Skip to content

Commit a539beb

Browse files
authored
Fix Falcon-40b accuracy issue. (#5351)(#5359)
* Fix Falcon 40b accuracy issue.
* Fix typos.
* Remove useless code.
* Format.
1 parent 378954b commit a539beb

File tree

1 file changed

+30
-20
lines changed
  • intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules

1 file changed

+30
-20
lines changed

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/falcon.py

Lines changed: 30 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -127,15 +127,18 @@ def __init__(
127127
"IPEXAttention dose not support this modelType {} !".format(dtype)
128128
)
129129

130-
if self.grouped:
131-
self.attn.load_parameter = partial(
132-
chatglm_load_attn_params_grouped, self.attn
133-
)
134-
else:
135-
self.attn.load_parameter = partial(load_attn_fused_qkv_params, self.attn)
136-
self.attn.transpose_parameter = partial(
137-
transpose_attn_fused_qkv_params, self.attn
138-
)
130+
if not self.new_decoder_architecture:
131+
if self.grouped:
132+
self.attn.load_parameter = partial(
133+
chatglm_load_attn_params_grouped, self.attn
134+
)
135+
else:
136+
self.attn.load_parameter = partial(
137+
load_attn_fused_qkv_params, self.attn
138+
)
139+
self.attn.transpose_parameter = partial(
140+
transpose_attn_fused_qkv_params, self.attn
141+
)
139142

140143
self.mlp = (
141144
FalconMLP(config)
@@ -218,11 +221,17 @@ def build_ipex_transformer_config(
218221
)
219222

220223
def port_attn_parameter(self):
221-
self.attn.load_parameter(
222-
self.module.self_attention.query_key_value,
223-
self.module.self_attention.dense,
224-
dtype=self.ipex_config.dtype,
225-
)
224+
if self.new_decoder_architecture:
225+
self.attn.load_parameter(
226+
qkv_proj=self.module.self_attention.query_key_value,
227+
out_proj=self.module.self_attention.dense,
228+
)
229+
else:
230+
self.attn.load_parameter(
231+
self.module.self_attention.query_key_value,
232+
self.module.self_attention.dense,
233+
dtype=self.ipex_config.dtype,
234+
)
226235

227236
def port_mlp_parameter(self):
228237
if self.new_decoder_architecture:
@@ -255,13 +264,14 @@ def port_norm_parameter(self):
255264

256265
def transpose_parameter(self):
257266
if self.new_decoder_architecture:
267+
self.attn.transpose_parameter()
258268
self.mlp.transpose_parameter()
259-
260-
if not self.grouped:
261-
dtype = self.ipex_config.dtype
262-
self.attn.transpose_parameter(dtype=dtype)
263269
else:
264-
self.attn.transpose_parameter()
270+
if not self.grouped:
271+
dtype = self.ipex_config.dtype
272+
self.attn.transpose_parameter(dtype=dtype)
273+
else:
274+
self.attn.transpose_parameter()
265275

266276
def port_all_parameters_to_new_module(self):
267277
super().port_all_parameters_to_new_module()
@@ -369,7 +379,7 @@ def forward(
369379
next_cache = None
370380
if use_cache:
371381
layer_past = outputs[0]
372-
outputs = (output, layer_past) + outputs
382+
outputs = (output, layer_past) + outputs[1:]
373383
else:
374384
outputs = (output,) + outputs[1:]
375385
return outputs

0 commit comments

Comments
 (0)