
Commit e9c6ee8

add scipy interface
1 parent deaa986

File tree: 3 files changed, +57 -10 lines changed

Diff for: CHANGELOG.md

+2 lines
@@ -12,6 +12,8 @@
 
 - add `state_centric` decorator in `tc.templates.blocks` to transform a circuit-to-circuit function into a state-to-state function
 
+- add `interfaces.scipy_optimize_interface` to transform a quantum function into a `scipy.optimize.minimize` compatible form
+
 ### Fixed
 
 - avoid error on watching non `tf.Tensor` inputs in the tensorflow backend grad method

Diff for: tensorcircuit/interfaces.py

+28 -10 lines
@@ -2,12 +2,12 @@
 interfaces bridging different backends
 """
 
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Tuple, Optional
 
 import numpy as np
 import torch
 
-from .cons import backend
+from .cons import backend, dtypestr
 from .backends import get_backend  # type: ignore
 
 Tensor = Any
@@ -17,16 +17,10 @@
 
 
 def tensor_to_numpy(t: Tensor) -> Array:
-    from jax import numpy as jnp
-    import tensorflow as tf
-
-    if isinstance(t, torch.Tensor):
-        return t.numpy()
-    if isinstance(t, tf.Tensor) or isinstance(t, tf.Variable):
+    try:
         return t.numpy()
-    if isinstance(t, jnp.ndarray):
+    except AttributeError:
         return np.array(t)
-    return t
 
 
 def general_args_to_numpy(args: Any, same_pytree: bool = True) -> Any:
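The rewritten `tensor_to_numpy` above replaces per-framework `isinstance` checks with duck typing: torch and tensorflow tensors expose a `.numpy()` method, while jax arrays, lists, and scalars do not and fall through to `np.array`. A minimal standalone sketch of the same pattern (the `to_numpy` name here is hypothetical, independent of the tensorcircuit internals):

import numpy as np

def to_numpy(t):
    try:
        # EAFP: prefer the tensor's own converter, available on
        # torch.Tensor, tf.Tensor, and tf.Variable
        return t.numpy()
    except AttributeError:
        # jax arrays, lists, and scalars all go through np.array
        return np.array(t)

print(to_numpy([1.0, 2.0]))  # list input takes the np.array fallback

This also covers inputs the old chain silently passed through unchanged (the final `return t`), since `np.array` accepts anything exposing `__array__` or the standard sequence protocols.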
@@ -133,3 +127,27 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
 
     # currently, memory transparent dlpack in these ML frameworks has broken support on complex dtypes
     return Fun.apply  # type: ignore
+
+
+def scipy_optimize_interface(
+    fun: Callable[..., Any], shape: Optional[Tuple[int, ...]] = None, jit: bool = True
+) -> Callable[..., Any]:
+    vag = backend.value_and_grad(fun, argnums=0)
+    if jit:
+        vag = backend.jit(vag)
+
+    def scipy_vag(*args: Any, **kws: Any) -> Tuple[Tensor, Tensor]:
+        scipy_args = numpy_args_to_backend(args, dtype=dtypestr)
+        if shape is not None:
+            scipy_args = list(scipy_args)
+            scipy_args[0] = backend.reshape(scipy_args[0], shape)
+            scipy_args = tuple(scipy_args)
+        vs, gs = vag(*scipy_args, **kws)
+        scipy_vs = general_args_to_numpy(vs)
+        gs = backend.reshape(gs, [-1])
+        scipy_gs = general_args_to_numpy(gs)
+        scipy_vs = scipy_vs.astype(np.float64)
+        scipy_gs = scipy_gs.astype(np.float64)
+        return scipy_vs, scipy_gs
+
+    return scipy_vag
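What makes the result `scipy.optimize.minimize` compatible is the wrapper's contract: scipy supplies a flat float64 numpy vector, `scipy_vag` casts it to the backend dtype (`dtypestr`), reshapes the first argument if `shape` is given, evaluates value and gradient in one pass (jitted by default, so only the first call pays compilation cost), and returns a float64 scalar plus a flat float64 gradient, exactly the pair `minimize` expects when called with `jac=True`. A minimal sketch of calling the wrapped function directly; the quadratic loss and the [2, 3] shape are illustrative assumptions, not from the commit:

import numpy as np
import tensorcircuit as tc
from tensorcircuit import interfaces

tc.set_backend("tensorflow")

def loss(param):
    # param arrives already cast and reshaped to [2, 3]; any
    # backend-differentiable scalar works, a quadratic keeps it small
    return tc.backend.real(tc.backend.sum((param - 1.0) ** 2))

f_scipy = interfaces.scipy_optimize_interface(loss, shape=[2, 3])
v, g = f_scipy(np.zeros([6]))  # flat input, as scipy will supply it
# v is a float64 scalar, g a flat float64 array of length 6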

Diff for: tests/test_interfaces.py

+27 lines
@@ -2,6 +2,7 @@
 import sys
 import pytest
 from pytest_lazyfixture import lazy_fixture as lf
+from scipy import optimize
 
 thisfile = os.path.abspath(__file__)
 modulepath = os.path.dirname(os.path.dirname(thisfile))
@@ -98,3 +99,29 @@ def f3(x):
     l3.backward()
     pg = param3.grad
     np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
+
+
+@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
+def test_scipy_interface(backend):
+    n = 3
+
+    def f(param):
+        c = tc.Circuit(n)
+        for i in range(n):
+            c.rx(i, theta=param[0, i])
+            c.rz(i, theta=param[1, i])
+        loss = c.expectation(
+            [
+                tc.gates.y(),
+                [
+                    0,
+                ],
+            ]
+        )
+        return tc.backend.real(loss)
+
+    f_scipy = interfaces.scipy_optimize_interface(f, shape=[2, n])
+    r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="L-BFGS-B", jac=True)
+    # L-BFGS-B may have issues with float32
+    # see: https://github.com/scipy/scipy/issues/5832
+    np.testing.assert_allclose(r["fun"], -1.0, atol=1e-5)
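For completeness, `optimize.minimize` returns its solution in the same flat layout the wrapper consumes, so recovering the circuit parameters from the test above is a single reshape; a short continuation sketch using the names from the test:

opt_param = r["x"].reshape([2, n])  # flat float64 minimizer back to [2, n]
print(r["fun"])  # minimized expectation value, close to -1.0 here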
