Skip to content

Commit 51a7023

Browse files
Merge pull request #2873 from Erland366/fix/unslothtrainingarguments
Fix `UnslothTrainingArguments` not patching `trl.Config` properly
2 parents b347ec5 + 3bad31a commit 51a7023

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

unsloth/trainer.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -66,12 +66,11 @@ def unsloth_train(trainer, *args, **kwargs):
6666
except:
6767
from transformers import TrainingArguments
6868
pass
69-
@dataclass
69+
7070
class UnslothTrainingArguments(TrainingArguments):
    """TrainingArguments extended with a separate embedding learning rate.

    Implemented via an explicit ``__init__`` (rather than ``@dataclass``
    fields, as in the pre-commit version visible in this diff) so the extra
    argument survives however trl/transformers construct the config class.

    Args:
        embedding_learning_rate: Different learning rate for the
            ``embed_tokens`` / ``lm_head`` parameters; ``None`` (default)
            means "use the main ``learning_rate`` for everything".
        *args, **kwargs: Forwarded unchanged to ``TrainingArguments``.
    """

    def __init__(self, embedding_learning_rate: float = None, *args, **kwargs):
        # BUG FIX: the committed line assigned the local variable to itself
        # (`embedding_learning_rate = embedding_learning_rate`), a no-op, so
        # the value was silently dropped and never reachable by the trainer.
        # Store it on the instance instead.
        self.embedding_learning_rate = embedding_learning_rate
        super().__init__(*args, **kwargs)
pass
7675

7776

0 commit comments

Comments (0)