@@ -58,7 +58,6 @@
     Iterator,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Union,
@@ -128,7 +127,7 @@ class WrappedFunction:
     """

     model: Callable[..., Any]
-    static_input_idxs: Sequence[int]
+    static_input_idxs: List[int]
     id: FunctionID
     constants: Tuple[torch.Tensor, ...]

@@ -787,6 +786,16 @@ def __init__(
             set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
         )

+        self.non_static_input_idx: LevelList[int] = [
+            i for i in range(len(inputs)) if i not in self.static_input_idxs
+        ]
+
+        self.non_managed_static_input_idxs: LevelList[int] = [
+            i
+            for i in wrapped_function.static_input_idxs
+            if i not in self.cudagraph_managed_idxs
+        ]
+
         self.static_input_data_ptrs: InputList[Optional[int]] = [
             (
                 inputs[i].data_ptr()
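The two index lists added above are computed once, at construction time, so that run() further down can walk precomputed lists instead of re-testing set membership for every input on every replay. A minimal standalone sketch of that intent (the index values are illustrative, not taken from the diff):

# Illustrative values only; in the real code these come from the compiled graph.
static_input_idxs = [0, 2]
cudagraph_managed_idxs = [2]
num_inputs = 4

# Computed once, up front ...
non_static_input_idx = [i for i in range(num_inputs) if i not in static_input_idxs]
non_managed_static_input_idxs = [
    i for i in static_input_idxs if i not in cudagraph_managed_idxs
]

# ... so each replay iterates a ready-made list rather than filtering again.
assert non_static_input_idx == [1, 3]
assert non_managed_static_input_idxs == [0]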
@@ -924,6 +933,23 @@ def _copy_input(self, idx, dst, src):
         # TODO - one jit kernel across multiple inputs
         dst.copy_(src)

+    def check_static_inputs_are_stable(self, new_inputs):
+        # avoid checking managed tensor static points since we already checked those in check_invariants
+        if not torch._C._tensors_data_ptrs_at_indices_equal(
+            new_inputs, self.static_input_data_ptrs, self.non_managed_static_input_idxs
+        ):
+            # this should error
+            static_tensors = [new_inputs[i] for i in self.non_managed_static_input_idxs]
+            data_ptrs = [
+                self.static_input_data_ptrs[i]
+                for i in self.non_managed_static_input_idxs
+            ]
+            for t, data_ptr in zip(static_tensors, data_ptrs):
+                torch._check(
+                    t.data_ptr() == data_ptr,
+                    lambda: f"static input data pointer changed from {data_ptr} to {t.data_ptr()}",
+                )
+
     def run_first_inputs(self, new_inputs):
         if config.triton.fast_path_cudagraph_asserts:
             self.debug_check_invariants_before_invocation()
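check_static_inputs_are_stable leans on the private helper torch._C._tensors_data_ptrs_at_indices_equal; the check_invariants hunk below spells out its Python equivalent. A pure-Python reference under the same calling convention as above (lists indexed by position):

from typing import List, Optional

import torch

def tensors_data_ptrs_at_indices_equal(
    tensors: List[torch.Tensor],
    data_ptrs: List[Optional[int]],
    indices: List[int],
) -> bool:
    # True iff every indexed tensor still lives at its recorded address.
    return all(tensors[i].data_ptr() == data_ptrs[i] for i in indices)

Note that torch._check takes the failure message as a lambda, so the f-string above is only built when a pointer comparison actually fails.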
@@ -936,30 +962,23 @@ def run_first_inputs(self, new_inputs):
         return outputs

     def run(self, new_inputs):
-        if config.triton.fast_path_cudagraph_asserts:
-            self.debug_check_invariants_before_invocation()
+        self.check_static_inputs_are_stable(new_inputs)

-        assert len(self.static_input_data_ptrs) == len(new_inputs)
-        # NB: this ranges over non-static inputs too
-        for idx, data_ptr in enumerate(self.static_input_data_ptrs):
-            if idx in self.cudagraph_managed_idxs:
-                continue
+        for idx in self.non_static_input_idx:
             if not isinstance(new_inputs[idx], torch.Tensor):
-                pass
-            elif data_ptr is not None:
-                # static input, e.g., parameter
-                assert data_ptr == new_inputs[idx].data_ptr()
-            else:
-                # non-static input, need to copy it into CUDA graph
-                dst = self.reconstructed_inputs[idx]
-                src = new_inputs[idx]
-                self._copy_input(idx, dst, src)
+                continue
+
+            # non-static input, need to copy it into CUDA graph
+            self._copy_input(idx, self.reconstructed_inputs[idx], new_inputs[idx])

         new_inputs.clear()
+
         self.run_graph()

         outputs = self.reconstruct_outputs()
-        self.debug_check_invariants_after_invocation()
+
+        if config.triton.fast_path_cudagraph_asserts:
+            self.debug_check_invariants_after_invocation()

         return outputs

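The copy-then-replay shape of run() follows the standard CUDA graphs contract: a captured graph always reads from the addresses recorded at capture time, so fresh data must be copied into those fixed buffers (reconstructed_inputs here) before replay. A minimal sketch adapted from the documented torch.cuda.CUDAGraph pattern, assuming a CUDA device is available:

import torch

static_input = torch.empty(5, device="cuda")

# Warm up on a side stream so lazy initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input * 2

# Analogous to _copy_input: refresh the placeholder buffer, then replay.
static_input.copy_(torch.full((5,), 3.0, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(static_output)  # tensor of 6s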
@@ -1513,9 +1532,12 @@ def check_invariants(self, inputs: List[Tensor]) -> bool:
         """

         # previously managed data pointers remain stable
-        for idx in self.cudagraph_managed_idxs:
-            if inputs[idx].data_ptr() != self.static_input_data_ptrs[idx]:
-                return False
+        # this is on the hot path so moved to C++. equivalent to:
+        # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs))
+        if not torch._C._tensors_data_ptrs_at_indices_equal(
+            inputs, self.static_input_data_ptrs, self.cudagraph_managed_idxs
+        ):
+            return False

         if not self._check_liveness(
             self.expected_dead_indices_before_graph, self.path_weakrefs
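A hypothetical micro-benchmark of the hot-path motivation: each tensors[i].data_ptr() call in the old Python loop crosses the Python/C++ boundary, while the batched helper does every comparison in one native call. The setup below is illustrative only; absolute numbers vary by machine and nothing here comes from the diff itself:

import timeit

import torch

tensors = [torch.empty(1) for _ in range(64)]
ptrs = [t.data_ptr() for t in tensors]
idxs = list(range(64))

def py_loop():
    return all(tensors[i].data_ptr() == ptrs[i] for i in idxs)

def batched():
    return torch._C._tensors_data_ptrs_at_indices_equal(tensors, ptrs, idxs)

print("python loop:", timeit.timeit(py_loop, number=10_000))
print("batched C++:", timeit.timeit(batched, number=10_000))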
@@ -1931,7 +1953,7 @@ def add_function(
         self.ids_to_stack_traces[id] = stack_traces
         self.ids_to_funcs[id] = WrappedFunction(
             model,
-            static_input_idxs,
+            list(static_input_idxs),
             id,
             tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda),
         )
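Together with dropping the Sequence import and retyping WrappedFunction.static_input_idxs as List[int] at the top of the diff, this list(static_input_idxs) call normalizes whatever Sequence[int] callers pass, so the stored field always matches its new, narrower annotation.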