
Commit d3bd9ac

[Flax] improve large model init and loading (#16148)
* begin do_init
* add params_shape_tree
* raise error if params are accessed when do_init is False
* don't allow do_init=False when keys are missing
* make shape tree a property
* assign self._params at the end
* add test for do_init
* add do_init arg to all flax models
* fix param setting
* disable do_init for composite models
* update test
* add do_init in FlaxBigBirdForMultipleChoice
* better names and errors
* improve test
* style
* add a warning when do_init=False
* remove extra if
* set params after _required_params
* add test for from_pretrained
* do_init => _do_init
* change warning to info
* fix typo
* add params in init_weights
* add params to gpt neo init
* add params to init_weights
* update do_init test
* Trigger CI
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update template
* trigger CI
* style
* style
* fix template

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 6de4ee6 commit d3bd9ac

30 files changed: +702 -148 lines
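The net effect of this change is a lazy-initialization path for loading large Flax checkpoints. Below is a minimal usage sketch based on the diff that follows (the class, the checkpoint name, and the final call are illustrative, not part of the commit): with `_do_init=False`, `from_pretrained` returns a `(model, params)` tuple, keeps the checkpoint weights on CPU, and leaves any missing weights to be filled in manually via `init_weights`.

```python
import jax
from transformers import FlaxAlbertModel

# With _do_init=False, from_pretrained returns (model, params) and keeps the
# deserialized checkpoint on CPU instead of materializing it on the default
# device; the model object itself holds no parameters.
model, params = FlaxAlbertModel.from_pretrained("albert-base-v2", _do_init=False)

# Parameter shapes/dtypes are still available without allocating any memory.
print(model.params_shape_tree)

# If the checkpoint was missing keys (e.g. a freshly added head), fill them in
# with randomly initialized values before use.
params = model.init_weights(model.key, model.input_shape, params=params)

# Parameters are then passed explicitly at call time (illustrative):
# outputs = model(input_ids, params=params)
```

Accessing `model.params` in this mode raises a `ValueError`; see the `modeling_flax_utils.py` hunks below.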

examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py

+1 -1

@@ -140,7 +140,7 @@ def __init__(
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensor
         input_ids = jnp.zeros(input_shape[0], dtype="i4")
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])

src/transformers/modeling_flax_utils.py

+71 -17

@@ -90,6 +90,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
     base_model_prefix = ""
     main_input_name = "input_ids"
     _auto_class = None
+    _missing_keys = set()

     def __init__(
         self,
@@ -98,6 +99,7 @@ def __init__(
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
     ):
         if config is None:
             raise ValueError("config cannot be None")
@@ -112,15 +114,35 @@ def __init__(
         # Those are public as their type is generic to every derived class.
         self.key = PRNGKey(seed)
         self.dtype = dtype
+        self.input_shape = input_shape

-        # randomly initialized parameters
-        random_params = self.init_weights(self.key, input_shape)
+        # To check if the model was initialized automatically.
+        self._is_initialized = _do_init
+
+        if _do_init:
+            # randomly initialized parameters
+            random_params = self.init_weights(self.key, input_shape)
+            params_shape_tree = jax.eval_shape(lambda params: params, random_params)
+        else:
+            init_fn = partial(self.init_weights, input_shape=input_shape)
+            params_shape_tree = jax.eval_shape(init_fn, self.key)
+
+            logger.info(
+                "Model weights are not initialized as `_do_init` is set to `False`. "
+                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
+            )
+
+        # get the shape of the parameters
+        self._params_shape_tree = params_shape_tree

         # save required_params as set
-        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
-        self.params = random_params
+        self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+        # initialize the parameters
+        if _do_init:
+            self.params = random_params

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
         raise NotImplementedError(f"init method has to be implemented for {self}")

     @classmethod
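The `_do_init=False` branch above leans on `jax.eval_shape`, which traces `init_weights` abstractly so that only shapes and dtypes are computed and no parameter memory is allocated. A small standalone sketch of that mechanism, in plain Flax rather than transformers code:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

dense = nn.Dense(features=4096)

def init_fn(rng):
    # An ordinary Flax init function; run concretely it would allocate a
    # 4096x4096 kernel plus a bias vector.
    return dense.init(rng, jnp.ones((1, 4096)))

# eval_shape runs init_fn abstractly: the result is a pytree with
# jax.ShapeDtypeStruct leaves, and no device arrays are created.
shape_tree = jax.eval_shape(init_fn, jax.random.PRNGKey(0))
print(jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), shape_tree))
```

This is why `_required_params` can now be derived from `params_shape_tree` instead of from real, randomly initialized parameters.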
@@ -147,14 +169,31 @@ def module(self) -> nn.Module:

     @property
     def params(self) -> Union[Dict, FrozenDict]:
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
+                "You must call `init_weights` manually and store the params outside of the model and "
+                "pass it explicitly where needed."
+            )
         return self._params

     @property
     def required_params(self) -> Set:
         return self._required_params

+    @property
+    def params_shape_tree(self) -> Dict:
+        return self._params_shape_tree
+
     @params.setter
     def params(self, params: Union[Dict, FrozenDict]):
+        # don't set params if the model is not initialized
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be set from model when the model is created with `_do_init=False`. "
+                "You must store the params outside of the model."
+            )
+
         if isinstance(params, FrozenDict):
             params = unfreeze(params)
         param_keys = set(flatten_dict(params).keys())
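A short behavior sketch for the properties above (the concrete model and config classes are used purely for illustration): with `_do_init=False`, the shape tree and the required-key set are available, but reading or assigning `params` raises.

```python
from transformers import AlbertConfig, FlaxAlbertModel

# Construct the model without initializing any weights.
model = FlaxAlbertModel(AlbertConfig(), _do_init=False)

model.params_shape_tree   # pytree of jax.ShapeDtypeStruct leaves, no memory used
model.required_params     # set of flattened parameter key tuples

try:
    model.params          # raises: params cannot be accessed with _do_init=False
except ValueError as err:
    print(err)
```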
@@ -417,6 +456,7 @@ def from_pretrained(
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
+        _do_init = kwargs.pop("_do_init", True)

        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
@@ -553,7 +593,7 @@ def from_pretrained(
            resolved_archive_file = None

        # init random models
-        model = cls(config, *model_args, **model_kwargs)
+        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
@@ -577,25 +617,36 @@ def from_pretrained(
            # make sure all arrays are stored as jnp.arrays
            # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
-            state = jax.tree_util.tree_map(jnp.array, state)
+            if _do_init:
+                state = jax.tree_util.tree_map(jnp.array, state)
+            else:
+                # keep the params on CPU if we don't want to initialize
+                state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

        # if model is base model only use model_prefix key
-        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
+        if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # if model is head model and we are loading weights from base model
        # we initialize new params dict with base_model_prefix
-        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
+        if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

-        random_state = flatten_dict(unfreeze(model.params))
+        random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

+        if missing_keys and not _do_init:
+            logger.warn(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                f"Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
        # matching the weights in the model.
        mismatched_keys = []
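One detail of the hunk above worth spelling out: with `_do_init=False` the deserialized checkpoint is not converted to default-device `jnp` arrays but explicitly pinned to the host, which is what allows loading a checkpoint larger than a single accelerator's memory. A standalone sketch of that pattern, with a toy pytree standing in for the real state:

```python
import jax
import numpy as np

# Toy "checkpoint" pytree in place of the deserialized state dict.
state = {
    "dense": {
        "kernel": np.ones((8, 8), dtype=np.float32),
        "bias": np.zeros((8,), dtype=np.float32),
    }
}

# Pin every leaf to the host CPU instead of the default accelerator, mirroring
# the _do_init=False branch above.
cpu = jax.devices("cpu")[0]
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu), state)
```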
@@ -612,9 +663,10 @@ def from_pretrained(
                "model."
            )

-        # add missing keys as random parameters
-        for missing_key in missing_keys:
-            state[missing_key] = random_state[missing_key]
+        # add missing keys as random parameters if we are initializing
+        if missing_keys and _do_init:
+            for missing_key in missing_keys:
+                state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
@@ -680,10 +732,12 @@ def from_pretrained(
                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
            )

-        # set correct parameters
-        model.params = unflatten_dict(state)
-
-        return model
+        if _do_init:
+            # set correct parameters
+            model.params = unflatten_dict(state)
+            return model
+        else:
+            return model, unflatten_dict(state)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
        """

src/transformers/models/albert/modeling_flax_albert.py

+18 -6

@@ -21,8 +21,9 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax

 from ...modeling_flax_outputs import (
@@ -522,12 +523,13 @@ def __init__(
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
@@ -537,9 +539,19 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
-            "params"
-        ]
+        random_params = self.module.init(
+            rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
+        )["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
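The `params is not None` branch added above backs the manual re-initialization flow: after a partial checkpoint is loaded with `_do_init=False`, the keys recorded in `_missing_keys` are copied from a freshly initialized tree into the loaded params. A hedged usage sketch (class and checkpoint names are illustrative; the same pattern is repeated in the other model files below):

```python
import jax
from transformers import FlaxAlbertForSequenceClassification

# Loading a base-model checkpoint into a head model leaves the head weights
# missing; with _do_init=False they are NOT randomly initialized automatically.
model, params = FlaxAlbertForSequenceClassification.from_pretrained(
    "albert-base-v2", _do_init=False
)

# init_weights(..., params=...) builds a full random tree, then copies only the
# recorded missing keys (here the classification head) into the loaded params.
params = model.init_weights(jax.random.PRNGKey(0), model.input_shape, params=params)
```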

src/transformers/models/bart/modeling_flax_bart.py

+19 -6

@@ -24,9 +24,10 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 from jax.random import PRNGKey

@@ -912,12 +913,13 @@ def __init__(
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
@@ -933,7 +935,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
@@ -943,6 +945,16 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
            decoder_position_ids,
        )["params"]

+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
@@ -1737,14 +1749,15 @@ def __init__(
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        config.is_decoder = True
        config.is_encoder_decoder = False
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

src/transformers/models/beit/modeling_flax_beit.py

+24 -5

@@ -22,8 +22,9 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict

 from ...modeling_flax_outputs import (
     FlaxBaseModelOutput,
@@ -591,21 +592,39 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
    main_input_name = "pixel_values"
    module_class: nn.Module = None

-    def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
+    def __init__(
+        self,
+        config: BeitConfig,
+        input_shape=None,
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+        **kwargs
+    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, 3)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

-        return self.module.init(rngs, pixel_values, return_dict=False)["params"]
+        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
