diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py
index cce52eefdb3..a3fc3ab16eb 100644
--- a/beginner_source/transformer_tutorial.py
+++ b/beginner_source/transformer_tutorial.py
@@ -36,12 +36,8 @@
 # of the word (see the next paragraph for more details). The
 # ``nn.TransformerEncoder`` consists of multiple layers of
 # `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__.
-# Along with the input sequence, a square attention mask is required because the
-# self-attention layers in ``nn.TransformerDecoder`` are only allowed to attend
-# the earlier positions in the sequence. For the language modeling task, any
-# tokens on the future positions should be masked. To produce a probability
-# distribution over output words, the output of the ``nn.TransformerEncoder``
-# model is passed through a linear layer followed by a log-softmax function.
+# To produce a probability distribution over output words, the output of 
+# the ``nn.TransformerEncoder`` model is passed through a linear layer.
 #
 
 import math
@@ -51,7 +47,6 @@
 
 import torch
 from torch import nn, Tensor
-import torch.nn.functional as F
 from torch.nn import TransformerEncoder, TransformerEncoderLayer
 from torch.utils.data import dataset
 
@@ -64,19 +59,19 @@ def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
         self.pos_encoder = PositionalEncoding(d_model, dropout)
         encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
-        self.encoder = nn.Embedding(ntoken, d_model)
+        self.embedding = nn.Embedding(ntoken, d_model)
         self.d_model = d_model
-        self.decoder = nn.Linear(d_model, ntoken)
+        self.linear = nn.Linear(d_model, ntoken)
 
         self.init_weights()
 
     def init_weights(self) -> None:
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        self.embedding.weight.data.uniform_(-initrange, initrange)
+        self.linear.bias.data.zero_()
+        self.linear.weight.data.uniform_(-initrange, initrange)
 
-    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
+    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
         """
         Arguments:
             src: Tensor, shape ``[seq_len, batch_size]``
@@ -85,18 +80,13 @@ def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
         Returns:
             output Tensor of shape ``[seq_len, batch_size, ntoken]``
         """
-        src = self.encoder(src) * math.sqrt(self.d_model)
+        src = self.embedding(src) * math.sqrt(self.d_model)
         src = self.pos_encoder(src)
         output = self.transformer_encoder(src, src_mask)
-        output = self.decoder(output)
+        output = self.linear(output)
         return output
 
 
-def generate_square_subsequent_mask(sz: int) -> Tensor:
-    """Generates an upper-triangular matrix of ``-inf``, with zeros on ``diag``."""
-    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
-
-
 ######################################################################
 # ``PositionalEncoding`` module injects some information about the
 # relative or absolute position of the tokens in the sequence. The
@@ -286,7 +276,6 @@ def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
 # to prevent gradients from exploding.
 #
 
-import copy
 import time
 
 criterion = nn.CrossEntropyLoss()
@@ -299,16 +288,13 @@ def train(model: nn.Module) -> None:
     total_loss = 0.
     log_interval = 200
     start_time = time.time()
-    src_mask = generate_square_subsequent_mask(bptt).to(device)
 
     num_batches = len(train_data) // bptt
     for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
         data, targets = get_batch(train_data, i)
-        seq_len = data.size(0)
-        if seq_len != bptt:  # only on last batch
-            src_mask = src_mask[:seq_len, :seq_len]
-        output = model(data, src_mask)
-        loss = criterion(output.view(-1, ntokens), targets)
+        output = model(data)
+        output_flat = output.view(-1, ntokens)
+        loss = criterion(output_flat, targets)
 
         optimizer.zero_grad()
         loss.backward()
@@ -330,14 +316,11 @@ def train(model: nn.Module) -> None:
 def evaluate(model: nn.Module, eval_data: Tensor) -> float:
     model.eval()  # turn on evaluation mode
     total_loss = 0.
-    src_mask = generate_square_subsequent_mask(bptt).to(device)
     with torch.no_grad():
         for i in range(0, eval_data.size(0) - 1, bptt):
             data, targets = get_batch(eval_data, i)
             seq_len = data.size(0)
-            if seq_len != bptt:
-                src_mask = src_mask[:seq_len, :seq_len]
-            output = model(data, src_mask)
+            output = model(data)
             output_flat = output.view(-1, ntokens)
             total_loss += seq_len * criterion(output_flat, targets).item()
     return total_loss / (len(eval_data) - 1)