Skip to content

vllm.compilation.rocm_aiter_fusion

FP8_DTYPE module-attribute

# FP8 dtype returned by fp8_dtype() (presumably the platform's preferred fp8
# format — verify); used as the quant dtype for every pattern registered by
# RocmAiterRMSNormFusionPass.
FP8_DTYPE = fp8_dtype()

logger module-attribute

# Module-level logger; the fusion passes use it to report how many patterns
# they replaced.
logger = init_logger(__name__)

AiterFusedAddRMSFp8GroupQuantPattern

Bases: AiterRMSNormQuantPattern

This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops into an aiter rms_norm_with_add_group_fp8_quant op.

Source code in vllm/compilation/rocm_aiter_fusion.py
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
    into an aiter rms_norm_with_add_group_fp8_quant op.

    The group size handed to the fused kernel is derived from ``group_shape``
    so the replacement always quantizes with the same grouping that the
    matched quant op was configured with.
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        """
        Args:
            epsilon: RMSNorm epsilon to match; also forwarded to FUSED_OP.
            quant_dtype: dtype of the quantized output.
            group_shape: quantization group shape; its column extent is used
                as the fused kernel's ``group_size``.
            match_aiter_quant: forwarded to the base class, which passes it to
                MatcherQuantFP8 as ``match_rocm_aiter``.
            symmetric: whether the matched quantization is symmetric.
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # Keep the replacement's group size in lock-step with the pattern's
        # quant key instead of hard-coding it.
        self.group_size = group_shape.col

        super().__init__(epsilon, key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the fused-add RMSNorm + group fp8 quant rewrite with
        ``pm_pass``, traced with ``pm.fwd_only`` on the RMSNorm matcher's
        example inputs."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Unfused form: residual-add RMSNorm followed by a separate quant.
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual_out, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            at = self.FUSED_OP(
                x=input,
                residual=residual,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=self.group_size,
            )

            # Must line up with pattern's (result, residual_out, scale).
            # NOTE(review): assumes FUSED_OP yields (quantized, residual, scale)
            # in that order — confirm against the op definition.
            return at[0], at[1], at[2]

        pm.register_replacement(
            pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
        )

FUSED_OP class-attribute instance-attribute

FUSED_OP = get_rmsnorm_group_add_fused_quant_op()

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape,
    match_aiter_quant: bool = True,
    symmetric: bool = True,
) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape,
    match_aiter_quant: bool = True,
    symmetric: bool = True,
) -> None:
    """Build a fused-add RMSNorm + group-quant key and initialise the base.

    Args:
        epsilon: RMSNorm epsilon to match.
        quant_dtype: dtype of the quantized output.
        group_shape: quantization group shape stored in the float32 scale
            descriptor.
        match_aiter_quant: forwarded to the base initialiser.
        symmetric: whether the matched quantization is symmetric.
    """
    quant_key = QuantKey(
        dtype=quant_dtype,
        scale=ScaleDesc(torch.float32, False, group_shape),
        symmetric=symmetric,
    )
    fused_key = FusedRMSQuantKey(fused_add=True, quant=quant_key)
    super().__init__(epsilon, fused_key, match_aiter_quant)

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the fused-add RMSNorm + group fp8 quant rewrite with ``pm_pass``.

    ``pattern`` is the subgraph to search for and ``replacement`` the fused
    subgraph substituted for each match. Both are traced with ``pm.fwd_only``
    on ``self.rmsnorm_matcher.inputs()`` as example inputs.
    """

    def pattern(
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Unfused form: residual-add RMSNorm followed by a separate quant.
        result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
        result, scale = self.quant_matcher(result_rms)

        return result, residual_out, scale

    def replacement(
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # NOTE(review): group_size is hard-coded to 128 rather than derived
        # from the group_shape given to __init__; this is only correct when
        # that shape's column extent is 128 (as RocmAiterRMSNormFusionPass uses).
        at = self.FUSED_OP(
            x=input,
            residual=residual,
            weight=weight,
            variance_epsilon=self.epsilon,
            group_size=128,
        )

        # Must line up with pattern's (result, residual_out, scale).
        # NOTE(review): assumes FUSED_OP yields (quantized, residual, scale)
        # in that order — confirm against the op definition.
        return at[0], at[1], at[2]

    pm.register_replacement(
        pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
    )

AiterFusedAddRMSNormDynamicQuantPattern

Bases: AiterRMSNormQuantPattern

AITER RMSNorm Fused Add + Dynamic Quantization pattern.

Source code in vllm/compilation/rocm_aiter_fusion.py
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
    """
    AITER RMSNorm Fused Add + Dynamic Quantization pattern.

    Matches a residual-add RMSNorm followed by a quant op and rewrites the
    pair into a single ``FUSED_OP`` call that returns the quantized tensor,
    the updated residual, and the scale.
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        match_aiter_quant: bool = True,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        quant_key = QuantKey(
            dtype=quant_dtype,
            scale=ScaleDesc(torch.float32, False, group_shape),
            symmetric=symmetric,
        )
        super().__init__(
            epsilon,
            FusedRMSQuantKey(fused_add=True, quant=quant_key),
            match_aiter_quant,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            normed, residual_out = self.rmsnorm_matcher(input, weight, residual)
            quantized, scale = self.quant_matcher(normed)
            return quantized, residual_out, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            fused = self.FUSED_OP(
                x=input,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
                quant_dtype=self.quant_dtype,
            )
            quantized, residual_out, scale = fused[0], fused[1], fused[2]
            return quantized, residual_out, scale

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )

FUSED_OP class-attribute instance-attribute

FUSED_OP = get_rmsnorm_fused_add_dynamic_quant_op()

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    match_aiter_quant: bool = True,
    group_shape: GroupShape = PER_TOKEN,
    symmetric: bool = True,
) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    match_aiter_quant: bool = True,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric: bool = True,
) -> None:
    """Build a fused-add RMSNorm + dynamic-quant key and initialise the base.

    ``group_shape`` defaults to per-token; the scale descriptor is float32.
    """
    quant_key = QuantKey(
        dtype=quant_dtype,
        scale=ScaleDesc(torch.float32, False, group_shape),
        symmetric=symmetric,
    )
    super().__init__(
        epsilon,
        FusedRMSQuantKey(fused_add=True, quant=quant_key),
        match_aiter_quant,
    )

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the fused-add RMSNorm + dynamic quant rewrite with ``pm_pass``.

    Traced with ``pm.fwd_only`` on ``self.rmsnorm_matcher.inputs()``.
    """

    def pattern(
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Unfused form: residual-add RMSNorm followed by a separate quant.
        result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
        result, scale = self.quant_matcher(result_rms)

        return result, residual_out, scale

    def replacement(
        input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        result = self.FUSED_OP(
            x=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
            quant_dtype=self.quant_dtype,
        )

        # Positional outputs must line up with pattern's
        # (result, residual_out, scale) — assumes FUSED_OP uses that order.
        return result[0], result[1], result[2]

    pm.register_replacement(
        pattern,
        replacement,
        self.rmsnorm_matcher.inputs(),
        pm.fwd_only,
        pm_pass,
    )

AiterRMSFp8GroupQuantPattern

Bases: AiterRMSNormQuantPattern

This pattern fuses aiter rms_norm & group fp8 quant custom ops into an aiter rms_norm_group_fp8_quant op.

Source code in vllm/compilation/rocm_aiter_fusion.py
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    This pattern fuses aiter rms_norm & group fp8 quant custom
    ops into an aiter rms_norm_group_fp8_quant op.

    The group size handed to the fused kernel is derived from ``group_shape``
    so the replacement always quantizes with the same grouping that the
    matched quant op was configured with.
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        """
        Args:
            epsilon: RMSNorm epsilon to match; also forwarded to FUSED_OP.
            quant_dtype: dtype of the quantized output.
            group_shape: quantization group shape; its column extent is used
                as the fused kernel's ``group_size``.
            match_aiter_quant: forwarded to the base class, which passes it to
                MatcherQuantFP8 as ``match_rocm_aiter``.
            symmetric: whether the matched quantization is symmetric.
        """
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # Keep the replacement's group size in lock-step with the pattern's
        # quant key instead of hard-coding it.
        self.group_size = group_shape.col

        super().__init__(epsilon, key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the RMSNorm + group fp8 quant rewrite with ``pm_pass``,
        traced with ``pm.fwd_only`` on the RMSNorm matcher's example inputs."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: RMSNorm followed by a separate quant op.
            result_rms = self.rmsnorm_matcher(input, weight)
            result, scale = self.quant_matcher(result_rms)
            return result, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at = self.FUSED_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=self.group_size,
            )

            # Must line up with pattern's (result, scale).
            return at[0], at[1]

        pm.register_replacement(
            pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
        )

FUSED_OP class-attribute instance-attribute

FUSED_OP = get_rmsnorm_group_fused_quant_op()

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape,
    match_aiter_quant: bool = True,
    symmetric: bool = True,
) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape,
    match_aiter_quant: bool = True,
    symmetric: bool = True,
) -> None:
    """Build a plain (no residual add) RMSNorm + group-quant key and
    initialise the base.

    The scale descriptor is float32 with the given ``group_shape``.
    """
    quant_key = QuantKey(
        dtype=quant_dtype,
        scale=ScaleDesc(torch.float32, False, group_shape),
        symmetric=symmetric,
    )
    plain_key = FusedRMSQuantKey(fused_add=False, quant=quant_key)
    super().__init__(epsilon, plain_key, match_aiter_quant)

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the RMSNorm + group fp8 quant rewrite with ``pm_pass``.

    Traced with ``pm.fwd_only`` on ``self.rmsnorm_matcher.inputs()``.
    """

    def pattern(
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Unfused form: RMSNorm followed by a separate quant op.
        result_rms = self.rmsnorm_matcher(input, weight)
        result, scale = self.quant_matcher(result_rms)
        return result, scale

    def replacement(
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # NOTE(review): group_size is hard-coded to 128 rather than derived
        # from the group_shape given to __init__; this is only correct when
        # that shape's column extent is 128.
        at = self.FUSED_OP(
            x=input,
            weight=weight,
            variance_epsilon=self.epsilon,
            group_size=128,
        )

        # Must line up with pattern's (result, scale).
        return at[0], at[1]

    pm.register_replacement(
        pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
    )

AiterRMSNormDynamicQuantPattern

Bases: AiterRMSNormQuantPattern

AITER RMSNorm + Dynamic Quantization pattern.

Source code in vllm/compilation/rocm_aiter_fusion.py
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
    """
    AITER RMSNorm + Dynamic Quantization pattern.

    Matches an RMSNorm (no residual add) followed by a quant op and rewrites
    the pair into a single ``FUSED_OP`` call returning (quantized, scale).
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        match_aiter_quant: bool = True,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        quant_key = QuantKey(
            dtype=quant_dtype,
            scale=ScaleDesc(torch.float32, False, group_shape),
            symmetric=symmetric,
        )
        super().__init__(
            epsilon,
            FusedRMSQuantKey(fused_add=False, quant=quant_key),
            match_aiter_quant,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            normed = self.rmsnorm_matcher(input, weight)
            quantized, scale = self.quant_matcher(normed)
            return quantized, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused = self.FUSED_OP(
                x=input,
                weight=weight,
                epsilon=self.epsilon,
                quant_dtype=self.quant_dtype,
            )
            quantized, scale = fused[0], fused[1]
            return quantized, scale

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )

FUSED_OP class-attribute instance-attribute

FUSED_OP = get_rmsnorm_fused_dynamic_quant_op()

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    match_aiter_quant: bool = True,
    group_shape: GroupShape = PER_TOKEN,
    symmetric: bool = True,
) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    match_aiter_quant: bool = True,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric: bool = True,
) -> None:
    """Build a plain RMSNorm + dynamic-quant key and initialise the base.

    ``group_shape`` defaults to per-token; the scale descriptor is float32.
    """
    quant_key = QuantKey(
        dtype=quant_dtype,
        scale=ScaleDesc(torch.float32, False, group_shape),
        symmetric=symmetric,
    )
    super().__init__(
        epsilon,
        FusedRMSQuantKey(fused_add=False, quant=quant_key),
        match_aiter_quant,
    )

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the RMSNorm + dynamic quant rewrite with ``pm_pass``.

    Traced with ``pm.fwd_only`` on ``self.rmsnorm_matcher.inputs()``.
    """

    def pattern(
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Unfused form: RMSNorm followed by a separate quant op.
        result_rms = self.rmsnorm_matcher(input, weight)
        result, scale = self.quant_matcher(result_rms)
        return result, scale

    def replacement(
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        result = self.FUSED_OP(
            x=input,
            weight=weight,
            epsilon=self.epsilon,
            quant_dtype=self.quant_dtype,
        )

        # Must line up with pattern's (result, scale).
        return result[0], result[1]

    pm.register_replacement(
        pattern,
        replacement,
        self.rmsnorm_matcher.inputs(),
        pm.fwd_only,
        pm_pass,
    )

AiterRMSNormQuantPattern

Source code in vllm/compilation/rocm_aiter_fusion.py
class AiterRMSNormQuantPattern:
    """
    Shared base for the AITER RMSNorm + fp8 quant fusion patterns.

    Stores ``epsilon`` and the quantized dtype taken from ``key`` and builds
    the two matchers subclasses compose in their ``pattern`` functions.
    """

    def __init__(
        self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
    ):
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        # The residual-add variant needs the fused-add matcher; the RMSNorm
        # matcher always targets the AITER ops.
        if key.fused_add:
            self.rmsnorm_matcher = MatcherFusedAddRMSNorm(
                epsilon, match_rocm_aiter=True
            )
        else:
            self.rmsnorm_matcher = MatcherRMSNorm(epsilon, match_rocm_aiter=True)

        # Whether the quant side matches the AITER or the built-in quant op
        # is left to the caller.
        self.quant_matcher = MatcherQuantFP8(
            key.quant, match_rocm_aiter=match_aiter_quant
        )

epsilon instance-attribute

epsilon = epsilon

quant_dtype instance-attribute

quant_dtype = dtype

quant_matcher instance-attribute

quant_matcher = MatcherQuantFP8(
    quant, match_rocm_aiter=match_aiter_quant
)

rmsnorm_matcher instance-attribute

rmsnorm_matcher = (
    MatcherRMSNorm(epsilon, match_rocm_aiter=True)
    if not fused_add
    else MatcherFusedAddRMSNorm(
        epsilon, match_rocm_aiter=True
    )
)

__init__

__init__(
    epsilon: float,
    key: FusedRMSQuantKey,
    match_aiter_quant: bool = True,
)
Source code in vllm/compilation/rocm_aiter_fusion.py
def __init__(
    self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
):
    """Store epsilon/quant dtype and build the RMSNorm and quant matchers.

    ``key.fused_add`` selects the fused-add RMSNorm matcher;
    ``match_aiter_quant`` is forwarded to MatcherQuantFP8 as
    ``match_rocm_aiter``.
    """
    self.epsilon = epsilon
    self.quant_dtype = key.quant.dtype

    if key.fused_add:
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(
            epsilon, match_rocm_aiter=True
        )
    else:
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon, match_rocm_aiter=True)

    self.quant_matcher = MatcherQuantFP8(
        key.quant, match_rocm_aiter=match_aiter_quant
    )

AiterSiluMulFp8GroupQuantPattern

Bases: ActivationQuantPattern

This pattern fuses aiter silu_and_mul & group fp8 quant custom ops into an aiter silu_and_mul_group_fp8_quant op.

Source code in vllm/compilation/rocm_aiter_fusion.py
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    """
    This pattern fuses aiter silu_and_mul & group fp8 quant custom
    ops into an aiter silu_and_mul_group_fp8_quant op.
    """

    FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()

    def __init__(self, quant_op: OpOverload, group_size: int = 128) -> None:
        """
        Args:
            quant_op: group fp8 quant op to match after silu_and_mul; the
                pattern calls it as ``quant_op(x, group_size)``.
            group_size: quant group size used by both the pattern and the
                fused replacement (defaults to the previously hard-coded 128).
        """
        self.silu_and_mul_matcher = MatcherSiluAndMul()
        self.quant_op = quant_op
        self.group_size = group_size

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing: the first silu_and_mul matcher input."""
        return [
            self.silu_and_mul_matcher.inputs()[0],
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the silu_and_mul + group fp8 quant rewrite with
        ``pm_pass``, traced with ``pm.fwd_only`` on ``self.get_inputs()``."""

        def pattern(
            input: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            at1 = self.silu_and_mul_matcher(input)
            at2 = self.quant_op(at1, self.group_size)
            return at2[0], at2[1]

        def replacement(
            input: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Same group size as the pattern so the fused op reproduces the
            # matched quantization.
            at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=self.group_size)
            return at[0], at[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

FUSED_SILU_MUL_QUANT_OP class-attribute instance-attribute

FUSED_SILU_MUL_QUANT_OP = (
    get_act_mul_fused_fp8_group_quant_op()
)

quant_op instance-attribute

quant_op = quant_op

silu_and_mul_matcher instance-attribute

silu_and_mul_matcher = MatcherSiluAndMul()

__init__

__init__(quant_op: OpOverload) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def __init__(self, quant_op: OpOverload) -> None:
    """Remember the quant op to match and create the silu_and_mul matcher."""
    self.quant_op = quant_op
    self.silu_and_mul_matcher = MatcherSiluAndMul()

get_inputs

get_inputs() -> list[Tensor]
Source code in vllm/compilation/rocm_aiter_fusion.py
def get_inputs(self) -> list[torch.Tensor]:
    """Return the tracing example inputs: only the first input produced by
    the silu_and_mul matcher, wrapped in a one-element list."""
    example_inputs = self.silu_and_mul_matcher.inputs()
    return [example_inputs[0]]

register

register(pm_pass: PatternMatcherPass) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the silu_and_mul + group fp8 quant rewrite with ``pm_pass``.

    Example inputs come from ``self.get_inputs()``; tracing uses
    ``pm.fwd_only``.
    """

    def pattern(
        input: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Unfused form: silu_and_mul followed by the configured quant op.
        # NOTE(review): group size 128 is hard-coded here and in
        # replacement; the two must agree.
        at1 = self.silu_and_mul_matcher(input)
        at2 = self.quant_op(at1, 128)
        # First two outputs of the quant op, presumably (quantized, scale)
        # — confirm against the op definition.
        return at2[0], at2[1]

    def replacement(
        input: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
        return at[0], at[1]

    pm.register_replacement(
        pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
    )

RocmAiterRMSNormFusionPass

Bases: VllmPatternMatcherPass

This pass fuses aiter rms_norm & vllm/aiter quant custom ops into a fused rms_norm_quant op. It also supports fused_add_rms_norm.

Source code in vllm/compilation/rocm_aiter_fusion.py
class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses aiter rms_norm & vllm/aiter quant custom ops
    into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register every AITER RMSNorm + fp8 quant fusion pattern.

        Patterns are registered for epsilon 1e-5 and 1e-6, covering group fp8
        quant with GroupShape(1, 128) and per-token dynamic fp8 quant (the
        latter matching both the aiter and the vllm built-in quant op).
        """
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
            AiterFusedAddRMSFp8GroupQuantPattern(
                epsilon, FP8_DTYPE, GroupShape(1, 128)
            ).register(self.patterns)

            # Fuse aiter rms_norm + aiter dynamic group fp8 quant
            AiterRMSFp8GroupQuantPattern(
                epsilon, FP8_DTYPE, GroupShape(1, 128)
            ).register(self.patterns)

            for match_aiter_quant in [True, False]:
                # Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
                # dynamic per-token fp8 quant
                AiterFusedAddRMSNormDynamicQuantPattern(
                    epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
                ).register(self.patterns)

                # Fuse aiter rms_norm + (aiter / vllm built-in)
                # dynamic per-token fp8 quant
                AiterRMSNormDynamicQuantPattern(
                    epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and log the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Identity hash over this pass and every pattern class it registers.

        The shared base AiterRMSNormQuantPattern is included because it builds
        the matchers each pattern traces, mirroring how the silu-mul pass
        hashes its base ActivationQuantPattern.
        """
        fusion_patterns = [
            AiterRMSNormQuantPattern,
            AiterRMSNormDynamicQuantPattern,
            AiterFusedAddRMSNormDynamicQuantPattern,
            AiterRMSFp8GroupQuantPattern,
            AiterFusedAddRMSFp8GroupQuantPattern,
        ]
        return self.hash_source(self, *fusion_patterns)

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
)

__call__

__call__(graph: Graph) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
    self.matched_count = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns", self.matched_count)

__init__

__init__(config: VllmConfig) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
    super().__init__(config)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
    )

    # Make sure fused add patterns are before simple rms norm,
    # as the latter is a subset of the former in torch ops
    for epsilon in [1e-5, 1e-6]:
        #  Fuse aiter rms_norm + aiter dynamic group fp8 quant
        AiterRMSFp8GroupQuantPattern(
            epsilon, FP8_DTYPE, GroupShape(1, 128)
        ).register(self.patterns)

        # Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
        AiterFusedAddRMSFp8GroupQuantPattern(
            epsilon, FP8_DTYPE, GroupShape(1, 128)
        ).register(self.patterns)

        for match_aiter_quant in [True, False]:
            # Fuse aiter rms_norm + (aiter / vllm built-in)
            # dynamic per-token fp8 quant
            AiterRMSNormDynamicQuantPattern(
                epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
            ).register(self.patterns)

            # Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
            # dynamic per-token fp8 quant
            AiterFusedAddRMSNormDynamicQuantPattern(
                epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
            ).register(self.patterns)

    self.dump_patterns(config, self.patterns)

uuid

uuid() -> str
Source code in vllm/compilation/rocm_aiter_fusion.py
def uuid(self) -> str:
    """Identity hash over this pass and every pattern class it registers.

    The shared base AiterRMSNormQuantPattern is included because it builds
    the matchers each pattern traces, mirroring how the silu-mul pass hashes
    its base ActivationQuantPattern.
    """
    fusion_patterns = [
        AiterRMSNormQuantPattern,
        AiterRMSNormDynamicQuantPattern,
        AiterFusedAddRMSNormDynamicQuantPattern,
        AiterRMSFp8GroupQuantPattern,
        AiterFusedAddRMSFp8GroupQuantPattern,
    ]
    return self.hash_source(self, *fusion_patterns)

RocmAiterSiluMulFp8GroupQuantFusionPass

Bases: VllmPatternMatcherPass

This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them.

Because patterns can only be registered once, the pass is a singleton. This will be addressed in a future version of PyTorch: https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980

Source code in vllm/compilation/rocm_aiter_fusion.py
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuse silu_and_mul followed by a group fp8 quant op into the single AITER
    silu_and_mul + group fp8 quant op, using the torch pattern matcher to find
    and replace the unfused sequence.

    One pattern is registered per entry in ``QUANT_OPS`` (the AITER and the
    Triton group-quant ops).

    NOTE(review): upstream documents this pass as a singleton because patterns
    can only be registered once
    (https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980);
    nothing in this class enforces that — confirm how callers instantiate it.
    """

    AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
    TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

    # Each of these quant ops gets its own fusion pattern.
    QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
        )
        for op in self.QUANT_OPS:
            fusion = AiterSiluMulFp8GroupQuantPattern(op)
            fusion.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        replaced = self.patterns.apply(graph)
        self.matched_count = replaced
        logger.debug("Replaced %s patterns", replaced)

    def uuid(self) -> str:
        # Hash this pass together with the pattern class and its base.
        return VllmInductorPass.hash_source(
            self, ActivationQuantPattern, AiterSiluMulFp8GroupQuantPattern
        )

AITER_GROUP_FP8_QUANT_OP class-attribute instance-attribute

AITER_GROUP_FP8_QUANT_OP = get_group_quant_op()

QUANT_OPS class-attribute instance-attribute

TRITON_GROUP_FP8_QUANT_OP class-attribute instance-attribute

TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)

__call__

__call__(graph: Graph) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
    self.matched_count = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns", self.matched_count)

__init__

__init__(config: VllmConfig) -> None
Source code in vllm/compilation/rocm_aiter_fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
    super().__init__(config)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
    )

    for quant_op in self.QUANT_OPS:
        AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)

    self.dump_patterns(config, self.patterns)

uuid

uuid() -> str
Source code in vllm/compilation/rocm_aiter_fusion.py
def uuid(self) -> str:
    """Identity hash over this pass, the activation-quant base pattern, and
    the AITER silu-mul group-quant pattern."""
    return VllmInductorPass.hash_source(
        self, ActivationQuantPattern, AiterSiluMulFp8GroupQuantPattern
    )