@@ -416,11 +416,11 @@ def test_diagonalize(self):
416
416
417
417
def test_pad_and_transpose_last_two_dims (self ):
418
418
hidden_states = self ._get_hidden_states ()
419
- self .assertTrue (hidden_states .shape , (1 , 8 , 4 ))
419
+ self .assertEqual (hidden_states .shape , (1 , 4 , 8 ))
420
420
padding = (0 , 0 , 0 , 1 )
421
421
422
422
padded_hidden_states = LongformerSelfAttention ._pad_and_transpose_last_two_dims (hidden_states , padding )
423
- self .assertTrue (padded_hidden_states .shape , (1 , 8 , 5 ))
423
+ self .assertEqual (padded_hidden_states .shape , (1 , 8 , 5 ))
424
424
425
425
expected_added_dim = torch .zeros ((5 ,), device = torch_device , dtype = torch .float32 )
426
426
self .assertTrue (torch .allclose (expected_added_dim , padded_hidden_states [0 , - 1 , :], atol = 1e-6 ))
@@ -445,7 +445,7 @@ def test_chunk(self):
445
445
446
446
self .assertTrue (torch .allclose (chunked_hidden_states [0 , :, 0 , 0 ], expected_slice_along_seq_length , atol = 1e-3 ))
447
447
self .assertTrue (torch .allclose (chunked_hidden_states [0 , 0 , :, 0 ], expected_slice_along_chunk , atol = 1e-3 ))
448
- self .assertTrue (chunked_hidden_states .shape , (1 , 3 , 4 , 4 ))
448
+ self .assertEqual (chunked_hidden_states .shape , (1 , 3 , 4 , 4 ))
449
449
450
450
def test_mask_invalid_locations (self ):
451
451
hidden_states = self ._get_hidden_states ()
@@ -493,7 +493,7 @@ def test_layer_local_attn(self):
493
493
is_global_attn = is_global_attn ,
494
494
)[0 ]
495
495
496
- self .assertTrue (output_hidden_states .shape , (1 , 4 , 8 ))
496
+ self .assertEqual (output_hidden_states .shape , (1 , 4 , 8 ))
497
497
self .assertTrue (
498
498
torch .allclose (
499
499
output_hidden_states [0 , 1 ],
@@ -531,7 +531,7 @@ def test_layer_global_attn(self):
531
531
is_global_attn = is_global_attn ,
532
532
)[0 ]
533
533
534
- self .assertTrue (output_hidden_states .shape , (2 , 4 , 8 ))
534
+ self .assertEqual (output_hidden_states .shape , (2 , 4 , 8 ))
535
535
536
536
self .assertTrue (
537
537
torch .allclose (
0 commit comments