
Commit 9fa95f0

noqqaqq and Svetlana Karslioglu authored
Mention prerequisites for running the tutorial based on observations (#2461)

* Mention prerequisites for running the tutorial based on observations made in issue #1993 --------- Co-authored-by: noqqaqq <noqqaqq@users.noreply.github.com> Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent 994bd83 commit 9fa95f0

File tree

1 file changed: +74 −48 lines changed


beginner_source/text_sentiment_ngrams_tutorial.py

Lines changed: 74 additions & 48 deletions
@@ -7,6 +7,18 @@
    - Access to the raw data as an iterator
    - Build data processing pipeline to convert the raw text strings into ``torch.Tensor`` that can be used to train the model
    - Shuffle and iterate the data with `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__
+
+
+Prerequisites
+~~~~~~~~~~~~~~~~
+
+A recent 2.x version of the ``portalocker`` package needs to be installed prior to running the tutorial.
+For example, in the Colab environment, this can be done by adding the following line at the top of the script:
+
+.. code-block:: bash
+
+   !pip install -U 'portalocker>=2.0.0'
+
 """
 
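For readers following along outside Colab, a quick pre-flight check can confirm the requirement before the dataset code runs. This is a minimal sketch using the standard-library ``importlib.metadata``; it is illustrative only and not part of the commit:

.. code-block:: python

   # Sanity check (illustrative, not in the tutorial): verify that a
   # 2.x portalocker is importable before torchtext tries to use it.
   from importlib.metadata import PackageNotFoundError, version

   try:
       major = int(version("portalocker").split(".")[0])
   except PackageNotFoundError:
       raise SystemExit("portalocker missing; run: pip install -U 'portalocker>=2.0.0'")
   if major < 2:
       raise SystemExit("portalocker 2.x required, found " + version("portalocker"))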

@@ -16,12 +28,13 @@
 #
 # The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the ``AG_NEWS`` dataset iterators yield the raw data as a tuple of label and text.
 #
-# To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.
+# To access torchtext datasets, please install torchdata following the instructions at https://github.com/pytorch/data.
 #
 
 import torch
 from torchtext.datasets import AG_NEWS
-train_iter = iter(AG_NEWS(split='train'))
+
+train_iter = iter(AG_NEWS(split="train"))
 
 ######################################################################
 # ::
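The ``::`` literal block that follows in the file (outside this hunk) shows sample output. For orientation, the iterator yields ``(label, text)`` tuples with integer labels 1 to 4; a minimal peek, assuming the two imports above have run:

.. code-block:: python

   # Inspect one raw AG_NEWS sample; labels are 1=World, 2=Sports,
   # 3=Business, 4=Sci/Tec, and text is an untokenized string.
   label, text = next(train_iter)
   print(label, text[:80])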
@@ -60,13 +73,15 @@
 from torchtext.data.utils import get_tokenizer
 from torchtext.vocab import build_vocab_from_iterator
 
-tokenizer = get_tokenizer('basic_english')
-train_iter = AG_NEWS(split='train')
+tokenizer = get_tokenizer("basic_english")
+train_iter = AG_NEWS(split="train")
+
 
 def yield_tokens(data_iter):
     for _, text in data_iter:
         yield tokenizer(text)
 
+
 vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
 vocab.set_default_index(vocab["<unk>"])
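The resulting ``vocab`` maps token lists to integer indices, with out-of-vocabulary tokens falling back to ``<unk>`` via the default index. A brief usage sketch (the printed indices depend on the built vocabulary):

.. code-block:: python

   # Tokens -> indices; any token absent from the vocabulary maps to
   # the <unk> index set by set_default_index above.
   print(vocab(tokenizer("here is an example")))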

@@ -96,7 +111,6 @@ def yield_tokens(data_iter):
 #
 
 
-
 ######################################################################
 # Generate data batch and iterator
 # --------------------------------
@@ -111,22 +125,27 @@ def yield_tokens(data_iter):
 
 
 from torch.utils.data import DataLoader
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+
 def collate_batch(batch):
     label_list, text_list, offsets = [], [], [0]
-    for (_label, _text) in batch:
-        label_list.append(label_pipeline(_label))
-        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
-        text_list.append(processed_text)
-        offsets.append(processed_text.size(0))
+    for _label, _text in batch:
+        label_list.append(label_pipeline(_label))
+        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
+        text_list.append(processed_text)
+        offsets.append(processed_text.size(0))
     label_list = torch.tensor(label_list, dtype=torch.int64)
     offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
     text_list = torch.cat(text_list)
     return label_list.to(device), text_list.to(device), offsets.to(device)
 
-train_iter = AG_NEWS(split='train')
-dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
+
+train_iter = AG_NEWS(split="train")
+dataloader = DataLoader(
+    train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch
+)
 
 
 ######################################################################
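A note on the ``offsets`` arithmetic above: the list starts as ``[0, len(sample0), len(sample1), ...]``, so dropping the final length and taking a cumulative sum yields each sample's start position in the flat ``text_list`` tensor. A tiny standalone illustration with made-up lengths:

.. code-block:: python

   import torch

   # [0, 3, 2, 4] = leading zero plus three hypothetical sample lengths;
   # offsets[:-1].cumsum() turns lengths into start indices.
   lengths = torch.tensor([0, 3, 2, 4])
   print(lengths[:-1].cumsum(dim=0))  # tensor([0, 3, 5])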
@@ -144,8 +163,8 @@ def collate_batch(batch):
 
 from torch import nn
 
-class TextClassificationModel(nn.Module):
 
+class TextClassificationModel(nn.Module):
     def __init__(self, vocab_size, embed_dim, num_class):
         super(TextClassificationModel, self).__init__()
         self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
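``nn.EmbeddingBag`` consumes exactly that flat-indices-plus-offsets layout, reducing each sample's embeddings in one fused call (its default ``mode`` is ``"mean"``). A self-contained sketch with illustrative dimensions, not the tutorial's:

.. code-block:: python

   import torch
   from torch import nn

   # Two samples (lengths 3 and 2) flattened into one index tensor;
   # offsets mark where each sample starts. Output is one averaged
   # embedding vector per sample.
   bag = nn.EmbeddingBag(num_embeddings=100, embedding_dim=8)
   flat = torch.tensor([4, 8, 15, 16, 23])
   offsets = torch.tensor([0, 3])
   print(bag(flat, offsets).shape)  # torch.Size([2, 8])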
@@ -179,7 +198,7 @@ def forward(self, text, offsets):
 # We build a model with the embedding dimension of 64. The vocab size is equal to the length of the vocabulary instance. The number of classes is equal to the number of labels.
 #
 
-train_iter = AG_NEWS(split='train')
+train_iter = AG_NEWS(split="train")
 num_class = len(set([label for (label, text) in train_iter]))
 vocab_size = len(vocab)
 emsize = 64
@@ -194,6 +213,7 @@ def forward(self, text, offsets):
 
 import time
 
+
 def train(dataloader):
     model.train()
     total_acc, total_count = 0, 0
@@ -211,12 +231,16 @@ def train(dataloader):
         total_count += label.size(0)
         if idx % log_interval == 0 and idx > 0:
             elapsed = time.time() - start_time
-            print('| epoch {:3d} | {:5d}/{:5d} batches '
-                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
-                                              total_acc/total_count))
+            print(
+                "| epoch {:3d} | {:5d}/{:5d} batches "
+                "| accuracy {:8.3f}".format(
+                    epoch, idx, len(dataloader), total_acc / total_count
+                )
+            )
             total_acc, total_count = 0, 0
             start_time = time.time()
 
+
 def evaluate(dataloader):
     model.eval()
     total_acc, total_count = 0, 0
@@ -227,7 +251,7 @@ def evaluate(dataloader):
             loss = criterion(predicted_label, label)
             total_acc += (predicted_label.argmax(1) == label).sum().item()
             total_count += label.size(0)
-    return total_acc/total_count
+    return total_acc / total_count
 
 
 ######################################################################
@@ -253,10 +277,11 @@ def evaluate(dataloader):
 
 from torch.utils.data.dataset import random_split
 from torchtext.data.functional import to_map_style_dataset
+
 # Hyperparameters
-EPOCHS = 10 # epoch
+EPOCHS = 10  # epoch
 LR = 5  # learning rate
-BATCH_SIZE = 64 # batch size for training
+BATCH_SIZE = 64  # batch size for training
 
 criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.SGD(model.parameters(), lr=LR)
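The ``scheduler`` and ``total_accu`` referenced in the loop below are defined in the surrounding file rather than in this hunk; in the tutorial they sit right after the optimizer:

.. code-block:: python

   # Context from the tutorial file (unchanged by this commit): decay
   # the learning rate 10x whenever validation accuracy stops improving.
   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
   total_accu = None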
@@ -266,31 +291,36 @@ def evaluate(dataloader):
 train_dataset = to_map_style_dataset(train_iter)
 test_dataset = to_map_style_dataset(test_iter)
 num_train = int(len(train_dataset) * 0.95)
-split_train_, split_valid_ = \
-    random_split(train_dataset, [num_train, len(train_dataset) - num_train])
-
-train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
-                              shuffle=True, collate_fn=collate_batch)
-valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
-                              shuffle=True, collate_fn=collate_batch)
-test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
-                             shuffle=True, collate_fn=collate_batch)
+split_train_, split_valid_ = random_split(
+    train_dataset, [num_train, len(train_dataset) - num_train]
+)
+
+train_dataloader = DataLoader(
+    split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
+)
+valid_dataloader = DataLoader(
+    split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
+)
+test_dataloader = DataLoader(
+    test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
+)
 
 for epoch in range(1, EPOCHS + 1):
     epoch_start_time = time.time()
     train(train_dataloader)
     accu_val = evaluate(valid_dataloader)
     if total_accu is not None and total_accu > accu_val:
-        scheduler.step()
+        scheduler.step()
     else:
-        total_accu = accu_val
-    print('-' * 59)
-    print('| end of epoch {:3d} | time: {:5.2f}s | '
-          'valid accuracy {:8.3f} '.format(epoch,
-                                           time.time() - epoch_start_time,
-                                           accu_val))
-    print('-' * 59)
-
+        total_accu = accu_val
+    print("-" * 59)
+    print(
+        "| end of epoch {:3d} | time: {:5.2f}s | "
+        "valid accuracy {:8.3f} ".format(
+            epoch, time.time() - epoch_start_time, accu_val
+        )
+    )
+    print("-" * 59)
 
 
 ######################################################################
@@ -299,15 +329,12 @@ def evaluate(dataloader):
 #
 
 
-
 ######################################################################
 # Checking the results of the test dataset…
 
-print('Checking the results of test dataset.')
+print("Checking the results of test dataset.")
 accu_test = evaluate(test_dataloader)
-print('test accuracy {:8.3f}'.format(accu_test))
-
-
+print("test accuracy {:8.3f}".format(accu_test))
 
 
 ######################################################################
@@ -318,17 +345,16 @@ def evaluate(dataloader):
 #
 
 
-ag_news_label = {1: "World",
-                 2: "Sports",
-                 3: "Business",
-                 4: "Sci/Tec"}
+ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}
+
 
 def predict(text, text_pipeline):
     with torch.no_grad():
         text = torch.tensor(text_pipeline(text))
         output = model(text, torch.tensor([0]))
         return output.argmax(1).item() + 1
 
+
 ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
     enduring the season’s worst weather conditions on Sunday at The \
     Open on his way to a closing 75 at Royal Portrush, which \
@@ -343,4 +369,4 @@ def predict(text, text_pipeline):
 
 model = model.to("cpu")
 
-print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)])
+print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])
