Skip to content

Commit cf1dbc7

Browse files
pearufacebook-github-bot
authored andcommitted
Vectorize unary operator erfinv (pytorch#26629)
Summary: Resolves pytorch#19088 for erfinv. erfinv speedup (MKL, AMD Ryzen Threadripper 2970WX 24-Core Processor): 22x Pull Request resolved: pytorch#26629 Differential Revision: D17527230 Pulled By: ezyang fbshipit-source-id: 0a5a53a88f7eb219617120383a454a01ad78279a
1 parent c643290 commit cf1dbc7

File tree

9 files changed

+42
-17
lines changed

9 files changed

+42
-17
lines changed

aten/src/ATen/core/TensorMethods.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5120,7 +5120,13 @@ inline Tensor Tensor::polygamma(int64_t n) const {
51205120
}
51215121
inline Tensor Tensor::erfinv() const {
51225122
#ifdef USE_STATIC_DISPATCH
5123-
return TypeDefault::erfinv(const_cast<Tensor&>(*this));
5123+
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
5124+
case Backend::CPU:
5125+
return CPUType::erfinv(const_cast<Tensor&>(*this));
5126+
break;
5127+
default:
5128+
AT_ERROR("erfinv not implemented for ", at::toString(type_set()));
5129+
}
51245130
#else
51255131
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfinv", ""}).value();
51265132
return c10::Dispatcher::singleton().callUnboxed<Tensor, const Tensor &>(
@@ -5129,7 +5135,13 @@ inline Tensor Tensor::erfinv() const {
51295135
}
51305136
inline Tensor & Tensor::erfinv_() const {
51315137
#ifdef USE_STATIC_DISPATCH
5132-
return TypeDefault::erfinv_(const_cast<Tensor&>(*this));
5138+
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
5139+
case Backend::CPU:
5140+
return CPUType::erfinv_(const_cast<Tensor&>(*this));
5141+
break;
5142+
default:
5143+
AT_ERROR("erfinv_ not implemented for ", at::toString(type_set()));
5144+
}
51335145
#else
51345146
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfinv_", ""}).value();
51355147
return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &>(

aten/src/ATen/cpu/vec256/vec256_base.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <ATen/Utils.h>
1010
#include <ATen/native/Copy.h>
11+
#include <ATen/native/Math.h>
1112
#include <ATen/NumericUtils.h>
1213
#include <c10/util/C++17.h>
1314
#include <c10/util/BFloat16.h>
@@ -197,6 +198,9 @@ struct Vec256 {
197198
Vec256<T> erfc() const {
198199
return map(std::erfc);
199200
}
201+
Vec256<T> erfinv() const {
202+
return map(calc_erfinv);
203+
}
200204
Vec256<T> exp() const {
201205
return map(std::exp);
202206
}

aten/src/ATen/cpu/vec256/vec256_double.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ template <> class Vec256<double> {
109109
Vec256<double> erfc() const {
110110
return Vec256<double>(Sleef_erfcd4_u15(values));
111111
}
112+
Vec256<double> erfinv() const {
113+
return map(calc_erfinv);
114+
}
112115
Vec256<double> exp() const {
113116
return Vec256<double>(Sleef_expd4_u10(values));
114117
}

aten/src/ATen/cpu/vec256/vec256_float.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ template <> class Vec256<float> {
117117
Vec256<float> erfc() const {
118118
return Vec256<float>(Sleef_erfcf8_u15(values));
119119
}
120+
Vec256<float> erfinv() const {
121+
return map(calc_erfinv);
122+
}
120123
Vec256<float> exp() const {
121124
return Vec256<float>(Sleef_expf8_u10(values));
122125
}

aten/src/ATen/cpu/vml.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ IMPLEMENT_VML_BUG(cos)
106106
// IMPLEMENT_VML_BUG(cosh)
107107
IMPLEMENT_VML_BUG(erf)
108108
IMPLEMENT_VML_BUG(erfc)
109+
IMPLEMENT_VML(erfinv)
109110
IMPLEMENT_VML_BUG(exp)
110111
IMPLEMENT_VML_BUG(expm1)
111112
IMPLEMENT_VML_BUG(floor)
@@ -174,6 +175,7 @@ IMPLEMENT_VML_MKL(cos, Cos)
174175
// IMPLEMENT_VML_MKL(cosh, Cosh)
175176
IMPLEMENT_VML_MKL(erf, Erf)
176177
IMPLEMENT_VML_MKL(erfc, Erfc)
178+
IMPLEMENT_VML_MKL(erfinv, ErfInv)
177179
IMPLEMENT_VML_MKL(exp, Exp)
178180
IMPLEMENT_VML_MKL(expm1, Expm1)
179181
IMPLEMENT_VML_MKL(log, Ln)

aten/src/ATen/native/Math.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
#pragma once
2+
13
#include <cstdlib>
24
#include <cmath>
35
#include <limits>
46
#include <type_traits>
57

8+
#ifndef M_PIf
9+
#define M_PIf 3.1415926535f
10+
#endif // M_PIf
11+
612
/* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
713
Below is the copyright.
814
Output was modified to be inf or -inf when input is 1 or -1. */

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ Tensor& ceil_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(
6565
Tensor ceil(const Tensor& self) { return unary_op_impl(self, at::ceil_out); }
6666
Tensor& ceil_(Tensor& self) { return unary_op_impl_(self, at::ceil_out); }
6767

68-
Tensor& erfinv_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, erfinv_stub); }
69-
Tensor erfinv(const Tensor& self) { return unary_op_impl(self, at::erfinv_out); }
70-
Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, at::erfinv_out); }
71-
7268
Tensor& floor_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, floor_stub); }
7369
Tensor floor(const Tensor& self) { return unary_op_impl(self, at::floor_out); }
7470
Tensor& floor_(Tensor& self) { return unary_op_impl_(self, at::floor_out); }
@@ -281,6 +277,7 @@ IMPLEMENT_UNARY_OP_VEC(cos)
281277
IMPLEMENT_UNARY_OP_VEC(cosh)
282278
IMPLEMENT_UNARY_OP_VEC(erf)
283279
IMPLEMENT_UNARY_OP_VEC(erfc)
280+
IMPLEMENT_UNARY_OP_VEC_CUDA(erfinv)
284281
IMPLEMENT_UNARY_OP_VEC(exp)
285282
IMPLEMENT_UNARY_OP_VEC(expm1)
286283
IMPLEMENT_UNARY_OP_VEC(frac)

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,6 @@ static void cosh_kernel(TensorIterator& iter) {
156156
});
157157
}
158158

159-
static void erfinv_kernel(TensorIterator& iter) {
160-
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "erfinv_cpu", [&]() {
161-
cpu_kernel(
162-
iter,
163-
[=](scalar_t a) -> scalar_t { return calc_erfinv(a); });
164-
});
165-
}
166-
167159
static void digamma_kernel(TensorIterator& iter) {
168160
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "digamma", [&]() {
169161
cpu_kernel(
@@ -337,7 +329,6 @@ REGISTER_DISPATCH(neg_stub, &neg_kernel);
337329
REGISTER_DISPATCH(sign_stub, &sign_kernel);
338330
REGISTER_DISPATCH(sinh_stub, &sinh_kernel);
339331
REGISTER_DISPATCH(cosh_stub, &cosh_kernel);
340-
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel);
341332
REGISTER_DISPATCH(digamma_stub, &digamma_kernel);
342333
REGISTER_DISPATCH(trigamma_stub, &trigamma_kernel);
343334
REGISTER_DISPATCH(polygamma_stub, &polygamma_kernel);
@@ -355,6 +346,7 @@ IMPLEMENT_FLOAT_KERNEL(FLOATING, cos)
355346
// IMPLEMENT_FLOAT_KERNEL(FLOATING, cosh)
356347
IMPLEMENT_FLOAT_KERNEL(FLOATING, erf)
357348
IMPLEMENT_FLOAT_KERNEL(FLOATING, erfc)
349+
IMPLEMENT_FLOAT_KERNEL(FLOATING, erfinv)
358350
IMPLEMENT_FLOAT_KERNEL(FLOATING, exp)
359351
IMPLEMENT_FLOAT_KERNEL(FLOATING, expm1)
360352
IMPLEMENT_FLOAT_KERNEL(FLOATING, floor)

aten/src/ATen/native/native_functions.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4842,17 +4842,23 @@
48424842
use_c10_dispatcher: full
48434843
supports_named_tensor: True
48444844
variants: method, function
4845+
dispatch:
4846+
CPU: erfinv
4847+
CUDA: erfinv
48454848

48464849
- func: erfinv_(Tensor(a!) self) -> Tensor(a!)
48474850
use_c10_dispatcher: unboxed_only
48484851
supports_named_tensor: True
48494852
variants: method
4853+
dispatch:
4854+
CPU: _erfinv__cpu
4855+
CUDA: _erfinv__cuda
48504856

48514857
- func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
48524858
supports_named_tensor: True
48534859
dispatch:
4854-
CPU: erfinv_out
4855-
CUDA: erfinv_out
4860+
CPU: _erfinv_out_cpu
4861+
CUDA: _erfinv_out_cuda
48564862

48574863
- func: sign(Tensor self) -> Tensor
48584864
use_c10_dispatcher: unboxed_only

0 commit comments

Comments
 (0)