vllm.model_executor.layers.attention.mm_encoder_attention

logger module-attribute

logger = init_logger(__name__)

MMEncoderAttention

Bases: CustomOp

Multi-headed attention without any cache, used for multimodal encoders.

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: softmax scale applied to the attention logits.
            num_kv_heads: number of KV heads.
            prefix: This has no effect; it is only here to make it easier
                    to swap between Attention and MultiHeadAttention.
            multimodal_config: multi-modal configuration.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set to the model's
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Try to get vision attention backend from multimodal_config.
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend

        # Get device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self._fa_version = (
            get_flash_attn_version() if self.is_flash_attn_backend else None
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        return True

    def maybe_reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        assert (cu_seqlens is not None and max_seqlen is not None) or (
            cu_seqlens is None and max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
            fa_version=self._fa_version,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.is_flash_attn_backend, (
            "XPU only supports FLASH_ATTN for vision attention."
        )
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
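
A minimal usage sketch, not taken from the vLLM documentation: it assumes the module can be constructed outside a full vLLM engine setup (in practice CustomOp initialization may expect additional configuration) and calls forward_native directly so the example stays device-agnostic. Sizes are illustrative.

import torch

from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)

# Hypothetical head layout for illustration only.
num_heads, head_size = 16, 64
attn = MMEncoderAttention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size**-0.5,
)

# (batch_size, seq_len, hidden_size) inputs; 3D tensors are reshaped to 4D
# internally by maybe_reshape_qkv_to_4d.
q = torch.randn(2, 256, num_heads * head_size)
k = torch.randn(2, 256, num_heads * head_size)
v = torch.randn(2, 256, num_heads * head_size)

# forward_native always takes the SDPA path, regardless of device.
out = attn.forward_native(q, k, v)
print(out.shape)  # torch.Size([2, 256, 1024])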

_fa_version instance-attribute

_fa_version = (
    get_flash_attn_version()
    if is_flash_attn_backend
    else None
)

attn_backend instance-attribute

attn_backend = get_vit_attn_backend(
    head_size=head_size,
    dtype=dtype,
    attn_backend_override=attn_backend_override,
)

head_size instance-attribute

head_size = head_size

is_flash_attn_backend instance-attribute

is_flash_attn_backend = attn_backend in {
    FLASH_ATTN,
    ROCM_AITER_FA,
}

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = (
    num_heads if num_kv_heads is None else num_kv_heads
)

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = scale

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
    multimodal_config: MultiModalConfig | None = None,
) -> None

Parameters:

    num_heads (int, required):
        number of attention heads per partition.
    head_size (int, required):
        hidden_size per attention head.
    scale (float | None, default None):
        softmax scale applied to the attention logits.
    num_kv_heads (int | None, default None):
        number of KV heads.
    prefix (str, default ''):
        This has no effect; it is only here to make it easier to swap
        between Attention and MultiHeadAttention.
    multimodal_config (MultiModalConfig | None, default None):
        multi-modal configuration.

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
    multimodal_config: MultiModalConfig | None = None,
) -> None:
    """
    Args:
        num_heads: number of attention heads per partition.
        head_size: hidden_size per attention head.
        scale: softmax scale applied to the attention logits.
        num_kv_heads: number of KV heads.
        prefix: This has no effect; it is only here to make it easier
                to swap between Attention and MultiHeadAttention.
        multimodal_config: multi-modal configuration.
    """
    super().__init__()

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = scale
    self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
    self.layer_name = prefix

    assert self.num_heads % self.num_kv_heads == 0, (
        f"num_heads ({self.num_heads}) is not "
        f"divisible by num_kv_heads ({self.num_kv_heads})"
    )
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    # During model initialization, the default dtype is set to the model's
    # weight and activation dtype.
    dtype = torch.get_default_dtype()

    # Try to get vision attention backend from multimodal_config.
    attn_backend_override = None
    if multimodal_config is not None:
        attn_backend_override = multimodal_config.mm_encoder_attn_backend

    # Get device-specific vision attention backend.
    self.attn_backend = get_vit_attn_backend(
        head_size=head_size,
        dtype=dtype,
        attn_backend_override=attn_backend_override,
    )

    self.is_flash_attn_backend = self.attn_backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }

    self._fa_version = (
        get_flash_attn_version() if self.is_flash_attn_backend else None
    )

    logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

_forward_fa

_forward_fa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_fa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    """Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)
    """
    assert (cu_seqlens is not None and max_seqlen is not None) or (
        cu_seqlens is None and max_seqlen is None
    ), "cu_seqlens and max_seqlen should be both set or both None."

    bsz, q_len = query.size()[:2]
    kv_len = key.size(1)
    is_reshaped = query.dim() != 4

    query, key, value = self.maybe_reshape_qkv_to_4d(
        query, key, value, bsz, q_len, kv_len
    )

    output = vit_flash_attn_wrapper(
        q=query,
        k=key,
        v=value,
        batch_size=bsz,
        is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
        fa_version=self._fa_version,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    if is_reshaped:
        output = output.reshape(bsz, q_len, -1)
    return output
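
The assertion above requires cu_seqlens and max_seqlen to be provided together. Below is a hedged sketch of how these are typically built under the standard varlen flash-attention convention (cumulative sequence lengths with a leading zero); the exact layout expected by vit_flash_attn_wrapper is an assumption here, not something stated on this page.

import torch

# Hypothetical per-image patch counts packed along one sequence dimension.
seq_lens = torch.tensor([196, 256, 144], dtype=torch.int32)

# Cumulative lengths with a leading zero, e.g. [0, 196, 452, 596].
cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)

# Kept as a 0-dim tensor to match the max_seqlen annotation above.
max_seqlen = seq_lens.max()

# Both are then passed together, e.g.:
# out = attn.forward_cuda(q, k, v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)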

_forward_sdpa

_forward_sdpa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_sdpa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)
    """
    bsz, q_len = query.size()[:2]
    kv_len = key.size(1)
    is_reshaped = query.dim() != 4

    query, key, value = self.maybe_reshape_qkv_to_4d(
        query, key, value, bsz, q_len, kv_len
    )

    output = vit_torch_sdpa_wrapper(
        q=query,
        k=key,
        v=value,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
    )
    if is_reshaped:
        output = output.reshape(bsz, q_len, -1)
    return output
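
When cu_seqlens is supplied on the SDPA path, the intent is that tokens attend only within their own packed sequence. The sketch below illustrates that semantics with a plain block-diagonal mask and torch.nn.functional.scaled_dot_product_attention; it is an illustration of the masking idea under that assumption, not the actual vit_torch_sdpa_wrapper implementation.

import torch
import torch.nn.functional as F

# Illustrative packed batch: three sequences of lengths 4, 3, and 5.
cu_seqlens = torch.tensor([0, 4, 7, 12])
total, num_heads, head_size = 12, 2, 8
q = k = v = torch.randn(1, num_heads, total, head_size)

# Block-diagonal boolean mask: True where attention is allowed.
mask = torch.zeros(total, total, dtype=torch.bool)
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    mask[start:end, start:end] = True

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 2, 12, 8])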

enabled classmethod

enabled() -> bool
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@classmethod
def enabled(cls) -> bool:
    return True

forward_cpu

forward_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_cpu(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    return self._forward_sdpa(query, key, value, cu_seqlens)

forward_cuda

forward_cuda(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_cuda(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    if self.is_flash_attn_backend:
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
    elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
        return self._forward_sdpa(query, key, value, cu_seqlens)
    else:
        raise ValueError(
            f"Unsupported multi-modal encoder attention backend for CUDA: "
            f"{self.attn_backend}."
        )

forward_native

forward_native(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_native(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    return self._forward_sdpa(query, key, value, cu_seqlens)

forward_xpu

forward_xpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_xpu(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    assert self.is_flash_attn_backend, (
        "XPU only supports FLASH_ATTN for vision attention."
    )
    return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

maybe_reshape_qkv_to_4d

maybe_reshape_qkv_to_4d(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[Tensor, Tensor, Tensor]

Reshape query, key, value to 4D tensors: (batch_size, seq_len, num_heads, head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def maybe_reshape_qkv_to_4d(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reshape query, key, value to 4D tensors:
    (batch_size, seq_len, num_heads, head_size)
    """
    query = query.view(bsz, q_len, self.num_heads, self.head_size)
    key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
    value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

    if (num_repeat := self.num_queries_per_kv) > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_repeat, dim=2)
        value = torch.repeat_interleave(value, num_repeat, dim=2)

    return query, key, value
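
A small illustration of the MQA/GQA expansion above, with made-up shapes: with 8 query heads and 2 KV heads, each KV head is repeated num_queries_per_kv = 4 times along the head dimension so that key and value match the query head count.

import torch

bsz, seq_len, head_size = 2, 16, 64
num_heads, num_kv_heads = 8, 2
num_queries_per_kv = num_heads // num_kv_heads  # 4

key = torch.randn(bsz, seq_len, num_kv_heads, head_size)
key_expanded = torch.repeat_interleave(key, num_queries_per_kv, dim=2)

print(key_expanded.shape)  # torch.Size([2, 16, 8, 64])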