Skip to content

Commit efbc9de

Browse files
committed
JAX Tutorials: Fixing RNG JAX/PyTorch clash
1 parent a30e19d commit efbc9de

File tree

10 files changed

+1344
-18
lines changed

10 files changed

+1344
-18
lines changed

docs/tutorial_notebooks/JAX/tutorial11/NF_image_modeling.ipynb

+2-3
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107
"import jax\n",
108108
"import jax.numpy as jnp\n",
109109
"from jax import random\n",
110+
"# Seeding for random operations\n",
111+
"main_rng = random.PRNGKey(42)\n",
110112
"\n",
111113
"## Flax (NN in JAX)\n",
112114
"try:\n",
@@ -136,9 +138,6 @@
136138
"# Path to the folder where the pretrained models are saved\n",
137139
"CHECKPOINT_PATH = \"../../saved_models/tutorial11_jax\"\n",
138140
"\n",
139-
"# Seeding for random operations\n",
140-
"main_rng = random.PRNGKey(42)\n",
141-
"\n",
142141
"print(\"Device:\", jax.devices()[0])"
143142
]
144143
},

docs/tutorial_notebooks/JAX/tutorial12/Autoregressive_Image_Modeling.ipynb

+2
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@
104104
"import jax\n",
105105
"import jax.numpy as jnp\n",
106106
"from jax import random\n",
107+
"# Seeding for random operations\n",
108+
"main_rng = random.PRNGKey(42)\n",
107109
"\n",
108110
"## Flax (NN in JAX)\n",
109111
"try:\n",

docs/tutorial_notebooks/JAX/tutorial13/GAN.ipynb

+1,323
Large diffs are not rendered by default.

docs/tutorial_notebooks/JAX/tutorial15/Vision_Transformer.ipynb

+2-3
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107
"import jax\n",
108108
"import jax.numpy as jnp\n",
109109
"from jax import random\n",
110+
"# Seeding for random operations\n",
111+
"main_rng = random.PRNGKey(42)\n",
110112
"\n",
111113
"## Flax (NN in JAX)\n",
112114
"try:\n",
@@ -140,9 +142,6 @@
140142
"# Path to the folder where the pretrained models are saved\n",
141143
"CHECKPOINT_PATH = \"../../saved_models/tutorial15_jax\"\n",
142144
"\n",
143-
"# Seeding for random operations\n",
144-
"main_rng = random.PRNGKey(42)\n",
145-
"\n",
146145
"print(\"Device:\", jax.devices()[0])"
147146
]
148147
},

docs/tutorial_notebooks/JAX/tutorial17/SimCLR.ipynb

+2
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@
104104
"import jax\n",
105105
"import jax.numpy as jnp\n",
106106
"from jax import random\n",
107+
"# Seeding for random operations\n",
108+
"main_rng = random.PRNGKey(42)\n",
107109
"\n",
108110
"## Flax (NN in JAX)\n",
109111
"try:\n",

docs/tutorial_notebooks/JAX/tutorial4/Optimization_and_Initialization.ipynb

+3-1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
"import jax.numpy as jnp\n",
8585
"from jax import random\n",
8686
"from jax.tree_util import tree_map\n",
87+
"# Seeding for random operations\n",
88+
"main_rng = random.PRNGKey(42)\n",
8789
"\n",
8890
"## Flax (NN in JAX)\n",
8991
"try:\n",
@@ -229,7 +231,7 @@
229231
"# We define a set of data loaders that we can use for various purposes later.\n",
230232
"# Note that for actually training a model, we will use different data loaders\n",
231233
"# with a lower batch size.\n",
232-
"train_loader = data.DataLoader(train_set,\n",
234+
"train_loader = data.DataLoader(train_set,r\n",
233235
" batch_size=1024,\n",
234236
" shuffle=False,\n",
235237
" drop_last=False,\n",

docs/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.ipynb

+2-3
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292
"import jax\n",
9393
"import jax.numpy as jnp\n",
9494
"from jax import random\n",
95+
"# Seeding for random operations\n",
96+
"main_rng = random.PRNGKey(42)\n",
9597
"\n",
9698
"## Flax (NN in JAX)\n",
9799
"try:\n",
@@ -144,9 +146,6 @@
144146
"# Path to the folder where the pretrained models are saved\n",
145147
"CHECKPOINT_PATH = \"../../saved_models/tutorial5_jax\"\n",
146148
"\n",
147-
"# Seeding for random operations\n",
148-
"main_rng = random.PRNGKey(42)\n",
149-
"\n",
150149
"print(\"Device:\", jax.devices()[0])"
151150
]
152151
},

docs/tutorial_notebooks/JAX/tutorial6/Transformers_and_MHAttention.ipynb

+2-3
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@
101101
"import jax\n",
102102
"import jax.numpy as jnp\n",
103103
"from jax import random\n",
104+
"# Seeding for random operations\n",
105+
"main_rng = random.PRNGKey(42)\n",
104106
"\n",
105107
"## Flax (NN in JAX)\n",
106108
"try:\n",
@@ -131,9 +133,6 @@
131133
"# Path to the folder where the pretrained models are saved\n",
132134
"CHECKPOINT_PATH = \"../../saved_models/tutorial6_jax\"\n",
133135
"\n",
134-
"# Seeding for random operations\n",
135-
"main_rng = random.PRNGKey(42)\n",
136-
"\n",
137136
"print(\"Device:\", jax.devices()[0])"
138137
]
139138
},

docs/tutorial_notebooks/JAX/tutorial7/GNN_overview.ipynb

+4-5
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@
9090
"import jax\n",
9191
"import jax.numpy as jnp\n",
9292
"from jax import random\n",
93+
"# Seeding for random operations\n",
94+
"main_rng = random.PRNGKey(42)\n",
9395
"\n",
9496
"## Flax (NN in JAX)\n",
9597
"try:\n",
@@ -119,9 +121,6 @@
119121
"# Path to the folder where the pretrained models are saved\n",
120122
"CHECKPOINT_PATH = \"../../saved_models/tutorial7_jax\"\n",
121123
"\n",
122-
"# Seeding for random operations\n",
123-
"main_rng = random.PRNGKey(42)\n",
124-
"\n",
125124
"print(\"Device:\", jax.devices()[0])"
126125
]
127126
},
@@ -535,7 +534,7 @@
535534
],
536535
"metadata": {
537536
"kernelspec": {
538-
"display_name": "Python 3 (ipykernel)",
537+
"display_name": "Python 3",
539538
"language": "python",
540539
"name": "python3"
541540
},
@@ -549,7 +548,7 @@
549548
"name": "python",
550549
"nbconvert_exporter": "python",
551550
"pygments_lexer": "ipython3",
552-
"version": "3.8.2"
551+
"version": "3.7.9"
553552
}
554553
},
555554
"nbformat": 4,

docs/tutorial_notebooks/JAX/tutorial9/AE_CIFAR10.ipynb

+2
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@
100100
"import jax\n",
101101
"import jax.numpy as jnp\n",
102102
"from jax import random\n",
103+
"# Seeding for random operations\n",
104+
"main_rng = random.PRNGKey(42)\n",
103105
"\n",
104106
"## Flax (NN in JAX)\n",
105107
"try:\n",

0 commit comments

Comments
 (0)