
Commit f1cb62c

Remove improper src_mask from encoder tutorial (#2423)
Fixes #1877. The tutorial uses a transformer encoder, and the mask being applied was one meant for masking a decoder, which is not part of the tutorial, so the mask is removed. Some variable names are changed to better reflect their purpose, and some unused imports are removed.

Signed-off-by: BJ Hargrave <hargrave@us.ibm.com>
1 parent 9e00157 commit f1cb62c
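
For context: the helper removed below builds a causal ("subsequent") mask, an upper-triangular matrix of -inf that keeps each position from attending to later positions. That constraint belongs to decoder-style self-attention; the encoder used in this tutorial may attend to every position, so the mask serves no purpose here. A minimal sketch of what the removed helper produced (the print output is shown as a comment):

import torch

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # -inf above the diagonal blocks attention to future positions;
    # zeros on and below the diagonal leave past and present visible.
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

print(generate_square_subsequent_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])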

File tree

1 file changed: +14 -31 lines changed


beginner_source/transformer_tutorial.py (+14 -31)
@@ -36,12 +36,8 @@
 # of the word (see the next paragraph for more details). The
 # ``nn.TransformerEncoder`` consists of multiple layers of
 # `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__.
-# Along with the input sequence, a square attention mask is required because the
-# self-attention layers in ``nn.TransformerDecoder`` are only allowed to attend
-# the earlier positions in the sequence. For the language modeling task, any
-# tokens on the future positions should be masked. To produce a probability
-# distribution over output words, the output of the ``nn.TransformerEncoder``
-# model is passed through a linear layer followed by a log-softmax function.
+# To produce a probability distribution over output words, the output of
+# the ``nn.TransformerEncoder`` model is passed through a linear layer.
 #
 
 import math
@@ -51,7 +47,6 @@
 
 import torch
 from torch import nn, Tensor
-import torch.nn.functional as F
 from torch.nn import TransformerEncoder, TransformerEncoderLayer
 from torch.utils.data import dataset
 
@@ -64,19 +59,19 @@ def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
         self.pos_encoder = PositionalEncoding(d_model, dropout)
         encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
-        self.encoder = nn.Embedding(ntoken, d_model)
+        self.embedding = nn.Embedding(ntoken, d_model)
         self.d_model = d_model
-        self.decoder = nn.Linear(d_model, ntoken)
+        self.linear = nn.Linear(d_model, ntoken)
 
         self.init_weights()
 
     def init_weights(self) -> None:
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        self.embedding.weight.data.uniform_(-initrange, initrange)
+        self.linear.bias.data.zero_()
+        self.linear.weight.data.uniform_(-initrange, initrange)
 
-    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
+    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
         """
         Arguments:
             src: Tensor, shape ``[seq_len, batch_size]``
@@ -85,18 +80,13 @@ def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
         Returns:
             output Tensor of shape ``[seq_len, batch_size, ntoken]``
         """
-        src = self.encoder(src) * math.sqrt(self.d_model)
+        src = self.embedding(src) * math.sqrt(self.d_model)
         src = self.pos_encoder(src)
         output = self.transformer_encoder(src, src_mask)
-        output = self.decoder(output)
+        output = self.linear(output)
         return output
 
 
-def generate_square_subsequent_mask(sz: int) -> Tensor:
-    """Generates an upper-triangular matrix of ``-inf``, with zeros on ``diag``."""
-    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
-
-
 ######################################################################
 # ``PositionalEncoding`` module injects some information about the
 # relative or absolute position of the tokens in the sequence. The
@@ -286,7 +276,6 @@ def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
 # to prevent gradients from exploding.
 #
 
-import copy
 import time
 
 criterion = nn.CrossEntropyLoss()
@@ -299,16 +288,13 @@ def train(model: nn.Module) -> None:
     total_loss = 0.
     log_interval = 200
     start_time = time.time()
-    src_mask = generate_square_subsequent_mask(bptt).to(device)
 
     num_batches = len(train_data) // bptt
     for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
         data, targets = get_batch(train_data, i)
-        seq_len = data.size(0)
-        if seq_len != bptt:  # only on last batch
-            src_mask = src_mask[:seq_len, :seq_len]
-        output = model(data, src_mask)
-        loss = criterion(output.view(-1, ntokens), targets)
+        output = model(data)
+        output_flat = output.view(-1, ntokens)
+        loss = criterion(output_flat, targets)
 
         optimizer.zero_grad()
         loss.backward()
@@ -330,14 +316,11 @@ def train(model: nn.Module) -> None:
 def evaluate(model: nn.Module, eval_data: Tensor) -> float:
     model.eval()  # turn on evaluation mode
     total_loss = 0.
-    src_mask = generate_square_subsequent_mask(bptt).to(device)
     with torch.no_grad():
         for i in range(0, eval_data.size(0) - 1, bptt):
             data, targets = get_batch(eval_data, i)
             seq_len = data.size(0)
-            if seq_len != bptt:
-                src_mask = src_mask[:seq_len, :seq_len]
-            output = model(data, src_mask)
+            output = model(data)
             output_flat = output.view(-1, ntokens)
             total_loss += seq_len * criterion(output_flat, targets).item()
     return total_loss / (len(eval_data) - 1)
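
With src_mask now defaulting to None, the model can be called as model(data); nn.TransformerEncoder treats a missing mask as unrestricted attention, and passing one explicitly remains possible. A minimal sketch of this behavior, using small hypothetical sizes rather than the tutorial's actual hyperparameters:

import torch
from torch import nn

# Hypothetical sizes for illustration only.
d_model, nhead, d_hid, nlayers = 64, 4, 128, 2
seq_len, batch_size = 10, 8

layer = nn.TransformerEncoderLayer(d_model, nhead, d_hid)
encoder = nn.TransformerEncoder(layer, nlayers)

src = torch.rand(seq_len, batch_size, d_model)  # [seq_len, batch_size, d_model]

out_unmasked = encoder(src)  # no mask: every position attends to every other
causal = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
out_masked = encoder(src, causal)  # a mask can still be supplied when wanted

print(out_unmasked.shape, out_masked.shape)  # both torch.Size([10, 8, 64])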
