-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathhybrid_gpu_pipeline.py
108 lines (83 loc) · 2.71 KB
/
hybrid_gpu_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
quantum part in tensorflow or jax, neural part in torch, both on GPU,
fantastic hybrid pipeline
"""
import os
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
import time
import numpy as np
import tensorflow as tf
import torch
import tensorcircuit as tc
K = tc.set_backend("tensorflow")
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
print(device)
# Dataset preparation: load MNIST, append a trailing channel axis to the
# training images and scale pixel values from [0, 255] into [0, 1].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1) / 255.0
def filter_pair(x, y, a, b):
    """Restrict a labelled dataset to the two classes ``a`` and ``b``.

    Args:
        x: samples, indexed along the first axis in step with ``y``.
        y: label array.
        a: label mapped to ``True`` in the output labels.
        b: label mapped to ``False`` in the output labels.

    Returns:
        ``(x_kept, is_a)``: the samples whose label is ``a`` or ``b``, and a
        boolean array that is ``True`` exactly where the kept label equals ``a``.
    """
    in_pair = (y == a) | (y == b)
    kept_labels = y[in_pair]
    return x[in_pair], kept_labels == a
# Keep only digits 1 and 5 (label True for digit 1), downsample every image
# to 3x3, binarise at 0.5 and flatten to a 9-feature vector per sample.
x_train, y_train = filter_pair(x_train, y_train, 1, 5)
x_train_small = tf.image.resize(x_train, (3, 3)).numpy()
x_train_bin = (x_train_small > 0.5).astype(np.float32)
x_train_bin = np.squeeze(x_train_bin).reshape([-1, 9])
# Build the torch tensors directly on the selected device; labels become float
# so they can be used as BCELoss targets.
x_train_torch = torch.tensor(x_train_bin, device=device)
y_train_torch = torch.tensor(y_train, dtype=torch.float32, device=device)
# Circuit size: one qubit per input feature, and the number of
# entangling + rotation layers used by qpreds.
n = 9
nlayers = 3
# The quantum function: it is built and evaluated with the tensorcircuit
# backend chosen at the top of the script (TensorFlow), not with PyTorch.
def qpreds(x, weights):
    """Evaluate the parameterised circuit for one input sample.

    Args:
        x: per-qubit encoding angles, one rx angle for each of the ``n`` qubits.
        weights: trainable angles of shape ``(2 * nlayers, n)``; row ``2*j``
            holds the rx angles and row ``2*j + 1`` the ry angles of layer ``j``.

    Returns:
        A length-``n`` stack of the real parts of the single-qubit Z
        expectation values.
    """
    circuit = tc.Circuit(n)
    # Angle-encode the input features.
    for qubit in range(n):
        circuit.rx(qubit, theta=x[qubit])
    for layer in range(nlayers):
        # Nearest-neighbour CNOT chain entangling qubit q with q + 1.
        for control in range(n - 1):
            circuit.cnot(control, control + 1)
        # Trainable single-qubit rotations for this layer.
        rx_row, ry_row = 2 * layer, 2 * layer + 1
        for qubit in range(n):
            circuit.rx(qubit, theta=weights[rx_row, qubit])
            circuit.ry(qubit, theta=weights[ry_row, qubit])
    expectations = [
        K.real(circuit.expectation_ps(z=[qubit])) for qubit in range(n)
    ]
    return K.stack(expectations)
# Expose the quantum function as a torch module. The flags ask tensorcircuit
# to vmap it over the batch, JIT-compile it and exchange tensors with PyTorch
# through DLPack. This is roughly what manually wrapping
# K.vmap(qpreds, vectorized_argnums=0) with tc.interfaces.torch_interface(...,
# jit=True, enable_dlpack=True) would do.
quantumnet = tc.TorchLayer(
    qpreds,
    weights_shape=[2 * nlayers, n],
    use_vmap=True,
    use_interface=True,
    use_jit=True,
    enable_dlpack=True,
)
# Hybrid model: 9 quantum outputs -> linear readout -> sigmoid probability.
model = torch.nn.Sequential(
    quantumnet,
    torch.nn.Linear(9, 1),
    torch.nn.Sigmoid(),
).to(device=device)
criterion = torch.nn.BCELoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
# Training loop. Each step draws a random mini-batch (with replacement), runs
# the hybrid forward/backward pass plus one Adam update, and records the
# wall-clock time of that step.
nepochs = 300
nbatch = 32
# Sample from the whole filtered training set. A hard-coded upper bound (100)
# would silently ignore most of the data and raise IndexError on a smaller set.
nsamples = x_train_torch.shape[0]
times = []
for epoch in range(nepochs):
    index = np.random.randint(low=0, high=nsamples, size=nbatch)
    inputs, labels = x_train_torch[index], y_train_torch[index]
    opt.zero_grad()
    time0 = time.perf_counter()  # monotonic, high-resolution interval clock
    yps = model(inputs)
    loss = criterion(
        torch.reshape(yps, [nbatch, 1]), torch.reshape(labels, [nbatch, 1])
    )
    loss.backward()
    opt.step()
    # CUDA kernels are launched asynchronously. Wait for them to finish so the
    # interval measures the actual GPU work rather than just the launches.
    if device.type == "cuda":
        torch.cuda.synchronize()
    times.append(time.perf_counter() - time0)
    # Report outside the timed region: printing a device tensor forces a sync
    # and does console I/O, which would inflate these steps' timings.
    if epoch % 100 == 0:
        print(loss)
# The first step includes tracing/JIT compilation of the quantum layer, so it
# is left out of the average.
print("training time per step: ", np.mean(times[1:]))