Skip to content

Commit 3733891

Browse files
make quantum and interfaces in init
1 parent 8fe55a1 commit 3733891

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

tensorcircuit/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@
1818
from .densitymatrix import DMCircuit
1919
from .densitymatrix2 import DMCircuit2
2020
from .gates import num_to_tensor, array_to_tensor
21+
from . import interfaces
2122
from . import templates
23+
from . import quantum

tensorcircuit/interfaces.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Any, Callable, Tuple, Optional
66

77
import numpy as np
8-
import torch
98

109
from .cons import backend, dtypestr
1110
from .backends import get_backend # type: ignore
@@ -78,6 +77,8 @@ def is_sequence(x: Any) -> bool:
7877

7978

8079
def torch_interface(fun: Callable[..., Any], jit: bool = False) -> Callable[..., Any]:
80+
import torch
81+
8182
def vjp_fun(x: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
8283
return backend.vjp(fun, x, v) # type: ignore
8384

tests/test_interfaces.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111

1212
try:
1313
import torch
14+
15+
is_torch = True
1416
except ImportError:
15-
pytest.skip("torch not available", allow_module_level=True)
17+
is_torch = False
1618

1719
import numpy as np
1820
import tensorcircuit as tc
19-
from tensorcircuit import interfaces
2021

2122

23+
@pytest.mark.skipif(is_torch is False, reason="torch not installed")
2224
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
2325
def test_torch_interface(backend):
2426
n = 4
@@ -38,7 +40,7 @@ def f(param):
3840

3941
f_jit = tc.backend.jit(f)
4042

41-
f_jit_torch = interfaces.torch_interface(f_jit)
43+
f_jit_torch = tc.interfaces.torch_interface(f_jit)
4244

4345
param = torch.ones([4, n], requires_grad=True)
4446
l = f_jit_torch(param)
@@ -76,7 +78,7 @@ def f2(paramzz, paramx):
7678
)
7779
return tc.backend.real(loss1), tc.backend.real(loss2)
7880

79-
f2_torch = interfaces.torch_interface(f2, jit=True)
81+
f2_torch = tc.interfaces.torch_interface(f2, jit=True)
8082

8183
paramzz = torch.ones([2, n], requires_grad=True)
8284
paramx = torch.ones([2, n], requires_grad=True)
@@ -92,7 +94,7 @@ def f2(paramzz, paramx):
9294
def f3(x):
9395
return tc.backend.real(x ** 2)
9496

95-
f3_torch = interfaces.torch_interface(f3)
97+
f3_torch = tc.interfaces.torch_interface(f3)
9698
param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
9799
l3 = f3_torch(param3)
98100
l3 = torch.sum(l3)
@@ -120,7 +122,7 @@ def f(param):
120122
)
121123
return tc.backend.real(loss)
122124

123-
f_scipy = interfaces.scipy_optimize_interface(f, shape=[2, n])
125+
f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2, n])
124126
r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="L-BFGS-B", jac=True)
125127
# L-BFGS-B may has issue with float32
126128
# see: https://github.com/scipy/scipy/issues/5832

0 commit comments

Comments
 (0)