diff --git a/_static/img/seq-seq-images/attention-decoder-network.png b/_static/img/seq-seq-images/attention-decoder-network.png
index 243f87c6e97..d31d42a5af1 100755
Binary files a/_static/img/seq-seq-images/attention-decoder-network.png and b/_static/img/seq-seq-images/attention-decoder-network.png differ
diff --git a/intermediate_source/seq2seq_translation_tutorial.py b/intermediate_source/seq2seq_translation_tutorial.py
index ea583821f85..c2b0b722e5b 100644
--- a/intermediate_source/seq2seq_translation_tutorial.py
+++ b/intermediate_source/seq2seq_translation_tutorial.py
@@ -440,25 +440,27 @@ def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGT
         self.max_length = max_length
 
         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
-        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
-        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
+        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size))
+        torch.nn.init.xavier_uniform_(self.alignment_vector)
         self.dropout = nn.Dropout(self.dropout_p)
-        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
+        self.gru = nn.GRU(self.hidden_size * 2, self.hidden_size)
         self.out = nn.Linear(self.hidden_size, self.output_size)
 
     def forward(self, input, hidden, encoder_outputs):
-        embedded = self.embedding(input).view(1, 1, -1)
+        embedded = self.embedding(input).view(1, -1)
         embedded = self.dropout(embedded)
 
-        attn_weights = F.softmax(
-            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
-        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
-                                 encoder_outputs.unsqueeze(0))
-
-        output = torch.cat((embedded[0], attn_applied[0]), 1)
-        output = self.attn_combine(output).unsqueeze(0)
+        transformed_hidden = self.fc_hidden(hidden[0])
+        expanded_hidden_state = transformed_hidden.expand(self.max_length, -1)
+        alignment_scores = torch.tanh(expanded_hidden_state +
+                                      self.fc_encoder(encoder_outputs))
+        alignment_scores = self.alignment_vector.mm(alignment_scores.T)
+        attn_weights = F.softmax(alignment_scores, dim=1)
+        context_vector = attn_weights.mm(encoder_outputs)
 
-        output = F.relu(output)
+        output = torch.cat((embedded, context_vector), 1).unsqueeze(0)
         output, hidden = self.gru(output, hidden)
 
         output = F.log_softmax(self.out(output[0]), dim=1)
@@ -761,15 +763,15 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 hidden_size = 256
-encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
-attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
+encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
+attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
 
-trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
+trainIters(encoder, attn_decoder, 75000, print_every=5000)
 
 ######################################################################
 #
 
-evaluateRandomly(encoder1, attn_decoder1)
+evaluateRandomly(encoder, attn_decoder)
 
 
 ######################################################################
@@ -787,7 +789,7 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 output_words, attentions = evaluate(
-    encoder1, attn_decoder1, "je suis trop froid .")
+    encoder, attn_decoder, "je suis trop froid .")
 plt.matshow(attentions.numpy())
 
 
@@ -817,7 +819,7 @@ def showAttention(input_sentence, output_words, attentions):
 
 def evaluateAndShowAttention(input_sentence):
     output_words, attentions = evaluate(
-        encoder1, attn_decoder1, input_sentence)
+        encoder, attn_decoder, input_sentence)
     print('input =', input_sentence)
     print('output =', ' '.join(output_words))
     showAttention(input_sentence, output_words, attentions)
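
Note on the first hunk: it replaces the tutorial's original location-based attention (a single nn.Linear that scores the concatenated embedding and hidden state against max_length positions) with additive, Bahdanau-style attention, in which the decoder hidden state and each encoder output are projected, summed, squashed with tanh, and scored against a learned alignment vector. The standalone sketch below traces only that scoring step; the shapes (hidden of size (1, 1, hidden_size), encoder_outputs of size (max_length, hidden_size)) and the dummy sizes are assumptions mirroring the tutorial's one-token-at-a-time decoding loop, and the variable names are illustrative rather than part of the patch.

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, max_length = 256, 10  # assumed dummy sizes for illustration

# The three learned pieces of additive attention introduced by the patch.
fc_hidden = nn.Linear(hidden_size, hidden_size, bias=False)
fc_encoder = nn.Linear(hidden_size, hidden_size, bias=False)
alignment_vector = nn.Parameter(torch.empty(1, hidden_size))
nn.init.xavier_uniform_(alignment_vector)

hidden = torch.zeros(1, 1, hidden_size)             # decoder GRU state
encoder_outputs = torch.zeros(max_length, hidden_size)

# score_i = v . tanh(W_h h + W_e e_i) for every encoder position i
scores = torch.tanh(fc_hidden(hidden[0]).expand(max_length, -1)
                    + fc_encoder(encoder_outputs))  # (max_length, hidden_size)
scores = alignment_vector.mm(scores.T)              # (1, max_length)
attn_weights = F.softmax(scores, dim=1)             # sums to 1 over positions
context_vector = attn_weights.mm(encoder_outputs)   # (1, hidden_size)

In the patched forward, this context_vector is concatenated with the embedded input token and fed to the GRU, which is why the GRU's input size grows to hidden_size * 2.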