class MRopeState:
    """Per-request storage for multimodal RoPE (M-RoPE) prefill positions.

    Two buffers are kept, both indexed by request slot:

    * ``prefill_mrope_positions`` -- one row per request holding
      ``3 * max_model_len`` int32 entries.  The positions of axis ``i``
      (``i`` in ``0..2``) are staged starting at column
      ``i * max_model_len`` of that row.
    * ``prefill_mrope_delta`` -- one int32 value per request.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        device: torch.device,
    ):
        """Allocate the position and delta buffers.

        Args:
            max_num_reqs: Number of request slots (rows) to allocate.
            max_model_len: Number of position entries reserved per axis
                for each request.
            device: Device handed to the staged-write position buffer.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.device = device

        # The three position axes of a request are stored back to back in
        # a single row, hence the factor of 3 on the second dimension.
        positions_shape = (max_num_reqs, 3 * max_model_len)
        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
        # wasting a lot of CPU memory.
        self.prefill_mrope_positions = StagedWriteTensor(
            positions_shape,
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
        self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)

    def init_prefill_mrope_positions(
        self,
        req_idx: int,
        mrope_model: SupportsMRoPE,
        prefill_token_ids: list[int],
        mm_features: list,
    ) -> None:
        """Compute and stage the M-RoPE positions of one request.

        The positions and delta are obtained from
        ``mrope_model.get_mrope_input_positions``.  The first three rows of
        the returned positions are staged into the request's row of
        ``prefill_mrope_positions``; the delta is written directly into the
        numpy view of ``prefill_mrope_delta``.

        Args:
            req_idx: Row (request slot) to populate.
            mrope_model: Model providing ``get_mrope_input_positions``.
            prefill_token_ids: Token ids forwarded to the model.
            mm_features: Multimodal features forwarded to the model.
        """
        positions, delta = mrope_model.get_mrope_input_positions(
            prefill_token_ids,
            mm_features,
        )
        for axis in range(3):
            # Second argument of stage_write is presumably the start column
            # within the row -- each axis gets its own max_model_len segment.
            start = axis * self.max_model_len
            self.prefill_mrope_positions.stage_write(
                req_idx, start, positions[axis].tolist()
            )
        self.prefill_mrope_delta.np[req_idx] = delta

    def apply_staged_writes(self) -> None:
        """Apply the staged position writes and copy the deltas to UVA."""
        self.prefill_mrope_positions.apply_write()
        self.prefill_mrope_delta.copy_to_uva()

    def prepare_mrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
        mrope_positions: torch.Tensor,
    ) -> None:
        """Launch the kernel that fills ``mrope_positions`` for a batch.

        ``_prepare_mrope_positions_kernel`` is launched over a 1-D grid with
        one entry per element along the first dimension of ``idx_mapping``.
        It receives the output tensor and its leading stride, the stored
        prefill positions and their leading stride, ``max_model_len``, the
        stored deltas, and the per-batch tensors passed in here.

        Args:
            idx_mapping: Tensor whose first dimension gives the number of
                requests in the batch (presumably maps batch index to
                request slot -- confirm against the kernel).
            query_start_loc: Forwarded to the kernel unchanged.
            prefill_lens: Forwarded to the kernel unchanged.
            num_computed_tokens: Forwarded to the kernel unchanged.
            mrope_positions: Output tensor; presumably written in place by
                the kernel, since nothing is returned here.
        """
        num_reqs = idx_mapping.shape[0]
        grid = (num_reqs,)
        stored_positions = self.prefill_mrope_positions.gpu
        _prepare_mrope_positions_kernel[grid](
            mrope_positions,
            mrope_positions.stride(0),
            stored_positions,
            stored_positions.stride(0),
            self.max_model_len,
            self.prefill_mrope_delta.gpu,
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
        )