@@ -8,7 +8,7 @@
 import torch.nn._intrinsic.qat as nniqat
 import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
-from .QConfig import default_dynamic_qconfig
+from .QConfig import default_dynamic_qconfig, float16_dynamic_qconfig
 import torch.nn.qat as nnqat
 
 
@@ -256,20 +256,47 @@ def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING):
     convert(model, mapping)
     return model
 
-DEFAULT_QCONFIG_DICT = {
-    nn.Linear: default_dynamic_qconfig,
-    nn.LSTM: default_dynamic_qconfig,
-}
+def quantize_dynamic(model, qconfig_dict=None, dtype=torch.qint8, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING):
+    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.
+
+    Replaces the specified modules with dynamic weight-only quantized versions and outputs the quantized model.
+
+    For the simplest usage, provide the `dtype` argument, which can be torch.float16 or torch.qint8. Weight-only
+    quantization is by default performed for layers with large weights, i.e. Linear and RNN variants.
 
-def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING, dtype=torch.qint8):
-    r"""Converts a float model to dynamic quantized model.
+    Fine-grained control is possible with `qconfig_dict` and `mapping`, which act similarly to `quantize()`.
+    If `qconfig_dict` is provided, the `dtype` argument is ignored.
 
-    Perform dynamic training and output a quantized model.
+    Args:
+        model: input model
+        qconfig_dict: dictionary that maps the name or type of a submodule to a quantization
+            configuration; a qconfig applies to all submodules of a given module
+            unless a qconfig for a submodule is specified (i.e. the submodule
+            already has a qconfig attribute). Entries in the dictionary need to
+            be QConfigDynamic instances.
+        mapping: maps the type of a submodule to the type of the corresponding dynamically
+            quantized version with which the submodule is replaced
     """
+    if qconfig_dict is None:
+        if dtype == torch.qint8:
+            qconfig_dict = {
+                nn.Linear: default_dynamic_qconfig,
+                nn.LSTM: default_dynamic_qconfig,
+            }
+        elif dtype == torch.float16:
+            qconfig_dict = {
+                # TODO: uncomment when float16 Linear support is added
+                # nn.Linear: default_dynamic_qconfig,
+                nn.LSTM: float16_dynamic_qconfig,
+            }
+        else:
+            raise ValueError(
+                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
+
     model = copy.deepcopy(model)
     model.eval()
     propagate_qconfig(model, qconfig_dict)
-    convert(model, mapping, dtype)
+    convert(model, mapping)
     return model
 
 def prepare_qat(model):
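
For reference, here is a minimal usage sketch of the new entry point, assuming it is re-exported as torch.quantization.quantize_dynamic (the toy model is illustrative only):

    import torch
    import torch.nn as nn
    from torch.quantization import quantize_dynamic  # assumed public export

    # Toy float model containing the layer types quantized by default.
    model = nn.Sequential(nn.Linear(20, 30), nn.ReLU(), nn.Linear(30, 10))

    # Default path: qint8 weights-only quantization of Linear/LSTM submodules.
    qmodel = quantize_dynamic(model, dtype=torch.qint8)
    out = qmodel(torch.randn(4, 20))
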
@@ -295,7 +322,7 @@ def quantize_qat(model, run_fn, run_args):
     convert(model)
     return model
 
-def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
+def convert(module, mapping=DEFAULT_MODULE_MAPPING):
     r"""Converts the float module with observers (where we can get quantization
     parameters) to a quantized module.
     Args:
@@ -312,13 +339,13 @@ def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
 
     for name, mod in module.named_children():
         if type(mod) not in SWAPPABLE_MODULES:
-            convert(mod, mapping, dtype)
-        reassign[name] = swap_module(mod, mapping, dtype)
+            convert(mod, mapping)
+        reassign[name] = swap_module(mod, mapping)
 
     for key, value in reassign.items():
         module._modules[key] = value
 
-def swap_module(mod, mapping, dtype=torch.qint8):
+def swap_module(mod, mapping):
     r"""Swaps the module if it has a quantized counterpart and it has an
     `observer` attached.
 
@@ -332,14 +359,7 @@ def swap_module(mod, mapping, dtype=torch.qint8):
     new_mod = mod
     if hasattr(mod, 'qconfig') and mod.qconfig is not None:
         if type(mod) in mapping:
-            supported_scalar_types = [torch.qint8, torch.float16]
-            if dtype not in supported_scalar_types:
-                raise RuntimeError('Unsupported dtype: {}'.format(dtype))
-            if dtype == torch.qint8:
-                new_mod = mapping[type(mod)].from_float(mod)
-            elif dtype == torch.float16:
-                # We want to support float16 dynamic quantization
-                new_mod = mapping[type(mod)].from_float(mod, dtype)
+            new_mod = mapping[type(mod)].from_float(mod)
     return new_mod
 
 def get_observer_dict(mod, target_dict, prefix=""):
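
With the dtype plumbing removed from convert() and swap_module(), the weight dtype now travels entirely inside each module's qconfig, which from_float(mod) can consult via the mod.qconfig attribute that propagate_qconfig attaches. A sketch of the fine-grained path, assuming quantize_dynamic and float16_dynamic_qconfig are re-exported from torch.quantization and using a hypothetical wrapper model:

    import torch
    import torch.nn as nn
    from torch.quantization import quantize_dynamic, float16_dynamic_qconfig  # assumed exports

    class TextEncoder(nn.Module):  # hypothetical example model
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(10, 20)
            self.fc = nn.Linear(20, 5)

        def forward(self, x):
            out, _ = self.lstm(x)
            return self.fc(out[-1])

    # Explicit qconfig_dict: only the LSTM gets float16 weight quantization;
    # the Linear stays float, and the dtype argument is ignored on this path.
    qmodel = quantize_dynamic(TextEncoder(), qconfig_dict={nn.LSTM: float16_dynamic_qconfig})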