forked from AI-Hypercomputer/jetstream-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjax_test.py
294 lines (247 loc) · 9.06 KB
/
jax_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import jax
import jax.numpy as jnp
import functools
def test1():
    """Demonstrate jit retracing on static args and tracing of Python mutation.

    First jits an add-or-subtract function whose third argument (``issum``) is
    static, calls it with ``issum=True`` and then ``issum=False``, and prints
    each result followed by the jit cache size. Then jits a function that
    mutates a plain Python object holding a traced value, and prints its
    result three times.
    """

    def f(x, i, issum):
        # `issum` is static: each distinct value gets its own compiled entry.
        return x + i if issum else x - i

    add_or_sub = functools.partial(jax.jit, static_argnums=(2, ))(f)

    ones = jnp.ones((10, ))
    for offset, flag in ((0, True), (1, False)):
        print(add_or_sub(ones, offset, flag))
        print('cache', add_or_sub._cache_size())

    class Box:
        def __init__(self, a):
            self.a = a

        def incr(self):
            self.a += 1

    def bump(x):
        # The mutation happens on a Python object during tracing; only the
        # resulting traced value is returned.
        holder = Box(x)
        holder.incr()
        return holder.a

    bumped = jax.jit(bump)
    for _ in range(3):
        print(bumped(ones))
from jax.sharding import PositionalSharding
from jax.experimental import mesh_utils
def test2():
    """Benchmark three ways of writing one sequence position into KV caches.

    Builds (batch=96, heads=40, seq=2048, dim=128) bfloat16 key and value
    caches, sharded 8 ways along the head axis. Prints the lowered module
    text of each insertion strategy. Then, for each strategy, times 10 rounds
    of 40 jitted insertions and prints each round's total wall time.
    Requires 8 devices.
    """
    batch, seq, heads, dim = 96, 2048, 40, 128
    sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
    sharding = sharding.reshape((1, 8, 1, 1))
    val_sharding = sharding.reshape((1, 8, 1, 1))
    caches_k = jnp.zeros(
        (batch, heads, seq, dim), device=sharding, dtype=jnp.bfloat16)
    caches_v = jnp.zeros(
        (batch, heads, seq, dim), device=sharding, dtype=jnp.bfloat16)

    def insert_cache(caches_k, caches_v, pos, key, val):
        """Write key/val at sequence position `pos` with `.at[].set`."""
        # key/val are (batch, heads, 1, dim). Drop the length-1 seq axis so
        # they match the (batch, heads, dim) slice selected by `pos`.
        # An equivalent slice form is `.at[:, :, pos:pos+1, :].set(key)`.
        return (caches_k.at[:, :, pos, :].set(key.squeeze(2)),
                caches_v.at[:, :, pos, :].set(val.squeeze(2)))

    def insert_cache2(caches_k, caches_v, pos, key, val):
        """Write key/val at `pos` with a full-cache masked select."""
        seqlen = caches_k.shape[2]
        key = jnp.broadcast_to(key, caches_k.shape)
        val = jnp.broadcast_to(val, caches_v.shape)
        iota = jnp.arange(0, seqlen).reshape(1, 1, seqlen, 1)
        iota = jnp.broadcast_to(iota, caches_k.shape)
        at_pos = iota == pos
        # Take the new entry where the position matches; keep the cache
        # everywhere else.
        return (jnp.where(at_pos, key, caches_k),
                jnp.where(at_pos, val, caches_v))

    def insert_cache3(caches_k, caches_v, pos, key, val):
        """Write key/val at `pos` with `lax.dynamic_update_slice`."""
        return (
            jax.lax.dynamic_update_slice(caches_k, key, (0, 0, pos, 0)),
            jax.lax.dynamic_update_slice(caches_v, val, (0, 0, pos, 0)),
        )

    # Donating both caches lets XLA reuse their buffers for the outputs.
    insert_cache = jax.jit(insert_cache, donate_argnums=(0, 1))
    insert_cache2 = jax.jit(insert_cache2, donate_argnums=(0, 1))
    insert_cache3 = jax.jit(insert_cache3, donate_argnums=(0, 1))

    to_insert = jax.device_put(
        jax.random.normal(
            jax.random.PRNGKey(234), (batch, heads, 1, dim),
            dtype=jnp.bfloat16),
        device=val_sharding).block_until_ready()
    pos = jnp.int32(7).block_until_ready()
    print('====1====')
    print(insert_cache.lower(
        caches_k, caches_v, pos, to_insert, to_insert).as_text())
    print('====2====')
    print(insert_cache2.lower(
        caches_k, caches_v, pos, to_insert, to_insert).as_text())
    print('====3====')
    print(insert_cache3.lower(
        caches_k, caches_v, pos, to_insert, to_insert).as_text())

    rng = jax.random.PRNGKey(0)
    for func in (insert_cache, insert_cache2, insert_cache3):
        for _ in range(10):
            all_times = 0
            for step in range(40):
                # Separate subkeys so the new key and value entries differ.
                rng, k_rng, v_rng = jax.random.split(rng, 3)
                new_k = jax.device_put(
                    jax.random.normal(
                        k_rng, (batch, heads, 1, dim), dtype=jnp.bfloat16),
                    device=val_sharding).block_until_ready()
                new_v = jax.device_put(
                    jax.random.normal(
                        v_rng, (batch, heads, 1, dim), dtype=jnp.bfloat16),
                    device=val_sharding).block_until_ready()
                pos = jnp.int32(step).block_until_ready()
                start = time.perf_counter()
                caches_k, caches_v = func(caches_k, caches_v, pos, new_k, new_v)
                caches_k.block_until_ready()
                caches_v.block_until_ready()
                end = time.perf_counter()
                all_times += (end - start)
            print(func.__name__, 'time is', all_times)
def test3():
    """Print the lowered module for the same einsum written two ways.

    The first is a torch einsum run through torch_xla2's jax view. The second
    is the equivalent `jnp.einsum`. Both are lowered on (10, 10, 10) inputs
    of ones.
    """
    import torch
    import torch_xla2
    import torch_xla2.extra

    lhs = jnp.ones((10, 10, 10))
    rhs = jnp.ones((10, 10, 10))

    # Function and parameter names are left as-is: they surface in the
    # printed lowered module.
    def f(x, y):
        return torch.einsum("ijm, ijn -> imn", [x, y])

    def g(x, y):
        return jnp.einsum("ijm, ijn -> imn", x, y)

    print('====== 1 ======')
    with torch_xla2.tensor.XLAFunctionMode():
        torch_lowered = jax.jit(torch_xla2.extra.jax_view(f)).lower(lhs, rhs)
        print(torch_lowered.as_text())
    print('====== 2 ======')
    jax_lowered = jax.jit(g).lower(lhs, rhs)
    print(jax_lowered.as_text())
from flax import struct
class A:
    """Minimal mutable holder with a single field, registered as a pytree."""

    def __init__(self, a):
        # The single child value exposed by the pytree flatten rule.
        self.a = a

    def plus(self):
        """Rebind ``a`` to ``a + 1`` (a new value, not an in-place update)."""
        bumped = self.a + 1
        self.a = bumped
def flatten_A(x):
    """Pytree flatten rule for ``A``: one child (``x.a``) and no aux data."""
    children = (x.a, )
    aux_data = None
    return children, aux_data
def unflatten_A(aux_data, flat_content):
    """Pytree unflatten rule for ``A``: rebuild an ``A`` from its children.

    Args:
      aux_data: Unused; ``flatten_A`` always emits ``None`` here.
      flat_content: The children tuple, i.e. ``(a,)``.

    Returns:
      A new ``A`` instance wrapping the single child.
    """
    return A(*flat_content)
# Register A with jax's pytree machinery via flatten_A/unflatten_A, so A
# instances can be passed to and returned from jitted functions.
jax.tree_util.register_pytree_node(A, flatten_A, unflatten_A)
import functools
def f(a):
    """Call ``a.plus()`` under jit and return the resulting pytree.

    ``a`` is expected to be a registered pytree exposing a ``plus`` method
    (e.g. ``A``). Argument 0 is donated, so the caller should not read the
    input's arrays after the call.
    """
    a.plus()
    return a


# Equivalent to decorating with @functools.partial(jax.jit, donate_argnums=(0, )).
f = functools.partial(jax.jit, donate_argnums=(0, ))(f)
def test4():
    """Pass a zero-initialised ``A`` through the jitted ``f`` and print ``.a``."""
    result = f(A(a=jnp.zeros((2, ))))
    print(result.a)
def test5():
    """Benchmark three ways of writing a multi-token chunk into one cache slot.

    Builds a (batch=96, heads=40, seq=2048, dim=128) bfloat16 cache, sharded
    8 ways along the head axis. Prints the lowered module text of the
    `lax.scatter` variant. Then, for each variant, times 10 rounds of 40
    jitted writes of a (1, heads, 1024, dim) chunk and prints each round's
    total wall time. Requires 8 devices.
    """
    batch, seq, heads, dim = 96, 2048, 40, 128
    sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
    sharding = sharding.reshape((1, 8, 1, 1))
    val_sharding = sharding.reshape((1, 8, 1, 1))
    caches_k = jnp.zeros(
        (batch, heads, seq, dim), device=sharding, dtype=jnp.bfloat16)

    def insert_cache(
        cache,
        new_entry,
        slot,
        head_indexes,
        update_indexes
    ):
        """Write via broadcast advanced indexing over (slot, head, position)."""
        # head_indexes (1, H, 1) and update_indexes reshaped to (1, L)
        # broadcast to (1, H, L), matching new_entry's (1, H, L, dim).
        res = cache.at[
            slot, head_indexes, update_indexes.reshape(1, -1), :].set(new_entry)
        res = jax.lax.with_sharding_constraint(res, sharding)
        return res

    def insert_cache2(
        cache,
        new_entry,
        slot,
        head_indexes,
        update_indexes
    ):
        """Write via `.at[slot, :, positions, :]`; `head_indexes` is unused."""
        # Array-like indices separated by a slice move the indexed axes to
        # the front, so the target view is (L, heads, dim). Transpose to match.
        res = cache.at[slot, :, update_indexes, :].set(
            jnp.transpose(new_entry.squeeze(0), (1, 0, 2)))
        res = jax.lax.with_sharding_constraint(res, sharding)
        return res

    def insert_cache3(
        cache,
        new_entry,
        slot,
        head_indexes,
        update_indexes
    ):
        """Write via a hand-built `lax.scatter`; `head_indexes` is unused."""
        # One (slot, position) index vector per token, shape (L, 2).
        slots = jnp.expand_dims(jnp.full_like(update_indexes, slot), -1)
        positions = jnp.expand_dims(update_indexes, -1)
        combined = jnp.concatenate([slots, positions], axis=-1)
        dimension_numbers = jax.lax.ScatterDimensionNumbers(
            # updates is (heads, L, dim). Heads and dim are copied as window
            # dims; the L axis enumerates the index vectors.
            update_window_dims=(0, 2),
            # Operand axes 0 (batch) and 2 (seq) are addressed point-wise.
            inserted_window_dims=(0, 2),
            # Index component 0 (slot) -> operand axis 0, and component 1
            # (position) -> operand axis 2.
            scatter_dims_to_operand_dims=(0, 2),
        )
        res = jax.lax.scatter(
            cache, combined, new_entry.squeeze(0), dimension_numbers,
            unique_indices=True,
            # Positions wrap around (7..1023 then 0..6), so they are NOT
            # sorted. Promising otherwise gives implementation-defined results.
            indices_are_sorted=False,
            mode='promise_in_bounds')
        res = jax.lax.with_sharding_constraint(res, sharding)
        return res

    insert_cache = jax.jit(insert_cache, donate_argnums=(0, 1))
    insert_cache2 = jax.jit(insert_cache2, donate_argnums=(0, 1))
    insert_cache3 = jax.jit(insert_cache3, donate_argnums=(0, 1))

    insert_seqlen = 1024
    to_insert = jax.device_put(
        jax.random.normal(
            jax.random.PRNGKey(234), (1, heads, insert_seqlen, dim),
            dtype=jnp.bfloat16),
        device=val_sharding).block_until_ready()
    slot = jnp.int32(7).block_until_ready()
    # 1024 consecutive positions ending at 6 and wrapping from 1023 to 0.
    # NOTE(review): this wraps at 1024 rather than at the cache length `seq`
    # (2048). Confirm that is intended.
    update_indexes = (jnp.arange(-insert_seqlen, 0) + 7) % 1024
    head_indexes = jnp.arange(heads).reshape(1, -1, 1)

    for func in (insert_cache3, ):
        print(f'===={func.__name__}====')
        print(func.lower(
            caches_k, to_insert, slot, head_indexes, update_indexes).as_text())

    rng = jax.random.PRNGKey(0)
    for func in (insert_cache, insert_cache2, insert_cache3):
        for _ in range(10):
            all_times = 0
            for step in range(40):
                rng, subkey = jax.random.split(rng)
                new_entry = jax.device_put(
                    jax.random.normal(
                        subkey, (1, heads, insert_seqlen, dim),
                        dtype=jnp.bfloat16),
                    device=val_sharding).block_until_ready()
                slot = jnp.int32(step).block_until_ready()
                start = time.perf_counter()
                # Argument 1 is donated, so pass a fresh array each call rather
                # than reusing one that may have been invalidated.
                caches_k = func(
                    caches_k, new_entry, slot, head_indexes, update_indexes)
                caches_k.block_until_ready()
                end = time.perf_counter()
                all_times += (end - start)
            print(func.__name__, 'time is', all_times)
if __name__ == "__main__":
    # Run the 8-device benchmark only when executed as a script, not when the
    # module is imported (e.g. collected by a test runner as `jax_test.py`).
    test5()