Skip to content

Commit 1c96619

Browse files
Add GRPO metrics (#1718)
* Update llama.py * Update llama.py * Faster inference? * Update llama.py * Update llama.py * Update utils.py * Update llama.py * Update llama.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update mapper.py * Fast Inference via vLLM * Update llama.py * Update llama.py * Update utils.py * Create rl.py * PatchRL * Update rl.py * Update rl.py * Update rl.py * PatchRLStatistics * Update rl.py * Update rl.py * Update rl.py * Update utils.py * Update utils.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * RL metrics * Update rl.py * RL metrics * Update __init__.py * Update rl.py * Update rl.py * Update rl.py * Update chat_templates.py * Update mapper.py * Fp8 cache * Update llama.py * Update llama.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update __init__.py * Update loader.py * Update rl.py * Update rl.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Better TRL handling * Update rl.py * Update tokenizer_utils.py * Auto patching * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update rl.py * Update tokenizer_utils.py * Update rl.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * 
Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update tokenizer_utils.py * Update rl.py * Update rl.py * Update rl.py * max seq length * Update rl.py * Update rl.py * Patching * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * NEFTune * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Extra replacements * Update rl_replacements.py * Update rl.py * extra RL replacements * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update _utils.py * Update loader_utils.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * autocast * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update pyproject.toml * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update _utils.py * Update llama.py * Update _utils.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update llama.py * Update 
llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * GRPO optimized * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Selective Log softmax * Fix GRPO bsz * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Fix TRL * Metrics GRPO * Update rl_replacements.py * Update rl_replacements.py
1 parent f6003b0 commit 1c96619

File tree

3 files changed

+36
-6
lines changed

3 files changed

+36
-6
lines changed

unsloth/models/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version__ = "2025.2.10"
15+
__version__ = "2025.2.11"
1616

1717
__all__ = [
1818
"SUPPORTS_BFLOAT16",

unsloth/models/rl.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
RL_FUNCTIONS,
3131
RL_PRE_ITEMS,
3232
RL_CONFIG_CHANGES,
33+
RL_METRICS_CHANGES,
3334
)
3435
selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"]
3536

@@ -310,10 +311,20 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
310311
RLTrainer_post += neftune_check
311312
pass
312313

314+
# Edit optional metrics
315+
other_metrics_processor = ""
316+
if trainer_file in RL_METRICS_CHANGES:
317+
process_extra_args = RL_METRICS_CHANGES[trainer_file]
318+
for process_extra_arg in process_extra_args:
319+
other_metrics_processor += process_extra_arg(call_args, extra_args)
320+
pass
321+
313322
# Add statistics as well!
314323
extra_args += \
324+
"other_metrics = []\n"\
325+
f"{other_metrics_processor}\n"\
315326
"from unsloth_zoo.logging_utils import PatchRLStatistics\n"\
316-
f"PatchRLStatistics('{trainer_file}')\n"
327+
f"PatchRLStatistics('{trainer_file}', other_metrics)\n"
317328

318329
# Patch optional args
319330
if trainer_file in RL_EXTRA_ARGS:

unsloth/models/rl_replacements.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,17 +17,19 @@
1717
"RL_FUNCTIONS",
1818
"RL_PRE_ITEMS",
1919
"RL_CONFIG_CHANGES",
20+
"RL_METRICS_CHANGES",
2021
]
2122

2223
import re
2324
import torch
2425
import inspect
2526
from collections import defaultdict
2627
from unsloth_zoo.rl_replacements import RL_REPLACEMENTS
27-
RL_EXTRA_ARGS = defaultdict(list)
28-
RL_FUNCTIONS = defaultdict(list)
29-
RL_PRE_ITEMS = defaultdict(list)
30-
RL_CONFIG_CHANGES = defaultdict(list)
28+
RL_EXTRA_ARGS = defaultdict(list)
29+
RL_FUNCTIONS = defaultdict(list)
30+
RL_PRE_ITEMS = defaultdict(list)
31+
RL_CONFIG_CHANGES = defaultdict(list)
32+
RL_METRICS_CHANGES = defaultdict(list)
3133

3234
torch_compile_options = {
3335
"epilogue_fusion" : True,
@@ -260,3 +262,20 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source):
260262
return check_batch_size
261263
pass
262264
RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_trainer_fix_batch_size)
265+
266+
267+
# Add other reward function names
268+
def grpo_trainer_metrics(RLTrainer_source, RLConfig_source):
269+
if "reward_funcs" not in RLTrainer_source: return ""
270+
271+
log_metrics = \
272+
"if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"\
273+
"else: _reward_funcs = reward_funcs\n"\
274+
"for reward_func in _reward_funcs:\n"\
275+
" try:\n"\
276+
" reward_func_name = reward_func.__name__\n"\
277+
" other_metrics.append(f'rewards/{reward_func_name}')\n"\
278+
" except: pass\n"
279+
return log_metrics
280+
pass
281+
RL_METRICS_CHANGES["grpo_trainer"].append(grpo_trainer_metrics)

0 commit comments

Comments
 (0)