@@ -73,14 +73,13 @@ def _cross_entropy_forward(
     mask = col_offsets < VOCAB_SIZE
 
     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
 
     # Do logit scaling for Cohere: t * x
     if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
     # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
-    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype)
-
-    logits = logits.to(tl.float32)
+    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
+
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
 
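Note: with the upcast folded into `tl.load(...).to(tl.float32)`, everything downstream (scaling, softcapping, the logsumexp reduction) runs in float32 without per-op casts. A rough PyTorch equivalent of one row of the forward pass, for reference only (the helper name and signature are illustrative, not from this patch):

    import torch

    def ce_forward_row(logits_row, label, logit_scale = None, softcap = None):
        # Mirrors the kernel: upcast once, then scale/softcap, then a
        # numerically stable logsumexp (subtract the row max before exp).
        x = logits_row.float()                     # matches .to(tl.float32) at load
        if logit_scale is not None:
            x = logit_scale * x                    # Cohere: t * x
        if softcap is not None:
            x = softcap * torch.tanh(x / softcap)  # Gemma 2: t * tanh(x / t)
        c = x.max()
        logsumexp = c + torch.log(torch.exp(x - c).sum())
        return logsumexp - x[label]                # loss = logsumexp - logit[label]

Subtracting the max `c` before `exp` keeps the sum finite even when scaled logits are large, which is why the kernel computes `c = tl.max(logits, 0)` first.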
@@ -152,14 +151,13 @@ def _chunked_cross_entropy_forward(
     mask = col_offsets < VOCAB_SIZE
 
     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
 
     # Do logit scaling for Cohere: t * x
     if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
     # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
-    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype)
+    if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
 
-    logits = logits.to(tl.float32)
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
 
@@ -229,7 +227,7 @@ def _cross_entropy_backward(
     else:
         dloss = 0.0
 
-    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
+    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
 
     # Do logit scaling for Cohere
     if DO_LOGIT_SCALING:
@@ -241,12 +239,12 @@ def _cross_entropy_backward(
     partial = x
     if DO_SOFTCAPPING:
         # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
-        partial = triton_tanh(x.to(tl.float32) / SOFTCAP).to(x.dtype)
+        partial = triton_tanh(x / SOFTCAP)
         x = SOFTCAP * partial
     pass
 
     logsumexp = tl.load(logsumexp_ptr + row_idx)
-    y = tl.exp(x.to(tl.float32) - logsumexp)
+    y = tl.exp(x - logsumexp)
     y = tl.where(
         col_offsets == label_idx,
         y - 1.0, # exp(x - logsumexp) - 1
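Note: `partial` is kept around because the softcapping derivative reuses the tanh value. A rough float32 reference for one row of the backward pass (illustrative helper, not from this patch; omits the `DO_LOGIT_SCALING` branch):

    import torch

    def ce_backward_row(logits_row, label, logsumexp, dloss, softcap = None):
        x = logits_row.float()                      # same upcast-at-load as the kernel
        partial = None
        if softcap is not None:
            partial = torch.tanh(x / softcap)       # saved for the chain rule below
            x = softcap * partial
        y = torch.exp(x - logsumexp)                # softmax(x) via the saved logsumexp
        y[label] -= 1.0                             # dCE/dx = softmax(x) - one_hot(label)
        if softcap is not None:
            y = y * (1.0 - partial * partial)       # d/dx [t*tanh(x/t)] = 1 - tanh^2(x/t)
        return dloss * y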
@@ -337,6 +335,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
         return losses
     pass
 
+
     @staticmethod
     def backward(ctx, dlosses):
         logits, logsumexp, labels = ctx.saved_tensors
@@ -345,6 +344,8 @@ def backward(ctx, dlosses):
         n_rows, vocab_size = logits.shape
 
         BLOCK_SIZE : int = 4096
+        div : int
+        mod : int
         div, mod = divmod(vocab_size, BLOCK_SIZE)
         n_blocks : int = div + (mod != 0)
 
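Note: the new annotations only pre-declare `div` and `mod`; the block count itself is a ceiling division over the vocabulary. A standalone sanity check of that arithmetic (vocab sizes here are assumed examples, not from the patch):

    BLOCK_SIZE = 4096
    for vocab_size in (32000, 128256, 4096):
        div, mod = divmod(vocab_size, BLOCK_SIZE)
        n_blocks = div + (mod != 0)                  # ceil(vocab_size / BLOCK_SIZE)
        assert n_blocks == -(-vocab_size // BLOCK_SIZE)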