     LlamaLinearScalingRotaryEmbedding,
 )
 from .mistral import *
-
+from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
+from peft.tuners.lora import Linear4bit as Peft_Linear4bit
 try:
     from transformers.models.granite.modeling_granite import (
         GraniteAttention,
@@ -423,6 +424,18 @@ class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
     def __init__(self, config):
         super().__init__(config = config)
 
+def patched_init(original_init):
+    def new_init(self, *args, **kwargs):
+        # We can use the self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here:
+        # https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/granite/modeling_granite.py#L243
+        # The problem is that we have access to neither the value nor the config in GraniteModel_fast_forward_inference,
+        # so we need a way to pass this value around. It is probably better to pass the entire config in case we need it later.
+        config = kwargs.get("config", args[0] if args else None)
+        if config is not None:
+            self.config = config
+        original_init(self, *args, **kwargs)
+    return new_init
+
 class FastGraniteModel(FastLlamaModel):
 
     @staticmethod
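For illustration, a minimal, self-contained sketch of what the patched_init wrapper above buys us; the toy classes and the 0.22 value are made up for this example and are not part of the commit. Wrapping __init__ stashes whatever config the constructor receives onto the instance, so later code such as GraniteModel_fast_forward_inference can read fields like residual_multiplier from self.config:

# Hypothetical usage of patched_init (defined in the hunk above); TinyConfig and
# TinyCausalLM are stand-ins, not the real transformers classes.
class TinyConfig:
    residual_multiplier = 0.22   # illustrative value only

class TinyCausalLM:
    def __init__(self, config):
        self.layers = []          # pretend model construction

TinyCausalLM.__init__ = patched_init(TinyCausalLM.__init__)

model = TinyCausalLM(TinyConfig())
# The wrapper stored the config, so downstream inference code can read it back:
print(model.config.residual_multiplier)   # -> 0.22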
@@ -437,12 +450,13 @@ def pre_patch():
             exec(function, globals())
             GraniteAttention.__init__ = eval(init_name)
         pass
-        GraniteAttention.forward = GraniteAttention_fast_forward
-        GraniteSdpaAttention.forward = GraniteAttention_fast_forward
-        GraniteFlashAttention2.forward = GraniteAttention_fast_forward
-        GraniteDecoderLayer.forward = GraniteDecoderLayer_fast_forward
-        GraniteModel.forward = LlamaModel_fast_forward
-        GraniteForCausalLM.forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
+        GraniteAttention      .forward = GraniteAttention_fast_forward
+        GraniteSdpaAttention  .forward = GraniteAttention_fast_forward
+        GraniteFlashAttention2.forward = GraniteAttention_fast_forward
+        GraniteDecoderLayer   .forward = GraniteDecoderLayer_fast_forward
+        GraniteModel          .forward = LlamaModel_fast_forward
+        GraniteForCausalLM    .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
+        GraniteForCausalLM    .__init__ = patched_init(GraniteForCausalLM.__init__)
         PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
         fix_prepare_inputs_for_generation(GraniteForCausalLM)
 
@@ -454,7 +468,7 @@ def pre_patch():
 
 
     @staticmethod
-    def post_patch(model):
+    def post_patch(model, tokenizer):
 
         # Torch.compile fails on embedding matrix??
         # Workaround randomly fixes it for torch versions < 2.2
@@ -519,7 +533,7 @@ def post_patch(model):
         for _ in range(3):
             gc.collect()
             torch.cuda.empty_cache()
-        return model
+        return model, tokenizer
     pass
 pass
 
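Call sites need a matching update, since post_patch now also threads the tokenizer through both its arguments and its return value; a hedged sketch of how a hypothetical loader would adapt (the surrounding loader code is assumed, not shown in this commit):

# Before: model = FastGraniteModel.post_patch(model)
model, tokenizer = FastGraniteModel.post_patch(model, tokenizer)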