vllm.v1.worker.gpu_input_batch

CachedRequestState dataclass

Source code in vllm/v1/worker/gpu_input_batch.py
@dataclass
class CachedRequestState:
    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # for pooling models
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    def __post_init__(self):
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )

        if self.pooling_params is not None:
            self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        if idx < self.num_prompt_tokens:
            if self.prompt_token_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown."
                )
            return self.prompt_token_ids[idx]
        if idx - self.num_prompt_tokens < len(self.output_token_ids):
            return self.output_token_ids[idx - self.num_prompt_tokens]
        return -1
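
The following is a minimal usage sketch (illustrative, not taken from the vLLM docs or tests): it builds a CachedRequestState for a plain text request with made-up token IDs, default SamplingParams, and a single empty KV-cache block group, then reads the counters derived in __post_init__.

from vllm import SamplingParams
from vllm.v1.worker.gpu_input_batch import CachedRequestState

# Illustrative values: token IDs and block layout are placeholders.
state = CachedRequestState(
    req_id="req-0",
    prompt_token_ids=[101, 7592, 2088],  # three prompt tokens
    mm_features=[],                      # no multimodal inputs
    sampling_params=SamplingParams(),
    generator=None,
    block_ids=([],),                     # one KV-cache group, no blocks yet
    num_computed_tokens=0,
    output_token_ids=[],
)

assert state.num_prompt_tokens == 3      # set in __post_init__
assert state.num_tokens == 3             # prompt tokens + generated tokens
state.output_token_ids.append(2023)
assert state.num_tokens == 4
assert state.get_token_id(0) == 101      # prompt position
assert state.get_token_id(3) == 2023     # first generated token
assert state.get_token_id(10) == -1      # out of range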

block_ids instance-attribute

block_ids: tuple[list[int], ...]

generator instance-attribute

generator: Generator | None

lora_request class-attribute instance-attribute

lora_request: LoRARequest | None = None

mm_features instance-attribute

mm_features: list[MultiModalFeatureSpec]

mrope_position_delta class-attribute instance-attribute

mrope_position_delta: int | None = None

mrope_positions class-attribute instance-attribute

mrope_positions: Tensor | None = None

num_computed_tokens instance-attribute

num_computed_tokens: int

num_tokens property

num_tokens: int

output_token_ids instance-attribute

output_token_ids: list[int]

pooling_params class-attribute instance-attribute

pooling_params: PoolingParams | None = None

pooling_states class-attribute instance-attribute

pooling_states: PoolingStates | None = None

prev_num_draft_len class-attribute instance-attribute

prev_num_draft_len: int = 0

prompt_embeds class-attribute instance-attribute

prompt_embeds: Tensor | None = None

prompt_token_ids instance-attribute

prompt_token_ids: list[int] | None

req_id instance-attribute

req_id: str

sampling_params instance-attribute

sampling_params: SamplingParams | None

xdrope_positions class-attribute instance-attribute

xdrope_positions: Tensor | None = None

__init__

__init__(
    req_id: str,
    prompt_token_ids: list[int] | None,
    mm_features: list[MultiModalFeatureSpec],
    sampling_params: SamplingParams | None,
    generator: Generator | None,
    block_ids: tuple[list[int], ...],
    num_computed_tokens: int,
    output_token_ids: list[int],
    mrope_positions: Tensor | None = None,
    mrope_position_delta: int | None = None,
    xdrope_positions: Tensor | None = None,
    lora_request: LoRARequest | None = None,
    prompt_embeds: Tensor | None = None,
    prev_num_draft_len: int = 0,
    pooling_params: PoolingParams | None = None,
    pooling_states: PoolingStates | None = None,
) -> None

__post_init__

__post_init__()
Source code in vllm/v1/worker/gpu_input_batch.py
def __post_init__(self):
    self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
        self.prompt_token_ids, self.prompt_embeds
    )

    if self.pooling_params is not None:
        self.pooling_states = PoolingStates()
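
A short sketch of the pooling branch (illustrative; assumes PoolingParams() default construction): when pooling_params is supplied, __post_init__ attaches a fresh PoolingStates container to the request.

from vllm.pooling_params import PoolingParams
from vllm.v1.worker.gpu_input_batch import CachedRequestState

pooled = CachedRequestState(
    req_id="pool-0",
    prompt_token_ids=[1, 2, 3],          # placeholder token IDs
    mm_features=[],
    sampling_params=None,                # pooling requests carry no sampling params
    generator=None,
    block_ids=([],),
    num_computed_tokens=0,
    output_token_ids=[],
    pooling_params=PoolingParams(),
)
assert pooled.pooling_states is not None  # created by __post_init__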

get_token_id

get_token_id(idx: int) -> int
Source code in vllm/v1/worker/gpu_input_batch.py
def get_token_id(self, idx: int) -> int:
    if idx < self.num_prompt_tokens:
        if self.prompt_token_ids is None:
            raise ValueError(
                f"Tried to access token index {idx}, but that token was "
                "provided via prompt_embeds, and its ID is unknown."
            )
        return self.prompt_token_ids[idx]
    if idx - self.num_prompt_tokens < len(self.output_token_ids):
        return self.output_token_ids[idx - self.num_prompt_tokens]
    return -1
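
A second sketch (illustrative values only) for the prompt_embeds case: the prompt-token IDs are unknown, so get_token_id() raises ValueError for prompt positions, returns generated tokens by offset, and returns -1 for out-of-range indices.

import torch

from vllm.v1.worker.gpu_input_batch import CachedRequestState

embeds_req = CachedRequestState(
    req_id="embeds-0",
    prompt_token_ids=None,               # prompt supplied as embeddings instead
    mm_features=[],
    sampling_params=None,
    generator=None,
    block_ids=([],),
    num_computed_tokens=0,
    output_token_ids=[42],
    prompt_embeds=torch.zeros(4, 8),     # 4 prompt positions, hidden size 8
)

assert embeds_req.num_prompt_tokens == 4
assert embeds_req.get_token_id(4) == 42   # first generated token
assert embeds_req.get_token_id(10) == -1  # past the generated tokens
try:
    embeds_req.get_token_id(0)            # prompt position: ID is unknown
except ValueError:
    pass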

InputBatch

Source code in vllm/v1/worker/gpu_input_batch.py
class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
        cp_kv_cache_interleave_size: int = 1,
    ):
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids_tensor = torch.zeros(
            (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()
        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            num_speculative_tokens=num_speculative_tokens,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related.
        self.temperature = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device=device
        )
        self.temperature_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
        self.top_p_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
        self.top_k_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.presence_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None

    @property
    def req_ids(self) -> list[str]:
        # None elements should only be present transiently
        # while performing state updates to the batch.
        return cast(list[str], self._req_ids)

    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Track add-request operations for logits processors.
        Not applicable to pooling models.
        """

        # Fill the next empty index if there is one.
        if (new_req_index := self.batch_update_builder.pop_removed()) is None:
            # Append to end otherwise.
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
        self.batch_update_builder.batch_changed = True
        if request.sampling_params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            self.batch_update_builder.added.append(
                (
                    new_req_index,
                    request.sampling_params,
                    request.prompt_token_ids,
                    request.output_token_ids,
                )
            )

        return new_req_index

    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index].clear()

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        req_id = request.req_id
        req_index = self.req_id_to_index[req_id]
        cur_spec_token_ids = self.spec_token_ids[req_index]
        # When speculative decoding is used with structured output,
        # the scheduler can drop draft tokens that do not
        # conform to the schema. This can result in
        # scheduler_output.scheduled_spec_decode_tokens being empty,
        # even when speculative decoding is enabled.
        cur_spec_token_ids.clear()
        spec_token_ids = scheduled_spec_tokens.get(req_id, ())
        num_spec_tokens = len(spec_token_ids)
        request.prev_num_draft_len = num_spec_tokens
        if not spec_token_ids:
            return

        # For async scheduling, token_ids_cpu assigned from
        # spec_token_ids are placeholders and will be overwritten in
        # _prepare_input_ids.
        start_index = self.num_tokens_no_spec[req_index]
        end_token_index = start_index + num_spec_tokens
        self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
        cur_spec_token_ids.extend(spec_token_ids)

    def remove_request(self, req_id: str) -> int | None:
        """This method must always be followed by a call to condense().

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """

        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index].clear()

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            self.pooling_states.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
        if self.prev_req_id_to_index is not None:
            self.prev_req_id_to_index.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            (
                self.allowed_token_ids_mask_cpu_tensor[i1],
                self.allowed_token_ids_mask_cpu_tensor[i2],
            ) = (
                self.allowed_token_ids_mask_cpu_tensor[i2],
                self.allowed_token_ids_mask_cpu_tensor[i1],
            )

    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens_no_spec[last_req_index] + len(
                self.spec_token_ids[last_req_index]
            )

            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()

            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]

    def refresh_metadata(self):
        """Apply any batch updates to sampling metadata."""

        if self.is_pooling_model:
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata()

    def _make_sampling_metadata(self) -> SamplingMetadata:
        num_reqs = self.num_reqs
        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs
            )
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(
                self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
            )
            copy_slice(
                self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
            )
            copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
            )

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not self.no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

    def get_pooling_params(self) -> list[PoolingParams]:
        assert len(self.req_ids) == len(self.pooling_params)
        return [self.pooling_params[req_id] for req_id in self.req_ids]

    def get_pooling_states(self) -> list[PoolingStates]:
        assert len(self.req_ids) == len(self.pooling_states)
        return [self.pooling_states[req_id] for req_id in self.req_ids]

    def get_pooling_metadata(self) -> PoolingMetadata:
        pooling_params = self.get_pooling_params()
        pooling_states = self.get_pooling_states()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
            pooling_states=pooling_states,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)

    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where, prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """

        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
        async_copy_ready_event: torch.Event,
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.tolist()
            # Replace placeholder token id(s) with actual sampled id(s).
            new_ids: list[int] = sampled_token_ids[prev_index]
            if not new_ids:
                continue
            num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
            # Also account for case where there may be a smaller number of
            # output placeholders (tokens can be discarded after a kv-load failure).
            first_placeholder = req_output_token_ids.index(-1)
            num_placeholders = len(req_output_token_ids) - first_placeholder
            num_to_replace = min(num_sampled_ids, num_placeholders)
            del new_ids[num_to_replace:]
            end_index = first_placeholder + num_to_replace
            req_output_token_ids[first_placeholder:end_index] = new_ids

    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        In async scheduling case, update spec_token_ids in sampling metadata with
        real draft token ids from prior step. This is called right before they are
        needed by the rejection sampler for penalty/bad_words computation.
        """
        if not draft_token_ids or not self.prev_req_id_to_index:
            return

        if (spec_token_ids := self.sampling_metadata.spec_token_ids) is not None:
            for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
                if spec_ids:
                    prev_index = self.prev_req_id_to_index.get(req_id)
                    if prev_index is not None:
                        draft_ids = draft_token_ids[prev_index]
                        if draft_ids:
                            del draft_ids[len(spec_ids) :]
                            spec_ids.clear()
                            spec_ids.extend(draft_ids)

    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        return (
            len(self.presence_penalties_reqs) == 0
            and len(self.frequency_penalties_reqs) == 0
            and len(self.repetition_penalties_reqs) == 0
        )

    @property
    def max_num_logprobs(self) -> int | None:
        return max(self.num_logprobs.values()) if self.num_logprobs else None

    @property
    def no_allowed_token_ids(self) -> bool:
        return len(self.has_allowed_token_ids) == 0
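
Below is a rough batch-lifecycle sketch (assumptions: CPU device, a single KV-cache group with block size 16, and constructor arguments matching the signature shown in the source above; exact names can shift between vLLM versions). It adds one greedy request, refreshes the sampling metadata, then removes it and condenses the batch.

import torch

from vllm import SamplingParams
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

batch = InputBatch(
    max_num_reqs=8,
    max_model_len=512,
    max_num_batched_tokens=2048,
    device=torch.device("cpu"),
    pin_memory=False,
    vocab_size=32000,
    block_sizes=[16],                    # one KV-cache group
    kernel_block_sizes=[16],
)

req = CachedRequestState(
    req_id="req-0",
    prompt_token_ids=[1, 2, 3],          # placeholder token IDs
    mm_features=[],
    sampling_params=SamplingParams(temperature=0.0),  # greedy
    generator=None,
    block_ids=([0],),                    # one block in the single group
    num_computed_tokens=0,
    output_token_ids=[],
)

idx = batch.add_request(req)
batch.refresh_metadata()                 # rebuild sampling metadata after the change
assert batch.num_reqs == 1 and batch.req_ids[idx] == "req-0"
assert batch.sampling_metadata.all_greedy

batch.remove_request("req-0")
batch.condense()                         # must always follow remove_request()
batch.refresh_metadata()
assert batch.num_reqs == 0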

_req_ids instance-attribute

_req_ids: list[str | None] = []

all_greedy property

all_greedy: bool

all_random property

all_random: bool

allowed_token_ids_mask instance-attribute

allowed_token_ids_mask: Tensor | None = None

allowed_token_ids_mask_cpu_tensor instance-attribute

allowed_token_ids_mask_cpu_tensor: Tensor | None = None

async_copy_ready_event instance-attribute

async_copy_ready_event: Event | None = None

bad_words_token_ids instance-attribute

bad_words_token_ids: dict[int, list[list[int]]] = {}

batch_update_builder instance-attribute

batch_update_builder = BatchUpdateBuilder()

block_table instance-attribute

block_table = MultiGroupBlockTable(
    max_num_reqs=max_num_reqs,
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    pin_memory=pin_memory,
    device=device,
    block_sizes=block_sizes,
    kernel_block_sizes=kernel_block_sizes,
    num_speculative_tokens=num_speculative_tokens,
    cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)

device instance-attribute

device = device

frequency_penalties instance-attribute

frequency_penalties = empty(
    (max_num_reqs,), dtype=float, device=device
)

frequency_penalties_cpu instance-attribute

frequency_penalties_cpu = numpy()

frequency_penalties_cpu_tensor instance-attribute

frequency_penalties_cpu_tensor = empty(
    (max_num_reqs,),
    dtype=float,
    device="cpu",
    pin_memory=pin_memory,
)

frequency_penalties_reqs instance-attribute

frequency_penalties_reqs: set[str] = set()

generators instance-attribute

generators: dict[int, Generator] = {}

greedy_reqs instance-attribute

greedy_reqs: set[str] = set()

has_allowed_token_ids instance-attribute

has_allowed_token_ids: set[str] = set()

in_progress_prompt_logprobs_cpu instance-attribute

in_progress_prompt_logprobs_cpu: dict[
    str, LogprobsTensors
] = {}

is_pooling_model instance-attribute

is_pooling_model = is_pooling_model

is_spec_decode instance-attribute

is_spec_decode = is_spec_decode

is_token_ids instance-attribute

is_token_ids = numpy()

is_token_ids_tensor instance-attribute

is_token_ids_tensor = zeros(
    (max_num_reqs, max_model_len),
    device="cpu",
    dtype=bool,
    pin_memory=False,
)

logits_processing_needs_token_ids instance-attribute

logits_processing_needs_token_ids = zeros(
    max_num_reqs, dtype=bool
)

logitsprocs instance-attribute

logitsprocs = logitsprocs or LogitsProcessors()

logitsprocs_need_output_token_ids instance-attribute

logitsprocs_need_output_token_ids = (
    logitsprocs_need_output_token_ids
)

lora_id_to_lora_request instance-attribute

lora_id_to_lora_request: dict[int, LoRARequest] = {}

lora_id_to_request_ids instance-attribute

lora_id_to_request_ids: dict[int, set[str]] = {}

max_model_len instance-attribute

max_model_len = max_model_len

max_num_batched_tokens instance-attribute

max_num_batched_tokens = max_num_batched_tokens

max_num_logprobs property

max_num_logprobs: int | None

max_num_reqs instance-attribute

max_num_reqs = max_num_reqs

no_allowed_token_ids property

no_allowed_token_ids: bool

no_penalties property

no_penalties: bool

no_top_k property

no_top_k: bool

no_top_p property

no_top_p: bool

num_accepted_tokens_cpu instance-attribute

num_accepted_tokens_cpu = numpy()

num_accepted_tokens_cpu_tensor instance-attribute

num_accepted_tokens_cpu_tensor = ones(
    (max_num_reqs,),
    dtype=int64,
    device="cpu",
    pin_memory=pin_memory,
)

num_computed_tokens_cpu instance-attribute

num_computed_tokens_cpu = numpy()

num_computed_tokens_cpu_tensor instance-attribute

num_computed_tokens_cpu_tensor = zeros(
    (max_num_reqs,),
    device="cpu",
    dtype=int32,
    pin_memory=pin_memory,
)

num_logprobs instance-attribute

num_logprobs: dict[str, int] = {}

num_prompt_tokens instance-attribute

num_prompt_tokens = zeros(max_num_reqs, dtype=int32)

num_reqs property

num_reqs: int

num_tokens_no_spec instance-attribute

num_tokens_no_spec = zeros(max_num_reqs, dtype=int32)

pin_memory instance-attribute

pin_memory = pin_memory

pooling_params instance-attribute

pooling_params: dict[str, PoolingParams] = {}

pooling_states instance-attribute

pooling_states: dict[str, PoolingStates] = {}

presence_penalties instance-attribute

presence_penalties = empty(
    (max_num_reqs,), dtype=float, device=device
)

presence_penalties_cpu instance-attribute

presence_penalties_cpu = numpy()

presence_penalties_cpu_tensor instance-attribute

presence_penalties_cpu_tensor = empty(
    (max_num_reqs,),
    dtype=float,
    device="cpu",
    pin_memory=pin_memory,
)

presence_penalties_reqs instance-attribute

presence_penalties_reqs: set[str] = set()

prev_req_id_to_index instance-attribute

prev_req_id_to_index: dict[str, int] | None = None

prev_sampled_token_ids instance-attribute

prev_sampled_token_ids: Tensor | None = None

random_reqs instance-attribute

random_reqs: set[str] = set()

repetition_penalties instance-attribute

repetition_penalties = empty(
    (max_num_reqs,), dtype=float, device=device
)

repetition_penalties_cpu instance-attribute

repetition_penalties_cpu = numpy()

repetition_penalties_cpu_tensor instance-attribute

repetition_penalties_cpu_tensor = empty(
    (max_num_reqs,),
    dtype=float,
    device="cpu",
    pin_memory=pin_memory,
)

repetition_penalties_reqs instance-attribute

repetition_penalties_reqs: set[str] = set()

req_id_to_index instance-attribute

req_id_to_index: dict[str, int] = {}

req_ids property

req_ids: list[str]

req_output_token_ids instance-attribute

req_output_token_ids: list[list[int] | None] = []

req_prompt_embeds instance-attribute

req_prompt_embeds: dict[int, Tensor] = {}

request_lora_mapping instance-attribute

request_lora_mapping = zeros((max_num_reqs,), dtype=int64)

sampled_token_ids_cpu instance-attribute

sampled_token_ids_cpu: Tensor | None = None

sampling_metadata instance-attribute

sampling_metadata = _make_sampling_metadata()

spec_token_ids instance-attribute

spec_token_ids: list[list[int]] = [
    [] for _ in (range(max_num_reqs))
]

temperature instance-attribute

temperature = empty(
    (max_num_reqs,), dtype=float32, device=device
)

temperature_cpu instance-attribute

temperature_cpu = numpy()

temperature_cpu_tensor instance-attribute

temperature_cpu_tensor = empty(
    (max_num_reqs,),
    dtype=float32,
    device="cpu",
    pin_memory=pin_memory,
)

token_ids_cpu instance-attribute

token_ids_cpu = numpy()

token_ids_cpu_tensor instance-attribute

token_ids_cpu_tensor = zeros(
    (max_num_reqs, max_model_len),
    device="cpu",
    dtype=int32,
    pin_memory=False,
)

top_k instance-attribute

top_k = empty((max_num_reqs,), dtype=int32, device=device)

top_k_cpu instance-attribute

top_k_cpu = numpy()

top_k_cpu_tensor instance-attribute

top_k_cpu_tensor = empty(
    (max_num_reqs,),
    dtype=int32,
    device="cpu",
    pin_memory=pin_memory,
)

top_k_reqs instance-attribute

top_k_reqs: set[str] = set()

top_p instance-attribute

top_p = empty((max_num_reqs,), dtype=float32, device=device)

top_p_cpu instance-attribute

top_p_cpu = top_p_cpu_tensor.numpy()

top_p_cpu_tensor instance-attribute

top_p_cpu_tensor = empty(
    (max_num_reqs,),
    dtype=float32,
    device="cpu",
    pin_memory=pin_memory,
)

top_p_reqs instance-attribute

top_p_reqs: set[str] = set()

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    max_num_reqs: int,
    max_model_len: int,
    max_num_batched_tokens: int,
    device: device,
    pin_memory: bool,
    vocab_size: int,
    block_sizes: list[int],
    kernel_block_sizes: list[int],
    logitsprocs: LogitsProcessors | None = None,
    logitsprocs_need_output_token_ids: bool = False,
    is_spec_decode: bool = False,
    is_pooling_model: bool = False,
    num_speculative_tokens: int = 0,
    cp_kv_cache_interleave_size: int = 1,
)
Source code in vllm/v1/worker/gpu_input_batch.py
def __init__(
    self,
    max_num_reqs: int,
    max_model_len: int,
    max_num_batched_tokens: int,
    device: torch.device,
    pin_memory: bool,
    vocab_size: int,
    block_sizes: list[int],  # The block_size of each kv cache group
    kernel_block_sizes: list[int],
    logitsprocs: LogitsProcessors | None = None,
    logitsprocs_need_output_token_ids: bool = False,
    is_spec_decode: bool = False,
    is_pooling_model: bool = False,
    num_speculative_tokens: int = 0,
    cp_kv_cache_interleave_size: int = 1,
):
    self.is_pooling_model = is_pooling_model
    self.is_spec_decode = is_spec_decode
    self.max_num_reqs = max_num_reqs
    self.max_model_len = max_model_len
    self.max_num_batched_tokens = max_num_batched_tokens
    self.device = device
    self.pin_memory = pin_memory
    self.vocab_size = vocab_size

    self._req_ids: list[str | None] = []
    self.req_id_to_index: dict[str, int] = {}

    # TODO(woosuk): This buffer could be too large if max_model_len is big.
    # Find a way to reduce the CPU memory usage.
    # This buffer is not directly transferred to the GPU, so it does not
    # need to be pinned.
    self.token_ids_cpu_tensor = torch.zeros(
        (max_num_reqs, max_model_len),
        device="cpu",
        dtype=torch.int32,
        pin_memory=False,
    )
    self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
    self.is_token_ids_tensor = torch.zeros(
        (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
    )
    self.is_token_ids = self.is_token_ids_tensor.numpy()
    # Store prompt embeddings per request to avoid OOM from large upfront
    # allocation if max_model_len is big.
    # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
    self.req_prompt_embeds: dict[int, torch.Tensor] = {}
    self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
    self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
    self.num_computed_tokens_cpu_tensor = torch.zeros(
        (max_num_reqs,),
        device="cpu",
        dtype=torch.int32,
        pin_memory=pin_memory,
    )
    self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

    # Block table.
    self.block_table = MultiGroupBlockTable(
        max_num_reqs=max_num_reqs,
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        pin_memory=pin_memory,
        device=device,
        block_sizes=block_sizes,
        kernel_block_sizes=kernel_block_sizes,
        num_speculative_tokens=num_speculative_tokens,
        cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
    )

    # Sampling-related.
    self.temperature = torch.empty(
        (max_num_reqs,), dtype=torch.float32, device=device
    )
    self.temperature_cpu_tensor = torch.empty(
        (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
    )
    self.temperature_cpu = self.temperature_cpu_tensor.numpy()
    self.greedy_reqs: set[str] = set()
    self.random_reqs: set[str] = set()

    self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
    self.top_p_cpu_tensor = torch.empty(
        (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
    )
    self.top_p_cpu = self.top_p_cpu_tensor.numpy()
    self.top_p_reqs: set[str] = set()

    self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
    self.top_k_cpu_tensor = torch.empty(
        (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
    )
    self.top_k_cpu = self.top_k_cpu_tensor.numpy()
    self.top_k_reqs: set[str] = set()

    # Frequency penalty related data structures
    self.frequency_penalties = torch.empty(
        (max_num_reqs,), dtype=torch.float, device=device
    )
    self.frequency_penalties_cpu_tensor = torch.empty(
        (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
    )
    self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
    self.frequency_penalties_reqs: set[str] = set()

    # Presence penalty related data structures
    self.presence_penalties = torch.empty(
        (max_num_reqs,), dtype=torch.float, device=device
    )
    self.presence_penalties_cpu_tensor = torch.empty(
        (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
    )
    self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
    self.presence_penalties_reqs: set[str] = set()

    # Repetition penalty related data structures
    self.repetition_penalties = torch.empty(
        (max_num_reqs,), dtype=torch.float, device=device
    )
    self.repetition_penalties_cpu_tensor = torch.empty(
        (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
    )
    self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
    self.repetition_penalties_reqs: set[str] = set()

    # Speculative decoding
    self.num_accepted_tokens_cpu_tensor = torch.ones(
        (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
    )
    self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

    # lora related
    self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
    self.lora_id_to_request_ids: dict[int, set[str]] = {}
    self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

    # req_index -> generator
    # NOTE(woosuk): The indices of the requests that do not have their own
    # generator should not be included in the dictionary.
    self.generators: dict[int, torch.Generator] = {}

    self.num_logprobs: dict[str, int] = {}

    # To accumulate prompt logprobs tensor chunks across prefill steps.
    self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

    # Internal representation of per-step batch state changes, used for
    # reordering persistent batch and generating logitsprocs batch state
    # updates. Should reset each step.
    self.batch_update_builder = BatchUpdateBuilder()

    # TODO convert this to LogitsProcessor
    self.has_allowed_token_ids: set[str] = set()
    # NOTE(lufang): In the mask tensor, the value is False if the
    # corresponding token is allowed, since we use masked_fill_ to set -inf.
    self.allowed_token_ids_mask: torch.Tensor | None = None
    self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

    # req_index -> bad_words_token_ids
    self.bad_words_token_ids: dict[int, list[list[int]]] = {}

    self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

    self.req_output_token_ids: list[list[int] | None] = []

    # Store provided logitsprocs. If none are provided, initialize empty
    # data structure
    self.logitsprocs = logitsprocs or LogitsProcessors()
    self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

    # Store last speculative tokens for sampler.
    self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

    # This is updated each time the batch constituents change.
    self.sampling_metadata = self._make_sampling_metadata()

    # for pooling models
    self.pooling_params: dict[str, PoolingParams] = {}
    self.pooling_states: dict[str, PoolingStates] = {}

    # Cached reference to the GPU tensor of previously sampled tokens
    self.prev_sampled_token_ids: torch.Tensor | None = None
    self.prev_req_id_to_index: dict[str, int] | None = None
    # These are used to update output_token_ids with real sampled
    # ids from prior step, if required by current sampling params
    # (e.g. penalties).
    self.sampled_token_ids_cpu: torch.Tensor | None = None
    self.async_copy_ready_event: torch.Event | None = None
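
Throughout the constructor, each sampling parameter is stored three times: a GPU tensor, a CPU staging tensor (pinned when possible), and a NumPy view that shares the CPU tensor's memory. The snippet below is a minimal standalone sketch of that pattern, not the vLLM API; the CUDA-availability guard around pinning is an assumption for machines without a GPU.

import torch

max_num_reqs = 4
# Pinning host memory is only useful (and may only be allowed) with CUDA.
pin = torch.cuda.is_available()

# CPU staging tensor plus a zero-copy NumPy view of the same storage.
temperature_cpu_tensor = torch.empty(
    (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin
)
temperature_cpu = temperature_cpu_tensor.numpy()

# Writes through the NumPy view are visible to the tensor...
temperature_cpu[0] = 0.5
assert temperature_cpu_tensor[0].item() == 0.5

# ...so the host buffer can later be copied (asynchronously, if pinned)
# into the matching device tensor.
if torch.cuda.is_available():
    temperature = torch.empty(
        (max_num_reqs,), dtype=torch.float32, device="cuda"
    )
    temperature[:1].copy_(temperature_cpu_tensor[:1], non_blocking=True)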

_make_prompt_token_ids_tensor

_make_prompt_token_ids_tensor() -> Tensor
Source code in vllm/v1/worker/gpu_input_batch.py
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
    num_reqs = self.num_reqs
    max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
    prompt_token_ids_cpu_tensor = torch.empty(
        (self.num_reqs, max_prompt_len),
        device="cpu",
        dtype=torch.int64,
        pin_memory=self.pin_memory,
    )
    prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
    prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
    # Use the value of vocab_size as a pad since we don't have a
    # token_id of this value.
    for i in range(num_reqs):
        prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
    return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
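
A toy illustration (made-up values) of the padding convention above: valid token ids lie in [0, vocab_size), so the value vocab_size itself can never collide with a real token and is safe to use as right-padding.

import torch

vocab_size = 10
num_prompt_tokens = [2, 4, 1]                 # per-request prompt lengths
token_ids = torch.tensor([[3, 7, 0, 0],
                          [1, 2, 3, 4],
                          [9, 0, 0, 0]], dtype=torch.int64)

max_prompt_len = max(num_prompt_tokens)
prompt_token_ids = token_ids[:, :max_prompt_len].clone()
for i, n in enumerate(num_prompt_tokens):
    prompt_token_ids[i, n:] = vocab_size      # pad with an out-of-vocab id

assert prompt_token_ids.tolist() == [[3, 7, 10, 10],
                                     [1, 2, 3, 4],
                                     [9, 10, 10, 10]]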

_make_sampling_metadata

_make_sampling_metadata() -> SamplingMetadata
Source code in vllm/v1/worker/gpu_input_batch.py
def _make_sampling_metadata(self) -> SamplingMetadata:
    num_reqs = self.num_reqs
    if not self.all_greedy:
        temperature = copy_slice(
            self.temperature_cpu_tensor, self.temperature, num_reqs
        )
    else:
        temperature = None
    if not self.no_top_p:
        copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
    if not self.no_top_k:
        copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

    if not self.no_penalties:
        # Since syncing these tensors is expensive, only copy them
        # if necessary, i.e. if there are requests which require
        # penalties to be applied during sampling.
        copy_slice(
            self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
        )
        copy_slice(
            self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
        )
        copy_slice(
            self.repetition_penalties_cpu_tensor,
            self.repetition_penalties,
            num_reqs,
        )

    needs_prompt_token_ids = (
        not self.no_penalties
        or self.logits_processing_needs_token_ids[:num_reqs].any()
    )
    # The prompt tokens are used only for applying penalties or
    # step pooling during the sampling/pooling process.
    # Hence copy these tensors only when there are requests which
    # need penalties/step_pooler to be applied.
    prompt_token_ids = (
        self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
    )

    # Only set output_token_ids if required by the current requests'
    # sampling parameters.
    needs_output_token_ids = (
        not self.no_penalties
        or bool(self.bad_words_token_ids)
        or self.logitsprocs_need_output_token_ids
    )
    output_token_ids = (
        cast(list[list[int]], self.req_output_token_ids)
        if needs_output_token_ids
        else []
    )

    allowed_token_ids_mask: torch.Tensor | None = None
    if not self.no_allowed_token_ids:
        assert self.allowed_token_ids_mask is not None
        copy_slice(
            self.allowed_token_ids_mask_cpu_tensor,
            self.allowed_token_ids_mask,
            num_reqs,
        )
        allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=self.all_greedy,
        all_random=self.all_random,
        top_p=None if self.no_top_p else self.top_p[:num_reqs],
        top_k=None if self.no_top_k else self.top_k[:num_reqs],
        generators=self.generators,
        max_num_logprobs=self.max_num_logprobs,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=self.frequency_penalties[:num_reqs],
        presence_penalties=self.presence_penalties[:num_reqs],
        repetition_penalties=self.repetition_penalties[:num_reqs],
        output_token_ids=output_token_ids,
        spec_token_ids=self.spec_token_ids,
        no_penalties=self.no_penalties,
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=self.bad_words_token_ids,
        logitsprocs=self.logitsprocs,
    )
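
copy_slice is a helper imported from elsewhere in vLLM. The sketch below is only an assumption about its shape, based on how it is used here: it copies the first num_reqs entries of a CPU staging tensor into the matching device tensor and hands back the device slice. The real helper's signature and semantics may differ.

import torch

def copy_slice_sketch(
    cpu_tensor: torch.Tensor, out_tensor: torch.Tensor, length: int
) -> torch.Tensor:
    # Copy only the first `length` entries host -> device and return that
    # slice, so callers never read stale rows beyond the current batch size.
    out_tensor[:length].copy_(cpu_tensor[:length], non_blocking=True)
    return out_tensor[:length]

This is also why every tensor handed to SamplingMetadata above is sliced to [:num_reqs]: only the rows belonging to active requests are ever synchronized or read.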

_register_add_request

_register_add_request(request: CachedRequestState) -> int

Track add-request operations for logits processors. Not applicable to pooling models.

Source code in vllm/v1/worker/gpu_input_batch.py
def _register_add_request(self, request: "CachedRequestState") -> int:
    """Track add-request operations for logits processors.
    Not applicable to pooling models.
    """

    # Fill the next empty index if there is one.
    if (new_req_index := self.batch_update_builder.pop_removed()) is None:
        # Append to end otherwise.
        new_req_index = self.num_reqs

    assert new_req_index < self.max_num_reqs
    self.batch_update_builder.batch_changed = True
    if request.sampling_params:
        # Detailed added request metadata is only required for non-pooling
        # models, to support logitsprocs.
        self.batch_update_builder.added.append(
            (
                new_req_index,
                request.sampling_params,
                request.prompt_token_ids,
                request.output_token_ids,
            )
        )

    return new_req_index
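
A standalone toy of the slot-reuse policy above (SlotAllocator is hypothetical, not the real BatchUpdateBuilder): freed indices are recycled first, lowest index first; otherwise the new request is appended at the end of the batch.

import heapq

class SlotAllocator:
    """Toy illustration of the slot-reuse policy: freed indices are
    recycled (lowest first); otherwise a new index is appended."""

    def __init__(self) -> None:
        self.free: list[int] = []   # min-heap of removed indices
        self.size = 0               # one past the highest index ever used

    def remove(self, idx: int) -> None:
        heapq.heappush(self.free, idx)

    def add(self) -> int:
        if self.free:
            return heapq.heappop(self.free)
        idx = self.size
        self.size += 1
        return idx

alloc = SlotAllocator()
assert [alloc.add() for _ in range(3)] == [0, 1, 2]
alloc.remove(1)
assert alloc.add() == 1   # reuse the hole
assert alloc.add() == 3   # otherwise append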

add_request

add_request(request: CachedRequestState) -> int
Source code in vllm/v1/worker/gpu_input_batch.py
def add_request(
    self,
    request: "CachedRequestState",
) -> int:
    req_index = self._register_add_request(request)

    req_id = request.req_id
    if req_index == len(self._req_ids):
        self._req_ids.append(req_id)
        self.req_output_token_ids.append(request.output_token_ids)
        self.spec_token_ids.append([])
    else:
        self._req_ids[req_index] = req_id
        self.req_output_token_ids[req_index] = request.output_token_ids
        self.spec_token_ids[req_index].clear()

    self.req_id_to_index[req_id] = req_index

    # Copy the prompt token ids and output token ids.
    num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
        request.prompt_token_ids, request.prompt_embeds
    )
    self.num_prompt_tokens[req_index] = num_prompt_tokens
    start_idx = num_prompt_tokens
    end_idx = start_idx + len(request.output_token_ids)
    if request.prompt_token_ids is not None:
        self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
        self.is_token_ids[req_index, :num_prompt_tokens] = True
    else:
        self.is_token_ids[req_index, :num_prompt_tokens] = False
    if request.prompt_embeds is not None:
        self.req_prompt_embeds[req_index] = request.prompt_embeds
    self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
    self.is_token_ids[req_index, start_idx:end_idx] = True
    # Number of tokens without spec decode tokens.
    self.num_tokens_no_spec[req_index] = request.num_tokens

    self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
    self.block_table.add_row(request.block_ids, req_index)

    if sampling_params := request.sampling_params:
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Should avoid division by zero later when apply_temperature.
            self.temperature_cpu[req_index] = 0.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k
        self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[req_index] = (
            sampling_params.repetition_penalty
        )
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = (
                self.vocab_size
                if sampling_params.logprobs == -1
                else sampling_params.logprobs
            )

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device=self.device,
                )
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu",
                )
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids
            ] = False

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[req_index] = (
                sampling_params.bad_words_token_ids
            )
    elif pooling_params := request.pooling_params:
        pooling_states = request.pooling_states
        assert pooling_states is not None

        self.pooling_params[req_id] = pooling_params
        self.pooling_states[req_id] = pooling_states
        self.logits_processing_needs_token_ids[req_index] = (
            pooling_params.requires_token_ids
        )
    else:
        raise NotImplementedError("Unrecognized request type")

    # Speculative decoding: by default 1 token is generated.
    self.num_accepted_tokens_cpu[req_index] = 1

    # Add request lora ID
    if request.lora_request:
        lora_id = request.lora_request.lora_int_id
        if lora_id not in self.lora_id_to_request_ids:
            self.lora_id_to_request_ids[lora_id] = set()

        self.request_lora_mapping[req_index] = lora_id
        self.lora_id_to_request_ids[lora_id].add(request.req_id)
        self.lora_id_to_lora_request[lora_id] = request.lora_request
    else:
        # No LoRA
        self.request_lora_mapping[req_index] = 0

    return req_index
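
A toy illustration of the top_k normalization above (effective_top_k is a hypothetical helper name): a non-positive or over-large top_k is stored as vocab_size, which effectively disables top-k truncation.

vocab_size = 32000

def effective_top_k(top_k: int) -> int:
    # Mirrors the `0 < top_k < self.vocab_size` check in add_request().
    return top_k if 0 < top_k < vocab_size else vocab_size

assert effective_top_k(-1) == vocab_size        # disabled
assert effective_top_k(50) == 50                # active top-k
assert effective_top_k(10**9) == vocab_size     # larger than vocab: disabled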

condense

condense() -> None

Slide non-empty requests down into lower, empty indices.

Any consecutive empty indices at the very end of the list are not filled.

Returns:

    swaps: list of (from, to) swap tuples for moved requests
    empty_req_indices: indices not filled by condensation

Source code in vllm/v1/worker/gpu_input_batch.py
def condense(self) -> None:
    """Slide non-empty requests down into lower, empty indices.

    Any consecutive empty indices at the very end of the list are not
    filled.

    Returns:
      swaps: list of (from,to) swap tuples for moved requests
      empty_req_indices: indices not filled by condensation
    """
    num_reqs = self.num_reqs

    if not (empty_req_indices := self.batch_update_builder.removed):
        # All removed requests were replaced by added requests, or else no
        # requests were removed at all. No condense() needed
        return
    if num_reqs == 0:
        # The batched states are empty.
        self._req_ids.clear()
        self.req_output_token_ids.clear()
        self.spec_token_ids.clear()
        return

    # NOTE(woosuk): This function assumes that empty_req_indices is
    # sorted in descending order.
    last_req_index = num_reqs + len(empty_req_indices) - 1
    while empty_req_indices:
        # Find the largest non-empty index.
        while last_req_index in empty_req_indices:
            last_req_index -= 1

        # Find the smallest empty index.
        empty_index = self.batch_update_builder.peek_removed()
        assert empty_index is not None
        if empty_index >= last_req_index:
            break

        # Move active request down into empty request
        # index.
        self.batch_update_builder.pop_removed()
        req_id = self._req_ids[last_req_index]
        output_token_ids = self.req_output_token_ids[last_req_index]
        assert req_id is not None
        self._req_ids[empty_index] = req_id
        self._req_ids[last_req_index] = None
        self.req_output_token_ids[empty_index] = output_token_ids
        self.req_output_token_ids[last_req_index] = None
        self.req_id_to_index[req_id] = empty_index

        num_tokens = self.num_tokens_no_spec[last_req_index] + len(
            self.spec_token_ids[last_req_index]
        )

        (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
            self.spec_token_ids[empty_index],
            self.spec_token_ids[last_req_index],
        )
        self.spec_token_ids[last_req_index].clear()

        self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
            last_req_index, :num_tokens
        ]
        self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
            last_req_index, :num_tokens
        ]
        if last_req_index in self.req_prompt_embeds:
            self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                last_req_index
            )
        self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
            last_req_index
        ]
        self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
        self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
            last_req_index
        ]
        self.block_table.move_row(last_req_index, empty_index)

        self.request_lora_mapping[empty_index] = self.request_lora_mapping[
            last_req_index
        ]

        if self.is_pooling_model:
            last_req_index -= 1
            # Sampling state not used by pooling models.
            continue

        # Autoregressive models require detailed tracking of condense
        # operations to support logitsprocs
        self.batch_update_builder.moved.append(
            (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
        )

        self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
        self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
        self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
        self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
            last_req_index
        ]
        self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
            last_req_index
        ]
        self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
            last_req_index
        ]
        self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
            last_req_index
        ]
        generator = self.generators.pop(last_req_index, None)
        if generator is not None:
            self.generators[empty_index] = generator

        # TODO convert these to LogitsProcessors
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                self.allowed_token_ids_mask_cpu_tensor[last_req_index]
            )

        bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
        if bad_words_token_ids is not None:
            self.bad_words_token_ids[empty_index] = bad_words_token_ids

        # Decrement last_req_index since it is now empty.
        last_req_index -= 1

    # Trim lists to the batch size.
    del self._req_ids[num_reqs:]
    del self.req_output_token_ids[num_reqs:]
    del self.spec_token_ids[num_reqs:]
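
A standalone sketch of the compaction idea behind condense() (the real method also moves all of the per-request tensors, block-table rows, and sampling state): the highest-indexed live entries slide down into the lowest holes, and trailing holes are trimmed. The compact helper below is purely illustrative.

def compact(slots: list, holes: list[int]) -> list:
    holes = sorted(holes)                 # smallest empty index first
    last = len(slots) - 1
    num_live = len(slots) - len(holes)
    for empty in holes:
        while last >= 0 and slots[last] is None:
            last -= 1                     # find the largest non-empty index
        if empty >= last:
            break                         # remaining holes are already at the tail
        slots[empty], slots[last] = slots[last], None
        last -= 1
    return slots[:num_live]               # trim trailing holes

assert compact(["a", None, "c", None, "e"], [1, 3]) == ["a", "e", "c"]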

get_pooling_metadata

get_pooling_metadata() -> PoolingMetadata
Source code in vllm/v1/worker/gpu_input_batch.py
def get_pooling_metadata(self) -> PoolingMetadata:
    pooling_params = self.get_pooling_params()
    pooling_states = self.get_pooling_states()

    return PoolingMetadata(
        prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
        prompt_token_ids=self.sampling_metadata.prompt_token_ids,
        pooling_params=pooling_params,
        pooling_states=pooling_states,
    )

get_pooling_params

get_pooling_params() -> list[PoolingParams]
Source code in vllm/v1/worker/gpu_input_batch.py
def get_pooling_params(self) -> list[PoolingParams]:
    assert len(self.req_ids) == len(self.pooling_params)
    return [self.pooling_params[req_id] for req_id in self.req_ids]

get_pooling_states

get_pooling_states() -> list[PoolingStates]
Source code in vllm/v1/worker/gpu_input_batch.py
def get_pooling_states(self) -> list[PoolingStates]:
    assert len(self.req_ids) == len(self.pooling_states)
    return [self.pooling_states[req_id] for req_id in self.req_ids]

make_lora_inputs

make_lora_inputs(
    num_scheduled_tokens: ndarray,
    num_sampled_tokens: ndarray,
) -> tuple[
    tuple[int, ...], tuple[int, ...], set[LoRARequest]
]

Given the num_scheduled_tokens for each request in the batch, return data structures used to activate the current LoRAs.

Returns:

    1. prompt_lora_mapping: a tuple of size np.sum(num_sampled_tokens) where prompt_lora_mapping[i] is the LoRA id to use for the ith sampled token.
    2. token_lora_mapping: a tuple of size np.sum(num_scheduled_tokens) where token_lora_mapping[i] is the LoRA id to use for the ith token.
    3. lora_requests: set of relevant LoRA requests.

Source code in vllm/v1/worker/gpu_input_batch.py
def make_lora_inputs(
    self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
    """
    Given the num_scheduled_tokens for each request in the batch, return
    data structures used to activate the current LoRAs.
    Returns:
        1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
           where, prompt_lora_mapping[i] is the LoRA id to use for the ith
           sampled token.
        2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
           where, token_lora_mapping[i] is the LoRA id to use for ith token.
        3. lora_requests: Set of relevant LoRA requests.
    """

    req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
    prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
    token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

    active_lora_requests: set[LoRARequest] = set(
        self.lora_id_to_lora_request.values()
    )

    return prompt_lora_mapping, token_lora_mapping, active_lora_requests
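
A toy illustration of the repeat() expansion above: the per-request LoRA ids are repeated once per sampled/scheduled token to obtain the per-token mappings. Values are made up.

import numpy as np

req_lora_mapping = np.array([1, 0, 2])       # LoRA id per request (0 = no LoRA)
num_scheduled_tokens = np.array([3, 2, 1])   # tokens scheduled per request

token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
assert token_lora_mapping == (1, 1, 1, 0, 0, 2)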

refresh_metadata

refresh_metadata()

Apply any batch updates to sampling metadata.

Source code in vllm/v1/worker/gpu_input_batch.py
def refresh_metadata(self):
    """Apply any batch updates to sampling metadata."""

    if self.is_pooling_model:
        batch_changed = self.batch_update_builder.reset()
        if batch_changed:
            self.sampling_metadata = self._make_sampling_metadata()
        return

    # For non-pooling models - generate and apply logitsprocs update;
    # reset batch update tracking.
    # Update sampling metadata if batch state is changed.
    batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
    for logit_proc in self.logitsprocs.all:
        logit_proc.update_state(batch_update)
    if batch_update:
        self.sampling_metadata = self._make_sampling_metadata()

remove_request

remove_request(req_id: str) -> int | None

This method must always be followed by a call to condense().

Parameters:

    req_id (str, required): request to remove

Returns:

    int | None: Removed request index, or None if req_id not recognized

Source code in vllm/v1/worker/gpu_input_batch.py
def remove_request(self, req_id: str) -> int | None:
    """This method must always be followed by a call to condense().

    Args:
      req_id: request to remove

    Returns:
      Removed request index, or `None` if `req_id` not recognized
    """

    req_index = self.req_id_to_index.pop(req_id, None)
    if req_index is None:
        return None

    self.batch_update_builder.removed_append(req_index)
    self._req_ids[req_index] = None
    self.req_output_token_ids[req_index] = None
    self.spec_token_ids[req_index].clear()

    # LoRA
    lora_id = self.request_lora_mapping[req_index]
    if lora_id != 0:
        lora_req_ids = self.lora_id_to_request_ids[lora_id]
        lora_req_ids.discard(req_id)
        if not lora_req_ids:
            del self.lora_id_to_request_ids[lora_id]
            del self.lora_id_to_lora_request[lora_id]
        self.request_lora_mapping[req_index] = 0

    if self.is_pooling_model:
        self.pooling_params.pop(req_id, None)
        self.pooling_states.pop(req_id, None)
        return req_index

    self.greedy_reqs.discard(req_id)
    self.random_reqs.discard(req_id)
    self.top_p_reqs.discard(req_id)
    self.top_k_reqs.discard(req_id)
    self.frequency_penalties_reqs.discard(req_id)
    self.presence_penalties_reqs.discard(req_id)
    self.repetition_penalties_reqs.discard(req_id)
    self.generators.pop(req_index, None)
    self.num_logprobs.pop(req_id, None)
    self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
    if self.prev_req_id_to_index is not None:
        self.prev_req_id_to_index.pop(req_id, None)

    self.has_allowed_token_ids.discard(req_id)
    if self.allowed_token_ids_mask_cpu_tensor is not None:
        # False means we don't fill with -inf.
        self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
    self.bad_words_token_ids.pop(req_index, None)
    return req_index
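
A hedged usage sketch of the contract above (the input_batch argument is assumed to be an instance of this class, and remove_and_compact is a hypothetical wrapper): every removal is followed by condense(), and refresh_metadata() then rebuilds the sampling metadata for the compacted batch.

def remove_and_compact(input_batch, req_id: str) -> None:
    # remove_request() tombstones the slot; condense() compacts the batch;
    # refresh_metadata() re-derives sampling metadata for the new layout.
    if input_batch.remove_request(req_id) is not None:
        input_batch.condense()
        input_batch.refresh_metadata()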

set_async_sampled_token_ids

set_async_sampled_token_ids(
    sampled_token_ids_cpu: Tensor,
    async_copy_ready_event: Event,
) -> None

In async scheduling case, store ref to sampled_token_ids_cpu tensor and corresponding copy-ready event. Used to repair output_token_ids prior to sampling, if needed by logits processors.

Source code in vllm/v1/worker/gpu_input_batch.py
def set_async_sampled_token_ids(
    self,
    sampled_token_ids_cpu: torch.Tensor,
    async_copy_ready_event: torch.Event,
) -> None:
    """
    In async scheduling case, store ref to sampled_token_ids_cpu
    tensor and corresponding copy-ready event. Used to repair
    output_token_ids prior to sampling, if needed by logits processors.
    """
    if self.sampling_metadata.output_token_ids:
        self.sampled_token_ids_cpu = sampled_token_ids_cpu
        self.async_copy_ready_event = async_copy_ready_event
    else:
        self.sampled_token_ids_cpu = None
        self.async_copy_ready_event = None

swap_states

swap_states(i1: int, i2: int) -> None
Source code in vllm/v1/worker/gpu_input_batch.py
def swap_states(self, i1: int, i2: int) -> None:
    old_id_i1 = self._req_ids[i1]
    old_id_i2 = self._req_ids[i2]
    self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
    self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
        self.req_output_token_ids[i2],
        self.req_output_token_ids[i1],
    )
    self.spec_token_ids[i1], self.spec_token_ids[i2] = (
        self.spec_token_ids[i2],
        self.spec_token_ids[i1],
    )
    assert old_id_i1 is not None and old_id_i2 is not None
    self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
        self.req_id_to_index[old_id_i2],
        self.req_id_to_index[old_id_i1],
    )
    self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
        self.num_tokens_no_spec[i2],
        self.num_tokens_no_spec[i1],
    )
    self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
        self.num_prompt_tokens[i2],
        self.num_prompt_tokens[i1],
    )
    self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
        self.num_computed_tokens_cpu[i2],
        self.num_computed_tokens_cpu[i1],
    )

    # NOTE: the following is unsafe
    # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
    #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
    # instead, we need to temporarily copy the data for one of the indices
    # TODO(lucas): optimize this by only copying valid indices
    tmp = self.token_ids_cpu[i1, ...].copy()
    self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
    self.token_ids_cpu[i2, ...] = tmp

    self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

    # Swap prompt embeddings if they exist
    embeds_i1 = self.req_prompt_embeds.get(i1)
    embeds_i2 = self.req_prompt_embeds.get(i2)
    if embeds_i1 is not None:
        self.req_prompt_embeds[i2] = embeds_i1
    else:
        self.req_prompt_embeds.pop(i2, None)
    if embeds_i2 is not None:
        self.req_prompt_embeds[i1] = embeds_i2
    else:
        self.req_prompt_embeds.pop(i1, None)

    self.block_table.swap_row(i1, i2)

    self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
        self.request_lora_mapping[i2],
        self.request_lora_mapping[i1],
    )

    if self.is_pooling_model:
        # Sampling and logits parameters don't apply to pooling models.
        return

    # For autoregressive models, track detailed request reordering info
    # to support logitsprocs.
    self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

    self.temperature_cpu[i1], self.temperature_cpu[i2] = (
        self.temperature_cpu[i2],
        self.temperature_cpu[i1],
    )
    self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
    self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
    self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
        self.frequency_penalties_cpu[i2],
        self.frequency_penalties_cpu[i1],
    )
    self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
        self.presence_penalties_cpu[i2],
        self.presence_penalties_cpu[i1],
    )
    self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
        self.repetition_penalties_cpu[i2],
        self.repetition_penalties_cpu[i1],
    )
    self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
        self.num_accepted_tokens_cpu[i2],
        self.num_accepted_tokens_cpu[i1],
    )

    swap_dict_values(self.generators, i1, i2)
    swap_dict_values(self.bad_words_token_ids, i1, i2)

    if self.allowed_token_ids_mask_cpu_tensor is not None:
        (
            self.allowed_token_ids_mask_cpu_tensor[i1],
            self.allowed_token_ids_mask_cpu_tensor[i2],
        ) = (
            self.allowed_token_ids_mask_cpu_tensor[i2],
            self.allowed_token_ids_mask_cpu_tensor[i1],
        )
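
The temporary copy in swap_states() exists because a "simultaneous" swap of two NumPy view slices clobbers one side before it is read; the small demonstration below shows the failure and the fix. (The fancy-indexed is_token_ids swap above is safe because advanced indexing on the right-hand side produces a copy.)

import numpy as np

a = np.arange(6).reshape(2, 3)            # [[0 1 2], [3 4 5]]
a[0, :], a[1, :] = a[1, :], a[0, :]       # RHS a[0, :] is a view, already clobbered
assert a.tolist() == [[3, 4, 5], [3, 4, 5]]   # not a swap!

b = np.arange(6).reshape(2, 3)
tmp = b[0, :].copy()                      # materialize one row first
b[0, :] = b[1, :]
b[1, :] = tmp
assert b.tolist() == [[3, 4, 5], [0, 1, 2]]   # correct swap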

update_async_output_token_ids

update_async_output_token_ids() -> None

In the async scheduling case, update output_token_ids in sampling metadata from the prior step's sampled token ids once they've finished copying to CPU. This is called right before they are needed by the logits processors.

Source code in vllm/v1/worker/gpu_input_batch.py
def update_async_output_token_ids(self) -> None:
    """
    In the async scheduling case, update output_token_ids in sampling metadata
    from the prior step's sampled token ids once they've finished copying to
    CPU. This is called right before they are needed by the logits processors.
    """
    output_token_ids = self.sampling_metadata.output_token_ids
    if self.sampled_token_ids_cpu is None or not output_token_ids:
        # Output token ids not needed or not async scheduling.
        return

    assert self.prev_req_id_to_index is not None
    sampled_token_ids = None
    for index, req_id in enumerate(self.req_ids):
        prev_index = self.prev_req_id_to_index.get(req_id)
        if prev_index is None:
            continue
        req_output_token_ids = output_token_ids[index]
        if not req_output_token_ids or req_output_token_ids[-1] != -1:
            # Final output id is not a placeholder; some tokens must have
            # been discarded after a kv-load failure.
            continue
        if sampled_token_ids is None:
            assert self.async_copy_ready_event is not None
            self.async_copy_ready_event.synchronize()
            sampled_token_ids = self.sampled_token_ids_cpu.tolist()
        # Replace placeholder token id(s) with actual sampled id(s).
        new_ids: list[int] = sampled_token_ids[prev_index]
        if not new_ids:
            continue
        num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
        # Also account for case where there may be a smaller number of
        # output placeholders (tokens can be discarded after a kv-load failure).
        first_placeholder = req_output_token_ids.index(-1)
        num_placeholders = len(req_output_token_ids) - first_placeholder
        num_to_replace = min(num_sampled_ids, num_placeholders)
        del new_ids[num_to_replace:]
        end_index = first_placeholder + num_to_replace
        req_output_token_ids[first_placeholder:end_index] = new_ids
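
A toy walk-through of the placeholder repair above: trailing -1 entries in a request's output list are replaced with however many real sampled ids are available for it. The values here are made up.

# Two trailing placeholders in the output list; the sampled row has one
# padded slot of its own, so only two real ids are available.
req_output_token_ids = [11, 12, -1, -1]
new_ids = [21, 22, -1]

num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
first_placeholder = req_output_token_ids.index(-1)
num_placeholders = len(req_output_token_ids) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders)

del new_ids[num_to_replace:]
req_output_token_ids[first_placeholder:first_placeholder + num_to_replace] = new_ids
assert req_output_token_ids == [11, 12, 21, 22]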

update_async_spec_token_ids

update_async_spec_token_ids(
    draft_token_ids: list[list[int]],
) -> None

In async scheduling case, update spec_token_ids in sampling metadata with real draft token ids from prior step. This is called right before they are needed by the rejection sampler for penalty/bad_words computation.

Source code in vllm/v1/worker/gpu_input_batch.py
def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
    """
    In async scheduling case, update spec_token_ids in sampling metadata with
    real draft token ids from prior step. This is called right before they are
    needed by the rejection sampler for penalty/bad_words computation.
    """
    if not draft_token_ids or not self.prev_req_id_to_index:
        return

    if (spec_token_ids := self.sampling_metadata.spec_token_ids) is not None:
        for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
            if spec_ids:
                prev_index = self.prev_req_id_to_index.get(req_id)
                if prev_index is not None:
                    draft_ids = draft_token_ids[prev_index]
                    if draft_ids:
                        del draft_ids[len(spec_ids) :]
                        spec_ids.clear()
                        spec_ids.extend(draft_ids)

update_req_spec_token_ids

update_req_spec_token_ids(
    request: CachedRequestState,
    scheduled_spec_tokens: dict[str, list[int]],
) -> None
Source code in vllm/v1/worker/gpu_input_batch.py
def update_req_spec_token_ids(
    self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
) -> None:
    req_id = request.req_id
    req_index = self.req_id_to_index[req_id]
    cur_spec_token_ids = self.spec_token_ids[req_index]
    # When speculative decoding is used with structured output,
    # the scheduler can drop draft tokens that do not
    # conform to the schema. This can result in
    # scheduler_output.scheduled_spec_decode_tokens being empty,
    # even when speculative decoding is enabled.
    cur_spec_token_ids.clear()
    spec_token_ids = scheduled_spec_tokens.get(req_id, ())
    num_spec_tokens = len(spec_token_ids)
    request.prev_num_draft_len = num_spec_tokens
    if not spec_token_ids:
        return

    # For async scheduling, token_ids_cpu assigned from
    # spec_token_ids are placeholders and will be overwritten in
    # _prepare_input_ids.
    start_index = self.num_tokens_no_spec[req_index]
    end_token_index = start_index + num_spec_tokens
    self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
    cur_spec_token_ids.extend(spec_token_ids)
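
A toy illustration of the final write above: the scheduled draft tokens are appended into the flat per-request token buffer immediately after the non-speculative tokens (under async scheduling these entries are placeholders that get overwritten later). Shapes and values are illustrative only.

import numpy as np

token_ids_cpu = np.zeros((1, 8), dtype=np.int32)   # one request, 8 token slots
num_tokens_no_spec = 5                              # prompt + sampled tokens so far
spec_token_ids = [101, 102]                         # scheduled draft tokens

start = num_tokens_no_spec
end = start + len(spec_token_ids)
token_ids_cpu[0, start:end] = spec_token_ids
assert token_ids_cpu[0, 5:7].tolist() == [101, 102]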