@@ -296,7 +296,7 @@ def test_mbart_fast_forward(self):
296
296
lm_model = BartForConditionalGeneration (config ).to (torch_device )
297
297
context = torch .Tensor ([[71 , 82 , 18 , 33 , 46 , 91 , 2 ], [68 , 34 , 26 , 58 , 30 , 2 , 1 ]]).long ().to (torch_device )
298
298
summary = torch .Tensor ([[82 , 71 , 82 , 18 , 2 ], [58 , 68 , 2 , 1 , 1 ]]).long ().to (torch_device )
299
- loss , logits , enc_features = lm_model (input_ids = context , decoder_input_ids = summary , lm_labels = summary )
299
+ loss , logits , enc_features = lm_model (input_ids = context , decoder_input_ids = summary , labels = summary )
300
300
expected_shape = (* summary .shape , config .vocab_size )
301
301
self .assertEqual (logits .shape , expected_shape )
302
302
@@ -361,7 +361,7 @@ def test_lm_forward(self):
361
361
lm_labels = ids_tensor ([batch_size , input_ids .shape [1 ]], self .vocab_size ).to (torch_device )
362
362
lm_model = BartForConditionalGeneration (config )
363
363
lm_model .to (torch_device )
364
- loss , logits , enc_features = lm_model (input_ids = input_ids , lm_labels = lm_labels )
364
+ loss , logits , enc_features = lm_model (input_ids = input_ids , labels = lm_labels )
365
365
expected_shape = (batch_size , input_ids .shape [1 ], config .vocab_size )
366
366
self .assertEqual (logits .shape , expected_shape )
367
367
self .assertIsInstance (loss .item (), float )
@@ -381,7 +381,7 @@ def test_lm_uneven_forward(self):
381
381
lm_model = BartForConditionalGeneration (config ).to (torch_device )
382
382
context = torch .Tensor ([[71 , 82 , 18 , 33 , 46 , 91 , 2 ], [68 , 34 , 26 , 58 , 30 , 2 , 1 ]]).long ().to (torch_device )
383
383
summary = torch .Tensor ([[82 , 71 , 82 , 18 , 2 ], [58 , 68 , 2 , 1 , 1 ]]).long ().to (torch_device )
384
- loss , logits , enc_features = lm_model (input_ids = context , decoder_input_ids = summary , lm_labels = summary )
384
+ loss , logits , enc_features = lm_model (input_ids = context , decoder_input_ids = summary , labels = summary )
385
385
expected_shape = (* summary .shape , config .vocab_size )
386
386
self .assertEqual (logits .shape , expected_shape )
387
387
0 commit comments