File tree: 3 files changed +7 -9
lines changed Original file line number Diff line number Diff line change @@ -150,8 +150,8 @@ def patch_ipykernel_hf_xet():
150150 Version (importlib_version ("ipykernel" )) == Version ("7.0.0" )
151151 ):
152152 print (
153- "#### Unsloth: `hf_xet==1.1.10` and `ipykernel>6.30.1 ` breaks progress bars. Disabling for now in XET.\n " \
154- "#### Unsloth: To re-enable progress bars, please downgrade to `ipykernel==6.30.1 ` or wait for a fix to\n " \
153+ "#### Unsloth: `hf_xet==1.1.10` and `ipykernel==7.0.0 ` breaks progress bars. Disabling for now in XET.\n " \
154+ "#### Unsloth: To re-enable progress bars, please upgrade to `ipykernel>7.0.0 ` or wait for a fix to\n " \
155155 "https://github.com/huggingface/xet-core/issues/526"
156156 )
157157 from huggingface_hub .utils import disable_progress_bars
Original file line number Diff line number Diff line change @@ -1203,12 +1203,8 @@ def _CausalLM_fast_forward(
12031203 else :
12041204 RETURN_LOGITS = os .environ .get ("UNSLOTH_RETURN_LOGITS" , "0" ) == "1"
12051205 # < 1024 Normal Unsloth uses less VRAM!
1206- if DEVICE_TYPE == "hip" :
1207- # [TODO] AMD GPUs fail on chunked_cross_entropy loss!
1208- # RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
1209- RETURN_LOGITS = False
1210- elif bsz * q_len <= 1024 :
1211- # Uses 800MB more VRAM it seems than fused CE Loss
1206+ if bsz * q_len <= 1024 and not RETURN_LOGITS :
1207+ # Use unsloth_fused_ce_loss which actually calculates the best chunk size to reduce VRAM usage
12121208 RETURN_LOGITS = False
12131209
12141210 if not RETURN_LOGITS and labels is not None :
Original file line number Diff line number Diff line change @@ -298,7 +298,9 @@ def MistralForCausalLM_fast_forward(
298298 else :
299299 RETURN_LOGITS = os .environ .get ("UNSLOTH_RETURN_LOGITS" , "0" ) == "1"
300300 # < 1024 Normal Unsloth uses less VRAM!
301- if bsz * q_len <= 1024 : RETURN_LOGITS = True
301+ if bsz * q_len <= 1024 and not RETURN_LOGITS :
302+ # Use unsloth_fused_ce_loss which actually calculates the best chunk size to reduce VRAM usage
303+ RETURN_LOGITS = False
302304
303305 if not RETURN_LOGITS and labels is not None :
304306 n_items = kwargs .get ("num_items_in_batch" , None ) or kwargs .get ("n_items" , None )
You can’t perform that action at this time.
0 commit comments