@@ -344,7 +344,7 @@ def vqe_energy(inputs, param, n, nlayers):
 def test_vvag(backend):
     n = 4
     nlayers = 3
-    inp = tc.backend.ones([2 ** n]) / 2 ** (n / 2)
+    inp = tc.backend.ones([2**n]) / 2 ** (n / 2)
     param = tc.backend.ones([2 * nlayers, n])
     inp = tc.backend.cast(inp, "complex64")
     param = tc.backend.cast(param, "complex64")
@@ -355,7 +355,7 @@ def test_vvag(backend):
     v0, (g00, g01) = vg(inp, param)
 
     batch = 8
-    inps = tc.backend.ones([batch, 2 ** n]) / 2 ** (n / 2)
+    inps = tc.backend.ones([batch, 2**n]) / 2 ** (n / 2)
     inps = tc.backend.cast(inps, "complex64")
 
     pvag = tc.backend.vvag(vqe_energy_p, argnums=(0, 1))
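The `vvag` being exercised here is TensorCircuit's vectorized value-and-grad. As a point of reference only, a minimal sketch of the same pattern in plain JAX, with a toy scalar function standing in for `vqe_energy_p`; tc's real `vvag` may aggregate the shared-parameter gradient differently:

import jax
import jax.numpy as jnp

def f(inp, param):
    # toy stand-in for vqe_energy_p: scalar output per sample
    return jnp.sum(inp) * jnp.sum(param**2)

# value_and_grad w.r.t. both arguments, vmapped over the leading
# (batch) axis of inp only; param is shared across the batch
vag = jax.value_and_grad(f, argnums=(0, 1))
pvag = jax.vmap(vag, in_axes=(0, None))

inps = jnp.ones([8, 16]) / 4.0   # batch of 8 normalized states
param = jnp.ones([6, 4])
values, (g0, g1) = pvag(inps, param)   # values has shape [8]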
@@ -382,7 +382,7 @@ def dict_plus(x, y):
 @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
 def test_vjp(backend):
     def f(x):
-        return x ** 2
+        return x**2
 
     inputs = tc.backend.ones([2, 2])
     v, g = tc.backend.vjp(f, inputs, inputs)
@@ -410,7 +410,7 @@ def f(x):
     np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
 
     def f2(x):
-        return x ** 2
+        return x**2
 
     inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
     v = tc.backend.ones([1], dtype="complex64")  # + 1.0j * tc.backend.ones([1])
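For reference, the reverse-mode contract these tests exercise, sketched with `jax.vjp`; this illustrates the math being asserted, not tc's backend-agnostic implementation:

import jax
import jax.numpy as jnp

def f(x):
    return x**2

inputs = jnp.ones([2, 2])
value, vjp_fn = jax.vjp(f, inputs)
(g,) = vjp_fn(inputs)   # cotangent v = inputs
# g = v * df/dx = 2 * x * v elementwise, so all entries equal 2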
@@ -440,7 +440,7 @@ def f3(d):
 @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
 def test_jvp(backend):
     def f(x):
-        return x ** 2
+        return x**2
 
     inputs = tc.backend.ones([2, 2])
     v, g = tc.backend.jvp(f, inputs, inputs)
@@ -469,7 +469,7 @@ def f(x):
     np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
 
     def f2(x):
-        return x ** 2
+        return x**2
 
     inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
     v = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
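And the forward-mode counterpart: `jvp` pushes a tangent vector through the linearized function. A sketch with `jax.jvp` of what the assertions above check:

import jax
import jax.numpy as jnp

def f(x):
    return x**2

inputs = jnp.ones([2, 2])
value, tangent_out = jax.jvp(f, (inputs,), (inputs,))
# tangent_out = df/dx * v = 2 * x * v elementwise; all twos here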
@@ -496,27 +496,27 @@ def test_jac(backend, mode):
     backend_jac = getattr(tc.backend, mode)
 
     def f(x):
-        return x ** 2
+        return x**2
 
     x = tc.backend.ones([3])
     jacf = backend_jac(f)
     np.testing.assert_allclose(jacf(x), 2 * np.eye(3), atol=1e-5)
 
     def f2(x):
-        return x ** 2, x
+        return x**2, x
 
     jacf2 = backend_jac(f2)
     np.testing.assert_allclose(jacf2(x)[1], np.eye(3), atol=1e-5)
     np.testing.assert_allclose(jacf2(x)[0], 2 * np.eye(3), atol=1e-5)
 
     def f3(x, y):
-        return x + y ** 2
+        return x + y**2
 
     jacf3 = backend_jac(f3, argnums=(0, 1))
     np.testing.assert_allclose(jacf3(x, x)[1], 2 * np.eye(3), atol=1e-5)
 
     def f4(x, y):
-        return x ** 2, y
+        return x**2, y
 
     # note the subtle difference of two tuples order in jacrev and jacfwd for current API
     # the value happen to be the same here, though
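The parametrized `mode` switches between reverse- and forward-mode Jacobians. A sketch of the single-input case in plain JAX (the tuple-order caveat in the comment above concerns tc's current wrappers and is not reproduced here):

import jax
import jax.numpy as jnp

def f(x):
    return x**2

x = jnp.ones([3])
# reverse mode builds the Jacobian one row (output) at a time,
# forward mode one column (input) at a time; for an elementwise
# square both give diag(2 * x)
jr = jax.jacrev(f)(x)
jf = jax.jacfwd(f)(x)
assert jnp.allclose(jr, 2 * jnp.eye(3)) and jnp.allclose(jr, jf)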
@@ -531,7 +531,7 @@ def test_jac_md_input(backend, mode):
     backend_jac = getattr(tc.backend, mode)
 
     def f(x):
-        return x ** 2
+        return x**2
 
     x = tc.backend.ones([2, 3])
     jacf = backend_jac(f)
@@ -565,7 +565,7 @@ def f(x):
 def test_vvag_has_aux(backend):
     def f(x):
         y = tc.backend.sum(x)
-        return tc.backend.real(y ** 2), y
+        return tc.backend.real(y**2), y
 
     fvvag = tc.backend.vvag(f, has_aux=True)
     (_, v1), _ = fvvag(tc.backend.ones([10, 2]))
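A sketch of the `vvag`-with-aux structure in plain JAX, assuming tc's default of batching over the first argument; the destructuring `(_, v1), _` above then picks out the auxiliary output:

import jax
import jax.numpy as jnp

def f(x):
    y = jnp.sum(x)
    # differentiated output first, auxiliary output second
    return jnp.real(y**2), y

fvvag = jax.vmap(jax.value_and_grad(f, has_aux=True), in_axes=0)
(v, aux), g = fvvag(jnp.ones([10, 2]))
# per sample: y = 2, value = 4, aux = 2, grad = 2 * y * ones([2])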
@@ -741,7 +741,7 @@ def test_with_level_set_return(backend):
 @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
 def test_grad_has_aux(backend):
     def f(x):
-        return tc.backend.real(x ** 2), x ** 3
+        return tc.backend.real(x**2), x**3
 
     vg = tc.backend.value_and_grad(f, has_aux=True)
@@ -750,7 +750,7 @@ def f(x):
     )
 
     def f2(x):
-        return tc.backend.real(x ** 2), (x ** 3, tc.backend.ones([3]))
+        return tc.backend.real(x**2), (x**3, tc.backend.ones([3]))
 
     gs = tc.backend.grad(f2, has_aux=True)
     np.testing.assert_allclose(gs(tc.backend.ones([]))[0], 2.0, atol=1e-5)
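The `has_aux` contract these assertions rely on, sketched with JAX: only the first element of the returned pair is differentiated, and anything else (here a nested tuple) rides along untouched:

import jax
import jax.numpy as jnp

def f2(x):
    return jnp.real(x**2), (x**3, jnp.ones([3]))

gs = jax.grad(f2, has_aux=True)
g, (aux0, aux1) = gs(jnp.ones([]))
# g == 2.0 from d(x**2)/dx at x = 1; aux is passed through as-is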
@@ -833,7 +833,7 @@ def f2(params, n):
 def test_hessian(backend):
     # hessian support is now very fragile and especially has potential issues on tf backend
     def f(param):
-        return tc.backend.sum(param ** 2)
+        return tc.backend.sum(param**2)
 
     hf = tc.backend.hessian(f)
     param = tc.backend.ones([2])
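For the quadratic used in this test the expected Hessian is just 2·I, which a plain-JAX sketch confirms (illustration only, not tc's backend implementation):

import jax
import jax.numpy as jnp

def f(param):
    return jnp.sum(param**2)

hf = jax.hessian(f)   # forward-over-reverse composition in JAX
param = jnp.ones([2])
# Hessian of sum(x**2) is the constant matrix 2 * I
assert jnp.allclose(hf(param), 2 * jnp.eye(2))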