Skip to content

Commit a65db65

Browse files
smessmerfacebook-github-bot
authored andcommitted
Enable registering stackbased kernels with lambdas (pytorch#26658)
Summary: Pull Request resolved: pytorch#26658 By SFINAE'ing the lambda registration to only kernels that aren't stackbased kernels, an attempt to register a stackbased lambda kernel will correctly fallback to the stackbased registration function and work as expected. ghstack-source-id: 90610843 Test Plan: unit tests Differential Revision: D17533871 fbshipit-source-id: 1bfe3106b0576d46798a51bdaa5b7b5508164766
1 parent 839e636 commit a65db65

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

aten/src/ATen/core/boxing/kernel_stackbased_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,26 @@ TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistered_thenC
6262
expectCallsIncrement(TensorTypeId::CPUTensorId);
6363
}
6464

65+
TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegisteredAsLambda_thenCanBeCalled) {
66+
auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId,
67+
[] (OperatorKernel*, Stack* stack) {
68+
int input = torch::jit::pop(*stack).toInt();
69+
torch::jit::pop(*stack); // pop the dummy tensor
70+
torch::jit::push(*stack, input + 1);
71+
}));
72+
expectCallsIncrement(TensorTypeId::CPUTensorId);
73+
}
74+
75+
TEST(OperatorRegistrationTest_StackBasedKernel, givenCatchAllKernel_whenRegisteredAsLambda_thenCanBeCalled) {
76+
auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().catchAllKernel(
77+
[] (OperatorKernel*, Stack* stack) {
78+
int input = torch::jit::pop(*stack).toInt();
79+
torch::jit::pop(*stack); // pop the dummy tensor
80+
torch::jit::push(*stack, input + 1);
81+
}));
82+
expectCallsIncrement(TensorTypeId::CPUTensorId);
83+
}
84+
6585
TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
6686
auto registrar = RegisterOperators()
6787
.op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, &incrementKernel))

aten/src/ATen/core/op_registration/op_registration.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,10 @@ class CAFFE2_API RegisterOperators final {
325325
*/
326326
template<class Lambda>
327327
// enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
328-
guts::enable_if_t<guts::is_functor<guts::decay_t<Lambda>>::value, Options&&> kernel(TensorTypeId dispatch_key, Lambda&& functor) && {
328+
guts::enable_if_t<
329+
guts::is_functor<guts::decay_t<Lambda>>::value
330+
&& !std::is_same<typename guts::infer_function_traits_t<guts::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
331+
Options&&> kernel(TensorTypeId dispatch_key, Lambda&& functor) && {
329332
static_assert(!std::is_base_of<OperatorKernel, guts::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
330333

331334
// We don't support stateful lambdas (i.e. lambdas with a capture), because their
@@ -362,7 +365,10 @@ class CAFFE2_API RegisterOperators final {
362365
*/
363366
template<class Lambda>
364367
// enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
365-
guts::enable_if_t<guts::is_functor<guts::decay_t<Lambda>>::value, Options&&> catchAllKernel(Lambda&& lambda) && {
368+
guts::enable_if_t<
369+
guts::is_functor<guts::decay_t<Lambda>>::value
370+
&& !std::is_same<typename guts::infer_function_traits_t<guts::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
371+
Options&&> catchAllKernel(Lambda&& lambda) && {
366372
static_assert(!std::is_base_of<OperatorKernel, guts::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
367373

368374
// We don't support stateful lambdas (i.e. lambdas with a capture), because their

0 commit comments

Comments
 (0)