diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py
index cce52eefdb3..a3fc3ab16eb 100644
--- a/beginner_source/transformer_tutorial.py
+++ b/beginner_source/transformer_tutorial.py
@@ -36,12 +36,8 @@
 # of the word (see the next paragraph for more details). The
 # ``nn.TransformerEncoder`` consists of multiple layers of
 # `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__.
-# Along with the input sequence, a square attention mask is required because the
-# self-attention layers in ``nn.TransformerDecoder`` are only allowed to attend
-# the earlier positions in the sequence. For the language modeling task, any
-# tokens on the future positions should be masked. To produce a probability
-# distribution over output words, the output of the ``nn.TransformerEncoder``
-# model is passed through a linear layer followed by a log-softmax function.
+# To produce a probability distribution over output words, the output of
+# the ``nn.TransformerEncoder`` model is passed through a linear layer.
 #
 
 import math
@@ -51,7 +47,6 @@
 
 import torch
 from torch import nn, Tensor
-import torch.nn.functional as F
 from torch.nn import TransformerEncoder, TransformerEncoderLayer
 from torch.utils.data import dataset
 
@@ -64,19 +59,19 @@ def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
         self.pos_encoder = PositionalEncoding(d_model, dropout)
         encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
-        self.encoder = nn.Embedding(ntoken, d_model)
+        self.embedding = nn.Embedding(ntoken, d_model)
         self.d_model = d_model
-        self.decoder = nn.Linear(d_model, ntoken)
+        self.linear = nn.Linear(d_model, ntoken)
 
         self.init_weights()
 
     def init_weights(self) -> None:
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        self.embedding.weight.data.uniform_(-initrange, initrange)
+        self.linear.bias.data.zero_()
+        self.linear.weight.data.uniform_(-initrange, initrange)
 
-    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
+    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
         """
         Arguments:
             src: Tensor, shape ``[seq_len, batch_size]``
@@ -85,18 +80,13 @@ def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
         Returns:
             output Tensor of shape ``[seq_len, batch_size, ntoken]``
         """
-        src = self.encoder(src) * math.sqrt(self.d_model)
+        src = self.embedding(src) * math.sqrt(self.d_model)
         src = self.pos_encoder(src)
         output = self.transformer_encoder(src, src_mask)
-        output = self.decoder(output)
+        output = self.linear(output)
         return output
 
 
-def generate_square_subsequent_mask(sz: int) -> Tensor:
-    """Generates an upper-triangular matrix of ``-inf``, with zeros on ``diag``."""
-    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
-
-
 ######################################################################
 # ``PositionalEncoding`` module injects some information about the
 # relative or absolute position of the tokens in the sequence. The
@@ -286,7 +276,6 @@ def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
 # to prevent gradients from exploding.
 #
 
-import copy
 import time
 
 criterion = nn.CrossEntropyLoss()
@@ -299,16 +288,13 @@ def train(model: nn.Module) -> None:
     total_loss = 0.
     log_interval = 200
     start_time = time.time()
-    src_mask = generate_square_subsequent_mask(bptt).to(device)
 
     num_batches = len(train_data) // bptt
     for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
         data, targets = get_batch(train_data, i)
-        seq_len = data.size(0)
-        if seq_len != bptt:  # only on last batch
-            src_mask = src_mask[:seq_len, :seq_len]
-        output = model(data, src_mask)
-        loss = criterion(output.view(-1, ntokens), targets)
+        output = model(data)
+        output_flat = output.view(-1, ntokens)
+        loss = criterion(output_flat, targets)
 
         optimizer.zero_grad()
         loss.backward()
@@ -330,14 +316,11 @@ def train(model: nn.Module) -> None:
 def evaluate(model: nn.Module, eval_data: Tensor) -> float:
     model.eval()  # turn on evaluation mode
     total_loss = 0.
-    src_mask = generate_square_subsequent_mask(bptt).to(device)
     with torch.no_grad():
         for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)
-            if seq_len != bptt:
-                src_mask = src_mask[:seq_len, :seq_len]
-            output = model(data, src_mask)
+            output = model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += seq_len * criterion(output_flat, targets).item()
     return total_loss / (len(eval_data) - 1)
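
Note on the now-optional ``src_mask``: with this change the encoder runs without a causal mask by default, so every position can attend to every other position. Callers who still want the previous left-to-right behaviour can build the mask themselves. A minimal sketch, assuming the ``TransformerModel``, ``data``, ``bptt``, and ``device`` names from the tutorial and PyTorch's built-in ``nn.Transformer.generate_square_subsequent_mask`` (the stock counterpart of the removed helper; a static method in recent releases):

import torch
from torch import nn

seq_len = 35  # the tutorial's ``bptt``
# Float mask with -inf above the diagonal and zeros elsewhere, so position i
# may only attend to positions <= i.
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

# Hypothetical usage with the renamed model:
# output = model(data, causal_mask.to(device))  # masked, as before this change
# output = model(data)                          # unmasked, the new default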
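
The simplified loss computation keeps the same shape contract as before: the model returns ``[seq_len, batch_size, ntoken]`` and ``get_batch`` returns flattened targets of length ``seq_len * batch_size``, so flattening the output lines the two up row for row, and ``nn.CrossEntropyLoss`` applies the log-softmax internally (which is why the introductory comment no longer mentions a separate log-softmax step). A self-contained shape check with dummy sizes (the numbers are illustrative, not the tutorial's):

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()
seq_len, batch_size, ntokens = 35, 20, 100          # dummy sizes; ntokens = vocabulary size
output = torch.randn(seq_len, batch_size, ntokens)  # what the model returns
targets = torch.randint(0, ntokens, (seq_len * batch_size,))  # what get_batch returns

output_flat = output.view(-1, ntokens)  # [seq_len * batch_size, ntokens]
loss = criterion(output_flat, targets)  # scalar; log-softmax happens inside CrossEntropyLoss
print(output_flat.shape, loss.item())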