Skip to content

Commit 3c3ce72

Browse files
add parameter shift grad (super fast)
1 parent bc89236 commit 3c3ce72

File tree

5 files changed

+142
-7
lines changed

5 files changed

+142
-7
lines changed

CHANGELOG.md

+10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88

99
- Add alias `expps` for `expectation_ps` and `sexpps` for `sampled_expectation_ps`
1010

11+
- Add `counts_d2s` and `counts_s2d` in quantum module to transform between different representations of measurement shot results
12+
13+
- Add vmap enhanced `parameter_shift_grad` in experimental module (API subject to change)
14+
15+
- Add `parameter_shift.py` script in examples
16+
17+
### Changed
18+
19+
- `rxx`, `ryy`, `rzz` gates now have a 1/2 factor before theta, consistent with `rx`, `ry`, `rz` gates. (breaking change)
20+
1121
## 0.3.1
1222

1323
### Added

examples/parameter_shift.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
Demonstration on the correctness and efficiency of parameter shift gradient implementation
3+
"""
4+
5+
import numpy as np
6+
import tensorcircuit as tc
7+
from tensorcircuit import experimental as E
8+
9+
K = tc.set_backend("tensorflow")
10+
11+
n = 6
12+
m = 3
13+
14+
15+
def f1(param):
    """
    Build the benchmark circuit for a single weight tensor and measure it.

    The circuit has ``n`` qubits and ``m`` layers; each layer is a chain of
    CNOTs on neighbouring qubits followed by an ``rx`` rotation on every qubit.

    :param param: rotation angles, indexed as ``param[qubit, layer]``
    :return: result of ``expectation_ps`` requested with ``y=[n // 2]``
    """
    circuit = tc.Circuit(n)
    for layer in range(m):
        # entangle neighbouring qubits before the rotation layer
        for qubit in range(n - 1):
            circuit.cnot(qubit, qubit + 1)
        for qubit in range(n):
            circuit.rx(qubit, theta=param[qubit, layer])
    middle = n // 2
    return circuit.expectation_ps(y=[middle])
23+
24+
25+
g1f1 = K.jit(K.grad(f1))
26+
27+
r1, ts, tr = tc.utils.benchmark(g1f1, K.ones([n, m], dtype="float32"))
28+
29+
g2f1 = K.jit(E.parameter_shift_grad(f1))
30+
31+
r2, ts, tr = tc.utils.benchmark(g2f1, K.ones([n, m], dtype="float32"))
32+
33+
np.testing.assert_allclose(r1, r2, atol=1e-5)
34+
print("equality test passed!")
35+
36+
# multiple weights args version
37+
38+
39+
def f2(paramzz, paramx):
    """
    Build the benchmark circuit that takes two separate weight tensors.

    The circuit has ``n`` qubits and ``m`` layers; each layer applies ``rzz``
    on neighbouring qubit pairs followed by ``rx`` on every qubit.

    :param paramzz: ``rzz`` angles, indexed as ``paramzz[qubit, layer]``
        for ``qubit`` in ``range(n - 1)``
    :param paramx: ``rx`` angles, indexed as ``paramx[qubit, layer]``
    :return: result of ``expectation_ps`` requested with ``y=[n // 2]``
    """
    circuit = tc.Circuit(n)
    for layer in range(m):
        for qubit in range(n - 1):
            circuit.rzz(qubit, qubit + 1, theta=paramzz[qubit, layer])
        for qubit in range(n):
            circuit.rx(qubit, theta=paramx[qubit, layer])
    middle = n // 2
    return circuit.expectation_ps(y=[middle])
47+
48+
49+
g1f2 = K.jit(K.grad(f2, argnums=(0, 1)))
50+
51+
r1, ts, tr = tc.utils.benchmark(
52+
g1f2, K.ones([n, m], dtype="float32"), K.ones([n, m], dtype="float32")
53+
)
54+
55+
g2f2 = K.jit(E.parameter_shift_grad(f2, argnums=(0, 1)))
56+
57+
r2, ts, tr = tc.utils.benchmark(
58+
g2f2, K.ones([n, m], dtype="float32"), K.ones([n, m], dtype="float32")
59+
)
60+
61+
np.testing.assert_allclose(r1[0], r2[0], atol=1e-5)
62+
np.testing.assert_allclose(r1[1], r2[1], atol=1e-5)
63+
print("equality test passed!")

tensorcircuit/experimental.py

+55
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from functools import partial
66
from typing import Any, Callable, Optional, Sequence, Union
77

8+
import numpy as np
9+
810
from .cons import backend, dtypestr
911

1012
Tensor = Any
@@ -202,3 +204,56 @@ def energy(params: Tensor) -> Tensor:
202204
return backend.grad(energy)(params)
203205

204206
return wrapper
207+
208+
209+
def parameter_shift_grad(
    f: Callable[..., Tensor],
    argnums: Union[int, Sequence[int]] = 0,
    jit: bool = False,
) -> Callable[..., Tensor]:
    """
    similar to `grad` function but using parameter shift internally instead of AD,
    vmap is utilized for evaluation, so the speed is still ok

    For every differentiated argument, each element is shifted by ``+pi/2`` and
    ``-pi/2`` (all shifted copies are evaluated in one vmapped call each), and
    half of the difference of the two results is returned as the derivative.

    NOTE(review): the fixed ``pi/2`` shift and ``1/2`` prefactor presumably
    assume every parameter enters through a gate of the form
    ``exp(-i * theta / 2 * P)`` with ``P**2 = I`` -- confirm before using
    with custom gates.

    :param f: quantum function with weights in and expectation out
    :type f: Callable[..., Tensor]
    :param argnums: label which args should be differentiated,
        defaults to 0
    :type argnums: Union[int, Sequence[int]], optional
    :param jit: whether jit the original function `f` at the beginning,
        defaults to False
    :type jit: bool, optional
    :return: the grad function; it returns one tensor shaped like the
        differentiated argument, or a tuple of such tensors (in ``argnums``
        order) when more than one argument is differentiated
    :rtype: Callable[..., Tensor]
    """
    if jit:
        f = backend.jit(f)

    argnums_seq: Sequence[int] = (
        [argnums] if isinstance(argnums, int) else argnums
    )

    # Key the vmapped functions by *argument index*. Indexing a list built in
    # ``argnums`` order with the argument index breaks for e.g. ``argnums=1``
    # (IndexError) or ``argnums=(1, 0)`` (wrong argument vectorized).
    vfs = {i: backend.vmap(f, vectorized_argnums=i) for i in argnums_seq}

    def grad_f(*args: Any, **kws: Any) -> Any:
        grad_values = []
        for i in argnums_seq:
            vf = vfs[i]
            shape = backend.shape_tuple(args[i])
            size = backend.sizen(args[i])
            # Row k of ``onehot`` is a tensor of the argument's shape that is
            # pi/2 at flat position k and zero elsewhere.
            onehot = backend.eye(size)
            onehot = backend.cast(onehot, args[i].dtype)
            onehot = backend.reshape(onehot, [size] + list(shape))
            onehot = np.pi / 2 * onehot
            # Replicate the argument ``size`` times along a new leading axis so
            # that each copy can carry the shift of exactly one element.
            arg = backend.reshape(args[i], [1] + list(shape))
            batched_arg = backend.tile(arg, [size] + [1 for _ in shape])
            plus_args = list(args)
            plus_args[i] = batched_arg + onehot
            minus_args = list(args)
            minus_args[i] = batched_arg - onehot
            r = (vf(*plus_args, **kws) - vf(*minus_args, **kws)) / 2.0
            # One derivative per element; restore the argument's layout.
            r = backend.reshape(r, shape)
            grad_values.append(r)
        if len(argnums_seq) > 1:
            return tuple(grad_values)
        return grad_values[0]

    return grad_f

tensorcircuit/gates.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,9 @@ def exponential_gate(unitary: Tensor, theta: float, name: str = "none") -> Gate:
721721
# exp = exponential_gate
722722

723723

724-
def exponential_gate_unity(unitary: Tensor, theta: float, name: str = "none") -> Gate:
724+
def exponential_gate_unity(
725+
unitary: Tensor, theta: float, half: bool = False, name: str = "none"
726+
) -> Gate:
725727
r"""
726728
Faster exponential gate directly implemented based on RHS. Only works when :math:`U^2 = I` is an identity matrix.
727729
@@ -733,6 +735,9 @@ def exponential_gate_unity(unitary: Tensor, theta: float, name: str = "none") ->
733735
:type unitary: Tensor
734736
:param theta: angle in radians
735737
:type theta: float
738+
:param half: if True, the angle theta is multiplied by 1/2,
739+
defaults to False
740+
:type half: bool
736741
:param name: suffix of Gate name
737742
:type name: str, optional
738743
:return: Exponential Gate
@@ -745,15 +750,17 @@ def exponential_gate_unity(unitary: Tensor, theta: float, name: str = "none") ->
745750
i = i.reshape([2 for _ in range(n)])
746751
unitary = backend.reshape(unitary, [2 for _ in range(n)])
747752
it = array_to_tensor(i)
753+
if half is True:
754+
theta = theta / 2.0
748755
mat = backend.cos(theta) * it - 1.0j * backend.sin(theta) * unitary
749756
return Gate(mat, name="exp1-" + name)
750757

751758

752759
exp1_gate = exponential_gate_unity
753760
# exp1 = exponential_gate_unity
754-
rzz_gate = partial(exp1_gate, unitary=_zz_matrix)
755-
rxx_gate = partial(exp1_gate, unitary=_xx_matrix)
756-
ryy_gate = partial(exp1_gate, unitary=_yy_matrix)
761+
rzz_gate = partial(exp1_gate, unitary=_zz_matrix, half=True)
762+
rxx_gate = partial(exp1_gate, unitary=_xx_matrix, half=True)
763+
ryy_gate = partial(exp1_gate, unitary=_yy_matrix, half=True)
757764

758765

759766
def multicontrol_gate(unitary: Tensor, ctrl: Union[int, Sequence[int]] = 1) -> Operator:

tests/test_gates.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_rxx_gate():
8888
c1.ryy(0, 2, theta=0.5)
8989
c1.rzz(0, 1, theta=-0.5)
9090
c2 = tc.Circuit(3)
91-
c2.exp1(0, 1, theta=1.0, unitary=tc.gates._xx_matrix)
92-
c2.exp1(0, 2, theta=0.5, unitary=tc.gates._yy_matrix)
93-
c2.exp1(0, 1, theta=-0.5, unitary=tc.gates._zz_matrix)
91+
c2.exp1(0, 1, theta=1.0 / 2, unitary=tc.gates._xx_matrix)
92+
c2.exp1(0, 2, theta=0.5 / 2, unitary=tc.gates._yy_matrix)
93+
c2.exp1(0, 1, theta=-0.5 / 2, unitary=tc.gates._zz_matrix)
9494
np.testing.assert_allclose(c1.state(), c2.state(), atol=1e-5)

0 commit comments

Comments
 (0)