Skip to content

Commit e6d23a4

Browse files
authored
Improve test_pt_tf_model_equivalence on PT side (#16731)
* Update test_pt_tf_model_equivalence on PT side Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 3dd57b1 commit e6d23a4

File tree

5 files changed

+255
-602
lines changed

5 files changed

+255
-602
lines changed

tests/clip/test_modeling_clip.py

-144
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from transformers.testing_utils import (
2929
is_flax_available,
3030
is_pt_flax_cross_test,
31-
is_pt_tf_cross_test,
3231
require_torch,
3332
require_vision,
3433
slow,
@@ -602,149 +601,6 @@ def test_load_vision_text_config(self):
602601
text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
603602
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
604603

605-
# overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
606-
@is_pt_tf_cross_test
607-
def test_pt_tf_model_equivalence(self):
608-
import numpy as np
609-
import tensorflow as tf
610-
611-
import transformers
612-
613-
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
614-
615-
for model_class in self.all_model_classes:
616-
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
617-
618-
if not hasattr(transformers, tf_model_class_name):
619-
# transformers does not have TF version yet
620-
return
621-
622-
tf_model_class = getattr(transformers, tf_model_class_name)
623-
624-
config.output_hidden_states = True
625-
626-
tf_model = tf_model_class(config)
627-
pt_model = model_class(config)
628-
629-
# make sure only tf inputs that actually exist in the function args are forwarded
630-
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
631-
632-
# remove all head masks
633-
tf_input_keys.discard("head_mask")
634-
tf_input_keys.discard("cross_attn_head_mask")
635-
tf_input_keys.discard("decoder_head_mask")
636-
637-
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
638-
pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}
639-
640-
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
641-
pt_model.eval()
642-
tf_inputs_dict = {}
643-
for key, tensor in pt_inputs.items():
644-
# skip key that does not exist in tf
645-
if type(tensor) == bool:
646-
tf_inputs_dict[key] = tensor
647-
elif key == "input_values":
648-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
649-
elif key == "pixel_values":
650-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
651-
else:
652-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
653-
654-
# Check we can load pt model in tf and vice-versa with model => model functions
655-
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
656-
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
657-
658-
# need to rename encoder-decoder "inputs" for PyTorch
659-
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
660-
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
661-
662-
with torch.no_grad():
663-
pto = pt_model(**pt_inputs)
664-
tfo = tf_model(tf_inputs_dict, training=False)
665-
666-
self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
667-
for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
668-
669-
if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
670-
continue
671-
672-
tf_out = tf_output.numpy()
673-
pt_out = pt_output.cpu().numpy()
674-
675-
self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")
676-
677-
if len(tf_out.shape) > 0:
678-
679-
tf_nans = np.copy(np.isnan(tf_out))
680-
pt_nans = np.copy(np.isnan(pt_out))
681-
682-
pt_out[tf_nans] = 0
683-
tf_out[tf_nans] = 0
684-
pt_out[pt_nans] = 0
685-
tf_out[pt_nans] = 0
686-
687-
max_diff = np.amax(np.abs(tf_out - pt_out))
688-
self.assertLessEqual(max_diff, 4e-2)
689-
690-
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
691-
with tempfile.TemporaryDirectory() as tmpdirname:
692-
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
693-
torch.save(pt_model.state_dict(), pt_checkpoint_path)
694-
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
695-
696-
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
697-
tf_model.save_weights(tf_checkpoint_path)
698-
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
699-
pt_model = pt_model.to(torch_device)
700-
701-
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
702-
pt_model.eval()
703-
tf_inputs_dict = {}
704-
for key, tensor in pt_inputs.items():
705-
# skip key that does not exist in tf
706-
if type(tensor) == bool:
707-
tensor = np.array(tensor, dtype=bool)
708-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
709-
elif key == "input_values":
710-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
711-
elif key == "pixel_values":
712-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
713-
else:
714-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
715-
716-
# need to rename encoder-decoder "inputs" for PyTorch
717-
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
718-
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
719-
720-
with torch.no_grad():
721-
pto = pt_model(**pt_inputs)
722-
723-
tfo = tf_model(tf_inputs_dict)
724-
725-
self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
726-
for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
727-
728-
if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
729-
continue
730-
731-
tf_out = tf_output.numpy()
732-
pt_out = pt_output.cpu().numpy()
733-
734-
self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")
735-
736-
if len(tf_out.shape) > 0:
737-
tf_nans = np.copy(np.isnan(tf_out))
738-
pt_nans = np.copy(np.isnan(pt_out))
739-
740-
pt_out[tf_nans] = 0
741-
tf_out[tf_nans] = 0
742-
pt_out[pt_nans] = 0
743-
tf_out[pt_nans] = 0
744-
745-
max_diff = np.amax(np.abs(tf_out - pt_out))
746-
self.assertLessEqual(max_diff, 4e-2)
747-
748604
# overwrite from common since FlaxCLIPModel returns nested output
749605
# which is not supported in the common test
750606
@is_pt_flax_cross_test

tests/lxmert/test_modeling_lxmert.py

+27-119
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515

1616

1717
import copy
18-
import os
19-
import tempfile
2018
import unittest
2119

2220
import numpy as np
2321

24-
import transformers
2522
from transformers import LxmertConfig, is_tf_available, is_torch_available
2623
from transformers.models.auto import get_values
27-
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device
24+
from transformers.testing_utils import require_torch, slow, torch_device
2825

2926
from ..test_configuration_common import ConfigTester
3027
from ..test_modeling_common import ModelTesterMixin, ids_tensor
@@ -527,6 +524,8 @@ def prepare_config_and_inputs_for_common(self, return_obj_labels=False):
527524

528525
if return_obj_labels:
529526
inputs_dict["obj_labels"] = obj_labels
527+
else:
528+
config.task_obj_predict = False
530529

531530
return config, inputs_dict
532531

@@ -740,121 +739,30 @@ def test_retain_grad_hidden_states_attentions(self):
740739
self.assertIsNotNone(hidden_states_vision.grad)
741740
self.assertIsNotNone(attentions_vision.grad)
742741

743-
@is_pt_tf_cross_test
744-
def test_pt_tf_model_equivalence(self):
745-
for model_class in self.all_model_classes:
746-
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
747-
return_obj_labels="PreTraining" in model_class.__name__
748-
)
749-
750-
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
751-
752-
if not hasattr(transformers, tf_model_class_name):
753-
# transformers does not have TF version yet
754-
return
755-
756-
tf_model_class = getattr(transformers, tf_model_class_name)
757-
758-
config.output_hidden_states = True
759-
config.task_obj_predict = False
760-
761-
pt_model = model_class(config)
762-
tf_model = tf_model_class(config)
763-
764-
# Check we can load pt model in tf and vice-versa with model => model functions
765-
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
766-
767-
def recursive_numpy_convert(iterable):
768-
return_dict = {}
769-
for key, value in iterable.items():
770-
if type(value) == bool:
771-
return_dict[key] = value
772-
if isinstance(value, dict):
773-
return_dict[key] = recursive_numpy_convert(value)
774-
else:
775-
if isinstance(value, (list, tuple)):
776-
return_dict[key] = (
777-
tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value
778-
)
779-
else:
780-
return_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)
781-
return return_dict
782-
783-
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
784-
785-
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
786-
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
787-
788-
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
789-
pt_model.eval()
790-
791-
# Delete obj labels as we want to compute the hidden states and not the loss
792-
793-
if "obj_labels" in inputs_dict:
794-
del inputs_dict["obj_labels"]
795-
796-
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
797-
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
798-
799-
with torch.no_grad():
800-
pto = pt_model(**pt_inputs)
801-
tfo = tf_model(tf_inputs_dict, training=False)
802-
tf_hidden_states = tfo[0].numpy()
803-
pt_hidden_states = pto[0].cpu().numpy()
804-
805-
tf_nans = np.copy(np.isnan(tf_hidden_states))
806-
pt_nans = np.copy(np.isnan(pt_hidden_states))
807-
808-
pt_hidden_states[tf_nans] = 0
809-
tf_hidden_states[tf_nans] = 0
810-
pt_hidden_states[pt_nans] = 0
811-
tf_hidden_states[pt_nans] = 0
812-
813-
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
814-
# Debug info (remove when fixed)
815-
if max_diff >= 2e-2:
816-
print("===")
817-
print(model_class)
818-
print(config)
819-
print(inputs_dict)
820-
print(pt_inputs)
821-
self.assertLessEqual(max_diff, 6e-2)
822-
823-
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
824-
with tempfile.TemporaryDirectory() as tmpdirname:
825-
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
826-
torch.save(pt_model.state_dict(), pt_checkpoint_path)
827-
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
828-
829-
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
830-
tf_model.save_weights(tf_checkpoint_path)
831-
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
832-
833-
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
834-
pt_model.eval()
835-
836-
for key, value in pt_inputs.items():
837-
if key in ("visual_feats", "visual_pos"):
838-
pt_inputs[key] = value.to(torch.float32)
839-
else:
840-
pt_inputs[key] = value.to(torch.long)
841-
842-
with torch.no_grad():
843-
pto = pt_model(**pt_inputs)
844-
845-
tfo = tf_model(tf_inputs_dict)
846-
tfo = tfo[0].numpy()
847-
pto = pto[0].cpu().numpy()
848-
tf_nans = np.copy(np.isnan(tfo))
849-
pt_nans = np.copy(np.isnan(pto))
850-
851-
pto[tf_nans] = 0
852-
tfo[tf_nans] = 0
853-
pto[pt_nans] = 0
854-
tfo[pt_nans] = 0
855-
856-
max_diff = np.amax(np.abs(tfo - pto))
857-
self.assertLessEqual(max_diff, 6e-2)
742+
def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
    """Build the TensorFlow counterpart of a dict of PyTorch model inputs.

    Conversion rules, applied to every value in ``pt_inputs_dict``:

    * ``dict`` values are converted recursively with this same method.
    * ``list`` / ``tuple`` values are converted element by element and
      returned as a ``tuple``.
    * ``bool`` values are passed through unchanged.
    * Values stored under ``input_values``, ``pixel_values`` or
      ``input_features``, and any value whose ``is_floating_point()`` is
      true, become ``tf.float32`` tensors.
    * Every other value becomes a ``tf.int32`` tensor.

    Non-bool leaves are assumed to be ``torch.Tensor`` objects (they must
    provide ``is_floating_point()`` and ``cpu().numpy()``).

    Args:
        pt_inputs_dict: Mapping from input name to PyTorch-side value.

    Returns:
        A new dict with the same keys holding the TensorFlow-side values.
    """
    # Inputs that are always fed to the TF model as float32, whatever their
    # dtype on the PyTorch side.
    float_keys = ("input_values", "pixel_values", "input_features")

    def convert(key, value):
        if isinstance(value, dict):
            # Recurse through this method. The previous code called the
            # non-existent `prepare_pt_inputs_from_tf_inputs` here, which
            # raised AttributeError for any nested input.
            return self.prepare_tf_inputs_from_pt_inputs(value)
        if isinstance(value, (list, tuple)):
            # Materialize a tuple: a generator could only be consumed once,
            # and elements may be tensors rather than dicts, so each one is
            # converted individually under the parent key.
            return tuple(convert(key, item) for item in value)
        if isinstance(value, bool):
            return value
        if key in float_keys or value.is_floating_point():
            return tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
        return tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)

    return {key: convert(key, value) for key, value in pt_inputs_dict.items()}
858766

859767

860768
@require_torch

0 commit comments

Comments
 (0)