vllm.model_executor.layers.quantization.fp8

ACTIVATION_SCHEMES module-attribute

ACTIVATION_SCHEMES = ['static', 'dynamic']

logger module-attribute

logger = init_logger(__name__)

CopyNumelCounter

Bases: TorchDispatchMode

Tracks total number of elements modified with copy_. Useful for keeping track of weight loading where underlying weights can be arbitrarily transformed (such as with narrow) before calling copy.

Source code in vllm/model_executor/layers/quantization/fp8.py
class CopyNumelCounter(TorchDispatchMode):
    """
    Tracks total number of elements modified with `copy_`. Useful for keeping
    track of weight loading where underlying weights can be arbitrarily
    transformed (such as with `narrow`) before calling copy.
    """

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        out = func(*args, **kwargs)
        if func == torch.ops.aten.copy_.default:
            self.copied_numel += args[0].numel()
        return out
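
Because CopyNumelCounter is a TorchDispatchMode, it can be used as a context manager around any code that ends in a copy_ call. A minimal sketch (tensor shapes are illustrative; CopyNumelCounter is the class defined above):

import torch

param = torch.empty(4, 8)
loaded_shard = torch.randn(2, 8)

counter = CopyNumelCounter()
with counter:
    # Only the narrowed slice is written, so 2 * 8 = 16 elements are counted.
    param.narrow(0, 0, 2).copy_(loaded_shard)

assert counter.copied_numel == loaded_shard.numel()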

copied_numel instance-attribute

copied_numel = 0

__init__

__init__()
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self):
    super().__init__()
    self.copied_numel = 0

__torch_dispatch__

__torch_dispatch__(func, types, args=(), kwargs=None)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    out = func(*args, **kwargs)
    if func == torch.ops.aten.copy_.default:
        self.copied_numel += args[0].numel()
    return out

Fp8Config

Bases: QuantizationConfig

Config class for FP8.

Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)

            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

activation_scheme instance-attribute

activation_scheme = activation_scheme

ignored_layers instance-attribute

ignored_layers = ignored_layers or []

is_checkpoint_fp8_serialized instance-attribute

is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

weight_block_size instance-attribute

weight_block_size = weight_block_size

__init__

__init__(
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
    ignored_layers: list[str] | None = None,
    weight_block_size: list[int] | None = None,
) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
    ignored_layers: list[str] | None = None,
    weight_block_size: list[int] | None = None,
) -> None:
    super().__init__()

    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

    if activation_scheme not in ACTIVATION_SCHEMES:
        raise ValueError(f"Unsupported activation scheme {activation_scheme}")
    self.activation_scheme = activation_scheme
    self.ignored_layers = ignored_layers or []
    if weight_block_size is not None:
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        if len(weight_block_size) != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {len(weight_block_size)} dimensions"
            )
        if activation_scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{activation_scheme} activation scheme."
            )
    self.weight_block_size = weight_block_size

apply_vllm_mapper

apply_vllm_mapper(hf_to_vllm_mapper: WeightsMapper)
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
    if self.ignored_layers is not None:
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

from_config classmethod

from_config(config: dict[str, Any]) -> Fp8Config
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
    quant_method = cls.get_from_keys(config, ["quant_method"])
    is_checkpoint_fp8_serialized = "fp8" in quant_method
    activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
    ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
    weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
    if not ignored_layers:
        ignored_layers = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
    return cls(
        is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
        activation_scheme=activation_scheme,
        ignored_layers=ignored_layers,
        weight_block_size=weight_block_size,
    )
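
A hedged example of a Hugging Face quantization_config dict that from_config consumes; the values are illustrative and only exercise the key lookups shown above:

hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "modules_to_not_convert": ["lm_head"],
    "weight_block_size": [128, 128],
}

config = Fp8Config.from_config(hf_quant_config)
assert config.is_checkpoint_fp8_serialized
assert config.ignored_layers == ["lm_head"]
assert config.weight_block_size == [128, 128]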

get_cache_scale

get_cache_scale(name: str) -> str | None

Check whether the param name matches the format for k/v cache scales in compressed-tensors. If this is the case, return its equivalent param name expected by vLLM

:param name: param name
:return: matching param name for KV cache scale in vLLM

Source code in vllm/model_executor/layers/quantization/fp8.py
def get_cache_scale(self, name: str) -> str | None:
    """
    Check whether the param name matches the format for k/v cache scales
    in compressed-tensors. If this is the case, return its equivalent
    param name expected by vLLM

    :param name: param name
    :return: matching param name for KV cache scale in vLLM
    """
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    # If no matches, return None
    return None
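
For example (the layer path is hypothetical):

config = Fp8Config(is_checkpoint_fp8_serialized=True)
name = "model.layers.0.self_attn.k_proj.output_scale"
assert config.get_cache_scale(name) == "model.layers.0.self_attn.attn.k_scale"
assert config.get_cache_scale("model.layers.0.mlp.down_proj.weight") is None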

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return []

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_min_capability(cls) -> int:
    return 75

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "fp8"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/fp8.py
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    if current_platform.is_xpu():
        return self.get_xpu_quant_method(layer, prefix)
    if isinstance(layer, LinearBase):
        if is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedLinearMethod()
        quant_method = Fp8LinearMethod(self)
        quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method
    elif isinstance(layer, FusedMoE):
        if is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedFusedMoEMethod(layer.moe_config)
        if self.is_checkpoint_fp8_serialized:
            moe_quant_method = Fp8MoEMethod(self, layer)
        else:
            moe_quant_method = Fp8OnlineMoEMethod(self, layer)
        return moe_quant_method
    elif isinstance(layer, Attention):
        return Fp8KVCacheMethod(self)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half]

get_xpu_quant_method

get_xpu_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/fp8.py
def get_xpu_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    from vllm.model_executor.layers.quantization.ipex_quant import (
        XPUFp8LinearMethod,
        XPUFp8MoEMethod,
    )

    fp8_config = Fp8Config(
        is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
        activation_scheme=self.activation_scheme,
        ignored_layers=self.ignored_layers,
        weight_block_size=self.weight_block_size,
    )

    if isinstance(layer, LinearBase):
        if is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedLinearMethod()
        return XPUFp8LinearMethod(fp8_config)
    elif isinstance(layer, FusedMoE):
        if is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedFusedMoEMethod(layer.moe_config)

        return XPUFp8MoEMethod(fp8_config, layer)
    elif isinstance(layer, Attention):
        return Fp8KVCacheMethod(self)
    return None

Fp8KVCacheMethod

Bases: BaseKVCacheMethod

Supports loading kv-cache scaling factors from FP8 checkpoints.

Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)

__init__

__init__(quant_config: Fp8Config)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config):
    super().__init__(quant_config)

Fp8LinearMethod

Bases: LinearMethodBase

Linear method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale.

Also supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after the model weights are loaded.

Limitations:

1. Only supports per-tensor quantization due to torch._scaled_mm support.
2. Only supports the float8_e4m3fn data type due to the limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856).

Parameters:

Name          Type       Description               Default
quant_config  Fp8Config  The quantization config.  required
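
The per-tensor limitation above can be illustrated with plain PyTorch. This is a minimal sketch, not the kernel vLLM uses: in practice the weights are quantized with ops.scaled_fp8_quant and the GEMM runs in FP8 rather than being dequantized to bf16. Shapes are illustrative:

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

weight = torch.randn(4096, 4096, dtype=torch.bfloat16)  # [out_features, in_features]
scale = weight.abs().max().float() / fp8_max             # one scale for the whole tensor
qweight = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

# Reference computation: dequantize and run the GEMM in bf16.
x = torch.randn(8, 4096, dtype=torch.bfloat16)
out = x @ (qweight.to(torch.float32) * scale).to(torch.bfloat16).t()  # [8, 4096]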
Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            else:
                # per-tensor/channel: dequant to BF16 and run GEMM
                weight_fp8 = layer.weight.to(torch.bfloat16)
                weight_scale = layer.weight_scale.to(torch.bfloat16)
                if weight_scale.numel() == 1:
                    # Per-tensor: simple scalar multiplication
                    weight_bf16 = weight_fp8 * weight_scale
                else:
                    # Multiple scales (fused modules like QKV)
                    # Try to infer correct broadcasting
                    # weight is [K, N], scale could be [num_logical_weights]
                    # Need to figure out how to broadcast - for now just try
                    # direct multiplication
                    if (
                        weight_scale.dim() == 1
                        and weight_scale.shape[0] == weight_fp8.shape[0]
                    ):
                        # Per-row scaling
                        weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                    else:
                        # Fallback
                        weight_bf16 = weight_fp8 * weight_scale
                return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )

act_q_group_shape instance-attribute

act_q_group_shape = GroupShape(1, weight_block_size[0])

act_q_static instance-attribute

act_q_static = activation_scheme == 'static'

block_quant instance-attribute

block_quant = weight_block_size is not None

cutlass_block_fp8_supported instance-attribute

cutlass_block_fp8_supported = cutlass_block_fp8_supported()

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(
    act_quant_static=act_q_static,
    act_quant_group_shape=act_q_group_shape,
)

marlin_input_dtype instance-attribute

marlin_input_dtype = None

out_dtype instance-attribute

out_dtype = get_default_dtype()

quant_config instance-attribute

quant_config = quant_config

use_aiter_and_is_supported instance-attribute

use_aiter_and_is_supported = is_linear_fp8_enabled()

use_deep_gemm instance-attribute

use_deep_gemm = is_deep_gemm_supported()

use_marlin instance-attribute

use_marlin = (
    not has_device_capability(89)
    or VLLM_TEST_FORCE_FP8_MARLIN
)

w8a8_block_fp8_linear instance-attribute

w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
    weight_group_shape=GroupShape(*(weight_block_size)),
    act_quant_group_shape=act_q_group_shape,
    cutlass_block_fp8_supported=cutlass_block_fp8_supported,
    use_aiter_and_is_supported=use_aiter_and_is_supported,
)

weight_block_size instance-attribute

weight_block_size = weight_block_size

__init__

__init__(quant_config: Fp8Config)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config):
    self.quant_config = quant_config
    self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
    self.out_dtype = torch.get_default_dtype()

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.marlin_input_dtype = None
    self.use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    # Disable marlin for rocm
    if current_platform.is_rocm():
        self.use_marlin = False
    if vllm_is_batch_invariant():
        self.use_marlin = False

    self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
    self.use_deep_gemm = is_deep_gemm_supported()

    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant = self.weight_block_size is not None
    self.act_q_static = self.quant_config.activation_scheme == "static"
    if self.weight_block_size:
        self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
    else:
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

    if self.block_quant:
        assert not self.act_q_static
        assert self.weight_block_size is not None
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(*self.weight_block_size),
            act_quant_group_shape=self.act_q_group_shape,
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
        )
    else:
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape,
        )

apply

apply(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
    # we will use BF16 dequant when DeepGEMM is not supported.
    if vllm_is_batch_invariant():
        if self.block_quant:
            assert self.weight_block_size is not None
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )
        else:
            # per-tensor/channel: dequant to BF16 and run GEMM
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # Multiple scales (fused modules like QKV)
                # Try to infer correct broadcasting
                # weight is [K, N], scale could be [num_logical_weights]
                # Need to figure out how to broadcast - for now just try
                # direct multiplication
                if (
                    weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    # Per-row scaling
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    # Fallback
                    weight_bf16 = weight_fp8 * weight_scale
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)

    if self.use_marlin:
        if self.block_quant:
            weight_scale = layer.weight_scale_inv
        else:
            weight_scale = layer.weight_scale

        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            input_dtype=self.marlin_input_dtype,
            bias=bias,
        )

    if self.block_quant:
        assert self.weight_block_size is not None

        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )

    return self.fp8_linear.apply(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        out_dtype=self.out_dtype,
        input_scale=layer.input_scale,
        bias=bias,
    )

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    maybe_create_device_identity()

    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    if self.block_quant:
        assert self.weight_block_size is not None
        layer.weight_block_size = self.weight_block_size
        validate_fp8_block_shape(
            layer,
            input_size,
            output_size,
            input_size_per_partition,
            output_partition_sizes,
            self.weight_block_size,
        )

    # WEIGHT
    if self.quant_config.is_checkpoint_fp8_serialized:
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
    else:

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call from doing
                # anything
                layer._already_called_process_weights_after_loading = True

            return res

        # For non-serialized checkpoints, use original dtype
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
    layer.register_parameter("weight", weight)

    # If checkpoint is serialized fp8, load them.
    # Otherwise, wait until process_weights_after_loading.
    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)
        else:
            layer.register_parameter("input_scale", None)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return

    size_k_first = True
    input_scale = None
    # TODO(rob): refactor block quant into separate class.
    if self.block_quant:
        assert not self.act_q_static
        size_k_first = False

        weight, weight_scale_inv = process_fp8_weight_block_strategy(
            layer.weight, layer.weight_scale_inv
        )

        # Update layer with new values
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

    # If checkpoint not serialized fp8, quantize the weights.
    else:
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = (
                    process_fp8_weight_tensor_strategy(
                        weight,
                        weight_scale,
                        layer.logical_widths,
                        getattr(layer, "input_scale", None),
                    )
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

    if input_scale is not None:
        replace_parameter(layer, "input_scale", input_scale)
    else:
        layer.input_scale = None

    if self.use_marlin:
        prepare_fp8_layer_for_marlin(
            layer, size_k_first, input_dtype=self.marlin_input_dtype
        )
        # Activations not quantized for marlin.
        del layer.input_scale
        return

    if self.block_quant:
        maybe_post_process_fp8_weight_block(layer)

Fp8MoEMethod

Bases: FusedMoEMethodBase

MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale.

Also supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after the model weights are loaded.

Parameters:

Name          Type       Description               Default
quant_config  Fp8Config  The quantization config.  required
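
For block-wise checkpoints, the per-block weight-scale shapes allocated in create_weights can be worked out as follows (a sketch with hypothetical sizes and block size [128, 128]):

num_experts, hidden_size, intermediate = 8, 4096, 11008
block_n = block_k = 128

# w13 packs the gate and up projections, hence the factor of 2.
w13_scale_shape = (
    num_experts,
    2 * ((intermediate + block_n - 1) // block_n),  # 2 * 86 = 172
    (hidden_size + block_k - 1) // block_k,         # 32
)
w2_scale_shape = (
    num_experts,
    (hidden_size + block_n - 1) // block_n,         # 32
    (intermediate + block_k - 1) // block_k,        # 86
)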
Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        self.fp8_backend = select_fp8_moe_backend(
            block_quant=self.block_quant,
            tp_size=layer.moe_parallel_config.tp_size,
            with_lora_support=self.moe.is_lora_enabled,
        )

        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    "activation function, but got {layer.activation}."
                )
        dynamic_per_token = (
            not self.block_quant and self.quant_config.activation_scheme != "static"
        )
        if dynamic_per_token and self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_TRTLLM,
            Fp8MoeBackend.FLASHINFER_CUTLASS,
        ]:
            raise NotImplementedError(
                "FlashInfer FP8 MoE backend does not support dynamic per token "
                "activation quantization."
            )

        self.kernel: mk.FusedMoEModularKernel | None = None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                layer=layer,
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per-tensor kernels require a single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per-tensor kernels require a single weight scale for w13 per expert,
        # but on disk there is one scale for w1 and one for w3. Use the max to
        # requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        if self.fp8_backend in [
            Fp8MoeBackend.AITER,
            Fp8MoeBackend.MARLIN,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(self.moe_quant_config)
        else:
            assert self.fp8_backend == Fp8MoeBackend.TRITON
            logger.debug(
                "TritonExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonExperts(self.moe_quant_config)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                return apply_fi_trtllm_fp8_per_tensor_moe(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        assert self.kernel is not None
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
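
The static and dynamic activation schemes handled above differ only in where the activation scale comes from: the static scheme loads pre-calibrated w13_input_scale / w2_input_scale parameters, while the dynamic scheme derives the scale from each activation at runtime. The snippet below is a minimal, illustrative sketch of that difference in plain PyTorch; it is not part of the vLLM source, and quantize_activation is a hypothetical helper.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_activation(x: torch.Tensor, static_scale: torch.Tensor | None = None):
    if static_scale is None:
        # "dynamic": derive the scale from the runtime activation.
        scale = x.abs().amax().clamp(min=1e-12) / FP8_MAX
    else:
        # "static": use a pre-calibrated scale (loaded as w13_input_scale /
        # w2_input_scale in the method above).
        scale = static_scale
    x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scale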

allow_inplace property

allow_inplace: bool

block_quant instance-attribute

block_quant: bool = weight_block_size is not None

fp8_backend instance-attribute

fp8_backend = select_fp8_moe_backend(
    block_quant=block_quant,
    tp_size=tp_size,
    with_lora_support=is_lora_enabled,
)

kernel instance-attribute

kernel: FusedMoEModularKernel | None = None

quant_config instance-attribute

quant_config = quant_config

supports_eplb property

supports_eplb: bool

weight_block_size instance-attribute

weight_block_size = weight_block_size

weight_scale_name instance-attribute

weight_scale_name = (
    "weight_scale_inv" if block_quant else "weight_scale"
)

__init__

__init__(quant_config: Fp8Config, layer: Module)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
    super().__init__(layer.moe_config)
    self.quant_config = quant_config
    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant: bool = self.weight_block_size is not None
    self.weight_scale_name = (
        "weight_scale_inv" if self.block_quant else "weight_scale"
    )
    self.fp8_backend = select_fp8_moe_backend(
        block_quant=self.block_quant,
        tp_size=layer.moe_parallel_config.tp_size,
        with_lora_support=self.moe.is_lora_enabled,
    )

    if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
        if self.block_quant and self.weight_block_size != [128, 128]:
            raise NotImplementedError(
                "FlashInfer CUTLASS FP8 MoE backend only supports block "
                "size [128, 128]."
            )
        if layer.activation != "silu":
            raise NotImplementedError(
                "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                "activation function, but got {layer.activation}."
            )
    dynamic_per_token = (
        not self.block_quant and self.quant_config.activation_scheme != "static"
    )
    if dynamic_per_token and self.fp8_backend in [
        Fp8MoeBackend.FLASHINFER_TRTLLM,
        Fp8MoeBackend.FLASHINFER_CUTLASS,
    ]:
        raise NotImplementedError(
            "FlashInfer FP8 MoE backend does not support dynamic per token "
            "activation quantization."
        )

    self.kernel: mk.FusedMoEModularKernel | None = None

_setup_kernel

_setup_kernel(
    layer: Module,
    w13: Tensor,
    w2: Tensor,
    w13_scale: Tensor,
    w2_scale: Tensor,
    w13_input_scale: Tensor | None,
    w2_input_scale: Tensor | None,
) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def _setup_kernel(
    self,
    layer: Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_input_scale: torch.Tensor | None,
    w2_input_scale: torch.Tensor | None,
) -> None:
    # Shuffle weights to runtime format.
    w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
        fp8_backend=self.fp8_backend,
        layer=layer,
        w13=w13,
        w2=w2,
        w13_scale=w13_scale,
        w2_scale=w2_scale,
        w13_input_scale=w13_input_scale,
        w2_input_scale=w2_input_scale,
    )

    # Replace parameters with updated versions. Note that this helper
    # function ensures the replacement is compatible with RL weight reloads.
    replace_parameter(layer, "w13_weight", w13)
    replace_parameter(layer, "w2_weight", w2)
    replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
    replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

    # Setup modular kernel for TP case.
    self.moe_quant_config = self.get_fused_moe_quant_config(layer)
    if self.moe_quant_config:
        self.kernel, self.use_inplace = make_fp8_moe_kernel(
            layer=layer,
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            fp8_backend=self.fp8_backend,
        )

apply

apply(
    layer: FusedMoE,
    router: FusedMoERouter,
    x: Tensor,
    router_logits: Tensor,
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply(
    self,
    layer: FusedMoE,
    router: FusedMoERouter,
    x: torch.Tensor,
    router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
        # TODO(rob): convert this to MK.
        if layer.enable_eplb:
            raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
        assert layer.activation == "silu", (
            f"Expected 'silu' activation but got {layer.activation}"
        )

        if self.block_quant:
            import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

            e_score_correction_bias = (
                layer.e_score_correction_bias.to(x.dtype)
                if layer.e_score_correction_bias is not None
                else None
            )
            routing_method_type = layer.routing_method_type
            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                routing_logits=router_logits.to(torch.float32)
                if routing_method_type == RoutingMethodType.DeepSeekV3
                else router_logits,
                routing_bias=e_score_correction_bias,
                x=x,
                w13_weight=layer.w13_weight,
                w13_weight_scale_inv=layer.w13_weight_scale_inv,
                w2_weight=layer.w2_weight,
                w2_weight_scale_inv=layer.w2_weight_scale_inv,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                block_shape=self.weight_block_size,
                routing_method_type=routing_method_type,
                routed_scaling=layer.routed_scaling_factor,
            )
        else:
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

    topk_weights, topk_ids = router.select_experts(
        hidden_states=x,
        router_logits=router_logits,
    )

    assert self.kernel is not None
    result = self.kernel(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        inplace=self.use_inplace,
        activation=layer.activation,
        global_num_experts=layer.global_num_experts,
        expert_map=layer.expert_map,
        apply_router_weight_on_input=layer.apply_router_weight_on_input,
    )

    return result

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    layer.intermediate_size_per_partition = intermediate_size_per_partition
    layer.hidden_size = hidden_size
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    assert self.quant_config.is_checkpoint_fp8_serialized
    params_dtype = torch.float8_e4m3fn

    if self.block_quant:
        assert self.weight_block_size is not None
        layer.weight_block_size = self.weight_block_size
        tp_size = get_tensor_model_parallel_world_size()
        block_n, block_k = (
            self.weight_block_size[0],
            self.weight_block_size[1],
        )
        # NOTE: To ensure proper alignment of the block-wise quantization
        # scales, the output_size of the weights for both the gate and up
        # layers must be divisible by block_n.
        # Required by column parallel or enabling merged weights
        if intermediate_size_per_partition % block_n != 0:
            raise ValueError(
                f"The output_size of gate's and up's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_n = {block_n}."
            )
        if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
            # Required by row parallel
            raise ValueError(
                f"The input_size of down's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_k = {block_k}."
            )

    # WEIGHTS
    w13_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    if not self.block_quant:
        # For per-tensor quant, w13 has one scale per expert for each of w1 and
        # w3 (combined into a single scale after loading); w2 has one scale per
        # expert.
        w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
        w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
    else:
        # For block quant, the scales are per block (typically 128x128).
        w13_scale_data = torch.ones(
            num_experts,
            2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
            (hidden_size + block_k - 1) // block_k,
            dtype=torch.float32,
        )
        w2_scale_data = torch.ones(
            num_experts,
            (hidden_size + block_n - 1) // block_n,
            (intermediate_size_per_partition + block_k - 1) // block_k,
            dtype=torch.float32,
        )
    w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
    w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
    # Note: name is weight_scale for tensor, weight_scale_inv for block.
    layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
    layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

    # Add the quantization method used (per tensor/grouped/channel)
    # to ensure the weight scales are loaded in properly
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        if self.block_quant
        else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
    )
    set_weight_attrs(w13_weight_scale, extra_weight_attrs)
    set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    # INPUT_SCALES
    if self.quant_config.activation_scheme == "static":
        assert not self.block_quant
        w13_input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

    else:
        layer.w13_input_scale = None
        layer.w2_input_scale = None
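
For the block-quantized path above, the scale tensors hold one entry per block_n x block_k tile of the weight. A worked sketch of the resulting shapes, assuming weight_block_size = [128, 128] and made-up expert and hidden sizes (none of these numbers come from the source):

num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 14336
block_n, block_k = 128, 128

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

w13_scale_shape = (
    num_experts,
    2 * ceil_div(intermediate_size_per_partition, block_n),  # gate + up halves
    ceil_div(hidden_size, block_k),
)
w2_scale_shape = (
    num_experts,
    ceil_div(hidden_size, block_n),
    ceil_div(intermediate_size_per_partition, block_k),
)
# w13_scale_shape == (8, 224, 32), w2_scale_shape == (8, 32, 112)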

get_fused_moe_quant_config

get_fused_moe_quant_config(
    layer: Module,
) -> FusedMoEQuantConfig | None
Source code in vllm/model_executor/layers/quantization/fp8.py
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    # TRTLLM does not use Modular Kernel.
    if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
        return None

    w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
    w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
    a1_scale = layer.w13_input_scale
    a2_scale = layer.w2_input_scale

    return make_fp8_moe_quant_config(
        fp8_backend=self.fp8_backend,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=self.weight_block_size,
    )

maybe_make_prepare_finalize

maybe_make_prepare_finalize(
    routing_tables: tuple[Tensor, Tensor, Tensor]
    | None = None,
) -> FusedMoEPrepareAndFinalize | None
Source code in vllm/model_executor/layers/quantization/fp8.py
def maybe_make_prepare_finalize(
    self,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
    if self.fp8_backend in [
        Fp8MoeBackend.AITER,
        Fp8MoeBackend.MARLIN,
        Fp8MoeBackend.FLASHINFER_TRTLLM,
    ]:
        return None
    elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            self.moe,
            use_deepseek_fp8_block_scale=self.block_quant,
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize
    return super().maybe_make_prepare_finalize(routing_tables)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return

    # Allow for accessing weights and scales in standard way.
    w13 = layer.w13_weight
    w2 = layer.w2_weight
    w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
    w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
    w13_input_scale = layer.w13_input_scale
    w2_input_scale = layer.w2_input_scale

    # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
    if current_platform.is_fp8_fnuz():
        w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
            w13,
            w13_scale,
            w13_input_scale,
        )
        w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
            w2,
            w2_scale,
            w2_input_scale,
        )

    # Per-tensor kernels require a single activation scale. Use the max.
    if self.quant_config.activation_scheme == "static":
        assert not self.block_quant
        assert w13_input_scale is not None and w2_input_scale is not None
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

    # Per-tensor kernels require a single weight scale for w13 per expert,
    # but on disk there is one scale for w1 and one for w3. Use the max to
    # requantize.
    if not self.block_quant:
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13, w13_scale, shard_size, layer.local_num_experts
        )

    # Shuffle weights to runtime format and setup kernel.
    self._setup_kernel(
        layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
    )
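
The per-tensor branch above delegates the w1/w3 scale merge to process_fp8_weight_tensor_strategy_moe. The sketch below shows the underlying idea only: keep the larger of the two per-expert scales and requantize both halves of w13 with it so a single scale remains. It is illustrative, not the helper's actual implementation, and merge_w13_scales is a hypothetical name.

import torch

def merge_w13_scales(
    w13_q: torch.Tensor,      # (E, 2 * N, K) quantized w13 weights
    w13_scale: torch.Tensor,  # (E, 2) scales: one for w1, one for w3, per expert
    shard_size: int,          # N, the per-partition intermediate size
) -> tuple[torch.Tensor, torch.Tensor]:
    max_scale = w13_scale.max(dim=1).values  # (E,) keep the larger scale
    out = torch.empty_like(w13_q)
    for e in range(w13_q.shape[0]):
        for i in range(2):
            rows = slice(i * shard_size, (i + 1) * shard_size)
            # Dequantize with the original scale, requantize with the max.
            dequant = w13_q[e, rows].float() * w13_scale[e, i]
            out[e, rows] = (dequant / max_scale[e]).to(w13_q.dtype)
    return out, max_scale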

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: Module,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/fp8.py
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
    from vllm.model_executor.layers.fused_moe import (
        BatchedDeepGemmExperts,
        BatchedTritonExperts,
        TritonExperts,
        TritonOrDeepGemmExperts,
    )

    if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
        raise NotImplementedError(
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

    assert self.moe_quant_config is not None

    if (
        prepare_finalize.activation_format
        == FusedMoEActivationFormat.BatchedExperts
    ):
        max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens_per_rank is not None

        experts_impl = (
            BatchedDeepGemmExperts
            if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
            else BatchedTritonExperts
        )
        logger.debug(
            "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
            experts_impl.__name__,
            self.__class__.__name__,
            max_num_tokens_per_rank,
            self.weight_block_size,
            False,
        )
        return experts_impl(
            max_num_tokens=max_num_tokens_per_rank,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            quant_config=self.moe_quant_config,
        )
    elif self.moe.is_lora_enabled:
        return TritonExperts(quant_config=self.moe_quant_config)
    elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
        # Select GEMM experts with block-scale when weights are block-quantized
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
            use_deepseek_fp8_block_scale=self.block_quant,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts
    elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__,
            self.weight_block_size,
            False,
        )
        return TritonOrDeepGemmExperts(self.moe_quant_config)
    else:
        assert self.fp8_backend == Fp8MoeBackend.TRITON
        logger.debug(
            "TritonExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__,
            self.weight_block_size,
            False,
        )
        return TritonExperts(self.moe_quant_config)

Fp8OnlineMoEMethod

Bases: Fp8MoEMethod

MoE method for online FP8 quantization. Supports loading unquantized FP16/BF16 model checkpoints and quantizing the weights to FP8 online, with dynamic activation scaling. The weight scaling factors are initialized after the model weights are loaded.

Parameters:

Name          Type       Description               Default
quant_config  Fp8Config  The quantization config.  required
Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.
    Supports loading unquantized FP16/BF16 model checkpoints and quantizing
    the weights to FP8 online, with dynamic activation scaling. The weight
    scaling factors are initialized after the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, so patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # Create a shallow copy so we don't modify the behavior of any other
        # objects that may still hold a reference to the original dict.
        new_extra_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate one scale per expert for w13 and for w2. The values are
        # computed during online quantization in process_weights_after_loading.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # The checkpoint is FP16/BF16: quantize each expert's weights to FP8.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
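
The per-expert loop in process_weights_after_loading above relies on ops.scaled_fp8_quant to return an FP8 weight tensor plus a single float32 scale per expert. The pure-PyTorch sketch below shows that per-tensor math only; the real op is a fused kernel, and reference_scaled_fp8_quant is a hypothetical name.

import torch

def reference_scaled_fp8_quant(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale for the whole tensor, chosen so the largest magnitude maps to
    # the edge of the e4m3 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = w.abs().amax().clamp(min=1e-12) / fp8_max
    w_q = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_q, scale.to(torch.float32)

# Usage mirroring the per-expert loop above (the tensors are placeholders):
# for expert in range(num_local_experts):
#     w13[expert], w13_scale[expert] = reference_scaled_fp8_quant(w13_fp16[expert])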

__init__

__init__(quant_config: Fp8Config, layer: Module)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
    super().__init__(quant_config, layer)
    assert not quant_config.is_checkpoint_fp8_serialized
    assert quant_config.activation_scheme == "dynamic"
    assert quant_config.weight_block_size is None

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    layer.intermediate_size_per_partition = intermediate_size_per_partition
    layer.hidden_size = hidden_size
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    # We are doing online quantization, so patch the weight loader
    # to call `process_weights_after_loading` in a streaming fashion
    # as soon as the last weight chunk is loaded.
    weight_loader = extra_weight_attrs["weight_loader"]
    # Create a shallow copy so we don't modify the behavior of any other
    # objects that may still hold a reference to the original dict.
    new_extra_weight_attrs = dict(extra_weight_attrs)

    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        # add a counter to track how many elements we have updated
        if not hasattr(layer, "_loaded_numel"):
            layer._loaded_numel = 0

        # load the current weight chunk
        copy_numel_counter = CopyNumelCounter()
        with copy_numel_counter:
            res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
        layer._loaded_numel += copy_numel_counter.copied_numel

        # if we have loaded all of the elements, call
        # process_weights_after_loading
        target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
        if layer._loaded_numel == target_loaded_numel:
            self.process_weights_after_loading(layer)

            # Delete the bookkeeping
            del layer._loaded_numel
            # Prevent the usual `process_weights_after_loading` call
            # from doing anything
            layer._already_called_process_weights_after_loading = True

        return res

    new_extra_weight_attrs["weight_loader"] = patched_weight_loader
    extra_weight_attrs = new_extra_weight_attrs

    # WEIGHTS
    w13_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    # Allocate one scale per expert for w13 and for w2. The values are
    # computed during online quantization in process_weights_after_loading.
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, dtype=torch.float32), requires_grad=False
    )
    w2_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, dtype=torch.float32), requires_grad=False
    )
    layer.register_parameter("w13_weight_scale", w13_weight_scale)
    layer.register_parameter("w2_weight_scale", w2_weight_scale)
    set_weight_attrs(w13_weight_scale, extra_weight_attrs)
    set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    layer.w13_input_scale = None
    layer.w2_input_scale = None
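
The patched_weight_loader above detects the end of loading by comparing the number of elements written, measured with CopyNumelCounter around the wrapped loader, against w13_weight.numel() + w2_weight.numel(). The toy sketch below illustrates only that completion check; the layer, shapes, and chunk sizes are made up.

import torch

class _ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w13_weight = torch.nn.Parameter(torch.empty(2, 8, 4), requires_grad=False)
        self.w2_weight = torch.nn.Parameter(torch.empty(2, 4, 8), requires_grad=False)

layer = _ToyLayer()
target_numel = layer.w13_weight.numel() + layer.w2_weight.numel()  # 64 + 64 = 128

loaded_numel = 0
for chunk_numel in (32, 32, 32, 32):  # pretend four shards arrive one by one
    loaded_numel += chunk_numel       # the real path measures this with CopyNumelCounter
    if loaded_numel == target_numel:
        # This is the point where process_weights_after_loading(layer) fires.
        print("all expert weights loaded")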

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return

    # The checkpoint is FP16/BF16: quantize each expert's weights to FP8.
    fp8_dtype = current_platform.fp8_dtype()
    w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
    w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
    w13_scale = layer.w13_weight_scale
    w2_scale = layer.w2_weight_scale

    for expert in range(layer.local_num_experts):
        w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
            layer.w13_weight[expert, :, :]
        )
        w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
            layer.w2_weight[expert, :, :]
        )

    # Shuffle weights to runtime format and setup kernel.
    self._setup_kernel(
        layer,
        w13,
        w2,
        w13_scale,
        w2_scale,
        layer.w13_input_scale,
        layer.w2_input_scale,
    )