Skip to content

vllm.compilation.inductor_pass

P module-attribute

# ParamSpec used by enable_fake_mode so the decorated function keeps its
# parameter signature.
P = ParamSpec('P')

R module-attribute

# Return-type variable paired with P in enable_fake_mode's signature.
R = TypeVar('R')

__all__ module-attribute

# Names exported by ``from <module> import *``.
__all__ = ['CustomGraphPass']

_pass_context module-attribute

# The currently active PassContext. It is set and restored by pass_context()
# and read by get_pass_context(); None when no pass_context block is active.
_pass_context: "PassContext | None" = None

CallableInductorPass

Bases: InductorPass

This class is a wrapper for a callable that automatically provides an implementation of the UUID.

Source code in vllm/compilation/inductor_pass.py
class CallableInductorPass(InductorPass):
    """
    Adapter that turns an arbitrary graph callable into an InductorPass.

    The UUID reported to the Inductor code cache is the ``uuid`` passed to the
    constructor when one is given. Otherwise it is computed once, at
    construction time, by hashing the callable's source with ``hash_source``.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
    ) -> None:
        self.callable = callable
        # Compare against None explicitly so falsy ids (0, "") are kept.
        if uuid is not None:
            self._uuid = uuid
        else:
            self._uuid = self.hash_source(callable)

    def __call__(self, graph: torch.fx.Graph) -> None:
        # Forward the graph to the wrapped callable; its result is discarded.
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the identifier fixed in ``__init__``."""
        return self._uuid

_uuid instance-attribute

# The explicit uuid if one was passed, otherwise a hash of the callable's source.
_uuid = hash_source(callable) if uuid is None else uuid

callable instance-attribute

# The wrapped graph callable that __call__ invokes.
callable = callable

__call__

__call__(graph: Graph) -> None
Source code in vllm/compilation/inductor_pass.py
def __call__(self, graph: torch.fx.Graph) -> None:
    """Invoke the wrapped callable on ``graph``; its result is discarded."""
    wrapped = self.callable
    wrapped(graph)

__init__

__init__(
    callable: Callable[[Graph], None],
    uuid: Any | None = None,
) -> None
Source code in vllm/compilation/inductor_pass.py
def __init__(
    self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
) -> None:
    self.callable = callable
    self._uuid = self.hash_source(callable) if uuid is None else uuid

uuid

uuid() -> Any
Source code in vllm/compilation/inductor_pass.py
def uuid(self) -> Any:
    """Return the identifier chosen when the pass was constructed."""
    stored = self._uuid
    return stored

CustomGraphPass

Bases: ABC

This class replaces CustomGraphPass from torch==2.6 when using torch<2.6. It conforms to the 2.6 interface but also supports pickling, because before 2.6 the inductor code cache pickles the pass to determine the cache key. (In 2.6 and above, uuid() is used instead.)

Subclasses only need to implement uuid(): pickling is routed through it, so they can "pretend" that uuid() is used even on torch<2.6.

Source code in vllm/compilation/torch25_custom_graph_pass.py
class Torch25CustomGraphPass(ABC):  # noqa (redefinition)
    """
    This class replaces CustomGraphPass from torch==2.6 when using torch<2.6.
    It conforms to the 2.6 interface but also supports pickling, as that's what
    the inductor code cache uses to determine the cache key before 2.6.
    (in 2.6 and above, uuid() is used.)

    Subclasses can just "pretend" that uuid is used.
    """

    @abstractmethod
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """
        Implementation of the custom pass.
        """

    @abstractmethod
    def uuid(self) -> Any | None:
        """
        Return an ID to uniquely identify your custom pass implementation.
        Return None to skip inductor code caching entirely.
        """

    def __getstate__(self) -> Any | None:
        """
        Pickling is used instead of uuid() in torch<2.6. Just return uuid()
         to enable subclasses to only have to implement uuid.
        """
        return self.uuid()

    def __setstate__(self, state: Any) -> NoReturn:
        raise ValueError(
            "Cannot unpickle CustomGraphPass because pickling"
            " is used for cache key uuid. Use torch>=2.6 with"
            " native uuid support for custom passes."
        )

__call__ abstractmethod

__call__(graph: Graph) -> None

Implementation of the custom pass.

Source code in vllm/compilation/torch25_custom_graph_pass.py
@abstractmethod
def __call__(self, graph: torch.fx.graph.Graph) -> None:
    """Apply this custom pass to ``graph``; concrete subclasses implement it."""
    ...

__getstate__

__getstate__() -> Any | None

Pickling is used instead of uuid() in torch<2.6. This method simply returns uuid(), so subclasses only have to implement uuid().

Source code in vllm/compilation/torch25_custom_graph_pass.py
def __getstate__(self) -> Any | None:
    """
    Pickling is used instead of uuid() in torch<2.6. Just return uuid()
     to enable subclasses to only have to implement uuid.
    """
    return self.uuid()

__setstate__

__setstate__(state: Any) -> NoReturn
Source code in vllm/compilation/torch25_custom_graph_pass.py
def __setstate__(self, state: Any) -> NoReturn:
    # Unpickling is always rejected: the pickled state is only the
    # cache-key uuid, not something an object can be rebuilt from.
    msg = (
        "Cannot unpickle CustomGraphPass because pickling"
        " is used for cache key uuid. Use torch>=2.6 with"
        " native uuid support for custom passes."
    )
    raise ValueError(msg)

uuid abstractmethod

uuid() -> Any | None

Return an ID to uniquely identify your custom pass implementation. Return None to skip inductor code caching entirely.

Source code in vllm/compilation/torch25_custom_graph_pass.py
@abstractmethod
def uuid(self) -> Any | None:
    """
    Return an ID to uniquely identify your custom pass implementation.
    Return None to skip inductor code caching entirely.
    """

InductorPass

Bases: Torch25CustomGraphPass

A custom graph pass that uses a hash of its source as the UUID. This is defined as a convenience and should work in most cases.

Source code in vllm/compilation/inductor_pass.py
class InductorPass(CustomGraphPass):  # type: ignore[misc]
    """
    A custom graph pass that uses a hash of its source as the UUID.
    This is defined as a convenience and should work in most cases.
    """

    def uuid(self) -> str:
        """
        Provide a unique identifier for the pass, used in Inductor code cache.
        This should depend on the pass implementation, so that changes to the
        pass result in recompilation.
        By default, the source of the object's class is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
    def hash_source(*srcs: str | Any) -> str:
        """
        Utility method to hash the sources of functions or objects.

        :param srcs: strings or objects to add to the hash. Strings are hashed
            verbatim. Functions, bound methods and classes have their own
            source inspected. Any other object has the source of its class
            inspected.
        :return: the SHA-256 hex digest of the concatenated sources.

        Errors from ``inspect.getsource`` (e.g. when source is unavailable)
        propagate to the caller.
        """
        hasher = hashlib.sha256()
        for src in srcs:
            if isinstance(src, str):
                src_str = src
            elif isinstance(src, (types.FunctionType, types.MethodType, type)):
                # Bound methods must be inspected directly. Their class is the
                # built-in ``method`` type, whose source cannot be retrieved.
                src_str = inspect.getsource(src)
            else:
                # object instance
                src_str = inspect.getsource(src.__class__)
            # NOTE: pieces are concatenated without a separator, so
            # hash_source("ab", "c") == hash_source("a", "bc").
            hasher.update(src_str.encode("utf-8"))
        return hasher.hexdigest()

    @staticmethod
    def hash_dict(dict_: dict[Any, Any]) -> str:
        """
        Utility method to hash a dictionary, can alternatively be used for uuid.
        :return: A sha256 hex digest of the JSON representation of the
            dictionary (keys sorted, so insertion order does not matter).
        """
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return whether this pass should run for ``compile_range``.

        The base implementation accepts every range.
        """
        return True

hash_dict staticmethod

hash_dict(dict_: dict[Any, Any]) -> str

Utility method to hash a dictionary; it can also be used to compute a uuid. :return: the SHA-256 hex digest of the dictionary's JSON representation, serialized with sorted keys.

Source code in vllm/compilation/inductor_pass.py
@staticmethod
def hash_dict(dict_: dict[Any, Any]) -> str:
    """
    Utility method to hash a dictionary, can alternatively be used for uuid.
    :return: A sha256 hash of the json rep of the dictionary.
    """
    encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
    return hashlib.sha256(encoded).hexdigest()

hash_source staticmethod

hash_source(*srcs: str | Any) -> str

Utility method to hash the sources of functions or objects. :param srcs: strings (hashed verbatim) or functions/objects, whose source is inspected (for an object instance, the source of its class). :return: the SHA-256 hex digest of the concatenated sources.

Source code in vllm/compilation/inductor_pass.py
@staticmethod
def hash_source(*srcs: str | Any) -> str:
    """
    Utility method to hash the sources of functions or objects.
    :param srcs: strings or objects to add to the hash.
    Objects and functions have their source inspected.
    :return:
    """
    hasher = hashlib.sha256()
    for src in srcs:
        if isinstance(src, str):
            src_str = src
        elif isinstance(src, (types.FunctionType, type)):
            src_str = inspect.getsource(src)
        else:
            # object instance
            src_str = inspect.getsource(src.__class__)
        hasher.update(src_str.encode("utf-8"))
    return hasher.hexdigest()

is_applicable_for_range

is_applicable_for_range(compile_range: Range) -> bool
Source code in vllm/compilation/inductor_pass.py
def is_applicable_for_range(self, compile_range: Range) -> bool:
    """Report whether this pass should run for ``compile_range``.

    This implementation places no restriction and always answers True.
    """
    _ = compile_range  # unused by this implementation
    return True

uuid

uuid() -> str

Provide a unique identifier for the pass, used in Inductor code cache. This should depend on the pass implementation, so that changes to the pass result in recompilation. By default, the object source is hashed.

Source code in vllm/compilation/inductor_pass.py
def uuid(self) -> str:
    """
    Return the identifier the Inductor code cache uses for this pass.

    It should change whenever the pass implementation changes so that edits
    trigger recompilation. By default it is InductorPass.hash_source applied
    to this object, i.e. a hash of the source of the object's class.
    """
    source_hash = InductorPass.hash_source(self)
    return source_hash

PassContext

Source code in vllm/compilation/inductor_pass.py
class PassContext:
    """State made available to passes while a ``pass_context`` block runs."""

    def __init__(self, compile_range: Range):
        # Remember which compile range the passes are being run for.
        self.compile_range: Range = compile_range

compile_range instance-attribute

# The compile range this context was created for.
compile_range: Range = compile_range

__init__

__init__(compile_range: Range)
Source code in vllm/compilation/inductor_pass.py
def __init__(self, compile_range: Range):
    """Record ``compile_range`` on the newly created context."""
    self.compile_range: Range = compile_range

enable_fake_mode

enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]

Applies a FakeTensorMode context. This is useful when you don't want to create or run things with real tensors.

Source code in vllm/compilation/inductor_pass.py
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Decorate ``fn`` so every call runs under a fresh FakeTensorMode.

    Each call enters ``torch._guards.tracing(None)``,
    ``unset_fake_temporarily()`` and a new ``FakeTensorMode()`` in that order,
    runs ``fn``, and returns its result after those contexts have exited.
    This is useful when you don't want to create or run things with real
    tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        tracing_ctx = torch._guards.tracing(None)
        with tracing_ctx, unset_fake_temporarily(), FakeTensorMode():
            return fn(*args, **kwargs)

    return wrapper

get_pass_context

get_pass_context() -> PassContext

Get the current pass context.

Source code in vllm/compilation/inductor_pass.py
def get_pass_context() -> PassContext:
    """Return the PassContext installed by the innermost active
    ``pass_context`` block; asserts that such a block is active."""
    current = _pass_context
    assert current is not None
    return current

pass_context

pass_context(
    compile_range: Range,
) -> Generator[None, None, None]

A context manager that installs a PassContext for the given compile range as the current pass context, restoring the previous context on exit.

Source code in vllm/compilation/inductor_pass.py
@contextmanager
def pass_context(compile_range: Range) -> Generator[None, None, None]:
    """Install a PassContext for ``compile_range`` for the duration of a
    ``with`` block.

    The previously active context (possibly None) is saved on entry and
    restored on exit, even if the body raises, so nested uses unwind
    correctly.
    """
    global _pass_context
    saved, _pass_context = _pass_context, PassContext(compile_range)
    try:
        yield
    finally:
        _pass_context = saved