Skip to content

vllm.model_executor.layers.pooler.tokwise

Poolers that produce an output for each token in the sequence.

Modules:

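Name     Description
heads    Token pooler heads that post-process per-token pooled data (TokenPoolerHead, TokenClassifierPoolerHead, TokenEmbeddingPoolerHead).
methods  Token pooling methods that select per-token hidden states for each request (TokenPoolingMethod, AllPool, StepPool, get_tok_pooling_method).
poolers  The TokenPooler layer and its factories (pooler_for_token_classify, pooler_for_token_embed).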

TokenPoolerHeadOutputItem module-attribute

# Output of a token pooler head for one request; ``None`` marks a request whose
# chunked prefill has not finished yet.
TokenPoolerHeadOutputItem: TypeAlias = Tensor | None

TokenPoolerOutput module-attribute

# Batch result of ``TokenPooler.forward``: one item per request, aligned with
# ``pooling_metadata.pooling_params``; ``None`` for unfinished chunked prefill.
TokenPoolerOutput: TypeAlias = list[Tensor | None]

TokenPoolingMethodOutputItem module-attribute

# Output of a token pooling method for one request; ``None`` while that
# request's chunked prefill is still in progress.
TokenPoolingMethodOutputItem: TypeAlias = Tensor | None

__all__ module-attribute

# Public names re-exported by the ``tokwise`` package (heads, methods, poolers).
__all__ = [
    "TokenPoolerHead",
    "TokenPoolerHeadOutputItem",
    "TokenClassifierPoolerHead",
    "TokenEmbeddingPoolerHead",
    "TokenPoolingMethod",
    "TokenPoolingMethodOutputItem",
    "AllPool",
    "StepPool",
    "get_tok_pooling_method",
    "TokenPooler",
    "TokenPoolerOutput",
    "pooler_for_token_classify",
    "pooler_for_token_embed",
]

AllPool

Bases: TokenPoolingMethod

Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
class AllPool(TokenPoolingMethod):
    """Return every token's hidden state for each request.

    With chunked prefill enabled, each step's chunk is accumulated in the
    request's pooling state. The concatenated sequence is emitted only once
    that request's prefill has finished; unfinished requests yield ``None``.
    """

    def __init__(self):
        super().__init__()

        # Read once, at construction time, from the active vLLM config.
        self.enable_chunked_prefill = (
            get_current_vllm_config().scheduler_config.enable_chunked_prefill
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        cursor = pooling_metadata.get_pooling_cursor()
        # Split the flat token batch into one slice per scheduled request,
        # then order the slices as given by the cursor's index.
        per_request = hidden_states.split(cursor.num_scheduled_tokens_cpu.tolist())
        chunks = [per_request[idx] for idx in cursor.index]

        if not self.enable_chunked_prefill:
            return chunks

        states = pooling_metadata.pooling_states

        # Step 1: stash this step's chunk in each request's cache.
        for state, chunk in zip(states, chunks):
            state.hidden_states_cache.append(chunk)

        # Step 2: emit the full sequence for requests whose prefill is done.
        results: list[TokenPoolingMethodOutputItem] = []
        for state, done in zip(states, cursor.is_finished()):
            if not done:
                results.append(None)
                continue
            cached = state.hidden_states_cache
            # A single cached chunk is returned as-is (no concat copy).
            merged = cached[0] if len(cached) == 1 else torch.concat(cached, dim=0)
            results.append(merged)
            state.clean()
        return results

enable_chunked_prefill instance-attribute

enable_chunked_prefill = enable_chunked_prefill

__init__

__init__()
Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
def __init__(self):
    """Initialise the module and cache the scheduler's chunked-prefill flag."""
    super().__init__()

    # Read once, at construction time, from the active vLLM config.
    self.enable_chunked_prefill = (
        get_current_vllm_config().scheduler_config.enable_chunked_prefill
    )

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> list[TokenPoolingMethodOutputItem]
Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> list[TokenPoolingMethodOutputItem]:
    """Return every token's hidden state, one item per request.

    Without chunked prefill this is simply the per-request slices. With it,
    chunks are accumulated per request. The concatenated sequence is returned
    once the request's prefill is finished; ``None`` is returned otherwise.
    """
    cursor = pooling_metadata.get_pooling_cursor()
    # Split the flat token batch into one slice per scheduled request,
    # then order the slices as given by the cursor's index.
    per_request = hidden_states.split(cursor.num_scheduled_tokens_cpu.tolist())
    chunks = [per_request[idx] for idx in cursor.index]

    if not self.enable_chunked_prefill:
        return chunks

    states = pooling_metadata.pooling_states

    # Step 1: stash this step's chunk in each request's cache.
    for state, chunk in zip(states, chunks):
        state.hidden_states_cache.append(chunk)

    # Step 2: emit the full sequence for requests whose prefill is done.
    results: list[TokenPoolingMethodOutputItem] = []
    for state, done in zip(states, cursor.is_finished()):
        if not done:
            results.append(None)
            continue
        cached = state.hidden_states_cache
        # A single cached chunk is returned as-is (no concat copy).
        merged = cached[0] if len(cached) == 1 else torch.concat(cached, dim=0)
        results.append(merged)
        state.clean()
    return results

StepPool

Bases: AllPool

Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
class StepPool(AllPool):
    """Token-wise pooling restricted to step-tag rows and selected columns.

    Builds on ``AllPool``. For every request whose output is ready, it
    optionally keeps only the columns listed in ``returned_token_ids``. It then
    keeps only the rows whose prompt token equals ``step_tag_id``.
    """

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Request prompt token ids; forward() compares them with step_tag_id.
        update = PoolingParamsUpdate(requires_token_ids=True)
        return update

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        per_request = super().forward(hidden_states, pooling_metadata)
        prompt_ids_lst = pooling_metadata.get_prompt_token_ids()
        params_lst = pooling_metadata.pooling_params

        selected: list[torch.Tensor | None] = []
        for item, prompt_ids, params in zip(per_request, prompt_ids_lst, params_lst):
            # ``None`` (unfinished chunked prefill) is passed through untouched.
            if item is not None:
                tag_id = params.step_tag_id
                keep_ids = params.returned_token_ids

                if keep_ids is not None and len(keep_ids) > 0:
                    item = item[:, keep_ids]
                if tag_id is not None:
                    item = item[prompt_ids == tag_id]
            selected.append(item)

        return selected

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> list[TokenPoolingMethodOutputItem]
Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> list[TokenPoolingMethodOutputItem]:
    """Pool all tokens, then keep step-tag rows and requested columns.

    For every request whose output is ready, ``returned_token_ids`` (when
    non-empty) selects columns. ``step_tag_id`` (when set) keeps only rows
    whose prompt token equals it. ``None`` entries (unfinished chunked
    prefill) are passed through unchanged.
    """
    per_request = super().forward(hidden_states, pooling_metadata)
    prompt_ids_lst = pooling_metadata.get_prompt_token_ids()
    params_lst = pooling_metadata.pooling_params

    selected: list[torch.Tensor | None] = []
    for item, prompt_ids, params in zip(per_request, prompt_ids_lst, params_lst):
        if item is not None:
            tag_id = params.step_tag_id
            keep_ids = params.returned_token_ids

            if keep_ids is not None and len(keep_ids) > 0:
                item = item[:, keep_ids]
            if tag_id is not None:
                item = item[prompt_ids == tag_id]
        selected.append(item)

    return selected

get_pooling_updates

get_pooling_updates(
    task: PoolingTask,
) -> PoolingParamsUpdate
Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    """Request prompt token ids (for every ``task``); step-tag filtering uses them."""
    update = PoolingParamsUpdate(requires_token_ids=True)
    return update

TokenClassifierPoolerHead

Bases: TokenPoolerHead

Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
class TokenClassifierPoolerHead(TokenPoolerHead):
    """Per-token classification head for the ``token_classify`` task.

    Args:
        classifier: Optional callable mapping per-token data to scores. When
            ``None``, the (dtype-cast) pooled data is used as the scores.
        logit_bias: Optional value subtracted from every score.
        head_dtype: Optional dtype the pooled data is cast to first.
        activation: Optional activation applied to the scores when the
            request's ``use_activation`` is truthy.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        logit_bias: float | None = None,
        head_dtype: torch.dtype | str | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        self.classifier = classifier
        self.logit_bias = logit_bias
        self.head_dtype = head_dtype
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Compute per-token scores for one request.

        Returns ``None`` when ``pooled_data`` is ``None`` (unfinished chunked
        prefill). Otherwise returns the scores after the optional dtype cast,
        classifier, logit-bias subtraction and activation.
        """
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        if self.head_dtype is not None:
            pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Out-of-place on purpose. Without a classifier (and when the
            # dtype cast above is a no-op), ``scores`` is the very tensor
            # handed in by the pooling method. That tensor can be a view of
            # the model's hidden states or a cached chunk, so an in-place
            # ``-=`` would silently modify it.
            scores = scores - self.logit_bias

        if self.activation is not None and pooling_param.use_activation:
            scores = self.activation(scores)

        # scores shape: [n_token, num_labels]
        return scores

activation instance-attribute

activation = activation

classifier instance-attribute

classifier = classifier

head_dtype instance-attribute

head_dtype = head_dtype

logit_bias instance-attribute

logit_bias = logit_bias

__init__

__init__(
    classifier: ClassifierFn | None = None,
    logit_bias: float | None = None,
    head_dtype: dtype | str | None = None,
    activation: ActivationFn | None = None,
) -> None
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
def __init__(
    self,
    classifier: ClassifierFn | None = None,
    logit_bias: float | None = None,
    head_dtype: torch.dtype | str | None = None,
    activation: ActivationFn | None = None,
) -> None:
    """Store the optional scoring components used by ``forward_chunk``.

    Each argument defaults to ``None``, which skips the corresponding step.
    """
    super().__init__()

    # Assignment order kept as-is: nn.Module registers submodules in the
    # order they are set.
    self.classifier = classifier
    self.logit_bias = logit_bias
    self.head_dtype = head_dtype
    self.activation = activation

forward_chunk

forward_chunk(
    pooled_data: TokenPoolingMethodOutputItem,
    pooling_param: PoolingParams,
) -> TokenPoolerHeadOutputItem
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
def forward_chunk(
    self,
    pooled_data: TokenPoolingMethodOutputItem,
    pooling_param: PoolingParams,
) -> TokenPoolerHeadOutputItem:
    """Compute per-token classification scores for a single request.

    Args:
        pooled_data: The request's per-token data from the pooling method,
            or ``None`` while its chunked prefill is still in progress.
        pooling_param: The request's pooling parameters; only
            ``use_activation`` is read here.

    Returns:
        The scores after the optional dtype cast, classifier, logit-bias
        subtraction and activation, or ``None`` if ``pooled_data`` is ``None``.
    """
    # for unfinished chunked prefill
    if pooled_data is None:
        return None

    if self.head_dtype is not None:
        pooled_data = pooled_data.to(self.head_dtype)
    # hidden_states shape: [n_token, hidden_size]

    if self.classifier is not None:
        scores = self.classifier(pooled_data)
    else:
        scores = pooled_data
    # scores shape: [n_token, num_labels]

    if self.logit_bias is not None:
        # Out-of-place on purpose. Without a classifier (and when the dtype
        # cast above is a no-op), ``scores`` is the very tensor handed in by
        # the pooling method. That tensor can be a view of the model's hidden
        # states or a cached chunk, so an in-place ``-=`` would silently
        # modify it.
        scores = scores - self.logit_bias

    if self.activation is not None and pooling_param.use_activation:
        scores = self.activation(scores)

    # scores shape: [n_token, num_labels]
    return scores

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Report the single task this head serves: per-token classification."""
    supported = {"token_classify"}
    return supported

TokenEmbeddingPoolerHead

Bases: TokenPoolerHead

Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
class TokenEmbeddingPoolerHead(TokenPoolerHead):
    """Per-token embedding head for the ``token_embed`` task.

    Per request, it applies an optional dtype cast and an optional projector.
    It then truncates the last dimension to ``pooling_param.dimensions`` and
    applies an optional activation (gated by ``pooling_param.use_activation``).
    """

    def __init__(
        self,
        head_dtype: torch.dtype | str | None = None,
        projector: ProjectorFn | None = None,
        activation: ActivationFn | None = None,
    ) -> None:
        super().__init__()

        # Assignment order kept as-is: nn.Module registers submodules in the
        # order they are set.
        self.head_dtype = head_dtype
        self.projector = projector
        self.activation = activation

    def get_supported_tasks(self) -> Set[PoolingTask]:
        supported = {"token_embed"}
        return supported

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # ``None`` marks an unfinished chunked prefill; nothing to embed yet.
        if pooled_data is None:
            return None

        emb = pooled_data if self.head_dtype is None else pooled_data.to(self.head_dtype)

        # Sentence-Transformers projector, if one was configured.
        if self.projector is not None:
            emb = self.projector(emb)

        # Matryoshka truncation; a ``None`` bound keeps every dimension.
        emb = emb[..., : pooling_param.dimensions]

        if self.activation is None or not pooling_param.use_activation:
            return emb
        return self.activation(emb)

activation instance-attribute

activation = activation

head_dtype instance-attribute

head_dtype = head_dtype

projector instance-attribute

projector = projector

__init__

__init__(
    head_dtype: dtype | str | None = None,
    projector: ProjectorFn | None = None,
    activation: ActivationFn | None = None,
) -> None
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
def __init__(
    self,
    head_dtype: torch.dtype | str | None = None,
    projector: ProjectorFn | None = None,
    activation: ActivationFn | None = None,
) -> None:
    """Store the optional embedding components used by ``forward_chunk``.

    Each argument defaults to ``None``, which skips the corresponding step.
    """
    super().__init__()

    # Assignment order kept as-is: nn.Module registers submodules in the
    # order they are set.
    self.head_dtype = head_dtype
    self.projector = projector
    self.activation = activation

forward_chunk

forward_chunk(
    pooled_data: TokenPoolingMethodOutputItem,
    pooling_param: PoolingParams,
) -> TokenPoolerHeadOutputItem
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
def forward_chunk(
    self,
    pooled_data: TokenPoolingMethodOutputItem,
    pooling_param: PoolingParams,
) -> TokenPoolerHeadOutputItem:
    """Turn one request's per-token data into token embeddings.

    Pipeline: optional dtype cast -> optional projector -> truncate the last
    dimension to ``pooling_param.dimensions`` -> optional activation (gated
    by ``pooling_param.use_activation``). ``None`` (unfinished chunked
    prefill) is returned unchanged.
    """
    if pooled_data is None:
        return None

    emb = pooled_data if self.head_dtype is None else pooled_data.to(self.head_dtype)

    # Sentence-Transformers projector, if one was configured.
    if self.projector is not None:
        emb = self.projector(emb)

    # Matryoshka truncation; a ``None`` bound keeps every dimension.
    emb = emb[..., : pooling_param.dimensions]

    if self.activation is None or not pooling_param.use_activation:
        return emb
    return self.activation(emb)

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Report the single task this head serves: per-token embedding."""
    supported = {"token_embed"}
    return supported

TokenPooler

Bases: Pooler

A layer that pools specific information from hidden states.

This layer does the following:

1. Extracts specific tokens or aggregates data based on the pooling method.
2. Post-processes the output with the pooling head.
3. Returns the structured results as a `TokenPoolerOutput`.

Source code in vllm/model_executor/layers/pooler/tokwise/poolers.py
class TokenPooler(Pooler):
    """Pooler that produces one result per token for each request.

    Processing happens in two stages:

    1. ``pooling`` selects or aggregates per-token hidden states.
    2. ``head`` post-processes that data for each request.

    The combined result is returned as a ``TokenPoolerOutput``.
    """

    def __init__(
        self,
        pooling: TokenPoolingMethod | TokenPoolingFn,
        head: TokenPoolerHead | TokenPoolingHeadFn,
    ) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Start from every known task and narrow by each component that
        # declares its own support; plain callables impose no restriction.
        supported = set(POOLING_TASKS)
        checks = (
            (self.pooling, TokenPoolingMethod),
            (self.head, TokenPoolerHead),
        )
        for component, base in checks:
            if isinstance(component, base):
                supported &= component.get_supported_tasks()
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Only a TokenPoolingMethod can contribute extra parameter updates.
        merged = PoolingParamsUpdate()
        if not isinstance(self.pooling, TokenPoolingMethod):
            return merged
        merged |= self.pooling.get_pooling_updates(task)
        return merged

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        method_output = self.pooling(hidden_states, pooling_metadata)
        return self.head(method_output, pooling_metadata)

head instance-attribute

head = head

pooling instance-attribute

pooling = pooling

__init__

__init__(
    pooling: TokenPoolingMethod | TokenPoolingFn,
    head: TokenPoolerHead | TokenPoolingHeadFn,
) -> None
Source code in vllm/model_executor/layers/pooler/tokwise/poolers.py
def __init__(
    self,
    pooling: TokenPoolingMethod | TokenPoolingFn,
    head: TokenPoolerHead | TokenPoolingHeadFn,
) -> None:
    """Wire a pooling stage and a head stage into one module.

    Args:
        pooling: Method or callable producing per-request token data.
        head: Head or callable post-processing that data.
    """
    super().__init__()

    # Kept in this order: pooling is registered before head.
    self.pooling = pooling
    self.head = head

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> TokenPoolerOutput
Source code in vllm/model_executor/layers/pooler/tokwise/poolers.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
    """Run the pooling stage, then the head stage, over the whole batch."""
    method_output = self.pooling(hidden_states, pooling_metadata)
    return self.head(method_output, pooling_metadata)

get_pooling_updates

get_pooling_updates(
    task: PoolingTask,
) -> PoolingParamsUpdate
Source code in vllm/model_executor/layers/pooler/tokwise/poolers.py
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    """Collect the pooling-parameter updates ``task`` needs from this pooler.

    Only a ``TokenPoolingMethod`` can contribute. Otherwise, a
    default-constructed ``PoolingParamsUpdate`` is returned.
    """
    merged = PoolingParamsUpdate()
    if not isinstance(self.pooling, TokenPoolingMethod):
        return merged
    merged |= self.pooling.get_pooling_updates(task)
    return merged

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/tokwise/poolers.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Return the tasks supported by both the pooling stage and the head.

    Components that are plain callables (not ``TokenPoolingMethod`` /
    ``TokenPoolerHead`` instances) do not narrow the set.
    """
    supported = set(POOLING_TASKS)
    checks = (
        (self.pooling, TokenPoolingMethod),
        (self.head, TokenPoolerHead),
    )
    for component, base in checks:
        if isinstance(component, base):
            supported &= component.get_supported_tasks()
    return supported

TokenPoolerHead

Bases: Module, ABC

Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
class TokenPoolerHead(nn.Module, ABC):
    """Base class for heads that post-process token-wise pooled data.

    Subclasses implement ``forward_chunk`` for a single request. ``forward``
    applies it to every request, pairing each item with the matching entry
    of ``pooling_metadata.pooling_params``.
    """

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this head supports."""
        raise NotImplementedError

    @abstractmethod
    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Post-process one request's pooled data (which may be ``None``)."""
        raise NotImplementedError

    def forward(
        self,
        pooled_data: list[TokenPoolingMethodOutputItem],
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolerHeadOutputItem]:
        params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(params)

        return list(map(self.forward_chunk, pooled_data, params))

forward

forward(
    pooled_data: list[TokenPoolingMethodOutputItem],
    pooling_metadata: PoolingMetadata,
) -> list[TokenPoolerHeadOutputItem]
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
def forward(
    self,
    pooled_data: list[TokenPoolingMethodOutputItem],
    pooling_metadata: PoolingMetadata,
) -> list[TokenPoolerHeadOutputItem]:
    """Apply ``forward_chunk`` to each request's data and its pooling params.

    The two sequences must have equal length (checked with ``assert``).
    """
    params = pooling_metadata.pooling_params
    assert len(pooled_data) == len(params)

    return list(map(self.forward_chunk, pooled_data, params))

forward_chunk abstractmethod

forward_chunk(
    pooled_data: TokenPoolingMethodOutputItem,
    pooling_param: PoolingParams,
) -> TokenPoolerHeadOutputItem
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
@abstractmethod
def forward_chunk(
    self,
    pooled_data: TokenPoolingMethodOutputItem,
    pooling_param: PoolingParams,
) -> TokenPoolerHeadOutputItem:
    """Post-process the pooled data of a single request.

    Subclasses must override this; the base implementation only raises.
    """
    raise NotImplementedError

get_supported_tasks abstractmethod

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/tokwise/heads.py
@abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]:
    """Return the pooling tasks a concrete head supports (must be overridden)."""
    raise NotImplementedError

TokenPoolingMethod

Bases: Module, ABC

Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
class TokenPoolingMethod(nn.Module, ABC):
    """Base class for methods that select per-token hidden states per request.

    By default a method supports both token-wise tasks and requests a
    default-constructed ``PoolingParamsUpdate``; subclasses override either
    as needed and must implement ``forward``.
    """

    def get_supported_tasks(self) -> Set[PoolingTask]:
        supported = {"token_embed", "token_classify"}
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        no_changes = PoolingParamsUpdate()
        return no_changes

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Return one item per request (``None`` where no output is ready)."""
        raise NotImplementedError

forward abstractmethod

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> list[TokenPoolingMethodOutputItem]
Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
@abstractmethod
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> list[TokenPoolingMethodOutputItem]:
    """Return one pooled item per request; concrete methods must override."""
    raise NotImplementedError

get_pooling_updates

get_pooling_updates(
    task: PoolingTask,
) -> PoolingParamsUpdate
Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    """Return a default-constructed update; subclasses override to request more."""
    no_changes = PoolingParamsUpdate()
    return no_changes

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    """By default a token pooling method supports both token-wise tasks."""
    supported = {"token_embed", "token_classify"}
    return supported

get_tok_pooling_method

get_tok_pooling_method(
    pooling_type: TokenPoolingType | str,
)
Source code in vllm/model_executor/layers/pooler/tokwise/methods.py
def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
    if pooling_type == "ALL":
        return AllPool()
    if pooling_type == "STEP":
        return StepPool()

    raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")

pooler_for_token_classify

pooler_for_token_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: TokenPoolingMethod
    | TokenPoolingFn
    | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
)
Source code in vllm/model_executor/layers/pooler/tokwise/poolers.py
def pooler_for_token_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: TokenPoolingMethod | TokenPoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a ``TokenPooler`` for the ``token_classify`` task.

    Args:
        pooler_config: Pooler configuration. It supplies the token pooling type
            (when ``pooling`` is not given) and the logit bias.
        pooling: Optional pooling method/callable that overrides the one
            derived from ``pooler_config``.
        classifier: Optional per-token classifier passed to the head.
        act_fn: Optional activation override, resolved through
            ``resolve_classifier_act_fn`` together with the model config.
    """
    if pooling is None:
        pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())

    vllm_config = get_current_vllm_config()
    model_config = vllm_config.model_config
    head = TokenClassifierPoolerHead(
        head_dtype=model_config.head_dtype,
        classifier=classifier,
        # Take the bias from the config this factory was given (the same one
        # the pooling type comes from) rather than re-fetching it from the
        # global model config, which a caller may not have passed in.
        logit_bias=pooler_config.logit_bias,
        activation=resolve_classifier_act_fn(
            model_config, static_num_labels=False, act_fn=act_fn
        ),
    )

    return TokenPooler(pooling=pooling, head=head)

pooler_for_token_embed

pooler_for_token_embed(pooler_config: PoolerConfig)
Source code in vllm/model_executor/layers/pooler/tokwise/poolers.py
def pooler_for_token_embed(pooler_config: PoolerConfig):
    """Build a ``TokenPooler`` for the ``token_embed`` task.

    The pooling method is chosen from ``pooler_config``'s token pooling type.
    The head uses the current model config's ``head_dtype``, the projector
    returned by ``_load_st_projector`` for that model, and ``PoolerNormalize()``
    as its activation.
    """
    pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())

    model_config = get_current_vllm_config().model_config
    return TokenPooler(
        pooling=pooling,
        head=TokenEmbeddingPoolerHead(
            head_dtype=model_config.head_dtype,
            projector=_load_st_projector(model_config),
            activation=PoolerNormalize(),
        ),
    )