
vllm.model_executor.layers.rotary_embedding

Rotary Positional Embeddings.
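
In the standard RoPE formulation (as implemented by the classes below; the exact channel pairing within a head depends on is_neox_style), each pair of channels (x_{2i}, x_{2i+1}) in the first rotary_dim channels of a query or key head is rotated by a position-dependent angle, while the remaining head_size - rotary_dim channels pass through unchanged:

$$
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos(p\,\theta_i) & -\sin(p\,\theta_i) \\ \sin(p\,\theta_i) & \cos(p\,\theta_i) \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
\qquad
\theta_i = \text{base}^{-2i/\text{rotary\_dim}},
$$

where p is the token position. The submodules listed next provide this base rotation and its long-context scaling variants.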

Modules:

Name                          Description
base                          Rotary Positional Embeddings Base Class.
common
deepseek_scaling_rope
dual_chunk_rope
dynamic_ntk_alpha_rope
dynamic_ntk_scaling_rope
ernie45_vl_rope
linear_scaling_rope
llama3_rope
llama4_vision_rope
mrope
ntk_scaling_rope
phi3_long_rope_scaled_rope
xdrope
yarn_scaling_rope

_ROPE_DICT module-attribute

_ROPE_DICT: dict[tuple[Any, ...], RotaryEmbedding] = {}
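
get_rope (documented below) memoizes the layers it constructs in this dictionary, keyed on head size, rotary dimension, maximum position, style, RoPE parameters, dual-chunk configuration, and dtype, so repeated calls with identical arguments return the same instance. A minimal sketch (argument values are illustrative):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

# Identical arguments hit the _ROPE_DICT cache and return the same object.
a = get_rope(head_size=64, max_position=2048, dtype=torch.float32)
b = get_rope(head_size=64, max_position=2048, dtype=torch.float32)
assert a is b

# A different key (here, a different dtype) builds and caches a new layer.
c = get_rope(head_size=64, max_position=2048, dtype=torch.float16)
assert c is not a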

__all__ module-attribute

__all__ = ['RotaryEmbedding']

RotaryEmbedding

Bases: RotaryEmbeddingBase

Source code in vllm/model_executor/layers/rotary_embedding/base.py
class RotaryEmbedding(RotaryEmbeddingBase):
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    @staticmethod
    def forward_static(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        head_size: int,
        rotary_dim: int,
        cos_sin_cache: torch.Tensor,
        is_neox_style: bool,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """A PyTorch-native implementation of forward()."""
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = ApplyRotaryEmb.forward_static(
            query_rot,
            cos,
            sin,
            is_neox_style,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, head_size)
            key_rot = key[..., :rotary_dim]
            key_pass = key[..., rotary_dim:]
            key_rot = ApplyRotaryEmb.forward_static(
                key_rot,
                cos,
                sin,
                is_neox_style,
            )
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """A PyTorch-native implementation of forward()."""
        return self.forward_static(
            positions,
            query,
            key,
            self.head_size,
            self.rotary_dim,
            self.cos_sin_cache,
            self.is_neox_style,
        )

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if self.use_flashinfer:
            torch.ops.vllm.flashinfer_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
            )
            return query, key

        from vllm import _custom_ops as ops

        self._match_cos_sin_cache_dtype(query)

        # ops.rotary_embedding() is an in-place operation
        # that updates the query and key tensors.
        ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if self.is_rocm_triton_rotary_embed_enabled:
            self._match_cos_sin_cache_dtype(query)
            rocm_aiter_ops.triton_rotary_embed(
                positions,
                query,
                key,
                self.cos_sin_cache,
                self.head_size,
                self.rotary_dim,
                self.is_neox_style,
            )
            return query, key
        return self.forward_cuda(positions, query, key)

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        from vllm._ipex_ops import ipex_ops as ops

        self._match_cos_sin_cache_dtype(query)
        # ops.rotary_embedding() is an in-place operation
        # that updates the query and key tensors.
        if key is None:
            # XPU kernel doesn't support key=None so fall back to native impl
            # TODO(sarckk): add support for optional key in
            # ipex.llm.functional.rotary_embedding_batched
            return self.forward_native(positions, query, key)
        else:
            ops.rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
            )
        return query, key

    def forward_cpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        from vllm import _custom_ops as ops

        self._match_cos_sin_cache_dtype(query)

        # ops.rotary_embedding() is an in-place operation
        # that updates the query and key tensors.
        ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
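
A minimal usage sketch of the layer as constructed through get_rope (shapes and values are illustrative; queries and keys are laid out as [num_tokens, num_heads * head_size], matching the view(num_tokens, -1, head_size) calls above):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

num_tokens, num_heads, head_size = 4, 8, 64
rope = get_rope(head_size=head_size, max_position=2048, dtype=torch.float32)

positions = torch.arange(num_tokens)
query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_heads * head_size)

# The PyTorch-native path returns new tensors with the rotation applied.
query, key = rope.forward_native(positions, query, key)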

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/base.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
) -> None:
    super().__init__(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    )

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/rotary_embedding/base.py
def extra_repr(self) -> str:
    s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
    s += f", max_position_embeddings={self.max_position_embeddings}"
    s += f", base={self.base}, is_neox_style={self.is_neox_style}"
    return s
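
Since the layer is a torch.nn.Module, this string appears when the module is printed; for the illustrative arguments used above, the output looks roughly like:

print(rope)
# RotaryEmbedding(head_size=64, rotary_dim=64, max_position_embeddings=2048, base=10000, is_neox_style=True)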

forward_cpu

forward_cpu(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None = None,
) -> tuple[Tensor, Tensor | None]
Source code in vllm/model_executor/layers/rotary_embedding/base.py
def forward_cpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    from vllm import _custom_ops as ops

    self._match_cos_sin_cache_dtype(query)

    # ops.rotary_embedding() is an in-place operation
    # that updates the query and key tensors.
    ops.rotary_embedding(
        positions,
        query,
        key,
        self.head_size,
        self.cos_sin_cache,
        self.is_neox_style,
    )
    return query, key

forward_cuda

forward_cuda(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None = None,
) -> tuple[Tensor, Tensor | None]
Source code in vllm/model_executor/layers/rotary_embedding/base.py
def forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    if self.use_flashinfer:
        torch.ops.vllm.flashinfer_rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    from vllm import _custom_ops as ops

    self._match_cos_sin_cache_dtype(query)

    # ops.rotary_embedding() is an in-place operation
    # that updates the query and key tensors.
    ops.rotary_embedding(
        positions,
        query,
        key,
        self.head_size,
        self.cos_sin_cache,
        self.is_neox_style,
    )
    return query, key
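
Unlike forward_native, which assembles new tensors, this path rotates query and key in place via the custom op and returns the tensors it was given. A minimal sketch, assuming a CUDA device and a vLLM build with the compiled custom ops (the layer is moved to the GPU so its cos_sin_cache buffer follows):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(head_size=64, max_position=2048, dtype=torch.float16).to("cuda")

positions = torch.arange(4, device="cuda")
query = torch.randn(4, 8 * 64, device="cuda", dtype=torch.float16)
key = torch.randn(4, 8 * 64, device="cuda", dtype=torch.float16)

q_out, k_out = rope.forward_cuda(positions, query, key)
assert q_out is query and k_out is key  # updated in place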

forward_hip

forward_hip(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None = None,
) -> tuple[Tensor, Tensor | None]
Source code in vllm/model_executor/layers/rotary_embedding/base.py
def forward_hip(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    if self.is_rocm_triton_rotary_embed_enabled:
        self._match_cos_sin_cache_dtype(query)
        rocm_aiter_ops.triton_rotary_embed(
            positions,
            query,
            key,
            self.cos_sin_cache,
            self.head_size,
            self.rotary_dim,
            self.is_neox_style,
        )
        return query, key
    return self.forward_cuda(positions, query, key)

forward_native

forward_native(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None = None,
) -> tuple[Tensor, Tensor | None]

A PyTorch-native implementation of forward().

Source code in vllm/model_executor/layers/rotary_embedding/base.py
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """A PyTorch-native implementation of forward()."""
    return self.forward_static(
        positions,
        query,
        key,
        self.head_size,
        self.rotary_dim,
        self.cos_sin_cache,
        self.is_neox_style,
    )

forward_static staticmethod

forward_static(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None,
    head_size: int,
    rotary_dim: int,
    cos_sin_cache: Tensor,
    is_neox_style: bool,
) -> tuple[Tensor, Tensor | None]

A PyTorch-native implementation of forward().

Source code in vllm/model_executor/layers/rotary_embedding/base.py
@staticmethod
def forward_static(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    head_size: int,
    rotary_dim: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """A PyTorch-native implementation of forward()."""
    positions = positions.flatten()
    num_tokens = positions.shape[0]
    cos_sin = cos_sin_cache.index_select(0, positions)
    cos, sin = cos_sin.chunk(2, dim=-1)

    query_shape = query.shape
    query = query.view(num_tokens, -1, head_size)
    query_rot = query[..., :rotary_dim]
    query_pass = query[..., rotary_dim:]
    query_rot = ApplyRotaryEmb.forward_static(
        query_rot,
        cos,
        sin,
        is_neox_style,
    )
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

    # key may be None in some cases, e.g. cross-layer KV sharing
    if key is not None:
        key_shape = key.shape
        key = key.view(num_tokens, -1, head_size)
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = ApplyRotaryEmb.forward_static(
            key_rot,
            cos,
            sin,
            is_neox_style,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key
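
Because forward_static takes everything it needs as explicit arguments, it can also be called without an instance, e.g. from code that only has the cache tensor at hand. The sketch below (illustrative shapes) checks it against forward_native, which simply forwards the instance's own attributes as shown above:

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope

rope = get_rope(head_size=64, max_position=2048, dtype=torch.float32)

positions = torch.arange(4)
query = torch.randn(4, 8 * 64)
key = torch.randn(4, 8 * 64)

q_ref, k_ref = rope.forward_native(positions, query, key)
q_static, k_static = RotaryEmbedding.forward_static(
    positions,
    query,
    key,
    rope.head_size,
    rope.rotary_dim,
    rope.cos_sin_cache,
    rope.is_neox_style,
)
assert torch.equal(q_ref, q_static) and torch.equal(k_ref, k_static)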

forward_xpu

forward_xpu(
    positions: Tensor,
    query: Tensor,
    key: Tensor | None = None,
) -> tuple[Tensor, Tensor | None]
Source code in vllm/model_executor/layers/rotary_embedding/base.py
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    from vllm._ipex_ops import ipex_ops as ops

    self._match_cos_sin_cache_dtype(query)
    # ops.rotary_embedding() is an in-place operation
    # that updates the query and key tensors.
    if key is None:
        # XPU kernel doesn't support key=None so fall back to native impl
        # TODO(sarckk): add support for optional key in
        # ipex.llm.functional.rotary_embedding_batched
        return self.forward_native(positions, query, key)
    else:
        ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
    return query, key

get_rope

get_rope(
    head_size: int,
    max_position: int,
    is_neox_style: bool = True,
    rope_parameters: dict[str, Any] | None = None,
    dtype: dtype | None = None,
    dual_chunk_attention_config: dict[str, Any] | None = None,
) -> RotaryEmbedding
Source code in vllm/model_executor/layers/rotary_embedding/__init__.py
def get_rope(
    head_size: int,
    max_position: int,
    is_neox_style: bool = True,
    rope_parameters: dict[str, Any] | None = None,
    dtype: torch.dtype | None = None,
    dual_chunk_attention_config: dict[str, Any] | None = None,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_parameters is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_parameters_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_parameters.items()
        }
        rope_parameters_args = tuple(rope_parameters_tuple.items())
    else:
        rope_parameters_args = None

    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None

    rope_parameters = rope_parameters or {}
    base = rope_parameters.get("rope_theta", 10000)
    scaling_type = rope_parameters.get("rope_type", "default")
    partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)

    if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
        raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
    rotary_dim = int(head_size * partial_rotary_factor)

    key = (
        head_size,
        rotary_dim,
        max_position,
        is_neox_style,
        rope_parameters_args,
        dual_chunk_attention_args,
        dtype,
    )
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            dtype,
            **extra_kwargs,
        )
    elif scaling_type == "default":
        if "mrope_section" in rope_parameters:
            rotary_emb = MRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
                mrope_section=rope_parameters["mrope_section"],
                mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
            )
        else:
            rotary_emb = RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
            )
    elif scaling_type == "llama3":
        scaling_factor = rope_parameters["factor"]
        low_freq_factor = rope_parameters["low_freq_factor"]
        high_freq_factor = rope_parameters["high_freq_factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        rotary_emb = Llama3RotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            dtype,
            scaling_factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position,
        )
    elif scaling_type == "mllama4":
        rotary_emb = Llama4VisionRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        )
    elif scaling_type == "linear":
        scaling_factor = rope_parameters["factor"]
        rotary_emb = LinearScalingRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            scaling_factor,
            dtype,
        )
    elif scaling_type == "ntk":
        scaling_factor = rope_parameters["factor"]
        mixed_b = rope_parameters.get("mixed_b")
        rotary_emb = NTKScalingRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            scaling_factor,
            dtype,
            mixed_b,
        )
    elif scaling_type == "dynamic":
        if "alpha" in rope_parameters:
            scaling_alpha = rope_parameters["alpha"]
            rotary_emb = DynamicNTKAlphaRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_alpha,
                dtype,
            )
        elif "factor" in rope_parameters:
            scaling_factor = rope_parameters["factor"]
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
            )
        else:
            raise ValueError(
                "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
            )
    elif scaling_type == "xdrope":
        scaling_alpha = rope_parameters["alpha"]
        rotary_emb = XDRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            scaling_alpha,
            dtype,
            xdrope_section=rope_parameters["xdrope_section"],
        )
    elif scaling_type == "yarn":
        scaling_factor = rope_parameters["factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        extra_kwargs = {
            k: v
            for k, v in rope_parameters.items()
            if k
            in (
                "extrapolation_factor",
                "attn_factor",
                "beta_fast",
                "beta_slow",
                "apply_yarn_scaling",
                "truncate",
            )
        }
        if "mrope_section" in rope_parameters:
            extra_kwargs.pop("apply_yarn_scaling", None)
            rotary_emb = MRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max_position,
                base,
                is_neox_style,
                dtype,
                mrope_section=rope_parameters["mrope_section"],
                mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
                scaling_factor=scaling_factor,
                **extra_kwargs,
            )
        else:
            rotary_emb = YaRNScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
                **extra_kwargs,
            )
    elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
        scaling_factor = rope_parameters["factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        # assert max_position == original_max_position * scaling_factor
        extra_kwargs = {
            k: v
            for k, v in rope_parameters.items()
            if k
            in (
                "extrapolation_factor",
                "attn_factor",
                "beta_fast",
                "beta_slow",
                "mscale",
                "mscale_all_dim",
            )
        }
        rotary_emb = DeepseekScalingRotaryEmbedding(
            head_size,
            rotary_dim,
            original_max_position,
            base,
            is_neox_style,
            scaling_factor,
            dtype,
            **extra_kwargs,
        )
    elif scaling_type == "longrope":
        short_factor = rope_parameters["short_factor"]
        long_factor = rope_parameters["long_factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        extra_kwargs = {
            k: v
            for k, v in rope_parameters.items()
            if k in ("short_mscale", "long_mscale")
        }
        rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            original_max_position,
            base,
            is_neox_style,
            dtype,
            short_factor,
            long_factor,
            **extra_kwargs,
        )
    else:
        raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
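
The scaling variant is selected from rope_parameters["rope_type"], as shown above. A sketch with illustrative values for the llama3 branch (the required keys mirror the lookups in that branch; real values come from the model's Hugging Face config):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=128,
    max_position=131072,
    rope_parameters={
        "rope_type": "llama3",
        "rope_theta": 500000.0,
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
    dtype=torch.bfloat16,
)
print(type(rope).__name__)  # Llama3RotaryEmbedding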