
Commit 7bdc0c1

xuhdev authored and facebook-github-bot committed
Move the CUDA implementation of trunc to ATen. (pytorch#25423)
Summary: Pull Request resolved: pytorch#25423
Fix pytorch#24650
Test Plan: Imported from OSS
Differential Revision: D17397489
Pulled By: VitalyFedyunin
fbshipit-source-id: 933f915a44ff9b7803ddb2708bf0e723433ee0b6
1 parent d6ee584 commit 7bdc0c1

10 files changed: +27 −34 lines
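For context, trunc rounds each element toward zero, i.e. it simply drops the fractional part. A minimal host-side sketch of the semantics being ported (illustrative only, not PyTorch code):

// trunc semantics: round toward zero; negative inputs move up toward zero.
#include <cmath>
#include <cstdio>

int main() {
  const double xs[] = {-1.7, -0.5, 0.5, 1.7};
  for (double x : xs) {
    std::printf("trunc(%+.1f) = %+.1f\n", x, std::trunc(x));
  }
  return 0;
}

// Prints: trunc(-1.7) = -1.0, trunc(-0.5) = -0.0,
//         trunc(+0.5) = +0.0, trunc(+1.7) = +1.0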

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 14 deletions

@@ -1410,20 +1410,6 @@
       output: True
     - THTensor* self
 ]]
-[[
-  name: _th_trunc
-  cname: trunc
-  types:
-    - floating_point
-  backends:
-    - CUDA
-  variants: function
-  return: argument 0
-  arguments:
-    - arg: THTensor* result
-      output: True
-    - THTensor* self
-]]
 [[
   name: _th_frac_
   types:

aten/src/ATen/core/TensorMethods.h

Lines changed: 1 addition & 7 deletions

@@ -2559,13 +2559,7 @@ inline Tensor Tensor::trunc() const {
 }
 inline Tensor & Tensor::trunc_() const {
 #ifdef USE_STATIC_DISPATCH
-    switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
-        case Backend::CPU:
-            return CPUType::trunc_(const_cast<Tensor&>(*this));
-            break;
-        default:
-            AT_ERROR("trunc_ not implemented for ", at::toString(type_set()));
-    }
+    return TypeDefault::trunc_(const_cast<Tensor&>(*this));
 #else
     static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::trunc_", ""}).value();
     return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &>(

(With a single native implementation now covering both backends, the static-dispatch path no longer needs a per-backend switch and falls through to TypeDefault::trunc_.)

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 4 additions & 1 deletion

@@ -85,6 +85,10 @@ Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out
 Tensor rsqrt(const Tensor& self) { return unary_op_impl(self, at::rsqrt_out); }
 Tensor& rsqrt_(Tensor& self) { return unary_op_impl_(self, at::rsqrt_out); }

+Tensor& trunc_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, trunc_stub); }
+Tensor trunc(const Tensor& self) { return unary_op_impl(self, at::trunc_out); }
+Tensor& trunc_(Tensor& self) { return unary_op_impl_(self, at::trunc_out); }
+
 Tensor& neg_out(Tensor& result, const Tensor& self) {
   TORCH_CHECK(self.scalar_type() != kBool,
               "Negation, the `-` operator, on a bool tensor is not supported. "
@@ -291,7 +295,6 @@ IMPLEMENT_UNARY_OP_VEC(sinh)
 IMPLEMENT_UNARY_OP_VEC(sqrt)
 IMPLEMENT_UNARY_OP_VEC(tan)
 IMPLEMENT_UNARY_OP_VEC(tanh)
-IMPLEMENT_UNARY_OP_VEC(trunc)
 IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma)

 DEFINE_DISPATCH(abs_stub);
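The three new functions are one-liners because the shared unary_op_impl* helpers build a TensorIterator and call through trunc_stub; each backend then registers its kernel into that stub (see the REGISTER_DISPATCH line in the CUDA file below). A self-contained toy analogue of the stub pattern, with hypothetical names — this is not ATen's actual code:

#include <cmath>
#include <cstdio>
#include <stdexcept>

// Toy stand-in for ATen's DispatchStub: a per-op kernel table indexed by backend.
enum class DeviceType { CPU, CUDA, COUNT };

using UnaryFn = void (*)(float* out, const float* in, int n);

struct TruncStub {
  UnaryFn table[static_cast<int>(DeviceType::COUNT)] = {};
  void operator()(DeviceType d, float* out, const float* in, int n) const {
    UnaryFn f = table[static_cast<int>(d)];
    if (!f) throw std::runtime_error("trunc: no kernel registered for this backend");
    f(out, in, n);
  }
};
TruncStub trunc_stub;

// Stand-in for REGISTER_DISPATCH: fills the table during static initialization.
struct RegisterTrunc {
  RegisterTrunc(DeviceType d, UnaryFn f) {
    trunc_stub.table[static_cast<int>(d)] = f;
  }
};

static void trunc_kernel_cpu(float* out, const float* in, int n) {
  for (int i = 0; i < n; ++i) out[i] = std::trunc(in[i]);
}
static RegisterTrunc reg_cpu(DeviceType::CPU, trunc_kernel_cpu);

int main() {
  float in[] = {-1.7f, 2.5f}, out[2];
  trunc_stub(DeviceType::CPU, out, in, 2);  // operator code never names a backend
  std::printf("%g %g\n", out[0], out[1]);   // prints: -1 2
  return 0;
}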

aten/src/ATen/native/cuda/CUDAUnaryOps.cpp

Lines changed: 0 additions & 1 deletion

@@ -86,6 +86,5 @@ IMPLEMENT_UNARY_OP_PREQUEL(sinh)
 IMPLEMENT_UNARY_OP_PREQUEL(sqrt)
 IMPLEMENT_UNARY_OP_PREQUEL(tan)
 IMPLEMENT_UNARY_OP_PREQUEL(tanh)
-IMPLEMENT_UNARY_OP_PREQUEL(trunc)

 }}

aten/src/ATen/native/cuda/UnaryOpsKernel.cu

Lines changed: 19 additions & 0 deletions

@@ -75,6 +75,24 @@ void round_kernel_cuda(TensorIterator& iter) {
   });
 }

+// We manually overload trunc because std::trunc does not work with ROCm.
+template <typename scalar_t>
+__host__ __device__ static inline scalar_t trunc_wrapper(scalar_t a) {
+  return static_cast<scalar_t>(::truncf(static_cast<float>(a)));
+}
+
+__host__ __device__ static inline double trunc_wrapper(double a) {
+  return ::trunc(a);
+}
+
+void trunc_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "trunc_cuda", [&]() {
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return trunc_wrapper(a);
+    });
+  });
+}
+
 void rsqrt_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "rsqrt_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
@@ -147,6 +165,7 @@ REGISTER_DISPATCH(neg_stub, &neg_kernel_cuda);
 REGISTER_DISPATCH(round_stub, &round_kernel_cuda);
 REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda);
 REGISTER_DISPATCH(sign_stub, &sign_kernel_cuda);
+REGISTER_DISPATCH(trunc_stub, &trunc_kernel_cuda);
 REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
 REGISTER_DISPATCH(digamma_stub, &digamma_kernel_cuda);
 REGISTER_DISPATCH(polygamma_stub, &polygamma_kernel_cuda);
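The wrapper above exists because, per the commit's own comment, std::trunc did not work in device code under ROCm at the time: every dtype except double is widened to float and truncated with ::truncf, while double keeps a double-precision ::trunc through its own overload. A standalone CUDA sketch of the same overload trick (kernel and buffer names here are made up for illustration):

#include <cstdio>
#include <cuda_runtime.h>

// Generic path: widen to float, truncate with ::truncf, cast back.
template <typename scalar_t>
__host__ __device__ static inline scalar_t trunc_wrapper(scalar_t a) {
  return static_cast<scalar_t>(::truncf(static_cast<float>(a)));
}

// double keeps full precision through its own overload.
__host__ __device__ static inline double trunc_wrapper(double a) {
  return ::trunc(a);
}

__global__ void trunc_demo(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = trunc_wrapper(in[i]);  // picks the float template path
}

int main() {
  const int n = 4;
  float h_in[n] = {-1.7f, -0.5f, 0.5f, 2.9f}, h_out[n];
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  trunc_demo<<<1, n>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) std::printf("%g ", h_out[i]);  // -1 -0 0 2
  std::printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}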

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 5 deletions

@@ -2729,15 +2729,12 @@
   use_c10_dispatcher: unboxed_only
   supports_named_tensor: True
   variants: function, method
-  dispatch:
-    CPU: _trunc__cpu
-    CUDA: _trunc__cuda

 - func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   supports_named_tensor: True
   dispatch:
-    CPU: _trunc_out_cpu
-    CUDA: _trunc_out_cuda
+    CPU: trunc_out
+    CUDA: trunc_out

 - func: type_as(Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: full

(The in-place variant loses its per-backend dispatch section entirely and resolves through the default implementation, while trunc.out now points both CPU and CUDA at the shared native trunc_out.)

aten/src/THC/THCNumerics.cuh

Lines changed: 0 additions & 3 deletions

@@ -212,7 +212,6 @@ struct THCNumerics<at::Half> {
   static inline __host__ __device__ at::Half cos(at::Half a) { return ::cos(a); }
   static inline __host__ __device__ at::Half sin(at::Half a) { return ::sin(a); }
   static inline __host__ __device__ at::Half sqrt(at::Half a) { return ::sqrt(a); }
-  static inline __host__ __device__ at::Half trunc(at::Half a) { return ::trunc(a); }
   static inline __host__ __device__ at::Half acos(at::Half a) { return ::acos(a); }
   static inline __host__ __device__ at::Half cosh(at::Half a) { return ::cosh(a); }
   static inline __host__ __device__ at::Half asin(at::Half a) { return ::asin(a); }
@@ -290,7 +289,6 @@
   static inline __host__ __device__ float cos (float a) { return cosf(a); }
   static inline __host__ __device__ float sin (float a) { return sinf(a); }
   static inline __host__ __device__ float sqrt (float a) { return sqrtf(a); }
-  static inline __host__ __device__ float trunc(float a) { return truncf(a); }
   static inline __host__ __device__ float acos (float a) { return acosf(a); }
   static inline __host__ __device__ float cosh (float a) { return coshf(a); }
   static inline __host__ __device__ float acosh(float a) { return acoshf(a); }
@@ -343,7 +341,6 @@
   static inline __host__ __device__ double cos (double a) { return ::cos(a); }
   static inline __host__ __device__ double sin (double a) { return ::sin(a); }
   static inline __host__ __device__ double sqrt (double a) { return ::sqrt(a); }
-  static inline __host__ __device__ double trunc(double a) { return ::trunc(a); }
   static inline __host__ __device__ double acos (double a) { return ::acos(a); }
   static inline __host__ __device__ double cosh (double a) { return ::cosh(a); }
   static inline __host__ __device__ double acosh(double a) { return ::acosh(a); }

aten/src/THC/generic/THCTensorMathPointwise.cu

Lines changed: 0 additions & 1 deletion

@@ -207,7 +207,6 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(expm1, THCNumerics<scalar_t>::expm1, Real)
 IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cos, THCNumerics<scalar_t>::cos, Real)
 IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( sin, THCNumerics<scalar_t>::sin, Real)
 IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( sqrt, THCNumerics<scalar_t>::sqrt, Real)
-IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(trunc, THCNumerics<scalar_t>::trunc, Real)

 IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( acos, THCNumerics<scalar_t>::acos, Real)
 IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cosh, THCNumerics<scalar_t>::cosh, Real)

aten/src/THC/generic/THCTensorMathPointwise.h

Lines changed: 0 additions & 1 deletion

@@ -34,7 +34,6 @@ THC_API void THCTensor_(tanh)(THCState *state, THCTensor *self, THCTensor *src);
 THC_API void THCTensor_(erf)(THCState *state, THCTensor *self, THCTensor *src);
 THC_API void THCTensor_(erfc)(THCState *state, THCTensor *self, THCTensor *src);
 THC_API void THCTensor_(sqrt)(THCState *state, THCTensor *self, THCTensor *src);
-THC_API void THCTensor_(trunc)(THCState *state, THCTensor *self, THCTensor *src);
 THC_API void THCTensor_(frac)(THCState *state, THCTensor *self, THCTensor *src);

 THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src);

test/test_torch.py

Lines changed: 1 addition & 1 deletion

@@ -11836,7 +11836,7 @@ def test_unary_out_op_mem_overlap(self, device):
             ("tanh", doubles, True, True, 'cpu'),
             ("tanh", doubles, False, True, 'cuda'),
             ("trunc", doubles, True, True, 'cpu'),
-            ("trunc", doubles, False, True, 'cuda')
+            ("trunc", doubles, True, True, 'cuda')
         ]

         for (fn, inputs, has_input_output_mem_overlap_check,

(Because CUDA trunc now runs through TensorIterator, it performs the same input/output memory-overlap check as the CPU path, so the has_input_output_mem_overlap_check flag for the 'cuda' case flips to True.)
