@@ -3196,8 +3196,7 @@ def test_pad_scalar_error(self):
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))
 
-    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
-                     "Scipy v1.0 and/or numpy not found")
+    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
     def test_multihead_attention(self):
         def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None):
             """ Numpy-based reference implementation of scaled dot attention
@@ -3209,7 +3208,7 @@ def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None)
                 / np.sqrt(dims[3], dtype=np.float32),  # divide by sqrt(d_head)
             )
             b1, b2, s1, s2 = QKT.shape
-            if unseen_mask is not None or src_lengths is not None:
+            if unseen_mask is not None or key_padding_mask is not None:
                 # assert s1 == s2
                 for i in range(b1):
                     for j in range(b2):
@@ -3301,9 +3300,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
             saved_v_tensor = None
             if saved_kv:
                 saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head)
-                saved_k_tensor = torch.from_numpy(saved_k)
+                saved_k_tensor = torch.from_numpy(saved_k).to(torch.get_default_dtype())
                 saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head)
-                saved_v_tensor = torch.from_numpy(saved_v)
+                saved_v_tensor = torch.from_numpy(saved_v).to(torch.get_default_dtype())
 
             key_padding_mask = None
             key_padding_mask_tensor = None
@@ -3312,8 +3311,8 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                 key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1)
                 key_padding_mask_tensor = torch.from_numpy(key_padding_mask)
 
-            decoder_state = np.random.rand(batch_sz, d_model).astype(np.float64)
-            K = np.random.rand(*dims).astype(np.float64)
+            decoder_state = np.random.rand(batch_sz, d_model)
+            K = np.random.rand(*dims)
             V = K
             Q = np.expand_dims(decoder_state, 1)
             attn_mask = np.random.randint(0, 2, size=(1, seq_len))
@@ -3322,8 +3321,8 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
             attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
             attn_mask_tensor = attn_mask_tensor.double()
 
-            decoder_state_tensor = torch.from_numpy(decoder_state).double()
-            source_hid_tensor = torch.from_numpy(K).double().transpose(0, 1)
+            decoder_state_tensor = torch.from_numpy(decoder_state).to(torch.get_default_dtype())
+            source_hid_tensor = torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1)
 
             multihead_attn_module = MultiheadAttention(d_model, nheads,
                                                        add_bias_kv=add_bias_kv,
@@ -3337,7 +3336,6 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
             bias_k = None
             bias_v = None
 
-            _batch_size = decoder_state_tensor.shape[0]
             _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1)
             _V = source_hid_tensor
             _K = source_hid_tensor
@@ -3397,7 +3395,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
             else:
                 K_split = _split_heads_ref(K_fc, dims, nheads, d_head)
 
-            if saved_k is not None:
+            if saved_v is not None:
                 V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head])
             else:
                 V_split = _split_heads_ref(V_fc, dims, nheads, d_head)
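
Note on the dtype changes above: np.random.rand returns float64 arrays and torch.from_numpy preserves that dtype, while MultiheadAttention creates its parameters in the default dtype (float32 unless overridden), so the test casts the numpy-backed tensors with .to(torch.get_default_dtype()) instead of forcing everything to double. A minimal standalone sketch of that pattern follows; the sizes and variable names are made up for illustration and are not taken from the test.

import numpy as np
import torch
import torch.nn as nn

seq_len, batch_sz, d_model, nheads = 4, 2, 8, 2      # hypothetical sizes
query = np.random.rand(seq_len, batch_sz, d_model)   # numpy gives float64
q = torch.from_numpy(query)                          # from_numpy keeps float64
q = q.to(torch.get_default_dtype())                  # float32 unless the default was changed

mha = nn.MultiheadAttention(d_model, nheads)         # parameters use the default dtype
attn_output, attn_weights = mha(q, q, q)             # input and parameter dtypes now match
print(attn_output.dtype, attn_weights.dtype)         # torch.float32 torch.float32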