Commit cc5120b

committed Aug 7, 2022
batch and jit support for sample_expectation_ps
1 parent: 3c3ce72

13 files changed: +198 −34 lines

CHANGELOG.md (+4)

@@ -14,10 +14,14 @@
 
 - Add `parameter_shift.py` script in examples
 
+- Add jit support and external random management for `tc.quantum.measurement_counts`
+
 ### Changed
 
 - `rxx`, `ryy`, `rzz` gates now have a 1/2 factor before theta, consistent with `rx`, `ry`, `rz` gates (breaking change)
 
+- Replace the `status` argument of the `sample` method with `random_generator` (new convention: `status` for 0/1 uniform randomness, `random_generator` for a random key) (breaking change)
+
 ## 0.3.1
 
 ### Added
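
A usage sketch of the new `sample` convention (a minimal example of mine, mirroring the updated tests below; Jax backend assumed):

import tensorcircuit as tc

tc.set_backend("jax")

c = tc.Circuit(2)
c.h(0)
c.cnot(0, 1)
# new convention: pass a backend random key via `random_generator` (not `status`)
key = tc.backend.get_random_state(42)
print(c.sample(batch=8, allow_state=True, random_generator=key))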

docs/source/advance.rst (+3 −1)

@@ -171,4 +171,6 @@ And a more neat approach to achieve this is as follows:
 
 It is worth noting that since ``Circuit.unitary_kraus`` and ``Circuit.general_kraus`` call the ``implicit_rand*`` API, the correct usage of these APIs is the same as above.
 
-One may wonder why random numbers are dealt with in such a complicated way; please refer to the `Jax design note <https://github.com/google/jax/blob/main/docs/design_notes/prng.md>`_ for some hints.
+One may wonder why random numbers are dealt with in such a complicated way; please refer to the `Jax design note <https://github.com/google/jax/blob/main/docs/design_notes/prng.md>`_ for some hints.
+
+If vmap is also involved apart from jit, I currently find no way to maintain backend agnosticity, as TensorFlow seems to have no support for vmap over random keys (ping me on GitHub if you think you have a way to do this). I strongly recommend that users use the Jax backend in the vmap+random setup.
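
To illustrate the point above with plain Jax (a self-contained sketch of mine, independent of this repo): vmap can map over a batch of PRNG keys just like any other array axis, which is what TensorFlow lacks a counterpart for.

import jax
import jax.numpy as jnp

def noisy_cos(theta, key):
    # toy stand-in for a finite-shot expectation: exact value plus sampling noise
    return jnp.cos(theta) + 0.01 * jax.random.normal(key)

thetas = jnp.linspace(0.0, 1.0, 4)
keys = jax.random.split(jax.random.PRNGKey(0), 4)  # one independent key per batch entry
batched = jax.jit(jax.vmap(noisy_cos, in_axes=(0, 0)))
print(batched(thetas, keys))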

examples/parameter_shift.py (+45 −7)

@@ -19,7 +19,7 @@ def f1(param):
             c.cnot(i, i + 1)
         for i in range(n):
             c.rx(i, theta=param[i, j])
-    return c.expectation_ps(y=[n // 2])
+    return K.real(c.expectation_ps(y=[n // 2]))
 
 
 g1f1 = K.jit(K.grad(f1))

@@ -43,21 +43,59 @@ def f2(paramzz, paramx):
             c.rzz(i, i + 1, theta=paramzz[i, j])
         for i in range(n):
             c.rx(i, theta=paramx[i, j])
-    return c.expectation_ps(y=[n // 2])
+    return K.real(c.expectation_ps(y=[n // 2]))
 
 
 g1f2 = K.jit(K.grad(f2, argnums=(0, 1)))
 
-r1, ts, tr = tc.utils.benchmark(
+r12, ts, tr = tc.utils.benchmark(
     g1f2, K.ones([n, m], dtype="float32"), K.ones([n, m], dtype="float32")
 )
 
 g2f2 = K.jit(E.parameter_shift_grad(f2, argnums=(0, 1)))
 
-r2, ts, tr = tc.utils.benchmark(
+r22, ts, tr = tc.utils.benchmark(
     g2f2, K.ones([n, m], dtype="float32"), K.ones([n, m], dtype="float32")
 )
 
-np.testing.assert_allclose(r1[0], r2[0], atol=1e-5)
-np.testing.assert_allclose(r1[1], r2[1], atol=1e-5)
-print("equality test passed!")
+np.testing.assert_allclose(r12[0], r22[0], atol=1e-5)
+np.testing.assert_allclose(r12[1], r22[1], atol=1e-5)
+print("multiple weight inputs: equality test passed!")
+
+# sampled expectation version
+
+
+def f3(param):
+    c = tc.Circuit(n)
+    for j in range(m):
+        for i in range(n - 1):
+            c.cnot(i, i + 1)
+        for i in range(n):
+            c.rx(i, theta=param[i, j])
+    return K.real(c.sample_expectation_ps(y=[n // 2]))
+
+
+g2f3 = K.jit(E.parameter_shift_grad(f3))
+
+r2, ts, tr = tc.utils.benchmark(g2f3, K.ones([n, m], dtype="float32"))
+
+np.testing.assert_allclose(r1, r2, atol=1e-5)
+print("analytical sampled expectation: equality test passed!")
+
+
+# def f3(param):
+#     c = tc.Circuit(n)
+#     for j in range(m):
+#         for i in range(n - 1):
+#             c.cnot(i, i + 1)
+#         for i in range(n):
+#             c.rx(i, theta=param[i, j])
+#     return K.real(c.sample_expectation_ps(y=[n // 2], shots=81920))
+
+
+# g2f3 = K.jit(E.parameter_shift_grad(f3))
+
+# r2, ts, tr = tc.utils.benchmark(g2f3, K.ones([n, m], dtype="float32"))
+# print(r1 - r2)
+# np.testing.assert_allclose(r1 - r2, np.zeros_like(r1), atol=1e-3)
+# print("finite sampled expectation: equality test passed!")

tensorcircuit/backends/abstract_backend.py (+1 −1)

@@ -534,7 +534,7 @@ def argmin(self: Any, a: Tensor, axis: int = 0) -> Tensor:
             "Backend '{}' has not implemented `argmin`.".format(self.name)
         )
 
-    def unique_with_counts(self: Any, a: Tensor) -> Tuple[Tensor, Tensor]:
+    def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         """
         Find the unique elements and their corresponding counts of the given tensor ``a``.

tensorcircuit/backends/jax_backend.py (+4 −2)

@@ -367,8 +367,10 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return jnp.argmin(a, axis=axis)
 
-    def unique_with_counts(self, a: Tensor) -> Tuple[Tensor, Tensor]:
-        return jnp.unique(a, return_counts=True)  # type: ignore
+    def unique_with_counts(
+        self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
+    ) -> Tuple[Tensor, Tensor]:
+        return jnp.unique(a, return_counts=True, size=size, fill_value=fill_value)  # type: ignore
 
     def sigmoid(self, a: Tensor) -> Tensor:
         return libjax.nn.sigmoid(a)
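
For context, a standalone sketch (not from this commit) of the Jax behavior the new keyword arguments expose: `jnp.unique` has a data-dependent output shape, so under `jax.jit` the shape must be pinned with a static `size`, padding unused slots with `fill_value`.

import jax
import jax.numpy as jnp

@jax.jit
def hist(samples):
    # `size` makes the output shape static; missing unique values are padded
    # with fill_value (and, as I understand it, counts of 0)
    return jnp.unique(samples, return_counts=True, size=4, fill_value=-1)

values, counts = hist(jnp.array([0, 2, 2, 3, 0, 0]))
print(values, counts)  # expected: [ 0  2  3 -1] [3 2 1 0]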

tensorcircuit/backends/numpy_backend.py (+1 −1)

@@ -164,7 +164,7 @@ def mean(
     ) -> Tensor:
         return np.mean(a, axis=axis, keepdims=keepdims)
 
-    def unique_with_counts(self, a: Tensor) -> Tuple[Tensor, Tensor]:
+    def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         return np.unique(a, return_counts=True)  # type: ignore
 
     def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor:

tensorcircuit/backends/pytorch_backend.py (+1 −1)

@@ -349,7 +349,7 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return torchlib.argmin(a, dim=axis)
 
-    def unique_with_counts(self, a: Tensor) -> Tuple[Tensor, Tensor]:
+    def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         return torchlib.unique(a, return_counts=True)  # type: ignore
 
     def sigmoid(self, a: Tensor) -> Tensor:

tensorcircuit/backends/tensorflow_backend.py (+1 −1)

@@ -355,7 +355,7 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return tf.math.argmin(a, axis=axis)
 
-    def unique_with_counts(self, a: Tensor) -> Tuple[Tensor, Tensor]:
+    def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         r = tf.unique_with_counts(a)
         order = tf.argsort(r.y)
         return tf.gather(r.y, order), tf.gather(r.count, order)
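
A small standalone demo (mine, not part of the commit) of why the `argsort` is there: `tf.unique_with_counts` returns values in first-appearance order, whereas `np.unique`, and hence the other backends, return them sorted.

import tensorflow as tf

a = tf.constant([3, 1, 3, 2, 1, 3])
r = tf.unique_with_counts(a)
print(r.y.numpy(), r.count.numpy())  # [3 1 2] [3 2 1], first-appearance order

order = tf.argsort(r.y)
print(tf.gather(r.y, order).numpy(), tf.gather(r.count, order).numpy())
# [1 2 3] [2 1 3], sorted to match np.unique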

tensorcircuit/basecircuit.py (+27 −14)

@@ -4,7 +4,6 @@
 # pylint: disable=invalid-name
 
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
-from copy import deepcopy
 
 import numpy as np
 import graphviz

@@ -543,7 +542,7 @@ def sample(
         self,
         batch: Optional[int] = None,
         allow_state: bool = False,
-        status: Optional[Tensor] = None,
+        random_generator: Optional[Any] = None,
     ) -> Any:
         """
         batched sampling from state or circuit tensor network directly

@@ -553,18 +552,18 @@
         :param allow_state: if true, we sample from the final state
             if memory allows, True is preferred, defaults to False
         :type allow_state: bool, optional
-        :param status: random generator, defaults to None
-        :type status: Optional[Tensor], optional
+        :param random_generator: random generator, defaults to None
+        :type random_generator: Optional[Any], optional
         :return: List (if batch) of tuple (binary configuration tensor and corresponding probability)
         :rtype: Any
         """
         # allow_state = False is compatibility issue
         if not allow_state:
-            if status is None:
-                status = backend.get_random_state()
+            if random_generator is None:
+                random_generator = backend.get_random_state()
 
             if batch is None:
-                seed = backend.stateful_randu(status, shape=[self._nqubits])
+                seed = backend.stateful_randu(random_generator, shape=[self._nqubits])
                 return self.perfect_sampling(seed)
 
             @backend.jit  # type: ignore

@@ -574,7 +573,7 @@ def perfect_sampling(key: Any) -> Any:
 
             r = []
 
-            subkey = status
+            subkey = random_generator
             for _ in range(batch):
                 key, subkey = backend.random_split(subkey)
                 r.append(perfect_sampling(key))

@@ -589,12 +588,12 @@ def perfect_sampling(key: Any) -> Any:
         if self.is_dm is False:
             p = backend.abs(s) ** 2
         else:
-            p = backend.real(backend.diagonal(s))
-        if status is None:
+            p = backend.abs(backend.diagonal(s))
+        if random_generator is None:
             ch = backend.implicit_randc(a=2**self._nqubits, shape=[nbatch], p=p)
         else:
             ch = backend.stateful_randc(
-                status, a=2**self._nqubits, shape=[nbatch], p=p
+                random_generator, a=2**self._nqubits, shape=[nbatch], p=p
             )
         prob = backend.gather1d(p, ch)
         confg = backend.mod(

@@ -614,6 +613,7 @@ def sample_expectation_ps(
         y: Optional[Sequence[int]] = None,
         z: Optional[Sequence[int]] = None,
         shots: Optional[int] = None,
+        random_generator: Optional[Any] = None,
         **kws: Any,
     ) -> Tensor:
         """

@@ -635,10 +635,15 @@ def sample_expectation_ps(
         :type z: Optional[Sequence[int]], optional
         :param shots: number of measurement shots, defaults to None, indicating analytical result
         :type shots: Optional[int], optional
+        :param random_generator: random generator, defaults to None
+        :type random_generator: Optional[Any]
         :return: [description]
         :rtype: Tensor
         """
-        c = deepcopy(self)
+        if self.is_dm is False:
+            c = type(self)(self._nqubits, mps_inputs=self.quvector())  # type: ignore
+        else:
+            c = type(self)(self._nqubits, mpo_dminputs=self.get_dm_as_quoperator())  # type: ignore
         if x is None:
             x = []
         if y is None:

@@ -653,9 +658,17 @@
         if c.is_dm is False:
             p = backend.abs(s) ** 2
         else:
-            p = backend.real(backend.diagonal(s))
+            p = backend.abs(backend.diagonal(s))
         # readout error can be processed here later
-        mc = measurement_counts(p, counts=shots, sparse=False, is_prob=True)
+        # TODO(@refraction-ray): explicit management on randomness
+        mc = measurement_counts(
+            p,
+            counts=shots,
+            sparse=False,
+            is_prob=True,
+            random_generator=random_generator,
+            jittable=True,
+        )
         x = list(x)
         y = list(y)
         z = list(z)
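
A short usage sketch of the extended `sample_expectation_ps` signature (my example, Jax backend assumed; the keyword names come from the diff above):

import tensorcircuit as tc

tc.set_backend("jax")

c = tc.Circuit(2)
c.h(0)
c.cnot(0, 1)
key = tc.backend.get_random_state(42)
# finite-shot estimate with explicitly managed randomness, safe to wrap in backend.jit
print(c.sample_expectation_ps(z=[0, 1], shots=8192, random_generator=key))
# shots=None (the default) returns the analytical expectation for comparison
print(c.sample_expectation_ps(z=[0, 1]))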

tensorcircuit/experimental.py (+1)

@@ -226,6 +226,7 @@ def parameter_shift_grad(
     :return: the grad function
     :rtype: Callable[..., Tensor]
     """
+    # TODO(@refraction-ray): finite shot sample_expectation_ps not supported well for now
     if jit is True:
         f = backend.jit(f)

tensorcircuit/quantum.py (+20 −4)

@@ -1878,6 +1878,8 @@
     counts: Optional[int] = 8192,
     sparse: bool = True,
     is_prob: bool = False,
+    random_generator: Optional[Any] = None,
+    jittable: bool = False,
 ) -> Union[Tuple[Tensor, Tensor], Tensor]:
     """
     Simulate the measuring of each qubit of ``p`` in the computational basis,

@@ -1900,6 +1902,11 @@
     :param sparse: Defaults True. The bool indicating whether the return
         is in the form of two arrays or one array of the same length as ``state`` (if ``sparse=False``).
     :type sparse: bool
+    :param is_prob: if True, ``state`` is directly regarded as a probability list,
+        defaults to False
+    :type is_prob: bool
+    :param random_generator: random generator, defaults to None
+    :type random_generator: Optional[Any]
     :return: The counts for each bit string measured.
     :rtype: Tuple[]
     """

@@ -1908,21 +1915,30 @@
     else:
         if len(state.shape) == 2:
             state /= backend.trace(state)
-            pi = backend.real(backend.diagonal(state))
+            pi = backend.abs(backend.diagonal(state))
         else:
             state /= backend.norm(state)
             pi = backend.real(backend.conj(state) * state)
     pi = backend.reshape(pi, [-1])
-    d = int(pi.shape[0])
+    d = int(backend.shape_tuple(pi)[0])
+    drange = backend.arange(d)
     # raw counts in terms of integers
     if (counts is None) or counts <= 0:
         if not sparse:
             return pi
         else:
             return counts_d2s(pi)
     else:
-        raw_counts = backend.implicit_randc(d, shape=counts, p=pi)
-        results = backend.unique_with_counts(raw_counts)
+        if random_generator is None:
+            raw_counts = backend.implicit_randc(drange, shape=counts, p=pi)
+        else:
+            raw_counts = backend.stateful_randc(
+                random_generator, a=drange, shape=counts, p=pi
+            )
+        if not jittable:
+            results = backend.unique_with_counts(raw_counts)  # non-jittable
+        else:  # jax specific
+            results = backend.unique_with_counts(raw_counts, size=d, fill_value=-1)
     if sparse:
         return results  # type: ignore
     dense_results = counts_s2d(results, d)
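
And a minimal sketch of the jittable path through `measurement_counts` (my example, Jax backend assumed):

import tensorcircuit as tc

tc.set_backend("jax")

p = tc.backend.convert_to_tensor([0.5, 0.0, 0.0, 0.5])  # Bell-pair probabilities
key = tc.backend.get_random_state(42)

@tc.backend.jit
def sampled_hist(p, key):
    # jittable=True routes to unique_with_counts(size=d, fill_value=-1),
    # so all shapes stay static under jit
    return tc.quantum.measurement_counts(
        p, counts=1024, sparse=False, is_prob=True, random_generator=key, jittable=True
    )

print(sampled_hist(p, key))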

tests/test_circuit.py (+6 −2)

@@ -1002,10 +1002,14 @@ def test_batch_sample(backend):
     c.cnot(0, 1)
     print(c.sample())
     print(c.sample(batch=8))
-    print(c.sample(status=tc.backend.get_random_state(42)))
+    print(c.sample(random_generator=tc.backend.get_random_state(42)))
     print(c.sample(allow_state=True))
     print(c.sample(batch=8, allow_state=True))
-    print(c.sample(batch=8, allow_state=True, status=tc.backend.get_random_state(42)))
+    print(
+        c.sample(
+            batch=8, allow_state=True, random_generator=tc.backend.get_random_state(42)
+        )
+    )
 
 
 def test_expectation_y_bug():

tests/test_dmcircuit.py (+84)

@@ -468,3 +468,87 @@ def test_prepend_dmcircuit(backend):
     assert n["name"] == n0
     s = c3.wavefunction()
     np.testing.assert_allclose(s[0], s[1], atol=1e-5)
+
+
+@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
+def test_dm_sexpps(backend):
+    c = tc.DMCircuit(1, inputs=1 / np.sqrt(2) * np.array([1.0, 1.0j]))
+    y = c.sample_expectation_ps(y=[0])
+    ye = c.expectation_ps(y=[0])
+    np.testing.assert_allclose(y, 1.0, atol=1e-5)
+    np.testing.assert_allclose(ye, 1.0, atol=1e-5)
+
+    c = tc.DMCircuit(4)
+    c.H(0)
+    c.H(1)
+    c.X(2)
+    c.Y(3)
+    c.cnot(0, 1)
+    c.depolarizing(1, px=0.05, py=0.05, pz=0.1)
+    c.rx(1, theta=0.3)
+    c.ccnot(2, 3, 1)
+    c.depolarizing(0, px=0.05, py=0.0, pz=0.1)
+    c.rzz(0, 3, theta=0.5)
+    c.ry(3, theta=2.2)
+    c.amplitudedamping(2, gamma=0.1, p=0.95)
+    c.s(1)
+    c.td(2)
+    c.cswap(3, 0, 1)
+    y = c.sample_expectation_ps(x=[1], y=[0], z=[2, 3])
+    ye = c.expectation_ps(x=[1], y=[0], z=[2, 3])
+    np.testing.assert_allclose(ye, y, atol=1e-5)
+    y2 = c.sample_expectation_ps(x=[1], y=[0], z=[2, 3], shots=81920)
+    assert np.abs(y2 - y) < 0.01
+
+
+def test_dm_sexpps_jittable_vamppable(jaxb):
+    n = 4
+    m = 2
+
+    def f(param, key):
+        c = tc.DMCircuit(n)
+        for j in range(m):
+            for i in range(n - 1):
+                c.cnot(i, i + 1)
+            for i in range(n):
+                c.rx(i, theta=param[i, j])
+        return tc.backend.real(
+            c.sample_expectation_ps(y=[n // 2], shots=8192, random_generator=key)
+        )
+
+    vf = tc.backend.jit(tc.backend.vmap(f, vectorized_argnums=(0, 1)))
+    r = vf(
+        tc.backend.ones([2, n, m], dtype="float32"),
+        tc.backend.stack(
+            [
+                tc.backend.get_random_state(42),
+                tc.backend.get_random_state(43),
+            ]
+        ),
+    )
+    assert np.abs(r[0] - r[1]) > 1e-4
+
+    print(r)
+
+
+def test_dm_sexpps_jittable_vamppable_tf(tfb):
+    # finally giving up backend agnosticity
+    # and not sure about the efficiency and the safety of vmap random in tf
+    n = 4
+    m = 2
+
+    def f(param):
+        c = tc.DMCircuit(n)
+        for j in range(m):
+            for i in range(n - 1):
+                c.cnot(i, i + 1)
+            for i in range(n):
+                c.rx(i, theta=param[i, j])
+        return tc.backend.real(c.sample_expectation_ps(y=[n // 2], shots=8192))
+
+    vf = tc.backend.jit(tc.backend.vmap(f, vectorized_argnums=0))
+    r = vf(tc.backend.ones([2, n, m]))
+    r1 = vf(tc.backend.ones([2, n, m]))
+    assert np.abs(r[0] - r[1]) > 1e-5
+    assert np.abs(r[0] - r1[0]) > 1e-5
+    print(r, r1)
