
Commit 52b69fb

jerryzh168 authored and facebook-github-bot committed
Remove _dequantize_per_channel in the pattern (pytorch#26680)
Summary: Pull Request resolved: pytorch#26680

This was introduced under the assumption that we would have both a qconv_per_tensor_affine and a qconv_per_channel_affine op, but it turns out we don't have these, so these functions are removed.

Test Plan: python test/test_jit.py 'TestJit.test_quant_fusion'

Imported from OSS

Differential Revision: D17542607

fbshipit-source-id: b90ce5738170f0922bdc2eb1c4dbecd930f68a48
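For context, the old three-op pattern existed only because the dequantize step was spelled out per quantization scheme; a quantized tensor already carries its scale, zero point, and dtype as metadata, which is why a bare aten::dequantize suffices. A minimal eager-mode sketch of that metadata round trip (illustrative values, not taken from this PR):

```python
import torch

# A quantized tensor carries its quantization params with it, so
# dequantize() takes no arguments. This is what lets the JIT pattern
# shrink from quantize -> int_repr -> _dequantize_per_tensor down to
# quantize -> dequantize. (Illustrative values, not from this PR.)
x = torch.rand(2, 3)
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
print(xq.q_scale(), xq.q_zero_point())  # metadata travels with the tensor
print(xq.int_repr())                    # the raw uint8 values, if needed
x_fp = xq.dequantize()                  # back to float; no scale/zp passed
```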
1 parent cf272d4 commit 52b69fb

2 files changed: +52, -98 lines

test/test_jit.py

Lines changed: 29 additions & 60 deletions
@@ -1169,12 +1169,10 @@ def get_forward(m):
                    .check("return") \
                    .run(str(get_forward(m).graph))
         FileCheck().check("aten::quantize_per_tensor") \
-                   .check_next("aten::int_repr") \
-                   .check_next("aten::_dequantize_per_tensor") \
+                   .check_next("aten::dequantize") \
                    .check("aten::conv2d") \
                    .check("aten::quantize_per_tensor") \
-                   .check_next("aten::int_repr") \
-                   .check_next("aten::_dequantize_per_tensor") \
+                   .check_next("aten::dequantize") \
                    .check("return") \
                    .run(str(m._c._get_module('conv')._get_method('conv2d_forward').graph))
 
@@ -1185,68 +1183,48 @@ def test_quant_fusion(self):
 graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
       %r_scale, %r_zero_point, %r_dtype, %c, %d, %e, %f):
 %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
-# CHECK-NOT: aten::int_repr
-%a_intrepr = aten::int_repr(%a_quant)
-# CHECK-NOT: aten::_dequantize_per_tensor
-%a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+# CHECK-NOT: aten::dequantize
+%a_dequant = aten::dequantize(%a_quant)
 %w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
-# CHECK-NOT: aten::int_repr
-%w_intrepr = aten::int_repr(%w_quant)
-# CHECK-NOT: aten::_dequantize_per_tensor
-%w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+# CHECK-NOT: aten::dequantize
+%w_dequant = aten::dequantize(%w_quant)
 # CHECK: quantized::conv_prepack
 # CHECK: quantized::conv2d
 # CHECK-NOT: aten::conv2d
 %r = aten::conv2d(%a_dequant, %w_dequant, %b, %c, %d, %e, %f)
 # CHECK-NOT: aten::quantize_per_tensor
 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
-# CHECK: aten::int_repr
-%r_intrepr = aten::int_repr(%r_quant)
-# CHECK: aten::_dequantize_per_tensor
-%r_dequant = aten::_dequantize_per_tensor(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
+# CHECK: aten::dequantize
+%r_dequant = aten::dequantize(%r_quant)
 return (%r_dequant)""",
 # addmm -> quantized::linear
 """
-graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
-      %r_scale, %r_zero_point, %r_dtype, %4):
+graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
 %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
-# CHECK-NOT: aten::int_repr
-%a_intrepr = aten::int_repr(%a_quant)
-# CHECK-NOT: aten::_dequantize_per_tensor
-%a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+# CHECK-NOT: aten::dequantize
+%a_dequant = aten::dequantize(%a_quant)
 %w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
-# CHECK-NOT: aten::int_repr
-%w_intrepr = aten::int_repr(%w_quant)
-# CHECK-NOT: aten::_dequantize_per_tensor
-%w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+# CHECK-NOT: aten::dequantize
+%w_dequant = aten::dequantize(%w_quant)
 # CHECK: aten::t
 # CHECK: quantized::linear_prepack
 # CHECK: quantized::linear
 # CHECK-NOT: aten::addmm
 %r = aten::addmm(%b, %a_dequant, %w_dequant, %4, %4)
 # CHECK-NOT: aten::quantize_per_tensor
 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
-# CHECK: aten::int_repr
-%r_intrepr = aten::int_repr(%r_quant)
-# CHECK: aten::_dequantize_per_tensor
-%r_dequant = aten::_dequantize_per_tensor(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
+# CHECK: aten::dequantize
+%r_dequant = aten::dequantize(%r_quant)
 return (%r_dequant)""",
 # matmul(with bias) -> quantized::linear
 """
-graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
-      %r_scale, %r_zero_point, %r_dtype, %4):
+graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
 %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
-# CHECK-NOT: aten::int_repr
-%a_intrepr = aten::int_repr(%a_quant)
-# CHECK-NOT: aten::_dequantize_per_tensor
-%a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+# CHECK-NOT: aten::dequantize
+%a_dequant = aten::dequantize(%a_quant)
 %w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
-# CHECK-NOT: aten::int_repr
-%w_intrepr = aten::int_repr(%w_quant)
-# CHECK-NOT: aten::_dequantize_per_tensor
-%w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
-# CHECK-NOT: aten::int_repr
-# CHECK-NOT: aten::_dequantize_per_tensor
+# CHECK-NOT: aten::dequantize
+%w_dequant = aten::dequantize(%w_quant)
 # CHECK: aten::t
 # CHECK: quantized::linear_prepack
 # CHECK: quantized::linear
@@ -1255,25 +1233,18 @@ def test_quant_fusion(self):
 %r = aten::add_(%output, %b, %4)
 # CHECK-NOT: aten::quantize_per_tensor
 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
-# CHECK: aten::int_repr
-%r_intrepr = aten::int_repr(%r_quant)
-# CHECK: aten::_dequantize_per_tensor
-%r_dequant = aten::_dequantize_per_tensor(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
+# CHECK: aten::dequantize
+%r_dequant = aten::dequantize(%r_quant)
 return (%r_dequant)""",
 # matmul(without bias) -> quantized::linear
 """
-graph(%a, %w, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
-      %r_scale, %r_zero_point, %r_dtype):
+graph(%a, %w, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype):
 %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
-# CHECK-NOT: aten::int_repr
-%a_intrepr = aten::int_repr(%a_quant)
-# CHECK-NOT: aten::_dequantize_per_tensor
-%a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+# CHECK-NOT: aten::dequantize
+%a_dequant = aten::dequantize(%a_quant)
 %w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
-# CHECK-NOT: aten::int_repr
-%w_intrepr = aten::int_repr(%w_quant)
-# CHECK-NOT: aten::_dequantize_per_tensor
-%w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+# CHECK-NOT: aten::dequantize
+%w_dequant = aten::dequantize(%w_quant)
 # CHECK: aten::t
 # CHECK: prim::Constant()
 # CHECK: quantized::linear_prepack
@@ -1282,10 +1253,8 @@ def test_quant_fusion(self):
 %r = aten::matmul(%a_dequant, %w_dequant)
 # CHECK-NOT: aten::quantize_per_tensor
 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
-# CHECK: aten::int_repr
-%r_intrepr = aten::int_repr(%r_quant)
-# CHECK: aten::_dequantize_per_tensor
-%r_dequant = aten::_dequantize_per_tensor(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
+# CHECK: aten::dequantize
+%r_dequant = aten::dequantize(%r_quant)
 return (%r_dequant)"""
 ]
 for input_str in input_strs:
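For reference, the updated FileCheck assertions can be exercised on a small standalone function; a sketch, assuming torch.quantize_per_tensor and Tensor.dequantize are scriptable as in this era of PyTorch (quant_roundtrip is a hypothetical example, not code from this PR):

```python
import torch
from torch.testing import FileCheck

# A small scripted function whose graph contains the quantize/dequantize
# pair the test above checks for (hypothetical example).
@torch.jit.script
def quant_roundtrip(x, scale: float, zero_point: int):
    xq = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
    return xq.dequantize()

# check() scans forward for each string in order, so this passes only if
# aten::quantize_per_tensor is followed by aten::dequantize in the graph.
FileCheck().check("aten::quantize_per_tensor") \
           .check("aten::dequantize") \
           .run(str(quant_roundtrip.graph))
```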

torch/csrc/jit/passes/quantization.cpp

Lines changed: 23 additions & 38 deletions
@@ -159,7 +159,7 @@ Node* createQuantNode(Value* v, Graph* g) {
 // Create Dequant node
 Node* createDeQuantNode(Value* v, Graph* g) {
   Node* dequant =
-      g->create(at::Symbol::fromQualString("aten::_dequantize_per_tensor"));
+      g->create(at::Symbol::fromQualString("aten::dequantize"));
   TORCH_INTERNAL_ASSERT(dequant != nullptr, "Failed to create dequant node");
   dequant->output()->setDebugName(v->debugName() + ".dequant");
   return dequant;
@@ -401,7 +401,6 @@ Node* insertQuantDeQuantCall(
     bool insert_after = true) {
   Graph* g = v->node()->owningGraph();
   Node* quant = createQuantNode(v, g);
-  Node* intrepr = createIntReprNode(v, g);
   Node* dequant = createDeQuantNode(v, g);
   Node* insert_point = insert_after ? v->node() : *g->nodes().begin();
   WithCurrentScope scope_guard(
@@ -417,30 +416,24 @@ Node* insertQuantDeQuantCall(
   Value* scale_val = g->insertConstant(scale);
   Value* zero_point_val = g->insertConstant(zero_point);
 
-  // Insert quant/int_repr/dequant nodes
+  // Insert quant/dequant nodes
   if (insert_after) {
     quant->insertAfter(insert_point);
   } else {
     quant->insertBefore(insert_point);
   }
 
-  intrepr->insertAfter(quant);
-  dequant->insertAfter(intrepr);
+  dequant->insertAfter(quant);
 
-  // Attach inputs to quantization pattern nodes
+  // Attach inputs to quantize node
   quant->addInput(v);
-  intrepr->addInput(quant->output());
-  dequant->addInput(intrepr->output());
-
   quant->addInput(scale_val);
   quant->addInput(zero_point_val);
-  dequant->addInput(scale_val);
-  dequant->addInput(zero_point_val);
-
   Value* scalar_type_val = insertScalarType(quant, scalar_type.toScalarType());
   TORCH_INTERNAL_ASSERT(scalar_type_val != nullptr);
   quant->addInput(scalar_type_val);
-  dequant->addInput(scalar_type_val);
+
+  dequant->addInput(quant->output());
   return dequant;
 }
 
@@ -665,68 +658,60 @@ void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
 void QuantFusion(std::shared_ptr<Graph>& graph) {
   const std::string quantized_linear_with_bias =
       R"(
-graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
+graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %4):
 %w_quant_t = aten::t(%w_quant)
 %packed_params = quantized::linear_prepack(%w_quant_t, %b)
 %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
 return (%r))";
   const std::unordered_map<std::string, std::string> pattern_and_replacements =
       {// quantized::conv2d
        {R"(
-graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
-%a_intrepr = aten::int_repr(%a_quant)
-%a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
-%w_intrepr = aten::int_repr(%w_quant)
-%w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
+%a_dequant = aten::dequantize(%a_quant)
+%w_dequant = aten::dequantize(%w_quant)
 %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
 return (%r_quant))",
        R"(
-graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
+graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
 %packed_params = quantized::conv_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
-%r = quantized::conv2d(%a_quant, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
+%r_quant = quantized::conv2d(%a_quant, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
 %0 : int = prim::Constant[value=0]()
 %1 : int = prim::Constant[value=1]()
 %2 : int = prim::Constant[value=2]()
 %3 : int = prim::Constant[value=3]()
 %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
-%r_perm = aten::permute(%r, %out_param)
+%r_perm = aten::permute(%r_quant, %out_param)
 return (%r_perm))"},
        // addmm -> quantized::linear
        {R"(
-graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
-%a_intrepr = aten::int_repr(%a_quant)
-%a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
-%w_intrepr = aten::int_repr(%w_quant)
-%w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %4):
+%a_dequant = aten::dequantize(%a_quant)
+%w_dequant = aten::dequantize(%w_quant)
 %r = aten::addmm(%b, %a_dequant, %w_dequant, %4, %4)
 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
 return (%r_quant))",
        quantized_linear_with_bias},
       // matmul(with bias) -> quantized::linear
       {R"(
-graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
-%a_intrepr = aten::int_repr(%a_quant)
-%a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
-%w_intrepr = aten::int_repr(%w_quant)
-%w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %4):
+%a_dequant = aten::dequantize(%a_quant)
+%w_dequant = aten::dequantize(%w_quant)
 %output = aten::matmul(%a_dequant, %w_dequant)
 %r = aten::add_(%output, %b, %4)
 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
 return (%r_quant))",
       quantized_linear_with_bias},
      // matmul(without bias) -> quantized::linear
      {R"(
-graph(%a_quant, %w_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype):
-%a_intrepr = aten::int_repr(%a_quant)
-%a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
-%w_intrepr = aten::int_repr(%w_quant)
-%w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+graph(%a_quant, %w_quant, %r_scale, %r_zero_point, %r_dtype):
+%a_dequant = aten::dequantize(%a_quant)
+%w_dequant = aten::dequantize(%w_quant)
 %r = aten::matmul(%a_dequant, %w_dequant)
 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
 return (%r_quant))",
      R"(
-graph(%a_quant, %w_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype):
+graph(%a_quant, %w_quant, %r_scale, %r_zero_point, %r_dtype):
 %w_quant_t = aten::t(%w_quant)
 %bias: Tensor? = prim::Constant()
 %packed_params = quantized::linear_prepack(%w_quant_t, %bias)
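The pattern and replacement strings above are plain TorchScript IR. As a sketch of how such strings correspond to graphs, the same syntax can be parsed from Python via torch._C.parse_ir, an internal helper that test_jit.py of this era also uses (not a stable API; behavior may differ across releases):

```python
import torch

# Parse a tiny IR snippet written in the same syntax as the fusion
# patterns above; untyped inputs default to Tensor.
pattern = """
graph(%a_quant):
  %a_dequant = aten::dequantize(%a_quant)
  return (%a_dequant)"""

g = torch._C.parse_ir(pattern)
print(g)  # a Graph with a single aten::dequantize node, as in the new pattern
```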
