 from unsloth_zoo.utils import Version, _get_dtype
 from unsloth_zoo.hf_utils import dtype_from_config, add_dtype_kwargs, fix_lora_auto_mapping
 from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES
-from unsloth import DEVICE_TYPE, DEVICE_COUNT
+from unsloth import DEVICE_TYPE, DEVICE_COUNT, DEVICE_TYPE_TORCH

 transformers_version = Version(transformers_version)
 # Transformers moved rotary embeddings out of all attention layers
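
Note: the new DEVICE_TYPE_TORCH import appears to be the torch-facing counterpart of DEVICE_TYPE, since PyTorch's ROCm build still exposes AMD GPUs under the "cuda" device namespace. A minimal sketch of the assumed mapping (hypothetical helper, not the actual unsloth_zoo implementation):

# Hypothetical sketch only: map the detected accelerator name to the string
# torch expects in torch.device() / torch.autocast(). Assumes ROCm devices are
# detected as "hip" but addressed as "cuda" by torch.
def _device_type_torch(device_type: str) -> str:
    return "cuda" if device_type == "hip" else device_type

assert _device_type_torch("hip")  == "cuda"   # AMD ROCm
assert _device_type_torch("cuda") == "cuda"   # NVIDIA
assert _device_type_torch("xpu")  == "xpu"    # Intel XPU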
@@ -732,7 +732,7 @@ def LlamaModel_fast_forward(
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length,
             dtype  = torch.int32,
-            device = f"{DEVICE_TYPE}:0",
+            device = f"{DEVICE_TYPE_TORCH}:0",
         )
         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     elif position_ids is not None:
@@ -905,13 +905,13 @@ def LlamaModel_fast_forward(
                 is_causal = True,
                 sliding_window = self.config.sliding_window,
             )\
-                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE,)\
+                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE_TORCH,)\
                 .squeeze(0).squeeze(0)

             self.GA_mask = AttentionMaskConverter(
                 is_causal = True,
             )\
-                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE,)\
+                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE_TORCH,)\
                 .squeeze(0).squeeze(0)
         pass
     pass
@@ -1028,11 +1028,11 @@ def LlamaModel_fast_forward_inference_custom(
     bsz, q_len, hd = X.shape
     assert(q_len == 1)
     # Get saved buffers to reduce memory movement
-    residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
-    _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
+    residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
+    _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
     XX, XX2 = _XX[0], _XX[1]
-    variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
-    temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE}:0")
+    variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
+    temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE_TORCH}:0")
     temp_gates, temp_ups = tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)), tuple(temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT))

     seq_len = past_key_values[0][0].shape[-2]
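
For context, a minimal sketch of the pre-allocation pattern this hunk touches, with made-up tensor sizes (the real buffers live inside LlamaModel_fast_forward_inference_custom): scratch tensors are allocated once on the primary device, and the MLP temporaries are replicated to every visible device so the per-token decode loop never allocates.

import torch

device_type_torch = "cuda"   # assumption: the torch-facing device string
device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0

if device_count:
    bsz, q_len, hd, mlp_size = 1, 1, 4096, 14336   # illustrative sizes only
    residual = torch.empty((bsz, q_len, hd), dtype = torch.float32,
                           device = f"{device_type_torch}:0")
    temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = torch.float16,
                           device = f"{device_type_torch}:0")
    # One (gate, up) scratch pair per device, mirroring the hunk above.
    temp_gates = tuple(temp_mlp[0].to(torch.device(x)) for x in range(device_count))
    temp_ups   = tuple(temp_mlp[1].to(torch.device(x)) for x in range(device_count))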
@@ -1196,10 +1196,14 @@ def _CausalLM_fast_forward(
         else:
             RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
         # < 1024 Normal Unsloth uses less VRAM!
-        if bsz*q_len <= 1024: RETURN_LOGITS = True
+        if DEVICE_TYPE == "hip":
+            # [TODO] AMD GPUs fail on chunked_cross_entropy loss!
+            # RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
+            RETURN_LOGITS = False
+        elif bsz*q_len <= 1024:
+            RETURN_LOGITS = True

         if not RETURN_LOGITS and labels is not None:
-
             n_items = kwargs.get("num_items_in_batch", None)
             if n_items is None: n_items = kwargs.get("n_items", None)

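A standalone sketch of the updated heuristic in this hunk, with assumed names (not the real function): the UNSLOTH_RETURN_LOGITS environment variable is read first, but after this change the HIP branch overrides it and forces RETURN_LOGITS off on AMD, while small batches still return full logits.

import os

def _should_return_logits(device_type: str, bsz: int, q_len: int) -> bool:
    # Hypothetical standalone version of the control flow shown above.
    return_logits = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
    if device_type == "hip":
        # Per the TODO in the hunk: chunked_cross_entropy currently raises a
        # Triton HIP error on AMD GPUs, so RETURN_LOGITS is disabled there.
        return False
    elif bsz * q_len <= 1024:
        # < 1024 tokens: returning logits uses less VRAM.
        return True
    return return_logits

print(_should_return_logits("hip",  2, 512))   # False on AMD
print(_should_return_logits("cuda", 2, 512))   # True: small batch
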
@@ -1374,7 +1378,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
             partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
             dim = getattr(config, "head_dim", None)
             if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
-            device = DEVICE_TYPE
+            device = DEVICE_TYPE_TORCH
             max_position_embeddings = config.max_position_embeddings
         pass

@@ -1486,7 +1490,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
             base = config.rope_theta
             partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
             dim = int((config.hidden_size // config.num_attention_heads))
-            device = DEVICE_TYPE
+            device = DEVICE_TYPE_TORCH
             max_position_embeddings = config.max_position_embeddings
         pass

@@ -1606,7 +1610,7 @@ def __init__(self,
             base = config.rope_theta
             partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
             dim = int((config.hidden_size // config.num_attention_heads))
-            device = DEVICE_TYPE
+            device = DEVICE_TYPE_TORCH
             max_position_embeddings = config.max_position_embeddings
         pass

@@ -1760,7 +1764,7 @@ def unsloth_fast_generate(
     kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)

     # Mixed precision autocast
-    with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE, dtype = dtype):
+    with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype):
         output = self._old_generate(*args, **kwargs)
     pass

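A minimal sketch of the autocast pattern in this hunk, with assumed values for the device string and dtype. torch.autocast only accepts device types torch itself recognises ("cuda", "xpu", "cpu", ...), which is presumably why the torch-facing DEVICE_TYPE_TORCH is used here rather than the raw DEVICE_TYPE:

import torch

device_type_torch = "cuda" if torch.cuda.is_available() else "cpu"   # assumption
dtype = torch.float16 if device_type_torch == "cuda" else torch.bfloat16

model = torch.nn.Linear(8, 8).to(device_type_torch)
x = torch.randn(1, 8, device = device_type_torch)

# Inference-only forward pass under mixed precision, as in the hunk above.
with torch.inference_mode(), torch.autocast(device_type = device_type_torch, dtype = dtype):
    y = model(x)
print(y.dtype)   # torch.float16 on CUDA, torch.bfloat16 on CPU
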
@@ -2384,7 +2388,7 @@ def get_peft_model(
         pass

         model.get_input_embeddings().modules_to_save.default\
-            .to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
+            .to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
         model.get_input_embeddings().modules_to_save.default.requires_grad_(True)

         # [TODO] Move old embed_tokens to CPU - should be disk!
@@ -2404,7 +2408,7 @@ def get_peft_model(
         pass

         model.get_output_embeddings().modules_to_save.default\
-            .to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
+            .to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
         model.get_output_embeddings().modules_to_save.default.requires_grad_(True)

         # [TODO] Move old lm_head to CPU - should be disk!
@@ -2673,7 +2677,7 @@ def get_peft_model(
         pass

         model.get_input_embeddings().modules_to_save.default\
-            .to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
+            .to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
         model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
         pass

@@ -2689,7 +2693,7 @@ def get_peft_model(
         pass

         model.get_output_embeddings().modules_to_save.default\
-            .to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
+            .to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
         model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
         pass

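The four hunks above share one pattern: when embed_tokens / lm_head are kept in PEFT's modules_to_save, the trainable copy is moved onto the accelerator in the chosen dtype and marked trainable. A minimal, hypothetical restatement of that pattern (not the actual get_peft_model code):

import torch

def _move_trainable_copy(module: torch.nn.Module,
                         device_type_torch: str = "cuda",      # assumed device string
                         new_dtype: torch.dtype = torch.float32) -> None:
    # non_blocking=True lets the host-to-device copy overlap with other work
    # when the source tensor is in pinned memory; otherwise it is a no-op.
    module.to(device = device_type_torch, dtype = new_dtype, non_blocking = True)
    module.requires_grad_(True)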