@@ -88,7 +88,7 @@ def __init__(
         self.bos_token_id = vocab_size - 1
         self.eos_token_id = vocab_size - 1

-    def prepare_config_and_inputs(self):
+    def prepare_config_and_inputs(self, gradient_checkpointing=False):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

         input_mask = None
@@ -127,6 +127,7 @@ def prepare_config_and_inputs(self):
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             return_dict=True,
+            gradient_checkpointing=gradient_checkpointing,
         )

         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -269,6 +270,15 @@ def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mas
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

+    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+        model = GPT2LMHeadModel(config)
+        model.to(torch_device)
+
+        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def create_and_check_double_lm_head_model(
         self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
     ):
@@ -355,6 +365,10 @@ def test_gpt2_double_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)

+    def test_gpt2_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -366,33 +380,34 @@ def test_model_from_pretrained(self):
 class GPT2ModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_gpt2(self):
-        model = GPT2LMHeadModel.from_pretrained("gpt2")
-        model.to(torch_device)
-        input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
-        expected_output_ids = [
-            464,
-            3290,
-            373,
-            1043,
-            287,
-            257,
-            2214,
-            1474,
-            262,
-            16246,
-            286,
-            2688,
-            290,
-            2688,
-            27262,
-            13,
-            198,
-            198,
-            464,
-            3290,
-        ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+        for checkpointing in [True, False]:
+            model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
+            model.to(torch_device)
+            input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
+            expected_output_ids = [
+                464,
+                3290,
+                373,
+                1043,
+                287,
+                257,
+                2214,
+                1474,
+                262,
+                16246,
+                286,
+                2688,
+                290,
+                2688,
+                27262,
+                13,
+                198,
+                198,
+                464,
+                3290,
+            ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
+            output_ids = model.generate(input_ids, do_sample=False)
+            self.assertListEqual(output_ids[0].tolist(), expected_output_ids)

     @slow
     def test_lm_generate_distilgpt2(self):
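
For reference, a minimal sketch of how the flag exercised by these tests could be used directly, assuming the `gradient_checkpointing` option this PR threads through the config and `from_pretrained`; the prompt and the single training step are illustrative only, not part of the change.

```python
# Usage sketch (assumption, not part of the PR): load GPT-2 with gradient
# checkpointing enabled and run one forward/backward pass, the same pattern
# create_and_check_forward_and_backwards exercises above.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=True)
model.train()  # checkpointing only saves memory when gradients are needed

inputs = tokenizer("The dog", return_tensors="pt")
# Using the input ids as labels yields a language-modeling loss, as in the tests.
outputs = model(**inputs, labels=inputs["input_ids"], return_dict=True)
outputs.loss.backward()  # activations are recomputed block by block instead of stored
```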