@@ -1419,7 +1419,70 @@ def perfect_sampling(self) -> Tuple[str, float]:
1419
1419
"""
1420
1420
return self .measure_jit (* [i for i in range (self ._nqubits )], with_prob = True )
1421
1421
1422
- sample = perfect_sampling
1422
+ # sample = perfect_sampling
1423
+
1424
+ def sample (
1425
+ self ,
1426
+ batch : Optional [int ] = None ,
1427
+ allow_state : bool = False ,
1428
+ status : Optional [Tensor ] = None ,
1429
+ ) -> Any :
1430
+ """
1431
+ batched sampling from state or circuit tensor network directly
1432
+
1433
+ :param batch: number of samples, defaults to None
1434
+ :type batch: Optional[int], optional
1435
+ :param allow_state: if true, we sample from the final state
1436
+ if memory allsows, True is prefered, defaults to False
1437
+ :type allow_state: bool, optional
1438
+ :param status: random generator, defaults to None
1439
+ :type status: Optional[Tensor], optional
1440
+ :return: List (if batch) of tuple (binary configuration tensor and correponding probability)
1441
+ :rtype: Any
1442
+ """
1443
+ # allow_state = False is compatibility issue
1444
+ if not allow_state :
1445
+ if batch is None :
1446
+ return self .perfect_sampling ()
1447
+
1448
+ @backend .jit # type: ignore
1449
+ def perfect_sampling (key : Any ) -> Any :
1450
+ backend .set_random_state (key )
1451
+ return self .perfect_sampling ()
1452
+
1453
+ r = []
1454
+ if status is None :
1455
+ status = backend .get_random_state ()
1456
+ subkey = status
1457
+ for _ in range (batch ):
1458
+ key , subkey = backend .random_split (subkey )
1459
+ r .append (perfect_sampling (key ))
1460
+
1461
+ return r
1462
+
1463
+ if batch is None :
1464
+ nbatch = 1
1465
+ else :
1466
+ nbatch = batch
1467
+ s = self .state ()
1468
+ p = backend .abs (s ) ** 2
1469
+ if status is None :
1470
+ ch = backend .implicit_randc (a = 2 ** self ._nqubits , shape = [nbatch ], p = p )
1471
+ else :
1472
+ ch = backend .stateful_randc (
1473
+ status , a = 2 ** self ._nqubits , shape = [nbatch ], p = p
1474
+ )
1475
+ prob = backend .gather1d (p , ch )
1476
+ confg = backend .mod (
1477
+ backend .right_shift (
1478
+ ch [..., None ], backend .reverse (backend .arange (self ._nqubits ))
1479
+ ),
1480
+ 2 ,
1481
+ )
1482
+ r = list (zip (confg , prob ))
1483
+ if batch is None :
1484
+ r = r [0 ]
1485
+ return r
1423
1486
1424
1487
# TODO(@refraction-ray): more _before function like state_before? and better API?
1425
1488
0 commit comments