int8_quantization_dynamic.py
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
###################################################### # noqa F401
##### Example Model ##### # noqa F401
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
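# Random token IDs serve as example inputs for the quantization flow
# below; real text is not needed to prepare or trace the model, only
# tensors of the right dtype and shape.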
vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
######################### # noqa F401
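# Use IPEX's default recipe for dynamic quantization: weights are
# quantized to int8 ahead of time, while activations are quantized
# on the fly at runtime.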
qconfig_mapping = ipex.quantization.default_dynamic_qconfig_mapping
# Alternatively, define your own qconfig:
# from torch.ao.quantization import PerChannelMinMaxObserver, PlaceholderObserver, QConfig, QConfigMapping
# qconfig = QConfig(
# activation = PlaceholderObserver.with_args(dtype=torch.float, is_dynamic=True),
# weight = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
# qconfig_mapping = QConfigMapping().set_global(qconfig)
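# prepare() attaches the observers described by the qconfig mapping;
# convert() then swaps in the quantized modules. For dynamic
# quantization no calibration pass is required in between.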
prepared_model = prepare(model, qconfig_mapping, example_inputs=data)
converted_model = convert(prepared_model)
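# Trace and freeze the model into TorchScript so the quantized graph
# can be optimized and saved; strict=False lets the traced module
# return BERT's dict-style outputs.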
with torch.no_grad():
    traced_model = torch.jit.trace(
        converted_model, (data,), check_trace=False, strict=False
    )
    traced_model = torch.jit.freeze(traced_model)
traced_model.save("dynamic_quantized_model.pt")
print("Saved model to: dynamic_quantized_model.pt")
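
# A minimal usage sketch (not part of the original example), assuming
# the file saved above: reload the TorchScript module and run inference.
# With strict=False tracing, the loaded BERT model returns a dict of
# output tensors keyed by the usual Hugging Face output names.
loaded_model = torch.jit.load("dynamic_quantized_model.pt")
loaded_model.eval()
with torch.no_grad():
    outputs = loaded_model(data)
print("last_hidden_state shape:", outputs["last_hidden_state"].shape)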