
Commit 981baad

davidberard98 authored and pytorchmergebot committed
[JIT] Add autocasting to freezing pass & enable autocast pass by default (pytorch#74178)
Summary:
Pull Request resolved: pytorch#74178

Autocasting + freezing should reduce model size in some scenarios, since half-precision constants are smaller than full-precision constants.

This also enables the JIT autocast pass by default, so `torch._C._jit_set_autocast_mode(True)` no longer needs to be called to enable autocasting.

Test Plan: Imported from OSS

Reviewed By: zou3519, eellison

Differential Revision: D34914245

Pulled By: davidberard98

fbshipit-source-id: 301f3669431feabbd695ebbdfc9c17bd1be3b565
(cherry picked from commit 0530cd3)
1 parent f5a9c36 commit 981baad
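With the pass on by default, a scripted module that uses torch.cuda.amp.autocast should pick up the reduced-precision casts without any extra flag, and freezing folds them into the graph. A minimal usage sketch, assuming a CUDA build; the MatMul module here is illustrative, not part of the commit:

import torch

class MatMul(torch.nn.Module):
    def forward(self, x, y):
        with torch.cuda.amp.autocast():
            return torch.mm(x, y)

x = torch.rand((3, 4), device="cuda")
y = torch.rand((4, 5), device="cuda")

# No longer required before scripting autocast regions:
# torch._C._jit_set_autocast_mode(True)
frozen = torch.jit.freeze(torch.jit.script(MatMul().eval()))
frozen(x, y)
print(frozen.graph)  # expect aten::_autocast_to_reduced_precision casts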

3 files changed: +66, -6 lines


test/test_jit_autocast.py

Lines changed: 49 additions & 0 deletions
@@ -659,6 +659,55 @@ def forward(self, x, y):
         # isn't enabled
         self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))
 
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_jit_freeze_autocast_basic(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+
+            def forward(self, x, y):
+                with torch.cuda.amp.autocast():
+                    return torch.mm(x, y)
+
+        x = torch.rand((3, 4), dtype=torch.float).cuda()
+        y = torch.rand((4, 5), dtype=torch.float).cuda()
+
+        mod = TestModule().eval()
+
+        # sanity check
+        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)
+
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)
+
+        # make sure that the runtime pass doesn't duplicate autocast nodes
+        frozen_mod(x, y)
+        optimized_graph = frozen_mod.graph_for(x, y)
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)
+
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_jit_freeze_autocast_constants(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                self.x = torch.rand((3, 4), dtype=torch.float).cuda()
+
+            def forward(self, y):
+                with torch.cuda.amp.autocast():
+                    return torch.mm(self.x, y)
+
+        y = torch.rand((4, 5), dtype=torch.float).cuda()
+        mod = TestModule().eval()
+
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
+        # freezing should pre-cast the constant self.x to remove one autocast call
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)
+
+        # the runtime autocasting pass will re-insert the second autocast call,
+        # but constant propagation will merge it with the constant that it's casting.
+        frozen_mod(y)
+        optimized_graph = frozen_mod.graph_for(y)
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)
 
 
 if __name__ == "__main__":
     run_tests()

torch/csrc/jit/passes/autocast.cpp

Lines changed: 13 additions & 6 deletions
@@ -12,15 +12,14 @@
 
 #include <stack>
 #include <unordered_set>
+#include <vector>
 
 namespace torch {
 namespace jit {
 
 namespace {
 
-// TODO: Turn on autocast by default. default turned off to avoid tests failures
-// as we prototype the support
-bool autocast_enabled = false;
+bool autocast_enabled = true;
 
 struct AutocastContext {
   bool gpu_enabled = false;

@@ -149,17 +148,23 @@ void castTensorInputs(
   const auto graph = node->owningGraph();
 
   std::unordered_set<Value*> casted_inputs;
+  // need to also keep the inputs in order, otherwise tracing fails
+  // sanity checks because casting ops are inserted in random order
+  std::vector<Value*> casted_inputs_ordered;
   for (auto input : node->inputs()) {
     // TODO: update cast_op signature to take dynamic context flags
     auto input_tensor_type = input->type()->cast<TensorType>();
     if (input_tensor_type && input->node()->kind() != cast_op) {
-      casted_inputs.insert(input);
+      auto has_inserted = casted_inputs.insert(input);
+      if (has_inserted.second) {
+        casted_inputs_ordered.push_back(input);
+      }
     }
   }
 
   WithInsertPoint insert_point(node);
 
-  for (auto input : casted_inputs) {
+  for (auto input : casted_inputs_ordered) {
     if (cast_op == aten::_autocast_to_full_precision) {
       const auto new_input = graph->insert(
           cast_op,

@@ -437,7 +442,9 @@ void handleBlock(Block* block, AutocastContext initial_state) {
 
     // Banned in autocast, see binary_cross_entropy_banned()
     case aten::binary_cross_entropy:
-      AT_ERROR("Unsafe to autocast");
+      if (current_state()) {
+        AT_ERROR("Unsafe to autocast");
+      }
   }
 
   // process sub-blocks, if any
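Because the pass now runs by default, it also visits scripted graphs that never enter an autocast region, so the error on aten::binary_cross_entropy becomes conditional on current_state() rather than unconditional. A hedged sketch of the resulting behavior, assuming a CUDA build (function names are illustrative):

import torch
import torch.nn.functional as F

@torch.jit.script
def bce_plain(pred, target):
    # no autocast region: with the new guard, this should compile and run
    return F.binary_cross_entropy(pred, target)

@torch.jit.script
def bce_autocast(pred, target):
    with torch.cuda.amp.autocast():
        # still banned under autocast; expected to raise "Unsafe to autocast"
        return F.binary_cross_entropy(pred, target)

pred = torch.rand(8, device="cuda")
target = torch.rand(8, device="cuda")
bce_plain(pred, target)     # ok
bce_autocast(pred, target)  # expected: RuntimeError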

torch/csrc/jit/passes/freeze_module.cpp

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/passes/autocast.h>
 #include <torch/csrc/jit/passes/clear_profiling.h>
 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
 #include <torch/csrc/jit/passes/inliner.h>

@@ -101,6 +102,9 @@ class AttributePropagator {
       ClearProfilingInformation(subgraph);
     };
     auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
+#ifndef C10_MOBILE
+      Autocast(subgraph);
+#endif
       runOptimization(
           subgraph,
           /* unroll_non_constant_loops? */ false,
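Since freezing now runs the Autocast pass on every optimized subgraph (skipped on mobile builds via the C10_MOBILE guard), a model that must stay in full precision can presumably opt out globally before scripting. This one-liner assumes the flag named in the summary also accepts False:

torch._C._jit_set_autocast_mode(False)  # assumption: symmetrically disables the JIT autocast pass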
