
Commit 499635a

Authored by danielhanchen, chrehall68, neph1, xyangk, and Oseltamivir
Gemma2 (#709)
* Update mapper.py * Update loader.py * Update llama.py * Update tokenizer_utils.py * info * edits * Create chat template * Fix tokenizer * Update tokenizer_utils.py * fix case where gguf saving fails due to first_conversion dtype (#630) * Support revision parameter in FastLanguageModel.from_pretrained (#629) * support `revision` parameter * match unsloth formatting of named parameters * clears any selected_adapters before calling internal_model.save_pretrained (#609) * Update __init__.py (#602) Check for incompatible modules before importing unsloth * Fixed unsloth/tokenizer_utils.py for chat training (#604) * Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345) * Add save to llama.cpp GGML to save.py. * Fix conversion command and path of convert to GGML function. * Add autosaving lora to the GGML function * Create lora save function for conversion to GGML * Test fix #2 for saving lora * Test fix #3 to save the lora adapters to convert to GGML * Remove unwated tokenizer saving for conversion to ggml and added a few print statements. * Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages. * Positional arguments didn't work out, so reverted to older version of the code, and added a few comments. * Test fix 1 for arch * Test fix 2 new Mistral error. * Test fix 3 * Revert to old version for testing. * Upload issue test fix 1 * Fix 2 uploading ggml * Positional ags added. * Temporray remove positional args * Fix upload again!!! * Add print statements and fix link * Make the calling name better * Create local saving for GGML * Add choosing directory to save local GGML. * Fix lil variable error in the save_to_custom_dir func * docs: Add LoraConfig parameters documentation (#619) * llama.cpp failing (#371) llama.cpp is failing to generate quantize versions for the trained models. Error: ```bash You might have to compile llama.cpp yourself, then run this again. You do not need to close this Python program. Run the following commands in a new terminal: You must run this in the same folder as you're saving your model. git clone https://github.com/ggerganov/llama.cpp cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j Once that's done, redo the quantization. ``` But when i do clone this with recursive it works. Co-authored-by: Daniel Han <danielhanchen@gmail.com> * fix libcuda_dirs import for triton 3.0 (#227) * fix libcuda_dirs import for triton 3.0 * Update __init__.py * Update __init__.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update save.py * Update __init__.py * Update fast_lora.py * Update save.py * Update save.py * Update save.py * Update loader.py * Update save.py * Update save.py * quantize now llama-quantize * Update chat_templates.py * Update loader.py * Update mapper.py * Update __init__.py * embedding size * Update qwen2.py * docs * Update README.md * Update qwen2.py * README: Fix minor typo. (#559) * README: Fix minor typo. One-character typo fix while reading. 
* Update README.md --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update mistral.py * Update qwen2.py * Update qwen2.py * Update qwen2.py * Update llama.py * Update llama.py * Update llama.py * Update README.md * FastMistralModel * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Auto check rope scaling * Update llama.py * Update llama.py * Update llama.py * GPU support * Typo * Update gemma.py * gpu * Multiple GGUF saving * Update save.py * Update save.py * check PEFT and base * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update chat_templates.py * Fix breaking bug in save.py with interpreting quantization_method as a string when saving to gguf (#651) * Nightly (#649) * Update llama.py * offload * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * continued pretraining trainer * Update trainer.py * Update trainer.py * Update trainer.py * Update trainer.py * is_bfloat16_supported * Update __init__.py * Update README.md * Update llama.py * is_bfloat16_supported * Update __init__.py * Mistral v3 * Phi 3 medium * Update chat_templates.py * Update chat_templates.py * Phi-3 * Update save.py * Update README.md Mistral v3 to Mistral v0.3 * Untrained tokens * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update save.py * Update save.py * Update save.py * checkpoint * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * accelerate * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * train_dataloader * Update llama.py * Update llama.py * Update llama.py * use_fast_convert * Update save.py * Update save.py * Update save.py * Update save.py * remove_special_tokens * Ollama * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update llama.py * Update chat_templates.py * Support bfloat16 GGUF * Update save.py * Update llama.py * fast_forward_inference * Update mapper.py * Update loader.py * Update llama.py * Update tokenizer_utils.py * info * edits * Create chat template * Fix tokenizer * Update tokenizer_utils.py * fix case where gguf saving fails due to first_conversion dtype (#630) * Support revision parameter in FastLanguageModel.from_pretrained (#629) * support `revision` parameter * match unsloth formatting of named parameters * clears any selected_adapters before calling internal_model.save_pretrained (#609) * Update __init__.py (#602) Check for incompatible modules before importing unsloth * Fixed unsloth/tokenizer_utils.py for chat training (#604) * Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345) * Add save to llama.cpp GGML to save.py. 
* Fix conversion command and path of convert to GGML function. * Add autosaving lora to the GGML function * Create lora save function for conversion to GGML * Test fix #2 for saving lora * Test fix #3 to save the lora adapters to convert to GGML * Remove unwated tokenizer saving for conversion to ggml and added a few print statements. * Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages. * Positional arguments didn't work out, so reverted to older version of the code, and added a few comments. * Test fix 1 for arch * Test fix 2 new Mistral error. * Test fix 3 * Revert to old version for testing. * Upload issue test fix 1 * Fix 2 uploading ggml * Positional ags added. * Temporray remove positional args * Fix upload again!!! * Add print statements and fix link * Make the calling name better * Create local saving for GGML * Add choosing directory to save local GGML. * Fix lil variable error in the save_to_custom_dir func * docs: Add LoraConfig parameters documentation (#619) * llama.cpp failing (#371) llama.cpp is failing to generate quantize versions for the trained models. Error: ```bash You might have to compile llama.cpp yourself, then run this again. You do not need to close this Python program. Run the following commands in a new terminal: You must run this in the same folder as you're saving your model. git clone https://github.com/ggerganov/llama.cpp cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j Once that's done, redo the quantization. ``` But when i do clone this with recursive it works. Co-authored-by: Daniel Han <danielhanchen@gmail.com> * fix libcuda_dirs import for triton 3.0 (#227) * fix libcuda_dirs import for triton 3.0 * Update __init__.py * Update __init__.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update save.py * Update __init__.py * Update fast_lora.py * Update save.py * Update save.py * Update save.py * Update loader.py * Update save.py * Update save.py * quantize now llama-quantize * Update chat_templates.py * Update loader.py * Update mapper.py * Update __init__.py * embedding size * Update qwen2.py * docs * Update README.md * Update qwen2.py * README: Fix minor typo. (#559) * README: Fix minor typo. One-character typo fix while reading. 
* Update README.md --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update mistral.py * Update qwen2.py * Update qwen2.py * Update qwen2.py * Update llama.py * Update llama.py * Update llama.py * Update README.md * FastMistralModel * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Auto check rope scaling * Update llama.py * Update llama.py * Update llama.py * GPU support * Typo * Update gemma.py * gpu * Multiple GGUF saving * Update save.py * Update save.py * check PEFT and base * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update chat_templates.py --------- Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén <rickardeden@gmail.com> Co-authored-by: XiaoYang <xyangk@gmail.com> Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com> Co-authored-by: Alberto Ferrer <albertof@barrahome.org> Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de> Co-authored-by: Walter Korman <lemurware@gmail.com> * Fix bug in save.py with interpreting quantization_method as a string that prevents GGUF from saving * Implemented better list management and then forgot to actually call the new list variable, fixed * Check type of given quantization method and return type error if not list or string * Update save.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén <rickardeden@gmail.com> Co-authored-by: XiaoYang <xyangk@gmail.com> Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com> Co-authored-by: Alberto Ferrer <albertof@barrahome.org> Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de> Co-authored-by: Walter Korman <lemurware@gmail.com> * Revert "Fix breaking bug in save.py with interpreting quantization_method as …" (#652) This reverts commit 30605de. * Revert "Revert "Fix breaking bug in save.py with interpreting quantization_me…" (#653) This reverts commit e2b2083. * Update llama.py * peft * patch * Update loader.py * retrain * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * offload * Update llama.py * Create a starter script for command-line training to integrate in ML ops pipelines. 
(#623) * Update chat_templates.py * Ollama * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Ollama * Update chat_templates.py * ollama * Update mapper.py * Update chat_templates.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update llama.py * Fixes * clearer messages * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * Update llama.py * Update llama.py * log * Update __init__.py * Update llama.py * Update __init__.py * Create Merge.png * Create ollama.png * Gemma2 * Update llama.py * Update loader.py * Update pyproject.toml * Update pyproject.toml * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update _utils.py * Revert Gemma2 * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update rms_layernorm.py * Update gemma2.py * logit softcapping * Update cross_entropy_loss.py * Update llama.py * Update llama.py * Update gemma2.py * Update gemma2.py * Update cross_entropy_loss.py * Update llama.py * Update llama.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update llama.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update llama.py * Update gemma2.py * Update llama.py * Update llama.py * Update gemma2.py * Update gemma2.py * Update llama.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update _utils.py * Update _utils.py * Update gemma2.py * compile flags * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update gemma2.py * Update gemma2.py * fixes * Update _utils.py * Fix generation * Update llama.py * Update llama.py * Update _utils.py * Update _utils.py * Update _utils.py * pad token * Update gemma2.py * pad token * Update _utils.py * Update llama.py * Update gemma2.py * edit warning * Update tokenizer_utils.py --------- Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén <rickardeden@gmail.com> Co-authored-by: XiaoYang <xyangk@gmail.com> Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com> Co-authored-by: Alberto Ferrer <albertof@barrahome.org> Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de> Co-authored-by: Walter Korman <lemurware@gmail.com> Co-authored-by: ArcadaLabs-Jason <52756218+ArcadaLabs-Jason@users.noreply.github.com> Co-authored-by: Michael Han 
<107991372+shimmyshimmer@users.noreply.github.com>
1 parent 933d9fe · commit 499635a

File tree: 17 files changed, +772 −60 lines


images/Merge.png

30.7 KB

images/ollama.png

65.6 KB

pyproject.toml

Lines changed: 3 additions & 3 deletions
```diff
@@ -34,7 +34,7 @@ exclude = ["images*"]
 [project.optional-dependencies]
 huggingface = [
     "tyro",
-    "transformers>=4.38.2",
+    "transformers>=4.42.3",
     "datasets>=2.16.0",
     "sentencepiece>=0.2.0",
     "tqdm",
@@ -185,9 +185,9 @@ colab-ampere-torch220 = [
 ]
 colab-new = [
     "tyro",
-    "transformers>=4.38.2",
+    "transformers>=4.42.3",
     "datasets>=2.16.0",
-    "sentencepiece",
+    "sentencepiece>=0.2.0",
     "tqdm",
     "psutil",
     "wheel>=0.42.0",
```

unsloth/kernels/cross_entropy_loss.py

Lines changed: 73 additions & 29 deletions
```diff
@@ -19,14 +19,17 @@
 from transformers.models.llama.modeling_llama import logger


+@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
 @triton.jit
 def _cross_entropy_forward(
     logits_ptr, logits_row_stride,
     loss_ptr,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    VOCAB_SIZE     : tl.constexpr,
+    BLOCK_SIZE     : tl.constexpr,
+    DO_SOFTCAPPING : tl.constexpr,
+    SOFTCAP        : tl.constexpr,
 ):
     """
         Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@@ -58,29 +61,38 @@ def _cross_entropy_forward(
     mask = col_offsets < VOCAB_SIZE

     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+    if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)
+
+    logits = logits.to(tl.float32)
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

     if label_idx != -100:
-        x = tl.load(logits_ptr + label_idx).to(tl.float32)
-        loss = logsumexp - x
+        x = tl.load(logits_ptr + label_idx)
+        # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+        if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
+        loss = logsumexp - x.to(tl.float32)
     else:
         loss = 0.0
     tl.store(logsumexp_ptr, logsumexp)
     tl.store(loss_ptr, loss)
 pass


+@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
 @triton.jit
 def _chunked_cross_entropy_forward(
     logits_ptr, logits_row_stride,
     loss_ptr,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE : tl.constexpr,
-    N_CHUNKS : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    VOCAB_SIZE     : tl.constexpr,
+    N_CHUNKS       : tl.constexpr,
+    BLOCK_SIZE     : tl.constexpr,
+    DO_SOFTCAPPING : tl.constexpr,
+    SOFTCAP        : tl.constexpr,
 ):
     """
         256K vocab divided in 4 chunks
@@ -117,7 +129,11 @@ def _chunked_cross_entropy_forward(
     mask = col_offsets < VOCAB_SIZE

     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+    if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)
+
+    logits = logits.to(tl.float32)
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

@@ -126,7 +142,9 @@ def _chunked_cross_entropy_forward(
     # Do the -x separately
     if label_idx != -100:
         x = tl.load(logits_ptr + label_idx).to(tl.float32)
-        loss = -1.0 * x
+        # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+        if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
+        loss = -1.0 * x.to(tl.float32)
     else:
         loss = 0.0
     tl.store(loss_ptr, loss)
@@ -135,14 +153,17 @@ def _chunked_cross_entropy_forward(
 pass


+@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
 @triton.jit
 def _cross_entropy_backward(
     logits_ptr, logits_row_stride,
     dloss_ptr, dloss_row_stride,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    VOCAB_SIZE     : tl.constexpr,
+    BLOCK_SIZE     : tl.constexpr,
+    DO_SOFTCAPPING : tl.constexpr,
+    SOFTCAP        : tl.constexpr,
 ):
     """
         CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
@@ -173,15 +194,27 @@ def _cross_entropy_backward(
     else:
         dloss = 0.0

-    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+    if DO_SOFTCAPPING:
+        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
+        partial = tl.math.tanh(x / SOFTCAP)
+        x = SOFTCAP * partial
+    pass
+
     logsumexp = tl.load(logsumexp_ptr + row_idx)
-    y = tl.exp(x - logsumexp)
+    y = tl.exp(x.to(tl.float32) - logsumexp)
     y = tl.where(
         col_offsets == label_idx,
         y - 1.0, # exp(x - logsumexp) - 1
         y,       # exp(x - logsumexp)
     )

+    if DO_SOFTCAPPING:
+        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
+        y = y * (1.0 - partial*partial)
+    pass
+
     # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
     tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
 pass
@@ -191,40 +224,46 @@ def _cross_entropy_backward(

 class Fast_CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, logits, labels):
+    def forward(ctx, logits, labels, logit_softcapping = 0):
         n_rows, vocab_size = logits.shape

         div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
         n_chunks = div + (mod != 0)
-        losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
+        losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
+
+        DO_SOFTCAPPING = (logit_softcapping != 0)

         if n_chunks == 1:
             # For small vocabs <= 65336 like Llama, Mistral
             BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
-            logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
+            logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

             _cross_entropy_forward[(n_rows,)](
                 logits, logits.stride(0),
                 losses,
                 logsumexp,
                 labels,
-                VOCAB_SIZE = vocab_size,
-                BLOCK_SIZE = BLOCK_SIZE,
-                num_warps = num_warps,
+                VOCAB_SIZE     = vocab_size,
+                BLOCK_SIZE     = BLOCK_SIZE,
+                DO_SOFTCAPPING = DO_SOFTCAPPING,
+                SOFTCAP        = logit_softcapping,
+                num_warps      = num_warps,
             )
         else:
             # For large vocabs > 65336 like Gemma 256K
-            logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda")
+            logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")

             _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
                 logits, logits.stride(0),
                 losses,
                 logsumexp,
                 labels,
-                VOCAB_SIZE = vocab_size,
-                N_CHUNKS = n_chunks,
-                BLOCK_SIZE = MAX_FUSED_SIZE,
-                num_warps = 32,
+                VOCAB_SIZE     = vocab_size,
+                N_CHUNKS       = n_chunks,
+                BLOCK_SIZE     = MAX_FUSED_SIZE,
+                DO_SOFTCAPPING = DO_SOFTCAPPING,
+                SOFTCAP        = logit_softcapping,
+                num_warps      = 32,
             )
             # logsumexp(chunked_logsumexp) - x
             # Do the -x separately
@@ -234,6 +273,8 @@ def forward(ctx, logits, labels):
         pass

         ctx.save_for_backward(logits, logsumexp, labels)
+        ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
+        ctx.logit_softcapping = logit_softcapping
         return losses
     pass

@@ -251,16 +292,18 @@ def backward(ctx, dlosses):
             dlosses, dlosses.stride(0),
             logsumexp,
             labels,
-            VOCAB_SIZE = vocab_size,
-            BLOCK_SIZE = BLOCK_SIZE,
-            num_warps = 8,
+            VOCAB_SIZE     = vocab_size,
+            BLOCK_SIZE     = BLOCK_SIZE,
+            DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
+            SOFTCAP        = ctx.logit_softcapping,
+            num_warps      = 8,
         )
         return logits, None, None,
     pass
 pass


-def fast_cross_entropy_loss(logits, labels):
+def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
     """
     Arguments:
         logits: (batch, seq_len, vocab_size)
@@ -274,6 +317,7 @@ def fast_cross_entropy_loss(logits, labels):
     loss = Fast_CrossEntropyLoss.apply(
         logits.view(batch*seq_len, d),
         labels.view(-1),
+        logit_softcapping,
     )
     n_items = torch.count_nonzero(labels != -100)
     return loss.sum() / n_items
```
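
For readers following the kernel changes, here is a small, unfused PyTorch reference (not part of the commit) of what the new `DO_SOFTCAPPING`/`SOFTCAP` path computes, plus a check of the derivative used in `_cross_entropy_backward`. The cap value `30.0` is only an illustrative choice; the real value comes from the model config.

```python
import torch
import torch.nn.functional as F

def softcapped_cross_entropy_reference(logits, labels, softcap=30.0):
    # Same math as the fused kernels: cap logits with t * tanh(x / t),
    # then standard cross entropy, ignoring -100 labels.
    capped = softcap * torch.tanh(logits.float() / softcap)
    return F.cross_entropy(
        capped.view(-1, capped.size(-1)), labels.view(-1),
        ignore_index=-100, reduction="mean",
    )

# Derivative used in the backward kernel: d/dx [t * tanh(x/t)] = 1 - tanh^2(x/t).
x = torch.randn(16, dtype=torch.float64, requires_grad=True)
t = 30.0
(t * torch.tanh(x / t)).sum().backward()
assert torch.allclose(x.grad, 1.0 - torch.tanh(x / t) ** 2)
```

With the fused path, the same quantity is requested through the new signature `fast_cross_entropy_loss(logits, labels, logit_softcapping = softcap)`; passing `0` keeps the previous behaviour, since `DO_SOFTCAPPING = (logit_softcapping != 0)`.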

unsloth/kernels/geglu.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -41,7 +41,7 @@ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
 def geglu_exact_forward_kernel(gate, up):
     batch, seq_len, hd = gate.shape
     n_elements = gate.numel()
-    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
+    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
     _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
     return out
@@ -133,7 +133,7 @@ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
 def geglu_approx_forward_kernel(gate, up):
     batch, seq_len, hd = gate.shape
     n_elements = gate.numel()
-    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
+    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
     _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
     return out
```
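
Only the allocation device changes here, but for orientation: the "exact"/"approx" pair mirrors the two standard GELU formulations (erf form vs. tanh approximation). A hypothetical, unfused sketch, where the assumption that GELU is applied to `gate` and multiplied by `up` is mine, not something this diff shows:

```python
import torch
import torch.nn.functional as F

def geglu_reference(gate: torch.Tensor, up: torch.Tensor, approx: bool = False):
    # "exact" GELU uses the erf form; "approx" uses the tanh approximation,
    # mirroring the exact/approx kernel pair above. (Operand mapping assumed.)
    act = F.gelu(gate, approximate="tanh" if approx else "none")
    return act * up
```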

unsloth/kernels/rms_layernorm.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -119,7 +119,7 @@ def _gemma_rms_layernorm_forward(
     W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)

     row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
-    inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl
+    inv_var = tl.math.rsqrt(row_var + eps)
     tl.store(r, inv_var)
     normed = X_row * inv_var
     output = normed * (W_row + 1.0)
@@ -137,8 +137,8 @@ def forward(ctx, X, W, eps, gemma = False):
         n_rows, n_cols = X.shape
         BLOCK_SIZE, num_warps = calculate_settings(n_cols)

-        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
-        r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
+        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
+        r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

         fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
         fx[(n_rows,)](
```
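
The context lines above spell out the Gemma-style RMS LayerNorm that `_gemma_rms_layernorm_forward` implements (note the `rsqrt` and the `(W_row + 1.0)` scaling). A non-fused PyTorch equivalent for reference; the `eps` default is an assumption taken from typical Gemma configs, not from this diff:

```python
import torch

def gemma_rms_layernorm_reference(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
    # Row variance -> rsqrt normalization -> Gemma's (weight + 1) scaling,
    # computed in float32 as in the kernel.
    x32 = x.float()
    inv_var = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * inv_var) * (w.float() + 1.0)
```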

unsloth/kernels/swiglu.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -41,7 +41,7 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
 def swiglu_fg_kernel(e, g):
     batch, seq_len, hd = e.shape
     n_elements = e.numel()
-    h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda")
+    h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
     _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
     return h
```

unsloth/kernels/utils.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -105,14 +105,14 @@ def fast_dequantize(W, quant_state = None, out = None):

     # Create weight matrix
     if out is None:
-        out = torch.empty(shape, dtype = dtype, device = "cuda")
+        out = torch.empty(shape, dtype = dtype, device = "cuda:0")
     else:
         assert(out.shape == shape)
         assert(out.dtype == dtype)

     # NF4 dequantization of statistics
     n_elements_absmax = absmax.numel()
-    out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda")
+    out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")

     # Do dequantization
     ptr_out_absmax = get_ptr(out_absmax)
@@ -161,7 +161,7 @@ def fast_gemv(X, W, quant_state, out = None):
     bout = shape[0]

     if out is None:
-        out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda")
+        out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
     # else:
     #     assert(out.shape == (1, 1, bout,))
     # pass
@@ -179,7 +179,7 @@ def fast_gemv(X, W, quant_state, out = None):
     ldb = ctypes.c_int32(ldb)
     ldc = ctypes.c_int32(ldc)

-    df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda")
+    df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
     cdequantize_blockwise_fp32(
         get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
         ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
```
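
A recurring change in this commit (here and in the geglu, swiglu, and rms_layernorm kernels) is allocating scratch buffers on `"cuda:0"` instead of `"cuda"`. The snippet below only illustrates the PyTorch semantics involved; it is not from the repository:

```python
import torch

if torch.cuda.is_available():
    # "cuda" carries no index and resolves to the *current* device at use time;
    # "cuda:0" pins the allocation to device index 0, as the edited calls now do.
    print(torch.device("cuda").index)    # None
    print(torch.device("cuda:0").index)  # 0
    buf = torch.empty(4, device="cuda:0")
    print(buf.device)                    # cuda:0
```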
