
Commit fd17fae

infra for random of backends; randn
1 parent fa17d65 commit fd17fae

File tree

7 files changed (+269, -9 lines)


README.md

+8
@@ -98,6 +98,14 @@ pylint tensorcircuit tests
 mypy tensorcircuit
 ```

+### Integrated script
+
+For now, we introduce a one-for-all checker for development:
+
+```bash
+./check_all.sh
+```
+
 ### CI

 We currently use GitHub Action for test CI, but it has limited quota for free private repo.

check_all.sh

+13
@@ -0,0 +1,13 @@
+#! /bin/sh
+set -e
+echo "black check"
+black . --check
+echo "mypy check"
+mypy tensorcircuit
+echo "pylint check"
+pylint tensorcircuit tests
+echo "pytest check"
+pytest --cov=tensorcircuit -vv
+echo "sphinx check"
+cd docs && make html
+echo "all checks passed, congratulations!"

mypy.ini

+12
@@ -3,8 +3,20 @@ python_version = 3.6
 ignore_missing_imports = True
 strict = True
 warn_unused_ignores = False
+disallow_untyped_calls = False
+

 [mypy-tensorcircuit.applications.van]
 ;;mypy simply cannot ignore files with wildcard patterns...
 ;;only module level * works...
 ignore_errors = True
+
+
+;;[mypy-numpy.*]
+;;ignore_errors = True
+
+;; doesn't work due to https://github.com/python/mypy/issues/10757
+;; mypy + numpy is currently a disaster; never use mypy in your next project
+;; unless you enjoy writing something worse than C
+;; both the established status of mypy and the support from other packages are just wasting your time
+;; GET AWAY MYPY AND TYPE ANNOTATION !!! WRITE PYTHON AS IT IS !!!

tensorcircuit/backends.py

+206, -1
@@ -29,7 +29,8 @@


 Tensor = Any
-
+PRNGKeyArray = Any  # libjax.random.PRNGKeyArray
+RGenerator = Any  # tf.random.Generator

 libjax: Any
 jnp: Any
@@ -356,6 +357,77 @@ def tree_map( # pylint: disable=unused-variable

         return r

+    def set_random_state(  # pylint: disable=unused-variable
+        self: Any, seed: Optional[int] = None
+    ) -> None:
+        """
+        Set the random state attached to the backend.
+
+        :param seed: the random seed, defaults to None
+        :type seed: Optional[int], optional
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `set_random_state`.".format(self.name)
+        )
+
+    def implicit_randn(  # pylint: disable=unused-variable
+        self: Any,
+        shape: Union[int, Sequence[int]] = 1,
+        mean: float = 0,
+        stddev: float = 1,
+        dtype: str = "32",
+    ) -> Tensor:
+        """
+        Call the random normal generator with the random state managed behind the scenes.
+
+        :param shape: shape of the output random tensor, defaults to 1
+        :type shape: Union[int, Sequence[int]], optional
+        :param mean: mean of the normal distribution, defaults to 0
+        :type mean: float, optional
+        :param stddev: standard deviation of the normal distribution, defaults to 1
+        :type stddev: float, optional
+        :param dtype: only real dtypes are supported, "32" or "64", defaults to "32"
+        :type dtype: str, optional
+        :return: random tensor sampled from the normal distribution
+        :rtype: Tensor
+        """
+        g = getattr(self, "g", None)
+        if g is None:
+            self.set_random_state()
+            g = getattr(self, "g", None)
+        r = self.stateful_randn(g, shape, mean, stddev, dtype)
+        return r
+
+    def stateful_randn(  # pylint: disable=unused-variable
+        self: Any,
+        g: Any,
+        shape: Union[int, Sequence[int]] = 1,
+        mean: float = 0,
+        stddev: float = 1,
+        dtype: str = "32",
+    ) -> Tensor:
+        """
+        Generate a random normal tensor from the given stateful random generator.
+
+        :param self: the backend object
+        :type self: Any
+        :param g: stateful random generator for each backend package
+        :type g: Any
+        :param shape: shape of the output sampled tensor
+        :type shape: Union[int, Sequence[int]]
+        :param mean: mean of the normal distribution, defaults to 0
+        :type mean: float, optional
+        :param stddev: standard deviation of the normal distribution, defaults to 1
+        :type stddev: float, optional
+        :param dtype: only real data types are supported, "32" or "64", defaults to "32"
+        :type dtype: str, optional
+        :return: random tensor sampled from the normal distribution
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `stateful_randn`.".format(self.name)
+        )
+
     def grad(  # pylint: disable=unused-variable
         self: Any, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
     ) -> Callable[..., Any]:
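
Taken together, the three methods above define the backend-agnostic contract: `set_random_state` attaches a generator `g` to the backend object, `implicit_randn` seeds it lazily on first use and delegates to `stateful_randn`, and each concrete backend overrides the latter two. The following standalone sketch reproduces that lazy-seeding pattern with NumPy only; `ToyBackend` is illustrative and not part of this commit.

```python
# Standalone sketch of the lazy-seeding pattern above (illustrative only;
# `ToyBackend` is not part of the commit).
from typing import Any, Optional, Sequence, Union

import numpy as np

Tensor = Any


class ToyBackend:
    def set_random_state(self, seed: Optional[int] = None) -> None:
        # attach a stateful generator to the backend object
        self.g = np.random.default_rng(seed)

    def stateful_randn(
        self,
        g: np.random.Generator,
        shape: Union[int, Sequence[int]] = 1,
        mean: float = 0,
        stddev: float = 1,
        dtype: str = "32",
    ) -> Tensor:
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        r = g.normal(loc=mean, scale=stddev, size=shape)
        return r.astype(np.float32 if dtype[-2:] == "32" else np.float64)

    def implicit_randn(self, shape=1, mean=0, stddev=1, dtype="32") -> Tensor:
        # seed lazily on first use, then delegate to the stateful variant
        if getattr(self, "g", None) is None:
            self.set_random_state()
        return self.stateful_randn(self.g, shape, mean, stddev, dtype)


b = ToyBackend()
b.set_random_state(42)                              # reproducible stream
print(b.implicit_randn([2, 3], dtype="64").shape)   # (2, 3)
```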
@@ -580,6 +652,33 @@ def real(self, a: Tensor) -> Tensor:
     def cast(self, a: Tensor, dtype: str) -> Tensor:
         return a.astype(getattr(np, dtype))

+    def set_random_state(self, seed: Optional[int] = None) -> None:
+        g = np.random.default_rng(seed)  # seed=None is automatically supported
+        self.g = g
+
+    def stateful_randn(
+        self,
+        g: np.random.Generator,
+        shape: Union[int, Sequence[int]] = 1,
+        mean: float = 0,
+        stddev: float = 1,
+        dtype: str = "32",
+    ) -> Tensor:
+        if isinstance(dtype, str):
+            dtype = dtype[-2:]
+        if isinstance(shape, int):
+            shape = (shape,)
+        r = g.normal(loc=mean, scale=stddev, size=shape)  # type: ignore
+        if dtype == "32":
+            r = r.astype(np.float32)
+        elif dtype == "64":
+            r = r.astype(np.float64)
+        elif not isinstance(dtype, str):
+            r = r.astype(dtype)
+        else:
+            raise ValueError("unsupported `dtype` %s" % dtype)
+        return r
+
     def grad(
         self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
     ) -> Callable[..., Any]:
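
One detail of the NumPy implementation worth noting: the `dtype = dtype[-2:]` slice means callers may pass either the short form "32"/"64" or a full name such as "float32". A quick check of that convention (illustrative snippet, not part of the diff):

```python
# The dtype-string convention used by stateful_randn: "float32"[-2:] == "32",
# so both spellings select np.float32 (illustrative check).
import numpy as np

g = np.random.default_rng(0)  # same constructor as NumpyBackend.set_random_state
for dtype in ("32", "float32", "64", "float64"):
    target = np.float32 if dtype[-2:] == "32" else np.float64
    r = g.normal(loc=0.0, scale=1.0, size=(2, 2)).astype(target)
    print(dtype, "->", r.dtype)  # float32, float32, float64, float64
```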
@@ -735,6 +834,49 @@ def is_tensor(self, a: Any) -> bool:
             return True
         return False

+    def set_random_state(self, seed: Optional[int] = None) -> None:
+        if seed is None:
+            seed = np.random.randint(42)
+        g = libjax.random.PRNGKey(seed)
+        self.g = g
+
+    def implicit_randn(
+        self,
+        shape: Union[int, Sequence[int]] = 1,
+        mean: float = 0,
+        stddev: float = 1,
+        dtype: str = "32",
+    ) -> Tensor:
+        g = getattr(self, "g", None)
+        if g is None:
+            self.set_random_state()
+            g = getattr(self, "g", None)
+        key, subkey = libjax.random.split(g)
+        r = self.stateful_randn(subkey, shape, mean, stddev, dtype)
+        self.g = key
+        return r
+
+    def stateful_randn(
+        self,
+        g: PRNGKeyArray,
+        shape: Union[int, Sequence[int]] = 1,
+        mean: float = 0,
+        stddev: float = 1,
+        dtype: str = "32",
+    ) -> Tensor:
+        if isinstance(dtype, str):
+            dtype = dtype[-2:]
+        if isinstance(shape, int):
+            shape = (shape,)
+        if dtype == "32":
+            dtyper = jnp.float32
+        elif dtype == "64":
+            dtyper = jnp.float64
+        elif not isinstance(dtype, str):
+            dtyper = dtype
+        r = libjax.random.normal(g, shape=shape, dtype=dtyper) * stddev + mean
+        return r
+
     def grad(
         self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
     ) -> Any:
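
Since JAX random functions are stateless, `implicit_randn` above splits the stored `PRNGKey` on every call, keeps one half as the new state and spends the other on sampling, so repeated calls yield fresh but reproducible numbers. The same split-and-carry pattern written directly against `jax.random` (illustrative sketch, assuming `jax` is installed; not part of the diff):

```python
# The split-and-carry pattern used by the jax backend's implicit_randn,
# expressed directly with jax.random (illustrative sketch).
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)  # what set_random_state stores as self.g


def randn(key, shape, mean=0.0, stddev=1.0):
    key, subkey = jax.random.split(key)  # carry `key`, consume `subkey`
    r = jax.random.normal(subkey, shape=shape, dtype=jnp.float32) * stddev + mean
    return key, r


key, r1 = randn(key, (2, 3))
key, r2 = randn(key, (2, 3))  # different samples, reproducible from seed 42
print(r1.shape, bool(jnp.all(r1 == r2)))  # (2, 3) False
```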
@@ -824,6 +966,38 @@ def _outer_product_tf(self: Any, tensor1: Tensor, tensor2: Tensor) -> Tensor:
     return tf.tensordot(tensor1, tensor2, 0)


+def _random_choice_tf(
+    g: RGenerator,
+    a: Union[int, Sequence[int], Tensor],
+    shape: Union[int, Sequence[int]],
+    p: Optional[Union[Sequence[float], Tensor]] = None,
+) -> Tensor:
+    # only replace=True is supported; replace=False is not implemented
+    # for the stateless random module, tf has a corresponding categorical function similar to choice
+    # however, no such utility is implemented with ``tf.random.Generator``
+    if isinstance(a, int):
+        assert a > 0
+        a = tf.range(a)
+    assert len(a.shape) == 1
+    if isinstance(shape, int):
+        shape = (shape,)
+    if p is None:
+        dtype = tf.float32
+        p = tf.ones_like(a)
+        p = tf.cast(p, dtype=dtype)
+        p /= tf.reduce_sum(p)
+    else:
+        if not isinstance(p, tf.Tensor):
+            p = tf.constant(p)
+        dtype = p.dtype
+    shape1 = reduce(mul, shape)
+    p_cuml = tf.cumsum(p)
+    r = p_cuml[-1] * (1 - g.uniform([shape1], dtype=dtype))
+    ind = tf.searchsorted(p_cuml, r)
+    res = tf.gather(a, ind)
+    return tf.reshape(res, shape)
+
+
 # temporary hot replace until new version of tensorflow is released,
 # see issue: https://github.com/google/TensorNetwork/issues/940
 # avoid buggy tensordot2 in tensornetwork
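
`_random_choice_tf` samples with replacement via an inverse-CDF lookup: take the cumulative sum of the (possibly unnormalized) weights, draw uniforms scaled to the total weight, and let `searchsorted` map each draw back to an index. The same idea spelled out in NumPy for comparison (illustrative sketch, not part of the diff):

```python
# Inverse-CDF sampling with replacement, mirroring _random_choice_tf in NumPy
# (illustrative sketch).
import numpy as np

rng = np.random.default_rng(0)
a = np.array([10, 20, 30])     # values to sample from
p = np.array([0.2, 0.3, 0.5])  # (possibly unnormalized) weights

p_cuml = np.cumsum(p)                          # [0.2, 0.5, 1.0]
r = p_cuml[-1] * (1 - rng.uniform(size=1000))  # uniforms in (0, total]
ind = np.searchsorted(p_cuml, r)               # bucket index per draw
samples = a[ind]
print(np.bincount(ind) / 1000)                 # roughly [0.2, 0.3, 0.5]
```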
@@ -910,6 +1084,35 @@ def real(self, a: Tensor) -> Tensor:
     def cast(self, a: Tensor, dtype: str) -> Tensor:
         return tf.cast(a, dtype=getattr(tf, dtype))

+    def set_random_state(self, seed: Optional[Union[int, RGenerator]] = None) -> None:
+        if seed is None:
+            g = tf.random.Generator.from_non_deterministic_state()
+        elif isinstance(seed, int):
+            g = tf.random.Generator.from_seed(seed)
+        else:
+            g = seed
+        self.g = g
+
+    def stateful_randn(
+        self,
+        g: RGenerator,
+        shape: Union[int, Sequence[int]] = 1,
+        mean: float = 0,
+        stddev: float = 1,
+        dtype: str = "32",
+    ) -> Tensor:
+        if isinstance(dtype, str):
+            dtype = dtype[-2:]
+        if isinstance(shape, int):
+            shape = (shape,)
+        if dtype == "32":
+            dtyper = tf.float32
+        elif dtype == "64":
+            dtyper = tf.float64
+        elif not isinstance(dtype, str):
+            dtyper = dtype
+        return g.normal(shape, mean, stddev, dtype=dtyper)
+
     def grad(
         self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
     ) -> Callable[..., Any]:
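
For TensorFlow the attached state is a `tf.random.Generator`: `from_seed` gives a deterministic stream, `from_non_deterministic_state` seeds from OS entropy, and an existing generator can be passed through unchanged. A small standalone check of the generator calls used by `stateful_randn` (illustrative, assuming TensorFlow is installed):

```python
# Standalone check of the tf.random.Generator calls used above
# (illustrative sketch, not part of the diff).
import tensorflow as tf

g = tf.random.Generator.from_seed(42)  # deterministic stream
r32 = g.normal([2, 3], mean=0.0, stddev=1.0, dtype=tf.float32)
r64 = g.normal([2, 3], mean=0.0, stddev=1.0, dtype=tf.float64)
print(r32.shape, r32.dtype, r64.dtype)  # (2, 3) float32 float64

g2 = tf.random.Generator.from_non_deterministic_state()  # what seed=None selects
print(g2.normal([2]).numpy())
```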
@@ -1090,6 +1293,8 @@ def wrapper(
 vvag = vectorized_value_and_grad


+# TODO(@refraction-ray): stateful random methods are not implemented for now;
+# to be added once the pytorch backend is ready
 class PyTorchBackend(pytorch_backend.PyTorchBackend):  # type: ignore
     def __init__(self) -> None:
         super(PyTorchBackend, self).__init__()
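
The TODO above leaves the PyTorch backend without these methods for now. Purely as an illustration of what a future implementation might look like (not part of this commit, and the eventual API may differ), `torch.Generator` offers a comparable stateful register:

```python
# Hypothetical sketch of pytorch-style stateful sampling, for illustration only;
# this is NOT part of the commit and the eventual implementation may differ.
import torch

g = torch.Generator()
g.manual_seed(42)  # analogous to set_random_state(42)

# normal sampling driven by an explicit generator, analogous to stateful_randn
r = torch.randn(2, 3, generator=g) * 1.0 + 0.0  # stddev * z + mean
print(r.shape, r.dtype)  # torch.Size([2, 3]) torch.float32
```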

tensorcircuit/gates.py

+6, -6
@@ -17,6 +17,7 @@
 thismodule = sys.modules[__name__]

 Tensor = Any
+Array = Any

 # Common single qubit states as np.ndarray objects
 zero_state = np.array([1.0, 0.0], dtype=npdtype)
@@ -107,7 +108,7 @@ class Gate(tn.Node): # type: ignore
     pass


-def num_to_tensor(*num: float, dtype: Optional[str] = None) -> Any:
+def num_to_tensor(*num: Union[float, Tensor], dtype: Optional[str] = None) -> Any:
     l = []
     if not dtype:
         dtype = dtypestr
@@ -124,7 +125,7 @@ def num_to_tensor(*num: float, dtype: Optional[str] = None) -> Any:
 array_to_tensor = num_to_tensor


-def gate_wrapper(m: np.array, n: Optional[str] = None) -> Gate:
+def gate_wrapper(m: Tensor, n: Optional[str] = None) -> Gate:
     if not n:
         n = "unknowngate"
     m = m.astype(npdtype)
@@ -161,7 +162,7 @@ def matrix_for_gate(gate: Gate) -> Tensor:
     return t


-def bmatrix(a: np.array) -> str:
+def bmatrix(a: Array) -> str:
     """
     Returns a LaTeX bmatrix

@@ -261,11 +262,10 @@ def random_single_qubit_gate() -> Gate:
     """
     Returns the random single qubit gate described in https://arxiv.org/abs/2002.07730.
     """
-
     # Get the random parameters
-    theta, alpha, phi = np.random.rand(3) * 2 * np.pi
+    theta, alpha, phi = np.random.rand(3) * 2 * np.pi  # type: ignore

-    return rgate(theta, alpha, phi)
+    return rgate(theta, alpha, phi)  # type: ignore


 rs = random_single_qubit_gate
rs = random_single_qubit_gate

tensorcircuit/quantum.py

+1, -2
@@ -13,7 +13,6 @@
     Tuple,
     Set,
     List,
-    Type,
 )

 import numpy as np
@@ -66,7 +65,7 @@ def quantum_constructor(

 def identity(
     space: Sequence[int],
-    dtype: Type[np.number] = np.float64,
+    dtype: Any = np.float64,
 ) -> "QuOperator":
     """Construct a `QuOperator` representing the identity on a given space.
     Internally, this is done by constructing `CopyNode`s for each edge, with
