
Commit 9e9a1fb

TevenLeScao, patrickvonplaten, and sgugger authored
Adding gradient checkpointing to GPT2 (#7446)
* GPT2 gradient checkpointing
* find_unused_parameters removed if checkpointing
* find_unused_parameters removed if checkpointing
* Update src/transformers/configuration_gpt2.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Added a test for generation with checkpointing
* Update src/transformers/configuration_gpt2.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent 52e8392 commit 9e9a1fb

4 files changed (+79, -40 lines)


src/transformers/configuration_gpt2.py (+4)

@@ -103,6 +103,8 @@ class GPT2Config(PretrainedConfig):
             :class:`~transformers.GPT2DoubleHeadsModel` and :class:`~transformers.TFGPT2DoubleHeadsModel`.
 
             The dropout ratio to be used after the projection and activation.
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
 
     Example::
@@ -142,6 +144,7 @@ def __init__(
         summary_first_dropout=0.1,
         bos_token_id=50256,
         eos_token_id=50256,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -164,6 +167,7 @@ def __init__(
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
+        self.gradient_checkpointing = gradient_checkpointing
 
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
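
As a quick illustration of the new option (not part of the diff), a minimal usage sketch could look as follows; it assumes a working PyTorch install and relies only on the public GPT2Config / GPT2LMHeadModel / GPT2Tokenizer API:

# Minimal sketch: enable gradient checkpointing through the new config flag.
import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

config = GPT2Config.from_pretrained("gpt2", gradient_checkpointing=True)
model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
model.train()  # checkpointing only pays off when gradients are needed

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("The dog", return_tensors="pt")

# The forward pass stores fewer activations; they are recomputed during
# backward, trading extra compute for lower peak memory.
outputs = model(**inputs, labels=inputs["input_ids"], return_dict=True)
outputs.loss.backward()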

src/transformers/modeling_gpt2.py (+29, -11)

@@ -15,7 +15,6 @@
 # limitations under the License.
 """PyTorch OpenAI GPT-2 model."""
 
-
 import os
 import warnings
 from dataclasses import dataclass
@@ -624,16 +623,35 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
 
-            outputs = block(
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # checkpointing only works with tuple returns, not with lists
+                        return tuple(output for output in module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    layer_past,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
 
             hidden_states, present = outputs[:2]
             if use_cache is True:
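
The closure-plus-checkpoint pattern above can be illustrated in isolation. The sketch below is a toy reconstruction, not transformers code: ToyBlock and the tensor shapes are invented for the example, while create_custom_forward mirrors the wrapper in the diff:

# Toy, self-contained illustration of the checkpointing pattern used above.
import torch
from torch.utils.checkpoint import checkpoint


class ToyBlock(torch.nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask, use_cache, output_attentions):
        hidden_states = torch.relu(self.linear(hidden_states)) + attention_mask
        return (hidden_states,)  # blocks return a tuple of tensors


def create_custom_forward(module, use_cache, output_attentions):
    # The boolean flags are captured by closure, so checkpoint() only has to
    # pass tensors positionally to custom_forward.
    def custom_forward(*inputs):
        return tuple(output for output in module(*inputs, use_cache, output_attentions))

    return custom_forward


block = ToyBlock()
hidden_states = torch.randn(2, 4, 16, requires_grad=True)
attention_mask = torch.zeros(2, 4, 16)

# Activations inside ToyBlock are not stored; they are recomputed when
# backward() re-runs custom_forward under the hood.
outputs = checkpoint(create_custom_forward(block, False, False), hidden_states, attention_mask)
outputs[0].sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 4, 16])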

src/transformers/trainer.py (+3, -1)

@@ -679,8 +679,10 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=True,
+                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
             )
+            # find_unused_parameters breaks checkpointing as per
+            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
 
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", self.args.to_json_string())
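
For context, here is a standalone sketch of the same decision outside of Trainer; wrap_model_for_ddp is a made-up helper name, and it assumes torch.distributed.init_process_group has already been called before wrapping:

# Illustrative helper (not part of the diff): mirror the Trainer logic of
# disabling find_unused_parameters when the model checkpoints its gradients.
import torch
from torch.nn.parallel import DistributedDataParallel


def wrap_model_for_ddp(model: torch.nn.Module, local_rank: int) -> DistributedDataParallel:
    # With (reentrant) gradient checkpointing the forward pass is re-run during
    # backward; combined with find_unused_parameters=True, DDP can mark the same
    # parameter ready more than once and raise an error, so the flag is flipped.
    use_checkpointing = getattr(model.config, "gradient_checkpointing", False)
    return DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=not use_checkpointing,
    )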

tests/test_modeling_gpt2.py (+43, -28)

@@ -88,7 +88,7 @@ def __init__(
         self.bos_token_id = vocab_size - 1
         self.eos_token_id = vocab_size - 1
 
-    def prepare_config_and_inputs(self):
+    def prepare_config_and_inputs(self, gradient_checkpointing=False):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
         input_mask = None
@@ -127,6 +127,7 @@ def prepare_config_and_inputs(self):
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             return_dict=True,
+            gradient_checkpointing=gradient_checkpointing,
         )
 
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -269,6 +270,15 @@ def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mas
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+        model = GPT2LMHeadModel(config)
+        model.to(torch_device)
+
+        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def create_and_check_double_lm_head_model(
         self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
     ):
@@ -355,6 +365,10 @@ def test_gpt2_double_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
 
+    def test_gpt2_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -366,33 +380,34 @@ def test_model_from_pretrained(self):
 class GPT2ModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_gpt2(self):
-        model = GPT2LMHeadModel.from_pretrained("gpt2")
-        model.to(torch_device)
-        input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
-        expected_output_ids = [
-            464,
-            3290,
-            373,
-            1043,
-            287,
-            257,
-            2214,
-            1474,
-            262,
-            16246,
-            286,
-            2688,
-            290,
-            2688,
-            27262,
-            13,
-            198,
-            198,
-            464,
-            3290,
-        ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+        for checkpointing in [True, False]:
+            model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
+            model.to(torch_device)
+            input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
+            expected_output_ids = [
+                464,
+                3290,
+                373,
+                1043,
+                287,
+                257,
+                2214,
+                1474,
+                262,
+                16246,
+                286,
+                2688,
+                290,
+                2688,
+                27262,
+                13,
+                198,
+                198,
+                464,
+                3290,
+            ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
+            output_ids = model.generate(input_ids, do_sample=False)
+            self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
 
     @slow
     def test_lm_generate_distilgpt2(self):
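
Beyond the shape assertions in the new test, the property being relied on is that checkpointing changes memory and compute behaviour but not the numbers. Here is a hedged sketch of such a spot check, with arbitrary tiny config values and dropout disabled so the comparison is deterministic:

# Sketch: forward/backward with and without checkpointing on a tiny GPT-2,
# checking that the logits agree. Config sizes here are arbitrary.
import torch
from transformers import GPT2Config, GPT2LMHeadModel

input_ids = torch.randint(0, 99, (2, 8))
logits = {}

for checkpointing in (False, True):
    config = GPT2Config(
        vocab_size=99, n_positions=32, n_ctx=32, n_embd=32, n_layer=2, n_head=4,
        resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,  # no dropout -> deterministic
        gradient_checkpointing=checkpointing, return_dict=True,
    )
    torch.manual_seed(0)  # identical init for both runs
    model = GPT2LMHeadModel(config)
    model.train()

    output = model(input_ids, labels=input_ids)
    output.loss.backward()  # backward must work in both modes
    logits[checkpointing] = output.logits.detach()

print(torch.allclose(logits[False], logits[True], atol=1e-5))  # expect True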
