@@ -90,6 +90,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    base_model_prefix = ""
    main_input_name = "input_ids"
    _auto_class = None
+    _missing_keys = set()

    def __init__(
        self,
@@ -98,6 +99,7 @@ def __init__(
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
    ):
        if config is None:
            raise ValueError("config cannot be None")
@@ -112,15 +114,35 @@ def __init__(
        # Those are public as their type is generic to all derived classes.
        self.key = PRNGKey(seed)
        self.dtype = dtype
+        self.input_shape = input_shape

-        # randomly initialized parameters
-        random_params = self.init_weights(self.key, input_shape)
+        # To check if the model was initialized automatically.
+        self._is_initialized = _do_init
+
+        if _do_init:
+            # randomly initialized parameters
+            random_params = self.init_weights(self.key, input_shape)
+            params_shape_tree = jax.eval_shape(lambda params: params, random_params)
+        else:
+            init_fn = partial(self.init_weights, input_shape=input_shape)
+            params_shape_tree = jax.eval_shape(init_fn, self.key)
+
+            logger.info(
+                "Model weights are not initialized as `_do_init` is set to `False`. "
+                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
+            )
+
+        # get the shape of the parameters
+        self._params_shape_tree = params_shape_tree

        # save required_params as set
-        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
-        self.params = random_params
+        self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+        # initialize the parameters
+        if _do_init:
+            self.params = random_params

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        raise NotImplementedError(f"init method has to be implemented for {self}")

    @classmethod
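For reference, a minimal, self-contained sketch of the `jax.eval_shape` pattern the new `__init__` relies on: the init function is traced abstractly and the result is a pytree whose leaves are `jax.ShapeDtypeStruct` objects, so the parameter structure is known without allocating any weights. The `toy_init` function and its shapes below are invented for illustration and are not part of this diff.

```python
import jax
import jax.numpy as jnp


def toy_init(rng):
    # stand-in for a model's `init_weights`; not the transformers implementation
    k1, _ = jax.random.split(rng)
    return {
        "dense": {
            "kernel": jax.random.normal(k1, (768, 3072)),
            "bias": jnp.zeros((3072,)),
        }
    }


# Traced abstractly: no random numbers are drawn, no device arrays are allocated.
shape_tree = jax.eval_shape(toy_init, jax.random.PRNGKey(0))

# Leaves carry only shape and dtype.
print(shape_tree["dense"]["kernel"].shape)  # (768, 3072)
print(shape_tree["dense"]["bias"].dtype)    # float32
```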
@@ -147,14 +169,31 @@ def module(self) -> nn.Module:

    @property
    def params(self) -> Union[Dict, FrozenDict]:
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
+                "You must call `init_weights` manually and store the params outside of the model and "
+                "pass it explicitly where needed."
+            )
        return self._params

    @property
    def required_params(self) -> Set:
        return self._required_params

+    @property
+    def params_shape_tree(self) -> Dict:
+        return self._params_shape_tree
+
    @params.setter
    def params(self, params: Union[Dict, FrozenDict]):
+        # don't set params if the model is not initialized
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be set from model when the model is created with `_do_init=False`. "
+                "You must store the params outside of the model."
+            )
+
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
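From the caller's side, these guards behave roughly as sketched below. This assumes a concrete subclass such as `FlaxBertModel` that forwards `_do_init` to this base `__init__`; that wiring lives outside this diff.

```python
from transformers import BertConfig, FlaxBertModel

# Build the model skeleton without materializing any weights (assumes the
# subclass passes `_do_init` through to FlaxPreTrainedModel.__init__).
model = FlaxBertModel(BertConfig(), _do_init=False)

# Reading (or assigning) `params` now raises, because the weights live outside the model.
try:
    _ = model.params
except ValueError as err:
    print(err)

# Only the shape/dtype tree is stored on the model.
shapes = model.params_shape_tree

# Weights are materialized explicitly and kept by the caller.
params = model.init_weights(model.key, model.input_shape)
```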
@@ -417,6 +456,7 @@ def from_pretrained(
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
+        _do_init = kwargs.pop("_do_init", True)

        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
@@ -553,7 +593,7 @@ def from_pretrained(
            resolved_archive_file = None

        # init random models
-        model = cls(config, *model_args, **model_kwargs)
+        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
@@ -577,25 +617,36 @@ def from_pretrained(
            # make sure all arrays are stored as jnp.arrays
            # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
-            state = jax.tree_util.tree_map(jnp.array, state)
+            if _do_init:
+                state = jax.tree_util.tree_map(jnp.array, state)
+            else:
+                # keep the params on CPU if we don't want to initialize
+                state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

        # if model is base model only use model_prefix key
-        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
+        if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # if model is head model and we are loading weights from base model
        # we initialize new params dict with base_model_prefix
-        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
+        if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

-        random_state = flatten_dict(unfreeze(model.params))
+        random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

+        if missing_keys and not _do_init:
+            logger.warn(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                f"Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
        # matching the weights in the model.
        mismatched_keys = []
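The CPU-pinning branch above uses a standard JAX idiom: map `jax.device_put(..., cpu)` over the state pytree so the checkpoint arrays stay in host memory instead of being copied to an accelerator. A stand-alone sketch with a toy `state` dict (the dict is illustrative only, not the deserialized Flax state):

```python
import numpy as np
import jax

# toy checkpoint state; in the code above this is the deserialized Flax state dict
state = {"layer": {"kernel": np.ones((4, 4), dtype=np.float32), "bias": np.zeros(4, dtype=np.float32)}}

cpu = jax.devices("cpu")[0]

# Commit every leaf to the CPU device; nothing is transferred to GPU/TPU memory,
# which is what keeps loading with `_do_init=False` cheap on accelerator hosts.
state_on_cpu = jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu), state)
```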
@@ -612,9 +663,10 @@ def from_pretrained(
                "model."
            )

-        # add missing keys as random parameters
-        for missing_key in missing_keys:
-            state[missing_key] = random_state[missing_key]
+        # add missing keys as random parameters if we are initializing
+        if missing_keys and _do_init:
+            for missing_key in missing_keys:
+                state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
@@ -680,10 +732,12 @@ def from_pretrained(
                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
            )

-        # set correct parameters
-        model.params = unflatten_dict(state)
-
-        return model
+        if _do_init:
+            # set correct parameters
+            model.params = unflatten_dict(state)
+            return model
+        else:
+            return model, unflatten_dict(state)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
        """