forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextra_state.cpp
165 lines (150 loc) · 5.04 KB
/
extra_state.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/cpython_defs.h>
#include <torch/csrc/dynamo/debug_macros.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/utils/python_compat.h>
#if IS_PYTHON_3_12_PLUS
#define _PyCode_GetExtra PyUnstable_Code_GetExtra
#define _PyCode_SetExtra PyUnstable_Code_SetExtra
#endif
// Slot index passed to _PyCode_GetExtra / _PyCode_SetExtra by
// get_extra_state / set_extra_state below to locate this module's per-code
// ExtraState. Starts at -1; presumably a real index is assigned elsewhere
// before first use — TODO confirm against the initialization code.
Py_ssize_t extra_index = -1;
// Returns the head of this state's cache entry list, or nullptr if the list
// is empty. The head is the entry most recently created (create_cache_entry
// inserts at the front) or promoted by a cache hit (move_to_front).
CacheEntry* ExtraState::get_first_entry() {
  auto& entries = this->cache_entry_list;
  return entries.empty() ? nullptr : &entries.front();
}
// Relinks `cache_entry` to the head of this state's cache list so that the
// next lookup (which scans the list in order) tries it first.
// Preconditions (CHECKed): the entry is owned by this ExtraState, the list is
// non-empty, and `_owner_loc` is the entry's current position in the list.
// std::list::splice only relinks nodes, so pointers to the entry and its
// stored `_owner_loc` iterator remain valid afterwards.
void ExtraState::move_to_front(CacheEntry* cache_entry) {
  auto& entries = this->cache_entry_list;
  CHECK(cache_entry->_owner == this);
  CHECK(!entries.empty());
  CHECK(cache_entry == &*cache_entry->_owner_loc);
  entries.splice(entries.begin(), entries, cache_entry->_owner_loc);
}
// Erases `cache_entry` from this state's cache list, destroying the entry.
// Preconditions (CHECKed): the entry is owned by this ExtraState, the list is
// non-empty, and `_owner_loc` points at the entry. The `cache_entry` pointer
// dangles once this returns.
void ExtraState::invalidate(CacheEntry* cache_entry) {
  auto& entries = this->cache_entry_list;
  CHECK(cache_entry->_owner == this);
  CHECK(!entries.empty());
  CHECK(cache_entry == &*cache_entry->_owner_loc);
  const auto position = cache_entry->_owner_loc;
  entries.erase(position);
}
// Returns the first cache entry of `extra_state`, or nullptr when there is no
// real state: either the slot is empty (nullptr) or it holds the SKIP_CODE
// sentinel.
CacheEntry* extract_cache_entry(ExtraState* extra_state) {
  const bool has_state = extra_state != nullptr && extra_state != SKIP_CODE;
  return has_state ? extra_state->get_first_entry() : nullptr;
}
// Returns the raw object held in `extra_state->frame_state`, viewed as a
// FrameState*. Returns nullptr when the slot is empty or holds SKIP_CODE.
// The returned pointer is borrowed; no reference is taken.
FrameState* extract_frame_state(ExtraState* extra_state) {
  if (extra_state != nullptr && extra_state != SKIP_CODE) {
    return reinterpret_cast<FrameState*>(extra_state->frame_state.ptr());
  }
  return nullptr;
}
// Reads the ExtraState stored in `code`'s extra slot `extra_index`.
// May return nullptr (nothing stored), the SKIP_CODE sentinel, or a real
// ExtraState*. The status returned by _PyCode_GetExtra is intentionally not
// inspected here; the out-value starts as nullptr.
ExtraState* get_extra_state(PyCodeObject* code) {
  void* raw_extra = nullptr;
  _PyCode_GetExtra(reinterpret_cast<PyObject*>(code), extra_index, &raw_extra);
  return static_cast<ExtraState*>(raw_extra);
}
void destroy_extra_state(void* obj) {
ExtraState* extra = (ExtraState*)obj;
if (extra != nullptr && extra != SKIP_CODE) {
delete extra;
}
}
// Stores `extra_state` (which may be nullptr or SKIP_CODE) into `code`'s extra
// slot `extra_index`.
// The CHECK rejects re-storing the same non-sentinel pointer that is already
// there. An empty (nullptr) or SKIP_CODE slot may be overwritten with anything.
// NOTE: the previous value is NOT freed here; releasing it is the caller's
// responsibility.
void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
  ExtraState* const previous = get_extra_state(code);
  const bool previous_is_placeholder =
      previous == nullptr || previous == SKIP_CODE;
  CHECK(previous_is_placeholder || previous != extra_state);
  _PyCode_SetExtra(reinterpret_cast<PyObject*>(code), extra_index, extra_state);
}
// Allocates a fresh ExtraState, stores it in `code`'s extra slot, and returns
// it. The pointer is owned by the slot from then on; see destroy_extra_state
// for how it is released.
ExtraState* init_and_set_extra_state(PyCodeObject* code) {
  // Invariant: the slot must still be empty (nullptr) — not SKIP_CODE and not
  // an existing ExtraState.
  CHECK(get_extra_state(code) == nullptr);
  auto* fresh_state = new ExtraState();
  NULL_CHECK(fresh_state);
  set_extra_state(code, fresh_state);
  return fresh_state;
}
// Scans `extra_state`'s cache entries in list order. Returns the compiled code
// of the first entry whose backend is compatible and whose guards accept the
// current frame.
//
// Parameters:
//   extra_state - per-code-object state holding the cache entry list.
//   f_locals    - the frame's locals; passed to the C++ root guard manager,
//                 or wrapped in a handle for the Python check_fn.
//   backend     - backend requested by the caller. Py_False means "run only"
//                 mode, in which every entry's backend is accepted.
//
// Returns (all borrowed, no new reference):
//   - the matching entry's `code` object. That entry is also moved to the
//     front of the list so the next lookup tries it first.
//   - Py_None when no entry matches.
//   - nullptr with the Python error indicator set when a guard raised.
//
// This function is called from C. Python errors raised by guards are
// therefore reported through the error indicator, not propagated as C++
// exceptions.
PyObject* lookup(
    ExtraState* extra_state,
    PyObject* f_locals,
    const PyObject* backend) {
  size_t index = 0;
  CacheEntry* found = nullptr;
  py::handle locals(f_locals);
  for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
    // Check backend. Py_False means run only mode.
    bool valid = backend == Py_False || cache_entry.backend == backend;
    if (valid) {
      try {
        // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is
        // True by default
        if (cache_entry.root_mgr != nullptr) {
          valid = torch::dynamo::run_root_guard_manager(
              cache_entry.root_mgr, f_locals);
        } else {
          valid = cache_entry.check_fn(locals).cast<bool>();
        }
      } catch (py::error_already_set& e) {
        if (guard_error_hook) {
          // The hook is arbitrary Python code and may raise too. If that
          // error escaped as a C++ exception, it would unwind into our C
          // caller and skip e.restore() below, losing the guard error.
          // Instead, report the hook's failure via sys.unraisablehook and
          // keep the original guard error as the one surfaced.
          try {
            py::handle guard_error_hook_handle(guard_error_hook);
            guard_error_hook_handle(
                cache_entry.check_fn,
                cache_entry.code,
                locals,
                index,
                index == extra_state->cache_entry_list.size() - 1);
          } catch (py::error_already_set& hook_error) {
            hook_error.discard_as_unraisable("guard_error_hook");
          }
        }
        // this function is called from C, so we cannot repropagate
        // the exception
        e.restore();
        return nullptr;
      }
    }
    if (valid) {
      found = &cache_entry;
      break;
    }
    ++index;
  }
  if (found) {
    // std::list::splice keeps `found` valid while relinking it to the front.
    extra_state->move_to_front(found);
    return found->code.ptr();
  }
  return py::none().ptr();
}
// Builds a new CacheEntry from (guarded_code, backend) at the head of
// `extra_state`'s cache list and records its owner and list position.
// The guarded code's `check_fn` then receives non-owning Python references
// to the new entry (`cache_entry`) and to `extra_state` (`extra_state`).
// Returns a pointer to the new entry.
CacheEntry* create_cache_entry(
    ExtraState* extra_state,
    PyObject* guarded_code,
    PyObject* backend) {
  auto& entries = extra_state->cache_entry_list;
  entries.emplace_front(guarded_code, backend);
  const auto head = entries.begin();
  CacheEntry& entry = *head;
  entry._owner = extra_state;
  entry._owner_loc = head;

  // Set check_fn references to extra_state and CacheEntry.
  // Warning: lifetime is controlled by C++! These are reference-policy casts,
  // so Python does not keep the entry or the state alive.
  py::object check_fn = py::handle(guarded_code).attr("check_fn");
  check_fn.attr("cache_entry") =
      py::cast(entry, py::return_value_policy::reference);
  check_fn.attr("extra_state") =
      py::cast(extra_state, py::return_value_policy::reference);
  return &entry;
}
// Debug helper: returns a Python list of the cache entries attached to a code
// object, in list order (head first). Elements are reference-policy casts, so
// they do not own the entries.
// Throws TypeError if `code_obj` is not a types.CodeType instance. Returns an
// empty list when the code object has no ExtraState or is marked SKIP_CODE.
py::list _debug_get_cache_entry_list(const py::handle& code_obj) {
  py::object code_type = py::module::import("types").attr("CodeType");
  if (!py::isinstance(code_obj, code_type)) {
    throw py::type_error("expected a code object!");
  }
  ExtraState* extra =
      get_extra_state(reinterpret_cast<PyCodeObject*>(code_obj.ptr()));
  py::list result;
  if (extra == nullptr || extra == SKIP_CODE) {
    return result;
  }
  for (CacheEntry& entry : extra->cache_entry_list) {
    result.append(py::cast(entry, py::return_value_policy::reference));
  }
  return result;
}