
Commit 3ea7044

Authored by danielhanchen, timothelaborie, eltociear, Erland366, and Datta0
Bug fix (#1249)
* Fix TRL
* Update mistral.py
* Patch processing_class
* Update tokenizer_utils.py (×6)
* Installation guide (#1165)
* chore: update chat_templates.py (#1166): orginal -> original
* Disable Flex Attention
* Update tokenizer_utils.py
* Update _utils.py
* n_items
* Update cross_entropy_loss.py
* Fix DPO, ORPO
* Update _utils.py (×2)
* fix/transformers-unpack (#1180): squashes "Fix DPO, ORPO" (#1177), whose commit list repeats the entries above, and adds a warning for missing Unpack and KwargsForCausalLM in older Transformers versions (Co-authored-by: Daniel Han, timothelaborie, Ikko Eltociear Ashimine)
* Update cross_entropy_loss.py
* Update _utils.py (×2)
* donot upcast lm_head and embeddings to float32 (#1186)
* Cleanup upcast logs (#1188)
* Fix/phi-longrope (#1193): enhance rotary embedding handling in LlamaAttention and LongRopeRotaryEmbedding; fix a typo; improve rotary embedding handling in LlamaAttention to prevent errors with short KV cache; Update llama.py (×2) (Co-authored-by: Daniel Han)
* Update transformers
* Unk token issues
* Update _utils.py
* Fix pad token
* Update llama.py
* Typo
* ignored labels
* Revert "ignored labels" (reverts commit 9d07be0)
* More patching
* Update _utils.py (×2)
* Update cross_entropy_loss.py (×3)
* Feat/all tmp (#1219): Update save.py to check whether the path is in the /tmp dir for the Kaggle environment; move temporary_location to /tmp in Kaggle; enhance Kaggle environment support in save and tokenizer utilities (Co-authored-by: dendarrion, Erland366)
* Bug fixes
* Update pyproject.toml
* Update _utils.py
* Update __init__.py (×2)
* Update _utils.py (×4)
* Update cross_entropy_loss.py (×14)
* Tied weights
* Revert "Tied weights" (reverts commit 8090b7c)
* Tied weights
* Utils
* CE Loss patching
* Update __init__.py (×2)
* Patching
* Update cross_entropy_loss.py
* CE Loss
* Update _utils.py (×2)
* CE Loss
* Update _utils.py (×2)
* Layernorm
* Update _utils.py (×2)
* Post patch
* Update _utils.py
* Update llama.py
* Update _utils.py
* Update cross_entropy_loss.py (×17)
* typing
* Update cross_entropy_loss.py (×9)
* int64
* Update _utils.py
* Update cross_entropy_loss.py
* constexpr (×2)
* Update cross_entropy_loss.py (×2)
* Update _utils.py (×3)
* CE
* Update cross_entropy_loss.py
* Update _utils.py
* Update llama.py
* Update _utils.py
* Update rms_layernorm.py (×6)
* Update utils.py
* Update rms_layernorm.py (×12)
* typing
* Update rope_embedding.py
* types
* Disable compiling
* Update _utils.py (×2)
* Forward hook
* Update _utils.py
* Update llama.py
* Update _utils.py
* Update llama.py (×2)
* Update _utils.py
* Update pyproject.toml
* Update _utils.py
* Update llama.py
* CE Loss
* Update cross_entropy_loss.py
* Update _utils.py
* Update cross_entropy_loss.py (×3)
* Update llama.py

Co-authored-by: timothelaborie <97834767+timothelaborie@users.noreply.github.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: Edd <68678137+Erland366@users.noreply.github.com>
Co-authored-by: Datta Nimmaturi <datta.nimmaturi@nutanix.com>
Co-authored-by: dendarrion <37800703+dendarrion@users.noreply.github.com>
Co-authored-by: Erland366 <erland.pg366@gmail.com>
1 parent e2e406e · commit 3ea7044

File tree

1 file changed (+11 lines, -10 lines)


unsloth/kernels/cross_entropy_loss.py

Lines changed: 11 additions & 10 deletions
@@ -73,14 +73,13 @@ def _cross_entropy_forward(
     mask = col_offsets < VOCAB_SIZE

     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

     # Go logit scaling for Cohere: t * x
     if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
     # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
-    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype)
-
-    logits = logits.to(tl.float32)
+    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
+
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
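After this change the forward kernel upcasts the loaded logits to float32 once, optionally applies logit scaling or softcapping, and then computes a numerically stable logsumexp. A rough NumPy reference of the per-row computation (plain Python for illustration; the function name and arguments are assumptions, not the kernel's signature):

import numpy as np

def cross_entropy_forward_reference(logits_row, label_idx, logit_scale=None, softcap=None):
    # Mirror the kernel: upcast to float32 immediately after loading.
    x = logits_row.astype(np.float32)

    # Optional logit scaling (Cohere): t * x
    if logit_scale is not None:
        x = logit_scale * x
    # Optional logit softcapping (Gemma 2): t * tanh(x / t)
    if softcap is not None:
        x = softcap * np.tanh(x / softcap)

    # Numerically stable logsumexp: subtract the row max before exponentiating.
    c = x.max()
    logsumexp = c + np.log(np.exp(x - c).sum())

    # Cross-entropy loss for this row: logsumexp minus the true label's logit.
    loss = logsumexp - x[label_idx]
    return loss, logsumexp

# Example: a float16 row is upcast to float32 before any of the math runs.
row = np.array([2.0, -1.0, 0.5, 3.0], dtype=np.float16)
loss, lse = cross_entropy_forward_reference(row, label_idx=3, softcap=30.0)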

@@ -152,14 +151,13 @@ def _chunked_cross_entropy_forward(
     mask = col_offsets < VOCAB_SIZE

     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

     # Go logit scaling for Cohere: t * x
     if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
     # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
-    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype)
+    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

-    logits = logits.to(tl.float32)
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
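The chunked variant applies the same per-chunk math; the per-chunk logsumexp values are later reduced into a single logsumexp per row. A small sketch of that reduction step, with an illustrative helper that is not part of the library:

import numpy as np

def logsumexp(v):
    # Same max-shift trick as the kernel, in plain NumPy.
    c = v.max()
    return c + np.log(np.exp(v - c).sum())

def combine_chunk_logsumexps(chunk_lses):
    # log(sum_j exp(lse_j)) == log(sum over the whole row of exp(x_i)),
    # so reducing the chunk logsumexps recovers the full-row logsumexp.
    return logsumexp(np.asarray(chunk_lses, dtype=np.float32))

# Example: two vocabulary chunks of one row give the same result as the full row.
x = np.random.randn(8).astype(np.float32)
assert np.allclose(logsumexp(x), combine_chunk_logsumexps([logsumexp(x[:4]), logsumexp(x[4:])]))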

@@ -229,7 +227,7 @@ def _cross_entropy_backward(
     else:
         dloss = 0.0

-    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

     # Do logit scaling for Cohere
     if DO_LOGIT_SCALING:
@@ -241,12 +239,12 @@ def _cross_entropy_backward(
     partial = x
     if DO_SOFTCAPPING:
         # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
-        partial = triton_tanh(x.to(tl.float32) / SOFTCAP).to(x.dtype)
+        partial = triton_tanh(x / SOFTCAP)
         x = SOFTCAP * partial
     pass

     logsumexp = tl.load(logsumexp_ptr + row_idx)
-    y = tl.exp(x.to(tl.float32) - logsumexp)
+    y = tl.exp(x - logsumexp)
     y = tl.where(
         col_offsets == label_idx,
         y - 1.0, # exp(x - logsumexp) - 1
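The backward kernel follows the usual softmax-gradient form, softmax(x) minus a one-hot at the label, with chain-rule factors for softcapping and logit scaling. A hedged NumPy sketch of one row (the name and arguments below are illustrative, not the kernel's interface):

import numpy as np

def cross_entropy_backward_reference(logits_row, label_idx, logsumexp, dloss,
                                     logit_scale=None, softcap=None):
    # Upcast once, as the kernel now does at load time.
    x = logits_row.astype(np.float32)

    if logit_scale is not None:
        x = logit_scale * x
    partial = None
    if softcap is not None:
        # Forward used t * tanh(x / t); keep tanh(x / t) for the chain rule:
        # d/dx [t * tanh(x / t)] = 1 - tanh^2(x / t)
        partial = np.tanh(x / softcap)
        x = softcap * partial

    # d(loss)/d(logits) = softmax(x) - one_hot(label)
    y = np.exp(x - logsumexp)
    y[label_idx] -= 1.0

    # Chain-rule corrections for the transforms applied above.
    if softcap is not None:
        y = y * (1.0 - partial * partial)
    if logit_scale is not None:
        y = y * logit_scale
    return dloss * y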
@@ -337,6 +335,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
         return losses
     pass

+
     @staticmethod
     def backward(ctx, dlosses):
         logits, logsumexp, labels = ctx.saved_tensors
@@ -345,6 +344,8 @@ def backward(ctx, dlosses):
         n_rows, vocab_size = logits.shape

         BLOCK_SIZE : int = 4096
+        div : int
+        mod : int
         div, mod = divmod(vocab_size, BLOCK_SIZE)
         n_blocks : int = div + (mod != 0)
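The two added annotations just pre-declare div and mod as ints; the block count itself is a plain ceiling division over the vocabulary size. A quick worked example (the vocabulary size is picked purely for illustration):

BLOCK_SIZE = 4096
vocab_size = 128256  # illustrative value, roughly a Llama-3-sized vocabulary

div, mod = divmod(vocab_size, BLOCK_SIZE)   # 31 full blocks, 1280 columns left over
n_blocks = div + (mod != 0)                 # the bool promotes to 1, giving ceil division
assert n_blocks == 32                       # ceil(128256 / 4096) == 32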
