
Commit d2c0c1b

Nightly (#632)

Authored by: danielhanchen, shimmyshimmer, chrehall68, neph1, and xyangk

* Update llama.py
* offload
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* continued pretraining trainer
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* is_bfloat16_supported
* Update __init__.py
* Update README.md
* Update llama.py
* is_bfloat16_supported
* Update __init__.py
* Mistral v3
* Phi 3 medium
* Update chat_templates.py
* Update chat_templates.py
* Phi-3
* Update save.py
* Update README.md Mistral v3 to Mistral v0.3
* Untrained tokens
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update llama.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update save.py
* Update save.py
* Update save.py
* checkpoint
* Update _utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update tokenizer_utils.py
* Update llama.py
* accelerate
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update tokenizer_utils.py
* train_dataloader
* Update llama.py
* Update llama.py
* Update llama.py
* use_fast_convert
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* remove_special_tokens
* Ollama
* Update chat_templates.py
* Update chat_templates.py
* Update chat_templates.py
* Update llama.py
* Update chat_templates.py
* Support bfloat16 GGUF
* Update save.py
* Update llama.py
* fast_forward_inference
* Update mapper.py
* Update loader.py
* Update llama.py
* Update tokenizer_utils.py
* info
* edits
* Create chat template
* Fix tokenizer
* Update tokenizer_utils.py
* fix case where gguf saving fails due to first_conversion dtype (#630)
* Support revision parameter in FastLanguageModel.from_pretrained (#629)
* support `revision` parameter
* match unsloth formatting of named parameters
* clears any selected_adapters before calling internal_model.save_pretrained (#609)
* Update __init__.py (#602): Check for incompatible modules before importing unsloth
* Fixed unsloth/tokenizer_utils.py for chat training (#604)
* Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345)
* Add save to llama.cpp GGML to save.py.
* Fix conversion command and path of convert to GGML function.
* Add autosaving lora to the GGML function
* Create lora save function for conversion to GGML
* Test fix #2 for saving lora
* Test fix #3 to save the lora adapters to convert to GGML
* Remove unwated tokenizer saving for conversion to ggml and added a few print statements.
* Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages.
* Positional arguments didn't work out, so reverted to older version of the code, and added a few comments.
* Test fix 1 for arch
* Test fix 2 new Mistral error.
* Test fix 3
* Revert to old version for testing.
* Upload issue test fix 1
* Fix 2 uploading ggml
* Positional ags added.
* Temporray remove positional args
* Fix upload again!!!
* Add print statements and fix link
* Make the calling name better
* Create local saving for GGML
* Add choosing directory to save local GGML.
* Fix lil variable error in the save_to_custom_dir func
* docs: Add LoraConfig parameters documentation (#619)
* llama.cpp failing (#371)

  llama.cpp is failing to generate quantize versions for the trained models. Error:

  ```bash
  You might have to compile llama.cpp yourself, then run this again.
  You do not need to close this Python program. Run the following commands in a new terminal:
  You must run this in the same folder as you're saving your model.
  git clone https://github.com/ggerganov/llama.cpp
  cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j
  Once that's done, redo the quantization.
  ```

  But when i do clone this with recursive it works.

  Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* fix libcuda_dirs import for triton 3.0 (#227)
* fix libcuda_dirs import for triton 3.0
* Update __init__.py
* Update __init__.py

  Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update save.py
* Update __init__.py
* Update fast_lora.py
* Update save.py
* Update save.py
* Update save.py
* Update loader.py
* Update save.py
* Update save.py
* quantize now llama-quantize
* Update chat_templates.py
* Update loader.py
* Update mapper.py
* Update __init__.py
* embedding size

---------

Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com>
Co-authored-by: Rickard Edén <rickardeden@gmail.com>
Co-authored-by: XiaoYang <xyangk@gmail.com>
Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com>
Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com>
Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com>
Co-authored-by: Alberto Ferrer <albertof@barrahome.org>
Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de>
1 parent 8a9e24e commit d2c0c1b

9 files changed (+342 lines added, -159 lines removed)


PARAMETERS.md

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@

## LoraConfig Parameters

Adjusting the `LoraConfig` parameters allows you to balance model performance and computational efficiency in Low-Rank Adaptation (LoRA). Here’s a concise breakdown of key parameters:

**r**
- **Description**: Rank of the low-rank decomposition for factorizing weight matrices.
- **Impact**:
  - **Higher**: Retains more information, increases computational load.
  - **Lower**: Fewer parameters, more efficient training, potential performance drop if too small.

**lora_alpha**
- **Description**: Scaling factor for the low-rank matrices' contribution.
- **Impact**:
  - **Higher**: Increases influence, speeds up convergence, risks instability or overfitting.
  - **Lower**: Subtler effect, may require more training steps.

**lora_dropout**
- **Description**: Probability of zeroing out elements in low-rank matrices for regularization.
- **Impact**:
  - **Higher**: More regularization, prevents overfitting, may slow training and degrade performance.
  - **Lower**: Less regularization, may speed up training, risks overfitting.

**loftq_config**
- **Description**: Configuration for LoftQ, a quantization method for the backbone weights and initialization of LoRA layers.
- **Impact**:
  - **Not None**: If specified, LoftQ will quantize the backbone weights and initialize the LoRA layers. It requires setting `init_lora_weights='loftq'`.
  - **None**: LoftQ quantization is not applied.
- **Note**: Do not pass an already quantized model when using LoftQ as LoftQ handles the quantization process itself.

**use_rslora**
- **Description**: Enables Rank-Stabilized LoRA (RSLora).
- **Impact**:
  - **True**: Uses Rank-Stabilized LoRA, setting the adapter scaling factor to `lora_alpha/math.sqrt(r)`, which has been proven to work better as per the [Rank-Stabilized LoRA paper](https://doi.org/10.48550/arXiv.2312.03732).
  - **False**: Uses the original default scaling factor `lora_alpha/r`.

**gradient_accumulation_steps**
- **Default**: 1
- **Description**: The number of steps to accumulate gradients before performing a backpropagation update.
- **Impact**:
  - **Higher**: Accumulate gradients over multiple steps, effectively increasing the batch size without requiring additional memory. This can improve training stability and convergence, especially with large models and limited hardware.
  - **Lower**: Faster updates but may require more memory per step and can be less stable.

**weight_decay**
- **Default**: 0.01
- **Description**: Regularization technique that applies a small penalty to the weights during training.
- **Impact**:
  - **Non-zero Value (e.g., 0.01)**: Adds a penalty proportional to the magnitude of the weights to the loss function, helping to prevent overfitting by discouraging large weights.
  - **Zero**: No weight decay is applied, which can lead to overfitting, especially in large models or with small datasets.

**learning_rate**
- **Default**: 2e-4
- **Description**: The rate at which the model updates its parameters during training.
- **Impact**:
  - **Higher**: Faster convergence but risks overshooting optimal parameters and causing instability in training.
  - **Lower**: More stable and precise updates but may slow down convergence, requiring more training steps to achieve good performance.

## Target Modules

**q_proj (query projection)**
- **Description**: Part of the attention mechanism in transformer models, responsible for projecting the input into the query space.
- **Impact**: Transforms the input into query vectors that are used to compute attention scores.

**k_proj (key projection)**
- **Description**: Projects the input into the key space in the attention mechanism.
- **Impact**: Produces key vectors that are compared with query vectors to determine attention weights.

**v_proj (value projection)**
- **Description**: Projects the input into the value space in the attention mechanism.
- **Impact**: Produces value vectors that are weighted by the attention scores and combined to form the output.

**o_proj (output projection)**
- **Description**: Projects the output of the attention mechanism back into the original space.
- **Impact**: Transforms the combined weighted value vectors back to the input dimension, integrating attention results into the model.

**gate_proj (gate projection)**
- **Description**: Typically used in gated mechanisms within neural networks, such as gating units in gated recurrent units (GRUs) or other gating mechanisms.
- **Impact**: Controls the flow of information through the gate, allowing selective information passage based on learned weights.

**up_proj (up projection)**
- **Description**: Used for up-projection, typically increasing the dimensionality of the input.
- **Impact**: Expands the input to a higher-dimensional space, often used in feedforward layers or when transitioning between different layers with differing dimensionalities.

**down_proj (down projection)**
- **Description**: Used for down-projection, typically reducing the dimensionality of the input.
- **Impact**: Compresses the input to a lower-dimensional space, useful for reducing computational complexity and controlling the model size.
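
To see how these settings fit together, here is a minimal sketch that wires the parameters above into Unsloth's `FastLanguageModel.get_peft_model` and a TRL `SFTTrainer`. The tiny dataset, sequence length, and step count are placeholders rather than values from this commit, and the exact `SFTTrainer` keywords can vary with your `trl` version.

```python
from unsloth import FastLanguageModel, is_bfloat16_supported
from transformers import TrainingArguments
from trl import SFTTrainer
from datasets import Dataset

# Load a 4-bit base model (any supported model name works here).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = 2048,
    load_in_4bit = True,
)

# LoraConfig-style parameters described above.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,              # rank of the low-rank decomposition
    lora_alpha = 16,     # scaling factor for the LoRA update
    lora_dropout = 0.0,  # dropout on the low-rank matrices
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    use_rslora = False,  # True switches scaling to lora_alpha / sqrt(r)
    loftq_config = None, # or a LoftQConfig to quantize the backbone
)

# Placeholder dataset with a "text" column, just to make the sketch runnable.
dataset = Dataset.from_dict({"text": ["### Instruction:\nSay hi.\n\n### Response:\nHi!"]})

# Trainer-level knobs from the table: gradient accumulation, weight decay, learning rate.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = 2048,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 1,
        learning_rate = 2e-4,
        weight_decay = 0.01,
        max_steps = 60,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        output_dir = "outputs",
    ),
)
trainer.train()
```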

unsloth/__init__.py

Lines changed: 25 additions & 4 deletions
```diff
@@ -14,8 +14,20 @@
 import os
 import warnings
 import importlib
+import sys
+from packaging.version import Version
 
-# Currently only supports 1 GPU, or else seg faults will occur.
+# Define a list of modules to check
+MODULES_TO_CHECK = ["peft", "bitsandbytes"]
+
+# Check if any of the modules in the list have been imported
+for module in MODULES_TO_CHECK:
+    if module in sys.modules:
+        raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
+    pass
+pass
+
+# Currently only supports 1 GPU, or else seg faults will occur.
 if "CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
     devices = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -66,8 +78,14 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 
 # Try loading bitsandbytes and triton
 import bitsandbytes as bnb
+
 import triton
-from triton.common.build import libcuda_dirs
+libcuda_dirs = lambda: None
+if Version(triton.__version__) >= Version("3.0.0"):
+    try: from triton.backends.nvidia.driver import libcuda_dirs
+    except: pass
+else: from triton.common.build import libcuda_dirs
+
 import os
 import re
 import numpy as np
@@ -103,8 +121,11 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
     importlib.reload(bnb)
     importlib.reload(triton)
     try:
-        import bitsandbytes as bnb
-        from triton.common.build import libcuda_dirs
+        libcuda_dirs = lambda: None
+        if Version(triton.__version__) >= Version("3.0.0"):
+            try: from triton.backends.nvidia.driver import libcuda_dirs
+            except: pass
+        else: from triton.common.build import libcuda_dirs
         cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
         libcuda_dirs()
     except:
```
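
As a usage note, the module check above makes import order significant. A small sketch of the intended ordering; the error text comes from the diff, the surrounding code is illustrative:

```python
# Import Unsloth first, then the libraries it patches.
import unsloth            # must come before peft / bitsandbytes
import peft               # fine: Unsloth is already initialized
import bitsandbytes

# Reversing the order (importing peft first, then unsloth) would raise:
#   ImportError: Unsloth: Please import Unsloth before peft.
```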

unsloth/chat_templates.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1286,7 +1286,7 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
     pass
 
     for prompt in prompts:
-        command = f"./llama.cpp/main -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
+        command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
                   f"--check-tensors -p '{prompt}'"
 
         datas = []
```
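
For context, recent llama.cpp builds renamed their binaries, so the equivalence test now shells out to `llama-cli` (and, per the commit message, `quantize` became `llama-quantize`). A rough sketch of the equivalent manual call; the real helper builds this string inside `test_hf_gguf_equivalence`, and the model path and prompt here are placeholders:

```python
import subprocess

gguf_model = "./model-unsloth.F16.gguf"   # default path used by the helper
prompt = "Hello!"                         # placeholder prompt

# Mirrors the command assembled in the diff above.
command = (
    f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "
    f"--check-tensors -p '{prompt}'"
)
subprocess.run(command, shell = True, check = True)
```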

unsloth/kernels/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -24,6 +24,7 @@
 )
 from .fast_lora import (
     get_lora_parameters,
+    get_lora_parameters_bias,
     apply_lora_mlp_swiglu,
     apply_lora_mlp_geglu_exact,
     apply_lora_mlp_geglu_approx,
```

unsloth/kernels/fast_lora.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -13,7 +13,13 @@
 # limitations under the License.
 
 import torch
-from .utils import fast_dequantize, QUANT_STATE, get_lora_parameters, matmul_lora
+from .utils import (
+    fast_dequantize,
+    QUANT_STATE,
+    get_lora_parameters,
+    get_lora_parameters_bias,
+    matmul_lora,
+)
 
 
 class LoRA_MLP(torch.autograd.Function):
```

unsloth/models/loader.py

Lines changed: 38 additions & 30 deletions
```diff
@@ -33,11 +33,8 @@
 
 def _get_model_name(model_name, load_in_4bit = True):
 
-    # First try replacing lowercase 'b' with uppercase 'B'
-    model_name = model_name.lower()
-
     if not SUPPORTS_FOURBIT and model_name in INT_TO_FLOAT_MAPPER:
-        model_name = INT_TO_FLOAT_MAPPER[model_name]
+        model_name = INT_TO_FLOAT_MAPPER[model_name.lower()]
         logger.warning_once(
             f"Unsloth: Your transformers version of {transformers_version} does not support native "\
             f"4bit loading.\nThe minimum required version is 4.37.\n"\
@@ -47,15 +44,15 @@ def _get_model_name(model_name, load_in_4bit = True):
         )
 
     elif not load_in_4bit and model_name in INT_TO_FLOAT_MAPPER:
-        new_model_name = INT_TO_FLOAT_MAPPER[model_name]
+        new_model_name = INT_TO_FLOAT_MAPPER[model_name.lower()]
         # logger.warning_once(
         #     f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
         #     f"`load_in_4bit = False`. We shall load `{new_model_name}` instead."
         # )
         model_name = new_model_name
 
     elif load_in_4bit and SUPPORTS_FOURBIT and model_name in FLOAT_TO_INT_MAPPER:
-        new_model_name = FLOAT_TO_INT_MAPPER[model_name]
+        new_model_name = FLOAT_TO_INT_MAPPER[model_name.lower()]
         # logger.warning_once(
         #     f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
         #     f"We shall load `{new_model_name}` for 4x faster loading."
@@ -70,17 +67,18 @@ def _get_model_name(model_name, load_in_4bit = True):
 class FastLanguageModel(FastLlamaModel):
     @staticmethod
     def from_pretrained(
-        model_name = "unsloth/llama-3-8b-bnb-4bit",
-        max_seq_length = None,
-        dtype = None,
-        load_in_4bit = True,
-        token = None,
-        device_map = "sequential",
-        rope_scaling = None,
-        fix_tokenizer = True,
-        trust_remote_code = False,
-        use_gradient_checkpointing = True,
-        resize_model_vocab = None,
+        model_name                 = "unsloth/llama-3-8b-bnb-4bit",
+        max_seq_length             = None,
+        dtype                      = None,
+        load_in_4bit               = True,
+        token                      = None,
+        device_map                 = "sequential",
+        rope_scaling               = None,
+        fix_tokenizer              = True,
+        trust_remote_code          = False,
+        use_gradient_checkpointing = "unsloth",
+        resize_model_vocab         = None,
+        revision                   = None,
         *args, **kwargs,
     ):
         if token is None and "HF_TOKEN" in os.environ:
@@ -95,12 +93,12 @@ def from_pretrained(
         # First check if it's a normal model via AutoConfig
         is_peft = False
         try:
-            model_config = AutoConfig.from_pretrained(model_name, token = token)
+            model_config = AutoConfig.from_pretrained(model_name, token = token, revision = revision)
             is_peft = False
         except:
             try:
                 # Most likely a PEFT model
-                peft_config = PeftConfig.from_pretrained(model_name, token = token)
+                peft_config = PeftConfig.from_pretrained(model_name, token = token, revision = revision)
             except:
                 raise RuntimeError(f"Unsloth: `{model_name}` is not a full model or a PEFT model.")
 
@@ -143,22 +141,24 @@ def from_pretrained(
         pass
 
         model, tokenizer = dispatch_model.from_pretrained(
-            model_name = model_name,
-            max_seq_length = max_seq_length,
-            dtype = dtype,
-            load_in_4bit = load_in_4bit,
-            token = token,
-            device_map = device_map,
-            rope_scaling = rope_scaling,
-            fix_tokenizer = fix_tokenizer,
-            model_patcher = dispatch_model,
-            tokenizer_name = tokenizer_name,
+            model_name        = model_name,
+            max_seq_length    = max_seq_length,
+            dtype             = dtype,
+            load_in_4bit      = load_in_4bit,
+            token             = token,
+            device_map        = device_map,
+            rope_scaling      = rope_scaling,
+            fix_tokenizer     = fix_tokenizer,
+            model_patcher     = dispatch_model,
+            tokenizer_name    = tokenizer_name,
             trust_remote_code = trust_remote_code,
+            revision          = revision if not is_peft else None,
             *args, **kwargs,
         )
 
         if resize_model_vocab is not None:
             model.resize_token_embeddings(resize_model_vocab)
+        pass
 
         # In case the model supports tagging, add the unsloth tag.
         if hasattr(model, "add_model_tags"):
@@ -188,8 +188,16 @@ def from_pretrained(
         pass
 
        if is_peft:
+            # From https://github.com/huggingface/peft/issues/184
             # Now add PEFT adapters
-            model = PeftModel.from_pretrained(model, old_model_name, token = token)
+            model.enable_input_require_grads()
+            model = PeftModel.from_pretrained(
+                model,
+                old_model_name,
+                token = token,
+                revision = revision,
+                is_trainable = True,
+            )
             # Patch it as well!
             model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
         pass
```
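
A short usage sketch of the new `revision` argument added in this commit; the revision string below is a placeholder, not a value from the diff:

```python
from unsloth import FastLanguageModel

# Pin the download to a specific branch, tag, or commit SHA on the Hub.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = 2048,
    load_in_4bit = True,
    revision = "main",   # placeholder: any branch, tag, or commit SHA
)
```

For PEFT repositories, the same `revision` is forwarded to `PeftModel.from_pretrained`, and the adapter is now loaded with `is_trainable = True` so it can be fine-tuned further.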

unsloth/models/mapper.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -186,6 +186,9 @@
     "unsloth/Qwen2-70B-Instruct-bnb-4bit" : (
         "Qwen/Qwen2-70B-Instruct",
     ),
+    "mistralai/Codestral-22B-v0.1" : (
+        "mistral-community/Codestral-22B-v0.1",
+    ),
 }
 
 INT_TO_FLOAT_MAPPER = {}
```
