
Commit 50c68bf

Committed Mar 24, 2022
add extra size vqe example using mpo hamiltonian
1 parent 5f06129 commit 50c68bf

File tree: 1 file changed, +146 −0 lines

examples/vqe_extra_mpo.py

@@ -0,0 +1,146 @@
"""
Demonstration of TFIM VQE with extra size in MPO formulation
"""

import time
import logging
import sys
import numpy as np

logger = logging.getLogger("tensorcircuit")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
logger.addHandler(ch)

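# raised recursion limit: contracting a deep 150-qubit network presumably
# recurses past Python's default limit of 1000 (an inference; the commit
# itself does not explain the choice)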
sys.setrecursionlimit(10000)

import tensorflow as tf
import tensornetwork as tn
import cotengra as ctg

import optax
import tensorcircuit as tc

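# hyper-optimized contraction-path search: candidate trees are generated by
# greedy and kahypar partitioning, scored by the combined flops+write "combo"
# objective, searched in parallel on a ray cluster, and cached for reuse
# across contractions of identical geometry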
opt = ctg.ReusableHyperOptimizer(
    methods=["greedy", "kahypar"],
    parallel="ray",
    minimize="combo",
    max_time=360,
    max_repeats=4096,
    progbar=True,
)


def opt_reconf(inputs, output, size, **kws):
    tree = opt.search(inputs, output, size)
    tree_r = tree.subtree_reconfigure_forest(
        parallel="ray",
        progbar=True,
        num_trees=20,
        num_restarts=20,
        subtree_weight_what=("size",),
    )
    return tree_r.path()


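# register the two-stage path finder above as the global contractor;
# preprocessing=True lets tensorcircuit simplify the circuit network
# (e.g. absorbing single-qubit gates) before the path search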
tc.set_contractor("custom", optimizer=opt_reconf, preprocessing=True)
tc.set_dtype("complex64")
tc.set_backend("tensorflow")
# jax backend is incompatible with keras.save

dtype = np.complex64

nwires, nlayers = 150, 7

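# open-boundary TFIM Hamiltonian, H = sum_i Jx[i] X_i X_{i+1} + sum_i Bz[i] Z_i,
# assembled as a finite MPO in tensornetwork and converted below into a
# tensorcircuit QuOperator via tn2qop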
Jx = np.array([1.0 for _ in range(nwires - 1)])  # strength of xx interaction (OBC)
Bz = np.array([-1.0 for _ in range(nwires)])  # strength of transverse field
hamiltonian_mpo = tn.matrixproductstates.mpo.FiniteTFI(
    Jx, Bz, dtype=dtype
)  # matrix product operator
hamiltonian_mpo = tc.quantum.tn2qop(hamiltonian_mpo)

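# ansatz: one layer of Hadamards, then nlayers blocks of parameterized XX
# two-qubit gates on neighboring pairs followed by RZ-RY-RZ single-qubit
# rotations, consuming four rows of param per block (shape [4 * nlayers, nwires]);
# split_conf truncates each two-qubit gate to 2 singular values on construction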
def vqe_forward(param):
    print("compiling")
    split_conf = {
        "max_singular_values": 2,
        "fixed_choice": 1,
    }
    c = tc.Circuit(nwires, split=split_conf)
    for i in range(nwires):
        c.H(i)
    for j in range(nlayers):
        for i in range(0, nwires - 1):
            c.exp1(
                i,
                (i + 1) % nwires,
                theta=param[4 * j, i],
                unitary=tc.gates._xx_matrix,
            )

        for i in range(nwires):
            c.rz(i, theta=param[4 * j + 1, i])
        for i in range(nwires):
            c.ry(i, theta=param[4 * j + 2, i])
        for i in range(nwires):
            c.rz(i, theta=param[4 * j + 3, i])
    return tc.templates.measurements.mpo_expectation(c, hamiltonian_mpo)


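# staging workflow: set refresh = True on the first run so the jitted
# value-and-grad function is traced once and saved to disk; subsequent runs
# reload it via tc.keras.load_func and skip the expensive staging step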
if __name__ == "__main__":
    refresh = False

    time0 = time.time()
    if refresh:
        tc_vag = tf.function(
            tc.backend.value_and_grad(vqe_forward),
            input_signature=[tf.TensorSpec([4 * nlayers, nwires], tf.float32)],
        )
        tc.keras.save_func(tc_vag, "./funcs/%s_%s_tfim_mpo" % (nwires, nlayers))
        time1 = time.time()
        print("staging time: ", time1 - time0)

    tc_vag_loaded = tc.keras.load_func("./funcs/%s_%s_tfim_mpo" % (nwires, nlayers))

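    # two-stage schedule: Adam with lr1 for the first `switch` steps,
    # then plain SGD with lr2 for the remaining steps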
    lr1 = 0.008
    lr2 = 0.06
    steps = 2000
    switch = 400
    debug_steps = 20

    if tc.backend.name == "jax":
        opt = tc.backend.optimizer(optax.adam(lr1))
        opt2 = tc.backend.optimizer(optax.sgd(lr2))
    else:
        opt = tc.backend.optimizer(tf.keras.optimizers.Adam(lr1))
        opt2 = tc.backend.optimizer(tf.keras.optimizers.SGD(lr2))

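    # training loop: tc.backend.optimizer gives optax and keras optimizers the
    # same update(grads, params) -> new params interface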
    times = []
    param = tc.backend.implicit_randn(stddev=0.1, shape=[4 * nlayers, nwires])

    for j in range(steps):
        loss, gr = tc_vag_loaded(param)
        if j < switch:
            param = opt.update(gr, param)
        else:
            if j == switch:
                print("switching the optimizer")
            param = opt2.update(gr, param)
        if j % debug_steps == 0 or j == steps - 1:
            times.append(time.time())
            print("loss", tc.backend.numpy(loss))
            if j > 0:
                print("running time:", (times[-1] - times[0]) / j)


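# the commented block below sketches how a DMRG reference energy could be
# obtained with quimb for comparison with the converged VQE loss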
"""
# Baseline code: reference energy obtained from DMRG using quimb

import quimb

h = quimb.tensor.tensor_gen.MPO_ham_ising(nwires, 4, 2, cyclic=False)
dmrg = quimb.tensor.tensor_dmrg.DMRG2(h, bond_dims=[10, 20, 100, 100, 200], cutoffs=1e-10)
dmrg.solve(tol=1e-9, verbosity=1)  # may require calling repeatedly until convergence
"""

0 commit comments
