Commit 9c0d2ac

Add gemma and update recent changes to multiple host (#74)
add gemma and update recent changes to multiple host
1 parent dab2d7a commit 9c0d2ac

File tree

3 files changed: +55 -12 lines changed


jetstream_pt/ray_engine.py

+7

@@ -152,8 +152,14 @@ def create_pytorch_ray_engine(
     quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,
+    sharding_config=None,
 ) -> PyTorchRayEngine:
 
+  supported_models = ["llama-2", "llama-3", "gemma"]
+  if model_name not in supported_models:
+    raise NotImplementedError(
+        f"Model name should be one of{','.join(supported_models)}"
+    )
   ray.init(ignore_reinit_error=True)
   pod_name = tpu.get_current_pod_name()
   num_hosts = tpu.get_current_pod_worker_count()
@@ -183,6 +189,7 @@ def create_pytorch_ray_engine(
         quantize_weights=quantize_weights,
         quantize_kv=quantize_kv,
         max_cache_length=max_cache_length,
+        sharding_config=sharding_config,
     )
     engine_workers.append(engine_worker)
   engine_master = PyTorchRayEngine(
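
For orientation, a minimal call sketch of the factory with the new arguments. Every literal below is an illustrative placeholder rather than a value from this commit, and arguments of the existing signature that are not visible in this diff are omitted:

from jetstream_pt import ray_engine

# Sketch only: paths and sizes are placeholders. Passing a name outside
# supported_models now raises NotImplementedError (see the hunk above).
engine = ray_engine.create_pytorch_ray_engine(
    model_name="gemma",                         # "llama-2", "llama-3", or "gemma"
    tokenizer_path="/path/to/tokenizer.model",  # placeholder path
    ckpt_path="/path/to/checkpoint",            # placeholder path
    bf16_enable=True,
    quantize_weights=False,
    quantize_kv=False,
    max_cache_length=1024,
    sharding_config=None,  # empty/None -> worker falls back to default_shardings/gemma.yaml
    # remaining arguments from the existing signature are not shown here
)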

jetstream_pt/ray_worker.py

+38 -12

@@ -17,6 +17,7 @@
 from typing import Any, List, Optional, Tuple, Union
 import threading
 import functools
+import os
 import humanize
 
 
@@ -39,6 +40,7 @@
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
+from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
 
 
 Mesh = jax.sharding.Mesh
@@ -103,6 +105,7 @@ def __init__(
       quantize_weights=False,
       quantize_kv=False,
       max_cache_length=1024,
+      sharding_config=None,
   ):
 
     jax.config.update("jax_default_prng_impl", "unsafe_rbg")
@@ -144,38 +147,61 @@ def __init__(
       checkpoint_format = "safetensors"
       checkpoint_path = paths[0]
 
+    if not sharding_config:
+      sharding_config = os.path.join("default_shardings", model_name + ".yaml")
+
     env_data = JetEngineEnvironmentData(
         tokenizer_path=tokenizer_path,
         checkpoint_path=checkpoint_path,
         checkpoint_format=checkpoint_format,
-        model_type="llama-2-" + param_size,
         batch_size=batch_size,
         max_decode_length=max_decode_length,
         max_input_sequence_length=context_length,
        enable_weight_quantization=quantize_weights,
         enable_kv_quantization=quantize_kv,
         cache_sequence_length=max_cache_length,
         bf16_enable=bf16_enable,
+        sharding_config_path=sharding_config,
     )
     env = JetEngineEnvironment(env_data)
 
-    pt_model = None
-    if "llama" in model_name:
+    if model_name.startswith("llama"):
+
       args = model_args.get_model_args(
-          model_name + "-" + param_size,
-          context_length,
-          batch_size,
-          bf16_enable,
+          model_name + "-" + param_size, context_length, batch_size, bf16_enable
       )
       args.device = "meta"
       args.quantize = quantize_weights
+      env_data.cache_shape = (
+          batch_size,
+          args.n_kv_heads,
+          max_cache_length,
+          args.dim // args.n_heads,
+      )
+      env_data.model_type = "llama-2-" + param_size
+      env_data.num_layers = args.n_layers
+      env = JetEngineEnvironment(env_data)
       pt_model = model_exportable.Transformer(args, env)
+    elif model_name == "gemma":
+      args = gemma_config.get_model_config(param_size)
+      env_data.cache_shape = (
+          batch_size,
+          args.num_key_value_heads,
+          max_cache_length,
+          args.head_dim,
+      )
+      env_data.model_type = "gemma-" + param_size
+      env_data.num_layers = args.num_hidden_layers
+      env = JetEngineEnvironment(env_data)
+      pt_model = gemma_model.GemmaModel(args, env)
+    else:
+      raise RuntimeError(f"Model with name {model_name} not found")
 
-    num_params_size = 0
-    num_params = 0
-    for _, v in pt_model.state_dict().items():
-      num_params += 1
-      num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
+    num_params_size = 0
+    num_params = 0
+    for _, v in pt_model.state_dict().items():
+      num_params += 1
+      num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
     print("Number of param Gbytes:", num_params_size / (1 << 30))
     print("Number of param: ", num_params)
 
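
The worker hunk above does two new things: it falls back to a per-model default sharding YAML when none is supplied, and it derives env_data.cache_shape, model_type, and num_layers per model family instead of hard-coding llama-2. A self-contained sketch of that logic, using a hypothetical stand-in for the llama argument object (attribute names mirror the diff; the numeric defaults and the stub class are illustrative only, not repository code):

import os
from dataclasses import dataclass

@dataclass
class LlamaArgsStub:
  # Illustrative stand-in for the object returned by model_args.get_model_args;
  # the values are placeholders, not a real model configuration.
  n_kv_heads: int = 8
  n_heads: int = 32
  dim: int = 4096
  n_layers: int = 32

def resolve_sharding_config(sharding_config, model_name):
  # Empty or None falls back to the per-model default YAML, as in __init__ above.
  if not sharding_config:
    sharding_config = os.path.join("default_shardings", model_name + ".yaml")
  return sharding_config

def llama_cache_shape(batch_size, max_cache_length, args):
  # (batch, kv heads, cache length, head dim), matching env_data.cache_shape
  # in the llama branch; the gemma branch reads head_dim directly from its config.
  return (batch_size, args.n_kv_heads, max_cache_length, args.dim // args.n_heads)

print(resolve_sharding_config("", "gemma"))          # default_shardings/gemma.yaml
print(llama_cache_shape(32, 1024, LlamaArgsStub()))  # (32, 8, 1024, 128)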

run_interactive_multiple_host.py

+10

@@ -65,6 +65,14 @@
     "max_cache_length", 1024, "kv_cache_quantize"
 )
 
+_MODEL_NAME = flags.DEFINE_string(
+    "model_name", None, "model type", required=False
+)
+
+_SHARDING_CONFIG = flags.DEFINE_string(
+    "sharding_config", "", "config file for sharding"
+)
+
 
 def create_engine():
   """create a pytorch engine"""
@@ -73,6 +81,7 @@ def create_engine():
 
   start = time.perf_counter()
   engine = ray_engine.create_pytorch_ray_engine(
+      model_name=_MODEL_NAME.value,
       tokenizer_path=_TOKENIZER_PATH.value,
       ckpt_path=_CKPT_PATH.value,
       bf16_enable=True,
@@ -82,6 +91,7 @@ def create_engine():
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,
+      sharding_config=_SHARDING_CONFIG.value,
   )
 
   print("Initialize engine", time.perf_counter() - start)
