
Commit 87f38db

add: embedding model (#40694)
* Gemma 3 for Embeddings
* Style fixes
* Rename conversion file for consistency
* Default padding side emb vs gen
* Corrected 270m config
* style fixes
* EmbeddingGemma config
* TODO for built-in prompts
* Resolving the sentence similarity bug and updating the architecture
* code style
* Add query prompt for SentenceTransformers
* Code quality
* Fixing or_mask_function return types
* Adding placeholder prompts for document and passage
* Finalizing prompt templates
* Adding Retrieval to preconfigured prompts
* Add Gemma 3 270M Config
* Correcting num_linear_layers flag default
* Export Sentence Transformer in correct dtype

Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
1 parent 5b0c01b commit 87f38db

File tree: 4 files changed, +169 additions, −28 deletions

src/transformers/models/gemma3/configuration_gemma3.py

Lines changed: 4 additions & 0 deletions
@@ -136,6 +136,8 @@ class Gemma3TextConfig(PretrainedConfig):
                 Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
         rope_local_base_freq (float, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings for local attention.
+        use_bidirectional_attention (`bool`, *optional*, defaults to `False`): If True, the model will attend to all
+            text tokens instead of using a causal mask. This does not change behavior for vision tokens.
 
     ```python
     >>> from transformers import Gemma3TextModel, Gemma3TextConfig
@@ -193,6 +195,7 @@ def __init__(
         attn_logit_softcapping=None,
         rope_scaling=None,
         rope_local_base_freq=10_000.0,
+        use_bidirectional_attention=False,
         **kwargs,
     ):
         super().__init__(
@@ -222,6 +225,7 @@ def __init__(
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types
+        self.use_bidirectional_attention = use_bidirectional_attention
 
         self.rope_local_base_freq = rope_local_base_freq
         self.rope_scaling = rope_scaling
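For context, the new flag plugs straight into the existing `Gemma3TextConfig` API. A minimal sketch of enabling it, using hyperparameters that mirror the `embedding` variant added in the conversion script below (illustrative only, not a published checkpoint):

```python
from transformers import Gemma3TextConfig, Gemma3TextModel

# Text-only config with bidirectional attention, mirroring the "embedding" variant.
config = Gemma3TextConfig(
    vocab_size=262_144,
    hidden_size=768,
    intermediate_size=1152,
    num_hidden_layers=24,
    num_attention_heads=3,
    num_key_value_heads=1,
    head_dim=256,
    max_position_embeddings=1024,
    sliding_window=512,
    use_bidirectional_attention=True,
)
model = Gemma3TextModel(config)  # randomly initialized; all text tokens attend to each other
```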

src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py renamed to src/transformers/models/gemma3/convert_gemma3_weights.py

Lines changed: 121 additions & 26 deletions
@@ -16,15 +16,15 @@
 
 r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.
 
-python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \
+python src/transformers/models/gemma3/convert_gemma3_weights.py \
     --variant='gemma3_4b' \
     --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
     --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \
     --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/"
 """
 
 from collections.abc import Iterator, Sequence
-from typing import Any
+from typing import Any, Optional
 
 import accelerate
 import numpy as np
@@ -40,6 +40,7 @@
     Gemma3ImageProcessor,
     Gemma3Processor,
     Gemma3TextConfig,
+    Gemma3TextModel,
     GemmaTokenizerFast,
     GenerationConfig,
     SiglipVisionConfig,
@@ -100,10 +101,10 @@
 _SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
 _SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"
 
-_TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
+_TRANSFORMER_DECODER_BLOCK = "/layer_"
 _TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
-_TRANSFORMER_EMBEDDER = "transformer/embedder"
-_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
+_TRANSFORMER_EMBEDDER = "/embedder"
+_TRANSFORMER_FINAL_NORM = "/final_norm"
 _TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
 _TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
 
@@ -121,11 +122,46 @@
     "vision_use_head": False,
 }
 
+_VARIANT_EMBEDDINGGEMMA = "embedding"
+_VARIANT_GEMMA_3_270M = "gemma3_270m"
 _VARIANT_GEMMA_3_1B = "gemma3_1b"
 _VARIANT_GEMMA_3_4B = "gemma3_4b"
 _VARIANT_GEMMA_3_12B = "gemma3_12b"
 _VARIANT_GEMMA_3_27B = "gemma3_27b"
 _VARIANTS = {
+    _VARIANT_EMBEDDINGGEMMA: Gemma3Config(
+        text_config=Gemma3TextConfig(
+            vocab_size=262_144,
+            hidden_size=768,
+            intermediate_size=1152,
+            num_hidden_layers=24,
+            num_attention_heads=3,
+            num_key_value_heads=1,
+            head_dim=256,
+            max_position_embeddings=1024,
+            query_pre_attn_scalar=256,
+            sliding_window=512,
+            rope_scaling=None,
+            use_bidirectional_attention=True,
+        ),
+        vision_config=None,
+    ),
+    _VARIANT_GEMMA_3_270M: Gemma3Config(
+        text_config=Gemma3TextConfig(
+            vocab_size=262_144,
+            hidden_size=640,
+            intermediate_size=2048,
+            num_hidden_layers=18,
+            num_attention_heads=4,
+            num_key_value_heads=1,
+            head_dim=256,
+            max_position_embeddings=32768,
+            query_pre_attn_scalar=256,
+            sliding_window=512,
+            rope_scaling=None,
+        ),
+        vision_config=None,
+    ),
     _VARIANT_GEMMA_3_1B: Gemma3Config(
         text_config=Gemma3TextConfig(
             vocab_size=262_144,
@@ -200,6 +236,8 @@
     ),
 }
 
+_TEXT_ONLY_VARIANTS = (_VARIANT_EMBEDDINGGEMMA, _VARIANT_GEMMA_3_270M, _VARIANT_GEMMA_3_1B)
+
 # ==== Flags ====
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -220,6 +258,12 @@
     required=True,
 )
 
+_NUM_LINEAR_LAYERS = flags.DEFINE_integer(
+    name="num_linear_layers",
+    default=2,
+    help="Number of linear projection layers at the end of the Sentence Transformer.",
+)
+
 _TRANSFORMER_DTYPE = flags.DEFINE_enum(
     name="text_dtype",
     default="bfloat16",
@@ -358,12 +402,12 @@ def convert_transformer_weights(
     attn_head_dim = config.num_attention_heads * config.head_dim
     kv_head_dim = config.num_key_value_heads * config.head_dim
 
-    if path == _TRANSFORMER_EMBEDDER:
+    if path.endswith(_TRANSFORMER_EMBEDDER):
         if prop == "input_embedding":
             # Tied to language_model.lm_head.weight, assigned at the end.
             converted_paths = ["language_model.model.embed_tokens.weight"]
 
-            if _VARIANT.value != _VARIANT_GEMMA_3_1B:
+            if _VARIANT.value not in _TEXT_ONLY_VARIANTS:
                 # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
                 pre_expansion_embeddings = weights
                 mu = np.mean(pre_expansion_embeddings, axis=0)
@@ -372,12 +416,12 @@ def convert_transformer_weights(
             weights = np.vstack([pre_expansion_embeddings, new_embeddings])
 
             converted_weights = [weights]
-        elif _VARIANT.value == _VARIANT_GEMMA_3_1B or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
+        elif _VARIANT.value in _TEXT_ONLY_VARIANTS or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
             return zip([], [])
         else:
             raise ValueError(f"Unexpected member, {prop}, in Embedder.")
     elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
-        if _VARIANT.value == _VARIANT_GEMMA_3_1B:
+        if _VARIANT.value in _TEXT_ONLY_VARIANTS:
             return zip([], [])
 
         if path.endswith("/mm_input_projection"):
@@ -388,14 +432,16 @@
             converted_weights = [weights]
         else:
             raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
-    elif path == _TRANSFORMER_FINAL_NORM:
+    elif path.endswith(_TRANSFORMER_FINAL_NORM):
         converted_paths = ["language_model.model.norm.weight"]
         converted_weights = [weights]
-    elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
-        decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
-        next_path_separator_idx = decoder_block_path.find("/")
-        layer_idx = decoder_block_path[:next_path_separator_idx]
-        decoder_block_path = decoder_block_path[next_path_separator_idx:]
+    elif _TRANSFORMER_DECODER_BLOCK in path:
+        decoder_block_start = path.find(_TRANSFORMER_DECODER_BLOCK)
+        decoder_block_offset = decoder_block_start + _TRANSFORMER_DECODER_BLOCK_LEN
+        decoder_block_path = path[decoder_block_offset:]
+        next_path_seperator_idx = decoder_block_path.find("/")
+        layer_idx = decoder_block_path[:next_path_seperator_idx]
+        decoder_block_path = decoder_block_path[next_path_seperator_idx:]
 
         base_path = f"language_model.model.layers.{layer_idx}"
 
@@ -445,8 +491,6 @@ def convert_transformer_weights(
             converted_weights = [weights]
         else:
             raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
-    else:
-        raise ValueError(f"Unexpected path `{path}`.")
 
     if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
         raise ValueError(
@@ -457,11 +501,14 @@
     return zip(converted_paths, converted_weights)
 
 
-def convert(checkpoint_path: str, config: Gemma3Config) -> dict[str, torch.Tensor]:
+def convert(
+    checkpoint_path: str, config: Gemma3Config, variant: str
+) -> tuple[dict[str, torch.Tensor], Optional[Sequence[np.ndarray]]]:
     """Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
     checkpointer = obc.PyTreeCheckpointer()
     ckpt = checkpointer.restore(checkpoint_path)
     hf_tree: dict[str, torch.Tensor] = {}
+    orbax_tree_flat = tree.flatten_with_path(ckpt)
 
     def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
         hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype)
@@ -473,7 +520,7 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
             target_dtype,
         )
 
-    for paths, value in tree.flatten_with_path(ckpt):
+    for paths, value in orbax_tree_flat:
         if paths[0].startswith("SigLiPFromPatches_"):
             if config.vision_config is None:
                 continue
@@ -482,17 +529,21 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
                 update_tree(path, weights, config.vision_config.dtype)
         else:
             for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
-                if config.vision_config is None:
+                if variant in _TEXT_ONLY_VARIANTS:
                     path = path[len("language_model.") :]
+                    if variant == _VARIANT_EMBEDDINGGEMMA:
+                        path = path[len("model.") :]
 
                 update_tree(path, weights, config.text_config.dtype)
 
-    if config.vision_config is None:
+    if variant == _VARIANT_EMBEDDINGGEMMA:
+        return hf_tree, [weight[1].T for weight in orbax_tree_flat[: _NUM_LINEAR_LAYERS.value]]
+    elif config.vision_config is None:
         hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
     else:
         hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
 
-    return hf_tree
+    return hf_tree, None
 
 
 def main(*args):
@@ -504,7 +555,7 @@ def main(*args):
    config = _VARIANTS[variant]
    config.text_config.dtype = getattr(torch, _TRANSFORMER_DTYPE.value)

-    if variant == _VARIANT_GEMMA_3_1B:
+    if variant in _TEXT_ONLY_VARIANTS:
        config.vision_config = None
    else:
        config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value)
@@ -520,11 +571,13 @@
         _TRANSFORMER_DTYPE.value,
         _VISION_DTYPE.value,
     )
-    state_tree = convert(_CHECKPOINT_PATH.value, config)
+    state_tree, st_linears = convert(_CHECKPOINT_PATH.value, config, variant)
     logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)
 
     with accelerate.init_empty_weights():
-        if variant == _VARIANT_GEMMA_3_1B:
+        if variant == _VARIANT_EMBEDDINGGEMMA:
+            model = Gemma3TextModel(config=config.text_config)
+        elif variant in _TEXT_ONLY_VARIANTS:
             model = Gemma3ForCausalLM(config=config.text_config)
         else:
             model = Gemma3ForConditionalGeneration(config)
@@ -548,6 +601,8 @@
     tokenizer = GemmaTokenizerFast(
         _TOKENIZER_PATH.value,
         add_bos_token=True,
+        add_eos_token=variant == _VARIANT_EMBEDDINGGEMMA,
+        padding_side="right" if variant == _VARIANT_EMBEDDINGGEMMA else "left",
         extra_special_tokens={
             "image_token": "<image_soft_token>",  # Should be ID=262_144
             "boi_token": "<start_of_image>",  # Should be ID=255_999
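With `add_eos_token=True` and `padding_side="right"` for the embedding variant, batched inputs keep their real tokens (including the appended EOS) at the start of each row and padding at the end. A quick check against a converted checkpoint (the path is a placeholder for wherever `--output_path` pointed):

```python
from transformers import AutoTokenizer

# Placeholder path: a checkpoint directory produced by this conversion script.
tokenizer = AutoTokenizer.from_pretrained("/path/to/embeddinggemma-converted")

batch = tokenizer(["short query", "a somewhat longer query"], padding=True, return_tensors="pt")
# Right padding puts pad ids at the end of each row; the attention_mask flags them,
# so downstream mean pooling averages only the real tokens.
print(batch["input_ids"].shape, batch["attention_mask"])
```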
@@ -558,7 +613,7 @@
     tokenizer.save_pretrained(output_path)
     logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
 
-    if variant != _VARIANT_GEMMA_3_1B:
+    if variant not in _TEXT_ONLY_VARIANTS:
         image_processor = Gemma3ImageProcessor(
             image_seq_length=256,
             image_mean=(0.5,) * 3,
@@ -589,6 +644,46 @@ def main(*args):
     )
     generation_config.save_pretrained(output_path)
 
+    if variant == _VARIANT_EMBEDDINGGEMMA:
+        from sentence_transformers import SentenceTransformer, models
+
+        # TODO: Support Retrieval tasks where we use `"title: {title} | text: {passage}"` interally and construct this
+        # from split-records cached data, but externally these come through as a single string with components
+        # separated by a newline. This should be used for `passage` for SentenceTransformers and the relevant MTEB
+        # Retrieval tasks.
+        # https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts
+        task_prompts = {
+            "query": "task: search result | query: ",
+            "document": "title: none | text: ",
+            "BitextMining": "task: search result | query: ",
+            "Clustering": "task: clustering | query: ",
+            "Classification": "task: classification | query: ",
+            "InstructionRetrieval": "task: code retrieval | query: ",
+            "MultilabelClassification": "task: classification | query: ",
+            "PairClassification": "task: sentence similarity | query: ",
+            "Reranking": "task: search result | query: ",
+            "Retrieval": "task: search result | query: ",
+            "Retrieval-query": "task: search result | query: ",
+            "Retrieval-document": "title: none | text: ",
+            "STS": "task: sentence similarity | query: ",
+            "Summarization": "task: summarization | query: ",
+        }
+
+        transformer = models.Transformer(output_path)
+        pooling = models.Pooling(config.text_config.hidden_size, pooling_mode="mean")
+        normalize = models.Normalize()
+        linears = []
+
+        for linear_weight in st_linears:
+            out_size, in_size = linear_weight.shape[:2]
+            dense = models.Dense(in_size, out_size, bias=False, activation_function=None)
+            dense.linear.weight.data = torch.from_numpy(linear_weight.astype("float32"))
+            linears.append(dense)
+
+        model = SentenceTransformer(modules=[transformer, pooling, *linears, normalize], prompts=task_prompts)
+        model = model.to(getattr(torch, _TRANSFORMER_DTYPE.value))
+        model.save_pretrained(output_path)
+
 
 if __name__ == "__main__":
     app.run(main)
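Because the prompts dict is saved with the exported model, downstream code can select a prompt by name at encode time. A usage sketch against a converted export (the path is a placeholder; `prompt_name` is the standard sentence-transformers hook for stored prompts):

```python
from sentence_transformers import SentenceTransformer, util

# Placeholder path: wherever --output_path pointed during conversion.
model = SentenceTransformer("/path/to/embeddinggemma-converted")

queries = ["Which planet is known as the Red Planet?"]
documents = ["Mars is often called the Red Planet because of its reddish appearance."]

# "query" and "document" are keys of the prompts dict written by the converter;
# the corresponding prefix is prepended to the text before tokenization.
query_emb = model.encode(queries, prompt_name="query")
doc_emb = model.encode(documents, prompt_name="document")
print(util.cos_sim(query_emb, doc_emb))
```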

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 20 additions & 1 deletion
@@ -443,6 +443,19 @@ def _init_weights(self, module):
             module.mm_input_projection_weight.data.zero_()
 
 
+def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
+    """
+    Enables a bidirectional mask within the sliding window.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        """A token can attend to any other token if their absolute distance is within
+        half the sliding window size (distance <= sliding_window // 2)."""
+        return abs(q_idx - kv_idx) <= sliding_window // 2
+
+    return inner_mask
+
+
 @auto_docstring
 class Gemma3TextModel(Gemma3PreTrainedModel):
     config: Gemma3TextConfig
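The overlay is a pure function of the query/key indices, so its behavior is easy to check in isolation. A small sketch restating the helper (it is private to `modeling_gemma3.py`) with an illustrative window size:

```python
# Restated locally for illustration; the real helper lives in modeling_gemma3.py.
def bidirectional_window_overlay(sliding_window: int):
    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return abs(q_idx - kv_idx) <= sliding_window // 2

    return inner_mask

mask_fn = bidirectional_window_overlay(sliding_window=512)
print(mask_fn(0, 0, 100, 300))  # True:  |100 - 300| = 200 <= 256
print(mask_fn(0, 0, 100, 400))  # False: |100 - 400| = 300 >  256
print(mask_fn(0, 0, 300, 100))  # True:  symmetric, so "future" tokens are visible too
```

Passed as `or_mask_function` in the forward pass below, the positions it admits are OR-ed on top of the causal sliding-window mask, which is what makes the sliding layers bidirectional inside the window.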
@@ -531,10 +544,16 @@ def forward(
                 "past_key_values": past_key_values,
                 "position_ids": position_ids,
             }
+            sliding_mask_kwargs = mask_kwargs.copy()
+
+            if self.config.use_bidirectional_attention:
+                mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
+                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
+
             # Create the masks
             causal_mask_mapping = {
                 "full_attention": create_causal_mask(**mask_kwargs),
-                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
             }
 
             # embed positions
