Commit 128a65e

Dmytro Dzhulgakov authored and facebook-github-bot committed
Use noop observer to pass dtype for dynamic quantization (pytorch#26709)
Summary:
Pull Request resolved: pytorch#26709

Polishes the implementation from pytorch#25975. Primarily, we use NoopObserver to communicate that weights need to be quantized to float16. The very top-level API (quantize_dynamic) keeps its `dtype` argument, but the implementation now follows the common observer-driven flow.

One can argue that dynamic fp16 quantization doesn't really fit into the 'observer' mechanism. It is indeed not ideal, but sharing a single flow is better than branching on both dtype and qconfig.

Test Plan: Imported from OSS

Differential Revision: D17544103

Pulled By: dzhulgakov

fbshipit-source-id: 6af3f18c35929a1a53ea734079c005f656e4925f
1 parent ae0732c · commit 128a65e

7 files changed: +109 −76 lines

test/test_quantization.py

Lines changed: 2 additions & 14 deletions
@@ -485,20 +485,8 @@ def test_quantized_rnn(self):
 
         ref = copy.deepcopy(cell)
 
-        qconfig_dynamic_dict = {
-            torch.nn.LSTM: default_dynamic_qconfig,
-        }
-        default_dynamic_module_mapping = {
-            torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM,
-        }
-        model_int8 = quantize_dynamic(
-            model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
-            dtype=torch.qint8
-        )
-        model_fp16 = quantize_dynamic(
-            model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
-            dtype=torch.float16
-        )
+        model_int8 = quantize_dynamic(model=model, dtype=torch.qint8)
+        model_fp16 = quantize_dynamic(model=model, dtype=torch.float16)
         cell_int8 = model_int8.lstm
         cell_fp16 = model_fp16.lstm
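The slimmed-down test shows the payoff of the change: callers no longer build a qconfig dict or module mapping by hand. A minimal, self-contained usage sketch (the wrapper module here is illustrative, not taken from the test):

    import torch
    import torch.nn as nn
    from torch.quantization import quantize_dynamic

    # Hypothetical float model wrapping an LSTM, the module type that is
    # dynamically quantized by default.
    class LSTMModel(nn.Module):
        def __init__(self):
            super(LSTMModel, self).__init__()
            self.lstm = nn.LSTM(input_size=4, hidden_size=8)

        def forward(self, x, hx=None):
            return self.lstm(x, hx)

    # The dtype argument alone now selects the default qconfig_dict internally.
    model_int8 = quantize_dynamic(model=LSTMModel(), dtype=torch.qint8)
    model_fp16 = quantize_dynamic(model=LSTMModel(), dtype=torch.float16)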

torch/nn/quantized/dynamic/modules/linear.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def from_float(cls, mod):
         """
         assert type(mod) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear'
         assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
-        if mod.qconfig is not None and mod.qconfig.weight() is not None:
+        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
         else:
             # We have the circular import issues if we import the qconfig in the beginning of this file:
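The one-character fix above is subtle: `qconfig.weight` is an observer factory (a `with_args` partial), and calling it always returns an observer instance, so the old check `mod.qconfig.weight() is not None` was effectively always true. A hedged illustration of the distinction:

    from torch.quantization import default_dynamic_qconfig

    factory = default_dynamic_qconfig.weight  # the stored factory; this is what may be absent
    observer = factory()                      # a factory call always yields a fresh observer
    assert observer is not None               # hence the old instance-based check was vacuous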

torch/nn/quantized/dynamic/modules/rnn.py

Lines changed: 16 additions & 20 deletions
@@ -201,39 +201,35 @@ def __setstate__(self, state):
             self._all_weight_values.append(torch.ops.quantized.linear_prepack(*dynamic_vals[i]))
 
     @classmethod
-    def from_float(cls, mod, dtype=torch.qint8):
+    def from_float(cls, mod):
         assert type(mod) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
         assert hasattr(
             mod, 'qconfig'), 'Input float module must have qconfig defined'
 
+        if mod.qconfig is not None and mod.qconfig.weight is not None:
+            weight_observer = mod.qconfig.weight()
+        else:
+            # We have the circular import issues if we import the qconfig in the beginning of this file:
+            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+            # import until we need it.
+            from torch.quantization.QConfig import default_dynamic_qconfig
+            weight_observer = default_dynamic_qconfig.weight()
+
+        dtype = weight_observer.dtype
         supported_scalar_types = [torch.qint8, torch.float16]
         if dtype not in supported_scalar_types:
-            raise RuntimeError('Unsupported dtype: {}'.format(dtype))
-
-        # When dtype = torch.float16, we don't need weight_observer
-        if dtype == torch.qint8:
-            if mod.qconfig is not None and mod.qconfig.weight() is not None:
-                weight_observer = mod.qconfig.weight()
-            else:
-                # We have the circular import issues if we import the qconfig in the beginning of this file:
-                # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
-                # import until we need it.
-                from torch.quantization.QConfig import default_dynamic_qconfig
-                weight_observer = default_dynamic_qconfig.weight()
-            assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
+            raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
 
         if mod.mode == 'LSTM':
             qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
                             mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
+        else:
+            raise NotImplementedError('Only LSTM is supported for QuantizedRNN for now')
 
         num_directions = 2 if mod.bidirectional else 1
 
         assert mod.bias
 
-        # TODO: support more than just LSTM
-        if qRNNBase.mode != 'LSTM':
-            raise RuntimeError('Only LSTM is supported for QuantizedRNN')
-
         qRNNBase._all_weight_names = []
         qRNNBase._all_weight_values = []
         for layer in range(qRNNBase.num_layers):
@@ -372,5 +368,5 @@ def forward(self, input, hx=None):
             return self.forward_tensor(input, hx)
 
     @classmethod
-    def from_float(cls, mod, dtype=torch.qint8):
-        return super(LSTM, cls).from_float(mod, dtype)
+    def from_float(cls, mod):
+        return super(LSTM, cls).from_float(mod)
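With the dtype parameter gone, from_float derives the target dtype from the qconfig attached to the float module. A sketch of that flow, assuming float16_dynamic_qconfig is exported as in this commit:

    import torch
    import torch.nn as nn
    import torch.nn.quantized.dynamic as nnqd
    from torch.quantization import float16_dynamic_qconfig

    lstm = nn.LSTM(input_size=4, hidden_size=8)
    lstm.qconfig = float16_dynamic_qconfig
    # from_float reads dtype from the qconfig's weight observer (torch.float16 here);
    # no explicit dtype argument is needed anymore.
    qlstm = nnqd.LSTM.from_float(lstm)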

torch/quantization/QConfig.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ def __new__(cls, weight):
         return super(QConfigDynamic, cls).__new__(cls, weight)
 
 default_dynamic_qconfig = QConfigDynamic(weight=default_weight_observer)
+float16_dynamic_qconfig = QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16))
 
 default_qat_qconfig = QConfig(activation=default_fake_quant,
                               weight=default_weight_fake_quant)
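`with_args` returns a partial constructor, so the qconfig stores a factory rather than an observer instance; each module being quantized later builds its own observer from it. A small sketch of the mechanics (import path per this commit's layout):

    import torch
    from torch.quantization.observer import NoopObserver

    factory = NoopObserver.with_args(dtype=torch.float16)  # partial constructor
    observer = factory()                                   # built per-module at quantization time
    assert observer.dtype == torch.float16                 # the dtype that from_float picks up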

torch/quantization/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def default_eval_fn(model, calib_data):
     'Observer', 'WeightObserver', 'observer', 'default_observer',
     'default_weight_observer',
     # QConfig
-    'QConfig', 'default_qconfig', 'default_dynamic_qconfig',
+    'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig',
     # QAT utilities
     'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
     # module transformations

torch/quantization/observer.py

Lines changed: 47 additions & 19 deletions
@@ -39,21 +39,39 @@ def _with_args(cls_or_self, **kwargs):
 ABC = ABCMeta(str("ABC"), (object,), {})  # compatible with Python 2 *and* 3:
 
 
-class ObserverBase(ABC, nn.Module):
-    r"""Observer base Module
-    Any concrete observer implementation should derive from this class.
+class Observer(ABC, nn.Module):
+    r"""
+    Observer base Module. Any observer implementation should derive from this class.
 
     Concrete observers should follow the same API. In forward, they will update
     the statistics of the observed Tensor. And they should provide a
     `calculate_qparams` function that computes the quantization parameters given
     the collected statistics.
     """
+    def __init__(self, dtype):
+        super(Observer, self).__init__()
+        self.dtype = dtype
+
+    @abstractmethod
+    def forward(self, x):
+        pass
+
+    @abstractmethod
+    def calculate_qparams(self, **kwargs):
+        pass
+
+    with_args = classmethod(_with_args)
+
+
+class _ObserverBase(Observer):
+    r"""
+    Common base for all qint/quint8 observers
+    """
 
     def __init__(
         self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False
     ):
-        super(ObserverBase, self).__init__()
-        self.dtype = dtype
+        super(_ObserverBase, self).__init__(dtype=dtype)
         self.qscheme = qscheme
         self.reduce_range = reduce_range
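The refactoring splits the old ObserverBase in two: an abstract Observer that carries only the dtype plus the forward/calculate_qparams contract, and _ObserverBase with the shared qint8/quint8 qparam machinery. The concrete observers below now derive from _ObserverBase, while NoopObserver derives from Observer directly. A sketch of a custom observer against the new abstract base (the class itself is hypothetical):

    import torch
    from torch.quantization.observer import Observer

    class FixedQParamsObserver(Observer):
        """Hypothetical observer that records nothing and returns fixed qparams."""
        def __init__(self, dtype=torch.quint8):
            super(FixedQParamsObserver, self).__init__(dtype=dtype)

        def forward(self, x):
            return x  # nothing to record

        def calculate_qparams(self):
            return torch.tensor([1.0]), torch.tensor([0])  # fixed scale / zero_point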

@@ -71,14 +89,6 @@ def __init__(
             torch.quint8,
         ), "Default Observer only works for qint8 and quint8 data type"
 
-    @abstractmethod
-    def forward(self, x):
-        pass
-
-    @abstractmethod
-    def calculate_qparams(self, **kwargs):
-        pass
-
     def _calculate_per_channel_qparams(self, min_vals, max_vals):
         # type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
         """
@@ -158,10 +168,8 @@ def _calculate_qparams(self, min_val, max_val):
 
         return torch.tensor([scale]), torch.tensor([zero_point])
 
-    with_args = classmethod(_with_args)
-
 
-class MinMaxObserver(ObserverBase):
+class MinMaxObserver(_ObserverBase):
     r"""Default Observer Module
     A default implementation of the observer module, only works for
     `per_tensor_affine` quantization scheme. The module will record the
@@ -216,7 +224,7 @@ def extra_repr(self):
         return "min_val={}, max_val={}".format(self.min_val, self.max_val)
 
 
-class PerChannelMinMaxObserver(ObserverBase):
+class PerChannelMinMaxObserver(_ObserverBase):
     r"""Per Channel Observer Module
     The module will record the running average of max and min value for each
     channel of the observed Tensor and calculate_qparams will calculate
channel of the observed Tensor and calculate_qparams will calculate
@@ -266,7 +274,7 @@ def extra_repr(self):
266274

267275

268276

269-
class HistogramObserver(ObserverBase):
277+
class HistogramObserver(_ObserverBase):
270278
r"""
271279
The module records the running histogram of tensor values along with
272280
min/max values. calculate_qparams will calculate scale and zero_point
@@ -521,7 +529,7 @@ def calculate_qparams(self):
         return self._calculate_qparams(new_min.item(), new_max.item())
 
 
-class RecordingObserver(ObserverBase):
+class RecordingObserver(_ObserverBase):
     r"""
     The module is mainly for debug and records the tensor values during runtime
     """
@@ -544,6 +552,26 @@ def get_tensor_value(self):
         return self.tensor_val
 
 
+class NoopObserver(Observer):
+    r"""
+    Observer that doesn't do anything and just passes its configuration to the
+    quantized module's ``.from_float()``.
+
+    Primarily used for quantization to float16 which doesn't require determining
+    ranges.
+    """
+    def __init__(self, dtype=torch.float16):
+        if dtype != torch.float16:
+            raise ValueError("Only float16 quantization can be used without calibration process")
+        super(NoopObserver, self).__init__(dtype=dtype)
+
+    def forward(self, x):
+        return x
+
+    def calculate_qparams(self):
+        raise Exception("calculate_qparams should not be called for NoopObserver")
+
+
 # Restrict activations to be in the range (0,127)
 default_observer = MinMaxObserver.with_args(reduce_range=True)
 default_debug_observer = RecordingObserver
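Behaviorally, NoopObserver is a pure pass-through that only ferries the dtype. A quick sketch of what its contract implies:

    import torch
    from torch.quantization.observer import NoopObserver

    obs = NoopObserver()            # dtype defaults to torch.float16
    x = torch.randn(2, 3)
    assert obs(x) is x              # forward is the identity; no statistics collected
    try:
        obs.calculate_qparams()     # fp16 quantization needs no qparams
    except Exception as e:
        print(e)                    # "calculate_qparams should not be called for NoopObserver"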

torch/quantization/quantize.py

Lines changed: 41 additions & 21 deletions
@@ -8,7 +8,7 @@
 import torch.nn._intrinsic.qat as nniqat
 import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
-from .QConfig import default_dynamic_qconfig
+from .QConfig import default_dynamic_qconfig, float16_dynamic_qconfig
 import torch.nn.qat as nnqat

@@ -256,20 +256,47 @@ def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING):
     convert(model, mapping)
     return model
 
-DEFAULT_QCONFIG_DICT = {
-    nn.Linear : default_dynamic_qconfig,
-    nn.LSTM : default_dynamic_qconfig,
-}
+def quantize_dynamic(model, qconfig_dict=None, dtype=torch.qint8, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING):
+    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.
+
+    Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.
+
+    For the simplest usage, provide the `dtype` argument, which can be float16 or qint8. Weight-only
+    quantization is performed by default for layers with large weights, i.e. Linear and RNN variants.
 
-def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING, dtype=torch.qint8):
-    r"""Converts a float model to dynamic quantized model.
+    Fine-grained control is possible with `qconfig_dict` and `mapping`, which act similarly to `quantize()`.
+    If `qconfig_dict` is provided, the `dtype` argument is ignored.
 
-    Perform dynamic training and output a quantized model.
+    Args:
+        model: input model
+        qconfig_dict: dictionary that maps from the name or type of a submodule to its quantization
+                      configuration; a qconfig applies to all submodules of a given
+                      module unless a qconfig is specified for the submodules (i.e. the
+                      submodule already has a qconfig attribute). Entries in the dictionary
+                      need to be QConfigDynamic instances.
+        mapping: maps the type of a submodule to the type of the corresponding dynamically
+                 quantized version with which the submodule is replaced
     """
+    if qconfig_dict is None:
+        if dtype == torch.qint8:
+            qconfig_dict = {
+                nn.Linear : default_dynamic_qconfig,
+                nn.LSTM : default_dynamic_qconfig,
+            }
+        elif dtype == torch.float16:
+            qconfig_dict = {
+                # TODO: uncomment when float16 Linear support is added
+                # nn.Linear : default_dynamic_qconfig,
+                nn.LSTM : float16_dynamic_qconfig,
+            }
+        else:
+            raise ValueError(
+                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
+
     model = copy.deepcopy(model)
     model.eval()
     propagate_qconfig(model, qconfig_dict)
-    convert(model, mapping, dtype)
+    convert(model, mapping)
     return model
 
 def prepare_qat(model):
@@ -295,7 +322,7 @@ def quantize_qat(model, run_fn, run_args):
     convert(model)
     return model
 
-def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
+def convert(module, mapping=DEFAULT_MODULE_MAPPING):
     r"""Converts the float module with observers(where we can get quantization
     parameters) to a quantized module.
     Args:
@@ -312,13 +339,13 @@ def convert(module, mapping=DEFAULT_MODULE_MAPPING):
 
     for name, mod in module.named_children():
         if type(mod) not in SWAPPABLE_MODULES:
-            convert(mod, mapping, dtype)
-        reassign[name] = swap_module(mod, mapping, dtype)
+            convert(mod, mapping)
+        reassign[name] = swap_module(mod, mapping)
 
     for key, value in reassign.items():
         module._modules[key] = value
 
-def swap_module(mod, mapping, dtype=torch.qint8):
+def swap_module(mod, mapping):
     r"""Swaps the module if it has a quantized counterpart and it has an
     `observer` attached.
@@ -332,14 +359,7 @@ def swap_module(mod, mapping):
     new_mod = mod
     if hasattr(mod, 'qconfig') and mod.qconfig is not None:
         if type(mod) in mapping:
-            supported_scalar_types = [torch.qint8, torch.float16]
-            if dtype not in supported_scalar_types:
-                raise RuntimeError('Unsupported dtype: {}'.format(dtype))
-            if dtype == torch.qint8:
-                new_mod = mapping[type(mod)].from_float(mod)
-            elif dtype == torch.float16:
-                # We want to support float16 dynamic quantization
-                new_mod = mapping[type(mod)].from_float(mod, dtype)
+            new_mod = mapping[type(mod)].from_float(mod)
     return new_mod
 
 def get_observer_dict(mod, target_dict, prefix=""):
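Putting the new quantize.py pieces together, the fine-grained path that bypasses the dtype defaults looks roughly like this (the model is illustrative):

    import torch
    import torch.nn as nn
    from torch.quantization import quantize_dynamic, float16_dynamic_qconfig

    class Seq2Seq(nn.Module):  # hypothetical two-LSTM model
        def __init__(self):
            super(Seq2Seq, self).__init__()
            self.encoder = nn.LSTM(16, 32)
            self.decoder = nn.LSTM(32, 32)

        def forward(self, x):
            out, _ = self.encoder(x)
            return self.decoder(out)

    # With an explicit qconfig_dict the dtype argument is ignored: both LSTMs
    # get fp16 weight-only quantization via the NoopObserver-backed qconfig.
    qmodel = quantize_dynamic(Seq2Seq(), qconfig_dict={nn.LSTM: float16_dynamic_qconfig})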
