Merged
Commits
78 commits
7b81ca5
Update _utils.py
danielhanchen Aug 22, 2024
94f2d34
Update _utils.py
danielhanchen Aug 22, 2024
7c5222d
Update _utils.py
danielhanchen Aug 22, 2024
15d4417
Update _utils.py
danielhanchen Aug 22, 2024
1ea463c
Update _utils.py
danielhanchen Aug 22, 2024
cf929e2
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
5a7be98
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
2590b4c
Update tokenizer_utils.py
danielhanchen Aug 22, 2024
621e65b
update token retrieval logic (#952)
not-lain Aug 23, 2024
b62e5cd
Update llama.py
danielhanchen Aug 23, 2024
fb9dd65
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Aug 23, 2024
3b49609
get_token
danielhanchen Aug 24, 2024
9c8875e
Update README.md
danielhanchen Aug 24, 2024
c25de14
Merge branch 'main' into nightly
danielhanchen Aug 25, 2024
646a27b
Merge branch 'main' into nightly
danielhanchen Aug 27, 2024
a44357d
Update gemma2.py
danielhanchen Aug 30, 2024
7ed1c16
Update rms_layernorm.py
danielhanchen Aug 30, 2024
d7ef49e
synchronize
danielhanchen Aug 30, 2024
9a69548
Update gemma2.py
danielhanchen Aug 30, 2024
e6dadb4
Update rms_layernorm.py
danielhanchen Aug 30, 2024
f8e77cf
Update rms_layernorm.py
danielhanchen Aug 30, 2024
cfbaa97
Update rms_layernorm.py
danielhanchen Aug 30, 2024
32b2f3f
layernorm
danielhanchen Aug 30, 2024
9e7057d
Update rms_layernorm.py
danielhanchen Aug 30, 2024
a193508
Update gemma2.py
danielhanchen Aug 30, 2024
65eaa2d
Update rms_layernorm.py
danielhanchen Aug 30, 2024
1beeb22
Update rms_layernorm.py
danielhanchen Aug 30, 2024
1eb7705
revert
danielhanchen Aug 30, 2024
c3fe972
Gemma
danielhanchen Aug 31, 2024
75dbfba
Update rms_layernorm.py
danielhanchen Aug 31, 2024
332b091
Update rms_layernorm.py
danielhanchen Aug 31, 2024
4ecc119
Update rms_layernorm.py
danielhanchen Aug 31, 2024
07a1246
Update rms_layernorm.py
danielhanchen Aug 31, 2024
e3239e4
Update rms_layernorm.py
danielhanchen Aug 31, 2024
6ae1ac2
Update rms_layernorm.py
danielhanchen Aug 31, 2024
4d89f27
Update rms_layernorm.py
danielhanchen Aug 31, 2024
c76be22
Update rms_layernorm.py
danielhanchen Aug 31, 2024
ace509c
Update rms_layernorm.py
danielhanchen Aug 31, 2024
e474cfe
Update rms_layernorm.py
danielhanchen Aug 31, 2024
1576a1e
Update rms_layernorm.py
danielhanchen Aug 31, 2024
a2c4691
Update rms_layernorm.py
danielhanchen Aug 31, 2024
1a02e75
Update rms_layernorm.py
danielhanchen Aug 31, 2024
a26e1d1
Update rms_layernorm.py
danielhanchen Aug 31, 2024
afdb443
Update rms_layernorm.py
danielhanchen Sep 1, 2024
c3e14d8
Update rms_layernorm.py
danielhanchen Sep 1, 2024
1830bdd
Update rms_layernorm.py
danielhanchen Sep 1, 2024
6abf66a
Update rms_layernorm.py
danielhanchen Sep 1, 2024
f5cf796
Update rms_layernorm.py
danielhanchen Sep 1, 2024
b191530
Update rms_layernorm.py
danielhanchen Sep 1, 2024
512c61f
Update rms_layernorm.py
danielhanchen Sep 1, 2024
f5d50ef
Update rms_layernorm.py
danielhanchen Sep 1, 2024
d791bb9
Update rms_layernorm.py
danielhanchen Sep 1, 2024
9225608
Update gemma2.py
danielhanchen Sep 1, 2024
f61869c
Change UnslothTrainingArguments base class to SFTConfig (#979)
vTuanpham Sep 2, 2024
73d49ad
Cohere
danielhanchen Sep 2, 2024
86b6236
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Sep 2, 2024
edef5ca
Update trainer.py
danielhanchen Sep 2, 2024
6d4300c
Cohere
danielhanchen Sep 2, 2024
754e670
Cohere
danielhanchen Sep 2, 2024
d242866
New models
danielhanchen Sep 3, 2024
0b7e973
Update llama.py
danielhanchen Sep 3, 2024
19549f2
Update llama.py
danielhanchen Sep 3, 2024
8823e13
Update cohere.py
danielhanchen Sep 3, 2024
90050b7
Update llama.py
danielhanchen Sep 3, 2024
4c1ec3a
Update cohere.py
danielhanchen Sep 3, 2024
97b3956
retry
danielhanchen Sep 3, 2024
fd615ea
Update fast_lora.py
danielhanchen Sep 3, 2024
fe45990
Update llama.py
danielhanchen Sep 3, 2024
f564b8a
Update fast_lora.py
danielhanchen Sep 3, 2024
b26da84
Update llama.py
danielhanchen Sep 3, 2024
61be6a3
Update llama.py
danielhanchen Sep 3, 2024
ea48761
Update cross_entropy_loss.py
danielhanchen Sep 3, 2024
6e795c6
_apply_lora_mlp
danielhanchen Sep 3, 2024
dacba39
Update _utils.py
danielhanchen Sep 3, 2024
5074427
Gemma fixes
danielhanchen Sep 3, 2024
743ba55
Update llama.py
danielhanchen Sep 3, 2024
315136a
Merge branch 'main' into nightly
danielhanchen Sep 3, 2024
7ea6395
Update flex_attention.py
danielhanchen Sep 3, 2024
layernorm
danielhanchen committed Aug 30, 2024
commit 32b2f3f3b38b738a91a0c64f32f23ca934ae39f0
5 changes: 0 additions & 5 deletions unsloth/kernels/rms_layernorm.py

@@ -123,7 +123,6 @@ def _gemma_rms_layernorm_forward(
     tl.store(r, inv_var)
     normed = X_row * inv_var
     output = normed * (W_row + 1.0)
-    output = output.to(X_row.dtype)

     tl.store(Y + col_offsets, output, mask = mask)
 pass
@@ -141,7 +140,6 @@ def forward(ctx, X, W, eps, gemma = False):
         Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
         r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

-        torch.cuda.synchronize()
         fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
         fx[(n_rows,)](
             Y, Y.stride(0),
@@ -152,7 +150,6 @@ def forward(ctx, X, W, eps, gemma = False):
             BLOCK_SIZE = BLOCK_SIZE,
             num_warps = num_warps,
         )
-        torch.cuda.synchronize()
         ctx.eps = eps
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
@@ -170,7 +167,6 @@ def backward(ctx, dY):
         n_rows, n_cols = dY.shape
         dW = X

-        torch.cuda.synchronize()
         _rms_layernorm_backward[(n_rows,)](
             dY, dY.stride(0),
             X,  X .stride(0),
@@ -182,7 +178,6 @@ def backward(ctx, dY):
             BLOCK_SIZE = ctx.BLOCK_SIZE,
             num_warps = ctx.num_warps,
         )
-        torch.cuda.synchronize()
         dX = dY.view(*shape)
         return dX, None, None, None
     pass
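This commit drops the explicit output.to(X_row.dtype) cast in the Gemma kernel (Y is allocated with X's dtype, so the tl.store still down-casts) and removes the torch.cuda.synchronize() calls bracketing the forward and backward kernel launches, which appear to have been debugging aids from the earlier "synchronize" commit, so the Triton launches run asynchronously again. For orientation, a rough PyTorch sketch of what the Gemma-flavoured kernel computes is below; the (W_row + 1.0) offset and the float32 inv_var mirror the lines above, while the function name, eps argument, and final cast are illustrative assumptions rather than unsloth API.

import torch

def gemma_rms_layernorm_reference(X: torch.Tensor, W: torch.Tensor, eps: float) -> torch.Tensor:
    # Accumulate in float32, matching the kernel's float32 inv_var buffer.
    X_f32 = X.float()
    inv_var = torch.rsqrt(X_f32.square().mean(dim = -1, keepdim = True) + eps)
    normed = X_f32 * inv_var
    # Gemma applies the weight as (W + 1.0) rather than W.
    return (normed * (W.float() + 1.0)).to(X.dtype)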
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py

@@ -350,7 +350,7 @@ def is_big_gpu(index):
     "epilogue_fusion" : True,
     "max_autotune" : True,
     "shape_padding" : True,
-    "trace.enabled" : False, # Output Triton kernel outputs!
+    "trace.enabled" : True, # Output Triton kernel outputs!
     "triton.cudagraphs" : False,
 }
 # =============================================
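The only change here flips trace.enabled to True in what looks like unsloth's TorchInductor options dict (the surrounding keys are standard inductor config names). As a hedged sketch — the decorated function below is illustrative, not unsloth code — such a dict is typically handed to torch.compile through its options argument, and with trace.enabled the compiler writes out the Triton kernels it generates (by default under a torch_compile_debug/ directory), which is useful while debugging the layernorm kernels touched in this PR.

import torch

# Same keys as the dict above; values here are only an example.
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : True,   # write generated Triton kernels to disk
    "triton.cudagraphs" : False,
}

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def rms_layernorm_compiled(x, w, eps):
    # Minimal RMS layernorm, just to give the compiler something to trace.
    x32 = x.float()
    x32 = x32 * torch.rsqrt(x32.square().mean(-1, keepdim = True) + eps)
    return (x32 * w.float()).to(x.dtype)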
9 changes: 4 additions & 5 deletions unsloth/models/gemma2.py

@@ -18,7 +18,6 @@
     GemmaFixedRotaryEmbedding,
     GemmaFixedLinearScalingRotaryEmbedding,
     fast_geglu_inference,
-    fast_rms_layernorm,
 )
 try:
     from transformers.models.gemma2.modeling_gemma2 import (
@@ -205,7 +204,7 @@ def Gemma2DecoderLayer_fast_forward(
         hidden_states += residual
     else:
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             causal_mask=causal_mask,
@@ -216,14 +215,14 @@ def Gemma2DecoderLayer_fast_forward(
             use_cache=use_cache,
             padding_mask=padding_mask,
         )
-        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)
         hidden_states = residual + hidden_states

         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True)
         hidden_states = residual + hidden_states
     pass

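The four layernorm call sites in Gemma2DecoderLayer_fast_forward now go through fast_rms_layernorm_gemma2_compiled instead of the Triton-based fast_rms_layernorm, and the now-unused import is dropped. The compiled helper's definition is not part of this diff, so the sketch below is only a guess at its shape — a torch.compile'd Gemma-style RMS layernorm called with the layernorm module and the hidden states, matching the call sites above; the signature, attribute lookups, and internals are assumptions, not unsloth source.

import torch

@torch.compile(fullgraph = True, dynamic = True)
def _gemma2_rms_layernorm(hidden_states, weight, eps):
    x = hidden_states.float()
    x = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + eps)
    # Gemma-style scaling: the weight is applied as (weight + 1.0).
    return (x * (weight.float() + 1.0)).to(hidden_states.dtype)

def fast_rms_layernorm_gemma2_compiled_sketch(layernorm, hidden_states, gemma = True):
    # HF RMSNorm modules name their epsilon differently (eps vs variance_epsilon),
    # so probe both rather than assuming one.
    eps = getattr(layernorm, "variance_epsilon", getattr(layernorm, "eps", 1e-6))
    return _gemma2_rms_layernorm(hidden_states, layernorm.weight, eps)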