
Commit df9ed47

black & pass mypy

1 parent 79cd54f

19 files changed, +69 -64 lines changed

examples/checkpoint_memsave.py (+1 -1)

@@ -97,7 +97,7 @@ def totallayer(s, param):
 
 
 def vqe_forward(param):
-    s = tc.backend.ones([2**nwires])
+    s = tc.backend.ones([2 ** nwires])
     s /= tc.backend.norm(s)
     s = totallayer(s, param)
     e = tc.expectation((tc.gates.x(), [1]), ket=s)
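Aside: every 2**n -> 2 ** n hunk in this commit is the same mechanical reformat, and the spacing is purely cosmetic at runtime. A minimal sketch of the two styles (assuming a pre-22.1.0 black, which spaced all binary operators; black 22.1.0 and later instead hug power operators with simple operands):

# Both spellings compile to the same result; only the formatter style differs.
dim_spaced = 2 ** 10  # style produced by the black version used in this commit
dim_hugged = 2**10  # style later black releases (>= 22.1.0) prefer
assert dim_spaced == dim_hugged == 1024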

examples/variational_dynamics.py (+1 -1)

@@ -96,7 +96,7 @@ def update(theta, lhs, rhs, tau):
 # TFIM Hamiltonian defined on lattice graph g (1D OBC chain)
 h = tc.array_to_tensor(h)
 
-psi0 = np.zeros(2**N)
+psi0 = np.zeros(2 ** N)
 psi0[0] = 1.0
 psi0 = tc.array_to_tensor(psi0)
 

examples/vqeh2o_benchmark.py (+1 -1)

@@ -41,7 +41,7 @@
 mb = tc.quantum.PauliStringSum2COO_numpy(lsb, wb)
 mbd = mb.todense()
 mb = K.coo_sparse_matrix(
-    np.transpose(np.stack([mb.row, mb.col])), mb.data, shape=(2**n, 2**n)
+    np.transpose(np.stack([mb.row, mb.col])), mb.data, shape=(2 ** n, 2 ** n)
 )
 mbd = tc.array_to_tensor(mbd)
 

tensorcircuit/applications/utils.py (+5 -5)

@@ -54,11 +54,11 @@ def amplitude_encoding(
     norm = tf.linalg.norm(fig, axis=1)
     norm = norm[..., tf.newaxis]
     fig = fig / norm
-    if fig.shape[1] < 2**qubits:
+    if fig.shape[1] < 2 ** qubits:
         fig = tf.concat(
             [
                 fig,
-                tf.zeros([fig.shape[0], 2**qubits - fig.shape[1]], dtype=tf.float64),
+                tf.zeros([fig.shape[0], 2 ** qubits - fig.shape[1]], dtype=tf.float64),
             ],
             axis=1,
         )
@@ -205,7 +205,7 @@ def train_qml_vag(
     with tf.GradientTape() as tape:
         tape.watch(nnp)
         cnnp = tf.cast(nnp, dtype=tf.complex64)
-        c = Circuit(nqubits, inputs=np.ones([1024], dtype=np.complex64) / 2**5)
+        c = Circuit(nqubits, inputs=np.ones([1024], dtype=np.complex64) / 2 ** 5)
        for epoch in range(epochs):
            for i in range(nqubits):
                c.rz(i, theta=cnnp[3 * epoch, i])  # type: ignore
@@ -275,7 +275,7 @@ def validate_qml_vag(
 ) -> Any:
     xs, ys = gdata
     cnnp = tf.cast(nnp, dtype=tf.complex64)
-    c = Circuit(nqubits, inputs=np.ones([1024], dtype=np.complex64) / 2**5)
+    c = Circuit(nqubits, inputs=np.ones([1024], dtype=np.complex64) / 2 ** 5)
     for epoch in range(epochs):
         for i in range(nqubits):
             c.rz(i, theta=cnnp[3 * epoch, i])  # type: ignore
@@ -401,7 +401,7 @@ def TFIM1Denergy(
         Jzz *= 4
     for i in range(L):
         q = np.pi * (2 * i - (1 + (-1) ** L) / 2) / L
-        e -= np.abs(Jx) / 2 * np.sqrt(1 + Jzz**2 / 4 / Jx**2 - Jzz / Jx * np.cos(q))
+        e -= np.abs(Jx) / 2 * np.sqrt(1 + Jzz ** 2 / 4 / Jx ** 2 - Jzz / Jx * np.cos(q))
     return e
 
 
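The TFIM1Denergy hunk above accumulates the exact free-fermion ground-state energy of the 1D transverse-field Ising chain, one momentum mode per loop iteration. Reading the formula straight off the code (same conventions as the hunk, not an independent derivation):

E = -\frac{|J_x|}{2} \sum_{i=0}^{L-1} \sqrt{1 + \frac{J_{zz}^2}{4 J_x^2} - \frac{J_{zz}}{J_x}\cos q_i},
\qquad q_i = \frac{\pi}{L}\left(2i - \frac{1 + (-1)^L}{2}\right)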

tensorcircuit/applications/vags.py (+3 -2)

@@ -54,7 +54,7 @@ def GHZ_vag(
     gdata: Any, nnp: Tensor, preset: Sequence[int], verbose: bool = False, n: int = 3
 ) -> Tuple[Tensor, Tensor]:
     # gdata = None
-    reference_state = np.zeros([2**n])
+    reference_state = np.zeros([2 ** n])
     # W states benchmarks
     # for i in range(n):
     #     reference_state[2**(i)] = 1/np.sqrt(n)
@@ -80,7 +80,7 @@ def GHZ_vag(
     s = circuit.wavefunction()
     s = tf.reshape(
         s,
-        [2**n],
+        [2 ** n],
     )
     loss = tf.math.reduce_sum(
         tf.math.abs(s - reference_state)
@@ -1053,6 +1053,7 @@ def quantum_mp_qaoa_vag(
         gmatrix = tf.constant(gmatrix)
         return loss[0], gmatrix
 
+
 except NameError as e:
     logger.warning(e)
     logger.warning("tfq related vags disabled due to missing packages")

tensorcircuit/applications/van.py (+10 -10)

@@ -3,14 +3,15 @@
 """
 
 from typing import Any, Optional, Tuple, Union, List
-from sympy import N
 import tensorflow as tf
+from tensorflow.keras.layers import Layer
+from tensorflow.keras.models import Model
 import numpy as np
 
 # TODO(@refraction-ray): Add type annotation in this module some time.
 
 
-class MaskedLinear(tf.keras.layers.Layer):
+class MaskedLinear(Layer):  # type: ignore
     def __init__(
         self,
         input_space: int,
@@ -55,7 +56,7 @@ def regularization(self, lbd_w: float = 1.0, lbd_b: float = 1.0) -> tf.Tensor:
         return lbd_w * tf.reduce_sum(self.w ** 2) + lbd_b * tf.reduce_sum(self.b ** 2)
 
 
-class MADE(tf.keras.Model):
+class MADE(Model):  # type: ignore
     def __init__(
         self,
         input_space: int,
@@ -75,7 +76,7 @@ def __init__(
         else:
             self._dtype = dtype
         self._m: List[np.array] = []
-        self._masks = []
+        self._masks: List[tf.Tensor] = []
         self.ml_layer: List[Any] = []
         self.input_space = input_space
         self.spin_channel = spin_channel
@@ -241,7 +242,7 @@ def log_prob(self, sample: tf.Tensor) -> tf.Tensor:
         return log_prob
 
 
-class MaskedConv2D(tf.keras.layers.Layer):
+class MaskedConv2D(Layer):  # type: ignore
     def __init__(self, mask_type: str, **kwargs: Any):
         super().__init__()
         assert mask_type in {"A", "B"}, "mask_type must be in A or B"
@@ -268,7 +269,7 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
         return self.conv(inputs)
 
 
-class ResidualBlock(tf.keras.layers.Layer):
+class ResidualBlock(Layer):  # type: ignore
     def __init__(self, layers: List[Any]):
         super().__init__()
         self.layers = layers
@@ -280,7 +281,7 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
         return y + inputs
 
 
-class PixelCNN(tf.keras.Model):
+class PixelCNN(Model):  # type: ignore
     def __init__(self, spin_channel: int, depth: int, filters: int):
         super().__init__()
         self.rb = []
@@ -340,7 +341,6 @@ def _log_prob(self, sample: tf.Tensor, x_hat: tf.Tensor) -> tf.Tensor:
         probm = tf.multiply(x_hat, sample)
         probm = tf.reduce_sum(probm, axis=-1)
         lnprobm = tf.math.log(probm + eps)
-
         return tf.reduce_sum(lnprobm, axis=[-1, -2])
 
     def log_prob(self, sample: tf.Tensor) -> tf.Tensor:
@@ -349,7 +349,7 @@ def log_prob(self, sample: tf.Tensor) -> tf.Tensor:
         return log_prob
 
 
-class NMF(tf.keras.Model):
+class NMF(Model):  # type: ignore
     def __init__(
         self,
         spin_channel: int,
@@ -378,7 +378,7 @@ def call(self, inputs: Optional[tf.Tensor] = None) -> tf.Tensor:
         else:
             return self.w + self.probamp
 
-    def sample(self, batch_size: int) -> tf.Tensor:
+    def sample(self, batch_size: int) -> Tuple[tf.Tensor, tf.Tensor]:
         x_hat = self.call()
         x_hat = x_hat[tf.newaxis, :]
         tile_shape = tuple([batch_size] + [1 for _ in range(self.D + 1)])
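The van.py changes carry the "pass mypy" half of the commit message: an unused sympy import is dropped, and the keras base classes are imported under local aliases and subclassed with # type: ignore. mypy cannot follow tf.keras's lazy module loading, so Layer and Model resolve to Any, and subclassing Any is an error under disallow_subclassing_any (which this repo's mypy config apparently enables, given the added ignores). A standalone sketch of the pattern, not the module's real classes:

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model


class Doubler(Layer):  # type: ignore
    # Without the ignore comment, mypy flags subclassing an untyped (Any) base.
    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        return 2 * inputs


class Wrapper(Model):  # type: ignore
    def __init__(self) -> None:
        super().__init__()
        self.double = Doubler()

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        return self.double(inputs)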

tensorcircuit/applications/vqes.py (+1 -1)

@@ -86,7 +86,7 @@ def construct_matrix_tf(ham: List[List[float]], dtype: Any = tf.complex128) -> Tensor:
 def construct_matrix_v2(ham: List[List[float]], dtype: Any = tf.complex128) -> Tensor:
     # deprecated
     s = len(ham[0]) - 1
-    h = tf.zeros([2**s, 2**s], dtype=dtype)
+    h = tf.zeros([2 ** s, 2 ** s], dtype=dtype)
     for term in tqdm(ham, desc="Hamiltonian building"):
         term = list(term)
         for i, t in enumerate(term):

tensorcircuit/densitymatrix.py (+1 -1)

@@ -287,7 +287,7 @@ def densitymatrix(self, check: bool = False, reuse: bool = True) -> Tensor:
         nodes, _ = self._copy_dm_tensor(conj=False, reuse=reuse)
         # t = contractor(nodes, output_edge_order=d_edges)
         dm = backend.reshape(
-            nodes[0].tensor, shape=[2**self._nqubits, 2**self._nqubits]
+            nodes[0].tensor, shape=[2 ** self._nqubits, 2 ** self._nqubits]
         )
         if check:
             self.check_density_matrix(dm)

tensorcircuit/mpscircuit.py (+1 -1)

@@ -277,7 +277,7 @@ def apply_adjacent_double_gate(
             max_truncation_err=self.max_truncation_err,
             relative=self.relative,
         )
-        self._fidelity *= 1 - backend.real(backend.sum(err**2))
+        self._fidelity *= 1 - backend.real(backend.sum(err ** 2))
 
     def apply_double_gate(
         self,

tensorcircuit/quantum.py (+2 -1)

@@ -1277,6 +1277,7 @@ def ps2coo_core(
         )
         return tf.SparseTensor(indices=indices, values=values, dense_shape=(s, s))  # type: ignore
 
+
 except NameError:
     logger.warning(
         "tensorflow is not installed, and sparse Hamiltonian generation utilities are disabled"
@@ -1703,7 +1704,7 @@ def spin_by_basis(n: int, m: int, elements: Tuple[int, int] = (1, -1)) -> Tensor:
         backend.cast(
             backend.convert_to_tensor(np.array([[elements[0]], [elements[1]]])), "int32"
         ),
-        [2**m, int(2 ** (n - m - 1))],
+        [2 ** m, int(2 ** (n - m - 1))],
     )
     return backend.reshape(s, [-1])
 
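In the spin_by_basis hunk, the shape [2 ** m, 2 ** (n - m - 1)] tiles the column [1, -1] so that, once flattened, the result lists spin m's eigenvalue for each of the 2 ** n computational basis states in order. A numpy sketch of that computation (a reading of the visible lines; the surrounding tile call is assumed):

import numpy as np

def spin_by_basis_np(n, m, elements=(1, -1)):
    # Tile the (2, 1) column 2**m times vertically and 2**(n-m-1) times
    # horizontally, then flatten: entry i is spin m's value in basis state i.
    col = np.array([[elements[0]], [elements[1]]])
    s = np.tile(col, [2 ** m, int(2 ** (n - m - 1))])
    return s.reshape([-1])

print(spin_by_basis_np(2, 0))  # [ 1  1 -1 -1]
print(spin_by_basis_np(2, 1))  # [ 1 -1  1 -1]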

tensorcircuit/templates/dataset.py (+2 -2)

@@ -20,12 +20,12 @@ def amplitude_encoding(
     fig = backend.reshape(fig, shape=[-1])
     norm = backend.norm(fig)
     fig = fig / norm
-    if backend.shape_tuple(fig)[0] < 2**nqubits:
+    if backend.shape_tuple(fig)[0] < 2 ** nqubits:
         fig = backend.concat(
             [
                 fig,
                 backend.zeros(
-                    [2**nqubits - backend.shape_tuple(fig)[0]], dtype=fig.dtype
+                    [2 ** nqubits - backend.shape_tuple(fig)[0]], dtype=fig.dtype
                 ),
             ],
         )
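This hunk sits inside amplitude_encoding, which flattens the input, normalizes it to unit norm, and zero-pads it to length 2 ** nqubits so it can serve as a state vector. A backend-free numpy sketch of the same logic (the function's other preprocessing steps, if any, are omitted):

import numpy as np

def amplitude_encoding_np(fig, nqubits):
    fig = np.reshape(fig, [-1]).astype(np.float64)
    fig = fig / np.linalg.norm(fig)  # unit norm, as a quantum state requires
    dim = 2 ** nqubits
    if fig.shape[0] < dim:  # zero-pad up to the full Hilbert-space dimension
        fig = np.concatenate([fig, np.zeros(dim - fig.shape[0], dtype=fig.dtype)])
    return fig

state = amplitude_encoding_np(np.arange(1.0, 6.0), nqubits=3)  # 5 values -> 8 amplitudes
assert state.shape == (8,) and np.isclose(np.linalg.norm(state), 1.0)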

tensorcircuit/translation.py (+2 -2)

@@ -33,8 +33,8 @@ def perm_matrix(n: int) -> Tensor:
     :return: The permutation matrix P
     :rtype: Tensor
     """
-    p_mat = np.zeros([2**n, 2**n])
-    for i in range(2**n):
+    p_mat = np.zeros([2 ** n, 2 ** n])
+    for i in range(2 ** n):
         bit = i
         revs_i = 0
         for j in range(n):
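perm_matrix builds the 2 ** n by 2 ** n permutation that reverses qubit ordering, as needed when translating circuits between conventions with opposite endianness; the hunk cuts off just as the loop starts reversing each index's bits. A sketch of the complete computation (a reconstruction from the visible lines, assuming the inner loop simply assembles the bit reversal):

import numpy as np

def perm_matrix_np(n):
    # P[revs_i, i] = 1 where revs_i is i with its n-bit string reversed, so
    # P @ state reorders amplitudes between little- and big-endian qubit order.
    p_mat = np.zeros([2 ** n, 2 ** n])
    for i in range(2 ** n):
        bit = i
        revs_i = 0
        for _ in range(n):
            revs_i = (revs_i << 1) | (bit & 1)  # pull bits off i, push onto revs_i
            bit >>= 1
        p_mat[revs_i, i] = 1
    return p_mat

assert np.array_equal(perm_matrix_np(2) @ perm_matrix_np(2), np.eye(4))  # involution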

tests/test_backends.py (+15 -15)

@@ -344,7 +344,7 @@ def vqe_energy(inputs, param, n, nlayers):
 def test_vvag(backend):
     n = 4
     nlayers = 3
-    inp = tc.backend.ones([2**n]) / 2 ** (n / 2)
+    inp = tc.backend.ones([2 ** n]) / 2 ** (n / 2)
     param = tc.backend.ones([2 * nlayers, n])
     inp = tc.backend.cast(inp, "complex64")
     param = tc.backend.cast(param, "complex64")
@@ -355,7 +355,7 @@ def test_vvag(backend):
     v0, (g00, g01) = vg(inp, param)
 
     batch = 8
-    inps = tc.backend.ones([batch, 2**n]) / 2 ** (n / 2)
+    inps = tc.backend.ones([batch, 2 ** n]) / 2 ** (n / 2)
     inps = tc.backend.cast(inps, "complex64")
 
     pvag = tc.backend.vvag(vqe_energy_p, argnums=(0, 1))
@@ -382,7 +382,7 @@ def dict_plus(x, y):
 @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
 def test_vjp(backend):
     def f(x):
-        return x**2
+        return x ** 2
 
     inputs = tc.backend.ones([2, 2])
     v, g = tc.backend.vjp(f, inputs, inputs)
@@ -410,7 +410,7 @@ def f(x):
     np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
 
     def f2(x):
-        return x**2
+        return x ** 2
 
     inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
     v = tc.backend.ones([1], dtype="complex64")  # + 1.0j * tc.backend.ones([1])
@@ -440,7 +440,7 @@ def f3(d):
 @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
 def test_jvp(backend):
     def f(x):
-        return x**2
+        return x ** 2
 
     inputs = tc.backend.ones([2, 2])
     v, g = tc.backend.jvp(f, inputs, inputs)
@@ -469,7 +469,7 @@ def f(x):
     np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
 
     def f2(x):
-        return x**2
+        return x ** 2
 
     inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
     v = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
@@ -496,27 +496,27 @@ def test_jac(backend, mode):
     backend_jac = getattr(tc.backend, mode)
 
     def f(x):
-        return x**2
+        return x ** 2
 
     x = tc.backend.ones([3])
     jacf = backend_jac(f)
     np.testing.assert_allclose(jacf(x), 2 * np.eye(3), atol=1e-5)
 
     def f2(x):
-        return x**2, x
+        return x ** 2, x
 
     jacf2 = backend_jac(f2)
     np.testing.assert_allclose(jacf2(x)[1], np.eye(3), atol=1e-5)
     np.testing.assert_allclose(jacf2(x)[0], 2 * np.eye(3), atol=1e-5)
 
     def f3(x, y):
-        return x + y**2
+        return x + y ** 2
 
     jacf3 = backend_jac(f3, argnums=(0, 1))
     np.testing.assert_allclose(jacf3(x, x)[1], 2 * np.eye(3), atol=1e-5)
 
     def f4(x, y):
-        return x**2, y
+        return x ** 2, y
 
     # note the subtle difference of two tuples order in jacrev and jacfwd for current API
     # the value happen to be the same here, though
@@ -531,7 +531,7 @@ def test_jac_md_input(backend, mode):
     backend_jac = getattr(tc.backend, mode)
 
     def f(x):
-        return x**2
+        return x ** 2
 
     x = tc.backend.ones([2, 3])
     jacf = backend_jac(f)
@@ -565,7 +565,7 @@ def f(x):
 def test_vvag_has_aux(backend):
     def f(x):
         y = tc.backend.sum(x)
-        return tc.backend.real(y**2), y
+        return tc.backend.real(y ** 2), y
 
     fvvag = tc.backend.vvag(f, has_aux=True)
     (_, v1), _ = fvvag(tc.backend.ones([10, 2]))
@@ -741,7 +741,7 @@ def test_with_level_set_return(backend):
 @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
 def test_grad_has_aux(backend):
     def f(x):
-        return tc.backend.real(x**2), x**3
+        return tc.backend.real(x ** 2), x ** 3
 
     vg = tc.backend.value_and_grad(f, has_aux=True)
 
@@ -750,7 +750,7 @@ def f(x):
     )
 
     def f2(x):
-        return tc.backend.real(x**2), (x**3, tc.backend.ones([3]))
+        return tc.backend.real(x ** 2), (x ** 3, tc.backend.ones([3]))
 
     gs = tc.backend.grad(f2, has_aux=True)
     np.testing.assert_allclose(gs(tc.backend.ones([]))[0], 2.0, atol=1e-5)
@@ -833,7 +833,7 @@ def f2(params, n):
 def test_hessian(backend):
     # hessian support is now very fragile and especially has potential issues on tf backend
     def f(param):
-        return tc.backend.sum(param**2)
+        return tc.backend.sum(param ** 2)
 
     hf = tc.backend.hessian(f)
     param = tc.backend.ones([2])
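Most of these test hunks only touch small helper functions like f(x): return x ** 2 that are handed to the backend's autodiff wrappers. As a usage reference, a small self-contained sketch of the vvag (vectorized value-and-grad) call pattern that test_vvag and test_vvag_has_aux exercise (assuming tensorcircuit with the TensorFlow backend is installed):

import numpy as np
import tensorcircuit as tc

tc.set_backend("tensorflow")

def f(x):
    # per-sample scalar loss; vvag maps it over the leading batch axis
    return tc.backend.real(tc.backend.sum(x ** 2))

fvvag = tc.backend.vvag(f)
values, grads = fvvag(tc.backend.ones([10, 2]))
np.testing.assert_allclose(tc.backend.numpy(values), 2 * np.ones([10]), atol=1e-5)
np.testing.assert_allclose(tc.backend.numpy(grads), 2 * np.ones([10, 2]), atol=1e-5)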
