@@ -603,6 +603,15 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def update_keys_to_ignore(self, config, del_keys_to_ignore):
+        """Remove some keys from ignore list"""
+        if not config.tie_word_embeddings:
+            # must make a new list, or the class variable gets modified!
+            self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
+            self._keys_to_ignore_on_load_missing = [
+                k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
+            ]
+
 
 
 ROBERTA_START_DOCSTRING = r"""
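The inline comment above ("must make a new list, or the class variable gets modified!") is the crux of the helper: `_keys_to_ignore_on_save` and `_keys_to_ignore_on_load_missing` are class attributes shared by every instance, so the filtering must rebind a fresh instance-level list rather than mutate the shared one. A standalone illustration of the pattern (the class and method names below are invented for the sketch, not part of the patch):

# Illustrative only: why update_keys_to_ignore() builds new lists instead of
# mutating the class-level ones (names below are invented for the sketch).
class Sketch:
    _keys_to_ignore_on_save = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def drop_by_mutation(self, del_keys):
        # BAD: .remove() edits the class attribute shared by every instance.
        for key in del_keys:
            self._keys_to_ignore_on_save.remove(key)

    def drop_by_copy(self, del_keys):
        # GOOD: rebinds an instance-level copy; the class attribute is untouched.
        self._keys_to_ignore_on_save = [
            k for k in self._keys_to_ignore_on_save if k not in del_keys
        ]


a, b = Sketch(), Sketch()
a.drop_by_copy(["lm_head.decoder.weight"])
assert Sketch._keys_to_ignore_on_save == ["lm_head.decoder.weight", "lm_head.decoder.bias"]

b.drop_by_mutation(["lm_head.decoder.weight"])
assert Sketch._keys_to_ignore_on_save == ["lm_head.decoder.bias"]  # leaked into the class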
@@ -864,7 +873,8 @@ def forward(
     """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
 )
 class RobertaForCausalLM(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
@@ -876,6 +886,9 @@ def __init__(self, config):
         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.lm_head = RobertaLMHead(config)
 
+        # The LM head weights require special treatment only when they are tied with the word embeddings
+        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
+
         self.init_weights()
 
     def get_output_embeddings(self):
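Why the decoder weight goes on both ignore lists and is then conditionally removed: when `config.tie_word_embeddings` is `True` (the default), `lm_head.decoder.weight` is the same tensor as the input word embeddings, so it is redundant in the checkpoint and harmless when reported missing; when the embeddings are untied it is an independent parameter that must be saved and loaded. A rough sketch of the resulting behaviour, assuming a build of `transformers` with this patch applied (the tiny config sizes are arbitrary):

# Hedged sketch, assuming this patch is installed; config sizes are arbitrary.
from transformers import RobertaConfig, RobertaForMaskedLM

def tiny_config(tie):
    return RobertaConfig(vocab_size=32, hidden_size=16, num_hidden_layers=1,
                         num_attention_heads=2, intermediate_size=32,
                         tie_word_embeddings=tie)

tied = RobertaForMaskedLM(tiny_config(True))
untied = RobertaForMaskedLM(tiny_config(False))

# Tied: the decoder weight is literally the input embedding matrix, so it can
# stay on the ignore lists and be skipped when writing the checkpoint.
assert tied.lm_head.decoder.weight is tied.roberta.embeddings.word_embeddings.weight
assert "lm_head.decoder.weight" in tied._keys_to_ignore_on_save

# Untied: update_keys_to_ignore() stripped it from both lists, so the
# independent weight is saved and is expected to be present on load.
assert "lm_head.decoder.weight" not in untied._keys_to_ignore_on_save
assert "lm_head.decoder.weight" not in untied._keys_to_ignore_on_load_missing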
@@ -1010,7 +1023,8 @@ def _reorder_cache(self, past, beam_idx):
 
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
 class RobertaForMaskedLM(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
@@ -1025,6 +1039,9 @@ def __init__(self, config):
         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.lm_head = RobertaLMHead(config)
 
+        # The LM head weights require special treatment only when they are tied with the word embeddings
+        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
+
         self.init_weights()
 
     def get_output_embeddings(self):
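The `RobertaForMaskedLM` changes mirror the `RobertaForCausalLM` ones. A minimal round-trip check of the untied case, again assuming a build of `transformers` with this patch applied (config sizes are arbitrary): even though `lm_head.decoder.weight` sits on the class-level ignore list, an untied LM head is still written to and restored from the checkpoint.

# Hedged round-trip sketch, assuming this patch is installed.
import tempfile
import torch
from transformers import RobertaConfig, RobertaForMaskedLM

config = RobertaConfig(vocab_size=32, hidden_size=16, num_hidden_layers=1,
                       num_attention_heads=2, intermediate_size=32,
                       tie_word_embeddings=False)
model = RobertaForMaskedLM(config)

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)                      # untied decoder weight is written out
    reloaded = RobertaForMaskedLM.from_pretrained(tmp)

# The independent LM head weight made it through save/load unchanged.
assert torch.equal(model.lm_head.decoder.weight, reloaded.lm_head.decoder.weight)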