My first ever blog on deep learning. This topic may not be super beginner-friendly, but I wanted to just start somewhere, so here we go.
Model optimization is a crucial part of ML engineering, because time is money and nobody wants to waste it training models, right?
PyTorch 2.0 introduced something new called torch.compile,
and it was game changing. It feels like one of those tools that seem like magic when you use them.
So what's the big deal?
To keep things simple and straightforward, let's use the analogy of a coffee shop.
The standard PyTorch way (we call this eager mode) is like having a set of routines that are well meant but inefficient.
This is like a barista taking an order, walking to the grinder, grinding the beans, then walking back to the counter to check the order slip again. Then they walk back to the machine, tamp the coffee, walk back… you get the picture.
Each individual action (walk, grind, tamp) is a separate step. In PyTorch terms, each step is a separate kernel launch on the GPU, and the “walking back and forth” is the Python overhead. The GPU often ends up just waiting for Python to give it the next small instruction.
Enter torch.compile
torch.compile acts like a master consultant who redesigns your entire coffee shop workflow. It is a Just-In-Time (JIT) compiler: it observes the barista preparing a single coffee, comprehends the entire process (grind, tamp, pull shot, steam milk, pour), and then rewrites the rulebook.
Looking at that graph, it says, "Stop walking back and forth. Set up a single station where you can do all of these tasks in one continuous movement." This is operator fusion. Instead of five separate tasks with a lot of downtime, you get one continuous, highly optimized process. The result? More coffee, quicker. Faster model training, too.
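To see what that looks like in code, here is a minimal sketch (the function, shapes, and names are made up for illustration): in eager mode each of the three pointwise ops below is its own kernel launch with its own intermediate tensor, while the compiled version can fuse them into a single kernel.
import torch

def pointwise_chain(x):
    y = x * 2             # eager mode: kernel launch #1
    y = y + 1             # kernel launch #2
    return torch.relu(y)  # kernel launch #3

compiled_chain = torch.compile(pointwise_chain)  # Inductor can fuse the three ops into one kernel

x = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
out = compiled_chain(x)  # the first call compiles; later calls reuse the fused kernel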
How to use it?
We will be using AlexNet for a simple demonstration.
First, let's time the model with plain, traditional (eager-mode) PyTorch code.
import torch  # Importing libraries
import torchvision.models as models
import time

device = "cuda" if torch.cuda.is_available() else "cpu"  # Use the NVIDIA GPU if one is available
print(f"Using Device : {device}")

weights = models.AlexNet_Weights.DEFAULT
model = models.alexnet(weights=weights).to(device)  # Loading the AlexNet model and pretrained weights

dummy_input = torch.randn(32, 3, 224, 224, device=device)  # Example input: batch of 32 RGB images, 224x224 each

model.eval()

# Warm-up run
with torch.no_grad():
    _ = model(dummy_input)
if device == "cuda":
    torch.cuda.synchronize()  # Wait for all GPU work to finish

start_time = time.time()
with torch.no_grad():
    for _ in range(100):
        _ = model(dummy_input)
if device == "cuda":
    torch.cuda.synchronize()  # Wait for all GPU work to finish
end_time = time.time()

baseline_time = end_time - start_time
print(f"Baseline time for 100 batches: {baseline_time:.4f} seconds")
To compile the model with torch.compile, all you need is:
compiled_model = torch.compile(model)
and that's it. Yes, really.
# Warm-up run: the first call triggers compilation, so keep it out of the timing
with torch.no_grad():
    _ = compiled_model(dummy_input)
if device == "cuda":
    torch.cuda.synchronize()

start_time = time.time()
with torch.no_grad():
    for _ in range(100):
        _ = compiled_model(dummy_input)
if device == "cuda":
    torch.cuda.synchronize()
end_time = time.time()

compiled_time = end_time - start_time
print(f"Compiled time for 100 batches: {compiled_time:.4f} seconds")
print(f"\nSpeedup: {baseline_time / compiled_time:.2f}x")
After running this, you should see a noticeable speedup, often in the range of 30-50% or more for a model like AlexNet. That’s a significant performance boost for adding a single function call.
Why does this work?
Under the hood, torch.compile relies on the following components:
1) TorchDynamo
The job of Dynamo is to capture your code as a computational graph. Its brilliance lies in its safety. Think of it this way: suppose you have a print(x.shape) statement somewhere in the forward of your model for debugging purposes. Earlier, more restrictive compilers would simply crash as soon as they encountered a print statement they did not know how to compile. TorchDynamo is smarter. It recognizes that construct and creates a "graph break". At this point, it stops the compilation, lets the print statement run in the standard Python interpreter, and then resumes compiling the code immediately after. This is why it works so well for the overwhelming majority of real-world models, which are full of such pythonic oddities.
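Here is a tiny sketch of that behavior (the toy model below is made up for illustration). Compiling a forward that contains a print still works; Dynamo just breaks the graph around it. On recent PyTorch 2.x builds, running with the environment variable TORCH_LOGS="graph_breaks" should log where the breaks happen.
import torch
import torch.nn as nn

class ChattyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        x = self.fc(x)
        print(x.shape)        # Dynamo can't trace this -> graph break, not a crash
        return torch.relu(x)  # compilation resumes right after the print

chatty = torch.compile(ChattyModel())
_ = chatty(torch.randn(8, 16))  # runs fine; the print executes in plain Python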
2) TorchInductor
Much of the heavy lifting happens in this phase. TorchInductor takes the clean graphs produced by Dynamo and applies aggressive optimizations. The most important of these is operator fusion. For instance, a standard sequence in a vision model might be a convolution (conv), followed by batch normalization (bn), followed by a ReLU activation. In eager mode, that's three separate calls to the GPU. TorchInductor can automatically fuse those three calls into a single optimized kernel. This reduces launch overhead and can also reduce memory traffic, since the intermediate results between the fused layers do not always have to be written out to global memory.
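As a rough illustration (the tiny block below is made up, not taken from AlexNet), you can compile exactly such a conv → bn → relu sequence. If you're curious what Inductor actually generates, running with TORCH_LOGS="output_code" should dump the generated kernels on recent PyTorch versions.
import torch
import torch.nn as nn

dev = "cuda" if torch.cuda.is_available() else "cpu"

# The classic conv -> batchnorm -> relu pattern: three GPU calls in eager mode
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
).to(dev).eval()

compiled_block = torch.compile(block)

with torch.no_grad():
    _ = compiled_block(torch.randn(8, 3, 224, 224, device=dev))  # Inductor can fuse this into fewer kernels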
3) Backends
Once TorchInductor has its optimized plan, it still needs to produce machine code. There are various backends for various hardware. For NVIDIA GPUs, it generates extremely fast CUDA code via Triton. For CPUs, it uses C++/OpenMP. This means torch.compile can be an all-in-one tool to accelerate your code on a variety of hardware types.
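If you want to see which backends your install actually supports, torch._dynamo.list_backends() should list them (a small sketch; "inductor" is the default and is usually the one you want).
import torch
import torch.nn as nn
import torch._dynamo as dynamo

print(dynamo.list_backends())  # lists the compiler backends registered on this install

# Explicitly picking the default Inductor backend (equivalent to plain torch.compile)
tiny = nn.Linear(8, 2)
compiled_tiny = torch.compile(tiny, backend="inductor")
_ = compiled_tiny(torch.randn(4, 8))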
Parameters and Customization options for torch.compile
The default options are great, but torch.compile lets you turn a few knobs for additional control, specifically through the mode argument.
default: A good compromise between compilation time and runtime speed. Great for most applications.
reduce-overhead: For small models or small batches, where per-call overhead dominates the actual compute. It trades a bit of extra memory (for example via CUDA graphs) to cut down the Python and kernel-launch overhead on every run.
max-autotune: For when you're ready to go to production. It takes much longer to compile, because it benchmarks many different kernel configurations to find the fastest ones for your input shapes and hardware.
optimized_model = torch.compile(model, mode="max-autotune")
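And if your model is tiny and per-call overhead is the bottleneck, reduce-overhead is the knob to reach for (which mode wins depends on your model and hardware, so it's worth timing both):
small_and_snappy = torch.compile(model, mode="reduce-overhead")  # trades a bit of memory for lower per-call overhead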
Reasons to use it
a) You get speed for free
b) You don't have to deal with low-level optimization headaches
c) It's a single line of code