
Commit 254128c

masnesral authored and pytorchmergebot committed
[inductor] Remove usage of device_interface from _inductor.runtime (pytorch#124592)
Differential Revision: [D56723770](https://our.internmc.facebook.com/intern/diff/D56723770)
Co-authored-by: Sam Larsen <slarsen@meta.com>
Pull Request resolved: pytorch#124592
Approved by: https://github.com/masnesral
1 parent 5f4c6d9 commit 254128c
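
In short: Inductor's runtime previously consulted a device interface (which imports torch) to interpret a bare device index stored in triton_meta. After this change, codegen captures what the runtime needs in a pickleable DeviceProperties record under the single "device" key, so _inductor.runtime no longer depends on torch._dynamo.device_interface at import time. A minimal before/after sketch of the metadata shape; the names triton_meta_old / triton_meta_new are illustrative, not from the codebase:

import torch
from torch._inductor.runtime.hints import DeviceProperties

# Before: a bare index plus a separate "device_type" key; the runtime had to
# ask the device interface for anything else about the device.
triton_meta_old = {
    "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
    "device": 0,
    "device_type": "cuda",
}

# After: one self-describing record created on the main process.
triton_meta_new = {
    "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
    "device": DeviceProperties.create(torch.device("cuda")),
}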

File tree

8 files changed: +99 −78 lines


test/inductor/test_cuda_repro.py

Lines changed: 2 additions & 1 deletion
@@ -14,6 +14,7 @@
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx_inner
+from torch._inductor.runtime.hints import DeviceProperties
 from torch._inductor.utils import run_and_get_code
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
@@ -405,7 +406,7 @@ def decorator(fn):
             ],
             meta={
                 "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
-                "device": 0,
+                "device": DeviceProperties.create(torch.device("cuda")),
                 "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
                 "constants": {},
             },

torch/_inductor/codecache.py

Lines changed: 8 additions & 22 deletions
@@ -45,16 +45,12 @@
     Optional,
     Set,
     Tuple,
-    Type,
     TYPE_CHECKING,
     Union,
 )

 import torch
-from torch._dynamo.device_interface import (
-    get_interface_for_device,
-    get_registered_device_interfaces,
-)
+from torch._dynamo.device_interface import get_registered_device_interfaces
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor import config, exc, metrics
 from torch._inductor.codegen.cuda import cuda_env
@@ -70,7 +66,6 @@
 from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv

 if TYPE_CHECKING:
-    from torch._dynamo.device_interface import DeviceInterface
     from torch._inductor.graph import GraphLowering
     from torch._inductor.ir import ChoiceCaller

@@ -2823,14 +2818,9 @@ def _set_triton_ptxas_path() -> None:

 def _worker_compile_triton(
     load_kernel: Callable[[], Any],
-    cc: int,
-    device: torch.device,
-    device_interface: Type[DeviceInterface],
 ):
     _set_triton_ptxas_path()
-    device_interface.Worker.set_device(device.index)
-    kernel = load_kernel()
-    kernel.precompile(warm_cache_only_with_cc=cc)
+    load_kernel().precompile(warm_cache_only=True)


 class CodeCacheFuture:
@@ -2993,17 +2983,13 @@ def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):

         kernel = TritonCodeCache.load(kernel_name, source_code)
         if config.compile_threads > 1:
-            device_interface = get_interface_for_device(device_str)
-            device = torch.device(device_str, device_interface.current_device())
-            cc = device_interface.get_compute_capability(device)
-            future = self.process_pool().submit(
-                _worker_compile_triton,
-                kernel._reload_in_subproc,
-                cc,
-                device,
-                device_interface,
+            return TritonFuture(
+                kernel,
+                self.process_pool().submit(
+                    _worker_compile_triton,
+                    kernel._reload_in_subproc,
+                ),
             )
-            return TritonFuture(kernel, future)
         else:
             kernel.precompile()
             return kernel
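
The worker no longer needs a device, compute capability, or device interface because warm_cache_only=True compiles ahead of time without loading the binary onto a device; device selection happens later on the main process. A hedged sketch of the resulting submit path, with kernel and pool as placeholders for the loaded TritonCodeCache kernel and the compile process pool (warm_compile_async itself is illustrative, not a function in the codebase):

def warm_compile_async(kernel, pool):
    # Only a picklable loader callable crosses the process boundary; the
    # subprocess never touches CUDA/HIP state.
    future = pool.submit(
        _worker_compile_triton,      # new signature: just the loader
        kernel._reload_in_subproc,
    )
    return TritonFuture(kernel, future)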

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 4 deletions
@@ -34,7 +34,7 @@
 from torch._dynamo.utils import preserve_rng_state

 from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata
-from torch._inductor.runtime.hints import AutotuneHint
+from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
 from torch._prims_common import is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import ValueRanges
@@ -125,7 +125,7 @@ def gen_common_triton_imports():
         """
         from torch._inductor.runtime import triton_helpers, triton_heuristics
         from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
-        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
+        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
         """
     )
     return imports.getvalue()
@@ -2833,8 +2833,7 @@ def codegen_kernel(self, name=None):
         )
         triton_meta = {
             "signature": triton_meta_signature,
-            "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
+            "device": DeviceProperties.create(V.graph.scheduler.current_device),
             "constants": {},
         }

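Since gen_common_triton_imports now adds DeviceProperties to the generated kernel's import block, the triton_meta that codegen_kernel embeds can spell out the device record directly in the emitted source. A rough, abridged sketch of what a generated kernel header might contain after this change; the field values in the comment are made up for illustration:

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import (
    AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
)

# The embedded metadata can now round-trip without importing torch, e.g.:
# triton_meta = {
#     "signature": {...},
#     "device": DeviceProperties(type="cuda", index=0, cc=80, major=8,
#                                regs_per_multiprocessor=65536,
#                                max_threads_per_multi_processor=2048,
#                                multi_processor_count=108),
#     "constants": {},
# }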

torch/_inductor/codegen/triton_foreach.py

Lines changed: 2 additions & 2 deletions
@@ -6,6 +6,7 @@
 from sympy import Integer

 from .. import metrics
+from ..runtime.hints import DeviceProperties
 from ..scheduler import SchedulerNode
 from ..utils import ceildiv, Placeholder
 from ..virtualized import V
@@ -157,8 +158,7 @@ def jit_lines(self):
         _, _, signature = self.args.python_argdefs()
         triton_meta = {
             "signature": signature_to_meta(signature, size_dtype=size_dtype),
-            "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
+            "device": DeviceProperties.create(V.graph.scheduler.current_device),
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]

torch/_inductor/codegen/wrapper.py

Lines changed: 2 additions & 2 deletions
@@ -40,6 +40,7 @@
 from .. import codecache, config, ir
 from ..ir import ReinterpretView
 from ..runtime import triton_heuristics
+from ..runtime.hints import DeviceProperties
 from ..utils import (
     cache_on_self,
     get_benchmark_name,
@@ -1130,8 +1131,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
                 size_dtype=index_dtype,
                 indices=non_constant_indices,
             ),
-            "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
+            "device": DeviceProperties.create(V.graph.scheduler.current_device),
             # Triton compiler includes equal_to_1 args into constants even
             # when they are not constexpr. otherwise there may be a segfault
             # during launching the Inductor-compiled Triton kernel.

torch/_inductor/runtime/hints.py

Lines changed: 38 additions & 0 deletions
@@ -1,6 +1,8 @@
 import collections
+import typing
 from dataclasses import fields
 from enum import auto, Enum
+from typing import Optional


 # NOTE: if these fail asserts submit a PR to increase them
@@ -89,3 +91,39 @@ class AutotuneHint(Enum):
     # which isn't valid python.
     # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
     __repr__ = Enum.__str__
+
+
+class DeviceProperties(typing.NamedTuple):
+    """Copy device properties into a data structure not requiring torch to be imported"""
+
+    type: str  # type: ignore[assignment]
+    index: int  # type: ignore[assignment]
+    cc: int
+    major: Optional[int] = None
+    regs_per_multiprocessor: Optional[int] = None
+    max_threads_per_multi_processor: Optional[int] = None
+    multi_processor_count: Optional[int] = None
+
+    @classmethod
+    def create(cls, device):
+        import torch
+        from torch._dynamo.device_interface import get_interface_for_device
+
+        device_type = device.type if torch.version.hip is None else "hip"
+        device_interface = get_interface_for_device(device)
+        if device_type == "cuda":
+            props = device_interface.get_device_properties(device)
+            return cls(
+                type=device_type,
+                index=device.index,
+                cc=device_interface.get_compute_capability(device),
+                major=props.major,
+                regs_per_multiprocessor=props.regs_per_multiprocessor,
+                max_threads_per_multi_processor=props.max_threads_per_multi_processor,
+                multi_processor_count=props.multi_processor_count,
+            )
+        return cls(
+            type=device_type,
+            index=device.index,
+            cc=device_interface.get_compute_capability(device),
+        )
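
DeviceProperties is a plain NamedTuple, so once it is created on the main process (where torch and the device interface are importable) it can be pickled into compile workers and serialized into generated code. A small usage sketch, assuming a CUDA build with at least one visible GPU; the commented values are examples, not guarantees:

import torch
from torch._inductor.runtime.hints import DeviceProperties

# Create where torch is available; consumers then only see plain Python data.
props = DeviceProperties.create(torch.device("cuda", 0))

print(props.type)   # "cuda" ("hip" on ROCm builds)
print(props.index)  # 0
print(props.cc)     # compute capability as reported by the device interface
print(props.major)  # populated on the CUDA path, None otherwise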

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 42 additions & 45 deletions
@@ -16,12 +16,12 @@

 import torch

-from torch._dynamo.device_interface import DeviceGuard, get_interface_for_device
 from .coordinate_descent_tuner import CoordescTuner

 from .hints import (
     _NUM_THREADS_PER_WARP,
     AutotuneHint,
+    DeviceProperties,
     HeuristicType,
     ReductionHint,
     TileHint,
@@ -144,21 +144,19 @@ def __init__(

         assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
         self.fn = fn
-        self.triton_meta = triton_meta
+        self.device_props: DeviceProperties = triton_meta["device"]
+        self.triton_meta = {
+            **triton_meta,
+            "device": self.device_props.index,
+            "device_type": self.device_props.type,
+        }
         self.inductor_meta = {} if inductor_meta is None else inductor_meta
         self.save_cache_hook = save_cache_hook
         self.mutated_arg_names = mutated_arg_names
         self.configs = configs
         self.heuristic_type = heuristic_type
         self.custom_kernel = custom_kernel
         self.cuda_kernel_saved = False
-
-        # Align the default design that default as cuda
-        self.device_type = (
-            triton_meta["device_type"] if "device_type" in triton_meta else "cuda"
-        )
-        self.device_interface = get_interface_for_device(self.device_type)
-
         if log.isEnabledFor(logging.DEBUG):
             log.debug(
                 "CachingAutotuner gets %d configs for %s",
@@ -186,7 +184,7 @@ def __init__(
         )
         self.filename = filename

-    def precompile(self, warm_cache_only_with_cc=None):
+    def precompile(self, warm_cache_only=False):
         with self.lock:
             if self.launchers:
                 return
@@ -197,7 +195,7 @@ def precompile(self, warm_cache_only_with_cc=None):
             for c in self.configs:
                 try:
                     compiled_binary, launcher = self._precompile_config(
-                        c, warm_cache_only_with_cc
+                        c, warm_cache_only
                     )
                 except OutOfResources as e:
                     if len(self.configs) == 1:
@@ -215,19 +213,19 @@

             seen_configs = set(self.configs)

-            device_prop = self.device_interface.Worker.get_device_properties(
-                self.triton_meta["device"]
-            )
+            device_prop = self.device_props
             if (
                 self.inductor_meta.get("dynamic_scale_rblock", True)
                 and self.heuristic_type == HeuristicType.REDUCTION
                 and self.size_hints is not None
-                # Disable for AMDGPU as Triton is not ready to return n_regs for a compiled_binary.
-                and not self.inductor_meta.get("is_hip")
-                # Disable for Intel GPU as Triton is not ready to return n_regs for a compiled_binary.
-                and self.device_type != "xpu"
+                # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
+                and device_prop.type == "cuda"
+                and device_prop.major
                 and device_prop.major >= 8
             ):
+                assert device_prop.regs_per_multiprocessor
+                assert device_prop.max_threads_per_multi_processor
+                assert device_prop.multi_processor_count
                 for triton_config, compiled_binary in zip(
                     self.configs, compiled_binaries
                 ):
@@ -288,15 +286,21 @@ def precompile(self, warm_cache_only_with_cc=None):
                     continue
                 seen_configs.add(new_config)
                 self.launchers.append(
-                    self._precompile_config(new_config, warm_cache_only_with_cc)[1]
+                    self._precompile_config(new_config, warm_cache_only)[1]
                 )
             self.configs = None

-    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]):
+    def get_device_interface(self):
+        # this code cannot run in compile workers, because it imports from torch
+        from torch._dynamo.device_interface import get_interface_for_device
+
+        return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
+
+    def _precompile_config(self, cfg: Config, warm_cache_only: bool):
         """Ahead of time compile a given autotuner config."""
         compile_meta = copy.deepcopy(self.triton_meta)
         for k, v in cfg.kwargs.items():
-            if torch.version.hip is not None:
+            if self.device_props.type != "hip":
                 if k == "matrix_instr_nonkdim":
                     compile_meta["matrix_instr_nonkdim"] = v
                     continue
@@ -310,22 +314,9 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
             "assert_indirect_indexing", True
         ) and not self.inductor_meta.get("is_hip", False)

-        # Setting device_type="hip" required on ROCm to pass down to triton
-        compile_meta["device_type"] = (
-            self.device_type if torch.version.hip is None else "hip"
-        )
-
-        if warm_cache_only_with_cc:
-            cc = warm_cache_only_with_cc
-        else:
-            # Use device_type 'cuda' for both cuda and hip devices to retrieve
-            # the compute capability.
-            device_type = self.device_type if torch.version.hip is None else "cuda"
-            device_id = compile_meta["device"]
-            device = torch.device(device_type, device_id)
-            cc = self.device_interface.get_compute_capability(device)
-
-        compile_meta["cc"] = cc
+        # device type will be "hip" rather than "cuda" here
+        compile_meta["device_type"] = self.device_props.type
+        compile_meta["cc"] = self.device_props.cc

         if ASTSource:
             compile_args = (
@@ -354,7 +345,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
             "num_stages": compile_meta["num_stages"],
             "debug": compile_meta["debug"],
         }
-        if torch.version.hip is not None:
+        if self.device_props.type != "hip":
             if "waves_per_eu" in compile_meta:
                 options["waves_per_eu"] = compile_meta["waves_per_eu"]
             if "matrix_instr_nonkdim" in compile_meta:
@@ -369,16 +360,21 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
             compile_args = (self.fn,)
             compile_kwargs = compile_meta

-        if warm_cache_only_with_cc:
+        if warm_cache_only:
             return (
                 triton.compile(*compile_args, **compile_kwargs),
                 None,
             )

+        # importing from torch is safe now that precompile has returned
+        from torch._dynamo.device_interface import DeviceGuard
+
+        device_interface = self.get_device_interface()
+
         # load binary to the correct device
-        with DeviceGuard(self.device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
+        with DeviceGuard(device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
             # need to initialize context
-            self.device_interface.synchronize(self.device_interface.current_device())
+            device_interface.synchronize(device_interface.current_device())

             try:
                 binary = triton.compile(*compile_args, **compile_kwargs)
@@ -596,8 +592,9 @@ def bench(self, launcher, *args, grid, **kwargs):
             )
             return float("inf")

-        stream = self.device_interface.get_raw_stream(  # type: ignore[call-arg]
-            self.device_interface.current_device()
+        device_interface = self.get_device_interface()
+        stream = device_interface.get_raw_stream(  # type: ignore[call-arg]
+            device_interface.current_device()
         )

         def kernel_call():
@@ -706,7 +703,7 @@ def save_cuda_kernel(self, grid, stream, launcher):

         binary = (
             launcher.bin.asm["cubin"]
-            if torch.version.hip is None
+            if self.device_props.type != "hip"
             else launcher.bin.asm["hsaco"]
         )
         CudaKernelParamCache.set(key, params, binary)
@@ -736,7 +733,7 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs):

         def benchmark_one_config(config):
             with self.lock:
-                _, launcher = self._precompile_config(config, None)
+                _, launcher = self._precompile_config(config, False)
             config2launcher[config] = launcher

             out = self.bench(launcher, *cloned_args, **kwargs)
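
To keep Triton's compile API unchanged, CachingAutotuner splits the incoming metadata back apart: it keeps the rich DeviceProperties record for its own decisions, hands Triton the flat device / device_type keys, and imports the device interface only lazily, never inside a compile worker. A condensed sketch of that bookkeeping; split_triton_meta and lazy_device_interface are hypothetical helpers mirroring the __init__ and get_device_interface code above:

from torch._inductor.runtime.hints import DeviceProperties

def split_triton_meta(triton_meta):
    # Hypothetical mirror of CachingAutotuner.__init__: keep the record for
    # ourselves, flatten it for Triton's compile metadata.
    device_props: DeviceProperties = triton_meta["device"]
    flat_meta = {
        **triton_meta,
        "device": device_props.index,
        "device_type": device_props.type,
    }
    return device_props, flat_meta

def lazy_device_interface(device_props: DeviceProperties):
    # Hypothetical mirror of get_device_interface(): importing torch here is
    # fine because this never runs in a compile worker; "hip" maps onto the
    # "cuda" interface, matching the code above.
    from torch._dynamo.device_interface import get_interface_for_device

    return get_interface_for_device(device_props.type.replace("hip", "cuda"))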
