Skip to content

Commit 958ae60

Browse files
add jit in torch interface
1 parent 986f607 commit 958ae60

File tree

6 files changed

+34
-26
lines changed

6 files changed

+34
-26
lines changed

check_all.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mypy tensorcircuit
77
echo "pylint check"
88
pylint tensorcircuit tests
99
echo "pytest check"
10-
pytest --cov=tensorcircuit -vv
10+
pytest --cov=tensorcircuit -vv -W ignore::DeprecationWarning
1111
echo "sphinx check"
1212
cd docs && make html
13-
echo "all checks passed, congratulates!"
13+
echo "all checks passed, congratulates! 💐"

tensorcircuit/backends.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def softmax( # pylint: disable=unused-variable
329329
330330
.. math ::
331331
332-
\\mathrm{softmax}(x) = \\frac{\exp(x_i)}{\\sum_j \\exp(x_j)}
332+
\\mathrm{softmax}(x) = \\frac{\\exp(x_i)}{\\sum_j \\exp(x_j)}
333333
334334
335335
:param a: Tensor
@@ -850,7 +850,7 @@ def vectorized_value_and_grad( # pylint: disable=unused-variable
850850
And if argnums=1, the gradient is like
851851
852852
.. math::
853-
g^1_i = \\frac{\\partial \sum_j f(vargs[0][j], args[1])}{\\partial args[1][i]}
853+
g^1_i = \\frac{\\partial \\sum_j f(vargs[0][j], args[1])}{\\partial args[1][i]}
854854
855855
, which is suitable for quantum machine learning scenarios, where ``f`` is the loss function,
856856
args[0] corresponds to the input data and args[1] corresponds to the weights in the QML model.

tensorcircuit/cons.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _get_path_cache_friendly(
319319
mapping_dict[id(e)] = i
320320
i += 1
321321
input_sets = [set([mapping_dict[id(e)] for e in node.edges]) for node in nodes]
322-
order = np.argsort(list(map(sorted, input_sets)) + [[1e10]])[:-1] # type: ignore
322+
order = np.argsort(np.array(list(map(sorted, input_sets)) + [[1e10]], dtype=object))[:-1] # type: ignore
323323
# TODO(@refraction-ray): more stable and warning-free arg sorting here
324324
nodes_new = [nodes[i] for i in order]
325325
input_sets = [set([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]

tensorcircuit/interfaces.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
interfaces bridging different backends
33
"""
44

5-
from typing import Any, Callable
5+
from typing import Any, Callable, Tuple
66

77
import numpy as np
8-
from jax import numpy as jnp
98
import torch
10-
import tensorflow as tf
119

1210
from .cons import backend
1311
from .backends import get_backend
@@ -19,6 +17,9 @@
1917

2018

2119
def tensor_to_numpy(t: Tensor) -> Array:
20+
from jax import numpy as jnp
21+
import tensorflow as tf
22+
2223
if isinstance(t, torch.Tensor):
2324
return t.numpy()
2425
if isinstance(t, tf.Tensor) or isinstance(t, tf.Variable):
@@ -28,7 +29,7 @@ def tensor_to_numpy(t: Tensor) -> Array:
2829
return t
2930

3031

31-
def general_args_to_numpy(args: Any, same_pytree: bool = False) -> Any:
32+
def general_args_to_numpy(args: Any, same_pytree: bool = True) -> Any:
3233
res = []
3334
alone = False
3435
if not (isinstance(args, tuple) or isinstance(args, list)):
@@ -46,7 +47,7 @@ def general_args_to_numpy(args: Any, same_pytree: bool = False) -> Any:
4647

4748

4849
def numpy_args_to_backend(
49-
args: Any, same_pytree: bool = False, dtype: Any = None, target_backend: Any = None
50+
args: Any, same_pytree: bool = True, dtype: Any = None, target_backend: Any = None
5051
) -> Any:
5152
# TODO(@refraction-ray): switch same_pytree default to True
5253
if target_backend is None:
@@ -82,13 +83,20 @@ def is_sequence(x: Any) -> bool:
8283
return False
8384

8485

85-
def torch_interface(fun: Callable[..., Any]) -> Callable[..., Any]:
86+
def torch_interface(fun: Callable[..., Any], jit: bool = False) -> Callable[..., Any]:
87+
def vjp_fun(x: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
88+
return backend.vjp(fun, x, v) # type: ignore
89+
90+
if jit is True:
91+
fun = backend.jit(fun)
92+
vjp_fun = backend.jit(vjp_fun)
93+
8694
class F(torch.autograd.Function): # type: ignore
8795
@staticmethod
8896
def forward(ctx: Any, *x: Any) -> Any: # type: ignore
8997
ctx.xdtype = [xi.dtype for xi in x]
90-
x = general_args_to_numpy(x, same_pytree=True)
91-
x = numpy_args_to_backend(x, same_pytree=True)
98+
x = general_args_to_numpy(x)
99+
x = numpy_args_to_backend(x)
92100
y = fun(*x)
93101
if not is_sequence(y):
94102
ctx.ydtype = [y.dtype]
@@ -99,25 +107,23 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
99107
else:
100108
ctx.x = x
101109
y = numpy_args_to_backend(
102-
general_args_to_numpy(y, same_pytree=True),
103-
same_pytree=True,
110+
general_args_to_numpy(y),
104111
target_backend="pytorch",
105112
)
106113
return y
107114

108115
@staticmethod
109116
def backward(ctx: Any, *grad_y: Any) -> Any:
110-
grad_y = general_args_to_numpy(grad_y, same_pytree=True)
117+
grad_y = general_args_to_numpy(grad_y)
111118
grad_y = numpy_args_to_backend(
112-
grad_y, dtype=[d for d in ctx.ydtype], same_pytree=True
119+
grad_y, dtype=[d for d in ctx.ydtype]
113120
) # backend.dtype
114121
if len(grad_y) == 1:
115122
grad_y = grad_y[0]
116-
_, g = backend.vjp(fun, ctx.x, grad_y)
123+
_, g = vjp_fun(ctx.x, grad_y)
117124
# a redundancy due to the current vjp API
118125
r = numpy_args_to_backend(
119-
general_args_to_numpy(g, same_pytree=True),
120-
same_pytree=True,
126+
general_args_to_numpy(g),
121127
dtype=[d for d in ctx.xdtype], # torchdtype
122128
target_backend="pytorch",
123129
)

tensorcircuit/mpscircuit.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def split_tensor(
3535
:type max_truncation_err: float, optional
3636
:param relative: Multiply `max_truncation_err` with the largest singular value.
3737
:type relative: bool, optional
38+
:return: two tensors after splitting
39+
:rtype: Tuple[Tensor, Tensor]
3840
"""
3941
# The behavior is a little bit different from tn.split_node because it explicitly requires a center
4042
svd = (max_truncation_err is not None) or (max_singular_values is not None)
@@ -398,7 +400,7 @@ def from_wavefunction(
398400
"""
399401
wavefunction = backend.reshape(wavefunction, (-1, 1))
400402
tensors: List[Tensor] = []
401-
while True:
403+
while True: # not jittable
402404
nright = wavefunction.shape[1]
403405
wavefunction = backend.reshape(wavefunction, (-1, nright * 2))
404406
wavefunction, Q = split_tensor(
@@ -599,10 +601,10 @@ def expectation_two_gates_correlations(
599601
:type gate1: Gate
600602
:param gate2: second gate to be applied
601603
:type gate2: Gate
602-
:param site: qubit index of the first gate
603-
:type site: int
604-
:param site: qubit index of the second gate
605-
:type site: int
604+
:param site1: qubit index of the first gate
605+
:type site1: int
606+
:param site2: qubit index of the second gate
607+
:type site2: int
606608
"""
607609
value = self._mps.measure_two_body_correlator(
608610
gate1.tensor, gate2.tensor, site1, [site2]

tests/test_interfaces.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def f2(paramzz, paramx):
8181
)
8282
return tc.backend.real(loss1), tc.backend.real(loss2)
8383

84-
f2_torch = interfaces.torch_interface(f2)
84+
f2_torch = interfaces.torch_interface(f2, jit=True)
8585

8686
paramzz = torch.ones([2, n], requires_grad=True)
8787
paramx = torch.ones([2, n], requires_grad=True)

0 commit comments

Comments
 (0)