Skip to content

Commit c2dbc03

Browse files
committed
Fixes
1 parent f61c450 commit c2dbc03

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

unsloth/import_fixes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def patch_ipykernel_hf_xet():
150150
Version(importlib_version("ipykernel")) == Version("7.0.0")
151151
):
152152
print(
153-
"#### Unsloth: `hf_xet==1.1.10` and `ipykernel>6.30.1` breaks progress bars. Disabling for now in XET.\n"\
154-
"#### Unsloth: To re-enable progress bars, please downgrade to `ipykernel==6.30.1` or wait for a fix to\n"\
153+
"#### Unsloth: `hf_xet==1.1.10` and `ipykernel==7.0.0` breaks progress bars. Disabling for now in XET.\n"\
154+
"#### Unsloth: To re-enable progress bars, please upgrade to `ipykernel>7.0.0` or wait for a fix to\n"\
155155
"https://github.com/huggingface/xet-core/issues/526"
156156
)
157157
from huggingface_hub.utils import disable_progress_bars

unsloth/models/llama.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,12 +1203,8 @@ def _CausalLM_fast_forward(
12031203
else:
12041204
RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
12051205
# < 1024 Normal Unsloth uses less VRAM!
1206-
if DEVICE_TYPE == "hip":
1207-
# [TODO] AMD GPUs fail on chunked_cross_entropy loss!
1208-
# RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
1209-
RETURN_LOGITS = False
1210-
elif bsz*q_len <= 1024:
1211-
# Uses 800MB more VRAM it seems than fused CE Loss
1206+
if bsz * q_len <= 1024 and not RETURN_LOGITS:
1207+
# Use unsloth_fused_ce_loss which actually calculates the best chunk size to reduce VRAM usage
12121208
RETURN_LOGITS = False
12131209

12141210
if not RETURN_LOGITS and labels is not None:

unsloth/models/mistral.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,9 @@ def MistralForCausalLM_fast_forward(
298298
else:
299299
RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
300300
# < 1024 Normal Unsloth uses less VRAM!
301-
if bsz * q_len <= 1024: RETURN_LOGITS = True
301+
if bsz * q_len <= 1024 and not RETURN_LOGITS:
302+
# Use unsloth_fused_ce_loss which actually calculates the best chunk size to reduce VRAM usage
303+
RETURN_LOGITS = False
302304

303305
if not RETURN_LOGITS and labels is not None:
304306
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)

0 commit comments

Comments
 (0)