Skip to content

vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe

TritonOrDeepGemmExperts

Bases: FallbackExperts

DeepGemm experts with a fallback to Triton experts for low-latency shapes.

Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
class TritonOrDeepGemmExperts(FallbackExperts):
    """Fused-MoE experts that prefer DeepGemm and fall back to Triton.

    The Triton implementation is used for low-latency shapes that the
    DeepGemm path does not accept. The choice is made per call in
    ``_select_experts_impl``; ``workspace_shapes`` sizes workspaces from the
    problem shape alone.
    """

    def __init__(self, quant_config: FusedMoEQuantConfig):
        # Both implementations are built from the same quantization config.
        # DeepGemm is registered as the primary ``experts`` and Triton as
        # ``fallback_experts`` of the base class.
        deep_gemm_impl = DeepGemmExperts(quant_config)
        triton_impl = TritonExperts(quant_config)
        super().__init__(experts=deep_gemm_impl, fallback_experts=triton_impl)

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: str,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the three workspace shapes for the implementation that fits.

        DeepGemm's shapes (``self.experts``) are reported when
        ``is_deep_gemm_e8m0_used()`` is true or when
        ``_valid_deep_gemm_shape(M, N, K)`` accepts the problem size.
        Otherwise Triton's shapes (``self.fallback_experts``) are reported.
        All arguments are forwarded unchanged to the chosen implementation.
        """
        # DeepGemm needs strictly larger workspaces than Triton. Sizing for
        # DeepGemm therefore stays safe even if the call is later routed to
        # Triton, e.g. when expert maps are set.
        shape_fits_deep_gemm = is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(
            M, N, K
        )
        sizing_impl = (
            self.experts if shape_fits_deep_gemm else self.fallback_experts
        )
        return sizing_impl.workspace_shapes(
            M,
            N,
            K,
            topk,
            global_num_experts,
            local_num_experts,
            expert_tokens_meta,
            activation,
        )

    def _select_experts_impl(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation to run for these inputs.

        Returns ``self.experts`` when ``is_deep_gemm_e8m0_used()`` is true.
        In that case ``_valid_deep_gemm`` is not evaluated at all. It also
        returns ``self.experts`` when ``_valid_deep_gemm(hidden_states, w1,
        w2)`` accepts the inputs. Otherwise it returns
        ``self.fallback_experts``.
        """
        # NOTE(review): E8M0 forces the DeepGemm path unconditionally.
        # Presumably the weights are then prepared in a DeepGemm-only format;
        # confirm.
        use_deep_gemm = is_deep_gemm_e8m0_used() or _valid_deep_gemm(
            hidden_states, w1, w2
        )
        return self.experts if use_deep_gemm else self.fallback_experts

__init__

__init__(quant_config: FusedMoEQuantConfig)
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def __init__(self, quant_config: FusedMoEQuantConfig):
    """Build the DeepGemm primary and Triton fallback from one quant config.

    The DeepGemm implementation is passed to the base class as ``experts``.
    The Triton implementation is passed as ``fallback_experts``.
    """
    deep_gemm_impl = DeepGemmExperts(quant_config)
    triton_impl = TritonExperts(quant_config)
    super().__init__(experts=deep_gemm_impl, fallback_experts=triton_impl)

_select_experts_impl

_select_experts_impl(
    hidden_states: Tensor, w1: Tensor, w2: Tensor
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def _select_experts_impl(
    self,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Pick the experts implementation to run for these inputs.

    Returns ``self.experts`` when ``is_deep_gemm_e8m0_used()`` is true.
    In that case ``_valid_deep_gemm`` is not evaluated at all. It also
    returns ``self.experts`` when ``_valid_deep_gemm(hidden_states, w1, w2)``
    accepts the inputs. Otherwise it returns ``self.fallback_experts``.
    """
    # NOTE(review): E8M0 forces the DeepGemm path unconditionally. Presumably
    # the weights are then prepared in a DeepGemm-only format; confirm.
    use_deep_gemm = is_deep_gemm_e8m0_used() or _valid_deep_gemm(
        hidden_states, w1, w2
    )
    return self.experts if use_deep_gemm else self.fallback_experts

workspace_shapes

workspace_shapes(
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: ExpertTokensMetadata | None,
    activation: str,
) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...]
]
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def workspace_shapes(
    self,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    """Return the three workspace shapes for the implementation that fits.

    DeepGemm's shapes (``self.experts``) are reported when
    ``is_deep_gemm_e8m0_used()`` is true or when
    ``_valid_deep_gemm_shape(M, N, K)`` accepts the problem size. Otherwise
    Triton's shapes (``self.fallback_experts``) are reported. All arguments
    are forwarded unchanged to the chosen implementation.
    """
    # DeepGemm needs strictly larger workspaces than Triton. Sizing for
    # DeepGemm therefore stays safe even if the call is later routed to
    # Triton, e.g. when expert maps are set.
    shape_fits_deep_gemm = is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(
        M, N, K
    )
    sizing_impl = self.experts if shape_fits_deep_gemm else self.fallback_experts
    return sizing_impl.workspace_shapes(
        M,
        N,
        K,
        topk,
        global_num_experts,
        local_num_experts,
        expert_tokens_meta,
        activation,
    )