File tree Expand file tree Collapse file tree 1 file changed +8
-3
lines changed
intel_extension_for_pytorch/transformers/models/xpu/fusions Expand file tree Collapse file tree 1 file changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -471,18 +471,23 @@ def flash_attn_varlen_func(
471471 is_causal ,
472472 block_table ,
473473 alibi_slopes = None ,
474+ k_scale : float = 1.0 ,
475+ v_scale : float = 1.0 ,
474476 ):
475477 head_dim = query .size (- 1 )
476478 pad_query = query
477479 pad_k_cache = k_cache
478480 pad_v_cache = v_cache
481+ block_table_s = block_table
479482 pad_output = output
480483 if head_dim % 64 != 0 :
481484 pad_size = 64 - head_dim % 64
482485 pad_query = torch .nn .functional .pad (query , (0 , pad_size ))
483- pad_k_cache = torch .nn .functional .pad (k_cache , (0 , pad_size ))
484- pad_v_cache = torch .nn .functional .pad (v_cache , (0 , pad_size ))
486+ block_valid , block_table_s = block_table .unique (return_inverse = True )
487+ pad_k_cache = torch .nn .functional .pad (k_cache [block_valid ], (0 , pad_size ))
488+ pad_v_cache = torch .nn .functional .pad (v_cache [block_valid ], (0 , pad_size ))
485489 pad_output = torch .nn .functional .pad (output , (0 , pad_size ))
490+ block_table_s = block_table_s .to (torch .int32 )
486491 torch .ops .torch_ipex .chunked_prefill (
487492 pad_query ,
488493 pad_k_cache ,
@@ -491,7 +496,7 @@ def flash_attn_varlen_func(
491496 cu_seqlens_q ,
492497 cu_seqlens_kv ,
493498 None ,
494- block_table ,
499+ block_table_s ,
495500 alibi_slopes ,
496501 max_seqlen_q ,
497502 max_seqlen_kv ,
You can’t perform that action at this time.
0 commit comments