Skip to content

Commit 68be5b5

Browse files
committed
CI: update ruff to v0.6.1
1 parent afff0e0 commit 68be5b5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+241
-314
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ repos:
2626
files: \.py$
2727

2828
- repo: https://github.com/astral-sh/ruff-pre-commit
29-
rev: v0.4.4
29+
rev: v0.6.1
3030
hooks:
3131
- id: ruff
3232

cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,8 @@
510510
"outputs": [],
511511
"source": [
512512
"image_partitions = P(1, 1, 4, 2)\n",
513-
"sharded_conv = sharded_jit(conv, \n",
514-
" in_parts=(image_partitions, None), \n",
513+
"sharded_conv = sharded_jit(conv,\n",
514+
" in_parts=(image_partitions, None),\n",
515515
" out_parts=image_partitions)\n",
516516
"\n",
517517
"sharded_conv(image, kernel)"

cloud_tpu_colabs/JAX_demo.ipynb

+1-12
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@
877877
" def g(z):\n",
878878
" return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
879879
" return grad(lambda w: jnp.sum(g(w)))(x)\n",
880-
" \n",
880+
"\n",
881881
"f(x)"
882882
]
883883
},
@@ -950,17 +950,6 @@
950950
"per_example_hess = pmap(input_hess) # pmap!\n",
951951
"per_example_hess(inputs)"
952952
]
953-
},
954-
{
955-
"cell_type": "code",
956-
"execution_count": 0,
957-
"metadata": {
958-
"colab": {},
959-
"colab_type": "code",
960-
"id": "u3ggM_WYZ8QC"
961-
},
962-
"outputs": [],
963-
"source": []
964953
}
965954
],
966955
"metadata": {

cloud_tpu_colabs/Wave_Equation.ipynb

-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
"source": [
6868
"from functools import partial\n",
6969
"import jax\n",
70-
"from jax import jit, pmap\n",
7170
"from jax import lax\n",
7271
"from jax import tree_util\n",
7372
"import jax.numpy as jnp\n",

docs/_tutorials/advanced-autodiff.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ def our_jacrev(f):
640640
y, vjp_fun = vjp(f, x)
641641
# Use vmap to do a matrix-Jacobian product.
642642
# Here, the matrix is the Euclidean basis, so we get all
643-
# entries in the Jacobian at once.
643+
# entries in the Jacobian at once.
644644
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
645645
return J
646646
return jacfun
@@ -654,7 +654,7 @@ from jax import jacfwd as builtin_jacfwd
654654
def our_jacfwd(f):
655655
def jacfun(x):
656656
_jvp = lambda s: jvp(f, (x,), (s,))[1]
657-
Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
657+
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
658658
return jnp.transpose(Jt)
659659
return jacfun
660660

docs/jep/9407-type-promotion.ipynb

-2
Original file line numberDiff line numberDiff line change
@@ -3317,7 +3317,6 @@
33173317
],
33183318
"source": [
33193319
"# @title\n",
3320-
"from jax import dtypes\n",
33213320
"import jax\n",
33223321
"import jax.numpy as jnp\n",
33233322
"import pandas as pd\n",
@@ -3802,7 +3801,6 @@
38023801
],
38033802
"source": [
38043803
"# @title\n",
3805-
"from jax import dtypes\n",
38063804
"import jax\n",
38073805
"import jax.numpy as jnp\n",
38083806
"import pandas as pd\n",

docs/jep/9407-type-promotion.md

-2
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,6 @@ display.HTML(table.to_html())
908908
:tags: [hide-input]
909909
910910
# @title
911-
from jax import dtypes
912911
import jax
913912
import jax.numpy as jnp
914913
import pandas as pd
@@ -963,7 +962,6 @@ display.HTML(table.to_html())
963962
:tags: [hide-input]
964963
965964
# @title
966-
from jax import dtypes
967965
import jax
968966
import jax.numpy as jnp
969967
import pandas as pd

docs/notebooks/Common_Gotchas_in_JAX.ipynb

+5-7
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@
226226
],
227227
"source": [
228228
"import jax.numpy as jnp\n",
229-
"import jax.lax as lax\n",
230229
"from jax import make_jaxpr\n",
231230
"\n",
232231
"# lax.fori_loop\n",
@@ -1031,7 +1030,6 @@
10311030
}
10321031
],
10331032
"source": [
1034-
"from jax import random\n",
10351033
"key = random.key(0)\n",
10361034
"key"
10371035
]
@@ -1105,8 +1103,8 @@
11051103
"print(\"old key\", key)\n",
11061104
"key, subkey = random.split(key)\n",
11071105
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
1108-
"print(\" \\---SPLIT --> new key \", key)\n",
1109-
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
1106+
"print(r\" \\---SPLIT --> new key \", key)\n",
1107+
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
11101108
]
11111109
},
11121110
{
@@ -1140,8 +1138,8 @@
11401138
"print(\"old key\", key)\n",
11411139
"key, subkey = random.split(key)\n",
11421140
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
1143-
"print(\" \\---SPLIT --> new key \", key)\n",
1144-
"print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
1141+
"print(r\" \\---SPLIT --> new key \", key)\n",
1142+
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
11451143
]
11461144
},
11471145
{
@@ -1701,7 +1699,7 @@
17011699
],
17021700
"source": [
17031701
"init_val = 0\n",
1704-
"cond_fun = lambda x: x<10\n",
1702+
"cond_fun = lambda x: x < 10\n",
17051703
"body_fun = lambda x: x+1\n",
17061704
"lax.while_loop(cond_fun, body_fun, init_val)\n",
17071705
"# --> array(10, dtype=int32)"

docs/notebooks/Common_Gotchas_in_JAX.md

+5-7
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ It is not recommended to use iterators in any JAX function you want to `jit` or
130130
:outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903
131131
132132
import jax.numpy as jnp
133-
import jax.lax as lax
134133
from jax import make_jaxpr
135134
136135
# lax.fori_loop
@@ -471,7 +470,6 @@ The random state is described by a special array element that we call a __key__:
471470
:id: yPHE7KTWgAWs
472471
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
473472
474-
from jax import random
475473
key = random.key(0)
476474
key
477475
```
@@ -504,8 +502,8 @@ Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a ne
504502
print("old key", key)
505503
key, subkey = random.split(key)
506504
normal_pseudorandom = random.normal(subkey, shape=(1,))
507-
print(" \---SPLIT --> new key ", key)
508-
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
505+
print(r" \---SPLIT --> new key ", key)
506+
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
509507
```
510508

511509
+++ {"id": "tqtFVE4MthO3"}
@@ -519,8 +517,8 @@ We propagate the __key__ and make new __subkeys__ whenever we need a new random
519517
print("old key", key)
520518
key, subkey = random.split(key)
521519
normal_pseudorandom = random.normal(subkey, shape=(1,))
522-
print(" \---SPLIT --> new key ", key)
523-
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
520+
print(r" \---SPLIT --> new key ", key)
521+
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
524522
```
525523

526524
+++ {"id": "0KLYUluz3lN3"}
@@ -805,7 +803,7 @@ def while_loop(cond_fun, body_fun, init_val):
805803
:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
806804
807805
init_val = 0
808-
cond_fun = lambda x: x<10
806+
cond_fun = lambda x: x < 10
809807
body_fun = lambda x: x+1
810808
lax.while_loop(cond_fun, body_fun, init_val)
811809
# --> array(10, dtype=int32)

docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb

+2-4
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@
247247
}
248248
],
249249
"source": [
250-
"import jax.numpy as jnp\n",
251250
"\n",
252251
"def log1pexp(x):\n",
253252
" return jnp.log(1. + jnp.exp(x))\n",
@@ -984,7 +983,7 @@
984983
" (a, x_star, x_star_bar),\n",
985984
" x_star_bar))\n",
986985
" return a_bar, jnp.zeros_like(x_star)\n",
987-
" \n",
986+
"\n",
988987
"def rev_iter(f, packed, u):\n",
989988
" a, x_star, x_star_bar = packed\n",
990989
" _, vjp_x = vjp(lambda x: f(a, x), x_star)\n",
@@ -1884,7 +1883,6 @@
18841883
}
18851884
],
18861885
"source": [
1887-
"from jax import vjp\n",
18881886
"\n",
18891887
"y, f_vjp = vjp(f, 3.)\n",
18901888
"print(y)"
@@ -1983,7 +1981,7 @@
19831981
" return x, x\n",
19841982
"\n",
19851983
"def debug_bwd(x, g):\n",
1986-
" import pdb; pdb.set_trace()\n",
1984+
" pdb.set_trace()\n",
19871985
" return g\n",
19881986
"\n",
19891987
"debug.defvjp(debug_fwd, debug_bwd)"

docs/notebooks/Custom_derivative_rules_for_Python_code.md

+2-4
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ Say we want to write a function called `log1pexp`, which computes $x \mapsto \lo
145145
:id: 6lWbTvs40ET-
146146
:outputId: 8caff99e-add1-4c70-ace3-212c0c5c6f4e
147147
148-
import jax.numpy as jnp
149148
150149
def log1pexp(x):
151150
return jnp.log(1. + jnp.exp(x))
@@ -524,7 +523,7 @@ def fixed_point_rev(f, res, x_star_bar):
524523
(a, x_star, x_star_bar),
525524
x_star_bar))
526525
return a_bar, jnp.zeros_like(x_star)
527-
526+
528527
def rev_iter(f, packed, u):
529528
a, x_star, x_star_bar = packed
530529
_, vjp_x = vjp(lambda x: f(a, x), x_star)
@@ -965,7 +964,6 @@ print(grad(f)(3.))
965964
:id: s1Pn_qCIODcF
966965
:outputId: 423d34e0-35b8-4b57-e89d-f70f20e28ea9
967966
968-
from jax import vjp
969967
970968
y, f_vjp = vjp(f, 3.)
971969
print(y)
@@ -1015,7 +1013,7 @@ def debug_fwd(x):
10151013
return x, x
10161014
10171015
def debug_bwd(x, g):
1018-
import pdb; pdb.set_trace()
1016+
pdb.set_trace()
10191017
return g
10201018
10211019
debug.defvjp(debug_fwd, debug_bwd)

docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb

-2
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@
3030
},
3131
"outputs": [],
3232
"source": [
33-
"import os\n",
3433
"\n",
35-
"import functools\n",
3634
"from typing import Optional\n",
3735
"\n",
3836
"import numpy as np\n",

docs/notebooks/Distributed_arrays_and_automatic_parallelization.md

-2
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ This tutorial discusses parallelism via `jax.Array`, the unified array object mo
2626
```{code-cell}
2727
:id: FNxScTfq3vGF
2828
29-
import os
3029
31-
import functools
3230
from typing import Optional
3331
3432
import numpy as np

0 commit comments

Comments
 (0)