@@ -244,11 +244,16 @@ def __init__(
                     "This argument can only be used when the `model` argument is a string."
                 )
 
+        self.beta = args.beta
+
         if peft_config is not None:
             model = get_peft_model(model, peft_config)
 
         # Reference model
-        if is_deepspeed_zero3_enabled():
+        if self.beta == 0.0:
+            # If beta is 0.0, the reference model is not needed
+            self.ref_model = None
+        elif is_deepspeed_zero3_enabled():
             self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
         elif not is_peft_model(model):
             # If PEFT configuration is not provided, create a reference model based on the initial model.
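For context, a minimal usage sketch (assuming a trl version that includes this change; `output_dir` is just a placeholder path): `beta` is the KL coefficient from `GRPOConfig`, and setting it to `0.0` now skips instantiating the reference model entirely, saving a full model copy in memory.

```python
from trl import GRPOConfig

# Hypothetical usage: with beta=0.0 the KL penalty is disabled and
# GRPOTrainer sets self.ref_model = None instead of loading a second model.
training_args = GRPOConfig(output_dir="grpo-output", beta=0.0)
```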
@@ -314,8 +319,6 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.use_vllm = args.use_vllm
 
-        self.beta = args.beta
-
         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
         # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
         # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
@@ -603,7 +606,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
 
         with torch.inference_mode():
-            if self.ref_model is not None:
+            if self.beta == 0.0:
+                ref_per_token_logps = None
+            elif self.beta == 0.0 else if self.ref_model is not None:
                 ref_per_token_logps = self._get_per_token_logps(
                     self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                 )
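A toy illustration (standalone tensors, not the trainer's real inputs) of why the reference pass sits under `torch.inference_mode()`: values produced there carry no autograd history, so the reference log-probs act as constants in the KL term, and with `beta == 0.0` the forward pass is skipped altogether.

```python
import torch

# Values computed under inference_mode are detached from autograd, so
# ref_per_token_logps contributes no gradient of its own.
logits = torch.randn(4, requires_grad=True)
with torch.inference_mode():
    ref_logps = torch.log_softmax(logits * 1.0, dim=-1)
print(ref_logps.requires_grad)  # False
```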
@@ -723,21 +728,26 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
 
         # Compute the KL divergence between the model and the reference model
-        ref_per_token_logps = inputs["ref_per_token_logps"]
-        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+        if self.beta != 0.0:
+            ref_per_token_logps = inputs["ref_per_token_logps"]
+            per_token_kl = (
+                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+            )
 
         # x - x.detach() allows for preserving gradients from x
         advantages = inputs["advantages"]
-        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
-        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+        per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        if self.beta != 0.0:
+            per_token_loss = per_token_loss + self.beta * per_token_kl
         loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
 
         # Log the metrics
         completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
         self._metrics["completion_length"].append(completion_length)
 
-        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
-        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+        if self.beta != 0.0:
+            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
 
         return loss
 
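Two identities in `compute_loss` are worth spelling out with a toy sketch (standalone tensors, not the trainer's real inputs). First, the KL term uses the `exp(d) - d - 1` estimator with `d = ref_logp - logp`, which is non-negative per token and averages to KL(pi || pi_ref) under samples from the policy. Second, `exp(logp - logp.detach())` evaluates to exactly 1 in the forward pass but carries derivative 1 w.r.t. `logp`, so multiplying by the (constant) advantage recovers the policy-gradient estimator.

```python
import torch

# Toy illustration of the two identities used in compute_loss above.
logp = torch.tensor([-1.2, -0.5, -2.0], requires_grad=True)   # policy log-probs
ref_logp = torch.tensor([-1.0, -0.7, -1.5])                   # reference log-probs

# (1) k3 KL estimator: exp(d) - d - 1 with d = ref_logp - logp is always >= 0
#     and averages to KL(pi || pi_ref) under samples from the policy.
d = ref_logp - logp
per_token_kl = torch.exp(d) - d - 1
assert bool((per_token_kl >= 0).all())

# (2) exp(logp - logp.detach()) is exactly 1 in the forward pass, but its
#     derivative w.r.t. logp is 1, so -(ratio * A) backpropagates -A into
#     logp, i.e. REINFORCE with the advantage held constant.
advantages = torch.tensor([2.0, -1.0, 0.5])
ratio = torch.exp(logp - logp.detach())
(-(ratio * advantages).sum()).backward()
print(logp.grad)  # tensor([-2.0000,  1.0000, -0.5000])
```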