Commit ca66926

unification cpu code for release/2.6.10+xpu_rc round 2 (#5316)(#5345)
* sync csrc/cpu.
* sync cmake/cpu
* sync other cmake changes.
* sync test/cpu.
* sync examples/cpu.
* frontend sync with IPEX CPU 2.6 (#5326)
* fix test_dyndisp.py

Co-authored-by: Han, Xu <xu.han@intel.com>
1 parent 0dc3a7b commit ca66926

File tree: 33 files changed (+629, -871 lines)

cmake/Modules/FindoneMKL.cmake

Lines changed: 4 additions & 0 deletions
@@ -72,6 +72,10 @@ endfunction()
 # IPEX CPU lib always download and install mkl-static lib and use static linker for mkl-static lib.
 # IPEX CPU lib can manual config to use the dynamic link for oneMKL lib.
 if(BUILD_MODULE_TYPE STREQUAL "GPU")
+  set(USE_SYSTEM_MKL ON)
+endif()
+
+if(USE_SYSTEM_MKL)
   get_mkl_from_env_var()
 else()
   if(BUILD_WITH_XPU)

cmake/cpu/Options.cmake

Lines changed: 7 additions & 0 deletions
@@ -7,6 +7,13 @@ set(Options_CPU_cmake_included true)
 # The options to build cpu
 include(CMakeDependentOption)
 
+option(USE_SYSTEM_LIBXSMM "Use system LIBXSMM library" OFF)
+option(USE_SYSTEM_ONEDNN "Use system oneDNN library" OFF)
+option(USE_SYSTEM_SLEEF "Use system SLEEF library" OFF)
+option(USE_SYSTEM_MKL "Use system MKL library" OFF)
+option(USE_SYSTEM_IDEEP "Use system ideep library" OFF)
+option(USE_SYSTEM_GTEST "Use system GoogleTest library" OFF)
+
 option(BUILD_LIBXSMM_VIA_CMAKE "Build LIBXSMM via CMake" ON)
 option(USE_LIBXSMM "Enable LIBXSMM" ON)
 option(USE_DNNL_GRAPH_COMPILER "Build with DNNL Graph Compiler" ON)

csrc/cpu/CMakeLists.txt

Lines changed: 52 additions & 22 deletions
@@ -37,8 +37,20 @@ if((DEFINED ENV{DNNL_GRAPH_BUILD_COMPILER_BACKEND}) AND USE_DNNL_GRAPH_COMPILER)
 endif()
 
 set(THIRD_PARTY_BUILD_PATH_NAME "cpu_third_party")
-add_subdirectory(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn ${THIRD_PARTY_BUILD_PATH_NAME}/ideep/mkl-dnn EXCLUDE_FROM_ALL)
-# add_subdirectory(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/mkl-dnn cpu_third_party/mkl-dnn)
+if(USE_SYSTEM_ONEDNN)
+  find_package(dnnl 3.4.1 CONFIG REQUIRED)
+  get_target_property(ONEDNN_INCLUDE_DIR DNNL::dnnl INTERFACE_INCLUDE_DIRECTORIES)
+  set(ONEDNN_LIBRARY DNNL::dnnl)
+  set(ONEDNN_GENERATED_INCLUDE ${ONEDNN_INCLUDE_DIR})
+else()
+  add_subdirectory(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn ${THIRD_PARTY_BUILD_PATH_NAME}/ideep/mkl-dnn EXCLUDE_FROM_ALL)
+  set(ONEDNN_INCLUDE_DIR ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn/include)
+  set(ONEDNN_LIBRARY dnnl)
+
+  # path of oneDNN .h.in generated file
+  file(RELATIVE_PATH CUR_DIR_REL_PATH "${IPEX_ROOT_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}")
+  set(ONEDNN_GENERATED_INCLUDE "${CMAKE_BINARY_DIR}/${CUR_DIR_REL_PATH}/${THIRD_PARTY_BUILD_PATH_NAME}/ideep/mkl-dnn/include")
+endif()
 
 IF(IPEX_DISP_OP)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIPEX_DISP_OP")
@@ -120,7 +132,7 @@ add_library(${PLUGIN_NAME_CPU} SHARED ${IPEX_CPU_CPP_SRCS})
 # For IPEX_API macro
 target_compile_definitions(${PLUGIN_NAME_CPU} PUBLIC "BUILD_IPEX_MAIN_LIB")
 
-set_target_properties(${PLUGIN_NAME_CPU} PROPERTIES ONEDNN_INCLUDE_DIR "${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn/include")
+set_target_properties(${PLUGIN_NAME_CPU} PROPERTIES ONEDNN_INCLUDE_DIR ${ONEDNN_INCLUDE_DIR})
 
 # includes
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_ROOT_DIR})
@@ -133,19 +145,21 @@ target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR}/jit)
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_JIT_CPP_ROOT})
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_UTLIS_CPP_ROOT})
 
-target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn/include)
+target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${ONEDNN_INCLUDE_DIR})
 
 if(USE_LIBXSMM)
   target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR}/tpp)
-  target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/include)
+  target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${LIBXSMM_INCLUDE_DIRS})
 endif(USE_LIBXSMM)
 
-# path of oneDNN .h.in generated file
-file(RELATIVE_PATH CUR_DIR_REL_PATH "${IPEX_ROOT_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}")
-set(ONEDNN_GENERATED_INCLUDE "${CMAKE_BINARY_DIR}/${CUR_DIR_REL_PATH}/${THIRD_PARTY_BUILD_PATH_NAME}/ideep/mkl-dnn/include")
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${ONEDNN_GENERATED_INCLUDE})
 
-target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/include)
+if(USE_SYSTEM_IDEEP)
+  find_path(IDEEP_INCLUDE_DIR ideep.hpp REQUIRED)
+else()
+  set(IDEEP_INCLUDE_DIR ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/include)
+endif()
+target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IDEEP_INCLUDE_DIR})
 target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${PYTHON_INCLUDE_DIR})
 
 if(BUILD_CPU_WITH_ONECCL)
@@ -165,12 +179,17 @@ if(CLANG_FORMAT)
 endif()
 
 if(USE_LIBXSMM)
-  if(BUILD_LIBXSMM_VIA_CMAKE)
+  if(USE_SYSTEM_LIBXSMM)
+    find_package(PkgConfig REQUIRED)
+    pkg_check_modules(LIBXSMM REQUIRED libxsmm)
+    target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${LIBXSMM_INCLUDE_DIRS})
+    target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE ${LIBXSMM_LIBRARIES})
+  elseif(BUILD_LIBXSMM_VIA_CMAKE)
     add_subdirectory(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm cpu_third_party/libxsmm EXCLUDE_FROM_ALL)
     add_definitions(-DLIBXSMM_DEFAULT_CONFIG)
-    target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/include)
+    set(LIBXSMM_INCLUDE_DIRS ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/include)
     target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE xsmm)
-  else(BUILD_LIBXSMM_VIA_CMAKE)
+  else()
    include(${CMAKE_ROOT}/Modules/ExternalProject.cmake)
    set(args
      CC=${CMAKE_C_COMPILER}
@@ -188,20 +207,31 @@ if(USE_LIBXSMM)
      ${args}
      INSTALL_COMMAND ""
    )
+    set(LIBXSMM_INCLUDE_DIRS ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/include)
    target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/lib/libxsmm.a)
   endif(BUILD_LIBXSMM_VIA_CMAKE)
 endif(USE_LIBXSMM)
 
-# setup sleef options:
-set(SLEEF_BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef as static library" FORCE)
-set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
-set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
-set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
-set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will be built." FORCE)
-add_subdirectory(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/sleef ${THIRD_PARTY_BUILD_PATH_NAME}/sleef EXCLUDE_FROM_ALL)
-target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE sleef)
+if(USE_SYSTEM_SLEEF)
+  find_package(PkgConfig REQUIRED)
+  pkg_check_modules(SLEEF REQUIRED sleef)
+  target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${SLEEF_INCLUDE_DIRS})
+  target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE ${SLEEF_LIBRARIES})
+else()
+  # setup sleef options:
+  set(SLEEF_BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef as static library" FORCE)
+  set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
+  set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
+  set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
+  set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will be built." FORCE)
+  add_subdirectory(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/sleef ${THIRD_PARTY_BUILD_PATH_NAME}/sleef EXCLUDE_FROM_ALL)
+  target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE sleef)
+endif()
+
+if(NOT USE_SYSTEM_ONEDNN)
+  add_dependencies(${PLUGIN_NAME_CPU} dnnl)
+endif()
 
-add_dependencies(${PLUGIN_NAME_CPU} dnnl)
 # If Graph Compiler is built, then it should link to its LLVM dependencies,
 # and not the LLVM symbols exposed by PyTorch.
 if ((DEFINED ENV{DNNL_GRAPH_BUILD_COMPILER_BACKEND}) AND USE_DNNL_GRAPH_COMPILER)
@@ -213,7 +243,7 @@ if ((DEFINED ENV{DNNL_GRAPH_BUILD_COMPILER_BACKEND}) AND USE_DNNL_GRAPH_COMPILER
    set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs=${DNNL_GRAPHCOMPILER_LLVM_LIB_EXCLUDE}")
   endif()
 else()
-  target_link_libraries(${PLUGIN_NAME_CPU} PUBLIC dnnl)
+  target_link_libraries(${PLUGIN_NAME_CPU} PUBLIC ${ONEDNN_LIBRARY})
 endif()
 find_package(oneMKL QUIET)
 if (ONEMKL_FOUND)

csrc/cpu/aten/MoE.cpp

Lines changed: 5 additions & 3 deletions
@@ -298,7 +298,8 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
     const int64_t n_group,
     const int64_t topk_group,
     const int64_t n_routed_experts,
-    const int64_t top_k) {
+    const int64_t top_k,
+    c10::optional<at::Tensor> e_score_cbias) {
   RECORD_FUNCTION("ipex::deepseek_moegate", c10::ArrayRef<c10::IValue>({}));
 
   return deepseek_moegate_kernel_stub(
@@ -309,7 +310,8 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
       n_group,
       topk_group,
       n_routed_experts,
-      top_k);
+      top_k,
+      e_score_cbias);
 }
 } // namespace cpu
 } // namespace torch_ipex
@@ -374,7 +376,7 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
      c10::DispatchKey::CPU,
      torch_ipex::cpu::deepseek_moe_woq);
  m.def(
-      "deepseek_moegate(Tensor hidden_states, Tensor scores, Tensor routed_scaling_factor, int n_group, int topk_group, int n_routed_experts, int top_k) -> (Tensor, Tensor)");
+      "deepseek_moegate(Tensor hidden_states, Tensor scores, Tensor routed_scaling_factor, int n_group, int topk_group, int n_routed_experts, int top_k, Tensor? e_score_cbias=None) -> (Tensor, Tensor)");
  m.impl(
      "deepseek_moegate",
      c10::DispatchKey::CPU,
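
Note: the updated schema keeps e_score_cbias optional (Tensor? e_score_cbias=None), so existing seven-argument callers continue to work. The sketch below is not part of the commit; the include path, tensor shapes, and values are illustrative assumptions about how a C++ caller might exercise both the old and the new paths.

// Hypothetical call-site sketch; the "aten/MoE.h" include path and all shapes are assumptions.
#include <ATen/ATen.h>
#include "aten/MoE.h"

std::tuple<at::Tensor, at::Tensor> moegate_example() {
  const int64_t n_tokens = 4, hidden = 1024;
  const int64_t n_routed_experts = 64, n_group = 8, topk_group = 4, top_k = 6;
  auto hidden_states = at::randn({n_tokens, hidden});
  auto scores = at::rand({n_tokens, n_routed_experts});
  auto routed_scaling_factor = at::tensor(2.5f);

  // DeepSeek-V2 style gating: no expert-score correction bias.
  auto v2 = torch_ipex::cpu::deepseek_moegate(
      hidden_states, scores, routed_scaling_factor,
      n_group, topk_group, n_routed_experts, top_k,
      c10::nullopt);
  (void)v2;

  // DeepSeek-V3 style gating: pass the per-expert correction bias.
  auto e_score_cbias = at::zeros({n_routed_experts});
  return torch_ipex::cpu::deepseek_moegate(
      hidden_states, scores, routed_scaling_factor,
      n_group, topk_group, n_routed_experts, top_k,
      e_score_cbias);
}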

csrc/cpu/aten/MoE.h

Lines changed: 4 additions & 2 deletions
@@ -97,7 +97,8 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate(
     const int64_t n_group,
     const int64_t topk_group,
     const int64_t n_routed_experts,
-    const int64_t top_k);
+    const int64_t top_k,
+    c10::optional<at::Tensor> e_score_cbias);
 using mixtral_moe_tpp_kernel_fn = at::Tensor (*)(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
@@ -179,7 +180,8 @@ using deepseek_moegate_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
     const int64_t n_group,
     const int64_t topk_group,
     const int64_t n_routed_experts,
-    const int64_t top_k);
+    const int64_t top_k,
+    c10::optional<at::Tensor> e_score_cbias);
 IPEX_DECLARE_DISPATCH(mixtral_moe_tpp_kernel_fn, mixtral_moe_tpp_kernel_stub);
 IPEX_DECLARE_DISPATCH(deepseek_moe_tpp_kernel_fn, deepseek_moe_tpp_kernel_stub);
 IPEX_DECLARE_DISPATCH(mixtral_moe_woq_kernel_fn, mixtral_moe_woq_kernel_stub);

csrc/cpu/aten/kernels/MoEKrnl.cpp

Lines changed: 112 additions & 6 deletions
@@ -292,7 +292,6 @@ at::Tensor mixtral_moe_woq_kernl_impl(
 
 template <typename T>
 std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
-    const at::Tensor& hidden_states,
     const at::Tensor& scores,
     const at::Tensor& routed_scaling_factor,
     const int64_t n_group,
@@ -302,7 +301,7 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
   auto group_size = n_routed_experts / n_group;
   auto n = scores.size(0);
   auto h = scores.size(1);
-  auto group_scores = at::empty({n, n_group}, hidden_states.options());
+  auto group_scores = at::empty({n, n_group}, scores.options());
   auto group_scores_ptr = group_scores.data_ptr<T>();
   auto scores_ptr = scores.data_ptr<T>();
 #pragma omp parallel for collapse(2)
@@ -319,7 +318,7 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
   }
 
   auto group_idx = std::get<1>(group_scores.topk(topk_group, -1, true, false));
-  auto tmp_scores = at::zeros_like(scores, hidden_states.options());
+  auto tmp_scores = at::zeros_like(scores, scores.options());
   auto group_idx_ptr = group_idx.data_ptr<int64_t>();
   auto tmp_scores_ptr = tmp_scores.data_ptr<T>();
   T scale = routed_scaling_factor.item<T>();
@@ -339,17 +338,117 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel(
   return std::make_tuple(topk, topk_weight);
 }
 
+template <typename T>
+std::tuple<at::Tensor, at::Tensor> deepseekv3_moegate_kernel(
+    const at::Tensor& scores,
+    const at::Tensor& routed_scaling_factor,
+    const int64_t n_group,
+    const int64_t topk_group,
+    const int64_t n_routed_experts,
+    const int64_t top_k,
+    const at::Tensor& e_score_cbias) {
+  auto group_size = n_routed_experts / n_group;
+  auto n = scores.size(0);
+  auto h = scores.size(1);
+  auto scores_for_choice = at::empty({n, n_group, group_size}, at::kFloat);
+  auto scores_ptr = scores.data_ptr<T>();
+  auto scores_for_choice_ptr = scores_for_choice.data_ptr<float>();
+  auto scores_for_choice_stride0 = scores_for_choice.stride(0);
+  auto e_score_cbias_ptr = e_score_cbias.data_ptr<float>();
+#pragma omp parallel for collapse(2)
+  for (auto i = 0; i < n; i++) {
+    for (auto j = 0; j < n_group; j++) {
+      auto k_start = j * group_size;
+      auto k_end = k_start + group_size;
+      for (auto k = k_start; k < k_end; k++) {
+        scores_for_choice_ptr[i * scores_for_choice_stride0 + k] =
+            scores_ptr[i * h + k] + e_score_cbias_ptr[k];
+      }
+    }
+  }
+  auto group_scores =
+      std::get<0>(scores_for_choice.topk(2, -1, true, false)).sum(-1);
+  auto group_idx = std::get<1>(group_scores.topk(topk_group, -1, true, false));
+  auto tmp_scores = at::zeros_like(scores, at::kFloat);
+  auto group_idx_ptr = group_idx.data_ptr<int64_t>();
+  auto tmp_scores_ptr = tmp_scores.data_ptr<float>();
+#pragma omp parallel for collapse(2)
+  for (auto i = 0; i < n; i++) {
+    for (auto j = 0; j < topk_group; j++) {
+      auto selected_idx = group_idx_ptr[i * topk_group + j];
+      auto k_start = selected_idx * group_size;
+      auto k_end = k_start + group_size;
+      for (auto k = k_start; k < k_end; k++) {
+        tmp_scores_ptr[i * h + k] =
+            scores_for_choice_ptr[i * scores_for_choice_stride0 + k];
+      }
+    }
+  }
+  auto topk = std::get<1>(tmp_scores.topk(top_k, -1, true, false));
+  auto topk_weight = scores.gather(1, topk);
+  return std::make_tuple(topk, topk_weight);
+}
+
 std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel_impl(
     const at::Tensor& hidden_states,
     const at::Tensor& scores,
     const at::Tensor& routed_scaling_factor,
     const int64_t n_group,
     const int64_t topk_group,
     const int64_t n_routed_experts,
-    const int64_t top_k) {
+    const int64_t top_k,
+    c10::optional<at::Tensor> e_score_cbias) {
+  if (e_score_cbias.has_value()) { // deepseekv3
+    if (hidden_states.scalar_type() == at::ScalarType::Float) {
+      return deepseekv3_moegate_kernel<float>(
+          scores,
+          routed_scaling_factor,
+          n_group,
+          topk_group,
+          n_routed_experts,
+          top_k,
+          e_score_cbias.value());
+    } else if (hidden_states.scalar_type() == at::ScalarType::BFloat16) {
+      return deepseekv3_moegate_kernel<at::BFloat16>(
+          scores,
+          routed_scaling_factor,
+          n_group,
+          topk_group,
+          n_routed_experts,
+          top_k,
+          e_score_cbias.value());
+    } else if (hidden_states.scalar_type() == at::ScalarType::Half) {
+      return deepseekv3_moegate_kernel<at::Half>(
+          scores,
+          routed_scaling_factor,
+          n_group,
+          topk_group,
+          n_routed_experts,
+          top_k,
+          e_score_cbias.value());
+    }
+    auto n = hidden_states.size(0);
+    auto group_size = n_routed_experts / n_group;
+    auto scores_for_choice =
+        scores.view({n, -1}) + e_score_cbias.value().unsqueeze(0);
+    auto group_scores = std::get<0>(
+        scores_for_choice.view({n, n_group, -1}).topk(2, -1, true, false));
+    group_scores = group_scores.sum(-1);
+    auto group_idx =
+        std::get<1>(group_scores.topk(topk_group, -1, true, false));
+    auto group_mask = at::zeros_like(group_scores);
+    group_mask.scatter_(1, group_idx, 1);
+    auto score_mask = group_mask.unsqueeze(-1)
+                          .expand({n, n_group, group_size})
+                          .reshape({n, -1});
+    auto tmp_scores =
+        scores_for_choice.masked_fill(~score_mask.to(at::kBool), 0.0);
+    auto topk = std::get<1>(tmp_scores.topk(top_k, -1, true, false));
+    auto topk_weight = scores.gather(1, topk);
+    return std::make_tuple(topk, topk_weight.to(hidden_states.scalar_type()));
+  }
   if (hidden_states.scalar_type() == at::ScalarType::Float) {
     return deepseek_moegate_kernel<float>(
-        hidden_states,
         scores,
         routed_scaling_factor,
         n_group,
@@ -358,7 +457,14 @@ std::tuple<at::Tensor, at::Tensor> deepseek_moegate_kernel_impl(
         top_k);
   } else if (hidden_states.scalar_type() == at::ScalarType::BFloat16) {
     return deepseek_moegate_kernel<at::BFloat16>(
-        hidden_states,
+        scores,
+        routed_scaling_factor,
+        n_group,
+        topk_group,
+        n_routed_experts,
+        top_k);
+  } else if (hidden_states.scalar_type() == at::ScalarType::Half) {
+    return deepseek_moegate_kernel<at::Half>(
         scores,
         routed_scaling_factor,
         n_group,
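
For reference, the new deepseekv3_moegate_kernel realizes DeepSeek-V3 group-limited gating: expert scores are biased by e_score_cbias for selection only, each expert group is ranked by the sum of its two largest biased scores, only the topk_group best groups stay eligible, and the final top_k experts are drawn from those groups while the returned weights are gathered from the original, un-biased scores. The standalone sketch below is not from the commit; it replays that rule for a single token with toy sizes and plain STL containers.

// Toy, dependency-free sketch of the routing rule (single token, made-up sizes).
#include <algorithm>
#include <cstdio>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

int main() {
  const int n_group = 4, group_size = 2, topk_group = 2, top_k = 2;
  std::vector<float> scores = {0.1f, 0.9f, 0.8f, 0.2f, 0.3f, 0.4f, 0.7f, 0.6f};
  std::vector<float> bias = {0.0f, 0.0f, 0.5f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};

  // Biased scores are used only to pick groups and experts.
  std::vector<float> biased(scores.size());
  for (size_t i = 0; i < scores.size(); ++i)
    biased[i] = scores[i] + bias[i];

  // Group score = sum of the two largest biased scores inside the group.
  std::vector<std::pair<float, int>> group_scores;
  for (int g = 0; g < n_group; ++g) {
    std::vector<float> grp(biased.begin() + g * group_size,
                           biased.begin() + (g + 1) * group_size);
    std::partial_sort(grp.begin(), grp.begin() + 2, grp.end(),
                      std::greater<float>());
    group_scores.push_back({grp[0] + grp[1], g});
  }
  std::partial_sort(group_scores.begin(), group_scores.begin() + topk_group,
                    group_scores.end(), std::greater<>());

  // Zero out experts whose group lost, then take top_k experts by biased score.
  std::vector<float> masked(scores.size(), 0.0f);
  for (int j = 0; j < topk_group; ++j) {
    int g = group_scores[j].second;
    for (int k = g * group_size; k < (g + 1) * group_size; ++k)
      masked[k] = biased[k];
  }
  std::vector<int> idx(scores.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                    [&](int a, int b) { return masked[a] > masked[b]; });

  // Returned weights come from the un-biased scores, like scores.gather(1, topk).
  for (int j = 0; j < top_k; ++j)
    std::printf("expert %d, weight %.2f\n", idx[j], scores[idx[j]]);
  return 0;
}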

csrc/cpu/aten/utils/woq.h

Lines changed: 12 additions & 2 deletions
@@ -3454,11 +3454,21 @@ static at::Tensor woq_gemm_ref_impl(
     at::silu_(y);
   } else if (fusion_type == WOQ_FUSE_ADD || fusion_type == WOQ_FUSE_ADD_ADD) {
     for (auto& tin : others_list) {
-      y = at::add(y, tin.view(y.sizes()));
+      auto tin_view = tin.view({-1, y.size(-1)});
+      if (tin_view.size(0) < y.size(0)) {
+        tin_view = at::pad(
+            tin_view, {0, 0, 0, y.size(0) - tin_view.size(0)}, "constant", 0);
+      }
+      y = at::add(y, tin_view);
     }
   } else if (fusion_type == WOQ_FUSE_MUL) {
     for (auto& tin : others_list) {
-      y = at::mul(y, tin.view(y.sizes()));
+      auto tin_view = tin.view({-1, y.size(-1)});
+      if (tin_view.size(0) < y.size(0)) {
+        tin_view = at::pad(
+            tin_view, {0, 0, 0, y.size(0) - tin_view.size(0)}, "constant", 0);
+      }
+      y = at::mul(y, tin_view);
     }
   } else {
     TORCH_CHECK(
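
The padding added above covers fused add/mul inputs whose leading dimension is shorter than the GEMM output: the missing rows are zero-filled so the elementwise add or mul sees matching shapes. As a reminder of the pad-spec ordering (a minimal sketch, not from the commit; the helper name is made up), at::pad takes pairs starting from the last dimension, so {0, 0, 0, rows} leaves the columns untouched and appends rows at the bottom of a 2-D tensor.

#include <ATen/ATen.h>

// Hypothetical helper mirroring the padding pattern used in woq_gemm_ref_impl.
at::Tensor pad_rows_to(const at::Tensor& t2d, int64_t target_rows) {
  const int64_t missing = target_rows - t2d.size(0);
  if (missing <= 0) {
    return t2d;
  }
  // Pad spec runs from the last dim backwards: {left, right, top, bottom}
  // for 2-D input, so this appends `missing` zero rows at the bottom.
  return at::pad(t2d, {0, 0, 0, missing}, "constant", 0);
}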
