Skip to content

Commit 1de1f17

Browse files
add general scan method acc examples
1 parent 05efada commit 1de1f17

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

examples/hea_scan_jit_acc.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
reducing jit compiling time by general scan magic
3+
"""
4+
5+
import numpy as np
6+
import tensorcircuit as tc
7+
8+
n = 10
9+
nlayers = 16
10+
param_np = np.random.normal(size=[nlayers, n, 2])
11+
12+
for backend in ["tensorflow", "jax"]:
13+
with tc.runtime_backend(backend) as K:
14+
print("running %s" % K.name)
15+
16+
def energy_reference(param, n, nlayers):
17+
c = tc.Circuit(n)
18+
for i in range(n):
19+
c.h(i)
20+
for i in range(nlayers):
21+
for j in range(n - 1):
22+
c.rzz(j, j + 1, theta=param[i, j, 0])
23+
for j in range(n):
24+
c.rx(j, theta=param[i, j, 1])
25+
return K.real(c.expectation_ps(z=[0, 1]) + c.expectation_ps(x=[2]))
26+
27+
vg_reference = K.jit(
28+
K.value_and_grad(energy_reference, argnums=0), static_argnums=(1, 2)
29+
)
30+
31+
# a jit efficient way to utilize scan
32+
33+
def energy(param, n, nlayers, each):
34+
def loop_f(s_, param_):
35+
c_ = tc.Circuit(n, inputs=s_)
36+
for i in range(each):
37+
for j in range(n - 1):
38+
c_.rzz(j, j + 1, theta=param_[i, j, 0])
39+
for j in range(n):
40+
c_.rx(j, theta=param_[i, j, 1])
41+
s_ = c_.state()
42+
return s_
43+
44+
c = tc.Circuit(n)
45+
for i in range(n):
46+
c.h(i)
47+
s = c.state()
48+
s1 = K.scan(loop_f, K.reshape(param, [nlayers // each, each, n, 2]), s)
49+
c1 = tc.Circuit(n, inputs=s1)
50+
return K.real(c1.expectation_ps(z=[0, 1]) + c1.expectation_ps(x=[2]))
51+
52+
vg = K.jit(
53+
K.value_and_grad(energy, argnums=0),
54+
static_argnums=(1, 2, 3),
55+
jit_compile=True,
56+
)
57+
# set to False can improve compile time for tf
58+
59+
param = K.convert_to_tensor(param_np)
60+
61+
for each in [1, 2, 4]:
62+
print(" scan impl with each=%s" % str(each))
63+
r1 = tc.utils.benchmark(vg, param, n, nlayers, each)
64+
print(r1[0][0])
65+
66+
print(" plain impl")
67+
r0 = tc.utils.benchmark(vg_reference, param, n, nlayers) # too slow
68+
np.testing.assert_allclose(r0[0][0], r1[0][0], atol=1e-5)
69+
np.testing.assert_allclose(r0[0][1], r1[0][1], atol=1e-5)
70+
# correctness check
71+
72+
73+
# jit_compile=True icrease runtime while degrades jit time for tensorflow
74+
# and in general jax improves better with scan methodology,
75+
# both compile time and running time can outperform tf

examples/jax_scan_jit_acc.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""
2-
reducing jax jit compiling time by some magic
2+
reducing jax jit compiling time by some magic:
3+
for backend agnostic but similar approach,
4+
see `hea_scan_jit_acc.py`
35
"""
46

57
import numpy as np

0 commit comments

Comments
 (0)