@@ -398,16 +398,8 @@ at::Tensor woq_linear_pack_weight(
  // Note that weight is already compressed
  int64_t K_int4_compressed = K / 2;
  int64_t N_int4 = N % block_n ? N / block_n * block_n + block_n : N;
- at::Tensor weight_int4 = at::empty(
-     {N_int4, K_int4_compressed}, device(c10::kCPU).dtype(c10::kByte));
- int64_t weight_size_bytes = weight.numel();
- int64_t weight_int4_size_bytes = weight_int4.numel();
- int64_t pad_size_bytes = weight_int4_size_bytes - weight_size_bytes;
- std::memcpy(weight_int4.data_ptr(), weight.data_ptr(), weight_size_bytes);
- std::fill_n(
-     (uint8_t*)weight_int4.data_ptr() + weight_size_bytes,
-     pad_size_bytes,
-     0);
+ at::Tensor weight_int4 =
+     at::pad(weight, {0, 0, 0, N_int4 - N}, "constant", 0);
  return woq_tpp_gemm_packB_stub(
      kCPU, weight_int4, weight_dtype, block_n, block_k, lowp_mode);
}
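
For reference, a minimal standalone sketch (not part of the patch) of what the new at::pad call does to the compressed int4 weight: the pad spec {0, 0, 0, N_int4 - N} leaves the packed K dimension untouched and appends zero rows along N up to the next multiple of block_n. The shapes and block_n value below are illustrative only, and a reasonably recent libtorch is assumed.

// Illustrative only; assumes a libtorch build where at::pad is available.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  int64_t N = 10, K = 64, block_n = 16;
  int64_t K_int4_compressed = K / 2;  // two int4 values packed per byte
  int64_t N_int4 = N % block_n ? N / block_n * block_n + block_n : N;  // 16
  at::Tensor weight = at::randint(0, 256, {N, K_int4_compressed}, at::kByte);
  // Pad spec is (left, right) pairs starting from the last dim:
  // {0, 0} keeps dim -1 (packed K), {0, N_int4 - N} appends zero rows on dim -2 (N).
  at::Tensor weight_int4 = at::pad(weight, {0, 0, 0, N_int4 - N}, "constant", 0);
  std::cout << weight_int4.sizes() << std::endl;  // [16, 32]
  return 0;
}
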
@@ -491,7 +483,9 @@ at::Tensor woq_linear_kernel(
    int64_t lowp_mode,
    int64_t act_quant_mode,
    const c10::optional<at::Tensor>& compensation) {
- int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+ int64_t quant_w_mode = zps_list[0].defined()
+     ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+     : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
  auto K = self.size(-1);
  auto M = self.numel() / K;
  auto in = self;
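
A small self-contained sketch of the selection logic introduced above, since the QUANT_W_* constants are defined elsewhere in the file: when zero points are provided the asymmetric modes are chosen, otherwise the symmetric ones, and group_size > 0 switches from per-channel to per-K-block granularity. The numeric values assigned to the constants below are assumptions for illustration, not taken from the patch.

// Sketch only; the actual constant values live elsewhere in IPEX and may differ.
#include <cstdint>

enum QuantWMode : int64_t {
  QUANT_W_PER_CHANNEL = 0,      // per-channel scales, asymmetric (zero points present)
  QUANT_W_PER_K_BLOCK = 1,      // per-K-block (grouped) scales, asymmetric
  QUANT_W_PER_CHANNEL_SYM = 2,  // per-channel scales, symmetric (no zero points)
  QUANT_W_PER_K_BLOCK_SYM = 3,  // per-K-block scales, symmetric
};

// Mirrors the ternary above: zero points defined => asymmetric family,
// group_size > 0 => per-K-block granularity.
int64_t select_quant_w_mode(bool has_zero_points, int64_t group_size) {
  return has_zero_points
      ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
      : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
}
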
@@ -533,6 +527,63 @@ at::Tensor woq_linear_forward(
      ->run(input);
}

+ at::Tensor woq_linear_forward_v2(
+     const at::Tensor& input,
+     const at::Tensor& qweight,
+     const c10::string_view& weight_dtype,
+     const std::vector<int64_t>& weight_shape,
+     const std::vector<at::Tensor>& weight_scales,
+     const c10::optional<std::vector<at::Tensor>>& weight_zeros,
+     const c10::optional<std::vector<at::Tensor>>& bias,
+     const c10::optional<at::Tensor>& g_idx,
+     int64_t group_size,
+     int64_t lowp_mode,
+     int64_t act_quant_mode,
+     const c10::optional<at::Tensor>& compensation) {
+   static const std::map<c10::string_view, int64_t> WOQ_DTYPE_MAP = {
+       {"int8", WOQ_DTYPE_INT8},
+       {"int4", WOQ_DTYPE_INT4},
+       {"nf4", WOQ_DTYPE_NF4},
+   };
+   TORCH_CHECK(
+       WOQ_DTYPE_MAP.find(weight_dtype) != WOQ_DTYPE_MAP.end(),
+       "Unsupported weight dtype: ",
+       weight_dtype);
+   if (WOQ_DTYPE_MAP.at(weight_dtype) == WOQ_DTYPE_INT8 && lowp_mode == 3) {
+     TORCH_CHECK(compensation.has_value() && compensation.value().defined());
+   }
+   static const at::Tensor empty_tensor = at::Tensor();
+   // zp list of all dtypes = {fp32, fp16, bf16, int8}
+   static const std::vector<at::Tensor> empty_zp_list = {
+       empty_tensor, empty_tensor, empty_tensor, empty_tensor};
+   // bias list of all dtypes = {fp32, fp16, bf16}
+   static const std::vector<at::Tensor> empty_bias_list = {
+       empty_tensor, empty_tensor, empty_tensor};
+   if (weight_zeros.has_value()) {
+     TORCH_CHECK(
+         weight_zeros.value().size() == 4,
+         "IPEX WOQ: expect list of zeros to have length 4");
+   }
+   auto& zeros_list =
+       weight_zeros.has_value() ? weight_zeros.value() : empty_zp_list;
+   if (bias.has_value()) {
+     TORCH_CHECK(
+         bias.value().size() == 3, "IPEX WOQ: expect list of bias to have length 3");
+   }
+   auto& bias_list = bias.has_value() ? bias.value() : empty_bias_list;
+   return woq_linear_kernel(
+       input,
+       qweight,
+       WOQ_DTYPE_MAP.at(weight_dtype),
+       weight_scales,
+       zeros_list,
+       bias_list,
+       group_size,
+       lowp_mode,
+       act_quant_mode,
+       compensation);
+ }
+
at::Tensor woq_linear_unary_kernel(
    const at::Tensor& self,
    const at::Tensor& weight,
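
The helper above accepts optional zero-point and bias lists and substitutes lists of undefined tensors when they are absent. A minimal standalone sketch of that fallback, assuming only ATen and not taken from the patch, is below; note how an absent zero-point list leaves zeros_list[0].defined() false, which is exactly what selects the symmetric quant_w_mode variants in woq_linear_kernel.

// Standalone sketch of the empty-tensor fallback; not part of the patch.
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <iostream>
#include <vector>

int main() {
  // zp list of all dtypes = {fp32, fp16, bf16, int8}
  static const std::vector<at::Tensor> empty_zp_list = {
      at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()};
  c10::optional<std::vector<at::Tensor>> weight_zeros = c10::nullopt;
  const auto& zeros_list =
      weight_zeros.has_value() ? weight_zeros.value() : empty_zp_list;
  // Undefined tensors report defined() == false, which downstream code uses
  // to pick the symmetric (no-zero-point) weight quantization modes.
  std::cout << std::boolalpha << zeros_list[0].defined() << std::endl;  // false
  return 0;
}
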
@@ -559,7 +610,9 @@ at::Tensor woq_linear_unary_kernel(
  } else if (post_op == "silu") {
    post_op_fusion_type = WOQ_FUSE_SILU;
  }
- int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+ int64_t quant_w_mode = zps_list[0].defined()
+     ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+     : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
  auto K = self.size(-1);
  auto M = self.numel() / K;
  auto in = self;
@@ -648,7 +701,9 @@ at::Tensor woq_linear_binary_kernel(
  } else if (post_op == "mul") {
    post_op_fusion_type = WOQ_FUSE_MUL;
  }
- int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+ int64_t quant_w_mode = zps_list[0].defined()
+     ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+     : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
  auto K = self.size(-1);
  auto M = self.numel() / K;
  auto in = self;
@@ -782,6 +837,39 @@ at::Tensor woq_linear_forward(
  return op.call(cpu_cached_cast(target_type, input), op_context);
}

+ at::Tensor woq_linear_forward_v2(
+     const at::Tensor& input,
+     const at::Tensor& qweight,
+     const c10::string_view& weight_dtype,
+     const std::vector<int64_t>& weight_shape,
+     const std::vector<at::Tensor>& weight_scales,
+     const c10::optional<std::vector<at::Tensor>>& weight_zeros,
+     const c10::optional<std::vector<at::Tensor>>& bias,
+     const c10::optional<at::Tensor>& g_idx,
+     int64_t group_size,
+     int64_t lowp_mode,
+     int64_t act_quant_mode,
+     const c10::optional<at::Tensor>& compensation) {
+   c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
+   static auto op = torch::Dispatcher::singleton()
+                        .findSchemaOrThrow("torch_ipex::woq_linear", "")
+                        .typed<decltype(woq_linear_forward_v2)>();
+   auto target_type = get_autocast_dtype();
+   return op.call(
+       cpu_cached_cast(target_type, input),
+       qweight,
+       weight_dtype,
+       weight_shape,
+       weight_scales,
+       weight_zeros,
+       bias,
+       g_idx,
+       group_size,
+       lowp_mode,
+       act_quant_mode,
+       compensation);
+ }
+
at::Tensor woq_linear_gelu_forward(
    const at::Tensor& input,
    const at::Tensor& op_context) {
@@ -964,6 +1052,19 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
      "woq_linear_mul",
      c10::DispatchKey::AutocastCPU,
      torch_ipex::autocast::woq_linear_mul_forward);
+ // the version without op_context
+ m.def(
+     "woq_linear(Tensor input, Tensor qweight, str weight_dtype, int[] weight_shape, Tensor[] weight_scales, "
+     "Tensor[]? weight_zeros, Tensor[]? bias, Tensor? g_idx, int group_size, int lowp_mode, int act_quant_mode, "
+     "Tensor? compensation=None) -> Tensor");
+ m.impl(
+     "woq_linear",
+     c10::DispatchKey::CPU,
+     torch_ipex::cpu::woq_linear_forward_v2);
+ m.impl(
+     "woq_linear",
+     c10::DispatchKey::AutocastCPU,
+     torch_ipex::autocast::woq_linear_forward_v2);
#endif
  // fuse eltwise
  m.def(