
Commit 422c033

Update granite to work with latest post_patch methods (#1502)
* Update granite to work with latest post_patch methods
* Pass position_embeddings for granite even if transformers<4.47
* Update llama.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
1 parent e3a92e0 commit 422c033

File tree

1 file changed

unsloth/models/granite.py

Lines changed: 23 additions & 9 deletions
@@ -20,7 +20,8 @@
     LlamaLinearScalingRotaryEmbedding,
 )
 from .mistral import *
-
+from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
+from peft.tuners.lora import Linear4bit as Peft_Linear4bit
 try:
     from transformers.models.granite.modeling_granite import (
         GraniteAttention,
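
The two new imports alias the 4-bit linear layer from bitsandbytes and its LoRA-wrapped counterpart from PEFT under distinct names, so patching code can tell the two apart. A minimal sketch of how such aliases are commonly used, assuming only that bitsandbytes and peft are installed; count_quantized_layers is a made-up helper, not part of this commit:

# Illustrative only: the distinct aliases let inspection code distinguish a raw
# bitsandbytes 4-bit layer from one already wrapped by a PEFT LoRA adapter.
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit

def count_quantized_layers(model):
    raw, lora_wrapped = 0, 0
    for _, module in model.named_modules():
        if isinstance(module, Peft_Linear4bit):
            lora_wrapped += 1  # LoRA adapter sitting on top of a 4-bit base layer
        elif isinstance(module, Bnb_Linear4bit):
            raw += 1           # plain 4-bit quantized linear layer
    return raw, lora_wrapped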
@@ -423,6 +424,18 @@ class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
     def __init__(self, config):
         super().__init__(config = config)
 
+def patched_init(original_init):
+    def new_init(self, *args, **kwargs):
+        # we can use self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here
+        # https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/granite/modeling_granite.py#L243
+        # The problem is, we don't have access to either the value or config in GraniteModel_fast_forward_inference
+        # So we need a way to pass this value around. It is probably better to pass on entire config just in case we need it later
+        config = kwargs.get("config", args[0] if args else None)
+        if config is not None:
+            self.config = config
+        original_init(self, *args, **kwargs)
+    return new_init
+
 class FastGraniteModel(FastLlamaModel):
 
     @staticmethod
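
The patched_init helper added here wraps an existing __init__ so that the config passed at construction time is stashed on the instance before the original constructor runs; the fast inference path can then read values such as residual_multiplier off self.config. A self-contained sketch of the same wrapping pattern, using made-up ToyConfig/ToyModel classes rather than the real transformers ones:

# Toy demonstration of the init-wrapping pattern; ToyConfig and ToyModel are
# illustrative stand-ins, not unsloth or transformers classes.
def patched_init(original_init):
    def new_init(self, *args, **kwargs):
        config = kwargs.get("config", args[0] if args else None)
        if config is not None:
            self.config = config  # stash the config for code that only sees the instance
        original_init(self, *args, **kwargs)
    return new_init

class ToyConfig:
    residual_multiplier = 0.22

class ToyModel:
    def __init__(self, config):
        self.layers = []

ToyModel.__init__ = patched_init(ToyModel.__init__)
model = ToyModel(ToyConfig())
print(model.config.residual_multiplier)  # 0.22 -- the config is now reachable on the instance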
@@ -437,12 +450,13 @@ def pre_patch():
         exec(function, globals())
         GraniteAttention.__init__ = eval(init_name)
         pass
-        GraniteAttention      .forward = GraniteAttention_fast_forward
-        GraniteSdpaAttention  .forward = GraniteAttention_fast_forward
-        GraniteFlashAttention2.forward = GraniteAttention_fast_forward
-        GraniteDecoderLayer   .forward = GraniteDecoderLayer_fast_forward
-        GraniteModel          .forward = LlamaModel_fast_forward
-        GraniteForCausalLM    .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
+        GraniteAttention      .forward = GraniteAttention_fast_forward
+        GraniteSdpaAttention  .forward = GraniteAttention_fast_forward
+        GraniteFlashAttention2.forward = GraniteAttention_fast_forward
+        GraniteDecoderLayer   .forward = GraniteDecoderLayer_fast_forward
+        GraniteModel          .forward = LlamaModel_fast_forward
+        GraniteForCausalLM    .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
+        GraniteForCausalLM    .__init__ = patched_init(GraniteForCausalLM.__init__)
         PeftModelForCausalLM  .forward = PeftModelForCausalLM_fast_forward
         fix_prepare_inputs_for_generation(GraniteForCausalLM)
 
@@ -454,7 +468,7 @@ def pre_patch():
 
 
     @staticmethod
-    def post_patch(model):
+    def post_patch(model, tokenizer):
 
         # Torch.compile fails on embedding matrix??
         # Workaround randomnly fixes it for torch versions < 2.2
@@ -519,7 +533,7 @@ def post_patch(model):
         for _ in range(3):
             gc.collect()
             torch.cuda.empty_cache()
-        return model
+        return model, tokenizer
     pass
 pass
 
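With this change post_patch mirrors the updated FastLlamaModel contract: it accepts the tokenizer alongside the model and returns both. A hedged call-site sketch, assuming a model and tokenizer have already been loaded (the finish_patching wrapper is hypothetical):

# Hypothetical call site; loading of model/tokenizer is assumed and not shown.
from unsloth.models.granite import FastGraniteModel

def finish_patching(model, tokenizer):
    # Previously: model = FastGraniteModel.post_patch(model)
    # Now the tokenizer is threaded through and handed back as well.
    model, tokenizer = FastGraniteModel.post_patch(model, tokenizer)
    return model, tokenizer
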
0 commit comments