99 dequant_gemm_block ,
1010)
1111from .model_utils import xpu_gemm_use_xetla
12- from .CacheUtils import CacheFormat
1312
1413
1514class IPEXAttentionInt4 (IPEXAttention ):
@@ -194,21 +193,9 @@ def compute_qkv_gemm(self, hidden_states, query, key, value):
194193 mq = query .shape [- 1 ]
195194 mk = key .shape [- 1 ]
196195 if IPEXAttention .cache_type == "static" :
197- if (
198- IPEXAttention .cache_format == CacheFormat .FBNH
199- and not self .beam_search_first_iter (hidden_states .shape [1 ])
200- ):
201- query = qkv_out [:, :, :mq ].transpose (0 , 1 )
202- key .copy_ (
203- qkv_out [:, :, mq : mq + mk ].transpose (0 , 1 )
204- ).contiguous ()
205- value .copy_ (
206- qkv_out [:, :, mq + mk :].transpose (0 , 1 )
207- ).contiguous ()
208- else :
209- query = qkv_out [:, :, :mq ]
210- key .copy_ (qkv_out [:, :, mq : mq + mk ]).contiguous ()
211- value .copy_ (qkv_out [:, :, mq + mk :]).contiguous ()
196+ query = qkv_out [:, :, :mq ]
197+ key .copy_ (qkv_out [:, :, mq : mq + mk ]).contiguous ()
198+ value .copy_ (qkv_out [:, :, mq + mk :]).contiguous ()
212199 else :
213200 query = qkv_out [:, :, :mq ]
214201 key = qkv_out [:, :, mq : mq + mk ]
@@ -310,6 +297,9 @@ def out_proj_compute(self, attn_output, residual=None):
310297 if residual is not None :
311298 attn_output += residual
312299 return attn_output
300+ # ensure onednn kernel input is contiguous
301+ if not attn_output .is_contiguous ():
302+ attn_output = attn_output .contiguous ()
313303 if residual is None :
314304 if self .out_proj .bias is not None :
315305 attn_output = torch .ops .torch_ipex .mm_bias_int4 (
@@ -399,6 +389,9 @@ def cat_qkv(self):
399389 pass
400390
401391 def compute_qkv_gemm (self , hidden_states , query , key , value ):
392+ # ensure onednn kernel input is contiguous
393+ if not hidden_states .is_contiguous ():
394+ hidden_states = hidden_states .contiguous ()
402395 if (
403396 self .q_proj_quant .qweight is None
404397 and self .qkv_proj_quant .qweight is not None
@@ -418,32 +411,9 @@ def compute_qkv_gemm(self, hidden_states, query, key, value):
418411 intermediate_shape = (bs , seqlen , - 1 , num_group + 2 , self .head_dim )
419412 qkv_out = qkv_out .view (* intermediate_shape )
420413 if IPEXAttention .cache_type == "static" :
421- if (
422- IPEXAttention .cache_format == CacheFormat .FBNH
423- and not self .beam_search_first_iter (hidden_states .shape [1 ])
424- ):
425- query = (
426- qkv_out [:, :, :, :- 2 ]
427- .reshape (bs , seqlen , - 1 )
428- .transpose (0 , 1 )
429- .contiguous ()
430- )
431- key .copy_ (
432- qkv_out [:, :, :, [- 2 ]]
433- .reshape (bs , seqlen , - 1 )
434- .transpose (0 , 1 )
435- )
436- value .copy_ (
437- qkv_out [:, :, :, [- 1 ]]
438- .reshape (bs , seqlen , - 1 )
439- .transpose (0 , 1 )
440- )
441- else :
442- query = (
443- qkv_out [:, :, :, :- 2 ].reshape (bs , seqlen , - 1 ).contiguous ()
444- )
445- key .copy_ (qkv_out [:, :, :, [- 2 ]].reshape (bs , seqlen , - 1 ))
446- value .copy_ (qkv_out [:, :, :, [- 1 ]].reshape (bs , seqlen , - 1 ))
414+ query = qkv_out [:, :, :, :- 2 ].reshape (bs , seqlen , - 1 ).contiguous ()
415+ key .copy_ (qkv_out [:, :, :, [- 2 ]].reshape (bs , seqlen , - 1 ))
416+ value .copy_ (qkv_out [:, :, :, [- 1 ]].reshape (bs , seqlen , - 1 ))
447417 elif IPEXAttention .cache_type == "dynamic" :
448418 query = qkv_out [:, :, :, :- 2 ].reshape (bs , seqlen , - 1 ).contiguous ()
449419 key = qkv_out [:, :, :, [- 2 ]].reshape (bs , seqlen , - 1 ).contiguous ()
@@ -455,17 +425,9 @@ def compute_qkv_gemm(self, hidden_states, query, key, value):
455425 # Static Cache needs to store the key and value in the applied space.
456426 # Dynamic Cache concatenates the new key and value, so it does not need an in-place operation.
457427 if IPEXAttention .cache_type == "static" :
458- if (
459- IPEXAttention .cache_format == CacheFormat .FBNH
460- and not self .beam_search_first_iter (hidden_states .shape [1 ])
461- ):
462- query = qkv_out [:, :, :mq ].transpose (0 , 1 ).contiguous ()
463- key .copy_ (qkv_out [:, :, mq : mq + mk ].transpose (0 , 1 ))
464- value .copy_ (qkv_out [:, :, mq + mk :].transpose (0 , 1 ))
465- else :
466- query = qkv_out [:, :, :mq ].contiguous ()
467- key .copy_ (qkv_out [:, :, mq : mq + mk ])
468- value .copy_ (qkv_out [:, :, mq + mk :])
428+ query = qkv_out [:, :, :mq ].contiguous ()
429+ key .copy_ (qkv_out [:, :, mq : mq + mk ])
430+ value .copy_ (qkv_out [:, :, mq + mk :])
469431 elif IPEXAttention .cache_type == "dynamic" :
470432 query = qkv_out [:, :, :mq ].contiguous ()
471433 key = qkv_out [:, :, mq : mq + mk ].contiguous ()
@@ -500,12 +462,6 @@ def compute_qkv_gemm(self, hidden_states, query, key, value):
500462 self .v_proj_quant .g_idx ,
501463 )
502464 if IPEXAttention .cache_type == "static" :
503- if self .beam_idx is None or self .beam_search_next_token (
504- hidden_states .size (1 )
505- ):
506- key .copy_ (key_out .transpose (0 , 1 ))
507- value .copy_ (value_out .transpose (0 , 1 ))
508- return query .transpose (0 , 1 ).contiguous (), key , value
509465 key .copy_ (key_out )
510466 value .copy_ (value_out )
511467 return query , key , value
0 commit comments