Skip to content

Commit a9ec8f8

Browse files
committed
add quantizationin practice blogpost
1 parent 855f9d8 commit a9ec8f8

File tree

3 files changed

+200
-0
lines changed

3 files changed

+200
-0
lines changed
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
---
2+
layout: blog_detail
3+
title: 'Quantization in Practice'
4+
author: Suraj Subramanian
5+
featured-img: ''
6+
---
7+
8+
There are a few different ways to quantize your model with PyTorch. In this blog post, we'll take a look at how each technique looks like in practice. I will use a non-standard model that is not traceable, to paint an accurate picture of how much effort is really needed when quantizing your model.
9+
10+
<div class="text-center">
11+
<img src="/assets/images/quantization_gif.gif" width="60%">
12+
</div>
13+
14+
## What happens when you "quantize" a model?
15+
16+
Two things, generally - the model gets smaller and runs faster. This is because adding and multiplying 8-bit numbers is faster than 32-bit numbers. Loading a smaller model from memory reduces I/O, making models more energy efficient.
17+
18+
> If someone asks you what time it is, you don't respond "10:14:34:430705", but you might say "a quarter past 10".
19+
20+
Quantizing a model means reducing the numerical precision of its weights and/or activations a.k.a information compression. Quantization of deep networks is especially interesting because overparameterized DNNs have more degrees of freedom and this makes them good candidates for information compression.
21+
22+
At the heart of it all is a **mapping function**, a linear projection from floating-point to integer space: $Q(r) = round(r/S + Z)$
23+
24+
To reconvert to floating point space, the inverse function is given by $\tilde r = (Q(r) - Z) \cdot S$. $\tilde r \neq r$, and their difference constitutes the **quantization error**.
25+
26+
The scaling factor $S$ is simply the ratio of the input range to the output range: $S = \frac{\beta - \alpha}{\beta_q - \alpha_q}$
27+
where [$\alpha, \beta$] is the clipping range of the input, i.e. the boundaries of permissible inputs. [$\alpha_q, \beta_q$] is the range in quantized output space that it is mapped to. For 8-bit quantization, the output range $\beta_q - \alpha_q <= (2^8 - 1) $.
28+
29+
The process of choosing the appropriate input range is known as **calibration**; commonly used methods are MinMax and Entropy.
30+
31+
The zero-point $Z$ acts as a bias to ensure that a 0 in the input space maps perfectly to a 0 in the quantized space. $Z = -(\frac{\alpha}{S} - \alpha_q)$
32+
33+
34+
### Quantization Schemes
35+
$S, Z$ can be calculated and used for quantizing an entire tensor ("per-tensor"), or individually for each channel ("per-channel").
36+
37+
When [$\alpha, \beta$] are centered around 0, it is called **symmetric quantization**. The range is calculated as $-\alpha = \beta = max(|max(r)|,|min(r)|)$. This removes the need of a zero-point offset in the mapping function. Asymmetric or **affine** schemes simply assign the boundaries to the minimum and maximum observed values. Asymmetric schemes have a tighter clipping range (for non-negative ReLU activations, for instance) but require a non-zero offset.
38+
39+
40+
### PyTorch Classes
41+
`Observer` modules ([docs](https://PyTorch.org/docs/stable/torch.quantization.html?highlight=observer#observers), [code](https://github.com/PyTorch/PyTorch/blob/748d9d24940cd17938df963456c90fa1a13f3932/torch/ao/quantization/observer.py#L88)) collect statistics on the input values and calculate the qparams $S, Z$.
42+
43+
The `QConfig` ([code](https://github.com/PyTorch/PyTorch/blob/d6b15bfcbdaff8eb73fa750ee47cef4ccee1cd92/torch/ao/quantization/qconfig.py#L165)) NamedTuple specifies the observers and quantization schemes for the network's weights and activations. The default qconfig is at `torch.quantization.get_default_qconfig(backend)` where `backend='fbgemm'` for x86 CPU and `backend='qnnpack'` for ARM.
44+
45+
46+
## In PyTorch
47+
48+
PyTorch allows you a few different ways to quantize your model, depending on
49+
- if you prefer a manual, or a more automatic process (*Eager Mode* v/s *FX Graph Mode*)
50+
- if $S, Z$ for quantizing activations (layer outputs) are precomputed for all inputs, or calculated afresh with each input (*static* v/s *dynamic*),
51+
- if $S, Z$ are computed during, or after training (*quantization-aware training* v/s *post-training quantization*)
52+
53+
Each approach has its unique tradeoffs; for instance FX Graph Mode can automagically figure out the right quantization configurations, but only for models that are [symbolically traceable](https://PyTorch.org/docs/stable/fx.html#torch.fx.symbolic_trace). Dynamic quantization can offer better precision at the cost of additional overhead in each inference.
54+
55+
### Post-Training Dynamic Quantization
56+
The model's weights are pre-quantized before inference, but the activations are calibrated and quantized on the fly during inference. The simplest of all approaches, it has a one line API call in `torch.quantization.quantize_dynamic`
57+
58+
Because the calibrations are bespoke, clipping ranges can be tighter; dynamic quantization can therefore theoretically give higher accuracies, but calibrating and quantizing each layer's activations can add to the overhead.
59+
60+
For this reason, it is best suited for models where most of the execution time is spent in loading weights from memory (think very large models with billions of parameters).
61+
62+
```python
63+
import torch
64+
from torch import nn
65+
66+
# toy model
67+
m = nn.Sequential(
68+
nn.Conv1d(2, 64, (8,)),
69+
nn.ReLU(),
70+
nn.Linear(16,10),
71+
nn.LSTM(10, 10))
72+
73+
m.eval()
74+
75+
## EAGER MODE
76+
from torch.quantization import quantize_dynamic
77+
model_quantized = quantize_dynamic(
78+
model=m, qconfig_spec={nn.LSTM, nn.Linear}, dtype=torch.qint8, inplace=False
79+
)
80+
81+
## FX MODE
82+
from torch.quantization import quantize_fx
83+
qconfig_dict = {"": torch.quantization.default_dynamic_qconfig} # An empty key denotes the default applied to all modules
84+
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict)
85+
model_quantized = quantize_fx.convert_fx(model_prepared)
86+
```
87+
88+
### Post-Training Static Quantization
89+
Similar to dynamic quantization, weights are pre-quantized. In dynamic, the activations are calibrated and quantized at inference time. In static, activations are precalibrated by passing sample inputs to the model.
90+
91+
This method has faster inference than dynamic quantization, but is less robust to accuracy drops from out-of-distribution inputs, i.e. if the model encounters data different from the sample inputs it has calibrated.
92+
93+
```python
94+
# Static quantization of a model consists of the following steps:
95+
96+
# Fuse modules
97+
# Insert Quant/DeQuant Stubs
98+
# Prepare the fused module (insert observers before and after)
99+
# Calibrate the prepared module (pass it representative data)
100+
# Convert the calibrated module (replace with quantized version)
101+
102+
import torch
103+
from torch import nn
104+
105+
m = nn.Sequential(
106+
nn.Conv1d(2,64,(8,)),
107+
nn.ReLU(),
108+
nn.Conv1d(64, 128, (1,)),
109+
nn.ReLU()
110+
)
111+
112+
## EAGER MODE
113+
"""Fuse"""
114+
torch.quantization.fuse_modules(m, ['0','1'], inplace=True) # fuse first Conv-ReLU pair
115+
torch.quantization.fuse_modules(m, ['2','3'], inplace=True) # fuse second Conv-ReLU pair
116+
117+
"""Insert stubs"""
118+
m = nn.Sequential(torch.quantization.QuantStub(),
119+
*m,
120+
torch.quantization.DeQuantStub())
121+
122+
"""Prepare"""
123+
m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
124+
torch.quantization.prepare(m, inplace=True)
125+
126+
"""Calibrate"""
127+
with torch.inference_mode():
128+
for _ in range(10):
129+
x = torch.rand(1,2,28)
130+
m(x)
131+
132+
"""Convert"""
133+
torch.quantization.convert(m, inplace=True)
134+
135+
"""Check"""
136+
print(m[1].weight().element_size()) # 1 instead of 4 for FP32
137+
138+
139+
## FX GRAPH
140+
from torch.quantization import quantize_fx
141+
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
142+
model_to_quantize.eval()
143+
# Prepare
144+
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
145+
# Calibrate
146+
with torch.inference_mode():
147+
for _ in range(10):
148+
x = torch.rand(1,2,28)
149+
model_prepared(x)
150+
# quantize
151+
model_quantized = quantize_fx.convert_fx(model_prepared)
152+
```
153+
154+
### Quantization-aware Training (QAT)
155+
The previous two methods are to quantize FP32 models *after* they have been trained. Although they work surprisingly well ([citation to some case study](google.com)), they're still subject to prediction errors arising from the drop in numerical precision.
156+
157+
<p align="center">
158+
<img src="/assets/images/ptq_vs_qat.png" alt="Fig. 6: Comparison of PTQ and QAT" width="100%">
159+
<br>
160+
Wu, Hao, et al. "Integer quantization for deep learning inference: Principles and empirical evaluation." arXiv preprint arXiv:2004.09602 (2020)
161+
</p>
162+
163+
164+
Figure 6(a) shows why. In PTQ, the FP32 model's parameters are chosen by optimizing on the FP32 loss, and then projected to INT8 space. Depending on where the new INT8 weights lie on the loss curve, model accuracies can significantly change.
165+
166+
In QAT, the FP32 parameters are chosen by also optimizing on the INT8 loss. This allows the model to identify a wider region in the loss function (Figure 6(b)), and identify FP32 parameters such that quantizing them does not significantly affect accuracy.
167+
168+
It's likely that you can still use QAT by "fine-tuning" it on a sample of the training dataset, but I did not try it on demucs (yet).
169+
170+
171+
## Quantizing "real-world" models
172+
173+
Traceable models can be easily quantized with FX Graph Mode, but it's possible the model you're using is not traceable end-to-end. Maybe it has loops or `if` statements on inputs (dynamic control flow), or relies on third-party libraries. The model I use in this example has [dynamic control flow and uses third-party libraries](https://github.com/facebookresearch/demucs/blob/v2/demucs/model.py). As a result, it cannot be symbolically traced directly. In this code walkthrough, I show how you can bypass this limitation by quantizing the child modules individually for FX Graph Mode, and how to patch Quant/DeQuant stubs in Eager Mode.
174+
175+
Download the [notebook](https://gist.github.com/suraj813/735357e56321237950a0348b50f2f3b4) or run it on [Colab](https://colab.research.google.com/gist/suraj813/735357e56321237950a0348b50f2f3b4/fx-and-eager-mode-quantization-example.ipynb) (note that Colab runtimes may differ significantly from local machines).
176+
177+
178+
## What's next - Define-by-Run Quantization
179+
At PyTorch Dev Day 2021, we got a sneak peek into the next cool feature in PyTorch quantization - [define-by-run](https://s3.amazonaws.com/assets.pytorch.org/ptdd2021/posters/C8.png). DBR attempts to improve usability by resolving the problem of model non-traceability.
180+
For example, this dynamic control flow block would not be traceable:
181+
182+
```python
183+
def forward(self, x):
184+
# ....
185+
if x.shape[0] == 1:
186+
assert x.shape[1] == 1
187+
# .....
188+
```
189+
190+
As you might have seen in the real-world example above, refactoring the model can require effort. An early prototype of DBR aims to eliminate this cost. DBR dynamically traces the program, captures the subgraphs having quantizable ops, and performs the quantization transforms only on these subgraphs. The rest of the program is executed as-is. Although the if-block above operates on an input variable `x`, it does not perform any quantizable operation. DBR would not require this to be traced, and it can be executed as-is.
191+
192+
DBR is an early prototype [code](https://github.com/PyTorch/PyTorch/tree/master/torch/ao/quantization/_dbr) but feel to play around and provide feedback via Github Issues.
193+
194+
195+
196+
## References
197+
[Quantization Docs](https://pytorch.org/docs/stable/quantization.html#prototype-fx-graph-mode-quantization)
198+
[Integer quantization for deep learning inference: Principles and empirical evaluation (arxiv)](https://arxiv.org/abs/2004.09602)
199+
[A Survey of Quantization Methods for Efficient Neural Network Inference (arxiv)](https://arxiv.org/pdf/2103.13630.pdf)
200+
arxiv

assets/images/ptq_vs_qat.png

74 KB
Loading

assets/images/quantization_gif.gif

530 KB
Loading

0 commit comments

Comments
 (0)