Skip to content

Commit 35f6a69

Browse files
voznesenskym authored and
pytorchmergebot committed
Python Dispatcher integration with C++ dispatcher (pytorch#84826)
Signed-off-by: Edward Z. Yang <ezyang@fb.com> From @ezyang's original PR: There are a number of situations where we have non-backend kernels (e.g., CompositeImplicitAutograd, batching rules) which we would like to port to Python, but we have no way to integrate these ports with the overall system while using preexisting C++ registrations otherwise. This PR changes that by introducing a Python dispatcher (which can have its own kernels directly in Python), which can interpose over ordinary C++ dispatch. The ingredients: We introduce a new PythonDispatcher dispatch key, that has the same tenor as FuncTorchDynamicLayerFrontMode: it works by getting triggered before every other dispatch key in the dispatch key set, and shunting to a Python implementation The Python dispatcher is a per-interpreter global object that is enabled/disabled via the guard EnablePythonDispatcher/DisablePythonDispatcher. We don't make it compositional as I have no idea what a compositional version of this feature would look like. Because it is global, we don't need to memory manage it and so I use a simpler SafePyHandle (newly added) to control access to this pointer from non-Python C++. Like __torch_dispatch__, we use PyInterpreter to get to the Python interpreter to handle the dispatch. I need to reimplement dispatch table computation logic in Python. To do this, I expose a lot more helper functions for doing computations on alias dispatch keys and similar. I also improve the pybind11 handling for DispatchKey so that you can either accept the pybind11 bound enum or a string; this simplifies our binding code. See pybind/pybind11#483 (comment) for how this works; the technique is generally useful. I need to be able to call backend fallbacks. I do this by permitting you to call at a dispatch key which doesn't have a kernel for the operator; if the kernel doesn't exist, we check the backend fallback table instead. Signed-off-by: Edward Z. 
Yang <ezyang@fb.com> Pull Request resolved: pytorch#84826 Approved by: https://github.com/ezyang
1 parent 44c30c5 commit 35f6a69

33 files changed

+686
-160
lines changed

aten/src/ATen/ThreadLocalState.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ ThreadLocalState::ThreadLocalState()
1414
debug_info_(c10::ThreadLocalDebugInfo::current()),
1515
functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
1616
autograd_tls_(c10::AutogradState::get_tls_state()),
17+
python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
1718
python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) {
1819
rf_tls_ = at::get_record_function_tls_();
1920

@@ -41,6 +42,8 @@ void ThreadLocalState::setThreadLocalState(
4142

4243
at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_);
4344

45+
c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
46+
4447
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
4548

4649
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);

aten/src/ATen/ThreadLocalState.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <ATen/FuncTorchTLS.h>
1111
#include <ATen/PythonTorchFunctionTLS.h>
1212
#include <ATen/record_function.h>
13+
#include <c10/core/impl/PythonDispatcherTLS.h>
1314
#include <c10/core/impl/TorchDispatchModeTLS.h>
1415

1516
namespace at {
@@ -57,6 +58,9 @@ class TORCH_API ThreadLocalState {
5758
// TLS for enable_torch_dispatch_mode
5859
std::shared_ptr<SafePyObject> torch_dispatch_mode_state_;
5960

61+
// TLS for enable_python_dispatcher
62+
SafePyHandle python_dispatcher_state_;
63+
6064
// TLS for __torch_function__ (mode and disable_torch_function)
6165
at::impl::PythonTorchFunctionTLS python_torch_function_state_;
6266

aten/src/ATen/core/PythonFallbackKernel.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <c10/core/impl/TorchDispatchModeTLS.h>
2+
#include <c10/core/impl/PythonDispatcherTLS.h>
23
#include <ATen/core/PythonFallbackKernel.h>
34
#include <c10/core/SafePyObject.h>
45

@@ -87,6 +88,12 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
8788
TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
8889
}
8990

91+
void pythonDispatcherFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
92+
auto state = c10::impl::PythonDispatcherTLS::get_state();
93+
TORCH_INTERNAL_ASSERT(state, "Hit PythonDispatcher dispatch key but PythonDispatcherTLS was not set");
94+
state.pyinterpreter()->python_dispatcher(op, dispatch_keys.remove(c10::DispatchKey::PythonDispatcher), stack);
95+
}
96+
9097
void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
9198
// It is ok for the tls to be already set here.
9299
// It means that there are multiple calls into the dispatcher not originating from python code.
@@ -134,6 +141,10 @@ TORCH_LIBRARY_IMPL(_, Python, m) {
134141
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
135142
}
136143

144+
TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
145+
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonDispatcherFallback>());
146+
}
147+
137148
TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
138149
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
139150
}

aten/src/ATen/core/dispatch/Dispatcher.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ class TORCH_API Dispatcher final {
168168
// See Note [Plumbing Keys Through The Dispatcher]
169169
void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const;
170170

171+
bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
172+
auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
173+
if (dispatch_ix < 0) return false;
174+
return backendFallbackKernels_[dispatch_ix].kernel.isValid();
175+
}
176+
171177

172178
// ------------------------------------------------------------------------
173179
//
@@ -333,6 +339,10 @@ class TORCH_API OperatorHandle {
333339
return operatorDef_->op.hasKernelForDispatchKey(k);
334340
}
335341

342+
bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
343+
return operatorDef_->op.hasKernelForAnyDispatchKey(k);
344+
}
345+
336346
bool hasComputedKernelForDispatchKey(DispatchKey k) const {
337347
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
338348
}
@@ -635,11 +645,18 @@ inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, Dispat
635645
// We still compute this as we're obligated to pass it on to the internal
636646
// kernel, if it is a boxed fallback
637647
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
638-
const auto& kernel = entry.kernelForDispatchKey(dk);
648+
const auto& kernel = ([&]() {
649+
if (op.hasKernelForDispatchKey(dk)) {
650+
return entry.kernelForDispatchKey(dk);
651+
} else {
652+
auto idx = getDispatchTableIndexForDispatchKey(dk);
653+
TORCH_INTERNAL_ASSERT(idx >= 0);
654+
return backendFallbackKernels_[idx].kernel;
655+
}
656+
})();
639657
kernel.callBoxed(op, dispatchKeySet, stack);
640658
}
641659

642-
643660
inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
644661
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
645662
const auto& entry = op.operatorDef_->op;

c10/core/DispatchKey.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ const char* toString(DispatchKey t) {
172172
case DispatchKey::TESTING_ONLY_GenericMode:
173173
return "TESTING_ONLY_GenericMode";
174174

175+
case DispatchKey::PythonDispatcher:
176+
return "PythonDispatcher";
177+
175178
// Aliases
176179

177180
case DispatchKey::Autograd:
@@ -283,6 +286,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
283286
{"TESTING_ONLY_GenericWrapper",
284287
c10::DispatchKey::TESTING_ONLY_GenericWrapper},
285288
{"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
289+
{"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
286290

287291
{"CPU", c10::DispatchKey::CPU},
288292
{"CUDA", c10::DispatchKey::CUDA},

c10/core/DispatchKey.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,10 @@ enum class DispatchKey : uint16_t {
401401
// for a usage example
402402
TESTING_ONLY_GenericMode,
403403

404+
// This is a bypass that allows you to skip running the C++ dispatcher
405+
// entirely
406+
PythonDispatcher,
407+
404408
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
405409
EndOfFunctionalityKeys, // End of functionality keys.
406410

c10/core/DispatchKeySet.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ class DispatchKeySet final {
172172
(1ULL
173173
<< (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
174174
1)) -
175-
1) {}
175+
1) {
176+
*this = add(DispatchKey::PythonDispatcher);
177+
}
176178

177179
// Public version of DispatchKeySet(uint64_t) API; external users
178180
// must be explicit when they do this!

c10/core/SafePyObject.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,9 @@ PyObject* SafePyObject::ptr(const c10::impl::PyInterpreter* interpreter) const {
88
return data_;
99
}
1010

11+
PyObject* SafePyHandle::ptr(const c10::impl::PyInterpreter* interpreter) const {
12+
TORCH_INTERNAL_ASSERT(interpreter == pyinterpreter_);
13+
return data_;
14+
}
15+
1116
} // namespace c10

c10/core/SafePyObject.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,29 @@ struct C10_API SafePyObject {
4242
c10::impl::PyInterpreter* pyinterpreter_;
4343
};
4444

45+
// Like SafePyObject, but non-owning. Good for references to global PyObjects
46+
// that will be leaked on interpreter exit. You get a copy constructor/assign
47+
// this way.
48+
struct C10_API SafePyHandle {
49+
SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {}
50+
SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
51+
: data_(data), pyinterpreter_(pyinterpreter) {}
52+
53+
c10::impl::PyInterpreter& pyinterpreter() const {
54+
return *pyinterpreter_;
55+
}
56+
PyObject* ptr(const c10::impl::PyInterpreter*) const;
57+
void reset() {
58+
data_ = nullptr;
59+
pyinterpreter_ = nullptr;
60+
}
61+
operator bool() {
62+
return data_;
63+
}
64+
65+
private:
66+
PyObject* data_;
67+
c10::impl::PyInterpreter* pyinterpreter_;
68+
};
69+
4570
} // namespace c10

c10/core/impl/PyInterpreter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
2727
PANIC(dispatch);
2828
}
2929

30+
void python_dispatcher(
31+
const c10::OperatorHandle& op,
32+
c10::DispatchKeySet,
33+
torch::jit::Stack* stack) const override {
34+
PANIC(python_dispatcher);
35+
}
36+
3037
bool is_contiguous(const TensorImpl* self) const override {
3138
PANIC(is_contiguous);
3239
}

0 commit comments

Comments
 (0)