@@ -50,80 +50,51 @@ def _copy_model_and_optimizer(model, optimizer):
         return new_model, optimizer
     else:
         new_optimizer = copy.deepcopy(optimizer)
-        new_optimizer.state.clear()
         dic_param = {}
+        dic_param_for_master_case = {}
         for k, value in zip(model.parameters(), new_model.parameters()):
             dic_param[k] = value
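+        # Re-key the copied optimizer's params_attr so that its entries reference the
+        # deep-copied model's parameters; entries backed by a master parameter are
+        # collected separately in dic_param_for_master_case.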
+        if hasattr(optimizer, "params_attr"):
+            params_attr = getattr(optimizer, "params_attr")
+            param_key_pair = {}
+            if len(params_attr) != 0:
+                new_params_attr = copy.deepcopy(params_attr)
+                for (k1, v1), (k2, v2) in zip(
+                    params_attr.items(), new_params_attr.items()
+                ):
+                    if v1.master_parameter is None:
+                        v2.parameter = dic_param[v1.parameter]
+                    else:
+                        dic_param_for_master_case[k1] = k2
+                    param_key_pair[k1] = k2
+                if len(dic_param_for_master_case) != 0:
+                    dic_param = dic_param_for_master_case
+                for k, v in param_key_pair.items():
+                    new_params_attr[dic_param[k]] = new_params_attr.pop(v)
+                setattr(new_optimizer, "params_attr", new_params_attr)

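+        # Drop the deep-copied optimizer state; it is rebuilt below, keyed by the
+        # new model's parameters.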
+        new_optimizer.state.clear()
         # deep copy param_groups
         for group1, group2 in zip(optimizer.param_groups, new_optimizer.param_groups):
             for i, p in enumerate(group1["params"]):
-                # for the p not in the dic_param case, the new optimizer state will be updated
-                # in _deep_copy_params_attr because the param here in optimizer state is the master
-                # parameter of the model, which has ever optimized by ipex.optimize
                 if p in dic_param:
                     new_model_param = dic_param[p]
                     group2["params"][i] = new_model_param
                     new_optimizer.state[new_model_param] = copy.deepcopy(
                         optimizer.state[p]
                     )

-        # deep copy params_attr for reentrancy of ipex.optimize
-        def _deep_copy_params_attr(old_module, new_module):
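+        # Recursively copy the master_weight_split flag from the original modules onto
+        # the corresponding modules of the new model.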
+        def _attach_master_weight_split_attr(old_module, new_module):
             if hasattr(old_module, "master_weight_split"):
                 setattr(
                     new_module, "master_weight_split", old_module.master_weight_split
                 )
-                master_weight_split = getattr(new_module, "master_weight_split")
-
-                for name, param in old_module.named_parameters():
-                    if master_weight_split:
-                        attr_name = name + "_trail"
-                        if param in optimizer.params_attr:
-                            new_optimizer.params_attr[
-                                getattr(new_module, name)
-                            ] = optimizer.params_attr[param]
-                            new_optimizer.params_attr[getattr(new_module, name)][
-                                "trail"
-                            ] = getattr(new_module, attr_name)
-                    else:
-                        attr_name = "master_" + name
-                        old_master_param = getattr(old_module, attr_name)
-                        new_master_param = getattr(new_module, attr_name)
-                        if old_master_param in optimizer.params_attr:
-                            new_optimizer.params_attr[
-                                new_master_param
-                            ] = optimizer.params_attr[old_master_param]
-                            if (
-                                "bf16_param"
-                                in new_optimizer.params_attr[new_master_param]
-                            ):
-                                new_optimizer.params_attr[new_master_param][
-                                    "bf16_param"
-                                ] = getattr(new_module, name)
-                            if (
-                                "fp16_param"
-                                in new_optimizer.params_attr[new_master_param]
-                            ):
-                                new_optimizer.params_attr[new_master_param][
-                                    "fp16_param"
-                                ] = getattr(new_module, name)
-
-                        # deep copy new optimizer state for master parameter
-                        new_optimizer.state[new_master_param] = copy.deepcopy(
-                            optimizer.state[old_master_param]
-                        )
-
             for (_, old_child), (_, new_child) in zip(
                 old_module.named_children(), new_module.named_children()
             ):
-                _deep_copy_params_attr(old_child, new_child)
-
-        if hasattr(optimizer, "params_attr"):
-            params_attr = {}
-            setattr(new_optimizer, "params_attr", params_attr)
-            _deep_copy_params_attr(model, new_model)
+                _attach_master_weight_split_attr(old_child, new_child)

+        _attach_master_weight_split_attr(model, new_model)
         return new_model, new_optimizer

@@ -587,31 +558,30 @@ def xpu_check_channel_last():
         utils._weight_prepack.record_input_shape_for_prepack(
             optimized_model, sample_input
         )
-
+    params_attr = {}
     if not model.training:
         if opt_properties.conv_bn_folding:
             try:
-                optimized_model = optimization.fuse(optimized_model, inplace=inplace)
+                optimized_model = optimization.fuse(optimized_model, inplace=True)
             except:  # noqa E722
                 warnings.warn(
                     "Conv BatchNorm folding failed during the optimize process."
                 )
         if opt_properties.linear_bn_folding:
             try:
-                optimized_model = linear_bn_fuse(optimized_model, inplace=inplace)
+                optimized_model = linear_bn_fuse(optimized_model, inplace=True)
             except BaseException:
                 warnings.warn(
                     "Linear BatchNorm folding failed during the optimize process."
                 )
         if opt_properties.replace_dropout_with_identity:
             utils._model_convert.replace_dropout_with_identity(optimized_model)
-        if dtype == torch.bfloat16:
-            optimized_model = utils._model_convert.convert_module_data_type(
-                optimized_model, torch.bfloat16
-            )
-        if dtype == torch.half:
-            optimized_model = utils._model_convert.convert_module_data_type(
-                optimized_model, torch.half
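+        # convert_model_data_type casts the module to the requested low-precision dtype
+        # and returns the params_attr mapping consumed by the later stages (e.g. weight
+        # prepacking).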
+        if dtype in (
+            torch.bfloat16,
+            torch.float16,
+        ):
+            params_attr, optimized_model = utils._model_convert.convert_model_data_type(
+                optimized_model, dtype
             )

     if opt_properties.optimize_lstm:
@@ -654,37 +624,28 @@ def xpu_check_channel_last():
                 + " will use non-fused master weight update for bf16 training on XPU."
             )

-    # convert optimizer for training case.
-    params_attr = {}
-    if hasattr(optimized_optimizer, "params_attr"):
-        params_attr = optimized_optimizer.params_attr
-    if dtype == torch.bfloat16 and model.training:
-        (
-            optimized_model,
-            optimized_optimizer,
-            params_attr,
-        ) = utils._weight_cast.weight_dtype_convert_with_ipex(
-            optimized_model,
-            optimized_optimizer,
-            params_attr,
-            opt_properties.split_master_weight_for_bf16,
-            convert_dtype=torch.bfloat16,
-        )
-    if dtype == torch.half and model.training:
-        assert (
-            device_type != "xpu"
-        ), "For now, XPU device does not support model training with half precision."
-        (
-            optimized_model,
-            optimized_optimizer,
-            params_attr,
-        ) = utils._weight_cast.weight_dtype_convert_with_ipex(
-            optimized_model,
-            optimized_optimizer,
-            params_attr,
-            False,
-            convert_dtype=torch.half,
-        )
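+    # Training path: bf16 and fp16 share one conversion flow through
+    # weight_dtype_convert_with_ipex.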
+    if model.training:
+        if hasattr(optimized_optimizer, "params_attr"):
+            params_attr = optimized_optimizer.params_attr
+        if dtype == torch.float16:
+            assert (
+                device_type != "xpu"
+            ), "For now, XPU device does not support model training with half precision."
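+            # Master weight split is a bf16-specific optimization, so it is kept off
+            # for fp16 training.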
+            opt_properties.split_master_weight_for_bf16 = False
+        if dtype in (torch.bfloat16, torch.float16):
+            # convert optimizer for training case.
+            (
+                optimized_model,
+                optimized_optimizer,
+                params_attr,
+            ) = utils._weight_cast.weight_dtype_convert_with_ipex(
+                optimized_model,
+                optimized_optimizer,
+                params_attr,
+                opt_properties.split_master_weight_for_bf16,
+                dtype,
+            )
+
     # Since TorchDynamo cannot handle custom operations yet, for the case of inference graph mode,
     # the weights prepacking here is temporarily cancelled, and it will be completed on the graph.
     if opt_properties.weights_prepack:
@@ -704,7 +665,7 @@ def xpu_check_channel_last():
                 optimized_optimizer,
                 params_attr,
             ) = utils._weight_prepack.weight_prepack_with_ipex(
-                optimized_model, optimized_optimizer, params_attr, inplace, "cpu"
+                optimized_model, optimized_optimizer, params_attr, "cpu"
             )
             torch._dynamo.allow_in_graph(utils._weight_prepack._IPEXConv2d)
             torch._dynamo.allow_in_graph(utils._weight_prepack._IPEXConvTranspose2d)
@@ -719,7 +680,7 @@ def xpu_check_channel_last():
                 optimized_optimizer,
                 params_attr,
             ) = utils._weight_prepack.weight_prepack_with_ipex(
-                optimized_model, optimized_optimizer, params_attr, inplace, "xpu"
+                optimized_model, optimized_optimizer, params_attr, "xpu"
             )

     if opt_properties.graph_mode: