vllm.model_executor.layers.attention.static_sink_attention

logger module-attribute

logger = init_logger(__name__)

StaticSinkAttention

Bases: Attention, CustomOp

Attention with static sink tokens: a fixed set of sink_len key/value entries (provided via update_sink_kv) is written once into dedicated KV-cache blocks and prepended to every request's block table, so all queries attend to the sink positions in addition to their own context.

Source code in vllm/model_executor/layers/attention/static_sink_attention.py
@CustomOp.register("static_sink_attention")
class StaticSinkAttention(Attention, CustomOp):
    """
    Attention with static sink tokens
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        sink_len: int,
        attn_backend: type[AttentionBackend] | None = None,
        cache_config: CacheConfig | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_backend is not None:
            underlying_attn_backend = attn_backend
        else:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
        attn_backend = create_static_sink_attention_backend(
            underlying_attn_backend,  # type: ignore[arg-type]
            sink_len=sink_len,
        )
        Attention.__init__(
            self=self,
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            **kwargs,
        )
        CustomOp.__init__(self)

        self.sink_len = sink_len
        self.block_size = block_size
        self.sink_populated = False
        self.sink_key = None
        self.sink_value = None

    def update_sink_kv(self, sink_key, sink_value) -> None:
        self.sink_key = sink_key
        self.sink_value = sink_value

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        assert self.sink_key is not None and self.sink_value is not None, (
            "sink_key and sink_value have not been prepared"
        )
        if not self.sink_populated:
            forward_context: ForwardContext = get_forward_context()
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

        return super().forward(query, key, value, output_shape)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        return self.forward_native(query, key, value, output_shape)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def populate_sink_kv(self, self_kv_cache):
        sink_kv_slot_mapping = torch.arange(
            self.block_size,
            self.sink_len + self.block_size,
            device=torch.cuda.current_device(),
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            self.sink_key,
            self.sink_value,
            self_kv_cache,
            sink_kv_slot_mapping,
            self.kv_cache_dtype,
            self._k_scale,
            self._v_scale,
        )
        # We only populate the sink_key and sink_value once
        self.sink_populated = True

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        return SinkFullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            sink_len=self.sink_len,
            dtype=self.kv_cache_torch_dtype,
        )
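
A minimal usage sketch, assuming a hypothetical decoder layer that owns learned sink key/value tensors (the SinkAttentionLayer wrapper, its parameter layout, and the scale choice are illustrative assumptions, not vLLM API): the layer constructs StaticSinkAttention like a regular Attention layer and must call update_sink_kv before the first forward so the sink KV can be written into the cache lazily.

import torch
import torch.nn as nn

from vllm.model_executor.layers.attention.static_sink_attention import (
    StaticSinkAttention,
)

class SinkAttentionLayer(nn.Module):
    def __init__(self, num_heads: int, head_size: int, sink_len: int, cache_config):
        super().__init__()
        self.attn = StaticSinkAttention(
            num_heads=num_heads,
            head_size=head_size,
            scale=head_size**-0.5,
            sink_len=sink_len,
            cache_config=cache_config,
        )
        # Learned sink tokens shared across requests; the layout
        # [sink_len, num_heads, head_size] is an assumption.
        self.sink_k = nn.Parameter(torch.zeros(sink_len, num_heads, head_size))
        self.sink_v = nn.Parameter(torch.zeros(sink_len, num_heads, head_size))

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Hand the current sink KV to the attention layer; it is written into
        # the KV cache once, on the first forward pass (see populate_sink_kv).
        self.attn.update_sink_kv(self.sink_k, self.sink_v)
        return self.attn(q, k, v)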

block_size instance-attribute

block_size = block_size

sink_key instance-attribute

sink_key = None

sink_len instance-attribute

sink_len = sink_len

sink_populated instance-attribute

sink_populated = False

sink_value instance-attribute

sink_value = None

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    sink_len: int,
    attn_backend: type[AttentionBackend] | None = None,
    cache_config: CacheConfig | None = None,
    **kwargs,
)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    sink_len: int,
    attn_backend: type[AttentionBackend] | None = None,
    cache_config: CacheConfig | None = None,
    **kwargs,
):
    dtype = torch.get_default_dtype()

    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
    else:
        kv_cache_dtype = "auto"
        block_size = 16

    if attn_backend is not None:
        underlying_attn_backend = attn_backend
    else:
        underlying_attn_backend = get_attn_backend(
            head_size, dtype, kv_cache_dtype, block_size
        )
    attn_backend = create_static_sink_attention_backend(
        underlying_attn_backend,  # type: ignore[arg-type]
        sink_len=sink_len,
    )
    Attention.__init__(
        self=self,
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        cache_config=cache_config,
        attn_backend=attn_backend,
        **kwargs,
    )
    CustomOp.__init__(self)

    self.sink_len = sink_len
    self.block_size = block_size
    self.sink_populated = False
    self.sink_key = None
    self.sink_value = None
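
The backend wiring in __init__ can be reproduced by hand with the module's own get_attn_backend and create_static_sink_attention_backend; a sketch of the equivalent two-step resolution (the head_size, block_size, and sink_len values are illustrative):

import torch

# Resolve the regular attention backend for this head size and cache layout,
# then wrap it so its metadata builder prepends the sink blocks.
underlying = get_attn_backend(128, torch.get_default_dtype(), "auto", 16)
sink_backend = create_static_sink_attention_backend(underlying, sink_len=32)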

forward

forward(*args, **kwargs)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward(self, *args, **kwargs):
    return self._forward_method(*args, **kwargs)

forward_cuda

forward_cuda(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward_cuda(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    return self.forward_native(query, key, value, output_shape)

forward_native

forward_native(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward_native(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    assert self.sink_key is not None and self.sink_value is not None, (
        "sink_key and sink_value have not been prepared"
    )
    if not self.sink_populated:
        forward_context: ForwardContext = get_forward_context()
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

    return super().forward(query, key, value, output_shape)

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # Block size may get updated after model loading, refresh it
    block_size = vllm_config.cache_config.block_size
    # Should not be called for enc-dec or encoder-only attention.
    assert self.attn_type == AttentionType.DECODER

    return SinkFullAttentionSpec(
        block_size=block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        head_size_v=self.head_size_v,
        sink_len=self.sink_len,
        dtype=self.kv_cache_torch_dtype,
    )

populate_sink_kv

populate_sink_kv(self_kv_cache)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def populate_sink_kv(self, self_kv_cache):
    sink_kv_slot_mapping = torch.arange(
        self.block_size,
        self.sink_len + self.block_size,
        device=torch.cuda.current_device(),
        dtype=torch.long,
    )
    triton_reshape_and_cache_flash_diffkv(
        self.sink_key,
        self.sink_value,
        self_kv_cache,
        sink_kv_slot_mapping,
        self.kv_cache_dtype,
        self._k_scale,
        self._v_scale,
    )
    # We only populate the sink_key and sink_value once
    self.sink_populated = True
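
The slot mapping skips the first physical block (block ID 0, which the builder below never references) and writes the sink tokens into the blocks immediately after it. A small numeric sketch, assuming block_size=16 and sink_len=32:

import torch

block_size, sink_len = 16, 32
slots = torch.arange(block_size, sink_len + block_size, dtype=torch.long)
# slots == [16, 17, ..., 47]: the 32 sink tokens fill physical blocks 1 and 2,
# matching the block IDs 1..num_sink_blocks that the builder hard-codes into
# the block table (see create_static_sink_attention_backend).
assert slots[0] // block_size == 1 and slots[-1] // block_size == 2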

update_sink_kv

update_sink_kv(sink_key, sink_value) -> None
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def update_sink_kv(self, sink_key, sink_value) -> None:
    self.sink_key = sink_key
    self.sink_value = sink_value

create_static_sink_attention_backend cached

create_static_sink_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
    sink_len: int = 0,
) -> type[AttentionBackend]
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
@functools.lru_cache
def create_static_sink_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
    sink_len: int = 0,
) -> type[AttentionBackend]:
    prefix = "StaticSink_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class StaticSinkAttentionBuilder(underlying_builder):  # type: ignore
        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            model_config = vllm_config.model_config
            scheduler_config = vllm_config.scheduler_config
            self.sink_len = sink_len
            self.block_size = vllm_config.cache_config.block_size
            self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
            self.max_num_blocks = cdiv(
                model_config.max_model_len, vllm_config.cache_config.block_size
            )
            self.block_table_with_sink = torch.zeros(
                (
                    scheduler_config.max_num_seqs,
                    self.max_num_blocks + self.num_sink_blocks,
                ),
                device=device,
                dtype=torch.int32,
            )
            self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange(
                1,
                self.num_sink_blocks + 1,
                device=device,
                dtype=torch.int32,
            )

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            common_attn_metadata.seq_lens[:] = (
                common_attn_metadata.seq_lens + self.sink_len
            )
            common_attn_metadata.seq_lens[
                common_attn_metadata.seq_lens == self.sink_len
            ] = 0
            common_attn_metadata.max_seq_len = (
                common_attn_metadata.max_seq_len + self.sink_len
            )
            max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size)
            num_reqs = common_attn_metadata.num_reqs
            self.block_table_with_sink[
                :num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks
            ] = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
            common_attn_metadata.block_table_tensor = self.block_table_with_sink[
                :num_reqs
            ]

            return super().build(common_prefix_len, common_attn_metadata, fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=StaticSinkAttentionBuilder,
    )

    return attn_backend
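
A worked example of what the builder does per request, assuming block_size=16 and sink_len=32 (so num_sink_blocks=2): a request with 40 cached tokens and physical blocks [7, 9, 11] is presented to the underlying backend as if it had 72 tokens whose first two blocks are the fixed sink blocks.

import torch

block_size, sink_len = 16, 32
num_sink_blocks = sink_len // block_size              # 2

seq_len = 40                                          # real tokens of one request
blocks = torch.tensor([7, 9, 11], dtype=torch.int32)  # its real block table row

adjusted_seq_len = seq_len + sink_len                 # 72, as seen by the backend
block_table_row = torch.cat(
    [torch.arange(1, num_sink_blocks + 1, dtype=torch.int32), blocks]
)
# block_table_row == [1, 2, 7, 9, 11]. Requests whose original seq_len was 0
# would end up equal to sink_len and are reset back to 0 in build(), so padded
# or empty slots never attend to the sink tokens.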

maybe_populate_sink

maybe_populate_sink(
    self_kv_cache: Tensor, layer_name: str
) -> None
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def maybe_populate_sink(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    if self.sink_populated or self_kv_cache.numel() == 0:
        return
    self.populate_sink_kv(self_kv_cache)

maybe_populate_sink_fake

maybe_populate_sink_fake(
    self_kv_cache: Tensor, layer_name: str
) -> None
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def maybe_populate_sink_fake(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    return
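
forward_native invokes this pair through torch.ops.vllm.maybe_populate_sink, i.e. it is registered as a custom op whose fake (meta) variant lets torch.compile trace the call without touching the KV cache. The exact registration helper vLLM uses is not shown here; a minimal sketch of the same idea with plain torch.library (the "mylib" namespace is illustrative):

import torch

@torch.library.custom_op("mylib::maybe_populate_sink", mutates_args=("self_kv_cache",))
def _maybe_populate_sink(self_kv_cache: torch.Tensor, layer_name: str) -> None:
    # Real implementation: look up the layer via the forward context and write
    # its sink KV into the cache on the first call (as maybe_populate_sink does).
    ...

@_maybe_populate_sink.register_fake
def _(self_kv_cache: torch.Tensor, layer_name: str) -> None:
    # Meta implementation: a no-op with no side effects, used during tracing.
    return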