Skip to content

Commit de48165

Browse files
version0.2.2
1 parent 9386445 commit de48165

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
## 0.2.2
6+
57
### Added
68

79
- PyTorch backend support multi pytrees version of `tree_map`

examples/hybrid_gpu_pipeline.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222

2323
print(device)
2424

25+
enable_dlpack = True
26+
# enable_dlpack = False # for old version of ML libs
27+
tf_device = "/GPU:0"
28+
# tf_device "/device:CPU:0"
29+
2530
# dataset preparation
2631

2732
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
@@ -52,17 +57,18 @@ def filter_pair(x, y, a, b):
5257

5358

5459
def qpreds(x, weights):
55-
c = tc.Circuit(n)
56-
for i in range(n):
57-
c.rx(i, theta=x[i])
58-
for j in range(nlayers):
59-
for i in range(n - 1):
60-
c.cnot(i, i + 1)
60+
with tf.device(tf_device):
61+
c = tc.Circuit(n)
6162
for i in range(n):
62-
c.rx(i, theta=weights[2 * j, i])
63-
c.ry(i, theta=weights[2 * j + 1, i])
63+
c.rx(i, theta=x[i])
64+
for j in range(nlayers):
65+
for i in range(n - 1):
66+
c.cnot(i, i + 1)
67+
for i in range(n):
68+
c.rx(i, theta=weights[2 * j, i])
69+
c.ry(i, theta=weights[2 * j + 1, i])
6470

65-
return K.stack([K.real(c.expectation_ps(z=[i])) for i in range(n)])
71+
return K.stack([K.real(c.expectation_ps(z=[i])) for i in range(n)])
6672

6773

6874
# qpreds_vmap = K.vmap(qpreds, vectorized_argnums=0)
@@ -74,9 +80,8 @@ def qpreds(x, weights):
7480
use_vmap=True,
7581
use_interface=True,
7682
use_jit=True,
77-
enable_dlpack=True,
83+
enable_dlpack=enable_dlpack,
7884
)
79-
# enable_dlpack = False for old version of ML libs
8085

8186

8287
model = torch.nn.Sequential(quantumnet, torch.nn.Linear(9, 1), torch.nn.Sigmoid())

tensorcircuit/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.2.1"
1+
__version__ = "0.2.2"
22
__author__ = "TensorCircuit Authors"
33
__creator__ = "refraction-ray"
44

0 commit comments

Comments
 (0)