Commit 33638eb

new sample API and new backend methods
1 parent 2420910 commit 33638eb

14 files changed (+286 −11 lines)

CHANGELOG.md

+12
@@ -6,6 +6,18 @@
 
 - Add PyTorch nn Module wrapper in `torchnn`
 
+- Add `reverse`, `mod`, `left_shift`, `right_shift`, `arange` methods on backend
+
+- Brand new `sample` API with batch support and support for sampling from the state
+
+### Fixed
+
+- Fixed a bug in merging single gates when all gates are single-qubit ones
+
+### Changed
+
+- The default contractor enables a preprocessing step that merges single-qubit gates first
+
 ## 0.1.3
 
 ### Added

docs/source/quickstart.rst

+9
@@ -239,6 +239,7 @@ and the other part is implemented in `TensorCircuit package <modules.html#module
 'acosh',
 'addition',
 'adjoint',
+'arange',
 'argmax',
 'argmin',
 'asin',
@@ -255,6 +256,7 @@ and the other part is implemented in `TensorCircuit package <modules.html#module
 'conj',
 'convert_to_tensor',
 'coo_sparse_matrix',
+'coo_sparse_matrix_from_numpy',
 'copy',
 'cos',
 'cosh',
@@ -267,6 +269,7 @@ and the other part is implemented in `TensorCircuit package <modules.html#module
 'eigs',
 'eigsh',
 'eigsh_lanczos',
+'eigvalsh',
 'einsum',
 'eps',
 'exp',
@@ -293,12 +296,14 @@ and the other part is implemented in `TensorCircuit package <modules.html#module
 'jit',
 'jvp',
 'kron',
+'left_shift',
 'log',
 'matmul',
 'max',
 'mean',
 'min',
 'minor',
+'mod',
 'multiply',
 'name',
 'norm',
@@ -319,6 +324,8 @@ and the other part is implemented in `TensorCircuit package <modules.html#module
 'reshape',
 'reshape2',
 'reshapem',
+'reverse',
+'right_shift',
 'rq',
 'scatter',
 'serialize_tensor',
@@ -357,7 +364,9 @@ and the other part is implemented in `TensorCircuit package <modules.html#module
 'to_dense',
 'trace',
 'transpose',
+'tree_flatten',
 'tree_map',
+'tree_unflatten',
 'unique_with_counts',
 'value_and_grad',
 'vectorized_value_and_grad',

docs/source/tutorial.rst

+1
@@ -5,6 +5,7 @@ Jupyter Tutorials
 .. toctree::
 
     tutorials/circuit_basics.ipynb
+    tutorials/qaoa.ipynb
    tutorials/tfim_vqe.ipynb
    tutorials/mnist_qml.ipynb
    tutorials/torch_qml.ipynb

examples/sample_benchmark.py

+8 −6

@@ -27,16 +27,19 @@ def construct_circuit(n, nlayers):
 print("n: ", n, " nlayers: ", nlayers)
 c = construct_circuit(n, nlayers)
 time0 = time.time()
-s = c.state()
+s = c.sample(allow_state=True)
 time1 = time.time()
-smp = bin(np.random.choice(range(2**n), p=np.abs(K.numpy(s)) ** 2))
 # print(smp)
 print("state sampling time: ", time1 - time0)
 time0 = time.time()
 smp = c.sample()
 # print(smp)
 time1 = time.time()
 print("nonjit tensor sampling time: ", time1 - time0)
+time0 = time.time()
+s = c.sample(allow_state=True, batch=10)
+time1 = time.time()
+print("batch state sampling time: ", (time1 - time0) / 10)
 
 @K.jit
 def f(key):
@@ -48,11 +51,10 @@ def f(key):
 time0 = time.time()
 smp = f(key1)
 time1 = time.time()
-for _ in range(5):
+for _ in range(10):
     key1, key2 = K.random_split(key2)
     smp = f(key1)
     # print(smp)
 time2 = time.time()
-
-print("jittable tensor sampling staging time: ", time1 - time0)
-print("jittable tensor sampling running time: ", (time2 - time1) / 5)
+print("jittable tensor sampling staging time: ", time1 - time0)
+print("jittable tensor sampling running time: ", (time2 - time1) / 10)

tensorcircuit/backends/abstract_backend.py

+75
@@ -644,6 +644,81 @@ def cast(self: Any, a: Tensor, dtype: str) -> Tensor:
             "Backend '{}' has not implemented `cast`.".format(self.name)
         )
 
+    def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
+        """
+        Compute x mod y (behavior for negative numbers is not guaranteed
+        to be consistent across backends).
+
+        :param x: input values
+        :type x: Tensor
+        :param y: modulus ``y``
+        :type y: Tensor
+        :return: results
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `mod`.".format(self.name)
+        )
+
+    def reverse(self: Any, a: Tensor) -> Tensor:
+        """
+        Return ``a[::-1]``; consistent behavior is only guaranteed for 1D tensors.
+
+        :param a: 1D tensor
+        :type a: Tensor
+        :return: 1D tensor in reverse order
+        :rtype: Tensor
+        """
+        return a[::-1]
+
+    def right_shift(self: Any, x: Tensor, y: Tensor) -> Tensor:
+        """
+        Shift the bits of an integer ``x`` to the right by ``y`` bits.
+
+        :param x: input values
+        :type x: Tensor
+        :param y: number of bits to shift ``x`` by
+        :type y: Tensor
+        :return: result with the same shape as ``x``
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `right_shift`.".format(self.name)
+        )
+
+    def left_shift(self: Any, x: Tensor, y: Tensor) -> Tensor:
+        """
+        Shift the bits of an integer ``x`` to the left by ``y`` bits.
+
+        :param x: input values
+        :type x: Tensor
+        :param y: number of bits to shift ``x`` by
+        :type y: Tensor
+        :return: result with the same shape as ``x``
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `left_shift`.".format(self.name)
+        )
+
+    def arange(
+        self: Any, start: int, stop: Optional[int] = None, step: int = 1
+    ) -> Tensor:
+        """
+        Generate values within the half-open interval ``[start, stop)``,
+        or ``[0, start)`` when ``stop`` is None.
+
+        :param start: start index
+        :type start: int
+        :param stop: end index, defaults to None
+        :type stop: Optional[int], optional
+        :param step: step size, defaults to 1
+        :type step: int, optional
+        :return: 1D tensor of evenly spaced values
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `arange`.".format(self.name)
+        )
+
     def solve(self: Any, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
         """
         Solve the linear system Ax=b and return the solution x.
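
These primitives exist so the new `sample` API in `tensorcircuit/circuit.py` (further down in this commit) can decode sampled basis-state indices into bit configurations in a backend-agnostic way. Below is a minimal NumPy sketch of that composition, written for illustration only (the names `n` and `ch` are hypothetical):

import numpy as np

n = 3                        # hypothetical qubit count
ch = np.array([5, 2])        # sampled basis-state indices; 5 = 0b101
shifts = np.arange(n)[::-1]  # backend.reverse(backend.arange(n)) -> [2, 1, 0]
# shift each index right by [n-1, ..., 0] bits and keep the lowest bit
bits = np.mod(np.right_shift(ch[..., None], shifts), 2)
print(bits)  # [[1 0 1]
             #  [0 1 0]]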

tensorcircuit/backends/jax_backend.py

+14
@@ -315,6 +315,20 @@ def cast(self, a: Tensor, dtype: str) -> Tensor:
             return a.astype(getattr(jnp, dtype))
         return a.astype(dtype)
 
+    def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
+        if stop is None:
+            return jnp.arange(start=0, stop=start, step=step)
+        return jnp.arange(start=start, stop=stop, step=step)
+
+    def mod(self, x: Tensor, y: Tensor) -> Tensor:
+        return jnp.mod(x, y)
+
+    def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
+        return jnp.right_shift(x, y)
+
+    def left_shift(self, x: Tensor, y: Tensor) -> Tensor:
+        return jnp.left_shift(x, y)
+
     def expm(self, a: Tensor) -> Tensor:
         return jsp.linalg.expm(a)
         # currently expm in jax doesn't support AD, it will raise an AssertError,

tensorcircuit/backends/numpy_backend.py

+14
@@ -211,6 +211,20 @@ def cast(self, a: Tensor, dtype: str) -> Tensor:
             return a.astype(getattr(np, dtype))
         return a.astype(dtype)
 
+    def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
+        if stop is None:
+            return np.arange(start=0, stop=start, step=step)
+        return np.arange(start=start, stop=stop, step=step)
+
+    def mod(self, x: Tensor, y: Tensor) -> Tensor:
+        return np.mod(x, y)
+
+    def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
+        return np.right_shift(x, y)
+
+    def left_shift(self, x: Tensor, y: Tensor) -> Tensor:
+        return np.left_shift(x, y)
+
     def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor:
         # gen, sym, her, pos
         # https://stackoverflow.com/questions/44672029/difference-between-numpy-linalg-solve-and-numpy-linalg-lu-solve/44710451
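
As a quick illustration (not part of the commit) of the ``stop is None`` branch above: the backend ``arange`` follows NumPy's one-argument convention, counting from 0 when only ``start`` is given. A standalone sketch:

import numpy as np

def arange(start, stop=None, step=1):
    # mirrors NumpyBackend.arange above
    if stop is None:
        return np.arange(start=0, stop=start, step=step)
    return np.arange(start=start, stop=stop, step=step)

print(arange(4))        # [0 1 2 3]
print(arange(2, 8, 2))  # [2 4 6]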

tensorcircuit/backends/pytorch_backend.py

+17
@@ -375,9 +375,26 @@ def cast(self, a: Tensor, dtype: str) -> Tensor:
             return a.type(getattr(torchlib, dtype))
         return a.type(dtype)
 
+    def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
+        if stop is None:
+            return torchlib.arange(start=0, end=start, step=step)
+        return torchlib.arange(start=start, end=stop, step=step)
+
+    def mod(self, x: Tensor, y: Tensor) -> Tensor:
+        return torchlib.fmod(x, y)
+
+    def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
+        return torchlib.bitwise_right_shift(x, y)
+
+    def left_shift(self, x: Tensor, y: Tensor) -> Tensor:
+        return torchlib.bitwise_left_shift(x, y)
+
     def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
         return torchlib.linalg.solve(A, b)
 
+    def reverse(self, a: Tensor) -> Tensor:
+        return torchlib.flip(a, dims=(-1,))
+
     def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
         # TODO(@refraction-ray): torch not support multiple pytree args
         return torchlib.utils._pytree.tree_map(f, *pytrees)
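
Note that ``torch.fmod`` is a truncated modulo while ``np.mod`` is floored, which is presumably why the abstract ``mod`` docstring hedges on negative inputs (an inference; the commit does not say so). A small illustration:

import numpy as np
import torch

print(np.mod(-3, 2))                    # 1, floored modulo
print(torch.fmod(torch.tensor(-3), 2))  # tensor(-1), truncated modulo

For the bit-decoding use in `sample`, all operands are non-negative, so the backends agree there.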

tensorcircuit/backends/tensorflow_backend.py

+14
@@ -414,6 +414,20 @@ def cast(self, a: Tensor, dtype: str) -> Tensor:
             return tf.cast(a, dtype=getattr(tf, dtype))
         return tf.cast(a, dtype=dtype)
 
+    def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
+        if stop is None:
+            return tf.range(start=0, limit=start, delta=step)
+        return tf.range(start=start, limit=stop, delta=step)
+
+    def mod(self, x: Tensor, y: Tensor) -> Tensor:
+        return tf.math.mod(x, y)
+
+    def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
+        return tf.bitwise.right_shift(x, y)
+
+    def left_shift(self, x: Tensor, y: Tensor) -> Tensor:
+        return tf.bitwise.left_shift(x, y)
+
     def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
         if b.shape[-1] == A.shape[-1]:
             b = b[..., tf.newaxis]

tensorcircuit/circuit.py

+64 −1

@@ -1419,7 +1419,70 @@ def perfect_sampling(self) -> Tuple[str, float]:
         """
         return self.measure_jit(*[i for i in range(self._nqubits)], with_prob=True)
 
-    sample = perfect_sampling
+    # sample = perfect_sampling
+
+    def sample(
+        self,
+        batch: Optional[int] = None,
+        allow_state: bool = False,
+        status: Optional[Tensor] = None,
+    ) -> Any:
+        """
+        Batched sampling from the final state or directly from the circuit tensor network.
+
+        :param batch: number of samples, defaults to None
+        :type batch: Optional[int], optional
+        :param allow_state: if True, sample from the final state;
+            prefer True if memory allows, defaults to False
+        :type allow_state: bool, optional
+        :param status: random generator state, defaults to None
+        :type status: Optional[Tensor], optional
+        :return: list (if ``batch``) of tuples of binary configuration tensor and corresponding probability
+        :rtype: Any
+        """
+        # allow_state = False is kept for backward compatibility
+        if not allow_state:
+            if batch is None:
+                return self.perfect_sampling()
+
+            @backend.jit  # type: ignore
+            def perfect_sampling(key: Any) -> Any:
+                backend.set_random_state(key)
+                return self.perfect_sampling()
+
+            r = []
+            if status is None:
+                status = backend.get_random_state()
+            subkey = status
+            for _ in range(batch):
+                key, subkey = backend.random_split(subkey)
+                r.append(perfect_sampling(key))
+
+            return r
+
+        if batch is None:
+            nbatch = 1
+        else:
+            nbatch = batch
+        s = self.state()
+        p = backend.abs(s) ** 2
+        if status is None:
+            ch = backend.implicit_randc(a=2**self._nqubits, shape=[nbatch], p=p)
+        else:
+            ch = backend.stateful_randc(
+                status, a=2**self._nqubits, shape=[nbatch], p=p
+            )
+        prob = backend.gather1d(p, ch)
+        confg = backend.mod(
+            backend.right_shift(
+                ch[..., None], backend.reverse(backend.arange(self._nqubits))
+            ),
+            2,
+        )
+        r = list(zip(confg, prob))
+        if batch is None:
+            r = r[0]
+        return r
 
     # TODO(@refraction-ray): more _before function like state_before? and better API?
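
A hypothetical usage sketch of the new API (circuit construction follows the package's standard gate-call style; return shapes are read off the implementation above, not verified against a release):

import tensorcircuit as tc

c = tc.Circuit(2)
c.h(0)
c.cnot(0, 1)

print(c.sample())                           # one perfect-sampling shot from the tensor network
print(c.sample(allow_state=True))           # one (bits, prob) pair drawn from the full state
print(c.sample(batch=8, allow_state=True))  # list of 8 (bits, prob) pairs

With ``allow_state=True`` the state is materialized once and indices are drawn from |psi|^2, so batching is nearly free, whereas the tensor-network path reruns perfect sampling per shot (consistent with examples/sample_benchmark.py above).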
