
Commit 2578e95

🚛 Provide all columns of the dataset to the reward function (#2650)
* The reward function is provided with all columns from the dataset
* Minor clarifications
* Minor renaming in doc [ci skip]
* Fix indentation
1 parent 6f99f42 commit 2578e95


4 files changed: +104 −15 lines changed


docs/source/grpo_trainer.md (+37 −7)

@@ -121,7 +121,12 @@ The GRPO Trainer logs the following metrics:
 The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

 1. **Input arguments**:
-   - The function must accept two arguments: `prompts` and `completions`.
+   - The function must accept the following as keyword arguments:
+     - `prompts` (contains the prompts),
+     - `completions` (contains the generated completions),
+     - All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
+
+     The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
   - Depending on the dataset format, the input will vary:
     - For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
     - For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.
@@ -133,7 +138,7 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa
 Below is an example of a reward function for a standard format that rewards longer completions:

 ```python
-def reward_func(prompts, completions):
+def reward_func(completions, **kwargs):
     """Reward function that gives higher scores to longer completions."""
     return [float(len(completion)) for completion in completions]
 ```
@@ -143,19 +148,19 @@ You can test it as follows:
 ```python
 >>> prompts = ["The sky is", "The sun is"]
 >>> completions = [" blue.", " in the sky."]
->>> print(reward_func(prompts, completions))
+>>> print(reward_func(prompts=prompts, completions=completions))
 [6.0, 12.0]
 ```

 #### Example 2: Reward completions with specific format

-Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the reward function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
+Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
 It is designed for conversational format, where prompts and completions consist of structured messages.

 ```python
 import re

-def format_reward_func(prompts, completions):
+def format_reward_func(completions, **kwargs):
     """Reward function that checks if the completion has a specific format."""
     pattern = r"^<think>.*?</think><answer>.*?</answer>$"
     completion_contents = [completion[0]["content"] for completion in completions]
@@ -174,9 +179,34 @@ You can test this function as follows:
 ...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
 ...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
 ... ]
->>> format_reward_func(prompts, completions)
+>>> format_reward_func(prompts=prompts, completions=completions)
+[1.0, 0.0]
+```
+
+#### Example 3: Reward completions based on a reference
+
+Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
+This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.
+
+```python
+import re
+
+def reward_func(completions, ground_truth, **kwargs):
+    # Regular expression to capture content inside \boxed{}
+    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
+    contents = [match.group(1) if match else "" for match in matches]
+    # Reward 1 if the content is the same as the ground truth, 0 otherwise
+    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
+```
+
+You can test this function as follows:
+
+```python
+>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
+>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
+>>> ground_truth = ["2", "5"]
+>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
 [1.0, 0.0]
->>>
 ```

 #### Passing the reward function to the trainer
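
As a companion to the new Example 3, here is a minimal sketch of how such a reward function might be wired into the trainer. It follows the `GRPOTrainer`/`GRPOConfig` usage shown in the updated test file below; the toy dataset contents are illustrative and the final `trainer.train()` call is left commented out, since this snippet is not part of the commit.

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Toy dataset: besides "prompt", it carries a "ground_truth" column that the
# reward function receives as a keyword argument (values are illustrative).
dataset = Dataset.from_dict(
    {
        "prompt": ["Problem: Solve the equation $2x + 3 = 7$. Solution:"],
        "ground_truth": ["2"],
    }
)

def reward_func(completions, ground_truth, **kwargs):
    # Dummy reward: 1.0 when the reference answer appears verbatim in the completion.
    return [1.0 if gt in completion else 0.0 for completion, gt in zip(completions, ground_truth)]

training_args = GRPOConfig(output_dir="grpo-demo")  # remove_unused_columns now defaults to False
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",  # model id borrowed from the test file
    reward_funcs=reward_func,
    args=training_args,
    train_dataset=dataset,
)
# trainer.train()
```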

tests/test_grpo_trainer.py (+44 −5)

@@ -151,7 +151,7 @@ def test_training_reward_func_standard(self):
         # Test if trainer can handle reward function with standard format
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

-        def reward_func(prompts, completions):
+        def reward_func(completions, **kwargs):
             """Reward function that rewards longer completions."""
             return [float(len(completion)) for completion in completions]

@@ -186,7 +186,7 @@ def test_training_reward_func_conversational(self):
         # Test if trainer can handle reward function with conversational format
         dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

-        def reward_func(prompts, completions):
+        def reward_func(completions, **kwargs):
             """Reward function that gives higher scores to longer completion content."""
             completion_contents = [completion[0]["content"] for completion in completions]
             return [float(len(content)) for content in completion_contents]
@@ -222,11 +222,11 @@ def test_training_multiple_reward_funcs(self):
         # Test that GRPOTrainer can be instantiated with multiple reward functions
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

-        def reward_func1(prompts, completions):
+        def reward_func1(completions, **kwargs):
             """Reward function that rewards longer completions."""
             return [float(len(completion)) for completion in completions]

-        def reward_func2(prompts, completions):
+        def reward_func2(completions, **kwargs):
             """Reward function that rewards completions with more unique letters."""
             return [float(len(set(completion))) for completion in completions]

@@ -261,7 +261,7 @@ def test_training_multiple_mixed_reward_funcs(self):
         # Test if the trainer can handle a mix of reward functions and reward models
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

-        def reward_func(prompts, completions):
+        def reward_func(completions, **kwargs):
             """Reward function that rewards longer completions."""
             return [float(len(completion)) for completion in completions]

@@ -291,3 +291,42 @@ def reward_func(prompts, completions):
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    def test_training_reward_func_additional_column(self):
+        # Test if trainer can handle a reward function that relies on additional columns in the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        # Add a column to the dataset (dummy example, the column could be anything)
+        some_values = list(range(len(dataset)))
+        dataset = dataset.add_column("some_values", some_values)
+
+        def reward_func(completions, some_values, **kwargs):
+            """Reward function based on the absolute difference between each completion's length and the corresponding value in some_values."""
+            return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs=reward_func,
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
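
Outside the test harness, the column-as-keyword-argument pattern exercised by this new test can be sanity-checked directly. The snippet below is a standalone sketch with made-up prompts and values, not part of the TRL test suite:

```python
from datasets import Dataset

# Tiny dataset with an extra column, mirroring the test above.
dataset = Dataset.from_dict({"prompt": ["The sky is", "The sun is"]})
dataset = dataset.add_column("some_values", list(range(len(dataset))))

def reward_func(completions, some_values, **kwargs):
    # Absolute difference between each completion's length and its paired value.
    return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]

completions = [" blue.", " in the sky."]
print(reward_func(completions=completions, some_values=dataset["some_values"]))
# [6.0, 11.0]
```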

trl/trainer/grpo_config.py (+12 −0)

@@ -39,6 +39,9 @@ class GRPOConfig(TrainingArguments):

     > Parameters that control the data preprocessing

+    remove_unused_columns (`bool`, *optional*, defaults to `False`):
+        Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
+        requires any column other than `"prompts"` and `"completions"`, you should keep this set to `False`.
     max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
         Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
     num_generations (`int` or `None`, *optional*, defaults to `8`):
@@ -67,6 +70,15 @@ class GRPOConfig(TrainingArguments):
     )

     # Parameters that control the data preprocessing
+    # The default value of remove_unused_columns is overridden from the parent class, because in GRPO we usually
+    # rely on additional columns to compute the reward
+    remove_unused_columns: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
+            "that requires any column other than 'prompts' and 'completions', you should keep this set to `False`."
+        },
+    )
     max_prompt_length: Optional[int] = field(
         default=512,
         metadata={
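
The overridden default matters because the parent `transformers.TrainingArguments` class sets `remove_unused_columns=True`, which would strip any extra column before a custom reward function could see it. A small sketch of the resulting behaviour (the output directory name is illustrative):

```python
from trl import GRPOConfig

# GRPO keeps extra dataset columns (e.g. "ground_truth") by default, so they can
# be forwarded to custom reward functions as keyword arguments.
args = GRPOConfig(output_dir="grpo-demo")
print(args.remove_unused_columns)  # False (overridden from the parent-class default of True)

# Opting back into the parent-class behaviour drops those columns, so a reward
# function that relies on them would no longer receive its extra keyword arguments.
args_strict = GRPOConfig(output_dir="grpo-demo", remove_unused_columns=True)
print(args_strict.remove_unused_columns)  # True
```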

trl/trainer/grpo_trainer.py (+11 −3)

@@ -94,8 +94,9 @@ class GRPOTrainer(Trainer):
               using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
               keyword arguments in `args.model_init_kwargs`.
             - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
-            - A custom reward function: This should take a list of prompts and completions and return a list of
-              rewards. For more details, see [Using a custom reward function](#using-a-custom-reward-function).
+            - A custom reward function: The function is provided with the prompts and the generated completions,
+              plus any additional columns in the dataset. It should return a list of rewards. For more details, see
+              [Using a custom reward function](#using-a-custom-reward-function).
             - A list of reward functions, where each item can independently be any of the above types. Mixing different
               types within the list (e.g., a string model ID and a custom reward function) is allowed.
         args ([`GRPOConfig`], *optional*, defaults to `None`):
@@ -369,7 +370,14 @@ def get_per_token_logps(model, input_ids):
                 with torch.inference_mode():
                     rewards[i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
             else:
-                rewards[i] = torch.tensor(reward_func(prompts, completions))
+                # Repeat all input columns (except "prompt" and "completion") to match the number of generations
+                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+                for key in reward_kwargs:
+                    for example in inputs:
+                        # Repeat each value in the column `num_generations` times
+                        reward_kwargs[key].extend([example[key]] * self.num_generations)
+                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                rewards[i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
         # Sum the rewards from all reward functions
         rewards = rewards.sum(dim=0)

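
The column-repetition logic added here can be illustrated in isolation. The sketch below reproduces the same broadcasting with plain Python lists; the example rows and the `num_generations` value are made up for demonstration:

```python
# Two dataset rows, each with an extra "ground_truth" column, and 3 generations per prompt.
inputs = [
    {"prompt": "2 + 2 =", "ground_truth": "4"},
    {"prompt": "3 + 5 =", "ground_truth": "8"},
]
num_generations = 3

# Mirror the trainer logic: every column except "prompt"/"completion" is repeated
# num_generations times so it lines up with the flattened list of completions.
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
    for example in inputs:
        reward_kwargs[key].extend([example[key]] * num_generations)

print(reward_kwargs["ground_truth"])
# ['4', '4', '4', '8', '8', '8']  -> one value per generated completion
```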
