@@ -1463,6 +1463,193 @@ def test_pt_tf_model_equivalence(self):
1463
1463
1464
1464
import transformers
1465
1465
1466
def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
    """Build the TensorFlow inputs that mirror a dict of PyTorch inputs.

    Args:
        pt_inputs_dict: Mapping from input name to either a Python `bool`
            or an object exposing `.cpu().numpy()` (the PyTorch tensors
            prepared by the test).

    Returns:
        A new dict with the same keys. `bool` values are forwarded as-is;
        every other value is converted with `tf.convert_to_tensor`, using
        `tf.float32` for the keys listed in `float_keys` and `tf.int32`
        for all remaining keys.
    """
    # Keys that carry real-valued data and must therefore stay float32.
    # The last three are edge cases from `TFTapasForQuestionAnswering`:
    # PyTorch can deal with type casting automatically, but TensorFlow is
    # more strict!
    # TODO: find a clean/better way to deal with these extra keys that are not common.
    float_keys = (
        "input_values",
        "pixel_values",
        "input_features",
        "float_answer",
        "numeric_values",
        "numeric_values_scale",
    )

    tf_inputs_dict = {}
    for key, tensor in pt_inputs_dict.items():
        if isinstance(tensor, bool):
            # Plain Python flags are passed through without conversion.
            tf_inputs_dict[key] = tensor
            continue

        dtype = tf.float32 if key in float_keys else tf.int32
        tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=dtype)

    return tf_inputs_dict
1488
+
1489
def check_outputs(tf_outputs, pt_outputs, model_class, names):
    """Recursively assert that TensorFlow and PyTorch outputs are numerically close.

    Args:
        tf_outputs: A `tf.Tensor`, or a `tuple`/`list` (possibly nested) of them.
        pt_outputs: The PyTorch counterpart of `tf_outputs`; it must have the
            same container type and length at every level.
        model_class: The class of the model that is currently testing. For
            example, `TFBertModel`, `TFBertForMaskedLM`,
            `TFBertForSequenceClassification`, etc. It is only forwarded
            through the recursion; it could make debugging easier and faster.
        names: A string, or a tuple of strings, specifying what
            `tf_outputs`/`pt_outputs` represent in the model outputs. It is
            used to skip `past_key_values` and is propagated to nested
            elements (as `f"{names}_{idx}"` when a single string names a
            whole container). In the future it could also make the error
            message clearer by naming the tensor(s) that differ.

    Raises:
        ValueError: If `names` is neither a `tuple` nor a `str` while
            `tf_outputs` is a container, or if `tf_outputs` is not a
            `tuple`, a `list` or a `tf.Tensor`.
    """
    # Some issue (about `past_key_values`) to solve (e.g. `TFPegasusForConditionalGeneration`) in a separate PR.
    if names == "past_key_values":
        return

    # Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
    # The exact-type check is deliberate: both frameworks must return the very same container type.
    if type(tf_outputs) in [tuple, list]:
        self.assertEqual(type(tf_outputs), type(pt_outputs))
        self.assertEqual(len(tf_outputs), len(pt_outputs))

        if isinstance(names, tuple):
            for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
                check_outputs(tf_output, pt_output, model_class, names=name)
        elif isinstance(names, str):
            for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
                check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
        else:
            raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
    elif isinstance(tf_outputs, tf.Tensor):
        self.assertIsInstance(pt_outputs, torch.Tensor)

        # `np.atleast_1d` turns the NumPy scalar obtained from a 0-d tensor into
        # a 1-element array, so that the masked assignments below keep working.
        tf_outputs = np.atleast_1d(tf_outputs.numpy())
        pt_outputs = np.atleast_1d(pt_outputs.detach().to("cpu").numpy())

        tf_nans = np.isnan(tf_outputs)
        pt_nans = np.isnan(pt_outputs)

        # Zero out every position where either framework produced a NaN, so
        # that those positions do not contribute to the maximal difference.
        pt_outputs[tf_nans] = 0
        tf_outputs[tf_nans] = 0
        pt_outputs[pt_nans] = 0
        tf_outputs[pt_nans] = 0

        max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
        self.assertLessEqual(max_diff, 1e-5)
    else:
        raise ValueError(
            "`tf_outputs` should be a `tuple`, a `list` or an instance of `tf.Tensor`. "
            f"Got {type(tf_outputs)} instead."
        )
1537
+
1538
def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels):
    """Run both models on the same inputs and assert that their outputs agree.

    The comparison is done twice: once with `pt_inputs_dict`, and, when
    `pt_inputs_dict_maybe_with_labels` contains one of `labels`,
    `next_sentence_label` or `start_positions`, once more with those inputs
    so that the loss is compared as well.

    Args:
        tf_model: The TensorFlow model, called with a dict of TF inputs.
        pt_model: The PyTorch model; it is moved to `torch_device` and put
            in eval mode here.
        pt_inputs_dict: PyTorch inputs without labels.
        pt_inputs_dict_maybe_with_labels: PyTorch inputs that may contain labels.
    """
    # Some PT models return loss while the corresponding TF models don't (i.e. `None` for `loss`).
    # (Also, `TFTransfoXLLMHeadModel` has no `loss` while `TransfoXLLMHeadModel` return `losses`)
    # TODO: Fix PT/TF diff -> remove this exemption to fail the test if a diff occurs
    models_with_known_loss_diff = (
        "FlaubertWithLMHeadModel",
        "FunnelForPreTraining",
        "ElectraForPreTraining",
        "XLMWithLMHeadModel",
        "TransfoXLLMHeadModel",
    )

    # send pytorch model to the correct device
    pt_model.to(torch_device)

    # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
    pt_model.eval()

    # The TF inputs are built before the PT inputs are moved to `torch_device`.
    tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
    tf_inputs_dict_maybe_with_labels = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict_maybe_with_labels)

    # send pytorch inputs to the correct device
    pt_inputs_dict = {
        k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
    }
    pt_inputs_dict_maybe_with_labels = {
        k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v
        for k, v in pt_inputs_dict_maybe_with_labels.items()
    }

    # Original test: check without `labels`
    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs_dict)
    tf_outputs = tf_model(tf_inputs_dict)

    tf_keys = tuple(k for k, v in tf_outputs.items() if v is not None)
    pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

    self.assertEqual(tf_keys, pt_keys)
    check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)

    # check the case where `labels` is passed
    label_keys = ("labels", "next_sentence_label", "start_positions")
    if not any(key in tf_inputs_dict_maybe_with_labels for key in label_keys):
        return

    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs_dict_maybe_with_labels)
    tf_outputs = tf_model(tf_inputs_dict_maybe_with_labels)

    # Some models' output class don't have `loss` attribute despite `labels` is used.
    # TODO: identify which models
    tf_loss = getattr(tf_outputs, "loss", None)
    pt_loss = getattr(pt_outputs, "loss", None)

    tf_keys = tuple(k for k, v in tf_outputs.items() if v is not None)
    pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

    # TODO: remove this exemption once the above TODOs (above loss) are implemented
    if model_class.__name__ not in models_with_known_loss_diff:
        # Either both frameworks return a loss or neither does ...
        self.assertEqual(tf_loss is None, pt_loss is None)
        # ... and they expose the same non-`None` output keys.
        self.assertEqual(tf_keys, pt_keys)

    # Since we deliberately make some tests pass above (regarding the `loss`), let's still try to test
    # some remaining attributes in the outputs.
    # TODO: remove this block of `index` computing once the above TODOs (above loss) are implemented
    # `index` is the length of the common prefix of `tf_keys` and `pt_keys`,
    # i.e. the 1st position where they differ.
    index = 0
    for tf_key, pt_key in zip(tf_keys, pt_keys):
        if tf_key != pt_key:
            break
        index += 1

    # Some models require extra condition to return loss. For example, `(TF)BertForPreTraining` requires
    # both`labels` and `next_sentence_label`.
    if tf_loss is None or pt_loss is None:
        return

    # check anything else than `loss` (slicing from 1 skips the first output)
    check_outputs(tf_outputs[1:index], pt_outputs[1:index], model_class, names=tf_keys[1:index])

    # check `loss`

    # tf models returned loss is usually a tensor rather than a scalar.
    # (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
    # Change it here to a scalar to match PyTorch models' loss
    tf_loss = tf.math.reduce_mean(tf_loss).numpy()
    pt_loss = pt_loss.detach().to("cpu").numpy()

    tf_nans = np.isnan(tf_loss)
    pt_nans = np.isnan(pt_loss)
    # the 2 losses need to be both nan or both not nan
    self.assertEqual(tf_nans, pt_nans)

    if not tf_nans:
        max_diff = np.amax(np.abs(tf_loss - pt_loss))
        self.assertLessEqual(max_diff, 1e-5)
1652
+
1466
1653
config , inputs_dict = self .model_tester .prepare_config_and_inputs_for_common ()
1467
1654
1468
1655
for model_class in self .all_model_classes :
@@ -1472,9 +1659,30 @@ def test_pt_tf_model_equivalence(self):
1472
1659
# transformers does not have TF version yet
1473
1660
return
1474
1661
1475
- tf_model_class = getattr (transformers , tf_model_class_name )
1662
+ if self .has_attentions :
1663
+ config .output_attentions = True
1476
1664
1477
- config .output_hidden_states = True
1665
+ for k in ["attention_mask" , "encoder_attention_mask" , "decoder_attention_mask" ]:
1666
+ if k in inputs_dict :
1667
+ attention_mask = inputs_dict [k ]
1668
+ # make sure no all 0s attention masks - to avoid failure at this moment.
1669
+ # TODO: remove this line once the TODO below is implemented.
1670
+ attention_mask = torch .ones_like (attention_mask , dtype = torch .int32 )
1671
+ # Here we make the first sequence with all 0s as attention mask.
1672
+ # Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative
1673
+ # values, like `-1e4`, `-1e9`, `-1e30` and `-inf` for attention mask across models/frameworks.
1674
+ # TODO: enable this block once the large negative values thing is cleaned up.
1675
+ # (see https://github.com/huggingface/transformers/issues/14859)
1676
+ # attention_mask = torch.cat(
1677
+ # [
1678
+ # torch.zeros_like(attention_mask[:1], dtype=torch.int32),
1679
+ # attention_mask[1:].type(dtype=torch.int32)
1680
+ # ],
1681
+ # dim=0
1682
+ # )
1683
+ inputs_dict [k ] = attention_mask
1684
+
1685
+ tf_model_class = getattr (transformers , tf_model_class_name )
1478
1686
1479
1687
tf_model = tf_model_class (config )
1480
1688
pt_model = model_class (config )
@@ -1487,49 +1695,20 @@ def test_pt_tf_model_equivalence(self):
1487
1695
tf_input_keys .discard ("cross_attn_head_mask" )
1488
1696
tf_input_keys .discard ("decoder_head_mask" )
1489
1697
1490
- pt_inputs = self ._prepare_for_class (inputs_dict , model_class )
1491
- pt_inputs = { k : v for k , v in pt_inputs . items () if k in tf_input_keys }
1698
+ pt_inputs_dict = self ._prepare_for_class (inputs_dict , model_class )
1699
+ pt_inputs_dict_maybe_with_labels = self . _prepare_for_class ( inputs_dict , model_class , return_labels = True )
1492
1700
1493
- # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
1494
- pt_model .eval ()
1495
- tf_inputs_dict = {}
1496
- for key , tensor in pt_inputs .items ():
1497
- # skip key that does not exist in tf
1498
- if type (tensor ) == bool :
1499
- tf_inputs_dict [key ] = tensor
1500
- elif key == "input_values" :
1501
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor .cpu ().numpy (), dtype = tf .float32 )
1502
- elif key == "pixel_values" :
1503
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor .cpu ().numpy (), dtype = tf .float32 )
1504
- elif key == "input_features" :
1505
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor .cpu ().numpy (), dtype = tf .float32 )
1506
- else :
1507
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor .cpu ().numpy (), dtype = tf .int32 )
1701
+ pt_inputs_dict = {k : v for k , v in pt_inputs_dict .items () if k in tf_input_keys }
1702
+ pt_inputs_dict_maybe_with_labels = {
1703
+ k : v for k , v in pt_inputs_dict_maybe_with_labels .items () if k in tf_input_keys
1704
+ }
1508
1705
1509
1706
# Check we can load pt model in tf and vice-versa with model => model functions
1707
+ tf_inputs_dict = prepare_tf_inputs_from_pt_inputs (pt_inputs_dict )
1510
1708
tf_model = transformers .load_pytorch_model_in_tf2_model (tf_model , pt_model , tf_inputs = tf_inputs_dict )
1511
- pt_model = transformers .load_tf2_model_in_pytorch_model (pt_model , tf_model ). to ( torch_device )
1709
+ pt_model = transformers .load_tf2_model_in_pytorch_model (pt_model , tf_model )
1512
1710
1513
- # Make sure PyTorch tensors are on same device as model
1514
- pt_inputs = {k : v .to (torch_device ) if torch .is_tensor (v ) else v for k , v in pt_inputs .items ()}
1515
-
1516
- with torch .no_grad ():
1517
- pto = pt_model (** pt_inputs )
1518
- tfo = tf_model (tf_inputs_dict , training = False )
1519
-
1520
- tf_hidden_states = tfo [0 ].numpy ()
1521
- pt_hidden_states = pto [0 ].cpu ().numpy ()
1522
-
1523
- tf_nans = np .copy (np .isnan (tf_hidden_states ))
1524
- pt_nans = np .copy (np .isnan (pt_hidden_states ))
1525
-
1526
- pt_hidden_states [tf_nans ] = 0
1527
- tf_hidden_states [tf_nans ] = 0
1528
- pt_hidden_states [pt_nans ] = 0
1529
- tf_hidden_states [pt_nans ] = 0
1530
-
1531
- max_diff = np .amax (np .abs (tf_hidden_states - pt_hidden_states ))
1532
- self .assertLessEqual (max_diff , 4e-2 )
1711
+ check_pt_tf_models (tf_model , pt_model , pt_inputs_dict , pt_inputs_dict_maybe_with_labels )
1533
1712
1534
1713
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
1535
1714
with tempfile .TemporaryDirectory () as tmpdirname :
@@ -1542,43 +1721,7 @@ def test_pt_tf_model_equivalence(self):
1542
1721
pt_model = transformers .load_tf2_checkpoint_in_pytorch_model (pt_model , tf_checkpoint_path )
1543
1722
pt_model = pt_model .to (torch_device )
1544
1723
1545
- # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
1546
- pt_model .eval ()
1547
- tf_inputs_dict = {}
1548
- for key , tensor in pt_inputs .items ():
1549
- # skip key that does not exist in tf
1550
- if type (tensor ) == bool :
1551
- tensor = np .array (tensor , dtype = bool )
1552
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor , dtype = tf .int32 )
1553
- elif key == "input_values" :
1554
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor .cpu ().numpy (), dtype = tf .float32 )
1555
- elif key == "pixel_values" :
1556
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor .cpu ().numpy (), dtype = tf .float32 )
1557
- elif key == "input_features" :
1558
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor .cpu ().numpy (), dtype = tf .float32 )
1559
- else :
1560
- tf_inputs_dict [key ] = tf .convert_to_tensor (tensor .cpu ().numpy (), dtype = tf .int32 )
1561
-
1562
- # need to rename encoder-decoder "inputs" for PyTorch
1563
- # if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
1564
- # pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
1565
-
1566
- with torch .no_grad ():
1567
- pto = pt_model (** pt_inputs )
1568
-
1569
- tfo = tf_model (tf_inputs_dict )
1570
- tfo = tfo [0 ].numpy ()
1571
- pto = pto [0 ].cpu ().numpy ()
1572
- tf_nans = np .copy (np .isnan (tfo ))
1573
- pt_nans = np .copy (np .isnan (pto ))
1574
-
1575
- pto [tf_nans ] = 0
1576
- tfo [tf_nans ] = 0
1577
- pto [pt_nans ] = 0
1578
- tfo [pt_nans ] = 0
1579
-
1580
- max_diff = np .amax (np .abs (tfo - pto ))
1581
- self .assertLessEqual (max_diff , 4e-2 )
1724
+ check_pt_tf_models (tf_model , pt_model , pt_inputs_dict , pt_inputs_dict_maybe_with_labels )
1582
1725
1583
1726
def assert_almost_equals (self , a : np .ndarray , b : np .ndarray , tol : float ):
1584
1727
diff = np .abs ((a - b )).max ()
0 commit comments