Skip to content

Commit b768db0

Browse files
suo authored and facebook-github-bot committed
Allow DCE to clean up some mutable ops (pytorch#14601)
Summary: This PR makes DCE a little smarter in the presence of mutable ops. Previously mutable ops could never be cleaned up; now they can be cleaned up if we can prove there are no live uses of any alias sets that the op writes to. This behavior is optional; if you pass DCE a block instead of a graph, it will do the same thing as before. Also changed `InlineAutographSubgraph` to use the common subgraph utils. Tested on traced ResNet, and it gets rid of the dead code.
Pull Request resolved: pytorch#14601
Differential Revision: D13309118
Pulled By: suo
fbshipit-source-id: dac2791e7d2ecf219ae717a2759b83c1e927f254
1 parent 9783ce3 commit b768db0

20 files changed

+537
-140
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
graph() {
2+
%0 : int = prim::Constant[value=1]()
3+
%1 : int[] = prim::Constant[value=[0, -1]]()
4+
%2 : int = prim::Constant[value=0]()
5+
%3 : int = prim::Constant[value=6]()
6+
%4 : int = prim::Constant[value=2]()
7+
%5 : int = prim::Constant[value=3]()
8+
%6 : int[] = prim::ListConstruct(%4, %5)
9+
%a.1 : Tensor = aten::rand(%6, %3, %2, %1)
10+
%8 : int[] = prim::ListConstruct(%4, %5)
11+
%9 : Tensor = aten::rand(%8, %3, %2, %1)
12+
%a : Tensor = aten::add_(%a.1, %9, %0)
13+
return (%a);
14+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
graph() {
2+
%0 : int = prim::Constant[value=1]()
3+
%1 : int[] = prim::Constant[value=[0, -1]]()
4+
%2 : int = prim::Constant[value=0]()
5+
%3 : int = prim::Constant[value=6]()
6+
%4 : int = prim::Constant[value=2]()
7+
%5 : int = prim::Constant[value=3]()
8+
%6 : int[] = prim::ListConstruct(%4, %5)
9+
%a.1 : Tensor = aten::rand(%6, %3, %2, %1)
10+
%8 : int[] = prim::ListConstruct(%4, %5)
11+
%9 : Tensor = aten::rand(%8, %3, %2, %1)
12+
%a.2 : Tensor = aten::add_(%a.1, %9, %0)
13+
%11 : int[] = prim::ListConstruct(%4, %5)
14+
%b.1 : Tensor = aten::rand(%11, %3, %2, %1)
15+
%13 : int[] = prim::ListConstruct(%4, %5)
16+
%14 : Tensor = aten::zeros(%13, %3, %2, %1)
17+
%15 : Tensor = aten::gt(%a.2, %14)
18+
%16 : bool = prim::TensorToBool(%15)
19+
%b : Tensor = prim::If(%16)
20+
block0() {
21+
%18 : int[] = prim::ListConstruct(%4, %5)
22+
%19 : Tensor = aten::rand(%18, %3, %2, %1)
23+
%b.2 : Tensor = aten::add_(%b.1, %19, %0)
24+
-> (%b.2)
25+
}
26+
block1() {
27+
-> (%b.1)
28+
}
29+
return (%b);
30+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
graph(%a.1 : Tensor) {
2+
%1 : int = prim::Constant[value=1]()
3+
%2 : int[] = prim::Constant[value=[0, -1]]()
4+
%3 : int = prim::Constant[value=0]()
5+
%4 : int = prim::Constant[value=6]()
6+
%5 : int = prim::Constant[value=2]()
7+
%6 : int = prim::Constant[value=3]()
8+
%7 : int[] = prim::ListConstruct(%5, %6)
9+
%8 : Tensor = aten::rand(%7, %4, %3, %2)
10+
%a : Tensor = aten::add_(%a.1, %8, %1)
11+
return ();
12+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
graph(%a : Tensor) {
2+
%1 : int = prim::Constant[value=1]()
3+
%2 : int[] = prim::Constant[value=[0, -1]]()
4+
%3 : int = prim::Constant[value=6]()
5+
%4 : int = prim::Constant[value=0]()
6+
%5 : int = prim::Constant[value=2]()
7+
%6 : int = prim::Constant[value=3]()
8+
%l : Tensor[] = prim::ListConstruct()
9+
%8 : Tensor[] = aten::append(%l, %a)
10+
%c.1 : Tensor = aten::select(%l, %4)
11+
%10 : int[] = prim::ListConstruct(%5, %6)
12+
%b : Tensor = aten::rand(%10, %3, %4, %2)
13+
%12 : int[] = prim::ListConstruct(%5, %6)
14+
%13 : Tensor = aten::rand(%12, %3, %4, %2)
15+
%c : Tensor = aten::add_(%c.1, %13, %1)
16+
return (%b);
17+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
graph(%a : Tensor) {
2+
%1 : int[] = prim::Constant[value=[0, -1]]()
3+
%2 : int = prim::Constant[value=6]()
4+
%i.1 : int = prim::Constant[value=0]()
5+
%4 : int = prim::Constant[value=2]()
6+
%5 : int = prim::Constant[value=3]()
7+
%6 : int = prim::Constant[value=9223372036854775807]()
8+
%7 : int = prim::Constant[value=1]()
9+
%l : Tensor[] = prim::ListConstruct()
10+
%9 : Tensor[] = aten::append(%l, %a)
11+
%10 : int[] = prim::ListConstruct(%4, %5)
12+
%b : Tensor = aten::rand(%10, %2, %i.1, %1)
13+
%12 : bool = aten::lt(%i.1, %7)
14+
%i : int = prim::Loop(%6, %12, %i.1)
15+
block0(%14 : int, %15 : int) {
16+
%c.1 : Tensor = aten::select(%l, %i.1)
17+
%17 : int[] = prim::ListConstruct(%4, %5)
18+
%18 : Tensor = aten::rand(%17, %2, %i.1, %1)
19+
%c : Tensor = aten::add_(%c.1, %18, %7)
20+
%i.2 : int = aten::add(%15, %7)
21+
%21 : bool = aten::lt(%i.2, %7)
22+
-> (%21, %i.2)
23+
}
24+
return (%b);
25+
}

test/test_jit.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8881,6 +8881,68 @@ def fn(x, y):
88818881

88828882
self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
88838883

8884+
def test_mutable_dce(self):
8885+
@torch.jit.script
8886+
def foo():
8887+
a = torch.rand(2, 3)
8888+
a += torch.rand(2, 3)
8889+
b = torch.rand(2, 3)
8890+
b += torch.rand(2, 3)
8891+
# b should be cleaned up but not a
8892+
return a
8893+
8894+
self.assertExpectedGraph(foo.graph)
8895+
8896+
def test_mutable_dce_block(self):
8897+
@torch.jit.script
8898+
def foo():
8899+
a = torch.rand(2, 3)
8900+
a += torch.rand(2, 3)
8901+
b = torch.rand(2, 3)
8902+
if bool(a > torch.zeros(2, 3)):
8903+
b += torch.rand(2, 3)
8904+
a += torch.rand(2, 3)
8905+
# a should be cleaned up but not b
8906+
return b
8907+
8908+
self.assertExpectedGraph(foo.graph)
8909+
8910+
def test_mutable_dce_graph_input(self):
8911+
@torch.jit.script
8912+
def foo(a):
8913+
a += torch.rand(2, 3)
8914+
# shouldn't clean up `a` even though it's not used in the output
8915+
8916+
self.assertExpectedGraph(foo.graph)
8917+
8918+
def test_mutable_dce_list(self):
8919+
@torch.jit.script
8920+
def foo(a):
8921+
l = []
8922+
l.append(a)
8923+
c = l[0]
8924+
b = torch.rand(2, 3)
8925+
c += torch.rand(2, 3)
8926+
return b
8927+
8928+
self.assertExpectedGraph(foo.graph)
8929+
8930+
def test_mutable_dce_loop(self):
8931+
@torch.jit.script
8932+
def foo(a):
8933+
l = []
8934+
l.append(a)
8935+
i = 0
8936+
b = torch.rand(2, 3)
8937+
while i < 1:
8938+
dead = torch.rand(2, 3)
8939+
c = l[0]
8940+
c += torch.rand(2, 3)
8941+
i += 1
8942+
return b
8943+
8944+
self.assertExpectedGraph(foo.graph)
8945+
88848946

88858947
class MnistNet(nn.Module):
88868948
def __init__(self):

torch/csrc/jit/export.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ void validateBlock(Block *b, onnx_torch::OperatorExportTypes operator_export_typ
9696

9797
void validateGraph(const std::shared_ptr<Graph>& graph, onnx_torch::OperatorExportTypes operator_export_type) {
9898
validateBlock(graph->block(), operator_export_type);
99-
EliminateDeadCode(graph);
99+
EliminateDeadCode(graph->block());
100100
}
101101

102102
class EncoderBase {

torch/csrc/jit/init.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ void initJITBindings(PyObject *module) {
9696
.def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
9797
.def("_jit_pass_fuse", FuseGraph)
9898
.def("_jit_pass_dce", [](std::shared_ptr<Graph>& g) {
99-
return EliminateDeadCode(g); // overload resolution
99+
return EliminateDeadCode(g->block()); // overload resolution
100100
})
101101
.def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
102102
return EliminateCommonSubexpression(g); // overload resolution

torch/csrc/jit/ir.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -862,18 +862,18 @@ Value* Node::insertOutput(size_t i) {
862862
return outputs_.at(i);
863863
}
864864

865-
bool Node::isBefore(const Node * n) const {
866-
if (this == n) {
867-
return false;
868-
}
869-
return !isAfter(n);
870-
}
865+
bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
866+
if (this->owningBlock() == n->owningBlock()) {
867+
if (moveSide == MoveSide::BEFORE) {
868+
return this->topo_position_ < n->topo_position_;
869+
}
871870

872-
bool Node::isAfter(const Node * n) const {
873-
JIT_ASSERT(this->owningGraph() == n->owningGraph());
871+
if (moveSide == MoveSide::AFTER) {
872+
return this->topo_position_ > n->topo_position_;
873+
}
874874

875-
if (this->owningBlock() == n->owningBlock()) {
876-
return this->topo_position_ > n->topo_position_;
875+
JIT_ASSERT(this == n);
876+
return false;
877877
}
878878

879879
// These nodes don't share a common block. Traverse the blockchains upward
@@ -887,7 +887,7 @@ bool Node::isAfter(const Node * n) const {
887887
JIT_ASSERT(rhs->owningBlock());
888888

889889
if (lhs->owningBlock() == rhs->owningBlock()) {
890-
return lhs->isAfter(rhs);
890+
return lhs->isBeforeOrAfter(rhs, moveSide);
891891
}
892892
rhs = rhs->owningBlock()->owningNode();
893893
}
@@ -896,6 +896,15 @@ bool Node::isAfter(const Node * n) const {
896896
}
897897
// should never reach here, since both nodes are ultimately in the same graph
898898
JIT_ASSERT(false);
899+
900+
}
901+
902+
bool Node::isBefore(const Node * n) const {
903+
return isBeforeOrAfter(n, MoveSide::BEFORE);
904+
}
905+
906+
bool Node::isAfter(const Node * n) const {
907+
return isBeforeOrAfter(n, MoveSide::AFTER);
899908
}
900909

901910
Node* Node::insertBefore(Node * n) {

torch/csrc/jit/ir.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ struct Node : public Attributes<Node> {
599599
enum class MoveSide { BEFORE, AFTER };
600600
bool tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun);
601601
void move(Node* movePoint, MoveSide moveSide);
602+
bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
602603

603604
std::pair<Value*, const Argument&> findInput(Symbol name);
604605
void findSchema() const;
@@ -808,6 +809,12 @@ friend struct Block;
808809
const auto & block = *block_;
809810
return block.nodes();
810811
}
812+
Node * param_node() {
813+
return block_->param_node();
814+
}
815+
const Node * param_node() const {
816+
return block_->param_node();
817+
}
811818
Node * return_node() {
812819
return block_->return_node();
813820
}

0 commit comments

Comments
 (0)