
Commit 2444479

allow device move even with dlpack on torch interface
1 parent de48165 commit 2444479

2 files changed: +8 −7 lines


examples/hybrid_gpu_pipeline.py

+5 −1
@@ -25,7 +25,11 @@
 enable_dlpack = True
 # enable_dlpack = False # for old version of ML libs
 tf_device = "/GPU:0"
-# tf_device "/device:CPU:0"
+# tf_device = "/device:CPU:0"
+# another scheme to globally close GPU only for tf
+# https://datascience.stackexchange.com/a/76039
+# but if gpu support is fully shut down as above
+# dlpack=True wont work
 
 # dataset preparation
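The "another scheme" the new comments point to is hiding the GPU from TensorFlow alone, rather than switching tf_device to a CPU device string. A minimal sketch of that scheme (my reading of the linked answer; not part of this commit):

import tensorflow as tf

# Hide all GPUs from TensorFlow before any op is placed; other libraries in the
# same process (e.g. torch) can still see and use the GPU.
tf.config.set_visible_devices([], "GPU")

As the added comment warns, if GPU support is fully shut down this way for the producing framework, enable_dlpack = True will not work in that configuration.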

tensorcircuit/interfaces/torch.py

+3 −6
@@ -70,8 +70,7 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
         # (x, )
         if len(ctx.xdtype) == 1:
             ctx.xdtype = ctx.xdtype[0]
-        if not enable_dlpack:
-            ctx.device = (backend.tree_flatten(x)[0][0]).device
+        ctx.device = (backend.tree_flatten(x)[0][0]).device
         x = general_args_to_backend(x, enable_dlpack=enable_dlpack)
         y = fun(*x)
         ctx.ydtype = backend.tree_map(lambda s: s.dtype, y)
@@ -81,8 +80,7 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
         y = general_args_to_backend(
             y, target_backend="pytorch", enable_dlpack=enable_dlpack
         )
-        if not enable_dlpack:
-            y = backend.tree_map(lambda s: s.to(device=ctx.device), y)
+        y = backend.tree_map(lambda s: s.to(device=ctx.device), y)
         return y
 
     @staticmethod
@@ -104,8 +102,7 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
             target_backend="pytorch",
             enable_dlpack=enable_dlpack,
         )
-        if not enable_dlpack:
-            r = backend.tree_map(lambda s: s.to(device=ctx.device), r)
+        r = backend.tree_map(lambda s: s.to(device=ctx.device), r)
         if not is_sequence(r):
             return (r,)
         return r
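In short, ctx.device is now recorded and the move back to it is applied unconditionally, not only when DLPack is disabled, so torch tensors come back on the device they were passed in on even with zero-copy DLPack exchange. A minimal usage sketch of the resulting behavior (illustrative only; it assumes a torch_interface wrapper taking jit and enable_dlpack arguments, as this module suggests, and a CUDA-capable setup):

import torch
import tensorcircuit as tc

tc.set_backend("tensorflow")

def energy(params):
    # toy two-qubit circuit whose expectation value is differentiated through torch
    c = tc.Circuit(2)
    c.rx(0, theta=params[0])
    c.rx(1, theta=params[1])
    return tc.backend.real(c.expectation_ps(z=[0]))

# assumed call signature, mirroring the flags used in this file
f = tc.interfaces.torch_interface(energy, jit=True, enable_dlpack=True)

params = torch.ones(2, requires_grad=True, device="cuda")
loss = f(params)   # with this commit, the output is moved back to params.device
loss.backward()    # and the gradient in params.grad lands on the same device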
