
Commit cfd2f36

Fix padding issues in seq2seq translation tutorial
- Add dedicated PAD_token (index 2) for proper padding
- Use pack_padded_sequence in encoder to handle variable-length sequences
- Ensure encoder hidden state represents actual content, not padding
- Add ignore_index=PAD_token to loss function to exclude padding from gradients
- Update all embedding layers with padding_idx parameter
- Add comprehensive documentation explaining padding handling best practices

Fixes issues where:
1. GRU final hidden state could be from PAD tokens
2. Loss was computed on PAD tokens, affecting training
1 parent f99e9e8 commit cfd2f36
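The change combines four pieces: a dedicated PAD token, padding-aware embeddings, packed sequences in the encoder, and a masked loss. The following standalone sketch (toy vocabulary size, shapes, and batch contents are invented here for illustration; it is not code from the tutorial) shows how the pieces fit together:

# Toy end-to-end sketch of the padding recipe described above.
import torch
import torch.nn as nn

SOS_token, EOS_token, PAD_token = 0, 1, 2
vocab_size, hidden_size, max_len = 10, 8, 5

# Two "sentences" of different lengths, padded with PAD_token to max_len.
batch = torch.tensor([
    [4, 5, 6, EOS_token, PAD_token],
    [7, 8, EOS_token, PAD_token, PAD_token],
])
lengths = (batch != PAD_token).sum(dim=1).cpu()   # tensor([4, 3])

embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=PAD_token)
gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

# Pack so the GRU stops at each sequence's real length.
packed = nn.utils.rnn.pack_padded_sequence(
    embedding(batch), lengths, batch_first=True, enforce_sorted=False)
output, hidden = gru(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

# Loss over (pretend) decoder log-probabilities, ignoring PAD targets.
log_probs = torch.log_softmax(torch.randn(2, max_len, vocab_size), dim=-1)
criterion = nn.NLLLoss(ignore_index=PAD_token)
loss = criterion(log_probs.view(-1, vocab_size), batch.view(-1))
print(loss.item())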

File tree

1 file changed: +42, -9 lines


intermediate_source/seq2seq_translation_tutorial.py

Lines changed: 42 additions & 9 deletions
@@ -150,14 +150,15 @@
 
 SOS_token = 0
 EOS_token = 1
+PAD_token = 2
 
 class Lang:
     def __init__(self, name):
         self.name = name
         self.word2index = {}
         self.word2count = {}
-        self.index2word = {0: "SOS", 1: "EOS"}
-        self.n_words = 2  # Count SOS and EOS
+        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD"}
+        self.n_words = 3  # Count SOS, EOS, and PAD
 
     def addSentence(self, sentence):
         for word in sentence.split(' '):
@@ -335,13 +336,23 @@ def __init__(self, input_size, hidden_size, dropout_p=0.1):
         super(EncoderRNN, self).__init__()
         self.hidden_size = hidden_size
 
-        self.embedding = nn.Embedding(input_size, hidden_size)
+        self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=PAD_token)
         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
         self.dropout = nn.Dropout(dropout_p)
 
     def forward(self, input):
+        # Compute actual lengths (excluding padding)
+        lengths = (input != PAD_token).sum(dim=1).cpu()
+
         embedded = self.dropout(self.embedding(input))
-        output, hidden = self.gru(embedded)
+
+        # Pack padded sequences
+        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
+        output, hidden = self.gru(packed)
+
+        # Unpack sequences
+        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
+
         return output, hidden
 
 ######################################################################
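A quick standalone check (not part of the tutorial file; sizes and seed are arbitrary) of why the packing above matters: with pack_padded_sequence the GRU's final hidden state matches the one computed from the unpadded sequence, while feeding the padded tensor directly does not.

# Sketch: packing keeps the final hidden state free of padding influence.
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(4, 4, batch_first=True)

real = torch.randn(1, 3, 4)                               # 3 real timesteps
padded = torch.cat([real, torch.zeros(1, 2, 4)], dim=1)   # + 2 padded steps
                                                          # (zeros mimic padding_idx embeddings)

# Without packing: the hidden state keeps being updated on the padded steps.
_, h_plain = gru(padded)

# With packing: the GRU stops after the 3 real steps.
packed = nn.utils.rnn.pack_padded_sequence(
    padded, torch.tensor([3]), batch_first=True, enforce_sorted=False)
_, h_packed = gru(packed)
_, h_unpadded = gru(real)

print(torch.allclose(h_packed, h_unpadded))   # True
print(torch.allclose(h_plain, h_unpadded))    # False: padded steps changed the hidden state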
@@ -375,7 +386,7 @@ def forward(self, input):
 class DecoderRNN(nn.Module):
     def __init__(self, hidden_size, output_size):
         super(DecoderRNN, self).__init__()
-        self.embedding = nn.Embedding(output_size, hidden_size)
+        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_token)
         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
         self.out = nn.Linear(hidden_size, output_size)
 
@@ -480,7 +491,7 @@ def forward(self, query, keys):
 class AttnDecoderRNN(nn.Module):
     def __init__(self, hidden_size, output_size, dropout_p=0.1):
         super(AttnDecoderRNN, self).__init__()
-        self.embedding = nn.Embedding(output_size, hidden_size)
+        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_token)
         self.attention = BahdanauAttention(hidden_size)
         self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
         self.out = nn.Linear(hidden_size, output_size)
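For the padding_idx changes in this and the previous hunk, here is a small illustrative sketch (separate from the tutorial code) of what the parameter does: the PAD row of the embedding matrix is initialized to zeros and never receives gradient updates.

# Sketch: padding_idx gives a zero, non-trainable embedding row for PAD.
import torch
import torch.nn as nn

PAD_token = 2
emb = nn.Embedding(10, 4, padding_idx=PAD_token)

print(emb.weight[PAD_token])          # all zeros at initialization

out = emb(torch.tensor([[1, PAD_token, 3]]))
out.sum().backward()
print(emb.weight.grad[PAD_token])     # all zeros: the PAD row is never updated
print(emb.weight.grad[1])             # nonzero: real tokens do get gradients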
@@ -563,8 +574,8 @@ def get_dataloader(batch_size):
     input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
 
     n = len(pairs)
-    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
-    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
+    input_ids = np.full((n, MAX_LENGTH), PAD_token, dtype=np.int32)
+    target_ids = np.full((n, MAX_LENGTH), PAD_token, dtype=np.int32)
 
     for idx, (inp, tgt) in enumerate(pairs):
         inp_ids = indexesFromSentence(input_lang, inp)
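A brief sketch of the np.full change above (the id lists are made up for illustration): every slot past the end of a sentence now holds PAD_token rather than 0, which would collide with the SOS token.

# Sketch: pad variable-length id sequences into a PAD-filled array.
import numpy as np

PAD_token, EOS_token, MAX_LENGTH = 2, 1, 6
sentences_as_ids = [[5, 6, 7, EOS_token], [8, 9, EOS_token]]

input_ids = np.full((len(sentences_as_ids), MAX_LENGTH), PAD_token, dtype=np.int32)
for idx, ids in enumerate(sentences_as_ids):
    input_ids[idx, :len(ids)] = ids

print(input_ids)
# [[5 6 7 1 2 2]
#  [8 9 1 2 2 2]]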
@@ -583,6 +594,28 @@ def get_dataloader(batch_size):
 
 
 ######################################################################
+# .. note::
+#    When working with batched sequences of variable lengths, proper padding
+#    handling is crucial:
+#
+#    1. **Padding Token**: We use a dedicated ``PAD_token`` (index 2) to pad
+#       shorter sequences to the batch's maximum length. This is better than
+#       using 0 (SOS token) as padding.
+#
+#    2. **Encoder Padding**: The encoder uses ``pack_padded_sequence`` and
+#       ``pad_packed_sequence`` to handle variable-length sequences efficiently.
+#       This ensures the GRU's final hidden state represents the actual sentence
+#       content, not padding tokens.
+#
+#    3. **Loss Masking**: The loss function uses ``ignore_index=PAD_token`` to
+#       exclude padding tokens from the loss computation. This prevents the model
+#       from learning to predict padding and ensures gradients only flow from
+#       actual target tokens.
+#
+#    4. **Embedding Padding**: All embedding layers use ``padding_idx=PAD_token``
+#       to ensure padding tokens have zero embeddings that don't get updated
+#       during training.
+#
 # Training the Model
 # ------------------
 #
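As a standalone illustration of point 3 in the note above (toy tensors, not tutorial data): with ignore_index, NLLLoss is equivalent to computing the loss only over the non-PAD target positions.

# Sketch: ignore_index excludes PAD targets from the loss (and its gradients).
import torch
import torch.nn as nn

PAD_token = 2
log_probs = torch.log_softmax(torch.randn(4, 5), dim=-1)   # 4 positions, 5 classes
targets = torch.tensor([3, 1, PAD_token, PAD_token])       # last two are padding

masked_criterion = nn.NLLLoss(ignore_index=PAD_token)
loss = masked_criterion(log_probs, targets)

# Equivalent manual masking: average only over the non-PAD positions.
mask = targets != PAD_token
manual = nn.NLLLoss()(log_probs[mask], targets[mask])
print(torch.allclose(loss, manual))                        # True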
@@ -678,7 +711,7 @@ def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
 
     encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
     decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
-    criterion = nn.NLLLoss()
+    criterion = nn.NLLLoss(ignore_index=PAD_token)
 
     for epoch in range(1, n_epochs + 1):
         loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
