Skip to content

Commit 4360486

Browse files
jjsjann123facebook-github-bot
authored andcommitted
pass strict_fuser_check for recursive fusion (pytorch#47221)
Summary: We forgot to pass `strict_fuser_check` recursively to nested GraphFuser. Pull Request resolved: pytorch#47221 Reviewed By: zhangguanheng66 Differential Revision: D25060095 Pulled By: Krovatkin fbshipit-source-id: 31fe79c3bc080b637fce9aacc562d60708223321
1 parent ea1e78a commit 4360486

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torch/csrc/jit/passes/graph_fuser.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ struct GraphFuser {
151151
AliasDb* aliasDb,
152152
Block* block,
153153
FusionCallback callback,
154-
Symbol kind)
154+
Symbol kind,
155+
bool strict_fuser_check = false)
155156
: block_(block),
156157
aliasDb_(aliasDb),
157158
callback_(std::move(callback)),
158159
kind_(kind),
159-
strict_fuser_check_(false) {}
160+
strict_fuser_check_(strict_fuser_check) {}
160161

161162
void setInputArgLimit(size_t limit) {
162163
subgraph_arg_limit_ = limit;
@@ -1169,7 +1170,8 @@ struct GraphFuser {
11691170

11701171
for (Node* node : block_->nodes()) {
11711172
for (Block* sub_block : node->blocks()) {
1172-
GraphFuser(aliasDb_, sub_block, callback_, kind_).run();
1173+
GraphFuser(aliasDb_, sub_block, callback_, kind_, strict_fuser_check_)
1174+
.run();
11731175
}
11741176
}
11751177
}

0 commit comments

Comments
 (0)