- Access to the raw data as an iterator
- Build a data processing pipeline to convert the raw text strings into ``torch.Tensor`` that can be used to train the model
- Shuffle and iterate the data with `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__
+
+
+ Prerequisites
+ ~~~~~~~~~~~~~~~~
+
+ A recent 2.x version of the ``portalocker`` package needs to be installed prior to running the tutorial.
+ For example, in the Colab environment, this can be done by adding the following line at the top of the script:
+
+ .. code-block:: bash
+
+    !pip install -U portalocker>=2.0.0
+
"""


#
# The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the ``AG_NEWS`` dataset iterators yield the raw data as a tuple of label and text.
#
- # To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.
+ # To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.
#

import torch
from torchtext.datasets import AG_NEWS

- train_iter = iter(AG_NEWS(split='train'))
+
+ train_iter = iter(AG_NEWS(split="train"))

######################################################################
# ::
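The sample-output block that followed ``::`` is elided from this view. As a quick illustration (not part of the diff) of what the iterator yields, each item is a ``(label, text)`` tuple:

.. code-block:: python

   # Peek at one raw AG_NEWS sample; labels are integers 1-4
   label, text = next(train_iter)
   print(label, text[:60])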
@@ -60,13 +73,15 @@
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

- tokenizer = get_tokenizer('basic_english')
- train_iter = AG_NEWS(split='train')
+ tokenizer = get_tokenizer("basic_english")
+ train_iter = AG_NEWS(split="train")
+


def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

+
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

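``collate_batch`` in a later hunk calls ``text_pipeline`` and ``label_pipeline``, whose definitions fall in lines elided from this diff. A minimal sketch of what they presumably look like, built from the ``vocab`` and ``tokenizer`` above:

.. code-block:: python

   # Hypothetical reconstruction of the elided pipelines:
   # raw text -> list of token ids, label string -> 0-based int
   text_pipeline = lambda x: vocab(tokenizer(x))
   label_pipeline = lambda x: int(x) - 1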
@@ -96,7 +111,6 @@ def yield_tokens(data_iter):
#


-
######################################################################
# Generate data batch and iterator
# --------------------------------
@@ -111,22 +125,27 @@ def yield_tokens(data_iter):


from torch.utils.data import DataLoader
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
-     for (_label, _text) in batch:
-         label_list.append(label_pipeline(_label))
-         processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
-         text_list.append(processed_text)
-         offsets.append(processed_text.size(0))
+     for _label, _text in batch:
+         label_list.append(label_pipeline(_label))
+         processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
+         text_list.append(processed_text)
+         offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

- train_iter = AG_NEWS(split='train')
- dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
+
+ train_iter = AG_NEWS(split="train")
+ dataloader = DataLoader(
+     train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch
+ )


######################################################################
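The reason for ``offsets``: ``nn.EmbeddingBag`` takes one flat tensor of token ids for the whole batch plus the start index of each sequence, so no padding is needed. A small worked example (not part of the diff) of the tensors ``collate_batch`` builds:

.. code-block:: python

   import torch

   # Toy batch: two tokenized texts of lengths 3 and 2, already mapped to ids
   t1 = torch.tensor([4, 8, 15], dtype=torch.int64)
   t2 = torch.tensor([16, 23], dtype=torch.int64)

   text = torch.cat([t1, t2])  # tensor([ 4,  8, 15, 16, 23])
   # [0, len(t1), len(t2)] -> drop the last, cumulative sum = start indices
   offsets = torch.tensor([0, t1.size(0), t2.size(0)])[:-1].cumsum(dim=0)
   print(offsets)  # tensor([0, 3]) -- where each sequence starts in `text`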
@@ -144,8 +163,8 @@ def collate_batch(batch):

from torch import nn

- class TextClassificationModel(nn.Module):

+ class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
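The remainder of the class is elided here (the next hunk header shows the file continues into ``forward(self, text, offsets)``). A self-contained sketch of the conventional completion, assuming a single linear classifier head on top of the pooled embedding:

.. code-block:: python

   from torch import nn

   class TextClassificationModelSketch(nn.Module):
       # Hypothetical completion; only the first lines of __init__ appear in the diff
       def __init__(self, vocab_size, embed_dim, num_class):
           super().__init__()
           self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
           self.fc = nn.Linear(embed_dim, num_class)  # assumed classifier head

       def forward(self, text, offsets):
           embedded = self.embedding(text, offsets)  # one pooled vector per sequence
           return self.fc(embedded)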
@@ -179,7 +198,7 @@ def forward(self, text, offsets):
# We build a model with the embedding dimension of 64. The vocab size is equal to the length of the vocabulary instance. The number of classes is equal to the number of labels.
#

- train_iter = AG_NEWS(split='train')
+ train_iter = AG_NEWS(split="train")
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 64
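The instantiation itself is elided below this hunk; since ``model`` is used by the optimizer and the training loop later in the diff, it is presumably created along these lines (``device`` comes from the DataLoader section above):

.. code-block:: python

   # Hypothetical reconstruction of the elided instantiation
   model = TextClassificationModel(vocab_size, emsize, num_class).to(device)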
@@ -194,6 +213,7 @@ def forward(self, text, offsets):

import time

+
def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
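The middle of ``train`` (old lines 200-210) falls between this hunk and the next. A sketch of a conventional loop body consistent with the surrounding code, continuing the function shown above; ``criterion`` and ``optimizer`` are defined further down in the diff, and the logging interval and gradient-clipping value are assumptions:

.. code-block:: python

   # Hypothetical reconstruction of the elided middle of train()
   log_interval = 500  # assumed
   start_time = time.time()
   for idx, (label, text, offsets) in enumerate(dataloader):
       optimizer.zero_grad()
       predicted_label = model(text, offsets)
       loss = criterion(predicted_label, label)
       loss.backward()
       torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # assumed clip value
       optimizer.step()
       total_acc += (predicted_label.argmax(1) == label).sum().item()
       # ...the hunk below resumes with total_count and the periodic logging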
@@ -211,12 +231,16 @@ def train(dataloader):
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
-             print('| epoch {:3d} | {:5d}/{:5d} batches '
-                   '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
-                                               total_acc / total_count))
+             print(
+                 "| epoch {:3d} | {:5d}/{:5d} batches "
+                 "| accuracy {:8.3f}".format(
+                     epoch, idx, len(dataloader), total_acc / total_count
+                 )
+             )
            total_acc, total_count = 0, 0
            start_time = time.time()

+
def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0
@@ -227,7 +251,7 @@ def evaluate(dataloader):
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
-     return total_acc / total_count
+     return total_acc / total_count


######################################################################
@@ -253,10 +277,11 @@ def evaluate(dataloader):

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
+
# Hyperparameters
- EPOCHS = 10 # epoch
+ EPOCHS = 10  # epoch
LR = 5  # learning rate
- BATCH_SIZE = 64 # batch size for training
+ BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
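The lines between this hunk and the next (old 263-265) are elided; the epoch loop below reads ``scheduler`` and ``total_accu``, so they are presumably set up roughly like this (the StepLR parameters are assumptions):

.. code-block:: python

   # Hypothetical reconstruction of the elided setup
   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
   total_accu = None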
@@ -266,31 +291,36 @@ def evaluate(dataloader):
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
- split_train_, split_valid_ = \
-     random_split(train_dataset, [num_train, len(train_dataset) - num_train])
-
- train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
-                               shuffle=True, collate_fn=collate_batch)
- valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
-                               shuffle=True, collate_fn=collate_batch)
- test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
-                              shuffle=True, collate_fn=collate_batch)
+ split_train_, split_valid_ = random_split(
+     train_dataset, [num_train, len(train_dataset) - num_train]
+ )
+
+ train_dataloader = DataLoader(
+     split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
+ )
+ valid_dataloader = DataLoader(
+     split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
+ )
+ test_dataloader = DataLoader(
+     test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
+ )

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
-         scheduler.step()
+         scheduler.step()
    else:
-         total_accu = accu_val
-     print('-' * 59)
-     print('| end of epoch {:3d} | time: {:5.2f}s | '
-           'valid accuracy {:8.3f} '.format(epoch,
-                                            time.time() - epoch_start_time,
-                                            accu_val))
-     print('-' * 59)
-
+         total_accu = accu_val
+     print("-" * 59)
+     print(
+         "| end of epoch {:3d} | time: {:5.2f}s | "
+         "valid accuracy {:8.3f} ".format(
+             epoch, time.time() - epoch_start_time, accu_val
+         )
+     )
+     print("-" * 59)


######################################################################
@@ -299,15 +329,12 @@ def evaluate(dataloader):
#


-
######################################################################
# Checking the results of the test dataset.

- print('Checking the results of test dataset.')
+ print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader)
- print('test accuracy {:8.3f}'.format(accu_test))
-
-
+ print("test accuracy {:8.3f}".format(accu_test))


######################################################################
@@ -318,17 +345,16 @@ def evaluate(dataloader):
#


- ag_news_label = {1: "World",
-                  2: "Sports",
-                  3: "Business",
-                  4: "Sci/Tec"}
+ ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}
+

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1

+
ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
enduring the season's worst weather conditions on Sunday at The \
Open on his way to a closing 75 at Royal Portrush, which \
@@ -343,4 +369,4 @@ def predict(text, text_pipeline):

model = model.to("cpu")

- print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])
+ print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])