class PenaltiesState:
    """Per-request state for sampling penalties.

    The three penalty scalars (repetition / frequency / presence) are staged
    in UVA-backed host buffers and mirrored to the GPU, while the token
    statistics they operate on (a prompt-token bitmask and output-token
    counts) live directly in device memory.
    """

    def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.device = device
        # Penalty scalars, one slot per request; writes are staged on the
        # host side and flushed via copy_to_uva().
        self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        # Repetition penalty must start at 1.0 (identity): unlike the other
        # two penalties, a zero value is invalid for it.
        self.repetition_penalty.np.fill(1.0)
        self.repetition_penalty.copy_to_uva()
        # Token statistics consumed by the penalty kernel: one bit per vocab
        # token packed into int32 words for prompt membership.
        self.prompt_bin_mask = torch.zeros(
            self.max_num_reqs,
            cdiv(self.vocab_size, 32),
            dtype=torch.int32,
            device=self.device,
        )
        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
        # GBs of GPU memory. Optimize the memory usage.
        self.output_bin_counts = torch.zeros(
            self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
        )
        # Request slots whose token statistics still need (re)computation.
        self._penalties_reqs: list[int] = []

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the penalty parameters for the request in slot ``req_idx``.

        The values are only written to the host buffers here; they reach the
        GPU on the next call to ``apply_staged_writes``.
        """
        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
        # Only requests that actually use a penalty need their token
        # statistics computed.
        if use_penalty(sampling_params):
            self._penalties_reqs.append(req_idx)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        """Flush staged penalty values to UVA and refresh token statistics.

        For every pending request, recompute its prompt bitmask and output
        token counts from the prefill tokens, then mirror all three penalty
        buffers to the GPU.
        """
        # TODO(woosuk): Optimize this.
        for idx in self._penalties_reqs:
            bincount(
                prefill_token_ids[idx],
                int(prefill_lens[idx]),
                int(prompt_lens[idx]),
                self.prompt_bin_mask[idx],
                self.output_bin_counts[idx],
            )
        self._penalties_reqs.clear()
        for staged in (
            self.repetition_penalty,
            self.frequency_penalty,
            self.presence_penalty,
        ):
            staged.copy_to_uva()

    def apply_penalties_and_temperature(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        temperature: torch.Tensor,
    ) -> None:
        """Apply penalties and temperature to ``logits`` in place.

        Delegates to the module-level ``apply_penalties_and_temperature``
        kernel wrapper (the bare name here resolves to the global function,
        not this method), passing the GPU mirrors of the penalty buffers and
        the device-resident token statistics.
        """
        apply_penalties_and_temperature(
            logits,
            idx_mapping,
            temperature,
            self.repetition_penalty.gpu,
            self.frequency_penalty.gpu,
            self.presence_penalty.gpu,
            self.prompt_bin_mask,
            self.output_bin_counts,
        )