@@ -89,6 +89,7 @@ def GraniteAttention_fast_forward(
     n_groups   = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
     head_dim   = self.head_dim
+    dropout_p  = self.config.attention_dropout if self.training else 0
     assert(n_kv_heads * n_groups == n_heads)
 
     Q, K, V = self.apply_qkv(self, hidden_states)
@@ -135,15 +136,15 @@ def GraniteAttention_fast_forward(
             Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
         pass
 
-        A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale = self.scaling)
+        A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale = self.scaling, p = dropout_p)
         A = A.view(bsz, q_len, n_heads, head_dim)
 
     elif HAS_FLASH_ATTENTION and attention_mask is None:
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
         V = V.transpose(1, 2)
         window = (kv_seq_len, kv_seq_len)
-        A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale = self.scaling)
+        A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale = self.scaling, dropout_p = dropout_p)
     else:
         # Grouped query attention
         # if n_groups != 1:
@@ -157,7 +158,7 @@ def GraniteAttention_fast_forward(
         Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
         # Needs (batch_size, n_heads, seq_len, head_dim)
         # is_casual and attention_mask must not be both set!
-        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False)
+        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False, dropout_p = dropout_p)
         # Go back to (batch_size, seq_len, n_heads, head_dim)
         A = A.transpose(1, 2).contiguous()
     pass
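
The change above follows one pattern in all three attention backends: read the dropout probability from the model config, zero it out when the module is not in training mode, and forward it to the attention kernel (xformers takes it as p, flash-attn and PyTorch SDPA as dropout_p). Below is a minimal, self-contained sketch of that pattern using only torch.nn.functional.scaled_dot_product_attention, the same fallback path touched in the last hunk. The TinyAttention class and its attention_dropout argument are illustrative stand-ins, not the actual Granite module or config, and the scale keyword assumes PyTorch 2.1 or newer.

# Sketch: train-only attention dropout, as in the diff above.
# Names here are illustrative; only the dropout gating pattern is the point.
import torch
import torch.nn.functional as F

class TinyAttention(torch.nn.Module):
    def __init__(self, head_dim: int, attention_dropout: float = 0.1):
        super().__init__()
        self.attention_dropout = attention_dropout
        self.scaling = head_dim ** -0.5

    def forward(self, Q, K, V, attention_mask = None):
        # Dropout is applied only while training; evaluation uses p = 0.
        dropout_p = self.attention_dropout if self.training else 0.0
        # Q, K, V are (batch_size, n_heads, seq_len, head_dim).
        return F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask  = attention_mask,
            dropout_p  = dropout_p,
            scale      = self.scaling,
            is_causal  = False,
        )

if __name__ == "__main__":
    attn = TinyAttention(head_dim = 64)
    q = k = v = torch.randn(1, 8, 16, 64)
    attn.train()
    out_train = attn(q, k, v)   # dropout active
    attn.eval()
    out_eval  = attn(q, k, v)   # deterministic, dropout_p = 0
    print(out_train.shape, out_eval.shape)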