Commit 899caf0

Qwen 2.5 (#1280)

Authored by danielhanchen, timothelaborie, eltociear, Erland366, and Datta0

* Fix TRL * Update mistral.py * Patch processing_class * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Installation guide (#1165) * chore: update chat_templates.py (#1166) orginal -> original * Disable Flex Attention * Update tokenizer_utils.py * Update _utils.py * n_items * Update cross_entropy_loss.py * Fix DPO, ORPO * Update _utils.py * Update _utils.py * fix/transformers-unpack (#1180) * Fix DPO, ORPO (#1177) * Fix TRL * Update mistral.py * Patch processing_class * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Installation guide (#1165) * chore: update chat_templates.py (#1166) orginal -> original * Disable Flex Attention * Update tokenizer_utils.py * Update _utils.py * n_items * Update cross_entropy_loss.py * Fix DPO, ORPO * Update _utils.py --------- Co-authored-by: timothelaborie <97834767+timothelaborie@users.noreply.github.com> Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com> * Add warning for missing Unpack and KwargsForCausalLM in older Transformers versions --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: timothelaborie <97834767+timothelaborie@users.noreply.github.com> Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com> * Update cross_entropy_loss.py * Update _utils.py * Update _utils.py * donot upcast lm_head and embeddings to float32 (#1186) * Cleanup upcast logs (#1188) * Fix/phi-longrope (#1193) * Enhance rotary embedding handling in LlamaAttention and LongRopeRotaryEmbedding * Typo * Improve rotary embedding handling in LlamaAttention to prevent errors with short KV cache * Update llama.py * Update llama.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update transformers * Unk token issues * Update _utils.py * Fix pad token * Update llama.py * Typo * ignored labels * Revert "ignored labels" This reverts commit 9d07be0. * More patching * Update _utils.py * Update _utils.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Feat/all tmp (#1219) * Update save.py Check whether path is in /tmp dir for Kaggle environment * Update save.py Move temporary_location to /tmp in Kaggle * Enhance Kaggle environment support in save and tokenizer utilities --------- Co-authored-by: dendarrion <37800703+dendarrion@users.noreply.github.com> Co-authored-by: Erland366 <erland.pg366@gmail.com> * Bug fixes * Update pyproject.toml * Update _utils.py * Update __init__.py * Update __init__.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Tied weights * Revert "Tied weights" This reverts commit 8090b7c. 
* Tied weights * Utils * CE Loss patching * Update __init__.py * Update __init__.py * Patching * Update cross_entropy_loss.py * CE Loss * Update _utils.py * Update _utils.py * CE Loss * Update _utils.py * Update _utils.py * Layernorm * Update _utils.py * Update _utils.py * Post patch * Update _utils.py * Update llama.py * Update _utils.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * typing * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * int64 * Update _utils.py * Update cross_entropy_loss.py * constexpr * constexpr * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update _utils.py * Update _utils.py * Update _utils.py * CE * Update cross_entropy_loss.py * Update _utils.py * Update llama.py * Update _utils.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update utils.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * typing * Update rope_embedding.py * types * Disable compiling * Update _utils.py * Update _utils.py * Forward hook * Update _utils.py * Update llama.py * Update _utils.py * Update llama.py * Update llama.py * Update _utils.py * Update pyproject.toml * Update _utils.py * Update llama.py * CE Loss * Update cross_entropy_loss.py * Update _utils.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update llama.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Fix: cast logits to float32 in cross_entropy_forward to prevent errors (#1254) * Fix: cast logits to float32 in cross_entropy_forward to prevent errors * Update cross_entropy_loss.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Throw error when inferencing longer than max_popsition_embeddings (#1236) * Throw error when inferencing longer than max_popsition_embeddings without rope scaling * Update llama.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * CLI now handles user input strings for dtype correctly (#1235) Co-authored-by: root <root@ieeres.chu.cam.ac.uk> * Update flex_attention.py * Update _utils.py * Update _utils.py * Update flex_attention.py * Update flex_attention.py * Update loader.py * Update loader.py * Update flex_attention.py * Update flex_attention.py * Update flex_attention.py * Update flex_attention.py * Update _utils.py * Update cross_entropy_loss.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update 
tokenizer_utils.py * triton_cast * Update utils.py * Qwen 2.5 Coder --------- Co-authored-by: timothelaborie <97834767+timothelaborie@users.noreply.github.com> Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com> Co-authored-by: Edd <68678137+Erland366@users.noreply.github.com> Co-authored-by: Datta Nimmaturi <datta.nimmaturi@nutanix.com> Co-authored-by: dendarrion <37800703+dendarrion@users.noreply.github.com> Co-authored-by: Erland366 <erland.pg366@gmail.com> Co-authored-by: Edwin Fennell <edwinfennell1@gmail.com> Co-authored-by: root <root@ieeres.chu.cam.ac.uk>
1 parent 0c8c5ed commit 899caf0

File tree: 6 files changed, +77 −21 lines

unsloth/kernels/cross_entropy_loss.py
unsloth/kernels/utils.py
unsloth/models/_utils.py
unsloth/models/llama.py
unsloth/models/mapper.py
unsloth/tokenizer_utils.py

unsloth/kernels/cross_entropy_loss.py

Lines changed: 6 additions & 6 deletions
@@ -15,7 +15,7 @@
 import triton
 import triton.language as tl
 import torch
-from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh
+from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast
 from transformers.models.llama.modeling_llama import logger
 from packaging.version import Version

@@ -64,7 +64,7 @@ def _cross_entropy_forward(
         This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
     """
     row_idx = tl.program_id(0)
-    logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64)
+    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
     loss_ptr += row_idx
     logsumexp_ptr += row_idx
     labels_ptr += row_idx

@@ -142,7 +142,7 @@ def _chunked_cross_entropy_forward(
     """
     row_idx = tl.program_id(0)
     chunk_idx = tl.program_id(1)
-    logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64)
+    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
     loss_ptr += row_idx
     logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
     labels_ptr += row_idx

@@ -216,7 +216,7 @@ def _cross_entropy_backward(
     row_idx = tl.program_id(0)
     block_idx = tl.program_id(1)

-    logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64)
+    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
     dloss_ptr += row_idx * dloss_row_stride
     col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < VOCAB_SIZE
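
Why the 64-bit cast of the row stride matters, as a back-of-the-envelope check (all sizes below are illustrative, not taken from the kernel):

vocab_size = 150_000                     # illustrative vocabulary size
n_rows     = 16 * 8192                   # illustrative batch_size * sequence_length
max_offset = (n_rows - 1) * vocab_size   # largest element offset into the logits buffer
print(max_offset)                        # 19_660_650_000
print(max_offset > 2**31 - 1)            # True: a 32-bit offset would overflow
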
@@ -400,6 +400,6 @@ def fast_cross_entropy_loss(
     pass

 # Patch CE Losses in transformers
-def patch_loss_functions():
-    _patch_loss_functions(fast_cross_entropy_loss)
+def patch_loss_functions(torch_compile = True):
+    _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
 pass
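
The new torch_compile flag is only threaded through to _patch_loss_functions, which is not part of this diff; the stand-in below (function names are made up) only illustrates the idea of optionally compiling the replacement loss before patching it in:

import torch

def install_loss(loss_fn, torch_compile = True):
    # Optionally wrap the replacement loss in torch.compile before it is installed.
    return torch.compile(loss_fn, dynamic = True) if torch_compile else loss_fn

def toy_cross_entropy(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels)

loss_fn = install_loss(toy_cross_entropy, torch_compile = False)  # skip compilation
print(loss_fn(torch.randn(4, 10), torch.randint(0, 10, (4,))))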

unsloth/kernels/utils.py

Lines changed: 7 additions & 1 deletion
@@ -31,12 +31,18 @@
 # tl.math.tanh now is libdevice.tanh
 from packaging.version import Version
 import triton
+import triton.language as tl
 if Version(triton.__version__) >= Version("3.0.0"):
     from triton.language.extra import libdevice
     triton_tanh = libdevice.tanh
+    triton_cast = tl.cast
 else:
-    import triton.language as tl
     triton_tanh = tl.math.tanh
+    # No casting in old Triton versions
+    @triton.jit
+    def triton_cast(x, dtype):
+        return x.to(dtype)
+    pass
 pass

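A self-contained sketch of how the shim is used inside a kernel (the kernel _scale_rows and its sizes are illustrative; a CUDA GPU and an installed Triton are assumed). The kernel body stays identical whether triton_cast resolves to tl.cast on Triton >= 3.0 or to the jitted fallback on older versions:

import torch, triton
import triton.language as tl
from packaging.version import Version

if Version(triton.__version__) >= Version("3.0.0"):
    triton_cast = tl.cast
else:
    @triton.jit
    def triton_cast(x, dtype):
        return x.to(dtype)

@triton.jit
def _scale_rows(x_ptr, out_ptr, row_stride, N, scale, BLOCK : tl.constexpr):
    row = tl.program_id(0)
    # 64-bit row offsets, regardless of Triton version
    x_ptr   += row * triton_cast(row_stride, tl.int64)
    out_ptr += row * triton_cast(row_stride, tl.int64)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + cols, mask = mask, other = 0.0)
    tl.store(out_ptr + cols, x * scale, mask = mask)

x   = torch.randn(4, 1024, device = "cuda")
out = torch.empty_like(x)
_scale_rows[(4,)](x, out, x.stride(0), 1024, 2.0, BLOCK = 1024)
assert torch.allclose(out, x * 2)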

unsloth/models/_utils.py

Lines changed: 9 additions & 1 deletion
@@ -104,14 +104,22 @@
 # Ignore logging messages
 class HideLoggingMessage(logging.Filter):
     def __init__(self, text): self.text = text
-    def filter(self, x): return not x.getMessage().startswith(self.text)
+    def filter(self, x): return not (self.text in x.getMessage())
 pass

 # The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here.
 from transformers.training_args import logger as transformers_training_args_logger
 transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups"))
 del transformers_training_args_logger

+# Using the default loss: `ForCausalLMLoss`.
+try:
+    from transformers.modeling_utils import logger as transformers_modeling_utils_logger
+    transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss"))
+    del transformers_modeling_utils_logger
+except:
+    pass
+
 # =============================================

 # =============================================
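
A standalone sketch of the substring-based filter above, using only the standard library (the logger name and messages here are made up):

import logging

class HideLoggingMessage(logging.Filter):
    def __init__(self, text): self.text = text
    def filter(self, x): return not (self.text in x.getMessage())

logging.basicConfig(level = logging.INFO)
logger = logging.getLogger("demo")
logger.addFilter(HideLoggingMessage("ForCausalLMLoss"))

logger.info("Using the default loss: `ForCausalLMLoss`.")  # suppressed by the filter
logger.info("Other messages still get through.")           # printed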

unsloth/models/llama.py

Lines changed: 2 additions & 1 deletion
@@ -2317,7 +2317,8 @@ def patch_peft_model(
             layer.self_attn.apply_qkv = apply_lora_qkv
             n_qkv += 1
         else:
-            if model_type != "qwen2":
+            if model_type == "qwen2": n_qkv += 1
+            else:
                 logger.warning_once(
                     "Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\
                     "are not enabled or a bias term (like in Qwen) is used."

unsloth/models/mapper.py

Lines changed: 32 additions & 0 deletions
@@ -384,22 +384,54 @@
         "unsloth/Qwen2.5-Math-72B-Instruct",
         "Qwen/Qwen2.5-Math-72B-Instruct",
     ),
+    "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : (
+        "unsloth/Qwen2.5-Coder-0.5B",
+        "Qwen/Qwen2.5-Coder-0.5B",
+    ),
     "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : (
         "unsloth/Qwen2.5-Coder-1.5B",
         "Qwen/Qwen2.5-Coder-1.5B",
     ),
+    "unsloth/Qwen2.5-Coder-3B-bnb-4bit" : (
+        "unsloth/Qwen2.5-Coder-3B",
+        "Qwen/Qwen2.5-Coder-3B",
+    ),
     "unsloth/Qwen2.5-Coder-7B-bnb-4bit" : (
         "unsloth/Qwen2.5-Coder-7B",
         "Qwen/Qwen2.5-Coder-7B",
     ),
+    "unsloth/Qwen2.5-Coder-14B-bnb-4bit" : (
+        "unsloth/Qwen2.5-Coder-14B",
+        "Qwen/Qwen2.5-Coder-14B",
+    ),
+    "unsloth/Qwen2.5-Coder-32B-bnb-4bit" : (
+        "unsloth/Qwen2.5-Coder-32B",
+        "Qwen/Qwen2.5-Coder-32B",
+    ),
+    "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : (
+        "unsloth/Qwen2.5-Coder-Instruct-0.5B",
+        "Qwen/Qwen2.5-Coder-Instruct-0.5B",
+    ),
     "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : (
         "unsloth/Qwen2.5-Coder-Instruct-1.5B",
         "Qwen/Qwen2.5-Coder-Instruct-1.5B",
     ),
+    "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : (
+        "unsloth/Qwen2.5-Coder-3B-Instruct",
+        "Qwen/Qwen2.5-Coder-3B-Instruct",
+    ),
     "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : (
         "unsloth/Qwen2.5-Coder-7B-Instruct",
         "Qwen/Qwen2.5-Coder-7B-Instruct",
     ),
+    "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : (
+        "unsloth/Qwen2.5-Coder-14B-Instruct",
+        "Qwen/Qwen2.5-Coder-14B-Instruct",
+    ),
+    "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : (
+        "unsloth/Qwen2.5-Coder-32B-Instruct",
+        "Qwen/Qwen2.5-Coder-32B-Instruct",
+    ),
     "unsloth/Llama-3.2-1B-bnb-4bit" : (
         "unsloth/Llama-3.2-1B",
         "meta-llama/Llama-3.2-1B",

unsloth/tokenizer_utils.py

Lines changed: 21 additions & 12 deletions
@@ -588,15 +588,21 @@ def load_correct_tokenizer(
 def _fix_chat_template(chat_template):
     endfor = "{% endfor %}"
     where = chat_template.find(endfor)
-    if where == -1: return chat_template
+    if where == -1:
+        endfor = "{%- endfor %}"
+        where = chat_template.find(endfor)
+        if where == -1:
+            return chat_template

     after_endfor = chat_template[where + len(endfor):]

-    if "{% if" not in after_endfor and "{% set " not in after_endfor and \
+    dash = "-" if endfor.startswith("{%-") else ""
+
+    if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \
         after_endfor.startswith("{{") and after_endfor.endswith("}}") and \
         after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1:

-        after_endfor = "{% if add_generation_prompt %}" + after_endfor + "{% endif %}"
+        after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endfor

         chat_template = chat_template[:where + len(endfor)] + after_endfor
 pass

@@ -643,10 +649,12 @@ def fix_chat_template(tokenizer):

     if no == yes:
         # SAME?! That's not good! We check for add_generation_prompt
-        if "{% if add_generation_prompt %}" not in chat_template:
+        if "{% if add_generation_prompt %}" not in chat_template and \
+            "{%- if add_generation_prompt %}" not in chat_template:
             # Try fixing it by adding it
             new_chat_template = _fix_chat_template(chat_template)
-            if "{% if add_generation_prompt %}" not in new_chat_template:
+            if "{% if add_generation_prompt %}" not in new_chat_template and \
+                "{%- if add_generation_prompt %}" not in new_chat_template:
                 raise RuntimeError(
                     f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
                     "does not have a {% if add_generation_prompt %} for generation purposes.\n"\
@@ -1001,13 +1009,14 @@ def patch_sft_trainer_tokenizer():
         # Also DPO weirdly tokenizes non numeric columns? Delete them!
         check_text += \
         "\n"\
-        "column_names = set(self.train_dataset.column_names)\n"\
-        "check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
-        "         'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
-        "         'prompt_input_ids', 'prompt_attention_mask']\n"\
-        "if all(x in column_names for x in check):\n"\
-        "    self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
-        "del check, column_names\n"\
+        "if hasattr(self.train_dataset, 'column_names'):\n"\
+        "    column_names = set(self.train_dataset.column_names)\n"\
+        "    check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
+        "             'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
+        "             'prompt_input_ids', 'prompt_attention_mask']\n"\
+        "    if all(x in column_names for x in check):\n"\
+        "        self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
+        "    del check, column_names\n"\
         "\n"

         check_text = check_text.split("\n")
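
The hasattr guard matters because not every object passed as train_dataset exposes column_names (plain PyTorch datasets and generators do not, and streaming datasets may report it as None). A small stand-in run (the toy dataset below is illustrative, and the column check is shortened):

from datasets import Dataset

train_dataset = Dataset.from_dict({
    "prompt": ["hi"], "chosen": ["a"], "rejected": ["b"], "text": ["hi a"],
})

if hasattr(train_dataset, "column_names"):
    column_names = set(train_dataset.column_names)
    # The real check also requires the tokenized DPO columns; shortened here.
    if {"chosen", "rejected", "prompt"} <= column_names:
        train_dataset = train_dataset.remove_columns(["chosen", "rejected", "prompt"])

print(train_dataset.column_names)   # ['text']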
