diff --git a/README.md b/README.md index 9e288ad2d..3e635dabe 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ -### Finetune Llama 3.2, Mistral, Phi-3.5, Qwen 2.5 & Gemma 2-5x faster with 80% less memory! +### Finetune Llama 3.3, Mistral, Phi-4, Qwen 2.5 & Gemma 2-5x faster with 80% less memory! ![](https://i.ibb.co/sJ7RhGG/image-41.png) @@ -22,41 +22,41 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and | Unsloth supports | Free Notebooks | Performance | Memory use | |-----------|---------|--------|----------| -| **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb) | 2x faster | 60% less | -| **Phi-4** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4-Conversational.ipynb) | 2x faster | 50% less | -| **Llama 3.2 Vision (11B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb) | 2x faster | 40% less | -| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 60% less | -| **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma2_(9B)-Alpaca.ipynb) | 2x faster | 63% less | -| **Qwen 2.5 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynb) | 2x faster | 63% less | -| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-Conversational.ipynb) | 2.2x faster | 73% less | -| **Ollama** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb) | 1.9x faster | 43% less | -| **ORPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-ORPO.ipynb) | 1.9x faster | 43% less | -| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Zephyr_(7B)-DPO.ipynb) | 1.9x faster | 43% less | +| **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb) | 2x faster | 70% less | +| **Phi-4 (14B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4-Conversational.ipynb) | 2x faster | 70% less | +| **Llama 3.2 Vision (11B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb) | 2x faster | 50% less | +| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 70% less | +| **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma2_(9B)-Alpaca.ipynb) | 2x faster | 70% less | +| **Qwen 2.5 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynb) | 2x faster | 70% less | +| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-Conversational.ipynb) | 2.2x faster | 75% less | +| **Ollama** | [▶️ Start for 
free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb) | 1.9x faster | 60% less | +| **ORPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-ORPO.ipynb) | 1.9x faster | 50% less | +| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Zephyr_(7B)-DPO.ipynb) | 1.9x faster | 50% less | - See [all our notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks) and [all our models](https://docs.unsloth.ai/get-started/all-our-models) - **Kaggle Notebooks** for [Llama 3.2 Kaggle notebook](https://www.kaggle.com/danielhanchen/kaggle-llama-3-2-1b-3b-unsloth-notebook), [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Gemma 2 (9B)](https://www.kaggle.com/code/danielhanchen/kaggle-gemma-7b-unsloth-notebook/), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook) - Run notebooks for [Llama 3.2 conversational](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb), [Llama 3.1 conversational](https://colab.research.google.com/drive/15OyFkGoCImV9dSsewU1wa2JuKB4-mDE_?usp=sharing) and [Mistral v0.3 ChatML](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing) -- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text +- This [text completion notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_(7B)-Text_Completion.ipynb) is for continued pretraining / raw text - This [continued pretraining notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-CPT.ipynb) is for learning another language - Click [here](https://docs.unsloth.ai/) for detailed documentation for Unsloth. ## 🦥 Unsloth.ai News +- 📣 NEW! [DeepSeek-R1](https://unsloth.ai/blog/deepseek-r1) - the most powerful open reasoning models with Llama & Qwen distillations. Run or fine-tune them now! More details: [unsloth.ai/blog/deepseek-r1](https://unsloth.ai/blog/deepseek-r1). All model uploads: [here](https://huggingface.co/collections/unsloth/deepseek-r1-all-versions-678e1c48f5d2fce87892ace5). - 📣 NEW! [Phi-4](https://unsloth.ai/blog/phi4) by Microsoft is now supported. We also [fixed bugs](https://unsloth.ai/blog/phi4) in Phi-4 and [uploaded GGUFs, 4-bit](https://huggingface.co/collections/unsloth/phi-4-all-versions-677eecf93784e61afe762afa). Try the [Phi-4 Colab notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4-Conversational.ipynb) - 📣 NEW! [Llama 3.3 (70B)](https://huggingface.co/collections/unsloth/llama-33-all-versions-67535d7d994794b9d7cf5e9f), Meta's latest model is supported. - 📣 NEW! We worked with Apple to add [Cut Cross Entropy](https://arxiv.org/abs/2411.09009). Unsloth now supports 89K context for Meta's Llama 3.3 (70B) on a 80GB GPU - 13x longer than HF+FA2. For Llama 3.1 (8B), Unsloth enables 342K context, surpassing its native 128K support. -- 📣 NEW! Introducing Unsloth [Dynamic 4-bit Quantization](https://unsloth.ai/blog/dynamic-4bit)! We dynamically opt not to quantize certain parameters and this greatly increases accuracy while only using <10% more VRAM than BnB 4-bit. 
See our collection on [Hugging Face here.](https://huggingface.co/collections/unsloth/unsloth-4-bit-dynamic-quants-67503bb873f89e15276c44e7) -- 📣 NEW! [Vision models](https://unsloth.ai/blog/vision) now supported! [Llama 3.2 Vision (11B)](https://colab.research.google.com/drive/1j0N4XTY1zXXy7mPAhOC1_gMYZ2F2EBlk?usp=sharing), [Qwen 2.5 VL (7B)](https://colab.research.google.com/drive/1whHb54GNZMrNxIsi2wm2EY_-Pvo2QyKh?usp=sharing) and [Pixtral (12B) 2409](https://colab.research.google.com/drive/1K9ZrdwvZRE96qGkCq_e88FgV3MLnymQq?usp=sharing) -- 📣 NEW! Qwen-2.5 including [Coder](https://colab.research.google.com/drive/18sN803sU23XuJV9Q8On2xgqHSer6-UZF?usp=sharing) models are now supported with bugfixes. 14b fits in a Colab GPU! [Qwen 2.5 conversational notebook](https://colab.research.google.com/drive/1qN1CEalC70EO1wGKhNxs1go1W9So61R5?usp=sharing) -- 📣 NEW! We found and helped fix a [gradient accumulation bug](https://unsloth.ai/blog/gradient)! Please update Unsloth and transformers. +- 📣 Introducing Unsloth [Dynamic 4-bit Quantization](https://unsloth.ai/blog/dynamic-4bit)! We dynamically opt not to quantize certain parameters and this greatly increases accuracy while only using <10% more VRAM than BnB 4-bit. See our collection on [Hugging Face here.](https://huggingface.co/collections/unsloth/unsloth-4-bit-dynamic-quants-67503bb873f89e15276c44e7) +- 📣 [Vision models](https://unsloth.ai/blog/vision) now supported! [Llama 3.2 Vision (11B)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb), [Qwen 2.5 VL (7B)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2_VL_(7B)-Vision.ipynb) and [Pixtral (12B) 2409](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Pixtral_(12B)-Vision.ipynb)
Click for more news -- 📣 Try out [Chat interface](https://colab.research.google.com/drive/1i-8ESvtLRGNkkUQQr_-z_rcSAIo9c3lM?usp=sharing)! -- 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/drive/1oCEHcED15DzL8xXGU1VTx5ZfOJM8WY01?usp=sharing) finetuning fits in under 16GB of VRAM! -- 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) & [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) both Base and Instruct are now supported +- 📣 We found and helped fix a [gradient accumulation bug](https://unsloth.ai/blog/gradient)! Please update Unsloth and transformers. +- 📣 Try out [Chat interface](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Unsloth_Studio.ipynb)! +- 📣 NEW! Qwen-2.5 including [Coder](https://unsloth.ai/blog/qwen-coder) models are now supported with bugfixes. 14b fits in a Colab GPU! [Qwen 2.5 conversational notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_Coder_(14B)-Conversational.ipynb) +- 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_Small_(22B)-Alpaca.ipynb) finetuning fits in under 16GB of VRAM! - 📣 NEW! `pip install unsloth` now works! Head over to [pypi](https://pypi.org/project/unsloth/) to check it out! This allows non git pull installs. Use `pip install unsloth[colab-new]` for non dependency installs. -- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) for other languages like Korean! -- 📣 [2x faster inference](https://colab.research.google.com/drive/1aqlNQi7MMJbynFDyOQteD2t0yVfjb9Zh?usp=sharing) added for all our models +- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-CPT.ipynb) for other languages like Korean! +- 📣 [2x faster inference](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Inference.ipynb) added for all our models - 📣 We cut memory usage by a [further 30%](https://unsloth.ai/blog/long-context) and now support [4x longer context windows](https://unsloth.ai/blog/long-context)!
@@ -82,23 +82,17 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and ## 🥇 Performance Benchmarking -- For the full list of **reproducible** benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables) +- For our most detailed benchmarks, read our [Llama 3.3 Blog](https://unsloth.ai/blog/llama3-3). +- Benchmarking of Unsloth was also conducted by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl). -| 1 A100 40GB | 🤗Hugging Face | Flash Attention | 🦥Unsloth Open Source | 🦥[Unsloth Pro](https://unsloth.ai/pricing) | -|--------------|--------------|-----------------|---------------------|-----------------| -| Alpaca | 1x | 1.04x | 1.98x | **15.64x** | -| LAION Chip2 | 1x | 0.92x | 1.61x | **20.73x** | -| OASST | 1x | 1.19x | 2.17x | **14.83x** | -| Slim Orca | 1x | 1.18x | 2.22x | **14.82x** | - -- Benchmarking table below was conducted by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl). +We tested using the Alpaca Dataset, a batch size of 2, gradient accumulation steps of 4, rank = 32, and applied QLoRA on all linear layers (q, k, v, o, gate, up, down): + +| Model | VRAM | 🦥 Unsloth speed | 🦥 VRAM reduction | 🦥 Longer context | 😊 Hugging Face + FA2 | +|----------------|-------|-----------------|----------------|----------------|--------------------| +| Llama 3.3 (70B)| 80GB | 2x | >75% | 13x longer | 1x | +| Llama 3.1 (8B) | 80GB | 2x | >70% | 12x longer | 1x | -| Free Colab T4 | Dataset | 🤗Hugging Face | Pytorch 2.1.1 | 🦥Unsloth | 🦥 VRAM reduction | -| --- | --- | --- | --- | --- | --- | -| Llama-2 7b | OASST | 1x | 1.19x | 1.95x | -43.3% | -| Mistral 7b | Alpaca | 1x | 1.07x | 1.56x | -13.7% | -| Tiny Llama 1.1b | Alpaca | 1x | 2.06x | 3.87x | -73.8% | -| DPO with Zephyr | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% | +
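
For context, the benchmark configuration described above maps onto Unsloth's standard QLoRA recipe. The sketch below is illustrative only, not the exact benchmark harness: the checkpoint name, the one-row toy dataset and `max_seq_length` are placeholder assumptions; it simply mirrors batch size 2, gradient accumulation 4, rank 32 and LoRA on all linear projections.

```python
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import Dataset
from trl import SFTTrainer
from transformers import TrainingArguments

# Placeholder checkpoint; load the base weights in 4-bit for QLoRA.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B",
    max_seq_length = 2048,
    load_in_4bit = True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 32,                                  # rank = 32, as in the table above
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 32,
)

# Toy Alpaca-style dataset with a single formatted "text" column.
dataset = Dataset.from_dict({"text": ["### Instruction:\nSay hi.\n\n### Response:\nHi!"]})

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = 2048,
    args = TrainingArguments(
        per_device_train_batch_size = 2,     # batch size of 2
        gradient_accumulation_steps = 4,     # gradient accumulation steps of 4
        max_steps = 10,
        learning_rate = 2e-4,
        optim = "adamw_8bit",
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        output_dir = "outputs",
    ),
)
trainer.train()
```

With newer `trl` releases the `dataset_text_field` / `max_seq_length` arguments move into `SFTConfig`, so adjust to the installed version.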
![](https://i.ibb.co/sJ7RhGG/image-41.png) @@ -359,119 +353,28 @@ dpo_trainer.train() ``` ## 🥇 Detailed Benchmarking Tables -- Click "Code" for fully reproducible examples -- "Unsloth Equal" is a preview of our PRO version, with code stripped out. All settings and the loss curve remains identical. -- For the full list of benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables) - -| 1 A100 40GB | 🤗Hugging Face | Flash Attention 2 | 🦥Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max | -|--------------|-------------|-------------|-----------------|--------------|---------------|-------------| -| Alpaca | 1x | 1.04x | 1.98x | 2.48x | 5.32x | **15.64x** | -| code | [Code](https://colab.research.google.com/drive/1u4dBeM-0vGNVmmO6X7cScAut-Hyt4KDF?usp=sharing) | [Code](https://colab.research.google.com/drive/1fgTOxpMbVjloQBvZyz4lF4BacKSZOB2A?usp=sharing) | [Code](https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing) | [Code](https://colab.research.google.com/drive/1ANW8EFL3LVyTD7Gq4TkheC1Z7Rxw-rHp?usp=sharing) | | | -| seconds| 1040 | 1001 | 525 | 419 | 196 | 67 | -| memory MB| 18235 | 15365 | 9631 | 8525 | | | -| % saved| | 15.74 | 47.18 | 53.25 | | | | - -### Llama-Factory 3rd party benchmarking -- [Link to performance table.](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-Comparison) TGS: tokens per GPU per second. Model: LLaMA2-7B. GPU: NVIDIA A100 * 1. Batch size: 4. Gradient accumulation: 2. LoRA rank: 8. Max length: 1024. - -| Method | Bits | TGS | GRAM | Speed | -| --- | --- | --- | --- | --- | -| HF | 16 | 2392 | 18GB | 100% | -| HF+FA2 | 16 | 2954 | 17GB | 123% | -| Unsloth+FA2 | 16 | 4007 | 16GB | **168%** | -| HF | 4 | 2415 | 9GB | 101% | -| Unsloth+FA2 | 4 | 3726 | 7GB | **160%** | - -### Performance comparisons between popular models -
- Click for specific model benchmarking tables (Mistral 7b, CodeLlama 34b etc.) - -### Mistral 7b -| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max | -|--------------|-------------|-------------|-----------------|--------------|---------------|-------------| -| Mistral 7B Slim Orca | 1x | 1.15x | 2.15x | 2.53x | 4.61x | **13.69x** | -| code | [Code](https://colab.research.google.com/drive/1mePk3KzwTD81hr5mcNcs_AX3Kbg_Ha0x?usp=sharing) | [Code](https://colab.research.google.com/drive/1dgHxjvTmX6hb0bPcLp26RXSE6_n9DKj7?usp=sharing) | [Code](https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing) | [Code](https://colab.research.google.com/drive/18yOiyX0T81mTwZqOALFSCX_tSAqju6aD?usp=sharing) | | -| seconds | 1813 | 1571 | 842 | 718 | 393 | 132 | -| memory MB | 32853 | 19385 | 12465 | 10271 | | | -| % saved| | 40.99 | 62.06 | 68.74 | | | - -### CodeLlama 34b -| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max | -|--------------|-------------|-------------|-----------------|--------------|---------------|-------------| -| Code Llama 34B | OOM ❌ | 0.99x | 1.87x | 2.61x | 4.27x | 12.82x | -| code | [▶️ Code](https://colab.research.google.com/drive/1ykfz3BqrtC_AUFegCzUQjjfUNlxp6Otc?usp=sharing) | [Code](https://colab.research.google.com/drive/12ZypxQh7OC6kBXvWZI-5d05I4m-B_hoR?usp=sharing) | [Code](https://colab.research.google.com/drive/1gdHyAx8XJsz2yNV-DHvbHjR1iCef5Qmh?usp=sharing) | [Code](https://colab.research.google.com/drive/1fm7wqx9MJ0kRrwKOfmLkK1Rmw-pySahB?usp=sharing) | | -| seconds | 1953 | 1982 | 1043 | 748 | 458 | 152 | -| memory MB | 40000 | 33217 | 27413 | 22161 | | | -| % saved| | 16.96| 31.47 | 44.60 | | | | - -### 1 Tesla T4 - -| 1 T4 16GB | Hugging Face | Flash Attention | Unsloth Open | Unsloth Pro Equal | Unsloth Pro | Unsloth Max | -|--------------|-------------|-----------------|-----------------|---------------|---------------|-------------| -| Alpaca | 1x | 1.09x | 1.69x | 1.79x | 2.93x | **8.3x** | -| code | [▶️ Code](https://colab.research.google.com/drive/1XpLIV4s8Bj5uryB-X2gqM88oRGHEGdaB?usp=sharing) | [Code](https://colab.research.google.com/drive/1LyXu6CjuymQg6ddHX8g1dpUvrMa1nn4L?usp=sharing) | [Code](https://colab.research.google.com/drive/1gsv4LpY7C32otl1rgRo5wXTk4HIitXoM?usp=sharing) | [Code](https://colab.research.google.com/drive/1VtULwRQwhEnVdNryjm27zXfdSM1tNfFK?usp=sharing) | | | -| seconds | 1599 | 1468 | 942 | 894 | 545 | 193 | -| memory MB | 7199 | 7059 | 6459 | 5443 | | | -| % saved | | 1.94 | 10.28 | 24.39 | | | - -### 2 Tesla T4s via DDP - - | 2 T4 DDP | Hugging Face | Flash Attention | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max | -|--------------|----------|-------------|-----------------|--------------|---------------|-------------| -| Alpaca | 1x | 0.99x | 4.95x | 4.44x | 7.28x | **20.61x** | -| code | [▶️ Code](https://www.kaggle.com/danielhanchen/hf-original-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/hf-sdpa-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp) | | | -| seconds | 9882 | 9946 | 1996 | 2227 | 1357 | 480 | -| memory MB| 9176 | 9128 | 6904 | 6782 | | | -| % saved | | 0.52 | 24.76 | 26.09 | | | | -
- -### Performance comparisons on 1 Tesla T4 GPU: -
- Click for Time taken for 1 epoch - -One Tesla T4 on Google Colab -`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10` - -| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) | -| --- | --- | --- | --- | --- | --- | -| Huggingface | 1 T4 | 23h 15m | 56h 28m | 8h 38m | 391h 41m | -| Unsloth Open | 1 T4 | 13h 7m (1.8x) | 31h 47m (1.8x) | 4h 27m (1.9x) | 240h 4m (1.6x) | -| Unsloth Pro | 1 T4 | 3h 6m (7.5x) | 5h 17m (10.7x) | 1h 7m (7.7x) | 59h 53m (6.5x) | -| Unsloth Max | 1 T4 | 2h 39m (8.8x) | 4h 31m (12.5x) | 0h 58m (8.9x) | 51h 30m (7.6x) | - -**Peak Memory Usage** - -| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) | -| --- | --- | --- | --- | --- | --- | -| Huggingface | 1 T4 | 7.3GB | 5.9GB | 14.0GB | 13.3GB | -| Unsloth Open | 1 T4 | 6.8GB | 5.7GB | 7.8GB | 7.7GB | -| Unsloth Pro | 1 T4 | 6.4GB | 6.4GB | 6.4GB | 6.4GB | -| Unsloth Max | 1 T4 | 11.4GB | 12.4GB | 11.9GB | 14.4GB | -
- -
- Click for Performance Comparisons on 2 Tesla T4 GPUs via DDP: -**Time taken for 1 epoch** - -Two Tesla T4s on Kaggle -`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10` - -| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * | -| --- | --- | --- | --- | --- | --- | -| Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m * | -| Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) * | -| Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) * | +### Context length benchmarks +#### Llama 3.1 (8B) max. context length +We tested Llama 3.1 (8B) Instruct and did 4bit QLoRA on all linear layers (Q, K, V, O, gate, up and down) with rank = 32 with a batch size of 1. We padded all sequences to a certain maximum sequence length to mimic long context finetuning workloads. +| GPU VRAM | 🦥Unsloth context length | Hugging Face + FA2 | +|----------|-----------------------|-----------------| +| 8 GB | 2,972 | OOM | +| 12 GB | 21,848 | 932 | +| 16 GB | 40,724 | 2,551 | +| 24 GB | 78,475 | 5,789 | +| 40 GB | 153,977 | 12,264 | +| 48 GB | 191,728 | 15,502 | +| 80 GB | 342,733 | 28,454 | + +#### Llama 3.3 (70B) max. context length +We tested Llama 3.3 (70B) Instruct on a 80GB A100 and did 4bit QLoRA on all linear layers (Q, K, V, O, gate, up and down) with rank = 32 with a batch size of 1. We padded all sequences to a certain maximum sequence length to mimic long context finetuning workloads. + +| GPU VRAM | 🦥Unsloth context length | Hugging Face + FA2 | +|----------|------------------------|------------------| +| 48 GB | 12,106 | OOM | +| 80 GB | 89,389 | 6,916 | -**Peak Memory Usage on a Multi GPU System (2 GPUs)** - -| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * | -| --- | --- | --- | --- | --- | --- | -| Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB * | -| Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB * | -| Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB * | - -* Slim Orca `bsz=1` for all benchmarks since `bsz=2` OOMs. We can handle `bsz=2`, but we benchmark it with `bsz=1` for consistency. -
+
![](https://i.ibb.co/sJ7RhGG/image-41.png)
diff --git a/pyproject.toml b/pyproject.toml index b24abd355..d89ea2c4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.1.2", + "unsloth_zoo>=2025.2.2", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -131,6 +131,12 @@ cu124onlytorch240 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch250 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu121onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", @@ -147,6 +153,12 @@ cu124onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch251 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu121onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", @@ -163,6 +175,28 @@ cu124onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ 
https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] +cu124onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", +] +cu126onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu118 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -223,21 +257,31 @@ cu121-torch240 = [ "bitsandbytes>=0.43.3", "unsloth[cu121onlytorch240]", ] -cu121-torch250 = [ +cu124-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu121onlytorch250]", + "unsloth[cu124onlytorch240]", ] -cu124-torch240 = [ +cu118-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu124onlytorch240]", + "unsloth[cu118onlytorch250]", +] +cu121-torch250 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu121onlytorch250]", ] cu124-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", 
"unsloth[cu124onlytorch250]", ] +cu118-torch251 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu118onlytorch251]", +] cu121-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -248,6 +292,21 @@ cu124-torch251 = [ "bitsandbytes>=0.43.3", "unsloth[cu124onlytorch251]", ] +cu118-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu118onlytorch260]", +] +cu124-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu124onlytorch260]", +] +cu126-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu126onlytorch260]", +] kaggle = [ "unsloth[huggingface]", ] @@ -285,7 +344,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.1.2", + "unsloth_zoo>=2025.2.2", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -381,16 +440,22 @@ cu121-ampere-torch240 = [ "unsloth[cu121onlytorch240]", "unsloth[flashattention]", ] -cu121-ampere-torch250 = [ +cu124-ampere-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu121onlytorch250]", + "unsloth[cu124onlytorch240]", "unsloth[flashattention]", ] -cu124-ampere-torch240 = [ +cu118-ampere-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu124onlytorch240]", + "unsloth[cu118onlytorch250]", + "unsloth[flashattention]", +] +cu121-ampere-torch250 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu121onlytorch250]", "unsloth[flashattention]", ] cu124-ampere-torch250 = [ @@ -399,6 +464,12 @@ cu124-ampere-torch250 = [ "unsloth[cu124onlytorch250]", "unsloth[flashattention]", ] +cu118-ampere-torch251 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu118onlytorch251]", + "unsloth[flashattention]", +] cu121-ampere-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -411,6 +482,24 @@ cu124-ampere-torch251 = [ "unsloth[cu124onlytorch251]", "unsloth[flashattention]", ] +cu118-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu118onlytorch260]", + "unsloth[flashattention]", +] +cu124-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu124onlytorch260]", + "unsloth[flashattention]", +] +cu126-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu126onlytorch260]", + "unsloth[flashattention]", +] [project.urls] homepage = "http://www.unsloth.ai" diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 8002fbaef..bdde33c50 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -86,6 +86,10 @@ del os.environ["PYTORCH_CUDA_ALLOC_CONF"] pass +# First check if CUDA is available ie a NVIDIA GPU is seen +if not torch.cuda.is_available(): + raise NotImplementedError("Unsloth: No NVIDIA GPU found? Unsloth currently only supports GPUs!") + # Fix Xformers performance issues since 0.0.25 import importlib.util from pathlib import Path @@ -130,71 +134,69 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 torch.cuda.is_bf16_supported = is_bf16_supported pass +# For Gradio HF Spaces? 
+# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ: +import triton +libcuda_dirs = lambda: None +if Version(triton.__version__) >= Version("3.0.0"): + try: from triton.backends.nvidia.driver import libcuda_dirs + except: pass +else: from triton.common.build import libcuda_dirs + # Try loading bitsandbytes and triton import bitsandbytes as bnb +try: + cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 + libcuda_dirs() +except: + warnings.warn( + "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\ + ) -if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ: - - import triton - libcuda_dirs = lambda: None - if Version(triton.__version__) >= Version("3.0.0"): - try: from triton.backends.nvidia.driver import libcuda_dirs - except: pass - else: from triton.common.build import libcuda_dirs + if os.path.exists("/usr/lib64-nvidia"): + os.system("ldconfig /usr/lib64-nvidia") + elif os.path.exists("/usr/local"): + # Sometimes bitsandbytes cannot be linked properly in Runpod for example + possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n") + find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$") + possible_cudas = [find_cuda.search(x) for x in possible_cudas] + possible_cudas = [x.group(1) for x in possible_cudas if x is not None] + + # Try linking cuda folder, or everything in local + if len(possible_cudas) == 0: + os.system("ldconfig /usr/local/") + else: + find_number = re.compile(r"([\d\.]{2,})") + latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0] + latest_cuda = possible_cudas[latest_cuda] + os.system(f"ldconfig /usr/local/{latest_cuda}") + pass + importlib.reload(bnb) + importlib.reload(triton) try: + libcuda_dirs = lambda: None + if Version(triton.__version__) >= Version("3.0.0"): + try: from triton.backends.nvidia.driver import libcuda_dirs + except: pass + else: from triton.common.build import libcuda_dirs cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 libcuda_dirs() except: warnings.warn( - "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\ + "Unsloth: CUDA is not linked properly.\n"\ + "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\ + "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\ + "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\ + "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\ + "Unsloth will still run for now, but maybe it might crash - let's hope it works!" 
) - - if os.path.exists("/usr/lib64-nvidia"): - os.system("ldconfig /usr/lib64-nvidia") - elif os.path.exists("/usr/local"): - # Sometimes bitsandbytes cannot be linked properly in Runpod for example - possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n") - find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$") - possible_cudas = [find_cuda.search(x) for x in possible_cudas] - possible_cudas = [x.group(1) for x in possible_cudas if x is not None] - - # Try linking cuda folder, or everything in local - if len(possible_cudas) == 0: - os.system("ldconfig /usr/local/") - else: - find_number = re.compile(r"([\d\.]{2,})") - latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0] - latest_cuda = possible_cudas[latest_cuda] - os.system(f"ldconfig /usr/local/{latest_cuda}") - pass - - importlib.reload(bnb) - importlib.reload(triton) - try: - libcuda_dirs = lambda: None - if Version(triton.__version__) >= Version("3.0.0"): - try: from triton.backends.nvidia.driver import libcuda_dirs - except: pass - else: from triton.common.build import libcuda_dirs - cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 - libcuda_dirs() - except: - warnings.warn( - "Unsloth: CUDA is not linked properly.\n"\ - "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\ - "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\ - "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\ - "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\ - "Unsloth will still run for now, but maybe it might crash - let's hope it works!" - ) - pass pass # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.1.2"): + if Version(unsloth_zoo_version) < Version("2025.2.2"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/_auto_install.py b/unsloth/_auto_install.py index c3b94c670..8bb548519 100644 --- a/unsloth/_auto_install.py +++ b/unsloth/_auto_install.py @@ -18,14 +18,16 @@ v = V(torch.__version__) cuda = str(torch.version.cuda) is_ampere = torch.cuda.get_device_capability()[0] >= 8 -if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!") +if cuda != "12.1" and cuda != "11.8" and cuda != "12.4" and cuda != "12.6": raise RuntimeError(f"CUDA = {cuda} not supported!") if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!") elif v <= V('2.1.1'): x = 'cu{}{}-torch211' elif v <= V('2.1.2'): x = 'cu{}{}-torch212' elif v < V('2.3.0'): x = 'cu{}{}-torch220' elif v < V('2.4.0'): x = 'cu{}{}-torch230' elif v < V('2.5.0'): x = 'cu{}{}-torch240' -elif v < V('2.6.0'): x = 'cu{}{}-torch250' +elif v < V('2.5.1'): x = 'cu{}{}-torch250' +elif v <= V('2.5.1'): x = 'cu{}{}-torch251' +elif v < V('2.7.0'): x = 'cu{}{}-torch260' else: raise RuntimeError(f"Torch = {v} too new!") x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"') \ No newline at end of file diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index d8dc38522..c40139323 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -759,6 +759,10 @@ CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, 
llama31_ollama,) DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates + +for version in ("llama-3.2", "llama-3.3", "llama-32", "llama-33"): + CHAT_TEMPLATES[version] = CHAT_TEMPLATES["llama-3.1"] + DEFAULT_SYSTEM_MESSAGE[version] = "" pass diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index de543962e..f052914f9 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -15,6 +15,7 @@ import triton MAX_FUSED_SIZE : int = 65536 next_power_of_2 = triton.next_power_of_2 +import functools # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -66,6 +67,8 @@ def calculate_settings(n : int) -> (int, int,): CUDA_STREAM = None get_ptr = bnb.functional.get_ptr import ctypes +ctypes_c_int = ctypes.c_int +ctypes_c_int32 = ctypes.c_int32 cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 @@ -98,25 +101,31 @@ def get_lora_parameters(proj): def get_lora_parameters_bias(proj): # For DPO or disabled adapters - base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj) + base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight bias = base_layer.bias - if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + if getattr(proj, "disable_adapters", True) or proj.merged: return W, QUANT_STATE(W), None, None, None, bias pass active_adapter = proj.active_adapters[0] if \ - hasattr(proj, "active_adapters") else proj.active_adapter + getattr(proj, "active_adapters", ) else proj.active_adapter A = proj.lora_A [active_adapter].weight B = proj.lora_B [active_adapter].weight s = proj.scaling[active_adapter] return W, QUANT_STATE(W), A, B, s, bias pass +global WEIGHT_BUFFER +WEIGHT_BUFFER = None +global ABSMAX_BUFFER +ABSMAX_BUFFER = None if HAS_CUDA_STREAM: - def fast_dequantize(W, quant_state = None, out = None): + @torch.inference_mode + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class @@ -139,36 +148,54 @@ def fast_dequantize(W, quant_state = None, out = None): global CUDA_STREAM if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0") + n_elements_absmax = absmax.numel() + # Create weight matrix - if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") + if use_global_buffer: + + # Use same buffers for faster inference + size = shape[0]*shape[1] + global WEIGHT_BUFFER + global ABSMAX_BUFFER + if WEIGHT_BUFFER is None: + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False) + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) + + if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) + if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) + + out = WEIGHT_BUFFER[:size].view(shape) + out_absmax = ABSMAX_BUFFER[:n_elements_absmax] else: - assert(out.shape == shape) - assert(out.dtype == dtype) + if out is None: + out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False) + else: + assert(out.shape == shape) + assert(out.dtype == 
dtype) + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) + pass # NF4 dequantization of statistics - n_elements_absmax = absmax.numel() - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") - - # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM, + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM, ) out_absmax += offset + # Dequantize W fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,) + ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) return out.t() if is_transposed else out pass else: - def fast_dequantize(W, quant_state = None, out = None): + @torch.inference_mode + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class @@ -189,29 +216,45 @@ def fast_dequantize(W, quant_state = None, out = None): absmax2, code2, blocksize2, _, _, _, _ = state2 pass + n_elements_absmax = absmax.numel() + # Create weight matrix - if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") - else: - assert(out.shape == shape) - assert(out.dtype == dtype) + if use_global_buffer: - # NF4 dequantization of statistics - n_elements_absmax = absmax.numel() - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + # Use same buffers for faster inference + size = shape[0]*shape[1] + global WEIGHT_BUFFER + global ABSMAX_BUFFER + if WEIGHT_BUFFER is None: + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False) + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0", requires_grad = False) + + if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) + if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) + + out = WEIGHT_BUFFER[:size].view(shape) + out_absmax = ABSMAX_BUFFER[:n_elements_absmax] + else: + if out is None: + out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False) + else: + assert(out.shape == shape) + assert(out.dtype == dtype) + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) + pass # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), ) out_absmax += offset fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel()),) + ctypes_c_int(blocksize), ctypes_c_int(out.numel()),) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) @@ -263,17 +306,17 @@ def fast_gemv(X, W, quant_state, out = None): lda = shape[0] ldc = shape[0] ldb = (hd+1)//2 - m = ctypes.c_int32(m) - n = ctypes.c_int32(n) - 
k = ctypes.c_int32(k) - lda = ctypes.c_int32(lda) - ldb = ctypes.c_int32(ldb) - ldc = ctypes.c_int32(ldc) + m = ctypes_c_int32(m) + n = ctypes_c_int32(n) + k = ctypes_c_int32(k) + lda = ctypes_c_int32(lda) + ldb = ctypes_c_int32(ldb) + ldc = ctypes_c_int32(ldc) df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0") cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM, + ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM, ) df += offset absmax = df @@ -281,7 +324,7 @@ def fast_gemv(X, W, quant_state, out = None): fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ cgemm_4bit_inference_naive_bf16 - blocksize = ctypes.c_int32(blocksize) + blocksize = ctypes_c_int32(blocksize) fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize, CUDA_STREAM,) @@ -327,17 +370,17 @@ def fast_gemv(X, W, quant_state, out = None): lda = shape[0] ldc = shape[0] ldb = (hd+1)//2 - m = ctypes.c_int32(m) - n = ctypes.c_int32(n) - k = ctypes.c_int32(k) - lda = ctypes.c_int32(lda) - ldb = ctypes.c_int32(ldb) - ldc = ctypes.c_int32(ldc) + m = ctypes_c_int32(m) + n = ctypes_c_int32(n) + k = ctypes_c_int32(k) + lda = ctypes_c_int32(lda) + ldb = ctypes_c_int32(ldb) + ldc = ctypes_c_int32(ldc) df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0") cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), + ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), ) df += offset absmax = df @@ -345,7 +388,7 @@ def fast_gemv(X, W, quant_state, out = None): fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ cgemm_4bit_inference_naive_bf16 - blocksize = ctypes.c_int32(blocksize) + blocksize = ctypes_c_int32(blocksize) fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize,) @@ -354,6 +397,9 @@ def fast_gemv(X, W, quant_state, out = None): pass +torch_mm = torch.mm +torch_mv = torch.mv +torch_matmul = torch.matmul def fast_linear_forward(proj, X, temp_lora = None, out = None): W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj) @@ -361,12 +407,12 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) if W_quant is None: - out = torch.matmul(X, W.t(), out = out) + out = torch_matmul(X, W.t(), out = out) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: - W = fast_dequantize(W.t(), W_quant) - out = torch.matmul(X, W, out = out) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) + out = torch_matmul(X, W, out = out) pass # Add in LoRA weights @@ -381,11 +427,11 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if bsz == 1: out = out.view(out_dim) - temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora) + temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora) out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S) else: out = out.view(bsz, out_dim) - temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora) + temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora) out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S) pass out = out.view(bsz, 1, out_dim) @@ -399,7 +445,7 @@ def fast_linear_forward(proj, X, temp_lora = 
None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype - W = fast_dequantize(W.t(), W_quant) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) if X.dim() == 3: batch, seq_len, d = X.shape @@ -409,7 +455,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): reshape = False pass - out = torch.matmul(X, W, out = out) + out = torch_matmul(X, W, out = out) if W_quant is not None: del W if A is not None: diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index c52d14f40..b15e04ab7 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,3 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported +from .rl import PatchFastRL diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0036a18c4..2ec4adaa1 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.1.5" +__version__ = "2025.2.4" __all__ = [ "SUPPORTS_BFLOAT16", @@ -285,7 +285,11 @@ def _is_openai_available(): return False if _is_package_available("flash_attn"): # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl" try: - from flash_attn.flash_attn_interface import flash_attn_cuda + try: + # See https://github.com/unslothai/unsloth/issues/1437 + from flash_attn.flash_attn_interface import flash_attn_gpu + except: + from flash_attn.flash_attn_interface import flash_attn_cuda HAS_FLASH_ATTENTION = True # Also check for softcapping @@ -843,7 +847,9 @@ def patch_linear_scaling( "self.rotary_emb = .+?\)", function, flags = re.DOTALL | re.MULTILINE, ) - if len(rotary_emb) == 0: return None, function + if len(rotary_emb) == 0: + return None, exec_code + "\n\n" + function + rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py index 5dc71f920..9c12abb98 100644 --- a/unsloth/models/dpo.py +++ b/unsloth/models/dpo.py @@ -17,115 +17,8 @@ "PatchKTOTrainer", ] -try: - from transformers.utils.notebook import ( - IntervalStrategy, - NotebookTrainingTracker, - NotebookProgressCallback, - ) - HAS_NOTEBOOK = True -except: - HAS_NOTEBOOK = False -pass -import torch -from ._utils import torch_compile_options -import inspect -import torch.nn as nn -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from .rl import PatchFastRL +def PatchDPOTrainer(): PatchFastRL("DPO") -DPOTrainer_metrics = [ - "rewards/chosen", - "rewards/rejected", - "rewards/accuracies", - "rewards/margins", - "logps/rejected", - "logps/chosen", - "logits/rejected", - "logits/chosen", -] -set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics) - - -def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): - self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" - self.training_loss = 0 - self.last_log = 0 - column_names = [self.first_column] + ["Training Loss"] - if args.eval_strategy != IntervalStrategy.NO: - column_names.append("Validation Loss") - column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics] - self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) -pass - - -def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs): - # Only for when there is no 
evaluation - if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: - values = {"Training Loss": logs["loss"]} - for metric in DPOTrainer_metrics: - values[metric.replace("/", " / ")] = logs[metric] - pass - # First column is necessarily Step since we're not in epoch eval strategy - values["Step"] = state.global_step - self.training_tracker.write_line(values) - pass -pass - - -def NotebookTrainingTracker_write_line(self, values): - """ - Write the values in the inner table. - - Args: - values (`Dict[str, float]`): The values to display. - """ - if self.inner_table is None: - self.inner_table = [list(values.keys()), list(values.values())] - else: - columns = self.inner_table[0] - new_values = {} - for key, value in values.items(): - lowered = key.lower() - if lowered in set_DPOTrainer_metrics: - new_values[lowered.replace("/", " / ")] = value - else: - new_values[key] = value - pass - values = new_values - - self.inner_table[0] = columns - if len(self.inner_table) > 1: - last_values = self.inner_table[-1] - first_column = self.inner_table[0][0] - if last_values[0] != values[first_column]: - # write new line - self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) - else: - # update last line - new_values = values - for c in columns: - if c not in new_values.keys(): - new_values[c] = last_values[columns.index(c)] - self.inner_table[-1] = [new_values[c] for c in columns] - else: - # Edit for evaluation purposes - self.inner_table.append([values[c] if c in values else 0 for c in columns]) - pass - pass -pass - - -def PatchDPOTrainer(): - if HAS_NOTEBOOK: - from transformers.trainer import is_in_notebook - if is_in_notebook(): - # Patch DPO notebook printing - NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line - from transformers.trainer import DEFAULT_PROGRESS_CALLBACK - DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin - DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log - pass - pass -pass -PatchKTOTrainer = PatchDPOTrainer +def PatchKTOTrainer(): PatchFastRL("KTO") diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index c65434328..bc29c46ab 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -210,7 +210,15 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= config = None, # [TODO] Hack to pass in config - need to remove later ): super().__init__() - if config is not None: return # [TODO] Hack to pass in config - need to remove later + if config is not None: + # [TODO] Hack to pass in config - need to remove later + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = getattr(config, "head_dim", None) + if dim is None: dim = int((config.hidden_size // config.num_attention_heads)) + device = "cuda" + max_position_embeddings = config.max_position_embeddings + pass self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 497a357fe..fb7e96d8d 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -89,6 +89,7 @@ def GraniteAttention_fast_forward( n_groups = self.num_key_value_groups n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim + dropout_p = self.config.attention_dropout if self.training else 0 assert(n_kv_heads * n_groups == n_heads) Q, K, V = self.apply_qkv(self, hidden_states) @@ -135,7 +136,7 @@ 
def GraniteAttention_fast_forward( Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) pass - A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling) + A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling, p=dropout_p) A = A.view(bsz, q_len, n_heads, head_dim) elif HAS_FLASH_ATTENTION and attention_mask is None: @@ -143,7 +144,7 @@ def GraniteAttention_fast_forward( K = K.transpose(1, 2) V = V.transpose(1, 2) window = (kv_seq_len, kv_seq_len) - A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling) + A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling, dropout_p=dropout_p) else: # Grouped query attention # if n_groups != 1: @@ -157,7 +158,7 @@ def GraniteAttention_fast_forward( Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False) + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False, dropout_p=dropout_p) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index edd3ddf94..a337472a3 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,7 +20,7 @@ from ._utils import __version__ from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version -from unsloth_zoo.utils import Version +from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) # Transformers moved rotary embeddings out of all attention layers IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1") @@ -70,7 +70,8 @@ from huggingface_hub.utils._token import get_token pass from triton import __version__ as triton_version -BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None +HAS_XFORMERS = xformers is not None +BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None def original_apply_qkv(self, X): @@ -89,6 +90,8 @@ def original_apply_o(self, X): from math import sqrt as math_sqrt KV_CACHE_INCREMENT = 256 # KV Cache update size torch_nn_functional_softmax = torch.nn.functional.softmax +# SDPA has GQA internally +SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__ # Fix new HF's inference code def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): @@ -243,7 +246,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if n_groups != 1: + if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -262,7 +265,10 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + if SDPA_HAS_GQA: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True) + else: + A = 
scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) @@ -272,15 +278,15 @@ def LlamaAttention_fast_forward_inference( torch_nn_functional_silu = torch.nn.functional.silu -def fast_swiglu_inference(self, X): +def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None): # gate = self.gate_proj(X) # up = self.up_proj(X) bsz, _, hd = X.shape # mlp_size = self.config.intermediate_size # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") - gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0]) - up = fast_linear_forward(self. up_proj, X)#, out = temp[1]) + gate = fast_linear_forward(self.gate_proj, X, out = temp_gate) + up = fast_linear_forward(self. up_proj, X, out = temp_up) gate = torch_nn_functional_silu(gate, inplace = True) gate *= up @@ -289,14 +295,23 @@ def fast_swiglu_inference(self, X): return down pass - -def fast_rms_layernorm_inference(self, X): +torch_square = torch.square +torch_mean = torch.mean +def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None): old_dtype = X.dtype - XX = X.to(torch.float32) - variance = XX.square().mean(-1, keepdim = True) + if XX is None: + XX = X.to(torch.float32) + variance = XX.square().mean(-1, keepdim = True) + else: + XX.copy_(X) + torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance) + pass variance += self.variance_epsilon XX *= variance.rsqrt_() - X = XX.to(old_dtype) # Must preserve due to residual + + if XX is None: X = XX.to(old_dtype) + else: X.copy_(XX) + X *= self.weight return X pass @@ -403,7 +418,7 @@ def LlamaAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): # Xformers memory efficient attention # Also has Flash Attention v2 dispatching Q = Q.transpose(1, 2) @@ -636,6 +651,7 @@ def LlamaModel_fast_forward( IS_GEMMA2 = self.config.model_type.startswith("gemma2") IS_COHERE = self.config.model_type.startswith("cohere") IS_GRANITE = self.config.model_type.startswith("granite") + train_embed_tokens = self.embed_tokens.weight.requires_grad if IS_GEMMA: @@ -664,7 +680,7 @@ def LlamaModel_fast_forward( # Fix up attention mask by setting elements to 0 # Specifically for DPO - if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \ + if getattr(self, "_has_no_labels", False) is True and (attention_mask is not None) and (past_key_values is None) and \ (not train_embed_tokens): # Careful for inference the attention_mask is size (1, kv_seq_len) # Whilst the input_embeds is size (1, 1, 4096) @@ -792,9 +808,12 @@ def LlamaModel_fast_forward( pass pass - if IS_ATTENTION_REFACTOR and not hasattr(self.layers[0].self_attn, "rotary_emb"): + if (IS_ATTENTION_REFACTOR and (hasattr(self, "rotary_emb") or not hasattr(self.layers[0].self_attn, "rotary_emb"))) or IS_GRANITE: # Transformers main has made it mandatory to pass position_embeddings # https://github.com/huggingface/transformers/pull/34858 + # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor) + # unsloth's check for granite too has "version >= 4.45.0 (rightly so)". + # so let granite always use the attention refactor implementation. 
position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) else: position_embeddings = None @@ -898,15 +917,29 @@ def LlamaModel_fast_forward_inference( attention_mask = None, ): input_ids = input_ids[:,:self.max_seq_length] - hidden_states = self.model.embed_tokens(input_ids) - hidden_states = hidden_states.to(self.config.torch_dtype) - bsz, q_len, hd = hidden_states.shape + bsz, q_len = input_ids.shape + hd = self.config.hidden_size + mlp_size = self.config.intermediate_size + + X = self.model.embed_tokens(input_ids) + X = X.to(self.config.torch_dtype) + bsz, q_len, hd = X.shape + assert(q_len == 1) + + # Get saved buffers to reduce memory movement + residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") + _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") + XX, XX2 = _XX[0], _XX[1] + variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") + temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") + temp_gate, temp_up = temp_mlp[0], temp_mlp[1] + seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (bsz, q_len), - hidden_states, + X, seq_len, sliding_window = getattr(self.config, "sliding_window", None), ) @@ -915,30 +948,54 @@ def LlamaModel_fast_forward_inference( pass next_decoder_cache = [] + for idx, decoder_layer in enumerate(self.model.layers): - residual = hidden_states - hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) - hidden_states, present_key_value = LlamaAttention_fast_forward_inference( + residual.copy_(X) # residual = X + X = fast_rms_layernorm_inference( + decoder_layer.input_layernorm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) + X, present_key_value = LlamaAttention_fast_forward_inference( decoder_layer.self_attn, - hidden_states = hidden_states, + hidden_states = X, past_key_value = past_key_values[idx], position_ids = position_ids, attention_mask = attention_mask, do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), ) - hidden_states += residual - - residual = hidden_states - hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) - hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) - hidden_states += residual + X += residual + + residual.copy_(X) # residual = X + X = fast_rms_layernorm_inference( + decoder_layer.post_attention_layernorm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) + X = fast_swiglu_inference( + decoder_layer.mlp, + X, + temp_gate = temp_gate, + temp_up = temp_up, + ) + X += residual next_decoder_cache.append(present_key_value) pass - hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states) + X = fast_rms_layernorm_inference( + self.model.norm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) return BaseModelOutputWithPast( - last_hidden_state = hidden_states, + last_hidden_state = X, past_key_values = next_decoder_cache, hidden_states = [], attentions = [], @@ -973,7 +1030,7 @@ def _CausalLM_fast_forward( attention_mask = attention_mask, ) else: - causal_mask = xformers.attn_bias.LowerTriangularMask() + causal_mask = xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1155,7 +1212,8 @@ def 
__init__(self, dim = None, max_position_embeddings=2048, base=10000, device= # [TODO] Hack to pass in config - need to remove later base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads)) + dim = getattr(config, "head_dim", None) + if dim is None: dim = int((config.hidden_size // config.num_attention_heads)) device = "cuda" max_position_embeddings = config.max_position_embeddings pass @@ -1576,9 +1634,18 @@ def from_pretrained( model_patcher = None, tokenizer_name = None, trust_remote_code = False, + + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = False, + random_state = 3407, + max_lora_rank = 16, + disable_log_stats = False, **kwargs, ): if trust_remote_code: + if fast_inference: + raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.") print( "Unsloth: WARNING `trust_remote_code` is True.\n"\ "Are you certain you want to do remote code execution?" @@ -1592,9 +1659,9 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.\n"\ - f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ - f"O^O/ \_/ \\ Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ - f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ + f" {chr(92)}{chr(92)} /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ + f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. 
FA2 = {HAS_FLASH_ATTENTION}]\n"\ f' "-____-" Free Apache license: http://github.com/unslothai/unsloth' print(statistics) @@ -1622,7 +1689,11 @@ def from_pretrained( assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) # RoPE Scaling - model_config = AutoConfig.from_pretrained(model_name, token = token) + model_config = AutoConfig.from_pretrained( + model_name, + token = token, + attn_implementation = "sdpa", + ) model_max_seq_length = model_config.max_position_embeddings # Check if RoPE Scaling is even allowed @@ -1643,6 +1714,9 @@ def from_pretrained( rope_scaling = max_seq_length / model_max_seq_length + if fast_inference: + raise NotImplementedError("Unsloth: Fast inference does not yet work with RoPE Scaling.") + logger.warning_once( f"Unsloth: {model_name} can only handle sequence lengths of at most "\ f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\ @@ -1684,17 +1758,54 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config - model = AutoModelForCausalLM.from_pretrained( - model_name, - device_map = device_map, - torch_dtype = dtype, - # quantization_config = bnb_config, - token = token, - max_position_embeddings = max_position_embeddings, - trust_remote_code = trust_remote_code, - attn_implementation = "eager", - **kwargs, - ) + if not fast_inference: + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map = device_map, + torch_dtype = dtype, + # quantization_config = bnb_config, + token = token, + max_position_embeddings = max_position_embeddings, + trust_remote_code = trust_remote_code, + attn_implementation = "eager", + **kwargs, + ) + else: + from unsloth_zoo.vllm_utils import ( + load_vllm, + get_vllm_state_dict, + convert_vllm_to_huggingface, + generate_batches, + ) + allowed_args = inspect.getfullargspec(load_vllm).args + load_vllm_kwargs = dict( + model_name = model_name, + config = model_config, + gpu_memory_utilization = gpu_memory_utilization, + max_seq_length = max_seq_length, + dtype = dtype, + float8_kv_cache = float8_kv_cache, + enable_lora = True, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, + ) + for allowed_arg in allowed_args: + if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs: + load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg] + pass + + # Load vLLM first + llm = load_vllm(**load_vllm_kwargs) + + # Convert to HF format + _, quant_state_dict = get_vllm_state_dict(llm, config = model_config) + model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype) + model.vllm_engine = llm + model.fast_generate = model.vllm_engine.generate + + from functools import partial + model.fast_generate_batches = partial(generate_batches, model.vllm_engine) + pass # Return old flag os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer # We currently only support NVIDIA GPUs - AMD / Intel is a work in progress! @@ -2190,6 +2301,20 @@ def get_peft_model( modules_to_save = list(set(modules_to_save)) pass + vllm_engine = None + if hasattr(model, "vllm_engine"): + # Fast inference! 
+ vllm_engine = model.vllm_engine + vllm_fast_generate = model.fast_generate + vllm_fast_generate_batches = model.fast_generate_batches + + if modules_to_save is not None: + raise NotImplementedError("Unsloth: Currently fast inference does not work with training embeddings or lm_head.") + + if bias != "none": + raise NotImplementedError("Unsloth: Currently fast inference does not work with using biases for LoRA.") + pass + # Get LoRA arguments = dict( r = r, @@ -2296,6 +2421,19 @@ def get_peft_model( torch.cuda.empty_cache() pass + # Patch for fast inference + if vllm_engine is not None: + model.vllm_engine = vllm_engine + model.fast_generate = vllm_fast_generate + model.fast_generate_batches = vllm_fast_generate_batches + + # Also saving and loading LoRA + from functools import partial + from unsloth_zoo.vllm_utils import save_lora, load_lora + model.save_lora = partial(save_lora, model) + model.load_lora = partial(load_lora, model) + pass + return model pass @@ -2505,18 +2643,24 @@ def for_inference(model): # return # pass - internal_model = model - internal_model.gradient_checkpointing = False - internal_model.training = False - - while hasattr(internal_model, "model"): - internal_model = internal_model.model - internal_model.gradient_checkpointing = False - internal_model.training = False - pass - if hasattr(internal_model, "training"): - internal_model.training = False - pass + m = model + while hasattr(m, "model"): + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = False + if hasattr(m, "training"): + m.training = False + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "left" + m = m.model + pass + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = False + if hasattr(m, "training"): + m.training = False + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "left" # Also check if lm_head / embeddings are trained internal_model = model @@ -2525,30 +2669,13 @@ def for_inference(model): pass lm_head = internal_model.lm_head.weight device_type = lm_head.device.type - dtype = model.config.torch_dtype - - if type(dtype) is str: - if dtype == "float16": dtype = torch.float16 - elif dtype == "bfloat16": dtype = torch.bfloat16 - pass + dtype = _get_dtype(model.config.torch_dtype) # Wrap model.generate if model.generate.__name__ != "_fast_generate": model._unwrapped_old_generate = model.generate model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model) pass - - # Patch tokenizer to pad to the left - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "left" - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "left" - pass # Also disable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): @@ -2566,9 +2693,6 @@ def for_inference(model): @staticmethod def for_training(model, use_gradient_checkpointing = True): - internal_model = model - internal_model.gradient_checkpointing = use_gradient_checkpointing - internal_model.training = True # Delete all fast inference loras for param in model.parameters(): @@ -2576,14 +2700,24 @@ def for_training(model, use_gradient_checkpointing = True): del param._fast_lora pass - while hasattr(internal_model, "model"): - 
internal_model = internal_model.model - internal_model.gradient_checkpointing = use_gradient_checkpointing - internal_model.training = True - pass - if hasattr(internal_model, "training"): - internal_model.training = True - pass + m = model + while hasattr(m, "model"): + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = use_gradient_checkpointing + if hasattr(m, "training"): + m.training = True + # Pad tokenizer to the right + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "right" + m = m.model + pass + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = use_gradient_checkpointing + if hasattr(m, "training"): + m.training = True + # Pad tokenizer to the right + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "right" # Also revert model.generate if hasattr(model, "_unwrapped_old_generate"): @@ -2591,18 +2725,6 @@ def for_training(model, use_gradient_checkpointing = True): del model._unwrapped_old_generate pass - # Patch tokenizer to pad to the right - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "right" - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "right" - pass - # Also re-enable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): embeddings = model.get_input_embeddings() diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index e9caad0e6..39b367e27 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -30,11 +30,11 @@ from huggingface_hub.utils._token import get_token pass from huggingface_hub import HfFileSystem +import importlib.util # [TODO] Move USE_MODELSCOPE to utils USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1" if USE_MODELSCOPE: - import importlib if importlib.util.find_spec("modelscope") is None: raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`') pass @@ -73,9 +73,25 @@ def from_pretrained( resize_model_vocab = None, revision = None, use_exact_model_name = False, + + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = False, + random_state = 3407, + max_lora_rank = 64, + disable_log_stats = True, *args, **kwargs, ): if token is None: token = get_token() + + if fast_inference: + if importlib.util.find_spec("vllm") is None: + raise ImportError( + "Unsloth: Please install vLLM before enabling `fast_inference`!\n"\ + "You can do this in a terminal via `pip install vllm`" + ) + pass + pass old_model_name = model_name if not use_exact_model_name: @@ -255,6 +271,24 @@ def from_pretrained( tokenizer_name = None pass + if fast_inference: + from unsloth_zoo.vllm_utils import ( + patch_vllm, + vllm_dynamic_quant_supported, + ) + patch_vllm() + if model_name.endswith("unsloth-bnb-4bit"): + if not vllm_dynamic_quant_supported(model_name, model_config): + # Instead use -bnb-4bit variant + print( + f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"\ + f"we do not yet support fast inference for {model_name}" + ) + model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit" + pass + pass + pass + model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -268,6 +302,13 @@ def from_pretrained( 
tokenizer_name = tokenizer_name, trust_remote_code = trust_remote_code, revision = revision if not is_peft else None, + + fast_inference = fast_inference, + gpu_memory_utilization = gpu_memory_utilization, + float8_kv_cache = float8_kv_cache, + random_state = random_state, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, *args, **kwargs, ) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index c1113f529..c81290b66 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -304,25 +304,30 @@ "unsloth/Mistral-Small-Instruct-2409", "mistralai/Mistral-Small-Instruct-2409", ), - "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct", + "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct", + "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-3B-Instruct", + "unsloth/Qwen2.5-3B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-7B-Instruct", "Qwen/Qwen2.5-7B-Instruct", + "unsloth/Qwen2.5-7B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-14B-Instruct", "Qwen/Qwen2.5-14B-Instruct", + "unsloth/Qwen2.5-14B-Instruct-bnb-4bit", ), "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : ( "unsloth/Qwen2.5-32B-Instruct", @@ -332,25 +337,30 @@ "unsloth/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-72B-Instruct", ), - "unsloth/Qwen2.5-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-0.5B", "Qwen/Qwen2.5-0.5B", + "unsloth/Qwen2.5-0.5B-bnb-4bit", ), - "unsloth/Qwen2.5-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-1.5B", "Qwen/Qwen2.5-1.5B", + "unsloth/Qwen2.5-1.5B-bnb-4bit", ), - "unsloth/Qwen2.5-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-3B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-3B", "Qwen/Qwen2.5-3B", + "unsloth/Qwen2.5-3B-bnb-4bit", ), - "unsloth/Qwen2.5-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-7B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-7B", "Qwen/Qwen2.5-7B", + "unsloth/Qwen2.5-7B-bnb-4bit", ), - "unsloth/Qwen2.5-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-14B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-14B", "Qwen/Qwen2.5-14B", + "unsloth/Qwen2.5-14B-bnb-4bit", ), "unsloth/Qwen2.5-32B-bnb-4bit" : ( "unsloth/Qwen2.5-32B", @@ -432,21 +442,25 @@ "unsloth/Qwen2.5-Coder-32B-Instruct", "Qwen/Qwen2.5-Coder-32B-Instruct", ), - "unsloth/Llama-3.2-1B-bnb-4bit" : ( + "unsloth/Llama-3.2-1B-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-1B", "meta-llama/Llama-3.2-1B", + "unsloth/Llama-3.2-1B-bnb-4bit", ), - "unsloth/Llama-3.2-3B-bnb-4bit" : ( + "unsloth/Llama-3.2-3B-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-3B", "meta-llama/Llama-3.2-3B", + "unsloth/Llama-3.2-3B-bnb-4bit", ), - "unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct", + "unsloth/Llama-3.2-1B-Instruct-bnb-4bit", ), - "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", + 
"unsloth/Llama-3.2-3B-Instruct-bnb-4bit", ), "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : ( "unsloth/Llama-3.1-Nemotron-70B-Instruct", @@ -471,20 +485,18 @@ "meta-llama/Llama-3.2-11B-Vision-Instruct", "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", ), - "unsloth/Llama-3.2-90B-Vision-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit" : ( "unsloth/Llama-3.2-90B-Vision-Instruct", "meta-llama/Llama-3.2-90B-Vision-Instruct", - "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit", ), "unsloth/Llama-3.2-11B-Vision-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-11B-Vision", "meta-llama/Llama-3.2-11B-Vision", "unsloth/Llama-3.2-11B-Vision-bnb-4bit", ), - "unsloth/Llama-3.2-90B-Vision-unsloth-bnb-4bit" : ( + "unsloth/Llama-3.2-90B-Vision-bnb-4bit" : ( "unsloth/Llama-3.2-90B-Vision", "meta-llama/Llama-3.2-90B-Vision", - "unsloth/Llama-3.2-90B-Vision-bnb-4bit", ), "unsloth/Pixtral-12B-2409-unsloth-bnb-4bit" : ( "unsloth/Pixtral-12B-2409", @@ -524,6 +536,59 @@ "microsoft/phi-4", "unsloth/phi-4-bnb-4bit", ), + "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit" : ( + "unsloth/DeepSeek-R1-Distill-Qwen-32B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + ), + "unsloth/DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit" : ( + "unsloth/DeepSeek-R1-Distill-Qwen-14B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", + "unsloth/DeepSeek-R1-Distill-Qwen-14B-bnb-4bit", + ), + "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit" : ( + "unsloth/DeepSeek-R1-Distill-Qwen-7B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", + "unsloth/DeepSeek-R1-Distill-Qwen-7B-bnb-4bit", + ), + "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-unsloth-bnb-4bit" : ( + "unsloth/DeepSeek-R1-Distill-Qwen-1.5B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-bnb-4bit", + ), + "unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit" : ( + "unsloth/DeepSeek-R1-Distill-Llama-8B", + "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + "unsloth/DeepSeek-R1-Distill-Llama-8B-bnb-4bit", + ), + "unsloth/DeepSeek-R1-Distill-Llama-70B-bnb-4bit" : ( + "unsloth/DeepSeek-R1-Distill-Llama-70B", + "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + ), + "unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-24B-Base-2501", + "mistralai/Mistral-Small-24B-Base-2501", + "unsloth/Mistral-Small-24B-Base-2501-bnb-4bit", + ), + "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-24B-Instruct-2501", + "mistralai/Mistral-Small-24B-Instruct-2501", + "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen2.5-VL-3B-Instruct", + "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-7B-Instruct", + "Qwen/Qwen2.5-VL-7B-Instruct", + "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-72B-Instruct", + "Qwen/Qwen2.5-VL-72B-Instruct", + "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 9a97015f9..784ca9cb4 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -304,7 +304,7 @@ def pre_patch(): attention_module = MistralAttention, ) # Just for Mistral Nemo models! 
- if function is not None: + if function is not None and init_name is not None: function = patch_mistral_nemo_attention(function) # if True:#init_name is not None: exec(function, globals()) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py new file mode 100644 index 000000000..515c6587f --- /dev/null +++ b/unsloth/models/rl.py @@ -0,0 +1,423 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "PatchFastRL", +] + +METRICS_MOVE_TO_END = [ + "nll", + "aux", + "beta", + "alpha", +] +import torch +try: + from transformers.utils.notebook import ( + IntervalStrategy, + NotebookTrainingTracker, + NotebookProgressCallback, + ) + HAS_NOTEBOOK = True +except: + HAS_NOTEBOOK = False +pass +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import inspect +import os +import re +import functools +from unsloth_zoo.compiler import create_new_function + + +def PatchRL(FastLanguageModel): + + from trl.models.utils import unwrap_model_for_generation + from contextlib import contextmanager + + @contextmanager + def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + # Put the model in inference mode. 
+ FastLanguageModel.for_inference(unwrapped_model) + + # We must use .clone for Unsloth since we force inference_mode + # Rather we should have used no_grad + original_generate = unwrapped_model.generate + def generate_with_clone(*args, **kwargs): + out = original_generate(*args, **kwargs) + if isinstance(out, torch.Tensor): + return out.clone() + return out + pass + unwrapped_model.generate = generate_with_clone + + try: + yield unwrapped_model + finally: + # Restore generate and return + unwrapped_model.generate = original_generate + FastLanguageModel.for_training(model) + pass + pass + pass + + import trl.trainer + trainers = dir(trl.trainer) + trainers = [x for x in trainers if x.endswith("_trainer")] + unwrap = "unwrap_model_for_generation" + for trainer in trainers: + if hasattr(eval(f"trl.trainer.{trainer}"), unwrap): + exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") + pass +pass + + +def NotebookProgressCallback_on_train_begin(Trainer_metrics): + def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): + self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" + self.training_loss = 0 + self.last_log = 0 + column_names = [self.first_column] + ["Training Loss"] + if args.eval_strategy != IntervalStrategy.NO: + column_names.append("Validation Loss") + column_names += [x.replace("/", " / ") for x in Trainer_metrics] + self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) + pass + return _NotebookProgressCallback_on_train_begin +pass + + +def NotebookProgressCallback_on_log(Trainer_metrics): + def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs): + # Only for when there is no evaluation + if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: + values = {"Training Loss": logs["loss"]} + for metric in Trainer_metrics: + # Sometimes metric is not inside logs + try: values[metric.replace("/", " / ")] = logs[metric] + except: pass + pass + # First column is necessarily Step since we're not in epoch eval strategy + values["Step"] = state.global_step + self.training_tracker.write_line(values) + pass + pass + return _NotebookProgressCallback_on_log +pass + + +def NotebookTrainingTracker_write_line(Trainer_metrics): + set_Trainer_metrics = set(Trainer_metrics) + def _NotebookTrainingTracker_write_line(self, values): + """ + Write the values in the inner table. + + Args: + values (`Dict[str, float]`): The values to display. 
+ """ + if self.inner_table is None: + self.inner_table = [list(values.keys()), list(values.values())] + else: + columns = self.inner_table[0] + new_values = {} + for key, value in values.items(): + lowered = key.lower() + if lowered in set_Trainer_metrics: + new_values[lowered.replace("/", " / ")] = value + else: + new_values[key] = value + pass + values = new_values + + self.inner_table[0] = columns + if len(self.inner_table) > 1: + last_values = self.inner_table[-1] + first_column = self.inner_table[0][0] + if last_values[0] != values[first_column]: + # write new line + self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) + else: + # update last line + new_values = values + for c in columns: + if c not in new_values.keys(): + new_values[c] = last_values[columns.index(c)] + self.inner_table[-1] = [new_values[c] for c in columns] + else: + # Edit for evaluation purposes + self.inner_table.append([values[c] if c in values else 0 for c in columns]) + pass + pass + pass + return _NotebookTrainingTracker_write_line +pass + + +def _PatchRLStatistics(metrics, algorithm): + if HAS_NOTEBOOK: + if len(metrics) == 0: + raise RuntimeError(f"Unsloth: RL statistics for {algorithm} failed with no metrics seen?") + from transformers.trainer import is_in_notebook + if is_in_notebook(): + # Patch DPO notebook printing + NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line(metrics) + from transformers.trainer import DEFAULT_PROGRESS_CALLBACK + DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin(metrics) + DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log(metrics) + pass + pass +pass + + +@functools.cache +def get_trl_metrics(): + # Gets metrics so we can output them in notebooks + + import trl.trainer + trainers = dir(trl.trainer) + trainers = [x for x in trainers if x.endswith("_trainer")] + filepath = inspect.getfile(trl.trainer) + filepath = os.path.split(filepath)[0] + + all_metrics = dict() + for trainer in trainers: + filename = os.path.join(filepath, f"{trainer}.py") + if not os.path.exists(filename): continue + with open(filename, "r") as file: file = file.read() + + # Get metrics['kl'] or stats['kl'] + metrics = re.findall(r"metrics\[[\"\']([^\"\']{1,})[\"\']\]", file) + stats = re.findall(r"stats\[[\"\']([^\"\']{1,})[\"\']\]", file) + metrics = metrics + stats + + # Get optional f-strings + metrics_f = re.findall(r"metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) + stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) + metrics_f = metrics_f + stats_f + # Filter out prefixes if seen + # metrics[f"{prefix}rewards/chosen"] + left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file + if left_prefix: metrics += metrics_f + + # Move all eval_ things to the end and reward to the front + beginning = [] + middle = [] + end = [] + for x in metrics: + lowered = x.lower() + if "reward" in lowered: + beginning.append(x) + elif x.lower().startswith("eval"): + end.append(x) + else: + # Check if we want to move to the end + moved = False + for move_end in METRICS_MOVE_TO_END: + if move_end in lowered: + end.append(x) + moved = True + break + if not moved: + middle.append(x) + pass + pass + metrics = beginning + middle + end + + all_metrics[trainer[:trainer.find("_")].upper()] = metrics + pass + return all_metrics +pass + + +def PatchRLStatistics(algorithm = "GRPO"): + # Get notebook statistics columns to show up + algorithm = algorithm.upper() + all_metrics = 
get_trl_metrics() + if algorithm not in all_metrics: + print( + f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.\n"\ + f"We support: `{list(all_metrics.keys())}`" + ) + pass + _PatchRLStatistics(all_metrics[algorithm], algorithm) +pass + + +def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): + # Patch for vLLM and Unsloth PEFT + import trl + import trl.trainer + + trainer = eval(f"trl.trainer.{trainer_file}") + name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] + assert(len(name) == 1) + RLTrainer_name = name[0] + RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") + + try: + __init__ = inspect.getsource(RLTrainer.__init__) + except: + # Already patched most likely! + return + old__init__ = __init__ + all_imports = dir(trainer) + assert("Union" in all_imports) + imports = [x for x in all_imports if not x.startswith("_")] + imports += ["Trainer"] + + spaces = __init__.find("def") + __init__ = __init__.split("\n") + __init__ = "\n".join(x[spaces:] for x in __init__) + + # Replace vLLM sections since we already have it done! + vllm_part = re.findall( + r"(\n[\s]{4}"\ + r"if (self|args)\.use_vllm\:.+?"\ + r"\n[\s]{4,}"\ + "else:\n)", + __init__, + flags = re.MULTILINE | re.DOTALL, + ) + if (len(vllm_part) != 1): return + + vllm_part, args = vllm_part[0][0], vllm_part[0][1] + # Strip all comments + new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) + + # Get SamplingParams + sampling_params = re.findall( + r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\ + r"SamplingParams\(.+?\))", + new_vllm_part, + flags = re.MULTILINE | re.DOTALL, + ) + if len(sampling_params) != 1: return + + sampling_params = sampling_params[0] + # Replace with our vLLM engine + sampling_params = \ + " "*8 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ + sampling_params # Add spaces + new_vllm_part = f"\n if {args}.use_vllm:\n{sampling_params}\n else:\n" + __init__ = __init__.replace(vllm_part, new_vllm_part) + + # Remove peft_config + __init__ = __init__.replace("elif peft_config is None:", "elif False:") + __init__ = __init__.replace("elif peft_config is not None:", "elif False:") + __init__ = __init__.replace("if peft_config is None:", "if False:") + __init__ = __init__.replace("if peft_config is not None:", "if False:") + __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") + + # Add spaces back into __init__ + __init__ = __init__.split("\n") + __init__ = "\n".join(' '*spaces + x for x in __init__) + + # Search for vLLM calling in all child functions + functions = dir(RLTrainer) + RLTrainer_source = inspect.getsource(RLTrainer) + functions = [x for x in functions if f"def {x}" in RLTrainer_source] + + changed = {"__init__" : (old__init__, __init__,)} + for function in functions: + if not hasattr(RLTrainer, function): continue + fx = getattr(RLTrainer, function) + try: + source = inspect.getsource(fx) + except: + continue + original_source = source + + # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + source = re.sub( + r"(\n[\s]{4,}).+?model_executor\.driver_worker.+?\n", + r"\n\1pass\n", + source, + ) + + # llm_model.load_weights(model.state_dict().items()) + source = re.sub( + r"(\n[\s]{4,}).+?load_weights\(.+?\n", + r"\n\1pass\n", + source, + ) + + # .state_dict() + source = re.sub( + r"\.state_dict\(\)", + r"", + source, + ) + + # Replace self.llm.generate and self.llm.chat + lora_name = trainer_file + 
"_lora_model" + source = re.sub( + r"(self\.llm\.(?:generate|chat)\([^\)]{1,})\)", + r"\1, lora_request = self.model.load_lora('" + lora_name + r"', load_tensors = True))", + source + ) + + # Skip if no changes done + if source == original_source: continue + + # Find all imports + imports += [x for x in all_imports if not x.startswith("_") and x in source] + + changed[function] = (original_source, source,) + pass + + # Import all functions + imports = list(set(imports)) + + # Patch all functions + for function in changed: + old, new = changed[function] + RLTrainer_source = RLTrainer_source.replace(old, new) + pass + RLTrainer_source = RLTrainer_source.replace( + f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 + ) + + # Create new class in compiled cache and import it + module = create_new_function( + RLTrainer_name, + RLTrainer_source, + f"trl.trainer.{trainer_file}", + imports, + ) + + # Patch over modules + exec(f"trl.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) + return module +pass + + +def patch_trl_rl_trainers(): + # Patch all TRL modules if they have vLLM or PEFT + import trl.trainer + all_trainers = dir(trl.trainer) + all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")] + for trainer in all_trainers: + _patch_trl_rl_trainers(trainer) + return +pass + + +def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): + if FastLanguageModel is not None: PatchRL(FastLanguageModel) + patch_trl_rl_trainers() + PatchRLStatistics(algorithm) +pass