 import torch
-from torch._dynamo.device_interface import DeviceGuard, get_interface_for_device
 from .coordinate_descent_tuner import CoordescTuner
 from .hints import (
     _NUM_THREADS_PER_WARP,
     AutotuneHint,
+    DeviceProperties,
     HeuristicType,
     ReductionHint,
     TileHint,
@@ -144,21 +144,19 @@ def __init__(
         assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
         self.fn = fn
-        self.triton_meta = triton_meta
+        self.device_props: DeviceProperties = triton_meta["device"]
+        self.triton_meta = {
+            **triton_meta,
+            "device": self.device_props.index,
+            "device_type": self.device_props.type,
+        }
         self.inductor_meta = {} if inductor_meta is None else inductor_meta
         self.save_cache_hook = save_cache_hook
         self.mutated_arg_names = mutated_arg_names
         self.configs = configs
         self.heuristic_type = heuristic_type
         self.custom_kernel = custom_kernel
         self.cuda_kernel_saved = False
-
-        # Align the default design that default as cuda
-        self.device_type = (
-            triton_meta["device_type"] if "device_type" in triton_meta else "cuda"
-        )
-        self.device_interface = get_interface_for_device(self.device_type)
-
         if log.isEnabledFor(logging.DEBUG):
             log.debug(
                 "CachingAutotuner gets %d configs for %s",
@@ -186,7 +184,7 @@ def __init__(
         )
         self.filename = filename

-    def precompile(self, warm_cache_only_with_cc=None):
+    def precompile(self, warm_cache_only=False):
         with self.lock:
             if self.launchers:
                 return
@@ -197,7 +195,7 @@ def precompile(self, warm_cache_only_with_cc=None):
             for c in self.configs:
                 try:
                     compiled_binary, launcher = self._precompile_config(
-                        c, warm_cache_only_with_cc
+                        c, warm_cache_only
                    )
                 except OutOfResources as e:
                     if len(self.configs) == 1:
@@ -215,19 +213,19 @@ def precompile(self, warm_cache_only_with_cc=None):
             seen_configs = set(self.configs)

-            device_prop = self.device_interface.Worker.get_device_properties(
-                self.triton_meta["device"]
-            )
+            device_prop = self.device_props
             if (
                 self.inductor_meta.get("dynamic_scale_rblock", True)
                 and self.heuristic_type == HeuristicType.REDUCTION
                 and self.size_hints is not None
-                # Disable for AMDGPU as Triton is not ready to return n_regs for a compiled_binary.
-                and not self.inductor_meta.get("is_hip")
-                # Disable for Intel GPU as Triton is not ready to return n_regs for a compiled_binary.
-                and self.device_type != "xpu"
+                # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
+                and device_prop.type == "cuda"
+                and device_prop.major
                 and device_prop.major >= 8
             ):
+                assert device_prop.regs_per_multiprocessor
+                assert device_prop.max_threads_per_multi_processor
+                assert device_prop.multi_processor_count
                 for triton_config, compiled_binary in zip(
                     self.configs, compiled_binaries
                 ):
@@ -288,15 +286,21 @@ def precompile(self, warm_cache_only_with_cc=None):
                         continue
                     seen_configs.add(new_config)
                     self.launchers.append(
-                        self._precompile_config(new_config, warm_cache_only_with_cc)[1]
+                        self._precompile_config(new_config, warm_cache_only)[1]
                    )
             self.configs = None

-    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]):
+    def get_device_interface(self):
+        # this code cannot run in compile workers, because it imports from torch
+        from torch._dynamo.device_interface import get_interface_for_device
+
+        return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
+
+    def _precompile_config(self, cfg: Config, warm_cache_only: bool):
         """Ahead of time compile a given autotuner config."""
         compile_meta = copy.deepcopy(self.triton_meta)
         for k, v in cfg.kwargs.items():
-            if torch.version.hip is not None:
+            if self.device_props.type == "hip":
                 if k == "matrix_instr_nonkdim":
                     compile_meta["matrix_instr_nonkdim"] = v
                     continue
@@ -310,22 +314,9 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
             "assert_indirect_indexing", True
         ) and not self.inductor_meta.get("is_hip", False)

-        # Setting device_type="hip" required on ROCm to pass down to triton
-        compile_meta["device_type"] = (
-            self.device_type if torch.version.hip is None else "hip"
-        )
-
-        if warm_cache_only_with_cc:
-            cc = warm_cache_only_with_cc
-        else:
-            # Use device_type 'cuda' for both cuda and hip devices to retrieve
-            # the compute capability.
-            device_type = self.device_type if torch.version.hip is None else "cuda"
-            device_id = compile_meta["device"]
-            device = torch.device(device_type, device_id)
-            cc = self.device_interface.get_compute_capability(device)
-
-        compile_meta["cc"] = cc
+        # device type will be "hip" rather than "cuda" here
+        compile_meta["device_type"] = self.device_props.type
+        compile_meta["cc"] = self.device_props.cc

         if ASTSource:
             compile_args = (
@@ -354,7 +345,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
                 "num_stages": compile_meta["num_stages"],
                 "debug": compile_meta["debug"],
             }
-            if torch.version.hip is not None:
+            if self.device_props.type == "hip":
                 if "waves_per_eu" in compile_meta:
                     options["waves_per_eu"] = compile_meta["waves_per_eu"]
                 if "matrix_instr_nonkdim" in compile_meta:
@@ -369,16 +360,21 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
             compile_args = (self.fn,)
             compile_kwargs = compile_meta

-        if warm_cache_only_with_cc:
+        if warm_cache_only:
             return (
                 triton.compile(*compile_args, **compile_kwargs),
                 None,
             )

+        # importing from torch is safe now that precompile has returned
+        from torch._dynamo.device_interface import DeviceGuard
+
+        device_interface = self.get_device_interface()
+
         # load binary to the correct device
-        with DeviceGuard(self.device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
+        with DeviceGuard(device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
             # need to initialize context
-            self.device_interface.synchronize(self.device_interface.current_device())
+            device_interface.synchronize(device_interface.current_device())

             try:
                 binary = triton.compile(*compile_args, **compile_kwargs)
@@ -596,8 +592,9 @@ def bench(self, launcher, *args, grid, **kwargs):
             )
             return float("inf")

-        stream = self.device_interface.get_raw_stream(  # type: ignore[call-arg]
-            self.device_interface.current_device()
+        device_interface = self.get_device_interface()
+        stream = device_interface.get_raw_stream(  # type: ignore[call-arg]
+            device_interface.current_device()
         )

         def kernel_call():
@@ -706,7 +703,7 @@ def save_cuda_kernel(self, grid, stream, launcher):
         binary = (
             launcher.bin.asm["cubin"]
-            if torch.version.hip is None
+            if self.device_props.type != "hip"
             else launcher.bin.asm["hsaco"]
         )
         CudaKernelParamCache.set(key, params, binary)
@@ -736,7 +733,7 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs):
         def benchmark_one_config(config):
             with self.lock:
-                _, launcher = self._precompile_config(config, None)
+                _, launcher = self._precompile_config(config, False)
             config2launcher[config] = launcher

             out = self.bench(launcher, *cloned_args, **kwargs)