Commit 60b68e3

vasqu authored and Cyrilvallez committed
[Gemma Embedding] Fix SWA (#40700)
* fix gemma embedding flash attention
* fix sdpa
* fix attempt number 2
* alternative gemma fix
* fix modular
1 parent 87f38db commit 60b68e3

4 files changed: 11 additions, 5 deletions

src/transformers/models/gemma3/configuration_gemma3.py

Lines changed: 2 additions & 0 deletions
@@ -226,6 +226,8 @@ def __init__(
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types
         self.use_bidirectional_attention = use_bidirectional_attention
+        if use_bidirectional_attention:
+            self.sliding_window = (self.sliding_window // 2) + 1  # due to fa we set exclusive bounds
 
         self.rope_local_base_freq = rope_local_base_freq
         self.rope_scaling = rope_scaling
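The halving keeps the effective bidirectional window unchanged: flash attention treats the sliding window as an exclusive bound, so storing (sliding_window // 2) + 1 in the config and testing distance < sliding_window admits exactly the same tokens as the old inclusive check distance <= sliding_window // 2. A minimal sketch (not part of the commit; the helper names are illustrative) checks that equivalence numerically:

# Minimal sketch: the halved, exclusive window admits the same token
# distances as the original inclusive half-window check.
def old_allowed(distance: int, sliding_window: int) -> bool:
    # pre-commit behaviour: inclusive bound on half the window
    return distance <= sliding_window // 2

def new_allowed(distance: int, sliding_window: int) -> bool:
    # post-commit behaviour: config halves the window, mask uses an exclusive bound
    halved = (sliding_window // 2) + 1
    return distance < halved

for window in (2, 3, 8, 512, 4096):
    for distance in range(2 * window):
        assert old_allowed(distance, window) == new_allowed(distance, window)
print("old inclusive bound and new exclusive bound match")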

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 3 additions & 3 deletions
@@ -279,7 +279,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = config.query_pre_attn_scalar**-0.5
         self.attention_dropout = self.config.attention_dropout
-        self.is_causal = True
+        self.is_causal = not self.config.use_bidirectional_attention
 
         self.q_proj = nn.Linear(
             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -450,8 +450,8 @@ def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         """A token can attend to any other token if their absolute distance is within
-        half the sliding window size (distance <= sliding_window // 2)."""
-        return abs(q_idx - kv_idx) <= sliding_window // 2
+        the (exclusive) sliding window size (distance < sliding_window)."""
+        return abs(q_idx - kv_idx) < sliding_window
 
     return inner_mask
 
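The overlay returns a per-position predicate in the style of the transformers mask utilities. The sketch below copies the closure from the diff verbatim and adds an illustrative driver loop (not part of the commit) to show which key/value positions a query may attend to under the exclusive bound:

from typing import Callable

# Standalone copy of the overlay from the diff, for illustration only:
# a symmetric window with an exclusive bound on the absolute distance.
def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        """A token can attend to any other token if their absolute distance is within
        the (exclusive) sliding window size (distance < sliding_window)."""
        return abs(q_idx - kv_idx) < sliding_window

    return inner_mask

# Illustrative driver: with a (halved) window of 3, query position 5 sees
# two neighbours on each side plus itself.
mask = _bidirectional_window_overlay(sliding_window=3)
visible = [kv for kv in range(10) if mask(0, 0, 5, kv)]
print(visible)  # [3, 4, 5, 6, 7]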

src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 5 additions & 2 deletions
@@ -237,6 +237,8 @@ def __init__(
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types
         self.use_bidirectional_attention = use_bidirectional_attention
+        if use_bidirectional_attention:
+            self.sliding_window = (self.sliding_window // 2) + 1  # due to fa we set exclusive bounds
 
         self.rope_local_base_freq = rope_local_base_freq
         self.rope_scaling = rope_scaling
@@ -402,6 +404,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
 
         super().__init__(config, layer_idx)
         self.sliding_window = config.sliding_window if self.is_sliding else None
+        self.is_causal = not self.config.use_bidirectional_attention
 
         self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
         self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
@@ -546,8 +549,8 @@ def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        """A token can attend to any other token if their absolute distance is within
-        half the sliding window size (distance <= sliding_window // 2)."""
-        return abs(q_idx - kv_idx) <= sliding_window // 2
+        the (exclusive) sliding window size (distance < sliding_window)."""
+        return abs(q_idx - kv_idx) < sliding_window
 
     return inner_mask
 
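The modular file mirrors the config and attention changes above. As an extra illustration (a sketch assuming torch is available; not part of the commit), the same predicate can be materialised as a dense boolean mask to show the symmetric band the embedding model attends over:

import torch

# Illustrative only: materialise the bidirectional band mask for a short sequence.
# True means "may attend"; sliding_window here is the (already halved) config value.
def dense_bidirectional_mask(seq_len: int, sliding_window: int) -> torch.Tensor:
    idx = torch.arange(seq_len)
    distance = (idx[:, None] - idx[None, :]).abs()
    return distance < sliding_window  # exclusive bound, symmetric in both directions

print(dense_bidirectional_mask(seq_len=6, sliding_window=2).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1]])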

src/transformers/models/gemma3n/modular_gemma3n.py

Lines changed: 1 addition & 0 deletions
@@ -1744,6 +1744,7 @@ def apply_rotary_pos_emb(
 class Gemma3nTextAttention(Gemma3Attention):
     def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
         super().__init__(config, layer_idx)
+        self.is_causal = True
         del self.attn_logit_softcapping
         del self.scaling
         self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
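Gemma3nTextAttention inherits from Gemma3Attention, whose __init__ now derives is_causal from use_bidirectional_attention, so resetting self.is_causal = True after super().__init__ keeps Gemma3n strictly causal. A tiny sketch with hypothetical minimal classes (not the real modules) shows the override pattern:

# Hypothetical minimal classes showing the override pattern used in the diff:
# the child restores causal attention after the parent derives it from the config.
class ParentAttention:
    def __init__(self, use_bidirectional_attention: bool):
        self.is_causal = not use_bidirectional_attention

class ChildAttention(ParentAttention):
    def __init__(self, use_bidirectional_attention: bool):
        super().__init__(use_bidirectional_attention)
        self.is_causal = True  # always causal, regardless of the parent/config default

print(ChildAttention(use_bidirectional_attention=True).is_causal)  # True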
