@@ -39,20 +39,7 @@ bool isQuantized(const BufHandle& qx) {
39
39
return qx.node ()->qscale () && qx.node ()->qzero ();
40
40
}
41
41
42
- BufHandle makeQBufHandleNCHW (
43
- const std::string& name,
44
- const std::vector<ExprHandle>& dims,
45
- Dtype dtype,
46
- const ExprPtr qscale,
47
- const ExprPtr qzero) {
48
- BufHandle ResultBuf (name, dims, dtype);
49
- ResultBuf.node ()->set_qscale (qscale);
50
- ResultBuf.node ()->set_qzero (qzero);
51
- ResultBuf.node ()->set_strides (make_contiguous_strides (dims));
52
- return ResultBuf;
53
- }
54
-
55
- BufHandle makeQBufHandleNHWC (
42
+ BufHandle makeQBufHandleChannelsLast (
56
43
const std::string& name,
57
44
const std::vector<ExprHandle>& dims,
58
45
Dtype dtype,
@@ -65,21 +52,21 @@ BufHandle makeQBufHandleNHWC(
65
52
return ResultBuf;
66
53
}
67
54
68
- BufHandle makeQBufHandleNHWC (
55
+ BufHandle makeQBufHandleChannelsLast (
69
56
const std::string& name,
70
57
const std::vector<ExprHandle>& dims,
71
58
Dtype dtype,
72
59
const double qscale,
73
60
const int64_t qzero) {
74
- return makeQBufHandleNHWC (
61
+ return makeQBufHandleChannelsLast (
75
62
name,
76
63
dims,
77
64
dtype,
78
65
DoubleImm::make (qscale).node (),
79
66
LongImm::make (qzero).node ());
80
67
}
81
68
82
- BufHandle makeQBufHandleNLC (
69
+ BufHandle makeQBufHandleContiguous (
83
70
const std::string& name,
84
71
const std::vector<ExprHandle>& dims,
85
72
Dtype dtype,
@@ -88,62 +75,37 @@ BufHandle makeQBufHandleNLC(
88
75
BufHandle ResultBuf (name, dims, dtype);
89
76
ResultBuf.node ()->set_qscale (qscale);
90
77
ResultBuf.node ()->set_qzero (qzero);
91
- ResultBuf.node ()->set_strides (make_channels_last_strides (dims));
78
+ ResultBuf.node ()->set_strides (make_contiguous_strides (dims));
92
79
return ResultBuf;
93
80
}
94
81
95
- BufHandle makeQBufHandleNLC (
82
+ BufHandle makeQBufHandleContiguous (
96
83
const std::string& name,
97
84
const std::vector<ExprHandle>& dims,
98
85
Dtype dtype,
99
86
const double qscale,
100
87
const int64_t qzero) {
101
- return makeQBufHandleNLC (
88
+ return makeQBufHandleContiguous (
102
89
name,
103
90
dims,
104
91
dtype,
105
92
DoubleImm::make (qscale).node (),
106
93
LongImm::make (qzero).node ());
107
94
}
108
95
109
- BufHandle makeQBufHandleNCHW (
110
- const std::string& name,
111
- const std::vector<ExprHandle>& dims,
112
- Dtype dtype,
113
- const double qscale,
114
- const int64_t qzero) {
115
- return makeQBufHandleNCHW (
116
- name,
117
- dims,
118
- dtype,
119
- DoubleImm::make (qscale).node (),
120
- LongImm::make (qzero).node ());
121
- }
122
-
123
- bool isNHWC (const BufHandle& buf) {
96
+ bool isChannelsLast (const BufHandle& buf) {
124
97
const auto & strides = buf.node ()->strides ();
125
98
const auto & dims = buf.node ()->dims ();
126
- if (strides.size () != 4 ) {
99
+ const auto rank = dims.size ();
100
+ if (rank < 3 ) {
127
101
return false ;
128
102
}
129
- auto dims1 = to<LongImm>(IRSimplifier::simplify (dims[1 ]))->value ();
130
- auto strides1 = to<LongImm>(IRSimplifier::simplify (strides[1 ]))->value ();
131
- auto strides3 = to<LongImm>(IRSimplifier::simplify (strides[3 ]))->value ();
103
+ auto dimsC = to<LongImm>(IRSimplifier::simplify (dims[1 ]))->value ();
104
+ auto stridesC = to<LongImm>(IRSimplifier::simplify (strides[1 ]))->value ();
105
+ auto stridesLast =
106
+ to<LongImm>(IRSimplifier::simplify (strides[rank - 1 ]))->value ();
132
107
133
- return ((strides3 == dims1) && (strides1 == 1 ));
134
- }
135
-
136
- bool isNLC (const BufHandle& buf) {
137
- const auto & strides = buf.node ()->strides ();
138
- const auto & dims = buf.node ()->dims ();
139
- if (strides.size () != 3 ) {
140
- return false ;
141
- }
142
- auto dims1 = to<LongImm>(IRSimplifier::simplify (dims[1 ]))->value ();
143
- auto strides1 = to<LongImm>(IRSimplifier::simplify (strides[1 ]))->value ();
144
- auto strides3 = to<LongImm>(IRSimplifier::simplify (strides[3 ]))->value ();
145
-
146
- return ((strides3 == dims1) && (strides1 == 1 ));
108
+ return ((stridesLast == dimsC) && (stridesC == 1 ));
147
109
}
148
110
149
111
ExprHandle quant (
@@ -273,15 +235,11 @@ Tensor computeQuantizePerTensorExternalCall(
273
235
throw malformed_input (" Expected quantized dtype" );
274
236
}(qdtype);
275
237
auto ResultBuf = [&]() {
276
- if (isNHWC (x)) {
277
- return makeQBufHandleNHWC (
278
- " quantize_per_tensor" , outputShape, dtype, qscale, qzero);
279
- }
280
- if (isNLC (x)) {
281
- return makeQBufHandleNLC (
238
+ if (isChannelsLast (x)) {
239
+ return makeQBufHandleChannelsLast (
282
240
" quantize_per_tensor" , outputShape, dtype, qscale, qzero);
283
241
}
284
- return makeQBufHandleNCHW (
242
+ return makeQBufHandleContiguous (
285
243
" quantize_per_tensor" , outputShape, dtype, qscale, qzero);
286
244
}();
287
245
StmtPtr s = ExternalCall::make (
@@ -376,7 +334,7 @@ Tensor computeQuantizedConv1d(
376
334
const auto out_qzero = c10::get<int64_t >(inputs[3 ]);
377
335
// Change to dtype based on outputType when dtype propagation implemented
378
336
const auto out_qdtype = immQDType (qx);
379
- auto ResultBuf = makeQBufHandleNLC (
337
+ auto ResultBuf = makeQBufHandleChannelsLast (
380
338
" quantized_conv1d" ,
381
339
outputShape,
382
340
Dtype (out_qdtype),
@@ -407,7 +365,7 @@ Tensor computeQuantizedConv2d(
407
365
const auto out_qzero = c10::get<int64_t >(inputs[3 ]);
408
366
// Change to dtype based on outputType when dtype propagation implemented
409
367
const auto out_qdtype = immQDType (qx);
410
- auto ResultBuf = makeQBufHandleNHWC (
368
+ auto ResultBuf = makeQBufHandleChannelsLast (
411
369
" quantized_conv2d" ,
412
370
outputShape,
413
371
Dtype (out_qdtype),
@@ -438,7 +396,7 @@ Tensor computeQuantizedConv2dRelu(
438
396
const auto out_qzero = c10::get<int64_t >(inputs[3 ]);
439
397
// Change to dtype based on outputType when dtype propagation implemented
440
398
const auto out_qdtype = immQDType (qx);
441
- auto ResultBuf = makeQBufHandleNHWC (
399
+ auto ResultBuf = makeQBufHandleChannelsLast (
442
400
" quantized_conv2d_relu" ,
443
401
outputShape,
444
402
Dtype (out_qdtype),
@@ -469,7 +427,7 @@ Tensor computeQuantizedLinear(
469
427
const auto out_qzero = c10::get<int64_t >(inputs[3 ]);
470
428
// Change to dtype based on outputType when dtype propagation implemented
471
429
const auto out_qdtype = immQDType (qx);
472
- auto ResultBuf = makeQBufHandleNCHW (
430
+ auto ResultBuf = makeQBufHandleContiguous (
473
431
" quantized_linear" ,
474
432
outputShape,
475
433
Dtype (out_qdtype),
@@ -500,7 +458,7 @@ Tensor computeQuantizedLinearRelu(
500
458
const auto out_qzero = c10::get<int64_t >(inputs[3 ]);
501
459
// Change to dtype based on outputType when dtype propagation implemented
502
460
const auto out_qdtype = immQDType (qx);
503
- auto ResultBuf = makeQBufHandleNCHW (
461
+ auto ResultBuf = makeQBufHandleContiguous (
504
462
" quantized_linear_relu" ,
505
463
outputShape,
506
464
Dtype (out_qdtype),
@@ -531,16 +489,16 @@ Tensor computeQuantizedAddExternalCall(
531
489
const auto out_qzero = c10::get<int64_t >(inputs[3 ]);
532
490
// Change to dtype based on outputType when dtype propagation implemented
533
491
const auto out_qdtype = immQDType (qa);
534
- const bool isQAChannelsLast = isNHWC (qa);
535
- const bool isQBChannelsLast = isNHWC (qb);
492
+ const bool isQAChannelsLast = isChannelsLast (qa);
493
+ const bool isQBChannelsLast = isChannelsLast (qb);
536
494
auto ResultBuf = (isQAChannelsLast || isQBChannelsLast)
537
- ? makeQBufHandleNHWC (
495
+ ? makeQBufHandleChannelsLast (
538
496
" quantized_add" ,
539
497
outputShape,
540
498
Dtype (out_qdtype),
541
499
out_qscale,
542
500
out_qzero)
543
- : makeQBufHandleNCHW (
501
+ : makeQBufHandleContiguous (
544
502
" quantized_add" ,
545
503
outputShape,
546
504
Dtype (out_qdtype),
@@ -574,7 +532,7 @@ Tensor computeQuantizedMul(
574
532
const auto out_qzero = c10::get<int64_t >(inputs[3 ]);
575
533
// Change to dtype based on outputType when dtype propagation implemented
576
534
const auto out_qdtype = immQDType (qa);
577
- auto ResultBuf = makeQBufHandleNCHW (
535
+ auto ResultBuf = makeQBufHandleContiguous (
578
536
" quantized_mul" , outputShape, Dtype (out_qdtype), out_qscale, out_qzero);
579
537
StmtPtr s = ExternalCall::make (
580
538
ResultBuf,
@@ -603,7 +561,7 @@ Tensor computeQuantizedMulScalar(
603
561
// Change to dtype based on outputType when dtype propagation implemented
604
562
const auto out_qdtype = immQDType (qa);
605
563
double scale1 = immQScale (qa);
606
- auto ResultBuf = makeQBufHandleNCHW (
564
+ auto ResultBuf = makeQBufHandleContiguous (
607
565
" quantized_mul_scalar" ,
608
566
outputShape,
609
567
Dtype (out_qdtype),
@@ -626,14 +584,14 @@ Tensor computeQuantizedRelu(
626
584
at::Device device) {
627
585
const BufHandle& qa = c10::get<BufHandle>(inputs[0 ]);
628
586
const auto out_qdtype = immQDType (qa);
629
- const bool isQAChannelsLast = isNHWC (qa);
630
- auto ResultBuf = isQAChannelsLast ? makeQBufHandleNHWC (
587
+ const bool isQAChannelsLast = isChannelsLast (qa);
588
+ auto ResultBuf = isQAChannelsLast ? makeQBufHandleChannelsLast (
631
589
" quantized_relu" ,
632
590
outputShape,
633
591
Dtype (out_qdtype),
634
592
immQScale (qa),
635
593
immQZero (qa))
636
- : makeQBufHandleNCHW (
594
+ : makeQBufHandleContiguous (
637
595
" quantized_relu" ,
638
596
outputShape,
639
597
Dtype (out_qdtype),
@@ -674,7 +632,7 @@ Tensor computeQuantizedCat(
674
632
extra_args.emplace_back (argDim);
675
633
extra_args.emplace_back (out_qscale);
676
634
extra_args.emplace_back (out_qzero);
677
- auto ResultBuf = makeQBufHandleNCHW (
635
+ auto ResultBuf = makeQBufHandleContiguous (
678
636
" quantized_cat" ,
679
637
outputShape,
680
638
Dtype (immQDType (inputList[0 ])),
@@ -793,7 +751,7 @@ Tensor computeUpsampleNearest2dExternalCall(
793
751
794
752
BufHandle ResultBuf = [&]() {
795
753
if (isQuantized (x)) {
796
- return makeQBufHandleNHWC (
754
+ return makeQBufHandleChannelsLast (
797
755
" upsample_nearest2d" ,
798
756
outputShape,
799
757
Dtype (immQDType (x)),
@@ -829,7 +787,7 @@ Tensor computeQuantizedSigmoidExternalCall(
829
787
const double out_qscale = 1 .0f / 256 .0f ;
830
788
const int64_t out_qzero = (out_qdtype == ScalarType::QInt8) ? -128 : 0 ;
831
789
832
- auto ResultBuf = makeQBufHandleNHWC (
790
+ auto ResultBuf = makeQBufHandleChannelsLast (
833
791
" quantized_sigmoid" ,
834
792
outputShape,
835
793
Dtype (out_qdtype),
0 commit comments