@@ -159,7 +159,7 @@ Node* createQuantNode(Value* v, Graph* g) {
159159// Create Dequant node
160160Node* createDeQuantNode (Value* v, Graph* g) {
161161 Node* dequant =
162- g->create (at::Symbol::fromQualString ("aten::_dequantize_per_tensor"));
162+ g->create (at::Symbol::fromQualString ("aten::dequantize"));
163163 TORCH_INTERNAL_ASSERT (dequant != nullptr , "Failed to create dequant node");
164164 dequant->output ()->setDebugName (v->debugName () + ".dequant");
165165 return dequant;
@@ -401,7 +401,6 @@ Node* insertQuantDeQuantCall(
401401 bool insert_after = true ) {
402402 Graph* g = v->node ()->owningGraph ();
403403 Node* quant = createQuantNode (v, g);
404- Node* intrepr = createIntReprNode (v, g);
405404 Node* dequant = createDeQuantNode (v, g);
406405 Node* insert_point = insert_after ? v->node () : *g->nodes ().begin ();
407406 WithCurrentScope scope_guard (
@@ -417,30 +416,24 @@ Node* insertQuantDeQuantCall(
417416 Value* scale_val = g->insertConstant (scale);
418417 Value* zero_point_val = g->insertConstant (zero_point);
419418
420- // Insert quant/int_repr/dequant nodes
419+ // Insert quant/dequant nodes
421420 if (insert_after) {
422421 quant->insertAfter (insert_point);
423422 } else {
424423 quant->insertBefore (insert_point);
425424 }
426425
427- intrepr->insertAfter (quant);
428- dequant->insertAfter (intrepr);
426+ dequant->insertAfter (quant);
429427
430- // Attach inputs to quantization pattern nodes
428+ // Attach inputs to quantize node
431429 quant->addInput (v);
432- intrepr->addInput (quant->output ());
433- dequant->addInput (intrepr->output ());
434-
435430 quant->addInput (scale_val);
436431 quant->addInput (zero_point_val);
437- dequant->addInput (scale_val);
438- dequant->addInput (zero_point_val);
439-
440432 Value* scalar_type_val = insertScalarType (quant, scalar_type.toScalarType ());
441433 TORCH_INTERNAL_ASSERT (scalar_type_val != nullptr );
442434 quant->addInput (scalar_type_val);
443- dequant->addInput (scalar_type_val);
435+
436+ dequant->addInput (quant->output ());
444437 return dequant;
445438}
446439
@@ -665,68 +658,60 @@ void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
665658void QuantFusion (std::shared_ptr<Graph>& graph) {
666659 const std::string quantized_linear_with_bias =
667660 R"(
668- graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
661+ graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %4):
669662 %w_quant_t = aten::t(%w_quant)
670663 %packed_params = quantized::linear_prepack(%w_quant_t, %b)
671664 %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
672665 return (%r))" ;
673666 const std::unordered_map<std::string, std::string> pattern_and_replacements =
674667 {// quantized::conv2d
675668 {R"(
676- graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
677- %a_intrepr = aten::int_repr(%a_quant)
678- %a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
679- %w_intrepr = aten::int_repr(%w_quant)
680- %w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
669+ graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
670+ %a_dequant = aten::dequantize(%a_quant)
671+ %w_dequant = aten::dequantize(%w_quant)
681672 %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
682673 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
683674 return (%r_quant))" ,
684675 R"(
685- graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
676+ graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
686677 %packed_params = quantized::conv_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
687- %r = quantized::conv2d(%a_quant, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
678+ %r_quant = quantized::conv2d(%a_quant, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
688679 %0 : int = prim::Constant[value=0]()
689680 %1 : int = prim::Constant[value=1]()
690681 %2 : int = prim::Constant[value=2]()
691682 %3 : int = prim::Constant[value=3]()
692683 %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
693- %r_perm = aten::permute(%r, %out_param)
684+ %r_perm = aten::permute(%r_quant, %out_param)
694685 return (%r_perm))" },
695686 // addmm -> quantized::linear
696687 {R"(
697- graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
698- %a_intrepr = aten::int_repr(%a_quant)
699- %a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
700- %w_intrepr = aten::int_repr(%w_quant)
701- %w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
688+ graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %4):
689+ %a_dequant = aten::dequantize(%a_quant)
690+ %w_dequant = aten::dequantize(%w_quant)
702691 %r = aten::addmm(%b, %a_dequant, %w_dequant, %4, %4)
703692 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
704693 return (%r_quant))" ,
705694 quantized_linear_with_bias},
706695 // matmul(with bias) -> quantized::linear
707696 {R"(
708- graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4):
709- %a_intrepr = aten::int_repr(%a_quant)
710- %a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
711- %w_intrepr = aten::int_repr(%w_quant)
712- %w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
697+ graph(%a_quant, %w_quant, %b, %r_scale, %r_zero_point, %r_dtype, %4):
698+ %a_dequant = aten::dequantize(%a_quant)
699+ %w_dequant = aten::dequantize(%w_quant)
713700 %output = aten::matmul(%a_dequant, %w_dequant)
714701 %r = aten::add_(%output, %b, %4)
715702 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
716703 return (%r_quant))" ,
717704 quantized_linear_with_bias},
718705 // matmul(without bias) -> quantized::linear
719706 {R"(
720- graph(%a_quant, %w_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype):
721- %a_intrepr = aten::int_repr(%a_quant)
722- %a_dequant = aten::_dequantize_per_tensor(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
723- %w_intrepr = aten::int_repr(%w_quant)
724- %w_dequant = aten::_dequantize_per_tensor(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
707+ graph(%a_quant, %w_quant, %r_scale, %r_zero_point, %r_dtype):
708+ %a_dequant = aten::dequantize(%a_quant)
709+ %w_dequant = aten::dequantize(%w_quant)
725710 %r = aten::matmul(%a_dequant, %w_dequant)
726711 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
727712 return (%r_quant))" ,
728713 R"(
729- graph(%a_quant, %w_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype):
714+ graph(%a_quant, %w_quant, %r_scale, %r_zero_point, %r_dtype):
730715 %w_quant_t = aten::t(%w_quant)
731716 %bias: Tensor? = prim::Constant()
732717 %packed_params = quantized::linear_prepack(%w_quant_t, %bias)
0 commit comments