class Sampler:
    def __init__(
        self,
        max_num_reqs: int,
        vocab_size: int,
        device: torch.device,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.
        self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
        self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
        self.logit_bias_state = LogitBiasState(max_num_reqs, device)
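
    # Expected call pattern per scheduling step (as implied by the methods
    # below): add_request() stages sampling state for newly scheduled requests,
    # apply_staged_writes() commits the staged updates to the persistent
    # per-request buffers, and __call__() turns the model's logits into sampled
    # token ids (plus optional logprobs).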
    def add_request(
        self,
        req_idx: int,
        prompt_len: int,
        sampling_params: SamplingParams,
    ) -> None:
        self.sampling_states.add_request(req_idx, sampling_params)
        self.penalties_state.add_request(req_idx, sampling_params)
        self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        self.sampling_states.apply_staged_writes()
        self.penalties_state.apply_staged_writes(
            prefill_token_ids, prefill_lens, prompt_lens
        )
        self.logit_bias_state.apply_staged_writes()
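
    # __call__ consumes one row of logits per request in the batch. idx_mapping
    # (and its NumPy mirror idx_mapping_np) maps each logits row to its request
    # slot in the persistent state, and pos carries the token positions used by
    # position-dependent processors (e.g., min_tokens) and seeded sampling.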
    def __call__(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> SamplerOutput:
        # NOTE(woosuk): num_nans is intentionally computed before sampling so
        # that it reflects the raw logits, before penalties and temperature are
        # applied.
        num_nans = get_num_nans(logits) if self.compute_nans else None
        sampled, processed_logits = self.sample(
            logits, idx_mapping, idx_mapping_np, pos
        )
        max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
        if max_num_logprobs != NO_LOGPROBS:
            logits = (
                processed_logits
                if self.logprobs_mode == "processed_logprobs"
                else logits
            )
            logprobs_tensors = compute_topk_logprobs(logits, max_num_logprobs, sampled)
        else:
            logprobs_tensors = None
        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to a 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
        )
        return sampler_output
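
    # sample() applies the per-request logit processors in a fixed order
    # (logit bias, penalties + temperature, min_p, then top_k/top_p) and
    # finally draws one token per row via the Gumbel-based sampling kernel
    # (gumbel_sample). It returns both the sampled token ids and the processed
    # logits so that "processed_logprobs" mode can reuse them.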
    def sample(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Copy logits to a new FP32 tensor.
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
        # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)
        # Apply penalties and temperature in place.
        self.penalties_state.apply_penalties_and_temperature(
            logits, idx_mapping, self.sampling_states.temperature.gpu
        )
        # Apply min_p in place if any request has a non-zero min_p.
        do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
        if do_min_p:
            apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
        # Apply top_k and/or top_p. This might return a new tensor.
        do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
        top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
        do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
        top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
        if do_top_k or do_top_p:
            logits = apply_top_k_top_p(logits, top_k, top_p)
        # Sample the next token.
        sampled = gumbel_sample(
            logits,
            idx_mapping,
            self.sampling_states.temperature.gpu,
            self.sampling_states.seeds.gpu,
            pos,
            apply_temperature=False,
        )
        return sampled, logits
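

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the Sampler API above).
# It assumes toy shapes: a single request at slot 0, a tiny vocabulary, an
# identity idx_mapping, pos set to the prompt length, and default
# SamplingParams. In the engine these inputs come from the scheduler's batch.
# ---------------------------------------------------------------------------
def _example_usage() -> SamplerOutput:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab_size = 8
    prompt_len = 4
    sampler = Sampler(max_num_reqs=1, vocab_size=vocab_size, device=device)

    # Stage one request at slot 0 with default sampling parameters.
    sampler.add_request(
        req_idx=0, prompt_len=prompt_len, sampling_params=SamplingParams()
    )

    # Commit the staged per-request state using toy prefill/prompt metadata.
    prefill_token_ids = torch.randint(
        0, vocab_size, (1, prompt_len), dtype=torch.int64, device=device
    )
    sampler.apply_staged_writes(
        prefill_token_ids=prefill_token_ids,
        prefill_lens=np.array([prompt_len], dtype=np.int32),
        prompt_lens=np.array([prompt_len], dtype=np.int32),
    )

    # One row of logits per running request; row 0 maps to request slot 0.
    logits = torch.randn(1, vocab_size, device=device)
    idx_mapping_np = np.arange(1, dtype=np.int32)
    idx_mapping = torch.arange(1, dtype=torch.int64, device=device)
    pos = torch.tensor([prompt_len], dtype=torch.int64, device=device)

    output = sampler(logits, idx_mapping, idx_mapping_np, pos)
    # output.sampled_token_ids has shape [1, 1]: one generated token.
    return output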