vllm.tokenizers.grok2

Tokenizer for the Grok-2 .tok.json format.

CONTROL_TOKEN_TEXTS module-attribute

CONTROL_TOKEN_TEXTS = [
    f"<|control{i}|>" for i in (range(1, 705))
]

DEFAULT_CHAT_TEMPLATE module-attribute

DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\\n\\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\\n\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
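
For illustration, the default template can be rendered directly with Jinja2 (the conversation below is made up; in normal use the template is applied via apply_chat_template):

from jinja2 import Template

# Illustrative only: render DEFAULT_CHAT_TEMPLATE over a made-up conversation.
messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
]
prompt = Template(DEFAULT_CHAT_TEMPLATE).render(
    messages=messages, add_generation_prompt=True
)
# prompt separates turns with "<|separator|>" plus a blank line and ends with "Assistant:"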

DEFAULT_CONTROL_TOKENS module-attribute

DEFAULT_CONTROL_TOKENS = {
    "pad": PAD,
    "sep": SEP,
    "eos": EOS,
}

DEFAULT_SPECIAL_TOKENS module-attribute

DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]

EOS module-attribute

EOS = '<|eos|>'

PAD module-attribute

PAD = '<|pad|>'

PAT_STR_B module-attribute

PAT_STR_B = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"

RESERVED_TOKEN_TEXTS module-attribute

RESERVED_TOKEN_TEXTS = [
    f"<|reserved_{i}|>" for i in (range(3, 128))
]

SEP module-attribute

SEP = '<|separator|>'

logger module-attribute

logger = init_logger(__name__)

Grok2Tokenizer

Bases: TokenizerLike

Source code in vllm/tokenizers/grok2.py
class Grok2Tokenizer(TokenizerLike):
    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "Grok2Tokenizer":
        if args:
            logger.debug_once("Ignoring extra positional args for Grok2Tokenizer.")

        path = Path(path_or_repo_id)
        if path.is_file():
            vocab_file = path
            model_path = path.parent
            repo_id = None
        elif path.is_dir():
            vocab_file = path / "tokenizer.tok.json"
            model_path = path
            repo_id = None
        else:
            vocab_file = Path(
                hf_hub_download(
                    repo_id=str(path_or_repo_id),
                    filename="tokenizer.tok.json",
                    revision=revision,
                    cache_dir=download_dir,
                )
            )
            model_path = vocab_file.parent
            repo_id = str(path_or_repo_id)

        if not vocab_file.is_file():
            raise FileNotFoundError(f"tokenizer.tok.json not found at {vocab_file}.")

        config = _maybe_load_tokenizer_config(
            model_path,
            repo_id=repo_id,
            revision=revision,
            download_dir=download_dir,
        )

        return cls(
            vocab_file=vocab_file,
            name_or_path=str(path_or_repo_id),
            truncation_side=kwargs.get("truncation_side", "left"),
            chat_template=config.get("chat_template"),
            init_kwargs=config,
        )

    def __init__(
        self,
        *,
        vocab_file: Path,
        name_or_path: str,
        truncation_side: str,
        chat_template: str | None,
        init_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.name_or_path = name_or_path
        self._truncation_side = truncation_side
        self.init_kwargs = init_kwargs or {}
        self._chat_template = chat_template or DEFAULT_CHAT_TEMPLATE

        self._tokenizer, self._special_tokens = _load_tiktoken_encoding(vocab_file)

        self._token_to_id: dict[str, int] = {}
        self._id_to_token: dict[int, str] = {}
        for token, token_id in self._tokenizer._mergeable_ranks.items():
            token_str = token.decode("utf-8", errors="replace")
            self._token_to_id[token_str] = token_id
            self._id_to_token[token_id] = token_str

        for token, token_id in self._special_tokens.items():
            self._token_to_id[token] = token_id
            self._id_to_token[token_id] = token

        bos_token_id = self._special_tokens.get(SEP)
        if bos_token_id is None:
            bos_token_id = self._special_tokens.get(PAD)
        if bos_token_id is None:
            bos_token_id = self._special_tokens.get(EOS)
        if bos_token_id is None:
            bos_token_id = 0
        self._bos_token_id = bos_token_id

        self._eos_token_id = self._special_tokens.get(EOS, self._bos_token_id)
        self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
        self._unk_token_id = self._pad_token_id

    def num_special_tokens_to_add(self) -> int:
        return 0

    @property
    def all_special_tokens(self) -> list[str]:
        return list(self._special_tokens.keys())

    @property
    def all_special_ids(self) -> list[int]:
        return list(self._special_tokens.values())

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.n_vocab

    @property
    def max_token_id(self) -> int:
        return self._tokenizer.n_vocab - 1

    @property
    def truncation_side(self) -> str:
        return self._truncation_side

    def get_vocab(self) -> dict[str, int]:
        return dict(self._token_to_id)

    def get_added_vocab(self) -> dict[str, int]:
        return dict(self._special_tokens)

    def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
        if max_length is None or len(tokens) <= max_length:
            return tokens
        if self.truncation_side == "left":
            return tokens[-max_length:]
        return tokens[:max_length]

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        del add_special_tokens
        tokens = self._tokenizer.encode(text)
        if truncation:
            tokens = self._maybe_truncate(tokens, max_length)
        return tokens

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        if isinstance(ids, int):
            ids = [ids]
        if skip_special_tokens:
            ids = [
                token_id
                for token_id in ids
                if token_id not in self._special_tokens.values()
            ]
        return self._tokenizer.decode(ids)

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        if isinstance(tokens, str):
            return self._token_to_id.get(tokens, self._unk_token_id)
        return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

    def convert_ids_to_tokens(
        self, ids: list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        tokens = []
        for token_id in ids:
            if skip_special_tokens and token_id in self._special_tokens.values():
                continue
            tokens.append(self._id_to_token.get(token_id, "<|unk|>"))
        return tokens

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        token_ids = self.convert_tokens_to_ids(tokens)
        return self.decode(token_ids, skip_special_tokens=False)

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> BatchEncoding:
        if text_pair is not None:
            raise NotImplementedError("text_pair is not supported for Grok2Tokenizer.")

        if isinstance(text, list):
            input_ids_batch: list[list[int]] = [
                self.encode(
                    item,
                    truncation=truncation,
                    max_length=max_length,
                    add_special_tokens=add_special_tokens,
                )
                for item in text
            ]
            attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
            return BatchEncoding(
                {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
            )

        input_ids = self.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )
        attention_mask = [1] * len(input_ids)
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

    def get_chat_template(
        self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
    ) -> str | None:
        del tools
        return chat_template or self._chat_template

    def apply_chat_template(
        self,
        messages: list[ChatCompletionMessageParam],
        tools: list[dict[str, Any]] | None = None,
        chat_template: str | None = None,
        tokenize: bool = False,
        **kwargs,
    ) -> str | list[int]:
        template = self.get_chat_template(chat_template, tools=tools)
        if template is None:
            raise ValueError(
                "No chat template available. Provide `chat_template` explicitly."
            )
        prompt = hf_chat_utils.apply_chat_template(
            conversation=messages,
            chat_template=template,
            tools=tools,
            **kwargs,
        )
        if tokenize:
            return self.encode(prompt, add_special_tokens=False)
        return prompt

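A minimal usage sketch (the local path is hypothetical; any directory containing tokenizer.tok.json works):

tokenizer = Grok2Tokenizer.from_pretrained("/models/grok-2")  # hypothetical local directory
ids = tokenizer.encode("Hello, Grok!")
text = tokenizer.decode(ids)
assert text == "Hello, Grok!"  # tiktoken round-trips plain UTF-8 text exactly
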
_bos_token_id instance-attribute

_bos_token_id = bos_token_id

_chat_template instance-attribute

_chat_template = chat_template or DEFAULT_CHAT_TEMPLATE

_eos_token_id instance-attribute

_eos_token_id = _special_tokens.get(EOS, _bos_token_id)

_id_to_token instance-attribute

_id_to_token: dict[int, str] = {}

_pad_token_id instance-attribute

_pad_token_id = _special_tokens.get(PAD, _eos_token_id)

_token_to_id instance-attribute

_token_to_id: dict[str, int] = {}

_truncation_side instance-attribute

_truncation_side = truncation_side

_unk_token_id instance-attribute

_unk_token_id = _pad_token_id

all_special_ids property

all_special_ids: list[int]

all_special_tokens property

all_special_tokens: list[str]

bos_token_id property

bos_token_id: int

eos_token_id property

eos_token_id: int

init_kwargs instance-attribute

init_kwargs = init_kwargs or {}

is_fast property

is_fast: bool

max_token_id property

max_token_id: int

name_or_path instance-attribute

name_or_path = name_or_path

pad_token_id property

pad_token_id: int

truncation_side property

truncation_side: str

vocab_size property

vocab_size: int

__call__

__call__(
    text: str | list[str],
    text_pair: str | None = None,
    add_special_tokens: bool = True,
    truncation: bool = False,
    max_length: int | None = None,
) -> BatchEncoding
Source code in vllm/tokenizers/grok2.py
def __call__(
    self,
    text: str | list[str],
    text_pair: str | None = None,
    add_special_tokens: bool = True,
    truncation: bool = False,
    max_length: int | None = None,
) -> BatchEncoding:
    if text_pair is not None:
        raise NotImplementedError("text_pair is not supported for Grok2Tokenizer.")

    if isinstance(text, list):
        input_ids_batch: list[list[int]] = [
            self.encode(
                item,
                truncation=truncation,
                max_length=max_length,
                add_special_tokens=add_special_tokens,
            )
            for item in text
        ]
        attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
        return BatchEncoding(
            {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
        )

    input_ids = self.encode(
        text,
        truncation=truncation,
        max_length=max_length,
        add_special_tokens=add_special_tokens,
    )
    attention_mask = [1] * len(input_ids)
    return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})
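
Calling the tokenizer returns a BatchEncoding with input_ids and an all-ones attention_mask, mirroring a small subset of the Hugging Face tokenizer interface. A sketch with hypothetical prompts (tokenizer as in the earlier sketch):

batch = tokenizer(
    ["short prompt", "a considerably longer second prompt"],
    truncation=True,
    max_length=8,
)
print(batch["input_ids"])       # one list of token ids per prompt, each truncated to at most 8 ids
print(batch["attention_mask"])  # all-ones masks of matching lengths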

__init__

__init__(
    *,
    vocab_file: Path,
    name_or_path: str,
    truncation_side: str,
    chat_template: str | None,
    init_kwargs: dict[str, Any] | None = None,
) -> None
Source code in vllm/tokenizers/grok2.py
def __init__(
    self,
    *,
    vocab_file: Path,
    name_or_path: str,
    truncation_side: str,
    chat_template: str | None,
    init_kwargs: dict[str, Any] | None = None,
) -> None:
    super().__init__()
    self.name_or_path = name_or_path
    self._truncation_side = truncation_side
    self.init_kwargs = init_kwargs or {}
    self._chat_template = chat_template or DEFAULT_CHAT_TEMPLATE

    self._tokenizer, self._special_tokens = _load_tiktoken_encoding(vocab_file)

    self._token_to_id: dict[str, int] = {}
    self._id_to_token: dict[int, str] = {}
    for token, token_id in self._tokenizer._mergeable_ranks.items():
        token_str = token.decode("utf-8", errors="replace")
        self._token_to_id[token_str] = token_id
        self._id_to_token[token_id] = token_str

    for token, token_id in self._special_tokens.items():
        self._token_to_id[token] = token_id
        self._id_to_token[token_id] = token

    bos_token_id = self._special_tokens.get(SEP)
    if bos_token_id is None:
        bos_token_id = self._special_tokens.get(PAD)
    if bos_token_id is None:
        bos_token_id = self._special_tokens.get(EOS)
    if bos_token_id is None:
        bos_token_id = 0
    self._bos_token_id = bos_token_id

    self._eos_token_id = self._special_tokens.get(EOS, self._bos_token_id)
    self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
    self._unk_token_id = self._pad_token_id

_maybe_truncate

_maybe_truncate(
    tokens: list[int], max_length: int | None
) -> list[int]
Source code in vllm/tokenizers/grok2.py
def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
    if max_length is None or len(tokens) <= max_length:
        return tokens
    if self.truncation_side == "left":
        return tokens[-max_length:]
    return tokens[:max_length]

apply_chat_template

apply_chat_template(
    messages: list[ChatCompletionMessageParam],
    tools: list[dict[str, Any]] | None = None,
    chat_template: str | None = None,
    tokenize: bool = False,
    **kwargs,
) -> str | list[int]
Source code in vllm/tokenizers/grok2.py
def apply_chat_template(
    self,
    messages: list[ChatCompletionMessageParam],
    tools: list[dict[str, Any]] | None = None,
    chat_template: str | None = None,
    tokenize: bool = False,
    **kwargs,
) -> str | list[int]:
    template = self.get_chat_template(chat_template, tools=tools)
    if template is None:
        raise ValueError(
            "No chat template available. Provide `chat_template` explicitly."
        )
    prompt = hf_chat_utils.apply_chat_template(
        conversation=messages,
        chat_template=template,
        tools=tools,
        **kwargs,
    )
    if tokenize:
        return self.encode(prompt, add_special_tokens=False)
    return prompt
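
Sketch of prompt construction through the chat template (the message content is made up; add_generation_prompt is forwarded via **kwargs and assumed to be accepted by the underlying HF renderer):

messages = [{"role": "user", "content": "What is vLLM?"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
token_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True
)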

convert_ids_to_tokens

convert_ids_to_tokens(
    ids: list[int], skip_special_tokens: bool = False
) -> list[str]
Source code in vllm/tokenizers/grok2.py
def convert_ids_to_tokens(
    self, ids: list[int], skip_special_tokens: bool = False
) -> list[str]:
    tokens = []
    for token_id in ids:
        if skip_special_tokens and token_id in self._special_tokens.values():
            continue
        tokens.append(self._id_to_token.get(token_id, "<|unk|>"))
    return tokens

convert_tokens_to_ids

convert_tokens_to_ids(tokens: str) -> int
convert_tokens_to_ids(tokens: list[str]) -> list[int]
convert_tokens_to_ids(
    tokens: str | list[str],
) -> int | list[int]
Source code in vllm/tokenizers/grok2.py
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
    if isinstance(tokens, str):
        return self._token_to_id.get(tokens, self._unk_token_id)
    return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

convert_tokens_to_string

convert_tokens_to_string(tokens: list[str]) -> str
Source code in vllm/tokenizers/grok2.py
def convert_tokens_to_string(self, tokens: list[str]) -> str:
    token_ids = self.convert_tokens_to_ids(tokens)
    return self.decode(token_ids, skip_special_tokens=False)

decode

decode(
    ids: list[int] | int, skip_special_tokens: bool = False
) -> str
Source code in vllm/tokenizers/grok2.py
def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
    if isinstance(ids, int):
        ids = [ids]
    if skip_special_tokens:
        ids = [
            token_id
            for token_id in ids
            if token_id not in self._special_tokens.values()
        ]
    return self._tokenizer.decode(ids)
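
With skip_special_tokens=True, special token ids are filtered out before decoding, so separator and EOS markers never reach the output text. A small sketch:

sep_id = tokenizer.convert_tokens_to_ids(SEP)
ids = tokenizer.encode("Hello") + [sep_id]
assert SEP not in tokenizer.decode(ids, skip_special_tokens=True)
assert SEP in tokenizer.decode(ids, skip_special_tokens=False)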

encode

encode(
    text: str,
    truncation: bool | None = None,
    max_length: int | None = None,
    add_special_tokens: bool = True,
) -> list[int]
Source code in vllm/tokenizers/grok2.py
def encode(
    self,
    text: str,
    truncation: bool | None = None,
    max_length: int | None = None,
    add_special_tokens: bool = True,
) -> list[int]:
    del add_special_tokens
    tokens = self._tokenizer.encode(text)
    if truncation:
        tokens = self._maybe_truncate(tokens, max_length)
    return tokens
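
encode never adds BOS/EOS tokens (add_special_tokens is accepted only for interface compatibility and ignored); when truncation is requested, tokens are dropped on the side given by truncation_side. A sketch:

ids = tokenizer.encode("a fairly long prompt that will be cut", truncation=True, max_length=4)
assert len(ids) <= 4  # with the default truncation_side="left", leading tokens are dropped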

from_pretrained classmethod

from_pretrained(
    path_or_repo_id: str | Path,
    *args,
    trust_remote_code: bool = False,
    revision: str | None = None,
    download_dir: str | None = None,
    **kwargs,
) -> Grok2Tokenizer
Source code in vllm/tokenizers/grok2.py
@classmethod
def from_pretrained(
    cls,
    path_or_repo_id: str | Path,
    *args,
    trust_remote_code: bool = False,
    revision: str | None = None,
    download_dir: str | None = None,
    **kwargs,
) -> "Grok2Tokenizer":
    if args:
        logger.debug_once("Ignoring extra positional args for Grok2Tokenizer.")

    path = Path(path_or_repo_id)
    if path.is_file():
        vocab_file = path
        model_path = path.parent
        repo_id = None
    elif path.is_dir():
        vocab_file = path / "tokenizer.tok.json"
        model_path = path
        repo_id = None
    else:
        vocab_file = Path(
            hf_hub_download(
                repo_id=str(path_or_repo_id),
                filename="tokenizer.tok.json",
                revision=revision,
                cache_dir=download_dir,
            )
        )
        model_path = vocab_file.parent
        repo_id = str(path_or_repo_id)

    if not vocab_file.is_file():
        raise FileNotFoundError(f"tokenizer.tok.json not found at {vocab_file}.")

    config = _maybe_load_tokenizer_config(
        model_path,
        repo_id=repo_id,
        revision=revision,
        download_dir=download_dir,
    )

    return cls(
        vocab_file=vocab_file,
        name_or_path=str(path_or_repo_id),
        truncation_side=kwargs.get("truncation_side", "left"),
        chat_template=config.get("chat_template"),
        init_kwargs=config,
    )
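
from_pretrained resolves its argument in three ways: a direct path to tokenizer.tok.json, a directory containing that file, or a Hugging Face repo id to download from. The paths and repo id below are illustrative:

tok_a = Grok2Tokenizer.from_pretrained("/models/grok-2/tokenizer.tok.json")  # file path
tok_b = Grok2Tokenizer.from_pretrained("/models/grok-2")                     # directory
tok_c = Grok2Tokenizer.from_pretrained("xai-org/grok-2", revision="main")    # repo id (hypothetical)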

get_added_vocab

get_added_vocab() -> dict[str, int]
Source code in vllm/tokenizers/grok2.py
def get_added_vocab(self) -> dict[str, int]:
    return dict(self._special_tokens)

get_chat_template

get_chat_template(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None = None,
) -> str | None
Source code in vllm/tokenizers/grok2.py
def get_chat_template(
    self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
) -> str | None:
    del tools
    return chat_template or self._chat_template

get_vocab

get_vocab() -> dict[str, int]
Source code in vllm/tokenizers/grok2.py
def get_vocab(self) -> dict[str, int]:
    return dict(self._token_to_id)

num_special_tokens_to_add

num_special_tokens_to_add() -> int
Source code in vllm/tokenizers/grok2.py
def num_special_tokens_to_add(self) -> int:
    return 0

_load_tiktoken_encoding

_load_tiktoken_encoding(
    vocab_file: Path,
) -> tuple[Any, dict[str, int]]
Source code in vllm/tokenizers/grok2.py
def _load_tiktoken_encoding(
    vocab_file: Path,
) -> tuple[Any, dict[str, int]]:
    try:
        import tiktoken
    except ImportError as exc:
        raise ImportError("Grok-2 tokenizer requires the `tiktoken` package.") from exc

    with vocab_file.open("rb") as f:
        xtok_dict = json.load(f)

    mergeable_ranks = {
        bytes(item["bytes"]): item["token"]
        for item in xtok_dict.get("regular_tokens", [])
    }
    special_tokens = {
        bytes(item["bytes"]).decode("utf-8", errors="replace"): item["token"]
        for item in xtok_dict.get("special_tokens", [])
    }

    if xtok_dict.get("word_split") == "V1":
        pat_str = PAT_STR_B
    else:
        raise ValueError(f"Unknown word_split: {xtok_dict.get('word_split')!r}")

    pat_str = xtok_dict.get("pat_str", pat_str)

    kwargs = {
        "name": str(vocab_file),
        "pat_str": pat_str,
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }

    if "vocab_size" in xtok_dict:
        kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]

    tokenizer = tiktoken.Encoding(**kwargs)

    default_allowed_special: set[str] | None = None
    if "default_allowed_special" in xtok_dict:
        default_allowed_special = {
            bytes(bytes_list).decode("utf-8", errors="replace")
            for bytes_list in xtok_dict["default_allowed_special"]
        }

    tokenizer._default_allowed_special = default_allowed_special or set()
    tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

    def encode_patched(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | Set[str] = set(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[int]:
        del disallowed_special
        if isinstance(allowed_special, set):
            allowed_special |= self._default_allowed_special
        return tiktoken.Encoding.encode(
            self,
            text,
            allowed_special=allowed_special,
            disallowed_special=(),
        )

    tokenizer.encode = functools.partial(encode_patched, tokenizer)
    tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
    tokenizer._default_allowed_special |= set(
        CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
    )

    return tokenizer, special_tokens
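
The structure this loader expects from tokenizer.tok.json, written out as a Python dict with illustrative values (field names follow the code above; each byte list spells out the token text):

example_tok_json = {
    "word_split": "V1",    # must be "V1" unless "pat_str" supplies a custom split regex
    "vocab_size": 131072,  # optional; forwarded as explicit_n_vocab (value illustrative)
    "regular_tokens": [
        {"bytes": [104, 105], "token": 0},  # b"hi" -> rank 0
    ],
    "special_tokens": [
        {"bytes": [60, 124, 101, 111, 115, 124, 62], "token": 2},  # "<|eos|>" -> id 2
    ],
    "default_allowed_special": [  # optional: tokens allowed in plain encode() calls
        [60, 124, 101, 111, 115, 124, 62],
    ],
}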

_maybe_load_tokenizer_config

_maybe_load_tokenizer_config(
    model_path: Path,
    *,
    repo_id: str | None,
    revision: str | None,
    download_dir: str | None,
) -> dict[str, Any]
Source code in vllm/tokenizers/grok2.py
def _maybe_load_tokenizer_config(
    model_path: Path,
    *,
    repo_id: str | None,
    revision: str | None,
    download_dir: str | None,
) -> dict[str, Any]:
    config_path = model_path / "tokenizer_config.json"
    if config_path.is_file():
        with config_path.open("r", encoding="utf-8") as f:
            return json.load(f)

    if repo_id is None:
        return {}

    try:
        config_file = hf_hub_download(
            repo_id=repo_id,
            filename="tokenizer_config.json",
            revision=revision,
            cache_dir=download_dir,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError):
        # If the repo, revision, or file does not exist, fall back silently.
        return {}
    except HfHubHTTPError as exc:
        logger.warning(
            "Failed to download tokenizer_config.json from %s. "
            "This may be due to a network or authentication issue. "
            "The default chat template will be used. Error: %s",
            repo_id,
            exc,
        )
        return {}

    try:
        with Path(config_file).open("r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as exc:
        logger.warning(
            "Failed to parse tokenizer_config.json. "
            "The default chat template will be used. Error: %s",
            exc,
        )
        return {}
    except OSError as exc:
        logger.warning(
            "Failed to open tokenizer_config.json. "
            "The default chat template will be used. Error: %s",
            exc,
        )
        return {}
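
If a tokenizer_config.json sits next to tokenizer.tok.json (or in the repo), its chat_template key overrides DEFAULT_CHAT_TEMPLATE. A hypothetical local override:

import json
from pathlib import Path

config = {"chat_template": "{{ messages[-1]['content'] }}"}  # illustrative template
Path("/models/grok-2/tokenizer_config.json").write_text(
    json.dumps(config), encoding="utf-8"
)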