 
 SOS_token = 0
 EOS_token = 1
+PAD_token = 2
 
 class Lang:
     def __init__(self, name):
         self.name = name
         self.word2index = {}
         self.word2count = {}
-        self.index2word = {0: "SOS", 1: "EOS"}
-        self.n_words = 2  # Count SOS and EOS
+        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD"}
+        self.n_words = 3  # Count SOS, EOS, and PAD
 
     def addSentence(self, sentence):
         for word in sentence.split(' '):
@@ -335,13 +336,23 @@ def __init__(self, input_size, hidden_size, dropout_p=0.1):
         super(EncoderRNN, self).__init__()
         self.hidden_size = hidden_size
 
-        self.embedding = nn.Embedding(input_size, hidden_size)
+        self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=PAD_token)
         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
         self.dropout = nn.Dropout(dropout_p)
 
     def forward(self, input):
+        # Compute actual lengths (excluding padding)
+        lengths = (input != PAD_token).sum(dim=1).cpu()
+
         embedded = self.dropout(self.embedding(input))
-        output, hidden = self.gru(embedded)
+
+        # Pack the padded batch so the GRU skips padding timesteps
+        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
+        output, hidden = self.gru(packed)
+
+        # Unpack back to a padded tensor of shape (batch, seq_len, hidden_size)
+        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
+
         return output, hidden
 
 ######################################################################
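
As a quick standalone check of what the packing above buys us (not part of the tutorial diff; the toy vocabulary size, hidden size, and batch below are made up for illustration), the final hidden state returned for a padded sequence should match the one obtained by running that same sequence without any padding:

# Minimal sketch of the assumed packing behaviour; all sizes are illustrative.
import torch
import torch.nn as nn

PAD_token = 2
hidden_size = 4
emb = nn.Embedding(10, hidden_size, padding_idx=PAD_token)
gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

# A batch of two sequences; the second is padded from length 2 up to length 4.
batch = torch.tensor([[5, 6, 7, 1],
                      [8, 1, PAD_token, PAD_token]])
lengths = (batch != PAD_token).sum(dim=1).cpu()            # tensor([4, 2])

packed = nn.utils.rnn.pack_padded_sequence(
    emb(batch), lengths, batch_first=True, enforce_sorted=False)
out_packed, hidden = gru(packed)
out, _ = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True)

# The final hidden state of the short sequence matches running it unpadded.
_, hidden_ref = gru(emb(batch[1:2, :2]))
print(torch.allclose(hidden[:, 1], hidden_ref[:, 0]))      # expect: True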
@@ -375,7 +386,7 @@ def forward(self, input):
 class DecoderRNN(nn.Module):
     def __init__(self, hidden_size, output_size):
         super(DecoderRNN, self).__init__()
-        self.embedding = nn.Embedding(output_size, hidden_size)
+        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_token)
         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
         self.out = nn.Linear(hidden_size, output_size)
 
@@ -480,7 +491,7 @@ def forward(self, query, keys):
 class AttnDecoderRNN(nn.Module):
     def __init__(self, hidden_size, output_size, dropout_p=0.1):
         super(AttnDecoderRNN, self).__init__()
-        self.embedding = nn.Embedding(output_size, hidden_size)
+        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_token)
         self.attention = BahdanauAttention(hidden_size)
         self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
         self.out = nn.Linear(hidden_size, output_size)
@@ -563,8 +574,8 @@ def get_dataloader(batch_size):
     input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
 
     n = len(pairs)
-    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
-    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
+    input_ids = np.full((n, MAX_LENGTH), PAD_token, dtype=np.int32)
+    target_ids = np.full((n, MAX_LENGTH), PAD_token, dtype=np.int32)
 
     for idx, (inp, tgt) in enumerate(pairs):
         inp_ids = indexesFromSentence(input_lang, inp)
@@ -583,6 +594,28 @@ def get_dataloader(batch_size):
 
 
 ######################################################################
+# .. note::
+#    When working with batched sequences of variable lengths, proper padding
+#    handling is crucial:
+#
+#    1. **Padding token**: we use a dedicated ``PAD_token`` (index 2) to pad
+#       shorter sequences to the batch's maximum length, instead of reusing
+#       index 0, which already serves as the SOS token.
+#
+#    2. **Encoder packing**: the encoder uses ``pack_padded_sequence`` and
+#       ``pad_packed_sequence`` to handle variable-length sequences efficiently,
+#       so the GRU's final hidden state summarizes the actual sentence content
+#       rather than trailing padding tokens.
+#
+#    3. **Loss masking**: the loss function uses ``ignore_index=PAD_token`` to
+#       exclude padding positions from the loss, so the model never learns to
+#       predict padding and gradients flow only from real target tokens
+#       (see the toy check below).
+#
+#    4. **Embedding padding**: all embedding layers set ``padding_idx=PAD_token``,
+#       so the padding token's embedding is a zero vector that is never updated
+#       during training (see the sketch below).
+#
 # Training the Model
 # ------------------
 #
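
To illustrate point 4 of the note above, here is a small standalone sketch (not part of the tutorial diff; vocabulary size and embedding dimension are arbitrary) showing that ``padding_idx`` gives the PAD row a zero embedding and a zero gradient:

# Standalone sketch of the padding_idx behaviour; sizes are made up.
import torch
import torch.nn as nn

PAD_token = 2
emb = nn.Embedding(6, 3, padding_idx=PAD_token)

print(emb.weight[PAD_token])            # the PAD row starts as a zero vector

out = emb(torch.tensor([[0, 1, PAD_token]]))
out.sum().backward()
print(emb.weight.grad[PAD_token])       # zero gradient: the PAD row is never updated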
@@ -678,7 +711,7 @@ def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
 
     encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
     decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
-    criterion = nn.NLLLoss()
+    criterion = nn.NLLLoss(ignore_index=PAD_token)
 
     for epoch in range(1, n_epochs + 1):
         loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
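
And as a toy check of the ``ignore_index`` masking from point 3 of the note (again not part of the tutorial diff; the log-probabilities and targets below are fabricated), the mean loss over a batch with trailing padding equals the loss computed over only the real positions:

# Toy check of loss masking with ignore_index; values are illustrative only.
import torch
import torch.nn as nn

PAD_token = 2
logp = torch.log_softmax(torch.randn(4, 5), dim=1)    # 4 positions, vocab of 5
target = torch.tensor([3, 1, PAD_token, PAD_token])   # last two positions are padding

masked = nn.NLLLoss(ignore_index=PAD_token)(logp, target)
real_only = nn.NLLLoss()(logp[:2], target[:2])
print(torch.allclose(masked, real_only))              # expect: True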