Skip to content

Commit 4a665d1

Browse files
add numpy interface
1 parent f3ec276 commit 4a665d1

File tree

8 files changed

+90
-55
lines changed

8 files changed

+90
-55
lines changed

Diff for: tensorcircuit/backends/abstract_backend.py

-2
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,6 @@ def abs(self: Any, a: Tensor) -> Tensor:
268268
"Backend '{}' has not implemented `abs`.".format(self.name)
269269
)
270270

271-
# TODO(@refraction-ray): abs docstring doesn't get registered in the doc
272-
273271
def kron(self: Any, a: Tensor, b: Tensor) -> Tensor:
274272
"""
275273
Return the kronecker product of two matrices ``a`` and ``b``.

Diff for: tensorcircuit/circuit.py

-43
Original file line numberDiff line numberDiff line change
@@ -334,49 +334,6 @@ def _meta_apply(cls) -> None:
334334
for alias_gate in gate_alias[1:]:
335335
setattr(cls, alias_gate, getattr(cls, present_gate))
336336

337-
# @classmethod
338-
# def from_qcode(
339-
# cls, qcode: str
340-
# ) -> "Circuit": # forward reference, see https://github.com/python/mypy/issues/3661
341-
# """
342-
# [WIP], make circuit object from non universal simple assembly quantum language
343-
344-
# :param qcode:
345-
# :type qcode: str
346-
# :return: :py:class:`Circuit` object
347-
# """
348-
# # TODO(@refraction-ray): change to OpenQASM IO
349-
# lines = [s for s in qcode.split("\n") if s.strip()]
350-
# nqubits = int(lines[0])
351-
# c = cls(nqubits)
352-
# for l in lines[1:]:
353-
# ls = [s for s in l.split(" ") if s.strip()]
354-
# g = ls[0]
355-
# index = []
356-
# errloc = 0
357-
# for i, s in enumerate(ls[1:]):
358-
# try:
359-
# si = int(s)
360-
# index.append(si)
361-
# except ValueError:
362-
# errloc = i + 1
363-
# break
364-
# kwdict = {}
365-
# if errloc > 0:
366-
# for j, s in enumerate(ls[errloc::2]):
367-
# kwdict[s] = float(ls[2 * j + 1 + errloc])
368-
# getattr(c, g)(*index, **kwdict)
369-
# return c
370-
371-
# def to_qcode(self) -> str:
372-
# """
373-
# [WIP]
374-
375-
# :return: qcode str of corresponding circuit
376-
# :rtype: str
377-
# """
378-
# return self._qcode
379-
380337
def apply_single_gate(self, gate: Gate, index: int) -> None:
381338
"""
382339
Apply the gate to the bit with the given index.

Diff for: tensorcircuit/interfaces/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
general_args_to_numpy,
99
general_args_to_backend,
1010
)
11+
from .numpy import numpy_interface, np_interface
1112
from .scipy import scipy_interface, scipy_optimize_interface
1213
from .torch import torch_interface, pytorch_interface
1314
from .tensorflow import tensorflow_interface, tf_interface

Diff for: tensorcircuit/interfaces/numpy.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
Interface wraps quantum function as a numpy function
3+
"""
4+
5+
from typing import Any, Callable
6+
from functools import wraps
7+
8+
from ..cons import backend
9+
from .tensortrans import general_args_to_numpy, numpy_args_to_backend
10+
11+
Tensor = Any
12+
13+
14+
def numpy_interface(
15+
fun: Callable[..., Any],
16+
jit: bool = True,
17+
) -> Callable[..., Any]:
18+
"""
19+
Convert ``fun`` on ML backend into a numpy function
20+
21+
:Example:
22+
23+
.. code-block:: python
24+
25+
K = tc.set_backend("tensorflow")
26+
27+
def f(params, n):
28+
c = tc.Circuit(n)
29+
for i in range(n):
30+
c.rx(i, theta=params[i])
31+
for i in range(n-1):
32+
c.cnot(i, i+1)
33+
r = K.real(c.expectation_ps(z=[n-1]))
34+
return r
35+
36+
n = 3
37+
f_np = tc.interfaces.numpy_interface(f, jit=True)
38+
f_np(np.ones([n]), n) # 0.1577285
39+
40+
41+
:param fun: The quantum function
42+
:type fun: Callable[..., Any]
43+
:param jit: whether to jit ``fun``, defaults to True
44+
:type jit: bool, optional
45+
:return: The numpy interface compatible version of ``fun``
46+
:rtype: Callable[..., Any]
47+
"""
48+
if jit:
49+
fun = backend.jit(fun)
50+
51+
@wraps(fun)
52+
def numpy_fun(*args: Any, **kws: Any) -> Any:
53+
backend_args = numpy_args_to_backend(args)
54+
r = fun(*backend_args, **kws)
55+
np_r = general_args_to_numpy(r)
56+
return np_r
57+
58+
return numpy_fun
59+
60+
61+
np_interface = numpy_interface

Diff for: tensorcircuit/interfaces/tensortrans.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Any
6+
from functools import partial
67

78
from ..cons import backend
89
from ..backends import get_backend # type: ignore
@@ -39,9 +40,17 @@ def which_backend(a: Tensor, return_backend: bool = True) -> Any:
3940

4041

4142
def tensor_to_numpy(t: Tensor) -> Array:
43+
if isinstance(t, int) or isinstance(t, float):
44+
return t
4245
return which_backend(t).numpy(t)
4346

4447

48+
def numpy_to_tensor(t: Array, backend: Any) -> Tensor:
49+
if isinstance(t, int) or isinstance(t, float):
50+
return t
51+
return backend.convert_to_tensor(t)
52+
53+
4554
def tensor_to_dtype(t: Tensor) -> str:
4655
return which_backend(t).dtype(t) # type: ignore
4756

@@ -85,13 +94,13 @@ def numpy_args_to_backend(
8594
target_backend = get_backend(target_backend)
8695

8796
if dtype is None:
88-
return backend.tree_map(target_backend.convert_to_tensor, args)
97+
return backend.tree_map(partial(numpy_to_tensor, backend=target_backend), args)
8998
else:
9099
if isinstance(dtype, str):
91100
leaves, treedef = backend.tree_flatten(args)
92101
dtype = [dtype for _ in range(len(leaves))]
93102
dtype = backend.tree_unflatten(treedef, dtype)
94-
t = backend.tree_map(target_backend.convert_to_tensor, args)
103+
t = backend.tree_map(partial(numpy_to_tensor, backend=target_backend), args)
95104
t = backend.tree_map(target_backend.cast, t, dtype)
96105
return t
97106

@@ -100,7 +109,6 @@ def general_args_to_backend(
100109
args: Any, dtype: Any = None, target_backend: Any = None, enable_dlpack: bool = True
101110
) -> Any:
102111
if not enable_dlpack:
103-
# TODO(@refraction-ray): add device shift for numpy mediate transformation
104112
args = general_args_to_numpy(args)
105113
args = numpy_args_to_backend(args, dtype, target_backend)
106114
return args

Diff for: tensorcircuit/simplify.py

-3
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,3 @@ def _full_light_cone_cancel(nodes: List[Any]) -> List[Any]:
280280
while is_changed:
281281
nodes, is_changed = _light_cone_cancel(nodes)
282282
return nodes
283-
284-
285-
# TODO(@refraction-ray): utilize more simplification method in contractor preprocessing

Diff for: tests/test_interfaces.py

+17
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,23 @@ def f(param):
243243
np.testing.assert_allclose(r["fun"], -1.0, atol=1e-5)
244244

245245

246+
@pytest.mark.parametrize("backend", [lf("torchb"), lf("tfb"), lf("jaxb")])
247+
def test_numpy_interface(backend):
248+
def f(params, n):
249+
c = tc.Circuit(n)
250+
for i in range(n):
251+
c.rx(i, theta=params[i])
252+
for i in range(n - 1):
253+
c.cnot(i, i + 1)
254+
r = tc.backend.real(c.expectation_ps(z=[n - 1]))
255+
return r
256+
257+
n = 3
258+
f_np = tc.interfaces.numpy_interface(f, jit=False)
259+
r = f_np(np.ones([n]), n)
260+
np.testing.assert_allclose(r, 0.1577285, atol=1e-5)
261+
262+
246263
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
247264
def test_args_transformation(backend):
248265
ans = tc.interfaces.general_args_to_numpy(

Diff for: tests/test_mpscircuit.py

-4
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
sys.path.insert(0, modulepath)
1616
import tensorcircuit as tc
1717

18-
# TODO(@refraction-ray): mps circuit test: grad & jit, differentiable? jittable?
19-
# AD on jax backend may have issues for now, see
20-
# https://gist.github.com/refraction-ray/cc48c0b31984e6a04ee00050c0b36758
21-
# for a minimal demo
2218

2319
N = 16
2420
D = 100

0 commit comments

Comments
 (0)