File tree 2 files changed +8
-7
lines changed
2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change 25
25
enable_dlpack = True
26
26
# enable_dlpack = False # for old version of ML libs
27
27
tf_device = "/GPU:0"
28
- # tf_device "/device:CPU:0"
28
+ # tf_device = "/device:CPU:0"
29
+ # another scheme to globally close GPU only for tf
30
+ # https://datascience.stackexchange.com/a/76039
31
+ # but if gpu support is fully shut down as above
32
+ # dlpack=True wont work
29
33
30
34
# dataset preparation
31
35
Original file line number Diff line number Diff line change @@ -70,8 +70,7 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
70
70
# (x, )
71
71
if len (ctx .xdtype ) == 1 :
72
72
ctx .xdtype = ctx .xdtype [0 ]
73
- if not enable_dlpack :
74
- ctx .device = (backend .tree_flatten (x )[0 ][0 ]).device
73
+ ctx .device = (backend .tree_flatten (x )[0 ][0 ]).device
75
74
x = general_args_to_backend (x , enable_dlpack = enable_dlpack )
76
75
y = fun (* x )
77
76
ctx .ydtype = backend .tree_map (lambda s : s .dtype , y )
@@ -81,8 +80,7 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
81
80
y = general_args_to_backend (
82
81
y , target_backend = "pytorch" , enable_dlpack = enable_dlpack
83
82
)
84
- if not enable_dlpack :
85
- y = backend .tree_map (lambda s : s .to (device = ctx .device ), y )
83
+ y = backend .tree_map (lambda s : s .to (device = ctx .device ), y )
86
84
return y
87
85
88
86
@staticmethod
@@ -104,8 +102,7 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
104
102
target_backend = "pytorch" ,
105
103
enable_dlpack = enable_dlpack ,
106
104
)
107
- if not enable_dlpack :
108
- r = backend .tree_map (lambda s : s .to (device = ctx .device ), r )
105
+ r = backend .tree_map (lambda s : s .to (device = ctx .device ), r )
109
106
if not is_sequence (r ):
110
107
return (r ,)
111
108
return r
You can’t perform that action at this time.
0 commit comments