Skip to content

Commit 249e65b

Browse files
eee4017pytorchmergebot
authored andcommitted
Graph-Safe RNG State Exchange for Tensor Parallelism (pytorch#114068)
See pytorch#113541 The PR allows for registering and controlling multiple RNG states using indices, ensuring cudagraph-safe operations, and includes both C++ and Python API changes to support this functionality. cc @eellison @anijain2305 @jansel @ezyang @ptrblck @csarofeen @mcarilli Pull Request resolved: pytorch#114068 Approved by: https://github.com/ezyang, https://github.com/eqy, https://github.com/xuzhao9
1 parent fe41ba4 commit 249e65b

File tree

15 files changed

+643
-138
lines changed

15 files changed

+643
-138
lines changed

aten/src/ATen/core/Generator.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,12 @@ at::Tensor Generator::get_state() const {
1313
return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
1414
}
1515

16+
void Generator::graphsafe_set_state(const Generator& new_state) {
17+
this->impl_->graphsafe_set_state(new_state.getIntrusivePtr());
18+
}
19+
20+
Generator Generator::graphsafe_get_state() const {
21+
return Generator(this->impl_->graphsafe_get_state());
22+
}
23+
1624
} // namespace at

aten/src/ATen/core/Generator.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ struct TORCH_API Generator {
107107

108108
at::Tensor get_state() const;
109109

110+
void graphsafe_set_state(const Generator& new_state);
111+
112+
Generator graphsafe_get_state() const;
113+
110114
std::mutex& mutex() {
111115
return impl_->mutex_;
112116
}

0 commit comments

Comments
 (0)