11
11
12
12
try :
13
13
import torch
14
+
15
+ is_torch = True
14
16
except ImportError :
15
- pytest . skip ( "torch not available" , allow_module_level = True )
17
+ is_torch = False
16
18
17
19
import numpy as np
18
20
import tensorcircuit as tc
19
- from tensorcircuit import interfaces
20
21
21
22
23
+ @pytest .mark .skipif (is_torch is False , reason = "torch not installed" )
22
24
@pytest .mark .parametrize ("backend" , [lf ("tfb" ), lf ("jaxb" )])
23
25
def test_torch_interface (backend ):
24
26
n = 4
@@ -38,7 +40,7 @@ def f(param):
38
40
39
41
f_jit = tc .backend .jit (f )
40
42
41
- f_jit_torch = interfaces .torch_interface (f_jit )
43
+ f_jit_torch = tc . interfaces .torch_interface (f_jit )
42
44
43
45
param = torch .ones ([4 , n ], requires_grad = True )
44
46
l = f_jit_torch (param )
@@ -76,7 +78,7 @@ def f2(paramzz, paramx):
76
78
)
77
79
return tc .backend .real (loss1 ), tc .backend .real (loss2 )
78
80
79
- f2_torch = interfaces .torch_interface (f2 , jit = True )
81
+ f2_torch = tc . interfaces .torch_interface (f2 , jit = True )
80
82
81
83
paramzz = torch .ones ([2 , n ], requires_grad = True )
82
84
paramx = torch .ones ([2 , n ], requires_grad = True )
@@ -92,7 +94,7 @@ def f2(paramzz, paramx):
92
94
def f3 (x ):
93
95
return tc .backend .real (x ** 2 )
94
96
95
- f3_torch = interfaces .torch_interface (f3 )
97
+ f3_torch = tc . interfaces .torch_interface (f3 )
96
98
param3 = torch .ones ([2 ], dtype = torch .complex64 , requires_grad = True )
97
99
l3 = f3_torch (param3 )
98
100
l3 = torch .sum (l3 )
@@ -120,7 +122,7 @@ def f(param):
120
122
)
121
123
return tc .backend .real (loss )
122
124
123
- f_scipy = interfaces .scipy_optimize_interface (f , shape = [2 , n ])
125
+ f_scipy = tc . interfaces .scipy_optimize_interface (f , shape = [2 , n ])
124
126
r = optimize .minimize (f_scipy , np .zeros ([2 * n ]), method = "L-BFGS-B" , jac = True )
125
127
# L-BFGS-B may has issue with float32
126
128
# see: https://github.com/scipy/scipy/issues/5832
0 commit comments