5 | 5 | from functools import partial
6 | 6 | from typing import Any, Callable, Optional, Sequence, Union
7 | 7 |
| 8 | +import numpy as np
| 9 | +
8 | 10 | from .cons import backend, dtypestr
9 | 11 |
10 | 12 | Tensor = Any
@@ -202,3 +204,56 @@ def energy(params: Tensor) -> Tensor:
202 | 204 |         return backend.grad(energy)(params)
203 | 205 |
204 | 206 |     return wrapper
| 207 | +
| 208 | +
| 209 | +def parameter_shift_grad(
| 210 | +    f: Callable[..., Tensor],
| 211 | +    argnums: Union[int, Sequence[int]] = 0,
| 212 | +    jit: bool = False,
| 213 | +) -> Callable[..., Tensor]:
| 214 | +    """
| 215 | +    Similar to the `grad` function, but uses the parameter-shift rule internally
| 216 | +    instead of AD; `vmap` is utilized for the shifted evaluations, so the speed remains reasonable.
| 217 | +
| 218 | +    :param f: quantum function with weights in and expectation out
| 219 | +    :type f: Callable[..., Tensor]
| 220 | +    :param argnums: specify which positional args should be differentiated,
| 221 | +        defaults to 0
| 222 | +    :type argnums: Union[int, Sequence[int]], optional
| 223 | +    :param jit: whether to jit the original function `f` at the beginning,
| 224 | +        defaults to False
| 225 | +    :type jit: bool, optional
| 226 | +    :return: the grad function
| 227 | +    :rtype: Callable[..., Tensor]
| 228 | +    """
| 229 | +    if jit is True:
| 230 | +        f = backend.jit(f)
| 231 | +
| 232 | +    if isinstance(argnums, int):
| 233 | +        argnums = [argnums]
| 234 | +
| 235 | +    vfs = [backend.vmap(f, vectorized_argnums=i) for i in argnums]
| 236 | +
| 237 | +    def grad_f(*args: Any, **kws: Any) -> Any:
| 238 | +        grad_values = []
| 239 | +        for i in argnums:  # type: ignore
| 240 | +            shape = backend.shape_tuple(args[i])
| 241 | +            size = backend.sizen(args[i])
| 242 | +            onehot = backend.eye(size)
| 243 | +            onehot = backend.cast(onehot, args[i].dtype)
| 244 | +            onehot = backend.reshape(onehot, [size] + list(shape))
| 245 | +            onehot = np.pi / 2 * onehot
| 246 | +            nargs = list(args)
| 247 | +            arg = backend.reshape(args[i], [1] + list(shape))
| 248 | +            batched_arg = backend.tile(arg, [size] + [1 for _ in shape])
| 249 | +            nargs[i] = batched_arg + onehot
| 250 | +            nargs2 = list(args)
| 251 | +            nargs2[i] = batched_arg - onehot
| 252 | +            r = (vfs[i](*nargs, **kws) - vfs[i](*nargs2, **kws)) / 2.0
| 253 | +            r = backend.reshape(r, shape)
| 254 | +            grad_values.append(r)
| 255 | +        if len(argnums) > 1:  # type: ignore
| 256 | +            return tuple(grad_values)
| 257 | +        return grad_values[0]
| 258 | +
| 259 | +    return grad_f
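
The batched `onehot` construction above is the two-term parameter-shift rule applied to every entry of `args[i]` at once: for a gate generated by a Pauli rotation, d⟨H⟩/dθ = (⟨H⟩(θ + π/2) − ⟨H⟩(θ − π/2)) / 2. A standalone NumPy check of that rule on a single `rx` rotation (illustration only, not part of the commit):

```python
# Verify the pi/2 parameter-shift rule on <Z> after rx(theta)|0>,
# which is analytically cos(theta) with derivative -sin(theta).
import numpy as np


def expectation(theta: float) -> float:
    # rx(theta) = exp(-i * theta * X / 2) applied to |0>, then measure Z.
    rx = np.array(
        [[np.cos(theta / 2), -1j * np.sin(theta / 2)],
         [-1j * np.sin(theta / 2), np.cos(theta / 2)]]
    )
    psi = rx @ np.array([1.0, 0.0])
    z = np.diag([1.0, -1.0])
    return float(np.real(np.conj(psi) @ z @ psi))


theta = 0.7
shift_grad = (expectation(theta + np.pi / 2) - expectation(theta - np.pi / 2)) / 2.0
print(shift_grad, -np.sin(theta))  # both approximately -0.6442
```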
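A minimal usage sketch, assuming the diff above lands in TensorCircuit's `experimental` module and that the usual `tc.Circuit` / `expectation_ps` API is available; everything outside the diff (backend choice, circuit, parameter values) is illustrative:

```python
# Gradient of a variational expectation via parameter shift rather than AD.
import numpy as np
import tensorcircuit as tc
from tensorcircuit.experimental import parameter_shift_grad

K = tc.set_backend("tensorflow")
n = 3


def f(params):
    # One layer of rx rotations, measure <Z0>; every parameter enters through
    # a Pauli rotation, so the pi/2 shift rule is exact here.
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=params[i])
    return K.real(c.expectation_ps(z=[0]))


grad_f = parameter_shift_grad(f, argnums=0, jit=True)
params = K.convert_to_tensor(np.ones([n], dtype=np.float32))
print(grad_f(params))  # shape [3]; should match K.grad(f)(params) up to numerics
```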