Skip to content

Commit 67cd98f

Browse files
IvanKobzarev authored and pytorchmergebot committed
[tensorexpr] Fix isNLC segfault (pytorch#72786)
Summary:
Pull Request resolved: pytorch#72786
Test Plan: Imported from OSS
Reviewed By: H-Huang
Differential Revision: D34204523
Pulled By: IvanKobzarev
fbshipit-source-id: 9a0f2ce0a1921e261932029c3ebd842330fdf528
(cherry picked from commit b832606)
1 parent d2c0c0b commit 67cd98f

File tree

2 files changed

+67
-77
lines changed

2 files changed

+67
-77
lines changed

test/cpp/tensorexpr/test_quantization.cpp

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,38 @@ TEST_F(Quantization, QuantDequantUInt8) {
9090
CHECK_EQ(check, 1);
9191
}
9292

93+
TEST_F(Quantization, QuantDequantUInt8_NLC) {
94+
const auto graph_string = R"IR(
95+
graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)):
96+
%2 : int = prim::Constant[value=13]()
97+
%3 : int = prim::Constant[value=122]()
98+
%4 : float = prim::Constant[value=0.1]()
99+
%q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
100+
%6 : Float(1, 2, 2) = aten::dequantize(%q.1)
101+
return (%6))IR";
102+
auto graph = std::make_shared<Graph>();
103+
parseIR(graph_string, &*graph);
104+
105+
auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
106+
x.unsafeGetTensorImpl()->set_sizes_and_strides({1, 2, 2}, {4, 1, 2});
107+
auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
108+
auto y_expected = at::dequantize(q);
109+
TensorExprKernel k(graph);
110+
std::vector<at::Tensor> inputs = {x};
111+
StmtPtr s = k.getCodeGenStmt();
112+
113+
std::vector<IValue> stack = fmap<IValue>(inputs);
114+
k.run(stack);
115+
auto y = stack[0].toTensor();
116+
bool check = at::allclose(y_expected, y);
117+
if (!check) {
118+
std::cout << "x:\n" << x << std::endl;
119+
std::cout << "y_expected:\n" << y_expected << std::endl;
120+
std::cout << "y:\n" << y << std::endl;
121+
}
122+
CHECK_EQ(check, 1);
123+
}
124+
93125
at::Tensor quantized_add(
94126
at::Tensor x1,
95127
at::Tensor x2,

torch/csrc/jit/tensorexpr/operators/quantization.cpp

Lines changed: 35 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -39,20 +39,7 @@ bool isQuantized(const BufHandle& qx) {
3939
return qx.node()->qscale() && qx.node()->qzero();
4040
}
4141

42-
BufHandle makeQBufHandleNCHW(
43-
const std::string& name,
44-
const std::vector<ExprHandle>& dims,
45-
Dtype dtype,
46-
const ExprPtr qscale,
47-
const ExprPtr qzero) {
48-
BufHandle ResultBuf(name, dims, dtype);
49-
ResultBuf.node()->set_qscale(qscale);
50-
ResultBuf.node()->set_qzero(qzero);
51-
ResultBuf.node()->set_strides(make_contiguous_strides(dims));
52-
return ResultBuf;
53-
}
54-
55-
BufHandle makeQBufHandleNHWC(
42+
BufHandle makeQBufHandleChannelsLast(
5643
const std::string& name,
5744
const std::vector<ExprHandle>& dims,
5845
Dtype dtype,
@@ -65,21 +52,21 @@ BufHandle makeQBufHandleNHWC(
6552
return ResultBuf;
6653
}
6754

68-
BufHandle makeQBufHandleNHWC(
55+
BufHandle makeQBufHandleChannelsLast(
6956
const std::string& name,
7057
const std::vector<ExprHandle>& dims,
7158
Dtype dtype,
7259
const double qscale,
7360
const int64_t qzero) {
74-
return makeQBufHandleNHWC(
61+
return makeQBufHandleChannelsLast(
7562
name,
7663
dims,
7764
dtype,
7865
DoubleImm::make(qscale).node(),
7966
LongImm::make(qzero).node());
8067
}
8168

82-
BufHandle makeQBufHandleNLC(
69+
BufHandle makeQBufHandleContiguous(
8370
const std::string& name,
8471
const std::vector<ExprHandle>& dims,
8572
Dtype dtype,
@@ -88,62 +75,37 @@ BufHandle makeQBufHandleNLC(
8875
BufHandle ResultBuf(name, dims, dtype);
8976
ResultBuf.node()->set_qscale(qscale);
9077
ResultBuf.node()->set_qzero(qzero);
91-
ResultBuf.node()->set_strides(make_channels_last_strides(dims));
78+
ResultBuf.node()->set_strides(make_contiguous_strides(dims));
9279
return ResultBuf;
9380
}
9481

95-
BufHandle makeQBufHandleNLC(
82+
BufHandle makeQBufHandleContiguous(
9683
const std::string& name,
9784
const std::vector<ExprHandle>& dims,
9885
Dtype dtype,
9986
const double qscale,
10087
const int64_t qzero) {
101-
return makeQBufHandleNLC(
88+
return makeQBufHandleContiguous(
10289
name,
10390
dims,
10491
dtype,
10592
DoubleImm::make(qscale).node(),
10693
LongImm::make(qzero).node());
10794
}
10895

109-
BufHandle makeQBufHandleNCHW(
110-
const std::string& name,
111-
const std::vector<ExprHandle>& dims,
112-
Dtype dtype,
113-
const double qscale,
114-
const int64_t qzero) {
115-
return makeQBufHandleNCHW(
116-
name,
117-
dims,
118-
dtype,
119-
DoubleImm::make(qscale).node(),
120-
LongImm::make(qzero).node());
121-
}
122-
123-
bool isNHWC(const BufHandle& buf) {
96+
bool isChannelsLast(const BufHandle& buf) {
12497
const auto& strides = buf.node()->strides();
12598
const auto& dims = buf.node()->dims();
126-
if (strides.size() != 4) {
99+
const auto rank = dims.size();
100+
if (rank < 3) {
127101
return false;
128102
}
129-
auto dims1 = to<LongImm>(IRSimplifier::simplify(dims[1]))->value();
130-
auto strides1 = to<LongImm>(IRSimplifier::simplify(strides[1]))->value();
131-
auto strides3 = to<LongImm>(IRSimplifier::simplify(strides[3]))->value();
103+
auto dimsC = to<LongImm>(IRSimplifier::simplify(dims[1]))->value();
104+
auto stridesC = to<LongImm>(IRSimplifier::simplify(strides[1]))->value();
105+
auto stridesLast =
106+
to<LongImm>(IRSimplifier::simplify(strides[rank - 1]))->value();
132107

133-
return ((strides3 == dims1) && (strides1 == 1));
134-
}
135-
136-
bool isNLC(const BufHandle& buf) {
137-
const auto& strides = buf.node()->strides();
138-
const auto& dims = buf.node()->dims();
139-
if (strides.size() != 3) {
140-
return false;
141-
}
142-
auto dims1 = to<LongImm>(IRSimplifier::simplify(dims[1]))->value();
143-
auto strides1 = to<LongImm>(IRSimplifier::simplify(strides[1]))->value();
144-
auto strides3 = to<LongImm>(IRSimplifier::simplify(strides[3]))->value();
145-
146-
return ((strides3 == dims1) && (strides1 == 1));
108+
return ((stridesLast == dimsC) && (stridesC == 1));
147109
}
148110

149111
ExprHandle quant(
@@ -273,15 +235,11 @@ Tensor computeQuantizePerTensorExternalCall(
273235
throw malformed_input("Expected quantized dtype");
274236
}(qdtype);
275237
auto ResultBuf = [&]() {
276-
if (isNHWC(x)) {
277-
return makeQBufHandleNHWC(
278-
"quantize_per_tensor", outputShape, dtype, qscale, qzero);
279-
}
280-
if (isNLC(x)) {
281-
return makeQBufHandleNLC(
238+
if (isChannelsLast(x)) {
239+
return makeQBufHandleChannelsLast(
282240
"quantize_per_tensor", outputShape, dtype, qscale, qzero);
283241
}
284-
return makeQBufHandleNCHW(
242+
return makeQBufHandleContiguous(
285243
"quantize_per_tensor", outputShape, dtype, qscale, qzero);
286244
}();
287245
StmtPtr s = ExternalCall::make(
@@ -376,7 +334,7 @@ Tensor computeQuantizedConv1d(
376334
const auto out_qzero = c10::get<int64_t>(inputs[3]);
377335
// Change to dtype based on outputType when dtype propagation implemented
378336
const auto out_qdtype = immQDType(qx);
379-
auto ResultBuf = makeQBufHandleNLC(
337+
auto ResultBuf = makeQBufHandleChannelsLast(
380338
"quantized_conv1d",
381339
outputShape,
382340
Dtype(out_qdtype),
@@ -407,7 +365,7 @@ Tensor computeQuantizedConv2d(
407365
const auto out_qzero = c10::get<int64_t>(inputs[3]);
408366
// Change to dtype based on outputType when dtype propagation implemented
409367
const auto out_qdtype = immQDType(qx);
410-
auto ResultBuf = makeQBufHandleNHWC(
368+
auto ResultBuf = makeQBufHandleChannelsLast(
411369
"quantized_conv2d",
412370
outputShape,
413371
Dtype(out_qdtype),
@@ -438,7 +396,7 @@ Tensor computeQuantizedConv2dRelu(
438396
const auto out_qzero = c10::get<int64_t>(inputs[3]);
439397
// Change to dtype based on outputType when dtype propagation implemented
440398
const auto out_qdtype = immQDType(qx);
441-
auto ResultBuf = makeQBufHandleNHWC(
399+
auto ResultBuf = makeQBufHandleChannelsLast(
442400
"quantized_conv2d_relu",
443401
outputShape,
444402
Dtype(out_qdtype),
@@ -469,7 +427,7 @@ Tensor computeQuantizedLinear(
469427
const auto out_qzero = c10::get<int64_t>(inputs[3]);
470428
// Change to dtype based on outputType when dtype propagation implemented
471429
const auto out_qdtype = immQDType(qx);
472-
auto ResultBuf = makeQBufHandleNCHW(
430+
auto ResultBuf = makeQBufHandleContiguous(
473431
"quantized_linear",
474432
outputShape,
475433
Dtype(out_qdtype),
@@ -500,7 +458,7 @@ Tensor computeQuantizedLinearRelu(
500458
const auto out_qzero = c10::get<int64_t>(inputs[3]);
501459
// Change to dtype based on outputType when dtype propagation implemented
502460
const auto out_qdtype = immQDType(qx);
503-
auto ResultBuf = makeQBufHandleNCHW(
461+
auto ResultBuf = makeQBufHandleContiguous(
504462
"quantized_linear_relu",
505463
outputShape,
506464
Dtype(out_qdtype),
@@ -531,16 +489,16 @@ Tensor computeQuantizedAddExternalCall(
531489
const auto out_qzero = c10::get<int64_t>(inputs[3]);
532490
// Change to dtype based on outputType when dtype propagation implemented
533491
const auto out_qdtype = immQDType(qa);
534-
const bool isQAChannelsLast = isNHWC(qa);
535-
const bool isQBChannelsLast = isNHWC(qb);
492+
const bool isQAChannelsLast = isChannelsLast(qa);
493+
const bool isQBChannelsLast = isChannelsLast(qb);
536494
auto ResultBuf = (isQAChannelsLast || isQBChannelsLast)
537-
? makeQBufHandleNHWC(
495+
? makeQBufHandleChannelsLast(
538496
"quantized_add",
539497
outputShape,
540498
Dtype(out_qdtype),
541499
out_qscale,
542500
out_qzero)
543-
: makeQBufHandleNCHW(
501+
: makeQBufHandleContiguous(
544502
"quantized_add",
545503
outputShape,
546504
Dtype(out_qdtype),
@@ -574,7 +532,7 @@ Tensor computeQuantizedMul(
574532
const auto out_qzero = c10::get<int64_t>(inputs[3]);
575533
// Change to dtype based on outputType when dtype propagation implemented
576534
const auto out_qdtype = immQDType(qa);
577-
auto ResultBuf = makeQBufHandleNCHW(
535+
auto ResultBuf = makeQBufHandleContiguous(
578536
"quantized_mul", outputShape, Dtype(out_qdtype), out_qscale, out_qzero);
579537
StmtPtr s = ExternalCall::make(
580538
ResultBuf,
@@ -603,7 +561,7 @@ Tensor computeQuantizedMulScalar(
603561
// Change to dtype based on outputType when dtype propagation implemented
604562
const auto out_qdtype = immQDType(qa);
605563
double scale1 = immQScale(qa);
606-
auto ResultBuf = makeQBufHandleNCHW(
564+
auto ResultBuf = makeQBufHandleContiguous(
607565
"quantized_mul_scalar",
608566
outputShape,
609567
Dtype(out_qdtype),
@@ -626,14 +584,14 @@ Tensor computeQuantizedRelu(
626584
at::Device device) {
627585
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
628586
const auto out_qdtype = immQDType(qa);
629-
const bool isQAChannelsLast = isNHWC(qa);
630-
auto ResultBuf = isQAChannelsLast ? makeQBufHandleNHWC(
587+
const bool isQAChannelsLast = isChannelsLast(qa);
588+
auto ResultBuf = isQAChannelsLast ? makeQBufHandleChannelsLast(
631589
"quantized_relu",
632590
outputShape,
633591
Dtype(out_qdtype),
634592
immQScale(qa),
635593
immQZero(qa))
636-
: makeQBufHandleNCHW(
594+
: makeQBufHandleContiguous(
637595
"quantized_relu",
638596
outputShape,
639597
Dtype(out_qdtype),
@@ -674,7 +632,7 @@ Tensor computeQuantizedCat(
674632
extra_args.emplace_back(argDim);
675633
extra_args.emplace_back(out_qscale);
676634
extra_args.emplace_back(out_qzero);
677-
auto ResultBuf = makeQBufHandleNCHW(
635+
auto ResultBuf = makeQBufHandleContiguous(
678636
"quantized_cat",
679637
outputShape,
680638
Dtype(immQDType(inputList[0])),
@@ -793,7 +751,7 @@ Tensor computeUpsampleNearest2dExternalCall(
793751

794752
BufHandle ResultBuf = [&]() {
795753
if (isQuantized(x)) {
796-
return makeQBufHandleNHWC(
754+
return makeQBufHandleChannelsLast(
797755
"upsample_nearest2d",
798756
outputShape,
799757
Dtype(immQDType(x)),
@@ -829,7 +787,7 @@ Tensor computeQuantizedSigmoidExternalCall(
829787
const double out_qscale = 1.0f / 256.0f;
830788
const int64_t out_qzero = (out_qdtype == ScalarType::QInt8) ? -128 : 0;
831789

832-
auto ResultBuf = makeQBufHandleNHWC(
790+
auto ResultBuf = makeQBufHandleChannelsLast(
833791
"quantized_sigmoid",
834792
outputShape,
835793
Dtype(out_qdtype),

0 commit comments

Comments
 (0)