
Commit 4bb497b

pbelevich authored and facebook-github-bot committed
MultiheadAttention fixes
Summary: Pull Request resolved: pytorch#30666

Test Plan: Imported from OSS

Differential Revision: D18864094

Pulled By: pbelevich

fbshipit-source-id: f7a634b2c7f526282bf918d47b9cc82aa0c0af1d
1 parent 8b6d769 commit 4bb497b

File tree

1 file changed: +9 -11 lines changed


test/test_nn.py

Lines changed: 9 additions & 11 deletions
@@ -3196,8 +3196,7 @@ def test_pad_scalar_error(self):
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))

-    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
-                     "Scipy v1.0 and/or numpy not found")
+    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
     def test_multihead_attention(self):
         def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None):
             """ Numpy-based reference implementation of scaled dot attention
@@ -3209,7 +3208,7 @@ def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None)
                 / np.sqrt(dims[3], dtype=np.float32),  # divide by sqrt(d_head)
             )
             b1, b2, s1, s2 = QKT.shape
-            if unseen_mask is not None or src_lengths is not None:
+            if unseen_mask is not None or key_padding_mask is not None:
                 # assert s1 == s2
                 for i in range(b1):
                     for j in range(b2):
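For context, _scaled_dot_attn_ref is the NumPy reference the module output is compared against; the change above makes its masking branch check the key_padding_mask argument it actually receives rather than a stale src_lengths name. A simplified single-head sketch of the same idea (the function name and the -1e9 masking constant are illustrative, not the test's exact helper):

import numpy as np

def scaled_dot_attn_sketch(Q, K, V, d_head, unseen_mask=None, key_padding_mask=None):
    # Q, K, V: (batch, seq, d_head) arrays.
    scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_head)      # divide by sqrt(d_head)
    if unseen_mask is not None:
        scores = np.where(unseen_mask, -1e9, scores)                   # hide future positions
    if key_padding_mask is not None:
        scores = np.where(key_padding_mask[:, None, :], -1e9, scores)  # hide padded keys
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)            # softmax over keys
    return np.matmul(weights, V)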
@@ -3301,9 +3300,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
             saved_v_tensor = None
             if saved_kv:
                 saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head)
-                saved_k_tensor = torch.from_numpy(saved_k)
+                saved_k_tensor = torch.from_numpy(saved_k).to(torch.get_default_dtype())
                 saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head)
-                saved_v_tensor = torch.from_numpy(saved_v)
+                saved_v_tensor = torch.from_numpy(saved_v).to(torch.get_default_dtype())

             key_padding_mask = None
             key_padding_mask_tensor = None
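torch.from_numpy preserves the NumPy dtype, and np.random.rand produces float64, so the old tensors were always double. Casting with .to(torch.get_default_dtype()) instead ties the test tensors to whatever dtype the suite runs under, float32 unless the default has been changed. A small sketch of the behavior being relied on:

import numpy as np
import torch

x = np.random.rand(2, 3)             # NumPy defaults to float64
t = torch.from_numpy(x)              # t.dtype == torch.float64, shares memory with x
t = t.to(torch.get_default_dtype())  # typically torch.float32; a cast copy when the dtype differs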
@@ -3312,8 +3311,8 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                 key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1)
                 key_padding_mask_tensor = torch.from_numpy(key_padding_mask)

-            decoder_state = np.random.rand(batch_sz, d_model).astype(np.float64)
-            K = np.random.rand(*dims).astype(np.float64)
+            decoder_state = np.random.rand(batch_sz, d_model)
+            K = np.random.rand(*dims)
             V = K
             Q = np.expand_dims(decoder_state, 1)
             attn_mask = np.random.randint(0 , 2, size=(1, seq_len))
@@ -3322,8 +3321,8 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
             attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
             attn_mask_tensor = attn_mask_tensor.double()

-            decoder_state_tensor = torch.from_numpy(decoder_state).double()
-            source_hid_tensor = torch.from_numpy(K).double().transpose(0, 1)
+            decoder_state_tensor = torch.from_numpy(decoder_state).to(torch.get_default_dtype())
+            source_hid_tensor = torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1)

             multihead_attn_module = MultiheadAttention(d_model, nheads,
                                                        add_bias_kv=add_bias_kv,
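The module under test is torch.nn.MultiheadAttention. A minimal usage sketch with assumed toy sizes (not the test's configuration), showing the (sequence, batch, embedding) layout that the transposes in the test are arranging:

import torch
from torch.nn import MultiheadAttention

d_model, nheads, seq_len, batch_sz = 16, 4, 5, 2
mha = MultiheadAttention(d_model, nheads, add_bias_kv=True)
query = torch.rand(1, batch_sz, d_model)              # (target len, batch, embed dim)
key = value = torch.rand(seq_len, batch_sz, d_model)  # (source len, batch, embed dim)
attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)                              # torch.Size([1, 2, 16])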
@@ -3337,7 +3336,6 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
             bias_k = None
             bias_v = None

-            _batch_size = decoder_state_tensor.shape[0]
             _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1)
             _V = source_hid_tensor
             _K = source_hid_tensor
@@ -3397,7 +3395,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
             else:
                 K_split = _split_heads_ref(K_fc, dims, nheads, d_head)

-            if saved_k is not None:
+            if saved_v is not None:
                 V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head])
             else:
                 V_split = _split_heads_ref(V_fc, dims, nheads, d_head)
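This last hunk is the logic fix: the V branch was gated on saved_k, the tensor used for the keys, rather than the saved_v it goes on to reshape. The earlier hunk appears to set both under the same if saved_kv: block, so in practice the branch taken was likely unchanged, but the corrected guard checks the tensor that is actually reshaped. For reference, that reshape splits the per-head values back out of the flattened batch dimension; a standalone sketch with assumed toy sizes:

import numpy as np

batch_sz, nheads, seq_len, d_head = 2, 4, 5, 8
saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head)
V_split = np.reshape(saved_v, [batch_sz, nheads, seq_len, d_head])
assert V_split.shape == (batch_sz, nheads, seq_len, d_head)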
