
Commit 6a974be

ezyang authored and pytorchmergebot committed
Change flash attention outputs to be SymInt instead of int (pytorch#110533)
Fixes pytorch#110322

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#110533
Approved by: https://github.com/albanD
1 parent f1d8113 commit 6a974be
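For context, the net effect of the schema change: the max_q and max_k outputs of _scaled_dot_product_flash_attention (and the corresponding arguments of its backward) are now SymInt, so under dynamic-shape compilation they can stay symbolic instead of specializing on a concrete sequence length. A minimal sketch of inspecting them from Python, assuming a CUDA build with flash attention support; the shapes and dtype below are illustrative, not taken from this commit:

import torch

# Illustrative shapes: (batch, num_heads, seq_len, head_dim), fp16 on CUDA,
# which is the layout the flash attention kernel expects.
q = torch.rand(2, 8, 512, 64, device="cuda", dtype=torch.float16)
k = torch.rand(2, 8, 512, 64, device="cuda", dtype=torch.float16)
v = torch.rand(2, 8, 512, 64, device="cuda", dtype=torch.float16)

# Per the updated schema in native_functions.yaml below, the op returns
# (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
#  philox_seed, philox_offset, debug_attn_mask).
outs = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v)
max_q, max_k = outs[4], outs[5]

# In eager mode these come back as plain Python ints; when the op is traced
# with dynamic shapes (e.g. torch.compile(..., dynamic=True), as exercised by
# the new test in test_cuda_repro.py), the SymInt return type lets them remain
# symbolic rather than burning in 512.
print(max_q, max_k)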

File tree

10 files changed, +65 -15 lines changed


aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 4 deletions
@@ -14349,14 +14349,14 @@
   variants: function
   tags: nondeterministic_seeded

-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
     CPU: _scaled_dot_product_flash_attention_cpu
     CUDA: _scaled_dot_product_flash_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
   tags: nondeterministic_seeded

-- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
   device_check: NoCheck
   variants: function
   dispatch:
@@ -14375,13 +14375,13 @@
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda
   tags: nondeterministic_seeded

-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt? max_q, SymInt? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   variants: function
   dispatch:
     CUDA: _flash_attention_forward
   tags: nondeterministic_seeded

-- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 2 additions & 2 deletions
@@ -220,8 +220,8 @@ std::tuple<
     Tensor,
     Tensor,
     Tensor,
-    int64_t,
-    int64_t,
+    c10::SymInt,
+    c10::SymInt,
     Tensor,
     Tensor,
     Tensor>

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 2 additions & 2 deletions
@@ -744,8 +744,8 @@ std::tuple<
     at::Tensor,
     at::Tensor,
     at::Tensor,
-    int64_t,
-    int64_t,
+    c10::SymInt,
+    c10::SymInt,
     at::Tensor,
     at::Tensor,
     at::Tensor>

aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 1 addition & 1 deletion
@@ -668,7 +668,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
   }
   return std::make_tuple(std::move(proj), std::move(qkt));
 }
-std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
+std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,

test/inductor/test_cuda_repro.py

Lines changed: 48 additions & 0 deletions
@@ -5,13 +5,16 @@

 import torch
 import torch._dynamo.config as dynamo_config
+import torch.backends.cuda
+import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx_inner
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
     freeze_rng_state,
@@ -982,6 +985,51 @@ def fn(x, y, z):

         self.assertEqual(ref, res)

+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
+    )
+    def test_flash_attention_dynamic(self):
+        class Model(nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+                self.q = nn.Linear(1024, 1024)
+                self.k = nn.Linear(1024, 1024)
+                self.v = nn.Linear(1024, 1024)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.size()
+
+                queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+                keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+                values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+
+                attn = F.scaled_dot_product_attention(
+                    queries,
+                    keys,
+                    values,
+                )
+
+                return attn
+
+        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
+
+        model = Model().cuda().half()
+        model = torch.compile(model, backend=cnts, dynamic=True)
+
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=False, enable_mem_efficient=False
+        ):
+            input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
+            input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
+            input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)
+
+            out1 = model(input1)
+            out2 = model(input2)
+            out3 = model(input3)
+
+        self.assertEqual(cnts.frame_count, 1)
+
     @config.patch({"triton.cudagraphs": True})
     def test_index_put_no_fallback_cudagraph(self):
         def fn(x, y, z):

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 2 deletions
@@ -2764,9 +2764,9 @@
   output_differentiability: [True, False, False, False]
   query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)

-- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False, False, False, False, False]
-  query, key, value: _scaled_dot_product_flash_attention_backward(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
+  query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)

 # - name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor query_padded, Tensor key_padded, Tensor value_padded, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
 # output_differentiability: [True, False, False, False, False, False, False, False]

torch/_C/return_types.pyi.in

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ from typing import (
     Union,
 )

-from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor
+from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor, SymInt
 from torch.types import (
     _bool,
     _device,

torch/_inductor/ir.py

Lines changed: 2 additions & 0 deletions
@@ -3984,6 +3984,8 @@ def generate_output(output, indices):
                 )
             elif isinstance(output, int):
                 return output
+            elif isinstance(output, torch.SymInt):
+                return output.node.expr
             else:
                 assert (
                     output is None

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 2 additions & 2 deletions
@@ -228,8 +228,8 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     at::Tensor* ret3_tensor = new at::Tensor(std::move(r3));
     *ret3 = tensor_pointer_to_tensor_handle(ret3_tensor);
   }
-  *ret4 = r4;
-  *ret5 = r5;
+  *ret4 = r4.expect_int();
+  *ret5 = r5.expect_int();
   at::Tensor* ret6_tensor = new at::Tensor(std::move(r6));
   *ret6 = tensor_pointer_to_tensor_handle(ret6_tensor);
   at::Tensor* ret7_tensor = new at::Tensor(std::move(r7));

torchgen/api/python.py

Lines changed: 1 addition & 1 deletion
@@ -1129,7 +1129,7 @@ def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
     "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
     "::std::vector<at::Tensor>",
     # Needed for flash attention forw/backward
-    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t,int64_t,at::Tensor,at::Tensor,at::Tensor>",
+    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
     "at::Scalar",
     "bool",
     "int64_t",
