r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.

- python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \
+ python src/transformers/models/gemma3/convert_gemma3_weights.py \
    --variant='gemma3_4b' \
    --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
    --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \
    --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/"
"""

from collections.abc import Iterator, Sequence
- from typing import Any
+ from typing import Any, Optional

import accelerate
import numpy as np

    Gemma3ImageProcessor,
    Gemma3Processor,
    Gemma3TextConfig,
+     Gemma3TextModel,
    GemmaTokenizerFast,
    GenerationConfig,
    SiglipVisionConfig,

_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"

- _TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
+ _TRANSFORMER_DECODER_BLOCK = "/layer_"
_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
- _TRANSFORMER_EMBEDDER = "transformer/embedder"
- _TRANSFORMER_FINAL_NORM = "transformer/final_norm"
+ _TRANSFORMER_EMBEDDER = "/embedder"
+ _TRANSFORMER_FINAL_NORM = "/final_norm"
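+ # Bare "/..." suffixes (instead of the old "transformer/..." literals) let the suffix/substring
+ # matching below resolve these components regardless of the checkpoint's root prefix, e.g. in
+ # EmbeddingGemma trees.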
_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)

    "vision_use_head": False,
}

+ _VARIANT_EMBEDDINGGEMMA = "embedding"
+ _VARIANT_GEMMA_3_270M = "gemma3_270m"
_VARIANT_GEMMA_3_1B = "gemma3_1b"
_VARIANT_GEMMA_3_4B = "gemma3_4b"
_VARIANT_GEMMA_3_12B = "gemma3_12b"
_VARIANT_GEMMA_3_27B = "gemma3_27b"
_VARIANTS = {
+     _VARIANT_EMBEDDINGGEMMA: Gemma3Config(
+         text_config=Gemma3TextConfig(
+             vocab_size=262_144,
+             hidden_size=768,
+             intermediate_size=1152,
+             num_hidden_layers=24,
+             num_attention_heads=3,
+             num_key_value_heads=1,
+             head_dim=256,
+             max_position_embeddings=1024,
+             query_pre_attn_scalar=256,
+             sliding_window=512,
+             rope_scaling=None,
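+             # Bidirectional (non-causal) attention: EmbeddingGemma encodes text rather than generating it.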
+             use_bidirectional_attention=True,
+         ),
+         vision_config=None,
+     ),
+     _VARIANT_GEMMA_3_270M: Gemma3Config(
+         text_config=Gemma3TextConfig(
+             vocab_size=262_144,
+             hidden_size=640,
+             intermediate_size=2048,
+             num_hidden_layers=18,
+             num_attention_heads=4,
+             num_key_value_heads=1,
+             head_dim=256,
+             max_position_embeddings=32768,
+             query_pre_attn_scalar=256,
+             sliding_window=512,
+             rope_scaling=None,
+         ),
+         vision_config=None,
+     ),
    _VARIANT_GEMMA_3_1B: Gemma3Config(
        text_config=Gemma3TextConfig(
            vocab_size=262_144,

    ),
}

+ _TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B)
+
# ==== Flags ====

_CHECKPOINT_PATH = flags.DEFINE_string(

    required=True,
)

+ _NUM_LINEAR_LAYERS = flags.DEFINE_integer(
+     name="num_linear_layers",
+     default=2,
+     help="Number of linear projection layers at the end of the Sentence Transformer.",
+ )
+
_TRANSFORMER_DTYPE = flags.DEFINE_enum(
    name="text_dtype",
    default="bfloat16",
@@ -358,12 +402,12 @@ def convert_transformer_weights(
    attn_head_dim = config.num_attention_heads * config.head_dim
    kv_head_dim = config.num_key_value_heads * config.head_dim

-     if path == _TRANSFORMER_EMBEDDER:
+     if path.endswith(_TRANSFORMER_EMBEDDER):
        if prop == "input_embedding":
            # Tied to language_model.lm_head.weight, assigned at the end.
            converted_paths = ["language_model.model.embed_tokens.weight"]

-             if _VARIANT.value != _VARIANT_GEMMA_3_1B:
+             if _VARIANT.value not in _TEXT_ONLY_VARIANTS:
                # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
                pre_expansion_embeddings = weights
                mu = np.mean(pre_expansion_embeddings, axis=0)
@@ -372,12 +416,12 @@ def convert_transformer_weights(
                weights = np.vstack([pre_expansion_embeddings, new_embeddings])

            converted_weights = [weights]
-         elif _VARIANT.value == _VARIANT_GEMMA_3_1B or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
+         elif _VARIANT.value in _TEXT_ONLY_VARIANTS or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
            return zip([], [])
        else:
            raise ValueError(f"Unexpected member, {prop}, in Embedder.")
    elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
-         if _VARIANT.value == _VARIANT_GEMMA_3_1B:
+         if _VARIANT.value in _TEXT_ONLY_VARIANTS:
            return zip([], [])

        if path.endswith("/mm_input_projection"):
@@ -388,14 +432,16 @@ def convert_transformer_weights(
            converted_weights = [weights]
        else:
            raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
-     elif path == _TRANSFORMER_FINAL_NORM:
+     elif path.endswith(_TRANSFORMER_FINAL_NORM):
        converted_paths = ["language_model.model.norm.weight"]
        converted_weights = [weights]
-     elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
-         decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
-         next_path_separator_idx = decoder_block_path.find("/")
-         layer_idx = decoder_block_path[:next_path_separator_idx]
-         decoder_block_path = decoder_block_path[next_path_separator_idx:]
+     elif _TRANSFORMER_DECODER_BLOCK in path:
+         decoder_block_start = path.find(_TRANSFORMER_DECODER_BLOCK)
+         decoder_block_offset = decoder_block_start + _TRANSFORMER_DECODER_BLOCK_LEN
+         decoder_block_path = path[decoder_block_offset:]
+         next_path_separator_idx = decoder_block_path.find("/")
+         layer_idx = decoder_block_path[:next_path_separator_idx]
+         decoder_block_path = decoder_block_path[next_path_separator_idx:]
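+         # e.g. "transformer/layer_12/mlp/..." -> layer_idx "12", decoder_block_path "/mlp/..."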

        base_path = f"language_model.model.layers.{layer_idx}"
@@ -445,8 +491,6 @@ def convert_transformer_weights(
            converted_weights = [weights]
        else:
            raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
-     else:
-         raise ValueError(f"Unexpected path `{path}`.")
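+     # No final else anymore: paths matching no branch (e.g. the extra dense-projection entries in
+     # EmbeddingGemma checkpoints) fall through without conversions and are handled in convert() instead.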

    if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
        raise ValueError(
@@ -457,11 +501,14 @@ def convert_transformer_weights(
    return zip(converted_paths, converted_weights)


- def convert(checkpoint_path: str, config: Gemma3Config) -> dict[str, torch.Tensor]:
+ def convert(
+     checkpoint_path: str, config: Gemma3Config, variant: str
+ ) -> tuple[dict[str, torch.Tensor], Optional[Sequence[np.ndarray]]]:
    """Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
    checkpointer = obc.PyTreeCheckpointer()
    ckpt = checkpointer.restore(checkpoint_path)
    hf_tree: dict[str, torch.Tensor] = {}
+     orbax_tree_flat = tree.flatten_with_path(ckpt)
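+     # Flattened once up front: the loop below consumes it, and EmbeddingGemma's dense-projection
+     # kernels are sliced from the head of this list on return.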

    def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
        hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype)
@@ -473,7 +520,7 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
            target_dtype,
        )

-     for paths, value in tree.flatten_with_path(ckpt):
+     for paths, value in orbax_tree_flat:
        if paths[0].startswith("SigLiPFromPatches_"):
            if config.vision_config is None:
                continue
@@ -482,17 +529,21 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
            update_tree(path, weights, config.vision_config.dtype)
        else:
            for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
-                 if config.vision_config is None:
+                 if variant in _TEXT_ONLY_VARIANTS:
                    path = path[len("language_model."):]
+                     if variant == _VARIANT_EMBEDDINGGEMMA:
+                         path = path[len("model."):]

                update_tree(path, weights, config.text_config.dtype)

-     if config.vision_config is None:
+     if variant == _VARIANT_EMBEDDINGGEMMA:
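+         # The first _NUM_LINEAR_LAYERS entries of the flat tree are the Sentence Transformer dense kernels;
+         # `.T` converts each from the checkpoint's (in, out) layout to torch Linear's (out, in).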
+         return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]]
+     elif config.vision_config is None:
        hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
    else:
        hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]

-     return hf_tree
+     return hf_tree, None


def main(*args):
@@ -504,7 +555,7 @@ def main(*args):
    config = _VARIANTS[variant]
    config.text_config.dtype = getattr(torch, _TRANSFORMER_DTYPE.value)

-     if variant == _VARIANT_GEMMA_3_1B:
+     if variant in _TEXT_ONLY_VARIANTS:
        config.vision_config = None
    else:
        config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value)
@@ -520,11 +571,13 @@ def main(*args):
        _TRANSFORMER_DTYPE.value,
        _VISION_DTYPE.value,
    )
-     state_tree = convert(_CHECKPOINT_PATH.value, config)
+     state_tree, st_linears = convert(_CHECKPOINT_PATH.value, config, variant)
    logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)

    with accelerate.init_empty_weights():
-         if variant == _VARIANT_GEMMA_3_1B:
+         if variant == _VARIANT_EMBEDDINGGEMMA:
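+             # The bare text encoder without an LM head; SentenceTransformer pooling is stacked on top below.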
+             model = Gemma3TextModel(config=config.text_config)
+         elif variant in _TEXT_ONLY_VARIANTS:
            model = Gemma3ForCausalLM(config=config.text_config)
        else:
            model = Gemma3ForConditionalGeneration(config)
@@ -548,6 +601,8 @@ def main(*args):
    tokenizer = GemmaTokenizerFast(
        _TOKENIZER_PATH.value,
        add_bos_token=True,
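+         # EmbeddingGemma tokenizes with a trailing EOS and right padding, unlike the generative variants.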
+         add_eos_token=variant == _VARIANT_EMBEDDINGGEMMA,
+         padding_side="right" if variant == _VARIANT_EMBEDDINGGEMMA else "left",
        extra_special_tokens={
            "image_token": "<image_soft_token>",  # Should be ID=262_144
            "boi_token": "<start_of_image>",  # Should be ID=255_999
@@ -558,7 +613,7 @@ def main(*args):
    tokenizer.save_pretrained(output_path)
    logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)

-     if variant != _VARIANT_GEMMA_3_1B:
+     if variant not in _TEXT_ONLY_VARIANTS:
        image_processor = Gemma3ImageProcessor(
            image_seq_length=256,
            image_mean=(0.5,) * 3,
@@ -589,6 +644,46 @@ def main(*args):
        )
        generation_config.save_pretrained(output_path)

+     if variant == _VARIANT_EMBEDDINGGEMMA:
+         from sentence_transformers import SentenceTransformer, models
+
+         # TODO: Support Retrieval tasks where we use `"title: {title} | text: {passage}"` internally and construct
+         # this from split-records cached data, but externally these come through as a single string with components
+         # separated by a newline. This should be used for `passage` for SentenceTransformers and the relevant MTEB
+         # Retrieval tasks.
+         # https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts
+         task_prompts = {
+             "query": "task: search result | query: ",
+             "document": "title: none | text: ",
+             "BitextMining": "task: search result | query: ",
+             "Clustering": "task: clustering | query: ",
+             "Classification": "task: classification | query: ",
+             "InstructionRetrieval": "task: code retrieval | query: ",
+             "MultilabelClassification": "task: classification | query: ",
+             "PairClassification": "task: sentence similarity | query: ",
+             "Reranking": "task: search result | query: ",
+             "Retrieval": "task: search result | query: ",
+             "Retrieval-query": "task: search result | query: ",
+             "Retrieval-document": "title: none | text: ",
+             "STS": "task: sentence similarity | query: ",
+             "Summarization": "task: summarization | query: ",
+         }
+
+         transformer = models.Transformer(output_path)
+         pooling = models.Pooling(config.text_config.hidden_size, pooling_mode="mean")
+         normalize = models.Normalize()
+         linears = []
+
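+         # Kernels were already transposed to (out, in) in convert(), so shape[:2] reads directly as torch Linear dims.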
+         for linear_weight in st_linears:
+             out_size, in_size = linear_weight.shape[:2]
+             dense = models.Dense(in_size, out_size, bias=False, activation_function=None)
+             dense.linear.weight.data = torch.from_numpy(linear_weight.astype("float32"))
+             linears.append(dense)
+
+         model = SentenceTransformer(modules=[transformer, pooling, *linears, normalize], prompts=task_prompts)
+         model = model.to(getattr(torch, _TRANSFORMER_DTYPE.value))
+         model.save_pretrained(output_path)
+


if __name__ == "__main__":
    app.run(main)
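
As a quick check of the EmbeddingGemma export, the saved directory should load straight into SentenceTransformers. A minimal sketch, assuming a hypothetical output path (not part of this commit):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("$HOME/gemma3/embeddinggemma_safetensors")  # whatever --output_path was
embeddings = model.encode(["What is the capital of France?"], prompt_name="query")
print(embeddings.shape)  # (1, output dim of the final Dense projection)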