Skip to content

vllm.model_executor.layers.pooler.seqwise.methods

SequencePoolingMethodOutput module-attribute

# Return type shared by the sequence pooling methods' ``forward``:
# either a single tensor or a list of tensors.
SequencePoolingMethodOutput: TypeAlias = Tensor | list[Tensor]

CLSPool

Bases: SequencePoolingMethod

Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
class CLSPool(SequencePoolingMethod):
    """Pooling method that gathers rows of the hidden states at the
    pooling cursor's ``first_token_indices_gpu``."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Select the rows of ``hidden_states`` addressed by the cursor.

        Args:
            hidden_states: Tensor indexed along its first dimension.
            pooling_metadata: Supplies the pooling cursor through
                ``get_pooling_cursor()``.

        Returns:
            ``hidden_states`` indexed by ``first_token_indices_gpu``.

        Raises:
            AssertionError: If the cursor reports a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()

        unsupported_msg = "partial prefill not supported with CLS pooling"
        assert not cursor.is_partial_prefill(), unsupported_msg

        first_rows = cursor.first_token_indices_gpu
        return hidden_states[first_rows]

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolingMethodOutput
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
    """Select the rows of ``hidden_states`` addressed by the cursor.

    Args:
        hidden_states: Tensor indexed along its first dimension.
        pooling_metadata: Supplies the pooling cursor through
            ``get_pooling_cursor()``.

    Returns:
        ``hidden_states`` indexed by ``first_token_indices_gpu``.

    Raises:
        AssertionError: If the cursor reports a partial prefill.
    """
    cursor = pooling_metadata.get_pooling_cursor()

    unsupported_msg = "partial prefill not supported with CLS pooling"
    assert not cursor.is_partial_prefill(), unsupported_msg

    first_rows = cursor.first_token_indices_gpu
    return hidden_states[first_rows]

LastPool

Bases: SequencePoolingMethod

Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
class LastPool(SequencePoolingMethod):
    """Pooling method that gathers rows of the hidden states at the
    pooling cursor's ``last_token_indices_gpu``."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Select the rows of ``hidden_states`` addressed by the cursor.

        Args:
            hidden_states: Tensor indexed along its first dimension.
            pooling_metadata: Supplies the pooling cursor through
                ``get_pooling_cursor()``.

        Returns:
            ``hidden_states`` indexed by ``last_token_indices_gpu``.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        last_rows = cursor.last_token_indices_gpu
        return hidden_states[last_rows]

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolingMethodOutput
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
    """Select the rows of ``hidden_states`` addressed by the cursor.

    Args:
        hidden_states: Tensor indexed along its first dimension.
        pooling_metadata: Supplies the pooling cursor through
            ``get_pooling_cursor()``.

    Returns:
        ``hidden_states`` indexed by ``last_token_indices_gpu``.
    """
    cursor = pooling_metadata.get_pooling_cursor()
    last_rows = cursor.last_token_indices_gpu
    return hidden_states[last_rows]

MeanPool

Bases: SequencePoolingMethod

Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
class MeanPool(SequencePoolingMethod):
    """Pooling method that averages hidden-state rows between the cursor's
    first and last indices using a float32 cumulative sum."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Compute a span sum per index pair and divide by the prompt lengths.

        For every ``(start, end)`` pair taken from the cursor, the value
        ``cumsum[end] - cumsum[start] + hidden_states[start]`` is the sum of
        rows ``start..end`` inclusive (up to float rounding). That sum is
        divided by the matching entry of ``prompt_lens_cpu``.
        NOTE(review): presumably each pair delimits one prompt whose length
        is the matching ``prompt_lens_cpu`` entry; confirm against the cursor.

        Args:
            hidden_states: Tensor accumulated and indexed along dim 0.
            pooling_metadata: Supplies the pooling cursor through
                ``get_pooling_cursor()``.

        Returns:
            The span sums divided by ``prompt_lens`` broadcast over dim 1.

        Raises:
            AssertionError: If the cursor reports a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()

        unsupported_msg = "partial prefill not supported with MEAN pooling"
        assert not cursor.is_partial_prefill(), unsupported_msg

        # Move the lengths to the same device as the hidden states.
        target_device = hidden_states.device
        lengths = cursor.prompt_lens_cpu.to(target_device, non_blocking=True)

        # Accumulate in float32: a lower-precision running sum would lose
        # precision significantly.
        running_total = hidden_states.cumsum(dim=0, dtype=torch.float32)

        first_idx = cursor.first_token_indices_gpu
        last_idx = cursor.last_token_indices_gpu

        # The prefix-sum difference drops row ``first_idx``; add it back.
        span_sum = (
            running_total[last_idx]
            - running_total[first_idx]
            + hidden_states[first_idx]
        )
        return span_sum / lengths.unsqueeze(1)

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolingMethodOutput
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
    """Compute a span sum per index pair and divide by the prompt lengths.

    For every ``(start, end)`` pair taken from the cursor, the value
    ``cumsum[end] - cumsum[start] + hidden_states[start]`` is the sum of
    rows ``start..end`` inclusive (up to float rounding). That sum is
    divided by the matching entry of ``prompt_lens_cpu``.
    NOTE(review): presumably each pair delimits one prompt whose length
    is the matching ``prompt_lens_cpu`` entry; confirm against the cursor.

    Args:
        hidden_states: Tensor accumulated and indexed along dim 0.
        pooling_metadata: Supplies the pooling cursor through
            ``get_pooling_cursor()``.

    Returns:
        The span sums divided by ``prompt_lens`` broadcast over dim 1.

    Raises:
        AssertionError: If the cursor reports a partial prefill.
    """
    cursor = pooling_metadata.get_pooling_cursor()

    unsupported_msg = "partial prefill not supported with MEAN pooling"
    assert not cursor.is_partial_prefill(), unsupported_msg

    # Move the lengths to the same device as the hidden states.
    target_device = hidden_states.device
    lengths = cursor.prompt_lens_cpu.to(target_device, non_blocking=True)

    # Accumulate in float32: a lower-precision running sum would lose
    # precision significantly.
    running_total = hidden_states.cumsum(dim=0, dtype=torch.float32)

    first_idx = cursor.first_token_indices_gpu
    last_idx = cursor.last_token_indices_gpu

    # The prefix-sum difference drops row ``first_idx``; add it back.
    span_sum = (
        running_total[last_idx]
        - running_total[first_idx]
        + hidden_states[first_idx]
    )
    return span_sum / lengths.unsqueeze(1)

SequencePoolingMethod

Bases: Module, ABC

Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
class SequencePoolingMethod(nn.Module, ABC):
    """Abstract ``nn.Module`` base for sequence pooling methods.

    Subclasses must implement ``forward``; the other two methods provide
    defaults.
    """

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of pooling task names this method supports."""
        supported = {"token_embed", "token_classify", "embed", "classify", "score"}
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return a default-constructed ``PoolingParamsUpdate``.

        The ``task`` argument is not consulted by this base implementation.
        """
        default_update = PoolingParamsUpdate()
        return default_update

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

forward abstractmethod

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolingMethodOutput
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
@abstractmethod
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
    """Pool ``hidden_states`` according to ``pooling_metadata``.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

get_pooling_updates

get_pooling_updates(
    task: PoolingTask,
) -> PoolingParamsUpdate
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    """Return a default-constructed ``PoolingParamsUpdate``.

    The ``task`` argument is not consulted by this implementation.
    """
    default_update = PoolingParamsUpdate()
    return default_update

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Return the set of pooling task names this method supports."""
    supported = {"token_embed", "token_classify", "embed", "classify", "score"}
    return supported

get_seq_pooling_method

get_seq_pooling_method(
    pooling_type: SequencePoolingType | str,
)
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
    """Build the sequence pooling method named by ``pooling_type``.

    Args:
        pooling_type: ``"CLS"``, ``"LAST"`` or ``"MEAN"`` (compared with
            ``==``, so the match is case-sensitive for plain strings).

    Returns:
        A new ``CLSPool``, ``LastPool`` or ``MeanPool`` instance.

    Raises:
        NotImplementedError: If ``pooling_type`` equals none of the names.
    """
    # Checked in order; the first equal name wins.
    known_methods = (
        ("CLS", CLSPool),
        ("LAST", LastPool),
        ("MEAN", MeanPool),
    )
    for type_name, method_cls in known_methods:
        if pooling_type == type_name:
            return method_cls()

    raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")