vllm.v1.worker.gpu.sample.penalties

PenaltiesState
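
Per-request penalty state for the GPU sampler. The penalty scalars (repetition, frequency, presence) live in UVA-backed tensors so the CPU can stage writes cheaply, while the token statistics they need, a bit-packed prompt-membership mask and per-token output counts, live on the GPU.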

Source code in vllm/v1/worker/gpu/sample/penalties.py
class PenaltiesState:
    def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.device = device

        self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)

        # Initialize repetition penalty manually because 0 is an invalid value for it.
        self.repetition_penalty.np.fill(1.0)
        self.repetition_penalty.copy_to_uva()

        # Statistics for penalties.
        self.prompt_bin_mask = torch.zeros(
            self.max_num_reqs,
            cdiv(self.vocab_size, 32),
            dtype=torch.int32,
            device=self.device,
        )
        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
        # GBs of GPU memory. Optimize the memory usage.
        self.output_bin_counts = torch.zeros(
            self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
        )

        self._penalties_reqs: list[int] = []

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
        if use_penalty(sampling_params):
            self._penalties_reqs.append(req_idx)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        # TODO(woosuk): Optimize this.
        for req_idx in self._penalties_reqs:
            bincount(
                prefill_token_ids[req_idx],
                int(prefill_lens[req_idx]),
                int(prompt_lens[req_idx]),
                self.prompt_bin_mask[req_idx],
                self.output_bin_counts[req_idx],
            )
        self._penalties_reqs.clear()

        self.repetition_penalty.copy_to_uva()
        self.frequency_penalty.copy_to_uva()
        self.presence_penalty.copy_to_uva()

    def apply_penalties_and_temperature(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        temperature: torch.Tensor,
    ) -> None:
        apply_penalties_and_temperature(
            logits,
            idx_mapping,
            temperature,
            self.repetition_penalty.gpu,
            self.frequency_penalty.gpu,
            self.presence_penalty.gpu,
            self.prompt_bin_mask,
            self.output_bin_counts,
        )
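
A minimal lifecycle sketch, assuming a CUDA device with UVA support; the step inputs below (prefill_token_ids, prefill_lens, prompt_lens, logits, idx_mapping, temperature) are placeholders standing in for what the model runner provides, not real runner state:

import numpy as np
import torch
from vllm import SamplingParams

state = PenaltiesState(max_num_reqs=256, vocab_size=32_000, device=torch.device("cuda"))
# Stage penalty scalars for request slot 0; only requests with a non-default
# penalty are queued for the bincount pass.
state.add_request(req_idx=0, sampling_params=SamplingParams(frequency_penalty=0.7))

# Placeholder step inputs; in vLLM these come from the model runner.
prefill_token_ids = torch.randint(0, 32_000, (256, 1024), dtype=torch.int32, device="cuda")
prefill_lens = np.full(256, 16, dtype=np.int32)  # prompt + generated tokens so far
prompt_lens = np.full(256, 12, dtype=np.int32)
# Recompute token statistics and flush the penalty scalars to UVA memory.
state.apply_staged_writes(prefill_token_ids, prefill_lens, prompt_lens)

# At sampling time, rescale logits in place (one batch row mapped to slot 0).
logits = torch.randn(1, 32_000, device="cuda")
idx_mapping = torch.zeros(1, dtype=torch.int32, device="cuda")
temperature = torch.ones(256, device="cuda")
state.apply_penalties_and_temperature(logits, idx_mapping, temperature)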

_penalties_reqs instance-attribute

_penalties_reqs: list[int] = []

device instance-attribute

device = device

frequency_penalty instance-attribute

frequency_penalty = UvaBackedTensor(
    max_num_reqs, dtype=float32
)

max_num_reqs instance-attribute

max_num_reqs = max_num_reqs

output_bin_counts instance-attribute

output_bin_counts = zeros(
    max_num_reqs, vocab_size, dtype=int32, device=device
)
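
To make the TODO in the source concrete, a rough footprint calculation (the sizes here are illustrative assumptions, not values from the source):

max_num_reqs, vocab_size = 1024, 151_936      # e.g., a Qwen-sized vocabulary
bytes_needed = max_num_reqs * vocab_size * 4  # int32 elements
print(f"{bytes_needed / 2**30:.2f} GiB")      # ~0.58 GiB, linear in both dimensions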

presence_penalty instance-attribute

presence_penalty = UvaBackedTensor(
    max_num_reqs, dtype=float32
)

prompt_bin_mask instance-attribute

prompt_bin_mask = zeros(
    max_num_reqs,
    cdiv(vocab_size, 32),
    dtype=int32,
    device=device,
)
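
The mask packs one bit per vocabulary token, 32 tokens per int32 word, which is why the second dimension is cdiv(vocab_size, 32). A sketch of reading one bit back (token_in_prompt is a hypothetical helper, not part of the module):

import torch

def token_in_prompt(prompt_bin_mask_row: torch.Tensor, token_id: int) -> bool:
    # Bit (token_id % 32) of word (token_id // 32), mirroring the packing
    # done by _bincount_kernel.
    word = int(prompt_bin_mask_row[token_id // 32])
    return (word >> (token_id % 32)) & 1 == 1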

repetition_penalty instance-attribute

repetition_penalty = UvaBackedTensor(
    max_num_reqs, dtype=float32
)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    max_num_reqs: int, vocab_size: int, device: device
)
Source code in vllm/v1/worker/gpu/sample/penalties.py
def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
    self.max_num_reqs = max_num_reqs
    self.vocab_size = vocab_size
    self.device = device

    self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
    self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
    self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)

    # Initialize repetition penalty manually because 0 is an invalid value for it.
    self.repetition_penalty.np.fill(1.0)
    self.repetition_penalty.copy_to_uva()

    # Statistics for penalties.
    self.prompt_bin_mask = torch.zeros(
        self.max_num_reqs,
        cdiv(self.vocab_size, 32),
        dtype=torch.int32,
        device=self.device,
    )
    # TODO(woosuk): This tensor is rarely used but can be very large, taking up
    # GBs of GPU memory. Optimize the memory usage.
    self.output_bin_counts = torch.zeros(
        self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
    )

    self._penalties_reqs: list[int] = []

add_request

add_request(
    req_idx: int, sampling_params: SamplingParams
) -> None
Source code in vllm/v1/worker/gpu/sample/penalties.py
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
    self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
    self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
    self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
    if use_penalty(sampling_params):
        self._penalties_reqs.append(req_idx)

apply_penalties_and_temperature

apply_penalties_and_temperature(
    logits: Tensor, idx_mapping: Tensor, temperature: Tensor
) -> None
Source code in vllm/v1/worker/gpu/sample/penalties.py
def apply_penalties_and_temperature(
    self,
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
) -> None:
    apply_penalties_and_temperature(
        logits,
        idx_mapping,
        temperature,
        self.repetition_penalty.gpu,
        self.frequency_penalty.gpu,
        self.presence_penalty.gpu,
        self.prompt_bin_mask,
        self.output_bin_counts,
    )

apply_staged_writes

apply_staged_writes(
    prefill_token_ids: Tensor,
    prefill_lens: ndarray,
    prompt_lens: ndarray,
) -> None
Source code in vllm/v1/worker/gpu/sample/penalties.py
def apply_staged_writes(
    self,
    prefill_token_ids: torch.Tensor,
    prefill_lens: np.ndarray,
    prompt_lens: np.ndarray,
) -> None:
    # TODO(woosuk): Optimize this.
    for req_idx in self._penalties_reqs:
        bincount(
            prefill_token_ids[req_idx],
            int(prefill_lens[req_idx]),
            int(prompt_lens[req_idx]),
            self.prompt_bin_mask[req_idx],
            self.output_bin_counts[req_idx],
        )
    self._penalties_reqs.clear()

    self.repetition_penalty.copy_to_uva()
    self.frequency_penalty.copy_to_uva()
    self.presence_penalty.copy_to_uva()

_bincount_kernel

_bincount_kernel(
    prefill_token_ids_ptr,
    prefill_len,
    prompt_len,
    prompt_bin_mask_ptr,
    output_bin_counts_ptr,
    BLOCK_SIZE: constexpr,
)
Source code in vllm/v1/worker/gpu/sample/penalties.py
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
def _bincount_kernel(
    prefill_token_ids_ptr,
    prefill_len,
    prompt_len,
    prompt_bin_mask_ptr,
    output_bin_counts_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    block_idx = tl.program_id(0)
    if block_idx * BLOCK_SIZE >= prefill_len:
        return

    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if block_idx * BLOCK_SIZE < prompt_len:
        mask = block < prompt_len
        prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
        idx = prefill_tokens // 32
        bit_idx = prefill_tokens % 32
        bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
        tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
    if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
        mask = block < prefill_len
        mask &= block >= prompt_len
        prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
        tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
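
For clarity, an unpacked PyTorch reference of the same statistics (a sketch, not the kernel's actual output format: the kernel stores the prompt mask bit-packed, 32 tokens per int32 word, via atomic_or):

import torch

def bincount_reference(prefill_token_ids: torch.Tensor, prefill_len: int,
                       prompt_len: int, vocab_size: int):
    # Prompt tokens [0, prompt_len) only need membership, not counts.
    prompt_mask = torch.zeros(vocab_size, dtype=torch.bool)
    prompt_mask[prefill_token_ids[:prompt_len].long()] = True
    # Output tokens [prompt_len, prefill_len) need exact counts for the
    # frequency penalty.
    output_counts = torch.bincount(
        prefill_token_ids[prompt_len:prefill_len].long(), minlength=vocab_size
    ).to(torch.int32)
    return prompt_mask, output_counts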

_penalties_and_temperature_kernel

_penalties_and_temperature_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    temperature_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: constexpr,
)
Source code in vllm/v1/worker/gpu/sample/penalties.py
@triton.jit
def _penalties_and_temperature_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    temperature_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
    pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
    temperature = tl.load(temperature_ptr + req_state_idx)
    temperature = tl.where(temperature == 0.0, 1.0, temperature)

    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
    use_temperature = temperature != 1.0
    if not (use_penalty or use_temperature):
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)

    if use_penalty:
        output_bin_counts = tl.load(
            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
            mask=mask,
        )
        output_bin_mask = output_bin_counts > 0

        # Apply repetition penalties.
        if use_rep_penalty:
            packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
            packed_mask = tl.load(
                prompt_bin_mask_ptr
                + req_state_idx * prompt_bin_mask_stride
                + packed_block,
                mask=packed_block < tl.cdiv(vocab_size, 32),
            )
            prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
            prompt_bin_mask = prompt_bin_mask.to(tl.int1)
            prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)

            # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
            scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
            # If logits are positive, divide by penalty, otherwise multiply by penalty.
            logits *= tl.where(logits > 0, 1.0 / scale, scale)

        # Apply frequency penalties.
        logits -= freq_penalty * output_bin_counts
        # Apply presence penalties.
        logits -= pres_penalty * output_bin_mask

    # Apply temperature.
    logits = logits / temperature

    # Store back to logits.
    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
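
The per-row math, restated as a PyTorch sketch (penalties_reference is hypothetical; prompt_mask is the unpacked boolean form of one prompt_bin_mask row):

import torch

def penalties_reference(logits_row: torch.Tensor, rep: float, freq: float,
                        pres: float, temp: float, prompt_mask: torch.Tensor,
                        output_counts: torch.Tensor) -> torch.Tensor:
    out_mask = output_counts > 0
    if rep != 1.0:
        # Penalize any token seen in the prompt or the output so far.
        scale = torch.where(prompt_mask | out_mask, rep, 1.0)
        # Positive logits are divided by the penalty, negative ones multiplied.
        logits_row = logits_row * torch.where(logits_row > 0, 1.0 / scale, scale)
    logits_row = logits_row - freq * output_counts     # frequency: per occurrence
    logits_row = logits_row - pres * out_mask.float()  # presence: once per token
    # temperature == 0.0 marks greedy sampling, so scaling is skipped.
    return logits_row / (temp if temp != 0.0 else 1.0)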

apply_penalties_and_temperature

apply_penalties_and_temperature(
    logits: Tensor,
    idx_mapping: Tensor,
    temperature: Tensor,
    repetition_penalty: Tensor,
    frequency_penalty: Tensor,
    presence_penalty: Tensor,
    prompt_bin_mask: Tensor,
    output_bin_counts: Tensor,
) -> None
Source code in vllm/v1/worker/gpu/sample/penalties.py
def apply_penalties_and_temperature(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
    repetition_penalty: torch.Tensor,
    frequency_penalty: torch.Tensor,
    presence_penalty: torch.Tensor,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
) -> None:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
        logits,
        logits.stride(0),
        idx_mapping,
        repetition_penalty,
        frequency_penalty,
        presence_penalty,
        temperature,
        prompt_bin_mask,
        prompt_bin_mask.stride(0),
        output_bin_counts,
        output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
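
An illustrative call showing the expected shapes (all tensors below are placeholders; idx_mapping maps each batch row to its persistent request slot):

import torch

num_reqs, max_num_reqs, vocab_size = 4, 256, 32_000
logits = torch.randn(num_reqs, vocab_size, device="cuda")
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device="cuda")
temperature = torch.full((max_num_reqs,), 0.8, device="cuda")
repetition_penalty = torch.ones(max_num_reqs, device="cuda")  # 1.0 == no-op
frequency_penalty = torch.zeros(max_num_reqs, device="cuda")  # 0.0 == no-op
presence_penalty = torch.zeros(max_num_reqs, device="cuda")   # 0.0 == no-op
prompt_bin_mask = torch.zeros(max_num_reqs, (vocab_size + 31) // 32,
                              dtype=torch.int32, device="cuda")
output_bin_counts = torch.zeros(max_num_reqs, vocab_size,
                                dtype=torch.int32, device="cuda")

apply_penalties_and_temperature(
    logits, idx_mapping, temperature,
    repetition_penalty, frequency_penalty, presence_penalty,
    prompt_bin_mask, output_bin_counts,
)
# logits is updated in place: all penalties are no-ops here, temperature 0.8 applies.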

bincount

bincount(
    prefill_token_ids: Tensor,
    prefill_len: int,
    prompt_len: int,
    prompt_bin_mask: Tensor,
    output_bin_counts: Tensor,
) -> None
Source code in vllm/v1/worker/gpu/sample/penalties.py
def bincount(
    prefill_token_ids: torch.Tensor,
    prefill_len: int,
    prompt_len: int,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
) -> None:
    prompt_bin_mask.zero_()
    output_bin_counts.zero_()
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
    _bincount_kernel[(num_blocks,)](
        prefill_token_ids,
        prefill_len,
        prompt_len,
        prompt_bin_mask,
        output_bin_counts,
        BLOCK_SIZE=BLOCK_SIZE,
    )
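
A single-request call with placeholder tensors (in the module, these buffers are rows of PenaltiesState's statistics tensors):

import torch

vocab_size = 32_000
prefill_token_ids = torch.randint(0, vocab_size, (128,), dtype=torch.int32, device="cuda")
prompt_bin_mask = torch.zeros((vocab_size + 31) // 32, dtype=torch.int32, device="cuda")
output_bin_counts = torch.zeros(vocab_size, dtype=torch.int32, device="cuda")
# First 100 tokens are the prompt; the remaining 28 are generated output.
bincount(prefill_token_ids, prefill_len=128, prompt_len=100,
         prompt_bin_mask=prompt_bin_mask, output_bin_counts=output_bin_counts)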

use_penalty

use_penalty(sampling_params: SamplingParams) -> bool
Source code in vllm/v1/worker/gpu/sample/penalties.py
def use_penalty(sampling_params: SamplingParams) -> bool:
    return (
        sampling_params.repetition_penalty != 1.0
        or sampling_params.frequency_penalty != 0.0
        or sampling_params.presence_penalty != 0.0
    )
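
With vLLM's SamplingParams defaults (repetition_penalty=1.0, frequency_penalty=0.0, presence_penalty=0.0), only explicitly penalized requests pay for the bincount pass:

from vllm import SamplingParams

use_penalty(SamplingParams())                        # False: all defaults are no-ops
use_penalty(SamplingParams(repetition_penalty=1.2))  # True
use_penalty(SamplingParams(presence_penalty=-0.5))   # True: any non-zero value counts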