|
15 | 15 |
|
16 | 16 |
|
17 | 17 | import copy
|
18 |
| -import os |
19 |
| -import tempfile |
20 | 18 | import unittest
|
21 | 19 |
|
22 | 20 | import numpy as np
|
23 | 21 |
|
24 |
| -import transformers |
25 | 22 | from transformers import LxmertConfig, is_tf_available, is_torch_available
|
26 | 23 | from transformers.models.auto import get_values
|
27 |
| -from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device |
| 24 | +from transformers.testing_utils import require_torch, slow, torch_device |
28 | 25 |
|
29 | 26 | from ..test_configuration_common import ConfigTester
|
30 | 27 | from ..test_modeling_common import ModelTesterMixin, ids_tensor
|
@@ -527,6 +524,8 @@ def prepare_config_and_inputs_for_common(self, return_obj_labels=False):
|
527 | 524 |
|
528 | 525 | if return_obj_labels:
|
529 | 526 | inputs_dict["obj_labels"] = obj_labels
|
| 527 | + else: |
| 528 | + config.task_obj_predict = False |
530 | 529 |
|
531 | 530 | return config, inputs_dict
|
532 | 531 |
|
@@ -740,121 +739,30 @@ def test_retain_grad_hidden_states_attentions(self):
|
740 | 739 | self.assertIsNotNone(hidden_states_vision.grad)
|
741 | 740 | self.assertIsNotNone(attentions_vision.grad)
|
742 | 741 |
|
743 |
| - @is_pt_tf_cross_test |
744 |
| - def test_pt_tf_model_equivalence(self): |
745 |
| - for model_class in self.all_model_classes: |
746 |
| - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common( |
747 |
| - return_obj_labels="PreTraining" in model_class.__name__ |
748 |
| - ) |
749 |
| - |
750 |
| - tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning |
751 |
| - |
752 |
| - if not hasattr(transformers, tf_model_class_name): |
753 |
| - # transformers does not have TF version yet |
754 |
| - return |
755 |
| - |
756 |
| - tf_model_class = getattr(transformers, tf_model_class_name) |
757 |
| - |
758 |
| - config.output_hidden_states = True |
759 |
| - config.task_obj_predict = False |
760 |
| - |
761 |
| - pt_model = model_class(config) |
762 |
| - tf_model = tf_model_class(config) |
763 |
| - |
764 |
| - # Check we can load pt model in tf and vice-versa with model => model functions |
765 |
| - pt_inputs = self._prepare_for_class(inputs_dict, model_class) |
766 |
| - |
767 |
| - def recursive_numpy_convert(iterable): |
768 |
| - return_dict = {} |
769 |
| - for key, value in iterable.items(): |
770 |
| - if type(value) == bool: |
771 |
| - return_dict[key] = value |
772 |
| - if isinstance(value, dict): |
773 |
| - return_dict[key] = recursive_numpy_convert(value) |
774 |
| - else: |
775 |
| - if isinstance(value, (list, tuple)): |
776 |
| - return_dict[key] = ( |
777 |
| - tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value |
778 |
| - ) |
779 |
| - else: |
780 |
| - return_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32) |
781 |
| - return return_dict |
782 |
| - |
783 |
| - tf_inputs_dict = recursive_numpy_convert(pt_inputs) |
784 |
| - |
785 |
| - tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict) |
786 |
| - pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device) |
787 |
| - |
788 |
| - # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences |
789 |
| - pt_model.eval() |
790 |
| - |
791 |
| - # Delete obj labels as we want to compute the hidden states and not the loss |
792 |
| - |
793 |
| - if "obj_labels" in inputs_dict: |
794 |
| - del inputs_dict["obj_labels"] |
795 |
| - |
796 |
| - pt_inputs = self._prepare_for_class(inputs_dict, model_class) |
797 |
| - tf_inputs_dict = recursive_numpy_convert(pt_inputs) |
798 |
| - |
799 |
| - with torch.no_grad(): |
800 |
| - pto = pt_model(**pt_inputs) |
801 |
| - tfo = tf_model(tf_inputs_dict, training=False) |
802 |
| - tf_hidden_states = tfo[0].numpy() |
803 |
| - pt_hidden_states = pto[0].cpu().numpy() |
804 |
| - |
805 |
| - tf_nans = np.copy(np.isnan(tf_hidden_states)) |
806 |
| - pt_nans = np.copy(np.isnan(pt_hidden_states)) |
807 |
| - |
808 |
| - pt_hidden_states[tf_nans] = 0 |
809 |
| - tf_hidden_states[tf_nans] = 0 |
810 |
| - pt_hidden_states[pt_nans] = 0 |
811 |
| - tf_hidden_states[pt_nans] = 0 |
812 |
| - |
813 |
| - max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) |
814 |
| - # Debug info (remove when fixed) |
815 |
| - if max_diff >= 2e-2: |
816 |
| - print("===") |
817 |
| - print(model_class) |
818 |
| - print(config) |
819 |
| - print(inputs_dict) |
820 |
| - print(pt_inputs) |
821 |
| - self.assertLessEqual(max_diff, 6e-2) |
822 |
| - |
823 |
| - # Check we can load pt model in tf and vice-versa with checkpoint => model functions |
824 |
| - with tempfile.TemporaryDirectory() as tmpdirname: |
825 |
| - pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") |
826 |
| - torch.save(pt_model.state_dict(), pt_checkpoint_path) |
827 |
| - tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path) |
828 |
| - |
829 |
| - tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") |
830 |
| - tf_model.save_weights(tf_checkpoint_path) |
831 |
| - pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) |
832 |
| - |
833 |
| - # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences |
834 |
| - pt_model.eval() |
835 |
| - |
836 |
| - for key, value in pt_inputs.items(): |
837 |
| - if key in ("visual_feats", "visual_pos"): |
838 |
| - pt_inputs[key] = value.to(torch.float32) |
839 |
| - else: |
840 |
| - pt_inputs[key] = value.to(torch.long) |
841 |
| - |
842 |
| - with torch.no_grad(): |
843 |
| - pto = pt_model(**pt_inputs) |
844 |
| - |
845 |
| - tfo = tf_model(tf_inputs_dict) |
846 |
| - tfo = tfo[0].numpy() |
847 |
| - pto = pto[0].cpu().numpy() |
848 |
| - tf_nans = np.copy(np.isnan(tfo)) |
849 |
| - pt_nans = np.copy(np.isnan(pto)) |
850 |
| - |
851 |
| - pto[tf_nans] = 0 |
852 |
| - tfo[tf_nans] = 0 |
853 |
| - pto[pt_nans] = 0 |
854 |
| - tfo[pt_nans] = 0 |
855 |
| - |
856 |
| - max_diff = np.amax(np.abs(tfo - pto)) |
857 |
| - self.assertLessEqual(max_diff, 6e-2) |
| 742 | + def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict): |
| 743 | + |
| 744 | + tf_inputs_dict = {} |
| 745 | + for key, value in pt_inputs_dict.items(): |
| 746 | + # skip key that does not exist in tf |
| 747 | + if isinstance(value, dict): |
| 748 | + tf_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value) |
| 749 | + elif isinstance(value, (list, tuple)): |
| 750 | + tf_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value) |
| 751 | + elif type(value) == bool: |
| 752 | + tf_inputs_dict[key] = value |
| 753 | + elif key == "input_values": |
| 754 | + tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32) |
| 755 | + elif key == "pixel_values": |
| 756 | + tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32) |
| 757 | + elif key == "input_features": |
| 758 | + tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32) |
| 759 | + # other general float inputs |
| 760 | + elif value.is_floating_point(): |
| 761 | + tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32) |
| 762 | + else: |
| 763 | + tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32) |
| 764 | + |
| 765 | + return tf_inputs_dict |
858 | 766 |
|
859 | 767 |
|
860 | 768 | @require_torch
|
|
0 commit comments