-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathhybrid_gpu_pipeline.py
114 lines (88 loc) · 2.9 KB
/
hybrid_gpu_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Hybrid pipeline: the quantum part runs in tensorflow (or jax), the neural part
runs in torch, and both run on the GPU.
"""
import os
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
import time
import numpy as np
import tensorflow as tf
import torch
import tensorcircuit as tc
# Select tensorflow as the tensorcircuit backend; the returned handle K is
# what qpreds calls for backend ops (K.real, K.stack).
K = tc.set_backend("tensorflow")
# Torch side: use CUDA when torch can see a GPU, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Forwarded to tc.TorchLayer as its enable_dlpack option.
# Set to False for old versions of the ML libs.
enable_dlpack = True

# Device string handed to tf.device() inside qpreds.
# CPU alternative: tf_device = "/device:CPU:0"
tf_device = "/GPU:0"
# dataset preparation: load MNIST through keras, then give the training
# images a trailing axis of length 1 and divide the pixel values by 255
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1) / 255.0
def filter_pair(x, y, a, b):
    """Restrict a labelled dataset to the two classes ``a`` and ``b``.

    Returns the samples of ``x`` whose label in ``y`` equals ``a`` or ``b``,
    together with a boolean label array that is True where the label is ``a``
    and False where it is ``b``.
    """
    selected = (y == a) | (y == b)
    kept_x = x[selected]
    kept_y = y[selected]
    return kept_x, kept_y == a
# keep only labels 1 and 5; the label becomes True for 1 and False for 5
x_train, y_train = filter_pair(x_train, y_train, 1, 5)

# shrink every image to 3x3, binarize at 0.5 and flatten to 9 features per sample
x_train_small = tf.image.resize(x_train, (3, 3)).numpy()
x_train_bin = (x_train_small > 0.5).astype(np.float32)
x_train_bin = np.squeeze(x_train_bin).reshape([-1, 9])

# features and float labels as torch tensors on the selected torch device
x_train_torch = torch.tensor(x_train_bin).to(device=device)
y_train_torch = torch.tensor(y_train, dtype=torch.float32).to(device=device)
# circuit width: one qubit per pixel of the 3x3 binarized image
n = 3 * 3
# how many entangle-then-rotate layers qpreds applies after the data encoding
nlayers = 3
# The quantum function of the pipeline.
# Note that it is built from tensorflow ops and pinned to the tensorflow device.
def qpreds(x, weights):
    """Evaluate the parameterized circuit for one input sample.

    ``x`` is indexed as ``x[i]`` for each of the ``n`` qubits and ``weights``
    as ``weights[row, i]`` with ``2 * nlayers`` rows. Returns the stacked real
    parts of the Pauli-Z expectation on every qubit.
    """
    with tf.device(tf_device):
        circuit = tc.Circuit(n)

        # data encoding: RX on each qubit with the matching input feature as angle
        for qubit in range(n):
            circuit.rx(qubit, theta=x[qubit])

        for layer in range(nlayers):
            # CNOT chain between neighbouring qubits
            for control in range(n - 1):
                circuit.cnot(control, control + 1)
            # trainable RX then RY on every qubit; each layer owns two weight rows
            rx_row = 2 * layer
            ry_row = rx_row + 1
            for qubit in range(n):
                circuit.rx(qubit, theta=weights[rx_row, qubit])
                circuit.ry(qubit, theta=weights[ry_row, qubit])

        z_expectations = [
            K.real(circuit.expectation_ps(z=[qubit])) for qubit in range(n)
        ]
        return K.stack(z_expectations)
# Manual alternative to the TorchLayer wrapper below (not used):
# qpreds_vmap = K.vmap(qpreds, vectorized_argnums=0)
# qpreds_batch = tc.interfaces.torch_interface(qpreds_vmap, jit=True, enable_dlpack=True)

# Wrap qpreds as a torch layer. weights_shape matches the weights[row, i]
# indexing inside qpreds: 2 * nlayers rows of n angles each.
# NOTE(review): use_vmap / use_interface / use_jit presumably mirror the manual
# vmap + torch_interface(jit=True) route above - confirm in the tensorcircuit docs.
quantumnet = tc.TorchLayer(
    qpreds,
    weights_shape=[2 * nlayers, n],
    enable_dlpack=enable_dlpack,
    use_jit=True,
    use_interface=True,
    use_vmap=True,
)

# quantum layer -> linear read-out of its n outputs -> sigmoid, on the torch device
model = torch.nn.Sequential(
    quantumnet,
    torch.nn.Linear(n, 1),
    torch.nn.Sigmoid(),
).to(device=device)

# binary cross-entropy on the sigmoid output, optimized with Adam
criterion = torch.nn.BCELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.01)
nepochs = 300
nbatch = 32
times = []  # measured duration of every training step, in seconds


def _timestamp():
    """Return a monotonic clock reading taken after pending CUDA work is done."""
    if device.type == "cuda":
        # CUDA kernels are queued asynchronously; without this wait the
        # interval would cover kernel launches rather than the work itself
        torch.cuda.synchronize()
    # perf_counter is monotonic, so system clock adjustments cannot skew
    # (or make negative) the measured interval, unlike time.time()
    return time.perf_counter()


for epoch in range(nepochs):
    # mini-batch indices are drawn with replacement from the first 100
    # training samples only
    index = np.random.randint(low=0, high=100, size=nbatch)
    # index = np.arange(nbatch)
    inputs, labels = x_train_torch[index], y_train_torch[index]
    opt.zero_grad()

    with torch.set_grad_enabled(True):
        # timed region: forward pass, loss, backward pass and optimizer step
        time0 = _timestamp()
        yps = model(inputs)
        loss = criterion(
            torch.reshape(yps, [nbatch, 1]), torch.reshape(labels, [nbatch, 1])
        )
        loss.backward()
        if epoch % 100 == 0:
            print(loss)
        opt.step()
        time1 = _timestamp()
    times.append(time1 - time0)

# the first step is left out of the average (presumably because it includes
# the one-off jit compilation of the quantum layer)
print("training time per step: ", np.mean(times[1:]))