Skip to content

Commit d13a81e

Browse files
add mipt examples
1 parent da7f32d commit d13a81e

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

examples/mipt.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""
2+
demo example of mipt in tc style
3+
"""
4+
from functools import partial
5+
import time
6+
import numpy as np
7+
from scipy import stats
8+
import tensorcircuit as tc
9+
10+
K = tc.set_backend("jax")
11+
# tf backend is slow (at least on cpu)
12+
13+
14+
@partial(K.jit, static_argnums=(2, 3, 4))
15+
def circuit_output(random_matrix, status, n, d, p):
16+
"""
17+
mipt circuit
18+
19+
:param random_matrix: a float or complex tensor containing 4*4 random haar matrix wth size [d*n, 4, 4]
20+
:type random_matrix: _type_
21+
:param status: a int tensor with element in 0 or 1 or 2 (no meausrement) with size d*n
22+
:type status: _type_
23+
:param n: number of qubits
24+
:type n: _type_
25+
:param d: number of depth
26+
:type d: _type_
27+
:param p: measurement ratio
28+
:type p: float
29+
:return: output state
30+
"""
31+
random_matrix = K.reshape(random_matrix, [d, n, 4, 4])
32+
status = K.reshape(status, [d, n])
33+
inputs = None
34+
for j in range(d):
35+
if inputs is None:
36+
c = tc.Circuit(n)
37+
else:
38+
c = tc.Circuit(n, inputs=inputs)
39+
for i in range(0, n, 2):
40+
c.unitary(i, (i + 1) % n, unitary=random_matrix[j, i])
41+
for i in range(1, n, 2):
42+
c.unitary(i, (i + 1) % n, unitary=random_matrix[j, i])
43+
inputs = c.state()
44+
c = tc.Circuit(n, inputs=inputs)
45+
for i in range(n):
46+
c.general_kraus(
47+
[
48+
np.sqrt(p) * np.array([[1.0, 0], [0, 0]]),
49+
np.sqrt(p) * np.array([[0, 0], [0, 1.0]]),
50+
np.sqrt(1 - p) * np.eye(2),
51+
],
52+
i,
53+
status=status[j, i],
54+
)
55+
inputs = c.state()
56+
c = tc.Circuit(n, inputs=inputs)
57+
inputs = c.state()
58+
inputs /= K.norm(inputs)
59+
return inputs
60+
61+
62+
@partial(K.jit, static_argnums=(2, 3, 4))
63+
def cals(random_matrix, status, n, d, p):
64+
state = circuit_output(random_matrix, status, n, d, p)
65+
rho = tc.quantum.reduced_density_matrix(state, cut=[i for i in range(n // 2)])
66+
return tc.quantum.entropy(rho), tc.quantum.renyi_entropy(rho, k=2)
67+
68+
69+
if __name__ == "__main__":
70+
n = 12
71+
d = 12
72+
st = np.random.uniform(size=[d * n])
73+
## assume all X gate instead
74+
rm = [stats.unitary_group.rvs(4) for _ in range(d * n)]
75+
rm = [r / np.linalg.det(r) for r in rm]
76+
rm = np.stack(rm)
77+
time0 = time.time()
78+
print(cals(rm, st, n, d, 0.1))
79+
time1 = time.time()
80+
st = np.random.uniform(size=[d * n])
81+
print(cals(rm, st, n, d, 0.1))
82+
time2 = time.time()
83+
print(f"compiling time {time1-time0}, running time {time2-time1}")

0 commit comments

Comments
 (0)