vqe_parallel_pmap.py
"""
jax pmap paradigm for vqe on multiple gpus
"""
import os

# Emulate 8 XLA devices on the host CPU so this pmap example runs on a single
# machine; the flag must be set before JAX is initialized. On a real multi-GPU
# host, remove this line and use the physical devices instead.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from functools import partial
import jax
import optax
import tensorcircuit as tc

K = tc.set_backend("jax")
tc.set_contractor("cotengra")


def vqef(param, measure, n, nlayers):
    # Hardware-efficient ansatz: a Hadamard layer followed by alternating
    # ZZ-coupling and X-rotation layers parameterized by ``param``.
    c = tc.Circuit(n)
    c.h(range(n))
    for i in range(nlayers):
        c.rzz(range(n - 1), range(1, n), theta=param[i, 0])
        c.rx(range(n), theta=param[i, 1])
    # Expectation value of the Pauli string encoded by ``measure``.
    return K.real(
        tc.templates.measurements.parameterized_measurements(c, measure, onehot=True)
    )


def get_tfim_ps(n):
    # Pauli strings for the TFIM Hamiltonian terms: one X_i term per site and
    # one Z_i Z_{i+1} coupling per bond (periodic boundary conditions).
    tfim_ps = []
    for i in range(n):
        tfim_ps.append(tc.quantum.xyz2ps({"x": [i]}, n=n))
    for i in range(n):
        tfim_ps.append(tc.quantum.xyz2ps({"z": [i, (i + 1) % n]}, n=n))
    return K.convert_to_tensor(tfim_ps)


# Vectorize the energy-term evaluation over the Pauli strings assigned to one
# device; the circuit parameters are shared (in_axes=None) across all terms.
vqg_vgf = jax.vmap(K.value_and_grad(vqef), in_axes=(None, 0, None, None))


@partial(
    jax.pmap,
    axis_name="pmap",
    in_axes=(0, 0, None, None),
    static_broadcasted_argnums=(2, 3),
)
def update(param, measure, n, nlayers):
    # Compute loss and gradients for this device's shard of Hamiltonian terms.
    loss, grads = vqg_vgf(param, measure, n, nlayers)
    # Sum over the terms on this device, then over all devices.
    grads = K.sum(grads, axis=0)
    grads = jax.lax.psum(grads, axis_name="pmap")
    loss = K.sum(loss, axis=0)
    loss = jax.lax.psum(loss, axis_name="pmap")
    # Apply one optimizer step; every replica sees the same summed gradient,
    # so the parameters stay identical across devices.
    param = opt.update(grads, param)
    return param, loss


if __name__ == "__main__":
    n = 8
    nlayers = 4
    ndevices = 8

    # Shard the Hamiltonian terms across devices: [ndevices, terms_per_device, ...].
    m = get_tfim_ps(n)
    m = K.reshape(m, [ndevices, m.shape[0] // ndevices] + list(m.shape[1:]))

    # Replicate the same initial parameters on every device.
    param = K.stateful_randn(jax.random.PRNGKey(43), shape=[nlayers, 2, n], stddev=0.1)
    param = K.stack([param] * ndevices)

    opt = K.optimizer(optax.adam(1e-2))
    for _ in range(100):
        param, loss = update(param, m, n, nlayers)
        # All replicas carry the same psum-ed loss, so report device 0.
        print(loss[0])
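
    # A minimal follow-up sketch (not part of the original script): since the
    # gradient is psum-ed before every optimizer step, all replicas hold
    # identical parameters, so any single replica can be taken as the result.
    final_param = param[0]
    print("final loss:", loss[0], "param shape:", final_param.shape)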