
Commit ae0732c

xuhdev authored and facebook-github-bot committed
Speed up an integer to the power of a positive integer on CPU (pytorch#26020)
Summary: Currently, integer scalar exponents are always cast to double. This commit avoids that cast when the tensor is also integral and the scalar exponent is positive, which speeds up the operation.

Benchmark (Debian Buster, g++ 8, Intel(R) Xeon(R) E-2136 CPU @ 3.30GHz, Debug build, Turbo turned off):

```python
import timeit

for n, t in [(1000, 13000), (10_000, 1300)]:
    for e in (2, 3, 4):
        for dtype in ('torch.int16', 'torch.int32', 'torch.int64'):
            print(f'a.pow({e}) (a.numel() == {n}) for {t} times')
            print(f'dtype {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(
                f'a.pow({e})',
                setup=f'import torch; a = torch.arange({n}, device="cpu", dtype={dtype})',
                number=t))
```

Before:

```
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.6958350749996498
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		0.7989626339999631
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		0.7973162800003593
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.8660746679997828
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		0.8101709959996697
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		0.8135280149999744
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		5.010833072999958
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		4.801007671999741
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		3.963344578000033
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		1.6216251330001796
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		0.5672429639998882
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		0.5544572270000572
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		1.656308512999658
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		1.502670819999821
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		0.5757876879997639
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		4.775718216999849
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		4.754745475000163
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		3.737249878000057
```

After:

```
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.1006453190002503
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		1.0849009019998448
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		1.093259106000005
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.0859826279997833
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		1.1076840900000207
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		1.0755480369998622
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.918211066999902
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		1.9183043200000611
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		1.930021430999659
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		0.7271483560002707
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		0.7289002070001516
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		0.7267536800000016
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		0.7301799359997858
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		0.7289195180001116
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		0.7270008230002531
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		1.5354506029998447
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		1.528263066999898
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		1.5369428439998956
```

---

Best viewed with whitespace changes turned off

Pull Request resolved: pytorch#26020

Differential Revision: D17485400

Pulled By: VitalyFedyunin

fbshipit-source-id: 3a16b074825a5aab0f7e7af3d8100f9e4b7011a3
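As a minimal illustration of the two code paths described in the summary (dtypes and values as stated in the summary and the kernel comments below; this snippet is not part of the commit):

```python
import torch

# Fast path targeted by this commit: an integral tensor raised to a positive
# integer exponent stays in integer arithmetic (no cast of the exponent to double).
a = torch.arange(5, dtype=torch.int64)
print(a.pow(3))  # tensor([ 0,  1,  8, 27, 64]), dtype stays torch.int64

# Non-integer exponents still go through the double-cast path; as the kernel
# comment notes, tensor([4]).pow(0.5) evaluates to tensor([2]).
print(torch.tensor([4]).pow(0.5))
```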
1 parent 66d2750 commit ae0732c

File tree (2 files changed, +128 / -97 lines changed):

aten/src/ATen/native/cpu/PowKernel.cpp
test/test_torch.py


aten/src/ATen/native/cpu/PowKernel.cpp

Lines changed: 68 additions & 52 deletions
@@ -35,10 +35,8 @@ void pow_tensor_tensor_kernel(TensorIterator& iter) {
 }
 
 void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
-  // Casting exponent to double(not tensor.dtype) allows powering integral
-  // tensors to float exponent e.g. tensor([4]).pow(0.5) will be tensor([2])
-  const auto exp = exp_scalar.to<double>();
   if (isFloatingType(iter.dtype())) {
+    const auto exp = exp_scalar.to<double>();
     // Floating types allow AVX2 vector optimizations for pow/sqrt/rsqrt:
     AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "pow", [&]() {
       using Vec = Vec256<scalar_t>;
@@ -98,55 +96,73 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
   // Trying to implement pow/sqrt/rsqrt as loop in vec256_int.h does not allow
   // powering integral tensor to float exponent. That's why we need this code
   // duplication:
-  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
-    if (exp == 0.5) {
-      cpu_kernel(iter,
-        [](scalar_t base) -> scalar_t {
-          return std::sqrt(static_cast<long double>(base));
-        }
-      );
-    } else if (exp == 2) {
-      cpu_kernel(iter,
-        [](scalar_t base) -> scalar_t {
-          const auto ld_base = static_cast<long double>(base);
-          return ld_base * ld_base;
-        }
-      );
-    } else if (exp == 3) {
-      cpu_kernel(iter,
-        [](scalar_t base) -> scalar_t {
-          const auto ld_base = static_cast<long double>(base);
-          return ld_base * ld_base * ld_base;
-        }
-      );
-    } else if (exp == -0.5) {
-      cpu_kernel(iter,
-        [](scalar_t base) -> scalar_t {
-          return 1.0 / std::sqrt(static_cast<long double>(base));
-        }
-      );
-    } else if (exp == -1) {
-      cpu_kernel(iter,
-        [](scalar_t base) -> scalar_t {
-          return 1.0 / static_cast<long double>(base);
-        }
-      );
-    } else if (exp == -2) {
-      cpu_kernel(iter,
-        [](scalar_t base) -> scalar_t {
-          const auto ld_base = static_cast<long double>(base);
-          return 1.0 / (ld_base * ld_base);
-        }
-      );
-    } else {
-      cpu_kernel(iter,
-        [=](scalar_t base) -> scalar_t {
-          return std::pow(static_cast<long double>(base),
-                          static_cast<long double>(exp));
-        }
-      );
-    }
-  });
+
+  if (exp_scalar.isIntegral(true) && exp_scalar.to<int64_t>() >= 0) {
+    // Specifically deal with an integer to the power of a positive integer for better efficiency.
+    const auto exp = exp_scalar.to<int64_t>();
+
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
+      switch (exp) {
+        case 2:
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return base * base;
+            }
+          );
+          break;
+        case 3:
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return base * base * base;
+            }
+          );
+          break;
+        default:
+          cpu_kernel(iter,
+            [=](scalar_t base) -> scalar_t {
+              return std::pow(base, exp);
+            }
+          );
+      }
+    });
+  } else {
+    // Casting exponent to double(not tensor.dtype) allows powering integral
+    // tensors to float exponent e.g. tensor([4]).pow(0.5) will be tensor([2])
+    const auto exp = exp_scalar.to<double>();
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
+      if (exp == 0.5) {
+        cpu_kernel(iter,
+          [](scalar_t base) -> scalar_t {
+            return std::sqrt(static_cast<long double>(base));
+          }
+        );
+      } else if (exp == -0.5) {
+        cpu_kernel(iter,
+          [](scalar_t base) -> scalar_t {
+            return 1.0 / std::sqrt(static_cast<long double>(base));
+          }
+        );
+      } else if (exp == -1) {
+        cpu_kernel(iter,
+          [](scalar_t base) -> scalar_t {
+            return 1.0 / static_cast<long double>(base);
+          }
+        );
+      } else if (exp == -2) {
+        cpu_kernel(iter,
+          [](scalar_t base) -> scalar_t {
+            return 1.0 / (base * base);
+          }
+        );
+      } else {
+        cpu_kernel(iter,
+          [=](scalar_t base) -> scalar_t {
+            return std::pow(static_cast<long double>(base), exp);
+          }
+        );
+      }
+    });
+  }
 }
 }
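For readers skimming the diff, here is a small, hedged sketch of how the new dispatch is reached from Python. Which branch a floating-point exponent with an integral value (e.g. 4.0) takes is an inference from the exp_scalar.isIntegral(true) check above, not something stated elsewhere in the commit:

```python
import torch

a = torch.arange(10, dtype=torch.int32)

# A Python int exponent arrives as an integral Scalar, so it takes the new
# switch-based fast path (2 and 3 become plain multiplications, anything else
# falls back to std::pow on the integer values).
b = a.pow(4)

# A Python float exponent arrives as a floating-point Scalar; presumably it
# still takes the double-cast branch even though 4.0 has an integral value,
# because the kernel tests exp_scalar.isIntegral(true), not the value itself.
c = a.pow(4.0)

print(b, c)
```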

test/test_torch.py

Lines changed: 60 additions & 45 deletions
@@ -1347,51 +1347,6 @@ def test_baddbmm(self):
         res6 = torch.baddbmm(.1, res2, .5, b1, b2)
         self.assertEqual(res6, res2 * .1 + res * .5)
 
-    def test_pow(self):
-        # [res] torch.pow([res,] x)
-
-        # pow has dedicated implementation for different exponents
-        for exponent in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]:
-            # base - tensor, exponent - number
-            # contiguous
-            m1 = torch.rand(100, 100) + 0.5
-            res1 = torch.pow(m1[4], exponent)
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(m1[4][i], exponent)
-            self.assertEqual(res1, res2)
-
-            # non-contiguous
-            m1 = torch.rand(100, 100) + 0.5
-            res1 = torch.pow(m1[:, 4], exponent)
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(m1[i, 4], exponent)
-            self.assertEqual(res1, res2)
-
-            # base - number, exponent - tensor
-            # contiguous
-            m1 = torch.randn(100, 100)
-            res1 = torch.pow(3, m1[4])
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(3, m1[4, i])
-            self.assertEqual(res1, res2)
-
-            # non-contiguous
-            m1 = torch.randn(100, 100)
-            res1 = torch.pow(3, m1[:, 4])
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(3, m1[i][4])
-            self.assertEqual(res1, res2)
-
-        # resize behavior for exp == 1
-        m1 = torch.randn(2, 2)
-        out = torch.randn([0])
-        torch.pow(m1, 1, out=out)
-        self.assertEqual(out, m1)
-
     def _test_cop(self, torchfn, mathfn):
         def reference_implementation(res2):
             for i, j in iter_indices(sm1):
@@ -7022,6 +6977,66 @@ def test_diagonal(self, device):
         expected = torch.diag(x, 17)
         self.assertEqual(result, expected)
 
+    def test_pow(self, device):
+        # [res] torch.pow([res,] x)
+
+        # pow has dedicated implementation for different exponents
+        for dtype in torch.testing.get_all_math_dtypes(device):
+
+            # This test won't work on torch.half because math.pow will generate a much more accurate result. We skip it
+            # for now.
+            if dtype == torch.half:
+                continue
+
+            m1 = torch.empty(0, dtype=dtype, device=device)
+            if m1.is_floating_point():
+                m1 = torch.rand(100, 100, dtype=dtype, device=device) + 0.5
+            else:
+                # math.pow will overflow and throw exceptions for large integers
+                range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
+                m1 = torch.randint(1, range_high, (100, 100), dtype=dtype, device=device)
+
+            for num in [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3]:
+                if isinstance(num, int) and num < 0 and not m1.is_floating_point():
+                    with self.assertRaisesRegex(RuntimeError,
+                                                r'Integers to negative integer powers are not allowed\.'):
+                        torch.pow(m1[4], num)
+                else:
+                    # base - tensor, exponent - number
+                    # contiguous
+                    res1 = torch.pow(m1[4], num)
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(m1[4][i], num)
+                    self.assertEqual(res1, res2)
+
+                    # non-contiguous
+                    res1 = torch.pow(m1[:, 4], num)
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(m1[i, 4], num)
+                    self.assertEqual(res1, res2)
+
+                    # base - number, exponent - tensor
+                    # contiguous
+                    res1 = torch.pow(3, m1[4])
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(3, m1[4, i])
+                    self.assertEqual(res1, res2)
+
+                    # non-contiguous
+                    res1 = torch.pow(3, m1[:, 4])
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(3, m1[i][4])
+                    self.assertEqual(res1, res2)
+
+            # resize behavior for exp == 1
+            out = torch.zeros(1, dtype=dtype, device=device)
+            torch.pow(m1, 1, out=out)
+            self.assertEqual(out, m1)
+
     def test_neg(self, device):
         int_types = [torch.int, torch.short, torch.int8, torch.uint8]
         float_types = [torch.float, torch.double, torch.long]
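The error branch exercised by the new test can be reproduced directly; the message below is the one the assertRaisesRegex pattern above matches:

```python
import torch

# Integral tensors raised to negative integer exponents are rejected rather
# than silently truncated; the new test asserts on this RuntimeError message.
try:
    torch.arange(1, 5, dtype=torch.int64).pow(-1)
except RuntimeError as err:
    print(err)  # Integers to negative integer powers are not allowed.
```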

0 commit comments
