@@ -193,6 +193,10 @@ def LlamaAttention_fast_forward_inference(
 
     # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
     # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
+
+    # Need to extend the RoPE cache 2 steps before the short KV cache fills up,
+    # or else we get an error
+    self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
     cos, sin = self.rotary_emb.get_cached(kv_seq_len)
     cos = cos[position_ids].unsqueeze(1)
     sin = sin[position_ids].unsqueeze(1)
@@ -1122,7 +1126,7 @@ def get_cached(self, seq_len = None):
     def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
-        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
     pass
 pass
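The replacement expression is a pure-integer round-up to the next multiple of 8192, equivalent to the previous `math.ceil(seq_len / 8192) * 8192` for positive lengths but without going through float division. A quick sanity sketch (helper names below are mine, not from the patch):

```python
import math

def round_up_old(seq_len):
    return math.ceil(seq_len / 8192) * 8192

def round_up_new(seq_len):
    # Integer-only ceiling division to the next multiple of 8192
    return ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192

for n in (1, 8191, 8192, 8193, 100_000):
    assert round_up_old(n) == round_up_new(n)
print(round_up_new(8193))  # 16384
```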
@@ -1248,7 +1252,7 @@ def get_cached(self, seq_len = None):
     def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
-        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
     pass
 pass
@@ -1363,7 +1367,7 @@ def get_cached(self, seq_len = None):
     def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
-        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
+        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
     pass
 pass
@@ -1952,10 +1956,10 @@ def get_peft_model(
         # Offload!
         # [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)
         if "embed_tokens" in new_target_modules:
-            print("Unsloth: Casting embed_tokens to float32")
+            print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
 
             model.model.model.embed_tokens.modules_to_save.default\
-                .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+                .to(device = "cuda:0", non_blocking = True)
             model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
 
             # [TODO] Move old embed_tokens to CPU - should be disk!
@@ -1965,10 +1969,10 @@ def get_peft_model(
         pass
 
         if "lm_head" in new_target_modules:
-            print("Unsloth: Casting lm_head to float32")
+            print("Unsloth: Training lm_head in mixed precision to save VRAM")
 
             model.model.lm_head.modules_to_save.default\
-                .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+                .to(device = "cuda:0", non_blocking = True)
             model.model.lm_head.modules_to_save.default.requires_grad_(True)
 
             # [TODO] Move old lm_head to CPU - should be disk!
@@ -2203,18 +2207,18 @@ def get_peft_model(
 
     # Now patch lm_head and embed_tokens
     if train_embed_tokens:
-        print("Unsloth: Casting embed_tokens to float32")
+        print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
         assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
         model.model.model.embed_tokens.modules_to_save.default\
-            .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+            .to(device = "cuda:0", non_blocking = True)
         model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
     pass
 
     if train_lm_head:
-        print("Unsloth: Casting lm_head to float32")
+        print("Unsloth: Training lm_head in mixed precision to save VRAM")
         assert(hasattr(model.model.lm_head, "modules_to_save"))
         model.model.lm_head.modules_to_save.default\
-            .to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
+            .to(device = "cuda:0", non_blocking = True)
         model.model.lm_head.modules_to_save.default.requires_grad_(True)
     pass
 
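Dropping the `dtype = torch.float32` argument means the trainable `embed_tokens` / `lm_head` copies stay in whatever dtype the model was loaded in (e.g. bfloat16), roughly halving the weight memory of those matrices. A back-of-envelope estimate, using assumed Llama-3-8B-like shapes (the vocab and hidden sizes below are illustrative, not taken from the patch):

```python
# Rough VRAM estimate for one trainable embedding matrix (weights only,
# excluding gradients / optimizer state). Shapes are assumed for illustration.
vocab_size, hidden_size = 128256, 4096   # Llama-3-8B-like shapes (assumed)
params = vocab_size * hidden_size

fp32_gib = params * 4 / 2**30
bf16_gib = params * 2 / 2**30
print(f"float32: {fp32_gib:.2f} GiB, bfloat16: {bf16_gib:.2f} GiB per matrix")
# float32: 1.96 GiB, bfloat16: 0.98 GiB per matrix
```

When both `embed_tokens` and `lm_head` are trained, the saving applies to each matrix separately.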