Skip to content

Commit ec4cdad

Browse files
add with_prob in general_krais
1 parent 3af21a4 commit ec4cdad

File tree

4 files changed

+177
-3
lines changed

4 files changed

+177
-3
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
- Add multiple GPU VQE examples using jax pmap
88

9+
- Add `with_prob` option to `general_kraus` so that the probability of each option can be returned together
10+
911
- Add benchmark example showcasing new way of implementing matrix product using vmap
1012

1113
## 0.10.0

examples/mipt_pideal.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
demo example of mipt in tc style, with ideal p for each history trajectory
3+
p is also jittable now, change parameter p doesn't trigger recompiling
4+
"""
5+
6+
from functools import partial
7+
import time
8+
import numpy as np
9+
from scipy import stats
10+
import tensorcircuit as tc
11+
12+
K = tc.set_backend("jax")
13+
tc.set_dtype("complex128")
14+
# tf backend is slow (at least on cpu)
15+
16+
17+
def delete2(pick, plist):
18+
# pick = 0, 1 : return plist[pick]/(plist[0]+plist[1])
19+
# pick = 2: return 1
20+
indicator = (K.sign(1.5 - pick) + 1) / 2 # 0,1 : 1, 2: 0
21+
p = 0
22+
p += 1 - indicator
23+
p += indicator / (plist[0] + plist[1]) * (plist[0] * (1 - pick) + plist[1] * pick)
24+
return p
25+
26+
27+
@partial(K.jit, static_argnums=(2, 3))
28+
def circuit_output(random_matrix, status, n, d, p):
29+
"""
30+
mipt circuit
31+
32+
:param random_matrix: a float or complex tensor containing 4*4 random haar matrix wth size [d*n, 4, 4]
33+
:type random_matrix: _type_
34+
:param status: a int tensor with element in 0 or 1 or 2 (no meausrement) with size d*n
35+
:type status: _type_
36+
:param n: number of qubits
37+
:type n: _type_
38+
:param d: number of depth
39+
:type d: _type_
40+
:param p: measurement ratio
41+
:type p: float
42+
:return: output state
43+
"""
44+
random_matrix = K.reshape(random_matrix, [d, n, 4, 4])
45+
status = K.reshape(status, [d, n])
46+
inputs = None
47+
bs_history = []
48+
prob_history = []
49+
for j in range(d):
50+
if inputs is None:
51+
c = tc.Circuit(n)
52+
else:
53+
c = tc.Circuit(n, inputs=inputs)
54+
for i in range(0, n, 2):
55+
c.unitary(i, (i + 1) % n, unitary=random_matrix[j, i])
56+
for i in range(1, n, 2):
57+
c.unitary(i, (i + 1) % n, unitary=random_matrix[j, i])
58+
inputs = c.state()
59+
c = tc.Circuit(n, inputs=inputs)
60+
for i in range(n):
61+
pick, plist = c.general_kraus(
62+
[
63+
K.sqrt(p) * K.convert_to_tensor(np.array([[1.0, 0], [0, 0]])),
64+
K.sqrt(p) * K.convert_to_tensor(np.array([[0, 0], [0, 1.0]])),
65+
K.sqrt(1 - p) * K.eye(2),
66+
],
67+
i,
68+
status=status[j, i],
69+
with_prob=True,
70+
)
71+
bs_history.append(pick)
72+
prob_history.append(delete2(pick, plist))
73+
inputs = c.state()
74+
c = tc.Circuit(n, inputs=inputs)
75+
inputs = c.state()
76+
inputs /= K.norm(inputs)
77+
bs_history = K.stack(bs_history)
78+
prob_history = K.stack(prob_history)
79+
return inputs, bs_history, prob_history, K.sum(K.log(prob_history + 1e-11))
80+
81+
82+
@partial(K.jit, static_argnums=(2, 3))
83+
def cals(random_matrix, status, n, d, p):
84+
state, bs_history, prob_history, prob = circuit_output(
85+
random_matrix, status, n, d, p
86+
)
87+
rho = tc.quantum.reduced_density_matrix(state, cut=[i for i in range(n // 2)])
88+
return (
89+
tc.quantum.entropy(rho),
90+
tc.quantum.renyi_entropy(rho, k=2),
91+
bs_history,
92+
prob_history,
93+
prob,
94+
)
95+
96+
97+
if __name__ == "__main__":
98+
n = 12
99+
d = 12
100+
st = np.random.uniform(size=[d * n])
101+
## assume all X gate instead
102+
rm = [stats.unitary_group.rvs(4) for _ in range(d * n)]
103+
rm = [r / np.linalg.det(r) for r in rm]
104+
rm = np.stack(rm)
105+
time0 = time.time()
106+
print(cals(rm, st, n, d, 0.6))
107+
time1 = time.time()
108+
st = np.random.uniform(size=[d * n])
109+
print(cals(rm, st, n, d, 0.1))
110+
time2 = time.time()
111+
print(f"compiling time {time1-time0}, running time {time2-time1}")

tensorcircuit/circuit.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ def _general_kraus_2(
507507
kraus: Sequence[Gate],
508508
*index: int,
509509
status: Optional[float] = None,
510+
with_prob: bool = False,
510511
name: Optional[str] = None,
511512
) -> Tensor:
512513
# the graph building time is frustratingly slow, several minutes
@@ -554,16 +555,20 @@ def calculate_kraus_p(i: int) -> Tensor:
554555
k / backend.cast(backend.sqrt(w) + eps, dtypestr)
555556
for w, k in zip(prob, kraus_tensor)
556557
]
557-
558-
return self.unitary_kraus(
558+
pick = self.unitary_kraus(
559559
new_kraus, *index, prob=prob, status=status, name=name
560560
)
561+
if with_prob is False:
562+
return pick
563+
else:
564+
return pick, prob
561565

562566
def general_kraus(
563567
self,
564568
kraus: Sequence[Gate],
565569
*index: int,
566570
status: Optional[float] = None,
571+
with_prob: bool = False,
567572
name: Optional[str] = None,
568573
) -> Tensor:
569574
"""
@@ -583,7 +588,9 @@ def general_kraus(
583588
when the random number will be generated automatically
584589
:type status: Optional[float], optional
585590
"""
586-
return self._general_kraus_2(kraus, *index, status=status, name=name)
591+
return self._general_kraus_2(
592+
kraus, *index, status=status, with_prob=with_prob, name=name
593+
)
587594

588595
apply_general_kraus = general_kraus
589596

tests/test_circuit.py

+54
Original file line numberDiff line numberDiff line change
@@ -1573,3 +1573,57 @@ def test_fancy_circuit_indexing(backend):
15731573
assert c.gate_count("h") == 4
15741574
assert c.gate_count("rzz") == 2
15751575
assert c.gate_count("rxx") == 3
1576+
1577+
1578+
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("npb")])
1579+
def test_general_kraus(backend):
1580+
c = tc.Circuit(2)
1581+
c.h([0, 1])
1582+
p = 0.5
1583+
status = [0.3, 0.8]
1584+
rs = []
1585+
for i in range(2):
1586+
rs.append(
1587+
c.general_kraus(
1588+
[
1589+
np.sqrt(p) * np.array([[1.0, 0], [0, 0]]),
1590+
np.sqrt(p) * np.array([[0, 0], [0, 1.0]]),
1591+
np.sqrt(1 - p) * np.eye(2),
1592+
],
1593+
i,
1594+
status=status[i],
1595+
)
1596+
)
1597+
np.testing.assert_allclose(rs[0], 1)
1598+
np.testing.assert_allclose(rs[1], 2)
1599+
np.testing.assert_allclose(c.expectation_ps(z=[0]), -1, atol=1e-5)
1600+
np.testing.assert_allclose(c.expectation_ps(z=[1]), 0, atol=1e-5)
1601+
1602+
1603+
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("npb")])
1604+
def test_general_kraus_with_prob(backend):
1605+
c = tc.Circuit(2)
1606+
c.h([0, 1])
1607+
p = 0.5
1608+
status = [0.3, 0.8]
1609+
rs = []
1610+
for i in range(2):
1611+
rs.append(
1612+
c.general_kraus(
1613+
[
1614+
np.sqrt(p) * np.array([[1.0, 0], [0, 0]]),
1615+
np.sqrt(p) * np.array([[0, 0], [0, 1.0]]),
1616+
np.sqrt(1 - p) * np.eye(2),
1617+
],
1618+
i,
1619+
status=status[i],
1620+
with_prob=True,
1621+
)
1622+
)
1623+
np.testing.assert_allclose(rs[0][0], 1)
1624+
np.testing.assert_allclose(rs[1][0], 2)
1625+
np.testing.assert_allclose(c.expectation_ps(z=[0]), -1, atol=1e-5)
1626+
np.testing.assert_allclose(c.expectation_ps(z=[1]), 0, atol=1e-5)
1627+
np.testing.assert_allclose(rs[0][1], [0.25, 0.25, 0.5], atol=1e-5)
1628+
np.testing.assert_allclose(rs[1][1], [0.25, 0.25, 0.5], atol=1e-5)
1629+
np.testing.assert_allclose(tc.backend.norm(c.state()), 1, atol=1e-5)

0 commit comments

Comments
 (0)