
Commit dcbcdc5

add gradient free scipy interface
1 parent c1a4a83

File tree: 4 files changed (+44, -16)

CHANGELOG.md (+4)

@@ -2,6 +2,10 @@
 
 ## Unreleased
 
+### Added
+
+- add gradient free scipy interface for optimization
+
 ## 0.0.220311
 
 ### Added

setup.py (-1)

@@ -27,7 +27,6 @@
     ],
     classifiers=(
         "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: Apache",
         "Operating System :: OS Independent",
    ),
 )

tensorcircuit/interfaces.py (+30, -10)

@@ -133,24 +133,44 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
 
 
 def scipy_optimize_interface(
-    fun: Callable[..., Any], shape: Optional[Tuple[int, ...]] = None, jit: bool = True
+    fun: Callable[..., Any],
+    shape: Optional[Tuple[int, ...]] = None,
+    jit: bool = True,
+    gradient: bool = True,
 ) -> Callable[..., Any]:
-    vag = backend.value_and_grad(fun, argnums=0)
+    if gradient:
+        vag = backend.value_and_grad(fun, argnums=0)
+        if jit:
+            vag = backend.jit(vag)
+
+        def scipy_vag(*args: Any, **kws: Any) -> Tuple[Tensor, Tensor]:
+            scipy_args = numpy_args_to_backend(args, dtype=dtypestr)
+            if shape is not None:
+                scipy_args = list(scipy_args)
+                scipy_args[0] = backend.reshape(scipy_args[0], shape)
+                scipy_args = tuple(scipy_args)
+            vs, gs = vag(*scipy_args, **kws)
+            scipy_vs = general_args_to_numpy(vs)
+            gs = backend.reshape(gs, [-1])
+            scipy_gs = general_args_to_numpy(gs)
+            scipy_vs = scipy_vs.astype(np.float64)
+            scipy_gs = scipy_gs.astype(np.float64)
+            return scipy_vs, scipy_gs
+
+        return scipy_vag
+    # no gradient
     if jit:
-        vag = backend.jit(vag)
+        fun = backend.jit(fun)
 
-    def scipy_vag(*args: Any, **kws: Any) -> Tuple[Tensor, Tensor]:
+    def scipy_v(*args: Any, **kws: Any) -> Tensor:
         scipy_args = numpy_args_to_backend(args, dtype=dtypestr)
         if shape is not None:
             scipy_args = list(scipy_args)
             scipy_args[0] = backend.reshape(scipy_args[0], shape)
             scipy_args = tuple(scipy_args)
-        vs, gs = vag(*scipy_args, **kws)
+        vs = fun(*scipy_args, **kws)
         scipy_vs = general_args_to_numpy(vs)
-        gs = backend.reshape(gs, [-1])
-        scipy_gs = general_args_to_numpy(gs)
         scipy_vs = scipy_vs.astype(np.float64)
-        scipy_gs = scipy_gs.astype(np.float64)
-        return scipy_vs, scipy_gs
+        return scipy_vs
 
-    return scipy_vag
+    return scipy_v
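
For orientation, a minimal usage sketch of the new gradient-free mode (not part of the commit; assumes tensorcircuit is installed, and the toy objective f is hypothetical). The wrapper converts SciPy's NumPy arrays into backend tensors in the global dtype (complex64 by default), reshapes the first argument when shape is given, evaluates the (optionally jitted) objective, and hands back a float64 value, which is all a gradient-free optimizer such as COBYLA needs:

import numpy as np
from scipy import optimize
import tensorcircuit as tc

tc.set_backend("numpy")  # the gradient-free path needs no autodiff

# hypothetical toy objective with minimum at param = (1, -2);
# param arrives as a backend tensor in the global dtype
# (complex64 by default), so we take the real part of the loss
def f(param):
    return tc.backend.real((param[0] - 1.0) ** 2 + (param[1] + 2.0) ** 2)

# gradient=False returns a value-only wrapper for gradient-free methods
f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2], gradient=False)
r = optimize.minimize(f_scipy, np.zeros([2]), method="COBYLA")
print(r["x"])  # approximately [1, -2]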

tests/test_interfaces.py (+10, -5)

@@ -105,7 +105,7 @@ def f3(x):
     np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
 
 
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
+@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
 def test_scipy_interface(backend):
     n = 3
 
@@ -124,8 +124,13 @@ def f(param):
         )
         return tc.backend.real(loss)
 
-    f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2, n])
-    r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="L-BFGS-B", jac=True)
-    # L-BFGS-B may has issue with float32
-    # see: https://github.com/scipy/scipy/issues/5832
+    if tc.backend.name != "numpy":
+        f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2, n])
+        r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="L-BFGS-B", jac=True)
+        # L-BFGS-B may has issue with float32
+        # see: https://github.com/scipy/scipy/issues/5832
+        np.testing.assert_allclose(r["fun"], -1.0, atol=1e-5)
+
+    f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2, n], gradient=False)
+    r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="COBYLA")
     np.testing.assert_allclose(r["fun"], -1.0, atol=1e-5)
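
A note on the backend guard above: with jac=True, L-BFGS-B requires the value-and-gradient wrapper, which the numpy backend cannot supply (it has no automatic differentiation), so that branch is skipped under the npb fixture. COBYLA uses only function values, which is presumably why the new gradient-free wrapper is exercised on all three parametrized backends.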
