Commit c20cf97

mlazos authored and pytorchmergebot committed
Move some cudagraphs checks into C++ (pytorch#122251)
Based off of pytorch#111094. This, combined with the C++ guards, improves TIMM geomean optimizer performance by about 20%.

Pull Request resolved: pytorch#122251
Approved by: https://github.com/eellison
1 parent be5863d commit c20cf97
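For orientation, here is a pure-Python sketch of the pointer check this commit moves into C++. It is an illustrative equivalent based on the comment in the cudagraph_trees.py diff below ("return all(t.data_ptr() == data_ptr ...)"), not the committed implementation itself:

from typing import List, Optional

import torch

def tensors_data_ptrs_at_indices_equal(
    tensors: List[torch.Tensor],
    data_ptrs: List[Optional[int]],
    indices: List[int],
) -> bool:
    # Illustrative Python equivalent of the new C++ helper
    # torch._C._tensors_data_ptrs_at_indices_equal: compare data_ptr()
    # only at the requested indices, short-circuiting on first mismatch.
    return all(tensors[i].data_ptr() == data_ptrs[i] for i in indices)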

File tree (4 files changed, +77 -23 lines):

test/inductor/test_cudagraph_trees.py
torch/_C/__init__.pyi.in
torch/_inductor/cudagraph_trees.py
torch/csrc/cuda/Module.cpp

test/inductor/test_cudagraph_trees.py (17 additions, 0 deletions)

@@ -1230,6 +1230,23 @@ def foo(mod, inp):
         node = self.get_manager().current_node
         self.assertEqual(len(list(node.path_live_weakrefs())), 1)
 
+    def test_unstable_ptr(self):
+        import torch
+
+        @torch.compile(mode="reduce-overhead")
+        def foo(m, inp):
+            return m(inp)
+
+        def f():
+            l = []
+            m = torch.nn.Linear(20, 20).cuda()
+            for _ in range(4):
+                inp = torch.rand([20, 20], device="cuda")
+                foo(m, inp)
+                m.weight.data = torch.rand([20, 20], device="cuda")
+
+        self.assertRaises(RuntimeError, f)
+
     @requires_multigpu()
     def test_manager_per_device(self):
         def test():
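The test works because reassigning a parameter's .data rebinds it to a fresh tensor, so the parameter's data pointer no longer matches the address recorded for the static input. A minimal illustration (assuming only that the old tensor is still alive when the new one is allocated, so the two addresses must differ):

import torch

m = torch.nn.Linear(20, 20)
old_ptr = m.weight.data_ptr()
m.weight.data = torch.rand([20, 20])  # rebind .data to a new storage
assert m.weight.data_ptr() != old_ptr  # the "static" address moved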

torch/_C/__init__.pyi.in (1 addition, 0 deletions)

@@ -1695,6 +1695,7 @@ def _cuda_getCheckpointState(device: _int, mempool: Tuple[_int, _int]) -> _cuda_
 def _set_cached_tensors_enabled(enabled: _bool) -> None: ...
 def _add_cached_tensor(t: Tensor) -> None: ...
 def _remove_cached_tensor(t: Tensor) -> None: ...
+def _tensors_data_ptrs_at_indices_equal(tensors: List[Tensor], ptrs: List[Optional[_int]], indices: List[_int]) -> _bool: ...
 def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ...
 def _storage_Use_Count(storage_ptr: _int) -> _int: ...
 def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
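A small usage sketch of the new binding, assuming a CUDA build of PyTorch (the function is registered from the CUDA module, though the pointer comparison itself does not require GPU tensors):

import torch

a, b = torch.rand(4), torch.rand(4)
ptrs = [a.data_ptr(), b.data_ptr()]

# All listed indices match the recorded pointers.
assert torch._C._tensors_data_ptrs_at_indices_equal([a, b], ptrs, [0, 1])

# Swap in a fresh tensor at position 1: checking only index 0 still passes,
# while including index 1 detects the changed data pointer.
c = torch.rand(4)
assert torch._C._tensors_data_ptrs_at_indices_equal([a, c], ptrs, [0])
assert not torch._C._tensors_data_ptrs_at_indices_equal([a, c], ptrs, [0, 1])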

torch/_inductor/cudagraph_trees.py (45 additions, 23 deletions)

@@ -58,7 +58,6 @@
     Iterator,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Union,

@@ -128,7 +127,7 @@ class WrappedFunction:
     """
 
     model: Callable[..., Any]
-    static_input_idxs: Sequence[int]
+    static_input_idxs: List[int]
     id: FunctionID
     constants: Tuple[torch.Tensor, ...]

@@ -787,6 +786,16 @@ def __init__(
             set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
         )
 
+        self.non_static_input_idx: LevelList[int] = [
+            i for i in range(len(inputs)) if i not in self.static_input_idxs
+        ]
+
+        self.non_managed_static_input_idxs: LevelList[int] = [
+            i
+            for i in wrapped_function.static_input_idxs
+            if i not in self.cudagraph_managed_idxs
+        ]
+
         self.static_input_data_ptrs: InputList[Optional[int]] = [
             (
                 inputs[i].data_ptr()

@@ -924,6 +933,23 @@ def _copy_input(self, idx, dst, src):
         # TODO - one jit kernel across multiple inputs
         dst.copy_(src)
 
+    def check_static_inputs_are_stable(self, new_inputs):
+        # avoid checking managed tensor static points since we already checked those in check_invariants
+        if not torch._C._tensors_data_ptrs_at_indices_equal(
+            new_inputs, self.static_input_data_ptrs, self.non_managed_static_input_idxs
+        ):
+            # this should error
+            static_tensors = [new_inputs[i] for i in self.non_managed_static_input_idxs]
+            data_ptrs = [
+                self.static_input_data_ptrs[i]
+                for i in self.non_managed_static_input_idxs
+            ]
+            for t, data_ptr in zip(static_tensors, data_ptrs):
+                torch._check(
+                    t.data_ptr() == data_ptr,
+                    lambda: f"static input data pointer changed from {data_ptr} to {t.data_ptr()}",
+                )
+
     def run_first_inputs(self, new_inputs):
         if config.triton.fast_path_cudagraph_asserts:
             self.debug_check_invariants_before_invocation()

@@ -936,30 +962,23 @@ def run_first_inputs(self, new_inputs):
         return outputs
 
     def run(self, new_inputs):
-        if config.triton.fast_path_cudagraph_asserts:
-            self.debug_check_invariants_before_invocation()
+        self.check_static_inputs_are_stable(new_inputs)
 
-        assert len(self.static_input_data_ptrs) == len(new_inputs)
-        # NB: this ranges over non-static inputs too
-        for idx, data_ptr in enumerate(self.static_input_data_ptrs):
-            if idx in self.cudagraph_managed_idxs:
-                continue
+        for idx in self.non_static_input_idx:
             if not isinstance(new_inputs[idx], torch.Tensor):
-                pass
-            elif data_ptr is not None:
-                # static input, e.g., parameter
-                assert data_ptr == new_inputs[idx].data_ptr()
-            else:
-                # non-static input, need to copy it into CUDA graph
-                dst = self.reconstructed_inputs[idx]
-                src = new_inputs[idx]
-                self._copy_input(idx, dst, src)
+                continue
+
+            # non-static input, need to copy it into CUDA graph
+            self._copy_input(idx, self.reconstructed_inputs[idx], new_inputs[idx])
 
         new_inputs.clear()
+
         self.run_graph()
 
         outputs = self.reconstruct_outputs()
-        self.debug_check_invariants_after_invocation()
+
+        if config.triton.fast_path_cudagraph_asserts:
+            self.debug_check_invariants_after_invocation()
 
         return outputs

@@ -1513,9 +1532,12 @@ def check_invariants(self, inputs: List[Tensor]) -> bool:
         """
 
         # previously managed data pointers remain stable
-        for idx in self.cudagraph_managed_idxs:
-            if inputs[idx].data_ptr() != self.static_input_data_ptrs[idx]:
-                return False
+        # this is on the hot path so moved to C++. equivalent to:
+        # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs))
+        if not torch._C._tensors_data_ptrs_at_indices_equal(
+            inputs, self.static_input_data_ptrs, self.cudagraph_managed_idxs
+        ):
+            return False
 
         if not self._check_liveness(
             self.expected_dead_indices_before_graph, self.path_weakrefs

@@ -1931,7 +1953,7 @@ def add_function(
         self.ids_to_stack_traces[id] = stack_traces
         self.ids_to_funcs[id] = WrappedFunction(
             model,
-            static_input_idxs,
+            list(static_input_idxs),
             id,
             tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda),
         )
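Note the two-pass structure of check_static_inputs_are_stable above: a single cheap C++ call on the hot path, with the per-tensor Python loop run only after a mismatch is known, purely to produce a readable error. A generic sketch of that pattern (the function name here is illustrative, not part of the commit):

import torch

def check_fast_then_explain(tensors, expected_ptrs, indices):
    # Hot path: one C++ call, no per-element Python overhead.
    if torch._C._tensors_data_ptrs_at_indices_equal(tensors, expected_ptrs, indices):
        return
    # Cold path: a mismatch exists, so spend time building a precise message.
    for i in indices:
        torch._check(
            tensors[i].data_ptr() == expected_ptrs[i],
            lambda: f"static input data pointer changed from {expected_ptrs[i]} "
            f"to {tensors[i].data_ptr()}",
        )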

torch/csrc/cuda/Module.cpp (14 additions, 0 deletions)

@@ -1191,6 +1191,20 @@ static void registerCudaPluggableAllocator(PyObject* module) {
         return c10::raw::weak_intrusive_ptr::use_count(storage_impl);
       });
 
+  m.def(
+      "_tensors_data_ptrs_at_indices_equal",
+      [](py::list& tensors, py::list& data_ptrs, py::list& indices) {
+        for (size_t i = 0, end = indices.size(); i < end; ++i) {
+          auto index = indices[i].cast<int64_t>();
+          auto t = tensors[index].cast<at::Tensor>();
+          auto data_ptr = data_ptrs[index].cast<int64_t>();
+          if (reinterpret_cast<int64_t>(t.data_ptr()) != data_ptr) {
+            return false;
+          }
+        }
+        return true;
+      });
+
   m.def(
       "_construct_CUDA_Tensor_From_Storage_And_Metadata",
       [](py::dict& metadata, c10::Storage s) {
