Commit bc077ba

fix crash in tgi, and fix oom if head_size is not 64 aligned (#5317)(#5343)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent eaa5562 commit bc077ba

File tree

  • intel_extension_for_pytorch/transformers/models/xpu/fusions
      • mha_fusion.py

1 file changed: +8 -3 lines

intel_extension_for_pytorch/transformers/models/xpu/fusions/mha_fusion.py

Lines changed: 8 additions & 3 deletions
@@ -471,18 +471,23 @@ def flash_attn_varlen_func(
     is_causal,
     block_table,
     alibi_slopes=None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
     head_dim = query.size(-1)
     pad_query = query
     pad_k_cache = k_cache
     pad_v_cache = v_cache
+    block_table_s = block_table
     pad_output = output
     if head_dim % 64 != 0:
         pad_size = 64 - head_dim % 64
         pad_query = torch.nn.functional.pad(query, (0, pad_size))
-        pad_k_cache = torch.nn.functional.pad(k_cache, (0, pad_size))
-        pad_v_cache = torch.nn.functional.pad(v_cache, (0, pad_size))
+        block_valid, block_table_s = block_table.unique(return_inverse=True)
+        pad_k_cache = torch.nn.functional.pad(k_cache[block_valid], (0, pad_size))
+        pad_v_cache = torch.nn.functional.pad(v_cache[block_valid], (0, pad_size))
         pad_output = torch.nn.functional.pad(output, (0, pad_size))
+        block_table_s = block_table_s.to(torch.int32)
     torch.ops.torch_ipex.chunked_prefill(
         pad_query,
         pad_k_cache,
@@ -491,7 +496,7 @@ def flash_attn_varlen_func(
         cu_seqlens_q,
         cu_seqlens_kv,
         None,
-        block_table,
+        block_table_s,
         alibi_slopes,
         max_seqlen_q,
         max_seqlen_kv,
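
The OOM part of this fix comes from gathering only the KV-cache blocks that block_table actually references before padding, instead of padding the entire paged cache: torch.unique(return_inverse=True) yields both the list of referenced blocks and a remapped block table pointing into the compacted cache. Below is a minimal standalone sketch of that gather-and-remap step, using toy shapes that are not taken from the commit:

import torch

# Toy sizes for illustration only; real shapes come from the paged KV cache.
num_blocks_total, block_size, num_heads, head_dim = 1024, 16, 8, 80  # head_dim 80 is not 64-aligned
k_cache = torch.randn(num_blocks_total, block_size, num_heads, head_dim)
block_table = torch.tensor([[3, 7, 7, 42], [500, 3, 9, 9]])  # blocks used by two sequences

pad_size = 64 - head_dim % 64  # pad head_dim up to the next multiple of 64

# Old behaviour: pad the entire cache, allocating a full padded copy (can OOM).
# pad_k_cache = torch.nn.functional.pad(k_cache, (0, pad_size))

# New behaviour: gather only the referenced blocks, then remap block_table
# into the compacted cache via the inverse indices from unique().
block_valid, block_table_s = block_table.unique(return_inverse=True)
pad_k_cache = torch.nn.functional.pad(k_cache[block_valid], (0, pad_size))
block_table_s = block_table_s.to(torch.int32)

print(pad_k_cache.shape)  # torch.Size([5, 16, 8, 128]) -- only the 5 unique blocks are padded
print(block_valid)        # tensor([  3,   7,   9,  42, 500])
print(block_table_s)      # tensor([[0, 1, 1, 3], [4, 0, 2, 2]], dtype=torch.int32)

With this compaction the padded copies scale with the number of blocks in use for the current batch rather than with the full cache size.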
