@@ -440,25 +440,27 @@ def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
         self.max_length = max_length

         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
-        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
-        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
+        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size))
+        torch.nn.init.xavier_uniform_(self.alignment_vector)
         self.dropout = nn.Dropout(self.dropout_p)
-        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
+        self.gru = nn.GRU(self.hidden_size * 2, self.hidden_size)
         self.out = nn.Linear(self.hidden_size, self.output_size)

     def forward(self, input, hidden, encoder_outputs):
-        embedded = self.embedding(input).view(1, 1, -1)
+        embedded = self.embedding(input).view(1, -1)
         embedded = self.dropout(embedded)

-        attn_weights = F.softmax(
-            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
-        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
-                                 encoder_outputs.unsqueeze(0))
-
-        output = torch.cat((embedded[0], attn_applied[0]), 1)
-        output = self.attn_combine(output).unsqueeze(0)
+        transformed_hidden = self.fc_hidden(hidden[0])
+        expanded_hidden_state = transformed_hidden.expand(self.max_length, -1)
+        alignment_scores = torch.tanh(expanded_hidden_state +
+                                      self.fc_encoder(encoder_outputs))
+        alignment_scores = self.alignment_vector.mm(alignment_scores.T)
+        attn_weights = F.softmax(alignment_scores, dim=1)
+        context_vector = attn_weights.mm(encoder_outputs)

-        output = F.relu(output)
+        output = torch.cat((embedded, context_vector), 1).unsqueeze(0)
         output, hidden = self.gru(output, hidden)

         output = F.log_softmax(self.out(output[0]), dim=1)
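The hunk above swaps the tutorial's concat-based attention for additive (Bahdanau-style) scoring: the decoder state and each encoder output are projected by fc_hidden and fc_encoder, passed through tanh, and dotted with a learned alignment_vector; the resulting context vector is concatenated with the embedding, which is why the GRU input width grows to hidden_size * 2. Below is a minimal standalone sketch of that scoring step, not part of the commit; hidden_size=256 comes from later in this file, max_length=10 matches the tutorial's MAX_LENGTH, and the zero/random tensors are placeholders for the real decoder state and encoder annotations.

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, max_length = 256, 10

fc_hidden = nn.Linear(hidden_size, hidden_size, bias=False)   # projects the decoder state s_{t-1}
fc_encoder = nn.Linear(hidden_size, hidden_size, bias=False)  # projects each encoder output h_j
alignment_vector = nn.Parameter(torch.empty(1, hidden_size))  # scoring vector v_a
nn.init.xavier_uniform_(alignment_vector)

decoder_hidden = torch.zeros(1, hidden_size)             # placeholder previous decoder hidden state
encoder_outputs = torch.randn(max_length, hidden_size)   # placeholder: one annotation per source position

# score_j = v_a^T tanh(W_a s_{t-1} + U_a h_j), computed for all j at once via broadcasting
scores = alignment_vector.mm(
    torch.tanh(fc_hidden(decoder_hidden) + fc_encoder(encoder_outputs)).T)  # shape (1, max_length)
attn_weights = F.softmax(scores, dim=1)                  # normalized over source positions
context_vector = attn_weights.mm(encoder_outputs)        # shape (1, hidden_size); concatenated with the embedding before the GRU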
@@ -761,15 +763,15 @@ def evaluateRandomly(encoder, decoder, n=10):
 #

 hidden_size = 256
-encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
-attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
+encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
+attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)

-trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
+trainIters(encoder, attn_decoder, 75000, print_every=5000)

 ######################################################################
 #

-evaluateRandomly(encoder1, attn_decoder1)
+evaluateRandomly(encoder, attn_decoder)


 ######################################################################
@@ -787,7 +789,7 @@ def evaluateRandomly(encoder, decoder, n=10):
 #

 output_words, attentions = evaluate(
-    encoder1, attn_decoder1, "je suis trop froid .")
+    encoder, attn_decoder, "je suis trop froid .")
 plt.matshow(attentions.numpy())

@@ -817,7 +819,7 @@ def showAttention(input_sentence, output_words, attentions):

 def evaluateAndShowAttention(input_sentence):
     output_words, attentions = evaluate(
-        encoder1, attn_decoder1, input_sentence)
+        encoder, attn_decoder, input_sentence)
     print('input =', input_sentence)
     print('output =', ' '.join(output_words))
     showAttention(input_sentence, output_words, attentions)