
vllm.v1.attention.backends.flash_attn_diffkv

Attention layer with FlashAttention, supporting different head sizes for keys and values (diff-KV).

logger module-attribute

logger = init_logger(__name__)

FlashAttentionDiffKVBackend

Bases: FlashAttentionBackend

Source code in vllm/v1/attention/backends/flash_attn_diffkv.py
class FlashAttentionDiffKVBackend(FlashAttentionBackend):
    # Default to 128 for this backend
    head_size_v: int = 128

    @classmethod
    def set_head_size_v(cls, head_size_v: int) -> None:
        cls.head_size_v = head_size_v

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_DIFFKV"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        return FlashAttentionDiffKVImpl

    # Do not modify the interface of get_kv_cache_shape,
    # but consider head_size_v when returning result.
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (
            num_blocks,
            block_size,
            num_kv_heads,
            head_size + FlashAttentionDiffKVBackend.head_size_v,
        )

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD" and include_num_layers_dimension:
            # (num_blocks, num_layers, block_size,
            # num_kv_heads, head_size + head_size_v)
            return (1, 0, 2, 3, 4)
        elif cache_layout == "NHD":
            stride_order = (0, 1, 2, 3)
        elif cache_layout == "HND" and include_num_layers_dimension:
            # (num_blocks, num_kv_heads, num_layers,
            # block_size, head_size + head_size_v)
            return (1, 3, 0, 2, 4)
        elif cache_layout == "HND":
            stride_order = (0, 2, 1, 3)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order
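
Because the key and value head sizes differ, this backend packs both along the last cache dimension. Below is a minimal standalone sketch (plain PyTorch, illustrative sizes only, not the vLLM API itself) of how such a packed cache splits back into per-tensor views, mirroring what the impl does in forward():

import torch

# Illustrative sizes only; head_size (for K) and head_size_v need not match.
num_blocks, block_size, num_kv_heads = 4, 16, 2
head_size, head_size_v = 192, 128

# Packed layout from get_kv_cache_shape:
# (num_blocks, block_size, num_kv_heads, head_size + head_size_v)
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size + head_size_v)

# Slicing the packed last dimension yields separate K and V views (no copy),
# mirroring key_cache / value_cache in FlashAttentionDiffKVImpl.forward().
key_cache = kv_cache[..., :head_size]
value_cache = kv_cache[..., head_size:]
assert key_cache.shape == (4, 16, 2, 192)
assert value_cache.shape == (4, 16, 2, 128)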

head_size_v class-attribute instance-attribute

head_size_v: int = 128

get_impl_cls staticmethod

get_impl_cls() -> type[FlashAttentionImpl]
Source code in vllm/v1/attention/backends/flash_attn_diffkv.py
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
    return FlashAttentionDiffKVImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flash_attn_diffkv.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    if block_size % 16 != 0:
        raise ValueError("Block size must be a multiple of 16.")
    return (
        num_blocks,
        block_size,
        num_kv_heads,
        head_size + FlashAttentionDiffKVBackend.head_size_v,
    )
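
A hedged usage sketch: assuming the backend is importable from the module path shown on this page, the returned shape packs both head sizes into the last dimension (numeric values below are illustrative):

from vllm.v1.attention.backends.flash_attn_diffkv import (
    FlashAttentionDiffKVBackend,
)

# block_size must be a multiple of 16, otherwise a ValueError is raised.
shape = FlashAttentionDiffKVBackend.get_kv_cache_shape(
    num_blocks=1024,
    block_size=16,
    num_kv_heads=8,
    head_size=192,  # key head size; head_size_v defaults to 128
)
assert shape == (1024, 16, 8, 192 + 128)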

get_kv_cache_stride_order staticmethod

get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flash_attn_diffkv.py
@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    # `stride_order` indicates the permutation that gets
    # us from `get_kv_cache_shape` to the actual memory layout we want.
    cache_layout = get_kv_cache_layout()
    if cache_layout == "NHD" and include_num_layers_dimension:
        # (num_blocks, num_layers, block_size,
        # num_kv_heads, head_size + head_size_v)
        return (1, 0, 2, 3, 4)
    elif cache_layout == "NHD":
        stride_order = (0, 1, 2, 3)
    elif cache_layout == "HND" and include_num_layers_dimension:
        # (num_blocks, num_kv_heads, num_layers,
        # block_size, head_size + head_size_v)
        return (1, 3, 0, 2, 4)
    elif cache_layout == "HND":
        stride_order = (0, 2, 1, 3)
    else:
        raise ValueError(f"Unknown cache layout format {cache_layout}.")
    return stride_order
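
The stride order is a permutation mapping the logical shape from get_kv_cache_shape to the physical memory layout. A small sketch (plain Python, illustrative shape) of how the HND permutation (0, 2, 1, 3) rearranges the dimensions:

# Logical shape from get_kv_cache_shape (illustrative values):
# (num_blocks, block_size, num_kv_heads, head_size + head_size_v)
logical_shape = (1024, 16, 8, 320)

# HND layout without the num_layers dimension (see the source above).
stride_order = (0, 2, 1, 3)

# Permuting the logical shape gives the physical memory layout:
# (num_blocks, num_kv_heads, block_size, head_size + head_size_v)
physical_shape = tuple(logical_shape[i] for i in stride_order)
assert physical_shape == (1024, 8, 16, 320)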

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/flash_attn_diffkv.py
@staticmethod
def get_name() -> str:
    return "FLASH_ATTN_DIFFKV"

set_head_size_v classmethod

set_head_size_v(head_size_v: int) -> None
Source code in vllm/v1/attention/backends/flash_attn_diffkv.py
@classmethod
def set_head_size_v(cls, head_size_v: int) -> None:
    cls.head_size_v = head_size_v
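
set_head_size_v updates a class-level attribute, so later calls to get_kv_cache_shape see the new value. A hedged sketch, again assuming the module path shown on this page:

from vllm.v1.attention.backends.flash_attn_diffkv import (
    FlashAttentionDiffKVBackend,
)

# Override the class-level default (128) before the cache shape is queried.
FlashAttentionDiffKVBackend.set_head_size_v(96)

shape = FlashAttentionDiffKVBackend.get_kv_cache_shape(
    num_blocks=256, block_size=16, num_kv_heads=4, head_size=160
)
assert shape == (256, 16, 4, 160 + 96)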

FlashAttentionDiffKVImpl

Bases: FlashAttentionImpl

Source code in vllm/v1/attention/backends/flash_attn_diffkv.py
class FlashAttentionDiffKVImpl(FlashAttentionImpl):
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size_v]
            kv_cache: shape =
                [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size_v]
        NOTE: For FP8 quantization, flash-attn expects the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        # For decoder and cross-attention, use KV cache as before
        # Different head_size for K and V
        key_cache = kv_cache[..., : self.head_size]
        value_cache = kv_cache[..., self.head_size :]

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.

            # kv_cache update for different head_size K and V
            triton_reshape_and_cache_flash_diffkv(
                key,
                value,
                kv_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype
            )
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            if self.dcp_world_size > 1:
                self._forward_with_dcp(
                    query[:num_actual_tokens],
                    key[:num_actual_tokens],
                    value[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                )
                return output
            else:
                sliding_window_size = (
                    list(self.sliding_window)
                    if self.sliding_window is not None
                    else None
                )
                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=sliding_window_size,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
                    s_aux=self.sinks,
                )
                return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            max_num_splits=attn_metadata.max_num_splits,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        return output
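
For orientation, a hedged sketch (plain PyTorch, illustrative sizes) of the tensor shape contract described in the forward() docstring; it only allocates buffers and does not call the backend:

import torch

# Illustrative sizes only; the shape contract follows the forward() docstring.
num_tokens, num_heads, num_kv_heads = 32, 16, 4
head_size, head_size_v = 192, 128
num_blocks, block_size = 64, 16

query = torch.empty(num_tokens, num_heads, head_size)
key = torch.empty(num_tokens, num_kv_heads, head_size)
value = torch.empty(num_tokens, num_kv_heads, head_size_v)  # narrower V heads
kv_cache = torch.empty(
    num_blocks, block_size, num_kv_heads, head_size + head_size_v
)
# forward() requires a preallocated output buffer and returns it;
# its last dimension uses head_size_v, not head_size.
output = torch.empty(num_tokens, num_heads * head_size_v)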

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlashAttentionMetadata,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor

Forward pass with FlashAttention.

Parameters:

    query (Tensor, required): shape = [num_tokens, num_heads, head_size]
    key (Tensor, required): shape = [num_tokens, num_kv_heads, head_size]
    value (Tensor, required): shape = [num_tokens, num_kv_heads, head_size_v]
    kv_cache (Tensor, required): shape = [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
    attn_metadata (FlashAttentionMetadata, required): Metadata for attention.

Returns:

    shape = [num_tokens, num_heads * head_size_v]

NOTE: For FP8 quantization, flash-attn expects the size of {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values.
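
The NOTE above refers to broadcasting a per-layer scale up to the (num_sequences, num_kv_heads) shape that flash-attn expects. A small sketch (plain PyTorch, illustrative sizes) of why .expand() avoids a copy:

import torch

num_sequences, num_kv_heads = 8, 4
descale_shape = (num_sequences, num_kv_heads)

# A per-tensor scale such as layer._q_scale is typically a single-element
# tensor; expand() creates a broadcast view with stride 0 instead of
# materializing num_sequences * num_kv_heads copies.
q_scale = torch.tensor([[0.5]])
q_descale = q_scale.expand(descale_shape)
assert q_descale.shape == descale_shape
assert q_descale.stride() == (0, 0)                  # no data duplication
assert q_descale.data_ptr() == q_scale.data_ptr()    # same underlying storage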

Source code in vllm/v1/attention/backends/flash_attn_diffkv.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlashAttentionMetadata,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward pass with FlashAttention.

    Args:
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size_v]
        kv_cache: shape =
            [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size_v]
    NOTE: For FP8 quantization, flash-attn expects the size of
          {q,k,v}_descale to be (num_sequences, num_kv_heads).
          We use torch's .expand() to avoid duplicating values.
    """
    assert output is not None, "Output tensor must be provided."
    assert self.vllm_flash_attn_version is not None, (
        "FlashAttention version not detected."
    )

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlashAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run.
        return output.fill_(0)

    attn_type = self.attn_type

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
    # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
    # in this method. For example, `view` and `slice` (or `[:n]`) operations
    # are surprisingly slow even in the case they do not invoke any GPU ops.
    # Minimize the PyTorch ops in this method as much as possible.
    # Whenever making a change in this method, please benchmark the
    # performance to make sure it does not introduce any overhead.

    num_actual_tokens = attn_metadata.num_actual_tokens

    # Handle encoder attention differently - no KV cache needed
    if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
        # For encoder attention,
        # we use direct Q, K, V tensors without caching
        return self._forward_encoder_attention(
            query[:num_actual_tokens],
            key[:num_actual_tokens],
            value[:num_actual_tokens],
            output[:num_actual_tokens],
            attn_metadata,
            layer,
        )

    # For decoder and cross-attention, use KV cache as before
    # Different head_size for K and V
    key_cache = kv_cache[..., : self.head_size]
    value_cache = kv_cache[..., self.head_size :]

    # key and value may be None in the case of cross attention. They are
    # calculated once based on the output from the encoder and then cached
    # in KV cache.
    if (
        self.kv_sharing_target_layer_name is None
        and key is not None
        and value is not None
    ):
        # Reshape the input keys and values and store them in the cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens]
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.

        # kv_cache update for different head_size K and V
        triton_reshape_and_cache_flash_diffkv(
            key,
            value,
            kv_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    if self.kv_cache_dtype.startswith("fp8"):
        # queries are quantized in the attention layer
        dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
            self.kv_cache_dtype
        )
        key_cache = key_cache.view(dtype)
        value_cache = value_cache.view(dtype)

    if not attn_metadata.use_cascade:
        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table
        scheduler_metadata = attn_metadata.scheduler_metadata

        descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

        if self.dcp_world_size > 1:
            self._forward_with_dcp(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                key_cache,
                value_cache,
                output[:num_actual_tokens],
                attn_metadata,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
            return output
        else:
            sliding_window_size = (
                list(self.sliding_window)
                if self.sliding_window is not None
                else None
            )
            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,
                window_size=sliding_window_size,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                scheduler_metadata=scheduler_metadata,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                num_splits=attn_metadata.max_num_splits,
                s_aux=self.sinks,
            )
            return output

    # Cascade attention (rare case).
    cascade_attention(
        output[:num_actual_tokens],
        query[:num_actual_tokens],
        key_cache,
        value_cache,
        cu_query_lens=attn_metadata.query_start_loc,
        max_query_len=attn_metadata.max_query_len,
        cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
        prefix_kv_lens=attn_metadata.prefix_kv_lens,
        suffix_kv_lens=attn_metadata.suffix_kv_lens,
        max_kv_len=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        alibi_slopes=self.alibi_slopes,
        sliding_window=self.sliding_window,
        logits_soft_cap=self.logits_soft_cap,
        block_table=attn_metadata.block_table,
        common_prefix_len=attn_metadata.common_prefix_len,
        max_num_splits=attn_metadata.max_num_splits,
        fa_version=self.vllm_flash_attn_version,
        prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
        suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
        q_descale=layer._q_scale,
        k_descale=layer._k_scale,
        v_descale=layer._v_scale,
        s_aux=self.sinks,
    )
    return output