@@ -411,29 +411,6 @@ def dummy_inputs(self):
         }
         return dummy_inputs
 
-    def get_input_embeddings(self):
-        base_model = getattr(self, self.base_model_prefix, self)
-
-        return base_model.shared
-
-    def set_input_embeddings(self, value):
-        base_model = getattr(self, self.base_model_prefix, self)
-
-        try:
-            base_model.shared.weight = value
-        except AttributeError:
-            self(self.dummy_inputs)
-            base_model.shared.weight = value
-
-        base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
-
-        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
-            pass
-
-        embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
-        base_model.encoder.set_embed_tokens(embed_tokens)
-        base_model.decoder.set_embed_tokens(embed_tokens)
-
     @tf.function(
         input_signature=[
             {
@@ -605,6 +582,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
         self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
         self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
 
+    def get_embed_tokens(self):
+        return self.embed_tokens
+
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens
 
@@ -744,6 +724,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
 
         self.dropout = tf.keras.layers.Dropout(config.dropout)
 
+    def get_embed_tokens(self):
+        return self.embed_tokens
+
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens
 
@@ -871,13 +854,15 @@ def call(
         hidden_states = self.dropout(hidden_states, training=inputs["training"])
 
         # decoder layers
-        all_hidden_states = ()
-        all_self_attns = ()
-        present_key_values = ()
+        all_hidden_states = () if inputs["output_hidden_states"] else None
+        all_self_attns = () if inputs["output_attentions"] else None
+        present_key_values = () if inputs["use_cache"] else None
+
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if inputs["output_hidden_states"]:
                 all_hidden_states += (hidden_states,)
+
             dropout_probability = random.uniform(0, 1)
 
             if inputs["training"] and (dropout_probability < self.layerdrop):
@@ -901,12 +886,12 @@ def call(
 
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
-        else:
-            all_hidden_states = None
 
-        all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
+        if inputs["output_attentions"]:
+            all_self_attns = list(all_self_attns)
 
-        present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
+        if inputs["use_cache"]:
+            present_key_values = (inputs["encoder_hidden_states"], present_key_values)
 
         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns
@@ -919,16 +904,14 @@ def call(
         )
 
 
-@add_start_docstrings(
-    "The bare BART Model outputting raw hidden-states without any specific head on top.",
-    BART_START_DOCSTRING,
-)
 @keras_serializable
-class TFBartModel(TFBartPretrainedModel):
-    base_model_prefix = "model"
+class TFBartMainLayer(tf.keras.layers.Layer):
+    config_class = BartConfig
 
-    def __init__(self, config: BartConfig, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+    def __init__(self, config: BartConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.config = config
         self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
 
         with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
@@ -942,19 +925,20 @@ def __init__(self, config: BartConfig, *inputs, **kwargs):
         self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
         self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
 
-    def get_encoder(self):
-        return self.encoder
+    def get_input_embeddings(self):
+        return self.shared
 
-    def get_decoder(self):
-        return self.decoder
+    def set_input_embeddings(self, new_embeddings):
+        self.shared.weight = new_embeddings
+        self.shared.vocab_size = self.shared.weight.shape[0]
+        # retrieve correct absolute scope for embed token wrapper
+        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
+            pass
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
+        self.encoder.set_embed_tokens(embed_tokens)
+        self.decoder.set_embed_tokens(embed_tokens)
 
-    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @add_code_sample_docstrings(
-        tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="facebook/bart-large",
-        output_type=TFSeq2SeqModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
     def call(
         self,
         input_ids=None,
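
The hunk above turns `set_input_embeddings` into a `TFBartMainLayer` method that rewires both encoder and decoder to the shared embedding. A minimal sketch of how it behaves, assuming an already-built `TFBartMainLayer` instance named `main_layer` and an arbitrary resized vocabulary of 50270 tokens (both are assumptions for illustration, not part of the commit):

```python
import tensorflow as tf

# hypothetical replacement embedding matrix of shape (new_vocab_size, d_model)
new_weight = tf.Variable(tf.random.normal((50270, main_layer.config.d_model)))
main_layer.set_input_embeddings(new_weight)

# vocab_size is refreshed from the new weight's first dimension ...
assert main_layer.shared.vocab_size == 50270
# ... and encoder and decoder now point at the same wrapped embedding object,
# via the get_embed_tokens accessors added earlier in this commit.
assert main_layer.encoder.get_embed_tokens() is main_layer.decoder.get_embed_tokens()
```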
@@ -1053,8 +1037,86 @@ def call(
             encoder_attentions=inputs["encoder_outputs"].attentions,
         )
 
+
+@add_start_docstrings(
+    "The bare BART Model outputting raw hidden-states without any specific head on top.",
+    BART_START_DOCSTRING,
+)
+class TFBartModel(TFBartPretrainedModel):
+    def __init__(self, config: BartConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.model = TFBartMainLayer(config, name="model")
+
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="facebook/bart-large",
+        output_type=TFSeq2SeqModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
+        past_key_values=None,
+        inputs_embeds=None,
+        decoder_inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs
+    ):
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+
+        outputs = self.model(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            decoder_input_ids=inputs["decoder_input_ids"],
+            decoder_attention_mask=inputs["decoder_attention_mask"],
+            encoder_outputs=inputs["encoder_outputs"],
+            past_key_values=inputs["past_key_values"],
+            inputs_embeds=inputs["inputs_embeds"],
+            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
+            use_cache=inputs["use_cache"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )
+
+        return outputs
+
     def serving_output(self, output):
-        pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
         dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
         dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
         enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
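
With this hunk, `TFBartModel` becomes a thin `TFBartPretrainedModel` wrapper: it owns a single `TFBartMainLayer` named `"model"`, forwards `call` to it through `input_processing`, and resolves `get_encoder`/`get_decoder` through `self.model`. A rough usage sketch (the checkpoint name is the one in the `add_code_sample_docstrings` decorator above; the input sentence and dict-style call are illustrative assumptions, not taken from the commit):

```python
from transformers import BartTokenizer, TFBartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = TFBartModel.from_pretrained("facebook/bart-large")

inputs = tokenizer("Hello, BART!", return_tensors="tf")
outputs = model(inputs)  # forwarded to the inner TFBartMainLayer named "model"

# encoder and decoder now live on the main layer, not directly on the wrapper
assert model.get_encoder() is model.model.encoder
assert model.get_decoder() is model.model.decoder
```

Naming the inner layer `"model"` keeps the weight scope that the old `TFBartModel` base model used, which is presumably what keeps previously saved checkpoints loadable after the refactor.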
@@ -1083,7 +1145,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
-        self.model = TFBartModel(config, name="model")
+        self.model = TFBartMainLayer(config, name="model")
         self.use_cache = config.use_cache
         # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
         self.final_logits_bias = self.add_weight(
@@ -1199,7 +1261,7 @@ def call(
         )
 
     def serving_output(self, output):
-        pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
         dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
         enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
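
Both `serving_output` changes fix the same slip: the old right-hand side ended with a trailing comma, so `pkv` was always a one-element tuple rather than the past-key-values value (or `None`) itself. A plain-Python illustration of the difference, with made-up stand-in values:

```python
use_cache = False
past_key_values = "dummy-cache"  # stand-in for the real nested tuple of tensors

# old form: the trailing comma builds a 1-tuple, so pkv is (None,) instead of None
pkv = (past_key_values if use_cache else None,)
assert pkv == (None,)

# new form: pkv is the bare value (here None), matching how dec_hs, dec_attns
# and enc_hs are handled on the surrounding lines
pkv = past_key_values if use_cache else None
assert pkv is None
```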