def prepare_nvfp4_moe_layer_for_marlin(
layer: torch.nn.Module,
w13: torch.Tensor,
w13_scale: torch.Tensor,
w13_scale_2: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_scale_2: torch.Tensor,
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
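    """Repack NVFP4 MoE weights and their scales into Marlin format.

    Fallback path for GPUs without native FP4 compute: the weights stay
    FP4-compressed and are dequantized inside the Marlin GEMM kernel.
    Returns the repacked (w13, w13_scale, w13_scale_2, w2, w2_scale,
    w2_scale_2) tensors.
    """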
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Falling back to weight-only FP4 "
        "compression using the Marlin kernel. This may degrade performance "
        "for compute-heavy workloads."
    )
    input_dtype = get_marlin_input_dtype(prefix="")
    is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
    if is_a_8bit:
        raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
    GROUP_SIZE = 16  # NVFP4 quantizes weights in groups of 16 elements
    E = layer.num_experts
    K = layer.hidden_size
    N = layer.intermediate_size_per_partition
    device = w13.device
    param_dtype = layer.params_dtype
# WORKSPACE
layer.workspace = marlin_make_workspace_new(device, 4)
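    # An empty perm means no act-order permutation: NVFP4 weights are stored
    # in their natural order, so repacking uses the identity ordering.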
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT
# Repack weights to marlin format
def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
tensor_list = []
if "w13" in name:
size_n, size_k = N * 2, K
else:
size_n, size_k = K, N
assert weight.shape == (E, size_n, size_k // 2)
for i in range(E):
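            # Each uint8 packs two FP4 values, so viewing as int32 packs
            # eight values per element; after the transpose the layout is
            # (size_k // 8, size_n), as gptq_marlin_repack expects.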
qweight = weight[i].view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
tensor_list.append(marlin_qweight)
        return torch.stack(tensor_list)
w13 = repack_weight(w13, "w13")
w2 = repack_weight(w2, "w2")
# WEIGHT SCALES
# Permute scales
    def permute_scales(
scales: torch.Tensor, g_scales: torch.Tensor, name: str
) -> tuple[torch.Tensor, torch.Tensor]:
scales = scales.to(param_dtype)
g_scales = g_scales.to(param_dtype)
tensor_list = []
if "w13" in name:
size_n, size_k = N * 2, K
else:
size_n, size_k = K, N
for i in range(E):
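            # Scales are stored as (size_n, size_k // GROUP_SIZE); transpose
            # to the (size_k // GROUP_SIZE, size_n) layout expected by
            # marlin_permute_scales.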
scale = scales[i].T
marlin_scales = marlin_permute_scales(
s=scale,
size_k=size_k,
size_n=size_n,
group_size=GROUP_SIZE,
is_a_8bit=is_a_8bit,
)
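            # Post-process the permuted block scales into the format consumed
            # by the NVFP4 Marlin kernel.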
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
tensor_list.append(marlin_scales)
        scales = torch.stack(tensor_list)
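        # The per-tensor global scale likewise needs conversion to the
        # kernel's expected representation.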
g_scales = nvfp4_marlin_process_global_scale(g_scales)
return scales, g_scales
    w13_scale, w13_scale_2 = permute_scales(w13_scale, w13_scale_2, "w13")
    w2_scale, w2_scale_2 = permute_scales(w2_scale, w2_scale_2, "w2")
return w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2