Skip to content

Score

Source https://github.com/vllm-project/vllm/tree/main/examples/pooling/score.

Cohere Rerank Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python
Note that `pip install cohere` is needed to run this example.

run: vllm serve BAAI/bge-reranker-base
"""

import cohere
from cohere import Client, ClientV2

# Name of the reranker served by vLLM; passed as `model` to every rerank call.
model = "BAAI/bge-reranker-base"

# Query that the documents are ranked against.
query = "What is the capital of France?"

# Candidate documents; only the first one actually answers the query.
documents = [
    "The capital of France is Paris",
    "Reranking is fun!",
    "vLLM is an open-source framework for fast AI serving",
]


def cohere_rerank(
    client: Client | ClientV2, model: str, query: str, documents: list[str]
) -> dict:
    """Rerank ``documents`` against ``query`` through a Cohere SDK client.

    Accepts either the v1 ``Client`` or ``ClientV2``. The request is forwarded
    unchanged to ``client.rerank`` and its result is returned as-is.
    """
    # NOTE(review): the SDK's `rerank` presumably returns a response model
    # rather than a plain dict; confirm before relying on the `dict` annotation.
    request = {"model": model, "query": query, "documents": documents}
    return client.rerank(**request)


def main():
    """Rerank the sample documents with both Cohere client generations."""
    separator = "-" * 50

    # Cohere v1 client pointed at the local vLLM server.
    v1_client = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
    v1_result = cohere_rerank(v1_client, model, query, documents)
    print(separator)
    print("rerank_v1_result:\n", v1_result)
    print(separator)

    # The v2 client takes the API key as its first positional argument.
    v2_client = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
    v2_result = cohere_rerank(v2_client, model, query, documents)
    print("rerank_v2_result:\n", v2_result)
    print(separator)


if __name__ == "__main__":
    main()

Convert Model To Seq Cls

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

"""
Script to convert Large Language Models (LLMs) to Sequence Classification models.
This is particularly useful for converting reranker models that use next-token
prediction to a sequence classification format for compatibility with standard
classification and rerank pipelines.

Usage examples:
- For BAAI/bge-reranker-v2-gemma:
  python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma \
    --classifier_from_tokens '["Yes"]' --method no_post_processing \
    --path ./bge-reranker-v2-gemma-seq-cls

- For mxbai-rerank-v2:
  python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 \
    --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax \
    --path ./mxbai-rerank-base-v2-seq-cls

- For Qwen3-Reranker:
  python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B \
    --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax \
    --path ./Qwen3-Reranker-0.6B-seq-cls

Note: For BAAI/bge-reranker-v2-gemma, "Yes" and "yes" are different tokens.
"""

import argparse
import json

import torch
import transformers


def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
    """
    Fold a [false, true] token pair from the LM head into one classifier row.

    A softmax over two logits equals a sigmoid of their difference. Each logit
    is a dot product of the hidden state with an LM-head row (assuming the LM
    head has no bias, which this code ignores). So the single weight row
    ``w_true - w_false`` reproduces that difference as one score.

    Args:
        causal_lm: Source causal LM; only ``lm_head.weight`` is read.
        seq_cls_model: Target model whose ``score`` layer is overwritten.
        tokenizer: Tokenizer used to map the two token strings to ids.
        tokens: Exactly two token strings, ordered [false_token, true_token].
        device: Device the computed weight is moved to (e.g. "cpu"/"cuda").

    Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    """
    assert len(tokens) == 2, (
        "Method requires exactly two tokens for binary classification"
    )
    false_token, true_token = tokens

    head_weight = causal_lm.lm_head.weight  # (vocab_size, hidden_size)
    false_id = tokenizer.convert_tokens_to_ids(false_token)
    true_id = tokenizer.convert_tokens_to_ids(true_token)

    # Upcast both rows to float32 before subtracting, so the difference is not
    # computed in the checkpoint's storage dtype.
    true_row = head_weight[true_id].to(device).to(torch.float32)
    false_row = head_weight[false_id].to(device).to(torch.float32)
    score_weight = true_row - false_row

    with torch.no_grad():
        score_layer = seq_cls_model.score
        # The score layer expects a (1, hidden_size) weight for a single label.
        score_layer.weight.copy_(score_weight.unsqueeze(0))
        if score_layer.bias is not None:
            score_layer.bias.zero_()


def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
    """
    Copy the LM-head rows of ``tokens`` into the classifier, unchanged.

    Label ``i`` of the sequence classification model receives the LM-head
    weight row of ``tokens[i]``. Its score is therefore that token's
    next-token logit (any LM-head bias is not copied; the classifier bias is
    zeroed).

    Args:
        causal_lm: Source causal LM; only ``lm_head.weight`` is read.
        seq_cls_model: Target model whose ``score`` layer is overwritten.
        tokenizer: Tokenizer used to map token strings to ids.
        tokens: Token strings, one per classification label, in label order.
        device: Device the selected rows are moved to (e.g. "cpu"/"cuda").
    """
    head_weight = causal_lm.lm_head.weight  # (vocab_size, hidden_size)
    ids = list(map(tokenizer.convert_tokens_to_ids, tokens))

    # Gather the rows with list indexing -> (len(tokens), hidden_size).
    selected_rows = head_weight[ids].to(device)

    with torch.no_grad():
        classifier = seq_cls_model.score
        classifier.weight.copy_(selected_rows)
        if classifier.bias is not None:
            classifier.bias.zero_()


# Registry of weight-conversion strategies, keyed by the function name that is
# accepted as the `method` argument of `converting`.
method_map = {
    "from_2_way_softmax": from_2_way_softmax,
    "no_post_processing": no_post_processing,
}


def converting(
    model_name, classifier_from_tokens, path, method, use_sep_token=False, device="cpu"
):
    """
    Convert a *ForCausalLM checkpoint into a *ForSequenceClassification one.

    The causal LM and a sequence-classification model are both loaded from
    ``model_name``. The strategy ``method_map[method]`` overwrites the
    classifier's ``score`` layer with rows taken from the LM head. The result
    and the tokenizer are then saved to ``path``.

    Args:
        model_name: Name or path of the pretrained model.
        classifier_from_tokens: Token strings whose LM-head rows become the
            classifier (exactly two for 'from_2_way_softmax').
        path: Output directory for the converted model and tokenizer.
        method: Conversion method ('from_2_way_softmax' or 'no_post_processing').
        use_sep_token: Value stored as ``config.use_sep_token`` in the output.
        device: Device the models are loaded on ('cpu' or 'cuda').
    """
    assert method in method_map, f"Unknown method: {method}"

    # from_2_way_softmax collapses a [false, true] token pair into one score,
    # so it needs a single output; otherwise there is one output per token.
    if method != "from_2_way_softmax":
        num_labels = len(classifier_from_tokens)
    else:
        assert len(classifier_from_tokens) == 2
        num_labels = 1

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )
    # Same checkpoint as a classifier with `num_labels` outputs; its `score`
    # layer is overwritten by the conversion strategy below.
    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    convert = method_map[method]
    convert(causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device)

    # Note: `llm as reranker` defaults to not using separating token.
    config = seq_cls_model.config
    config.use_sep_token = use_sep_token
    config.sep_token_id = tokenizer.sep_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def parse_args():
    """Parse command-line options for the CausalLM -> SequenceClassification converter.

    Returns:
        argparse.Namespace with ``model_name``, ``classifier_from_tokens`` (a JSON
        string the caller decodes), ``method``, ``use_sep_token`` and ``path``.
    """
    parser = argparse.ArgumentParser(
        description="Converting *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-reranker-v2-gemma",
        help="HuggingFace model name or local path",
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["Yes"]',
        help="JSON string of tokens used for classification labels",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="no_post_processing",
        # Keep in sync with the conversion functions this script registers;
        # rejecting unknown names here gives a clear CLI error.
        choices=["from_2_way_softmax", "no_post_processing"],
        help="Conversion method to use",
    )
    # The entry point reads `args.use_sep_token`, so the flag must be stored
    # under that name. The old `--use-pad-token` spelling stored it as
    # `use_pad_token`, which crashed every run with AttributeError. That
    # spelling is kept as an alias so existing command lines still work.
    parser.add_argument(
        "--use-sep-token",
        "--use-pad-token",
        dest="use_sep_token",
        action="store_true",
        help="Enable the separating token in the sequence classification model",
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./bge-reranker-v2-gemma-seq-cls",
        help="Output directory to save the converted model",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # The token list arrives as a JSON string, e.g. '["no", "yes"]'.
    tokens = json.loads(args.classifier_from_tokens)
    converting(
        model_name=args.model_name,
        classifier_from_tokens=tokens,
        path=args.path,
        method=args.method,
        use_sep_token=args.use_sep_token,
    )

Qwen3 Reranker Offline

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

"""
What is the difference between the official original version and one
that has been converted into a sequence classification model?

Qwen3-Reranker is a language model that performs reranking using the
logits of the "no" and "yes" tokens.
This requires computing logits for all 151,669 tokens in the vocabulary,
making it inefficient and incompatible with vLLM's score() API.

A conversion method has been proposed to transform the original model into a
sequence classification model. This converted model:
1. Is significantly more efficient
2. Fully supports vLLM's score() API
3. Simplifies initialization parameters
Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py

For the converted model, initialization would simply be:
llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")

This example demonstrates loading the ORIGINAL model with special overrides
to make it compatible with vLLM's score API.
"""

from pathlib import Path

from vllm import LLM

model_name = "Qwen/Qwen3-Reranker-0.6B"


def get_llm() -> LLM:
    """
    Build a vLLM engine that scores with the ORIGINAL Qwen3-Reranker checkpoint.

    The checkpoint is a causal LM that judges relevance via the "no"/"yes"
    next-token logits. The ``hf_overrides`` below route it to the
    sequence-classification architecture, so that ``LLM.score`` works.

    Returns:
        LLM: Engine created with the pooling runner.
    """
    overrides = {
        # Load Qwen3ForSequenceClassification instead of the default
        # Qwen3ForCausalLM.
        "architectures": ["Qwen3ForSequenceClassification"],
        # LM-head tokens whose logits the original reranker compares.
        "classifier_from_token": ["no", "yes"],
        # Enables vLLM's handling for unconverted Qwen3-Reranker checkpoints,
        # which turns the two token vectors into one classification vector.
        "is_original_qwen3_reranker": True,
    }
    return LLM(model=model_name, runner="pooling", hf_overrides=overrides)


def main() -> None:
    """Score two query/document pairs with the original Qwen3-Reranker."""
    # The Jinja template lays out each query/document pair in the prompt format
    # the reranker was trained on.
    template_file = Path(__file__).parent / "template" / "qwen3_reranker.jinja"
    chat_template = template_file.read_text()

    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]
    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]

    llm = get_llm()

    # With two equal-length lists, each query is presumably scored against the
    # document at the same index; higher scores mean more relevant.
    outputs = llm.score(queries, documents, chat_template=chat_template)

    divider = "-" * 30
    print(divider)
    print("Relevance scores:", [output.outputs.score for output in outputs])
    print(divider)


if __name__ == "__main__":
    main()

Qwen3 Reranker Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
What is the difference between the official original version and one
that has been converted into a sequence classification model?

Qwen3-Reranker is a language model that performs reranking using the
logits of the "no" and "yes" tokens.
This requires computing logits for all 151,669 tokens in the vocabulary,
making it inefficient and incompatible with vLLM's score() API.

A conversion method has been proposed to transform the original model into a
sequence classification model. This converted model:
1. Is significantly more efficient
2. Fully supports vLLM's score() API
3. Simplifies initialization parameters
Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py

For the converted model, initialization would simply be:
    vllm serve tomaarsen/Qwen3-Reranker-0.6B-seq-cls --runner pooling --chat-template examples/pooling/score/template/qwen3_reranker.jinja

This example queries a server running the ORIGINAL model, loaded with special
overrides to make it compatible with vLLM's score API. Start it with:
    vllm serve Qwen/Qwen3-Reranker-0.6B --runner pooling --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' --chat-template examples/pooling/score/template/qwen3_reranker.jinja
"""

import json

import requests

# URL of the vLLM server's score endpoint
# Default vLLM server runs on localhost port 8000
url = "http://127.0.0.1:8000/score"

# HTTP headers for the request
headers = {"accept": "application/json", "Content-Type": "application/json"}

# Example queries & documents
queries = [
    "What is the capital of China?",
    "Explain gravity",
]
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

# Request payload for the score API
# "model" should match the model the server was started with (see the
# `vllm serve` command in the module docstring). text_1 carries the queries
# and text_2 the documents; as two equal-length lists they are presumably
# scored pairwise, index by index.
data = {
    "model": "Qwen/Qwen3-Reranker-0.6B",
    "text_1": queries,
    "text_2": documents,
}


def main():
    """Send one score request to the vLLM server and print the result.

    Posts the module-level ``data`` payload to ``url``. On HTTP 200 the JSON
    body (the relevance scores) is pretty-printed; otherwise the status code
    and raw response text are printed.
    """
    response = requests.post(url, headers=headers, json=data)

    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    print("Request successful!")
    print(json.dumps(response.json(), indent=2))


if __name__ == "__main__":
    main()

Rerank API Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
Jina and Cohere https://jina.ai/reranker

run: vllm serve BAAI/bge-reranker-base
"""

import json

import requests

# Rerank endpoint of a local vLLM server started as shown in the docstring.
url = "http://127.0.0.1:8000/rerank"

# JSON request/response headers.
headers = {"accept": "application/json", "Content-Type": "application/json"}

# Rerank payload: one query plus the candidate documents to order by relevance.
data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Horses and cows are both animals",
    ],
}


def main():
    """POST the rerank payload to the local server and print the outcome."""
    response = requests.post(url, headers=headers, json=data)
    succeeded = response.status_code == 200

    if succeeded:
        print("Request successful!")
        print(json.dumps(response.json(), indent=2))
    else:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)


if __name__ == "__main__":
    main()

Score API Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Score API.

Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
"""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST ``prompt`` as a JSON body to ``api_url`` and return the raw response."""
    return requests.post(api_url, headers={"User-Agent": "Test Client"}, json=prompt)


def parse_args():
    """Parse ``--host``, ``--port`` and ``--model`` for the score API client."""
    parser = argparse.ArgumentParser()
    options = (
        ("--host", str, "localhost"),
        ("--port", int, 8000),
        ("--model", str, "BAAI/bge-reranker-v2-m3"),
    )
    for flag, value_type, default in options:
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()


def _score_and_print(api_url: str, model_name: str, text_1, text_2, label: str) -> None:
    """Send one score request for (text_1, text_2) and print the prompt and response."""
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print(f"\nPrompt when {label}:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    pprint.pprint(score_response.json())


def main(args):
    """Demonstrate the three accepted shapes of the Score API inputs.

    One request is sent to ``http://{host}:{port}/score`` for each of these
    cases, and each prompt and JSON response is printed:
    - string vs. string
    - string vs. list
    - list vs. list
    """
    api_url = f"http://{args.host}:{args.port}/score"
    model_name = args.model

    # (label, text_1, text_2). The label completes "Prompt when ...:".
    examples = [
        (
            "text_1 and text_2 are both strings",
            "What is the capital of Brazil?",
            "The capital of Brazil is Brasilia.",
        ),
        (
            "text_1 is string and text_2 is a list",
            "What is the capital of France?",
            ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
        ),
        (
            "text_1 and text_2 are both lists",
            ["What is the capital of Brazil?", "What is the capital of France?"],
            ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
        ),
    ]
    for label, text_1, text_2 in examples:
        _score_and_print(api_url, model_name, text_1, text_2, label)


if __name__ == "__main__":
    args = parse_args()
    main(args)

Template - Bge-Reranker-V2-Gemma

A: {{ (messages | selectattr("role", "eq", "query") | first).content }}
B: {{ (messages | selectattr("role", "eq", "document") | first).content }}
Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.

Template - Mxbai Rerank V2

<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
query: {{ (messages | selectattr("role", "eq", "query") | first).content }}
document: {{ (messages | selectattr("role", "eq", "document") | first).content }}
You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).
Relevance:<|im_end|>
<|im_start|>assistant

Template - Nemotron-Rerank

question:{{ (messages | selectattr("role", "eq", "query") | first).content }} 

 passage:{{ (messages | selectattr("role", "eq", "document") | first).content }}

Template - Qwen3 Reranker

<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: {{ messages | selectattr("role", "eq", "system") | map(attribute="content") | first | default("Given a web search query, retrieve relevant passages that answer the query") }}
<Query>: {{ messages | selectattr("role", "eq", "query") | map(attribute="content") | first }}
<Document>: {{ messages | selectattr("role", "eq", "document") | map(attribute="content") | first }}<|im_end|>
<|im_start|>assistant
<think>

</think>

Template - Qwen3 Vl Reranker

<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: {{
    messages
    | selectattr("role", "eq", "system")
    | map(attribute="content")
    | first
    | default("Given a search query, retrieve relevant candidates that answer the query.")
}}<Query>:{{
    messages
    | selectattr("role", "eq", "query")
    | map(attribute="content")
    | first
}}
<Document>:{{
    messages
    | selectattr("role", "eq", "document")
    | map(attribute="content")
    | first
}}<|im_end|>
<|im_start|>assistant

Using Template Offline

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from argparse import Namespace
from pathlib import Path
from typing import Any

from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser


def parse_args():
    """Build the CLI for the reranking example and parse ``sys.argv``.

    Every ``EngineArgs`` option is available. A few defaults are overridden so
    the script runs out of the box with a reranking model on the pooling
    runner.
    """
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    example_defaults = {
        # Default reranking model for this demo.
        "model": "nvidia/llama-nemotron-rerank-1b-v2",
        # Cross-encoder / reranking models run on the pooling runner.
        "runner": "pooling",
        # Allow loading models that ship custom modeling code.
        "trust_remote_code": True,
    }
    parser.set_defaults(**example_defaults)
    return parser.parse_args()


def get_chat_template(model: str) -> str:
    """Read the Jinja chat template used to format prompts for ``model``.

    Each supported reranker expects its query/document pairs in a specific
    prompt layout. The matching template files live in the ``template``
    directory next to this script.

    Raises:
        ValueError: If this demo has no template registered for ``model``.
    """
    template_dir = Path(__file__).parent / "template"

    qwen3_template = "qwen3_reranker.jinja"
    mxbai_template = "mxbai_rerank_v2.jinja"
    template_by_model = {
        "BAAI/bge-reranker-v2-gemma": "bge-reranker-v2-gemma.jinja",
        # Original and seq-cls converted Qwen3 rerankers share one prompt format.
        "Qwen/Qwen3-Reranker-0.6B": qwen3_template,
        "Qwen/Qwen3-Reranker-4B": qwen3_template,
        "Qwen/Qwen3-Reranker-8B": qwen3_template,
        "tomaarsen/Qwen3-Reranker-0.6B-seq-cls": qwen3_template,
        "tomaarsen/Qwen3-Reranker-4B-seq-cls": qwen3_template,
        "tomaarsen/Qwen3-Reranker-8B-seq-cls": qwen3_template,
        "mixedbread-ai/mxbai-rerank-base-v2": mxbai_template,
        "mixedbread-ai/mxbai-rerank-large-v2": mxbai_template,
        "nvidia/llama-nemotron-rerank-1b-v2": "nemotron-rerank.jinja",
    }

    if model not in template_by_model:
        raise ValueError(f"This demo does not support model name: {model}.")
    return (template_dir / template_by_model[model]).read_text()


def get_hf_overrides(model: str) -> dict[str, Any]:
    """Return the ``hf_overrides`` needed to load ``model`` as a reranker in vLLM.

    Some reranking models are causal LMs and must be routed to a
    sequence-classification architecture before vLLM's score API can use
    them. Models that need nothing special map to an empty dict. A new dict
    is built on every call, so callers may mutate the result freely.

    Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/qwen3_reranker_offline.py
    Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py

    Raises:
        ValueError: If this demo has no entry for ``model``.
    """
    # Original Qwen3-Reranker checkpoints: score from the "no"/"yes" logits.
    qwen3_original = {
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    }
    # mxbai-rerank-v2: binary "0"/"1" tokens folded into a single score.
    mxbai_v2 = {
        "architectures": ["Qwen2ForSequenceClassification"],
        "classifier_from_token": ["0", "1"],
        "method": "from_2_way_softmax",
    }
    overrides_by_model = {
        "BAAI/bge-reranker-v2-gemma": {
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        },
        "Qwen/Qwen3-Reranker-0.6B": qwen3_original,
        "Qwen/Qwen3-Reranker-4B": qwen3_original,
        "Qwen/Qwen3-Reranker-8B": qwen3_original,
        # Already converted to sequence classification: nothing to override.
        "tomaarsen/Qwen3-Reranker-0.6B-seq-cls": {},
        "tomaarsen/Qwen3-Reranker-4B-seq-cls": {},
        "tomaarsen/Qwen3-Reranker-8B-seq-cls": {},
        "mixedbread-ai/mxbai-rerank-base-v2": mxbai_v2,
        "mixedbread-ai/mxbai-rerank-large-v2": mxbai_v2,
        "nvidia/llama-nemotron-rerank-1b-v2": {},
    }

    if model not in overrides_by_model:
        raise ValueError(f"This demo does not support model name: {model}.")
    return overrides_by_model[model]


def main(args: Namespace):
    """Score example documents against one query with a template-driven reranker."""
    # Model-specific overrides are injected into the parsed engine arguments,
    # and the whole namespace is passed to LLM as keyword arguments.
    args.hf_overrides = get_hf_overrides(args.model)
    llm = LLM(**vars(args))

    query = "how much protein should a female eat?"
    documents = [
        "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
        "Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments.",
        "Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
    ]

    # The template turns each (query, document) pair into the model's prompt.
    chat_template = get_chat_template(args.model)

    # One query against several documents: one output per document, where a
    # higher score means more relevant.
    outputs = llm.score(query, documents, chat_template=chat_template)

    rule = "-" * 30
    print(rule)
    print([output.outputs.score for output in outputs])
    print(rule)


if __name__ == "__main__":
    args = parse_args()
    main(args)

Using Template Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using the rerank API with template.

This script demonstrates how to interact with a vLLM server running
a reranking model via the REST API.
Before running this script, start the vLLM server with one of the
supported reranking models using the commands below.

note:
    Some reranking models require special configuration overrides to work correctly
    with vLLM's score API.
    Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/qwen3_reranker_online.py
    Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py

run:
    vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' --chat-template examples/pooling/score/template/bge-reranker-v2-gemma.jinja
    vllm serve tomaarsen/Qwen3-Reranker-0.6B-seq-cls --chat-template examples/pooling/score/template/qwen3_reranker.jinja
    vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}' --chat-template examples/pooling/score/template/mxbai_rerank_v2.jinja
    vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --trust-remote-code --chat-template examples/pooling/score/template/nemotron-rerank.jinja
    vllm serve Qwen/Qwen3-Reranker-0.6B --runner pooling --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' --chat-template examples/pooling/score/template/qwen3_reranker.jinja
"""

import json

import requests

# URL of the vLLM server's rerank endpoint
# Default vLLM server runs on localhost port 8000
url = "http://127.0.0.1:8000/rerank"

# HTTP headers for the request
headers = {"accept": "application/json", "Content-Type": "application/json"}

# Example query & documents
query = "how much protein should a female eat?"
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments.",
    "Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]

# Request payload for the rerank API
# NOTE: "model" is hard-coded to the Nemotron reranker. When serving one of the
# other models from the `vllm serve` commands in the module docstring, change
# it to that model's name.
data = {
    "model": "nvidia/llama-nemotron-rerank-1b-v2",  # Model to use for reranking
    "query": query,  # The query to score documents against
    "documents": documents,  # List of documents to be scored
}


def main():
    """Send one rerank request to the vLLM server and print the outcome.

    The module-level ``data`` payload (model, query, documents) is POSTed as
    JSON to ``url``. On HTTP 200 the JSON reply, which holds the relevance
    score of each document, is pretty-printed; otherwise the status code and
    the raw response body are printed.
    """
    resp = requests.post(url, headers=headers, json=data)

    if resp.status_code != 200:
        # Show the server's error so the user can tell what went wrong.
        print(f"Request failed with status code: {resp.status_code}")
        print(resp.text)
        return

    print("Request successful!")
    print(json.dumps(resp.json(), indent=2))


if __name__ == "__main__":
    main()

Vision Rerank API Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

"""
Example Python client for multimodal rerank API which is compatible with
Jina and Cohere https://jina.ai/reranker

Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
e.g.
    vllm serve jinaai/jina-reranker-m0 --runner pooling

    vllm serve Qwen/Qwen3-VL-Reranker-2B \
        --runner pooling \
        --max-model-len 4096 \
        --hf_overrides '{"architectures": ["Qwen3VLForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' \
        --chat-template examples/pooling/score/template/qwen3_vl_reranker.jinja
"""

import argparse
import json

import requests

# Send JSON and ask for JSON back.
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
}

# Text query that the multimodal document is ranked against.
query = "A woman playing with her dog on a beach at sunset."

# A single multimodal document: a text caption followed by an image URL.
documents = dict(
    content=[
        dict(
            type="text",
            text=(
                "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, "  # noqa: E501
                "as the dog offers its paw in a heartwarming display of companionship and trust."  # noqa: E501
            ),
        ),
        dict(
            type="image_url",
            image_url=dict(
                url="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
            ),
        ),
    ]
)


def parse_args():
    """Parse the vLLM server location (``--host``, ``--port``) from the CLI."""
    cli = argparse.ArgumentParser()
    for flag, kind, fallback in (("--host", str, "localhost"), ("--port", int, 8000)):
        cli.add_argument(flag, type=kind, default=fallback)
    return cli.parse_args()


def main(args):
    """Look up the served model, then rerank a multimodal document against a query.

    Args:
        args: Parsed CLI namespace providing the server ``host`` and ``port``.

    Raises:
        requests.HTTPError: If the ``/v1/models`` lookup returns an HTTP error.
        RuntimeError: If the server reports no served models.
    """
    base_url = f"http://{args.host}:{args.port}"
    models_url = base_url + "/v1/models"
    rerank_url = base_url + "/rerank"

    # Ask the server which model it hosts so the name need not be hard-coded.
    response = requests.get(models_url, headers=headers)
    # Surface a failed lookup as a clear HTTP error instead of a confusing
    # KeyError/JSONDecodeError while indexing into the error body.
    response.raise_for_status()
    served_models = response.json()["data"]
    if not served_models:
        raise RuntimeError(f"No models are being served at {models_url}")
    model = served_models[0]["id"]

    data = {
        "model": model,
        "query": query,
        "documents": documents,
    }
    response = requests.post(rerank_url, headers=headers, json=data)

    # Only decode JSON on success; on failure show the status and raw body.
    if response.status_code == 200:
        print("Request successful!")
        print(json.dumps(response.json(), indent=2))
    else:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)


if __name__ == "__main__":
    args = parse_args()
    main(args)

Vision Reranker Offline

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference with
vision language reranker models for multimodal scoring tasks.

Vision language rerankers score the relevance between a text query and
multimodal documents (text + images/videos).
"""

from argparse import Namespace
from collections.abc import Callable
from dataclasses import asdict
from pathlib import Path
from typing import NamedTuple

from vllm import LLM, EngineArgs
from vllm.entrypoints.score_utils import ScoreMultiModalParam
from vllm.utils.argparse_utils import FlexibleArgumentParser

# Directory named "template" next to this script; Jinja chat templates
# (e.g. qwen3_vl_reranker.jinja) are read from here.
TEMPLATE_HOME = Path(__file__).parent / "template"


class RerankModelData(NamedTuple):
    """Configuration for one example reranker, as returned by the run_* factories."""

    # vLLM engine settings; main() expands them with asdict() into LLM(...).
    engine_args: EngineArgs
    # Chat template text forwarded to LLM.score(); stays None for models that
    # do not supply one (e.g. the jinavl_reranker example).
    chat_template: str | None = None


def run_jinavl_reranker(modality: str) -> RerankModelData:
    """Build engine settings for ``jinaai/jina-reranker-m0``.

    Args:
        modality: Must be ``"image"``; this example only wires up image
            input for this model.

    Returns:
        A ``RerankModelData`` with the engine args and no custom chat template.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # An explicit raise (instead of ``assert``) keeps the check active under
    # ``python -O`` and gives the CLI user a readable message when they pass
    # ``--modality video`` with this model.
    if modality != "image":
        raise ValueError(
            f"jinavl_reranker only supports the 'image' modality, got {modality!r}"
        )

    engine_args = EngineArgs(
        model="jinaai/jina-reranker-m0",
        runner="pooling",
        max_model_len=32768,
        trust_remote_code=True,
        mm_processor_kwargs={
            "min_pixels": 3136,
            "max_pixels": 602112,
        },
        limit_mm_per_prompt={modality: 1},
    )
    return RerankModelData(
        engine_args=engine_args,
    )


def run_qwen3_vl_reranker(modality: str) -> RerankModelData:
    """Build engine settings and chat template for ``Qwen/Qwen3-VL-Reranker-2B``.

    Args:
        modality: Input modality (``"image"`` or ``"video"``);
            ``limit_mm_per_prompt`` is set to 1 for it.

    Returns:
        A ``RerankModelData`` carrying the engine args and the text of
        ``template/qwen3_vl_reranker.jinja``.
    """
    # Overrides applied on top of the HuggingFace model config:
    #  - architectures: route to Qwen3VLForSequenceClassification instead of
    #    the default Qwen3VLForConditionalGeneration.
    #  - classifier_from_token: the original reranker scores with the logits
    #    of the "no" and "yes" tokens from the language-model head.
    #  - is_original_qwen3_reranker: triggers the conversion that folds those
    #    two token vectors into a single classification vector.
    overrides = {
        "architectures": ["Qwen3VLForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    }
    engine_args = EngineArgs(
        model="Qwen/Qwen3-VL-Reranker-2B",
        runner="pooling",
        max_model_len=16384,
        limit_mm_per_prompt={modality: 1},
        hf_overrides=overrides,
    )
    template_text = TEMPLATE_HOME.joinpath("qwen3_vl_reranker.jinja").read_text()
    return RerankModelData(engine_args, template_text)


# Registry of example models keyed by the value accepted by --model-name;
# each factory takes the modality and returns a RerankModelData.
model_example_map: dict[str, Callable[[str], RerankModelData]] = dict(
    jinavl_reranker=run_jinavl_reranker,
    qwen3_vl_reranker=run_qwen3_vl_reranker,
)


def parse_args():
    """Read ``--model-name``/``-m`` and ``--modality`` from the command line."""
    blurb = (
        "Demo on using vLLM for offline inference with "
        "vision language reranker models for multimodal scoring tasks."
    )
    cli = FlexibleArgumentParser(description=blurb)
    cli.add_argument(
        "--model-name",
        "-m",
        type=str,
        default="jinavl_reranker",
        choices=model_example_map.keys(),
        help="The name of the reranker model.",
    )
    cli.add_argument(
        "--modality",
        type=str,
        default="image",
        choices=["image", "video"],
        help="Modality of the multimodal input (image or video).",
    )
    return cli.parse_args()


def get_multi_modal_input(modality: str) -> tuple[str, ScoreMultiModalParam]:
    # Sample query for testing the reranker
    if modality == "image":
        query = "A woman playing with her dog on a beach at sunset."
        # Sample multimodal documents to be scored against the query
        # Each document contains an image URL that will be fetched and processed
        documents: ScoreMultiModalParam = {
            "content": [
                {
                    "type": "text",
                    "text": (
                        "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, "  # noqa: E501
                        "as the dog offers its paw in a heartwarming display of companionship and trust."  # noqa: E501
                    ),
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
                    },
                },
            ]
        }
    elif modality == "video":
        query = "A girl is drawing pictures on an ipad."
        # Sample video documents to be scored against the query
        documents: ScoreMultiModalParam = {
            "content": [
                {
                    "type": "text",
                    "text": "A girl is drawing a guitar on her ipad with Apple Pencil.",
                },
                {
                    "type": "video_url",
                    "video_url": {
                        "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4"
                    },
                },
            ]
        }
    else:
        raise ValueError(f"Unsupported modality: {modality}")
    return query, documents


def main(args: Namespace):
    """Score the sample query against the sample documents with the chosen model.

    Args:
        args: Parsed CLI namespace with ``model_name`` and ``modality``.
    """
    modality = args.modality
    setup = model_example_map[args.model_name](modality)
    engine_args = setup.engine_args

    # EngineArgs is a dataclass, so its fields map directly onto LLM kwargs.
    llm = LLM(**asdict(engine_args))

    query, documents = get_multi_modal_input(modality)
    results = llm.score(query, documents, chat_template=setup.chat_template)

    rule = "-" * 50
    print(rule)
    print(f"Model: {engine_args.model}")
    print(f"Modality: {modality}")
    print(f"Query: {query}")
    print("Relevance scores:", [res.outputs.score for res in results])
    print(rule)


if __name__ == "__main__":
    args = parse_args()
    main(args)

Vision Score API Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

"""
Example online usage of Score API.

Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
e.g.
    vllm serve jinaai/jina-reranker-m0 --runner pooling

    vllm serve Qwen/Qwen3-VL-Reranker-2B \
        --runner pooling \
        --max-model-len 4096 \
        --hf_overrides '{"architectures": ["Qwen3VLForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' \
        --chat-template examples/pooling/score/template/qwen3_vl_reranker.jinja
"""

import argparse
import json
import pprint

import requests

# Send JSON and ask for JSON back.
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
}

# Plain-text query scored against each image in ``text_2``.
text_1 = "slm markdown"
# Multimodal "document": a list of remote page images, in order.
text_2 = {
    "content": [
        {"type": "image_url", "image_url": {"url": page_url}}
        for page_url in (
            "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
            "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
        )
    ]
}


def parse_args():
    """Return the CLI namespace holding the vLLM server's ``host`` and ``port``."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="localhost", type=str)
    cli.add_argument("--port", default=8000, type=int)
    return cli.parse_args()


def main(args):
    """Score a text query against a list of images via the ``/score`` endpoint.

    The served model name is looked up from ``/v1/models`` so it does not
    have to be hard-coded; the request payload and the server's reply are
    then printed.

    Args:
        args: Parsed CLI namespace providing the server ``host`` and ``port``.

    Raises:
        requests.HTTPError: If the ``/v1/models`` lookup returns an HTTP error.
        RuntimeError: If the server reports no served models.
    """
    base_url = f"http://{args.host}:{args.port}"
    models_url = base_url + "/v1/models"
    score_url = base_url + "/score"

    response = requests.get(models_url, headers=headers)
    # Fail with a clear HTTP error rather than a KeyError/JSONDecodeError
    # from indexing into an error body.
    response.raise_for_status()
    served_models = response.json()["data"]
    if not served_models:
        raise RuntimeError(f"No models are being served at {models_url}")
    model = served_models[0]["id"]

    prompt = {"model": model, "text_1": text_1, "text_2": text_2}
    response = requests.post(score_url, headers=headers, json=prompt)
    print("\nPrompt when text_1 is string and text_2 is a image list:")
    pprint.pprint(prompt)

    # Only decode JSON on success, as the other examples do; an error
    # response may not carry a JSON body at all.
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    print("\nScore Response:")
    print(json.dumps(response.json(), indent=2))


if __name__ == "__main__":
    args = parse_args()
    main(args)