Skip to content

Commit d2fe9ba

Browse files
authored
support load_state_dict after ipex.optimize (#1326)
* support load_state_dict after ipex.optimize * remove unused import * fix inference test for NotEqual * add inplace test && remove unnecessary check * improve comments
1 parent a6fafec commit d2fe9ba

14 files changed

+517
-100
lines changed

csrc/jit/cpu/kernels/ContextConvTranspose.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ struct ContextConvTranspose final {
1414
// at_weight will share same memory with weight_packed_
1515
// at_weight is used for autograd and optimizer update
1616
at::Tensor at_weight_;
17-
c10::optional<at::Tensor> bias_;
17+
c10::optional<at::Tensor> at_bias_;
1818
// paddings_, strided_, dilation_, output_padding_ here are expanded and
1919
// might different with those stored on ConvTransposeOpContext.
2020
// For example, aten deconv2d can accept padding = 2, but onednn deconv2d need
@@ -48,7 +48,7 @@ struct ContextConvTranspose final {
4848
: original_desc_(std::move(original_desc)),
4949
weight_packed_(std::move(weight_packed)),
5050
at_weight_(std::move(at_weight)),
51-
bias_(std::move(bias)),
51+
at_bias_(std::move(bias)),
5252
padding_(padding),
5353
output_padding_(output_padding),
5454
stride_(stride),

csrc/jit/cpu/kernels/ContextLinear.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct ContextLinear final {
1313
// at_weight will share same memory with weight_packed_
1414
// at_weight is used for autograd and optimizer update
1515
at::Tensor at_weight_;
16-
c10::optional<at::Tensor> bias_;
16+
c10::optional<at::Tensor> at_bias_;
1717

1818
ContextLinear() = delete;
1919

@@ -25,7 +25,7 @@ struct ContextLinear final {
2525
: original_desc_(std::move(original_desc)),
2626
weight_packed_(std::move(weight_packed)),
2727
at_weight_(std::move(at_weight)),
28-
bias_(std::move(bias)) {}
28+
at_bias_(std::move(bias)) {}
2929

3030
ContextLinear(ContextLinear&&) = default;
3131
ContextLinear& operator=(ContextLinear&&) = default;

csrc/jit/cpu/kernels/ContextLinearMKL.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ namespace cpu {
99
namespace detail {
1010
struct ContextLinearMKL final {
1111
std::vector<int64_t> sgemm_sizes_ = {0, 0, 0};
12-
at::Tensor mkl_weight_;
13-
at::Tensor ori_weight_;
14-
c10::optional<at::Tensor> bias_;
12+
at::Tensor at_weight_; // packed at weight
13+
at::Tensor ori_weight_; // non-packed at weight
14+
c10::optional<at::Tensor> at_bias_;
1515

1616
ContextLinearMKL() = delete;
1717

@@ -21,9 +21,9 @@ struct ContextLinearMKL final {
2121
at::Tensor&& ori_weight,
2222
c10::optional<at::Tensor>&& bias)
2323
: sgemm_sizes_(std::move(sgemm_sizes)),
24-
mkl_weight_(std::move(mkl_weight)),
24+
at_weight_(std::move(mkl_weight)),
2525
ori_weight_(std::move(ori_weight)),
26-
bias_(std::move(bias)) {}
26+
at_bias_(std::move(bias)) {}
2727

2828
ContextLinearMKL(ContextLinearMKL&&) = default;
2929
ContextLinearMKL& operator=(ContextLinearMKL&&) = default;

csrc/jit/cpu/kernels/ConvTransposePacked.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ at::Tensor run(
352352
check_shape_forward(
353353
input_.sizes(),
354354
context.origin_weight_dims_,
355-
context.bias_,
355+
context.at_bias_,
356356
context.padding_,
357357
context.stride_,
358358
context.dilation_,
@@ -361,7 +361,7 @@ at::Tensor run(
361361
return conv_transpose_kernel_impl(
362362
input_,
363363
context.weight_packed_,
364-
context.bias_,
364+
context.at_bias_,
365365
context.stride_,
366366
context.padding_,
367367
context.output_padding_,
@@ -397,7 +397,7 @@ at::Tensor& run(
397397
check_shape_forward(
398398
input_.sizes(),
399399
context.origin_weight_dims_,
400-
context.bias_,
400+
context.at_bias_,
401401
context.padding_,
402402
context.stride_,
403403
context.dilation_,
@@ -406,7 +406,7 @@ at::Tensor& run(
406406
conv_transpose_out_kernel_impl(
407407
input_,
408408
context.weight_packed_,
409-
context.bias_,
409+
context.at_bias_,
410410
accumu,
411411
context.stride_,
412412
context.padding_,

csrc/jit/cpu/kernels/LinearMKLPacked.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ at::Tensor run(ContextLinearMKL& context, const at::Tensor& input) {
6060
"Check the shapes of mat1 and mat2, they cannot be multiplied!");
6161
auto input_ = input.contiguous();
6262
c10::MaybeOwned<at::Tensor> bias_maybe_owned =
63-
at::borrow_from_optional_tensor(context.bias_);
63+
at::borrow_from_optional_tensor(context.at_bias_);
6464
const at::Tensor& bias = *bias_maybe_owned;
6565
int64_t input_batch = (int64_t)(input_.numel() / K);
6666

@@ -71,7 +71,7 @@ at::Tensor run(ContextLinearMKL& context, const at::Tensor& input) {
7171
if (input_batch != context.sgemm_sizes_[0])
7272
return mkl_sgemm_kernel(input_, context.ori_weight_, bias);
7373
return mkl_prepack_sgemm_kernel(
74-
input_, context.mkl_weight_, bias, context.sgemm_sizes_[2]);
74+
input_, context.at_weight_, bias, context.sgemm_sizes_[2]);
7575
}
7676

7777
at::Tensor& run(
@@ -84,14 +84,14 @@ at::Tensor& run(
8484
"Check the shapes of mat1 and mat2, they cannot be multiplied!");
8585
auto input_ = input.contiguous();
8686
c10::MaybeOwned<at::Tensor> bias_maybe_owned =
87-
at::borrow_from_optional_tensor(context.bias_);
87+
at::borrow_from_optional_tensor(context.at_bias_);
8888
const at::Tensor& bias = *bias_maybe_owned;
8989
int64_t input_batch = (int64_t)(input_.numel() / K);
9090
if (input_batch != context.sgemm_sizes_[0]) {
9191
mkl_sgemm_kernel_output(input_, context.ori_weight_, bias, accumu);
9292
} else {
9393
mkl_prepack_sgemm_kernel_output(
94-
input_, context.mkl_weight_, bias, context.sgemm_sizes_[2], accumu);
94+
input_, context.at_weight_, bias, context.sgemm_sizes_[2], accumu);
9595
}
9696
return accumu;
9797
}

csrc/jit/cpu/kernels/LinearPacked.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ at::Tensor run(
218218
"Check the shapes of mat1 and mat2, they cannot be multiplied!");
219219
auto input_ = input.contiguous();
220220
c10::MaybeOwned<at::Tensor> bias_maybe_owned =
221-
at::borrow_from_optional_tensor(context.bias_);
221+
at::borrow_from_optional_tensor(context.at_bias_);
222222
const at::Tensor& bias = *bias_maybe_owned;
223223
return linear_kernel(input_, context.weight_packed_, bias, attr);
224224
}
@@ -233,7 +233,7 @@ at::Tensor& run(
233233
"Check the shapes of mat1 and mat2, they cannot be multiplied!");
234234
auto input_ = input.contiguous();
235235
c10::MaybeOwned<at::Tensor> bias_maybe_owned =
236-
at::borrow_from_optional_tensor(context.bias_);
236+
at::borrow_from_optional_tensor(context.at_bias_);
237237
const at::Tensor& bias = *bias_maybe_owned;
238238
linear_kernel_output(input_, context.weight_packed_, bias, accumu, attr);
239239
return accumu;
@@ -250,8 +250,8 @@ void run_core(
250250
TORCH_CHECK(
251251
input.size(input.dim() - 1) == context.weight_packed_.get_dims()[1],
252252
"Check the shapes of mat1 and mat2, they cannot be multiplied!");
253-
if (context.bias_) {
254-
auto mkl_bias = itensor_view_from_dense(*context.bias_);
253+
if (context.at_bias_) {
254+
auto mkl_bias = itensor_view_from_dense(*context.at_bias_);
255255
ideep::inner_product_forward::prepare(
256256
param,
257257
mkldnn_input,
@@ -280,7 +280,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> run_backward(
280280
context.at_weight_,
281281
output_mask,
282282
context.weight_packed_,
283-
context.bias_);
283+
context.at_bias_);
284284
}
285285

286286
at::Tensor pack(ContextLinear& context, const at::Tensor& tensor) {

csrc/jit/cpu/kernels/OpContext.cpp

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,32 @@
88
namespace torch_ipex {
99
namespace cpu {
1010

11+
template <typename T1, typename T2>
12+
void load_from_ctx_template(T1* self, c10::intrusive_ptr<T2> other) {
13+
auto& other_ctx_ = other->get_context();
14+
auto loaded_weight = other_ctx_.at_weight_;
15+
auto loaded_bias = other_ctx_.at_bias_;
16+
self->get_context().at_weight_.copy_(loaded_weight);
17+
if (loaded_bias.has_value()) {
18+
self->get_context().at_bias_.value().copy_(loaded_bias.value());
19+
}
20+
return;
21+
}
22+
23+
template <>
24+
void load_from_ctx_template<IpexLinearMKLOpContext, MKLOpContext>(
25+
IpexLinearMKLOpContext* self,
26+
c10::intrusive_ptr<MKLOpContext> other) {
27+
auto& other_ctx_ = other->get_context();
28+
auto loaded_weight = other_ctx_.at_weight_;
29+
auto loaded_bias = other_ctx_.at_bias_;
30+
self->get_context().at_weight_.copy_(loaded_weight);
31+
if (loaded_bias.has_value()) {
32+
self->get_context().at_bias_.value().copy_(loaded_bias.value());
33+
}
34+
self->get_context().ori_weight_.copy_(other->get_context().ori_weight_);
35+
return;
36+
}
1137
c10::intrusive_ptr<ConvolutionOpContext> IpexConvolutionOpContext::
1238
create_context(
1339
at::Tensor&& weight,
@@ -99,6 +125,11 @@ at::Tensor IpexConvolutionOpContext::get_data_handle() {
99125
return ptr;
100126
}
101127

128+
void IpexConvolutionOpContext::load_from_ctx(
129+
c10::intrusive_ptr<ConvolutionOpContext> other) {
130+
load_from_ctx_template(this, other);
131+
}
132+
102133
c10::intrusive_ptr<LinearOpContext> IpexLinearOpContext::create_context(
103134
at::Tensor&& weight,
104135
c10::optional<at::Tensor>&& bias,
@@ -153,6 +184,11 @@ at::Tensor IpexLinearOpContext::to_public(const at::Tensor& tensor) {
153184
return torch_ipex::cpu::detail::linear::unpack(op_context_, tensor);
154185
}
155186

187+
void IpexLinearOpContext::load_from_ctx(
188+
c10::intrusive_ptr<LinearOpContext> other) {
189+
load_from_ctx_template(this, other);
190+
}
191+
156192
c10::intrusive_ptr<ConvTransposeOpContext> IpexConvTransposeOpContext::
157193
create_context(
158194
at::Tensor&& weight,
@@ -194,7 +230,7 @@ c10::intrusive_ptr<MKLOpContext> IpexLinearMKLOpContext::create_context(
194230
}
195231

196232
at::Tensor IpexLinearMKLOpContext::get_at_packed_weight() {
197-
return op_context_.mkl_weight_;
233+
return op_context_.at_weight_;
198234
}
199235

200236
at::Tensor IpexLinearMKLOpContext::get_data_handle() {
@@ -221,7 +257,7 @@ at::Tensor IpexLinearMKLOpContext::to_public(const at::Tensor& tensor) {
221257
return op_context_.ori_weight_.clone();
222258
}
223259

224-
detail::ContextLinearMKL& IpexLinearMKLOpContext::get_mkl_context() {
260+
detail::ContextLinearMKL& IpexLinearMKLOpContext::get_context() {
225261
return op_context_;
226262
}
227263

@@ -233,6 +269,11 @@ int64_t IpexLinearMKLOpContext::get_in_features() {
233269
return op_context_.sgemm_sizes_[1];
234270
}
235271

272+
void IpexLinearMKLOpContext::load_from_ctx(
273+
c10::intrusive_ptr<MKLOpContext> other) {
274+
load_from_ctx_template(this, other);
275+
}
276+
236277
at::Tensor IpexConvTransposeOpContext::run(
237278
const at::Tensor& input,
238279
const ideep::attr_t& attr) {
@@ -288,5 +329,10 @@ detail::ContextConvTranspose& IpexConvTransposeOpContext::get_context() {
288329
return op_context_;
289330
}
290331

332+
void IpexConvTransposeOpContext::load_from_ctx(
333+
c10::intrusive_ptr<ConvTransposeOpContext> other) {
334+
load_from_ctx_template(this, other);
335+
}
336+
291337
} // namespace cpu
292338
} // namespace torch_ipex

0 commit comments

Comments
 (0)