Commit b545f9f

fix acc issue when bs>1 and with static cache (#5327)(#5342)

* fix acc issue when bs>2 and with static cache
* fix int4 fusedgemm onednn path bs>1 issue
1 parent f44a91d commit b545f9f

File tree

2 files changed: +17, -59 lines
intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/XPUAttentionInt4.py

Lines changed: 15 additions & 59 deletions
@@ -9,7 +9,6 @@
     dequant_gemm_block,
 )
 from .model_utils import xpu_gemm_use_xetla
-from .CacheUtils import CacheFormat


 class IPEXAttentionInt4(IPEXAttention):
@@ -194,21 +193,9 @@ def compute_qkv_gemm(self, hidden_states, query, key, value):
         mq = query.shape[-1]
         mk = key.shape[-1]
         if IPEXAttention.cache_type == "static":
-            if (
-                IPEXAttention.cache_format == CacheFormat.FBNH
-                and not self.beam_search_first_iter(hidden_states.shape[1])
-            ):
-                query = qkv_out[:, :, :mq].transpose(0, 1)
-                key.copy_(
-                    qkv_out[:, :, mq : mq + mk].transpose(0, 1)
-                ).contiguous()
-                value.copy_(
-                    qkv_out[:, :, mq + mk :].transpose(0, 1)
-                ).contiguous()
-            else:
-                query = qkv_out[:, :, :mq]
-                key.copy_(qkv_out[:, :, mq : mq + mk]).contiguous()
-                value.copy_(qkv_out[:, :, mq + mk :]).contiguous()
+            query = qkv_out[:, :, :mq]
+            key.copy_(qkv_out[:, :, mq : mq + mk]).contiguous()
+            value.copy_(qkv_out[:, :, mq + mk :]).contiguous()
         else:
             query = qkv_out[:, :, :mq]
             key = qkv_out[:, :, mq : mq + mk]
@@ -310,6 +297,9 @@ def out_proj_compute(self, attn_output, residual=None):
             if residual is not None:
                 attn_output += residual
             return attn_output
+        # ensure onednn kernel input is contiguous
+        if not attn_output.is_contiguous():
+            attn_output = attn_output.contiguous()
         if residual is None:
             if self.out_proj.bias is not None:
                 attn_output = torch.ops.torch_ipex.mm_bias_int4(
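For context on the two "ensure onednn kernel input is contiguous" additions (here and in the fused-QKV path below): slicing or transposing a packed tensor returns a strided view, and with batch size > 1 those views are generally not contiguous, which the oneDNN int4 GEMM path apparently assumes. A minimal standalone sketch of the fix pattern (toy shapes, not the IPEX code):

import torch

# Packed QKV output: [batch, seq, q_width + k_width + v_width].
bs, seqlen, mq, mk = 2, 4, 8, 8
qkv_out = torch.randn(bs, seqlen, mq + 2 * mk)

# Slicing the last dimension gives a strided view into the packed buffer,
# not a compact tensor.
query = qkv_out[:, :, :mq]
print(query.is_contiguous())               # False

# The fix pattern: materialize a contiguous copy before handing the
# tensor to a kernel that expects dense memory.
if not query.is_contiguous():
    query = query.contiguous()
print(query.is_contiguous())               # True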
@@ -399,6 +389,9 @@ def cat_qkv(self):
         pass

     def compute_qkv_gemm(self, hidden_states, query, key, value):
+        # ensure onednn kernel input is contiguous
+        if not hidden_states.is_contiguous():
+            hidden_states = hidden_states.contiguous()
         if (
             self.q_proj_quant.qweight is None
             and self.qkv_proj_quant.qweight is not None
@@ -418,32 +411,9 @@ def compute_qkv_gemm(self, hidden_states, query, key, value):
             intermediate_shape = (bs, seqlen, -1, num_group + 2, self.head_dim)
             qkv_out = qkv_out.view(*intermediate_shape)
             if IPEXAttention.cache_type == "static":
-                if (
-                    IPEXAttention.cache_format == CacheFormat.FBNH
-                    and not self.beam_search_first_iter(hidden_states.shape[1])
-                ):
-                    query = (
-                        qkv_out[:, :, :, :-2]
-                        .reshape(bs, seqlen, -1)
-                        .transpose(0, 1)
-                        .contiguous()
-                    )
-                    key.copy_(
-                        qkv_out[:, :, :, [-2]]
-                        .reshape(bs, seqlen, -1)
-                        .transpose(0, 1)
-                    )
-                    value.copy_(
-                        qkv_out[:, :, :, [-1]]
-                        .reshape(bs, seqlen, -1)
-                        .transpose(0, 1)
-                    )
-                else:
-                    query = (
-                        qkv_out[:, :, :, :-2].reshape(bs, seqlen, -1).contiguous()
-                    )
-                    key.copy_(qkv_out[:, :, :, [-2]].reshape(bs, seqlen, -1))
-                    value.copy_(qkv_out[:, :, :, [-1]].reshape(bs, seqlen, -1))
+                query = qkv_out[:, :, :, :-2].reshape(bs, seqlen, -1).contiguous()
+                key.copy_(qkv_out[:, :, :, [-2]].reshape(bs, seqlen, -1))
+                value.copy_(qkv_out[:, :, :, [-1]].reshape(bs, seqlen, -1))
             elif IPEXAttention.cache_type == "dynamic":
                 query = qkv_out[:, :, :, :-2].reshape(bs, seqlen, -1).contiguous()
                 key = qkv_out[:, :, :, [-2]].reshape(bs, seqlen, -1).contiguous()
@@ -455,17 +425,9 @@ def compute_qkv_gemm(self, hidden_states, query, key, value):
             # Statice Cache needs to store the key and value in the applied space.
             # Dynamic Cache will cat the new key and value, so that does not need inplace operation.
             if IPEXAttention.cache_type == "static":
-                if (
-                    IPEXAttention.cache_format == CacheFormat.FBNH
-                    and not self.beam_search_first_iter(hidden_states.shape[1])
-                ):
-                    query = qkv_out[:, :, :mq].transpose(0, 1).contiguous()
-                    key.copy_(qkv_out[:, :, mq : mq + mk].transpose(0, 1))
-                    value.copy_(qkv_out[:, :, mq + mk :].transpose(0, 1))
-                else:
-                    query = qkv_out[:, :, :mq].contiguous()
-                    key.copy_(qkv_out[:, :, mq : mq + mk])
-                    value.copy_(qkv_out[:, :, mq + mk :])
+                query = qkv_out[:, :, :mq].contiguous()
+                key.copy_(qkv_out[:, :, mq : mq + mk])
+                value.copy_(qkv_out[:, :, mq + mk :])
             elif IPEXAttention.cache_type == "dynamic":
                 query = qkv_out[:, :, :mq].contiguous()
                 key = qkv_out[:, :, mq : mq + mk].contiguous()
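The context comment above captures the distinction this function has to serve. As a minimal sketch of the two update styles (toy shapes and buffer names, not the actual IPEX cache classes): a static cache is preallocated and must be filled in place via copy_, while a dynamic cache simply concatenates the new key/value, so no in-place write is needed.

import torch

bs, max_seq, hidden = 2, 16, 8
new_key = torch.randn(bs, 1, hidden)   # key states for the current step
pos = 3                                # current write position

# Static cache: preallocated buffer, updated in place at the current slot.
static_key_cache = torch.zeros(bs, max_seq, hidden)
static_key_cache[:, pos : pos + 1, :].copy_(new_key)

# Dynamic cache: grows by concatenation each step.
dynamic_key_cache = torch.randn(bs, pos, hidden)
dynamic_key_cache = torch.cat([dynamic_key_cache, new_key], dim=1)

print(static_key_cache.shape)    # torch.Size([2, 16, 8])
print(dynamic_key_cache.shape)   # torch.Size([2, 4, 8])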
@@ -500,12 +462,6 @@ def compute_qkv_gemm(self, hidden_states, query, key, value):
                 self.v_proj_quant.g_idx,
             )
             if IPEXAttention.cache_type == "static":
-                if self.beam_idx is None or self.beam_search_next_token(
-                    hidden_states.size(1)
-                ):
-                    key.copy_(key_out.transpose(0, 1))
-                    value.copy_(value_out.transpose(0, 1))
-                    return query.transpose(0, 1).contiguous(), key, value
                 key.copy_(key_out)
                 value.copy_(value_out)
             return query, key, value

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/XPUAttentionfp16.py

Lines changed: 2 additions & 0 deletions
@@ -463,6 +463,8 @@ def forward(
             key, value = past_key_value.get_kv_slice_for_qkv(
                 self.layer_idx, cache_position=cache_position
             )
+            # StaticCache format is FBNH, hidden_states(input) need align this format
+            hidden_states = hidden_states.permute(1, 0, 2)

         else:
             query = torch.empty(
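For illustration, a small sketch (toy shapes, not the IPEX classes) of what the added permute does: hidden_states usually arrives as [batch, seq, hidden], while a frame-first (FBNH-style) static cache keeps the sequence dimension leading, so permute(1, 0, 2) swaps batch and sequence. Note that permute returns a non-contiguous view, which is presumably why the Int4 path in the first file also gains explicit contiguity checks.

import torch

bs, seqlen, hidden_size = 2, 5, 16
hidden_states = torch.randn(bs, seqlen, hidden_size)    # [B, F, H]

# Align with a frame-first (FBNH-style) cache layout: sequence dim leads.
hidden_states_fb = hidden_states.permute(1, 0, 2)       # [F, B, H]

print(hidden_states_fb.shape)            # torch.Size([5, 2, 16])
print(hidden_states_fb.is_contiguous())  # False: permute returns a view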
