Skip to content

Commit 75c666b

Browse files
authored
Aggressive PT/TF equivalence test on PT side (#16250)
* Aggressive PT/TF equivalence test on PT side * Ugly fix for `TFTapasForQuestionAnswering` * apply review suggestions Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent d481b64 commit 75c666b

File tree

1 file changed

+220
-77
lines changed

1 file changed

+220
-77
lines changed

tests/test_modeling_common.py

+220-77
Original file line numberDiff line numberDiff line change
@@ -1463,6 +1463,193 @@ def test_pt_tf_model_equivalence(self):
14631463

14641464
import transformers
14651465

1466+
def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
1467+
1468+
tf_inputs_dict = {}
1469+
for key, tensor in pt_inputs_dict.items():
1470+
# skip key that does not exist in tf
1471+
if type(tensor) == bool:
1472+
tf_inputs_dict[key] = tensor
1473+
elif key == "input_values":
1474+
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1475+
elif key == "pixel_values":
1476+
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1477+
elif key == "input_features":
1478+
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1479+
# To deal with the edge cases from `TFTapasForQuestionAnswering`.
1480+
# PyTorch can deal with type casting automatically, but TensorFlow is more strict!
1481+
# TODO: find a clean/better way to deal with these extra keys that are not common.
1482+
elif key in ["float_answer", "numeric_values", "numeric_values_scale"]:
1483+
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1484+
else:
1485+
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
1486+
1487+
return tf_inputs_dict
1488+
1489+
def check_outputs(tf_outputs, pt_outputs, model_class, names):
1490+
"""
1491+
Args:
1492+
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
1493+
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
1494+
debugging easier and faster.
1495+
1496+
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
1497+
Currently unused, but in the future, we could use this information to make the error message clearer
1498+
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
1499+
"""
1500+
1501+
# Some issue (`about past_key_values`) to solve (e.g. `TFPegasusForConditionalGeneration`) in a separate PR.
1502+
if names == "past_key_values":
1503+
return
1504+
1505+
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
1506+
if type(tf_outputs) in [tuple, list]:
1507+
self.assertEqual(type(tf_outputs), type(pt_outputs))
1508+
self.assertEqual(len(tf_outputs), len(pt_outputs))
1509+
if type(names) == tuple:
1510+
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
1511+
check_outputs(tf_output, pt_output, model_class, names=name)
1512+
elif type(names) == str:
1513+
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
1514+
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
1515+
else:
1516+
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
1517+
elif isinstance(tf_outputs, tf.Tensor):
1518+
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
1519+
1520+
tf_outputs = tf_outputs.numpy()
1521+
pt_outputs = pt_outputs.detach().to("cpu").numpy()
1522+
1523+
tf_nans = np.isnan(tf_outputs)
1524+
pt_nans = np.isnan(pt_outputs)
1525+
1526+
pt_outputs[tf_nans] = 0
1527+
tf_outputs[tf_nans] = 0
1528+
pt_outputs[pt_nans] = 0
1529+
tf_outputs[pt_nans] = 0
1530+
1531+
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
1532+
self.assertLessEqual(max_diff, 1e-5)
1533+
else:
1534+
raise ValueError(
1535+
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
1536+
)
1537+
1538+
def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels):
1539+
1540+
# send pytorch model to the correct device
1541+
pt_model.to(torch_device)
1542+
1543+
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
1544+
pt_model.eval()
1545+
1546+
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
1547+
tf_inputs_dict_maybe_with_labels = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict_maybe_with_labels)
1548+
1549+
# send pytorch inputs to the correct device
1550+
pt_inputs_dict = {
1551+
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
1552+
}
1553+
pt_inputs_dict_maybe_with_labels = {
1554+
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v
1555+
for k, v in pt_inputs_dict_maybe_with_labels.items()
1556+
}
1557+
1558+
# Original test: check without `labels`
1559+
with torch.no_grad():
1560+
pt_outputs = pt_model(**pt_inputs_dict)
1561+
tf_outputs = tf_model(tf_inputs_dict)
1562+
1563+
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
1564+
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
1565+
1566+
self.assertEqual(tf_keys, pt_keys)
1567+
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
1568+
1569+
# check the case where `labels` is passed
1570+
has_labels = any(
1571+
x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
1572+
)
1573+
if has_labels:
1574+
1575+
with torch.no_grad():
1576+
pt_outputs = pt_model(**pt_inputs_dict_maybe_with_labels)
1577+
tf_outputs = tf_model(tf_inputs_dict_maybe_with_labels)
1578+
1579+
# Some models' output class don't have `loss` attribute despite `labels` is used.
1580+
# TODO: identify which models
1581+
tf_loss = getattr(tf_outputs, "loss", None)
1582+
pt_loss = getattr(pt_outputs, "loss", None)
1583+
1584+
# Some PT models return loss while the corresponding TF models don't (i.e. `None` for `loss`).
1585+
# - FlaubertWithLMHeadModel
1586+
# - FunnelForPreTraining
1587+
# - ElectraForPreTraining
1588+
# - XLMWithLMHeadModel
1589+
# TODO: Fix PT/TF diff -> remove this condition to fail the test if a diff occurs
1590+
if not ((tf_loss is None and pt_loss is None) or (tf_loss is not None and pt_loss is not None)):
1591+
if model_class.__name__ not in [
1592+
"FlaubertWithLMHeadModel",
1593+
"FunnelForPreTraining",
1594+
"ElectraForPreTraining",
1595+
"XLMWithLMHeadModel",
1596+
"TransfoXLLMHeadModel",
1597+
]:
1598+
self.assertEqual(tf_loss is None, pt_loss is None)
1599+
1600+
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
1601+
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
1602+
1603+
# TODO: remove these 2 conditions once the above TODOs (above loss) are implemented
1604+
# (Also, `TFTransfoXLLMHeadModel` has no `loss` while `TransfoXLLMHeadModel` return `losses`)
1605+
if tf_keys != pt_keys:
1606+
if model_class.__name__ not in [
1607+
"FlaubertWithLMHeadModel",
1608+
"FunnelForPreTraining",
1609+
"ElectraForPreTraining",
1610+
"XLMWithLMHeadModel",
1611+
"TransfoXLLMHeadModel",
1612+
]:
1613+
self.assertEqual(tf_keys, pt_keys)
1614+
1615+
# Since we deliberately make some tests pass above (regarding the `loss`), let's still try to test
1616+
# some remaining attributes in the outputs.
1617+
# TODO: remove this block of `index` computing once the above TODOs (above loss) are implemented
1618+
# compute the 1st `index` where `tf_keys` and `pt_keys` is different
1619+
index = 0
1620+
for _ in range(min(len(tf_keys), len(pt_keys))):
1621+
if tf_keys[index] == pt_keys[index]:
1622+
index += 1
1623+
else:
1624+
break
1625+
if tf_keys[:index] != pt_keys[:index]:
1626+
self.assertEqual(tf_keys, pt_keys)
1627+
1628+
# Some models require extra condition to return loss. For example, `(TF)BertForPreTraining` requires
1629+
# both`labels` and `next_sentence_label`.
1630+
if tf_loss is not None and pt_loss is not None:
1631+
1632+
# check anything else than `loss`
1633+
keys = tuple([k for k in tf_keys])
1634+
check_outputs(tf_outputs[1:index], pt_outputs[1:index], model_class, names=keys[1:index])
1635+
1636+
# check `loss`
1637+
1638+
# tf models returned loss is usually a tensor rather than a scalar.
1639+
# (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
1640+
# Change it here to a scalar to match PyTorch models' loss
1641+
tf_loss = tf.math.reduce_mean(tf_loss).numpy()
1642+
pt_loss = pt_loss.detach().to("cpu").numpy()
1643+
1644+
tf_nans = np.isnan(tf_loss)
1645+
pt_nans = np.isnan(pt_loss)
1646+
# the 2 losses need to be both nan or both not nan
1647+
self.assertEqual(tf_nans, pt_nans)
1648+
1649+
if not tf_nans:
1650+
max_diff = np.amax(np.abs(tf_loss - pt_loss))
1651+
self.assertLessEqual(max_diff, 1e-5)
1652+
14661653
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
14671654

14681655
for model_class in self.all_model_classes:
@@ -1472,9 +1659,30 @@ def test_pt_tf_model_equivalence(self):
14721659
# transformers does not have TF version yet
14731660
return
14741661

1475-
tf_model_class = getattr(transformers, tf_model_class_name)
1662+
if self.has_attentions:
1663+
config.output_attentions = True
14761664

1477-
config.output_hidden_states = True
1665+
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
1666+
if k in inputs_dict:
1667+
attention_mask = inputs_dict[k]
1668+
# make sure no all 0s attention masks - to avoid failure at this moment.
1669+
# TODO: remove this line once the TODO below is implemented.
1670+
attention_mask = torch.ones_like(attention_mask, dtype=torch.int32)
1671+
# Here we make the first sequence with all 0s as attention mask.
1672+
# Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative
1673+
# values, like `1e-4`, `1e-9`, `1e-30` and `-inf` for attention mask across models/frameworks.
1674+
# TODO: enable this block once the large negative values thing is cleaned up.
1675+
# (see https://github.com/huggingface/transformers/issues/14859)
1676+
# attention_mask = torch.cat(
1677+
# [
1678+
# torch.zeros_like(attention_mask[:1], dtype=torch.int32),
1679+
# attention_mask[1:].type(dtype=torch.int32)
1680+
# ],
1681+
# dim=0
1682+
# )
1683+
inputs_dict[k] = attention_mask
1684+
1685+
tf_model_class = getattr(transformers, tf_model_class_name)
14781686

14791687
tf_model = tf_model_class(config)
14801688
pt_model = model_class(config)
@@ -1487,49 +1695,20 @@ def test_pt_tf_model_equivalence(self):
14871695
tf_input_keys.discard("cross_attn_head_mask")
14881696
tf_input_keys.discard("decoder_head_mask")
14891697

1490-
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
1491-
pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}
1698+
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
1699+
pt_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
14921700

1493-
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
1494-
pt_model.eval()
1495-
tf_inputs_dict = {}
1496-
for key, tensor in pt_inputs.items():
1497-
# skip key that does not exist in tf
1498-
if type(tensor) == bool:
1499-
tf_inputs_dict[key] = tensor
1500-
elif key == "input_values":
1501-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1502-
elif key == "pixel_values":
1503-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1504-
elif key == "input_features":
1505-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1506-
else:
1507-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
1701+
pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}
1702+
pt_inputs_dict_maybe_with_labels = {
1703+
k: v for k, v in pt_inputs_dict_maybe_with_labels.items() if k in tf_input_keys
1704+
}
15081705

15091706
# Check we can load pt model in tf and vice-versa with model => model functions
1707+
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
15101708
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
1511-
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
1709+
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
15121710

1513-
# Make sure PyTorch tensors are on same device as model
1514-
pt_inputs = {k: v.to(torch_device) if torch.is_tensor(v) else v for k, v in pt_inputs.items()}
1515-
1516-
with torch.no_grad():
1517-
pto = pt_model(**pt_inputs)
1518-
tfo = tf_model(tf_inputs_dict, training=False)
1519-
1520-
tf_hidden_states = tfo[0].numpy()
1521-
pt_hidden_states = pto[0].cpu().numpy()
1522-
1523-
tf_nans = np.copy(np.isnan(tf_hidden_states))
1524-
pt_nans = np.copy(np.isnan(pt_hidden_states))
1525-
1526-
pt_hidden_states[tf_nans] = 0
1527-
tf_hidden_states[tf_nans] = 0
1528-
pt_hidden_states[pt_nans] = 0
1529-
tf_hidden_states[pt_nans] = 0
1530-
1531-
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
1532-
self.assertLessEqual(max_diff, 4e-2)
1711+
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels)
15331712

15341713
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
15351714
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -1542,43 +1721,7 @@ def test_pt_tf_model_equivalence(self):
15421721
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
15431722
pt_model = pt_model.to(torch_device)
15441723

1545-
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
1546-
pt_model.eval()
1547-
tf_inputs_dict = {}
1548-
for key, tensor in pt_inputs.items():
1549-
# skip key that does not exist in tf
1550-
if type(tensor) == bool:
1551-
tensor = np.array(tensor, dtype=bool)
1552-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
1553-
elif key == "input_values":
1554-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1555-
elif key == "pixel_values":
1556-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1557-
elif key == "input_features":
1558-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
1559-
else:
1560-
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
1561-
1562-
# need to rename encoder-decoder "inputs" for PyTorch
1563-
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
1564-
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
1565-
1566-
with torch.no_grad():
1567-
pto = pt_model(**pt_inputs)
1568-
1569-
tfo = tf_model(tf_inputs_dict)
1570-
tfo = tfo[0].numpy()
1571-
pto = pto[0].cpu().numpy()
1572-
tf_nans = np.copy(np.isnan(tfo))
1573-
pt_nans = np.copy(np.isnan(pto))
1574-
1575-
pto[tf_nans] = 0
1576-
tfo[tf_nans] = 0
1577-
pto[pt_nans] = 0
1578-
tfo[pt_nans] = 0
1579-
1580-
max_diff = np.amax(np.abs(tfo - pto))
1581-
self.assertLessEqual(max_diff, 4e-2)
1724+
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels)
15821725

15831726
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
15841727
diff = np.abs((a - b)).max()

0 commit comments

Comments
 (0)