from collections.abc import Iterable, Sequence

import torch

# NOTE: `UvaBuffer`, `UvaBufferPool`, `next_power_of_2`, and `_apply_write_kernel`
# are assumed to be defined elsewhere in this module.


class StagedWriteTensor:
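    """A tensor whose writes are staged on the CPU and applied in batches.

    Writes are recorded on the CPU via `stage_write()` / `stage_write_elem()`
    and flushed to the backing tensor (`self.gpu`) with a single
    `_apply_write_kernel` launch in `apply_write()`, using pooled UVA staging
    buffers for the transfer.

    Example (illustrative):
        t = StagedWriteTensor((num_rows, row_len), torch.int32, torch.device("cuda"))
        t.stage_write(index=0, start=3, x=[1, 2, 3])
        t.apply_write()  # all staged writes now land in t.gpu
    """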
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
device: torch.device,
max_concurrency: int = 2,
uva_instead_of_gpu: bool = False,
):
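        """
        Args:
            size: Shape of the backing tensor (an int for a 1-D tensor, or a
                sequence of ints).
            dtype: Element dtype; must be torch.int32, torch.int64, or
                torch.float32.
            device: Device for the backing GPU tensor (unused when
                uva_instead_of_gpu is True).
            max_concurrency: Forwarded to each UvaBufferPool; presumably the
                maximum number of staging buffers that can be in flight at once.
            uva_instead_of_gpu: If True, back the tensor with UVA (host) memory
                instead of GPU memory to save GPU memory.
        """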
supported_dtypes = [torch.int32, torch.int64, torch.float32]
if dtype not in supported_dtypes:
raise ValueError(
f"Unsupported dtype {dtype}: should be one of {supported_dtypes}"
)
self.num_rows = size if isinstance(size, int) else size[0]
self.dtype = dtype
self.max_concurrency = max_concurrency
if not uva_instead_of_gpu:
# Create a GPU tensor (default)
self.gpu = torch.zeros(size, dtype=dtype, device=device)
else:
            # For a large but infrequently accessed tensor, we can back it with
            # UVA (host) memory instead of GPU memory to save GPU memory
self._uva_buf = UvaBuffer(size, dtype)
self.gpu = self._uva_buf.uva
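        # CPU-side staging state: per-write row indices, start offsets, the
        # flattened write contents, and cumulative lengths (prefix sums) into
        # the contents list.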
self._staged_write_indices: list[int] = []
self._staged_write_starts: list[int] = []
self._staged_write_contents: list[int | float] = []
self._staged_write_cu_lens: list[int] = []
self.write_indices = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)
self.write_starts = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)
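        # Initial capacity for the flattened write contents; apply_write()
        # reallocates a larger buffer if the staged contents outgrow it.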
init_size = next_power_of_2(self.num_rows)
self.write_contents = UvaBufferPool(
init_size, dtype=dtype, max_concurrency=max_concurrency
)
self.write_cu_lens = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)

    def stage_write(
self,
index: int,
start: int,
x: Iterable[int] | Iterable[float],
) -> None:
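        """Stage a write of the values in `x` into the row at `index`, starting
        at element offset `start` within that row.

        The write is only buffered on the CPU; it takes effect on `self.gpu`
        once `apply_write()` is called. Passing an empty `x` is a no-op.
        """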
assert index >= 0
assert start >= 0
if not x:
return
self._staged_write_indices.append(index)
self._staged_write_starts.append(start)
self._staged_write_contents.extend(x)
self._staged_write_cu_lens.append(len(self._staged_write_contents))

    def stage_write_elem(self, index: int, x: int | float) -> None:
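        """Stage a write of the single element `x` into the row at `index`,
        at offset 0. Equivalent to `stage_write(index, 0, [x])`.
        """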
assert index >= 0
self._staged_write_indices.append(index)
self._staged_write_starts.append(0)
self._staged_write_contents.append(x)
self._staged_write_cu_lens.append(len(self._staged_write_contents))

    def apply_write(self) -> None:
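        """Flush all staged writes to `self.gpu` in a single kernel launch.

        The staged indices, starts, contents, and cumulative lengths are copied
        into UVA staging buffers (growing the contents buffer if necessary),
        and `_apply_write_kernel` then scatters the contents into the target
        rows. The staged state is cleared afterwards.
        """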
n = len(self._staged_write_indices)
if n == 0:
return
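        # Copy the CPU-side staged metadata into UVA staging buffers that the
        # kernel can read directly.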
indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)
# Special handling for write_contents
diff_len = len(self._staged_write_contents)
assert isinstance(self.write_contents.size, int)
if diff_len > self.write_contents.size:
# Re-allocate a larger buffer for the write_contents
new_size = next_power_of_2(diff_len)
self.write_contents = UvaBufferPool(
new_size, dtype=self.dtype, max_concurrency=self.max_concurrency
)
# NOTE(woosuk): Since the previous write_contents buffer is released,
# we perform a synchronization here to ensure that all data transfers
# involving the old buffer have finished before allocating a new one.
# This prevents potential race conditions. The slight overhead is
# negligible because the reallocations are infrequent in practice.
torch.cuda.synchronize()
contents_uva = self.write_contents.copy_to_uva(self._staged_write_contents)
# Write diffs to the GPU buffer
_apply_write_kernel[(n,)](
self.gpu,
self.gpu.stride(0),
indices_uva,
starts_uva,
contents_uva,
cu_lens_uva,
BLOCK_SIZE=1024,
)
# Clear the staged writes
self.clear_staged_writes()

    def clear_staged_writes(self) -> None:
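        """Discard all staged writes without applying them."""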
self._staged_write_indices.clear()
self._staged_write_starts.clear()
self._staged_write_contents.clear()
self._staged_write_cu_lens.clear()