Commit 3600b1a

First working draft of dynamic quantization tutorial

author: Seth Weidman (committed)

1 parent a300b1d commit 3600b1a

File tree

4 files changed: +307 -0 lines changed


Makefile

+8
@@ -81,6 +81,14 @@ download:
 	wget -N https://s3.amazonaws.com/pytorch-tutorial-assets/lenet_mnist_model.pth -P $(DATADIR)
 	cp $(DATADIR)/lenet_mnist_model.pth ./beginner_source/data/lenet_mnist_model.pth

+	# Download model for advanced_source/dynamic_quantization_tutorial.py
+	wget -N https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth -P $(DATADIR)
+	cp $(DATADIR)/word_language_model_quantize.pth ./advanced_source/data/word_language_model_quantize.pth
+
+	# Download data for advanced_source/dynamic_quantization_tutorial.py
+	wget -N https://s3.amazonaws.com/pytorch-tutorial-assets/wikitext.zip -P $(DATADIR)
+	cp $(DATADIR)/wikitext.zip ./advanced_source/data/wikitext.zip
+
 docs:
 	make download
 	make html

_static/img/quant_asym.png

8.18 KB

advanced_source/dynamic_quantization_tutorial.py

+287
@@ -0,0 +1,287 @@
"""
Dynamic Quantization on an LSTM Word Language Model
===================================================

**Author**: `James Reed <https://github.com/jamesr66a>`_

**Edited by**: `Seth Weidman <https://github.com/SethHWeidman/>`_

Introduction
------------

Quantization involves converting the weights and activations of your model from
float to int, which can result in a smaller model size and faster inference with
only a small hit to accuracy.

In this tutorial, we'll apply the easiest form of quantization -
*dynamic quantization* - to an LSTM-based next-word prediction model, closely
following the
`word language model <https://github.com/pytorch/examples/tree/master/word_language_model>`_
from the PyTorch examples.
"""
# imports
import os
from io import open
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
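######################################################################
# Before we build anything, a brief aside on what quantization does. This is
# a minimal sketch of our own, not part of the word language model example:
# a float value `x` is mapped to an integer `q` via an affine transform,
# roughly `q = round(x / scale) + zero_point`. PyTorch exposes this mapping
# directly through `torch.quantize_per_tensor`; the scale and zero point
# below are chosen by hand purely for illustration.

x = torch.randn(4)     # a float32 tensor
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(q)               # the quantized tensor
print(q.int_repr())    # the underlying int8 storage
print(q.dequantize())  # back to float32, with rounding error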
######################################################################
# 1. Define the model
# -------------------
#
# Here we define the LSTM model architecture, following the
# `model <https://github.com/pytorch/examples/blob/master/word_language_model/model.py>`_
# from the word language model example.

class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))
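######################################################################
# As a quick sanity check (a sketch with toy sizes, not part of the original
# example), we can push a random batch through the model to confirm the
# expected shapes: the input is `(seq_len, batch)` token indices, and the
# output is `(seq_len, batch, ntoken)` logits.

toy_model = LSTMModel(ntoken=10, ninp=8, nhid=8, nlayers=2)
toy_input = torch.randint(10, (5, 3), dtype=torch.long)  # seq_len=5, batch=3
toy_hidden = toy_model.init_hidden(3)
toy_output, toy_hidden = toy_model(toy_input, toy_hidden)
print(toy_output.shape)  # torch.Size([5, 3, 10])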
######################################################################
# 2. Load in the text data
# ------------------------
#
# Next, we load the
# `Wikitext-2 dataset <https://www.google.com/search?q=wikitext+2+data>`_ into a `Corpus`,
# again following the
# `preprocessing <https://github.com/pytorch/examples/blob/master/word_language_model/data.py>`_
# from the word language model example.

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')
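######################################################################
# A quick look at what we just built (illustrative only, not part of the
# original example): each split is a single 1-D tensor of token ids, and the
# dictionary maps between words and ids in both directions.

print('vocabulary size:', len(corpus.dictionary))
print('test set size:', corpus.test.size(0))
print('first test tokens:',
      [corpus.dictionary.idx2word[i] for i in corpus.test[:8]])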
######################################################################
# 3. Load the pre-trained model
# -----------------------------
#
# This is a tutorial on dynamic quantization, a quantization technique
# that is applied after a model has been trained. Therefore, we'll simply load some
# pre-trained weights into this model architecture; these weights were obtained
# by training for five epochs using the default settings in the word language model
# example.

ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken=ntokens,
    ninp=512,
    nhid=256,
    nlayers=5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
    )
)

model.eval()
print(model)
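######################################################################
# As a quick check of scale (our addition, not in the original example),
# we can count the parameters we just loaded:

print(sum(p.numel() for p in model.parameters()), 'parameters')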
######################################################################
# Now let's generate some text to ensure that the pre-trained model is working
# properly - as before, we follow the
# `generation script <https://github.com/pytorch/examples/blob/master/word_language_model/generate.py>`_
# from the word language model example.

input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(word + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, num_words))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)
######################################################################
# It's no GPT-2, but it looks like the model has started to learn the structure of
# language!
#
# We're almost ready to demonstrate dynamic quantization. We just need to define a few more
# helper functions:

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode, which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
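######################################################################
# To make `batchify` concrete (an illustration of our own, not part of the
# original example): a stream of 7 tokens with `bsz=2` keeps 6 of them and
# becomes a `(3, 2)` tensor, where column `j` holds a contiguous slice of
# the stream.

print(batchify(torch.arange(7), 2))
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])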
######################################################################
# 4. Test dynamic quantization
# ----------------------------
#
# Finally, we can call `torch.quantization.quantize_dynamic` on the model!

import torch.quantization

quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
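######################################################################
# As a quick check (a sketch of our own; the exact class names vary across
# PyTorch versions), we can confirm that the `nn.LSTM` and `nn.Linear`
# modules have been swapped for dynamically quantized counterparts, while
# the `nn.Embedding` is left in floating point:

print(type(quantized_model.rnn))      # a dynamically quantized LSTM
print(type(quantized_model.decoder))  # a dynamically quantized Linear
print(type(quantized_model.encoder))  # still a float nn.Embedding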
######################################################################
# Note that we specify that we want to quantize the `nn.LSTM` and `nn.Linear`
# modules, with a bit width of `int8`.
#
# How has this benefited us? First, we see a significant reduction in model size:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)
######################################################################
# Second, we see faster inference time, with no difference in evaluation loss:
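######################################################################
# One practical note (our addition, not part of the original example): the
# speedup from dynamic quantization is easiest to measure consistently with
# a single thread, so we pin the thread count before timing:

torch.set_num_threads(1)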
def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
######################################################################
# Conclusion
# ----------
#
# Dynamic quantization can be an easy way to reduce model size while only
# having a limited effect on accuracy.
#
# Thanks for reading! As always, we welcome any feedback, so please create an issue
# `here <https://github.com/pytorch/pytorch/issues>`_ if you have any.

index.rst

+12
@@ -223,6 +223,18 @@ Extending PyTorch

     <div style='clear:both'></div>

+Quantization
+------------
+
+.. customgalleryitem::
+   :tooltip: Perform dynamic quantization on a pre-trained PyTorch model
+   :description: :doc:`/advanced/dynamic_quantization_tutorial`
+   :figure: _static/img/quant_asym.png
+
+.. raw:: html
+
+    <div style='clear:both'></div>
+
 PyTorch in Other Languages
 --------------------------
