|
17 | 17 | "RL_FUNCTIONS", |
18 | 18 | "RL_PRE_ITEMS", |
19 | 19 | "RL_CONFIG_CHANGES", |
| 20 | + "RL_METRICS_CHANGES", |
20 | 21 | ] |
21 | 22 |
|
22 | 23 | import re |
23 | 24 | import torch |
24 | 25 | import inspect |
25 | 26 | from collections import defaultdict |
26 | 27 | from unsloth_zoo.rl_replacements import RL_REPLACEMENTS |
# Per-trainer patch registries, keyed by trainer name (e.g. "grpo_trainer").
# Each maps to a list of callables that are applied when Unsloth rewrites the
# corresponding TRL trainer's source code:
#   RL_EXTRA_ARGS       - callbacks that inject extra call arguments
#   RL_FUNCTIONS        - callbacks that replace/patch trainer functions
#   RL_PRE_ITEMS        - code fragments to prepend before patched items
#   RL_CONFIG_CHANGES   - callbacks that edit the trainer's Config source
#   RL_METRICS_CHANGES  - callbacks that add extra logged metric names
RL_EXTRA_ARGS = defaultdict(list)
RL_FUNCTIONS = defaultdict(list)
RL_PRE_ITEMS = defaultdict(list)
RL_CONFIG_CHANGES = defaultdict(list)
RL_METRICS_CHANGES = defaultdict(list)
31 | 33 |
|
32 | 34 | torch_compile_options = { |
33 | 35 | "epilogue_fusion" : True, |
@@ -260,3 +262,20 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): |
260 | 262 | return check_batch_size |
261 | 263 | pass |
262 | 264 | RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_trainer_fix_batch_size) |
| 265 | + |
| 266 | + |
# Add other reward function names
def grpo_trainer_metrics(RLTrainer_source, RLConfig_source):
    """Return a code snippet that registers per-reward-function metric names.

    The returned snippet (injected into the patched GRPO trainer) appends one
    ``'rewards/<function name>'`` entry to ``other_metrics`` for every callable
    in ``reward_funcs``, normalizing a single function into a one-item list
    first. Functions without a ``__name__`` are skipped silently — the bare
    ``except: pass`` inside the snippet is deliberate best-effort behavior.

    Returns an empty string when the trainer source never mentions
    ``reward_funcs`` (i.e. there is nothing to log).
    """
    if "reward_funcs" not in RLTrainer_source: return ""

    snippet_lines = [
        "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]",
        "else: _reward_funcs = reward_funcs",
        "for reward_func in _reward_funcs:",
        "    try:",
        "        reward_func_name = reward_func.__name__",
        "        other_metrics.append(f'rewards/{reward_func_name}')",
        "    except: pass",
    ]
    return "\n".join(snippet_lines) + "\n"
pass
| 281 | +RL_METRICS_CHANGES["grpo_trainer"].append(grpo_trainer_metrics) |
0 commit comments