Skip to content

Commit deaa986

Browse files
add jax.checkpoint example v1
1 parent bd62626 commit deaa986

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

examples/checkpoint_memsave.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
some possible attempts to save memory from state-like simulator with checkpoint tricks
3+
"""
4+
5+
from functools import partial
6+
from itertools import product
7+
import time
8+
import sys
9+
import logging
10+
11+
import numpy as np
12+
import jax
13+
from jax import numpy as jnp
14+
15+
logger = logging.getLogger("tensorcircuit")
16+
logger.setLevel(logging.INFO)
17+
ch = logging.StreamHandler()
18+
ch.setLevel(logging.DEBUG)
19+
logger.addHandler(ch)
20+
21+
sys.path.insert(0, "../")
22+
sys.setrecursionlimit(10000)
23+
24+
import tensorcircuit as tc
25+
import cotengra as ctg
26+
from tensorcircuit import keras
27+
28+
optr = ctg.ReusableHyperOptimizer(
29+
methods=["greedy", "kahypar"],
30+
parallel=True,
31+
minimize="write",
32+
max_time=15,
33+
max_repeats=512,
34+
progbar=True,
35+
)
36+
tc.set_contractor("custom", optimizer=optr, preprocessing=True)
37+
tc.set_dtype("complex64")
38+
tc.set_backend("jax")
39+
40+
41+
nwires, nlayers = 10, 36
42+
sn = int(np.sqrt(nlayers))
43+
44+
45+
def recursive_checkpoint(funs):
46+
if len(funs) == 1:
47+
return funs[0]
48+
elif len(funs) == 2:
49+
f1, f2 = funs
50+
return lambda s, param: f1(
51+
f2(s, param[: len(param) // 2]), param[len(param) // 2 :]
52+
)
53+
else:
54+
f1 = recursive_checkpoint(funs[len(funs) // 2 :])
55+
f2 = recursive_checkpoint(funs[: len(funs) // 2])
56+
return lambda s, param: f1(
57+
jax.checkpoint(f2)(s, param[: len(param) // 2]), param[len(param) // 2 :]
58+
)
59+
60+
61+
# not suggest in general for recursive checkpoint: too slow for staging (compiling)
62+
63+
"""
64+
test case:
65+
def f(s, param):
66+
return s + param
67+
fc = recursive_checkpoint([f for _ in range(100)])
68+
print(fc(jnp.zeros([2]), jnp.array([[i, i] for i in range(100)])))
69+
"""
70+
71+
72+
@jax.checkpoint
73+
@jax.jit
74+
def zzxlayer(s, param):
75+
c = tc.Circuit(nwires, inputs=s)
76+
for i in range(0, nwires):
77+
c.exp1(
78+
i,
79+
(i + 1) % nwires,
80+
theta=param[0, i],
81+
unitary=tc.gates._zz_matrix,
82+
)
83+
for i in range(nwires):
84+
c.rx(i, theta=param[0, nwires + i])
85+
return c.state()
86+
87+
88+
@jax.checkpoint
89+
@jax.jit
90+
def zzxsqrtlayer(s, param):
91+
for i in range(sn):
92+
s = zzxlayer(s, param[i : i + 1])
93+
return s
94+
95+
96+
@jax.jit
97+
def totallayer(s, param):
98+
for i in range(sn):
99+
s = zzxsqrtlayer(s, param[i * sn : (i + 1) * sn])
100+
return s
101+
102+
103+
def vqe_forward(param):
104+
s = tc.backend.ones([2 ** nwires])
105+
s /= tc.backend.norm(s)
106+
s = totallayer(s, param)
107+
e = tc.expectation((tc.gates.x(), [1]), ket=s)
108+
return tc.backend.real(e)
109+
110+
111+
def profile(tries=3):
112+
time0 = time.time()
113+
tc_vag = tc.backend.jit(tc.backend.value_and_grad(vqe_forward))
114+
param = tc.backend.cast(tc.backend.ones([nlayers, 2 * nwires]), "complex64")
115+
print(tc_vag(param))
116+
117+
time1 = time.time()
118+
for i in range(tries):
119+
print(tc_vag(param)[0])
120+
121+
time2 = time.time()
122+
print(time1 - time0, (time2 - time1) / tries)
123+
124+
125+
profile()

0 commit comments

Comments
 (0)