Speed up models with torch.compile
Learn how to use PyTorch's torch.compile with Replicate to speed up your models.
This guide covers how to implement torch.compile
in your PyTorch models to improve inference performance and what Replicate does
to help cache your compiled artifacts between model runs.
What is torch.compile?
torch.compile is a feature that analyzes your PyTorch code as it runs and compiles it into optimized CUDA kernels. It works by:
- Instrumentation: When you call torch.compile(f), it decorates your function for analysis but doesn’t trigger computation yet.
- Tracing: The first call to the decorated function traces the code execution and compiles it into optimized kernels, which are cached on disk.
- Optimization: Subsequent calls run the compiled code, which can be significantly faster than the original Python code.
Here’s a simple example:
import torch

def f(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

# Decorate f for instrumentation
opt_f = torch.compile(f)

# First call traces and compiles (slower)
print(opt_f(torch.randn(10, 10), torch.randn(10, 10)))

# Subsequent calls run optimized code (faster)
for _ in range(100):
    print(opt_f(torch.randn(10, 10), torch.randn(10, 10)))
How does it work with Replicate?
- When a model container starts, it looks for cached torch.compile artifacts.
- If found, PyTorch reuses them instead of re-compiling from scratch.
- When containers shut down gracefully, the cache is updated if needed.
- Cache files are keyed on model version and stored close to GPU nodes.
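For context, the artifacts in question are the files that torch.compile’s Inductor backend writes to an on-disk cache directory. The sketch below shows that cache in plain PyTorch, assuming the standard TORCHINDUCTOR_CACHE_DIR environment variable and a hypothetical /src/.torch-compile-cache path; on Replicate you don’t need to configure this yourself.

import os
import torch

# Inductor (the backend behind torch.compile) writes its on-disk cache to
# TORCHINDUCTOR_CACHE_DIR; set it before the first compile, or export it in
# your environment. "/src/.torch-compile-cache" is a hypothetical path.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/src/.torch-compile-cache"

def f(x):
    return torch.relu(x) * 2

compiled = torch.compile(f)
compiled(torch.randn(8, 8))  # first call compiles and populates the cache

# The cache directory now contains the compiled artifacts
print(os.listdir(os.environ["TORCHINDUCTOR_CACHE_DIR"]))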
Prerequisites
Before you start using torch.compile in your model, make sure you have:
- An existing cog model to work with. If you don’t have one yet, check out our guide to pushing your own model first.
- PyTorch 2.0 or later installed.
- A modern GPU (A100 or newer) for best results.
- A model that benefits from compilation (not all models see improvements).
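If you want to verify the version and GPU requirements from inside your model’s environment, a quick check like the following works; the thresholds simply mirror the list above (an A100 reports CUDA compute capability 8.0).

import torch

print("PyTorch version:", torch.__version__)
assert int(torch.__version__.split(".")[0]) >= 2, "torch.compile needs PyTorch 2.0+"

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU: {torch.cuda.get_device_name(0)} (compute capability {major}.{minor})")
    if major < 8:
        print("Older than A100-class hardware; speedups may be smaller.")
else:
    print("No CUDA GPU detected.")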
When to use torch.compile
Models that commonly benefit include those with:
- Complex computational graphs such as LLMs and diffusion models
- Repeated computational patterns
- High enough traffic that an instance serves multiple predictions before it shuts down, amortizing the increased boot time across those predictions.
torch.compile models are typically slower to set up than non-compiled models, even with caching.
Make sure to measure your overall performance so that you understand the interplay between slower setups and faster inference.
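One way to quantify that interplay is to time the one-off compilation cost and the per-call latency separately, then estimate how many predictions it takes to break even. Here is a rough sketch, assuming a CUDA GPU; the small Sequential network and example_input are placeholders for your own model and a representative input.

import time
import torch

# Placeholders: swap in your real model and a representative input.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024)
).to("cuda").eval()
example_input = torch.randn(64, 1024, device="cuda")

def avg_latency(fn, x, warmup=3, iters=20):
    # Average per-call latency, excluding warm-up calls
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

compiled = torch.compile(model)

start = time.perf_counter()
compiled(example_input)  # first call pays the one-off compilation cost
compile_cost = time.perf_counter() - start

eager = avg_latency(model, example_input)
optimized = avg_latency(compiled, example_input)

print(f"compile cost: {compile_cost:.1f}s")
print(f"eager: {eager * 1000:.2f} ms/call, compiled: {optimized * 1000:.2f} ms/call")

# Rough number of predictions needed to pay back the extra setup time
saved_per_call = eager - optimized
if saved_per_call > 0:
    print(f"break-even after ~{compile_cost / saved_per_call:.0f} predictions")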
Performance expectations
Performance varies significantly between models. In our testing, we have seen:
- Cold boots: 2x faster when using cached artifacts.
- Inference times: 20% or more faster compared to models that don’t use torch.compile.
Calling torch.compile does not trigger computation; it signals to PyTorch to compile and cache the function when it is first called.
To get consistent inference speed, we suggest making an inference call from your setup() function to trigger compilation before your model starts accepting incoming calls.
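For example, in a Cog predictor the warm-up call can live at the end of setup(). This is a minimal sketch: load_model() and the tiny network it returns are placeholders for your own model code, and the input shapes are illustrative.

import torch
from cog import BasePredictor, Input

def load_model():
    # Placeholder: build or load your real network here.
    return torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU()).to("cuda").eval()

class Predictor(BasePredictor):
    def setup(self):
        self.model = torch.compile(load_model())
        # Warm-up call: the first call through the compiled model is what
        # actually triggers compilation (or loads cached artifacts), so do it
        # here rather than during the first user-facing prediction.
        with torch.inference_mode():
            self.model(torch.randn(1, 512, device="cuda"))

    def predict(self, seed: int = Input(description="Random seed", default=0)) -> str:
        torch.manual_seed(seed)
        with torch.inference_mode():
            out = self.model(torch.randn(1, 512, device="cuda"))
        return f"output shape: {tuple(out.shape)}"

If your real predictions use varying input shapes, warm up with shapes representative of real traffic, since torch.compile may recompile the first time it sees a new shape.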
Additional information
To learn more about how to use torch.compile, check out the official PyTorch torch.compile tutorial.