Skip to content

Commit 291772b

Browse files
RyanMullinsamer-sinhavasqu
authored
add: differential privacy research model (#40851)
* VaultGemma * Removing Sequence and Token classification models. Removing integration tests for now * Remove pass-only modular code. style fixes * Update vaultgemma.md * Update docs/source/en/model_doc/vaultgemma.md Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> * Update docs/source/en/model_doc/vaultgemma.md Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> * Add links to model doc * Correct model doc usage examples * Updating model doc to describe differences from Gemma 2 * Update model_doc links * Adding integration tests * style fixes * repo consistency * attribute exception --------- Co-authored-by: Amer <amersinha@gmail.com> Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
1 parent 8502b41 commit 291772b

File tree

12 files changed

+1294
-0
lines changed

12 files changed

+1294
-0
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,8 @@
709709
title: UL2
710710
- local: model_doc/umt5
711711
title: UMT5
712+
- local: model_doc/vaultgemma
713+
title: VaultGemma
712714
- local: model_doc/xmod
713715
title: X-MOD
714716
- local: model_doc/xglm
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
<!--Copyright 2025 the HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
15+
16+
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
17+
18+
-->
19+
20+
# VaultGemma
21+
22+
## Overview
23+
24+
[VaultGemma](https://services.google.com/fh/files/blogs/vaultgemma_tech_report.pdf) is a text-only decoder model
25+
derived from [Gemma 2](https://huggingface.co/docs/transformers/en/model_doc/gemma2), notably it drops the norms after
26+
the Attention and MLP blocks, and uses full attention for all layers instead of alternating between full attention and
27+
local sliding attention. VaultGemma is available as a pretrained model with 1B parameters that uses a 1024 token
28+
sequence length.
29+
30+
VaultGemma was trained from scratch with sequence-level differential privacy (DP). Its training data includes the same
31+
mixture as the [Gemma 2 models](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315),
32+
consisting of a number of documents of varying lengths. Additionally, it is trained using
33+
[DP stochastic gradient descent (DP-SGD)](https://arxiv.org/abs/1607.00133) and provides a
34+
(ε ≤ 2.0, δ ≤ 1.1e-10)-sequence-level DP guarantee, where a sequence consists of 1024 consecutive tokens extracted from
35+
heterogeneous data sources. Specifically, the privacy unit of the guarantee is for the sequences after sampling and
36+
packing of the mixture.
37+
38+
> [!TIP]
39+
> Click on the VaultGemma models in the right sidebar for more examples of how to apply VaultGemma to different language tasks.
40+
41+
The example below demonstrates how to chat with the model with [`Pipeline`], the [`AutoModel`] class, or from the
42+
command line.
43+
44+
<hfoptions id="usage">
45+
<hfoption id="Pipeline">
46+
47+
48+
```python
49+
from transformers import pipeline
50+
51+
pipe = pipeline(
52+
task="text-generation",
53+
model="google/vaultgemma-1b",
54+
dtype="auto",
55+
device_map="auto",
56+
)
57+
58+
text = "Tell me an unknown interesting biology fact about the brain."
59+
outputs = pipe(text, max_new_tokens=32)
60+
response = outputs[0]["generated_text"]
61+
print(response)
62+
```
63+
64+
</hfoption>
65+
<hfoption id="AutoModel">
66+
67+
```python
68+
# pip install accelerate
69+
from transformers import AutoTokenizer, AutoModelForCausalLM
70+
71+
model_id = "google/vaultgemma-1b"
72+
tokenizer = AutoTokenizer.from_pretrained(model_id)
73+
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype="auto")
74+
75+
text = "Tell me an unknown interesting biology fact about the brain."
76+
input_ids = tokenizer(text, return_tensors="pt").to(model.device)
77+
78+
outputs = model.generate(**input_ids, max_new_tokens=32)
79+
print(tokenizer.decode(outputs[0]))
80+
```
81+
82+
</hfoption>
83+
<hfoption id="transformers CLI">
84+
85+
```
86+
echo -e "Write me a poem about Machine Learning. Answer:" | transformers run --task text2text-generation --model google/vaultgemma-1b-pt --device 0
87+
```
88+
89+
</hfoption>
90+
</hfoptions>
91+
92+
## VaultGemmaConfig
93+
94+
[[autodoc]] VaultGemmaConfig
95+
96+
## VaultGemmaModel
97+
98+
[[autodoc]] VaultGemmaModel
99+
- forward
100+
101+
## VaultGemmaForCausalLM
102+
103+
[[autodoc]] VaultGemmaForCausalLM

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@
338338
from .unispeech_sat import *
339339
from .univnet import *
340340
from .upernet import *
341+
from .vaultgemma import *
341342
from .video_llava import *
342343
from .videomae import *
343344
from .vilt import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@
400400
("univnet", "UnivNetConfig"),
401401
("upernet", "UperNetConfig"),
402402
("van", "VanConfig"),
403+
("vaultgemma", "VaultGemmaConfig"),
403404
("video_llava", "VideoLlavaConfig"),
404405
("videomae", "VideoMAEConfig"),
405406
("vilt", "ViltConfig"),
@@ -842,6 +843,7 @@
842843
("univnet", "UnivNet"),
843844
("upernet", "UPerNet"),
844845
("van", "VAN"),
846+
("vaultgemma", "VaultGemma"),
845847
("video_llava", "VideoLlava"),
846848
("videomae", "VideoMAE"),
847849
("vilt", "ViLT"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
384384
("unispeech-sat", "UniSpeechSatModel"),
385385
("univnet", "UnivNetModel"),
386386
("van", "VanModel"),
387+
("vaultgemma", "VaultGemmaModel"),
387388
("video_llava", "VideoLlavaModel"),
388389
("videomae", "VideoMAEModel"),
389390
("vilt", "ViltModel"),
@@ -732,6 +733,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
732733
("starcoder2", "Starcoder2ForCausalLM"),
733734
("transfo-xl", "TransfoXLLMHeadModel"),
734735
("trocr", "TrOCRForCausalLM"),
736+
("vaultgemma", "VaultGemmaForCausalLM"),
735737
("whisper", "WhisperForCausalLM"),
736738
("xglm", "XGLMForCausalLM"),
737739
("xlm", "XLMWithLMHeadModel"),
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# coding=utf-8
2+
# Copyright 2025 the HuggingFace Team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import TYPE_CHECKING
17+
18+
from ...utils import _LazyModule
19+
from ...utils.import_utils import define_import_structure
20+
21+
22+
if TYPE_CHECKING:
23+
from .configuration_vaultgemma import *
24+
from .modeling_vaultgemma import *
25+
else:
26+
import sys
27+
28+
_file = globals()["__file__"]
29+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2+
# This file was automatically generated from src/transformers/models/vaultgemma/modular_vaultgemma.py.
3+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
4+
# the file from the modular. If any change should be done, please apply the change to the
5+
# modular_vaultgemma.py file directly. One of our CI enforces this.
6+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7+
# coding=utf-8
8+
# Copyright 2025 the HuggingFace Team. All rights reserved.
9+
#
10+
# Licensed under the Apache License, Version 2.0 (the "License");
11+
# you may not use this file except in compliance with the License.
12+
# You may obtain a copy of the License at
13+
#
14+
# http://www.apache.org/licenses/LICENSE-2.0
15+
#
16+
# Unless required by applicable law or agreed to in writing, software
17+
# distributed under the License is distributed on an "AS IS" BASIS,
18+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19+
# See the License for the specific language governing permissions and
20+
# limitations under the License.
21+
22+
from ...configuration_utils import PretrainedConfig, layer_type_validation
23+
24+
25+
class VaultGemmaConfig(PretrainedConfig):
26+
r"""
27+
This is the configuration class to store the configuration of a [`VaultGemmaModel`]. It is used to instantiate an VaultGemma
28+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29+
defaults will yield a similar configuration to that of the VaultGemma-7B.
30+
e.g. [google/vaultgemma-7b](https://huggingface.co/google/vaultgemma-7b)
31+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32+
documentation from [`PretrainedConfig`] for more information.
33+
Args:
34+
vocab_size (`int`, *optional*, defaults to 256000):
35+
Vocabulary size of the VaultGemma model. Defines the number of different tokens that can be represented by the
36+
`inputs_ids` passed when calling [`VaultGemmaModel`]
37+
hidden_size (`int`, *optional*, defaults to 2304):
38+
Dimension of the hidden representations.
39+
intermediate_size (`int`, *optional*, defaults to 9216):
40+
Dimension of the MLP representations.
41+
num_hidden_layers (`int`, *optional*, defaults to 26):
42+
Number of hidden layers in the Transformer decoder.
43+
num_attention_heads (`int`, *optional*, defaults to 8):
44+
Number of attention heads for each attention layer in the Transformer decoder.
45+
num_key_value_heads (`int`, *optional*, defaults to 4):
46+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
47+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
48+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
49+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
50+
by meanpooling all the original heads within that group. For more details, check out [this
51+
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
52+
`num_attention_heads`.
53+
head_dim (`int`, *optional*, defaults to 256):
54+
The attention head dimension.
55+
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
56+
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
57+
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
58+
max_position_embeddings (`int`, *optional*, defaults to 8192):
59+
The maximum sequence length that this model might ever be used with.
60+
initializer_range (`float`, *optional*, defaults to 0.02):
61+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
63+
The epsilon used by the rms normalization layers.
64+
use_cache (`bool`, *optional*, defaults to `True`):
65+
Whether or not the model should return the last key/values attentions (not used by all models). Only
66+
relevant if `config.is_decoder=True`.
67+
pad_token_id (`int`, *optional*, defaults to 0):
68+
Padding token id.
69+
eos_token_id (`int`, *optional*, defaults to 1):
70+
End of stream token id.
71+
bos_token_id (`int`, *optional*, defaults to 2):
72+
Beginning of stream token id.
73+
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
74+
Whether to tie weight embeddings
75+
rope_theta (`float`, *optional*, defaults to 10000.0):
76+
The base period of the RoPE embeddings.
77+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
78+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
79+
attention_dropout (`float`, *optional*, defaults to 0.0):
80+
The dropout ratio for the attention probabilities.
81+
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
82+
scaling factor used on the attention scores
83+
sliding_window (`int`, *optional*, defaults to 4096):
84+
in VaultGemma, every other layer uses sliding window attention. This is the size of the sliding window.
85+
layer_types (`list`, *optional*):
86+
Attention pattern for each layer.
87+
final_logit_softcapping (`float`, *optional*, defaults to 30.0):
88+
scaling factor when applying tanh softcapping on the logits.
89+
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
90+
scaling factor when applying tanh softcapping on the attention scores.
91+
92+
```python
93+
>>> from transformers import VaultGemmaModel, VaultGemmaConfig
94+
>>> # Initializing a VaultGemma vaultgemma-7b style configuration
95+
>>> configuration = VaultGemmaConfig()
96+
>>> # Initializing a model from the vaultgemma-7b style configuration
97+
>>> model = VaultGemmaModel(configuration)
98+
>>> # Accessing the model configuration
99+
>>> configuration = model.config
100+
```"""
101+
102+
model_type = "vaultgemma"
103+
keys_to_ignore_at_inference = ["past_key_values"]
104+
base_model_tp_plan = {
105+
"layers.*.self_attn.q_proj": "colwise",
106+
"layers.*.self_attn.k_proj": "colwise",
107+
"layers.*.self_attn.v_proj": "colwise",
108+
"layers.*.self_attn.o_proj": "rowwise",
109+
"layers.*.mlp.gate_proj": "colwise",
110+
"layers.*.mlp.up_proj": "colwise",
111+
"layers.*.mlp.down_proj": "rowwise",
112+
}
113+
base_model_pp_plan = {
114+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
115+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
116+
"norm": (["hidden_states"], ["hidden_states"]),
117+
}
118+
119+
def __init__(
120+
self,
121+
vocab_size=256000,
122+
hidden_size=2304,
123+
intermediate_size=9216,
124+
num_hidden_layers=26,
125+
num_attention_heads=8,
126+
num_key_value_heads=4,
127+
head_dim=256,
128+
hidden_activation="gelu_pytorch_tanh",
129+
max_position_embeddings=8192,
130+
initializer_range=0.02,
131+
rms_norm_eps=1e-6,
132+
use_cache=True,
133+
pad_token_id=0,
134+
eos_token_id=1,
135+
bos_token_id=2,
136+
tie_word_embeddings=True,
137+
rope_theta=10000.0,
138+
attention_bias=False,
139+
attention_dropout=0.0,
140+
query_pre_attn_scalar=256,
141+
sliding_window=4096,
142+
layer_types=None,
143+
final_logit_softcapping=30.0,
144+
attn_logit_softcapping=50.0,
145+
**kwargs,
146+
):
147+
super().__init__(
148+
pad_token_id=pad_token_id,
149+
bos_token_id=bos_token_id,
150+
eos_token_id=eos_token_id,
151+
tie_word_embeddings=tie_word_embeddings,
152+
**kwargs,
153+
)
154+
self.vocab_size = vocab_size
155+
self.max_position_embeddings = max_position_embeddings
156+
self.hidden_size = hidden_size
157+
self.intermediate_size = intermediate_size
158+
self.num_hidden_layers = num_hidden_layers
159+
self.num_attention_heads = num_attention_heads
160+
self.head_dim = head_dim
161+
self.num_key_value_heads = num_key_value_heads
162+
self.initializer_range = initializer_range
163+
self.rms_norm_eps = rms_norm_eps
164+
self.use_cache = use_cache
165+
self.rope_theta = rope_theta
166+
self.attention_bias = attention_bias
167+
self.attention_dropout = attention_dropout
168+
self.hidden_activation = hidden_activation
169+
self.query_pre_attn_scalar = query_pre_attn_scalar
170+
self.sliding_window = sliding_window
171+
self.final_logit_softcapping = final_logit_softcapping
172+
self.attn_logit_softcapping = attn_logit_softcapping
173+
self.layer_types = layer_types
174+
175+
if self.layer_types is None:
176+
self.layer_types = [
177+
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
178+
]
179+
layer_type_validation(self.layer_types)
180+
181+
182+
__all__ = ["VaultGemmaConfig"]

0 commit comments

Comments
 (0)