@@ -291,6 +291,58 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
 pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps)
 
+def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
+    if function_name != "_get_per_token_logps_and_entropies": return function
+
+    # Copied from the _get_per_token_logps replacement function above; for now this returns None anyway.
+    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False):
+        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+            return {"logps": None, "entropies": None} # Unsloth efficient GRPO
+        # Otherwise, calculate normally:
+        if not hasattr(self, '_autocast_dtype'):
+            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+
+        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
+        with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+            # Add 1 to `logits_to_keep` because the last logit of the sequence is excluded later
+            logits = model(
+                input_ids      = input_ids,
+                attention_mask = attention_mask,
+                logits_to_keep = logits_to_keep + 1,
+            ).logits
+
+            entropies = None
+            if compute_entropy:
+                # See the entropy sketch after this hunk for what this computes.
+                from trl.trainer.utils import entropy_from_logits
+                entropies = entropy_from_logits(logits)
+
+            # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next-token prediction
+            return {"logps": logits, "entropies": entropies}
+            # input_ids = input_ids[:, -logits_to_keep:]
+            # For transformers<=4.48, the logits_to_keep argument isn't supported, so we drop logits ourselves.
+            # See https://github.com/huggingface/trl/issues/2770
+            # logits = logits[:, -logits_to_keep:]
+            # return logits
+            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+            # logits = logits / self.temperature
+            # logps = selective_log_softmax(logits, input_ids)
+
+            # row_indices, col_indices = torch.where(logps < -20)
+
+            # # Method 1: check if the tensors have elements
+            # if len(row_indices) > 0 and len(col_indices) > 0:
+            #     breakpoint() # Breakpoint triggered here
+            #     print("Found high values!")
+            # return logps # compute logprobs for the input tokens
+        pass
+    pass
+
+    function = inspect.getsource(_get_per_token_logps_and_entropies)
+    return function
+pass
+RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies)
+
 grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
 grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"]
 UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
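
For reference, the `entropy_from_logits` helper imported inside the hunk above computes the per-position entropy of the token distribution. A minimal, functionally equivalent sketch (illustrative only, not necessarily TRL's exact implementation):

import torch

def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Per-token entropy over the vocabulary axis:
    # H = logsumexp(logits) - sum(softmax(logits) * logits)
    #   = -sum(p * log p), computed without materializing log p separately.
    logsumexp = torch.logsumexp(logits, dim = -1)
    probs = torch.softmax(logits, dim = -1)
    return logsumexp - (probs * logits).sum(dim = -1)

# A (B, L, V) logits tensor yields a (B, L) tensor of entropies.
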
@@ -319,14 +371,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         _input_ids = input_ids
         _logits_to_keep = logits_to_keep
 
-        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+        get_logps_func = lambda model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False: self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, batch_size) if hasattr(self, "_get_per_token_logps") else self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)["logps"]
+
+        per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
 
         # Compute the KL divergence between the model and the reference model
         # _prepare_inputs doesn't return reference log probs anymore. We need to calculate them ourselves.
         # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
         if self.beta != 0.0:
             with torch.inference_mode(), model.disable_adapter():
-                ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+                # Assign only ref_per_token_logps here; chaining `per_token_logps =` onto this
+                # adapter-disabled pass would clobber the policy log-probs computed above.
+                ref_per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
         else:
             ref_per_token_logps = None
         # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
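
The commented-out `per_token_kl` line above is the k3 KL estimator from http://joschu.net/blog/kl-approx.html: with d = ref_per_token_logps - per_token_logps, exp(d) - d - 1 is a non-negative, unbiased per-token estimate of KL(policy || ref). A small sketch of that formula:

import torch

def per_token_kl_k3(ref_per_token_logps: torch.Tensor, per_token_logps: torch.Tensor) -> torch.Tensor:
    # k3 estimator: exp(d) - d - 1 >= 0, with d = log(ref / policy) per token.
    d = ref_per_token_logps - per_token_logps
    return torch.exp(d) - d - 1.0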
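
For readability, the single-line `get_logps_func` lambda above can be unpacked as follows. This is an illustrative equivalent only; `_call_get_logps` is not a name from the patch. The hasattr check appears to bridge TRL versions: older releases expose `_get_per_token_logps`, while newer ones fold it into `_get_per_token_logps_and_entropies`, which returns a dict.

def _call_get_logps(trainer, model, input_ids, attention_mask, logits_to_keep,
                    batch_size = None, compute_entropy = False):
    if hasattr(trainer, "_get_per_token_logps"):
        # Older TRL: the method returns the per-token log-probs directly.
        return trainer._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, batch_size)
    # Newer TRL: the combined method returns {"logps": ..., "entropies": ...}.
    return trainer._get_per_token_logps_and_entropies(
        model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy,
    )["logps"]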