@@ -292,7 +292,6 @@ at::Tensor mixtral_moe_woq_kernl_impl(
292292
293293template <typename T>
294294std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel (
295- const at::Tensor& hidden_states,
296295 const at::Tensor& scores,
297296 const at::Tensor& routed_scaling_factor,
298297 const int64_t n_group,
@@ -302,7 +301,7 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
302301 auto group_size = n_routed_experts / n_group;
303302 auto n = scores.size (0 );
304303 auto h = scores.size (1 );
305- auto group_scores = at::empty ({n, n_group}, hidden_states .options ());
304+ auto group_scores = at::empty ({n, n_group}, scores .options ());
306305 auto group_scores_ptr = group_scores.data_ptr <T>();
307306 auto scores_ptr = scores.data_ptr <T>();
308307#pragma omp parallel for collapse(2)
@@ -319,7 +318,7 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
319318 }
320319
321320 auto group_idx = std::get<1 >(group_scores.topk (topk_group, -1 , true , false ));
322- auto tmp_scores = at::zeros_like (scores, hidden_states .options ());
321+ auto tmp_scores = at::zeros_like (scores, scores .options ());
323322 auto group_idx_ptr = group_idx.data_ptr <int64_t >();
324323 auto tmp_scores_ptr = tmp_scores.data_ptr <T>();
325324 T scale = routed_scaling_factor.item <T>();
@@ -339,17 +338,117 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
339338 return std::make_tuple (topk, topk_weight);
340339}
341340
341+ template <typename T>
342+ std::tuple<at::Tensor, at::Tensor> deepseekv3_moegate_kernel (
343+ const at::Tensor& scores,
344+ const at::Tensor& routed_scaling_factor,
345+ const int64_t n_group,
346+ const int64_t topk_group,
347+ const int64_t n_routed_experts,
348+ const int64_t top_k,
349+ const at::Tensor& e_score_cbias) {
350+ auto group_size = n_routed_experts / n_group;
351+ auto n = scores.size (0 );
352+ auto h = scores.size (1 );
353+ auto scores_for_choice = at::empty ({n, n_group, group_size}, at::kFloat );
354+ auto scores_ptr = scores.data_ptr <T>();
355+ auto scores_for_choice_ptr = scores_for_choice.data_ptr <float >();
356+ auto scores_for_choice_stride0 = scores_for_choice.stride (0 );
357+ auto e_score_cbias_ptr = e_score_cbias.data_ptr <float >();
358+ #pragma omp parallel for collapse(2)
359+ for (auto i = 0 ; i < n; i++) {
360+ for (auto j = 0 ; j < n_group; j++) {
361+ auto k_start = j * group_size;
362+ auto k_end = k_start + group_size;
363+ for (auto k = k_start; k < k_end; k++) {
364+ scores_for_choice_ptr[i * scores_for_choice_stride0 + k] =
365+ scores_ptr[i * h + k] + e_score_cbias_ptr[k];
366+ }
367+ }
368+ }
369+ auto group_scores =
370+ std::get<0 >(scores_for_choice.topk (2 , -1 , true , false )).sum (-1 );
371+ auto group_idx = std::get<1 >(group_scores.topk (topk_group, -1 , true , false ));
372+ auto tmp_scores = at::zeros_like (scores, at::kFloat );
373+ auto group_idx_ptr = group_idx.data_ptr <int64_t >();
374+ auto tmp_scores_ptr = tmp_scores.data_ptr <float >();
375+ #pragma omp parallel for collapse(2)
376+ for (auto i = 0 ; i < n; i++) {
377+ for (auto j = 0 ; j < topk_group; j++) {
378+ auto selected_idx = group_idx_ptr[i * topk_group + j];
379+ auto k_start = selected_idx * group_size;
380+ auto k_end = k_start + group_size;
381+ for (auto k = k_start; k < k_end; k++) {
382+ tmp_scores_ptr[i * h + k] =
383+ scores_for_choice_ptr[i * scores_for_choice_stride0 + k];
384+ }
385+ }
386+ }
387+ auto topk = std::get<1 >(tmp_scores.topk (top_k, -1 , true , false ));
388+ auto topk_weight = scores.gather (1 , topk);
389+ return std::make_tuple (topk, topk_weight);
390+ }
391+
342392std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel_impl (
343393 const at::Tensor& hidden_states,
344394 const at::Tensor& scores,
345395 const at::Tensor& routed_scaling_factor,
346396 const int64_t n_group,
347397 const int64_t topk_group,
348398 const int64_t n_routed_experts,
349- const int64_t top_k) {
399+ const int64_t top_k,
400+ c10::optional<at::Tensor> e_score_cbias) {
401+ if (e_score_cbias.has_value ()) { // deepseekv3
402+ if (hidden_states.scalar_type () == at::ScalarType::Float) {
403+ return deepseekv3_moegate_kernel<float >(
404+ scores,
405+ routed_scaling_factor,
406+ n_group,
407+ topk_group,
408+ n_routed_experts,
409+ top_k,
410+ e_score_cbias.value ());
411+ } else if (hidden_states.scalar_type () == at::ScalarType::BFloat16) {
412+ return deepseekv3_moegate_kernel<at::BFloat16>(
413+ scores,
414+ routed_scaling_factor,
415+ n_group,
416+ topk_group,
417+ n_routed_experts,
418+ top_k,
419+ e_score_cbias.value ());
420+ } else if (hidden_states.scalar_type () == at::ScalarType::Half) {
421+ return deepseekv3_moegate_kernel<at::Half>(
422+ scores,
423+ routed_scaling_factor,
424+ n_group,
425+ topk_group,
426+ n_routed_experts,
427+ top_k,
428+ e_score_cbias.value ());
429+ }
430+ auto n = hidden_states.size (0 );
431+ auto group_size = n_routed_experts / n_group;
432+ auto scores_for_choice =
433+ scores.view ({n, -1 }) + e_score_cbias.value ().unsqueeze (0 );
434+ auto group_scores = std::get<0 >(
435+ scores_for_choice.view ({n, n_group, -1 }).topk (2 , -1 , true , false ));
436+ group_scores = group_scores.sum (-1 );
437+ auto group_idx =
438+ std::get<1 >(group_scores.topk (topk_group, -1 , true , false ));
439+ auto group_mask = at::zeros_like (group_scores);
440+ group_mask.scatter_ (1 , group_idx, 1 );
441+ auto score_mask = group_mask.unsqueeze (-1 )
442+ .expand ({n, n_group, group_size})
443+ .reshape ({n, -1 });
444+ auto tmp_scores =
445+ scores_for_choice.masked_fill (~score_mask.to (at::kBool ), 0.0 );
446+ auto topk = std::get<1 >(tmp_scores.topk (top_k, -1 , true , false ));
447+ auto topk_weight = scores.gather (1 , topk);
448+ return std::make_tuple (topk, topk_weight.to (hidden_states.scalar_type ()));
449+ }
350450 if (hidden_states.scalar_type () == at::ScalarType::Float) {
351451 return deepseek_moegate_kernel<float >(
352- hidden_states,
353452 scores,
354453 routed_scaling_factor,
355454 n_group,
@@ -358,7 +457,14 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel_impl(
358457 top_k);
359458 } else if (hidden_states.scalar_type () == at::ScalarType::BFloat16) {
360459 return deepseek_moegate_kernel<at::BFloat16>(
361- hidden_states,
460+ scores,
461+ routed_scaling_factor,
462+ n_group,
463+ topk_group,
464+ n_routed_experts,
465+ top_k);
466+ } else if (hidden_states.scalar_type () == at::ScalarType::Half) {
467+ return deepseek_moegate_kernel<at::Half>(
362468 scores,
363469 routed_scaling_factor,
364470 n_group,
0 commit comments