Skip to content

Commit 7259fd8

Browse files
baodii and geqinling authored
use seqlen k only for cases with block table (#5836) (#5839)
Signed-off-by: baodii <di.bao@intel.com>
Co-authored-by: Ge Qinling <qinling.ge@intel.com>
1 parent ba6aee9 commit 7259fd8

File tree

3 files changed

+29
-32
lines changed

3 files changed

+29
-32
lines changed

csrc/gpu/aten/operators/xetla/kernels/SDP/fmha_forward.hpp

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -358,19 +358,15 @@ class fmha_forward_t {
358358
mem_desc_Oi.init(
359359
args.O_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});
360360

361-
// get current location for kv
362-
kv_offset_y = 0;
363-
for (int32_t i = 0; i <= static_cast<int32_t>(batch_id) - 1; ++i) {
364-
kv_offset_y = kv_offset_y + args.cu_seqlen_k[i];
365-
}
366-
367361
// for local attention
368362
if constexpr (kIsLocal) {
369363
if constexpr (kIsCausal) {
370364
args.w_right = 0;
371365
}
372366
int32_t startF = item.get_group(1) * kBr;
373-
uint32_t real_T = args.cu_seqlen_k[batch_id];
367+
uint32_t real_T = args.block_tables == nullptr
368+
? args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id]
369+
: args.cu_seqlen_k[batch_id];
374370
uint32_t real_F =
375371
args.cu_seqlen_q[batch_id + 1] - args.cu_seqlen_q[batch_id];
376372
uint32_t seq_diff = real_T - real_F;
@@ -458,9 +454,9 @@ class fmha_forward_t {
458454
remain_T = remain_T < args.block_size ? remain_T : args.block_size;
459455
end_x = start_x + remain_T;
460456
} else {
461-
start_x = startT + kv_offset_y;
457+
start_x = startT + args.cu_seqlen_k[batch_id];
462458
end_x = start_x + kBc;
463-
int32_t limit_x = kv_offset_y + args.cu_seqlen_k[batch_id];
459+
int32_t limit_x = args.cu_seqlen_k[batch_id + 1];
464460
end_x = end_x < limit_x ? end_x : limit_x;
465461
}
466462
int32_t start_acc = head_id * args.uNkv / args.uN * args.uH;
@@ -701,7 +697,9 @@ class fmha_forward_t {
701697
}
702698
uint32_t real_T = args.uT;
703699
if constexpr (kVarlen) {
704-
real_T = args.cu_seqlen_k[ctx.batch_id];
700+
real_T = args.block_tables == nullptr
701+
? args.cu_seqlen_k[ctx.batch_id + 1] - args.cu_seqlen_k[ctx.batch_id]
702+
: args.cu_seqlen_k[ctx.batch_id];
705703
}
706704
uint32_t remainT = std::max(int(real_T) - int(sg_startT), 0);
707705
if constexpr (kIsLocal) {
@@ -1066,7 +1064,9 @@ class fmha_forward_t {
10661064
int32_t actual_seqlen_k = 0;
10671065
int32_t seqlen_diff = 0;
10681066
if constexpr (kVarlen) {
1069-
actual_seqlen_k = args.cu_seqlen_k[batch_id];
1067+
actual_seqlen_k = args.block_tables == nullptr
1068+
? args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id]
1069+
: args.cu_seqlen_k[batch_id];
10701070
seqlen_diff = actual_seqlen_k - actual_seqlen_q;
10711071
}
10721072

csrc/gpu/aten/operators/xetla/kernels/SDP/fmha_forward_v3.hpp

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -363,19 +363,15 @@ class fmha_forward_v3_t {
363363
mem_desc_Oi[i].init(
364364
args.O_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});
365365
}
366-
// get current kv location
367-
kv_offset_y = 0;
368-
for (int32_t i = 0; i <= static_cast<int>(batch_id) - 1; ++i) {
369-
kv_offset_y += args.cu_seqlen_k[i];
370-
}
371-
372366
// for local attention
373367
if constexpr (kIsLocal) {
374368
if constexpr (kIsCausal) {
375369
args.w_right = 0;
376370
}
377371
int32_t startF = item.get_group(1) * kBr;
378-
uint32_t real_T = args.cu_seqlen_k[batch_id];
372+
uint32_t real_T = args.block_tables == nullptr
373+
? args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id]
374+
: args.cu_seqlen_k[batch_id];
379375
uint32_t real_F =
380376
args.cu_seqlen_q[batch_id + 1] - args.cu_seqlen_q[batch_id];
381377
uint32_t seq_diff = real_T - real_F;
@@ -473,9 +469,9 @@ class fmha_forward_v3_t {
473469
remain_T = remain_T < args.block_size ? remain_T : args.block_size;
474470
end_x = start_x + remain_T;
475471
} else {
476-
start_x = startT + kv_offset_y;
472+
start_x = startT + args.cu_seqlen_k[batch_id];
477473
end_x = start_x + kBc;
478-
int32_t limit_x = kv_offset_y + args.cu_seqlen_k[batch_id];
474+
int32_t limit_x = args.cu_seqlen_k[batch_id + 1];
479475
end_x = end_x < limit_x ? end_x : limit_x;
480476
}
481477
int32_t start_acc = head_id_kv * args.uH;
@@ -664,7 +660,9 @@ class fmha_forward_v3_t {
664660
}
665661
uint32_t real_T = args.uT;
666662
if constexpr (kVarlen) {
667-
real_T = args.cu_seqlen_k[ctx.batch_id];
663+
real_T = args.block_tables == nullptr
664+
? args.cu_seqlen_k[ctx.batch_id + 1] - args.cu_seqlen_k[ctx.batch_id]
665+
: args.cu_seqlen_k[ctx.batch_id];
668666
}
669667
uint32_t remainT = std::max(int(real_T) - int(sg_startT), 0);
670668
if constexpr (kIsLocal) {
@@ -920,7 +918,9 @@ class fmha_forward_v3_t {
920918
int32_t actual_seqlen_k = 0;
921919
int32_t seqlen_diff = 0;
922920
if constexpr (kVarlen) {
923-
actual_seqlen_k = args.cu_seqlen_k[batch_id];
921+
actual_seqlen_k = args.block_tables == nullptr
922+
? args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id]
923+
: args.cu_seqlen_k[batch_id];
924924
seqlen_diff = actual_seqlen_k - actual_seqlen_q;
925925
}
926926

tests/gpu/examples/test_varlen_fwd.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -113,9 +113,9 @@ def varlen_fwd_reference(
113113
seqlen_q_ = seqlen_q.clone()
114114
seqlen_q_[:batch_size] = seqlen_q[1:]
115115
seqlen_q = (seqlen_q_ - seqlen_q)[:batch_size]
116-
# seqlen_k_ = seqlen_k.clone()
117-
# seqlen_k_[:batch_size] = seqlen_k[1:]
118-
# seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size]
116+
seqlen_k_ = seqlen_k.clone()
117+
seqlen_k_[:batch_size] = seqlen_k[1:]
118+
seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size]
119119

120120
pad_q = torch.zeros(
121121
[batch_size, max_seqlen_q, num_head, head_size],
@@ -263,8 +263,6 @@ def test_varlen_fwd(
263263
cu_seqlen = (
264264
torch.cat([torch.tensor([0]), cu_seqlen], dim=0).to(torch.int32).to("xpu")
265265
)
266-
seqlen_list = seqlen_list.to("xpu")
267-
print(f"seqlen_list: {seqlen_list} cu_seqlen: {cu_seqlen}")
268266

269267
query = torch.randn(
270268
[cu_seqlen[-1], num_heads_query, head_dim], dtype=dtype, device="xpu"
@@ -294,7 +292,7 @@ def test_varlen_fwd(
294292
value,
295293
out,
296294
cu_seqlen,
297-
seqlen_list,
295+
cu_seqlen,
298296
None,
299297
None,
300298
alibi_slopes,
@@ -316,7 +314,7 @@ def test_varlen_fwd(
316314
value,
317315
out_ref,
318316
cu_seqlen,
319-
seqlen_list,
317+
cu_seqlen,
320318
max_seqlen,
321319
max_seqlen,
322320
alibi_slopes,
@@ -336,7 +334,7 @@ def test_varlen_fwd(
336334
value,
337335
out,
338336
cu_seqlen,
339-
seqlen_list,
337+
cu_seqlen,
340338
alibi_slopes,
341339
max_seqlen,
342340
max_seqlen,
@@ -376,7 +374,6 @@ def test_varlen_attention_softcap(
376374
cu_seqlen = (
377375
torch.cat([torch.tensor([0]), cu_seqlen], dim=0).to(torch.int32).to("xpu")
378376
)
379-
seqlen_list = seqlen_list.to("xpu")
380377

381378
query = torch.randn(
382379
[cu_seqlen[-1], num_heads_query, head_dim], dtype=dtype, device="xpu"
@@ -395,7 +392,7 @@ def test_varlen_attention_softcap(
395392
value,
396393
out,
397394
cu_seqlen,
398-
seqlen_list,
395+
cu_seqlen,
399396
None,
400397
max_seqlen,
401398
max_seqlen,

0 commit comments

Comments
 (0)