def apply_bad_words_with_drafts(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> None:
    """Apply per-request bad-word rules to every draft row of ``logits``.

    Rows are laid out request by request. Request ``k`` owns the contiguous
    row range ``[offset_k, offset_k + num_draft_tokens[k])``, where
    ``offset_k`` is the sum of the draft counts of all earlier requests.
    The same row index selects the matching history in ``past_tokens_ids``.

    Args:
        logits: Indexed by draft row; each ``logits[row]`` is handed to
            ``_apply_bad_words_single_batch``. Presumably the helper mutates
            that row in place (e.g. masks banned tokens) — confirm against
            the helper.
        bad_words_token_ids: Maps a request index (position in
            ``num_draft_tokens``) to that request's bad-word token sequences.
            Requests without an entry are skipped.
        past_tokens_ids: Per-row token history, indexed like ``logits``.
        num_draft_tokens: Number of draft rows owned by each request, in
            request order.

    Returns:
        None. All work is done through ``_apply_bad_words_single_batch``.
    """
    offset = 0
    # Stop walking requests once every entry in the mapping has been handled.
    pending = len(bad_words_token_ids)
    for req_idx, num_drafts in enumerate(num_draft_tokens):
        row_begin = offset
        offset += num_drafts
        req_bad_words = bad_words_token_ids.get(req_idx)
        if req_bad_words is None:
            continue
        for row in range(row_begin, offset):
            _apply_bad_words_single_batch(
                logits[row],
                req_bad_words,
                past_tokens_ids[row],
            )
        pending -= 1
        if pending == 0:
            break