Commit f571dc2

Update PT Flax equivalence tests in PT test file (#16280)
* update PT/Flax equivalence tests on PT side
* overwrite check_outputs in BigBirdModelTest

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 41bfc1e commit f571dc2

File tree

2 files changed, +133 −45 lines


tests/big_bird/test_modeling_big_bird.py

+9
@@ -596,6 +596,15 @@ def test_for_change_to_full_attn(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
 
+    # overwrite from common in order to skip the check on `attentions`
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+        # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in the PyTorch
+        # version an effort was made to return `attention_probs` (yet to be verified).
+        if type(names) == str and names.startswith("attentions"):
+            return
+        else:
+            super().check_outputs(fx_outputs, pt_outputs, model_class, names)
+
 
 @require_torch
 @slow
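The override above works together with the new common `check_outputs` added below in tests/test_modeling_common.py: the common helper recurses into tuple/list outputs and suffixes the current name with the element index, so attention tensors arrive either under the plain key "attentions" or under derived names such as "attentions_0". The prefix test `names.startswith("attentions")` therefore skips the whole subtree. A minimal, standalone sketch of that name expansion (the `expand_names` helper and the toy outputs are illustrative only, not part of the commit):

    # Standalone illustration of how the recursive traversal derives leaf names.
    # It mirrors the naming scheme of the new `check_outputs` helper; nothing
    # here is imported from transformers.
    def expand_names(outputs, names):
        """Yield the leaf names the recursive check would visit."""
        if isinstance(outputs, (tuple, list)):
            if isinstance(names, tuple):
                for out, name in zip(outputs, names):
                    yield from expand_names(out, name)
            else:  # a single string: append the element index, as the common test does
                for idx, out in enumerate(outputs):
                    yield from expand_names(out, f"{names}_{idx}")
        else:
            yield names

    # Toy outputs: (last_hidden_state, hidden_states, attentions)
    fake_outputs = ("T", ("T", "T"), ("T", "T"))
    print(list(expand_names(fake_outputs, ("last_hidden_state", "hidden_states", "attentions"))))
    # ['last_hidden_state', 'hidden_states_0', 'hidden_states_1', 'attentions_0', 'attentions_1']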

tests/test_modeling_common.py

+124 −45
@@ -1660,8 +1660,9 @@ def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_
                 # transformers does not have TF version yet
                 return
 
-            if self.has_attentions:
-                config.output_attentions = True
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
 
             for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
                 if k in inputs_dict:
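The two replacement lines turn on hidden states unconditionally and attentions whenever the model supports them, so both frameworks emit every optional output and the key comparison further down covers them all. A quick, hedged illustration of what these flags change on the PyTorch side (BertModel is an arbitrary example, not something this hunk touches):

    # Illustrative only: with both flags enabled, the output dataclass exposes the
    # optional tensors, so a "non-None keys" comparison also covers hidden states
    # and attentions.
    import torch
    from transformers import BertConfig, BertModel

    config = BertConfig(output_hidden_states=True, output_attentions=True)
    model = BertModel(config).eval()

    with torch.no_grad():
        outputs = model(input_ids=torch.tensor([[101, 2023, 102]]))

    print([k for k, v in outputs.items() if v is not None])
    # e.g. ['last_hidden_state', 'pooler_output', 'hidden_states', 'attentions']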
@@ -1728,28 +1729,75 @@ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
         diff = np.abs((a - b)).max()
         self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
 
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+        """
+        Args:
+            model_class: The class of the model that is currently being tested. For example, ..., etc.
+                Currently unused, but it could make debugging easier and faster.
+
+            names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
+                Currently unused, but in the future, we could use this information to make the error message clearer
+                by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
+        """
+        if type(fx_outputs) in [tuple, list]:
+            self.assertEqual(type(fx_outputs), type(pt_outputs))
+            self.assertEqual(len(fx_outputs), len(pt_outputs))
+            if type(names) == tuple:
+                for fo, po, name in zip(fx_outputs, pt_outputs, names):
+                    self.check_outputs(fo, po, model_class, names=name)
+            elif type(names) == str:
+                for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
+                    self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
+            else:
+                raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
+        elif isinstance(fx_outputs, jnp.ndarray):
+            self.assertTrue(isinstance(pt_outputs, torch.Tensor))
+
+            # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
+            fx_outputs = np.array(fx_outputs)
+            pt_outputs = pt_outputs.detach().to("cpu").numpy()
+
+            fx_nans = np.isnan(fx_outputs)
+            pt_nans = np.isnan(pt_outputs)
+
+            pt_outputs[fx_nans] = 0
+            fx_outputs[fx_nans] = 0
+            pt_outputs[pt_nans] = 0
+            fx_outputs[pt_nans] = 0
+
+            self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
+        else:
+            raise ValueError(
+                f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
+            )
+
     @is_pt_flax_cross_test
     def test_equivalence_pt_to_flax(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):
-
-                # load PyTorch class
-                pt_model = model_class(config).eval()
-                # Flax models don't use the `use_cache` option and cache is not returned as a default.
-                # So we disable `use_cache` here for PyTorch model.
-                pt_model.config.use_cache = False
-
                 fx_model_class_name = "Flax" + model_class.__name__
 
                 if not hasattr(transformers, fx_model_class_name):
+                    # no flax model exists for this class
                     return
 
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
                 fx_model_class = getattr(transformers, fx_model_class_name)
 
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
                 # load Flax class
                 fx_model = fx_model_class(config, dtype=jnp.float32)
+
                 # make sure only flax inputs are forwarded that actually exist in function args
                 fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
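The leaf comparison in the new `check_outputs` zeroes, in both arrays, every position that is NaN in either framework's output, so those positions drop out of the max-difference check instead of turning it into NaN; everything else must agree to within 1e-5. A small, self-contained numpy sketch of that masking (the toy arrays are illustrative, not taken from any model):

    # Standalone numpy illustration of the NaN handling in the new `check_outputs`.
    import numpy as np

    fx_out = np.array([0.1, np.nan, 0.30000002, 0.4])
    pt_out = np.array([0.1, 0.2, 0.3, np.nan])

    # Positions that are NaN on either side are zeroed on both sides,
    # so they no longer contribute to the difference.
    for mask in (np.isnan(fx_out), np.isnan(pt_out)):
        fx_out[mask] = 0
        pt_out[mask] = 0

    diff = np.abs(fx_out - pt_out).max()
    assert diff <= 1e-5, f"Difference between torch and flax is {diff} (>= 1e-05)."
    print(diff)  # ~2e-08: only the non-NaN positions are compared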

@@ -1759,89 +1807,120 @@ def test_equivalence_pt_to_flax(self):
                 # remove function args that don't exist in Flax
                 pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
 
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
+
+                # convert inputs to Flax
+                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
+
                 fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                 fx_model.params = fx_state
 
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
+
                 with torch.no_grad():
-                    pt_outputs = pt_model(**pt_inputs).to_tuple()
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**fx_inputs)
 
-                # convert inputs to Flax
-                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
-                fx_outputs = fx_model(**fx_inputs).to_tuple()
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     pt_model.save_pretrained(tmpdirname)
                     fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
 
-                fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
-                self.assertEqual(
-                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
-                )
-                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
-                    self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
+                fx_outputs_loaded = fx_model_loaded(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
     @is_pt_flax_cross_test
     def test_equivalence_flax_to_pt(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):
-                # load corresponding PyTorch class
-                pt_model = model_class(config).eval()
-
-                # So we disable `use_cache` here for PyTorch model.
-                pt_model.config.use_cache = False
-
                 fx_model_class_name = "Flax" + model_class.__name__
 
                 if not hasattr(transformers, fx_model_class_name):
                     # no flax model exists for this class
                     return
 
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
                 fx_model_class = getattr(transformers, fx_model_class_name)
 
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
                 # load Flax class
                 fx_model = fx_model_class(config, dtype=jnp.float32)
+
                 # make sure only flax inputs are forwarded that actually exist in function args
                 fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
 
-                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
-
-                # make sure weights are tied in PyTorch
-                pt_model.tie_weights()
-
                 # prepare inputs
                 pt_inputs = self._prepare_for_class(inputs_dict, model_class)
 
                 # remove function args that don't exist in Flax
                 pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
 
-                with torch.no_grad():
-                    pt_outputs = pt_model(**pt_inputs).to_tuple()
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
 
+                # convert inputs to Flax
                 fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
 
-                fx_outputs = fx_model(**fx_inputs).to_tuple()
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
+                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
+
+                # make sure weights are tied in PyTorch
+                pt_model.tie_weights()
+
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
 
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     fx_model.save_pretrained(tmpdirname)
                     pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
 
+                # send pytorch model to the correct device
+                pt_model_loaded.to(torch_device)
+                pt_model_loaded.eval()
+
                 with torch.no_grad():
-                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
+                    pt_outputs_loaded = pt_model_loaded(**pt_inputs)
 
-                self.assertEqual(
-                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
-                )
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
 
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
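For reference, the PT→Flax path these tests now follow can be sketched outside the unittest harness roughly as below. BertModel/FlaxBertModel and the toy input_ids are illustrative choices only, and the sketch stops at the key comparison rather than reproducing the full `check_outputs` recursion.

    # Hedged sketch of the PT -> Flax flow, assuming torch, jax/flax and
    # transformers are installed. The real test iterates over all_model_classes
    # and runs on torch_device; here everything stays on CPU.
    import numpy as np
    import torch
    import jax.numpy as jnp
    from transformers import BertConfig, BertModel, FlaxBertModel
    from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

    # Output all for aggressive testing, as in the updated common test.
    config = BertConfig(output_hidden_states=True, output_attentions=True)

    pt_model = BertModel(config).eval()
    pt_model.config.use_cache = False  # Flax models do not return a cache by default

    fx_model = FlaxBertModel(config, dtype=jnp.float32)
    fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

    pt_inputs = {"input_ids": torch.tensor([[101, 2023, 2003, 102]])}
    fx_inputs = {k: np.array(v) for k, v in pt_inputs.items()}

    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs)
    fx_outputs = fx_model(**fx_inputs)

    # Same set of non-None outputs on both sides before any value comparison.
    fx_keys = tuple(k for k, v in fx_outputs.items() if v is not None)
    pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)
    assert fx_keys == pt_keys

The reverse direction in `test_equivalence_flax_to_pt` goes through `load_flax_weights_in_pytorch_model(pt_model, fx_model.params)` followed by `pt_model.tie_weights()`, as shown in the hunk above.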
