Skip to content

Commit a239241

Browse files
Fix some tests that misuse assertTrue for comparisons (#16771)
* Fix the avoid-misusing-assert-true issue found at https://codereview.doctor. * Fix the affected tests. * Fix the TF variants. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent d3bd9ac commit a239241

11 files changed

+30
-26
lines changed

src/transformers/models/wav2vec2/tokenization_wav2vec2.py

+4
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,10 @@ def convert_tokens_to_string(
299299
if output_word_offsets:
300300
word_offsets = self._get_word_offsets(char_offsets, self.replace_word_delimiter_char)
301301

302+
# don't output chars if not set to True
303+
if not output_char_offsets:
304+
char_offsets = None
305+
302306
# join to string
303307
join_char = " " if spaces_between_special_tokens else ""
304308
string = join_char.join(processed_chars).strip()

tests/longformer/test_modeling_longformer.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -416,11 +416,11 @@ def test_diagonalize(self):
416416

417417
def test_pad_and_transpose_last_two_dims(self):
418418
hidden_states = self._get_hidden_states()
419-
self.assertTrue(hidden_states.shape, (1, 8, 4))
419+
self.assertEqual(hidden_states.shape, (1, 4, 8))
420420
padding = (0, 0, 0, 1)
421421

422422
padded_hidden_states = LongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, padding)
423-
self.assertTrue(padded_hidden_states.shape, (1, 8, 5))
423+
self.assertEqual(padded_hidden_states.shape, (1, 8, 5))
424424

425425
expected_added_dim = torch.zeros((5,), device=torch_device, dtype=torch.float32)
426426
self.assertTrue(torch.allclose(expected_added_dim, padded_hidden_states[0, -1, :], atol=1e-6))
@@ -445,7 +445,7 @@ def test_chunk(self):
445445

446446
self.assertTrue(torch.allclose(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, atol=1e-3))
447447
self.assertTrue(torch.allclose(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, atol=1e-3))
448-
self.assertTrue(chunked_hidden_states.shape, (1, 3, 4, 4))
448+
self.assertEqual(chunked_hidden_states.shape, (1, 3, 4, 4))
449449

450450
def test_mask_invalid_locations(self):
451451
hidden_states = self._get_hidden_states()
@@ -493,7 +493,7 @@ def test_layer_local_attn(self):
493493
is_global_attn=is_global_attn,
494494
)[0]
495495

496-
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
496+
self.assertEqual(output_hidden_states.shape, (1, 4, 8))
497497
self.assertTrue(
498498
torch.allclose(
499499
output_hidden_states[0, 1],
@@ -531,7 +531,7 @@ def test_layer_global_attn(self):
531531
is_global_attn=is_global_attn,
532532
)[0]
533533

534-
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
534+
self.assertEqual(output_hidden_states.shape, (2, 4, 8))
535535

536536
self.assertTrue(
537537
torch.allclose(

tests/longformer/test_modeling_tf_longformer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def test_diagonalize(self):
413413

414414
def test_pad_and_transpose_last_two_dims(self):
415415
hidden_states = self._get_hidden_states()
416-
self.assertTrue(shape_list(hidden_states), [1, 8, 4])
416+
self.assertEqual(shape_list(hidden_states), [1, 4, 8])
417417

418418
# pad along seq length dim
419419
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
@@ -486,7 +486,7 @@ def test_layer_local_attn(self):
486486
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
487487
)
488488

489-
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
489+
self.assertEqual(output_hidden_states.shape, (1, 4, 8))
490490
tf.debugging.assert_near(output_hidden_states[0, 1], expected_slice, rtol=1e-3)
491491

492492
def test_layer_global_attn(self):
@@ -523,7 +523,7 @@ def test_layer_global_attn(self):
523523
]
524524
)[0]
525525

526-
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
526+
self.assertEqual(output_hidden_states.shape, (2, 4, 8))
527527
expected_slice_0 = tf.convert_to_tensor(
528528
[-0.06508, -0.039306, 0.030934, -0.03417, -0.00656, -0.01553, -0.02088, -0.04938], dtype=tf.dtypes.float32
529529
)

tests/test_sequence_feature_extraction_common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _inputs_are_equal(input_1, input_2):
185185

186186
expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
187187
self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
188-
self.assertTrue(input_9.shape[:2], (batch_size, expected_mult_pad_length))
188+
self.assertEqual(input_9.shape[:2], (batch_size, expected_mult_pad_length))
189189

190190
if feature_size > 1:
191191
self.assertTrue(input_9.shape[2] == feature_size)

tests/trainer/test_trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ def test_predict(self):
809809
trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True)
810810
preds = trainer.predict(trainer.eval_dataset).predictions
811811
x = trainer.eval_dataset.x
812-
self.assertTrue(len(preds), 2)
812+
self.assertEqual(len(preds), 2)
813813
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
814814
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
815815

@@ -819,7 +819,7 @@ def test_predict(self):
819819
preds = outputs.predictions
820820
labels = outputs.label_ids
821821
x = trainer.eval_dataset.x
822-
self.assertTrue(len(preds), 2)
822+
self.assertEqual(len(preds), 2)
823823
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
824824
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
825825
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))

tests/trainer/test_trainer_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def test_distributed_tensor_gatherer(self):
9797
gatherer.add_arrays([predictions[indices], [predictions[indices], predictions[indices]]])
9898
result = gatherer.finalize()
9999
self.assertTrue(isinstance(result, list))
100-
self.assertTrue(len(result), 2)
100+
self.assertEqual(len(result), 2)
101101
self.assertTrue(isinstance(result[1], list))
102-
self.assertTrue(len(result[1]), 2)
102+
self.assertEqual(len(result[1]), 2)
103103
self.assertTrue(np.array_equal(result[0], predictions))
104104
self.assertTrue(np.array_equal(result[1][0], predictions))
105105
self.assertTrue(np.array_equal(result[1][1], predictions))

tests/wav2vec2/test_modeling_flax_wav2vec2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def test_sample_negatives(self):
386386

387387
# make sure that full vectors are sampled and not values of vectors
388388
# => this means that `unique()` yields a single value for `hidden_size` dim
389-
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
389+
self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
390390

391391
def test_sample_negatives_with_attn_mask(self):
392392
batch_size = 2
@@ -428,7 +428,7 @@ def test_sample_negatives_with_attn_mask(self):
428428

429429
# make sure that full vectors are sampled and not just slices of vectors
430430
# => this means that `unique()` yields a single value for `hidden_size` dim
431-
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
431+
self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
432432

433433

434434
@require_flax

tests/wav2vec2/test_modeling_wav2vec2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def test_sample_negatives(self):
10611061
self.assertTrue(((negative - features) == 0).sum() == 0.0)
10621062

10631063
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
1064-
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
1064+
self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
10651065

10661066
def test_sample_negatives_with_mask(self):
10671067
batch_size = 2
@@ -1098,7 +1098,7 @@ def test_sample_negatives_with_mask(self):
10981098
self.assertTrue(((negative - features) == 0).sum() == 0.0)
10991099

11001100
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
1101-
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
1101+
self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
11021102

11031103

11041104
@require_torch

tests/wav2vec2/test_tokenization_wav2vec2.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _input_values_are_equal(input_values_1, input_values_2):
202202
input_values_5 = tokenizer(speech_inputs, padding="max_length", max_length=1600).input_values
203203

204204
self.assertTrue(_input_values_are_equal(input_values_1, input_values_4))
205-
self.assertTrue(input_values_5.shape, (3, 1600))
205+
self.assertEqual(input_values_5.shape, (3, 1600))
206206
# padding should be 0.0
207207
self.assertTrue(abs(sum(np.asarray(input_values_5[0])[800:1200])) < 1e-3)
208208

@@ -213,8 +213,8 @@ def _input_values_are_equal(input_values_1, input_values_2):
213213
).input_values
214214

215215
self.assertTrue(_input_values_are_equal(input_values_1, input_values_6))
216-
self.assertTrue(input_values_7.shape, (3, 1500))
217-
self.assertTrue(input_values_8.shape, (3, 2500))
216+
self.assertEqual(input_values_7.shape, (3, 1500))
217+
self.assertEqual(input_values_8.shape, (3, 2500))
218218
# padding should be 0.0
219219
self.assertTrue(abs(sum(np.asarray(input_values_7[0])[800:])) < 1e-3)
220220
self.assertTrue(abs(sum(np.asarray(input_values_7[1])[1000:])) < 1e-3)
@@ -489,21 +489,21 @@ def test_offsets(self):
489489

490490
outputs_char = tokenizer.decode(sample_ids, output_char_offsets=True)
491491
# check Wav2Vec2CTCTokenizerOutput keys for char
492-
self.assertTrue(len(outputs_char.keys()), 2)
492+
self.assertEqual(len(outputs_char.keys()), 2)
493493
self.assertTrue("text" in outputs_char)
494494
self.assertTrue("char_offsets" in outputs_char)
495495
self.assertTrue(isinstance(outputs_char, Wav2Vec2CTCTokenizerOutput))
496496

497497
outputs_word = tokenizer.decode(sample_ids, output_word_offsets=True)
498498
# check Wav2Vec2CTCTokenizerOutput keys for word
499-
self.assertTrue(len(outputs_word.keys()), 2)
499+
self.assertEqual(len(outputs_word.keys()), 2)
500500
self.assertTrue("text" in outputs_word)
501501
self.assertTrue("word_offsets" in outputs_word)
502502
self.assertTrue(isinstance(outputs_word, Wav2Vec2CTCTokenizerOutput))
503503

504504
outputs = tokenizer.decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
505505
# check Wav2Vec2CTCTokenizerOutput keys for both
506-
self.assertTrue(len(outputs.keys()), 3)
506+
self.assertEqual(len(outputs.keys()), 3)
507507
self.assertTrue("text" in outputs)
508508
self.assertTrue("char_offsets" in outputs)
509509
self.assertTrue("word_offsets" in outputs)

tests/wav2vec2_phoneme/test_tokenization_wav2vec2_phoneme.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_offsets(self):
265265

266266
outputs = tokenizer.decode(sample_ids, output_char_offsets=True, filter_word_delimiter_token=False)
267267
# check Wav2Vec2CTCTokenizerOutput keys for char
268-
self.assertTrue(len(outputs.keys()), 2)
268+
self.assertEqual(len(outputs.keys()), 2)
269269
self.assertTrue("text" in outputs)
270270
self.assertTrue("char_offsets" in outputs)
271271
self.assertTrue(isinstance(outputs, Wav2Vec2PhonemeCTCTokenizerOutput))

tests/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def test_offsets_integration_fast(self):
368368

369369
outputs = processor.decode(logits, output_word_offsets=True)
370370
# check Wav2Vec2CTCTokenizerOutput keys for word
371-
self.assertTrue(len(outputs.keys()), 2)
371+
self.assertEqual(len(outputs.keys()), 4)
372372
self.assertTrue("text" in outputs)
373373
self.assertTrue("word_offsets" in outputs)
374374
self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
@@ -385,7 +385,7 @@ def test_offsets_integration_fast_batch(self):
385385
outputs = processor.batch_decode(logits, output_word_offsets=True)
386386

387387
# check Wav2Vec2CTCTokenizerOutput keys for word
388-
self.assertTrue(len(outputs.keys()), 2)
388+
self.assertEqual(len(outputs.keys()), 4)
389389
self.assertTrue("text" in outputs)
390390
self.assertTrue("word_offsets" in outputs)
391391
self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))

0 commit comments

Comments
 (0)