@triton.jit
def kernel_paged_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
scale, # float32
k_scale, # float32
v_scale, # float32
    out_scale_inv,  # pointer to float32 inverse output scale (used when USE_FP8)
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
PHYSICAL_BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
x: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_4: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool
    USE_FP8: tl.constexpr,  # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
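    # Grid: one program per (sequence, kv head). Each program handles a single
    # query token and every query head mapped to this kv head (GQA grouping).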
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
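    # When filtering by query length, sequences with more than one new query
    # token (prefill) return early; this kernel only handles one-token decode.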
if filter_by_query_len:
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
if cur_batch_query_len > 1:
return
else:
cur_batch_in_all_start_index = seq_idx
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
0, num_queries_per_kv_padded
)
query_offset = (
cur_batch_in_all_start_index * query_stride_0
+ query_head_idx[:, None] * query_stride_1
)
head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
head_mask = head_mask & (query_head_idx < num_query_heads)
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1)
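    # head_mask keeps only the real query heads in this kv head's group;
    # dim_mask masks out the padded tail of the head dimension.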
# Q : (num_queries_per_kv, HEAD_SIZE,)
Q = tl.load(
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
mask=dim_mask[None, :] & head_mask[:, None],
other=0.0,
)
block_table_offset = seq_idx * block_table_stride
if not USE_SINKS:
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
L = tl.zeros([num_queries_per_kv_padded], dtype=tl.float32)
else:
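        # Attention sinks: seed the running max with the per-head sink logit
        # and count it as one softmax term (L = 1 where the sink is finite),
        # so it enters the normalizer without adding to the accumulator.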
M = tl.load(
sink_ptr + query_head_idx,
mask=head_mask,
other=float("-inf"),
).to(dtype=tl.float32)
L = tl.where(float("-inf") < M, 1.0, 0.0)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)
# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
# alibi slope for this head
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(
alibi_slopes_ptr + query_head_idx, mask=head_mask, other=0.0
)
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
offs_n = tl.arange(0, BLOCK_SIZE)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
# iterate through tiles
for j in range(0, num_blocks):
start_n = j * BLOCK_SIZE
        # Map each absolute token index in this tile to its logical block and
        # its offset within that block. This supports physical block sizes
        # that differ from the tile size (e.g. 544 in
        # Qwen/Qwen3-Next-80B-A3B-Thinking) and a non-contiguous mapping from
        # logical blocks to physical blocks.
abs_token_idx = start_n + offs_n
l_block_idx = abs_token_idx // PHYSICAL_BLOCK_SIZE
# Vectorized loading of physical block IDs
p_block_idx = tl.load(block_tables_ptr + block_table_offset + l_block_idx)
internal_offsets = abs_token_idx % PHYSICAL_BLOCK_SIZE
        # 5-D addressing of K: the head dim is split into head_size // x
        # groups of x contiguous elements (see the key_cache layout above).
k_offset = (
p_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_1
+ (offs_d[:, None] // x) * stride_k_cache_2
+ internal_offsets[None, :] * stride_k_cache_3
+ (offs_d[:, None] % x) * stride_k_cache_4
)
        # 4-D addressing of V (the block slot is the innermost dimension)
v_offset = (
p_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_d[None, :] * stride_v_cache_2
+ internal_offsets[:, None] * stride_v_cache_3
)
# K : (HEAD_SIZE, BLOCK_SIZE)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None],
other=0.0,
eviction_policy="evict_last",
)
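        # If the KV cache is stored in FP8, dequantize with the per-tensor
        # k_scale and cast back to the query dtype before the dot product.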
if K_load.dtype.is_fp8():
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
else:
K = K_load
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :],
other=0.0,
eviction_policy="evict_last",
)
if V_load.dtype.is_fp8():
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
seq_mask = seq_offset[None, :] < boundary
# First calculate the dot, then apply the mask.
qk = scale * tl.dot(Q, K)
S = tl.where(head_mask[:, None] & seq_mask, qk, float("-inf"))
context_len = seq_len - 1
if SLIDING_WINDOW > 0:
S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, -10000)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
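        # Online (flash-attention style) softmax update: rescale the running
        # accumulator by alpha = exp(M_old - M_new), then fold in this tile.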
# compute running maximum
# m_j : (num_queries_per_kv,)
m_j = tl.maximum(M, tl.max(S, axis=1))
# P : (num_queries_per_kv, BLOCK_SIZE,)
p = tl.exp(S - m_j[:, None])
p = tl.where(m_j[:, None] == float("-inf"), 0.0, p)
# l_j : (num_queries_per_kv,)
l_j = tl.sum(p, axis=1)
# alpha : (num_queries_per_kv, )
alpha = tl.exp(M - m_j)
alpha = tl.where(float("-inf") == M, 0.0, alpha)
        # acc : (num_queries_per_kv, HEAD_SIZE)
acc = acc * alpha[:, None]
# update constants
L = L * alpha + l_j
M = m_j
        # acc : (num_queries_per_kv, HEAD_SIZE)
acc += tl.dot(p.to(V.dtype), V)
# epilogue
acc = acc / (L[:, None] + 1e-10)
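    # For FP8 output, scale by the inverse output scale and clamp to the FP8
    # representable range before the store below casts to the output dtype.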
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (
cur_batch_in_all_start_index * output_stride_0
+ query_head_idx * output_stride_1
)
tl.store(
output_ptr + output_offset[:, None] + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
acc,
mask=dim_mask[None, :] & head_mask[:, None],
)
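

# A minimal, self-contained NumPy sketch (not used by the kernel above) of the
# online-softmax recurrence the main loop implements: a running max M,
# normalizer L, and accumulator acc updated tile by tile reproduce a direct
# softmax-weighted sum over all keys. Names, shapes, and the tile size here
# are illustrative only.
if __name__ == "__main__":
    import numpy as np

    def online_softmax_attention(q, k, v, tile=16):
        M, L = -np.inf, 0.0
        acc = np.zeros(v.shape[1])
        for start in range(0, k.shape[0], tile):
            s = q @ k[start:start + tile].T            # scores for this tile
            m_j = max(M, s.max())                      # new running max
            p = np.exp(s - m_j)                        # unnormalized tile probs
            alpha = np.exp(M - m_j) if M > -np.inf else 0.0
            acc = acc * alpha + p @ v[start:start + tile]
            L = L * alpha + p.sum()
            M = m_j
        return acc / L

    rng = np.random.default_rng(0)
    q = rng.standard_normal(64)
    k = rng.standard_normal((100, 64))
    v = rng.standard_normal((100, 64))
    scores = q @ k.T
    weights = np.exp(scores - scores.max())
    reference = (weights / weights.sum()) @ v
    assert np.allclose(online_softmax_attention(q, k, v), reference)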