Skip to content

Commit a8add2b

Browse files
SherlockNoMadpytorchmergebot
authored andcommitted
Support matching Args for SubgraphMatcher (#85456)
Subgraph matcher now handles the matching of non-Node arguments. Here are the 4 cases - pn is a Node, gn is a Node: this goes through the regular _match_nodes() function - pn is a Node, gn is not a Node: this is a match only if pn is a placeholder op - pn is not a Node, gn is a Node: this is a no-match case - pn is not a Node, gn is not a Node: this goes through the argument comparison. With this change ``` def target(x): return foo(x, 3) def pattern(x, y): return foo(x, y) ``` is a match Pull Request resolved: #85456 Approved by: https://github.com/jerryzh168
1 parent db40fbd commit a8add2b

File tree

2 files changed

+92
-4
lines changed

2 files changed

+92
-4
lines changed

test/test_fx_passes.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,49 @@ def pattern(a, b, c):
621621
TestCase(False, True, 0),
622622
]
623623

624+
class QuantizationFp8Pattern:
    """Subgraph-matcher test model whose pattern placeholders must match literals.

    ``forward`` calls the custom ``fp8_quantization`` ops with the literal
    dtype ``0``, while ``pattern`` receives the dtype and scale values as
    parameters, so a match requires pattern placeholders to bind to non-Node
    arguments in the target graph.
    """

    @classmethod
    def setup(cls):
        """Define the ``fp8_quantization`` operator schemas used by this model."""
        cls.quantization = torch.library.Library("fp8_quantization", "DEF")
        cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
        cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")

    @classmethod
    def tearDown(cls):
        """Drop the class-level reference to the library created by ``setup``."""
        del cls.quantization

    @staticmethod
    def forward(self, arg0_1, arg1_1):
        """Target graph: quantize/dequantize both inputs, add them, then quantize/dequantize the sum.

        NOTE: this is a ``staticmethod``, so ``self`` is an ordinary first
        parameter here rather than a bound instance.
        """
        qt = torch.ops.fp8_quantization
        _scale_0 = self._scale_0
        quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
        dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
        _scale_1 = self._scale_0
        quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
        dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
        add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
        _scale_2 = self._scale_0
        quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
        dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
        return dequantize_per_tensor_affine_fp8_2

    @staticmethod
    def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
        """Pattern graph: dequantize ``a`` and ``b``, add them, and quantize the result.

        The dtype and scale parameters become placeholders in the traced
        pattern; ``a_dtype`` is reused as the dtype of the output quantize op.
        """
        qt = torch.ops.fp8_quantization
        a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
        b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
        output = torch.ops.aten.add.Tensor(a, b)

        output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
        return output

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]
666+
624667
@instantiate_parametrized_tests
625668
class TestFXMatcherUtils(JitTestCase):
626669

@@ -639,8 +682,14 @@ class TestFXMatcherUtils(JitTestCase):
639682
MultipleOutputsIdenticalAnchor,
640683
MultipleOutputsHorizontalPattern,
641684
MultiOutputWithWithInvalidMatches,
685+
QuantizationFp8Pattern,
642686
])
643687
def test_subgraph_matcher(self, test_model):
688+
689+
setup = getattr(test_model, "setup", None)
690+
if callable(setup):
691+
setup()
692+
644693
traced = symbolic_trace(test_model.forward)
645694
pattern_traced = symbolic_trace(test_model.pattern)
646695

@@ -662,6 +711,10 @@ def test_subgraph_matcher(self, test_model):
662711
continue
663712
assert node in match.nodes_map
664713

714+
# Run the model's optional cleanup hook. The guard must test ``tearDown``
# itself: a model may define ``setup`` without ``tearDown`` (calling None
# would raise) or ``tearDown`` without ``setup`` (cleanup would be skipped).
tearDown = getattr(test_model, "tearDown", None)
if callable(tearDown):
    tearDown()
717+
665718

666719
if __name__ == "__main__":
667720
run_tests()

torch/fx/passes/utils/matcher_utils.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from torch.fx.graph import Graph
55
from torch.fx.node import Node
66
from torch.fx._compatibility import compatibility
7-
from typing import Dict, List, Set
7+
import torch.utils._pytree as pytree
8+
from typing import Dict, List, Set, Any
89

910
__all__ = ['SubgraphMatcher', 'InternalMatch']
1011

@@ -124,7 +125,27 @@ def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[Inte
124125
nodes_matched.add(gn)
125126
return non_overlapping_matches
126127

128+
def _match_args(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
    """Match a pattern argument against a graph argument when at most one is a Node.

    Args:
        pn: Argument taken from the pattern graph (a Node or a plain value).
        gn: Argument taken from the target graph (a Node or a plain value).
        match: Match under construction; ``match.nodes_map`` is updated when
            a pattern placeholder is bound to a non-Node graph argument.

    Returns:
        Whether ``pn`` matches ``gn``:
        - pattern Node vs. non-Node value: matches only if ``pn`` is a
          placeholder, and any earlier binding of ``pn`` must equal ``gn``;
        - non-Node pattern value vs. graph Node: never matches;
        - two non-Node values: must have the same exact type and be equal.

    Raises:
        AssertionError: If both ``pn`` and ``gn`` are Nodes.
    """
    pn_is_node = isinstance(pn, Node)
    gn_is_node = isinstance(gn, Node)
    assert not (pn_is_node and gn_is_node), "pn and gn cannot both be Node"

    if pn_is_node and not gn_is_node:
        # Only a placeholder can stand in for a literal argument.
        if pn.op != "placeholder":
            return False

        # Check if we've already matched this placeholder in the current
        # traversal; if so, it must be bound to the same value again.
        if pn in match.nodes_map:
            return match.nodes_map[pn] == gn

        match.nodes_map[pn] = gn
        return True

    if gn_is_node and not pn_is_node:
        # A literal in the pattern cannot match a computed value in the graph.
        return False

    # Exact type identity keeps e.g. 1, 1.0 and True from matching each other.
    return type(gn) is type(pn) and gn == pn
146+
127147
def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
148+
assert isinstance(pn, Node) and isinstance(gn, Node), "pn and gn must be Node"
128149

129150
# Check if we've already matched these nodes in the current
130151
# traversal
@@ -146,9 +167,23 @@ def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
146167

147168
# Recursively traverse upwards to check if `pn` is a true
148169
# match for `gn`
149-
match_found = (len(pn.all_input_nodes) == len(gn.all_input_nodes) and
150-
all(self._match_nodes(pn_, gn_, match) for pn_, gn_
151-
in zip(pn.all_input_nodes, gn.all_input_nodes)))
170+
match_found = True
171+
172+
pn_flatten_args, _ = pytree.tree_flatten(pn.args)
173+
gn_flatten_args, _ = pytree.tree_flatten(gn.args)
174+
175+
if len(pn_flatten_args) == len(gn_flatten_args):
176+
for pn_, gn_ in zip(pn_flatten_args, gn_flatten_args):
177+
if isinstance(gn_, Node) and isinstance(pn_, Node):
178+
matched = self._match_nodes(pn_, gn_, match)
179+
else:
180+
matched = self._match_args(pn_, gn_, match)
181+
182+
if not matched:
183+
match_found = False
184+
break
185+
else:
186+
match_found = False
152187

153188
if not match_found:
154189
match.nodes_map.pop(pn)

0 commit comments

Comments
 (0)