# of the word (see the next paragraph for more details). The
# ``nn.TransformerEncoder`` consists of multiple layers of
# `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__.
- # Along with the input sequence, a square attention mask is required because the
- # self-attention layers in ``nn.TransformerDecoder`` are only allowed to attend
- # the earlier positions in the sequence. For the language modeling task, any
- # tokens on the future positions should be masked. To produce a probability
- # distribution over output words, the output of the ``nn.TransformerEncoder``
- # model is passed through a linear layer followed by a log-softmax function.
+ # To produce a probability distribution over output words, the output of
+ # the ``nn.TransformerEncoder`` model is passed through a linear layer.
#
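The removed wording mentioned a log-softmax that the new text drops; the linear layer by itself only produces unnormalized scores. A minimal standalone sketch of the relationship (all sizes below are made-up placeholders, not values from the tutorial): ``F.log_softmax`` would turn the scores into log-probabilities, while ``nn.CrossEntropyLoss``, which the tutorial uses further down, applies that normalization internally, so the model no longer needs to.

import torch
import torch.nn.functional as F

seq_len, batch_size, ntoken = 35, 20, 1000            # placeholder sizes
logits = torch.randn(seq_len, batch_size, ntoken)     # stand-in for the linear layer's output

# Log-probability distribution over output words at every position.
log_probs = F.log_softmax(logits, dim=-1)
print(log_probs.exp().sum(dim=-1))                    # each distribution sums to 1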

import math

import torch
from torch import nn, Tensor
- import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

@@ -64,19 +59,19 @@ def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
-         self.encoder = nn.Embedding(ntoken, d_model)
+         self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
-         self.decoder = nn.Linear(d_model, ntoken)
+         self.linear = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
-         self.encoder.weight.data.uniform_(-initrange, initrange)
-         self.decoder.bias.data.zero_()
-         self.decoder.weight.data.uniform_(-initrange, initrange)
+         self.embedding.weight.data.uniform_(-initrange, initrange)
+         self.linear.bias.data.zero_()
+         self.linear.weight.data.uniform_(-initrange, initrange)

-     def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
+     def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``
@@ -85,18 +80,13 @@ def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """
-         src = self.encoder(src) * math.sqrt(self.d_model)
+         src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
-         output = self.decoder(output)
+         output = self.linear(output)
        return output


- def generate_square_subsequent_mask(sz: int) -> Tensor:
-     """Generates an upper-triangular matrix of ``-inf``, with zeros on ``diag``."""
-     return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
-
-
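For readers who still want the causal mask after this deletion, recent PyTorch releases expose an equivalent helper on ``nn.Transformer``; a minimal sketch, assuming the installed version defines ``generate_square_subsequent_mask`` as a static method (older releases exposed it only on instances):

import torch
from torch import nn

sz = 5
# Built-in equivalent of the deleted helper: ``-inf`` above the diagonal,
# zeros on and below it.
mask = nn.Transformer.generate_square_subsequent_mask(sz)

# Hand-rolled version, identical to the removed function's body.
manual = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
print(torch.equal(mask, manual))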
######################################################################
# ``PositionalEncoding`` module injects some information about the
# relative or absolute position of the tokens in the sequence. The
@@ -286,7 +276,6 @@ def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
# to prevent gradients from exploding.
#
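Gradient clipping of this kind is typically done with ``torch.nn.utils.clip_grad_norm_``. A minimal standalone sketch of where the call sits in a training step (the toy model, learning rate, and the 0.5 threshold are placeholders, not values taken from this diff):

import torch
from torch import nn

# Toy setup just to illustrate the clipping call.
model = nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=5.0)

loss = model(torch.randn(8, 10)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
# Rescale all gradients so their combined norm is at most 0.5 (placeholder
# threshold), keeping one bad batch from producing a huge parameter update.
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()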

- import copy
import time

criterion = nn.CrossEntropyLoss()
@@ -299,16 +288,13 @@ def train(model: nn.Module) -> None:
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
-     src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
-         seq_len = data.size(0)
-         if seq_len != bptt:  # only on last batch
-             src_mask = src_mask[:seq_len, :seq_len]
-         output = model(data, src_mask)
-         loss = criterion(output.view(-1, ntokens), targets)
+         output = model(data)
+         output_flat = output.view(-1, ntokens)
+         loss = criterion(output_flat, targets)

        optimizer.zero_grad()
        loss.backward()
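Because ``src_mask`` now defaults to ``None``, the encoder is run without any attention mask unless the caller supplies one, so every position may attend to every other position; a caller can still opt back into causal masking through the optional argument. A minimal standalone sketch with a bare encoder (all sizes are made-up placeholders):

import torch
from torch import nn

d_model, nhead, seq_len, batch_size = 16, 2, 6, 3
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=32)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
src = torch.randn(seq_len, batch_size, d_model)

out_unmasked = encoder(src)  # mask=None: full bidirectional attention
causal = nn.Transformer.generate_square_subsequent_mask(seq_len)
out_causal = encoder(src, mask=causal)  # positions attend only to earlier ones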
@@ -330,14 +316,11 @@ def train(model: nn.Module) -> None:
def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
-     src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)
-             if seq_len != bptt:
-                 src_mask = src_mask[:seq_len, :seq_len]
-             output = model(data, src_mask)
+             output = model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += seq_len * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)