@@ -2,12 +2,10 @@
 interfaces bridging different backends
 """

-from typing import Any, Callable
+from typing import Any, Callable, Tuple

 import numpy as np
-from jax import numpy as jnp
 import torch
-import tensorflow as tf

 from .cons import backend
 from .backends import get_backend
@@ -19,6 +17,9 @@


 def tensor_to_numpy(t: Tensor) -> Array:
+    from jax import numpy as jnp
+    import tensorflow as tf
+
     if isinstance(t, torch.Tensor):
         return t.numpy()
     if isinstance(t, tf.Tensor) or isinstance(t, tf.Variable):
@@ -28,7 +29,7 @@ def tensor_to_numpy(t: Tensor) -> Array:
     return t


-def general_args_to_numpy(args: Any, same_pytree: bool = False) -> Any:
+def general_args_to_numpy(args: Any, same_pytree: bool = True) -> Any:
     res = []
     alone = False
     if not (isinstance(args, tuple) or isinstance(args, list)):
@@ -46,7 +47,7 @@ def general_args_to_numpy(args: Any, same_pytree: bool = False) -> Any:


 def numpy_args_to_backend(
-    args: Any, same_pytree: bool = False, dtype: Any = None, target_backend: Any = None
+    args: Any, same_pytree: bool = True, dtype: Any = None, target_backend: Any = None
 ) -> Any:
     # TODO(@refraction-ray): switch same_pytree default to True
     if target_backend is None:
@@ -82,13 +83,20 @@ def is_sequence(x: Any) -> bool:
     return False


-def torch_interface(fun: Callable[..., Any]) -> Callable[..., Any]:
+def torch_interface(fun: Callable[..., Any], jit: bool = False) -> Callable[..., Any]:
+    def vjp_fun(x: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
+        return backend.vjp(fun, x, v)  # type: ignore
+
+    if jit is True:
+        fun = backend.jit(fun)
+        vjp_fun = backend.jit(vjp_fun)
+
     class F(torch.autograd.Function):  # type: ignore
         @staticmethod
         def forward(ctx: Any, *x: Any) -> Any:  # type: ignore
             ctx.xdtype = [xi.dtype for xi in x]
-            x = general_args_to_numpy(x, same_pytree=True)
-            x = numpy_args_to_backend(x, same_pytree=True)
+            x = general_args_to_numpy(x)
+            x = numpy_args_to_backend(x)
             y = fun(*x)
             if not is_sequence(y):
                 ctx.ydtype = [y.dtype]
@@ -99,25 +107,23 @@ def forward(ctx: Any, *x: Any) -> Any:  # type: ignore
             else:
                 ctx.x = x
             y = numpy_args_to_backend(
-                general_args_to_numpy(y, same_pytree=True),
-                same_pytree=True,
+                general_args_to_numpy(y),
                 target_backend="pytorch",
             )
             return y

         @staticmethod
         def backward(ctx: Any, *grad_y: Any) -> Any:
-            grad_y = general_args_to_numpy(grad_y, same_pytree=True)
+            grad_y = general_args_to_numpy(grad_y)
             grad_y = numpy_args_to_backend(
-                grad_y, dtype=[d for d in ctx.ydtype], same_pytree=True
+                grad_y, dtype=[d for d in ctx.ydtype]
             )  # backend.dtype
             if len(grad_y) == 1:
                 grad_y = grad_y[0]
-            _, g = backend.vjp(fun, ctx.x, grad_y)
+            _, g = vjp_fun(ctx.x, grad_y)
             # a redundency due to current vjp API
             r = numpy_args_to_backend(
-                general_args_to_numpy(g, same_pytree=True),
-                same_pytree=True,
+                general_args_to_numpy(g),
                 dtype=[d for d in ctx.xdtype],  # torchdtype
                 target_backend="pytorch",
             )
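
For context, a rough usage sketch of the new jit option. Nothing below is part of the diff: it assumes this module's global `backend` (imported above from .cons) points to a non-torch backend such as jax, that `torch_interface` is imported from this module, and that the toy function `f` and the backend ops it uses are purely illustrative.

import torch

def f(x):
    # hypothetical function written against the active (non-torch) backend
    return backend.sum(backend.sin(x) ** 2)

# jit=True now wraps both the forward function and its vjp with backend.jit
f_torch = torch_interface(f, jit=True)

x = torch.ones(3, requires_grad=True)
y = f_torch(x)   # forward: torch tensors -> numpy -> backend, run f, convert result back to torch
y.backward()     # backward: routed through the (jitted) vjp_fun defined inside torch_interface
print(x.grad)

Jitting vjp_fun alongside fun means repeated backward passes reuse the compiled gradient computation instead of re-tracing it on every call.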