Skip to content

Commit 450504c

Browse files
pbelevichfacebook-github-bot
authored andcommitted
C++ API parity: at::Tensor::set_data
Summary: Pull Request resolved: pytorch#26647 Test Plan: Imported from OSS Differential Revision: D17542604 Pulled By: pbelevich fbshipit-source-id: 37d5d67ebdb9348b5561d983f9bd26d310210983
1 parent 2cf1183 commit 450504c

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

test/cpp/api/tensor.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,18 @@ TEST(TensorTest, DetachInplace) {
489489
ASSERT_THROWS_WITH(x.detach_(), message);
490490
ASSERT_THROWS_WITH(y.detach_(), message);
491491
}
492+
493+
TEST(TensorTest, SetData) {
494+
auto x = torch::randn({5});
495+
auto y = torch::randn({5});
496+
ASSERT_FALSE(torch::equal(x, y));
497+
ASSERT_NE(x.data_ptr<float>(), y.data_ptr<float>());
498+
499+
x.set_data(y);
500+
ASSERT_TRUE(torch::equal(x, y));
501+
ASSERT_EQ(x.data_ptr<float>(), y.data_ptr<float>());
502+
503+
x = at::tensor({5});
504+
y = at::tensor({5});
505+
ASSERT_THROWS_WITH(x.set_data(y), "set_data is not implemented for Tensor");
506+
}

0 commit comments

Comments
 (0)