
Commit a7dabfb

Fix TF s2s models (#9478)
* Fix Seq2Seq models for serving
* Apply style
* Fix Longformer
* Fix mBart/Pegasus/Blenderbot
* Apply style
* Add a main intermediate layer
* Apply style
* Remove import
* Apply tf.function to Longformer
* Fix utils check_copy
* Update S2S template
* Fix BART + Blenderbot
* Fix BlenderbotSmall
* Fix BlenderbotSmall
* Fix BlenderbotSmall
* Fix MBart
* Fix Marian
* Fix Pegasus + template
* Apply style
* Fix common attributes test
* Forgot to fix the LED test
* Apply Patrick's comment on LED Decoder
1 parent 23e5a36 commit a7dabfb

20 files changed: +1024 / -665 lines
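The changes below touch the TensorFlow serving path: each TF model exposes a serving entry point wrapped in tf.function with a fixed input signature, and serving_output converts the model outputs into tensors a SavedModel can return. As a rough sketch of the export path being fixed, assuming a transformers version from around this commit with TensorFlow installed (the checkpoint name and export directory are illustrative, not taken from this commit):

    import tensorflow as tf
    from transformers import TFBartForConditionalGeneration

    # Any BART checkpoint with TF weights should exercise the same code path.
    model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large")

    # `model.serving` is the tf.function-wrapped entry point; its outputs pass through
    # `serving_output`, whose past_key_values handling is corrected in this diff.
    tf.saved_model.save(model, "exported_bart", signatures=model.serving)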

src/transformers/modeling_tf_utils.py (+5 / -2)
@@ -322,6 +322,7 @@ def input_processing(func, config, input_ids, **kwargs):
     """
     signature = dict(inspect.signature(func).parameters)
     signature.pop("kwargs", None)
+    signature.pop("self", None)
     parameter_names = list(signature.keys())
     output = {}
     allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
@@ -346,6 +347,8 @@ def input_processing(func, config, input_ids, **kwargs):
             f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
         )

+    kwargs.pop("kwargs_call")
+
     for k, v in kwargs.items():
         if isinstance(v, allowed_types) or v is None:
             output[k] = v
@@ -356,8 +359,8 @@ def input_processing(func, config, input_ids, **kwargs):
         for i, input in enumerate(input_ids):
             # EagerTensors don't allow to use the .name property so we check for a real Tensor
             if type(input) == tf.Tensor:
-                # Tensor names have always the pattern name:device_id then we check only the
-                # name and not the device id
+                # Tensor names have always the pattern `name:id` then we check only the
+                # `name` part
                 tensor_name = input.name.split(":")[0]

                 if tensor_name in parameter_names:
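For reference, the `name:id` pattern mentioned in the updated comment comes from graph-mode tensors; a small self-contained sketch (the placeholder name is illustrative):

    import tensorflow as tf

    # When Keras traces a model (e.g. for serving), inputs are graph tensors whose
    # names follow the "name:id" pattern; input_processing keeps only the part before
    # the colon and matches it against the call signature's parameter names.
    graph = tf.Graph()
    with graph.as_default():
        input_ids = tf.compat.v1.placeholder(tf.int32, shape=(None, None), name="input_ids")
        print(input_ids.name)                # input_ids:0
        print(input_ids.name.split(":")[0])  # input_ids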

src/transformers/models/bart/modeling_tf_bart.py (+114 / -52)
@@ -411,29 +411,6 @@ def dummy_inputs(self):
         }
         return dummy_inputs

-    def get_input_embeddings(self):
-        base_model = getattr(self, self.base_model_prefix, self)
-
-        return base_model.shared
-
-    def set_input_embeddings(self, value):
-        base_model = getattr(self, self.base_model_prefix, self)
-
-        try:
-            base_model.shared.weight = value
-        except AttributeError:
-            self(self.dummy_inputs)
-            base_model.shared.weight = value
-
-        base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
-
-        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
-            pass
-
-        embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
-        base_model.encoder.set_embed_tokens(embed_tokens)
-        base_model.decoder.set_embed_tokens(embed_tokens)
-
     @tf.function(
         input_signature=[
             {
@@ -605,6 +582,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
         self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
         self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")

+    def get_embed_tokens(self):
+        return self.embed_tokens
+
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

@@ -744,6 +724,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings

         self.dropout = tf.keras.layers.Dropout(config.dropout)

+    def get_embed_tokens(self):
+        return self.embed_tokens
+
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

@@ -871,13 +854,15 @@ def call(
         hidden_states = self.dropout(hidden_states, training=inputs["training"])

         # decoder layers
-        all_hidden_states = ()
-        all_self_attns = ()
-        present_key_values = ()
+        all_hidden_states = () if inputs["output_hidden_states"] else None
+        all_self_attns = () if inputs["output_attentions"] else None
+        present_key_values = () if inputs["use_cache"] else None
+
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if inputs["output_hidden_states"]:
                 all_hidden_states += (hidden_states,)
+
             dropout_probability = random.uniform(0, 1)

             if inputs["training"] and (dropout_probability < self.layerdrop):
@@ -901,12 +886,12 @@ def call(

         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
-        else:
-            all_hidden_states = None

-        all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
+        if inputs["output_attentions"]:
+            all_self_attns = list(all_self_attns)

-        present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
+        if inputs["use_cache"]:
+            present_key_values = (inputs["encoder_hidden_states"], present_key_values)

         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns
@@ -919,16 +904,14 @@ def call(
         )


-@add_start_docstrings(
-    "The bare BART Model outputting raw hidden-states without any specific head on top.",
-    BART_START_DOCSTRING,
-)
 @keras_serializable
-class TFBartModel(TFBartPretrainedModel):
-    base_model_prefix = "model"
+class TFBartMainLayer(tf.keras.layers.Layer):
+    config_class = BartConfig

-    def __init__(self, config: BartConfig, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+    def __init__(self, config: BartConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.config = config
         self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")

         with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
@@ -942,19 +925,20 @@ def __init__(self, config: BartConfig, *inputs, **kwargs):
         self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
         self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")

-    def get_encoder(self):
-        return self.encoder
+    def get_input_embeddings(self):
+        return self.shared

-    def get_decoder(self):
-        return self.decoder
+    def set_input_embeddings(self, new_embeddings):
+        self.shared.weight = new_embeddings
+        self.shared.vocab_size = self.shared.weight.shape[0]
+        # retrieve correct absolute scope for embed token wrapper
+        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
+            pass
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
+        self.encoder.set_embed_tokens(embed_tokens)
+        self.decoder.set_embed_tokens(embed_tokens)

-    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @add_code_sample_docstrings(
-        tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="facebook/bart-large",
-        output_type=TFSeq2SeqModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
     def call(
         self,
         input_ids=None,
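The two hunks above turn the old model body into a @keras_serializable TFBartMainLayer; the next hunk re-introduces TFBartModel as a thin wrapper around it. The same delegation pattern in miniature (a toy layer with illustrative names and sizes, not taken from the diff):

    import tensorflow as tf

    class ToyMainLayer(tf.keras.layers.Layer):
        # Plays the role of TFBartMainLayer: owns the real sub-modules.
        def __init__(self, hidden_size=8, **kwargs):
            super().__init__(**kwargs)
            self.encoder = tf.keras.layers.Dense(hidden_size, name="encoder")
            self.decoder = tf.keras.layers.Dense(hidden_size, name="decoder")

        def call(self, inputs):
            return self.decoder(self.encoder(inputs))

    class ToyModel(tf.keras.Model):
        # Plays the role of TFBartModel: a thin wrapper that only delegates.
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.model = ToyMainLayer(name="model")

        def call(self, inputs):
            return self.model(inputs)

Because both TFBartModel and TFBartForConditionalGeneration build the shared layer under the same name="model" scope, weight names stay unchanged and existing checkpoints keep loading; that appears to be the motivation for the intermediate layer, though the commit message does not spell it out.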
@@ -1053,8 +1037,86 @@ def call(
             encoder_attentions=inputs["encoder_outputs"].attentions,
         )

+
+@add_start_docstrings(
+    "The bare BART Model outputting raw hidden-states without any specific head on top.",
+    BART_START_DOCSTRING,
+)
+class TFBartModel(TFBartPretrainedModel):
+    def __init__(self, config: BartConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.model = TFBartMainLayer(config, name="model")
+
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="facebook/bart-large",
+        output_type=TFSeq2SeqModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
+        past_key_values=None,
+        inputs_embeds=None,
+        decoder_inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs
+    ):
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+
+        outputs = self.model(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            decoder_input_ids=inputs["decoder_input_ids"],
+            decoder_attention_mask=inputs["decoder_attention_mask"],
+            encoder_outputs=inputs["encoder_outputs"],
+            past_key_values=inputs["past_key_values"],
+            inputs_embeds=inputs["inputs_embeds"],
+            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
+            use_cache=inputs["use_cache"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )
+
+        return outputs
+
     def serving_output(self, output):
-        pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
         dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
         dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
         enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
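The one-line serving_output change at the end of the hunk above (repeated for the conditional-generation head below) removes a trailing comma that wrapped the whole conditional expression in a 1-tuple. The effect in isolation (stand-in values, not model outputs):

    use_cache = False
    past_key_values = None  # stand-in for the model's cache

    # Old form: the trailing comma turns the entire expression into a 1-tuple.
    pkv_old = (past_key_values if use_cache else None,)
    # Fixed form: just the conditional expression.
    pkv_new = past_key_values if use_cache else None

    print(pkv_old)  # (None,)
    print(pkv_new)  # None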
@@ -1083,7 +1145,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
-        self.model = TFBartModel(config, name="model")
+        self.model = TFBartMainLayer(config, name="model")
         self.use_cache = config.use_cache
         # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
         self.final_logits_bias = self.add_weight(
@@ -1199,7 +1261,7 @@ def call(
         )

     def serving_output(self, output):
-        pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
         dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
         dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
         enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
