Skip to content

vllm.v1.worker.gpu.sample.logit_bias

MAX_NUM_ALLOWED_TOKEN_IDS module-attribute

# Upper bound on the number of allowed token IDs a single request may carry.
# It is also the row width of the `LogitBiasState.allowed_token_ids` buffer;
# `add_request` raises `ValueError` when a request exceeds it.
MAX_NUM_ALLOWED_TOKEN_IDS = 1024

MAX_NUM_LOGIT_BIAS_TOKENS module-attribute

# Upper bound on the number of (token ID, bias) pairs a single request may
# carry. It is also the row width of the `LogitBiasState.logit_bias_token_ids`
# and `LogitBiasState.logit_bias` buffers; `add_request` raises `ValueError`
# when a request exceeds it.
MAX_NUM_LOGIT_BIAS_TOKENS = 1024

MAX_NUM_STOP_TOKEN_IDS module-attribute

# Upper bound on the number of stop token IDs tracked per request for the
# min-tokens rule. It is also the row width of the
# `LogitBiasState.stop_token_ids` buffer; `add_request` raises `ValueError`
# when a request exceeds it.
MAX_NUM_STOP_TOKEN_IDS = 128

LogitBiasState

Source code in vllm/v1/worker/gpu/sample/logit_bias.py
class LogitBiasState:
    """Per-request inputs for the logit post-processing kernel.

    Three independent features are tracked, each as a per-request count plus
    a fixed-width table of payload values:

    * allowed token IDs (every other logit is set to ``-inf``),
    * logit bias (a float offset added to selected token logits),
    * min tokens (stop-token logits are set to ``-inf`` while the position is
      below ``prompt_len + min_tokens``).

    Counts live in ``UvaBackedTensor`` objects and are written through their
    ``.np`` view; payloads live in ``StagedWriteTensor`` objects and are
    queued with ``stage_write``. Nothing is published until
    ``apply_staged_writes`` is called.
    """

    def __init__(
        self,
        max_num_reqs: int,
        device: torch.device,
    ):
        """Allocate the count and payload buffers.

        Args:
            max_num_reqs: Number of request rows in every buffer.
            device: Device passed to each ``StagedWriteTensor``.
        """
        self.max_num_reqs = max_num_reqs

        # Allowed token IDs.
        self.num_allowed_token_ids = UvaBackedTensor(
            self.max_num_reqs, dtype=torch.int32
        )
        self.allowed_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )
        # Logit bias. Token IDs and bias values are stored in two parallel
        # tables that share the same row layout.
        self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.logit_bias_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.int32,
            device=device,
        )
        self.logit_bias = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.float32,
            device=device,
        )
        # Min tokens.
        self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.stop_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )

    def add_request(
        self,
        req_idx: int,
        prompt_len: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Record the sampling constraints of one request in row ``req_idx``.

        All size limits are checked before any buffer is modified, so a
        rejected request leaves the row and the staged-write queues untouched.

        Args:
            req_idx: Row to fill in every buffer.
            prompt_len: Prompt length; added to ``min_tokens`` to obtain the
                length below which stop tokens are suppressed.
            sampling_params: Source of ``allowed_token_ids``, ``logit_bias``,
                ``min_tokens`` and ``all_stop_token_ids``.

        Raises:
            ValueError: If the request has more allowed token IDs, logit-bias
                entries or stop token IDs than the corresponding
                ``MAX_NUM_*`` constant.
        """
        allowed_token_ids = sampling_params.allowed_token_ids
        logit_bias = sampling_params.logit_bias
        stop_token_ids = sampling_params.all_stop_token_ids

        # --- Validation: nothing is written until every check has passed. ---
        num_allowed_token_ids = len(allowed_token_ids) if allowed_token_ids else 0
        if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
            raise ValueError(
                f"Too many allowed token IDs: {num_allowed_token_ids}. "
                f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
            )
        num_logit_bias = len(logit_bias) if logit_bias else 0
        if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
            raise ValueError(
                f"Too many logit bias tokens: {num_logit_bias}. "
                f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
            )
        num_stop_token_ids = len(stop_token_ids) if stop_token_ids else 0
        if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
            raise ValueError(
                f"Too many stop tokens: {num_stop_token_ids}. "
                f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
            )

        # --- Allowed token IDs. A count of 0 disables the feature. ---
        self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
        if allowed_token_ids:
            self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)

        # --- Logit bias. Keys and values are staged separately; both views
        # come from the same dict and therefore iterate in matching order. ---
        self.num_logit_bias.np[req_idx] = num_logit_bias
        if logit_bias:
            self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
            self.logit_bias.stage_write(req_idx, 0, logit_bias.values())

        # --- Min tokens. The threshold is always written, even when there
        # are no stop tokens to suppress. ---
        self.min_lens.np[req_idx] = prompt_len + sampling_params.min_tokens
        self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
        if stop_token_ids:
            self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)

    def apply_staged_writes(self) -> None:
        """Publish everything recorded by ``add_request`` since the last call.

        Calls ``copy_to_uva()`` on every count buffer and ``apply_write()`` on
        every payload table, grouped per feature.
        """
        self.num_allowed_token_ids.copy_to_uva()
        self.allowed_token_ids.apply_write()

        self.num_logit_bias.copy_to_uva()
        self.logit_bias_token_ids.apply_write()
        self.logit_bias.apply_write()

        self.min_lens.copy_to_uva()
        self.num_stop_token_ids.copy_to_uva()
        self.stop_token_ids.apply_write()

    def apply_logit_bias(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        pos: torch.Tensor,
    ) -> None:
        """Forward ``logits`` and this state to module-level ``apply_logit_bias``.

        Args:
            logits: Logits tensor handed to the kernel launcher.
            idx_mapping: Mapping from batch row to request-state row.
            pos: Per-row positions compared against ``min_lens``.
        """
        # The name below resolves to the module-level function, not to this
        # method, because method names are not in scope inside the body.
        apply_logit_bias(
            logits,
            idx_mapping,
            pos,
            self.num_allowed_token_ids.gpu,
            self.allowed_token_ids.gpu,
            self.num_logit_bias.gpu,
            self.logit_bias_token_ids.gpu,
            self.logit_bias.gpu,
            self.min_lens.gpu,
            self.num_stop_token_ids.gpu,
            self.stop_token_ids.gpu,
        )

allowed_token_ids instance-attribute

allowed_token_ids = StagedWriteTensor(
    (max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
    dtype=int32,
    device=device,
)

logit_bias instance-attribute

logit_bias = StagedWriteTensor(
    (max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
    dtype=float32,
    device=device,
)

logit_bias_token_ids instance-attribute

logit_bias_token_ids = StagedWriteTensor(
    (max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
    dtype=int32,
    device=device,
)

max_num_reqs instance-attribute

max_num_reqs = max_num_reqs

min_lens instance-attribute

min_lens = UvaBackedTensor(max_num_reqs, dtype=int32)

num_allowed_token_ids instance-attribute

num_allowed_token_ids = UvaBackedTensor(
    max_num_reqs, dtype=int32
)

num_logit_bias instance-attribute

num_logit_bias = UvaBackedTensor(max_num_reqs, dtype=int32)

num_stop_token_ids instance-attribute

num_stop_token_ids = UvaBackedTensor(
    max_num_reqs, dtype=int32
)

stop_token_ids instance-attribute

stop_token_ids = StagedWriteTensor(
    (max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
    dtype=int32,
    device=device,
)

__init__

__init__(max_num_reqs: int, device: device)
Source code in vllm/v1/worker/gpu/sample/logit_bias.py
def __init__(
    self,
    max_num_reqs: int,
    device: torch.device,
):
    """Allocate the per-request count and payload buffers.

    Every buffer has ``max_num_reqs`` rows. Count buffers are int32
    ``UvaBackedTensor`` objects; payload buffers are ``StagedWriteTensor``
    objects created on ``device`` with a row width given by the matching
    ``MAX_NUM_*`` constant.
    """
    self.max_num_reqs = max_num_reqs
    num_rows = self.max_num_reqs

    def new_counter():
        # One int32 slot per request row.
        return UvaBackedTensor(num_rows, dtype=torch.int32)

    def new_table(row_width, dtype):
        # A (num_rows, row_width) staged-write table on the target device.
        return StagedWriteTensor((num_rows, row_width), dtype=dtype, device=device)

    # Allowed token IDs.
    self.num_allowed_token_ids = new_counter()
    self.allowed_token_ids = new_table(MAX_NUM_ALLOWED_TOKEN_IDS, torch.int32)

    # Logit bias: token IDs and bias values sit in two parallel tables.
    self.num_logit_bias = new_counter()
    self.logit_bias_token_ids = new_table(MAX_NUM_LOGIT_BIAS_TOKENS, torch.int32)
    self.logit_bias = new_table(MAX_NUM_LOGIT_BIAS_TOKENS, torch.float32)

    # Min tokens.
    self.min_lens = new_counter()
    self.num_stop_token_ids = new_counter()
    self.stop_token_ids = new_table(MAX_NUM_STOP_TOKEN_IDS, torch.int32)

add_request

add_request(
    req_idx: int,
    prompt_len: int,
    sampling_params: SamplingParams,
) -> None
Source code in vllm/v1/worker/gpu/sample/logit_bias.py
def add_request(
    self,
    req_idx: int,
    prompt_len: int,
    sampling_params: SamplingParams,
) -> None:
    """Record the sampling constraints of one request in row ``req_idx``.

    All size limits are checked before any buffer is modified, so a rejected
    request leaves the row and the staged-write queues untouched.

    Args:
        req_idx: Row to fill in every buffer.
        prompt_len: Prompt length; added to ``min_tokens`` to obtain the
            length below which stop tokens are suppressed.
        sampling_params: Source of ``allowed_token_ids``, ``logit_bias``,
            ``min_tokens`` and ``all_stop_token_ids``.

    Raises:
        ValueError: If the request has more allowed token IDs, logit-bias
            entries or stop token IDs than the corresponding ``MAX_NUM_*``
            constant.
    """
    allowed_token_ids = sampling_params.allowed_token_ids
    logit_bias = sampling_params.logit_bias
    stop_token_ids = sampling_params.all_stop_token_ids

    # --- Validation: nothing is written until every check has passed. ---
    num_allowed_token_ids = len(allowed_token_ids) if allowed_token_ids else 0
    if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
        raise ValueError(
            f"Too many allowed token IDs: {num_allowed_token_ids}. "
            f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
        )
    num_logit_bias = len(logit_bias) if logit_bias else 0
    if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
        raise ValueError(
            f"Too many logit bias tokens: {num_logit_bias}. "
            f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
        )
    num_stop_token_ids = len(stop_token_ids) if stop_token_ids else 0
    if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
        raise ValueError(
            f"Too many stop tokens: {num_stop_token_ids}. "
            f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
        )

    # --- Allowed token IDs. A count of 0 disables the feature. ---
    self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
    if allowed_token_ids:
        self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)

    # --- Logit bias. Keys and values are staged separately; both views come
    # from the same dict and therefore iterate in matching order. ---
    self.num_logit_bias.np[req_idx] = num_logit_bias
    if logit_bias:
        self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
        self.logit_bias.stage_write(req_idx, 0, logit_bias.values())

    # --- Min tokens. The threshold is always written, even when there are
    # no stop tokens to suppress. ---
    self.min_lens.np[req_idx] = prompt_len + sampling_params.min_tokens
    self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
    if stop_token_ids:
        self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)

apply_logit_bias

apply_logit_bias(
    logits: Tensor, idx_mapping: Tensor, pos: Tensor
) -> None
Source code in vllm/v1/worker/gpu/sample/logit_bias.py
def apply_logit_bias(
    self,
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
) -> None:
    """Forward ``logits`` plus this object's buffers to ``apply_logit_bias``.

    The ``.gpu`` attribute of every count and payload buffer is passed
    positionally, grouped by feature: allowed token IDs, logit bias, then
    min tokens.
    """
    allowed_inputs = (
        self.num_allowed_token_ids.gpu,
        self.allowed_token_ids.gpu,
    )
    bias_inputs = (
        self.num_logit_bias.gpu,
        self.logit_bias_token_ids.gpu,
        self.logit_bias.gpu,
    )
    min_token_inputs = (
        self.min_lens.gpu,
        self.num_stop_token_ids.gpu,
        self.stop_token_ids.gpu,
    )
    apply_logit_bias(
        logits,
        idx_mapping,
        pos,
        *allowed_inputs,
        *bias_inputs,
        *min_token_inputs,
    )

apply_staged_writes

apply_staged_writes() -> None
Source code in vllm/v1/worker/gpu/sample/logit_bias.py
def apply_staged_writes(self) -> None:
    """Publish every count buffer and payload table, one feature at a time.

    Count buffers are flushed with ``copy_to_uva()`` and payload tables with
    ``apply_write()``. The calls run in the order listed below.
    """
    flush_steps = (
        # Allowed token IDs.
        self.num_allowed_token_ids.copy_to_uva,
        self.allowed_token_ids.apply_write,
        # Logit bias.
        self.num_logit_bias.copy_to_uva,
        self.logit_bias_token_ids.apply_write,
        self.logit_bias.apply_write,
        # Min tokens.
        self.min_lens.copy_to_uva,
        self.num_stop_token_ids.copy_to_uva,
        self.stop_token_ids.apply_write,
    )
    for flush in flush_steps:
        flush()

_bias_kernel

_bias_kernel(
    logits_ptr,
    logits_stride,
    vocab_size,
    idx_mapping_ptr,
    num_allowed_token_ids_ptr,
    allowed_token_ids_ptr,
    allowed_token_ids_stride,
    num_logit_bias_ptr,
    bias_token_ids_ptr,
    bias_token_ids_stride,
    bias_ptr,
    bias_stride,
    pos_ptr,
    min_lens_ptr,
    num_stop_token_ids_ptr,
    stop_token_ids_ptr,
    stop_token_ids_stride,
    BLOCK_SIZE: constexpr,
    LOGITS_BLOCK_SIZE: constexpr,
)
Source code in vllm/v1/worker/gpu/sample/logit_bias.py
@triton.jit
def _bias_kernel(
    logits_ptr,
    logits_stride,
    vocab_size,
    idx_mapping_ptr,
    # Allowed token IDs.
    num_allowed_token_ids_ptr,
    allowed_token_ids_ptr,
    allowed_token_ids_stride,
    # Logit bias.
    num_logit_bias_ptr,
    bias_token_ids_ptr,
    bias_token_ids_stride,
    bias_ptr,
    bias_stride,
    # Min tokens.
    pos_ptr,
    min_lens_ptr,
    num_stop_token_ids_ptr,
    stop_token_ids_ptr,
    stop_token_ids_stride,
    BLOCK_SIZE: tl.constexpr,
    LOGITS_BLOCK_SIZE: tl.constexpr,
):
    """Apply allowed-token masking, logit bias and min-tokens to one logits row.

    Each program handles the logits row selected by ``tl.program_id(0)`` and
    modifies it in place. The three steps run in this order:

    1. Allowed token IDs: if the request has any, every logit in the row is
       set to ``-inf`` except those at the allowed token IDs.
    2. Logit bias: the per-token bias is added to the selected logits.
    3. Min tokens: while ``pos < min_len``, the logits at the request's stop
       token IDs are set to ``-inf``.

    ``BLOCK_SIZE`` is the number of lanes used to read each per-request list,
    so it must be at least as large as the longest list. ``LOGITS_BLOCK_SIZE``
    is the chunk size used when overwriting the whole row.
    """
    batch_idx = tl.program_id(0)
    # Batch rows and request-state rows are indexed differently; translate.
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    # Lane indices shared by all three steps below.
    block = tl.arange(0, BLOCK_SIZE)

    # Allowed token IDs.
    num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx)
    if num_allowed_token_ids > 0:
        mask = block < num_allowed_token_ids

        # Save logits for allowed token IDs.
        # NOTE(review): token IDs are not range-checked against vocab_size
        # here; presumably they are validated upstream -- confirm.
        allowed_token_ids = tl.load(
            allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block,
            mask=mask,
        )
        logits = tl.load(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
        )

        # Set logits to -inf for all tokens.
        for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
            offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
            tl.store(
                logits_ptr + batch_idx * logits_stride + offset,
                -float("inf"),
                mask=offset < vocab_size,
            )

        # Restore logits for allowed token IDs.
        tl.store(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids,
            logits,
            mask=mask,
        )

    # Logit bias.
    num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx)
    if num_logit_bias > 0:
        mask = block < num_logit_bias
        token_ids = tl.load(
            bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block,
            mask=mask,
        )
        bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
        # Read-modify-write of only the biased positions.
        logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
        logits += bias
        tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)

    # Apply min tokens.
    num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
    pos = tl.load(pos_ptr + batch_idx)
    min_len = tl.load(min_lens_ptr + req_state_idx)
    # Stop tokens stay suppressed only while the position is below min_len.
    if num_stop_token_ids > 0 and pos < min_len:
        mask = block < num_stop_token_ids
        stop_token_ids = tl.load(
            stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block,
            mask=mask,
        )
        tl.store(
            logits_ptr + batch_idx * logits_stride + stop_token_ids,
            -float("inf"),
            mask=mask,
        )

apply_logit_bias

apply_logit_bias(
    logits: Tensor,
    idx_mapping: Tensor,
    pos: Tensor,
    num_allowed_token_ids: Tensor,
    allowed_token_ids: Tensor,
    num_logit_bias: Tensor,
    logit_bias_token_ids: Tensor,
    logit_bias: Tensor,
    min_lens: Tensor,
    num_stop_token_ids: Tensor,
    stop_token_ids: Tensor,
) -> None
Source code in vllm/v1/worker/gpu/sample/logit_bias.py
def apply_logit_bias(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
    num_logit_bias: torch.Tensor,
    logit_bias_token_ids: torch.Tensor,
    logit_bias: torch.Tensor,
    min_lens: torch.Tensor,
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    """Launch ``_bias_kernel`` with one program per row of ``logits``.

    ``logits`` must be two-dimensional, since its shape is unpacked into a
    row count and a vocabulary size. The kernel writes into ``logits``
    directly and nothing is returned.

    The per-request lane count is the next power of two at or above the
    widest of the three token-ID tables, so a single value covers all of
    them. Full-row overwrites are done in chunks of 8192 logits.
    """
    batch_size, vocab_size = logits.shape

    # One lane count must be wide enough for every per-request table.
    widest_row = max(
        allowed_token_ids.shape[-1],
        logit_bias_token_ids.shape[-1],
        stop_token_ids.shape[-1],
    )
    lanes_per_request = triton.next_power_of_2(widest_row)
    row_fill_chunk = 8192

    launch_grid = (batch_size,)
    _bias_kernel[launch_grid](
        logits,
        logits.stride(0),
        vocab_size,
        idx_mapping,
        # Allowed token IDs.
        num_allowed_token_ids,
        allowed_token_ids,
        allowed_token_ids.stride(0),
        # Logit bias.
        num_logit_bias,
        logit_bias_token_ids,
        logit_bias_token_ids.stride(0),
        logit_bias,
        logit_bias.stride(0),
        # Min tokens.
        pos,
        min_lens,
        num_stop_token_ids,
        stop_token_ids,
        stop_token_ids.stride(0),
        BLOCK_SIZE=lanes_per_request,
        LOGITS_BLOCK_SIZE=row_fill_chunk,
    )