diff --git a/README.md b/README.md
index 9e288ad2d..3e635dabe 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
-### Finetune Llama 3.2, Mistral, Phi-3.5, Qwen 2.5 & Gemma 2-5x faster with 80% less memory!
+### Finetune Llama 3.3, Mistral, Phi-4, Qwen 2.5 & Gemma 2-5x faster with 80% less memory!

@@ -22,41 +22,41 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and
| Unsloth supports | Free Notebooks | Performance | Memory use |
|-----------|---------|--------|----------|
-| **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb) | 2x faster | 60% less |
-| **Phi-4** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4-Conversational.ipynb) | 2x faster | 50% less |
-| **Llama 3.2 Vision (11B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb) | 2x faster | 40% less |
-| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 60% less |
-| **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma2_(9B)-Alpaca.ipynb) | 2x faster | 63% less |
-| **Qwen 2.5 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynb) | 2x faster | 63% less |
-| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-Conversational.ipynb) | 2.2x faster | 73% less |
-| **Ollama** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb) | 1.9x faster | 43% less |
-| **ORPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-ORPO.ipynb) | 1.9x faster | 43% less |
-| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Zephyr_(7B)-DPO.ipynb) | 1.9x faster | 43% less |
+| **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb) | 2x faster | 70% less |
+| **Phi-4 (14B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4-Conversational.ipynb) | 2x faster | 70% less |
+| **Llama 3.2 Vision (11B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb) | 2x faster | 50% less |
+| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 70% less |
+| **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma2_(9B)-Alpaca.ipynb) | 2x faster | 70% less |
+| **Qwen 2.5 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynb) | 2x faster | 70% less |
+| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-Conversational.ipynb) | 2.2x faster | 75% less |
+| **Ollama** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb) | 1.9x faster | 60% less |
+| **ORPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-ORPO.ipynb) | 1.9x faster | 50% less |
+| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Zephyr_(7B)-DPO.ipynb) | 1.9x faster | 50% less |
- See [all our notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks) and [all our models](https://docs.unsloth.ai/get-started/all-our-models)
- **Kaggle Notebooks** for [Llama 3.2 Kaggle notebook](https://www.kaggle.com/danielhanchen/kaggle-llama-3-2-1b-3b-unsloth-notebook), [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Gemma 2 (9B)](https://www.kaggle.com/code/danielhanchen/kaggle-gemma-7b-unsloth-notebook/), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook)
- Run notebooks for [Llama 3.2 conversational](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb), [Llama 3.1 conversational](https://colab.research.google.com/drive/15OyFkGoCImV9dSsewU1wa2JuKB4-mDE_?usp=sharing) and [Mistral v0.3 ChatML](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing)
-- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text
+- This [text completion notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_(7B)-Text_Completion.ipynb) is for continued pretraining / raw text
- This [continued pretraining notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-CPT.ipynb) is for learning another language
- Click [here](https://docs.unsloth.ai/) for detailed documentation for Unsloth.
## 🦥 Unsloth.ai News
+- 📣 NEW! [DeepSeek-R1](https://unsloth.ai/blog/deepseek-r1), the most powerful family of open reasoning models, is now supported, including its Llama & Qwen distillations. Run or fine-tune them now! Read more on [our blog](https://unsloth.ai/blog/deepseek-r1) and find all model uploads [here](https://huggingface.co/collections/unsloth/deepseek-r1-all-versions-678e1c48f5d2fce87892ace5).
- 📣 NEW! [Phi-4](https://unsloth.ai/blog/phi4) by Microsoft is now supported. We also [fixed bugs](https://unsloth.ai/blog/phi4) in Phi-4 and [uploaded GGUFs, 4-bit](https://huggingface.co/collections/unsloth/phi-4-all-versions-677eecf93784e61afe762afa). Try the [Phi-4 Colab notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4-Conversational.ipynb)
- 📣 NEW! [Llama 3.3 (70B)](https://huggingface.co/collections/unsloth/llama-33-all-versions-67535d7d994794b9d7cf5e9f), Meta's latest model is supported.
- 📣 NEW! We worked with Apple to add [Cut Cross Entropy](https://arxiv.org/abs/2411.09009). Unsloth now supports 89K context for Meta's Llama 3.3 (70B) on an 80GB GPU - 13x longer than HF+FA2. For Llama 3.1 (8B), Unsloth enables 342K context, surpassing its native 128K support.
-- 📣 NEW! Introducing Unsloth [Dynamic 4-bit Quantization](https://unsloth.ai/blog/dynamic-4bit)! We dynamically opt not to quantize certain parameters and this greatly increases accuracy while only using <10% more VRAM than BnB 4-bit. See our collection on [Hugging Face here.](https://huggingface.co/collections/unsloth/unsloth-4-bit-dynamic-quants-67503bb873f89e15276c44e7)
-- 📣 NEW! [Vision models](https://unsloth.ai/blog/vision) now supported! [Llama 3.2 Vision (11B)](https://colab.research.google.com/drive/1j0N4XTY1zXXy7mPAhOC1_gMYZ2F2EBlk?usp=sharing), [Qwen 2.5 VL (7B)](https://colab.research.google.com/drive/1whHb54GNZMrNxIsi2wm2EY_-Pvo2QyKh?usp=sharing) and [Pixtral (12B) 2409](https://colab.research.google.com/drive/1K9ZrdwvZRE96qGkCq_e88FgV3MLnymQq?usp=sharing)
-- 📣 NEW! Qwen-2.5 including [Coder](https://colab.research.google.com/drive/18sN803sU23XuJV9Q8On2xgqHSer6-UZF?usp=sharing) models are now supported with bugfixes. 14b fits in a Colab GPU! [Qwen 2.5 conversational notebook](https://colab.research.google.com/drive/1qN1CEalC70EO1wGKhNxs1go1W9So61R5?usp=sharing)
-- 📣 NEW! We found and helped fix a [gradient accumulation bug](https://unsloth.ai/blog/gradient)! Please update Unsloth and transformers.
+- 📣 Introducing Unsloth [Dynamic 4-bit Quantization](https://unsloth.ai/blog/dynamic-4bit)! We dynamically opt not to quantize certain parameters, which greatly increases accuracy while using <10% more VRAM than BnB 4-bit. See our collection on [Hugging Face here](https://huggingface.co/collections/unsloth/unsloth-4-bit-dynamic-quants-67503bb873f89e15276c44e7).
+- 📣 [Vision models](https://unsloth.ai/blog/vision) now supported! [Llama 3.2 Vision (11B)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb), [Qwen 2.5 VL (7B)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2_VL_(7B)-Vision.ipynb) and [Pixtral (12B) 2409](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Pixtral_(12B)-Vision.ipynb)
Click for more news
-- 📣 Try out [Chat interface](https://colab.research.google.com/drive/1i-8ESvtLRGNkkUQQr_-z_rcSAIo9c3lM?usp=sharing)!
-- 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/drive/1oCEHcED15DzL8xXGU1VTx5ZfOJM8WY01?usp=sharing) finetuning fits in under 16GB of VRAM!
-- 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) & [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) both Base and Instruct are now supported
+- 📣 We found and helped fix a [gradient accumulation bug](https://unsloth.ai/blog/gradient)! Please update Unsloth and transformers.
+- 📣 Try out [Chat interface](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Unsloth_Studio.ipynb)!
+- 📣 NEW! Qwen-2.5 models, including [Coder](https://unsloth.ai/blog/qwen-coder), are now supported with bug fixes. The 14B model fits on a Colab GPU! [Qwen 2.5 conversational notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_Coder_(14B)-Conversational.ipynb)
+- 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_Small_(22B)-Alpaca.ipynb) finetuning fits in under 16GB of VRAM!
- 📣 NEW! `pip install unsloth` now works! Head over to [pypi](https://pypi.org/project/unsloth/) to check it out! This allows installing without a git pull. Use `pip install unsloth[colab-new]` for installs without dependencies.
-- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) for other languages like Korean!
-- 📣 [2x faster inference](https://colab.research.google.com/drive/1aqlNQi7MMJbynFDyOQteD2t0yVfjb9Zh?usp=sharing) added for all our models
+- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-CPT.ipynb) for other languages like Korean!
+- 📣 [2x faster inference](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Inference.ipynb) added for all our models
- 📣 We cut memory usage by a [further 30%](https://unsloth.ai/blog/long-context) and now support [4x longer context windows](https://unsloth.ai/blog/long-context)!
@@ -82,23 +82,17 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and
## 🥇 Performance Benchmarking
-- For the full list of **reproducible** benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)
+- For our most detailed benchmarks, read our [Llama 3.3 Blog](https://unsloth.ai/blog/llama3-3).
+- Benchmarking of Unsloth was also conducted by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl).
-| 1 A100 40GB | 🤗Hugging Face | Flash Attention | 🦥Unsloth Open Source | 🦥[Unsloth Pro](https://unsloth.ai/pricing) |
-|--------------|--------------|-----------------|---------------------|-----------------|
-| Alpaca | 1x | 1.04x | 1.98x | **15.64x** |
-| LAION Chip2 | 1x | 0.92x | 1.61x | **20.73x** |
-| OASST | 1x | 1.19x | 2.17x | **14.83x** |
-| Slim Orca | 1x | 1.18x | 2.22x | **14.82x** |
-
-- Benchmarking table below was conducted by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl).
+We tested using the Alpaca dataset with a batch size of 2, gradient accumulation steps of 4, rank = 32, and QLoRA applied to all linear layers (q, k, v, o, gate, up, down); a minimal sketch of this setup follows the table below:
+
+| Model          | GPU VRAM | 🦥 Unsloth speed | 🦥 VRAM reduction | 🦥 Longer context | 😊 Hugging Face + FA2 |
+|----------------|----------|-----------------|-------------------|-------------------|-----------------------|
+| Llama 3.3 (70B)| 80GB     | 2x              | >75%              | 13x longer        | 1x                    |
+| Llama 3.1 (8B) | 80GB     | 2x              | >70%              | 12x longer        | 1x                    |
-| Free Colab T4 | Dataset | 🤗Hugging Face | Pytorch 2.1.1 | 🦥Unsloth | 🦥 VRAM reduction |
-| --- | --- | --- | --- | --- | --- |
-| Llama-2 7b | OASST | 1x | 1.19x | 1.95x | -43.3% |
-| Mistral 7b | Alpaca | 1x | 1.07x | 1.56x | -13.7% |
-| Tiny Llama 1.1b | Alpaca | 1x | 2.06x | 3.87x | -73.8% |
-| DPO with Zephyr | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% |
+
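+Below is a minimal sketch of the setup described above, using Unsloth's `FastLanguageModel` API. It is illustrative only: the model id, prompt template, and trainer arguments are assumptions for demonstration, not the exact benchmark script.
+
+```python
+from unsloth import FastLanguageModel
+from datasets import load_dataset
+from trl import SFTTrainer
+from transformers import TrainingArguments
+
+max_seq_length = 2048
+
+# Load a 4-bit base model (illustrative model choice).
+model, tokenizer = FastLanguageModel.from_pretrained(
+    model_name = "unsloth/Meta-Llama-3.1-8B-Instruct",
+    max_seq_length = max_seq_length,
+    load_in_4bit = True,
+)
+
+# QLoRA with rank 32 on all linear layers, as in the benchmark description.
+model = FastLanguageModel.get_peft_model(
+    model,
+    r = 32,
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+                      "gate_proj", "up_proj", "down_proj"],
+    lora_alpha = 32,
+)
+
+# Format the Alpaca dataset into a single "text" field (assumed prompt template).
+alpaca_prompt = "### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}"
+def format_example(example):
+    text = alpaca_prompt.format(example["instruction"], example["input"], example["output"])
+    return {"text": text + tokenizer.eos_token}
+dataset = load_dataset("yahma/alpaca-cleaned", split = "train").map(format_example)
+
+# Batch size 2 with gradient accumulation steps of 4, matching the table above.
+trainer = SFTTrainer(
+    model = model,
+    tokenizer = tokenizer,
+    train_dataset = dataset,
+    dataset_text_field = "text",
+    max_seq_length = max_seq_length,
+    args = TrainingArguments(
+        per_device_train_batch_size = 2,
+        gradient_accumulation_steps = 4,
+        learning_rate = 2e-4,
+        max_steps = 60,
+        output_dir = "outputs",
+    ),
+)
+trainer.train()
+```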

@@ -359,119 +353,28 @@ dpo_trainer.train()
```
## 🥇 Detailed Benchmarking Tables
-- Click "Code" for fully reproducible examples
-- "Unsloth Equal" is a preview of our PRO version, with code stripped out. All settings and the loss curve remains identical.
-- For the full list of benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)
-
-| 1 A100 40GB | 🤗Hugging Face | Flash Attention 2 | 🦥Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
-|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
-| Alpaca | 1x | 1.04x | 1.98x | 2.48x | 5.32x | **15.64x** |
-| code | [Code](https://colab.research.google.com/drive/1u4dBeM-0vGNVmmO6X7cScAut-Hyt4KDF?usp=sharing) | [Code](https://colab.research.google.com/drive/1fgTOxpMbVjloQBvZyz4lF4BacKSZOB2A?usp=sharing) | [Code](https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing) | [Code](https://colab.research.google.com/drive/1ANW8EFL3LVyTD7Gq4TkheC1Z7Rxw-rHp?usp=sharing) | | |
-| seconds| 1040 | 1001 | 525 | 419 | 196 | 67 |
-| memory MB| 18235 | 15365 | 9631 | 8525 | | |
-| % saved| | 15.74 | 47.18 | 53.25 | | | |
-
-### Llama-Factory 3rd party benchmarking
-- [Link to performance table.](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-Comparison) TGS: tokens per GPU per second. Model: LLaMA2-7B. GPU: NVIDIA A100 * 1. Batch size: 4. Gradient accumulation: 2. LoRA rank: 8. Max length: 1024.
-
-| Method | Bits | TGS | GRAM | Speed |
-| --- | --- | --- | --- | --- |
-| HF | 16 | 2392 | 18GB | 100% |
-| HF+FA2 | 16 | 2954 | 17GB | 123% |
-| Unsloth+FA2 | 16 | 4007 | 16GB | **168%** |
-| HF | 4 | 2415 | 9GB | 101% |
-| Unsloth+FA2 | 4 | 3726 | 7GB | **160%** |
-
-### Performance comparisons between popular models
-
- Click for specific model benchmarking tables (Mistral 7b, CodeLlama 34b etc.)
-
-### Mistral 7b
-| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
-|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
-| Mistral 7B Slim Orca | 1x | 1.15x | 2.15x | 2.53x | 4.61x | **13.69x** |
-| code | [Code](https://colab.research.google.com/drive/1mePk3KzwTD81hr5mcNcs_AX3Kbg_Ha0x?usp=sharing) | [Code](https://colab.research.google.com/drive/1dgHxjvTmX6hb0bPcLp26RXSE6_n9DKj7?usp=sharing) | [Code](https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing) | [Code](https://colab.research.google.com/drive/18yOiyX0T81mTwZqOALFSCX_tSAqju6aD?usp=sharing) | |
-| seconds | 1813 | 1571 | 842 | 718 | 393 | 132 |
-| memory MB | 32853 | 19385 | 12465 | 10271 | | |
-| % saved| | 40.99 | 62.06 | 68.74 | | |
-
-### CodeLlama 34b
-| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
-|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
-| Code Llama 34B | OOM ❌ | 0.99x | 1.87x | 2.61x | 4.27x | 12.82x |
-| code | [▶️ Code](https://colab.research.google.com/drive/1ykfz3BqrtC_AUFegCzUQjjfUNlxp6Otc?usp=sharing) | [Code](https://colab.research.google.com/drive/12ZypxQh7OC6kBXvWZI-5d05I4m-B_hoR?usp=sharing) | [Code](https://colab.research.google.com/drive/1gdHyAx8XJsz2yNV-DHvbHjR1iCef5Qmh?usp=sharing) | [Code](https://colab.research.google.com/drive/1fm7wqx9MJ0kRrwKOfmLkK1Rmw-pySahB?usp=sharing) | |
-| seconds | 1953 | 1982 | 1043 | 748 | 458 | 152 |
-| memory MB | 40000 | 33217 | 27413 | 22161 | | |
-| % saved| | 16.96| 31.47 | 44.60 | | | |
-
-### 1 Tesla T4
-
-| 1 T4 16GB | Hugging Face | Flash Attention | Unsloth Open | Unsloth Pro Equal | Unsloth Pro | Unsloth Max |
-|--------------|-------------|-----------------|-----------------|---------------|---------------|-------------|
-| Alpaca | 1x | 1.09x | 1.69x | 1.79x | 2.93x | **8.3x** |
-| code | [▶️ Code](https://colab.research.google.com/drive/1XpLIV4s8Bj5uryB-X2gqM88oRGHEGdaB?usp=sharing) | [Code](https://colab.research.google.com/drive/1LyXu6CjuymQg6ddHX8g1dpUvrMa1nn4L?usp=sharing) | [Code](https://colab.research.google.com/drive/1gsv4LpY7C32otl1rgRo5wXTk4HIitXoM?usp=sharing) | [Code](https://colab.research.google.com/drive/1VtULwRQwhEnVdNryjm27zXfdSM1tNfFK?usp=sharing) | | |
-| seconds | 1599 | 1468 | 942 | 894 | 545 | 193 |
-| memory MB | 7199 | 7059 | 6459 | 5443 | | |
-| % saved | | 1.94 | 10.28 | 24.39 | | |
-
-### 2 Tesla T4s via DDP
-
- | 2 T4 DDP | Hugging Face | Flash Attention | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
-|--------------|----------|-------------|-----------------|--------------|---------------|-------------|
-| Alpaca | 1x | 0.99x | 4.95x | 4.44x | 7.28x | **20.61x** |
-| code | [▶️ Code](https://www.kaggle.com/danielhanchen/hf-original-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/hf-sdpa-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp) | | |
-| seconds | 9882 | 9946 | 1996 | 2227 | 1357 | 480 |
-| memory MB| 9176 | 9128 | 6904 | 6782 | | |
-| % saved | | 0.52 | 24.76 | 26.09 | | | |
-
-
-### Performance comparisons on 1 Tesla T4 GPU:
-
- Click for Time taken for 1 epoch
-
-One Tesla T4 on Google Colab
-`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`
-
-| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
-| --- | --- | --- | --- | --- | --- |
-| Huggingface | 1 T4 | 23h 15m | 56h 28m | 8h 38m | 391h 41m |
-| Unsloth Open | 1 T4 | 13h 7m (1.8x) | 31h 47m (1.8x) | 4h 27m (1.9x) | 240h 4m (1.6x) |
-| Unsloth Pro | 1 T4 | 3h 6m (7.5x) | 5h 17m (10.7x) | 1h 7m (7.7x) | 59h 53m (6.5x) |
-| Unsloth Max | 1 T4 | 2h 39m (8.8x) | 4h 31m (12.5x) | 0h 58m (8.9x) | 51h 30m (7.6x) |
-
-**Peak Memory Usage**
-
-| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
-| --- | --- | --- | --- | --- | --- |
-| Huggingface | 1 T4 | 7.3GB | 5.9GB | 14.0GB | 13.3GB |
-| Unsloth Open | 1 T4 | 6.8GB | 5.7GB | 7.8GB | 7.7GB |
-| Unsloth Pro | 1 T4 | 6.4GB | 6.4GB | 6.4GB | 6.4GB |
-| Unsloth Max | 1 T4 | 11.4GB | 12.4GB | 11.9GB | 14.4GB |
-
-
-
- Click for Performance Comparisons on 2 Tesla T4 GPUs via DDP:
-**Time taken for 1 epoch**
-
-Two Tesla T4s on Kaggle
-`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`
-
-| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
-| --- | --- | --- | --- | --- | --- |
-| Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m * |
-| Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) * |
-| Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) * |
+### Context length benchmarks
+#### Llama 3.1 (8B) max. context length
+We tested Llama 3.1 (8B) Instruct and applied 4-bit QLoRA on all linear layers (Q, K, V, O, gate, up, down) with rank = 32 and a batch size of 1. We padded all sequences to a fixed maximum length to mimic long-context finetuning workloads.
+
+| GPU VRAM | 🦥Unsloth context length | Hugging Face + FA2 |
+|----------|-----------------------|-----------------|
+| 8 GB | 2,972 | OOM |
+| 12 GB | 21,848 | 932 |
+| 16 GB | 40,724 | 2,551 |
+| 24 GB | 78,475 | 5,789 |
+| 40 GB | 153,977 | 12,264 |
+| 48 GB | 191,728 | 15,502 |
+| 80 GB | 342,733 | 28,454 |
+
+#### Llama 3.3 (70B) max. context length
+We tested Llama 3.3 (70B) Instruct on an 80GB A100 and applied 4-bit QLoRA on all linear layers (Q, K, V, O, gate, up, down) with rank = 32 and a batch size of 1. We padded all sequences to a fixed maximum length to mimic long-context finetuning workloads; a short sketch of this padding setup follows the table below.
+
+| GPU VRAM | 🦥Unsloth context length | Hugging Face + FA2 |
+|----------|------------------------|------------------|
+| 48 GB | 12,106 | OOM |
+| 80 GB | 89,389 | 6,916 |
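+
+For context, the "padded to a maximum sequence length" methodology can be sketched with a plain Hugging Face tokenizer call. The model id and target length below are illustrative assumptions:
+
+```python
+from transformers import AutoTokenizer
+
+# Illustrative tokenizer; the benchmarks above used Llama 3.1 / 3.3 Instruct.
+tokenizer = AutoTokenizer.from_pretrained("unsloth/Meta-Llama-3.1-8B-Instruct")
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token  # ensure a padding token exists
+
+max_seq_length = 32768                     # illustrative target context length
+texts = ["example training document"] * 4  # placeholder batch
+
+# Pad every sequence to the same fixed length so each training step exercises
+# the full context window, mimicking a long-context finetuning workload.
+batch = tokenizer(
+    texts,
+    padding = "max_length",
+    max_length = max_seq_length,
+    truncation = True,
+    return_tensors = "pt",
+)
+```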
-**Peak Memory Usage on a Multi GPU System (2 GPUs)**
-
-| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
-| --- | --- | --- | --- | --- | --- |
-| Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB * |
-| Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB * |
-| Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB * |
-
-* Slim Orca `bsz=1` for all benchmarks since `bsz=2` OOMs. We can handle `bsz=2`, but we benchmark it with `bsz=1` for consistency.
-
+

diff --git a/pyproject.toml b/pyproject.toml
index b24abd355..d89ea2c4d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ triton = [
"triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
huggingface = [
- "unsloth_zoo>=2025.1.2",
+ "unsloth_zoo>=2025.2.2",
"packaging",
"tyro",
"transformers>=4.46.1,!=4.47.0",
@@ -131,6 +131,12 @@ cu124onlytorch240 = [
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
+cu118onlytorch250 = [
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+]
cu121onlytorch250 = [
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
@@ -147,6 +153,12 @@ cu124onlytorch250 = [
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
+cu118onlytorch251 = [
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+]
cu121onlytorch251 = [
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
@@ -163,6 +175,28 @@ cu124onlytorch251 = [
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
+cu118onlytorch260 = [
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+]
+cu124onlytorch260 = [
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
+]
+cu126onlytorch260 = [
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+]
cu118 = [
"unsloth[huggingface]",
"bitsandbytes>=0.43.3",
@@ -223,21 +257,31 @@ cu121-torch240 = [
"bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch240]",
]
-cu121-torch250 = [
+cu124-torch240 = [
"unsloth[huggingface]",
"bitsandbytes>=0.43.3",
- "unsloth[cu121onlytorch250]",
+ "unsloth[cu124onlytorch240]",
]
-cu124-torch240 = [
+cu118-torch250 = [
"unsloth[huggingface]",
"bitsandbytes>=0.43.3",
- "unsloth[cu124onlytorch240]",
+ "unsloth[cu118onlytorch250]",
+]
+cu121-torch250 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.43.3",
+ "unsloth[cu121onlytorch250]",
]
cu124-torch250 = [
"unsloth[huggingface]",
"bitsandbytes>=0.43.3",
"unsloth[cu124onlytorch250]",
]
+cu118-torch251 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.43.3",
+ "unsloth[cu118onlytorch251]",
+]
cu121-torch251 = [
"unsloth[huggingface]",
"bitsandbytes>=0.43.3",
@@ -248,6 +292,21 @@ cu124-torch251 = [
"bitsandbytes>=0.43.3",
"unsloth[cu124onlytorch251]",
]
+cu118-torch260 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.45.1",
+ "unsloth[cu118onlytorch260]",
+]
+cu124-torch260 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.45.1",
+ "unsloth[cu124onlytorch260]",
+]
+cu126-torch260 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.45.1",
+ "unsloth[cu126onlytorch260]",
+]
kaggle = [
"unsloth[huggingface]",
]
@@ -285,7 +344,7 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3",
]
colab-new = [
- "unsloth_zoo>=2025.1.2",
+ "unsloth_zoo>=2025.2.2",
"packaging",
"tyro",
"transformers>=4.46.1,!=4.47.0",
@@ -381,16 +440,22 @@ cu121-ampere-torch240 = [
"unsloth[cu121onlytorch240]",
"unsloth[flashattention]",
]
-cu121-ampere-torch250 = [
+cu124-ampere-torch240 = [
"unsloth[huggingface]",
"bitsandbytes>=0.43.3",
- "unsloth[cu121onlytorch250]",
+ "unsloth[cu124onlytorch240]",
"unsloth[flashattention]",
]
-cu124-ampere-torch240 = [
+cu118-ampere-torch250 = [
"unsloth[huggingface]",
"bitsandbytes>=0.43.3",
- "unsloth[cu124onlytorch240]",
+ "unsloth[cu118onlytorch250]",
+ "unsloth[flashattention]",
+]
+cu121-ampere-torch250 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.43.3",
+ "unsloth[cu121onlytorch250]",
"unsloth[flashattention]",
]
cu124-ampere-torch250 = [
@@ -399,6 +464,12 @@ cu124-ampere-torch250 = [
"unsloth[cu124onlytorch250]",
"unsloth[flashattention]",
]
+cu118-ampere-torch251 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.43.3",
+ "unsloth[cu118onlytorch251]",
+ "unsloth[flashattention]",
+]
cu121-ampere-torch251 = [
"unsloth[huggingface]",
"bitsandbytes>=0.43.3",
@@ -411,6 +482,24 @@ cu124-ampere-torch251 = [
"unsloth[cu124onlytorch251]",
"unsloth[flashattention]",
]
+cu118-ampere-torch260 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.45.1",
+ "unsloth[cu118onlytorch260]",
+ "unsloth[flashattention]",
+]
+cu124-ampere-torch260 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.45.1",
+ "unsloth[cu124onlytorch260]",
+ "unsloth[flashattention]",
+]
+cu126-ampere-torch260 = [
+ "unsloth[huggingface]",
+ "bitsandbytes>=0.45.1",
+ "unsloth[cu126onlytorch260]",
+ "unsloth[flashattention]",
+]
[project.urls]
homepage = "http://www.unsloth.ai"
diff --git a/unsloth/__init__.py b/unsloth/__init__.py
index 8002fbaef..bdde33c50 100644
--- a/unsloth/__init__.py
+++ b/unsloth/__init__.py
@@ -86,6 +86,10 @@
del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
pass
+# First check if CUDA is available ie a NVIDIA GPU is seen
+if not torch.cuda.is_available():
+ raise NotImplementedError("Unsloth: No NVIDIA GPU found? Unsloth currently only supports NVIDIA GPUs!")
+
# Fix Xformers performance issues since 0.0.25
import importlib.util
from pathlib import Path
@@ -130,71 +134,69 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
torch.cuda.is_bf16_supported = is_bf16_supported
pass
+# For Gradio HF Spaces?
+# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
+import triton
+libcuda_dirs = lambda: None
+if Version(triton.__version__) >= Version("3.0.0"):
+ try: from triton.backends.nvidia.driver import libcuda_dirs
+ except: pass
+else: from triton.common.build import libcuda_dirs
+
# Try loading bitsandbytes and triton
import bitsandbytes as bnb
+try:
+ cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
+ libcuda_dirs()
+except:
+ warnings.warn(
+ "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
+ )
-if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
-
- import triton
- libcuda_dirs = lambda: None
- if Version(triton.__version__) >= Version("3.0.0"):
- try: from triton.backends.nvidia.driver import libcuda_dirs
- except: pass
- else: from triton.common.build import libcuda_dirs
+ if os.path.exists("/usr/lib64-nvidia"):
+ os.system("ldconfig /usr/lib64-nvidia")
+ elif os.path.exists("/usr/local"):
+ # Sometimes bitsandbytes cannot be linked properly in Runpod for example
+ possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
+ find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
+ possible_cudas = [find_cuda.search(x) for x in possible_cudas]
+ possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
+
+ # Try linking cuda folder, or everything in local
+ if len(possible_cudas) == 0:
+ os.system("ldconfig /usr/local/")
+ else:
+ find_number = re.compile(r"([\d\.]{2,})")
+ latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
+ latest_cuda = possible_cudas[latest_cuda]
+ os.system(f"ldconfig /usr/local/{latest_cuda}")
+ pass
+ importlib.reload(bnb)
+ importlib.reload(triton)
try:
+ libcuda_dirs = lambda: None
+ if Version(triton.__version__) >= Version("3.0.0"):
+ try: from triton.backends.nvidia.driver import libcuda_dirs
+ except: pass
+ else: from triton.common.build import libcuda_dirs
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
warnings.warn(
- "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
+ "Unsloth: CUDA is not linked properly.\n"\
+ "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
+ "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
+ "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
+ "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
+ "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
)
-
- if os.path.exists("/usr/lib64-nvidia"):
- os.system("ldconfig /usr/lib64-nvidia")
- elif os.path.exists("/usr/local"):
- # Sometimes bitsandbytes cannot be linked properly in Runpod for example
- possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
- find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
- possible_cudas = [find_cuda.search(x) for x in possible_cudas]
- possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
-
- # Try linking cuda folder, or everything in local
- if len(possible_cudas) == 0:
- os.system("ldconfig /usr/local/")
- else:
- find_number = re.compile(r"([\d\.]{2,})")
- latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
- latest_cuda = possible_cudas[latest_cuda]
- os.system(f"ldconfig /usr/local/{latest_cuda}")
- pass
-
- importlib.reload(bnb)
- importlib.reload(triton)
- try:
- libcuda_dirs = lambda: None
- if Version(triton.__version__) >= Version("3.0.0"):
- try: from triton.backends.nvidia.driver import libcuda_dirs
- except: pass
- else: from triton.common.build import libcuda_dirs
- cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
- libcuda_dirs()
- except:
- warnings.warn(
- "Unsloth: CUDA is not linked properly.\n"\
- "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
- "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
- "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
- "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
- "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
- )
- pass
pass
# Check for unsloth_zoo
try:
unsloth_zoo_version = importlib_version("unsloth_zoo")
- if Version(unsloth_zoo_version) < Version("2025.1.2"):
+ if Version(unsloth_zoo_version) < Version("2025.2.2"):
try:
os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")
except:
diff --git a/unsloth/_auto_install.py b/unsloth/_auto_install.py
index c3b94c670..8bb548519 100644
--- a/unsloth/_auto_install.py
+++ b/unsloth/_auto_install.py
@@ -18,14 +18,16 @@
v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
-if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!")
+if cuda != "12.1" and cuda != "11.8" and cuda != "12.4" and cuda != "12.6": raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
elif v < V('2.3.0'): x = 'cu{}{}-torch220'
elif v < V('2.4.0'): x = 'cu{}{}-torch230'
elif v < V('2.5.0'): x = 'cu{}{}-torch240'
-elif v < V('2.6.0'): x = 'cu{}{}-torch250'
+elif v < V('2.5.1'): x = 'cu{}{}-torch250'
+elif v <= V('2.5.1'): x = 'cu{}{}-torch251'
+elif v < V('2.7.0'): x = 'cu{}{}-torch260'
else: raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
\ No newline at end of file
diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py
index d8dc38522..c40139323 100644
--- a/unsloth/chat_templates.py
+++ b/unsloth/chat_templates.py
@@ -759,6 +759,10 @@
CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates
+
+for version in ("llama-3.2", "llama-3.3", "llama-32", "llama-33"):
+ CHAT_TEMPLATES[version] = CHAT_TEMPLATES["llama-3.1"]
+ DEFAULT_SYSTEM_MESSAGE[version] = ""
pass
diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py
index de543962e..f052914f9 100644
--- a/unsloth/kernels/utils.py
+++ b/unsloth/kernels/utils.py
@@ -15,6 +15,7 @@
import triton
MAX_FUSED_SIZE : int = 65536
next_power_of_2 = triton.next_power_of_2
+import functools
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
@@ -66,6 +67,8 @@ def calculate_settings(n : int) -> (int, int,):
CUDA_STREAM = None
get_ptr = bnb.functional.get_ptr
import ctypes
+ctypes_c_int = ctypes.c_int
+ctypes_c_int32 = ctypes.c_int32
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
@@ -98,25 +101,31 @@ def get_lora_parameters(proj):
def get_lora_parameters_bias(proj):
# For DPO or disabled adapters
- base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
+ base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
bias = base_layer.bias
- if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
+ # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
+ if getattr(proj, "disable_adapters", True) or proj.merged:
return W, QUANT_STATE(W), None, None, None, bias
pass
active_adapter = proj.active_adapters[0] if \
- hasattr(proj, "active_adapters") else proj.active_adapter
+ getattr(proj, "active_adapters", ) else proj.active_adapter
A = proj.lora_A [active_adapter].weight
B = proj.lora_B [active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s, bias
pass
+global WEIGHT_BUFFER
+WEIGHT_BUFFER = None
+global ABSMAX_BUFFER
+ABSMAX_BUFFER = None
if HAS_CUDA_STREAM:
- def fast_dequantize(W, quant_state = None, out = None):
+ @torch.inference_mode
+ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
@@ -139,36 +148,54 @@ def fast_dequantize(W, quant_state = None, out = None):
global CUDA_STREAM
if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
+ n_elements_absmax = absmax.numel()
+
# Create weight matrix
- if out is None:
- out = torch.empty(shape, dtype = dtype, device = "cuda:0")
+ if use_global_buffer:
+
+ # Use same buffers for faster inference
+ size = shape[0]*shape[1]
+ global WEIGHT_BUFFER
+ global ABSMAX_BUFFER
+ if WEIGHT_BUFFER is None:
+ WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False)
+ ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)
+
+ if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
+ if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
+
+ out = WEIGHT_BUFFER[:size].view(shape)
+ out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
else:
- assert(out.shape == shape)
- assert(out.dtype == dtype)
+ if out is None:
+ out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False)
+ else:
+ assert(out.shape == shape)
+ assert(out.dtype == dtype)
+ out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)
+ pass
# NF4 dequantization of statistics
- n_elements_absmax = absmax.numel()
- out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
-
- # Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
- ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM,
+ ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM,
)
out_absmax += offset
+ # Dequantize W
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
cdequantize_blockwise_bf16_nf4
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
- ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,)
+ ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,)
# Careful returning transposed data
is_transposed = (True if W.shape[0] == 1 else False)
return out.t() if is_transposed else out
pass
else:
- def fast_dequantize(W, quant_state = None, out = None):
+ @torch.inference_mode
+ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
@@ -189,29 +216,45 @@ def fast_dequantize(W, quant_state = None, out = None):
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
+ n_elements_absmax = absmax.numel()
+
# Create weight matrix
- if out is None:
- out = torch.empty(shape, dtype = dtype, device = "cuda:0")
- else:
- assert(out.shape == shape)
- assert(out.dtype == dtype)
+ if use_global_buffer:
- # NF4 dequantization of statistics
- n_elements_absmax = absmax.numel()
- out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
+ # Use same buffers for faster inference
+ size = shape[0]*shape[1]
+ global WEIGHT_BUFFER
+ global ABSMAX_BUFFER
+ if WEIGHT_BUFFER is None:
+ WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False)
+ ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)
+
+ if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
+ if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
+
+ out = WEIGHT_BUFFER[:size].view(shape)
+ out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
+ else:
+ if out is None:
+ out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False)
+ else:
+ assert(out.shape == shape)
+ assert(out.dtype == dtype)
+ out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False)
+ pass
# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
- ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax),
+ ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax),
)
out_absmax += offset
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
cdequantize_blockwise_bf16_nf4
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
- ctypes.c_int(blocksize), ctypes.c_int(out.numel()),)
+ ctypes_c_int(blocksize), ctypes_c_int(out.numel()),)
# Careful returning transposed data
is_transposed = (True if W.shape[0] == 1 else False)
@@ -263,17 +306,17 @@ def fast_gemv(X, W, quant_state, out = None):
lda = shape[0]
ldc = shape[0]
ldb = (hd+1)//2
- m = ctypes.c_int32(m)
- n = ctypes.c_int32(n)
- k = ctypes.c_int32(k)
- lda = ctypes.c_int32(lda)
- ldb = ctypes.c_int32(ldb)
- ldc = ctypes.c_int32(ldc)
+ m = ctypes_c_int32(m)
+ n = ctypes_c_int32(n)
+ k = ctypes_c_int32(k)
+ lda = ctypes_c_int32(lda)
+ ldb = ctypes_c_int32(ldb)
+ ldc = ctypes_c_int32(ldc)
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
- ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM,
+ ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM,
)
df += offset
absmax = df
@@ -281,7 +324,7 @@ def fast_gemv(X, W, quant_state, out = None):
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
cgemm_4bit_inference_naive_bf16
- blocksize = ctypes.c_int32(blocksize)
+ blocksize = ctypes_c_int32(blocksize)
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
lda, ldb, ldc, blocksize, CUDA_STREAM,)
@@ -327,17 +370,17 @@ def fast_gemv(X, W, quant_state, out = None):
lda = shape[0]
ldc = shape[0]
ldb = (hd+1)//2
- m = ctypes.c_int32(m)
- n = ctypes.c_int32(n)
- k = ctypes.c_int32(k)
- lda = ctypes.c_int32(lda)
- ldb = ctypes.c_int32(ldb)
- ldc = ctypes.c_int32(ldc)
+ m = ctypes_c_int32(m)
+ n = ctypes_c_int32(n)
+ k = ctypes_c_int32(k)
+ lda = ctypes_c_int32(lda)
+ ldb = ctypes_c_int32(ldb)
+ ldc = ctypes_c_int32(ldc)
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
- ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
+ ctypes_c_int(blocksize2), ctypes_c_int(df.numel()),
)
df += offset
absmax = df
@@ -345,7 +388,7 @@ def fast_gemv(X, W, quant_state, out = None):
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
cgemm_4bit_inference_naive_bf16
- blocksize = ctypes.c_int32(blocksize)
+ blocksize = ctypes_c_int32(blocksize)
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
lda, ldb, ldc, blocksize,)
@@ -354,6 +397,9 @@ def fast_gemv(X, W, quant_state, out = None):
pass
+torch_mm = torch.mm
+torch_mv = torch.mv
+torch_matmul = torch.matmul
def fast_linear_forward(proj, X, temp_lora = None, out = None):
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
@@ -361,12 +407,12 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
if W_quant is None:
- out = torch.matmul(X, W.t(), out = out)
+ out = torch_matmul(X, W.t(), out = out)
elif bsz == 1 and q_len == 1:
out = fast_gemv(X, W, W_quant, out = out)
else:
- W = fast_dequantize(W.t(), W_quant)
- out = torch.matmul(X, W, out = out)
+ W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
+ out = torch_matmul(X, W, out = out)
pass
# Add in LoRA weights
@@ -381,11 +427,11 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
if bsz == 1:
out = out.view(out_dim)
- temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
+ temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
else:
out = out.view(bsz, out_dim)
- temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
+ temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
pass
out = out.view(bsz, 1, out_dim)
@@ -399,7 +445,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
def matmul_lora(X, W, W_quant, A, B, s, out = None):
dtype = X.dtype
- W = fast_dequantize(W.t(), W_quant)
+ W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
if X.dim() == 3:
batch, seq_len, d = X.shape
@@ -409,7 +455,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
reshape = False
pass
- out = torch.matmul(X, W, out = out)
+ out = torch_matmul(X, W, out = out)
if W_quant is not None: del W
if A is not None:
diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py
index c52d14f40..b15e04ab7 100644
--- a/unsloth/models/__init__.py
+++ b/unsloth/models/__init__.py
@@ -20,3 +20,4 @@
from .qwen2 import FastQwen2Model
from .dpo import PatchDPOTrainer, PatchKTOTrainer
from ._utils import is_bfloat16_supported
+from .rl import PatchFastRL
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 0036a18c4..2ec4adaa1 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "2025.1.5"
+__version__ = "2025.2.4"
__all__ = [
"SUPPORTS_BFLOAT16",
@@ -285,7 +285,11 @@ def _is_openai_available(): return False
if _is_package_available("flash_attn"):
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
try:
- from flash_attn.flash_attn_interface import flash_attn_cuda
+ try:
+ # See https://github.com/unslothai/unsloth/issues/1437
+ from flash_attn.flash_attn_interface import flash_attn_gpu
+ except:
+ from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True
# Also check for softcapping
@@ -843,7 +847,9 @@ def patch_linear_scaling(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
- if len(rotary_emb) == 0: return None, function
+ if len(rotary_emb) == 0:
+ return None, exec_code + "\n\n" + function
+
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py
index 5dc71f920..9c12abb98 100644
--- a/unsloth/models/dpo.py
+++ b/unsloth/models/dpo.py
@@ -17,115 +17,8 @@
"PatchKTOTrainer",
]
-try:
- from transformers.utils.notebook import (
- IntervalStrategy,
- NotebookTrainingTracker,
- NotebookProgressCallback,
- )
- HAS_NOTEBOOK = True
-except:
- HAS_NOTEBOOK = False
-pass
-import torch
-from ._utils import torch_compile_options
-import inspect
-import torch.nn as nn
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from .rl import PatchFastRL
+def PatchDPOTrainer(): PatchFastRL("DPO")
-DPOTrainer_metrics = [
- "rewards/chosen",
- "rewards/rejected",
- "rewards/accuracies",
- "rewards/margins",
- "logps/rejected",
- "logps/chosen",
- "logits/rejected",
- "logits/chosen",
-]
-set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics)
-
-
-def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs):
- self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"
- self.training_loss = 0
- self.last_log = 0
- column_names = [self.first_column] + ["Training Loss"]
- if args.eval_strategy != IntervalStrategy.NO:
- column_names.append("Validation Loss")
- column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics]
- self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
-pass
-
-
-def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs):
- # Only for when there is no evaluation
- if args.eval_strategy == IntervalStrategy.NO and "loss" in logs:
- values = {"Training Loss": logs["loss"]}
- for metric in DPOTrainer_metrics:
- values[metric.replace("/", " / ")] = logs[metric]
- pass
- # First column is necessarily Step since we're not in epoch eval strategy
- values["Step"] = state.global_step
- self.training_tracker.write_line(values)
- pass
-pass
-
-
-def NotebookTrainingTracker_write_line(self, values):
- """
- Write the values in the inner table.
-
- Args:
- values (`Dict[str, float]`): The values to display.
- """
- if self.inner_table is None:
- self.inner_table = [list(values.keys()), list(values.values())]
- else:
- columns = self.inner_table[0]
- new_values = {}
- for key, value in values.items():
- lowered = key.lower()
- if lowered in set_DPOTrainer_metrics:
- new_values[lowered.replace("/", " / ")] = value
- else:
- new_values[key] = value
- pass
- values = new_values
-
- self.inner_table[0] = columns
- if len(self.inner_table) > 1:
- last_values = self.inner_table[-1]
- first_column = self.inner_table[0][0]
- if last_values[0] != values[first_column]:
- # write new line
- self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
- else:
- # update last line
- new_values = values
- for c in columns:
- if c not in new_values.keys():
- new_values[c] = last_values[columns.index(c)]
- self.inner_table[-1] = [new_values[c] for c in columns]
- else:
- # Edit for evaluation purposes
- self.inner_table.append([values[c] if c in values else 0 for c in columns])
- pass
- pass
-pass
-
-
-def PatchDPOTrainer():
- if HAS_NOTEBOOK:
- from transformers.trainer import is_in_notebook
- if is_in_notebook():
- # Patch DPO notebook printing
- NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line
- from transformers.trainer import DEFAULT_PROGRESS_CALLBACK
- DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin
- DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log
- pass
- pass
-pass
-PatchKTOTrainer = PatchDPOTrainer
+def PatchKTOTrainer(): PatchFastRL("KTO")
diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index c65434328..bc29c46ab 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -210,7 +210,15 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
- if config is not None: return # [TODO] Hack to pass in config - need to remove later
+ if config is not None:
+ # [TODO] Hack to pass in config - need to remove later
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ dim = getattr(config, "head_dim", None)
+ if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
+ device = "cuda"
+ max_position_embeddings = config.max_position_embeddings
+ pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py
index 497a357fe..fb7e96d8d 100644
--- a/unsloth/models/granite.py
+++ b/unsloth/models/granite.py
@@ -89,6 +89,7 @@ def GraniteAttention_fast_forward(
n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
+ dropout_p = self.config.attention_dropout if self.training else 0
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
@@ -135,7 +136,7 @@ def GraniteAttention_fast_forward(
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
pass
- A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling)
+ A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling, p=dropout_p)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
@@ -143,7 +144,7 @@ def GraniteAttention_fast_forward(
K = K.transpose(1, 2)
V = V.transpose(1, 2)
window = (kv_seq_len, kv_seq_len)
- A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling)
+ A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling, dropout_p=dropout_p)
else:
# Grouped query attention
# if n_groups != 1:
@@ -157,7 +158,7 @@ def GraniteAttention_fast_forward(
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_casual and attention_mask must not be both set!
- A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False)
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False, dropout_p=dropout_p)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
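
Granite attention now forwards `attention_dropout` to every backend (xformers, Flash Attention 2, SDPA), but only while training. A small sketch of the same guard using only PyTorch's public SDPA (shapes are illustrative):

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    attention_dropout, training = 0.1, False
    dropout_p = attention_dropout if training else 0   # same guard as the diff above
    Q = K = V = torch.randn(1, 8, 16, 64)              # (bsz, n_heads, q_len, head_dim)
    A = scaled_dot_product_attention(Q, K, V, is_causal = True, dropout_p = dropout_p)
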
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index edd3ddf94..a337472a3 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -20,7 +20,7 @@
from ._utils import __version__
from torch.nn.functional import scaled_dot_product_attention
from transformers import __version__ as transformers_version
-from unsloth_zoo.utils import Version
+from unsloth_zoo.utils import Version, _get_dtype
transformers_version = Version(transformers_version)
# Transformers moved rotary embeddings out of all attention layers
IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1")
@@ -70,7 +70,8 @@
from huggingface_hub.utils._token import get_token
pass
from triton import __version__ as triton_version
-BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None
+HAS_XFORMERS = xformers is not None
+BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None
def original_apply_qkv(self, X):
@@ -89,6 +90,8 @@ def original_apply_o(self, X):
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
+# SDPA has GQA internally
+SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__
# Fix new HF's inference code
def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,):
@@ -243,7 +246,7 @@ def LlamaAttention_fast_forward_inference(
# Grouped query attention
_, _, cached_len, _ = Knn.shape
- if n_groups != 1:
+ if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
@@ -262,7 +265,10 @@ def LlamaAttention_fast_forward_inference(
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
else:
- A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
+ if SDPA_HAS_GQA:
+ A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True)
+ else:
+ A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
@@ -272,15 +278,15 @@ def LlamaAttention_fast_forward_inference(
torch_nn_functional_silu = torch.nn.functional.silu
-def fast_swiglu_inference(self, X):
+def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
bsz, _, hd = X.shape
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
- gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
- up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
+ gate = fast_linear_forward(self.gate_proj, X, out = temp_gate)
+ up = fast_linear_forward(self. up_proj, X, out = temp_up)
gate = torch_nn_functional_silu(gate, inplace = True)
gate *= up
@@ -289,14 +295,23 @@ def fast_swiglu_inference(self, X):
return down
pass
-
-def fast_rms_layernorm_inference(self, X):
+torch_square = torch.square
+torch_mean = torch.mean
+def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None):
old_dtype = X.dtype
- XX = X.to(torch.float32)
- variance = XX.square().mean(-1, keepdim = True)
+ if XX is None:
+ XX = X.to(torch.float32)
+ variance = XX.square().mean(-1, keepdim = True)
+ else:
+ XX.copy_(X)
+ torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance)
+ pass
variance += self.variance_epsilon
XX *= variance.rsqrt_()
- X = XX.to(old_dtype) # Must preserve due to residual
+
+ if XX is None: X = XX.to(old_dtype)
+ else: X.copy_(XX)
+
X *= self.weight
return X
pass
@@ -403,7 +418,7 @@ def LlamaAttention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
- if (not HAS_FLASH_ATTENTION and attention_mask is None):
+ if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
# Xformers memory efficient attention
# Also has Flash Attention v2 dispatching
Q = Q.transpose(1, 2)
@@ -636,6 +651,7 @@ def LlamaModel_fast_forward(
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
IS_COHERE = self.config.model_type.startswith("cohere")
IS_GRANITE = self.config.model_type.startswith("granite")
+
train_embed_tokens = self.embed_tokens.weight.requires_grad
if IS_GEMMA:
@@ -664,7 +680,7 @@ def LlamaModel_fast_forward(
# Fix up attention mask by setting elements to 0
# Specifically for DPO
- if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \
+ if getattr(self, "_has_no_labels", False) is True and (attention_mask is not None) and (past_key_values is None) and \
(not train_embed_tokens):
# Careful for inference the attention_mask is size (1, kv_seq_len)
# Whilst the input_embeds is size (1, 1, 4096)
@@ -792,9 +808,12 @@ def LlamaModel_fast_forward(
pass
pass
- if IS_ATTENTION_REFACTOR and not hasattr(self.layers[0].self_attn, "rotary_emb"):
+ if (IS_ATTENTION_REFACTOR and (hasattr(self, "rotary_emb") or not hasattr(self.layers[0].self_attn, "rotary_emb"))) or IS_GRANITE:
# Transformers main has made it mandatory to pass position_embeddings
# https://github.com/huggingface/transformers/pull/34858
+        # Also, transformers 4.45.0 added Granite support only via the attention refactor
+        # (Granite has always used it), and Unsloth's own Granite check likewise requires
+        # transformers >= 4.45.0, so Granite always takes the attention-refactor path here.
position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings)
else:
position_embeddings = None
@@ -898,15 +917,29 @@ def LlamaModel_fast_forward_inference(
attention_mask = None,
):
input_ids = input_ids[:,:self.max_seq_length]
- hidden_states = self.model.embed_tokens(input_ids)
- hidden_states = hidden_states.to(self.config.torch_dtype)
- bsz, q_len, hd = hidden_states.shape
+ bsz, q_len = input_ids.shape
+ hd = self.config.hidden_size
+ mlp_size = self.config.intermediate_size
+
+ X = self.model.embed_tokens(input_ids)
+ X = X.to(self.config.torch_dtype)
+ bsz, q_len, hd = X.shape
+ assert(q_len == 1)
+
+ # Get saved buffers to reduce memory movement
+ residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
+ _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
+ XX, XX2 = _XX[0], _XX[1]
+ variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0")
+ temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
+ temp_gate, temp_up = temp_mlp[0], temp_mlp[1]
+
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
- hidden_states,
+ X,
seq_len,
sliding_window = getattr(self.config, "sliding_window", None),
)
@@ -915,30 +948,54 @@ def LlamaModel_fast_forward_inference(
pass
next_decoder_cache = []
+
for idx, decoder_layer in enumerate(self.model.layers):
- residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
- hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
+ residual.copy_(X) # residual = X
+ X = fast_rms_layernorm_inference(
+ decoder_layer.input_layernorm,
+ X,
+ XX = XX,
+ XX2 = XX2,
+ variance = variance,
+ )
+ X, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
- hidden_states = hidden_states,
+ hidden_states = X,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
- hidden_states += residual
-
- residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
- hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
- hidden_states += residual
+ X += residual
+
+ residual.copy_(X) # residual = X
+ X = fast_rms_layernorm_inference(
+ decoder_layer.post_attention_layernorm,
+ X,
+ XX = XX,
+ XX2 = XX2,
+ variance = variance,
+ )
+ X = fast_swiglu_inference(
+ decoder_layer.mlp,
+ X,
+ temp_gate = temp_gate,
+ temp_up = temp_up,
+ )
+ X += residual
next_decoder_cache.append(present_key_value)
pass
- hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
+ X = fast_rms_layernorm_inference(
+ self.model.norm,
+ X,
+ XX = XX,
+ XX2 = XX2,
+ variance = variance,
+ )
return BaseModelOutputWithPast(
- last_hidden_state = hidden_states,
+ last_hidden_state = X,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
@@ -973,7 +1030,7 @@ def _CausalLM_fast_forward(
attention_mask = attention_mask,
)
else:
- causal_mask = xformers.attn_bias.LowerTriangularMask()
+ causal_mask = xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1155,7 +1212,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
# [TODO] Hack to pass in config - need to remove later
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
- dim = int((config.hidden_size // config.num_attention_heads))
+ dim = getattr(config, "head_dim", None)
+ if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
pass
@@ -1576,9 +1634,18 @@ def from_pretrained(
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
+
+ fast_inference = False, # uses vLLM
+ gpu_memory_utilization = 0.5,
+ float8_kv_cache = False,
+ random_state = 3407,
+ max_lora_rank = 16,
+ disable_log_stats = False,
**kwargs,
):
if trust_remote_code:
+ if fast_inference:
+ raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.")
print(
"Unsloth: WARNING `trust_remote_code` is True.\n"\
"Are you certain you want to do remote code execution?"
@@ -1592,9 +1659,9 @@ def from_pretrained(
statistics = \
f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.\n"\
- f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
- f"O^O/ \_/ \\ Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
- f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
+ f" {chr(92)}{chr(92)} /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
+ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
+ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
print(statistics)
@@ -1622,7 +1689,11 @@ def from_pretrained(
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
# RoPE Scaling
- model_config = AutoConfig.from_pretrained(model_name, token = token)
+ model_config = AutoConfig.from_pretrained(
+ model_name,
+ token = token,
+ attn_implementation = "sdpa",
+ )
model_max_seq_length = model_config.max_position_embeddings
# Check if RoPE Scaling is even allowed
@@ -1643,6 +1714,9 @@ def from_pretrained(
rope_scaling = max_seq_length / model_max_seq_length
+ if fast_inference:
+ raise NotImplementedError("Unsloth: Fast inference does not yet work with RoPE Scaling.")
+
logger.warning_once(
f"Unsloth: {model_name} can only handle sequence lengths of at most "\
f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\
@@ -1684,17 +1758,54 @@ def from_pretrained(
# Cannot be None, since HF now checks for the config
if load_in_4bit: kwargs["quantization_config"] = bnb_config
- model = AutoModelForCausalLM.from_pretrained(
- model_name,
- device_map = device_map,
- torch_dtype = dtype,
- # quantization_config = bnb_config,
- token = token,
- max_position_embeddings = max_position_embeddings,
- trust_remote_code = trust_remote_code,
- attn_implementation = "eager",
- **kwargs,
- )
+ if not fast_inference:
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ device_map = device_map,
+ torch_dtype = dtype,
+ # quantization_config = bnb_config,
+ token = token,
+ max_position_embeddings = max_position_embeddings,
+ trust_remote_code = trust_remote_code,
+ attn_implementation = "eager",
+ **kwargs,
+ )
+ else:
+ from unsloth_zoo.vllm_utils import (
+ load_vllm,
+ get_vllm_state_dict,
+ convert_vllm_to_huggingface,
+ generate_batches,
+ )
+ allowed_args = inspect.getfullargspec(load_vllm).args
+ load_vllm_kwargs = dict(
+ model_name = model_name,
+ config = model_config,
+ gpu_memory_utilization = gpu_memory_utilization,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ float8_kv_cache = float8_kv_cache,
+ enable_lora = True,
+ max_lora_rank = max_lora_rank,
+ disable_log_stats = disable_log_stats,
+ )
+ for allowed_arg in allowed_args:
+ if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
+ load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg]
+ pass
+
+ # Load vLLM first
+ llm = load_vllm(**load_vllm_kwargs)
+
+ # Convert to HF format
+ _, quant_state_dict = get_vllm_state_dict(llm, config = model_config)
+ model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype)
+ model.vllm_engine = llm
+ model.fast_generate = model.vllm_engine.generate
+
+ from functools import partial
+ model.fast_generate_batches = partial(generate_batches, model.vllm_engine)
+ pass
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
@@ -2190,6 +2301,20 @@ def get_peft_model(
modules_to_save = list(set(modules_to_save))
pass
+ vllm_engine = None
+ if hasattr(model, "vllm_engine"):
+ # Fast inference!
+ vllm_engine = model.vllm_engine
+ vllm_fast_generate = model.fast_generate
+ vllm_fast_generate_batches = model.fast_generate_batches
+
+ if modules_to_save is not None:
+ raise NotImplementedError("Unsloth: Currently fast inference does not work with training embeddings or lm_head.")
+
+ if bias != "none":
+ raise NotImplementedError("Unsloth: Currently fast inference does not work with using biases for LoRA.")
+ pass
+
# Get LoRA
arguments = dict(
r = r,
@@ -2296,6 +2421,19 @@ def get_peft_model(
torch.cuda.empty_cache()
pass
+ # Patch for fast inference
+ if vllm_engine is not None:
+ model.vllm_engine = vllm_engine
+ model.fast_generate = vllm_fast_generate
+ model.fast_generate_batches = vllm_fast_generate_batches
+
+ # Also saving and loading LoRA
+ from functools import partial
+ from unsloth_zoo.vllm_utils import save_lora, load_lora
+ model.save_lora = partial(save_lora, model)
+ model.load_lora = partial(load_lora, model)
+ pass
+
return model
pass
@@ -2505,18 +2643,24 @@ def for_inference(model):
# return
# pass
- internal_model = model
- internal_model.gradient_checkpointing = False
- internal_model.training = False
-
- while hasattr(internal_model, "model"):
- internal_model = internal_model.model
- internal_model.gradient_checkpointing = False
- internal_model.training = False
- pass
- if hasattr(internal_model, "training"):
- internal_model.training = False
- pass
+ m = model
+ while hasattr(m, "model"):
+ if hasattr(m, "gradient_checkpointing"):
+ m.gradient_checkpointing = False
+ if hasattr(m, "training"):
+ m.training = False
+ # Pad tokenizer to the left
+ if hasattr(m, "_saved_temp_tokenizer"):
+ m._saved_temp_tokenizer.padding_side = "left"
+ m = m.model
+ pass
+ if hasattr(m, "gradient_checkpointing"):
+ m.gradient_checkpointing = False
+ if hasattr(m, "training"):
+ m.training = False
+ # Pad tokenizer to the left
+ if hasattr(m, "_saved_temp_tokenizer"):
+ m._saved_temp_tokenizer.padding_side = "left"
# Also check if lm_head / embeddings are trained
internal_model = model
@@ -2525,30 +2669,13 @@ def for_inference(model):
pass
lm_head = internal_model.lm_head.weight
device_type = lm_head.device.type
- dtype = model.config.torch_dtype
-
- if type(dtype) is str:
- if dtype == "float16": dtype = torch.float16
- elif dtype == "bfloat16": dtype = torch.bfloat16
- pass
+ dtype = _get_dtype(model.config.torch_dtype)
# Wrap model.generate
if model.generate.__name__ != "_fast_generate":
model._unwrapped_old_generate = model.generate
model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)
pass
-
- # Patch tokenizer to pad to the left
- internal_model = model
- while hasattr(internal_model, "model"):
- if hasattr(internal_model, "_saved_temp_tokenizer"):
- internal_model._saved_temp_tokenizer.padding_side = "left"
- pass
- internal_model = internal_model.model
- pass
- if hasattr(internal_model, "_saved_temp_tokenizer"):
- internal_model._saved_temp_tokenizer.padding_side = "left"
- pass
# Also disable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
@@ -2566,9 +2693,6 @@ def for_inference(model):
@staticmethod
def for_training(model, use_gradient_checkpointing = True):
- internal_model = model
- internal_model.gradient_checkpointing = use_gradient_checkpointing
- internal_model.training = True
# Delete all fast inference loras
for param in model.parameters():
@@ -2576,14 +2700,24 @@ def for_training(model, use_gradient_checkpointing = True):
del param._fast_lora
pass
- while hasattr(internal_model, "model"):
- internal_model = internal_model.model
- internal_model.gradient_checkpointing = use_gradient_checkpointing
- internal_model.training = True
- pass
- if hasattr(internal_model, "training"):
- internal_model.training = True
- pass
+ m = model
+ while hasattr(m, "model"):
+ if hasattr(m, "gradient_checkpointing"):
+ m.gradient_checkpointing = use_gradient_checkpointing
+ if hasattr(m, "training"):
+ m.training = True
+ # Pad tokenizer to the right
+ if hasattr(m, "_saved_temp_tokenizer"):
+ m._saved_temp_tokenizer.padding_side = "right"
+ m = m.model
+ pass
+ if hasattr(m, "gradient_checkpointing"):
+ m.gradient_checkpointing = use_gradient_checkpointing
+ if hasattr(m, "training"):
+ m.training = True
+ # Pad tokenizer to the right
+ if hasattr(m, "_saved_temp_tokenizer"):
+ m._saved_temp_tokenizer.padding_side = "right"
# Also revert model.generate
if hasattr(model, "_unwrapped_old_generate"):
@@ -2591,18 +2725,6 @@ def for_training(model, use_gradient_checkpointing = True):
del model._unwrapped_old_generate
pass
- # Patch tokenizer to pad to the right
- internal_model = model
- while hasattr(internal_model, "model"):
- if hasattr(internal_model, "_saved_temp_tokenizer"):
- internal_model._saved_temp_tokenizer.padding_side = "right"
- pass
- internal_model = internal_model.model
- pass
- if hasattr(internal_model, "_saved_temp_tokenizer"):
- internal_model._saved_temp_tokenizer.padding_side = "right"
- pass
-
# Also re-enable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
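
The llama.py changes add preallocated inference buffers, optional xformers, and a vLLM-backed `fast_inference` path that exposes `model.fast_generate` and `model.fast_generate_batches`. A usage sketch of the new flags, assuming vLLM is installed (the model name and numbers are illustrative, not prescribed by this diff):

    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        "unsloth/Llama-3.2-3B-Instruct",
        max_seq_length = 2048,
        load_in_4bit = True,
        fast_inference = True,            # loads vLLM alongside the HF-format weights
        gpu_memory_utilization = 0.5,
        max_lora_rank = 16,
    )
    outputs = model.fast_generate(["Hello!"])   # delegates to model.vllm_engine.generate
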
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index e9caad0e6..39b367e27 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -30,11 +30,11 @@
from huggingface_hub.utils._token import get_token
pass
from huggingface_hub import HfFileSystem
+import importlib.util
# [TODO] Move USE_MODELSCOPE to utils
USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
if USE_MODELSCOPE:
- import importlib
if importlib.util.find_spec("modelscope") is None:
raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
pass
@@ -73,9 +73,25 @@ def from_pretrained(
resize_model_vocab = None,
revision = None,
use_exact_model_name = False,
+
+ fast_inference = False, # uses vLLM
+ gpu_memory_utilization = 0.5,
+ float8_kv_cache = False,
+ random_state = 3407,
+ max_lora_rank = 64,
+ disable_log_stats = True,
*args, **kwargs,
):
if token is None: token = get_token()
+
+ if fast_inference:
+ if importlib.util.find_spec("vllm") is None:
+ raise ImportError(
+ "Unsloth: Please install vLLM before enabling `fast_inference`!\n"\
+ "You can do this in a terminal via `pip install vllm`"
+ )
+ pass
+ pass
old_model_name = model_name
if not use_exact_model_name:
@@ -255,6 +271,24 @@ def from_pretrained(
tokenizer_name = None
pass
+ if fast_inference:
+ from unsloth_zoo.vllm_utils import (
+ patch_vllm,
+ vllm_dynamic_quant_supported,
+ )
+ patch_vllm()
+ if model_name.endswith("unsloth-bnb-4bit"):
+ if not vllm_dynamic_quant_supported(model_name, model_config):
+ # Instead use -bnb-4bit variant
+ print(
+ f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"\
+ f"we do not yet support fast inference for {model_name}"
+ )
+ model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit"
+ pass
+ pass
+ pass
+
model, tokenizer = dispatch_model.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
@@ -268,6 +302,13 @@ def from_pretrained(
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
revision = revision if not is_peft else None,
+
+ fast_inference = fast_inference,
+ gpu_memory_utilization = gpu_memory_utilization,
+ float8_kv_cache = float8_kv_cache,
+ random_state = random_state,
+ max_lora_rank = max_lora_rank,
+ disable_log_stats = disable_log_stats,
*args, **kwargs,
)
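
When `fast_inference = True` but vLLM cannot serve an Unsloth dynamic quant, the loader above falls back to the plain `-bnb-4bit` repository by rewriting the name. The string manipulation in isolation:

    model_name = "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit"
    if model_name.endswith("unsloth-bnb-4bit"):
        model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit"
    print(model_name)  # unsloth/Qwen2.5-7B-Instruct-bnb-4bit
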
diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py
index c1113f529..c81290b66 100644
--- a/unsloth/models/mapper.py
+++ b/unsloth/models/mapper.py
@@ -304,25 +304,30 @@
"unsloth/Mistral-Small-Instruct-2409",
"mistralai/Mistral-Small-Instruct-2409",
),
- "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : (
+ "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-0.5B-Instruct",
"Qwen/Qwen2.5-0.5B-Instruct",
+ "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit",
),
- "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : (
+ "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-1.5B-Instruct",
+ "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
),
- "unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : (
+ "unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-3B-Instruct",
+ "unsloth/Qwen2.5-3B-Instruct-bnb-4bit",
),
- "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : (
+ "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
+ "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
),
- "unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : (
+ "unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-14B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
+ "unsloth/Qwen2.5-14B-Instruct-bnb-4bit",
),
"unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-32B-Instruct",
@@ -332,25 +337,30 @@
"unsloth/Qwen2.5-72B-Instruct",
"Qwen/Qwen2.5-72B-Instruct",
),
- "unsloth/Qwen2.5-0.5B-bnb-4bit" : (
+ "unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-0.5B",
"Qwen/Qwen2.5-0.5B",
+ "unsloth/Qwen2.5-0.5B-bnb-4bit",
),
- "unsloth/Qwen2.5-1.5B-bnb-4bit" : (
+ "unsloth/Qwen2.5-1.5B-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-1.5B",
"Qwen/Qwen2.5-1.5B",
+ "unsloth/Qwen2.5-1.5B-bnb-4bit",
),
- "unsloth/Qwen2.5-3B-bnb-4bit" : (
+ "unsloth/Qwen2.5-3B-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-3B",
"Qwen/Qwen2.5-3B",
+ "unsloth/Qwen2.5-3B-bnb-4bit",
),
- "unsloth/Qwen2.5-7B-bnb-4bit" : (
+ "unsloth/Qwen2.5-7B-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-7B",
"Qwen/Qwen2.5-7B",
+ "unsloth/Qwen2.5-7B-bnb-4bit",
),
- "unsloth/Qwen2.5-14B-bnb-4bit" : (
+ "unsloth/Qwen2.5-14B-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-14B",
"Qwen/Qwen2.5-14B",
+ "unsloth/Qwen2.5-14B-bnb-4bit",
),
"unsloth/Qwen2.5-32B-bnb-4bit" : (
"unsloth/Qwen2.5-32B",
@@ -432,21 +442,25 @@
"unsloth/Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-Coder-32B-Instruct",
),
- "unsloth/Llama-3.2-1B-bnb-4bit" : (
+ "unsloth/Llama-3.2-1B-unsloth-bnb-4bit" : (
"unsloth/Llama-3.2-1B",
"meta-llama/Llama-3.2-1B",
+ "unsloth/Llama-3.2-1B-bnb-4bit",
),
- "unsloth/Llama-3.2-3B-bnb-4bit" : (
+ "unsloth/Llama-3.2-3B-unsloth-bnb-4bit" : (
"unsloth/Llama-3.2-3B",
"meta-llama/Llama-3.2-3B",
+ "unsloth/Llama-3.2-3B-bnb-4bit",
),
- "unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : (
+ "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit" : (
"unsloth/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct",
+ "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
),
- "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : (
+ "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit" : (
"unsloth/Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct",
+ "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
),
"unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : (
"unsloth/Llama-3.1-Nemotron-70B-Instruct",
@@ -471,20 +485,18 @@
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
),
- "unsloth/Llama-3.2-90B-Vision-Instruct-unsloth-bnb-4bit" : (
+ "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit" : (
"unsloth/Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct",
- "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit",
),
"unsloth/Llama-3.2-11B-Vision-unsloth-bnb-4bit" : (
"unsloth/Llama-3.2-11B-Vision",
"meta-llama/Llama-3.2-11B-Vision",
"unsloth/Llama-3.2-11B-Vision-bnb-4bit",
),
- "unsloth/Llama-3.2-90B-Vision-unsloth-bnb-4bit" : (
+ "unsloth/Llama-3.2-90B-Vision-bnb-4bit" : (
"unsloth/Llama-3.2-90B-Vision",
"meta-llama/Llama-3.2-90B-Vision",
- "unsloth/Llama-3.2-90B-Vision-bnb-4bit",
),
"unsloth/Pixtral-12B-2409-unsloth-bnb-4bit" : (
"unsloth/Pixtral-12B-2409",
@@ -524,6 +536,59 @@
"microsoft/phi-4",
"unsloth/phi-4-bnb-4bit",
),
+ "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit" : (
+ "unsloth/DeepSeek-R1-Distill-Qwen-32B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+ ),
+ "unsloth/DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit" : (
+ "unsloth/DeepSeek-R1-Distill-Qwen-14B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+ "unsloth/DeepSeek-R1-Distill-Qwen-14B-bnb-4bit",
+ ),
+ "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit" : (
+ "unsloth/DeepSeek-R1-Distill-Qwen-7B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+ "unsloth/DeepSeek-R1-Distill-Qwen-7B-bnb-4bit",
+ ),
+ "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-unsloth-bnb-4bit" : (
+ "unsloth/DeepSeek-R1-Distill-Qwen-1.5B",
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+ "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-bnb-4bit",
+ ),
+ "unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit" : (
+ "unsloth/DeepSeek-R1-Distill-Llama-8B",
+ "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+ "unsloth/DeepSeek-R1-Distill-Llama-8B-bnb-4bit",
+ ),
+ "unsloth/DeepSeek-R1-Distill-Llama-70B-bnb-4bit" : (
+ "unsloth/DeepSeek-R1-Distill-Llama-70B",
+ "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
+ ),
+ "unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit" : (
+ "unsloth/Mistral-Small-24B-Base-2501",
+ "mistralai/Mistral-Small-24B-Base-2501",
+ "unsloth/Mistral-Small-24B-Base-2501-bnb-4bit",
+ ),
+ "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit" : (
+ "unsloth/Mistral-Small-24B-Instruct-2501",
+ "mistralai/Mistral-Small-24B-Instruct-2501",
+ "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit",
+ ),
+ "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit" : (
+ "unsloth/Qwen2.5-VL-3B-Instruct",
+ "Qwen/Qwen2.5-VL-3B-Instruct",
+ "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit",
+ ),
+ "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit" : (
+ "unsloth/Qwen2.5-VL-7B-Instruct",
+ "Qwen/Qwen2.5-VL-7B-Instruct",
+ "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit",
+ ),
+ "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit" : (
+ "unsloth/Qwen2.5-VL-72B-Instruct",
+ "Qwen/Qwen2.5-VL-72B-Instruct",
+ "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit",
+ ),
}
INT_TO_FLOAT_MAPPER = {}
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index 9a97015f9..784ca9cb4 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -304,7 +304,7 @@ def pre_patch():
attention_module = MistralAttention,
)
# Just for Mistral Nemo models!
- if function is not None:
+ if function is not None and init_name is not None:
function = patch_mistral_nemo_attention(function)
# if True:#init_name is not None:
exec(function, globals())
diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
new file mode 100644
index 000000000..515c6587f
--- /dev/null
+++ b/unsloth/models/rl.py
@@ -0,0 +1,423 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = [
+ "PatchFastRL",
+]
+
+METRICS_MOVE_TO_END = [
+ "nll",
+ "aux",
+ "beta",
+ "alpha",
+]
+import torch
+try:
+ from transformers.utils.notebook import (
+ IntervalStrategy,
+ NotebookTrainingTracker,
+ NotebookProgressCallback,
+ )
+ HAS_NOTEBOOK = True
+except:
+ HAS_NOTEBOOK = False
+pass
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+import inspect
+import os
+import re
+import functools
+from unsloth_zoo.compiler import create_new_function
+
+
+def PatchRL(FastLanguageModel):
+
+ from trl.models.utils import unwrap_model_for_generation
+ from contextlib import contextmanager
+
+ @contextmanager
+ def unsloth_unwrap_model_for_generation(model, *args, **kwargs):
+ with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model:
+ # Put the model in inference mode.
+ FastLanguageModel.for_inference(unwrapped_model)
+
+            # We must use .clone() for Unsloth since we force inference_mode
+            # (ideally no_grad would have been used instead).
+ original_generate = unwrapped_model.generate
+ def generate_with_clone(*args, **kwargs):
+ out = original_generate(*args, **kwargs)
+ if isinstance(out, torch.Tensor):
+ return out.clone()
+ return out
+ pass
+ unwrapped_model.generate = generate_with_clone
+
+ try:
+ yield unwrapped_model
+ finally:
+ # Restore generate and return
+ unwrapped_model.generate = original_generate
+ FastLanguageModel.for_training(model)
+ pass
+ pass
+ pass
+
+ import trl.trainer
+ trainers = dir(trl.trainer)
+ trainers = [x for x in trainers if x.endswith("_trainer")]
+ unwrap = "unwrap_model_for_generation"
+ for trainer in trainers:
+ if hasattr(eval(f"trl.trainer.{trainer}"), unwrap):
+ exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}")
+ pass
+pass
+
+
+def NotebookProgressCallback_on_train_begin(Trainer_metrics):
+ def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs):
+ self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"
+ self.training_loss = 0
+ self.last_log = 0
+ column_names = [self.first_column] + ["Training Loss"]
+ if args.eval_strategy != IntervalStrategy.NO:
+ column_names.append("Validation Loss")
+ column_names += [x.replace("/", " / ") for x in Trainer_metrics]
+ self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
+ pass
+ return _NotebookProgressCallback_on_train_begin
+pass
+
+
+def NotebookProgressCallback_on_log(Trainer_metrics):
+ def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs):
+ # Only for when there is no evaluation
+ if args.eval_strategy == IntervalStrategy.NO and "loss" in logs:
+ values = {"Training Loss": logs["loss"]}
+ for metric in Trainer_metrics:
+ # Sometimes metric is not inside logs
+ try: values[metric.replace("/", " / ")] = logs[metric]
+ except: pass
+ pass
+ # First column is necessarily Step since we're not in epoch eval strategy
+ values["Step"] = state.global_step
+ self.training_tracker.write_line(values)
+ pass
+ pass
+ return _NotebookProgressCallback_on_log
+pass
+
+
+def NotebookTrainingTracker_write_line(Trainer_metrics):
+ set_Trainer_metrics = set(Trainer_metrics)
+ def _NotebookTrainingTracker_write_line(self, values):
+ """
+ Write the values in the inner table.
+
+ Args:
+ values (`Dict[str, float]`): The values to display.
+ """
+ if self.inner_table is None:
+ self.inner_table = [list(values.keys()), list(values.values())]
+ else:
+ columns = self.inner_table[0]
+ new_values = {}
+ for key, value in values.items():
+ lowered = key.lower()
+ if lowered in set_Trainer_metrics:
+ new_values[lowered.replace("/", " / ")] = value
+ else:
+ new_values[key] = value
+ pass
+ values = new_values
+
+ self.inner_table[0] = columns
+ if len(self.inner_table) > 1:
+ last_values = self.inner_table[-1]
+ first_column = self.inner_table[0][0]
+ if last_values[0] != values[first_column]:
+ # write new line
+ self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
+ else:
+ # update last line
+ new_values = values
+ for c in columns:
+ if c not in new_values.keys():
+ new_values[c] = last_values[columns.index(c)]
+ self.inner_table[-1] = [new_values[c] for c in columns]
+ else:
+ # Edit for evaluation purposes
+ self.inner_table.append([values[c] if c in values else 0 for c in columns])
+ pass
+ pass
+ pass
+ return _NotebookTrainingTracker_write_line
+pass
+
+
+def _PatchRLStatistics(metrics, algorithm):
+ if HAS_NOTEBOOK:
+ if len(metrics) == 0:
+ raise RuntimeError(f"Unsloth: RL statistics for {algorithm} failed with no metrics seen?")
+ from transformers.trainer import is_in_notebook
+ if is_in_notebook():
+            # Patch RL trainer notebook printing
+ NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line(metrics)
+ from transformers.trainer import DEFAULT_PROGRESS_CALLBACK
+ DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin(metrics)
+ DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log(metrics)
+ pass
+ pass
+pass
+
+
+@functools.cache
+def get_trl_metrics():
+ # Gets metrics so we can output them in notebooks
+
+ import trl.trainer
+ trainers = dir(trl.trainer)
+ trainers = [x for x in trainers if x.endswith("_trainer")]
+ filepath = inspect.getfile(trl.trainer)
+ filepath = os.path.split(filepath)[0]
+
+ all_metrics = dict()
+ for trainer in trainers:
+ filename = os.path.join(filepath, f"{trainer}.py")
+ if not os.path.exists(filename): continue
+ with open(filename, "r") as file: file = file.read()
+
+ # Get metrics['kl'] or stats['kl']
+ metrics = re.findall(r"metrics\[[\"\']([^\"\']{1,})[\"\']\]", file)
+ stats = re.findall(r"stats\[[\"\']([^\"\']{1,})[\"\']\]", file)
+ metrics = metrics + stats
+
+ # Get optional f-strings
+ metrics_f = re.findall(r"metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file)
+ stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file)
+ metrics_f = metrics_f + stats_f
+ # Filter out prefixes if seen
+ # metrics[f"{prefix}rewards/chosen"]
+ left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file
+ if left_prefix: metrics += metrics_f
+
+ # Move all eval_ things to the end and reward to the front
+ beginning = []
+ middle = []
+ end = []
+ for x in metrics:
+ lowered = x.lower()
+ if "reward" in lowered:
+ beginning.append(x)
+ elif x.lower().startswith("eval"):
+ end.append(x)
+ else:
+ # Check if we want to move to the end
+ moved = False
+ for move_end in METRICS_MOVE_TO_END:
+ if move_end in lowered:
+ end.append(x)
+ moved = True
+ break
+ if not moved:
+ middle.append(x)
+ pass
+ pass
+ metrics = beginning + middle + end
+
+ all_metrics[trainer[:trainer.find("_")].upper()] = metrics
+ pass
+ return all_metrics
+pass
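+
+# For example, a trainer module that assigns `metrics[f"{prefix}rewards/chosen"]`
+# and `metrics[f"{prefix}rewards/margins"]` would, after the reordering above,
+# contribute roughly all_metrics["DPO"] = ["rewards/chosen", "rewards/margins", ...].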
+
+
+def PatchRLStatistics(algorithm = "GRPO"):
+ # Get notebook statistics columns to show up
+ algorithm = algorithm.upper()
+ all_metrics = get_trl_metrics()
+ if algorithm not in all_metrics:
+        print(
+            f"Unsloth for {algorithm.upper()} is not yet implemented, so its statistics patch is skipped.\n"\
+            f"We support: `{list(all_metrics.keys())}`"
+        )
+        return
+    pass
+ _PatchRLStatistics(all_metrics[algorithm], algorithm)
+pass
+
+
+def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
+ # Patch for vLLM and Unsloth PEFT
+ import trl
+ import trl.trainer
+
+ trainer = eval(f"trl.trainer.{trainer_file}")
+ name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()]
+ assert(len(name) == 1)
+ RLTrainer_name = name[0]
+ RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
+
+ try:
+ __init__ = inspect.getsource(RLTrainer.__init__)
+ except:
+ # Already patched most likely!
+ return
+ old__init__ = __init__
+ all_imports = dir(trainer)
+ assert("Union" in all_imports)
+ imports = [x for x in all_imports if not x.startswith("_")]
+ imports += ["Trainer"]
+
+ spaces = __init__.find("def")
+ __init__ = __init__.split("\n")
+ __init__ = "\n".join(x[spaces:] for x in __init__)
+
+ # Replace vLLM sections since we already have it done!
+ vllm_part = re.findall(
+ r"(\n[\s]{4}"\
+ r"if (self|args)\.use_vllm\:.+?"\
+ r"\n[\s]{4,}"\
+ "else:\n)",
+ __init__,
+ flags = re.MULTILINE | re.DOTALL,
+ )
+ if (len(vllm_part) != 1): return
+
+ vllm_part, args = vllm_part[0][0], vllm_part[0][1]
+ # Strip all comments
+ new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part)
+
+ # Get SamplingParams
+ sampling_params = re.findall(
+ r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\
+ r"SamplingParams\(.+?\))",
+ new_vllm_part,
+ flags = re.MULTILINE | re.DOTALL,
+ )
+ if len(sampling_params) != 1: return
+
+ sampling_params = sampling_params[0]
+ # Replace with our vLLM engine
+ sampling_params = \
+ " "*8 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \
+ sampling_params # Add spaces
+ new_vllm_part = f"\n if {args}.use_vllm:\n{sampling_params}\n else:\n"
+ __init__ = __init__.replace(vllm_part, new_vllm_part)
+
+ # Remove peft_config
+ __init__ = __init__.replace("elif peft_config is None:", "elif False:")
+ __init__ = __init__.replace("elif peft_config is not None:", "elif False:")
+ __init__ = __init__.replace("if peft_config is None:", "if False:")
+ __init__ = __init__.replace("if peft_config is not None:", "if False:")
+ __init__ = __init__.replace("get_peft_model(model, peft_config)", "model")
+
+ # Add spaces back into __init__
+ __init__ = __init__.split("\n")
+ __init__ = "\n".join(' '*spaces + x for x in __init__)
+
+ # Search for vLLM calling in all child functions
+ functions = dir(RLTrainer)
+ RLTrainer_source = inspect.getsource(RLTrainer)
+ functions = [x for x in functions if f"def {x}" in RLTrainer_source]
+
+ changed = {"__init__" : (old__init__, __init__,)}
+ for function in functions:
+ if not hasattr(RLTrainer, function): continue
+ fx = getattr(RLTrainer, function)
+ try:
+ source = inspect.getsource(fx)
+ except:
+ continue
+ original_source = source
+
+ # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+ source = re.sub(
+ r"(\n[\s]{4,}).+?model_executor\.driver_worker.+?\n",
+ r"\n\1pass\n",
+ source,
+ )
+
+ # llm_model.load_weights(model.state_dict().items())
+ source = re.sub(
+ r"(\n[\s]{4,}).+?load_weights\(.+?\n",
+ r"\n\1pass\n",
+ source,
+ )
+
+ # .state_dict()
+ source = re.sub(
+ r"\.state_dict\(\)",
+ r"",
+ source,
+ )
+
+ # Replace self.llm.generate and self.llm.chat
+ lora_name = trainer_file + "_lora_model"
+ source = re.sub(
+ r"(self\.llm\.(?:generate|chat)\([^\)]{1,})\)",
+ r"\1, lora_request = self.model.load_lora('" + lora_name + r"', load_tensors = True))",
+ source
+ )
+
+ # Skip if no changes done
+ if source == original_source: continue
+
+ # Find all imports
+ imports += [x for x in all_imports if not x.startswith("_") and x in source]
+
+ changed[function] = (original_source, source,)
+ pass
+
+ # Import all functions
+ imports = list(set(imports))
+
+ # Patch all functions
+ for function in changed:
+ old, new = changed[function]
+ RLTrainer_source = RLTrainer_source.replace(old, new)
+ pass
+ RLTrainer_source = RLTrainer_source.replace(
+ f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1
+ )
+
+ # Create new class in compiled cache and import it
+ module = create_new_function(
+ RLTrainer_name,
+ RLTrainer_source,
+ f"trl.trainer.{trainer_file}",
+ imports,
+ )
+
+ # Patch over modules
+ exec(f"trl.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals())
+ exec(f"trl.trainer.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals())
+ exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals())
+ return module
+pass
+
+
+def patch_trl_rl_trainers():
+ # Patch all TRL modules if they have vLLM or PEFT
+ import trl.trainer
+ all_trainers = dir(trl.trainer)
+ all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")]
+ for trainer in all_trainers:
+ _patch_trl_rl_trainers(trainer)
+ return
+pass
+
+
+def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None):
+ if FastLanguageModel is not None: PatchRL(FastLanguageModel)
+ patch_trl_rl_trainers()
+ PatchRLStatistics(algorithm)
+pass
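
An end-to-end sketch of the new RL patching flow, assuming a TRL version that ships `GRPOTrainer` (training arguments are omitted; only the patching order matters here):

    from unsloth import FastLanguageModel
    from unsloth.models.rl import PatchFastRL

    # Patch TRL's trainers and the notebook metric columns before building a trainer.
    PatchFastRL("GRPO", FastLanguageModel)

    from trl import GRPOConfig, GRPOTrainer   # now resolves to the patched UnslothGRPOTrainer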