Skip to content

Commit 179840d

Browse files
Fix bugs (#1706)
* Bug fixes * fix: flash_attn_detection_error (#1556) * fix: flash_attn_detection_error * Update _utils.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update mapper.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * dim fix * Update _utils.py * Torch 2.6 support * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Faster inference? * Update llama.py * Update llama.py * Update utils.py * Update llama.py * Update llama.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update mapper.py * Fast Inference via vLLM * Update llama.py * Update llama.py * Update utils.py * Create rl.py * PatchRL * Update rl.py * Update rl.py * Update rl.py * PatchRLStatistics * Update rl.py * Update rl.py * Update rl.py * Update utils.py * Update utils.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * RL metrics * Update rl.py * RL metrics * Update __init__.py * Update rl.py * Update rl.py * Update rl.py * Update chat_templates.py * Update mapper.py * Fp8 cache * Update llama.py * Update llama.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update __init__.py * Update loader.py * Update rl.py * Update rl.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Better TRL handling * Update 
rl.py * Update tokenizer_utils.py * Auto patching * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update rl.py * Update tokenizer_utils.py * Update rl.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update tokenizer_utils.py * Update rl.py * Update rl.py * Update rl.py * max seq length * Update rl.py * Update rl.py * Patching * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * NEFTune * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Extra replacements * Update rl_replacements.py * Update rl.py * extra RL replacements * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update _utils.py * Update loader_utils.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * autocast * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update pyproject.toml * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update 
rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update _utils.py * Update llama.py * Update _utils.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py --------- Co-authored-by: Zhe Zhang <2631992879@qq.com>
1 parent a41cdff commit 179840d

File tree

4 files changed

+19
-21
lines changed

4 files changed

+19
-21
lines changed

unsloth/models/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version__ = "2025.2.8"
15+
__version__ = "2025.2.9"
1616

1717
__all__ = [
1818
"SUPPORTS_BFLOAT16",

unsloth/models/llama.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -708,7 +708,7 @@ def LlamaModel_fast_forward(
708708
if attention_mask is None:
709709
padding_mask = None
710710
elif self.training:
711-
# elif attention_mask is not None and self.training:
711+
# elif attention_mask is None:
712712
attention_mask = None
713713
padding_mask = None
714714
else:
@@ -724,7 +724,8 @@ def LlamaModel_fast_forward(
724724
past_key_values_length,
725725
sliding_window = getattr(self.config, "sliding_window", None),
726726
)
727-
attention_mask = attention_mask.to(torch.bool)
727+
if attention_mask is not None:
728+
attention_mask = attention_mask.to(torch.bool)
728729
pass
729730

730731
hidden_states = inputs_embeds

unsloth/models/rl.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -565,8 +565,8 @@ def patch_trl_rl_trainers():
565565

566566

567567
def PatchFastRL(algorithm = None, FastLanguageModel = None):
568-
return
569-
# if FastLanguageModel is not None: PatchRL(FastLanguageModel)
570-
# patch_trl_rl_trainers()
571-
# if algorithm is not None: PatchRLStatistics(algorithm)
568+
if FastLanguageModel is not None: PatchRL(FastLanguageModel)
569+
patch_trl_rl_trainers()
570+
if type(algorithm) is str and algorithm.islower():
571+
PatchRLStatistics(algorithm)
572572
pass

unsloth/models/rl_replacements.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -101,23 +101,20 @@ def sft_trainer_prepare_dataset(function_name, function):
101101

102102
# Ignore mean_token_accuracy since it needs logits
103103
# We override it directly with our version
104-
def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
105-
(loss, outputs) = super().compute_loss(
106-
model,
107-
inputs,
108-
return_outputs = return_outputs,
109-
num_items_in_batch = num_items_in_batch,
110-
)
111-
return (loss, outputs) if return_outputs else loss
112-
pass
113-
114104
def sft_trainer_compute_loss(function_name, function):
115105
if function_name != "compute_loss": return function
116106

117-
function = inspect.getsource(_sft_trainer_compute_loss)
118-
function = function.replace("def _sft_trainer_compute_loss", "def compute_loss")
119-
function = function.split("\n")
120-
function = "\n".join(" "*4+x for x in function)
107+
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
108+
outputs = super().compute_loss(
109+
model,
110+
inputs,
111+
return_outputs = return_outputs,
112+
num_items_in_batch = num_items_in_batch,
113+
)
114+
return outputs
115+
pass
116+
117+
function = inspect.getsource(compute_loss)
121118
return function
122119
pass
123120
RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss)

0 commit comments

Comments (0)