import numpy as np
import torch
import triton

# NOTE: UvaBufferPool, cdiv, InputBatch, and _apply_grammar_bitmask_kernel
# are project-internal helpers imported from elsewhere in the original module.


class StructuredOutputsWorker:
    def __init__(
        self,
        max_num_logits: int,
        vocab_size: int,
    ):
        # NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
        # to avoid an unnecessary CPU-to-CPU copy.
        self.logits_indices = UvaBufferPool(max_num_logits, torch.int32)
        self.grammar_bitmask = UvaBufferPool(
            (max_num_logits, cdiv(vocab_size, 32)), torch.int32
        )

    def apply_grammar_bitmask(
        self,
        logits: torch.Tensor,
        input_batch: InputBatch,
        grammar_req_ids: list[str],
        grammar_bitmask: np.ndarray,
    ) -> None:
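        """Apply per-request grammar bitmasks to `logits` in place.

        For each request id in `grammar_req_ids`, the rows of `logits`
        belonging to that request are located via
        `input_batch.cu_num_logits_np`, and the Triton kernel masks out
        every vocab entry whose bit in `grammar_bitmask` is 0.
        """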
        if not grammar_req_ids:
            return

        # Construct bitmask -> logits mapping.
        mapping: list[int] = []
        req_ids = input_batch.req_ids
        cu_num_logits = input_batch.cu_num_logits_np.tolist()
        req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
        for grammar_req_id in grammar_req_ids:
            req_idx = req_id_to_idx[grammar_req_id]
            logits_start_idx = cu_num_logits[req_idx]
            logits_end_idx = cu_num_logits[req_idx + 1]
            mapping.extend(range(logits_start_idx, logits_end_idx))
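        # Example: with cu_num_logits = [0, 1, 3] and grammar_req_ids naming
        # the second request, mapping becomes [1, 2] (its two logits rows).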

        # Copy the bitmask -> logits row mapping to a UVA (GPU-visible) buffer.
        mapping_np = np.array(mapping, dtype=np.int32)
        logits_indices = self.logits_indices.copy_to_uva(mapping_np)

        # Copy the bitmask to a UVA buffer.
        bitmask = self.grammar_bitmask.copy_to_uva(grammar_bitmask)
        num_masks = bitmask.shape[0]
        assert num_masks == len(mapping)

        vocab_size = logits.shape[-1]
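        # One Triton program per (bitmask row, vocab tile of BLOCK_SIZE tokens).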
        BLOCK_SIZE = 8192
        grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
        _apply_grammar_bitmask_kernel[grid](
            logits,
            logits.stride(0),
            logits_indices,
            bitmask,
            bitmask.stride(0),
            vocab_size,
            BLOCK_SIZE=BLOCK_SIZE,
        )
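

# The real _apply_grammar_bitmask_kernel is defined elsewhere in the codebase.
# Below is a minimal, hypothetical sketch of the masking logic such a kernel
# typically implements, assuming the standard packed-bitmask convention:
# bit (v % 32) of int32 word (v // 32) decides whether vocab token v is
# allowed (1) or masked to -inf (0). The `_sketch` name and parameter comments
# are illustrative, not the actual kernel signature.
import triton.language as tl


@triton.jit
def _apply_grammar_bitmask_kernel_sketch(
    logits_ptr,          # [num_rows, vocab_size] logits, modified in place
    logits_stride,       # elements between consecutive logits rows
    logits_indices_ptr,  # [num_masks] row in `logits` for each bitmask row
    bitmask_ptr,         # [num_masks, cdiv(vocab_size, 32)] packed int32 bits
    bitmask_stride,      # elements between consecutive bitmask rows
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    mask_id = tl.program_id(0)
    block_id = tl.program_id(1)
    row = tl.load(logits_indices_ptr + mask_id)

    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < vocab_size

    # Fetch the int32 word holding each token's bit, then test that bit.
    words = tl.load(
        bitmask_ptr + mask_id * bitmask_stride + offsets // 32,
        mask=in_bounds,
        other=0,
    )
    allowed = ((words >> (offsets % 32)) & 1) != 0

    vals = tl.load(logits_ptr + row * logits_stride + offsets, mask=in_bounds)
    vals = tl.where(allowed, vals, float("-inf"))
    tl.store(logits_ptr + row * logits_stride + offsets, vals, mask=in_bounds)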