torch.compile with Multimodal Encoders

torch.compile can now be applied to multimodal encoders and other miscellaneous nn.Module components in vLLM, including the encoders of vision-language models such as LLaMA 4, Qwen-VL, and similar encoder-based architectures.

This document covers the basics of how the torch.compile integration works for multimodal encoders in vLLM, as well as how to apply the decorator to new models to improve performance.

Note

For general information about torch.compile integration in vLLM, see the torch.compile design document.

Overview

We have recently enabled the @supports_torch_compile decorator to work on multiple nn.Module components within a single model type. This makes it possible to turn compilation on for multimodal encoders, bringing performance improvements to additional components of the stack.

When applied to the vision blocks of Qwen2_5_vl, we observe a ~4.5% end-to-end performance improvement, at the cost of some additional compilation time.

This feature is off by default, but can be enabled by setting compile_mm_encoder: true in the compilation config when models have the @supports_torch_compile decorator.
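
For example, the flag can be passed on the command line when serving a model (the model name here is only a placeholder):

    vllm serve Qwen/Qwen2.5-VL-7B-Instruct --compilation-config='{"compile_mm_encoder": true}'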

How Compilation Works for Multimodal Components

APIs for Enablement

To compile a multimodal component such as an encoder, we follow the same mechanism as for the LLM text backbone, with a few additional pieces of scaffolding:

  1. The @supports_torch_compile decorator should include enable_if=should_torch_compile_mm_vit. This gates compilation behind the compile_mm_encoder configuration option.

  2. The set_model_tag("<component_name>", is_encoder=True) context manager should be used around the nn.Module's instantiation. Since torch.compile relies on cached artifacts to reduce start-up time, we must properly propagate the <component_name> information to the cache in order to avoid collisions with the LLM text backbone or with other instances of the same artifact (as is the case with repeated vision blocks). is_encoder=True is also needed for encoder components (see Compile ranges below).

  3. The set_forward_context context manager should be used around the nn.Module's forward call. This propagates the vllm_config, which the torch.compile integration needs. All three steps are sketched together below.
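
The sketch below ties the three steps together. The import paths, the location of should_torch_compile_mm_vit, and the set_forward_context arguments are assumptions for illustration; consult the vLLM source for the actual signatures.

    # Illustrative sketch only: import paths and exact signatures below are
    # assumptions based on this document, not verified vLLM API.
    import torch
    import torch.nn as nn

    from vllm.compilation.decorators import supports_torch_compile  # assumed path
    from vllm.config import set_model_tag  # assumed path
    from vllm.forward_context import set_forward_context  # assumed path
    from vllm.model_executor.models.vision import should_torch_compile_mm_vit  # assumed path


    # Step 1: gate compilation behind the compile_mm_encoder config flag.
    @supports_torch_compile(enable_if=should_torch_compile_mm_vit)
    class VisionBlock(nn.Module):
        def __init__(self, *, vllm_config, prefix: str = ""):
            super().__init__()
            self.mlp = nn.Linear(1024, 1024)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.mlp(x)


    class VisionEncoder(nn.Module):
        def __init__(self, *, vllm_config, prefix: str = ""):
            super().__init__()
            self.vllm_config = vllm_config
            # Step 2: tag the component so cached compile artifacts do not
            # collide with the text backbone or other copies of this block.
            with set_model_tag("vision_block", is_encoder=True):
                self.blocks = nn.ModuleList(
                    [VisionBlock(vllm_config=vllm_config) for _ in range(4)]
                )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Step 3: propagate the vllm_config that the compile integration
            # reads during the forward pass.
            with set_forward_context(None, self.vllm_config):
                for block in self.blocks:
                    x = block(x)
            return x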

CompilationConfig

With the exception of compile_mm_encoder: true, the multimodal encoder inherits the same compilation config as the text LLM. We may extend this with more encoder-specific configuration in the future.
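
A hedged sketch of the same flag set from the Python API, assuming (as in recent vLLM releases) that the LLM constructor accepts a compilation config dict; the model name is only a placeholder:

    # Assumes LLM accepts a compilation_config dict, as in recent vLLM
    # releases; the model name is only a placeholder.
    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2.5-VL-7B-Instruct",
        compilation_config={"compile_mm_encoder": True},
    )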

Applying torch.compile to a New Multimodal Model/Component

To apply supports_torch_compile to a new nn.Module, we advise following the same steps as in debug_vllm_compile; this includes:

  1. Applying supports_torch_compile to small modules first (such as basic MLP layers), then expanding to more general modules until you reach a good performance tradeoff.

  2. Leveraging tlparse to identify and eliminate the sources of recompiles and graph breaks.

  3. Using dynamic_arg_dims and a proper dynamic_shapes_config to handle dynamism, as sketched below.
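
A minimal sketch of step 3, assuming dynamic_arg_dims maps forward-argument names to their dynamic dimensions; the import path and the PatchMLP module are illustrative, not taken from vLLM source:

    # Illustrative only: the import path is an assumption and PatchMLP is a
    # hypothetical module used to show the dynamic_arg_dims pattern.
    import torch
    import torch.nn as nn

    from vllm.compilation.decorators import supports_torch_compile  # assumed path


    # Mark dim 0 of hidden_states (the number of image patches, which varies
    # per request) as dynamic, so varying inputs do not trigger recompiles.
    @supports_torch_compile(dynamic_arg_dims={"hidden_states": 0})
    class PatchMLP(nn.Module):
        def __init__(self, *, vllm_config, prefix: str = ""):
            super().__init__()
            self.fc = nn.Linear(1024, 1024)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return self.fc(hidden_states)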

Common pitfalls

VllmBackend Feature Support

Compile ranges

The torch.compile integration normally relies on max_batch_size to infer compilation ranges for dynamic shapes. For modules used in the encoder, however, this shape is difficult to infer, because the range of input shapes the encoder may see is unspecified. We therefore rely on is_encoder=True in set_model_tag to signal that the range cannot be inferred, and default to the range (1, MAX_INT).

Note

We may seek to tighten this range for better performance in the future.

Cudagraphs

We have not yet explored compilation for multimodal encoders with CUDAGraph integration; behavior is currently unspecified.

Troubleshooting

Graph Breaks in Vision Encoders

Some vision encoder operations may cause graph breaks. To identify them:

TORCH_LOGS="+dynamo" vllm serve <MODEL>

Common causes of graph breaks in multimodal models:

  • Dynamic image sizes: Use dynamic_shapes_config to handle variable resolutions
  • Untraceable operations: Some operations (such as Tensor.tolist()) may not be supported by Dynamo
  • Conditional processing: Data-dependent branching based on image properties
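
As a hypothetical illustration of the last two causes (not taken from a real vLLM model), materializing tensor data in Python and branching on it both force Dynamo to break the graph:

    # Hypothetical example: .tolist() reads tensor data into Python, and the
    # data-dependent branch that follows cannot be traced; both break the graph.
    import torch

    def forward(x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        sizes = grid_thw.tolist()   # graph break: tensor data escapes to Python
        if sizes[0][0] > 1:         # graph break: data-dependent control flow
            x = x * 2
        return x

Keeping such logic in tensor ops, or hoisting it out of the compiled module, avoids the break.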

Compilation Errors

If compilation fails for a multimodal model:

  1. Disable and test: First verify the model works without compilation:

    VLLM_TORCH_COMPILE_LEVEL=0 vllm serve <model> --compilation-config='{"compile_mm_encoder": false}'
    

  2. Check logs: Enable debug logging to see compilation details:

    VLLM_LOGGING_LEVEL=DEBUG vllm serve <model> --compilation-config='{"compile_mm_encoder": true}'
    

  3. Report issues: If you find a bug, open an issue on GitHub.

See Also