@@ -9,6 +9,7 @@ namespace torch_ipex {
9
9
namespace cpu {
10
10
11
11
IPEX_DEFINE_DISPATCH (causal_conv1d_update_kernel_stub);
12
+ IPEX_DEFINE_DISPATCH (causal_conv1d_fn_kernel_stub);
12
13
std::vector<int64_t > calc_conv_output_size (
13
14
at::IntArrayRef input_size,
14
15
at::IntArrayRef kernel_size,
@@ -514,21 +515,55 @@ at::Tensor convolution_forward(
514
515
* @param conv_weights (dim, width)
515
516
* @param conv_bias (dim,)
516
517
* @param silu_activation If true, apply the SiLU activation function.
518
+ * @param cache_seqlens (batch,) or None
517
519
* @return (hidden_states, conv_states)
518
520
*/
519
521
/**
 * Single-step (decode-time) causal 1-D convolution update.
 * Thin CPU dispatcher: forwards every argument unchanged to
 * causal_conv1d_update_kernel_stub.
 *
 * @param hidden_states input activations for the current step
 * @param conv_states convolution state cache consumed by the kernel
 * @param conv_weights (dim, width)
 * @param conv_bias (dim,) or None
 * @param silu_activation If true, apply the SiLU activation function.
 * @param cache_seqlens (batch,) or None
 * @return (hidden_states, conv_states)
 */
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update(
    const at::Tensor& hidden_states,
    const at::Tensor& conv_states,
    const at::Tensor& conv_weights,
    const c10::optional<at::Tensor>& conv_bias,
    bool silu_activation,
    const c10::optional<at::Tensor>& cache_seqlens) {
  // Profiler event name must match the registered op name exactly
  // (the scraped text had a stray leading space inside the literal).
  RECORD_FUNCTION("causal_conv1d_update", c10::ArrayRef<c10::IValue>({}));
  return causal_conv1d_update_kernel_stub(
      kCPU,
      hidden_states,
      conv_states,
      conv_weights,
      conv_bias,
      silu_activation,
      cache_seqlens);
}
538
+
539
+ /* *
540
+ * Official Python implementation: causal_conv1d_ref:
541
+ * https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py#L133
542
+ * @param x (batch, dim, seqlen)
543
+ * @param conv_weights (dim, width)
544
+ * @param conv_bias (dim,)
545
+ * @param initial_states (batch, dim, width - 1)
546
+ * @param final_states_out (batch, dim, width - 1)
547
+ * @param silu_activation If true, apply the SiLU activation function.
548
+ * @return (out, final_states_out)
549
+ * out: (batch, dim, seqlen)
550
+ * final_states_out: (batch, dim, width - 1)
551
+ */
552
+ std::tuple<at::Tensor, at::Tensor> causal_conv1d_fn (
553
+ const at::Tensor& x,
554
+ const at::Tensor& conv_weights,
555
+ const c10::optional<at::Tensor>& conv_bias,
556
+ const c10::optional<at::Tensor>& initial_states,
557
+ const c10::optional<at::Tensor>& final_states_out,
558
+ bool silu_activation) {
559
+ RECORD_FUNCTION (" causal_conv1d_fn" , c10::ArrayRef<c10::IValue>({}));
560
+ return causal_conv1d_fn_kernel_stub (
561
+ kCPU ,
562
+ x,
563
+ conv_weights,
564
+ conv_bias,
565
+ initial_states,
566
+ final_states_out,
532
567
silu_activation);
533
568
}
534
569
@@ -589,11 +624,17 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
589
624
c10::DispatchKey::CPU,
590
625
torch_ipex::cpu::convolution_forward_impl);
591
626
// Register schemas and CPU implementations for the causal conv1d ops.
// cache_seqlens carries a None default so existing callers of the old
// five-argument causal_conv1d_update schema keep working unchanged.
m.def(
    "causal_conv1d_update(Tensor hidden_states, Tensor conv_states, Tensor conv_weights, Tensor? conv_bias, bool silu_activation, Tensor? cache_seqlens=None) -> (Tensor, Tensor)");
m.impl(
    "causal_conv1d_update",
    c10::DispatchKey::CPU,
    torch_ipex::cpu::causal_conv1d_update);
m.def(
    "causal_conv1d_fn(Tensor x, Tensor conv_weights, Tensor? conv_bias, Tensor? initial_states, Tensor? final_states_out, bool silu_activation) -> (Tensor, Tensor)");
m.impl(
    "causal_conv1d_fn",
    c10::DispatchKey::CPU,
    torch_ipex::cpu::causal_conv1d_fn);
597
638
// bw
598
639
m.def (
599
640
" convolution_backward(Tensor input, Tensor weight, Tensor? bias, Tensor grad_output, bool[3] out_mask, "
0 commit comments