Skip to content

Commit a2dec14

Browse files
authored
Detach logits before returning from function (#3554)
1 parent 9cfdcac commit a2dec14

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

unsloth/models/rl_replacements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
510510

511511
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
512512
# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
513-
return logits, entropies # logps, entropies
513+
return logits.detach(), entropies # logps, entropies
514514
# input_ids = input_ids[:, -logits_to_keep:]
515515
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
516516
# See https://github.com/huggingface/trl/issues/2770

0 commit comments

Comments
 (0)