22
22
23
23
print (device )
24
24
25
+ enable_dlpack = True
26
+ # enable_dlpack = False # for old version of ML libs
27
+ tf_device = "/GPU:0"
28
+ # tf_device "/device:CPU:0"
29
+
25
30
# dataset preparation
26
31
27
32
(x_train , y_train ), (x_test , y_test ) = tf .keras .datasets .mnist .load_data ()
@@ -52,17 +57,18 @@ def filter_pair(x, y, a, b):
52
57
53
58
54
59
def qpreds (x , weights ):
55
- c = tc .Circuit (n )
56
- for i in range (n ):
57
- c .rx (i , theta = x [i ])
58
- for j in range (nlayers ):
59
- for i in range (n - 1 ):
60
- c .cnot (i , i + 1 )
60
+ with tf .device (tf_device ):
61
+ c = tc .Circuit (n )
61
62
for i in range (n ):
62
- c .rx (i , theta = weights [2 * j , i ])
63
- c .ry (i , theta = weights [2 * j + 1 , i ])
63
+ c .rx (i , theta = x [i ])
64
+ for j in range (nlayers ):
65
+ for i in range (n - 1 ):
66
+ c .cnot (i , i + 1 )
67
+ for i in range (n ):
68
+ c .rx (i , theta = weights [2 * j , i ])
69
+ c .ry (i , theta = weights [2 * j + 1 , i ])
64
70
65
- return K .stack ([K .real (c .expectation_ps (z = [i ])) for i in range (n )])
71
+ return K .stack ([K .real (c .expectation_ps (z = [i ])) for i in range (n )])
66
72
67
73
68
74
# qpreds_vmap = K.vmap(qpreds, vectorized_argnums=0)
@@ -74,9 +80,8 @@ def qpreds(x, weights):
74
80
use_vmap = True ,
75
81
use_interface = True ,
76
82
use_jit = True ,
77
- enable_dlpack = True ,
83
+ enable_dlpack = enable_dlpack ,
78
84
)
79
- # enable_dlpack = False for old version of ML libs
80
85
81
86
82
87
model = torch .nn .Sequential (quantumnet , torch .nn .Linear (9 , 1 ), torch .nn .Sigmoid ())
0 commit comments