@@ -130,7 +130,6 @@ It is not recommended to use iterators in any JAX function you want to `jit` or
130
130
:outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903
131
131
132
132
import jax.numpy as jnp
133
- import jax.lax as lax
134
133
from jax import make_jaxpr
135
134
136
135
# lax.fori_loop
@@ -471,7 +470,6 @@ The random state is described by a special array element that we call a __key__:
471
470
:id: yPHE7KTWgAWs
472
471
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
473
472
474
- from jax import random
475
473
key = random.key(0)
476
474
key
477
475
```
@@ -504,8 +502,8 @@ Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a ne
504
502
print("old key", key)
505
503
key, subkey = random.split(key)
506
504
normal_pseudorandom = random.normal(subkey, shape=(1,))
507
- print(" \---SPLIT --> new key ", key)
508
- print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
505
+ print(r " \---SPLIT --> new key ", key)
506
+ print(r " \--> new subkey", subkey, "--> normal", normal_pseudorandom)
509
507
```
510
508
511
509
+++ {"id": "tqtFVE4MthO3"}
@@ -519,8 +517,8 @@ We propagate the __key__ and make new __subkeys__ whenever we need a new random
519
517
print("old key", key)
520
518
key, subkey = random.split(key)
521
519
normal_pseudorandom = random.normal(subkey, shape=(1,))
522
- print(" \---SPLIT --> new key ", key)
523
- print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
520
+ print(r " \---SPLIT --> new key ", key)
521
+ print(r " \--> new subkey", subkey, "--> normal", normal_pseudorandom)
524
522
```
525
523
526
524
+++ {"id": "0KLYUluz3lN3"}
@@ -805,7 +803,7 @@ def while_loop(cond_fun, body_fun, init_val):
805
803
:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
806
804
807
805
init_val = 0
808
- cond_fun = lambda x: x< 10
806
+ cond_fun = lambda x: x < 10
809
807
body_fun = lambda x: x+1
810
808
lax.while_loop(cond_fun, body_fun, init_val)
811
809
# --> array(10, dtype=int32)
0 commit comments