@@ -17,6 +17,7 @@
 from typing import Any, List, Optional, Tuple, Union
 import threading
 import functools
+import os
 import humanize


@@ -39,6 +40,7 @@
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
+from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model


 Mesh = jax.sharding.Mesh
@@ -103,6 +105,7 @@ def __init__(
     quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,
+    sharding_config=None,
 ):

   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
@@ -144,38 +147,61 @@ def __init__(
   checkpoint_format = "safetensors"
   checkpoint_path = paths[0]

+  if not sharding_config:
+    sharding_config = os.path.join("default_shardings", model_name + ".yaml")
+
   env_data = JetEngineEnvironmentData(
       tokenizer_path=tokenizer_path,
       checkpoint_path=checkpoint_path,
       checkpoint_format=checkpoint_format,
-      model_type="llama-2-" + param_size,
       batch_size=batch_size,
       max_decode_length=max_decode_length,
       max_input_sequence_length=context_length,
       enable_weight_quantization=quantize_weights,
       enable_kv_quantization=quantize_kv,
       cache_sequence_length=max_cache_length,
       bf16_enable=bf16_enable,
+      sharding_config_path=sharding_config,
   )
   env = JetEngineEnvironment(env_data)

-  pt_model = None
-  if "llama" in model_name:
+  if model_name.startswith("llama"):
+
     args = model_args.get_model_args(
-        model_name + "-" + param_size,
-        context_length,
-        batch_size,
-        bf16_enable,
+        model_name + "-" + param_size, context_length, batch_size, bf16_enable
     )
     args.device = "meta"
     args.quantize = quantize_weights
+    env_data.cache_shape = (
+        batch_size,
+        args.n_kv_heads,
+        max_cache_length,
+        args.dim // args.n_heads,
+    )
+    env_data.model_type = "llama-2-" + param_size
+    env_data.num_layers = args.n_layers
+    env = JetEngineEnvironment(env_data)
     pt_model = model_exportable.Transformer(args, env)
+  elif model_name == "gemma":
+    args = gemma_config.get_model_config(param_size)
+    env_data.cache_shape = (
+        batch_size,
+        args.num_key_value_heads,
+        max_cache_length,
+        args.head_dim,
+    )
+    env_data.model_type = "gemma-" + param_size
+    env_data.num_layers = args.num_hidden_layers
+    env = JetEngineEnvironment(env_data)
+    pt_model = gemma_model.GemmaModel(args, env)
+  else:
+    raise RuntimeError(f"Model with name {model_name} not found")

-  num_params_size = 0
-  num_params = 0
-  for _, v in pt_model.state_dict().items():
-    num_params += 1
-    num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
+  num_params_size = 0
+  num_params = 0
+  for _, v in pt_model.state_dict().items():
+    num_params += 1
+    num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
   print("Number of param Gbytes:", num_params_size / (1 << 30))
   print("Number of param: ", num_params)

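
As a quick illustration of what this change does at runtime, the sketch below mirrors the new logic in plain Python: the sharding config defaults to a per-model YAML under default_shardings/, and the KV-cache shape is derived from whichever model config the name-based dispatch selects. The helper name build_cache_shape and the example numbers are hypothetical, used only for illustration; in the engine itself the values come from model_args.get_model_args (llama) and gemma_config.get_model_config (gemma).

import os
from types import SimpleNamespace


def build_cache_shape(model_name, args, batch_size, max_cache_length):
  # Hypothetical helper (illustration only) mirroring the branches in the diff.
  if model_name.startswith("llama"):
    # Llama configs expose n_kv_heads; the head dim is dim // n_heads.
    return (batch_size, args.n_kv_heads, max_cache_length, args.dim // args.n_heads)
  if model_name == "gemma":
    # Gemma configs expose num_key_value_heads and head_dim directly.
    return (batch_size, args.num_key_value_heads, max_cache_length, args.head_dim)
  raise RuntimeError(f"Model with name {model_name} not found")


# The sharding config defaults to a per-model YAML when none is passed.
model_name, sharding_config = "gemma", None
if not sharding_config:
  sharding_config = os.path.join("default_shardings", model_name + ".yaml")
print(sharding_config)  # default_shardings/gemma.yaml

# Illustrative (not authoritative) config values, just to exercise the helper.
gemma_like = SimpleNamespace(num_key_value_heads=16, head_dim=256)
print(build_cache_shape("gemma", gemma_like, batch_size=1, max_cache_length=1024))
# (1, 16, 1024, 256)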