
Commit 80e7fa9

add hybrid pipeline example all on gpu
1 parent b5c788e commit 80e7fa9

File tree: 5 files changed, +119 -7 lines

- .github/ISSUE_TEMPLATE/tc_enhancement_proposal.md
- CHANGELOG.md
- examples/hybrid_gpu_pipeline.py
- tensorcircuit/interfaces/torch.py
- tensorcircuit/torchnn.py

.github/ISSUE_TEMPLATE/tc_enhancement_proposal.md (+2)

@@ -8,6 +8,8 @@ assignees: ""
 
 <!--Inspired from NEP: https://numpy.org/neps/nep-template.html-->
 
+<!-- If you have some small feature request or issue report, just open instead a plain issue -->
+
 # TEP - Title
 
 Author

CHANGELOG.md (+3 -1)

@@ -12,7 +12,7 @@
 
 - Add `to_dlpack` and `from_dlpack` method on backends
 
-- Add dlpack path for interfaces
+- Add `enable_dlpack` option on interfaces and torchnn
 
 ### Changed
 
@@ -26,6 +26,8 @@
 
 - Fixed `numpy` method bug in pytorch backend when the input tensor requires grad (#24) and when the tensor is on GPU (#25)
 
+- Fixed `TorchLayer` parameter list auto registration
+
 ## 0.2.1
 
 ### Added
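Editor's note: the `enable_dlpack` entry above also covers the functional interface, not only `TorchLayer`. Below is a minimal sketch of that path, mirroring the commented-out `torch_interface` lines in examples/hybrid_gpu_pipeline.py further down; the toy circuit `f` and its shapes are illustrative only and not part of this commit.

import torch
import tensorcircuit as tc

K = tc.set_backend("tensorflow")


def f(x, weights):
    # a tiny two-qubit circuit standing in for a real quantum function
    c = tc.Circuit(2)
    c.rx(0, theta=x[0])
    c.rx(1, theta=x[1])
    c.cnot(0, 1)
    c.ry(1, theta=weights[0])
    return K.real(c.expectation_ps(z=[1]))


# with enable_dlpack=True the torch <-> tensorflow tensor handoff can avoid
# a host round trip, which is the point of the new flag
f_torch = tc.interfaces.torch_interface(f, jit=True, enable_dlpack=True)

x = torch.ones([2])
w = torch.ones([1], requires_grad=True)
y = f_torch(x, w)
y.backward()
print(w.grad)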

examples/hybrid_gpu_pipeline.py (new file, +108)

"""
quantum part in tensorflow or jax, neural part in torch, both on GPU,
fantastic hybrid pipeline
"""

import os

os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
import time
import numpy as np
import tensorflow as tf
import torch
import tensorcircuit as tc

K = tc.set_backend("tensorflow")

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


print(device)

# dataset preparation

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., np.newaxis] / 255.0


def filter_pair(x, y, a, b):
    keep = (y == a) | (y == b)
    x, y = x[keep], y[keep]
    y = y == a
    return x, y


x_train, y_train = filter_pair(x_train, y_train, 1, 5)
x_train_small = tf.image.resize(x_train, (3, 3)).numpy()
x_train_bin = np.array(x_train_small > 0.5, dtype=np.float32)
x_train_bin = np.squeeze(x_train_bin).reshape([-1, 9])
y_train_torch = torch.tensor(y_train, dtype=torch.float32)
x_train_torch = torch.tensor(x_train_bin)
x_train_torch = x_train_torch.to(device=device)
y_train_torch = y_train_torch.to(device=device)

n = 9
nlayers = 3

# We define the quantum function,
# note how this function is running on tensorflow


def qpreds(x, weights):
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=x[i])
    for j in range(nlayers):
        for i in range(n - 1):
            c.cnot(i, i + 1)
        for i in range(n):
            c.rx(i, theta=weights[2 * j, i])
            c.ry(i, theta=weights[2 * j + 1, i])

    return K.stack([K.real(c.expectation_ps(z=[i])) for i in range(n)])


# qpreds_vmap = K.vmap(qpreds, vectorized_argnums=0)
# qpreds_batch = tc.interfaces.torch_interface(qpreds_vmap, jit=True, enable_dlpack=True)

quantumnet = tc.TorchLayer(
    qpreds,
    weights_shape=[2 * nlayers, n],
    use_vmap=True,
    use_interface=True,
    use_jit=True,
    enable_dlpack=True,
)


model = torch.nn.Sequential(quantumnet, torch.nn.Linear(9, 1), torch.nn.Sigmoid())
model = model.to(device=device)


criterion = torch.nn.BCELoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
nepochs = 300
nbatch = 32
times = []
for epoch in range(nepochs):
    index = np.random.randint(low=0, high=100, size=nbatch)
    # index = np.arange(nbatch)
    inputs, labels = x_train_torch[index], y_train_torch[index]
    opt.zero_grad()

    with torch.set_grad_enabled(True):
        time0 = time.time()
        yps = model(inputs)
        loss = criterion(
            torch.reshape(yps, [nbatch, 1]), torch.reshape(labels, [nbatch, 1])
        )
        loss.backward()
        if epoch % 100 == 0:
            print(loss)
        opt.step()
        time1 = time.time()
        times.append(time1 - time0)
print("training time per step: ", np.mean(times[1:]))

tensorcircuit/interfaces/torch.py (+1 -4)

@@ -70,12 +70,9 @@ def forward(ctx: Any, *x: Any) -> Any:  # type: ignore
             # (x, )
             if len(ctx.xdtype) == 1:
                 ctx.xdtype = ctx.xdtype[0]
+
             x = general_args_to_backend(x, enable_dlpack=enable_dlpack)
             y = fun(*x)
-            # if not is_sequence(y):
-            #     ctx.ydtype = [y.dtype]
-            # else:
-            #     ctx.ydtype = [yi.dtype for yi in y]
             ctx.ydtype = backend.tree_map(lambda s: s.dtype, y)
             if len(x) == 1:
                 x = x[0]
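Editor's note: the four commented-out lines are dropped in favor of the existing `backend.tree_map` call, which covers both the single-tensor and the tuple output cases in one expression. A small sketch of that behaviour, assuming `K.tree_map` on the backend object matches the `backend.tree_map` used in this file.

import tensorcircuit as tc

K = tc.set_backend("tensorflow")

single = K.ones([2])
pair = (K.ones([2]), K.zeros([3]))

# one dtype for a bare tensor, a tuple of dtypes for a tuple of tensors
print(K.tree_map(lambda s: s.dtype, single))
print(K.tree_map(lambda s: s.dtype, pair))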

tensorcircuit/torchnn.py (+5 -2)

@@ -22,6 +22,7 @@ def __init__(
         use_vmap: bool = True,
         use_interface: bool = True,
         use_jit: bool = True,
+        enable_dlpack: bool = False,
     ):
         """
         PyTorch nn Module wrapper on quantum function ``f``.
@@ -67,14 +68,16 @@ def qpred(x, weights):
         :type use_interface: bool, optional
         :param use_jit: whether jit ``f``, defaults to True
         :type use_jit: bool, optional
+        :param enable_dlpack: whether to enable dlpack in interfaces, defaults to False
+        :type enable_dlpack: bool, optional
         """
         super().__init__()
         if use_vmap:
             f = backend.vmap(f, vectorized_argnums=0)
         if use_interface:
-            f = torch_interface(f, jit=use_jit)
+            f = torch_interface(f, jit=use_jit, enable_dlpack=enable_dlpack)
         self.f = f
-        self.q_weights = []
+        self.q_weights = torch.nn.ParameterList()  # type: ignore
         if isinstance(weights_shape[0], int):
             weights_shape = [weights_shape]
         if not is_sequence(initializer):
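Editor's note: switching `q_weights` from a plain Python list to `torch.nn.ParameterList` is what the CHANGELOG calls the `TorchLayer` parameter-list auto-registration fix: the quantum weights now appear in `Module.parameters()`, so the `torch.optim.Adam(model.parameters(), ...)` call in the new example actually updates them. A hedged check below; the toy circuit is illustrative and not part of this commit.

import tensorcircuit as tc

K = tc.set_backend("tensorflow")

n, nlayers = 3, 2


def qpreds(x, weights):
    # small stand-in quantum function with one weight per qubit per layer
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=x[i])
    for j in range(nlayers):
        for i in range(n):
            c.ry(i, theta=weights[j, i])
    return K.stack([K.real(c.expectation_ps(z=[i])) for i in range(n)])


layer = tc.TorchLayer(qpreds, weights_shape=[nlayers, n])
# with the ParameterList fix this prints the registered quantum weight shape;
# with the old plain list the module exposed no parameters
print([p.shape for p in layer.parameters()])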
