From 2509680d081939c3f84688dc13fd03612e90bada Mon Sep 17 00:00:00 2001
From: BJ Hargrave <hargrave@us.ibm.com>
Date: Mon, 5 Jun 2023 10:34:25 -0400
Subject: [PATCH] Remove improper src_mask from encoder tutorial

The tutorial uses a transformer encoder, and the mask it built was the
causal mask meant for a decoder, which is not part of the tutorial.
The mask is removed. Some variable names are changed to better reflect
their purpose, and some unused imports are removed.

Fixes https://github.com/pytorch/tutorials/issues/1877

Signed-off-by: BJ Hargrave <hargrave@us.ibm.com>
---
 beginner_source/transformer_tutorial.py | 45 ++++++++----------------
 1 file changed, 14 insertions(+), 31 deletions(-)

diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py
index cce52eefdb3..a3fc3ab16eb 100644
--- a/beginner_source/transformer_tutorial.py
+++ b/beginner_source/transformer_tutorial.py
@@ -36,12 +36,8 @@
 # of the word (see the next paragraph for more details). The
 # ``nn.TransformerEncoder`` consists of multiple layers of
 # `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__.
-# Along with the input sequence, a square attention mask is required because the
-# self-attention layers in ``nn.TransformerDecoder`` are only allowed to attend
-# the earlier positions in the sequence. For the language modeling task, any
-# tokens on the future positions should be masked. To produce a probability
-# distribution over output words, the output of the ``nn.TransformerEncoder``
-# model is passed through a linear layer followed by a log-softmax function.
+# To produce a probability distribution over output words, the output of
+# the ``nn.TransformerEncoder`` model is passed through a linear layer.
 #
 
 import math
@@ -51,7 +47,6 @@
 
 import torch
 from torch import nn, Tensor
-import torch.nn.functional as F
 from torch.nn import TransformerEncoder, TransformerEncoderLayer
 from torch.utils.data import dataset
 
@@ -64,19 +59,19 @@ def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
         self.pos_encoder = PositionalEncoding(d_model, dropout)
         encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
-        self.encoder = nn.Embedding(ntoken, d_model)
+        self.embedding = nn.Embedding(ntoken, d_model)
         self.d_model = d_model
-        self.decoder = nn.Linear(d_model, ntoken)
+        self.linear = nn.Linear(d_model, ntoken)
 
         self.init_weights()
 
     def init_weights(self) -> None:
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        self.embedding.weight.data.uniform_(-initrange, initrange)
+        self.linear.bias.data.zero_()
+        self.linear.weight.data.uniform_(-initrange, initrange)
 
-    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
+    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
         """
         Arguments:
             src: Tensor, shape ``[seq_len, batch_size]``
@@ -85,18 +80,13 @@ def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
         Returns:
             output Tensor of shape ``[seq_len, batch_size, ntoken]``
         """
-        src = self.encoder(src) * math.sqrt(self.d_model)
+        src = self.embedding(src) * math.sqrt(self.d_model)
         src = self.pos_encoder(src)
         output = self.transformer_encoder(src, src_mask)
-        output = self.decoder(output)
+        output = self.linear(output)
         return output
 
 
-def generate_square_subsequent_mask(sz: int) -> Tensor:
-    """Generates an upper-triangular matrix of ``-inf``, with zeros on ``diag``."""
-    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
-
-
 ######################################################################
 # ``PositionalEncoding`` module injects some information about the
 # relative or absolute position of the tokens in the sequence. The
@@ -286,7 +276,6 @@ def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
 # to prevent gradients from exploding.
 #
 
-import copy
 import time
 
 criterion = nn.CrossEntropyLoss()
@@ -299,16 +288,13 @@ def train(model: nn.Module) -> None:
     total_loss = 0.
     log_interval = 200
     start_time = time.time()
-    src_mask = generate_square_subsequent_mask(bptt).to(device)
 
     num_batches = len(train_data) // bptt
     for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
         data, targets = get_batch(train_data, i)
-        seq_len = data.size(0)
-        if seq_len != bptt:  # only on last batch
-            src_mask = src_mask[:seq_len, :seq_len]
-        output = model(data, src_mask)
-        loss = criterion(output.view(-1, ntokens), targets)
+        output = model(data)
+        output_flat = output.view(-1, ntokens)
+        loss = criterion(output_flat, targets)
 
         optimizer.zero_grad()
         loss.backward()
@@ -330,14 +316,11 @@ def train(model: nn.Module) -> None:
 def evaluate(model: nn.Module, eval_data: Tensor) -> float:
     model.eval()  # turn on evaluation mode
     total_loss = 0.
-    src_mask = generate_square_subsequent_mask(bptt).to(device)
     with torch.no_grad():
         for i in range(0, eval_data.size(0) - 1, bptt):
             data, targets = get_batch(eval_data, i)
             seq_len = data.size(0)
-            if seq_len != bptt:
-                src_mask = src_mask[:seq_len, :seq_len]
-            output = model(data, src_mask)
+            output = model(data)
             output_flat = output.view(-1, ntokens)
             total_loss += seq_len * criterion(output_flat, targets).item()
     return total_loss / (len(eval_data) - 1)
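
Note (not part of the patch): the change leans on the fact that the attention
mask of ``nn.TransformerEncoder`` is optional (``mask=None`` by default), which
is why ``forward`` can declare ``src_mask: Tensor = None`` and the training and
evaluation loops can simply call ``model(data)``. A minimal standalone sketch of
that behaviour follows; the tensor sizes are made up for illustration and are
not taken from the tutorial.

# Illustrative only: the sizes below are invented, not from the tutorial.
import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer

d_model, nhead, d_hid, nlayers = 16, 2, 32, 2
encoder_layer = TransformerEncoderLayer(d_model, nhead, d_hid, dropout=0.1)
encoder = TransformerEncoder(encoder_layer, nlayers)

seq_len, batch_size = 5, 3
src = torch.rand(seq_len, batch_size, d_model)  # ``[seq_len, batch_size, d_model]``

# Without a mask, every position may attend to every other position.
out_unmasked = encoder(src)

# A causal mask can still be passed explicitly if wanted; this is the kind of
# upper-triangular ``-inf`` mask generate_square_subsequent_mask built before
# this patch removed it.
causal_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
out_masked = encoder(src, mask=causal_mask)

print(out_unmasked.shape, out_masked.shape)  # both torch.Size([5, 3, 16])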
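
Also not part of the patch: a short sketch of the reshaping done by the
simplified training loop. The model output has shape
``[seq_len, batch_size, ntoken]`` while ``nn.CrossEntropyLoss`` expects
``[N, C]`` logits and ``[N]`` targets, so the logits are flattened before the
loss is computed. The tensors below are illustrative stand-ins for
``model(data)`` and the flattened targets returned by ``get_batch``.

# Illustrative stand-ins only: ntokens and the shapes are made up.
import torch
from torch import nn

ntokens, seq_len, batch_size = 100, 5, 3
criterion = nn.CrossEntropyLoss()

output = torch.randn(seq_len, batch_size, ntokens)          # stand-in for model(data)
targets = torch.randint(ntokens, (seq_len * batch_size,))   # stand-in for get_batch targets

output_flat = output.view(-1, ntokens)                      # ``[seq_len * batch_size, ntoken]``
loss = criterion(output_flat, targets)
print(loss.item())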