Skip to content

Commit e9b6e14

Browse files
sample method now support status
1 parent f2d48ff commit e9b6e14

File tree

6 files changed

+45
-20
lines changed

6 files changed

+45
-20
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
- Add `probability_sample` method for backend as an alternative for `random_choice` since it supports `status` as external randomness format
2424

25+
- Add `status` support for `sample` and `sample_expection_ps` methods
26+
2527
### Changed
2628

2729
- The inner mechanism for `sample_expectation_ps` is changed to sample representation from count representation for a fast speed

tensorcircuit/backends/abstract_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ def probability_sample(
10981098
p_cuml = self.cumsum(p)
10991099
r = p_cuml[-1] * (1 - self.cast(status, p.dtype))
11001100
ind = self.searchsorted(p_cuml, r)
1101-
a = self.arange(shots)
1101+
a = self.arange(self.shape_tuple(p)[0])
11021102
res = self.gather1d(a, ind)
11031103
return res
11041104

tensorcircuit/basecircuit.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ def sample(
510510
readout_error: Optional[Sequence[Any]] = None,
511511
format: Optional[str] = None,
512512
random_generator: Optional[Any] = None,
513+
status: Optional[Tensor] = None,
513514
) -> Any:
514515
"""
515516
batched sampling from state or circuit tensor network directly
@@ -526,6 +527,9 @@ def sample(
526527
:type format: Optional[str]
527528
:param random_generator: random generator, defaults to None
528529
:type random_generator: Optional[Any], optional
530+
:param status: external randomness given by tensor uniformly from [0, 1],
531+
if set, can overwrite random_generator
532+
:type status: Optional[Tensor]
529533
:return: List (if batch) of tuple (binary configuration tensor and correponding probability)
530534
if the format is None, and consitent with format when given
531535
:rtype: Any
@@ -578,21 +582,20 @@ def perfect_sampling(key: Any) -> Any:
578582
# readout error
579583
if readout_error is not None:
580584
p = self.readouterror_bs(readout_error, p)
581-
582-
a_range = backend.arange(2**self._nqubits)
583-
if random_generator is None:
584-
ch = backend.implicit_randc(a=a_range, shape=[nbatch], p=p)
585-
else:
586-
ch = backend.stateful_randc(
587-
random_generator, a=a_range, shape=[nbatch], p=p
588-
)
585+
ch = backend.probability_sample(nbatch, p, status, random_generator)
586+
# if random_generator is None:
587+
# ch = backend.implicit_randc(a=a_range, shape=[nbatch], p=p)
588+
# else:
589+
# ch = backend.stateful_randc(
590+
# random_generator, a=a_range, shape=[nbatch], p=p
591+
# )
589592
# confg = backend.mod(
590593
# backend.right_shift(
591594
# ch[..., None], backend.reverse(backend.arange(self._nqubits))
592595
# ),
593596
# 2,
594597
# )
595-
if format is None:
598+
if format is None: # for backward compatibility
596599
confg = sample_int2bin(ch, self._nqubits)
597600
prob = backend.gather1d(p, ch)
598601
r = list(zip(confg, prob)) # type: ignore
@@ -608,6 +611,7 @@ def sample_expectation_ps(
608611
z: Optional[Sequence[int]] = None,
609612
shots: Optional[int] = None,
610613
random_generator: Optional[Any] = None,
614+
status: Optional[Tensor] = None,
611615
readout_error: Optional[Sequence[Any]] = None,
612616
**kws: Any,
613617
) -> Tensor:
@@ -635,7 +639,10 @@ def sample_expectation_ps(
635639
:param shots: number of measurement shots, defaults to None, indicating analytical result
636640
:type shots: Optional[int], optional
637641
:param random_generator: random_generator, defaults to None
638-
:type random_general: Optional[Any]
642+
:type random_generator: Optional[Any]
643+
:param status: external randomness given by tensor uniformly from [0, 1],
644+
if set, can overwrite random_generator
645+
:type status: Optional[Tensor]
639646
:param readout_error: readout_error, defaults to None
640647
:type readout_error: Optional[Sequence[Any]]. Tensor, List, Tuple
641648
:return: [description]
@@ -674,6 +681,7 @@ def sample_expectation_ps(
674681
counts=shots,
675682
format="count_vector",
676683
random_generator=random_generator,
684+
status=status,
677685
jittable=True,
678686
is_prob=True,
679687
)
@@ -684,6 +692,7 @@ def sample_expectation_ps(
684692
counts=shots,
685693
format="sample_bin",
686694
random_generator=random_generator,
695+
status=status,
687696
jittable=True,
688697
is_prob=True,
689698
)

tensorcircuit/quantum.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -1999,6 +1999,7 @@ def measurement_counts(
19991999
format: str = "count_vector",
20002000
is_prob: bool = False,
20012001
random_generator: Optional[Any] = None,
2002+
status: Optional[Tensor] = None,
20022003
jittable: bool = False,
20032004
) -> Any:
20042005
"""
@@ -2048,7 +2049,10 @@ def measurement_counts(
20482049
defaults to be False
20492050
:type is_prob: bool
20502051
:param random_generator: random_generator, defaults to None
2051-
:type random_general: Optional[Any]
2052+
:type random_generator: Optional[Any]
2053+
:param status: external randomness given by tensor uniformly from [0, 1],
2054+
if set, can overwrite random_generator
2055+
:type status: Optional[Tensor]
20522056
:param jittable: if True, jax backend try using a jittable count, defaults to False
20532057
:type jittable: bool
20542058
:return: The counts for each bit string measured.
@@ -2066,7 +2070,6 @@ def measurement_counts(
20662070
pi = backend.reshape(pi, [-1])
20672071
d = int(backend.shape_tuple(pi)[0])
20682072
n = int(np.log(d) / np.log(2) + 1e-8)
2069-
drange = backend.arange(d)
20702073
if (counts is None) or counts <= 0:
20712074
if format == "count_vector":
20722075
return pi
@@ -2081,12 +2084,15 @@ def measurement_counts(
20812084
"unsupported format %s for analytical measurement" % format
20822085
)
20832086
else:
2084-
if random_generator is None:
2085-
raw_counts = backend.implicit_randc(drange, shape=counts, p=pi)
2086-
else:
2087-
raw_counts = backend.stateful_randc(
2088-
random_generator, a=drange, shape=counts, p=pi
2089-
)
2087+
raw_counts = backend.probability_sample(
2088+
counts, pi, status=status, g=random_generator
2089+
)
2090+
# if random_generator is None:
2091+
# raw_counts = backend.implicit_randc(drange, shape=counts, p=pi)
2092+
# else:
2093+
# raw_counts = backend.stateful_randc(
2094+
# random_generator, a=drange, shape=counts, p=pi
2095+
# )
20902096
return sample2all(raw_counts, n, format=format, jittable=jittable)
20912097

20922098

tests/test_backends.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def test_backend_methods_2(backend):
291291
r = tc.backend.probability_sample(10000, p, status=np.random.uniform(size=[10000]))
292292
_, r = np.unique(r, return_counts=True)
293293
np.testing.assert_allclose(
294-
r - tc.backend.numpy(p) * 10000.0, np.zeros([10]), atol=100, rtol=1
294+
r - tc.backend.numpy(p) * 10000.0, np.zeros([10]), atol=200, rtol=1
295295
)
296296

297297

tests/test_circuit.py

+8
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,14 @@ def test_batch_sample(backend):
10171017
batch=8, allow_state=True, random_generator=tc.backend.get_random_state(42)
10181018
)
10191019
)
1020+
print(
1021+
c.sample(
1022+
batch=8,
1023+
allow_state=True,
1024+
status=np.random.uniform(size=[8]),
1025+
format="sample_bin",
1026+
)
1027+
)
10201028

10211029

10221030
def test_expectation_y_bug():

0 commit comments

Comments
 (0)