
Commit f008efa

voznesenskym authored and pytorchmergebot committed
Reconstruct streams via global registration, temporary impl to unblock FSDP (pytorch#117386)
This is a placeholder implementation that reconstructs streams via global storage to unblock FSDP, pending a proper stream-support design. This PR does a few things:

1) Fixes registration for devices with indices. We previously only supported "cuda"; we now also support "cuda:k" interfaces, where k is the GPU index.
2) Changes the stream objects in dynamo to take devices as device types instead of strings, and updates the string-based device APIs to gracefully accept device types.
3) Introduces reconstruct-by-global (using the existing cleanup-hook structures) for streams as a placeholder implementation for now.

Pull Request resolved: pytorch#117386
Approved by: https://github.com/jansel
1 parent ef3217d commit f008efa

5 files changed: +59 −14 lines

test/dynamo/test_ctx_manager.py

Lines changed: 24 additions & 0 deletions
@@ -177,6 +177,30 @@ def fn(x):
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 9)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_cuda_stream_across_graph_break(self):
+        def fn(x):
+            s = torch.cuda.Stream()
+            x = torch.mul(x, 5)
+            x = torch.add(x, 2)
+
+            print("foo")
+            tcs = torch.cuda.stream(s)
+            with tcs:
+                x = torch.relu(x)
+                x = torch.add(x, 1)
+                x = torch.cos(x)
+            return x
+
+        x = torch.randn((2, 2), device="cuda")
+        ref = fn(x)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+        self.assertEqual(cnts.frame_count, 2)
+        self.assertEqual(cnts.op_count, 9)
+
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_cuda_stream_context_manager2(self):
         def fn(x, s):
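For reference, the frame_count == 2 assertion reflects the graph break introduced by the print() call: Dynamo compiles one frame up to the print and a second frame for the rest of the function. Below is a minimal, CPU-only sketch of the same counting behavior; it is not part of this diff, and the function g is purely illustrative.

import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

def g(x):
    x = x + 1
    print("untraceable side effect")  # forces a graph break
    return x * 2

cnts = CompileCounter()
opt_g = torch._dynamo.optimize(cnts)(g)
opt_g(torch.randn(2))
# Two compiled frames: one before the break, one after it.
assert cnts.frame_count == 2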

torch/_dynamo/device_interface.py

Lines changed: 10 additions & 2 deletions
@@ -174,11 +174,17 @@ def get_compute_capability(device: _device_t = None):
 device_interfaces: Dict[str, Type[DeviceInterface]] = {}
 
 
-def register_interface_for_device(device: str, device_interface: Type[DeviceInterface]):
+def register_interface_for_device(
+    device: Union[str, torch.device], device_interface: Type[DeviceInterface]
+):
+    if isinstance(device, torch.device):
+        device = str(device)
     device_interfaces[device] = device_interface
 
 
-def get_interface_for_device(device: str) -> Type[DeviceInterface]:
+def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]:
+    if isinstance(device, torch.device):
+        device = str(device)
     if device in device_interfaces:
         return device_interfaces[device]
     raise NotImplementedError(f"No interface for device {device}")
@@ -189,3 +195,5 @@ def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterfa
 
 
 register_interface_for_device("cuda", CudaInterface)
+for i in range(torch.cuda.device_count()):
+    register_interface_for_device(f"cuda:{i}", CudaInterface)
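A quick sketch of what this registration change enables, assuming a CUDA build with at least one visible GPU: both the index-qualified string and the equivalent torch.device now resolve to the same interface, since torch.device inputs are normalized to their string form.

import torch
from torch._dynamo.device_interface import get_interface_for_device

# "cuda:0" is registered explicitly by the loop above; torch.device("cuda", 0)
# stringifies to "cuda:0", so both lookups hit the same entry.
by_string = get_interface_for_device("cuda:0")
by_device = get_interface_for_device(torch.device("cuda", 0))
assert by_string is by_device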

torch/_dynamo/symbolic_convert.py

Lines changed: 0 additions & 1 deletion
@@ -815,7 +815,6 @@ def popn(self, n: int) -> List[VariableTracker]:
 
     def LOAD_FAST(self, inst):
         name = inst.argval
-
         if name in self.f_locals and config.replay_record_enabled:
             self.exec_recorder.add_local_var(name, self.f_locals[name])
 

torch/_dynamo/variables/builder.py

Lines changed: 8 additions & 4 deletions
@@ -84,6 +84,7 @@
     AutocastModeVariable,
     EventVariable,
     NullContextVariable,
+    StreamContextVariable,
     StreamVariable,
 )
 from .dicts import (
@@ -570,12 +571,17 @@ def build_key_value(k, v):
         elif isinstance(value, HigherOrderOperator):
             self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
             return TorchHigherOrderOperatorVariable.make(value, source=self.source)
+        elif isinstance(value, torch.cuda.StreamContext):
+            self.install_guards(GuardBuilder.ID_MATCH)
+            stream_source = AttrSource(self.source, "stream")
+            stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
+            return StreamContextVariable.create(self.tx, stream_var)
         elif isinstance(value, _StreamBase):
            self.install_guards(GuardBuilder.ID_MATCH)
            return StreamVariable(
                None,
                value,
-               value.device.type,
+               value.device,
                source=self.source,
            )
         elif isinstance(value, _EventBase):
@@ -1500,9 +1506,7 @@ def _clone_input(value):
             for _, device_interface in get_registered_device_interfaces()
         ]:
             proxy.node.meta["example_value"] = example_value
-            return StreamVariable(
-                proxy, example_value, example_value.device.type, **options
-            )
+            return StreamVariable(proxy, example_value, example_value.device, **options)
         elif (
             inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
         ) or proxy.node.target in [
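The new torch.cuda.StreamContext branch covers the case where an already-constructed stream context enters a compiled frame from the outside (for example as a global or closure), rather than being created inside the traced code. A hedged sketch of that usage follows, assuming CUDA is available; the function fn here is illustrative and not part of the diff.

import torch

s = torch.cuda.Stream()
ctx = torch.cuda.stream(s)  # a torch.cuda.StreamContext built outside the compiled region

@torch.compile(backend="eager")
def fn(x):
    # ctx reaches the VariableBuilder as a StreamContext value; it is guarded by
    # identity and rebuilt as a StreamContextVariable wrapping its .stream attribute.
    with ctx:
        return torch.relu(x) + 1

out = fn(torch.randn(2, 2, device="cuda"))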

torch/_dynamo/variables/ctx_manager.py

Lines changed: 17 additions & 7 deletions
@@ -521,19 +521,13 @@ def exit(self, tx, *args):
             )
         self.state.cleanup_assert()
 
-    def module_name(self):
-        return "torch." + str(self.device)
-
-    def fn_name(self):
-        return "stream"
-
 
 class StreamVariable(VariableTracker):
     def __init__(self, proxy, value, device, **kwargs):
         if proxy is not None and "example_value" in proxy.node.meta:
             assert proxy.node.meta["example_value"] == value
         assert (
-            value.device.type == device
+            value.device.type == device.type
         ), "stream value is not equal to the passed device"
         super().__init__(**kwargs)
         self.proxy = proxy
@@ -586,6 +580,22 @@ def call_method(
     def as_proxy(self):
         return self.proxy
 
+    def reconstruct(self, codegen):
+        # If we got here, this stream is fully subsumed by the graph - this means it is
+        # not an input or global
+        assert not self.source
+        # Since we just proved that - for other such structures, like lists and dicts, reconstruction
+        # is fine and sound according to dynamo principles of treating collectives. However,
+        # streams are special in that we want to preserve the identity of the stream as the same as in the graph
+        # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
+        # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
+        # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
+        name = f"_stream_{self.device}_{id(self.value)}"
+        if name not in codegen.tx.output.global_scope:
+            codegen.tx.output.install_global(name, self.value)
+
+        return [codegen.create_load_global(name, push_null=False, add=True)]
+
 
 class EventVariable(VariableTracker):
     def __init__(self, proxy, value, **kwargs):
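The reconstruct-by-global approach matters because a stream rebuilt by value (the way lists or dicts are reconstructed) would be a different object from the one captured in the graph, breaking identity-sensitive code such as current-stream checks. Here is a minimal, dynamo-free sketch of the idea; the names GLOBAL_SCOPE, lift_to_global, and FakeStream are illustrative, not PyTorch APIs.

GLOBAL_SCOPE = {}

def lift_to_global(obj, device):
    # Mirror the naming scheme used in StreamVariable.reconstruct: one slot per object.
    name = f"_stream_{device}_{id(obj)}"
    # Installing the object once means every later load by name yields the very same
    # instance, so identity (not just equality) survives reconstruction.
    GLOBAL_SCOPE.setdefault(name, obj)
    return name

class FakeStream:
    pass

s = FakeStream()
name = lift_to_global(s, "cuda:0")
assert GLOBAL_SCOPE[name] is s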
