Skip to content

Commit da2eeba

Browse files
Correlation via fft implementation (#2203)
* Added keyword `method` to `correlate` function similar to [scipy correlate](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlate.html) * If `method == 'auto'` method is choosing automatically between `direct` and `fft` * Added implementation of fft-based correlation * fft-based implementation may have accuracy issues, so it is validated in non-standard way. Depends on: #2180, #2202 --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
1 parent 2ac196c commit da2eeba

File tree

3 files changed

+212
-33
lines changed

3 files changed

+212
-33
lines changed

dpnp/dpnp_iface_statistics.py

+102-22
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
import math
41+
4042
import dpctl.tensor as dpt
4143
import dpctl.tensor._tensor_elementwise_impl as ti
4244
import dpctl.utils as dpu
@@ -481,24 +483,66 @@ def _get_padding(a_size, v_size, mode):
481483
r_pad = v_size - l_pad - 1
482484
elif mode == "full":
483485
l_pad, r_pad = v_size - 1, v_size - 1
484-
else:
486+
else: # pragma: no cover
485487
raise ValueError(
486488
f"Unknown mode: {mode}. Only 'valid', 'same', 'full' are supported."
487489
)
488490

489491
return l_pad, r_pad
490492

491493

492-
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
494+
def _choose_conv_method(a, v, rdtype):
495+
assert a.size >= v.size
496+
if rdtype == dpnp.bool:
497+
# to avoid accuracy issues
498+
return "direct"
499+
500+
if v.size < 10**4 or a.size < 10**4:
501+
# direct method is faster for small arrays
502+
return "direct"
503+
504+
if dpnp.issubdtype(rdtype, dpnp.integer):
505+
max_a = int(dpnp.max(dpnp.abs(a)))
506+
sum_v = int(dpnp.sum(dpnp.abs(v)))
507+
max_value = int(max_a * sum_v)
508+
509+
default_float = dpnp.default_float_type(a.sycl_device)
510+
if max_value > 2 ** numpy.finfo(default_float).nmant - 1:
511+
# can't represent the result in the default float type
512+
return "direct" # pragma: no covers
513+
514+
if dpnp.issubdtype(rdtype, dpnp.number):
515+
return "fft"
516+
517+
raise ValueError(f"Unsupported dtype: {rdtype}") # pragma: no cover
518+
519+
520+
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype):
493521
queue = a.sycl_queue
522+
device = a.sycl_device
523+
524+
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
525+
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
526+
527+
if supported_dtype is None: # pragma: no cover
528+
raise ValueError(
529+
f"function does not support input types "
530+
f"({a.dtype.name}, {v.dtype.name}), "
531+
"and the inputs could not be coerced to any "
532+
f"supported types. List of supported types: "
533+
f"{[st.name for st in supported_types]}"
534+
)
535+
536+
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
537+
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
494538

495-
usm_type = dpu.get_coerced_usm_type([a.usm_type, v.usm_type])
496-
out_size = l_pad + r_pad + a.size - v.size + 1
539+
usm_type = dpu.get_coerced_usm_type([a_casted.usm_type, v_casted.usm_type])
540+
out_size = l_pad + r_pad + a_casted.size - v_casted.size + 1
497541
# out type is the same as input type
498-
out = dpnp.empty_like(a, shape=out_size, usm_type=usm_type)
542+
out = dpnp.empty_like(a_casted, shape=out_size, usm_type=usm_type)
499543

500-
a_usm = dpnp.get_usm_ndarray(a)
501-
v_usm = dpnp.get_usm_ndarray(v)
544+
a_usm = dpnp.get_usm_ndarray(a_casted)
545+
v_usm = dpnp.get_usm_ndarray(v_casted)
502546
out_usm = dpnp.get_usm_ndarray(out)
503547

504548
_manager = dpu.SequentialOrderManager[queue]
@@ -516,7 +560,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
516560
return out
517561

518562

519-
def correlate(a, v, mode="valid"):
563+
def _convolve_fft(a, v, l_pad, r_pad, rtype):
564+
assert a.size >= v.size
565+
assert l_pad < v.size
566+
567+
# +1 is needed to avoid circular convolution
568+
padded_size = a.size + r_pad + 1
569+
fft_size = 2 ** int(math.ceil(math.log2(padded_size)))
570+
571+
af = dpnp.fft.fft(a, fft_size) # pylint: disable=no-member
572+
vf = dpnp.fft.fft(v, fft_size) # pylint: disable=no-member
573+
574+
r = dpnp.fft.ifft(af * vf) # pylint: disable=no-member
575+
if dpnp.issubdtype(rtype, dpnp.floating):
576+
r = r.real
577+
elif dpnp.issubdtype(rtype, dpnp.integer) or rtype == dpnp.bool:
578+
r = r.real.round()
579+
580+
start = v.size - 1 - l_pad
581+
end = padded_size - 1
582+
583+
return r[start:end]
584+
585+
586+
def correlate(a, v, mode="valid", method="auto"):
520587
r"""
521588
Cross-correlation of two 1-dimensional sequences.
522589
@@ -541,6 +608,20 @@ def correlate(a, v, mode="valid"):
541608
is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
542609
543610
Default: ``"valid"``.
611+
method : {"auto", "direct", "fft"}, optional
612+
Specifies which method to use to calculate the correlation:
613+
614+
- `"direct"` : The correlation is determined directly from sums.
615+
- `"fft"` : The Fourier Transform is used to perform the calculations.
616+
This method is faster for long sequences but can have accuracy issues.
617+
- `"auto"` : Automatically chooses direct or Fourier method based on
618+
an estimate of which is faster.
619+
620+
Note: Use of the FFT convolution on input containing NAN or INF
621+
will lead to the entire output being NAN or INF.
622+
Use method='direct' when your input contains NAN or INF values.
623+
624+
Default: ``"auto"``.
544625
545626
Returns
546627
-------
@@ -608,20 +689,14 @@ def correlate(a, v, mode="valid"):
608689
f"Received shapes: a.shape={a.shape}, v.shape={v.shape}"
609690
)
610691

611-
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
692+
supported_methods = ["auto", "direct", "fft"]
693+
if method not in supported_methods:
694+
raise ValueError(
695+
f"Unknown method: {method}. Supported methods: {supported_methods}"
696+
)
612697

613698
device = a.sycl_device
614699
rdtype = result_type_for_device([a.dtype, v.dtype], device)
615-
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
616-
617-
if supported_dtype is None: # pragma: no cover
618-
raise ValueError(
619-
f"function does not support input types "
620-
f"({a.dtype.name}, {v.dtype.name}), "
621-
"and the inputs could not be coerced to any "
622-
f"supported types. List of supported types: "
623-
f"{[st.name for st in supported_types]}"
624-
)
625700

626701
if dpnp.issubdtype(v.dtype, dpnp.complexfloating):
627702
v = dpnp.conj(v)
@@ -633,10 +708,15 @@ def correlate(a, v, mode="valid"):
633708

634709
l_pad, r_pad = _get_padding(a.size, v.size, mode)
635710

636-
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
637-
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
711+
if method == "auto":
712+
method = _choose_conv_method(a, v, rdtype)
638713

639-
r = _run_native_sliding_dot_product1d(a_casted, v_casted, l_pad, r_pad)
714+
if method == "direct":
715+
r = _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype)
716+
elif method == "fft":
717+
r = _convolve_fft(a, v[::-1], l_pad, r_pad, rdtype)
718+
else: # pragma: no cover
719+
raise ValueError(f"Unknown method: {method}")
640720

641721
if revert:
642722
r = r[::-1]

dpnp/tests/helper.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def assert_dtype_allclose(
1313
check_type=True,
1414
check_only_type_kind=False,
1515
factor=8,
16+
relative_factor=None,
1617
):
1718
"""
1819
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
@@ -183,6 +184,7 @@ def generate_random_numpy_array(
183184
seed_value=None,
184185
low=-10,
185186
high=10,
187+
probability=0.5,
186188
):
187189
"""
188190
Generate a random numpy array with the specified shape and dtype.
@@ -197,23 +199,32 @@ def generate_random_numpy_array(
197199
dtype : str or dtype, optional
198200
Desired data-type for the output array.
199201
If not specified, data type will be determined by numpy.
202+
200203
Default : ``None``
201204
order : {"C", "F"}, optional
202205
Specify the memory layout of the output array.
206+
203207
Default: ``"C"``.
204208
hermitian : bool, optional
205209
If True, generates a Hermitian (symmetric if `dtype` is real) matrix.
210+
206211
Default : ``False``
207212
seed_value : int, optional
208213
The seed value to initialize the random number generator.
214+
209215
Default : ``None``
210216
low : {int, float}, optional
211217
Lower boundary of the generated samples from a uniform distribution.
218+
212219
Default : ``-10``.
213220
high : {int, float}, optional
214221
Upper boundary of the generated samples from a uniform distribution.
222+
215223
Default : ``10``.
224+
probability : float, optional
225+
If dtype is bool, the probability of True. Ignored for other dtypes.
216226
227+
Default : ``0.5``.
217228
Returns
218229
-------
219230
out : numpy.ndarray
@@ -232,9 +243,15 @@ def generate_random_numpy_array(
232243

233244
# dtype=int is needed for 0d arrays
234245
size = numpy.prod(shape, dtype=int)
235-
a = numpy.random.uniform(low, high, size).astype(dtype)
236-
if numpy.issubdtype(a.dtype, numpy.complexfloating):
237-
a += 1j * numpy.random.uniform(low, high, size)
246+
if dtype == dpnp.bool:
247+
a = numpy.random.choice(
248+
[False, True], size, p=[1 - probability, probability]
249+
)
250+
else:
251+
a = numpy.random.uniform(low, high, size).astype(dtype)
252+
253+
if numpy.issubdtype(a.dtype, numpy.complexfloating):
254+
a += 1j * numpy.random.uniform(low, high, size)
238255

239256
a = a.reshape(shape)
240257
if hermitian and a.size > 0:

dpnp/tests/test_statistics.py

+90-8
Original file line numberDiff line numberDiff line change
@@ -180,26 +180,101 @@ def test_corrcoef_scalar(self):
180180

181181

182182
class TestCorrelate:
183+
@staticmethod
184+
def _get_kwargs(mode=None, method=None):
185+
dpnp_kwargs = {}
186+
numpy_kwargs = {}
187+
if mode is not None:
188+
dpnp_kwargs["mode"] = mode
189+
numpy_kwargs["mode"] = mode
190+
if method is not None:
191+
dpnp_kwargs["method"] = method
192+
return dpnp_kwargs, numpy_kwargs
193+
194+
def setup_method(self):
195+
numpy.random.seed(0)
196+
183197
@pytest.mark.parametrize(
184198
"a, v", [([1], [1, 2, 3]), ([1, 2, 3], [1]), ([1, 2, 3], [1, 2])]
185199
)
186200
@pytest.mark.parametrize("mode", [None, "full", "valid", "same"])
187201
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
188-
def test_correlate(self, a, v, mode, dtype):
202+
@pytest.mark.parametrize("method", [None, "auto", "direct", "fft"])
203+
def test_correlate(self, a, v, mode, dtype, method):
189204
an = numpy.array(a, dtype=dtype)
190205
vn = numpy.array(v, dtype=dtype)
191206
ad = dpnp.array(an)
192207
vd = dpnp.array(vn)
193208

194-
if mode is None:
195-
expected = numpy.correlate(an, vn)
196-
result = dpnp.correlate(ad, vd)
197-
else:
198-
expected = numpy.correlate(an, vn, mode=mode)
199-
result = dpnp.correlate(ad, vd, mode=mode)
209+
dpnp_kwargs, numpy_kwargs = self._get_kwargs(mode, method)
210+
211+
expected = numpy.correlate(an, vn, **numpy_kwargs)
212+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
200213

201214
assert_dtype_allclose(result, expected)
202215

216+
@pytest.mark.parametrize("a_size", [1, 100, 10000])
217+
@pytest.mark.parametrize("v_size", [1, 100, 10000])
218+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
219+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
220+
@pytest.mark.parametrize("method", ["auto", "direct", "fft"])
221+
def test_correlate_random(self, a_size, v_size, mode, dtype, method):
222+
an = generate_random_numpy_array(a_size, dtype, probability=0.9)
223+
vn = generate_random_numpy_array(v_size, dtype, probability=0.9)
224+
225+
ad = dpnp.array(an)
226+
vd = dpnp.array(vn)
227+
228+
dpnp_kwargs, numpy_kwargs = self._get_kwargs(mode, method)
229+
230+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
231+
expected = numpy.correlate(an, vn, **numpy_kwargs)
232+
233+
rdtype = result.dtype
234+
if dpnp.issubdtype(rdtype, dpnp.integer):
235+
rdtype = dpnp.default_float_type(ad.device)
236+
237+
if method != "fft" and (
238+
dpnp.issubdtype(dtype, dpnp.integer) or dtype == dpnp.bool
239+
):
240+
# For 'direct' and 'auto' methods, we expect exact results for integer types
241+
assert_array_equal(result, expected)
242+
else:
243+
result = result.astype(rdtype)
244+
if method == "direct":
245+
expected = numpy.correlate(an, vn, **numpy_kwargs)
246+
# For 'direct' method we can use standard validation
247+
# acceptable error depends on the kernel size
248+
# while error grows linearly with the kernel size,
249+
# this empirically found formula provides a good balance
250+
# the resulting factor is 40 for kernel size = 1,
251+
# 400 for kernel size = 100 and 4000 for kernel size = 10000
252+
factor = int(40 * (min(a_size, v_size) ** 0.5))
253+
assert_dtype_allclose(result, expected, factor=factor)
254+
else:
255+
rtol = 1e-3
256+
atol = 1e-3
257+
258+
if rdtype == dpnp.float64 or rdtype == dpnp.complex128:
259+
rtol = 1e-6
260+
atol = 1e-6
261+
elif rdtype == dpnp.bool:
262+
result = result.astype(dpnp.int32)
263+
rdtype = result.dtype
264+
265+
expected = expected.astype(rdtype)
266+
267+
diff = numpy.abs(result.asnumpy() - expected)
268+
invalid = diff > atol + rtol * numpy.abs(expected)
269+
270+
# When using the 'fft' method, we might encounter outliers.
271+
# This usually happens when the resulting array contains values close to zero.
272+
# For these outliers, the relative error can be significant.
273+
# We can tolerate a few such outliers.
274+
max_outliers = 10 if expected.size > 1 else 0
275+
if invalid.sum() > max_outliers:
276+
assert_dtype_allclose(result, expected, factor=1000)
277+
203278
def test_correlate_mode_error(self):
204279
a = dpnp.arange(5)
205280
v = dpnp.arange(3)
@@ -240,7 +315,7 @@ def test_correlate_different_sizes(self, size):
240315
vd = dpnp.array(v)
241316

242317
expected = numpy.correlate(a, v)
243-
result = dpnp.correlate(ad, vd)
318+
result = dpnp.correlate(ad, vd, method="direct")
244319

245320
assert_dtype_allclose(result, expected, factor=20)
246321

@@ -251,6 +326,13 @@ def test_correlate_another_sycl_queue(self):
251326
with pytest.raises(ValueError):
252327
dpnp.correlate(a, v)
253328

329+
def test_correlate_unkown_method(self):
330+
a = dpnp.arange(5)
331+
v = dpnp.arange(3)
332+
333+
with pytest.raises(ValueError):
334+
dpnp.correlate(a, v, method="unknown")
335+
254336

255337
class TestCov:
256338
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)