Commit 631ab63

Dynamically adjust get_per_token_logps function and patch as well (unslothai#2911)
Parent: 6c32cae

File tree: 1 file changed (+56, −2 lines)


unsloth/models/rl_replacements.py

Lines changed: 56 additions & 2 deletions
@@ -291,6 +291,58 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
 pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps)
 
+def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
+    if function_name != "_get_per_token_logps_and_entropies": return function
+
+    # Copied from the _get_per_token_logps replacement above; for now this returns None anyway.
+    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None, compute_entropy = False):
+        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+            return {"logps": None, "entropies": None} # Unsloth efficient GRPO
+        # Otherwise, calculate normally:
+        if not hasattr(self, '_autocast_dtype'):
+            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+
+        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
+        with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+            # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
+            logits = model(
+                input_ids = input_ids,
+                attention_mask = attention_mask,
+                logits_to_keep = logits_to_keep + 1,
+            ).logits
+
+            entropies = None
+            if compute_entropy:
+                from trl.trainer.utils import entropy_from_logits
+                entropies = entropy_from_logits(logits)
+
+            # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next-token prediction
+            return {"logps": logits, "entropies": entropies}
+            # input_ids = input_ids[:, -logits_to_keep:]
+            # For transformers<=4.48, the logits_to_keep argument isn't supported, so here we drop logits ourselves.
+            # See https://github.com/huggingface/trl/issues/2770
+            # logits = logits[:, -logits_to_keep:]
+            # return logits
+            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+            # logits = logits / self.temperature
+            # logps = selective_log_softmax(logits, input_ids)
+
+            # row_indices, col_indices = torch.where(logps < -20)
+
+            # # Method 1: Check if tensors have elements
+            # if len(row_indices) > 0 and len(col_indices) > 0:
+            #     breakpoint() # Breakpoint triggered here
+            #     print("Found high values!")
+            # return logps # compute logprobs for the input tokens
+        pass
+    pass
+
+    function = inspect.getsource(_get_per_token_logps_and_entropies)
+    return function
+pass
+RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies)
+
 grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
 grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"]
 UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
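
For readers unfamiliar with this file's patching scheme: each entry appended to RL_FUNCTIONS["grpo_trainer"] is a hook that receives a method name plus that method's source code and returns either the source unchanged or the source of a replacement, extracted with inspect.getsource; the trainer is later rebuilt from the rewritten source. Below is a minimal, self-contained sketch of that pattern. The PATCHERS registry, patch_greet, apply_patches, and greet are illustrative stand-ins, not part of Unsloth's API.

import inspect

# Registry of source-level patchers; in the real file this is keyed
# per trainer (RL_FUNCTIONS["grpo_trainer"]). Each hook maps
# (function_name, function_source) -> function_source.
PATCHERS = []

def patch_greet(function_name, function):
    # Leave every method alone except the one we want to rewrite.
    if function_name != "greet": return function

    def greet(self, name):
        return f"Hello, {name}!"  # replacement body

    # Hand back the *source text* of the replacement, exactly as the
    # commit does with inspect.getsource. (getsource needs this code
    # to live in a real .py file, not a REPL.)
    return inspect.getsource(greet)

PATCHERS.append(patch_greet)

def apply_patches(function_name, function_source):
    # Thread the source through every registered hook in order.
    for patcher in PATCHERS:
        function_source = patcher(function_name, function_source)
    return function_source

original = "def greet(self, name):\n    return 'hi'\n"
print(apply_patches("greet", original))  # replacement source
print(apply_patches("other", original))  # unchanged pass-through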
@@ -319,14 +371,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     _input_ids = input_ids
     _logits_to_keep = logits_to_keep
 
-    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+    get_logps_func = lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False: self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, batch_size) if hasattr(self, "_get_per_token_logps") else self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy)['logps']
+
+    per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
 
     # Compute the KL divergence between the model and the reference model
     # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
     # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
     if self.beta != 0.0:
         with torch.inference_mode(), model.disable_adapter():
-            ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+            ref_per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
     else:
         ref_per_token_logps = None
     # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
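
The one-line get_logps_func lambda above is dense. Unrolled into an ordinary function, the dispatch it performs reads as follows; this is a readable sketch of the same logic, written against the dict-returning replacement installed in the first hunk, not the committed code.

def get_logps(trainer, model, input_ids, attention_mask, logits_to_keep,
              batch_size = None, compute_entropy = False):
    # Older TRL releases: the trainer still exposes _get_per_token_logps,
    # so call it directly.
    if hasattr(trainer, "_get_per_token_logps"):
        return trainer._get_per_token_logps(
            model, input_ids, attention_mask, logits_to_keep, batch_size,
        )
    # Newer TRL releases: only _get_per_token_logps_and_entropies exists.
    # With the replacement from the first hunk installed, it returns
    # {"logps": ..., "entropies": ...}; only the logps are needed here.
    return trainer._get_per_token_logps_and_entropies(
        model, input_ids, attention_mask, logits_to_keep,
        batch_size, compute_entropy,
    )["logps"]

Keeping it as a lambda in the commit plausibly keeps the patched compute_loss source a single self-contained block, at some cost to readability.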

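Background for the compute_entropy branch in the first hunk: trl.trainer.utils.entropy_from_logits computes the per-token entropy of the softmax distribution over the vocabulary. The identity it relies on can be sketched as a standalone reimplementation for illustration; this is not TRL's exact code.

import torch

def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
    # For p = softmax(logits):
    #   H = -sum(p * log p) = logsumexp(logits) - sum(p * logits)
    # Using logsumexp avoids materializing log-probabilities separately.
    logsumexp = torch.logsumexp(logits, dim=-1)
    p = torch.softmax(logits, dim=-1)
    return logsumexp - (p * logits).sum(dim=-1)

logits = torch.randn(2, 5, 32)                   # (batch, seq_len, vocab)
print(entropy_from_logits_sketch(logits).shape)  # torch.Size([2, 5])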