@@ -354,6 +354,69 @@ def test_pt_tf_model_equivalence(self):
354
354
max_diff = np .amax (np .abs (tfo - pto ))
355
355
self .assertLessEqual (max_diff , 4e-2 )
356
356
357
def test_train_pipeline_custom_model(self):
    """Train, save and reload a Keras model built around each serializable main layer.

    For every Keras-serializable ``*MainLayer`` class found in the modules of
    ``self.all_model_classes``, this test:

    1. instantiates the layer from the tester's config,
    2. wraps it in a functional ``tf.keras`` model with a softmax ``Dense`` head,
    3. fits the model for one epoch on random integer labels,
    4. saves it to a temporary ``.h5`` file and loads it back with the layer
       registered in ``custom_objects``,
    5. checks the reloaded object is a ``tf.keras.Model`` and calls it on the
       original inputs.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Collect the classes named `*MainLayer` that directly subclass
    # `tf.keras.layers.Layer` and carry a truthy `_keras_serializable` attribute.
    tf_main_layer_classes = set()
    for model_class in self.all_model_classes:
        module = import_module(model_class.__module__)
        for member_name in dir(module):
            if not member_name.endswith("MainLayer"):
                continue
            member = getattr(module, member_name)
            if (
                isinstance(member, type)
                and tf.keras.layers.Layer in member.__bases__
                and getattr(member, "_keras_serializable", False)
            ):
                tf_main_layer_classes.add(member)

    # Loop-invariant: fall back to 2 labels when the tester does not define `num_labels`.
    num_labels = getattr(self.model_tester, "num_labels", 2)

    for main_layer_class in tf_main_layer_classes:
        is_t5 = "T5" in main_layer_class.__name__
        # Classes Keras must be told about in order to deserialize the saved model.
        custom_objects = {main_layer_class.__name__: main_layer_class}

        # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
        if is_t5:
            # Use the same values as TFT5ModelTester for this shared layer
            shared = TFSharedEmbeddings(self.model_tester.vocab_size, self.model_tester.hidden_size, name="shared")
            config.use_cache = False
            main_layer = main_layer_class(config, embed_tokens=shared)
            # `inputs_dict` is shared by every iteration: `pop` with a default keeps
            # this safe when the key is absent or was already removed earlier.
            inputs_dict.pop("use_cache", None)
            custom_objects["TFSharedEmbeddings"] = TFSharedEmbeddings
        else:
            main_layer = main_layer_class(config)

        # Symbolic placeholders mirroring each real input, minus its leading dimension.
        symbolic_inputs = {
            name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
        }

        # NOTE(review): labels are random integers in [0, num_labels) with shape
        # (batch_size, 1), while the head below is a `num_labels`-way softmax trained
        # with "binary_crossentropy" -- confirm this pairing is intended.
        labels = np.random.randint(0, num_labels, (self.model_tester.batch_size, 1))
        dataset = tf.data.Dataset.from_tensor_slices((inputs_dict, labels)).batch(1)

        hidden_states = main_layer(symbolic_inputs)[0]
        outputs = tf.keras.layers.Dense(num_labels, activation="softmax", name="outputs")(hidden_states)
        model = tf.keras.models.Model(inputs=symbolic_inputs, outputs=[outputs])

        model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc"])
        model.fit(dataset, epochs=1)

        with tempfile.TemporaryDirectory() as tmpdirname:
            filepath = os.path.join(tmpdirname, "keras_model.h5")
            model.save(filepath)
            model = tf.keras.models.load_model(filepath, custom_objects=custom_objects)
            # unittest assertion instead of a bare `assert`, which is stripped under `-O`.
            self.assertIsInstance(model, tf.keras.Model)
            model(inputs_dict)
419
+
357
420
def test_compile_tf_model (self ):
358
421
config , inputs_dict = self .model_tester .prepare_config_and_inputs_for_common ()
359
422
0 commit comments