forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfunction_substitution.cpp
197 lines (182 loc) · 7.31 KB
/
function_substitution.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/naming.h>
namespace torch {
namespace jit {
namespace {
// Variable name passed to createFullScopeName when building the scope of the
// top-level module (see ONNXGraphTopLevelScope); it is the empty string.
const std::string kTopModuleVariableName{};
// Builds a readable class name from a TorchScript qualified class name.
//
// Atoms equal to "__torch__" and atoms containing "__torch_mangle" are
// dropped; the remaining atoms are joined with '.'.
//
// @param class_name  Qualified name of the class, if any.
// @return The tidied name, or "UNKNOWN_CLASS" when `class_name` is empty.
std::string TidyClassNameFromTorchScript(
    const std::optional<c10::QualifiedName>& class_name) {
  if (!class_name.has_value()) {
    return "UNKNOWN_CLASS";
  }
  std::string tidy_name;
  for (const auto& piece : class_name->atoms()) {
    const bool internal_root = piece == "__torch__";
    const bool mangled = piece.find("__torch_mangle") != std::string::npos;
    if (internal_root || mangled) {
      // TorchScript-internal atom: not part of the user-facing class name.
      continue;
    }
    if (!tidy_name.empty()) {
      tidy_name.push_back('.');
    }
    tidy_name.append(piece);
  }
  return tidy_name;
}
std::string GetCallNodeVariableName(const Node* call_node) {
TORCH_INTERNAL_ASSERT(
call_node->kind() == prim::CallFunction ||
call_node->kind() == prim::CallMethod);
auto module_node = call_node->input(0)->node();
if (!module_node->hasAttribute(attr::name)) {
return "";
}
std::string module_name = module_node->s(attr::name);
if (module_node->inputs().empty()) {
return module_name;
}
// If module is from container, attr::name in module node only carries
// index info. Need to check parent node (container) for variable name.
auto parent_module_value = module_node->input(0);
while (parent_module_value) {
auto parent_module_type = parent_module_value->type()->cast<ClassType>();
if (parent_module_type &&
parent_module_type->name() ==
"__torch__.torch.nn.modules.container.ModuleList") {
auto parent_module_node = parent_module_value->node();
module_name = parent_module_node->s(attr::name) + "." + module_name;
parent_module_value = !parent_module_node->inputs().empty()
? parent_module_node->input(0)
: nullptr;
} else {
break;
}
}
return module_name;
}
// Chooses the scope under which a prim::CallMethod node is inlined.
//
// Calls to "forward" get a new child of the graph's current scope. The child's
// name is built by createFullScopeName from the callee's tidied class name and
// the variable name it is reached through. Calls to any other method reuse the
// graph's current scope unchanged.
ScopePtr ForwardCallScope(Graph& graph, Node* call_node) {
  TORCH_INTERNAL_ASSERT(call_node->kind() == prim::CallMethod);
  if (call_node->s(attr::name) != "forward") {
    return graph.current_scope();
  }
  const auto module_type =
      call_node->input(0)->type()->expect<c10::NamedType>();
  const std::string tidy_class =
      TidyClassNameFromTorchScript(module_type->name());
  const std::string var_name = GetCallNodeVariableName(call_node);
  return graph.current_scope()->push(Symbol::scope(
      onnx::ONNXScopeName::createFullScopeName(tidy_class, var_name)));
}
// Recursively rewrites the call nodes of `block` in place.
//
// - A prim::CallFunction whose callee's qualified name contains both
//   "torch.nn.functional" and "interpolate" is replaced by a single
//   aten::__interpolate node that takes the call's remaining inputs.
// - Any other prim::CallFunction first has its callee graph processed
//   recursively, and is then inlined.
// - A prim::CallMethod on a ClassType receiver is handled as follows. The
//   callee method's graph is processed recursively, with the scope from
//   ForwardCallScope as that graph's current scope. The call is then inlined.
//   Receivers that are not a ClassType, and methods that are not
//   GraphFunctions, are left as calls.
// - Every other node is stamped with the graph's current scope (unless that
//   scope is blank), and its sub-blocks are processed recursively.
//
// NOTE(review): callee graphs (graphFunction.graph()) are rewritten in place
// before being inlined. The substitution is therefore also visible on the
// callee function objects themselves. Confirm that this side effect is
// intended.
void functionCallSubstitution(Block* block) {
  auto graph = block->owningGraph();
  // `it` is advanced before `cur` is handled because `cur` may be destroyed
  // (replaced or inlined) inside the switch.
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* cur = *it++;
    switch (cur->kind()) {
      case prim::CallFunction: {
        // Input 0 of the call is the constant whose type holds the callee.
        TORCH_INTERNAL_ASSERT(cur->input(0)->node()->kind() == prim::Constant);
        auto function_constant = cur->input(0)->node();
        auto fun_type =
            function_constant->output()->type()->expect<FunctionType>();
        if ((fun_type->function()->qualname().qualifiedName().find(
                 "torch.nn.functional") != std::string::npos) &&
            (fun_type->function()->qualname().qualifiedName().find(
                 "interpolate") != std::string::npos)) {
          // Remove input[0] (the callee constant) and destroy its producing
          // node if nothing else uses it.
          auto input_node_0 = cur->input(0)->node();
          cur->removeInput(0);
          if (!input_node_0->hasUses()) {
            input_node_0->destroy();
          }
          // Build the replacement with the remaining inputs and the same
          // number of outputs. output() below presumes a single output.
          Node* interpolate_node = block->owningGraph()->create(
              Symbol::fromQualString("aten::__interpolate"),
              {cur->inputs()},
              cur->outputs().size());
          interpolate_node->output()->copyMetadata(cur->output());
          // Inserted after `cur`, i.e. before the node `it` already points
          // to, so the loop does not visit the new node.
          interpolate_node->insertAfter(cur);
          // NOTE(review): metadata is copied from `cur`. Unlike the default
          // case, graph->current_scope() is not applied here; confirm the
          // resulting scope is what the exporter expects.
          interpolate_node->copyMetadata(cur);
          cur->replaceAllUsesWith(interpolate_node);
          cur->removeAllInputs();
          cur->destroy();
          GRAPH_UPDATE(
              "ONNX function call substitution function: '",
              fun_type->function()->name(),
              "' to aten::__interpolate");
          GRAPH_UPDATE(
              "Function in ONNX function call substitution body: ",
              toGraphFunction(*fun_type->function()).optimized_graph());
        } else {
          // Remove input[0] (the callee constant) and destroy its producing
          // node if nothing else uses it.
          auto input_node_0 = cur->input(0)->node();
          cur->removeInput(0);
          if (!input_node_0->hasUses()) {
            input_node_0->destroy();
          }
          // Rewrite the callee's graph first, then inline it in place of
          // `cur`. The `false` argument is presumably "do not use the
          // optimized graph", matching the graph() rewritten above; verify
          // against inlineCallTo.
          auto& graphFunction = toGraphFunction(*fun_type->function());
          functionCallSubstitution(graphFunction.graph()->block());
          inlineCallTo(cur, &graphFunction, false);
        }
      } break;
      case prim::CallMethod: {
        const std::string& name = cur->s(attr::name);
        if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
          Function& function = class_type->getMethod(name);
          // For "forward" calls this is a new child scope; otherwise it is
          // the current scope (see ForwardCallScope).
          ScopePtr call_scope = ForwardCallScope(*graph, cur);
          // Makes call_scope the current scope of this graph until the end
          // of this case.
          WithCurrentScope scope_guard(*graph, call_scope);
          GRAPH_DEBUG(
              "Setting scope guard for forward call: ",
              graph->current_scope()->namesFromRoot());
          // Null when `function` is not a GraphFunction; such calls are left
          // in place.
          if (auto graphFunction = tryToGraphFunction(function)) {
            GRAPH_DEBUG(
                "Inner graph for method call ",
                name,
                ": ",
                *graphFunction->graph());
            // Also make call_scope current on the callee graph, so that the
            // recursive call stamps the callee's nodes with it.
            WithCurrentScope inner_graph_scope_guard(
                *graphFunction->graph(), call_scope);
            functionCallSubstitution(graphFunction->graph()->block());
            inlineCallTo(cur, graphFunction, false);
          }
        }
      } break;
      default: {
        if (!graph->current_scope()->isBlank()) {
          cur->setScope(graph->current_scope());
        }
        // Nested blocks (e.g. of control-flow nodes) are processed with the
        // same current scope.
        for (auto b : cur->blocks()) {
          functionCallSubstitution(b);
        }
      } break;
    }
    // Any scope guards from the CallMethod case have gone out of scope by
    // now, so this logs the restored scope.
    GRAPH_DEBUG(
        "Graph current scope after node process: ",
        graph->current_scope()->namesFromRoot());
  }
}
// Computes the scope under which the whole graph is processed.
//
// If the graph's first input has a ClassType (presumably the module `self`;
// confirm for non-module graphs), a child of the current scope is pushed. Its
// name is built from the tidied class name and the empty top-level variable
// name. Otherwise, including for graphs with no inputs, the graph's current
// scope is returned unchanged.
ScopePtr ONNXGraphTopLevelScope(Graph& graph) {
  const auto inputs = graph.inputs();
  if (!inputs.empty()) {
    if (const auto top_module_type =
            inputs.front()->type()->cast<ClassType>()) {
      const auto scope_name =
          ::torch::jit::onnx::ONNXScopeName::createFullScopeName(
              TidyClassNameFromTorchScript(top_module_type->name()),
              kTopModuleVariableName);
      return graph.current_scope()->push(Symbol::scope(scope_name));
    }
  }
  return graph.current_scope();
}
} // namespace
// Function call substitution pass, used for ONNX conversion only.
//
// The ONNX converter depends on a number of deprecated aten operators. These
// operators have been removed from the IR and replaced by compiled Python
// function code. To keep ONNX conversion working, this pass takes two actions:
// - It replaces such function calls with an aten symbol the converter still
//   understands (aten::__interpolate for torch.nn.functional interpolate
//   calls).
// - It inlines the remaining function and method calls, assigning scopes
//   derived from "forward" method calls along the way.
// The graph is processed under the scope returned by ONNXGraphTopLevelScope.
void ONNXFunctionCallSubstitution(Graph& graph) {
  GRAPH_DUMP("Before function call substitution calls: ", &graph);
  const ScopePtr top_level_scope = ONNXGraphTopLevelScope(graph);
  WithCurrentScope scope_guard(graph, top_level_scope);
  functionCallSubstitution(graph.block());
  GRAPH_DUMP("After function call substitution calls: ", &graph);
}
} // namespace jit
} // namespace torch