@@ -1660,8 +1660,9 @@ def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_
             # transformers does not have TF version yet
             return
 
-        if self.has_attentions:
-            config.output_attentions = True
+        # Output all for aggressive testing
+        config.output_hidden_states = True
+        config.output_attentions = self.has_attentions
 
         for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
            if k in inputs_dict:
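The two flags switched on above make every forward pass return the per-layer hidden states and attention maps, so the equivalence checks compare as many tensors as possible. A minimal standalone illustration of what the flags do (BertConfig/BertModel are used purely as an example here and are not part of this change):

import torch
from transformers import BertConfig, BertModel

# Tiny config purely for illustration; any architecture supporting both flags behaves the same way.
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
config.output_hidden_states = True  # return the hidden states of every layer (plus the embeddings)
config.output_attentions = True     # return the attention maps of every layer

model = BertModel(config).eval()
with torch.no_grad():
    outputs = model(input_ids=torch.tensor([[1, 2, 3, 4]]))

# Besides `last_hidden_state` and `pooler_output`, the output now also carries
# `hidden_states` (num_hidden_layers + 1 tensors) and `attentions` (num_hidden_layers tensors).
print(list(outputs.keys()))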
@@ -1728,28 +1729,75 @@ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
         diff = np.abs((a - b)).max()
         self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
 
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+        """
+        Args:
+            model_class: The class of the model that is currently testing. For example, ..., etc.
+                Currently unused, but it could make debugging easier and faster.
+
+            names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
+                Currently unused, but in the future, we could use this information to make the error message clearer
+                by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
+        """
+        if type(fx_outputs) in [tuple, list]:
+            self.assertEqual(type(fx_outputs), type(pt_outputs))
+            self.assertEqual(len(fx_outputs), len(pt_outputs))
+            if type(names) == tuple:
+                for fo, po, name in zip(fx_outputs, pt_outputs, names):
+                    self.check_outputs(fo, po, model_class, names=name)
+            elif type(names) == str:
+                for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
+                    self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
+            else:
+                raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
+        elif isinstance(fx_outputs, jnp.ndarray):
+            self.assertTrue(isinstance(pt_outputs, torch.Tensor))
+
+            # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
+            fx_outputs = np.array(fx_outputs)
+            pt_outputs = pt_outputs.detach().to("cpu").numpy()
+
+            fx_nans = np.isnan(fx_outputs)
+            pt_nans = np.isnan(pt_outputs)
+
+            pt_outputs[fx_nans] = 0
+            fx_outputs[fx_nans] = 0
+            pt_outputs[pt_nans] = 0
+            fx_outputs[pt_nans] = 0
+
+            self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
+        else:
+            raise ValueError(
+                f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
+            )
+
     @is_pt_flax_cross_test
     def test_equivalence_pt_to_flax(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):
-
-                # load PyTorch class
-                pt_model = model_class(config).eval()
-                # Flax models don't use the `use_cache` option and cache is not returned as a default.
-                # So we disable `use_cache` here for PyTorch model.
-                pt_model.config.use_cache = False
-
                 fx_model_class_name = "Flax" + model_class.__name__
 
                 if not hasattr(transformers, fx_model_class_name):
+                    # no flax model exists for this class
                     return
 
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
                 fx_model_class = getattr(transformers, fx_model_class_name)
 
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
                 # load Flax class
                 fx_model = fx_model_class(config, dtype=jnp.float32)
+
                 # make sure only flax inputs are forward that actually exist in function args
                 fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
 
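The new check_outputs helper above recurses through nested output tuples and zeroes NaN positions on both sides before comparing, so a NaN produced at the same position by both frameworks does not fail the test. A rough, framework-free sketch of that comparison strategy (the function name and tolerance here are illustrative, not part of the diff):

import numpy as np

def compare_nested(fx, pt, tol=1e-5):
    """Recursively compare nested tuples/lists of arrays, masking NaNs on both sides first."""
    if isinstance(fx, (tuple, list)):
        assert type(fx) == type(pt) and len(fx) == len(pt)
        for f, p in zip(fx, pt):
            compare_nested(f, p, tol)
    else:
        f, p = np.array(fx, dtype=np.float32), np.array(pt, dtype=np.float32)
        nans = np.isnan(f) | np.isnan(p)
        f[nans], p[nans] = 0, 0
        diff = np.abs(f - p).max()
        assert diff <= tol, f"max difference {diff} exceeds {tol}"

# Example: a nested structure mimicking (logits, (hidden_state_0, hidden_state_1))
compare_nested(
    (np.ones((2, 3)), (np.zeros((2, 3)), np.full((2, 3), np.nan))),
    (np.ones((2, 3)), (np.zeros((2, 3)), np.full((2, 3), np.nan))),
)

The real helper additionally asserts that the Flax side is a jnp.ndarray and the PyTorch side a torch.Tensor before converting to NumPy.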
@@ -1759,89 +1807,120 @@ def test_equivalence_pt_to_flax(self):
                 # remove function args that don't exist in Flax
                 pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
 
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
+
+                # convert inputs to Flax
+                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
+
                 fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                 fx_model.params = fx_state
 
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
+
                 with torch.no_grad():
-                    pt_outputs = pt_model(**pt_inputs).to_tuple()
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**fx_inputs)
 
-                # convert inputs to Flax
-                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
-                fx_outputs = fx_model(**fx_inputs).to_tuple()
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     pt_model.save_pretrained(tmpdirname)
                     fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
 
-                    fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
-                    self.assertEqual(
-                        len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
-                    )
-                    for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
-                        self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
+                    fx_outputs_loaded = fx_model_loaded(**fx_inputs)
+
+                    fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
+                    pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                    self.assertEqual(fx_keys, pt_keys)
+                    self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
     @is_pt_flax_cross_test
     def test_equivalence_flax_to_pt(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):
-                # load corresponding PyTorch class
-                pt_model = model_class(config).eval()
-
-                # So we disable `use_cache` here for PyTorch model.
-                pt_model.config.use_cache = False
-
                 fx_model_class_name = "Flax" + model_class.__name__
 
                 if not hasattr(transformers, fx_model_class_name):
                     # no flax model exists for this class
                     return
 
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
                 fx_model_class = getattr(transformers, fx_model_class_name)
 
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
                 # load Flax class
                 fx_model = fx_model_class(config, dtype=jnp.float32)
+
                 # make sure only flax inputs are forward that actually exist in function args
                 fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
 
-                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
-
-                # make sure weights are tied in PyTorch
-                pt_model.tie_weights()
-
                 # prepare inputs
                 pt_inputs = self._prepare_for_class(inputs_dict, model_class)
 
                 # remove function args that don't exist in Flax
                 pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
 
-                with torch.no_grad():
-                    pt_outputs = pt_model(**pt_inputs).to_tuple()
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
 
+                # convert inputs to Flax
                 fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
 
-                fx_outputs = fx_model(**fx_inputs).to_tuple()
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
+                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
+
+                # make sure weights are tied in PyTorch
+                pt_model.tie_weights()
+
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
 
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     fx_model.save_pretrained(tmpdirname)
                     pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
 
+                    # send pytorch model to the correct device
+                    pt_model_loaded.to(torch_device)
+                    pt_model_loaded.eval()
+
                     with torch.no_grad():
-                        pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
+                        pt_outputs_loaded = pt_model_loaded(**pt_inputs)
 
-                    self.assertEqual(
-                        len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
-                    )
-                    for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
-                        self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                    fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                    pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
+
+                    self.assertEqual(fx_keys, pt_keys)
+                    self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
 
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
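For reference, the loaded-model branches above exercise the weight conversion that from_pretrained(..., from_pt=True) and from_pretrained(..., from_flax=True) perform. A minimal sketch of the PyTorch-to-Flax direction outside the test suite (the checkpoint name is an arbitrary example, not taken from this diff):

import tempfile
import numpy as np
import torch
from transformers import BertModel, FlaxBertModel

# Any checkpoint with both a PyTorch and a Flax implementation works; bert-base-uncased is just an example.
pt_model = BertModel.from_pretrained("bert-base-uncased").eval()
pt_model.config.use_cache = False  # Flax models do not return a cache by default

pt_inputs = {"input_ids": torch.tensor([[101, 7592, 2088, 102]])}
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items()}

with tempfile.TemporaryDirectory() as tmpdirname:
    pt_model.save_pretrained(tmpdirname)
    # `from_pt=True` converts the saved PyTorch weights into Flax parameters on load.
    fx_model = FlaxBertModel.from_pretrained(tmpdirname, from_pt=True)

with torch.no_grad():
    pt_out = pt_model(**pt_inputs)
fx_out = fx_model(**fx_inputs)

# Maximum elementwise difference between the two frameworks' last hidden states.
print(np.abs(np.array(fx_out.last_hidden_state) - pt_out.last_hidden_state.numpy()).max())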