
Commit 137eb47

Support llama3 (#64)
* Support llama3
* Sync with main branch
* Fix CI
* fix linting
* Fix pyink issues
* fix run_offline script
* Fix pyink
* Fix after merging main
* Update jetstream version in install_everything.sh
* Fix unit tests
* Fix test
1 parent 9606a1f commit 137eb47

18 files changed (+101 lines, -67 lines)

benchmarks/run_offline.py
Lines changed: 3 additions & 4 deletions

@@ -21,7 +21,6 @@
 import jax
 import jax.numpy as jnp
 
-from jetstream.engine import token_utils
 from jetstream_pt import engine as je
 # pylint: disable-next=all
 from benchmarks import analyze_sharegpt
@@ -97,11 +96,11 @@ def create_engine():
 def run_prefill_time(engine, params, decode_state, seqlen):
   """Run prefill and measure time."""
   metadata = engine.get_tokenizer()
-  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
+  tokenizer = engine.build_tokenizer(metadata)
 
   text = "This is a beautiful day"
-  tokens, true_length = token_utils.tokenize_and_pad(
-      text, vocab, is_bos=True, prefill_lengths=[seqlen]
+  tokens, true_length = tokenizer.encode(
+      text, is_bos=True, prefill_lengths=[seqlen]
   )
 
   for _ in range(3):
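
The encode call above replaces token_utils.tokenize_and_pad. A minimal sketch of the resulting flow, wrapped in a hypothetical helper (the function name and the padding behaviour noted in the comments are assumptions inferred from the old tokenize_and_pad semantics, not part of this commit):

def encode_prompt(engine, text, seqlen):
  """Hypothetical helper: tokenize `text` with the engine's own tokenizer."""
  metadata = engine.get_tokenizer()
  tokenizer = engine.build_tokenizer(metadata)
  # `prefill_lengths` is assumed to pad the ids up to the given bucket,
  # as token_utils.tokenize_and_pad did before this change.
  tokens, true_length = tokenizer.encode(text, is_bos=True, prefill_lengths=[seqlen])
  return tokens, true_length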

install_everything.sh
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 TORCHXLA_TAG=jetstream-pytorch
-JETSTREAM_TAG=v0.2.0
+JETSTREAM_TAG=v0.2.1
 
 # Uninstall existing jax
 pip3 show jax && pip3 uninstall -y jax

jetstream_pt/engine.py
Lines changed: 17 additions & 8 deletions

@@ -26,14 +26,14 @@
 import torch
 import numpy as np
 
-from jetstream.engine import engine_api, tokenizer_pb2, token_utils
+from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
 import torch_xla2
 from torch.utils import _pytree as pytree
 
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
-from jetstream_pt.third_party.llama2 import model_exportable, model_args
+from jetstream_pt.third_party.llama import model_exportable, model_args
 
 
 Mesh = jax.sharding.Mesh
@@ -526,6 +526,14 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
     # pylint: disable-next=all
     return tokenizer_pb2.TokenizerParameters(path=self.env.tokenizer_path)
 
+  def build_tokenizer(
+      self, metadata: tokenizer_pb2.TokenizerParameters  # pylint: disable=all
+  ) -> tokenizer_api.Tokenizer:
+    if "llama-3" in self.env.model_type:
+      return token_utils.TikToken(metadata)
+
+    return token_utils.SentencePieceTokenizer(metadata)
+
   def join_prefixes(
       self,
       prefix1: engine_api.Prefix,
@@ -652,13 +660,18 @@ def create_pytorch_engine(
     context_length: int = 1024,
     batch_size: int = 1,
     max_decode_length: int = 4096,
-    model_name="llama",
+    model_name="llama-2",
     quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
+  supported_models = ["llama-2", "llama-3"]
+  if model_name not in supported_models:
+    raise NotImplementedError(
+        f"Model name should be one of{','.join(supported_models)}"
+    )
   # See issue b/309529778 if it's turned on.
   jax.config.update("jax_dynamic_shapes", False)
   # Pytorch exports has int64 constants.
@@ -696,11 +709,7 @@ def create_pytorch_engine(
   if model_name.startswith("llama"):
 
     args = model_args.get_model_args(
-        param_size,
-        context_length,
-        batch_size,
-        tokenizer.vocab_size,
-        bf16_enable,
+        model_name + "-" + param_size, context_length, batch_size, bf16_enable
     )
     args.device = "meta"
     args.quantize = quantize_weights
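
The new gate restricts model_name to llama-2 or llama-3 and combines it with param_size before looking up model args. A small helper that mirrors that logic (resolve_model_key is hypothetical, written only to illustrate the composition):

def resolve_model_key(model_name, param_size):
  """Mirrors the validation and name composition added in create_pytorch_engine."""
  supported_models = ["llama-2", "llama-3"]
  if model_name not in supported_models:
    raise NotImplementedError(f"Model name should be one of {','.join(supported_models)}")
  # ("llama-3", "8b") -> "llama-3-8b", the key model_args.get_model_args expects.
  return model_name + "-" + param_size

assert resolve_model_key("llama-3", "8b") == "llama-3-8b"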

jetstream_pt/ray_worker.py
Lines changed: 5 additions & 7 deletions

@@ -32,9 +32,9 @@
 from torch.utils import _pytree as pytree
 import torch_xla2
 
-from jetstream.engine import engine_api, tokenizer_pb2, token_utils
+from jetstream.engine import engine_api, tokenizer_pb2
 
-from jetstream_pt.third_party.llama2 import model_exportable, model_args
+from jetstream_pt.third_party.llama import model_exportable, model_args
 
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
@@ -99,7 +99,7 @@ def __init__(
       context_length: int = 1024,
       batch_size: int = 1,
       max_decode_length: int = 4096,
-      model_name="llama",
+      model_name="llama-2",
       quantize_weights=False,
       quantize_kv=False,
       max_cache_length=1024,
@@ -159,14 +159,12 @@ def __init__(
     )
     env = JetEngineEnvironment(env_data)
 
-    tokenizer = token_utils.load_vocab(tokenizer_path)
     pt_model = None
-    if model_name == "llama":
+    if "llama" in model_name:
       args = model_args.get_model_args(
-          param_size,
+          model_name + "-" + param_size,
           context_length,
           batch_size,
-          tokenizer.vocab_size,
           bf16_enable,
       )
       args.device = "meta"

jetstream_pt/third_party/llama2/generation_original.py renamed to jetstream_pt/third_party/llama/generation_original.py
Lines changed: 2 additions & 2 deletions

@@ -5,9 +5,9 @@
 from typing import List, Literal, Optional, Tuple, TypedDict
 
 import torch
-from jetstream_pt.third_party.llama2 import model_original
+from jetstream_pt.third_party.llama import model_original
 from flax import struct
-from jetstream_pt.third_party.llama2.tokenizer import Tokenizer
+from jetstream_pt.third_party.llama.tokenizer import Tokenizer
 
 Role = Literal["system", "user", "assistant"]
 

jetstream_pt/third_party/llama2/model_args.py renamed to jetstream_pt/third_party/llama/model_args.py
Lines changed: 25 additions & 12 deletions

@@ -34,68 +34,81 @@ class ModelArgs:
   device = "cpu"
   quantize = False
 
+  rope_theta: float = 10000.0
+
 
 def get_arg(
-    param_size: str,
+    model_name: str,
     seqlen,
     batch_size,
-    vocab_size: int,
     bf16_enable: bool = False,
 ) -> ModelArgs:
   """Gets model args."""
 
   data = {}
-  if param_size == "tiny":
+  if model_name == "llama-2-tiny":
     data = {
         "dim": 128,
+        "vocab_size": 32000,
         "multiple_of": 32,
         "n_heads": 8,
         "n_layers": 3,
         "norm_eps": 1e-05,
     }
-  elif param_size == "7b":
+  elif model_name == "llama-2-7b":
     data = {
         "dim": 4096,
+        "vocab_size": 32000,
         "multiple_of": 256,
         "n_heads": 32,
         "n_layers": 32,
         "norm_eps": 1e-05,
     }
-  elif param_size == "13b":
+  elif model_name == "llama-2-13b":
     data = {
         "dim": 5120,
+        "vocab_size": 32000,
         "multiple_of": 256,
         "n_heads": 40,
         "n_layers": 40,
         "norm_eps": 1e-05,
     }
-  elif param_size == "70b":
+  elif model_name == "llama-2-70b":
    data = {
         "dim": 8192,
+        "vocab_size": 32000,
         "multiple_of": 4096,
         "ffn_dim_multiplier": 1.3,
         "n_heads": 64,
         "n_kv_heads": 8,
         "n_layers": 80,
         "norm_eps": 1e-05,
     }
+  elif model_name == "llama-3-8b":
+    data = {
+        "dim": 4096,
+        "vocab_size": 128256,
+        "multiple_of": 1024,
+        "ffn_dim_multiplier": 1.3,
+        "n_layers": 32,
+        "n_heads": 32,
+        "n_kv_heads": 8,
+        "norm_eps": 1e-05,
+        "rope_theta": 500000.0,
+    }
   return ModelArgs(
       max_seq_len=seqlen,
       max_batch_size=batch_size,
-      vocab_size=vocab_size,
       bf16_enable=bf16_enable,
       **data,
   )
 
 
-def get_model_args(
-    param_size, context_length, batch_size, vocab_size, bf16_enable
-):
+def get_model_args(model_name, context_length, batch_size, bf16_enable):
   model_args = get_arg(
-      param_size=param_size,
+      model_name=model_name,
       seqlen=context_length,
       batch_size=batch_size,
-      vocab_size=vocab_size,
       bf16_enable=bf16_enable,
   )
   model_args.n_kv_heads = (
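
The signature change drops the explicit vocab_size argument: the full model name now keys the table above, which carries vocab_size and rope_theta itself. A short usage sketch (the context length and batch size are illustrative, loosely following tests/helpers.py):

from jetstream_pt.third_party.llama import model_args

# Old call shape: bare size string plus an explicit vocab_size, e.g.
#   model_args.get_model_args("7b", 1024, 1, 32000, True)
# New call shape: full model name, no vocab_size.
args = model_args.get_model_args("llama-3-8b", 1024, 1, True)
# Per the llama-3-8b entry above: args.vocab_size == 128256, args.rope_theta == 500000.0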

jetstream_pt/third_party/llama2/model_exportable.py renamed to jetstream_pt/third_party/llama/model_exportable.py
Lines changed: 3 additions & 1 deletion

@@ -157,7 +157,9 @@ def __init__(
     )
     # TODO what to do with this
     freqs_cis = precompute_freqs_cis(
-        self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        self.params.dim // self.params.n_heads,
+        self.params.max_seq_len * 2,
+        theta=self.params.rope_theta,
     )
 
     self.register_buffer("freqs_cis", freqs_cis)
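
precompute_freqs_cis itself is not touched by this diff; for context, the widely used reference Llama formulation of it looks roughly like the sketch below (standard reference code, not copied from this repository), which shows where the new rope_theta parameter enters:

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
  """Standard Llama rotary-embedding table; llama-3 raises theta to 500000.0."""
  # Per-pair inverse frequencies, scaled by the RoPE base `theta`.
  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
  t = torch.arange(end, device=freqs.device)
  freqs = torch.outer(t, freqs).float()
  # Complex rotations e^{i * t * freq}, registered above as the `freqs_cis` buffer.
  return torch.polar(torch.ones_like(freqs), freqs)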

run_interactive.py
Lines changed: 24 additions & 17 deletions

@@ -24,6 +24,11 @@
 import jax
 
 from jetstream.engine import token_utils
+from colorama import Fore, Style
+import numpy as np
+
+import os
+
 from jetstream_pt import engine as je
 
 FLAGS = flags.FLAGS
@@ -64,6 +69,11 @@
 _MAX_CACHE_LENGTH = flags.DEFINE_integer(
     "max_cache_length", 1024, "kv_cache_quantize"
 )
+_MODEL_NAME = flags.DEFINE_string(
+    "model",
+    "llama-2",
+    "name of the model. Supported options are llama-2 and llama-3",
+)
 
 
 def create_engine():
@@ -81,6 +91,7 @@ def create_engine():
       param_size=_SIZE.value,
       context_length=_CONTEXT_LENGTH.value,
       batch_size=_BATCH_SIZE.value,
+      model_name=_MODEL_NAME.value,
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,
@@ -100,8 +111,7 @@ def main(argv):
   print("Load params ", time.perf_counter() - start)
 
   metadata = engine.get_tokenizer()
-  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
-  stop_tokens = [vocab.eos_id, vocab.pad_id]
+  tokenizer = engine.build_tokenizer(metadata)
   max_output_length = 1024
 
   if _PROFILING_OUTPUT.value:
@@ -121,9 +131,8 @@ def main(argv):
   ]
   for prompt in prompts:
     slot = random.randint(0, _BATCH_SIZE.value - 1)
-    tokens, true_length = token_utils.tokenize_and_pad(
-        prompt, vocab, is_bos=True
-    )
+    tokens, true_length = tokenizer.encode(prompt, is_bos=True)
+
     print(f"---- Input prompts are: {prompt}")
     print(f"---- Encoded tokens are: {tokens}")
 
@@ -135,29 +144,27 @@ def main(argv):
     decode_state = engine.insert(prefill_result, decode_state, slot=slot)
     sampled_tokens_list = []
     print(f"---- Streaming decode started on #slot{slot}.")
+    complete = np.zeros((1,), dtype=np.bool_)
     while True:
-      # pylint: disable-next=all
       decode_state, result_tokens = engine.generate(params, decode_state)
-
-      slot_data = result_tokens.get_result_at_slot(slot)
-      slot_tokens = slot_data.tokens
-      slot_lengths = slot_data.lengths
-
-      token_id = slot_tokens[slot, 0].item()
-      if slot_lengths > max_output_length or token_id in stop_tokens:
+      result_tokens = result_tokens.convert_to_numpy()
+      output, complete = tokenizer.decode(
+          slot, max_output_length, result_tokens, complete
+      )
+      if complete[0]:
         break
-
+      token_id = output[0][0]
       sampled_tokens_list.append(token_id)
-      # output = token_utils.mix_decode(vocab, token_id)
-      # print(Fore.GREEN + output, end="", flush=True)
+      # output_str = tokenizer.decode_str([token_id])
+      # print(Fore.GREEN + output_str, end="", flush=True)
 
     # print(Style.RESET_ALL + "\n")
     # print("---- Streaming decode finished.")
 
     print("---- All output tokens.")
     print(sampled_tokens_list)
     print("---- All output text.")
-    print(vocab.tokenizer.decode(sampled_tokens_list))
+    print(tokenizer.decode_str(sampled_tokens_list))
 
   if _PROFILING_OUTPUT.value:
     jax.profiler.stop_trace()
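
Stop-token handling moves from the script into the tokenizer: instead of comparing token ids against vocab.eos_id and vocab.pad_id, the loop asks tokenizer.decode for a per-slot completion flag. A condensed sketch of that loop, assuming a single requested slot (the helper name and its argument list are illustrative, not part of the commit):

import numpy as np

def stream_one_slot(engine, params, decode_state, tokenizer, slot, max_output_length=1024):
  """Condensed sketch of the new per-slot streaming loop shown above."""
  complete = np.zeros((1,), dtype=np.bool_)  # one requested slot
  sampled_tokens_list = []
  while True:
    decode_state, result_tokens = engine.generate(params, decode_state)
    result_tokens = result_tokens.convert_to_numpy()
    # decode() returns the newest token per slot and flips `complete` once
    # EOS/pad is emitted or max_output_length is exceeded.
    output, complete = tokenizer.decode(slot, max_output_length, result_tokens, complete)
    if complete[0]:
      break
    sampled_tokens_list.append(output[0][0])
  return tokenizer.decode_str(sampled_tokens_list)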

run_server.py
Lines changed: 6 additions & 0 deletions

@@ -71,6 +71,11 @@
     "The model size the server runs on.",
     required=False,
 )
+_MODEL_NAME = flags.DEFINE_string(
+    "model",
+    "llama-2",
+    "name of the model. Supported options are llama-2 and llama-3",
+)
 
 _QUANTIZE_WEIGHTS = flags.DEFINE_bool(
     "quantize_weights", False, "weight quantization"
@@ -98,6 +103,7 @@ def main(argv: Sequence[str]):
       param_size=_PARAM_SIZE.value,
       context_length=_CONTEXT_LENGTH.value,
       batch_size=_BATCH_SIZE.value,
+      model_name=_MODEL_NAME.value,
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,

tests/helpers.py
Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 import torch
 import jax
-from jetstream_pt.third_party.llama2 import model_args
+from jetstream_pt.third_party.llama import model_args
 from jetstream_pt import environment
 
 
@@ -9,7 +9,7 @@ def make_env_tiny(bf16_enable=True):
   torch.set_default_dtype(torch_dtype)
   jax.config.update("jax_dynamic_shapes", False)
   jax.config.update("jax_traceback_filtering", "off")
-  config = model_args.get_model_args("tiny", 128, 1, 32000, True)
+  config = model_args.get_model_args("llama-2-tiny", 128, 1, True)
   environment_data = environment.JetEngineEnvironmentData()
   environment_data.max_input_sequence_length = 128
   environment_data.max_input_sequence_length = 128

tests/test_engine.py
Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 from jetstream_pt.engine import PyTorchEngine, Prefix, DecodeState
-from jetstream_pt.third_party.llama2 import model_exportable, model_original
+from jetstream_pt.third_party.llama import model_exportable, model_original
 
 # This model will output tokens with value of 2
 # and will update caches with value of 1.0
