Skip to content

Commit 76e2ffc

Browse files
Mikhail Zolotukhinfacebook-github-bot
authored andcommitted
Remove 'recurse' parameter from Inline. (pytorch#26487)
Summary: Pull Request resolved: pytorch#26487 The way it is implemented currently is bad because while we're inlining to a graph G, we are also mutating all the graphs that are being inlined. The problem is that the graphs we're inlining are usually the original graphs of functions, so we're silently changing them behind the scenes, and we don't have a way to recover 'unoptimized' graphs afterwards. Test Plan: Imported from OSS Differential Revision: D17485748 Pulled By: ZolotukhinM fbshipit-source-id: 6094ef56077240e9379d4c53680867df1b6e79ef
1 parent a65db65 commit 76e2ffc

File tree

4 files changed

+9
-28
lines changed

4 files changed

+9
-28
lines changed

test/cpp/jit/test_inliner.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,14 @@ struct InlinerGuard {
3939

4040
void testInliner() {
4141
{
42-
// Test that the recursive inlining works
4342
// disable automatic inlining so we can test it manually
4443
InlinerGuard guard(/*shouldInline=*/false);
4544

4645
CompilationUnit cu(testSource);
4746
auto& fn = cu.get_function("foo3");
4847

4948
auto g = fn.graph();
50-
Inline(*g, /*recurse=*/true);
51-
FileCheck().check_count("prim::Print", 3)->run(*g);
52-
}
53-
{
54-
// disable automatic inlining so we can test it manually
55-
InlinerGuard guard(/*shouldInline=*/false);
56-
57-
CompilationUnit cu(testSource);
58-
auto& fn = cu.get_function("foo3");
59-
60-
auto g = fn.graph();
61-
Inline(*g, /*recurse=*/false);
49+
Inline(*g);
6250
FileCheck()
6351
.check("three")
6452
->check("two")

torch/csrc/jit/passes/inliner.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace prim {
99
using namespace ::c10::prim;
1010
}
1111

12-
void inlineCalls(Block* block, bool recurse) {
12+
void inlineCalls(Block* block) {
1313
for (auto it = block->nodes().begin(), end = block->nodes().end();
1414
it != end;) {
1515
Node* cur = *it++;
@@ -20,32 +20,26 @@ void inlineCalls(Block* block, bool recurse) {
2020
auto fun_type =
2121
function_constant->output()->type()->expect<FunctionType>();
2222
cur->removeInput(0);
23-
if (recurse) {
24-
Inline(*fun_type->function()->graph(), recurse);
25-
}
2623
inlineCallTo(cur, *fun_type->function()->graph());
2724
} break;
2825
case prim::CallMethod: {
2926
const std::string& name = cur->s(attr::name);
3027
if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
3128
auto function = class_type->getMethod(name);
32-
if (recurse) {
33-
Inline(*function->graph(), recurse);
34-
}
3529
inlineCallTo(cur, *function->graph());
3630
}
3731
} break;
3832
default: {
3933
for (auto b : cur->blocks()) {
40-
inlineCalls(b, recurse);
34+
inlineCalls(b);
4135
}
4236
} break;
4337
}
4438
}
4539
}
4640

47-
void Inline(Graph& graph, bool recurse) {
48-
inlineCalls(graph.block(), recurse);
41+
void Inline(Graph& graph) {
42+
inlineCalls(graph.block());
4943
}
5044

5145
} // namespace jit

torch/csrc/jit/passes/inliner.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
namespace torch {
66
namespace jit {
77

8-
// Inline function and method calls. If `recurse` is true, inline all nested
9-
// calls as well, resulting in a completely flattened graph.
10-
TORCH_API void Inline(Graph& graph, bool recurse = false);
8+
// Inline function and method calls.
9+
TORCH_API void Inline(Graph& graph);
1110

1211
} // namespace jit
1312
} // namespace torch

torch/onnx/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def _split_tensor_list_constants(g, block):
8686

8787

8888
def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False, fixed_batch_size=False):
89-
# Inline everyting (recursively)
90-
torch._C._jit_pass_inline(graph, True)
89+
# Inline everyting
90+
torch._C._jit_pass_inline(graph)
9191

9292
# Remove fork/wait nodes
9393
torch._C._jit_pass_inline_fork_wait(graph)

0 commit comments

Comments
 (0)