
Commit 0e5a507

danielhanchen, timothelaborie, and eltociear authored
Many bug fixes (#1162)
* Fix TRL
* Update mistral.py
* Patch processing_class
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Installation guide (#1165)
* chore: update chat_templates.py (#1166) orginal -> original
* Disable Flex Attention
* Update tokenizer_utils.py
* Update _utils.py

---------

Co-authored-by: timothelaborie <97834767+timothelaborie@users.noreply.github.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
1 parent 1f52468 commit 0e5a507

File tree: 7 files changed, +65 −11 lines changed


README.md

Lines changed: 12 additions & 0 deletions
@@ -181,6 +181,18 @@ x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
```

+### Windows Installation
+
+To run Unsloth directly on Windows:
+- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows
+- In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue:
+```python
+trainer = SFTTrainer(
+    dataset_num_proc=1,
+    ...
+)
+```
+
For **advanced installation instructions** or if you see weird errors during installations:

1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton`
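
For context, here is how that `dataset_num_proc=1` setting slots into a complete fine-tuning script. This is a hedged sketch rather than the repository's official example: the model name, dataset file, and training arguments are placeholders, and on newer TRL releases several of these keywords move from `SFTTrainer` onto `SFTConfig`.

```python
# Minimal sketch, assuming an Unsloth 4-bit model and a local JSONL dataset
# with a "text" column (all placeholders). The Windows-specific piece is
# dataset_num_proc = 1, which keeps dataset preprocessing single-process.
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/llama-3-8b-bnb-4bit",  # placeholder model
    max_seq_length = 2048,
    load_in_4bit   = True,
)
dataset = load_dataset("json", data_files = "train.jsonl", split = "train")  # placeholder data

trainer = SFTTrainer(
    model              = model,
    tokenizer          = tokenizer,
    train_dataset      = dataset,
    dataset_text_field = "text",
    max_seq_length     = 2048,
    dataset_num_proc   = 1,  # avoids the multiprocessing crash seen on Windows
    args               = TrainingArguments(output_dir = "outputs", max_steps = 60),
)
trainer.train()
```

Everything else is an ordinary SFT setup; only the single-process dataset mapping is Windows-specific.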

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ huggingface = [
    "wheel>=0.42.0",
    "numpy",
    "accelerate>=0.34.1",
-   "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1",
+   "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3",
    "peft>=0.7.1,!=0.11.0",
    "protobuf<4.0.0",
    "huggingface_hub",
@@ -227,7 +227,7 @@ colab-new = [
]
colab-no-deps = [
    "accelerate>=0.34.1",
-   "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1",
+   "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3",
    "peft>=0.7.1",
    "xformers<0.0.27",
    "bitsandbytes>=0.43.3",

unsloth/chat_templates.py

Lines changed: 1 addition & 1 deletion
@@ -678,7 +678,7 @@
{{- end }}
{{- if .Tools }}

-You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the orginal use question.
+You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original use question.
{{- end }}
{{- end }}<|eot_id|>
{{- range $i, $_ := .Messages }}

unsloth/kernels/flex_attention.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
        create_block_mask as _create_block_mask,
    )
    _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
-   HAS_FLEX_ATTENTION = True
+   HAS_FLEX_ATTENTION = False
except:
    HAS_FLEX_ATTENTION = False
pass
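
Setting the flag to `False` right after a successful import means callers that branch on `HAS_FLEX_ATTENTION` simply never take the Flex Attention path. A simplified, standalone sketch of that guard pattern follows (not the actual unsloth module; `torch_compile_options` and the helper function are placeholders).

```python
# Simplified sketch of the import-guard / feature-flag pattern above.
import torch

torch_compile_options = {}  # placeholder; the real module supplies its own options

try:
    from torch.nn.attention.flex_attention import (
        flex_attention as _flex_attention,
        create_block_mask as _create_block_mask,
    )
    _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
    HAS_FLEX_ATTENTION = False  # forced off by this commit even though the import succeeded
except Exception:
    HAS_FLEX_ATTENTION = False

def choose_attention_backend() -> str:
    # Downstream code branches on the flag, not on the import itself, so
    # flipping it to False silently routes everything to the regular path.
    return "flex_attention" if HAS_FLEX_ATTENTION else "default attention"
```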

unsloth/models/_utils.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-__version__ = "2024.10.3"
+__version__ = "2024.10.4"

__all__ = [
    "prepare_model_for_kbit_training",
@@ -1194,8 +1194,8 @@ def patch_gradient_accumulation_fix(Trainer):
        logger.warning_once(
            "Unsloth: We fixed a gradient accumulation bug, "\
            "but it seems like you don't have the latest transformers version!\n"\
-           "Please update transformers via:\n"\
-           '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`'
+           "Please update transformers, TRL and unsloth via:\n"\
+           '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`'
        )
        pass
    pass
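
The updated warning is emitted behind a transformers-version gate; the same check appears in this commit's `tokenizer_utils.py` patch below. Written out as a standalone snippet, illustrative only and not the packaged code, it amounts to:

```python
# Illustrative, standalone version of the check behind the warning above.
from packaging.version import Version
from transformers import __version__ as transformers_version

if Version(transformers_version) <= Version("4.45.2"):
    print(
        "**** Unsloth: Please use our fixed gradient_accumulation_steps by "
        "updating transformers, TRL and Unsloth!\n"
        "`pip install --upgrade --no-cache-dir unsloth "
        "git+https://github.com/huggingface/transformers.git "
        "git+https://github.com/huggingface/trl.git`"
    )
```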

unsloth/models/mistral.py

Lines changed: 3 additions & 2 deletions
@@ -254,8 +254,9 @@ def MistralForCausalLM_fast_forward(

        shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
        loss = fast_cross_entropy_loss(
-           logits = shift_logits,
-           labels = shift_labels,
+           logits = shift_logits,
+           labels = shift_labels,
+           n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
        )
        pass

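The new `n_items` argument is the gradient accumulation fix: the loss is divided by the token count of the whole accumulated batch (which transformers passes as `num_items_in_batch`) instead of each micro-batch's own mean. A plain PyTorch sketch of the idea, not unsloth's fused `fast_cross_entropy_loss`:

```python
import torch
import torch.nn.functional as F

def cross_entropy_with_n_items(logits, labels, n_items = None):
    # Sum the per-token losses, then normalize by n_items (the number of
    # non-ignored tokens across all accumulated micro-batches) when given;
    # otherwise fall back to this micro-batch's own token count.
    loss_sum = F.cross_entropy(
        logits.view(-1, logits.size(-1)).float(),
        labels.view(-1),
        ignore_index = -100,
        reduction    = "sum",
    )
    if n_items is None:
        n_items = (labels != -100).sum()
    return loss_sum / n_items
```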
unsloth/tokenizer_utils.py

Lines changed: 43 additions & 2 deletions
@@ -831,6 +831,7 @@ def check_nvidia():
PRE_CHECK = check_nvidia()


+import inspect
from inspect import getsource
import trl.trainer.sft_trainer
from trl.trainer.sft_trainer import *
@@ -869,6 +870,35 @@ def neftune_post_forward_hook(module, input, output):
pass


+def patch_trl_tokenizer_processing_class(trainer_name):
+    # New TRL removes tokenizer!
+    # We return it back!
+    exec(f"from trl import {trainer_name}", globals())
+    if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None
+    parameters = eval(f"inspect.signature({trainer_name}).parameters")
+    if "tokenizer" in parameters: return None
+
+    args = {
+        key : \
+        value.default \
+            if type(value.default) is not str else \
+        f"'{value.default}'" \
+        for key, value in parameters.items()
+    }
+    args["tokenizer"] = None
+    new_args = args.copy()
+    del new_args["tokenizer"]
+    del new_args["processing_class"]
+    new_args = ",\n".join(f"{' '*12}{key} = {key}" for key in new_args) + \
+        f",\n{' '*12}processing_class = tokenizer if tokenizer else processing_class"
+    args = ",\n".join(f"{' '*8}{key} = {value}" for key, value in args.items())
+    args = f"def __init__(\n" + f"{' '*8}self,\n" + args + "):"
+    args += f"\n{' '*8}\n{' '*8}super().__init__(\n{new_args}\n{' '*8})"
+    new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' '*4}{args}\n"""
+    return new_class
+pass
+
+
def patch_sft_trainer_tokenizer():
    """
    Patches the trainer with changes
@@ -884,7 +914,8 @@ def patch_sft_trainer_tokenizer():

    check_text = \
    "\n"\
-   "test_text = dataset[0][dataset_text_field] if (formatting_func is None or not use_formatting_func) else formatting_func(dataset[0])[0]\n"\
+   "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\
+   "test_text = dataset[0][dataset_text_field] if (formatting_func is not None and dataset_text_field is None) else formatting_func(dataset[0])[0]\n"\
    "chat_template = getattr(tokenizer, 'chat_template', None)\n"\
    "chat_template = '' if chat_template is None else chat_template\n"\
    "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
@@ -941,7 +972,8 @@ def patch_sft_trainer_tokenizer():
    " from transformers import __version__ as transformers_version\n"\
    " from packaging.version import Version\n"\
    " if Version(transformers_version) <= Version('4.45.2'):\n"\
-   "     print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers and Unsloth!')\n"\
+   "     print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\
+   "           '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`')\n"\
    "except:\n"\
    " pass\n"\
    "\n\n"
@@ -981,4 +1013,13 @@ def patch_sft_trainer_tokenizer():
    pass
pass

+# Fix TRL trainers with removed tokenizer args (got replaced with processing_class)
+for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"):
+    trainer_text = patch_trl_tokenizer_processing_class(trainer_name)
+    if trainer_text is None: continue
+    exec(trainer_text, globals())
+    exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals())
+pass
+
+# FInally patch TRL tokenizer things
patch_sft_trainer_tokenizer()
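
`patch_trl_tokenizer_processing_class` builds each wrapper class as a source string and `exec`s it. Hand-written for illustration (not the exec-generated code, and only meaningful against a TRL version that already uses `processing_class`), the generated trainer amounts to roughly this:

```python
from trl import SFTTrainer

class UnslothSFTTrainer(SFTTrainer):
    # Re-accept the old `tokenizer=` keyword that newer TRL removed and
    # forward it to the new `processing_class=` parameter, so existing
    # Unsloth notebooks keep working unchanged.
    def __init__(self, *args, tokenizer = None, processing_class = None, **kwargs):
        super().__init__(
            *args,
            processing_class = tokenizer if tokenizer is not None else processing_class,
            **kwargs,
        )
```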
