Skip to content

Llama 3.1 RoPE scaling #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
meta-llama/Llama-3.1-8B
meta-llama/Llama-3.1-8B-Instruct
meta-llama/Llama-3.2-1B
meta-llama/Llama-3.2-1B-Instruct
meta-llama/Llama-3.3-70B
meta-llama/Llama-3.3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
Expand Down
3 changes: 3 additions & 0 deletions jetstream_pt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
def shard_weights(env, weights, weight_shardings):
"""Shard weights according to weight_shardings"""
sharded = {}
# Some output and embeddings weights might be tied: in this case untie them
if weights["output.weight"].device.type == "meta":
weights["output.weight"] = weights["tok_embeddings.weight"].clone()
for key, val in weights.items():
sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
with jax.default_device(jax.devices("cpu")[0]):
Expand Down
11 changes: 10 additions & 1 deletion jetstream_pt/fetch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class ModelInfo:
_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8)
_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
_llama3_70 = _llama2_70
_llama3_1_8b = _llama3_8
_llama3_2_1b = ModelInfo(llama_model.Transformer, 16, 8, 64, 4)
_llama3_3_70b = _llama2_70

_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)

Expand All @@ -78,6 +81,12 @@ class ModelInfo:
"meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
"meta-llama/Meta-Llama-3-70B": _llama3_70,
"meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70,
"meta-llama/Llama-3.1-8B": _llama3_1_8b,
"meta-llama/Llama-3.1-8B-Instruct": _llama3_1_8b,
"meta-llama/Llama-3.2-1B": _llama3_2_1b,
"meta-llama/Llama-3.2-1B-Instruct": _llama3_2_1b,
"meta-llama/Llama-3.3-70B": _llama3_3_70b,
"meta-llama/Llama-3.3-70B-Instruct": _llama3_3_70b,
"google/gemma-2b": _gemma_2b,
"google/gemma-2b-it": _gemma_2b,
"google/gemma-7b": _gemma_7b,
Expand Down Expand Up @@ -215,7 +224,7 @@ def _hf_download(
local_dir_use_symlinks=False,
token=hf_token,
allow_patterns=[
"model-?????-of-?????.safetensors",
"model*.safetensors",
"*.json",
"*.model",
],
Expand Down
65 changes: 65 additions & 0 deletions jetstream_pt/third_party/llama/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
from typing import Optional


@dataclasses.dataclass
class RopeScalingArgs:
"""Rope scaling configuration parameters."""

factor: float = 8.0
low_freq_factor: float = 1.0
high_freq_factor: float = 4.0
original_max_position_embeddings: int = 8192


@dataclasses.dataclass
class ModelArgs:
"""Model configuration parameters."""
Expand All @@ -29,6 +39,7 @@ class ModelArgs:
device = "cpu"

rope_theta: float = 10000.0
rope_scaling_args: RopeScalingArgs = None


def get_arg(
Expand Down Expand Up @@ -103,6 +114,60 @@ def get_arg(
"vocab_size": 128256,
"rope_theta": 500000.0,
}
elif model_name == "llama-3.1-8b":
data = {
"dim": 4096,
"vocab_size": 128256,
"multiple_of": 1024,
"ffn_dim_multiplier": 1.3,
"n_layers": 32,
"n_heads": 32,
"n_kv_heads": 8,
"norm_eps": 1e-05,
"rope_theta": 500000.0,
"rope_scaling_args": RopeScalingArgs(
factor=8.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
original_max_position_embeddings=8192,
),
}
elif model_name == "llama-3.2-1b":
data = {
"dim": 2048,
"vocab_size": 128256,
"multiple_of": 1024,
"ffn_dim_multiplier": 1.5,
"n_layers": 16,
"n_heads": 32,
"n_kv_heads": 8,
"norm_eps": 1e-05,
"rope_theta": 500000.0,
"rope_scaling_args": RopeScalingArgs(
factor=32.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
original_max_position_embeddings=8192,
),
}
elif model_name == "llama-3.3-70b":
data = {
"dim": 8192,
"vocab_size": 128256,
"multiple_of": 1024,
"ffn_dim_multiplier": 1.3,
"n_layers": 80,
"n_heads": 64,
"n_kv_heads": 8,
"norm_eps": 1e-05,
"rope_theta": 500000.0,
"rope_scaling_args": RopeScalingArgs(
factor=8.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
original_max_position_embeddings=8192,
),
}

return ModelArgs(
max_seq_len=seqlen,
Expand Down
46 changes: 42 additions & 4 deletions jetstream_pt/third_party/llama/model_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, List, Optional
import copy
import jax
import math
import torch
import torch.nn.functional as F
import functools
Expand Down Expand Up @@ -170,12 +171,42 @@ def forward(
return out


def apply_scaling(freqs: torch.Tensor, config: model_args.RopeScalingArgs):
# Values obtained from grid search
scale_factor = config.factor
low_freq_factor = config.low_freq_factor
high_freq_factor = config.high_freq_factor
old_context_len = config.original_max_position_embeddings

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
dim: int, end: int, theta: float = 10000.0
) -> torch.Tensor:
dim: int,
end: int,
theta: float = 10000.0,
rope_scaling_config: model_args.RopeScalingArgs = None,
):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if rope_scaling_config is not None:
freqs = apply_scaling(freqs, rope_scaling_config)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

Expand Down Expand Up @@ -223,6 +254,7 @@ def __init__(
self.params.dim // self.params.n_heads,
self.params.max_seq_len * 2,
theta=self.params.rope_theta,
rope_scaling_config=self.params.rope_scaling_args,
)

self.register_buffer("freqs_cis", freqs_cis)
Expand Down Expand Up @@ -306,6 +338,12 @@ def from_hf_model_id(cls, model_id, env, is_tiny=False):
"meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b",
"meta-llama/Meta-Llama-3-70B": "llama-3-70b",
"meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b",
"meta-llama/Llama-3.1-8B": "llama-3.1-8b",
"meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-8b",
"meta-llama/Llama-3.2-1B": "llama-3.2-1b",
"meta-llama/Llama-3.2-1B-Instruct": "llama-3.2-1b",
"meta-llama/Llama-3.3-70B": "llama-3.3-70b",
"meta-llama/Llama-3.3-70B-Instruct": "llama-3.3-70b",
}.get(model_id)
assert name
args = model_args.get_model_args(
Expand Down
Loading