"""
Reducing jit compile time for layered circuits via the generic backend scan primitive.
"""

import numpy as np
import tensorcircuit as tc

n = 10
nlayers = 16
param_np = np.random.normal(size=[nlayers, n, 2])
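
# Why this helps: under jit, the plain implementation below traces and unrolls
# every gate of every layer, so the compiled graph (and hence the compile time)
# grows with nlayers.  The scan-based implementation instead traces only a
# small block of `each` layers and lets the backend's scan primitive loop that
# block over the circuit state, keeping the traced graph short.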

for backend in ["tensorflow", "jax"]:
    with tc.runtime_backend(backend) as K:
        print("running %s" % K.name)

        def energy_reference(param, n, nlayers):
            c = tc.Circuit(n)
            for i in range(n):
                c.h(i)
            for i in range(nlayers):
                for j in range(n - 1):
                    c.rzz(j, j + 1, theta=param[i, j, 0])
                for j in range(n):
                    c.rx(j, theta=param[i, j, 1])
            return K.real(c.expectation_ps(z=[0, 1]) + c.expectation_ps(x=[2]))

        vg_reference = K.jit(
            K.value_and_grad(energy_reference, argnums=0), static_argnums=(1, 2)
        )

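        # The reference version above builds the full nlayers-deep circuit
        # inside a single trace, so its jit compile time scales with the depth.
        # The scan-based version below avoids this; see the raw lax.scan sketch
        # at the end of this file for the loop-body convention assumed.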
        # a jit-efficient way to utilize scan

        def energy(param, n, nlayers, each):
            def loop_f(s_, param_):
                # scan body: the carry s_ is the statevector, and the slice
                # param_ holds the parameters of `each` consecutive layers
                c_ = tc.Circuit(n, inputs=s_)
                for i in range(each):
                    for j in range(n - 1):
                        c_.rzz(j, j + 1, theta=param_[i, j, 0])
                    for j in range(n):
                        c_.rx(j, theta=param_[i, j, 1])
                s_ = c_.state()
                return s_

            c = tc.Circuit(n)
            for i in range(n):
                c.h(i)
            s = c.state()
            # reshape so that scan runs nlayers // each steps, applying
            # `each` layers per step
            s1 = K.scan(loop_f, K.reshape(param, [nlayers // each, each, n, 2]), s)
            c1 = tc.Circuit(n, inputs=s1)
            return K.real(c1.expectation_ps(z=[0, 1]) + c1.expectation_ps(x=[2]))

        vg = K.jit(
            K.value_and_grad(energy, argnums=0),
            static_argnums=(1, 2, 3),
            jit_compile=True,
        )
        # setting jit_compile=False can shorten the compile time for the tensorflow backend
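        # (jit_compile=True asks TensorFlow to compile the whole function with
        # XLA; for the jax backend, which always compiles through XLA under jit,
        # the flag presumably makes no difference.)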

        param = K.convert_to_tensor(param_np)

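        # tc.utils.benchmark calls the jitted function and reports its timing;
        # the first element of its return value is the function output itself
        # (the (value, grad) pair), used below for printing and the correctness
        # check.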
        for each in [1, 2, 4]:
            print("  scan impl with each=%s" % str(each))
            r1 = tc.utils.benchmark(vg, param, n, nlayers, each)
            print(r1[0][0])

        print("  plain impl")
        r0 = tc.utils.benchmark(vg_reference, param, n, nlayers)  # much slower to compile
        # correctness check: the scan implementation must agree with the plain one
        np.testing.assert_allclose(r0[0][0], r1[0][0], atol=1e-5)
        np.testing.assert_allclose(r0[0][1], r1[0][1], atol=1e-5)


# Observations: for tensorflow, jit_compile=True improves the running time at the
# cost of a longer jit compile time.  In general, jax benefits more from the scan
# methodology: both its compile time and its running time can outperform tf.
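
# For reference, a minimal sketch (an illustration added here, not part of the
# benchmark above) of the raw JAX primitive that the backend-agnostic K.scan
# abstracts.  jax.lax.scan threads a carry through the leading axis of xs with
# a body of the form f(carry, x) -> (new_carry, per_step_output); note that the
# K.scan used above lets the body return just the carry.

import jax
import jax.numpy as jnp


def _scan_body(carry, x):
    # carry: running sum; x: one slice of shape [2] along the leading axis of xs
    return carry + jnp.sum(x), None


_final_carry, _ = jax.lax.scan(
    _scan_body, jnp.array(0.0, dtype=jnp.float32), jnp.ones([4, 2], jnp.float32)
)
# _final_carry == 8.0: four steps, each adding the sum (2.0) of its slice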