# Assumed imports for this excerpt; `UvaBackedTensor`, `SamplingParams`,
# `NO_LOGPROBS`, `_NP_INT64_MIN`, and `_NP_INT64_MAX` are assumed to be
# defined or imported elsewhere in the same module.
import numpy as np
import torch


class SamplingStates:
    """Per-request sampling parameters, staged in CPU-side numpy buffers and
    mirrored into UVA-backed tensors via `copy_to_uva()`."""

    def __init__(self, max_num_reqs: int, vocab_size: int):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size

        # CPU-staged parameter buffers mirrored into UVA tensors.
        self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
        self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)

        # Initialize top_k and top_p manually because 0 is an invalid value
        # for them.
        self.top_k.np.fill(self.vocab_size)
        self.top_k.copy_to_uva()
        self.top_p.np.fill(1.0)
        self.top_p.copy_to_uva()

        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
        # NO_LOGPROBS (-1) means no logprobs are requested.
        self.num_logprobs.fill(NO_LOGPROBS)

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
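        """Stage the sampling parameters of the request assigned to slot
        `req_idx`. Values are written to the CPU buffers only; call
        `apply_staged_writes()` to flush them to the UVA tensors."""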
        self.temperature.np[req_idx] = sampling_params.temperature
        self.top_p.np[req_idx] = sampling_params.top_p
        # Treat any top_k outside (0, vocab_size) as "no top-k filtering".
        if 0 < sampling_params.top_k < self.vocab_size:
            top_k = sampling_params.top_k
        else:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k
        self.min_p.np[req_idx] = sampling_params.min_p
        # Requests without an explicit seed get a random one.
        if sampling_params.seed is not None:
            seed = sampling_params.seed
        else:
            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
        self.seeds.np[req_idx] = seed
        if sampling_params.logprobs is not None:
            num_logprobs = sampling_params.logprobs
        else:
            num_logprobs = NO_LOGPROBS
        self.num_logprobs[req_idx] = num_logprobs

    def apply_staged_writes(self) -> None:
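        """Flush all CPU-staged parameter writes to their UVA tensors."""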
        self.temperature.copy_to_uva()
        self.top_p.copy_to_uva()
        self.top_k.copy_to_uva()
        self.min_p.copy_to_uva()
        self.seeds.copy_to_uva()

    def do_min_p(self, idx_mapping_np: np.ndarray) -> bool:
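        """Return True if any request selected by `idx_mapping_np` uses min-p."""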
        return bool(np.any(self.min_p.np[idx_mapping_np] != 0.0))

    def do_top_k(self, idx_mapping_np: np.ndarray) -> bool:
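        """Return True if any request selected by `idx_mapping_np` uses top-k."""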
        return bool(np.any(self.top_k.np[idx_mapping_np] != self.vocab_size))

    def do_top_p(self, idx_mapping_np: np.ndarray) -> bool:
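        """Return True if any request selected by `idx_mapping_np` uses top-p."""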
        return bool(np.any(self.top_p.np[idx_mapping_np] != 1.0))

    def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
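        """Return the largest number of logprobs requested by the selected
        requests, or NO_LOGPROBS (-1) if none of them requested logprobs."""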
        return int(np.max(self.num_logprobs[idx_mapping_np]))
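

# Illustrative usage sketch, not part of the original class. The function name
# and its arguments are hypothetical; it only shows the intended call sequence
# for one scheduling step: stage each request's parameters, flush them to UVA
# once, then use the cheap CPU-side checks to decide which sampling steps can
# be skipped.
def _prepare_sampling_example(
    states: SamplingStates,
    requests: list[tuple[int, SamplingParams]],
    idx_mapping_np: np.ndarray,
) -> None:
    for req_idx, params in requests:
        states.add_request(req_idx, params)
    # One batched CPU -> UVA flush per step instead of one copy per request.
    states.apply_staged_writes()
    # Skip top-k / top-p / min-p work entirely when no selected request needs it.
    if states.do_top_k(idx_mapping_np):
        pass  # run top-k filtering
    if states.do_top_p(idx_mapping_np):
        pass  # run top-p filtering
    if states.do_min_p(idx_mapping_np):
        pass  # run min-p filtering
    _ = states.max_num_logprobs(idx_mapping_np)  # -1 when no logprobs requested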