
Commit 5c56986

Browse files
pritamdamania authored and facebook-github-bot committed
Attach autograd edges only for tensors requiring grad. (pytorch#30904)
Summary:
Pull Request resolved: pytorch#30904

When we sent tensors over RPC, the server side would call addRecvRpcBackward, which called `set_history` on all tensors. This was incorrect and set the `requires_grad` flag on tensors that didn't actually need grad. To fix this, we now attach autograd edges only to tensors that require grad.

ghstack-source-id: 95113672
ghstack-source-id: 95113999

Test Plan: waitforbuildbot

Differential Revision: D18828561

fbshipit-source-id: d8942b76e9e4c567f8f1821f125c00d275ea0f90
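In practice, the fix means a tensor the caller sends with requires_grad=False now arrives on the callee without autograd history, and only grad-requiring tensors appear in the distributed autograd graph. A minimal sketch of that behaviour, modeled on the new test_mixed_requires_grad test below (the worker name, RPC initialization, and the remote_op helper are illustrative assumptions, not part of this commit):

    # Assumes rpc.init_rpc() has already been called on workers "worker0"/"worker1".
    import torch
    import torch.distributed.rpc as rpc
    import torch.distributed.autograd as dist_autograd

    def remote_op(t1, t2):
        # Runs on the callee. Before this fix, addRecvRpcBackward attached history
        # to every received tensor, so t2.requires_grad came back True here even
        # though the caller sent it with requires_grad=False.
        assert not t2.requires_grad
        return t1 * t2

    t1 = torch.rand(3, 3, requires_grad=True)
    t2 = torch.rand(3, 3, requires_grad=False)
    with dist_autograd.context() as context_id:
        ret = rpc.rpc_sync("worker1", remote_op, args=(t1, t2))
        # backward() signature as used by the tests in this commit; newer releases
        # take context_id as the first argument.
        dist_autograd.backward([ret.sum()])
        grads = dist_autograd.get_gradients(context_id)
        # Only t1 participates in the distributed autograd graph.
        assert t1 in grads and t2 not in grads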
1 parent 62b1072 commit 5c56986

File tree: 2 files changed, +80 −11 lines changed


test/dist_autograd_test.py

Lines changed: 66 additions & 8 deletions
@@ -665,14 +665,11 @@ def _test_rpc_complex_args(self, exec_mode):
                 dist_autograd._current_context()._send_functions().values()
             )[0].next_functions
             idx = 0
-            for i in range(num_tensors):
-                if i % 2 == 0:
-                    self.assertEqual(
-                        "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
-                    )
-                    self.assertEqual(tensors[i], next_funcs[i][0].variable)
-                else:
-                    self.assertIsNone(next_funcs[i][0])
+            for i in range(len(next_funcs)):
+                self.assertEqual(
+                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
+                )
+                self.assertEqual(tensors[i], next_funcs[i][0].variable)
 
             # Verify that the worker id has been recorded in the context
             ctx = dist_autograd._current_context()
@@ -1370,6 +1367,67 @@ def test_clean_context_during_backward(self):
         rpc.shutdown()
         sys.exit(0)
 
+    def _call_remote_embedding(embedding_rref, input, offsets, per_sample_weights):
+        embedding = embedding_rref.local_value()
+        return embedding(input, offsets, per_sample_weights)
+
+    def _get_grad(embedding_rref, context_id):
+        embedding = embedding_rref.local_value()
+        grad_map = dist_autograd.get_gradients(context_id)
+        # Can't send sparse tensors over RPC: https://github.com/pytorch/pytorch/issues/30807
+        return grad_map[embedding.weight].to_dense()
+
+    @dist_init
+    def test_embedding_bag_with_no_grad_tensors(self):
+        dst = self._next_rank()
+        remote_embedding = rpc.remote("worker{}".format(dst),
+                                      torch.nn.EmbeddingBag, args=(16, 16),
+                                      kwargs={'mode': 'sum', 'sparse': True})
+        local_embedding = torch.nn.EmbeddingBag(16, 16, mode='sum', sparse=True)
+
+        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
+        # requires_grad = True to record send/recv functions
+        per_sample_weights = torch.rand((8), requires_grad=True)
+        offsets = torch.LongTensor([0, 4])
+
+        local_res = local_embedding(input, offsets, per_sample_weights)
+        local_res.sum().backward()
+        local_grad = local_embedding.weight.grad
+
+        with dist_autograd.context() as context_id:
+            res = rpc.rpc_sync("worker{}".format(dst),
+                               DistAutogradTest._call_remote_embedding,
+                               args=(remote_embedding, input, offsets, per_sample_weights))
+
+            dist_autograd.backward([res.sum()])
+
+            remote_grad = rpc.rpc_sync("worker{}".format(dst),
+                                       DistAutogradTest._get_grad,
+                                       args=(remote_embedding, context_id))
+
+            self.assertEqual(local_grad.to_dense(), remote_grad)
+
+    def _mixed_requires_grad(t1, t2):
+        if t2.requires_grad:
+            return t1 - t2
+        else:
+            return t1 * t2
+
+    @dist_init
+    def test_mixed_requires_grad(self):
+        for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=False)
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2)
+                self.assertEqual(t1 * t2, ret)
+                dist_autograd.backward([ret.sum()])
+                self.assertTrue(t1.requires_grad)
+                self.assertFalse(t2.requires_grad)
+                grads = dist_autograd.get_gradients(context_id)
+                self.assertIn(t1, grads)
+                self.assertNotIn(t2, grads)
+                self.assertEqual(t2, grads[t1])
 
 if __name__ == '__main__':
     unittest.main()

torch/csrc/distributed/autograd/utils.cpp

Lines changed: 14 additions & 3 deletions
@@ -23,12 +23,21 @@ void addSendRpcBackward(
     const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors,
     const rpc::worker_id_t dst) {
+  // Attach autograd information only for tensors requiring grad.
+  std::vector<torch::Tensor> tensors_with_grad;
+  std::copy_if(
+      tensors.begin(),
+      tensors.end(),
+      std::back_inserter(tensors_with_grad),
+      [](const torch::Tensor& t) { return t.requires_grad(); });
+
   // Attach the appropriate autograd edges.
   auto grad_fn = std::make_shared<SendRpcBackward>();
-  grad_fn->set_next_edges(torch::autograd::collect_next_edges(tensors));
+  grad_fn->set_next_edges(
+      torch::autograd::collect_next_edges(tensors_with_grad));
 
   // Add the appropriate input metadata for the grad_fn.
-  for (const auto& tensor : tensors) {
+  for (const auto& tensor : tensors_with_grad) {
     grad_fn->add_input_metadata(tensor);
   }
 
@@ -52,7 +61,9 @@ ContextPtr addRecvRpcBackward(
   auto grad_fn = std::make_shared<RecvRpcBackward>(
       autogradMetadata, autogradContext, fromWorkerId);
   for (auto& tensor : tensors) {
-    torch::autograd::set_history(tensor, grad_fn);
+    if (tensor.requires_grad()) {
+      torch::autograd::set_history(tensor, grad_fn);
+    }
   }
 
   // Now update the autograd context with the necessary information.
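The send-side filtering above is what the updated _test_rpc_complex_args assertions exercise: the recorded SendRpcBackward now only has next edges for tensors that require grad. A rough way to observe this from Python, assuming an initialized RPC group with an illustrative peer named "worker1" and leaning on the private _current_context()/_send_functions() helpers already used in the test file above:

    import torch
    import torch.distributed.rpc as rpc
    import torch.distributed.autograd as dist_autograd

    t_grad = torch.ones(3, 3, requires_grad=True)
    t_no_grad = torch.ones(3, 3, requires_grad=False)

    with dist_autograd.context() as context_id:
        # Any RPC that ships both tensors records a send function in the context.
        rpc.rpc_sync("worker1", torch.add, args=(t_grad, t_no_grad))
        send_fn = list(
            dist_autograd._current_context()._send_functions().values()
        )[0]
        # Previously there were two next edges (the one for t_no_grad undefined);
        # after this change only the AccumulateGrad edge for t_grad is attached.
        assert len(send_fn.next_functions) == 1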
