Commit 0650b47

Authored by Svetlana Karslioglu
Merge branch 'main' into issues
2 parents: c9dea0f + a5376f7

4 files changed: 28 additions, 23 deletions


beginner_source/transformer_tutorial.py  (+5, -2)

@@ -2,7 +2,7 @@
 Language Modeling with ``nn.Transformer`` and torchtext
 ===============================================================
 
-This is a tutorial on training a sequence-to-sequence model that uses the
+This is a tutorial on training a model to predict the next word in a sequence using the
 `nn.Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ module.
 
 The PyTorch 1.2 release includes a standard transformer module based on the
@@ -29,7 +29,9 @@
 
 ######################################################################
 # In this tutorial, we train a ``nn.TransformerEncoder`` model on a
-# language modeling task. The language modeling task is to assign a
+# language modeling task. Please note that this tutorial does not cover
+# the training of `nn.TransformerDecoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html#torch.nn.TransformerDecoder>`__, as depicted in
+# the right half of the diagram above. The language modeling task is to assign a
 # probability for the likelihood of a given word (or a sequence of words)
 # to follow a sequence of words. A sequence of tokens are passed to the embedding
 # layer first, followed by a positional encoding layer to account for the order
@@ -130,6 +132,7 @@ def forward(self, x: Tensor) -> Tensor:
 # .. code-block:: bash
 #
 # %%bash
+# pip install portalocker
 # pip install torchdata
 #
 # The vocab object is built based on the train dataset and is used to numericalize
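The reworded introduction above frames the tutorial as next-word prediction with an encoder-only model. As a quick orientation for readers skimming the diff, a minimal sketch of that pipeline (embedding, positional encoding, ``nn.TransformerEncoder``, a linear projection back to the vocabulary, and a causal mask so each position only attends to earlier tokens) might look like the following. The class names and hyperparameters here are illustrative, not code from this commit.

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Add sinusoidal position information to token embeddings."""
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: [seq_len, batch_size, d_model]
        return x + self.pe[:x.size(0)]

class ToyLanguageModel(nn.Module):
    """Encoder-only LM: embedding -> positional encoding -> TransformerEncoder -> vocab logits."""
    def __init__(self, vocab_size: int, d_model: int = 64, nhead: int = 4, nlayers: int = 2):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128)
        self.encoder = nn.TransformerEncoder(layer, nlayers)
        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_mask):
        x = self.pos_encoder(self.embedding(src) * math.sqrt(self.d_model))
        return self.decoder(self.encoder(x, src_mask))

vocab_size, seq_len, batch_size = 100, 10, 2
model = ToyLanguageModel(vocab_size)
src = torch.randint(0, vocab_size, (seq_len, batch_size))
# Causal mask: position i may only attend to positions <= i.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
logits = model(src, causal_mask)  # [seq_len, batch_size, vocab_size]
print(logits.shape)

Training would then shift the targets by one position and minimize cross-entropy between the logits at position t and the token at position t+1, which is the language modeling objective the revised wording describes.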

intermediate_source/FSDP_adavnced_tutorial.rst  (+3, -3)

@@ -75,7 +75,7 @@ highlight different available features in FSDP that are helpful for training
 large scale model above 3B parameters. Also, we cover specific features for
 Transformer based models. The code for this tutorial is available in `Pytorch
 Examples
-<https://github.com/HamidShojanazeri/examples/tree/FSDP_example/FSDP/>`__.
+<https://github.com/HamidShojanazeri/examples/tree/FSDP_example/distributed/FSDP/>`__.
 
 
 *Setup*
@@ -97,13 +97,13 @@ Please create a `data` folder, download the WikiHow dataset from `wikihowAll.csv
 `wikihowSep.cs <https://ucsb.app.box.com/s/7yq601ijl1lzvlfu4rjdbbxforzd2oag>`__,
 and place them in the `data` folder. We will use the wikihow dataset from
 `summarization_dataset
-<https://github.com/HamidShojanazeri/examples/blob/FSDP_example/FSDP/summarization_dataset.py>`__.
+<https://github.com/HamidShojanazeri/examples/blob/FSDP_example/distributed/FSDP/summarization_dataset.py>`__.
 
 Next, we add the following code snippets to a Python script “T5_training.py”.
 
 .. note::
    The full source code for this tutorial is available in `PyTorch examples
-   <https://github.com/HamidShojanazeri/examples/tree/FSDP_example/FSDP>`__.
+   <https://github.com/HamidShojanazeri/examples/tree/FSDP_example/distributed/FSDP>`__.
 
 1.3 Import necessary packages:
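
The corrected links above point at the full FSDP example that T5_training.py lives in. As a rough orientation only, the core wrapping pattern that kind of script follows can be sketched as below; this is not the tutorial's T5_training.py, it assumes a multi-GPU host launched with torchrun, and it substitutes a stand-in model for T5.

# Minimal FSDP sketch (illustrative; not code from this commit).
# Assumed launch: torchrun --nproc_per_node=<num_gpus> fsdp_sketch.py
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    dist.init_process_group("nccl")   # torchrun provides RANK/WORLD_SIZE/MASTER_ADDR
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Stand-in model; the tutorial shards a T5 transformer instead.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).to(local_rank)
    model = FSDP(model)               # parameters are sharded across ranks

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    inputs = torch.randn(8, 1024, device=local_rank)
    loss = model(inputs).pow(2).mean()  # dummy loss for illustration
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()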

intermediate_source/seq2seq_translation_tutorial.py  (+20, -18)

@@ -440,25 +440,27 @@ def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGT
         self.max_length = max_length
 
         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
-        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
-        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
+        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size))
+        torch.nn.init.xavier_uniform_(self.alignment_vector)
         self.dropout = nn.Dropout(self.dropout_p)
-        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
+        self.gru = nn.GRU(self.hidden_size * 2, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
 
     def forward(self, input, hidden, encoder_outputs):
-        embedded = self.embedding(input).view(1, 1, -1)
+        embedded = self.embedding(input).view(1, -1)
         embedded = self.dropout(embedded)
 
-        attn_weights = F.softmax(
-            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
-        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
-                                 encoder_outputs.unsqueeze(0))
-
-        output = torch.cat((embedded[0], attn_applied[0]), 1)
-        output = self.attn_combine(output).unsqueeze(0)
+        transformed_hidden = self.fc_hidden(hidden[0])
+        expanded_hidden_state = transformed_hidden.expand(self.max_length, -1)
+        alignment_scores = torch.tanh(expanded_hidden_state +
+                                      self.fc_encoder(encoder_outputs))
+        alignment_scores = self.alignment_vector.mm(alignment_scores.T)
+        attn_weights = F.softmax(alignment_scores, dim=1)
+        context_vector = attn_weights.mm(encoder_outputs)
 
-        output = F.relu(output)
+        output = torch.cat((embedded, context_vector), 1).unsqueeze(0)
         output, hidden = self.gru(output, hidden)
 
         output = F.log_softmax(self.out(output[0]), dim=1)
@@ -761,15 +763,15 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 hidden_size = 256
-encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
-attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
+encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
+attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
 
-trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
+trainIters(encoder, attn_decoder, 75000, print_every=5000)
 
 ######################################################################
 #
 
-evaluateRandomly(encoder1, attn_decoder1)
+evaluateRandomly(encoder, attn_decoder)
 
 
 ######################################################################
@@ -787,7 +789,7 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 output_words, attentions = evaluate(
-    encoder1, attn_decoder1, "je suis trop froid .")
+    encoder, attn_decoder, "je suis trop froid .")
 plt.matshow(attentions.numpy())
 
 
@@ -817,7 +819,7 @@ def showAttention(input_sentence, output_words, attentions):
 
 def evaluateAndShowAttention(input_sentence):
     output_words, attentions = evaluate(
-        encoder1, attn_decoder1, input_sentence)
+        encoder, attn_decoder, input_sentence)
     print('input =', input_sentence)
     print('output =', ' '.join(output_words))
     showAttention(input_sentence, output_words, attentions)
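
The rewritten decoder above replaces the original concatenation-based attention with additive (Bahdanau-style) attention. The attention step it computes can be sketched standalone as below, with the same layer names and shapes as in the diff; the snippet is illustrative rather than code taken from the commit.

# Illustrative, standalone version of the additive-attention step introduced above
# (layer names mirror the diff; this is a sketch, not code from the commit).
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, max_length = 256, 10
fc_hidden = nn.Linear(hidden_size, hidden_size, bias=False)
fc_encoder = nn.Linear(hidden_size, hidden_size, bias=False)
alignment_vector = nn.Parameter(torch.empty(1, hidden_size))
nn.init.xavier_uniform_(alignment_vector)

decoder_hidden = torch.randn(1, hidden_size)             # current decoder hidden state
encoder_outputs = torch.randn(max_length, hidden_size)   # one vector per source position

# score_i = v . tanh(W_h h + W_e e_i), computed for every encoder position i at once
expanded_hidden = fc_hidden(decoder_hidden).expand(max_length, -1)
alignment_scores = torch.tanh(expanded_hidden + fc_encoder(encoder_outputs))
alignment_scores = alignment_vector.mm(alignment_scores.T)   # [1, max_length]
attn_weights = F.softmax(alignment_scores, dim=1)            # sums to 1 over source positions
context_vector = attn_weights.mm(encoder_outputs)            # [1, hidden_size]
print(attn_weights.shape, context_vector.shape)

The context vector is a weighted sum of encoder outputs; in the revised decoder it is concatenated with the current embedding before the GRU step, which is why the diff widens the GRU input to ``self.hidden_size * 2``.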
