Merged
Changes from 1 commit
Commits
116 commits
7b81ca5
Update _utils.py
danielhanchen Aug 22, 2024
94f2d34
Update _utils.py
danielhanchen Aug 22, 2024
7c5222d
Update _utils.py
danielhanchen Aug 22, 2024
15d4417
Update _utils.py
danielhanchen Aug 22, 2024
1ea463c
Update _utils.py
danielhanchen Aug 22, 2024
cf929e2
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
5a7be98
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
2590b4c
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
621e65b
update token retrieval logic (#952)
not-lain Aug 23, 2024
b62e5cd
Update llama.py
danielhanchen Aug 23, 2024
fb9dd65
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Aug 23, 2024
3b49609
get_token
danielhanchen Aug 24, 2024
9c8875e
Update README.md
danielhanchen Aug 24, 2024
c25de14
Merge branch 'main' into nightly
danielhanchen Aug 25, 2024
646a27b
Merge branch 'main' into nightly
danielhanchen Aug 27, 2024
a44357d
Update gemma2.py
danielhanchen Aug 30, 2024
7ed1c16
Update rms_layernorm.py
danielhanchen Aug 30, 2024
d7ef49e
synchronize
danielhanchen Aug 30, 2024
9a69548
Update gemma2.py
danielhanchen Aug 30, 2024
e6dadb4
Update rms_layernorm.py
danielhanchen Aug 30, 2024
f8e77cf
Update rms_layernorm.py
danielhanchen Aug 30, 2024
cfbaa97
Update rms_layernorm.py
danielhanchen Aug 30, 2024
32b2f3f
layernorm
danielhanchen Aug 30, 2024
9e7057d
Update rms_layernorm.py
danielhanchen Aug 30, 2024
a193508
Update gemma2.py
danielhanchen Aug 30, 2024
65eaa2d
Update rms_layernorm.py
danielhanchen Aug 30, 2024
1beeb22
Update rms_layernorm.py
danielhanchen Aug 30, 2024
1eb7705
revert
danielhanchen Aug 30, 2024
c3fe972
Gemma
danielhanchen Aug 31, 2024
75dbfba
Update rms_layernorm.py
danielhanchen Aug 31, 2024
332b091
Update rms_layernorm.py
danielhanchen Aug 31, 2024
4ecc119
Update rms_layernorm.py
danielhanchen Aug 31, 2024
07a1246
Update rms_layernorm.py
danielhanchen Aug 31, 2024
e3239e4
Update rms_layernorm.py
danielhanchen Aug 31, 2024
6ae1ac2
Update rms_layernorm.py
danielhanchen Aug 31, 2024
4d89f27
Update rms_layernorm.py
danielhanchen Aug 31, 2024
c76be22
Update rms_layernorm.py
danielhanchen Aug 31, 2024
ace509c
Update rms_layernorm.py
danielhanchen Aug 31, 2024
e474cfe
Update rms_layernorm.py
danielhanchen Aug 31, 2024
1576a1e
Update rms_layernorm.py
danielhanchen Aug 31, 2024
a2c4691
Update rms_layernorm.py
danielhanchen Aug 31, 2024
1a02e75
Update rms_layernorm.py
danielhanchen Aug 31, 2024
a26e1d1
Update rms_layernorm.py
danielhanchen Aug 31, 2024
afdb443
Update rms_layernorm.py
danielhanchen Sep 1, 2024
c3e14d8
Update rms_layernorm.py
danielhanchen Sep 1, 2024
1830bdd
Update rms_layernorm.py
danielhanchen Sep 1, 2024
6abf66a
Update rms_layernorm.py
danielhanchen Sep 1, 2024
f5cf796
Update rms_layernorm.py
danielhanchen Sep 1, 2024
b191530
Update rms_layernorm.py
danielhanchen Sep 1, 2024
512c61f
Update rms_layernorm.py
danielhanchen Sep 1, 2024
f5d50ef
Update rms_layernorm.py
danielhanchen Sep 1, 2024
d791bb9
Update rms_layernorm.py
danielhanchen Sep 1, 2024
9225608
Update gemma2.py
danielhanchen Sep 1, 2024
f61869c
Change UnslothTrainingArguments base class to SFTConfig (#979)
vTuanpham Sep 2, 2024
73d49ad
Cohere
danielhanchen Sep 2, 2024
86b6236
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Sep 2, 2024
edef5ca
Update trainer.py
danielhanchen Sep 2, 2024
6d4300c
Cohere
danielhanchen Sep 2, 2024
754e670
Cohere
danielhanchen Sep 2, 2024
d242866
New models
danielhanchen Sep 3, 2024
0b7e973
Update llama.py
danielhanchen Sep 3, 2024
19549f2
Update llama.py
danielhanchen Sep 3, 2024
8823e13
Update cohere.py
danielhanchen Sep 3, 2024
90050b7
Update llama.py
danielhanchen Sep 3, 2024
4c1ec3a
Update cohere.py
danielhanchen Sep 3, 2024
97b3956
retry
danielhanchen Sep 3, 2024
fd615ea
Update fast_lora.py
danielhanchen Sep 3, 2024
fe45990
Update llama.py
danielhanchen Sep 3, 2024
f564b8a
Update fast_lora.py
danielhanchen Sep 3, 2024
b26da84
Update llama.py
danielhanchen Sep 3, 2024
61be6a3
Update llama.py
danielhanchen Sep 3, 2024
ea48761
Update cross_entropy_loss.py
danielhanchen Sep 3, 2024
6e795c6
_apply_lora_mlp
danielhanchen Sep 3, 2024
dacba39
Update _utils.py
danielhanchen Sep 3, 2024
5074427
Gemma fixes
danielhanchen Sep 3, 2024
743ba55
Update llama.py
danielhanchen Sep 3, 2024
315136a
Merge branch 'main' into nightly
danielhanchen Sep 3, 2024
7ea6395
Update flex_attention.py
danielhanchen Sep 3, 2024
91d6773
Merge branch 'main' into nightly
danielhanchen Sep 4, 2024
df06a04
Update llama.py
danielhanchen Sep 4, 2024
7f139f1
layernorm
danielhanchen Sep 4, 2024
068fc0d
Update llama.py
danielhanchen Sep 4, 2024
4eaccb0
Update llama.py
danielhanchen Sep 4, 2024
4f909fc
Flex Attention
danielhanchen Sep 5, 2024
efef0ee
Update gemma2.py
danielhanchen Sep 5, 2024
6e8951f
Update __init__.py
danielhanchen Sep 5, 2024
d60a18c
Update flex_attention.py
danielhanchen Sep 5, 2024
1b4132e
Update flex_attention.py
danielhanchen Sep 5, 2024
f5d11dc
Update flex_attention.py
danielhanchen Sep 5, 2024
2454659
Update flex_attention.py
danielhanchen Sep 5, 2024
984d217
Update flex_attention.py
danielhanchen Sep 5, 2024
e3846f5
Update flex_attention.py
danielhanchen Sep 5, 2024
2d29299
Update flex_attention.py
danielhanchen Sep 5, 2024
03310b9
Update flex_attention.py
danielhanchen Sep 5, 2024
eb37676
Update flex_attention.py
danielhanchen Sep 5, 2024
cb6a835
Update flex_attention.py
danielhanchen Sep 5, 2024
cbd6a6a
Update flex_attention.py
danielhanchen Sep 5, 2024
712deaa
Update flex_attention.py
danielhanchen Sep 5, 2024
6e74563
Update flex_attention.py
danielhanchen Sep 5, 2024
0703ce8
Update flex_attention.py
danielhanchen Sep 5, 2024
e2cafc4
Update flex_attention.py
danielhanchen Sep 5, 2024
25fb059
Update flex_attention.py
danielhanchen Sep 5, 2024
6ddcd60
Update flex_attention.py
danielhanchen Sep 6, 2024
a806b20
Update chat_templates.py (#999)
AgainstEntropy Sep 7, 2024
a690e5e
Update key from "from" to "user" (#1000)
wa008 Sep 7, 2024
6693712
Update chat_templates.py
danielhanchen Sep 7, 2024
fabda63
Also patch the KTO trainer (#1001)
corbt Sep 7, 2024
f9b8a73
flex attention
danielhanchen Sep 7, 2024
2fa9979
Update llama.py
danielhanchen Sep 7, 2024
86017d3
Update flex_attention.py
danielhanchen Sep 7, 2024
130c739
Update flex_attention.py
danielhanchen Sep 7, 2024
528c673
Update _utils.py
danielhanchen Sep 8, 2024
7380ac5
Update _utils.py
danielhanchen Sep 8, 2024
4e1a50c
Update flex_attention.py
danielhanchen Sep 8, 2024
6e9d3de
Update gemma2.py
danielhanchen Sep 8, 2024
879fc88
Update gemma2.py
danielhanchen Sep 8, 2024
Cohere
danielhanchen committed Sep 2, 2024
commit 754e670daf6b53bf8fe92c5f07bae25a96aa67f1
123 changes: 84 additions & 39 deletions unsloth/kernels/cross_entropy_loss.py
@@ -19,17 +19,22 @@
from transformers.models.llama.modeling_llama import logger


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
})
@triton.jit
def _cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
DO_LOGIT_SCALING: tl.constexpr,
LOGIT_SCALE : tl.constexpr,
):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@@ -62,17 +67,22 @@ def _cross_entropy_forward(

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))

# Go logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

if label_idx != -100:
x = tl.load(logits_ptr + label_idx)
# Go logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = logsumexp - x.to(tl.float32)
else:
loss = 0.0
@@ -81,18 +91,23 @@
pass


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
})
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
DO_LOGIT_SCALING: tl.constexpr,
LOGIT_SCALE : tl.constexpr,
):
"""
256K vocab divided in 4 chunks
@@ -130,8 +145,11 @@ def _chunked_cross_entropy_forward(

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))

# Go logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
@@ -142,8 +160,10 @@
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
# Go logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = -1.0 * x.to(tl.float32)
else:
loss = 0.0
@@ -153,17 +173,22 @@
pass


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
})
@triton.jit
def _cross_entropy_backward(
logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
DO_LOGIT_SCALING: tl.constexpr,
LOGIT_SCALE : tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
@@ -210,6 +235,11 @@ def _cross_entropy_backward(
y, # exp(x - logsumexp)
)

if DO_LOGIT_SCALING:
# d/dx [s * x] = s
y = y * LOGIT_SCALE
pass

if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
y = y * (1.0 - partial*partial)
@@ -224,14 +254,15 @@

class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, logit_softcapping = 0):
def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0):
n_rows, vocab_size = logits.shape

div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

DO_SOFTCAPPING = (logit_softcapping != 0)
DO_SOFTCAPPING = (logit_softcapping != 0)
DO_LOGIT_SCALING = (logit_scaling != 0)

if n_chunks == 1:
# For small vocabs <= 65336 like Llama, Mistral
@@ -243,11 +274,13 @@ def forward(ctx, logits, labels, logit_softcapping = 0):
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = num_warps,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
LOGIT_SCALE = logit_scaling,
num_warps = num_warps,
)
else:
# For large vocabs > 65336 like Gemma 256K
@@ -258,12 +291,14 @@ def forward(ctx, logits, labels, logit_softcapping = 0):
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = 32,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
LOGIT_SCALE = logit_scaling,
num_warps = 32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
@@ -275,6 +310,8 @@ def forward(ctx, logits, labels, logit_softcapping = 0):
ctx.save_for_backward(logits, logsumexp, labels)
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
ctx.logit_softcapping = logit_softcapping
ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
ctx.logit_scaling = logit_scaling
return losses
pass

@@ -292,19 +329,26 @@ def backward(ctx, dlosses):
dlosses, dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
LOGIT_SCALE = ctx.logit_scaling,
num_warps = 8,
)
return logits, None, None,
return logits, None, None, None,
pass
pass


@torch._disable_dynamo
def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
def fast_cross_entropy_loss(
logits,
labels,
logit_softcapping = 0,
logit_scaling = 0,
):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
@@ -319,6 +363,7 @@ def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
logits.view(batch*seq_len, d),
labels.view(-1),
logit_softcapping,
logit_scaling,
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
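For reference, here is a minimal PyTorch sketch of what the fused forward kernel now computes: Cohere-style logit scaling (t * x) is applied first, then Gemma 2-style softcapping (t * tanh(x / t)), then ordinary per-row cross entropy with label -100 ignored. The helper name and the commented sanity check are illustrative assumptions, not code from this PR.

import torch
import torch.nn.functional as F

def reference_cross_entropy(logits, labels, logit_softcapping = 0.0, logit_scaling = 0.0):
    # Mirror the kernel's order of operations on the logits.
    logits = logits.float()
    if logit_scaling != 0:
        logits = logit_scaling * logits
    if logit_softcapping != 0:
        logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
    # Per-row losses; rows with label -100 come back as 0, matching the kernel.
    return F.cross_entropy(logits, labels, ignore_index = -100, reduction = "none")

# Hypothetical check against the fused path (argument order assumed from the diff above):
# losses = Fast_CrossEntropyLoss.apply(logits.cuda(), labels.cuda(), 0.0, 0.0625)
# torch.testing.assert_close(losses.cpu(), reference_cross_entropy(logits, labels, logit_scaling = 0.0625))
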
49 changes: 39 additions & 10 deletions unsloth/models/llama.py
@@ -305,6 +305,20 @@ def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
pass


# Normal layernorm with mean removal
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def fast_layernorm_compiled(layernorm, X):
old_dtype = X.dtype
X = X.float()
mean = X.mean(-1, keepdim = True)
Xbar = X - mean
X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \
layernorm.variance_epsilon) * \
layernorm.weight.float()
return X.to(old_dtype)
pass


# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
def LlamaAttention_fast_forward(
self,
@@ -597,6 +611,7 @@ def LlamaModel_fast_forward(
# Normalized from Gemma
IS_GEMMA = self.config.model_type.startswith("gemma")
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
IS_COHERE = self.config.model_type.startswith("cohere")
train_embed_tokens = self.embed_tokens.weight.requires_grad

if IS_GEMMA:
@@ -802,8 +817,11 @@ def custom_forward(*inputs):

# Final layernorm
if use_cache:
hidden_states = (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
hidden_states = \
(fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
(self.norm, hidden_states)
elif IS_COHERE:
hidden_states = fast_layernorm_compiled(self.norm, hidden_states)
else:
hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
pass
@@ -943,6 +961,7 @@ def _CausalLM_fast_forward(

loss = None
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
logit_scaling = getattr(self.config, "logit_scale", 0)
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
@@ -955,16 +974,26 @@
logits = shift_logits,
labels = shift_labels,
logit_softcapping = logit_softcapping,
logit_scaling = logit_scaling,
)
elif logit_softcapping != 0:
if logits.requires_grad:
logits = (1.0 / logit_softcapping) * logits
logits = torch.tanh(logits)
logits = logit_softcapping * logits
else:
logits *= (1.0 / logit_softcapping)
torch.tanh(logits, out = logits)
logits *= logit_softcapping
else:
if logit_scaling != 0:
if logits.requires_grad:
logits = logit_scaling * logits
else:
logits *= logit_scaling
pass
pass
if logit_softcapping != 0:
if logits.requires_grad:
logits = (1.0 / logit_softcapping) * logits
logits = torch.tanh(logits)
logits = logit_softcapping * logits
else:
logits *= (1.0 / logit_softcapping)
torch.tanh(logits, out = logits)
logits *= logit_softcapping
pass
pass
pass

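As a cross-check for the new fast_layernorm_compiled path used for Cohere's final norm, the same mean-removal layernorm can be written in plain PyTorch. This sketch assumes a bias-free layernorm module exposing .weight and .variance_epsilon, as in the compiled helper above; the equivalence against F.layer_norm is an illustrative assumption, not code from this PR.

import torch
import torch.nn.functional as F

def reference_layernorm(layernorm, X):
    # Subtract the mean, normalize by rsqrt of the (biased) variance of the
    # mean-removed values, then scale by the learned weight: the same math as
    # fast_layernorm_compiled, just without torch.compile.
    old_dtype = X.dtype
    X = X.float()
    Xbar = X - X.mean(-1, keepdim = True)
    X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + layernorm.variance_epsilon)
    return (X * layernorm.weight.float()).to(old_dtype)

# Should agree with a bias-free F.layer_norm (assuming the module really has no bias term):
# out = F.layer_norm(X.float(), (X.shape[-1],), weight = layernorm.weight.float(),
#                    eps = layernorm.variance_epsilon).to(X.dtype)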