Skip to content

vllm.model_executor.layers.pooler.seqwise

Poolers that produce an output aggregating all tokens in the sequence.

Modules:

Name Description
heads
methods
poolers

SequencePoolerHeadOutput module-attribute

# Return type of a sequence pooler head's ``forward``: either one tensor or a
# list of tensors.
SequencePoolerHeadOutput: TypeAlias = Tensor | list[Tensor]

SequencePoolerOutput module-attribute

# Return type of ``SequencePooler.forward``: either one tensor or a list of
# tensors.
SequencePoolerOutput: TypeAlias = Tensor | list[Tensor]

SequencePoolingFn module-attribute

SequencePoolingHeadFn module-attribute

SequencePoolingMethodOutput module-attribute

# Return type of a sequence pooling method's ``forward``: either one tensor or
# a list of tensors.
SequencePoolingMethodOutput: TypeAlias = (
    Tensor | list[Tensor]
)

__all__ module-attribute

# Public names exported by this package (what ``from ... import *`` exposes).
__all__ = [
    # Pooler heads and their output alias
    "SequencePoolerHead",
    "SequencePoolerHeadOutput",
    "ClassifierPoolerHead",
    "EmbeddingPoolerHead",
    # Pooling methods, their output alias and the factory
    "SequencePoolingMethod",
    "SequencePoolingMethodOutput",
    "CLSPool",
    "LastPool",
    "MeanPool",
    "get_seq_pooling_method",
    # The pooler layer, its callable/output aliases and the builders
    "SequencePooler",
    "SequencePoolingFn",
    "SequencePoolingHeadFn",
    "SequencePoolerOutput",
    "pooler_for_classify",
    "pooler_for_embed",
]

CLSPool

Bases: SequencePoolingMethod

Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
class CLSPool(SequencePoolingMethod):
    """Pooling method that picks the hidden state at each first-token index."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Select the rows of ``hidden_states`` at the cursor's first-token indices.

        Partial prefill is rejected by an assertion before indexing.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_token_rows = cursor.first_token_indices_gpu
        return hidden_states[first_token_rows]

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolingMethodOutput
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
    """Select the rows of ``hidden_states`` at the cursor's first-token indices.

    Partial prefill is rejected by an assertion before indexing.
    """
    cursor = pooling_metadata.get_pooling_cursor()
    assert not cursor.is_partial_prefill(), (
        "partial prefill not supported with CLS pooling"
    )

    first_token_rows = cursor.first_token_indices_gpu
    return hidden_states[first_token_rows]

ClassifierPoolerHead

Bases: SequencePoolerHead

Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
class ClassifierPoolerHead(SequencePoolerHead):
    """Pooler head that turns pooled hidden states into classification outputs.

    Every stage is optional and skipped when its attribute is ``None``; the
    stages run in the order: dtype cast, classifier, logit bias, activation.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        logit_bias: float | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        """Store the optional processing stages.

        Args:
            classifier: Callable applied to the pooled data.
            logit_bias: Value subtracted from the classifier output.
            head_dtype: Argument forwarded to ``Tensor.to`` before the
                classifier runs.
            activation: Callable applied to the requests whose pooling params
                set ``use_activation``.
        """
        super().__init__()

        self.classifier = classifier
        self.logit_bias = logit_bias
        # NOTE(review): the annotation allows ``str``, but ``Tensor.to`` parses
        # a string argument as a device, not a dtype - confirm callers always
        # pass a ``torch.dtype``.
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks served by this head."""
        return {"classify", "score"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Turn pooled hidden states into per-request classification outputs.

        Args:
            pooled_data: One tensor or a list of tensors (stacked here), with
                one entry per pooling param.
            pooling_metadata: Supplies ``pooling_params``.

        Returns:
            A single tensor, or a list of per-request tensors when the
            ``use_activation`` flags differ across requests.
        """
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose: without a classifier (and with a no-op
            # dtype cast) ``pooled_data`` is still the caller's tensor, and an
            # in-place ``-=`` would silently modify it.
            pooled_data = pooled_data - self.logit_bias

        if self.activation is not None:
            flags = [p.use_activation for p in pooling_params]
            if len(set(flags)) == 1:
                # Uniform flag: process the whole batch at once.
                if flags[0]:
                    pooled_data = self.activation(pooled_data)
            else:
                # Mixed flags: activate only the requests that asked for it.
                pooled_data = [
                    self.activation(vecs) if f else vecs
                    for vecs, f in zip(pooled_data, flags)
                ]

        # pooled_data shape: [batchsize, num_labels]
        return pooled_data

activation instance-attribute

activation = activation

classifier instance-attribute

classifier = classifier

head_dtype instance-attribute

head_dtype = head_dtype

logit_bias instance-attribute

logit_bias = logit_bias

__init__

__init__(
    classifier: ClassifierFn | None = None,
    logit_bias: float | None = None,
    head_dtype: dtype | str | None = None,
    activation: ActivationFn | None = None,
) -> None
Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
def __init__(
    self,
    classifier: ClassifierFn | None = None,
    logit_bias: float | None = None,
    head_dtype: torch.dtype | str | None = None,
    activation: ActivationFn | None = None,
) -> None:
    """Store the optional processing stages of the classifier head.

    Args:
        classifier: Callable applied to the pooled data; skipped when ``None``.
        logit_bias: Value subtracted from the classifier output; skipped when
            ``None``.
        head_dtype: Argument forwarded to ``Tensor.to`` before the classifier
            runs; skipped when ``None``.
        activation: Callable applied to the requests whose pooling params set
            ``use_activation``; skipped when ``None``.
    """
    super().__init__()

    self.classifier = classifier
    self.logit_bias = logit_bias
    # NOTE(review): the annotation allows ``str``, but ``Tensor.to`` parses a
    # string argument as a device, not a dtype - confirm callers always pass a
    # ``torch.dtype``.
    self.head_dtype = head_dtype
    self.activation = activation

forward

forward(
    pooled_data: SequencePoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolerHeadOutput
Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
def forward(
    self,
    pooled_data: SequencePoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolerHeadOutput:
    """Turn pooled hidden states into per-request classification outputs.

    Stages, each skipped when its attribute is ``None``: cast via
    ``head_dtype``, apply ``classifier``, subtract ``logit_bias``, apply
    ``activation`` to the requests whose pooling params set
    ``use_activation``.

    Args:
        pooled_data: One tensor or a list of tensors (stacked here), with one
            entry per pooling param.
        pooling_metadata: Supplies ``pooling_params``.

    Returns:
        A single tensor, or a list of per-request tensors when the
        ``use_activation`` flags differ across requests.
    """
    pooling_params = pooling_metadata.pooling_params
    assert len(pooled_data) == len(pooling_params)

    if isinstance(pooled_data, list):
        pooled_data = torch.stack(pooled_data)
    # pooled_data shape: [batchsize, hidden_size]

    if self.head_dtype is not None:
        pooled_data = pooled_data.to(self.head_dtype)

    if self.classifier is not None:
        pooled_data = self.classifier(pooled_data)
    # pooled_data shape: [batchsize, num_labels]

    if self.logit_bias is not None:
        # Out-of-place on purpose: without a classifier (and with a no-op
        # dtype cast) ``pooled_data`` is still the caller's tensor, and an
        # in-place ``-=`` would silently modify it.
        pooled_data = pooled_data - self.logit_bias

    if self.activation is not None:
        flags = [p.use_activation for p in pooling_params]
        if len(set(flags)) == 1:
            # Uniform flag: process the whole batch at once.
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            # Mixed flags: activate only the requests that asked for it.
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

    # pooled_data shape: [batchsize, num_labels]
    return pooled_data

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Return the pooling tasks served by the classifier head."""
    tasks: Set[PoolingTask] = {"classify", "score"}
    return tasks

EmbeddingPoolerHead

Bases: SequencePoolerHead

Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
class EmbeddingPoolerHead(SequencePoolerHead):
    """Pooler head that turns pooled hidden states into embeddings.

    Stages, in order: optional dtype cast, optional projector, optional
    per-request truncation to ``dimensions``, optional activation.
    """

    def __init__(
        self,
        projector: ProjectorFn | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        """Store the optional projector, cast target and activation."""
        super().__init__()

        self.projector = projector
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks served by this head."""
        tasks: Set[PoolingTask] = {"embed"}
        return tasks

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Post-process pooled hidden states into per-request embeddings.

        Returns a single tensor when every request can be handled as one
        batch, otherwise a list with one tensor per request.
        """
        params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(params)

        # Shape after stacking: [batchsize, hidden_dimension]
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)

        target_dtype = self.head_dtype
        if target_dtype is not None:
            pooled_data = pooled_data.to(target_dtype)

        # ST projector; shape afterwards: [batchsize, embedding_dimension]
        project = self.projector
        if project is not None:
            pooled_data = project(pooled_data)

        # Matryoshka representation: truncate to the requested dimensions.
        requested_dims = [param.dimensions for param in params]
        if not all(dim is None for dim in requested_dims):
            assert len(pooled_data) == len(requested_dims)
            same_for_all = len(set(requested_dims)) == 1
            if same_for_all and not isinstance(pooled_data, list):
                # One shared size: slice the whole batch in one go.
                pooled_data = pooled_data[..., : requested_dims[0]]
            else:
                pooled_data = [
                    row[..., :dim] if dim is not None else row
                    for row, dim in zip(pooled_data, requested_dims)
                ]

        # Normalization step.
        normalize = self.activation
        if normalize is None:
            return pooled_data

        wanted = [param.use_activation for param in params]
        if len(set(wanted)) != 1:
            # Mixed flags: activate only the requests that asked for it.
            return [
                normalize(row) if flag else row
                for row, flag in zip(pooled_data, wanted)
            ]
        if wanted[0]:
            return normalize(pooled_data)
        return pooled_data

activation instance-attribute

activation = activation

head_dtype instance-attribute

head_dtype = head_dtype

projector instance-attribute

projector = projector

__init__

__init__(
    projector: ProjectorFn | None = None,
    head_dtype: dtype | str | None = None,
    activation: ActivationFn | None = None,
) -> None
Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
def __init__(
    self,
    projector: ProjectorFn | None = None,
    head_dtype: torch.dtype | str | None = None,
    activation: ActivationFn | None = None,
) -> None:
    """Store the optional processing stages of the embedding head.

    Args:
        projector: Callable applied to the pooled data; skipped when ``None``.
        head_dtype: Argument forwarded to ``Tensor.to`` before the projector
            runs; skipped when ``None``.
        activation: Callable applied to the requests whose pooling params set
            ``use_activation``; skipped when ``None``.
    """
    super().__init__()

    self.projector = projector
    # NOTE(review): the annotation allows ``str``, but ``Tensor.to`` parses a
    # string argument as a device, not a dtype - confirm callers always pass a
    # ``torch.dtype``.
    self.head_dtype = head_dtype
    self.activation = activation

forward

forward(
    pooled_data: SequencePoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolerHeadOutput
Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
def forward(
    self,
    pooled_data: SequencePoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolerHeadOutput:
    """Post-process pooled hidden states into per-request embeddings.

    Stages, in order: optional dtype cast, optional projector, optional
    per-request truncation to ``dimensions``, optional activation. Returns a
    single tensor when every request can be handled as one batch, otherwise
    a list with one tensor per request.
    """
    params = pooling_metadata.pooling_params
    assert len(pooled_data) == len(params)

    # Shape after stacking: [batchsize, hidden_dimension]
    if isinstance(pooled_data, list):
        pooled_data = torch.stack(pooled_data)

    target_dtype = self.head_dtype
    if target_dtype is not None:
        pooled_data = pooled_data.to(target_dtype)

    # ST projector; shape afterwards: [batchsize, embedding_dimension]
    project = self.projector
    if project is not None:
        pooled_data = project(pooled_data)

    # Matryoshka representation: truncate to the requested dimensions.
    requested_dims = [param.dimensions for param in params]
    if not all(dim is None for dim in requested_dims):
        assert len(pooled_data) == len(requested_dims)
        same_for_all = len(set(requested_dims)) == 1
        if same_for_all and not isinstance(pooled_data, list):
            # One shared size: slice the whole batch in one go.
            pooled_data = pooled_data[..., : requested_dims[0]]
        else:
            pooled_data = [
                row[..., :dim] if dim is not None else row
                for row, dim in zip(pooled_data, requested_dims)
            ]

    # Normalization step.
    normalize = self.activation
    if normalize is None:
        return pooled_data

    wanted = [param.use_activation for param in params]
    if len(set(wanted)) != 1:
        # Mixed flags: activate only the requests that asked for it.
        return [
            normalize(row) if flag else row
            for row, flag in zip(pooled_data, wanted)
        ]
    if wanted[0]:
        return normalize(pooled_data)
    return pooled_data

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Return the pooling tasks served by the embedding head."""
    tasks: Set[PoolingTask] = {"embed"}
    return tasks

LastPool

Bases: SequencePoolingMethod

Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
class LastPool(SequencePoolingMethod):
    """Pooling method that picks the hidden state at each last-token index."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Select the rows of ``hidden_states`` at the cursor's last-token indices."""
        last_token_rows = pooling_metadata.get_pooling_cursor().last_token_indices_gpu
        return hidden_states[last_token_rows]

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolingMethodOutput
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
    """Select the rows of ``hidden_states`` at the cursor's last-token indices."""
    last_token_rows = pooling_metadata.get_pooling_cursor().last_token_indices_gpu
    return hidden_states[last_token_rows]

MeanPool

Bases: SequencePoolingMethod

Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
class MeanPool(SequencePoolingMethod):
    """Pooling method that averages hidden states between first and last token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Average ``hidden_states`` over each first..last token span.

        The span sum is derived from a running sum over dim 0 and divided by
        the cursor's prompt lengths. Partial prefill is rejected by an
        assertion.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        lengths = cursor.prompt_lens_cpu.to(hidden_states.device, non_blocking=True)

        # The running sum is accumulated in float32; a lower-precision
        # accumulator would lose precision significantly.
        running_total = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = cursor.first_token_indices_gpu
        last = cursor.last_token_indices_gpu

        # total[last] - total[first] drops the first token of the span, so it
        # is added back explicitly.
        span_sum = running_total[last] - running_total[first] + hidden_states[first]
        return span_sum / lengths.unsqueeze(1)

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolingMethodOutput
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
    """Average ``hidden_states`` over each first..last token span.

    The span sum is derived from a running sum over dim 0 and divided by the
    cursor's prompt lengths. Partial prefill is rejected by an assertion.
    """
    cursor = pooling_metadata.get_pooling_cursor()
    assert not cursor.is_partial_prefill(), (
        "partial prefill not supported with MEAN pooling"
    )

    lengths = cursor.prompt_lens_cpu.to(hidden_states.device, non_blocking=True)

    # The running sum is accumulated in float32; a lower-precision
    # accumulator would lose precision significantly.
    running_total = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

    first = cursor.first_token_indices_gpu
    last = cursor.last_token_indices_gpu

    # total[last] - total[first] drops the first token of the span, so it is
    # added back explicitly.
    span_sum = running_total[last] - running_total[first] + hidden_states[first]
    return span_sum / lengths.unsqueeze(1)

SequencePooler

Bases: Pooler

A layer that pools specific information from hidden states.

This layer does the following:

1. Extracts specific tokens or aggregates data based on the pooling method.
2. Postprocesses the output based on the pooling head.
3. Returns structured results as `PoolerOutput`.

Source code in vllm/model_executor/layers/pooler/seqwise/poolers.py
class SequencePooler(Pooler):
    """
    A layer that pools specific information from hidden states.

    Processing runs in two stages:
    1. The pooling stage extracts or aggregates data from the hidden states.
    2. The head stage postprocesses the pooled data, and its result is returned.
    """

    def __init__(
        self,
        pooling: SequencePoolingMethod | SequencePoolingFn,
        head: SequencePoolerHead | SequencePoolingHeadFn,
    ) -> None:
        """Store the pooling stage and the head stage."""
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by both stages.

        Starts from every task in ``POOLING_TASKS``; only stages that are
        instances of the pooling-method / head base classes narrow the set.
        """
        tasks = set(POOLING_TASKS)

        stages = (
            (self.pooling, SequencePoolingMethod),
            (self.head, SequencePoolerHead),
        )
        for stage, base_cls in stages:
            if isinstance(stage, base_cls):
                tasks = tasks & stage.get_supported_tasks()

        return tasks

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-param updates contributed by the pooling stage."""
        updates = PoolingParamsUpdate()

        # Only a SequencePoolingMethod instance is asked for updates.
        if not isinstance(self.pooling, SequencePoolingMethod):
            return updates

        updates |= self.pooling.get_pooling_updates(task)
        return updates

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerOutput:
        """Run the pooling stage, then the head stage, and return the result."""
        sequence_repr = self.pooling(hidden_states, pooling_metadata)
        return self.head(sequence_repr, pooling_metadata)

head instance-attribute

head = head

pooling instance-attribute

pooling = pooling

__init__

Source code in vllm/model_executor/layers/pooler/seqwise/poolers.py
def __init__(
    self,
    pooling: SequencePoolingMethod | SequencePoolingFn,
    head: SequencePoolerHead | SequencePoolingHeadFn,
) -> None:
    """Store the two stages of the pooler.

    Args:
        pooling: Stage called first in ``forward`` with the hidden states and
            the pooling metadata.
        head: Stage called second in ``forward`` with the pooled data and the
            pooling metadata.
    """
    super().__init__()

    self.pooling = pooling
    self.head = head

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolerOutput
Source code in vllm/model_executor/layers/pooler/seqwise/poolers.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolerOutput:
    """Run the pooling stage, then the head stage, and return the result."""
    sequence_repr = self.pooling(hidden_states, pooling_metadata)
    return self.head(sequence_repr, pooling_metadata)

get_pooling_updates

get_pooling_updates(
    task: PoolingTask,
) -> PoolingParamsUpdate
Source code in vllm/model_executor/layers/pooler/seqwise/poolers.py
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    """Return the pooling-param updates contributed by the pooling stage.

    A fresh ``PoolingParamsUpdate`` is returned unchanged unless the pooling
    stage is a ``SequencePoolingMethod`` instance.
    """
    updates = PoolingParamsUpdate()

    if not isinstance(self.pooling, SequencePoolingMethod):
        return updates

    updates |= self.pooling.get_pooling_updates(task)
    return updates

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/seqwise/poolers.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Return the tasks supported by both the pooling stage and the head.

    Starts from every task in ``POOLING_TASKS``; only stages that are
    instances of the pooling-method / head base classes narrow the set.
    """
    tasks = set(POOLING_TASKS)

    stages = (
        (self.pooling, SequencePoolingMethod),
        (self.head, SequencePoolerHead),
    )
    for stage, base_cls in stages:
        if isinstance(stage, base_cls):
            tasks = tasks & stage.get_supported_tasks()

    return tasks

SequencePoolerHead

Bases: Module, ABC

Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
class SequencePoolerHead(nn.Module, ABC):
    """Abstract base class for heads that post-process sequence-pooled data."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of pooling tasks this head supports."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Post-process ``pooled_data``; subclasses must override this."""
        raise NotImplementedError

forward abstractmethod

forward(
    pooled_data: SequencePoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolerHeadOutput
Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
@abstractmethod
def forward(
    self,
    pooled_data: SequencePoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolerHeadOutput:
    """Post-process ``pooled_data``; subclasses must override this."""
    raise NotImplementedError

get_supported_tasks abstractmethod

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/seqwise/heads.py
@abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Return the set of pooling tasks this head supports."""
    raise NotImplementedError

SequencePoolingMethod

Bases: Module, ABC

Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
class SequencePoolingMethod(nn.Module, ABC):
    """Abstract base class for methods that pool hidden states per sequence."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by default; subclasses may narrow this."""
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return a default-constructed update; ``task`` is unused here."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Pool ``hidden_states``; subclasses must override this."""
        raise NotImplementedError

forward abstractmethod

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> SequencePoolingMethodOutput
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
@abstractmethod
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
    """Pool ``hidden_states``; subclasses must override this."""
    raise NotImplementedError

get_pooling_updates

get_pooling_updates(
    task: PoolingTask,
) -> PoolingParamsUpdate
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    """Return a default-constructed update; ``task`` is unused here."""
    return PoolingParamsUpdate()

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Return the tasks a sequence pooling method supports by default."""
    tasks: Set[PoolingTask] = {
        "token_embed",
        "token_classify",
        "embed",
        "classify",
        "score",
    }
    return tasks

get_seq_pooling_method

get_seq_pooling_method(
    pooling_type: SequencePoolingType | str,
)
Source code in vllm/model_executor/layers/pooler/seqwise/methods.py
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
    if pooling_type == "CLS":
        return CLSPool()
    if pooling_type == "LAST":
        return LastPool()
    if pooling_type == "MEAN":
        return MeanPool()

    raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")

pooler_for_classify

pooler_for_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod
    | SequencePoolingFn
    | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
)
Source code in vllm/model_executor/layers/pooler/seqwise/poolers.py
def pooler_for_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a ``SequencePooler`` with a ``ClassifierPoolerHead``.

    Args:
        pooler_config: Used to pick the pooling method when ``pooling`` is
            not given.
        pooling: Pooling stage to use; when ``None`` one is created from
            ``pooler_config.get_seq_pooling_type()``.
        classifier: Forwarded to the classifier head.
        act_fn: Forwarded to ``resolve_classifier_act_fn``.
    """
    pooling_stage = (
        get_seq_pooling_method(pooler_config.get_seq_pooling_type())
        if pooling is None
        else pooling
    )

    model_config = get_current_vllm_config().model_config

    head_dtype = model_config.head_dtype
    # NOTE(review): the bias is read from the current vLLM config's
    # ``pooler_config``, not from the ``pooler_config`` argument.
    logit_bias = model_config.pooler_config.logit_bias
    activation = resolve_classifier_act_fn(
        model_config, static_num_labels=True, act_fn=act_fn
    )

    head = ClassifierPoolerHead(
        classifier=classifier,
        logit_bias=logit_bias,
        head_dtype=head_dtype,
        activation=activation,
    )
    return SequencePooler(pooling=pooling_stage, head=head)

pooler_for_embed

pooler_for_embed(pooler_config: PoolerConfig)
Source code in vllm/model_executor/layers/pooler/seqwise/poolers.py
def pooler_for_embed(pooler_config: PoolerConfig):
    """Build a ``SequencePooler`` with an ``EmbeddingPoolerHead``.

    The pooling method comes from ``pooler_config.get_seq_pooling_type()``;
    the head's dtype and projector come from the current vLLM model config,
    and its activation is a ``PoolerNormalize`` instance.
    """
    pooling_type = pooler_config.get_seq_pooling_type()
    pooling = get_seq_pooling_method(pooling_type)

    model_config = get_current_vllm_config().model_config

    head_dtype = model_config.head_dtype
    projector = _load_st_projector(model_config)
    normalize = PoolerNormalize()

    head = EmbeddingPoolerHead(
        projector=projector,
        head_dtype=head_dtype,
        activation=normalize,
    )
    return SequencePooler(pooling=pooling, head=head)