"""
Hybrid quantum-classical training pipeline.

The quantum part runs in tensorflow (or jax) through tensorcircuit, the neural
part runs in torch, and both can sit on the GPU.
"""

import os

# Must be set before tensorflow is imported: asks TF to grow its GPU memory
# allocation on demand, which leaves room for torch on the same device.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

import time

import numpy as np
import tensorflow as tf
import torch

import tensorcircuit as tc

K = tc.set_backend("tensorflow")

# ---------------------------------------------------------------------------
# configuration
# ---------------------------------------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

enable_dlpack = True
# enable_dlpack = False  # for old versions of the ML libs

tf_device = "/GPU:0"
# tf_device = "/device:CPU:0"
# Another scheme is to globally disable the GPU for tf only:
# https://datascience.stackexchange.com/a/76039
# but if GPU support is fully shut down that way, dlpack=True won't work.

side = 3  # images are downsampled to side x side pixels
n = side * side  # number of qubits: one per pixel of the downsampled image (9)
nlayers = 3  # number of entangling + rotation layers in the circuit

# ---------------------------------------------------------------------------
# dataset preparation
# ---------------------------------------------------------------------------

# The test split is loaded but not used by this script.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Add a trailing channel axis and rescale the pixel values by 1/255.
x_train = x_train[..., np.newaxis] / 255.0


def filter_pair(x, y, a, b):
    """Keep only the samples labelled ``a`` or ``b`` and binarise the labels.

    :param x: array of samples, indexed along its first axis
    :param y: array of labels aligned with ``x``
    :param a: label mapped to ``True``
    :param b: label mapped to ``False``
    :return: tuple ``(x, y)`` with the filtered samples and boolean labels
    """
    keep = (y == a) | (y == b)
    x, y = x[keep], y[keep]
    y = y == a
    return x, y


x_train, y_train = filter_pair(x_train, y_train, 1, 5)
x_train_small = tf.image.resize(x_train, (side, side)).numpy()
# Threshold every pixel to 0/1 and flatten each image into a length-n vector.
x_train_bin = np.array(x_train_small > 0.5, dtype=np.float32)
x_train_bin = np.squeeze(x_train_bin).reshape([-1, n])

y_train_torch = torch.tensor(y_train, dtype=torch.float32)
x_train_torch = torch.tensor(x_train_bin)
x_train_torch = x_train_torch.to(device=device)
y_train_torch = y_train_torch.to(device=device)

# ---------------------------------------------------------------------------
# quantum function -- note that this function runs on tensorflow
# ---------------------------------------------------------------------------


def qpreds(x, weights):
    """Build the parameterised circuit and return one Z expectation per qubit.

    :param x: per-qubit encoding angles; ``x[i]`` is the ``rx`` angle applied
        to qubit ``i``
    :param weights: trainable angles indexed as ``weights[2 * j, i]`` (``rx``)
        and ``weights[2 * j + 1, i]`` (``ry``) for layer ``j`` and qubit ``i``
    :return: stacked real parts of the ``z=[i]`` expectation for ``i`` in
        ``range(n)``
    """
    with tf.device(tf_device):
        c = tc.Circuit(n)
        # Data encoding: one rx rotation per qubit.
        for i in range(n):
            c.rx(i, theta=x[i])
        for j in range(nlayers):
            # Chain of CNOTs between neighbouring qubits.
            for i in range(n - 1):
                c.cnot(i, i + 1)
            # Trainable single-qubit rotations.
            for i in range(n):
                c.rx(i, theta=weights[2 * j, i])
                c.ry(i, theta=weights[2 * j + 1, i])

        return K.stack([K.real(c.expectation_ps(z=[i])) for i in range(n)])


# Manual equivalent of the layer below:
# qpreds_vmap = K.vmap(qpreds, vectorized_argnums=0)
# qpreds_batch = tc.interfaces.torch_interface(qpreds_vmap, jit=True, enable_dlpack=True)

# NOTE(review): use_vmap presumably batches qpreds over its first argument, as
# in the commented K.vmap variant above -- confirm against the tensorcircuit docs.
quantumnet = tc.TorchLayer(
    qpreds,
    weights_shape=[2 * nlayers, n],
    use_vmap=True,
    use_interface=True,
    use_jit=True,
    enable_dlpack=enable_dlpack,
)

# ---------------------------------------------------------------------------
# hybrid model and training loop
# ---------------------------------------------------------------------------

model = torch.nn.Sequential(quantumnet, torch.nn.Linear(n, 1), torch.nn.Sigmoid())
model = model.to(device=device)

criterion = torch.nn.BCELoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
nepochs = 300
nbatch = 32
npool = 100  # batches are drawn (with replacement) from the first npool samples only
times = []


def _sync_device():
    """Wait for queued CUDA work so that the clock reads reflect finished work."""
    if device.type == "cuda":
        torch.cuda.synchronize()


for epoch in range(nepochs):
    index = np.random.randint(low=0, high=npool, size=nbatch)
    # index = np.arange(nbatch)
    inputs, labels = x_train_torch[index], y_train_torch[index]
    opt.zero_grad()

    with torch.set_grad_enabled(True):
        # CUDA kernels are launched asynchronously, so synchronise before each
        # clock read; perf_counter is monotonic, unlike time.time().
        _sync_device()
        time0 = time.perf_counter()
        yps = model(inputs)
        loss = criterion(
            torch.reshape(yps, [nbatch, 1]), torch.reshape(labels, [nbatch, 1])
        )
        loss.backward()
        if epoch % 100 == 0:
            print(loss)
        opt.step()
        _sync_device()
        time1 = time.perf_counter()
        times.append(time1 - time0)

# The first step is left out of the average -- presumably because it carries
# the one-off jit tracing/compilation cost (use_jit=True).
print("training time per step: ", np.mean(times[1:]))