@@ -127,15 +127,18 @@ def __init__(
127127 "IPEXAttention dose not support this modelType {} !" .format (dtype )
128128 )
129129
130- if self .grouped :
131- self .attn .load_parameter = partial (
132- chatglm_load_attn_params_grouped , self .attn
133- )
134- else :
135- self .attn .load_parameter = partial (load_attn_fused_qkv_params , self .attn )
136- self .attn .transpose_parameter = partial (
137- transpose_attn_fused_qkv_params , self .attn
138- )
130+ if not self .new_decoder_architecture :
131+ if self .grouped :
132+ self .attn .load_parameter = partial (
133+ chatglm_load_attn_params_grouped , self .attn
134+ )
135+ else :
136+ self .attn .load_parameter = partial (
137+ load_attn_fused_qkv_params , self .attn
138+ )
139+ self .attn .transpose_parameter = partial (
140+ transpose_attn_fused_qkv_params , self .attn
141+ )
139142
140143 self .mlp = (
141144 FalconMLP (config )
@@ -218,11 +221,17 @@ def build_ipex_transformer_config(
218221 )
219222
220223 def port_attn_parameter (self ):
221- self .attn .load_parameter (
222- self .module .self_attention .query_key_value ,
223- self .module .self_attention .dense ,
224- dtype = self .ipex_config .dtype ,
225- )
224+ if self .new_decoder_architecture :
225+ self .attn .load_parameter (
226+ qkv_proj = self .module .self_attention .query_key_value ,
227+ out_proj = self .module .self_attention .dense ,
228+ )
229+ else :
230+ self .attn .load_parameter (
231+ self .module .self_attention .query_key_value ,
232+ self .module .self_attention .dense ,
233+ dtype = self .ipex_config .dtype ,
234+ )
226235
227236 def port_mlp_parameter (self ):
228237 if self .new_decoder_architecture :
@@ -255,13 +264,14 @@ def port_norm_parameter(self):
255264
256265 def transpose_parameter (self ):
257266 if self .new_decoder_architecture :
267+ self .attn .transpose_parameter ()
258268 self .mlp .transpose_parameter ()
259-
260- if not self .grouped :
261- dtype = self .ipex_config .dtype
262- self .attn .transpose_parameter (dtype = dtype )
263269 else :
264- self .attn .transpose_parameter ()
270+ if not self .grouped :
271+ dtype = self .ipex_config .dtype
272+ self .attn .transpose_parameter (dtype = dtype )
273+ else :
274+ self .attn .transpose_parameter ()
265275
266276 def port_all_parameters_to_new_module (self ):
267277 super ().port_all_parameters_to_new_module ()
@@ -369,7 +379,7 @@ def forward(
369379 next_cache = None
370380 if use_cache :
371381 layer_past = outputs [0 ]
372- outputs = (output , layer_past ) + outputs
382+ outputs = (output , layer_past ) + outputs [ 1 :]
373383 else :
374384 outputs = (output ,) + outputs [1 :]
375385 return outputs
0 commit comments