@@ -133,24 +133,44 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
133
133
134
134
135
135
def scipy_optimize_interface (
136
- fun : Callable [..., Any ], shape : Optional [Tuple [int , ...]] = None , jit : bool = True
136
+ fun : Callable [..., Any ],
137
+ shape : Optional [Tuple [int , ...]] = None ,
138
+ jit : bool = True ,
139
+ gradient : bool = True ,
137
140
) -> Callable [..., Any ]:
138
- vag = backend .value_and_grad (fun , argnums = 0 )
141
+ if gradient :
142
+ vag = backend .value_and_grad (fun , argnums = 0 )
143
+ if jit :
144
+ vag = backend .jit (vag )
145
+
146
+ def scipy_vag (* args : Any , ** kws : Any ) -> Tuple [Tensor , Tensor ]:
147
+ scipy_args = numpy_args_to_backend (args , dtype = dtypestr )
148
+ if shape is not None :
149
+ scipy_args = list (scipy_args )
150
+ scipy_args [0 ] = backend .reshape (scipy_args [0 ], shape )
151
+ scipy_args = tuple (scipy_args )
152
+ vs , gs = vag (* scipy_args , ** kws )
153
+ scipy_vs = general_args_to_numpy (vs )
154
+ gs = backend .reshape (gs , [- 1 ])
155
+ scipy_gs = general_args_to_numpy (gs )
156
+ scipy_vs = scipy_vs .astype (np .float64 )
157
+ scipy_gs = scipy_gs .astype (np .float64 )
158
+ return scipy_vs , scipy_gs
159
+
160
+ return scipy_vag
161
+ # no gradient
139
162
if jit :
140
- vag = backend .jit (vag )
163
+ fun = backend .jit (fun )
141
164
142
- def scipy_vag (* args : Any , ** kws : Any ) -> Tuple [ Tensor , Tensor ] :
165
+ def scipy_v (* args : Any , ** kws : Any ) -> Tensor :
143
166
scipy_args = numpy_args_to_backend (args , dtype = dtypestr )
144
167
if shape is not None :
145
168
scipy_args = list (scipy_args )
146
169
scipy_args [0 ] = backend .reshape (scipy_args [0 ], shape )
147
170
scipy_args = tuple (scipy_args )
148
- vs , gs = vag (* scipy_args , ** kws )
171
+ vs = fun (* scipy_args , ** kws )
149
172
scipy_vs = general_args_to_numpy (vs )
150
- gs = backend .reshape (gs , [- 1 ])
151
- scipy_gs = general_args_to_numpy (gs )
152
173
scipy_vs = scipy_vs .astype (np .float64 )
153
- scipy_gs = scipy_gs .astype (np .float64 )
154
- return scipy_vs , scipy_gs
174
+ return scipy_vs
155
175
156
- return scipy_vag
176
+ return scipy_v
0 commit comments