Skip to content

Commit 5fe629e

Browse files
kurtamohlerpytorchmergebot
authored andcommitted
Add PyObject preservation for UntypedStorage (pytorch#97470)
Part of pytorch#91395 Pull Request resolved: pytorch#97470 Approved by: https://github.com/ezyang
1 parent 488a430 commit 5fe629e

25 files changed

+1068
-246
lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ libtorch_python_core_sources = [
888888
"torch/csrc/utils/python_dispatch.cpp",
889889
"torch/csrc/utils/python_symnode.cpp",
890890
"torch/csrc/utils/pybind.cpp",
891+
"torch/csrc/utils/pyobject_preservation.cpp",
891892
"torch/csrc/utils/structseq.cpp",
892893
"torch/csrc/utils/tensor_apply.cpp",
893894
"torch/csrc/utils/tensor_dtypes.cpp",

c10/core/RefcountedDeleter.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#include <c10/core/RefcountedDeleter.h>
2+
3+
#include <mutex>
4+
5+
namespace c10 {
6+
7+
void refcounted_deleter(void* ctx_) {
8+
RefcountedDeleterContext& ctx =
9+
*reinterpret_cast<RefcountedDeleterContext*>(ctx_);
10+
ctx.refcount--;
11+
if (ctx.refcount == 0) {
12+
ctx.other_ctx = nullptr;
13+
delete &ctx;
14+
}
15+
}
16+
17+
std::mutex replace_data_ptr_mutex;
18+
19+
void maybeApplyRefcountedDeleter(c10::Storage storage) {
20+
std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
21+
c10::DataPtr& data_ptr = storage.mutable_data_ptr();
22+
23+
if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
24+
// Data pointer is already shared
25+
return;
26+
}
27+
28+
void* data = data_ptr.get();
29+
void* other_ctx = data_ptr.get_context();
30+
c10::DeleterFnPtr other_deleter = data_ptr.get_deleter();
31+
c10::Device device = data_ptr.device();
32+
33+
// Release the context of the original DataPtr so that the data doesn't
34+
// get deleted when the original DataPtr is replaced
35+
data_ptr.release_context();
36+
37+
c10::RefcountedDeleterContext* refcount_ctx =
38+
new c10::RefcountedDeleterContext(other_ctx, other_deleter);
39+
40+
c10::DataPtr new_data_ptr(
41+
data,
42+
reinterpret_cast<void*>(refcount_ctx),
43+
&c10::refcounted_deleter,
44+
device);
45+
storage.set_data_ptr(std::move(new_data_ptr));
46+
}
47+
48+
c10::Storage newStorageImplFromRefcountedDataPtr(c10::Storage storage) {
49+
c10::maybeApplyRefcountedDeleter(storage);
50+
51+
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
52+
53+
c10::DataPtr& data_ptr = storage.mutable_data_ptr();
54+
c10::DataPtr new_data_ptr(
55+
data_ptr.get(),
56+
data_ptr.get_context(),
57+
data_ptr.get_deleter(),
58+
data_ptr.device());
59+
60+
// NOTE: This refcount increment should always happen immediately after
61+
// `new_data_ptr` is created. No other lines of code should be added between
62+
// them in the future, unless there's a very good reason for it, because if
63+
// any errors are raised and `new_data_ptr` is deleted before the refcount is
64+
// incremented, the refcount will get decremented and end up being one less
65+
// than it should be.
66+
reinterpret_cast<c10::RefcountedDeleterContext*>(data_ptr.get_context())
67+
->refcount++;
68+
69+
c10::Allocator* allocator = c10::GetAllocator(storage_impl->device_type());
70+
c10::Storage new_storage = c10::make_intrusive<c10::StorageImpl>(
71+
c10::StorageImpl::use_byte_size_t(),
72+
storage_impl->nbytes(),
73+
allocator,
74+
/*resizable=*/storage_impl->resizable());
75+
new_storage.set_data_ptr(std::move(new_data_ptr));
76+
return new_storage;
77+
}
78+
79+
} // namespace c10

c10/core/RefcountedDeleter.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
3+
#include <c10/core/Storage.h>
4+
#include <c10/util/UniqueVoidPtr.h>
5+
6+
#include <atomic>
7+
#include <memory>
8+
9+
namespace c10 {
10+
11+
// A RefcountedDeleterContext object is used as the `ctx` argument for DataPtr
12+
// to implement a shared DataPtr. Normally, a DataPtr is unique, but we use
13+
// this custom context and the `refcounted_deleter` function below to make the
14+
// DataPtr act like a non-unique DataPtr. This context object holds onto an
15+
// inner context and deleter function which handle the actual deletion of the
16+
// data when the refcount reaches 0.
17+
//
18+
// This shared DataPtr feature is only used when storages are shared between
19+
// multiple Python interpreters in MultiPy. Before storages had PyObject
20+
// preservation, interpreters could just share the same StorageImpl instance.
21+
// But now a StorageImpl can only be associated with one interpreter in order
22+
// to properly manage a zombie PyObject. So we share storages across Python
23+
// interpreters by creating a different StorageImpl instance for each one, but
24+
// they all point to the same data.
25+
struct C10_API RefcountedDeleterContext {
26+
RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter)
27+
: other_ctx(other_ctx, other_deleter), refcount(1) {}
28+
29+
std::unique_ptr<void, c10::DeleterFnPtr> other_ctx;
30+
std::atomic_int refcount;
31+
};
32+
33+
// `refcounted_deleter` is used as the `ctx_deleter` for DataPtr to implement
34+
// a shared DataPtr.
35+
//
36+
// Warning: This should only be called on a pointer to
37+
// a RefcountedDeleterContext that was allocated on the heap with `new`,
38+
// because when the refcount reaches 0, the context is deleted with `delete`
39+
C10_API void refcounted_deleter(void* ctx_);
40+
41+
// If the storage's DataPtr does not use `refcounted_deleter`, replace it with
42+
// a DataPtr that does, so it can be shared between multiple StorageImpls
43+
C10_API void maybeApplyRefcountedDeleter(c10::Storage storage);
44+
45+
// Create a new StorageImpl that points to the same data. If the original
46+
// StorageImpl's DataPtr does not use `refcounted_deleter`, it will be replaced
47+
// with one that does
48+
C10_API c10::Storage newStorageImplFromRefcountedDataPtr(c10::Storage storage);
49+
50+
} // namespace c10

c10/core/SafePyObject.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct C10_API SafePyObject {
2929
SafePyObject& operator=(SafePyObject const&) = delete;
3030

3131
~SafePyObject() {
32-
(*pyinterpreter_)->decref(data_, /*is_tensor*/ false);
32+
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
3333
}
3434

3535
c10::impl::PyInterpreter& pyinterpreter() const {

c10/core/Storage.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
#include <c10/core/RefcountedDeleter.h>
12
#include <c10/core/Storage.h>
23

3-
namespace c10 {} // namespace c10
4+
namespace c10 {
5+
6+
bool isSharedStorageAlias(const Storage& storage0, const Storage& storage1) {
7+
c10::DeleterFnPtr deleter_expected = &c10::refcounted_deleter;
8+
c10::DeleterFnPtr deleter0 = storage0.data_ptr().get_deleter();
9+
c10::DeleterFnPtr deleter1 = storage1.data_ptr().get_deleter();
10+
11+
if ((deleter0 != deleter_expected) || (deleter1 != deleter_expected)) {
12+
return false;
13+
}
14+
15+
return storage0.data_ptr().get_context() == storage1.data_ptr().get_context();
16+
}
17+
18+
} // namespace c10

c10/core/Storage.h

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
#pragma once
22

33
#include <c10/core/StorageImpl.h>
4+
#include <c10/util/ExclusivelyOwned.h>
45

56
namespace c10 {
67

8+
struct Storage;
9+
10+
C10_API bool isSharedStorageAlias(
11+
const Storage& storage0,
12+
const Storage& storage1);
13+
714
struct C10_API Storage {
815
public:
916
struct use_byte_size_t {};
17+
struct unsafe_borrow_t {
18+
explicit unsafe_borrow_t() = default;
19+
};
1020

1121
Storage() = default;
1222
Storage(c10::intrusive_ptr<StorageImpl> ptr)
@@ -40,6 +50,14 @@ struct C10_API Storage {
4050
allocator,
4151
resizable)) {}
4252

53+
protected:
54+
explicit Storage(unsafe_borrow_t, const Storage& rhs)
55+
: storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(
56+
rhs.storage_impl_.get())) {}
57+
58+
friend MaybeOwnedTraits<Storage>;
59+
60+
public:
4361
// Legacy constructor for partially initialized (dtype or memory) storages
4462
// that can be temporarily created with Caffe2 APIs. See the note on top of
4563
// TensorImpl.h for details.
@@ -144,7 +162,9 @@ struct C10_API Storage {
144162
}
145163

146164
bool is_alias_of(const Storage& other) const {
147-
return storage_impl_ == other.storage_impl_;
165+
return (
166+
storage_impl_ == other.storage_impl_ ||
167+
isSharedStorageAlias(*this, other));
148168
}
149169

150170
void UniqueStorageShareExternalPointer(
@@ -175,4 +195,67 @@ struct C10_API Storage {
175195
c10::intrusive_ptr<StorageImpl> storage_impl_;
176196
};
177197

198+
template <>
199+
struct MaybeOwnedTraits<c10::Storage> {
200+
using owned_type = c10::Storage;
201+
using borrow_type = c10::Storage;
202+
203+
static borrow_type createBorrow(const owned_type& from) {
204+
return borrow_type(borrow_type::unsafe_borrow_t{}, from);
205+
}
206+
207+
static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
208+
lhs.unsafeReleaseStorageImpl();
209+
lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
210+
}
211+
212+
static void destroyBorrow(borrow_type& toDestroy) {
213+
toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
214+
}
215+
216+
static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
217+
return borrow;
218+
}
219+
220+
static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
221+
return &borrow;
222+
}
223+
224+
static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
225+
return true;
226+
}
227+
};
228+
229+
template <>
230+
struct ExclusivelyOwnedTraits<c10::Storage> {
231+
using repr_type = c10::Storage;
232+
using pointer_type = c10::Storage*;
233+
using const_pointer_type = const c10::Storage*;
234+
235+
static repr_type nullRepr() {
236+
return c10::Storage();
237+
}
238+
239+
template <class... Args>
240+
static repr_type createInPlace(Args&&... args) {
241+
return c10::Storage(std::forward<Args>(args)...);
242+
}
243+
244+
static repr_type moveToRepr(c10::Storage&& x) {
245+
return std::move(x);
246+
}
247+
248+
static c10::Storage take(c10::Storage& x) {
249+
return std::move(x);
250+
}
251+
252+
static pointer_type getImpl(repr_type& x) {
253+
return &x;
254+
}
255+
256+
static const_pointer_type getImpl(const repr_type& x) {
257+
return &x;
258+
}
259+
};
260+
178261
} // namespace c10

c10/core/StorageImpl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,14 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
205205
return received_cuda_;
206206
}
207207

208+
impl::PyObjectSlot* pyobj_slot() {
209+
return &pyobj_slot_;
210+
}
211+
212+
const impl::PyObjectSlot* pyobj_slot() const {
213+
return &pyobj_slot_;
214+
}
215+
208216
private:
209217
DataPtr data_ptr_;
210218
SymInt size_bytes_;

c10/core/TensorImpl.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ void TensorImpl::_set_fw_grad(
7373
autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
7474
}
7575

76-
TensorImpl::~TensorImpl() {
77-
pyobj_slot_.destroy_pyobj_if_needed();
78-
}
76+
TensorImpl::~TensorImpl() = default;
7977

8078
TensorImpl::TensorImpl(
8179
Storage&& storage,
@@ -582,7 +580,7 @@ void TensorImpl::release_resources() {
582580
if (storage_) {
583581
storage_ = {};
584582
}
585-
pyobj_slot_.destroy_pyobj_if_needed();
583+
pyobj_slot_.maybe_destroy_pyobj();
586584
}
587585

588586
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY

c10/core/impl/PyInterpreter.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
1010
return "<unloaded interpreter>";
1111
}
1212

13-
void decref(PyObject* pyobj, bool is_tensor) const override {} // do nothing
13+
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
14+
} // do nothing
1415

1516
#define PANIC(m) \
1617
TORCH_INTERNAL_ASSERT( \

c10/core/impl/PyInterpreter.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ struct C10_API PyInterpreterVTable {
127127
virtual std::string name() const = 0;
128128

129129
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
130-
// See NOTE [PyInterpreter::decref takes an `is_tensor` arg]
131-
virtual void decref(PyObject* pyobj, bool is_tensor) const = 0;
130+
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
131+
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
132132

133133
// Perform a detach by deferring to the __torch_dispatch__ implementation of
134134
// detach, which will also arrange for the PyObject to get copied in this

0 commit comments

Comments
 (0)