Commit 15fec31

ingambe and qgallouedec authored
πŸƒ GRPO - Do not load reference model when beta == 0 (#2806)
* 🔧 Optimize GRPO training by conditionally loading reference model based on beta value
* ✅ Add test for GRPOTrainer with beta=0 to ensure no reference model and KL divergence
* 🔧 Refactor GRPOTrainer code for improved readability and maintainability
* 🔧 Simplify per_token_loss calculation in GRPOTrainer for clarity
* fix test, style, and some struct for clarity

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent be1e340 commit 15fec31

File tree

3 files changed: +56 -12 lines changed


tests/test_grpo_trainer.py  (+30)
@@ -500,6 +500,36 @@ def test_training_with_sync_ref_model(self):
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

+    def test_beta_zero_no_ref_model_and_no_kl(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                beta=0.0,  # set beta to 0 to test the case where the reference model is not used
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
     @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
     @require_torch_accelerator
     @require_peft

trl/trainer/grpo_config.py  (+6 -2)
@@ -88,7 +88,8 @@ class GRPOConfig(TrainingArguments):
             Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
             [`~transformers.TrainingArguments`].
         beta (`float`, *optional*, defaults to `0.04`):
-            KL coefficient.
+            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
+            speed.
         reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
             Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
             weighted equally with weight `1.0`.
@@ -218,7 +219,10 @@ class GRPOConfig(TrainingArguments):
     )
     beta: float = field(
         default=0.04,
-        metadata={"help": "KL coefficient."},
+        metadata={
+            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
+            "training speed."
+        },
     )
     reward_weights: Optional[list[float]] = field(
         default=None,
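
In practice, dropping the KL term is just a config switch. A minimal sketch of how the new behavior would be exercised, mirroring the test added above (the tiny checkpoint names come from that test; the output directory is a placeholder):

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-beta0",  # placeholder output directory
    beta=0.0,                 # KL penalty off -> no reference model is loaded
    num_generations=3,
    max_completion_length=32,
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()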

trl/trainer/grpo_trainer.py  (+20 -10)
@@ -244,11 +244,16 @@ def __init__(
                "This argument can only be used when the `model` argument is a string."
            )

+        self.beta = args.beta
+
        if peft_config is not None:
            model = get_peft_model(model, peft_config)

        # Reference model
-        if is_deepspeed_zero3_enabled():
+        if self.beta == 0.0:
+            # If beta is 0.0, the reference model is not needed
+            self.ref_model = None
+        elif is_deepspeed_zero3_enabled():
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif not is_peft_model(model):
            # If PEFT configuration is not provided, create a reference model based on the initial model.
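
Stripped of the PEFT and DeepSpeed branches, the hunk above boils down to one idea: only materialize a frozen copy of the weights when the KL penalty is actually used. An illustrative sketch of that pattern, not the trainer code itself (maybe_load_ref_model and model_id are placeholders):

from transformers import AutoModelForCausalLM


def maybe_load_ref_model(model_id: str, beta: float):
    """Return a frozen reference model, or None when the KL penalty is disabled."""
    if beta == 0.0:
        # With beta == 0 the KL term never enters the loss, so keeping a second
        # copy of the weights would only cost memory and an extra forward pass.
        return None
    ref_model = AutoModelForCausalLM.from_pretrained(model_id)
    ref_model.eval()
    return ref_model
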
@@ -314,8 +319,6 @@ def data_collator(features):  # No data collation is needed in GRPO
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.use_vllm = args.use_vllm

-        self.beta = args.beta
-
        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
@@ -603,7 +606,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        with torch.inference_mode():
-            if self.ref_model is not None:
+            if self.beta == 0.0:
+                ref_per_token_logps = None
+            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
@@ -723,21 +728,26 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
-        ref_per_token_logps = inputs["ref_per_token_logps"]
-        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+        if self.beta != 0.0:
+            ref_per_token_logps = inputs["ref_per_token_logps"]
+            per_token_kl = (
+                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+            )

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
-        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
-        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+        per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        if self.beta != 0.0:
+            per_token_loss = per_token_loss + self.beta * per_token_kl
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

-        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
-        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+        if self.beta != 0.0:
+            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss
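
Collecting the pieces of compute_loss above into one place, the per-token objective with its optional KL penalty looks roughly like this. A standalone sketch, not the trainer method: per_token_logps, ref_per_token_logps and completion_mask are assumed to be shaped (batch, completion_len), and advantages shaped (batch,).

import torch


def grpo_loss(per_token_logps, ref_per_token_logps, advantages, completion_mask, beta):
    # exp(x - x.detach()) equals 1 in value but keeps the gradient of x, giving a
    # REINFORCE-style surrogate weighted by the per-sequence advantages.
    per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
    if beta != 0.0:
        # Per-token KL estimate used above: exp(d) - d - 1 with d = ref_logp - logp,
        # which is non-negative and unbiased for KL(pi_theta || pi_ref).
        d = ref_per_token_logps - per_token_logps
        per_token_loss = per_token_loss + beta * (torch.exp(d) - d - 1)
    # Average over completion tokens only; padding is masked out.
    return (per_token_loss * completion_mask).sum() / completion_mask.sum()

When beta == 0.0 the KL branch never runs, which is why _prepare_inputs can leave ref_per_token_logps as None and the reference model can be skipped entirely.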
