Skip to content

Commit ffdbbe0

Browse files
more random methods in backends;more stable sort of nodes
1 parent 9e5fe4e commit ffdbbe0

9 files changed

+180
-17
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Unreleased
44

5+
### Added
6+
7+
- add `get_random_state` and `random_split` methods to backends
8+
59
### Fixed
610

711
- avoid error on watch non `tf.Tensor` in tensorflow backend grad method

examples/mcnoise_boost.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
boosting the monte carlo noise simulation on general error with circuit layerwise slicing
3+
"""
4+
5+
from functools import partial
6+
import time
7+
import jax
8+
9+
10+
import tensorcircuit as tc
11+
12+
tc.set_backend("jax")
13+
14+
n = 10
15+
nlayer = 4
16+
17+
18+
@partial(tc.backend.jit, static_argnums=(2, 3))
19+
def f1(key, param, n, nlayer):
20+
if key is not None:
21+
tc.backend.set_random_state(key)
22+
c = tc.Circuit(n)
23+
for i in range(n):
24+
c.H(i)
25+
for j in range(nlayer):
26+
for i in range(n - 1):
27+
c.cnot(i, i + 1)
28+
c.apply_general_kraus(tc.channels.phasedampingchannel(0.15), i)
29+
c.apply_general_kraus(tc.channels.phasedampingchannel(0.15), i + 1)
30+
for i in range(n):
31+
c.rx(i, theta=param[j, i])
32+
return tc.backend.real(c.expectation((tc.gates.z(), [int(n / 2)])))
33+
34+
35+
@partial(tc.backend.jit, static_argnums=(2))
36+
def templatecnot(s, param, i):
37+
c = tc.Circuit(n, inputs=s)
38+
c.cnot(i, i + 1)
39+
return c.state()
40+
41+
42+
@partial(tc.backend.jit, static_argnums=(3))
43+
def templatenoise(key, s, param, i):
44+
c = tc.Circuit(n, inputs=s)
45+
status = tc.backend.stateful_randu(key)[0]
46+
c.apply_general_kraus(tc.channels.phasedampingchannel(0.15), i, status=status)
47+
return c.state()
48+
49+
50+
@partial(tc.backend.jit, static_argnums=(2))
51+
def templaterz(s, param, j):
52+
c = tc.Circuit(n, inputs=s)
53+
for i in range(n):
54+
c.rx(i, theta=param[j, i])
55+
return c.state()
56+
57+
58+
@partial(tc.backend.jit, static_argnums=(2, 3))
59+
def f2(key, param, n, nlayer):
60+
c = tc.Circuit(n)
61+
for i in range(n):
62+
c.H(i)
63+
s = c.state()
64+
for j in range(nlayer):
65+
for i in range(n - 1):
66+
s = templatecnot(s, param, i)
67+
key, subkey = tc.backend.random_split(key)
68+
s = templatenoise(subkey, s, param, i)
69+
key, subkey = tc.backend.random_split(key)
70+
s = templatenoise(subkey, s, param, i + 1)
71+
s = templaterz(s, param, j)
72+
return tc.backend.real(tc.expectation((tc.gates.z(), [int(n / 2)]), ket=s))
73+
74+
75+
vagf1 = tc.backend.jit(tc.backend.value_and_grad(f1, argnums=1), static_argnums=(2, 3))
76+
77+
vagf2 = tc.backend.jit(tc.backend.value_and_grad(f2, argnums=1), static_argnums=(2, 3))
78+
79+
param = tc.backend.ones([nlayer, n])
80+
81+
82+
def benchmark(f, tries=3):
83+
time0 = time.time()
84+
key = tc.backend.get_random_state(42)
85+
print(f(key, param, n, nlayer)[0])
86+
time1 = time.time()
87+
for _ in range(tries):
88+
print(f(key, param, n, nlayer)[0])
89+
time2 = time.time()
90+
print(
91+
"staging time: ",
92+
time1 - time0,
93+
"running time: ",
94+
(time2 - time1) / tries,
95+
)
96+
97+
98+
print("without layerwise slicing jit")
99+
benchmark(vagf1)
100+
print("=============================")
101+
print("with layerwise slicing jit")
102+
benchmark(vagf2)
103+
104+
# 235/0.36 vs. 26/0.04

tensorcircuit/backends/abstract_backend.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -436,18 +436,50 @@ def tree_map( # pylint: disable=unused-variable
436436
return r
437437

438438
def set_random_state( # pylint: disable=unused-variable
439-
self: Any, seed: Optional[int] = None
440-
) -> None:
439+
self: Any, seed: Optional[int] = None, get_only: bool = False
440+
) -> Any:
441441
"""
442442
set random state attached in the backend
443443
444444
:param seed: int, defaults to None
445445
:type seed: Optional[int], optional
446+
:param get_only:
447+
:type get_only: bool
446448
"""
447449
raise NotImplementedError(
448450
"Backend '{}' has not implemented `set_random_state`.".format(self.name)
449451
)
450452

453+
def get_random_state( # pylint: disable=unused-variable
454+
self: Any, seed: Optional[int] = None
455+
) -> Any:
456+
"""
457+
get backend specific random state object
458+
459+
:param seed: [description], defaults to None
460+
:type seed: Optional[int], optional
461+
:return: [description]
462+
:rtype: Any
463+
"""
464+
return self.set_random_state(seed, True)
465+
466+
def random_split( # pylint: disable=unused-variable
467+
self: Any, key: Any
468+
) -> Tuple[Any, Any]:
469+
"""
470+
a jax like split API, but does't split the key generator for other backends.
471+
just for a consistent interface of random code, be careful that you know what the function actually does.
472+
473+
:param key: [description]
474+
:type key: Any
475+
:return: [description]
476+
:rtype: Tuple[Any, Any]
477+
"""
478+
return key, key
479+
480+
# Though try hard, the current random API abstraction may not be perfect when with nested jit or vmap
481+
# so keep every random function a direct status parameters in case.
482+
451483
def implicit_randn( # pylint: disable=unused-variable
452484
self: Any,
453485
shape: Union[int, Sequence[int]] = 1,

tensorcircuit/backends/jax_backend.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,21 @@ def is_tensor(self, a: Any) -> bool:
223223
return True
224224
return False
225225

226-
def set_random_state(self, seed: Optional[Union[int, PRNGKeyArray]] = None) -> None:
226+
def set_random_state(
227+
self, seed: Optional[Union[int, PRNGKeyArray]] = None, get_only: bool = False
228+
) -> Any:
227229
if seed is None:
228230
seed = np.random.randint(42)
229231
if isinstance(seed, int):
230232
g = libjax.random.PRNGKey(seed)
231233
else:
232234
g = seed
233-
self.g = g
235+
if get_only is False:
236+
self.g = g
237+
return g
238+
239+
def random_split(self, key: Any) -> Tuple[Any, Any]:
240+
return libjax.random.split(key) # type: ignore
234241

235242
def implicit_randn(
236243
self,

tensorcircuit/backends/numpy_backend.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,13 @@ def cast(self, a: Tensor, dtype: str) -> Tensor:
136136
return a.astype(getattr(np, dtype))
137137
return a.astype(dtype)
138138

139-
def set_random_state(self, seed: Optional[int] = None) -> None:
139+
def set_random_state(
140+
self, seed: Optional[int] = None, get_only: bool = False
141+
) -> Any:
140142
g = np.random.default_rng(seed) # None auto supported
141-
self.g = g
143+
if get_only is False:
144+
self.g = g
145+
return g
142146

143147
def stateful_randn(
144148
self,

tensorcircuit/backends/tensorflow_backend.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,18 @@ def cast(self, a: Tensor, dtype: str) -> Tensor:
204204
return tf.cast(a, dtype=getattr(tf, dtype))
205205
return tf.cast(a, dtype=dtype)
206206

207-
def set_random_state(self, seed: Optional[Union[int, RGenerator]] = None) -> None:
207+
def set_random_state(
208+
self, seed: Optional[Union[int, RGenerator]] = None, get_only: bool = False
209+
) -> Any:
208210
if seed is None:
209211
g = tf.random.Generator.from_non_deterministic_state()
210212
elif isinstance(seed, int):
211213
g = tf.random.Generator.from_seed(seed)
212214
else:
213215
g = seed
214-
self.g = g
216+
if get_only is False:
217+
self.g = g
218+
return g
215219

216220
def stateful_randn(
217221
self,

tensorcircuit/circuit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ def _general_kraus_2(
673673
# building for jax+GPU ~100s 12 qubit * 5 layers
674674
# 370s 14 qubit * 7 layers, 0.35s running on vT4
675675
# vmap, grad, vvag are all fine for this function
676-
# layerwise jit technique can greatly boost the staging time
676+
# layerwise jit technique can greatly boost the staging time, see in /examples
677677
sites = len(index)
678678
kraus_tensor = [k.tensor for k in kraus]
679679

tensorcircuit/cons.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,8 @@ def _get_path_cache_friendly(
336336
i += 1
337337
# TODO(@refraction-ray): may be not that cache friendly, since the edge id correspondence is not that fixed?
338338
input_sets = [set([mapping_dict[id(e)] for e in node.edges]) for node in nodes]
339-
placeholder = [1e10]
340-
for s in input_sets:
341-
if len(s) > 1:
342-
break
343-
else:
344-
placeholder = [1e10, 1e10]
345-
order = np.argsort(np.array(list(map(sorted, input_sets)) + [placeholder], dtype=object))[:-1] # type: ignore
346-
# TODO(@refraction-ray): more stable and unwarning arg sorting here
339+
placeholder = [[1e20 for _ in range(100)]]
340+
order = np.argsort(np.array(list(map(sorted, input_sets)) + placeholder, dtype=object))[:-1] # type: ignore
347341
nodes_new = [nodes[i] for i in order]
348342
input_sets = [set([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
349343
output_set = set([mapping_dict[id(e)] for e in tn.get_subgraph_dangling(nodes_new)])

tests/test_backends.py

+14
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,17 @@ def test_sparse_methods(backend):
348348
np.array([[1], [2], [0], [0]], dtype=np.complex64),
349349
atol=1e-5,
350350
)
351+
352+
353+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
354+
def test_backend_randoms_v2(backend):
355+
g = tc.backend.get_random_state(42)
356+
for t in tc.backend.stateful_randc(g, 3, [3]):
357+
assert t >= 0
358+
assert t < 3
359+
key = tc.backend.get_random_state(42)
360+
r = []
361+
for _ in range(2):
362+
key, subkey = tc.backend.random_split(key)
363+
r.append(tc.backend.stateful_randc(subkey, 3, [5]))
364+
assert tuple(r[0]) != tuple(r[1])

0 commit comments

Comments
 (0)