
Commit 0e216f7

🍭 Custom reward function for RLOO (#2612)
* rloo custom reward function and test
* I don't even know why I did that
* removing get_reward_custom
* remove get_reward_custom test
* fix code quality check
* adding test
* end this misery already
* fix test
1 parent 59c2014 commit 0e216f7
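
With this change, RLOOTrainer accepts a plain Python callable in place of a reward model. The sketch below shows the intended usage pattern, mirroring the test added in this commit; the model name, tokenization, and config values are illustrative placeholders, not part of the commit itself.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import RLOOConfig, RLOOTrainer

model_id = "EleutherAI/pythia-14m"  # illustrative small model, not prescribed by the commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
policy = AutoModelForCausalLM.from_pretrained(model_id)
ref_policy = AutoModelForCausalLM.from_pretrained(model_id)

def reward_function(texts: list[str]) -> list[float]:
    # Any callable mapping decoded query+response strings to per-sequence floats works.
    return [float(len(text)) for text in texts]

prompt_ids = tokenizer("Hello World!")["input_ids"]
train_dataset = Dataset.from_dict({"input_ids": [prompt_ids, prompt_ids]})

training_args = RLOOConfig(
    output_dir="rloo-custom-reward",
    per_device_train_batch_size=2,
    total_episodes=1,
    max_steps=2,
    report_to="none",
)
trainer = RLOOTrainer(
    config=training_args,
    policy=policy,
    ref_policy=ref_policy,
    reward_model=reward_function,  # a callable instead of an nn.Module
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()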

File tree

2 files changed: +76 -13 lines changed


tests/test_rloo_trainer.py

+39
@@ -172,3 +172,42 @@ def test_rloo_training(self):
 
         # Check if objective/rlhf_reward is available
         self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1])
+
+    def test_rloo_training_with_custom_reward(self):
+        # dummy reward function
+        def reward_function(texts):
+            # based on length of text
+            rewards = [len(text) for text in texts]
+            return rewards
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = RLOOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                per_device_eval_batch_size=2,
+                total_episodes=1,
+                num_train_epochs=1,
+                max_steps=2,
+                report_to="none",
+            )
+
+            # Create a simple dataset
+            dummy_text = [{"content": "Hello World!", "role": "user"}]
+            dummy_data = self.tokenizer.apply_chat_template(dummy_text)
+            dummy_dataset = Dataset.from_dict({"input_ids": [dummy_data, dummy_data]})
+
+            trainer = RLOOTrainer(
+                config=training_args,
+                policy=self.policy_model,
+                reward_model=reward_function,
+                ref_policy=self.policy_ref_model,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset,
+                eval_dataset=dummy_dataset,
+            )
+
+            # Test that training completes without errors
+            trainer.train()
+
+            # Check if objective/rlhf_reward is available
+            self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1])
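
The dummy length-based reward above is the simplest callable satisfying the new contract: the function receives the decoded query+response strings and must return one float per string (the trainer changes below wrap the result in a float tensor). Two further illustrative examples, not part of the commit:

def keyword_reward(texts: list[str]) -> list[float]:
    # Hypothetical criterion: reward responses that mention a target keyword.
    return [1.0 if "Hello" in text else 0.0 for text in texts]

def brevity_reward(texts: list[str], max_chars: int = 200) -> list[float]:
    # Hypothetical criterion: full reward up to max_chars, decaying inversely with length after that.
    return [min(1.0, max_chars / max(len(text), 1)) for text in texts]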

trl/trainer/rloo_trainer.py

+37 -13
@@ -18,7 +18,7 @@
 import textwrap
 import time
 from collections import defaultdict
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -79,7 +79,7 @@ def __init__(
         ],
         policy: nn.Module,
         ref_policy: nn.Module,
-        reward_model: nn.Module,
+        reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
         train_dataset: Dataset,
         data_collator: Optional[DataCollatorWithPadding] = None,
         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
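
The widened annotation accepts either a torch module or a plain function. A small sketch of the two sides of that Union; TinyRewardModel is an illustration only, not TRL's actual reward-model type:

from typing import Callable, Union

import torch
import torch.nn as nn

RewardModelOrFn = Union[nn.Module, Callable[[list[str]], list[float]]]

class TinyRewardModel(nn.Module):
    # Stand-in for a learned reward model that scores hidden states.
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.score(hidden_states)

def text_length_reward(texts: list[str]) -> list[float]:
    return [float(len(t)) for t in texts]

learned_reward: RewardModelOrFn = TinyRewardModel()
functional_reward: RewardModelOrFn = text_length_reward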
@@ -152,7 +152,8 @@ def __init__(
         # setup model, optimizer, and others
         #########
         for module in [policy, ref_policy, reward_model]:
-            disable_dropout_in_model(module)
+            if isinstance(module, nn.Module):
+                disable_dropout_in_model(module)
         if args.stop_token and args.stop_token == "eos":
             args.stop_token_id = self.processing_class.eos_token_id
         self.model = policy
@@ -219,16 +220,18 @@ def __init__(
         self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
 
         if self.is_deepspeed_enabled:
-            self.reward_model = prepare_deepspeed(
-                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
-            )
+            if isinstance(self.reward_model, nn.Module):
+                self.reward_model = prepare_deepspeed(
+                    self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
+                )
             self.ref_policy = prepare_deepspeed(
                 self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
             )
             self.deepspeed = self.model
         else:
             self.ref_policy = self.ref_policy.to(self.accelerator.device)
-            self.reward_model = self.reward_model.to(self.accelerator.device)
+            if isinstance(self.reward_model, nn.Module):
+                self.reward_model = self.reward_model.to(self.accelerator.device)
 
     def get_train_dataloader(self) -> DataLoader:
         return self.dataloader
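
The dropout, DeepSpeed, and device-placement hunks above all add the same guard: torch-specific setup only runs when the reward is an nn.Module, so a plain callable passes through untouched. A condensed sketch of that pattern; prepare_reward is an illustrative helper, not a TRL API:

import torch
import torch.nn as nn

def prepare_reward(reward_model, device: torch.device):
    # Module-only setup (dropout disabling, DeepSpeed wrapping, device placement)
    # is skipped entirely for a plain Python callable.
    if isinstance(reward_model, nn.Module):
        reward_model.eval()  # stand-in for disable_dropout_in_model / prepare_deepspeed
        reward_model = reward_model.to(device)
    return reward_model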
@@ -350,9 +353,18 @@ def repeat_generator():
                     # Response Processing 2. run reward model on the truncated responses
                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                     sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
-                    _, score, _ = get_reward(
-                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                    )
+
+                    if isinstance(reward_model, nn.Module):
+                        _, score, _ = get_reward(
+                            reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+                        )
+                    else:
+                        score = torch.tensor(
+                            reward_model(
+                                processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
+                            ),
+                            dtype=torch.float,
+                        ).to(device)
 
                     # Store batch results
                     responses.append(response)
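
When the reward is a callable, the rollout loop above decodes the full query+response token ids, scores the resulting strings, and converts the returned floats to a float tensor on the rollout device. A self-contained sketch of that path outside the trainer; the gpt2 tokenizer stands in for processing_class and is purely illustrative:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for processing_class
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cpu")

def reward_function(texts: list[str]) -> list[float]:
    return [float(len(text)) for text in texts]

# Stand-in for postprocessed_query_response: a padded batch of token ids.
batch = tokenizer(["Hello World!", "A longer padded example."], return_tensors="pt", padding=True)
postprocessed_query_response = batch["input_ids"]

# The same three steps as the callable branch in the hunk above.
texts = tokenizer.batch_decode(postprocessed_query_response, skip_special_tokens=True)
score = torch.tensor(reward_function(texts), dtype=torch.float).to(device)
print(score)  # one scalar reward per sequence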
@@ -595,9 +607,21 @@ def generate_completions(self, sampling: bool = False):
                     )
 
                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                    _, score, _ = get_reward(
-                        self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                    )
+
+                    if isinstance(self.reward_model, nn.Module):
+                        _, score, _ = get_reward(
+                            self.reward_model,
+                            postprocessed_query_response,
+                            processing_class.pad_token_id,
+                            context_length,
+                        )
+                    else:
+                        score = torch.tensor(
+                            self.reward_model(
+                                processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
+                            ),
+                            dtype=torch.float,
+                        ).to(postprocessed_query_response.device)
                     table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
 
                 if sampling:
