Commit d9684a6

[LLM serving] Add Mamba/Jamba model kernel (#3576)
1 parent 776a1ce commit d9684a6

File tree

11 files changed: +1475 -447 lines changed


csrc/cpu/aten/Conv.cpp

Lines changed: 43 additions & 2 deletions
@@ -9,6 +9,7 @@ namespace torch_ipex {
 namespace cpu {
 
 IPEX_DEFINE_DISPATCH(causal_conv1d_update_kernel_stub);
+IPEX_DEFINE_DISPATCH(causal_conv1d_fn_kernel_stub);
 std::vector<int64_t> calc_conv_output_size(
     at::IntArrayRef input_size,
     at::IntArrayRef kernel_size,
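IPEX_DEFINE_DISPATCH here pairs with the IPEX_DECLARE_DISPATCH added to Conv.h below; the remaining piece, registering the actual CPU implementation against the stub, lives in a kernel translation unit that is not part of this excerpt. A hedged sketch of what that registration conventionally looks like, where the kernel name and body are hypothetical:

// Hypothetical kernel TU sketch following the IPEX DispatchStub convention;
// causal_conv1d_fn_kernel_impl is an illustrative name, not from this diff.
namespace torch_ipex {
namespace cpu {

std::tuple<at::Tensor, at::Tensor> causal_conv1d_fn_kernel_impl(
    const at::Tensor& x,
    const at::Tensor& conv_weights,
    const c10::optional<at::Tensor>& conv_bias,
    const c10::optional<at::Tensor>& initial_states,
    const c10::optional<at::Tensor>& final_states_out,
    bool silu_activation) {
  // ... vectorized causal depthwise conv would go here; placeholder return ...
  return std::make_tuple(at::Tensor(), at::Tensor());
}

// Binds the implementation to the stub so that
// causal_conv1d_fn_kernel_stub(kCPU, ...) dispatches to it.
IPEX_REGISTER_DISPATCH(causal_conv1d_fn_kernel_stub, &causal_conv1d_fn_kernel_impl);

} // namespace cpu
} // namespace torch_ipex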
@@ -514,21 +515,55 @@ at::Tensor convolution_forward(
  * @param conv_weights (dim, width)
  * @param conv_bias (dim,)
  * @param silu_activation If true, apply the SiLU activation function.
+ * @param cache_seqlens (batch,) or None
  * @return (hidden_states, conv_states)
  */
 std::tuple<at::Tensor, at::Tensor> causal_conv1d_update(
     const at::Tensor& hidden_states,
     const at::Tensor& conv_states,
     const at::Tensor& conv_weights,
     const c10::optional<at::Tensor>& conv_bias,
-    bool silu_activation) {
+    bool silu_activation,
+    const c10::optional<at::Tensor>& cache_seqlens) {
   RECORD_FUNCTION("causal_conv1d_update", c10::ArrayRef<c10::IValue>({}));
   return causal_conv1d_update_kernel_stub(
       kCPU,
       hidden_states,
       conv_states,
       conv_weights,
       conv_bias,
+      silu_activation,
+      cache_seqlens);
+}
+
+/**
+ * Official Python implementation: causal_conv1d_ref:
+ * https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py#L133
+ * @param x (batch, dim, seqlen)
+ * @param conv_weights (dim, width)
+ * @param conv_bias (dim,)
+ * @param initial_states (batch, dim, width - 1)
+ * @param final_states_out (batch, dim, width - 1)
+ * @param silu_activation If true, apply the SiLU activation function.
+ * @return (out, final_states_out)
+ *   out: (batch, dim, seqlen)
+ *   final_states_out: (batch, dim, width - 1)
+ */
+std::tuple<at::Tensor, at::Tensor> causal_conv1d_fn(
+    const at::Tensor& x,
+    const at::Tensor& conv_weights,
+    const c10::optional<at::Tensor>& conv_bias,
+    const c10::optional<at::Tensor>& initial_states,
+    const c10::optional<at::Tensor>& final_states_out,
+    bool silu_activation) {
+  RECORD_FUNCTION("causal_conv1d_fn", c10::ArrayRef<c10::IValue>({}));
+  return causal_conv1d_fn_kernel_stub(
+      kCPU,
+      x,
+      conv_weights,
+      conv_bias,
+      initial_states,
+      final_states_out,
       silu_activation);
 }
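For orientation, the computation the new wrapper dispatches to can be written as a handful of stock ATen calls. The sketch below mirrors the causal_conv1d_ref implementation cited in the doc comment above; it is a naive reference under assumed shapes, not this commit's optimized kernel. causal_conv1d_fn_reference is a name made up for illustration, and the in-place final_states_out argument is simplified to a returned tensor.

// A minimal reference sketch, assuming the semantics of causal_conv1d_ref.
#include <ATen/ATen.h>

std::tuple<at::Tensor, at::Tensor> causal_conv1d_fn_reference(
    const at::Tensor& x,             // (batch, dim, seqlen)
    const at::Tensor& conv_weights,  // (dim, width)
    const c10::optional<at::Tensor>& conv_bias,      // (dim,)
    const c10::optional<at::Tensor>& initial_states, // (batch, dim, width - 1)
    bool silu_activation) {
  const auto dim = x.size(1);
  const auto seqlen = x.size(2);
  const auto width = conv_weights.size(1);
  // Prepend the carried-over window (zeros if absent) so the first outputs
  // see the previous chunk's tail instead of implicit zero padding.
  auto init = initial_states.has_value()
      ? initial_states.value()
      : at::zeros({x.size(0), dim, width - 1}, x.options());
  auto x_padded = at::cat({init, x}, /*dim=*/2); // (batch, dim, seqlen + width - 1)
  // Depthwise conv (groups == dim) with no extra padding: output position t
  // only ever reads x positions <= t, which is what makes it causal.
  auto out = at::conv1d(
      x_padded,
      conv_weights.unsqueeze(1), // (dim, 1, width)
      conv_bias,
      /*stride=*/1,
      /*padding=*/0,
      /*dilation=*/1,
      /*groups=*/dim);
  // The last width - 1 inputs are the state carried into the next chunk.
  auto final_states = x_padded.slice(/*dim=*/2, seqlen, seqlen + width - 1);
  if (silu_activation) {
    out = at::silu(out);
  }
  return std::make_tuple(out, final_states.contiguous());
}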

@@ -589,11 +624,17 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       c10::DispatchKey::CPU,
       torch_ipex::cpu::convolution_forward_impl);
   m.def(
-      "causal_conv1d_update(Tensor hidden_states, Tensor conv_states, Tensor conv_weights, Tensor? conv_bias, bool silu_activation) -> (Tensor, Tensor)");
+      "causal_conv1d_update(Tensor hidden_states, Tensor conv_states, Tensor conv_weights, Tensor? conv_bias, bool silu_activation, Tensor? cache_seqlens=None) -> (Tensor, Tensor)");
   m.impl(
       "causal_conv1d_update",
       c10::DispatchKey::CPU,
       torch_ipex::cpu::causal_conv1d_update);
+  m.def(
+      "causal_conv1d_fn(Tensor x, Tensor conv_weights, Tensor? conv_bias, Tensor? initial_states, Tensor? final_states_out, bool silu_activation) -> (Tensor, Tensor)");
+  m.impl(
+      "causal_conv1d_fn",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::causal_conv1d_fn);
   // bw
   m.def(
       "convolution_backward(Tensor input, Tensor weight, Tensor? bias, Tensor grad_output, bool[3] out_mask, "

csrc/cpu/aten/Conv.h

Lines changed: 18 additions & 0 deletions
@@ -57,6 +57,15 @@ std::tuple<at::Tensor, at::Tensor> causal_conv1d_update(
     const at::Tensor& conv_states,
     const at::Tensor& conv_weights,
     const c10::optional<at::Tensor>& conv_bias,
+    bool silu_activation,
+    const c10::optional<at::Tensor>& cache_seqlens);
+
+std::tuple<at::Tensor, at::Tensor> causal_conv1d_fn(
+    const at::Tensor& x,
+    const at::Tensor& conv_weights,
+    const c10::optional<at::Tensor>& conv_bias,
+    const c10::optional<at::Tensor>& initial_states,
+    const c10::optional<at::Tensor>& final_states_out,
     bool silu_activation);
 
 // IPEX customized convolution OP with n-D packed weight
@@ -108,9 +117,18 @@ using causal_conv1d_update_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
     const at::Tensor& conv_states,
     const at::Tensor& conv_weights,
     const c10::optional<at::Tensor>& conv_bias,
+    bool silu_activation,
+    const c10::optional<at::Tensor>& cache_seqlens);
+using causal_conv1d_fn_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
+    const at::Tensor& x,
+    const at::Tensor& conv_weights,
+    const c10::optional<at::Tensor>& conv_bias,
+    const c10::optional<at::Tensor>& initial_states,
+    const c10::optional<at::Tensor>& final_states_out,
     bool silu_activation);
 IPEX_DECLARE_DISPATCH(
     causal_conv1d_update_kernel_fn,
     causal_conv1d_update_kernel_stub);
+IPEX_DECLARE_DISPATCH(causal_conv1d_fn_kernel_fn, causal_conv1d_fn_kernel_stub);
 } // namespace cpu
 } // namespace torch_ipex
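As a companion to the declarations above, here is a naive ATen reference for the single-token decode semantics of causal_conv1d_update: roll the cached window left, append the new token, and take a per-channel dot product with the filter taps. This is a sketch under assumed shapes; the cache_seqlens variable-length path that the real kernel now accepts is deliberately omitted, and the function name is hypothetical.

// Hypothetical single-token reference; not this commit's optimized kernel.
#include <ATen/ATen.h>

std::tuple<at::Tensor, at::Tensor> causal_conv1d_update_reference(
    const at::Tensor& hidden_states, // (batch, dim), one new token
    at::Tensor conv_states,          // (batch, dim, width), rolling window
    const at::Tensor& conv_weights,  // (dim, width)
    const c10::optional<at::Tensor>& conv_bias, // (dim,)
    bool silu_activation) {
  const auto width = conv_states.size(2);
  // Shift the cached window one step left, then write the new token into
  // the freed last slot.
  conv_states.copy_(at::roll(conv_states, /*shifts=*/{-1}, /*dims=*/{2}));
  conv_states.select(/*dim=*/2, width - 1).copy_(hidden_states);
  // Per-channel dot product with the filter taps:
  // out[b, d] = sum_w conv_states[b, d, w] * conv_weights[d, w] (+ bias[d])
  auto out = (conv_states * conv_weights.unsqueeze(0)).sum(/*dim=*/2);
  if (conv_bias.has_value()) {
    out = out + conv_bias.value();
  }
  if (silu_activation) {
    out = at::silu(out);
  }
  return std::make_tuple(std::move(out), std::move(conv_states));
}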
