Commit 7a5aca8

improve readability, add detail
1 parent 1e0dff6 commit 7a5aca8


_posts/2022-1-19-quantization-in-practice.md

Lines changed: 125 additions & 32 deletions
@@ -15,48 +15,130 @@ There are a few different ways to quantize your model with PyTorch. In this blog

> If someone asks you what time it is, you don't respond "10:14:34:430705", but you might say "a quarter past 10".

Quantization has roots in information compression; in deep networks it refers to reducing the numerical precision of a network's weights and/or activations.

Overparameterized DNNs have more degrees of freedom and this makes them good candidates for information compression [[1](https://arxiv.org/pdf/2103.13630.pdf)]. When you quantize a model, two things generally happen: the model gets smaller and runs more efficiently. Hardware vendors explicitly allow for faster processing of 8-bit data (than 32-bit). A smaller model also has a lower memory footprint and power consumption [[2](https://arxiv.org/pdf/1806.08342.pdf)], crucial for deployment at the edge.

At the heart of it all is a **mapping function**, a linear projection from floating-point to integer space: $Q(r) = round(r/S + Z)$

To convert back to floating-point space, the inverse function is given by $\tilde r = (Q(r) - Z) \cdot S$. In general $\tilde r \neq r$, and their difference constitutes the *quantization error*.

The scaling factor $S$ is simply the ratio of the input range to the output range: $S = \frac{\beta - \alpha}{\beta_q - \alpha_q}$
where [$\alpha, \beta$] is the clipping range of the input, i.e. the boundaries of permissible inputs, and [$\alpha_q, \beta_q$] is the range in quantized output space that it is mapped to. For 8-bit quantization, the output range $\beta_q - \alpha_q \leq 2^8 - 1$.

The zero-point $Z$ acts as a bias to ensure that a 0 in the input space maps perfectly to a 0 in the quantized space. $Z = -(\frac{\alpha}{S} - \alpha_q)$
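
For intuition, here is a minimal sketch of these formulas applied to a toy tensor with MinMax calibration (the tensor values below are made up for illustration):

```python
import torch

r = torch.tensor([-1.5, -0.4, 0.0, 0.75, 2.1])   # toy FP32 values

alpha, beta = r.min(), r.max()        # clipping range from the observed min/max
alpha_q, beta_q = 0, 255              # target range for unsigned 8-bit

S = (beta - alpha) / (beta_q - alpha_q)           # scale
Z = int(round(-(alpha / S - alpha_q).item()))     # zero-point

q = torch.clamp(torch.round(r / S + Z), alpha_q, beta_q)    # Q(r)
r_tilde = (q - Z) * S                                        # dequantize
print(q, r_tilde, (r - r_tilde).abs().max())                 # quantization error
```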

$S, Z$ can be calculated and used for quantizing an entire tensor ("per-tensor"), or individually for each channel ("per-channel").

### Calibration
The process of choosing the input range is known as **calibration**. The simplest technique (also the default in PyTorch) is to record the running minimum and maximum values and assign them to $\alpha$ and $\beta$. In PyTorch, `Observer` modules ([docs](https://PyTorch.org/docs/stable/torch.quantization.html?highlight=observer#observers), [code](https://github.com/PyTorch/PyTorch/blob/748d9d24940cd17938df963456c90fa1a13f3932/torch/ao/quantization/observer.py#L88)) collect statistics on the input values and calculate the qparams $S, Z$.

```python
import torch
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver

C, L = 3, 4
normal = torch.distributions.normal.Normal(0,1)
inputs = [normal.sample((C, L)), normal.sample((C, L))]
print(inputs)

# >>>>>
# [tensor([[-0.0590,  1.1674,  0.7119, -1.1270],
#          [-1.3974,  0.5077, -0.5601,  0.0683],
#          [-0.0929,  0.9473,  0.7159, -0.4574]]),

#  tensor([[-0.0236, -0.7599,  1.0290,  0.8914],
#          [-1.1727, -1.2556, -0.2271,  0.9568],
#          [-0.2500,  1.4579,  1.4707,  0.4043]])]

observers = [MinMaxObserver(), MovingAverageMinMaxObserver(), HistogramObserver()]
for obs in observers:
    for x in inputs: obs(x)
    print(obs.__class__.__name__, obs.calculate_qparams())

# >>>>>
# MinMaxObserver (tensor([0.0112]), tensor([124], dtype=torch.int32))
# MovingAverageMinMaxObserver (tensor([0.0101]), tensor([139], dtype=torch.int32))
# HistogramObserver (tensor([0.0100]), tensor([106], dtype=torch.int32))
```

### Affine and Symmetric Quantization Schemes
Affine or asymmetric quantization schemes assign the input range to the min and max observed values. Affine schemes offer tighter clipping ranges and are useful for quantizing non-negative activations (you don't need the input range to contain negative values if your input tensors are never negative). The range is calculated as
$\alpha = min(r), \beta = max(r)$.

Symmetric quantization schemes center the input range around 0, eliminating the need to calculate a zero-point offset. The range is calculated as
$-\alpha = \beta = max(|max(r)|,|min(r)|)$.

```python
for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
    obs = MovingAverageMinMaxObserver(qscheme=qscheme)
    for x in inputs: obs(x)
    print(f"Qscheme: {qscheme} | {obs.calculate_qparams()}")

# >>>>>
# Qscheme: torch.per_tensor_affine | (tensor([0.0101]), tensor([139], dtype=torch.int32))
# Qscheme: torch.per_tensor_symmetric | (tensor([0.0109]), tensor([128]))
```

### Per-Tensor and Per-Channel Quantization Schemes
Quantization parameters can be calculated for the layer's entire weight tensor as a whole, or separately for each channel.

```python
from torch.quantization.observer import MovingAveragePerChannelMinMaxObserver
obs = MovingAveragePerChannelMinMaxObserver(ch_axis=0)  # calculate qparams for all `C` channels separately
for x in inputs: obs(x)
print(obs.calculate_qparams())

# >>>>>
# (tensor([0.0090, 0.0075, 0.0055]), tensor([125, 187, 82], dtype=torch.int32))
```

Per-channel quantization provides better accuracies in convolutional networks; per-tensor quantization performs poorly due to the high variance in conv weights introduced by batchnorm scaling [[2](https://arxiv.org/pdf/1806.08342.pdf)].

### QConfig

The `QConfig` ([code](https://github.com/PyTorch/PyTorch/blob/d6b15bfcbdaff8eb73fa750ee47cef4ccee1cd92/torch/ao/quantization/qconfig.py#L165), [docs](https://pytorch.org/docs/stable/torch.quantization.html?highlight=qconfig#torch.quantization.QConfig)) NamedTuple stores the Observers and the quantization schemes used to quantize activations and weights.

Be sure to pass the Observer class (not the instance), or a callable that can return Observer instances. Use `with_args()` to override the default arguments.

```python
my_qconfig = torch.quantization.QConfig(
    activation=MovingAverageMinMaxObserver.with_args(qscheme=torch.per_tensor_affine),
    weight=MovingAveragePerChannelMinMaxObserver.with_args(qscheme=torch.qint8)
)
# >>>>>
# QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, qscheme=torch.per_tensor_affine){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, qscheme=torch.qint8){})
```

### Backend
Currently, quantized operators run on x86 machines via the [FBGEMM backend](https://github.com/pytorch/FBGEMM), or use [QNNPACK](https://github.com/pytorch/QNNPACK) primitives on ARM machines. Backend support for server GPUs (via TensorRT and cuDNN) is coming soon. Learn more about extending quantization to custom backends: [RFC-0019](https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md).

```python
backend = 'fbgemm' if x86 else 'qnnpack'  # `x86` is a boolean flag for the host CPU; use 'qnnpack' on ARM
qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
```


## Techniques in PyTorch

PyTorch gives you a few different ways to quantize your model on the CPU, depending on
- if you prefer a flexible but manual, or a restricted automagic process (*Eager Mode* v/s *FX Graph Mode*)
- if $S, Z$ for quantizing activations (layer outputs) are precomputed for all inputs, or calculated afresh with each input (*static* v/s *dynamic*),
- if $S, Z$ are computed during, or after training (*quantization-aware training* v/s *post-training quantization*)

FX Graph Mode automatically fuses eligible modules, inserts Quant/DeQuant stubs, calibrates the model and produces a quantized module - all in two method calls - but only for networks that are [symbolically traceable](https://PyTorch.org/docs/stable/fx.html#torch.fx.symbolic_trace). The examples below contain the calls using Eager Mode and FX Graph Mode for comparison.

In DNNs, eligible candidates for quantization are the FP32 weights (layer parameters) and activations (layer outputs). Quantizing weights reduces the model size, while quantizing activations typically results in faster inference.

As an example, the 50-layer ResNet network has ~26 million weight parameters and computes ~16 million activations in the forward pass.
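
A quick way to sanity-check the parameter count (this snippet assumes `torchvision` is installed; counting activations would need forward hooks and is omitted here):

```python
import torchvision

model = torchvision.models.resnet50()
n_params = sum(p.numel() for p in model.parameters())  # count all weight/bias elements
print(f"resnet50 has {n_params/1e6:.1f}M parameters")   # ~25.6M
```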

### Post-Training Dynamic/Weight-only Quantization
Here the model's weights are pre-quantized; the activations are quantized on-the-fly ("dynamic") during inference. The simplest of all approaches, it has a one-line API call in `torch.quantization.quantize_dynamic` (a minimal call is sketched after the tradeoffs below).

**(+)** Can result in higher accuracies since the clipping range is exactly calibrated for each input [1].

**(+)** Dynamic quantization is preferred for models like LSTMs and Transformers where writing/retrieving the model's weights from memory dominates bandwidth [4].

**(-)** Calibrating and quantizing the activations at each layer during runtime can add to the compute overhead.
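
A minimal sketch of that call (the toy model and the choice of module types below are assumptions for illustration, not the post's own example):

```python
import torch
from torch import nn

model_fp32 = nn.Sequential(nn.Linear(16, 4))  # stand-in for any FP32 model

# quantize the weights of all nn.Linear modules to int8;
# activations are quantized dynamically at inference time
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {nn.Linear}, dtype=torch.qint8
)
print(model_int8)
```
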
```python
import torch
@@ -85,9 +167,12 @@ model_quantized = quantize_fx.convert_fx(model_prepared)
```

### Post-Training Static Quantization
Model weights are pre-quantized; activations are pre-calibrated by observers using validation data and stay in quantized precision between operations. About 100 mini-batches of representative data are sufficient to calibrate the observers [2]. The examples below use random data in calibration for convenience; using random data in your application will result in bad qparams.

**(+)** Static quantization has faster inference than dynamic quantization because it eliminates the float<->int conversion costs between layers.

**(-)** Statically quantized models may need regular re-calibration to stay robust against distribution drift.

```python
# Static quantization of a model consists of the following steps:
@@ -101,6 +186,8 @@ This method has faster inference than dynamic quantization, but is less robust t
import torch
from torch import nn

backend = "fbgemm"  # running on an x86 CPU. Use "qnnpack" if running on ARM.

m = nn.Sequential(
    nn.Conv1d(2,64,(8,)),
    nn.ReLU(),
@@ -109,7 +196,9 @@ m = nn.Sequential(
)

## EAGER MODE
"""Fuse
- Inplace fusion replaces the first module in the sequence with the fused module, and the rest with identity modules
"""
torch.quantization.fuse_modules(m, ['0','1'], inplace=True)  # fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ['2','3'], inplace=True)  # fuse second Conv-ReLU pair

@@ -119,29 +208,31 @@ m = nn.Sequential(torch.quantization.QuantStub(),
                  torch.quantization.DeQuantStub())

"""Prepare"""
m.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(m, inplace=True)

"""Calibrate
- This example uses random data for convenience. Use representative (validation) data instead.
"""
with torch.inference_mode():
    for _ in range(10):
        x = torch.rand(1,2,28)
        m(x)

"""Convert"""
torch.quantization.convert(m, inplace=True)

"""Check"""
print(m[1].weight().element_size())  # 1 byte instead of 4 bytes for FP32


## FX GRAPH
from torch.quantization import quantize_fx
m.eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig(backend)}
# Prepare
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict)
# Calibrate - Use representative (validation) data.
with torch.inference_mode():
    for _ in range(10):
        x = torch.rand(1,2,28)
```
@@ -171,7 +262,7 @@ It's likely that you can still use QAT by "fine-tuning" it on a sample of the tr

**Download the [notebook](https://gist.github.com/suraj813/735357e56321237950a0348b50f2f3b4) or run it on [Colab](https://colab.research.google.com/gist/suraj813/735357e56321237950a0348b50f2f3b4/fx-and-eager-mode-quantization-example.ipynb) (note that Colab runtimes may differ significantly from local machines).**

Traceable models can be easily quantized with FX Graph Mode, but it's possible the model you're using is not traceable end-to-end. Maybe it has loops or `if` statements on inputs (dynamic control flow), or relies on third-party libraries. The model we use in this example has [dynamic control flow and uses third-party libraries](https://github.com/facebookresearch/demucs/blob/v2/demucs/model.py). As a result, it cannot be symbolically traced directly. In this code walkthrough, we show how you can bypass this limitation by quantizing the child modules individually for FX Graph Mode, and how to patch Quant/DeQuant stubs in Eager Mode.

@@ -194,6 +285,8 @@ DBR is an early prototype [code](https://github.com/PyTorch/PyTorch/tree/master/
## References
1. [A Survey of Quantization Methods for Efficient Neural Network Inference (arxiv)](https://arxiv.org/pdf/2103.13630.pdf)
2. [Quantizing Deep Convolutional Networks for Efficient Inference (arxiv)](https://arxiv.org/pdf/1806.08342.pdf)
3. [Integer quantization for deep learning inference: Principles and empirical evaluation (arxiv)](https://arxiv.org/abs/2004.09602)
4. [Introduction to Quantization on PyTorch](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/)
5. [PyTorch Quantization Docs](https://pytorch.org/docs/stable/quantization.html#prototype-fx-graph-mode-quantization)
