#include <memory>

#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/torch.h>

namespace torch {
namespace distributed {
namespace autograd {
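
// Test fixture shared by the tests below: SetUpTestCase() initializes a
// single DistAutogradContainer for worker id 0, and TearDown() releases the
// context created by each test.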
class DistAutogradTest : public ::testing::Test {
 protected:
  static void SetUpTestCase() {
    autogradContainer_ = &DistAutogradContainer::init(0);
  }

  void TearDown() override {
    autogradContainer_->releaseContext(
        autogradContainer_->currentContext()->contextId());
  }

  static DistAutogradContainer* autogradContainer_;
};

DistAutogradContainer* DistAutogradTest::autogradContainer_ = nullptr;
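
// Verify that SendRpcBackward rejects invalid inputs: it must not be given
// gradients as direct inputs, and it must not run with an undefined grad.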
TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
  auto options = at::TensorOptions().requires_grad(true);
  auto in1 = torch::ones({3, 3}, options);
  auto in2 = torch::ones({3, 3}, options);

  autogradContainer_->newContext();
  auto autogradContext = autogradContainer_->currentContext();
  // Attach the send autograd function to tensors.
  std::vector<torch::Tensor> tensors = {in1, in2};
  rpc::worker_id_t worker_id = 1;
  addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
  autogradContext->addKnownWorkerId(worker_id);
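  // sendFunctions() maps autograd message ids to their send functions; id 1
  // comes from the AutogradMetadata above.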
  auto send_function = autogradContext->sendFunctions()[1];

  // Ensure that the worker id was recorded.
  auto knownWorkerIds = autogradContext->getKnownWorkerIds();
  ASSERT_TRUE(knownWorkerIds.find(worker_id) != knownWorkerIds.end());
  ASSERT_EQ(knownWorkerIds.size(), 1);

  // This should fail since the SendRpcBackward function shouldn't receive any
  // input gradients.
  EXPECT_THROW(send_function->apply({in1, in2}), c10::Error);

  // This should fail since the SendRpcBackward function encounters an
  // undefined grad.
  send_function->setGrads({in1, torch::autograd::Variable()});
  EXPECT_THROW(send_function->apply({}), c10::Error);
}
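
// Verify that a backward pass over a purely local graph, run through the
// distributed engine, is cleaned up once execution completes.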
TEST_F(DistAutogradTest, TestInitializedContextCleanup) {
  autogradContainer_->newContext();
  auto contextId = autogradContainer_->currentContext()->contextId();
  auto& engine = DistEngine::getInstance();
  ASSERT_EQ(0, engine.numBackwardPasses());

  // Build the autograd graph.
  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto y = torch::randn({2, 2}, torch::requires_grad());
  auto z = (x * x + y * y).sum();
  ASSERT_NE(nullptr, z.grad_fn());

  // Execute the engine.
  engine.execute(contextId, {z}, /* retainGraph */ false);

  // Validate appropriate cleanup.
  ASSERT_EQ(0, engine.numBackwardPasses());
}
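
// Verify that an asynchronous backward pass started from a send function is
// cleaned up once it completes.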
TEST_F(DistAutogradTest, TestInitializedContextCleanupSendFunction) {
  autogradContainer_->newContext();
  auto context = autogradContainer_->currentContext();
  auto& engine = DistEngine::getInstance();
  ASSERT_EQ(0, engine.numBackwardPasses());

  // Attach the send function.
  auto options = at::TensorOptions().requires_grad(true);
  auto t = torch::ones({1}, options);
  auto tensors = std::vector<torch::Tensor>{t};
  addSendRpcBackward(
      context, AutogradMetadata(context->contextId(), 0), tensors);

  auto sendFunction = context->retrieveSendFunction(0);
  sendFunction->setGrads({t});

  // Execute the engine.
  engine
      .executeSendFunctionAsync(context, sendFunction, /*retainGraph*/ false)
      ->wait();

  // Validate appropriate cleanup.
  ASSERT_EQ(0, engine.numBackwardPasses());
}

} // namespace autograd
} // namespace distributed
} // namespace torch