
Commit 4fa3794

Shared param (#1664)
* enable shared parameter support for ipex.optimize
* auto format flake8
* manually fix format
* fix deepspeed prepack
* only use module_cls instead of use module in ParameterWrapper
* use module_para/optimizer_para instead of current_para/low_precision_cache in ParamWrapper
* auto format flake8
* manually fix flake8 error
* add TODO for mkl linear investigation
* fix setattr
* use parameter/master_parameter instead of module_parameter/optimizer_parameter
* Revert "fix setattr". This reverts commit 177f3ff8dc8747519d8f5e4cff0374016a724dd9.
* Revert "manually fix flake8 error". This reverts commit 50aecff126aeacac5ebcad9263cabd0a4cbc842a.
* Revert "auto format flake8". This reverts commit 2284fb8e4da9481e64bc9bd5326a8a330275f9de.
* Revert "manually fix format". This reverts commit f560177e52f5a3b60dcced74f8a947c6ceba0aaa.
* clean up codes
* clean up code and fix
* clean up code
1 parent b912405 commit 4fa3794
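For orientation, the feature this commit targets, shared (tied) parameters surviving `ipex.optimize`, can be exercised roughly as below. This is an illustrative sketch, not code from the PR: the toy `TiedDecoder` module and all shapes are made up, and only the public `ipex.optimize` entry point is assumed.

```python
import torch
import intel_extension_for_pytorch as ipex


class TiedDecoder(torch.nn.Module):
    # Toy module whose embedding and output projection share a single weight tensor.
    def __init__(self, vocab=100, dim=16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, dim)
        self.proj = torch.nn.Linear(dim, vocab, bias=False)
        self.proj.weight = self.embed.weight  # shared (tied) parameter

    def forward(self, idx):
        return self.proj(self.embed(idx))


# Inference: dtype conversion and prepacking should see the tied weight once.
model = TiedDecoder().eval()
model = ipex.optimize(model, dtype=torch.bfloat16)

# Training: the optimizer-side bookkeeping (params_attr) likewise should track the
# shared parameter once rather than treating the two references independently.
train_model = TiedDecoder().train()
opt = torch.optim.SGD(train_model.parameters(), lr=0.1)
train_model, opt = ipex.optimize(train_model, dtype=torch.bfloat16, optimizer=opt)
```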


12 files changed: +1168 lines added, -1043 lines removed


intel_extension_for_pytorch/frontend.py

Lines changed: 56 additions & 95 deletions
@@ -50,80 +50,51 @@ def _copy_model_and_optimizer(model, optimizer):
         return new_model, optimizer
     else:
         new_optimizer = copy.deepcopy(optimizer)
-        new_optimizer.state.clear()
         dic_param = {}
+        dic_param_for_master_case = {}
         for k, value in zip(model.parameters(), new_model.parameters()):
             dic_param[k] = value
+        if hasattr(optimizer, "params_attr"):
+            params_attr = getattr(optimizer, "params_attr")
+            param_key_pair = {}
+            if len(params_attr) != 0:
+                new_params_attr = copy.deepcopy(params_attr)
+                for (k1, v1), (k2, v2) in zip(
+                    params_attr.items(), new_params_attr.items()
+                ):
+                    if v1.master_parameter is None:
+                        v2.parameter = dic_param[v1.parameter]
+                    else:
+                        dic_param_for_master_case[k1] = k2
+                    param_key_pair[k1] = k2
+                if len(dic_param_for_master_case) != 0:
+                    dic_param = dic_param_for_master_case
+                for k, v in param_key_pair.items():
+                    new_params_attr[dic_param[k]] = new_params_attr.pop(v)
+                setattr(new_optimizer, "params_attr", new_params_attr)
 
+        new_optimizer.state.clear()
         # deep copy param_groups
         for group1, group2 in zip(optimizer.param_groups, new_optimizer.param_groups):
             for i, p in enumerate(group1["params"]):
-                # for the p not in the dic_param case, the new optimizer state will be updated
-                # in _deep_copy_params_attr because the param here in optimizer state is the master
-                # parameter of the model, which has ever optimized by ipex.optimize
                 if p in dic_param:
                     new_model_param = dic_param[p]
                     group2["params"][i] = new_model_param
                     new_optimizer.state[new_model_param] = copy.deepcopy(
                         optimizer.state[p]
                     )
 
-        # deep copy params_attr for reentrancy of ipex.optimize
-        def _deep_copy_params_attr(old_module, new_module):
+        def _attach_master_weight_split_attr(old_module, new_module):
             if hasattr(old_module, "master_weight_split"):
                 setattr(
                     new_module, "master_weight_split", old_module.master_weight_split
                 )
-                master_weight_split = getattr(new_module, "master_weight_split")
-
-                for name, param in old_module.named_parameters():
-                    if master_weight_split:
-                        attr_name = name + "_trail"
-                        if param in optimizer.params_attr:
-                            new_optimizer.params_attr[
-                                getattr(new_module, name)
-                            ] = optimizer.params_attr[param]
-                            new_optimizer.params_attr[getattr(new_module, name)][
-                                "trail"
-                            ] = getattr(new_module, attr_name)
-                    else:
-                        attr_name = "master_" + name
-                        old_master_param = getattr(old_module, attr_name)
-                        new_master_param = getattr(new_module, attr_name)
-                        if old_master_param in optimizer.params_attr:
-                            new_optimizer.params_attr[
-                                new_master_param
-                            ] = optimizer.params_attr[old_master_param]
-                            if (
-                                "bf16_param"
-                                in new_optimizer.params_attr[new_master_param]
-                            ):
-                                new_optimizer.params_attr[new_master_param][
-                                    "bf16_param"
-                                ] = getattr(new_module, name)
-                            if (
-                                "fp16_param"
-                                in new_optimizer.params_attr[new_master_param]
-                            ):
-                                new_optimizer.params_attr[new_master_param][
-                                    "fp16_param"
-                                ] = getattr(new_module, name)
-
-                        # deep copy new optimizer state for master parameter
-                        new_optimizer.state[new_master_param] = copy.deepcopy(
-                            optimizer.state[old_master_param]
-                        )
-
             for (_, old_child), (_, new_child) in zip(
                 old_module.named_children(), new_module.named_children()
             ):
-                _deep_copy_params_attr(old_child, new_child)
-
-        if hasattr(optimizer, "params_attr"):
-            params_attr = {}
-            setattr(new_optimizer, "params_attr", params_attr)
-            _deep_copy_params_attr(model, new_model)
+                _attach_master_weight_split_attr(old_child, new_child)
 
+        _attach_master_weight_split_attr(model, new_model)
         return new_model, new_optimizer
 
 
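The context lines in the hunk above are easier to follow outside the diff. Below is a minimal standalone sketch of the copy-and-rebind pattern that `_copy_model_and_optimizer` is built on, using plain PyTorch objects rather than the IPEX `params_attr` bookkeeping:

```python
import copy
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # populate optimizer.state so there is something to copy

new_model = copy.deepcopy(model)
new_optimizer = copy.deepcopy(optimizer)
new_optimizer.state.clear()

# Map each original parameter to its deep-copied counterpart, then re-point the
# copied optimizer's param_groups and state at the copied model's parameters.
dic_param = dict(zip(model.parameters(), new_model.parameters()))
for group_old, group_new in zip(optimizer.param_groups, new_optimizer.param_groups):
    for i, p in enumerate(group_old["params"]):
        if p in dic_param:
            new_p = dic_param[p]
            group_new["params"][i] = new_p
            new_optimizer.state[new_p] = copy.deepcopy(optimizer.state[p])
```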
@@ -587,31 +558,30 @@ def xpu_check_channel_last():
         utils._weight_prepack.record_input_shape_for_prepack(
             optimized_model, sample_input
         )
-
+    params_attr = {}
     if not model.training:
         if opt_properties.conv_bn_folding:
             try:
-                optimized_model = optimization.fuse(optimized_model, inplace=inplace)
+                optimized_model = optimization.fuse(optimized_model, inplace=True)
             except:  # noqa E722
                 warnings.warn(
                     "Conv BatchNorm folding failed during the optimize process."
                 )
         if opt_properties.linear_bn_folding:
             try:
-                optimized_model = linear_bn_fuse(optimized_model, inplace=inplace)
+                optimized_model = linear_bn_fuse(optimized_model, inplace=True)
             except BaseException:
                 warnings.warn(
                     "Linear BatchNorm folding failed during the optimize process."
                 )
         if opt_properties.replace_dropout_with_identity:
             utils._model_convert.replace_dropout_with_identity(optimized_model)
-        if dtype == torch.bfloat16:
-            optimized_model = utils._model_convert.convert_module_data_type(
-                optimized_model, torch.bfloat16
-            )
-        if dtype == torch.half:
-            optimized_model = utils._model_convert.convert_module_data_type(
-                optimized_model, torch.half
+        if dtype in (
+            torch.bfloat16,
+            torch.float16,
+        ):
+            params_attr, optimized_model = utils._model_convert.convert_model_data_type(
+                optimized_model, dtype
             )
 
     if opt_properties.optimize_lstm:
@@ -654,37 +624,28 @@ def xpu_check_channel_last():
             + " will use non-fused master weight update for bf16 training on XPU."
         )
 
-    # convert optimizer for training case.
-    params_attr = {}
-    if hasattr(optimized_optimizer, "params_attr"):
-        params_attr = optimized_optimizer.params_attr
-    if dtype == torch.bfloat16 and model.training:
-        (
-            optimized_model,
-            optimized_optimizer,
-            params_attr,
-        ) = utils._weight_cast.weight_dtype_convert_with_ipex(
-            optimized_model,
-            optimized_optimizer,
-            params_attr,
-            opt_properties.split_master_weight_for_bf16,
-            convert_dtype=torch.bfloat16,
-        )
-    if dtype == torch.half and model.training:
-        assert (
-            device_type != "xpu"
-        ), "For now, XPU device does not support model training with half precision."
-        (
-            optimized_model,
-            optimized_optimizer,
-            params_attr,
-        ) = utils._weight_cast.weight_dtype_convert_with_ipex(
-            optimized_model,
-            optimized_optimizer,
-            params_attr,
-            False,
-            convert_dtype=torch.half,
-        )
+    if model.training:
+        if hasattr(optimized_optimizer, "params_attr"):
+            params_attr = optimized_optimizer.params_attr
+        if dtype == torch.float16:
+            assert (
+                device_type != "xpu"
+            ), "For now, XPU device does not support model training with half precision."
+            opt_properties.split_master_weight_for_bf16 = False
+        if dtype in (torch.bfloat16, torch.float16):
+            # convert optimizer for training case.
+            (
+                optimized_model,
+                optimized_optimizer,
+                params_attr,
+            ) = utils._weight_cast.weight_dtype_convert_with_ipex(
+                optimized_model,
+                optimized_optimizer,
+                params_attr,
+                opt_properties.split_master_weight_for_bf16,
+                dtype,
+            )
+
     # Since TorchDynamo cannot handle custom operations yet, for the case of inference graph mode,
     # the weights prepacking here is temporarily cancelled, and it will be completed on the graph.
     if opt_properties.weights_prepack:
@@ -704,7 +665,7 @@ def xpu_check_channel_last():
                 optimized_optimizer,
                 params_attr,
             ) = utils._weight_prepack.weight_prepack_with_ipex(
-                optimized_model, optimized_optimizer, params_attr, inplace, "cpu"
+                optimized_model, optimized_optimizer, params_attr, "cpu"
             )
             torch._dynamo.allow_in_graph(utils._weight_prepack._IPEXConv2d)
             torch._dynamo.allow_in_graph(utils._weight_prepack._IPEXConvTranspose2d)
@@ -719,7 +680,7 @@ def xpu_check_channel_last():
                 optimized_optimizer,
                 params_attr,
             ) = utils._weight_prepack.weight_prepack_with_ipex(
-                optimized_model, optimized_optimizer, params_attr, inplace, "xpu"
+                optimized_model, optimized_optimizer, params_attr, "xpu"
             )
 
     if opt_properties.graph_mode:
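From the public API side, the consolidated training branch above corresponds roughly to the call below. This is a hedged sketch: it relies only on the documented `ipex.optimize` parameters (`dtype`, `optimizer`, `split_master_weight_for_bf16`), and the model and optimizer are placeholders.

```python
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).train()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# bf16 training: master-weight handling follows split_master_weight_for_bf16.
model, opt = ipex.optimize(
    model, dtype=torch.bfloat16, optimizer=opt, split_master_weight_for_bf16=True
)
# With dtype=torch.float16 the same entry point is used, but (per the diff above)
# the split-master-weight path is forced off internally and XPU is rejected by an assertion.
```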

intel_extension_for_pytorch/nn/utils/_model_convert.py

Lines changed: 41 additions & 89 deletions
@@ -1,10 +1,9 @@
 import torch
 import copy
-import warnings
-import types
-
 from torch.nn.utils.rnn import PackedSequence
-
+from ._parameter_wrapper import get_shared_parameter_status
+import contextlib
+import types
 
 class _LSTM(torch.nn.LSTM):
     # This is a solution to swap the lstm module with the ipex counterpart
@@ -127,91 +126,44 @@ def replace_dropout_with_identity(model):
                 replace_dropout_with_identity(child)
 
 
-def _save_to_state_dict(self, destination, prefix, keep_vars):
-    # convert weights(bias) of module to float while saving check point
-    param_dict = {}
-    for name, para in self.named_parameters():
-        if not hasattr(self, name):
-            continue
-        param_dict.update({name: para})
-        temp_param = torch.nn.Parameter(
-            para.to(torch.float), requires_grad=para.requires_grad
-        )
-        setattr(self, name, temp_param)
-    super(type(self), self)._save_to_state_dict(destination, prefix, keep_vars)
-    for p in param_dict:
-        origin_param = param_dict[p]
-        setattr(self, p, origin_param)
-
-
-def convert_module_data_type(module, dtype):
-    # convert weights(bias) of module to dtype to reduce dtype reorder
+def convert_model_data_type(model, dtype):
+    # convert weights(bias) of model to dtype to reduce dtype reorder
     assert dtype in [
         torch.bfloat16,
         torch.float16,
-    ], "module convert only support bf16 and fp16"
-    module_convert_list_bf16 = [
-        torch.nn.Conv2d,
-        torch.nn.Conv3d,
-        torch.nn.ConvTranspose2d,
-        torch.nn.ConvTranspose3d,
-        torch.nn.Linear,
-        torch.nn.Embedding,
-        torch.nn.LSTM,
-    ]
-
-    module_convert_list_fp16 = [
-        torch.nn.Conv1d,
-        torch.nn.Conv2d,
-        torch.nn.Conv3d,
-        torch.nn.Linear,
-    ]
-
-    module_convert_lists = {
-        torch.bfloat16: module_convert_list_bf16,
-        torch.float16: module_convert_list_fp16,
-    }
-
-    for module_cls in module_convert_lists[dtype]:
-        if isinstance(module, module_cls):
-            setattr(
-                module,
-                "_save_to_state_dict",
-                types.MethodType(_save_to_state_dict, module),
-            )
-            if module_cls is torch.nn.LSTM:
-                for name, param in module.named_parameters():
-                    ori_data = getattr(getattr(module, name), "data")
-                    ori_data_dtype = ori_data.dtype
-                    if (
-                        ori_data_dtype == torch.float
-                        or ori_data_dtype == torch.bfloat16
-                    ):
-                        casted_data = ori_data.detach().clone().to(dtype)
-                        setattr(getattr(module, name), "data", casted_data)
-                    else:
-                        warnings.warn(
-                            f"WARNING: Can't convert model's parameters dtyep from {ori_data_dtype} to {dtype}"
-                        )
-                        break
-            else:
-                ori_data_dtype = module.weight.dtype
-                # Assume weight and bias have same dtype, only need check weight dtype here.
-                if (
-                    ori_data_dtype == torch.float
-                    or ori_data_dtype == torch.bfloat16
-                    or ori_data_dtype == torch.half
-                ):
-                    weight_data = module.weight.detach().clone().to(dtype)
-                    module.weight.data = weight_data
-                    if hasattr(module, "bias") and module.bias is not None:
-                        bias_data = module.bias.detach().clone().to(dtype)
-                        module.bias.data = bias_data
-                else:
-                    warnings.warn(
-                        f"WARNING: Can't convert model's parameters dtype from {ori_data_dtype} to {dtype}"
-                    )
-            break
-    for child in module.children():
-        convert_module_data_type(child, dtype)
-    return module
+    ], "model convert only support bf16 and fp16"
+
+    params_attr = {}
+    get_shared_parameter_status(model, params_attr)
+
+    for _, param in model.named_parameters():
+        if param is None:
+            continue
+        if params_attr[param].can_cast_inference(dtype):
+            params_attr[param].cast_for_inference(dtype)
+
+    def patch_state_dict():
+        def cast_back_state_dict(
+            self, *args, destination=None, prefix="", keep_vars=False
+        ):
+            with torch.no_grad(), contextlib.ExitStack() as stack:
+                for v in params_attr.values():
+                    stack.enter_context(v.inference_cast_save())
+                out = self._original_state_dict(
+                    *args,
+                    destination=destination,
+                    prefix=prefix,
+                    keep_vars=keep_vars
+                )
+            return out
+
+        if not hasattr(model, "_original_state_dict"):
+            setattr(model, "_original_state_dict", model.state_dict)
+        setattr(
+            model,
+            "state_dict",
+            types.MethodType(cast_back_state_dict, model),
+        )
+
+    patch_state_dict()
+    return params_attr, model
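The `cast_back_state_dict` wrapper above is what keeps checkpoints in fp32 after the inference cast: `state_dict` is rebound on the model instance, and each wrapped parameter is temporarily cast back for the duration of the save. The snippet below is a minimal standalone sketch of that pattern in plain PyTorch; the `_cast_back` context manager is a hypothetical stand-in for the `ParameterWrapper.inference_cast_save` helper from `_parameter_wrapper.py`.

```python
import contextlib
import types
import torch


@contextlib.contextmanager
def _cast_back(param, dtype=torch.float32):
    # Temporarily expose the parameter in `dtype` (e.g. fp32) while saving, then restore.
    low_precision = param.data
    param.data = low_precision.to(dtype)
    try:
        yield
    finally:
        param.data = low_precision


def patch_state_dict(model):
    def cast_back_state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        with torch.no_grad(), contextlib.ExitStack() as stack:
            for p in self.parameters():
                stack.enter_context(_cast_back(p))
            return self._original_state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )

    if not hasattr(model, "_original_state_dict"):
        model._original_state_dict = model.state_dict
        model.state_dict = types.MethodType(cast_back_state_dict, model)


model = torch.nn.Linear(4, 4).to(torch.bfloat16)
patch_state_dict(model)
print(model.state_dict()["weight"].dtype)  # torch.float32 in the checkpoint
print(model.weight.dtype)                  # torch.bfloat16 for compute
```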
