Commit 2d1d921

stas00 and LysandreJik authored

[roberta] fix lm_head.decoder.weight ignore_key handling (#12446)

* fix lm_head.decoder.weight ignore_key handling
* fix the mutable class variable
* Update src/transformers/models/roberta/modeling_roberta.py
* replicate the comment
* make deterministic

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
1 parent 7f0027d commit 2d1d921

File tree: 4 files changed, +48 -8 lines changed
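The "fix the mutable class variable" bullet above refers to `_keys_to_ignore_on_save` being a class attribute: removing entries from it in place would change the list for every model sharing that class, not just the instance being configured. A minimal standalone sketch of the pitfall and of the list-copy fix (the `Buggy` and `Fixed` class names are invented for illustration, they are not part of the library):

class Buggy:
    _keys_to_ignore_on_save = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def untie(self):
        # BUG: remove() mutates the single list object shared by the whole class
        self._keys_to_ignore_on_save.remove("lm_head.decoder.weight")


class Fixed:
    _keys_to_ignore_on_save = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def untie(self):
        # FIX: build a new list and rebind it on the instance only
        self._keys_to_ignore_on_save = [
            k for k in self._keys_to_ignore_on_save if k != "lm_head.decoder.weight"
        ]


Buggy().untie()
print(Buggy._keys_to_ignore_on_save)  # ['lm_head.decoder.bias'] -- class state corrupted

Fixed().untie()
print(Fixed._keys_to_ignore_on_save)  # class attribute left intact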

src/transformers/modeling_utils.py (+1 -1)

@@ -445,7 +445,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_unexpected = None
     # a list of of tensor names to ignore when saving the model (useful for keys that aren't
-    # trained, but which are deterministic)
+    # trained, but which are deterministic, or tied variables)
     _keys_to_ignore_on_save = None

     is_parallelizable = False
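For context, `_keys_to_ignore_on_save` is consulted when the model is serialized: entries named in it are dropped from the state dict before the checkpoint is written (the updated common test below checks exactly that). The helper here is only an illustrative sketch of that idea, not the library's actual `save_pretrained` implementation:

import torch


def drop_ignored_keys(state_dict, keys_to_ignore_on_save):
    # keep only the entries whose names are not in the ignore list
    if not keys_to_ignore_on_save:
        return state_dict
    ignored = set(keys_to_ignore_on_save)
    return {name: tensor for name, tensor in state_dict.items() if name not in ignored}


# toy state dict standing in for real RoBERTa parameters
toy = {
    "lm_head.decoder.weight": torch.zeros(2, 2),
    "lm_head.decoder.bias": torch.zeros(2),
    "lm_head.dense.weight": torch.zeros(2, 2),
}
kept = drop_ignored_keys(toy, ["lm_head.decoder.weight", "lm_head.decoder.bias"])
assert list(kept) == ["lm_head.dense.weight"]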

src/transformers/models/roberta/modeling_roberta.py (+19 -2)

@@ -603,6 +603,15 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def update_keys_to_ignore(self, config, del_keys_to_ignore):
+        """Remove some keys from ignore list"""
+        if not config.tie_word_embeddings:
+            # must make a new list, or the class variable gets modified!
+            self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
+            self._keys_to_ignore_on_load_missing = [
+                k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
+            ]
+

 ROBERTA_START_DOCSTRING = r"""

@@ -864,7 +873,8 @@ def forward(
     """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
 )
 class RobertaForCausalLM(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):

@@ -876,6 +886,9 @@ def __init__(self, config):
         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.lm_head = RobertaLMHead(config)

+        # The LM head weights require special treatment only when they are tied with the word embeddings
+        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
+
         self.init_weights()

     def get_output_embeddings(self):

@@ -1010,7 +1023,8 @@ def _reorder_cache(self, past, beam_idx):

 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
 class RobertaForMaskedLM(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):

@@ -1025,6 +1039,9 @@ def __init__(self, config):
         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.lm_head = RobertaLMHead(config)

+        # The LM head weights require special treatment only when they are tied with the word embeddings
+        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
+
         self.init_weights()

     def get_output_embeddings(self):
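The new constructor comment ("special treatment only when they are tied with the word embeddings") can be verified directly: with `tie_word_embeddings=True` the LM head decoder weight is the very same tensor as the input embedding matrix, so skipping it at save time loses nothing, while the untied case keeps it in the save list only for the redundant bias. A quick check mirroring the new test, using made-up tiny config dimensions instead of a real checkpoint (requires torch and a transformers version containing this commit):

from transformers import RobertaConfig, RobertaForMaskedLM

tiny = dict(vocab_size=100, hidden_size=32, num_hidden_layers=2,
            num_attention_heads=2, intermediate_size=64)

tied = RobertaForMaskedLM(RobertaConfig(tie_word_embeddings=True, **tiny))
# tied: decoder weight and word embeddings are the same Parameter object
assert tied.lm_head.decoder.weight is tied.roberta.embeddings.word_embeddings.weight
assert tied._keys_to_ignore_on_save == ["lm_head.decoder.weight", "lm_head.decoder.bias"]

untied = RobertaForMaskedLM(RobertaConfig(tie_word_embeddings=False, **tiny))
# untied: the decoder weight is independent, so it is no longer ignored on save
assert untied.lm_head.decoder.weight is not untied.roberta.embeddings.word_embeddings.weight
assert untied._keys_to_ignore_on_save == ["lm_head.decoder.bias"]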

tests/test_modeling_common.py (+3 -3)

@@ -164,7 +164,7 @@ def test_save_load(self):
             max_diff = np.amax(np.abs(out_1 - out_2))
             self.assertLessEqual(max_diff, 1e-5)

-    def test_save_load__keys_to_ignore_on_save(self):
+    def test_save_load_keys_to_ignore_on_save(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:

@@ -175,15 +175,15 @@ def test_save_load__keys_to_ignore_on_save(self):

             # check the keys are in the original state_dict
             for k in _keys_to_ignore_on_save:
-                self.assertIn(k, model.state_dict())
+                self.assertIn(k, model.state_dict().keys(), "\n".join(model.state_dict().keys()))

             # check that certain keys didn't get saved with the model
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
                 output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
                 state_dict_saved = torch.load(output_model_file)
                 for k in _keys_to_ignore_on_save:
-                    self.assertNotIn(k, state_dict_saved)
+                    self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))

                 # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
                 load_result = model.load_state_dict(state_dict_saved, strict=False)

tests/test_modeling_roberta.py (+25 -2)

@@ -15,9 +15,10 @@


 import unittest
+from copy import deepcopy

 from transformers import is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device

 from .test_configuration_common import ConfigTester
 from .test_generation_utils import GenerationTesterMixin

@@ -43,6 +44,8 @@
         create_position_ids_from_input_ids,
     )

+ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
+

 class RobertaModelTester:
     def __init__(

@@ -475,7 +478,7 @@ def test_create_position_ids_from_inputs_embeds(self):


 @require_torch
-class RobertaModelIntegrationTest(unittest.TestCase):
+class RobertaModelIntegrationTest(TestCasePlus):
     @slow
     def test_inference_masked_lm(self):
         model = RobertaForMaskedLM.from_pretrained("roberta-base")

@@ -527,3 +530,23 @@ def test_inference_classification_head(self):
         # expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()

         self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
+
+    # XXX: this might be a candidate for common tests if we have many of those
+    def test_lm_head_ignore_keys(self):
+        keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+        keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
+        config = RobertaConfig.from_pretrained(ROBERTA_TINY)
+        config_tied = deepcopy(config)
+        config_tied.tie_word_embeddings = True
+        config_untied = deepcopy(config)
+        config_untied.tie_word_embeddings = False
+        for cls in [RobertaForMaskedLM, RobertaForCausalLM]:
+            model = cls(config_tied)
+            self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls)
+
+            # the keys should be different when embeddings aren't tied
+            model = cls(config_untied)
+            self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls)
+
+            # test that saving works with updated ignore keys - just testing that it doesn't fail
+            model.save_pretrained(self.get_auto_remove_tmp_dir())
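Building on the new test above, one can also check the round trip: with untied embeddings the decoder weight now lands in the checkpoint and should be restored exactly by `from_pretrained`. This is a hedged sketch with invented tiny dimensions (not the sshleifer/tiny-distilroberta-base checkpoint used by the test), not part of the PR:

import tempfile

import torch
from transformers import RobertaConfig, RobertaForMaskedLM

config = RobertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                       num_attention_heads=2, intermediate_size=64,
                       tie_word_embeddings=False)
model = RobertaForMaskedLM(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    reloaded = RobertaForMaskedLM.from_pretrained(tmp_dir)

# the independent decoder weight survives the save/load round trip unchanged
assert torch.equal(model.lm_head.decoder.weight, reloaded.lm_head.decoder.weight)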
