forked from AI-Hypercomputer/jetstream-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathengine.py
1200 lines (1086 loc) · 38.7 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement Jet Engine API."""
from typing import Any, List, Optional, Tuple, Union, Callable
import threading
import functools
import os
import glob
from etils import epath
from flax import struct
import jax
from jax import numpy as jnp
from safetensors import safe_open
import torch
import numpy as np
from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
from jetstream.engine import sampling_utils
import torch_xla2
from torch.utils import _pytree as pytree
from jetstream_pt import cache_manager
from jetstream_pt import quantize
from jetstream_pt import torchjax
from jetstream_pt.hf_tokenizer import HFTokenizerAdapter
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig
from jetstream_pt.page_attention_manager import PageAttentionManager
from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model
from absl import flags
# Global absl flag registry.
FLAGS = flags.FLAGS

# Short aliases for the JAX sharding primitives.
Mesh, P = jax.sharding.Mesh, jax.sharding.PartitionSpec

# Type aliases used in Engine API signatures; both are plain jax.Array.
Params = PrefillInputs = jax.Array
@struct.dataclass
# pylint: disable-next=all
class Prefix:
    """Output of ``prefill``, consumed by ``insert``.

    Fields are pytree leaves (flax ``struct.dataclass``), so inside jitted
    code ``seq_len`` may be a traced integer rather than a Python int.
    """

    # First token sampled from the last real prompt position. Its exact shape
    # follows sampling_utils.sampling. NOTE(review): confirm the shape; insert
    # writes it into one [1]-wide row of DecodeState.tokens.
    token: jax.Array
    # Per-layer (k, v) prefill caches. Axis 1 is kv heads, axis 2 is the
    # *padded* sequence length and axis 3 is the head dim, as unpacked by
    # insert_page_attention_with_reservation and insert.
    caches: List[Tuple[jax.Array, jax.Array]]
    # True (unpadded) prompt length. Real tokens occupy positions
    # [0, seq_len) of the padded input, and padding follows them.
    seq_len: int
@struct.dataclass
# pylint: disable-next=all
class DecodeState:
    """Batched decoding state threaded through ``insert`` and ``generate``.

    Every field is a pytree leaf (flax ``struct.dataclass``), so integer
    fields may be traced arrays inside jitted code. Per-slot arrays are
    indexed by slot along axis 0.
    """

    # [batch_size, 1]: the token each slot feeds into the next decode step.
    # init_decode_state zero-fills it; generate replaces it with the sampled
    # token.
    tokens: jax.Array
    # (k, v) cache pairs: one pair per layer, or a single stacked pair when
    # env.generate_cache_stacked is set.
    caches: List[Tuple[jax.Array, jax.Array]]
    # (k_scale, v_scale) pairs, only present in quantized kv. This is an empty
    # list when kv quantization is off; the page-attention insert path stores
    # None.
    cache_scales: Optional[List[Tuple[jax.Array, jax.Array]]]
    # Global decode step, advanced modulo cache_sequence_length by each
    # generate call. It is the write index in ring-buffer mode.
    current_position: int
    # [batch_size, 1]: output token length per slot. It is 0 until insert
    # sets it to 1.
    lens: jax.Array
    # [batch_size]: cache index where each slot's sequence starts.
    start: jax.Array
    # [batch_size]: per-slot position counter. It is the next write index
    # when the ring buffer is disabled.
    input_pos: jax.Array
    # [batch_size, cache_sequence_length]: -inf for invalid, 0 for valid.
    mask: jax.Array
# NOTE model specific
# pylint: disable-next=all
class PyTorchEngine(engine_api.Engine):
"""Wraps functions to the Jet Engine API format."""
def __init__(
    self,
    pt_model: torch.nn.Module,
    env: JetEngineEnvironment,
    weights=None,
):
    """Set up shardings and jit-compiled entry points for ``pt_model``.

    Args:
      pt_model: PyTorch model, executed via ``torch.func.functional_call``
        inside the torch_xla2 environment.
      env: Engine configuration (shardings, cache layout, quantization, ...).
      weights: Optional pre-loaded parameters. When given, ``load_params``
        returns them unchanged.

    ``prefill``, ``insert`` and ``generate`` are rebound on the instance to
    jitted callables. With ``env.page_attention`` set, ``insert`` and
    ``generate`` are replaced by the page-attention wrappers.
    """
    self.pt_model = pt_model
    self.env = env
    self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
    # NOTE(review): a fixed key; no code in this class visibly advances it.
    self.rng = jax.random.PRNGKey(0)
    self.weights = weights
    self.y_sharding = env.sharding_by_axis(1)
    self.x_sharding = env.sharding_by_axis(0)
    self.replicated = env.sharding_by_axis(-1)  # replicated
    self.cache_sharding = self.env.cache_sharding
    # Process-wide JAX setting: it affects every JAX computation in the process.
    jax.config.update("jax_enable_x64", False)
    self.prefill_cache_sharding = self.env.prefill_cache_sharding

    # Jitted entry points. Donated arguments (prefix / decode state) let XLA
    # reuse their buffers for the outputs.
    self.prefill = jax.jit(
        self.prefill,
        out_shardings=(self.get_prefix_destination_sharding(), None),
    )
    self.insert = jax.jit(
        self.insert,
        donate_argnums=(0, 1),
        out_shardings=self.get_decode_state_sharding(),
    )
    self.generate = jax.jit(
        self.generate_impl,
        donate_argnums=(1,),
        out_shardings=(self.get_decode_state_sharding(), None),
    )

    if self.env.page_attention:
        cache_len = self.env._data.cache_sequence_length
        page_size = self.env._data.paged_attention_page_size
        # Validate before dividing, so a misconfiguration does not silently
        # truncate the page count.
        assert cache_len % page_size == 0, (
            f"cache_sequence_length {cache_len} should be divisible by "
            f"paged_attention_page_size {page_size}"
        )
        max_pages_per_sequence = cache_len // page_size
        self.page_attention_manager = PageAttentionManager(
            batch_size=self.env.batch_size,
            paged_attention_total_num_pages=self.env._data.paged_attention_total_num_pages,
            paged_attention_page_size=page_size,
            max_pages_per_sequence=max_pages_per_sequence,
        )
        self._insert_page_attention_jit = jax.jit(
            self._insert_page_attention,
            donate_argnums=(0, 1),
            out_shardings=self.get_decode_state_sharding(),
        )
        # Page reservation happens on the host, so the public insert/generate
        # are plain Python wrappers around the jitted kernels.
        self.insert = self.insert_page_attention_with_reservation
        self.generate_jit = jax.jit(
            self.generate_impl,
            donate_argnums=(1,),
            out_shardings=(self.get_decode_state_sharding(), None),
        )
        self.generate = self.generate_page_attention

    # Serializes the model invocations in _call_model_generate and
    # _call_model_prefill.
    self._lock = threading.RLock()
def init_decode_state(
    self,
) -> DecodeState:
    """Build an empty DecodeState in which no slot is active yet.

    Per-slot counters are zero and the attention mask is -inf everywhere, so
    no cache position is attended until ``insert`` fills a slot. Cache scales
    are collected only when kv quantization is enabled (otherwise ``[]``).
    """
    cache_objs = self.env.make_caches_generate()
    kv_states = [obj.state() for obj in cache_objs]
    kv_scales = (
        [obj.scalers() for obj in cache_objs]
        if self.env.quant_config.enable_kv_quantization
        else []
    )
    batch = self.env.batch_size
    return DecodeState(
        tokens=jnp.zeros((batch, 1), dtype=jnp.int32),
        caches=kv_states,
        cache_scales=kv_scales,
        current_position=self.env.starting_position,
        lens=jnp.zeros((batch, 1), dtype=jnp.int32),
        start=jnp.zeros((batch,), dtype=jnp.int32),
        input_pos=jnp.zeros((batch,), dtype=jnp.int32),
        mask=jnp.full(
            (batch, self.env.cache_sequence_length),
            float("-inf"),
            dtype=self.default_dtype,
        ),
    )
# pylint: disable-next=all
def _call_model_generate(
    self,
    weights,
    tokens,
    input_indexes,
    caches,
    cache_scales,
    mask,
    start,
    input_pos,
    ragged_batch_index,
    ragged_block_index,
    page_token_indices,
):
    """Run one decode step of ``self.pt_model`` on JAX inputs.

    The raw cache arrays are wrapped in the cache_manager class matching the
    configuration: int8-quantized, paged, or dense. The model is then called
    through ``torch.func.functional_call`` inside the torch_xla2 environment.
    Each cache object is finalized before its state is read back.

    Returns:
      ``torchjax.from_torch((model_output, updated_caches, scales))``, where
      ``scales`` is an empty list unless kv quantization is enabled.
    """
    kv_quantized = self.env.quant_config.enable_kv_quantization
    if kv_quantized:
        torch_pairs = torchjax.to_torch(list(zip(caches, cache_scales)))
        cache_objs = [
            cache_manager.Int8KVCacheGenerate(
                k, v, ks, vs, input_indexes, env=self.env
            )
            for (k, v), (ks, vs) in torch_pairs
        ]
    elif self.env.page_attention:
        cache_objs = [
            cache_manager.PageKVCacheGenerate(
                k,
                v,
                self.page_attention_manager,
                page_token_indices,
                self.cache_sharding,
                env=self.env,
            )
            for k, v in torchjax.to_torch(caches)
        ]
    else:
        cache_objs = [
            cache_manager.KVCacheGenerate(
                k, v, input_indexes, self.cache_sharding, env=self.env
            )
            for k, v in torchjax.to_torch(caches)
        ]

    # Insert singleton axes 1 and 2 so the per-slot mask broadcasts against
    # the attention scores (assumed [batch, heads, q, kv] -- TODO confirm).
    expanded_mask = jnp.expand_dims(mask, (1, 2))
    model_inputs = (
        tokens,
        input_pos,
        cache_objs,
        expanded_mask,
        start,
        ragged_batch_index,
        ragged_block_index,
    )
    torch_params, torch_inputs = torchjax.to_torch((weights, model_inputs))

    with self._lock, torch_xla2.default_env():
        # The mode is needed so that tensors created inside of
        # the model (such as via torch.ones etc) also have the right type
        output = torch.func.functional_call(
            self.pt_model, torch_params, torch_inputs
        )
        new_states = []
        for cache in cache_objs:
            cache.finalize()
            new_states.append(cache.state())
        new_scales = [c.scalers() for c in cache_objs] if kv_quantized else []
    return torchjax.from_torch((output, new_states, new_scales))
@functools.partial(
    jax.jit,
    static_argnums=(0,),
)
def _call_model_prefill(self, weights, tokens, input_indexes):
    """Jitted prefill forward pass over ``tokens``.

    A fresh ``KVCachePrefill`` is created per model layer. A causal mask is
    built, with -inf strictly above the diagonal and 0 elsewhere, and the
    model runs via ``torch.func.functional_call``.

    Returns:
      ``(first element of the model output, per-layer cache states)`` as
      JAX values.
    """
    kv_quant = self.env.quant_config.enable_kv_quantization
    prefill_caches = [
        cache_manager.KVCachePrefill(kv_quant) for _ in self.pt_model.layers
    ]
    seq = tokens.shape[1]
    causal = jnp.full(
        (1, self.env.n_reps, seq, seq),
        float("-inf"),
        dtype=self.default_dtype,
    )
    # triu with k=1 keeps -inf only above the diagonal (future positions).
    causal = jnp.triu(causal, k=1).reshape(1, 1, -1, seq)
    start = jnp.zeros((tokens.shape[0],), dtype=jnp.int32)
    model_inputs = (tokens, input_indexes, prefill_caches, causal, start)
    torch_params, torch_inputs = torchjax.to_torch((weights, model_inputs))
    with self._lock, torch_xla2.default_env():
        output = torch.func.functional_call(
            self.pt_model, torch_params, torch_inputs
        )[0]
    cache_states = [c.state() for c in prefill_caches]
    return torchjax.from_torch((output, cache_states))
def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
    """Sample one token per row from the last position of ``logits``.

    A rank-2 ``logits`` gets a leading axis first. Sampling uses the
    environment's algorithm/topk/nucleus_topp/temperature with ``self.rng``;
    this method does not advance ``self.rng``. The result is reshaped to
    ``(batch_size, -1)`` int32.
    """
    batched = jnp.expand_dims(logits, 0) if len(logits.shape) == 2 else logits
    sampled = sampling_utils.sampling(
        batched[:, -1],
        self.rng,
        self.env.sampling_algorithm,
        self.env.topk,
        self.env.nucleus_topp,
        self.env.temperature,
    )
    return sampled.reshape(batch_size, -1).astype(jnp.int32)
def prefill(
    self,
    *,
    params: Any,  # Weights
    existing_prefix: Optional[Prefix] = None,
    padded_tokens: PrefillInputs,  # PrefillInputs[jax.Array],
    true_length: int,
    sampler: Optional[Callable[[Any], Any]] = None,
) -> Tuple[Prefix, engine_api.ResultTokens]:
    """Run the prompt through the model and sample the first output token.

    Args:
      params: Model weights.
      existing_prefix: Unused; accepted for Engine API compatibility.
      padded_tokens: Prompt token ids, padded after the real tokens.
        Flattened to ``[1, seq_len]``.
      true_length: Number of real prompt tokens. The first token is sampled
        from the logits at position ``true_length - 1``.
      sampler: Optional callable mapping that logits row to a token. Defaults
        to ``sampling_utils.sampling`` with the environment settings.

    Returns:
      ``(Prefix(token, caches, true_length), ResultTokens)``. The result data
      is ``[token, 1 (valid), 0 (length)]``. Caches keep the padded length.

    Raises:
      TypeError: if ``padded_tokens`` is not a ``jax.Array``.
    """
    if not isinstance(padded_tokens, jax.Array):
        raise TypeError(
            "Input tokens should be of type Jax Array, but receiving:"
            f" {type(padded_tokens)}"
        )
    batched_token = padded_tokens.reshape(1, -1)
    # Take the length from the flattened [1, seq_len] view. This also handles
    # input that is already batched as [1, N].
    seq_len = batched_token.shape[1]
    input_indexes = jnp.arange(0, seq_len)
    logits, updated_caches = self._call_model_prefill(
        params,
        batched_token,
        input_indexes,
    )
    if len(logits.shape) == 3:  # b, seqlen, num words
        logits = logits[0]  # seqlen, num words
    last_logits = logits[true_length - 1]
    if sampler:
        token = sampler(last_logits)
    else:
        token = sampling_utils.sampling(
            last_logits,
            self.rng,
            self.env.sampling_algorithm,
            self.env.topk,
            self.env.nucleus_topp,
            self.env.temperature,
        )
    token_out = jnp.reshape(token, (1, 1))
    data = jnp.concatenate(
        [
            token_out,  # First token
            jnp.ones_like(token_out),  # validity of first token
            jnp.zeros((1, 1), dtype=jnp.int32),  # length = 0
        ],
        axis=-1,
    )
    length = token_out.shape[1]
    result = engine_api.ResultTokens(
        data=data,
        tokens_idx=(0, length),
        valid_idx=(length, 2 * length),
        length_idx=(2 * length, 2 * length + 1),
        samples_per_slot=1,
    )
    # Caches are returned at the padded length. Truncating them to
    # true_length did not work inside jit and would have to happen outside.
    return Prefix(token, updated_caches, true_length), result
def shrink_prefix(
    self,
    prefix: Prefix,
    new_length: int,  # pylint: disable=unused-argument
) -> Prefix:
    """shrink prefix: a no-op that returns ``prefix`` unchanged.

    ``new_length`` is ignored.
    """
    unchanged = prefix
    return unchanged
# pylint: disable-next=all
def _insert_no_wrap(
    self,
    prefix: Prefix,
    decode_state: DecodeState,
    slot: int,
):
    """Copy a prefill result into ``slot`` as one contiguous cache slice.

    The write starts at cache index ``pos = current_pos - prefix.seq_len``.
    - Ring-buffer mode: ``current_pos`` is ``decode_state.current_position``,
      so the prompt ends at the current global position.
    - Otherwise: ``current_pos`` is ``prefix.seq_len``, so ``pos`` is 0 and
      the prompt is left-aligned.

    The slot's mask is set to 0 for indices in ``[pos, current_pos)`` and to
    -inf elsewhere. The full padded prefix cache is written, but only the
    first ``seq_len`` entries are unmasked. With int8 kv quantization the new
    entries are quantized first and their scales are written alongside.

    Returns:
      A new DecodeState. ``current_position`` is carried over; ``scales`` is
      ``[]`` when kv quantization is off.
    """
    scales = []
    caches = []
    if self.env.ring_buffer:
        current_pos = decode_state.current_position
    else:
        current_pos = prefix.seq_len
    pos = current_pos - prefix.seq_len
    tokens = decode_state.tokens.at[slot].set(prefix.token)
    x = jnp.arange(0, self.env.cache_sequence_length)
    cond = jnp.logical_and(x < current_pos, x >= pos)
    mask_insert = jnp.where(cond, 0, float("-inf"))
    mask = decode_state.mask.at[slot].set(mask_insert)
    start = decode_state.start.at[slot].set(
        pos % self.env.cache_sequence_length
    )
    input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len)
    if not self.env.quant_config.enable_kv_quantization:
        # Writes new_entry into cache at update_index. The old cache buffer is
        # donated. dynamic_update_slice clamps start indices so the slice fits;
        # insert() only routes ring-buffer inserts here when no wrap is needed.
        @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
        def insert(cache, new_entry, update_index):
            res = jax.lax.dynamic_update_slice(
                cache,
                new_entry,
                update_index,
            )
            res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
            return res
        if self.env.generate_cache_stacked:
            # A single stacked (k, v) pair holds all layers along a new leading
            # axis. Each layer's prefill cache gets that axis added and is
            # written at layer index idx, threading the updated pair through.
            caches = decode_state.caches
            for idx, (newk, newv) in enumerate(prefix.caches):
                update_index = [idx, slot, 0, pos, 0]
                newk = jnp.expand_dims(newk, 0)
                newv = jnp.expand_dims(newv, 0)
                caches = [
                    (
                        insert(caches[0][0], newk, update_index),
                        insert(caches[0][1], newv, update_index),
                    )
                ]
        else:
            # One (k, v) pair per layer; all layers share the same start index.
            update_index = [slot, 0, pos, 0]
            caches = [
                (insert(k, newk, update_index), insert(v, newv, update_index))
                for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches)
            ]
    else:
        # Quantizes new_entry (reducing over axes -3 and -1), then writes both
        # the quantized values and their scales at update_index. The local
        # `scales` shadows the outer list only inside this helper.
        @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
        def insert(cache, scaler, new_entry, update_index):
            reduce_axis = (-3, -1)
            vals, scales, _ = torchjax.call_torch(
                quantize.quantize_tensor, new_entry, reduce_axis
            )
            if self.env.generate_cache_stacked:
                # Add the leading layer axis used by the stacked cache layout.
                vals = jnp.expand_dims(vals, 0)
                scales = jnp.expand_dims(scales, 0)
            new_scaler = jax.lax.dynamic_update_slice(
                scaler,
                scales,
                update_index,
            )
            new_scaler = jax.lax.with_sharding_constraint(
                new_scaler, self.replicated
            )
            res = jax.lax.dynamic_update_slice(
                cache,
                vals,
                update_index,
            )
            res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
            return res, new_scaler
        if self.env.generate_cache_stacked:
            cache_k, k_scale = (
                decode_state.caches[0][0],
                decode_state.cache_scales[0][0],
            )
            cache_v, v_scale = (
                decode_state.caches[0][1],
                decode_state.cache_scales[0][1],
            )
            for idx, (newk, newv) in enumerate(prefix.caches):
                update_index = [idx, slot, 0, pos, 0]
                # No expand_dims here: insert() adds the layer axis itself
                # after quantizing.
                # newk = jnp.expand_dims(newk, 0)
                # newv = jnp.expand_dims(newv, 0)
                cache_k, k_scale = insert(cache_k, k_scale, newk, update_index)
                cache_v, v_scale = insert(cache_v, v_scale, newv, update_index)
            caches = [(cache_k, cache_v)]
            scales = [(k_scale, v_scale)]
        else:
            update_index = [slot, 0, pos, 0]
            for (k, v), (kscaler, vscaler), (newk, newv) in zip(
                decode_state.caches, decode_state.cache_scales, prefix.caches
            ):
                kcache, kscale = insert(k, kscaler, newk, update_index)
                vcache, vscale = insert(v, vscaler, newv, update_index)
                caches.append((kcache, vcache))
                scales.append((kscale, vscale))
    # A freshly inserted slot has produced one token (the prefill token).
    lens = decode_state.lens.at[slot].set(1)
    return DecodeState(
        tokens,
        caches,
        scales,
        decode_state.current_position,
        lens,
        start,
        input_pos,
        mask,
    )
# pylint: disable-next=all
def _insert_wrap(
    self,
    prefix: Prefix,
    decode_state: DecodeState,
    slot: int,
):  # returns Decode State
    """Insert a prefill result into ``slot`` when it wraps around the ring buffer.

    The prompt is placed to end at ``decode_state.current_position``. Its
    start index is taken modulo ``cache_sequence_length``, and each padded
    cache position is scattered to ``(start_insert + i) % cache_sequence_length``.
    The mask is opened over ``[start_insert, current_position)``, and that
    interval wraps past the end of the cache when
    ``current_position <= start_insert``.

    NOTE(review): unlike _insert_no_wrap, this path has no
    ``generate_cache_stacked`` branch. Confirm stacked caches never reach it.

    Returns:
      A new DecodeState. ``scales`` is ``[]`` when kv quantization is off.
    """
    start_insert = decode_state.current_position - prefix.seq_len
    tokens = decode_state.tokens.at[slot].set(prefix.token)
    start_insert = start_insert % self.env.cache_sequence_length
    # The raw start (pos) may be < 0 before the modulo above; cache indexes
    # for every padded prefix position wrap modulo the cache length.
    update_indexes = (
        jnp.arange(0, prefix.caches[0][0].shape[2]) + start_insert
    ) % self.env.cache_sequence_length
    update_indexes = update_indexes.reshape(1, -1)
    x = jnp.arange(0, self.env.cache_sequence_length)
    # Contiguous interval [start, current) vs. wrapped interval
    # [start, end) U [0, current).
    cond = jax.lax.cond(
        decode_state.current_position > start_insert,
        lambda x, start_insert, current_position: jnp.logical_and(
            x >= start_insert, x < current_position
        ),
        lambda x, start_insert, current_position: jnp.logical_or(
            x >= start_insert, x < current_position
        ),
        x,
        start_insert,
        decode_state.current_position,
    )
    mask_insert = jnp.where(cond, 0, float("-inf"))
    mask = decode_state.mask.at[slot].set(mask_insert)
    start = decode_state.start.at[slot].set(start_insert)
    input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len)
    old_caches = decode_state.caches
    old_scales = decode_state.cache_scales
    cache_inserts = prefix.caches
    scales = []
    caches = []
    if not self.env.quant_config.enable_kv_quantization:
        # Drops the leading size-1 axis of the prefill entry and moves the
        # sequence axis first, so it lines up with the advanced index
        # [slot, :, update_indexes, :]. This assumes the prefill layout
        # [1, kv_heads, seq, dim] -- TODO confirm. It then scatters the
        # result into the slot's rows.
        @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
        def insert(cache, new_entry):
            new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
            res = cache.at[slot, :, update_indexes, :].set(new_entry)
            res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
            return res
        caches = [
            (insert(k, newk), insert(v, newv))
            for (k, v), (newk, newv) in zip(old_caches, cache_inserts)
        ]
    else:
        # Same layout change as above. After the transpose it quantizes,
        # reducing over axes 1 and 2, then scatters both the values and the
        # scales. The local `scales` shadows the outer list only inside this
        # helper.
        @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
        def insert(cache, scaler, new_entry):
            new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
            reduce_axis = (1, 2)
            vals, scales, _ = torchjax.call_torch(
                quantize.quantize_tensor, new_entry, reduce_axis
            )
            new_scaler = scaler.at[slot, :, update_indexes, :].set(scales)
            new_scaler = jax.lax.with_sharding_constraint(
                new_scaler, self.replicated
            )
            res = cache.at[slot, :, update_indexes, :].set(vals)
            res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
            return res, new_scaler
        caches = []
        for (k, v), (kscaler, vscaler), (newk, newv) in zip(
            old_caches, old_scales, cache_inserts
        ):
            kcache, kscale = insert(k, kscaler, newk)
            vcache, vscale = insert(v, vscaler, newv)
            caches.append((kcache, vcache))
            scales.append((kscale, vscale))
    lens = decode_state.lens.at[slot].set(1)
    return DecodeState(
        tokens,
        caches,
        scales,
        decode_state.current_position,
        lens,
        start,
        input_pos,
        mask,
    )
def _insert_page_attention(
    self,
    prefix: Prefix,
    decode_state: DecodeState,
    slot: int,
    num_pages: int,
    update_indexes: jax.Array,
    tep_kv: jax.Array,
):
    """Insert a prefill result into ``slot`` using the paged kv cache.

    Pages are written by ``page_attention_manager.insert_prefill_cache``. The
    remaining bookkeeping mirrors the left-aligned path:
    - the slot starts at cache index 0;
    - its mask is open for the first ``prefix.seq_len`` positions;
    - ``input_pos`` is set to ``seq_len`` and ``lens`` to 1.

    Cache scales are not tracked (``None``). ``num_pages`` is accepted for
    signature compatibility but is not used here.
    """
    paged_caches = self.page_attention_manager.insert_prefill_cache(
        prefill_caches=prefix.caches,
        decode_caches=decode_state.caches,
        update_indexes=update_indexes,
        tep_kv=tep_kv,
        sharding=self.cache_sharding,
    )
    # Left-aligned: the sequence occupies cache indices [0, seq_len).
    positions = jnp.arange(0, self.env.cache_sequence_length)
    slot_mask = jnp.where(positions < prefix.seq_len, 0, float("-inf"))
    return DecodeState(
        decode_state.tokens.at[slot].set(prefix.token),
        paged_caches,
        None,
        decode_state.current_position,
        decode_state.lens.at[slot].set(1),
        decode_state.start.at[slot].set(0),
        decode_state.input_pos.at[slot].set(prefix.seq_len),
        decode_state.mask.at[slot].set(slot_mask),
    )
def insert(
    self,
    prefix: Prefix,
    decode_state: DecodeState,
    slot: int,
) -> DecodeState:
    """Insert a prefill result into ``slot`` of the decode state.

    Without a ring buffer the sequence is left-aligned at 0 and can never
    wrap, so ``_insert_no_wrap`` is used directly. With a ring buffer, the
    prompt ends at ``current_position``. The contiguous path is chosen when
    the padded prefix fits strictly before the cache end; otherwise the
    wrapping path is chosen.
    """
    if not self.env.ring_buffer:
        # Left aligned, starts from 0, guaranteed no wrap
        return self._insert_no_wrap(prefix, decode_state, slot)
    first = decode_state.current_position - prefix.seq_len
    padded_len = prefix.caches[0][0].shape[2]  # padded seqlen
    fits = jnp.logical_and(
        first >= 0, first + padded_len < self.env.cache_sequence_length
    )
    return jax.lax.cond(
        fits,
        self._insert_no_wrap,
        self._insert_wrap,
        prefix,
        decode_state,
        slot,
    )
def insert_page_attention_with_reservation(
    self,
    prefix: Prefix,
    decode_state: DecodeState,
    slot: int,
) -> DecodeState:
    """Reserve pages for ``slot`` on the host, then run the jitted insert.

    The page manager returns how many pages were reserved and the page update
    indexes. A zero-filled scratch buffer of shape
    ``(kv_heads, num_pages * page_size, head_dim)`` is allocated with the
    prefill cache sharding and passed through as ``tep_kv``.
    """
    manager = self.page_attention_manager
    num_pages, host_indexes = manager.reserve_pages_insert(slot, prefix.seq_len)
    device_indexes = jnp.array(host_indexes)
    _, kv_heads, _, head_dim = prefix.caches[0][0].shape
    scratch_len = num_pages * manager.paged_attention_page_size
    scratch = jnp.zeros(
        (kv_heads, scratch_len, head_dim),
        dtype=self.default_dtype,
        device=self.prefill_cache_sharding,
    )
    return self._insert_page_attention_jit(
        prefix, decode_state, slot, num_pages, device_indexes, scratch
    )
def precompute_ragged_block_indices(self, decode_state: DecodeState):
    """Precompute the ragged attention block indices. Ragged attention iterates the grid
    and relies on the computed grid index to skip the unnecessary blocks. The basic idea
    is to use input_pos, which is the length of each slot to determine if we should
    work on the next block of the slot or move to the next slot.

    For every (slot b, block i) grid cell this returns the (batch, block)
    pair whose data should actually be loaded. Each slot's valid cache
    region is ``[start, end)`` with ``end = (start + input_pos) % cache_len``.
    - If ``start < end``, the region is contiguous.
    - If ``start > end``, it wraps around the end of the cache.
    - If ``start == end``, the cell is redirected to the next slot.
      NOTE(review): that covers an empty slot and also, presumably, a
      completely full one.

    Returns:
      ``(b_next, i_next)``, each broadcast to ``[batch_size, num_blocks]``
      with ``num_blocks = cache_len // block_size``.
    """
    start = decode_state.start
    # NOTE(review): uses env.cache_len, while other methods use
    # env.cache_sequence_length -- confirm they are the same value.
    end = (start + decode_state.input_pos) % self.env.cache_len
    batch_size = start.shape[0]
    bk = self.env.block_size
    # The batch index
    b = jnp.arange(batch_size).reshape((batch_size, 1))
    num_bk = self.env.cache_len // self.env.block_size
    # The block index
    i = jnp.arange(num_bk).reshape((1, num_bk))
    i = jnp.broadcast_to(i, (batch_size, num_bk))
    start = start.reshape((batch_size, 1))
    end = end.reshape((batch_size, 1))
    am_last_batch = b == batch_size - 1
    # Block holding the last valid position: end-1 for a contiguous region,
    # otherwise the final block of the cache.
    last_good_block = jnp.where(
        start < end,
        jax.lax.div(end - 1, bk),
        jax.lax.div(self.env.cache_len - 1, bk),
    )
    # Where to go once this slot is exhausted: (b+1, block 0). The last slot
    # has no successor, so it stays on its own last good block.
    next_b = jnp.where(am_last_batch, b, b + 1)
    next_i = jnp.where(am_last_batch, last_good_block, 0)

    # start < end: keep working on the block if it overlaps [start, end).
    # Blocks at/after end move to the next slot; blocks entirely before start
    # jump to the block containing start.
    def true_comp(b, i, bk, start, end, next_b, next_i):
        b_next = jnp.where(i * bk >= end, next_b, b)
        i_next = jnp.where(i * bk >= end, next_i, i)
        i_next = jnp.where((i + 1) * bk <= start, jax.lax.div(start, bk), i_next)
        return b_next, i_next

    # start > end: keep working on the block unless it lies entirely inside
    # the invalid gap [end, start); such blocks jump to the block containing
    # start.
    def false_comp(b, i, bk, start, end):
        b_next = b
        i_next = jnp.where(
            jnp.logical_and(i * bk >= end, (i + 1) * bk <= start),
            jax.lax.div(start, bk),
            i,
        )
        return b_next, i_next

    true_comp_b, true_comp_i = true_comp(b, i, bk, start, end, next_b, next_i)
    false_comp_b, false_comp_i = false_comp(b, i, bk, start, end)
    # Select per slot: contiguous, empty (start == end), or wrapped.
    b_next = jnp.where(
        start < end, true_comp_b, jnp.where(start == end, next_b, false_comp_b)
    )
    i_next = jnp.where(
        start < end, true_comp_i, jnp.where(start == end, next_i, false_comp_i)
    )
    return b_next, i_next
def generate(
    self, params: Any, decode_state: DecodeState, sampler=None
) -> tuple[DecodeState, engine_api.ResultTokens]:
    """Placeholder for the Engine API ``generate`` method.

    ``__init__`` assigns a jitted callable to ``self.generate`` on every
    instance (built from ``generate_impl``, or ``generate_page_attention``
    when page attention is enabled). That attribute shadows this method, so
    it is only reached through the class; it then returns ``(None, None)``.
    """
    del params, decode_state, sampler  # unused placeholder arguments
    return None, None
def generate_page_attention(
    self, params: Any, decode_state: DecodeState
) -> tuple[DecodeState, engine_api.ResultTokens]:
    """Decode step with page attention.

    The host-side page bookkeeping runs first, then the jitted generate. The
    current ``input_pos`` is copied to the host (blocking on the device
    value) so the page manager can allocate any new pages and compute page
    token indices. Those indices are then passed to ``generate_jit``.
    """
    host_positions = np.asarray(decode_state.input_pos.block_until_ready())
    manager = self.page_attention_manager
    manager.fill_new_pages(host_positions)
    token_indices = jnp.asarray(manager.get_page_token_indices(host_positions))
    new_state, results = self.generate_jit(
        params, decode_state, page_token_indices=token_indices
    )
    return new_state, results
def generate_impl(
    self,
    params: Any,
    decode_state: DecodeState,
    sampler=None,
    page_token_indices=None,
) -> tuple[DecodeState, engine_api.ResultTokens]:
    """Run one decode step for every slot and sample the next tokens.

    Args:
      params: Model weights.
      decode_state: Current DecodeState.
      sampler: Optional callable applied to ``logits[:, -1]``. Defaults to
        ``self._sampling``.
      page_token_indices: Page-attention token indices. They are converted
        to torch only when ``env.page_attention`` is enabled.

    Returns:
      ``(new_decode_state, result_tokens)``. ``result_tokens.data``
      concatenates ``[decode_state.tokens, 1 (valid), lens]`` on the last
      axis, so it reports the tokens fed *into* this step. The freshly
      sampled token becomes ``new_decode_state.tokens``.
    """
    if self.env.page_attention:
        page_token_indices = torchjax.to_torch(page_token_indices)
    pos = decode_state.current_position
    if self.env.ring_buffer:
        input_indexes = jnp.full((1,), pos)
    else:
        input_indexes = decode_state.input_pos

    ragged_batch_index, ragged_block_index = (
        self.precompute_ragged_block_indices(decode_state)
    )
    # Flatten the [batch_size, num_blocks] grids.
    ragged_batch_index = ragged_batch_index.reshape(-1)
    ragged_block_index = ragged_block_index.reshape(-1)

    def update_mask():
        # Open (set to 0) the mask at the position written in this step.
        if self.env.ring_buffer:
            return decode_state.mask.at[:, decode_state.current_position].set(0)
        batch = jnp.arange(self.env.batch_size)
        return decode_state.mask.at[batch, decode_state.input_pos].set(0)

    mask = decode_state.mask
    if not self.env.lazy_cache_update:
        mask = update_mask()
    logits, new_caches, new_scales = self._call_model_generate(
        params,
        decode_state.tokens,
        input_indexes,
        decode_state.caches,
        decode_state.cache_scales,
        mask,
        decode_state.start,
        decode_state.input_pos,
        ragged_batch_index,
        ragged_block_index,
        page_token_indices,
    )
    if self.env.lazy_cache_update:
        # fill mask later, now use flash attention
        mask = update_mask()

    if sampler:
        next_token = sampler(logits[:, -1])
    else:
        next_token = self._sampling(logits, self.env.batch_size)

    if self.env.ring_buffer:
        input_pos = decode_state.input_pos + 1
        lens = decode_state.lens + 1
    else:
        # Slots still at 0 (never inserted) stay at 0; others advance by one.
        # This is deliberately a plain increment with no modulo: wrapping to 0
        # would be indistinguishable from an inactive slot under the `== 0`
        # check. (An expression like `x + 1 % n` parses as `x + (1 % n)`.)
        input_pos = jnp.where(
            decode_state.input_pos == 0,
            0,
            decode_state.input_pos + 1,
        )
        lens = jnp.where(decode_state.lens == 0, 0, decode_state.lens + 1)

    data = jnp.concatenate(
        [
            decode_state.tokens,
            jnp.ones_like(next_token),
            lens,
        ],
        axis=-1,
    )
    # [0] is the batch dimension, [1] normally should be 1
    length = next_token.shape[1]
    result_tokens = engine_api.ResultTokens(
        data=data,
        tokens_idx=(0, length),
        valid_idx=(length, 2 * length),
        length_idx=(2 * length, 2 * length + 1),
        samples_per_slot=1,
    )
    new_decode_state = DecodeState(
        next_token,
        new_caches,
        new_scales,
        (decode_state.current_position + 1) % self.env.cache_sequence_length,
        lens,
        decode_state.start,
        input_pos,
        mask,
    )
    return new_decode_state, result_tokens
# pylint: disable-next=all
def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
    """Return tokenizer parameters pointing at ``env.tokenizer_path``."""
    tokenizer_path = self.env.tokenizer_path
    # pylint: disable-next=all
    return tokenizer_pb2.TokenizerParameters(path=tokenizer_path)
def build_tokenizer(
    self, metadata: tokenizer_pb2.TokenizerParameters  # pylint: disable=all
) -> tokenizer_api.Tokenizer:
    """Build the tokenizer for this model.

    Selection order:
    1. A HuggingFace tokenizer from the environment wins if present.
    2. Model types containing ``"llama-3"`` use TikToken.
    3. Everything else uses SentencePiece.
    """
    hf_tokenizer = self.env.hf_tokenizer
    if hf_tokenizer is not None:
        return HFTokenizerAdapter(hf_tokenizer)
    if "llama-3" in self.env.model_type:
        tokenizer_cls = token_utils.TikToken
    else:
        tokenizer_cls = token_utils.SentencePieceTokenizer
    return tokenizer_cls(metadata)
def join_prefixes(
    self,
    prefix1: engine_api.Prefix,
    length1: int,
    prefix2: engine_api.Prefix,
    length2: int,
) -> tuple[engine_api.Prefix, int]:
    """join prefixes: not supported by this engine.

    Raises:
      NotImplementedError: always.
    """
    message = "join_prefixes not supported"
    raise NotImplementedError(message)
def _make_state_dict_jax(self, model_args_meta):
    """Build random stand-in weights shaped like ``model_args_meta``.

    Every ``torch.Tensor`` leaf becomes a JAX array of the same shape. Values
    are drawn from a standard normal in ``self.default_dtype`` and then cast
    to the JAX equivalent of the tensor's dtype. The same PRNG key is used
    for every leaf, so leaves with equal shapes get identical values.
    """
    prng_key = jax.random.key(0)

    def random_like(tensor):
        sample = jax.random.normal(
            prng_key, shape=tensor.shape, dtype=self.default_dtype
        )
        return sample.astype(torch_xla2.tensor.t2j_dtype(tensor.dtype))

    return pytree.tree_map_only(torch.Tensor, random_like, model_args_meta)
def _load_from_safetensors(self, path):
    """Load weights from a safetensors file as JAX arrays.

    Every key of ``pt_model.state_dict()`` except ``freqs_cis`` is read from
    ``path`` and must match the model's shape. ``freqs_cis`` is converted
    from the model itself.
    """
    loaded = {}
    with safe_open(path, framework="flax", device="cpu") as reader:
        for name, reference in self.pt_model.state_dict().items():
            if name == "freqs_cis":
                continue
            tensor = reader.get_tensor(name)
            loaded[name] = tensor
            assert tuple(reference.shape) == tuple(
                tensor.shape
            ), f"key: {name} error: {reference.shape} != {tensor.shape}"
    loaded["freqs_cis"] = torch_xla2.tensor.t2j(self.pt_model.freqs_cis)
    return loaded
def _load_from_state_dict(self, path):
    """Load weights from a ``torch.save``-d state dict and convert them to JAX.

    Every key of ``pt_model.state_dict()`` except ``freqs_cis`` must exist in
    the file with a matching shape. ``freqs_cis`` is converted from the
    model itself. The loaded key list is printed.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True where supported).
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    print(f"Loaded keys are : {checkpoint.keys()}")
    converted = {}
    for name, reference in self.pt_model.state_dict().items():
        if name == "freqs_cis":
            continue
        assert name in checkpoint, f"key: {name} not found"
        array = torch_xla2.tensor.t2j(checkpoint[name])
        converted[name] = array
        assert tuple(reference.shape) == tuple(
            array.shape
        ), f"key: {name} error: {reference.shape} != {array.shape}"
    converted["freqs_cis"] = torch_xla2.tensor.t2j(self.pt_model.freqs_cis)
    return converted
# pylint: disable-next=all
def load_params(self) -> Params:
    """Return the model weights as sharded JAX arrays.

    Weights passed to the constructor are returned unchanged. Otherwise:
    1. Load them on the colocated CPU device, from ``env.checkpoint_path``
       (``checkpoint_format`` ``"safetensors"`` or ``"state_dict"``), or
       initialize them randomly when no checkpoint path is set.
    2. With 4-bit weight quantization, cast the linear weights named in the
       model's quantized-weight map to int4.
    3. Place every array with ``env.sharding_by_name`` and print the
       non-layer and layer-0 shapes.

    Raises:
      ValueError: if ``env.checkpoint_format`` is not supported. Previously
        this surfaced as an UnboundLocalError on ``jax_weights``.
    """
    if self.weights is not None:
        return self.weights
    # We want to fix this: load from files
    with jax.default_device(self.colocated_cpus):
        if self.env.checkpoint_path:
            checkpoint_format = self.env.checkpoint_format
            if checkpoint_format == "safetensors":
                jax_weights = self._load_from_safetensors(self.env.checkpoint_path)
            elif checkpoint_format == "state_dict":
                jax_weights = self._load_from_state_dict(self.env.checkpoint_path)
            else:
                raise ValueError(
                    f"Unsupported checkpoint_format {checkpoint_format!r}; "
                    "expected 'safetensors' or 'state_dict'."
                )
        else:
            jax_weights = self._make_state_dict_jax(self.pt_model.state_dict())
    if self.env.quant_config.num_bits_weight == 4:
        assert (
            "gemma" not in self.env.model_type
        ), "int-4 is not supported in Gemma model yet."
        # str.endswith accepts a tuple of suffixes (an empty tuple never matches).
        quantized_suffixes = tuple(
            self.pt_model.get_quantized_linear_weight_to_scaler_map().keys()
        )
        with jax.default_device(jax.devices("cpu")[0]):
            # Only values are replaced, so mutating during iteration is safe.
            for key, val in jax_weights.items():
                if key.endswith(quantized_suffixes):
                    jax_weights[key] = val.astype(jnp.int4)
    jax_weights = {
        key: jax.device_put(value, self.env.sharding_by_name(key))
        for key, value in jax_weights.items()
    }
    for k, v in jax_weights.items():
        if k.startswith("layers") and not k.startswith("layers.0"):
            continue
        print(f"Name: {k}, shape: {v.shape} x {v.dtype}")
    return jax_weights
@property
def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
    """The first CPU device, used by ``load_params`` as ``jax.default_device``.

    NOTE(review): the annotation (list or None) follows the Engine API, but
    this returns a single device, which is what ``jax.default_device`` needs.
    """
    cpu_devices = jax.devices("cpu")
    return cpu_devices[0]
def get_prefix_destination_sharding(self) -> Prefix:
    """Returns the shardings necessary to transfer data between engines.

    ``token`` and ``seq_len`` are replicated. The caches are:
    - replicated when sharding on batch;
    - otherwise sharded like the prefill cache under page attention;
    - otherwise sharded like the generate cache.
    """
    if self.env.shard_on_batch:
        kv_sharding = self.replicated
    elif self.env.page_attention:
        kv_sharding = self.prefill_cache_sharding
    else:
        kv_sharding = self.cache_sharding
    return Prefix(self.replicated, kv_sharding, self.replicated)
def get_decode_state_sharding(self) -> DecodeState:
    """Gets the shardings corresponding to the decode state.

    Tokens are sharded on axis 0 when sharding on batch and are replicated
    otherwise. Caches use the cache sharding. Every other field is
    replicated.
    """
    rep = self.replicated
    token_sharding = self.x_sharding if self.env.shard_on_batch else rep
    return DecodeState(
        tokens=token_sharding,
        caches=self.cache_sharding,
        cache_scales=rep,
        current_position=rep,
        lens=rep,
        start=rep,
        input_pos=rep,
        mask=rep,
    )
def get_prefix_sequence_ddim(self) -> Any:
    """Return the same value as ``get_prefix_destination_sharding()``.

    NOTE(review): the method name and the original docstring ("index of the
    sequence dim in the prefix type") suggest an index, but this returns the
    Prefix of shardings. Confirm what callers expect.
    """
    prefix_sharding = self.get_prefix_destination_sharding()
    return prefix_sharding
@property
def max_concurrent_decodes(self) -> int:
    """Number of decode slots, taken from ``env.batch_size``."""
    slots = self.env.batch_size
    return slots
@property
def samples_per_slot(self) -> int:
    """Samples produced per slot per decode step; always 1 for this engine."""
    single_sample = 1
    return single_sample
@property
def max_prefill_length(self) -> int:
    """Maximum prefill length, read from ``env.max_input_sequence_length``."""
    limit = self.env.max_input_sequence_length
    return limit
@property
def max_decode_length(self) -> int:
"""Maximum decode length."""