
Commit e39803d

Sync back ipexcpu 2.5 rc2 final code from ipexgpu:releases/2.5.10+xpu_rc to master (#5074)
* Sync IPEX CPU 2.5 RC2 code (#4951)
* rebase frontend from ipex-cpu release/2.5 to ipex-gpu releases/2.5.10+xpu_rc (#4957)
* [CPU] Fix test/cpu/test_ipex_llm_quantization.py and test_ipex_optimize_transformers.py (#5016)
  * fix test_ipex_llm_quantization.py
  * fix test_ipex_optimize_transformers.py
  * change int4 parameter order in test

---------

Co-authored-by: Xu Han <xu.han@intel.com>
1 parent f7919a0 commit e39803d

166 files changed: +23305 −5417 lines


cmake/cpu/Options.cmake

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ include(CMakeDependentOption)
 
 option(BUILD_LIBXSMM_VIA_CMAKE "Build LIBXSMM via CMake" ON)
 option(USE_LIBXSMM "Enable LIBXSMM" ON)
-option(USE_DNNL_GRAPH_COMPILER "Build with DNNL Graph Compiler" ON)
+option(USE_DNNL_GRAPH_COMPILER "Build with DNNL Graph Compiler" ON)
 if(WIN32)
   set(USE_LIBXSMM ON)
 endif()

csrc/cpu/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ if(BUILD_CPU_WITH_ONECCL)
   find_package(oneCCL REQUIRED)
   list(APPEND DEPENDS_LIB oneCCL)
   list(APPEND DEPENDS_LIB mpi)
+  set(RPATH_VALUE)
+  list(APPEND RPATH_VALUE "$ORIGIN")
+  list(APPEND RPATH_VALUE "$ORIGIN/../opt/mpi/lib")
+  set(CMAKE_INSTALL_RPATH "${RPATH_VALUE}")
 endif()
 
 # TODO: Once llga is merged into oneDNN, use oneDNN directly as the third_party of IPEX

csrc/cpu/aten/Linear.cpp

Lines changed: 114 additions & 13 deletions
@@ -398,16 +398,8 @@ at::Tensor woq_linear_pack_weight(
   // Note that weight is already compressed
   int64_t K_int4_compressed = K / 2;
   int64_t N_int4 = N % block_n ? N / block_n * block_n + block_n : N;
-  at::Tensor weight_int4 = at::empty(
-      {N_int4, K_int4_compressed}, device(c10::kCPU).dtype(c10::kByte));
-  int64_t weight_size_bytes = weight.numel();
-  int64_t weight_int4_size_bytes = weight_int4.numel();
-  int64_t pad_size_bytes = weight_int4_size_bytes - weight_size_bytes;
-  std::memcpy(weight_int4.data_ptr(), weight.data_ptr(), weight_size_bytes);
-  std::fill_n(
-      (uint8_t*)weight_int4.data_ptr() + weight_size_bytes,
-      pad_size_bytes,
-      0);
+  at::Tensor weight_int4 =
+      at::pad(weight, {0, 0, 0, N_int4 - N}, "constant", 0);
   return woq_tpp_gemm_packB_stub(
       kCPU, weight_int4, weight_dtype, block_n, block_k, lowp_mode);
 }

@@ -491,7 +483,9 @@ at::Tensor woq_linear_kernel(
     int64_t lowp_mode,
     int64_t act_quant_mode,
     const c10::optional<at::Tensor>& compensation) {
-  int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+  int64_t quant_w_mode = zps_list[0].defined()
+      ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+      : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
   auto K = self.size(-1);
   auto M = self.numel() / K;
   auto in = self;

@@ -533,6 +527,63 @@ at::Tensor woq_linear_forward(
       ->run(input);
 }
 
+at::Tensor woq_linear_forward_v2(
+    const at::Tensor& input,
+    const at::Tensor& qweight,
+    const c10::string_view& weight_dtype,
+    const std::vector<int64_t>& weight_shape,
+    const std::vector<at::Tensor>& weight_scales,
+    const c10::optional<std::vector<at::Tensor>>& weight_zeros,
+    const c10::optional<std::vector<at::Tensor>>& bias,
+    const c10::optional<at::Tensor>& g_idx,
+    int64_t group_size,
+    int64_t lowp_mode,
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
+  static const std::map<c10::string_view, int64_t> WOQ_DTYPE_MAP = {
+      {"int8", WOQ_DTYPE_INT8},
+      {"int4", WOQ_DTYPE_INT4},
+      {"nf4", WOQ_DTYPE_NF4},
+  };
+  TORCH_CHECK(
+      WOQ_DTYPE_MAP.find(weight_dtype) != WOQ_DTYPE_MAP.end(),
+      "Unsupported weight dtype: ",
+      weight_dtype);
+  if (WOQ_DTYPE_MAP.at(weight_dtype) == WOQ_DTYPE_INT8 && lowp_mode == 3) {
+    TORCH_CHECK(compensation.has_value() && compensation.value().defined());
+  }
+  static const at::Tensor empty_tensor = at::Tensor();
+  // zp list of all dtypes = {fp32, fp16, bf16, int8}
+  static const std::vector<at::Tensor> empty_zp_list = {
+      empty_tensor, empty_tensor, empty_tensor, empty_tensor};
+  // bias list of all dtypes = {fp32, fp16, bf16}
+  static const std::vector<at::Tensor> empty_bias_list = {
+      empty_tensor, empty_tensor, empty_tensor};
+  if (weight_zeros.has_value()) {
+    TORCH_CHECK(
+        weight_zeros.value().size() == 4,
+        "IPEX WOQ: expect list of zeros has length 4");
+  }
+  auto& zeros_list =
+      weight_zeros.has_value() ? weight_zeros.value() : empty_zp_list;
+  if (bias.has_value()) {
+    TORCH_CHECK(
+        bias.value().size() == 3, "IPEX WOQ: expect list of bias has length 3");
+  }
+  auto& bias_list = bias.has_value() ? bias.value() : empty_bias_list;
+  return woq_linear_kernel(
+      input,
+      qweight,
+      WOQ_DTYPE_MAP.at(weight_dtype),
+      weight_scales,
+      zeros_list,
+      bias_list,
+      group_size,
+      lowp_mode,
+      act_quant_mode,
+      compensation);
+}
+
 at::Tensor woq_linear_unary_kernel(
     const at::Tensor& self,
     const at::Tensor& weight,

@@ -559,7 +610,9 @@ at::Tensor woq_linear_unary_kernel(
   } else if (post_op == "silu") {
     post_op_fusion_type = WOQ_FUSE_SILU;
   }
-  int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+  int64_t quant_w_mode = zps_list[0].defined()
+      ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+      : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
   auto K = self.size(-1);
   auto M = self.numel() / K;
   auto in = self;

@@ -648,7 +701,9 @@ at::Tensor woq_linear_binary_kernel(
   } else if (post_op == "mul") {
     post_op_fusion_type = WOQ_FUSE_MUL;
   }
-  int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+  int64_t quant_w_mode = zps_list[0].defined()
+      ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+      : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
   auto K = self.size(-1);
   auto M = self.numel() / K;
   auto in = self;

@@ -782,6 +837,39 @@ at::Tensor woq_linear_forward(
   return op.call(cpu_cached_cast(target_type, input), op_context);
 }
 
+at::Tensor woq_linear_forward_v2(
+    const at::Tensor& input,
+    const at::Tensor& qweight,
+    const c10::string_view& weight_dtype,
+    const std::vector<int64_t>& weight_shape,
+    const std::vector<at::Tensor>& weight_scales,
+    const c10::optional<std::vector<at::Tensor>>& weight_zeros,
+    const c10::optional<std::vector<at::Tensor>>& bias,
+    const c10::optional<at::Tensor>& g_idx,
+    int64_t group_size,
+    int64_t lowp_mode,
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
+  static auto op = torch::Dispatcher::singleton()
+                       .findSchemaOrThrow("torch_ipex::woq_linear", "")
+                       .typed<decltype(woq_linear_forward_v2)>();
+  auto target_type = get_autocast_dtype();
+  return op.call(
+      cpu_cached_cast(target_type, input),
+      qweight,
+      weight_dtype,
+      weight_shape,
+      weight_scales,
+      weight_zeros,
+      bias,
+      g_idx,
+      group_size,
+      lowp_mode,
+      act_quant_mode,
+      compensation);
+}
+
 at::Tensor woq_linear_gelu_forward(
     const at::Tensor& input,
     const at::Tensor& op_context) {

@@ -964,6 +1052,19 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       "woq_linear_mul",
       c10::DispatchKey::AutocastCPU,
       torch_ipex::autocast::woq_linear_mul_forward);
+  // the version without op_context
+  m.def(
+      "woq_linear(Tensor input, Tensor qweight, str weight_dtype, int[] weight_shape, Tensor[] weight_scales, "
+      "Tensor[]? weight_zeros, Tensor[]? bias, Tensor? g_idx, int group_size, int lowp_mode, int act_quant_mode, "
+      "Tensor? compensation = None) -> Tensor");
+  m.impl(
+      "woq_linear",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::woq_linear_forward_v2);
+  m.impl(
+      "woq_linear",
+      c10::DispatchKey::AutocastCPU,
+      torch_ipex::autocast::woq_linear_forward_v2);
 #endif
   // fuse eltwise
   m.def(
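
For illustration, a minimal sketch of how the newly defined torch_ipex::woq_linear schema could be invoked through the dispatcher. Everything below is an assumption for demonstration: the shapes, mode values, and tensor contents are placeholders, and a real call needs a weight actually packed by IPEX's WOQ packing path.

# Hedged sketch of calling the newly registered torch_ipex::woq_linear schema.
# Assumptions: an IPEX build containing this commit is importable; qweight and
# scales here are placeholders, not a usable packed int4 weight.
import torch
import intel_extension_for_pytorch  # noqa: F401  (registers torch_ipex ops)

M, K, N, group_size = 4, 64, 128, 32
x = torch.randn(M, K)
qweight = torch.zeros(N, K // 2, dtype=torch.uint8)   # packed int4 placeholder
scales = [torch.ones(N, K // group_size)]             # weight_scales (Tensor[])

out = torch.ops.torch_ipex.woq_linear(
    x,
    qweight,
    "int4",        # weight_dtype: one of "int8", "int4", "nf4"
    [N, K],        # weight_shape
    scales,
    None,          # weight_zeros -> symmetric quantization path
    None,          # bias
    None,          # g_idx
    group_size,
    0,             # lowp_mode
    0,             # act_quant_mode
)

The argument order follows the m.def string above; passing None for weight_zeros and bias exercises the empty_zp_list / empty_bias_list defaults in woq_linear_forward_v2.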

csrc/cpu/aten/Linear.h

Lines changed: 20 additions & 0 deletions
@@ -84,6 +84,20 @@ at::Tensor woq_linear_forward(
     const at::Tensor& input,
     const at::Tensor& op_context);
 
+at::Tensor woq_linear_forward_v2(
+    const at::Tensor& input,
+    const at::Tensor& qweight,
+    const c10::string_view& weight_dtype,
+    const std::vector<int64_t>& weight_shape,
+    const std::vector<at::Tensor>& weight_scales,
+    const c10::optional<std::vector<at::Tensor>>& weight_zeros,
+    const c10::optional<std::vector<at::Tensor>>& bias,
+    const c10::optional<at::Tensor>& g_idx,
+    int64_t group_size,
+    int64_t lowp_mode,
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation);
+
 at::Tensor woq_linear_gelu_forward(
     const at::Tensor& input,
     const at::Tensor& op_context);

@@ -252,6 +266,12 @@ IPEX_DECLARE_DISPATCH(
 #define WOQ_FUSE_ADD_ADD 0x20
 #define WOQ_FUSE_MUL 0x30
 
+// weight quant mode
+#define QUANT_W_PER_CHANNEL 0
+#define QUANT_W_PER_K_BLOCK 1
+#define QUANT_W_PER_CHANNEL_SYM 2
+#define QUANT_W_PER_K_BLOCK_SYM 3
+
 #define WOQ_N_BLOCK_SIZE 32
 
 #endif
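
The four QUANT_W_* modes pair with the selection logic added in Linear.cpp: zero points present selects the asymmetric modes, and group_size > 0 selects the per-K-block variants. A small illustrative helper (hypothetical, not part of this commit's sources) mirroring that mapping:

# Illustrative helper mirroring the quant_w_mode selection in woq_linear_kernel.
QUANT_W_PER_CHANNEL = 0
QUANT_W_PER_K_BLOCK = 1
QUANT_W_PER_CHANNEL_SYM = 2
QUANT_W_PER_K_BLOCK_SYM = 3

def select_quant_w_mode(has_zero_points: bool, group_size: int) -> int:
    # Asymmetric (zero points defined) vs. symmetric, crossed with
    # per-K-block (group_size > 0) vs. per-channel quantization.
    if has_zero_points:
        return QUANT_W_PER_K_BLOCK if group_size > 0 else QUANT_W_PER_CHANNEL
    return QUANT_W_PER_K_BLOCK_SYM if group_size > 0 else QUANT_W_PER_CHANNEL_SYM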

csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp

Lines changed: 15 additions & 18 deletions
@@ -130,27 +130,24 @@ inline Vectorized<scalar_t> exp_u20(Vectorized<scalar_t> data) {
 inline Vectorized<float> exp_u20(Vectorized<float> data) {
   __m512 values = __m512(data);
   // A faster version of exp with ULP=20
-  static __m512 vec_factorial_1 =
-      _mm512_set1_ps(0.999999701f); // 1/factorial(1)
-  static __m512 vec_factorial_2 =
-      _mm512_set1_ps(0.499991506f); // 1/factorial(2)
-  static __m512 vec_factorial_3 =
-      _mm512_set1_ps(0.166676521f); // 1/factorial(3)
-  static __m512 vec_factorial_4 =
+  const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); // 1/factorial(1)
+  const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); // 1/factorial(2)
+  const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); // 1/factorial(3)
+  const __m512 vec_factorial_4 =
       _mm512_set1_ps(0.0418978221f); // 1/factorial(4)
-  static __m512 vec_factorial_5 =
+  const __m512 vec_factorial_5 =
       _mm512_set1_ps(0.00828929059f); // 1/factorial(5)
-  static __m512 vec_exp_log2ef =
+  const __m512 vec_exp_log2ef =
       (__m512)_mm512_set1_epi32(0x3fb8aa3b); // log2(e)
-  static __m512 vec_half = _mm512_set1_ps(0.5f);
-  static __m512 vec_one = _mm512_set1_ps(1.f);
-  static __m512 vec_zero = _mm512_set1_ps(0.f);
-  static __m512 vec_two = _mm512_set1_ps(2.f);
-  static __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
-  static __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
-  static __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
-  static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
-  static int n_mantissa_bits = 23;
+  const __m512 vec_half = _mm512_set1_ps(0.5f);
+  const __m512 vec_one = _mm512_set1_ps(1.f);
+  const __m512 vec_zero = _mm512_set1_ps(0.f);
+  const __m512 vec_two = _mm512_set1_ps(2.f);
+  const __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
+  const __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
+  const __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
+  const __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
+  const int n_mantissa_bits = 23;
 
   // exp(x) =
   // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
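
For reference, a sketch of the range-reduction approximation these constants implement, following the in-code comment:

  exp(x) = 2^n · exp(r),   n = round(x · log2(e)),   r = x − n · ln(2)
  exp(r) ≈ 1 + r/1! + r²/2! + r³/3! + r⁴/4! + r⁵/5!

The vec_factorial_k values are slightly tuned 1/k! coefficients, vec_ln_flt_min / vec_ln_flt_max clamp x to the range where a float exp stays finite, and 2^n is presumably assembled from the IEEE-754 exponent bits, which is why vec_127 (the exponent bias) and n_mantissa_bits = 23 appear. Only the constants' storage class changes in this hunk (static to const); the math is unchanged.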
