Commit a5376f7

QasimKhan5x and Svetlana Karslioglu authored
Fix Attention in seq2seq_translation_tutorial AttnDecoderRNN (#2452)
* replace old decoder diagram with new one
* remove 1 from encoder1 and decoder1
* fix attention in AttnDecoderRNN
* fix formatting going over max character count

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent 203f567 commit a5376f7

File tree

2 files changed: +20 -18 lines changed

intermediate_source/seq2seq_translation_tutorial.py (+20 -18)
@@ -440,25 +440,27 @@ def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGT
         self.max_length = max_length
 
         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
-        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
-        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
+        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size))
+        torch.nn.init.xavier_uniform_(self.alignment_vector)
         self.dropout = nn.Dropout(self.dropout_p)
-        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
+        self.gru = nn.GRU(self.hidden_size * 2, self.hidden_size)
         self.out = nn.Linear(self.hidden_size, self.output_size)
 
     def forward(self, input, hidden, encoder_outputs):
-        embedded = self.embedding(input).view(1, 1, -1)
+        embedded = self.embedding(input).view(1, -1)
         embedded = self.dropout(embedded)
 
-        attn_weights = F.softmax(
-            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
-        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
-                                 encoder_outputs.unsqueeze(0))
-
-        output = torch.cat((embedded[0], attn_applied[0]), 1)
-        output = self.attn_combine(output).unsqueeze(0)
+        transformed_hidden = self.fc_hidden(hidden[0])
+        expanded_hidden_state = transformed_hidden.expand(self.max_length, -1)
+        alignment_scores = torch.tanh(expanded_hidden_state +
+                                      self.fc_encoder(encoder_outputs))
+        alignment_scores = self.alignment_vector.mm(alignment_scores.T)
+        attn_weights = F.softmax(alignment_scores, dim=1)
+        context_vector = attn_weights.mm(encoder_outputs)
 
-        output = F.relu(output)
+        output = torch.cat((embedded, context_vector), 1).unsqueeze(0)
         output, hidden = self.gru(output, hidden)
 
         output = F.log_softmax(self.out(output[0]), dim=1)
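For readers skimming the diff: the new code computes additive (Bahdanau-style) attention, score_i = v . tanh(W_h h + W_e e_i), over all encoder states at once, softmaxes the scores into weights, and concatenates the resulting context vector with the embedded input before the GRU, which is why the GRU input size grows to hidden_size * 2. Below is a minimal standalone shape-check sketch of that computation; MAX_LENGTH = 10 and hidden_size = 256 are assumed placeholder values, not part of this commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

MAX_LENGTH, hidden_size = 10, 256  # assumed placeholder sizes

# The three learned pieces of the additive score introduced by this commit:
fc_hidden = nn.Linear(hidden_size, hidden_size, bias=False)
fc_encoder = nn.Linear(hidden_size, hidden_size, bias=False)
alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size))
torch.nn.init.xavier_uniform_(alignment_vector)

# Dummy tensors with the shapes the decoder sees at one time step:
hidden = torch.zeros(1, 1, hidden_size)                 # previous GRU hidden state
encoder_outputs = torch.randn(MAX_LENGTH, hidden_size)  # one row per source position

transformed_hidden = fc_hidden(hidden[0])                          # (1, H)
expanded_hidden_state = transformed_hidden.expand(MAX_LENGTH, -1)  # (L, H)
alignment_scores = torch.tanh(expanded_hidden_state +
                              fc_encoder(encoder_outputs))         # (L, H)
alignment_scores = alignment_vector.mm(alignment_scores.T)         # (1, L)
attn_weights = F.softmax(alignment_scores, dim=1)                  # (1, L)
context_vector = attn_weights.mm(encoder_outputs)                  # (1, H)

assert attn_weights.shape == (1, MAX_LENGTH)
assert abs(attn_weights.sum().item() - 1.0) < 1e-5  # weights sum to 1
assert context_vector.shape == (1, hidden_size)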
@@ -761,15 +763,15 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 hidden_size = 256
-encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
-attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
+encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
+attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
 
-trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
+trainIters(encoder, attn_decoder, 75000, print_every=5000)
 
 ######################################################################
 #
 
-evaluateRandomly(encoder1, attn_decoder1)
+evaluateRandomly(encoder, attn_decoder)
 
 
 ######################################################################
@@ -787,7 +789,7 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 output_words, attentions = evaluate(
-    encoder1, attn_decoder1, "je suis trop froid .")
+    encoder, attn_decoder, "je suis trop froid .")
 plt.matshow(attentions.numpy())
 
 
@@ -817,7 +819,7 @@ def showAttention(input_sentence, output_words, attentions):
 
 def evaluateAndShowAttention(input_sentence):
     output_words, attentions = evaluate(
-        encoder1, attn_decoder1, input_sentence)
+        encoder, attn_decoder, input_sentence)
     print('input =', input_sentence)
     print('output =', ' '.join(output_words))
     showAttention(input_sentence, output_words, attentions)
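As a quick end-to-end smoke test of the renamed objects and the new GRU input size, one decoder step with dummy tensors. This sketch assumes the tutorial's MAX_LENGTH of 10, an SOS token id of 0, a placeholder vocabulary size, and that forward still returns (output, hidden, attn_weights) as in the rest of the tutorial; none of these appear in this diff.

import torch

# 2803 is a placeholder vocab size; the tutorial uses output_lang.n_words.
decoder = AttnDecoderRNN(hidden_size=256, output_size=2803)

input_token = torch.tensor([0])          # assumed SOS_token = 0
hidden = torch.zeros(1, 1, 256)          # initial decoder hidden state
encoder_outputs = torch.zeros(10, 256)   # MAX_LENGTH (assumed 10) encoder states

output, hidden, attn_weights = decoder(input_token, hidden, encoder_outputs)
print(output.shape)        # torch.Size([1, 2803]); log-probabilities over the vocab
print(hidden.shape)        # torch.Size([1, 1, 256])
print(attn_weights.shape)  # torch.Size([1, 10])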

0 commit comments