 from transformers.models.llama.modeling_llama import logger


+@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
 @triton.jit
 def _cross_entropy_forward(
     logits_ptr, logits_row_stride,
     loss_ptr,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    VOCAB_SIZE     : tl.constexpr,
+    BLOCK_SIZE     : tl.constexpr,
+    DO_SOFTCAPPING : tl.constexpr,
+    SOFTCAP        : tl.constexpr,
 ):
     """
         Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
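A note on the `@triton.heuristics` pattern added above: the lambda simply echoes the runtime `DO_SOFTCAPPING` value back as a `tl.constexpr` meta-parameter, so Triton specializes and caches a separate kernel per value and the `if DO_SOFTCAPPING:` branches below are resolved at compile time. Below is a minimal standalone sketch of that idiom, not taken from this commit; the kernel name and the `IS_EVEN` heuristic are made up for illustration.

```python
import triton
import triton.language as tl

# Hypothetical demo of the heuristics idiom: IS_EVEN is derived from the
# runtime argument `n` and arrives in the kernel as a tl.constexpr, so the
# branch on it costs nothing at runtime.
@triton.heuristics({"IS_EVEN": lambda args: args["n"] % 2 == 0})
@triton.jit
def _scale_if_even(x_ptr, n, IS_EVEN: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    if IS_EVEN:  # compile-time branch, eliminated when False
        x = x * 2.0
    tl.store(x_ptr + offsets, x, mask=mask)
```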
@@ -58,29 +61,38 @@ def _cross_entropy_forward(
     mask = col_offsets < VOCAB_SIZE

     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+    if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)
+
+    logits = logits.to(tl.float32)
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

     if label_idx != -100:
-        x = tl.load(logits_ptr + label_idx).to(tl.float32)
-        loss = logsumexp - x
+        x = tl.load(logits_ptr + label_idx)
+        # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+        if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
+        loss = logsumexp - x.to(tl.float32)
     else:
         loss = 0.0
     tl.store(logsumexp_ptr, logsumexp)
     tl.store(loss_ptr, loss)
 pass


+@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
 @triton.jit
 def _chunked_cross_entropy_forward(
     logits_ptr, logits_row_stride,
     loss_ptr,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE : tl.constexpr,
-    N_CHUNKS : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    VOCAB_SIZE     : tl.constexpr,
+    N_CHUNKS       : tl.constexpr,
+    BLOCK_SIZE     : tl.constexpr,
+    DO_SOFTCAPPING : tl.constexpr,
+    SOFTCAP        : tl.constexpr,
 ):
     """
         256K vocab divided in 4 chunks
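For reference, here is a plain PyTorch sketch of what the forward kernel above computes per row once softcapping is folded in: loss_i = logsumexp(cap(logits_i)) - cap(logits_i[label_i]), with -100 labels producing zero loss. This is my own eager-mode reference, not part of the commit, and `reference_softcapped_ce` is a hypothetical helper name.

```python
import torch

def reference_softcapped_ce(logits: torch.Tensor, labels: torch.Tensor, softcap: float = 0.0):
    """Eager reference for the fused kernel: per-row capped cross entropy and logsumexp."""
    if softcap != 0:
        logits = softcap * torch.tanh(logits / softcap)   # t * tanh(x / t), as in the kernel
    logits = logits.float()
    logsumexp = torch.logsumexp(logits, dim=-1)
    picked = logits.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    loss = torch.where(labels != -100, logsumexp - picked, torch.zeros_like(logsumexp))
    return loss, logsumexp
```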
@@ -117,7 +129,11 @@ def _chunked_cross_entropy_forward(
     mask = col_offsets < VOCAB_SIZE

     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+    if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)
+
+    logits = logits.to(tl.float32)
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

@@ -126,7 +142,9 @@ def _chunked_cross_entropy_forward(
     # Do the -x separately
     if label_idx != -100:
         x = tl.load(logits_ptr + label_idx).to(tl.float32)
-        loss = -1.0 * x
+        # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+        if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
+        loss = -1.0 * x.to(tl.float32)
     else:
         loss = 0.0
     tl.store(loss_ptr, loss)
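The chunked path works because log-sum-exp composes: the full-row logsumexp equals the logsumexp of the per-chunk logsumexps, which is exactly what the wrapper later reduces ("logsumexp(chunked_logsumexp) - x"). A quick sanity check of that identity, my own sketch with arbitrary sizes:

```python
import torch

logits = torch.randn(4, 8192)  # 4 rows, hypothetical vocab split into 4 chunks
per_chunk = torch.stack([c.logsumexp(-1) for c in logits.chunk(4, dim=-1)], dim=-1)
# logsumexp over chunk logsumexps == logsumexp over the whole row
assert torch.allclose(per_chunk.logsumexp(-1), logits.logsumexp(-1), atol=1e-5)
```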
@@ -135,14 +153,17 @@ def _chunked_cross_entropy_forward(
 pass


+@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
 @triton.jit
 def _cross_entropy_backward(
     logits_ptr, logits_row_stride,
     dloss_ptr, dloss_row_stride,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    VOCAB_SIZE     : tl.constexpr,
+    BLOCK_SIZE     : tl.constexpr,
+    DO_SOFTCAPPING : tl.constexpr,
+    SOFTCAP        : tl.constexpr,
 ):
     """
         CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
@@ -173,15 +194,27 @@ def _cross_entropy_backward(
     else:
         dloss = 0.0

-    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
+    if DO_SOFTCAPPING:
+        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
+        partial = tl.math.tanh(x / SOFTCAP)
+        x = SOFTCAP * partial
+    pass
+
     logsumexp = tl.load(logsumexp_ptr + row_idx)
-    y = tl.exp(x - logsumexp)
+    y = tl.exp(x.to(tl.float32) - logsumexp)
     y = tl.where(
         col_offsets == label_idx,
         y - 1.0, # exp(x - logsumexp) - 1
         y,       # exp(x - logsumexp)
     )

+    if DO_SOFTCAPPING:
+        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
+        y = y * (1.0 - partial * partial)
+    pass
+
     # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
     tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
 pass
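The extra `(1.0 - partial * partial)` factor above is the chain rule through the softcap: with z = t * tanh(x / t), dz/dx = 1 - tanh^2(x / t), so the usual softmax-minus-one-hot gradient is scaled elementwise by that term. A small autograd check of the claim, my own sketch with arbitrary shapes and an arbitrary cap value:

```python
import torch
import torch.nn.functional as F

t = 30.0                                              # hypothetical softcap value
x = torch.randn(5, 11, dtype=torch.float64, requires_grad=True)
y = torch.randint(0, 11, (5,))

z = t * torch.tanh(x / t)                             # softcapped logits
F.cross_entropy(z, y, reduction="sum").backward()

# (softmax(z) - onehot(y)) * d/dx[t * tanh(x/t)] should match autograd
manual = (torch.softmax(z, -1) - F.one_hot(y, 11)) * (1.0 - torch.tanh(x / t) ** 2)
assert torch.allclose(x.grad, manual.detach(), atol=1e-8)
```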
@@ -191,40 +224,46 @@ def _cross_entropy_backward(

 class Fast_CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, logits, labels):
+    def forward(ctx, logits, labels, logit_softcapping = 0):
         n_rows, vocab_size = logits.shape

         div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
         n_chunks = div + (mod != 0)
-        losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
+        losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
+
+        DO_SOFTCAPPING = (logit_softcapping != 0)

         if n_chunks == 1:
             # For small vocabs <= 65536 like Llama, Mistral
             BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
-            logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
+            logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

             _cross_entropy_forward[(n_rows,)](
                 logits, logits.stride(0),
                 losses,
                 logsumexp,
                 labels,
-                VOCAB_SIZE = vocab_size,
-                BLOCK_SIZE = BLOCK_SIZE,
-                num_warps = num_warps,
+                VOCAB_SIZE     = vocab_size,
+                BLOCK_SIZE     = BLOCK_SIZE,
+                DO_SOFTCAPPING = DO_SOFTCAPPING,
+                SOFTCAP        = logit_softcapping,
+                num_warps      = num_warps,
             )
         else:
             # For large vocabs > 65536 like Gemma 256K
-            logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda")
+            logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")

             _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
                 logits, logits.stride(0),
                 losses,
                 logsumexp,
                 labels,
-                VOCAB_SIZE = vocab_size,
-                N_CHUNKS = n_chunks,
-                BLOCK_SIZE = MAX_FUSED_SIZE,
-                num_warps = 32,
+                VOCAB_SIZE     = vocab_size,
+                N_CHUNKS       = n_chunks,
+                BLOCK_SIZE     = MAX_FUSED_SIZE,
+                DO_SOFTCAPPING = DO_SOFTCAPPING,
+                SOFTCAP        = logit_softcapping,
+                num_warps      = 32,
             )
         # logsumexp(chunked_logsumexp) - x
         # Do the -x separately
@@ -234,6 +273,8 @@ def forward(ctx, logits, labels):
         pass

         ctx.save_for_backward(logits, logsumexp, labels)
+        ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
+        ctx.logit_softcapping = logit_softcapping
         return losses
     pass

@@ -251,16 +292,18 @@ def backward(ctx, dlosses):
             dlosses, dlosses.stride(0),
             logsumexp,
             labels,
-            VOCAB_SIZE = vocab_size,
-            BLOCK_SIZE = BLOCK_SIZE,
-            num_warps = 8,
+            VOCAB_SIZE     = vocab_size,
+            BLOCK_SIZE     = BLOCK_SIZE,
+            DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
+            SOFTCAP        = ctx.logit_softcapping,
+            num_warps      = 8,
         )
         return logits, None, None,
     pass
 pass


-def fast_cross_entropy_loss(logits, labels):
+def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
     """
     Arguments:
         logits: (batch, seq_len, vocab_size)
@@ -274,6 +317,7 @@ def fast_cross_entropy_loss(logits, labels):
     loss = Fast_CrossEntropyLoss.apply(
         logits.view(batch*seq_len, d),
         labels.view(-1),
+        logit_softcapping,
     )
     n_items = torch.count_nonzero(labels != -100)
     return loss.sum() / n_items
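Callers keep the old behaviour by default, since `logit_softcapping = 0` makes `DO_SOFTCAPPING` false and the capped branches compile away. A Gemma 2 style model would thread its cap through instead. Below is a hypothetical call-site sketch: `model`, `shift_logits`, and `shift_labels` are placeholders, and `final_logit_softcapping` is assumed to be the config attribute carrying the cap (as in Gemma 2 configs in transformers).

```python
# Hypothetical wiring: fall back to 0 (no capping) when the config has no cap.
logit_softcapping = getattr(model.config, "final_logit_softcapping", None) or 0

loss = fast_cross_entropy_loss(
    shift_logits,                       # (batch, seq_len, vocab_size)
    shift_labels,                       # (batch, seq_len), -100 marks ignored positions
    logit_softcapping = logit_softcapping,
)
```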