Skip to content

Commit 6bfb1c7

Browse files
committed
avoid memory-overlap between input and output arrays
1 parent 230d8c1 commit 6bfb1c7

File tree

3 files changed

+76
-48
lines changed

3 files changed

+76
-48
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616

1717
### Fixed
1818
* Fixed a bug for N-D FFTs when both `s` and `out` are given [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)
19+
* Fixed a bug when there is overlapping memory of input and output arrays [gh-216](https://github.com/IntelPython/mkl_fft/pull/216)
1920

2021
## [2.0.0] - 2025-06-03
2122

mkl_fft/_fft_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,10 @@ def _c2c_fftnd_impl(
384384
raise ValueError("Direction of FFT should +1 or -1")
385385

386386
valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64]
387+
inplace_FFT = 0
388+
if x.dtype not in valid_dtypes:
389+
x = x.astype(np.complex128, copy=True)
390+
inplace_FFT = 1
387391
# _direct_fftnd requires complex type, and full-dimensional transform
388392
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
389393
_direct = s is None and axes is None
@@ -393,7 +397,7 @@ def _c2c_fftnd_impl(
393397
xs, xa = _cook_nd_args(x, s, axes)
394398
if _check_shapes_for_direct(xs, x.shape, xa):
395399
_direct = True
396-
_direct = _direct and x.dtype in valid_dtypes
400+
_direct = _direct
397401
else:
398402
_direct = False
399403

@@ -402,10 +406,11 @@ def _c2c_fftnd_impl(
402406
x,
403407
direction=direction,
404408
fsc=fsc,
409+
in_place=inplace_FFT,
405410
out=out,
406411
)
407412
else:
408-
if s is None and x.dtype in valid_dtypes:
413+
if s is None:
409414
x = np.asarray(x)
410415
if out is None:
411416
res = np.empty_like(x, dtype=_output_dtype(x.dtype))
@@ -417,7 +422,7 @@ def _c2c_fftnd_impl(
417422
x,
418423
axes,
419424
_direct_fftnd,
420-
{"direction": direction, "fsc": fsc},
425+
{"direction": direction, "fsc": fsc, "in_place": inplace_FFT},
421426
res,
422427
)
423428
else:

mkl_fft/_pydfti.pyx

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import sys
2929

3030
import numpy as np
3131

32-
if np.lib.NumpyVersion(np.__version__) >= "2.0.0a0":
32+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
3333
from numpy._core._multiarray_tests import internal_overlap
3434
else:
3535
from numpy.core._multiarray_tests import internal_overlap
@@ -389,9 +389,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
389389
x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 0)
390390
x_type = cnp.PyArray_TYPE(x_arr)
391391

392-
if out is not None:
393-
in_place = 0
394-
elif x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
392+
if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
395393
# we can operate in place if requested.
396394
if in_place:
397395
if not cnp.PyArray_ISONESEGMENT(x_arr):
@@ -416,6 +414,29 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
416414
x_type = cnp.PyArray_TYPE(x_arr)
417415
in_place = 1
418416

417+
f_arr = None
418+
if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT:
419+
f_type = cnp.NPY_CFLOAT
420+
else:
421+
f_type = cnp.NPY_CDOUBLE
422+
423+
if out is not None:
424+
out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
425+
_validate_out_array(out, x, out_dtype, axis=axis_, n=n_)
426+
if x is out:
427+
in_place = 1
428+
elif (
429+
_get_element_strides(x) == _get_element_strides(out)
430+
and not np.shares_memory(x, out)
431+
):
432+
# out array that is used in OneMKL c2c FFT must have the same stride
433+
# as input array and must have no common elements with input array.
434+
# If these conditions are not met, we need to allocate a new array,
435+
# which is done later.
436+
# TODO: check to see if the same stride condition can be relaxed
437+
f_arr = <cnp.ndarray> out
438+
in_place = 0
439+
419440
if in_place:
420441
_cache_capsule = _tls_dfti_cache_capsule()
421442
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(
@@ -453,25 +474,14 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
453474
ind[axis_] = slice(0, n_, None)
454475
x_arr = x_arr[tuple(ind)]
455476

456-
return x_arr
457-
else:
458-
if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT:
459-
f_type = cnp.NPY_CFLOAT
477+
if out is not None:
478+
out[...] = x_arr
479+
return out
460480
else:
461-
f_type = cnp.NPY_CDOUBLE
462-
463-
if out is None:
481+
return x_arr
482+
else:
483+
if f_arr is None:
464484
f_arr = _allocate_result(x_arr, n_, axis_, f_type)
465-
else:
466-
out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
467-
_validate_out_array(out, x, out_dtype, axis=axis_, n=n_)
468-
# out array that is used in OneMKL c2c FFT must have the exact same
469-
# stride as input array. If not, we need to allocate a new array.
470-
# TODO: check to see if this condition can be relaxed
471-
if _get_element_strides(x) == _get_element_strides(out):
472-
f_arr = <cnp.ndarray> out
473-
else:
474-
f_arr = _allocate_result(x_arr, n_, axis_, f_type)
475485

476486
# call out-of-place FFT
477487
_cache_capsule = _tls_dfti_cache_capsule()
@@ -612,9 +622,10 @@ def _r2c_fft1d_impl(
612622
# be compared directly.
613623
# TODO: currently instead of this condition, we check both input
614624
# and output to be c_contig or f_contig, relax this condition
625+
# In addition, input and output data sets must have no common elements
615626
c_contig = x.flags.c_contiguous and out.flags.c_contiguous
616627
f_contig = x.flags.f_contiguous and out.flags.f_contiguous
617-
if c_contig or f_contig:
628+
if c_contig or f_contig and not np.shares_memory(x, out):
618629
f_arr = <cnp.ndarray> out
619630
else:
620631
f_arr = _allocate_result(x_arr, f_shape, axis_, f_type)
@@ -715,9 +726,10 @@ def _c2r_fft1d_impl(
715726
# strides cannot be compared directly.
716727
# TODO: currently instead of this condition, we check both input
717728
# and output to be c_contig or f_contig, relax this condition
729+
# Also input and output data sets must have no common elements
718730
c_contig = x.flags.c_contiguous and out.flags.c_contiguous
719731
f_contig = x.flags.f_contiguous and out.flags.f_contiguous
720-
if c_contig or f_contig:
732+
if c_contig or f_contig and not np.shares_memory(x, out):
721733
f_arr = <cnp.ndarray> out
722734
else:
723735
f_arr = _allocate_result(x_arr, n_, axis_, f_type)
@@ -755,13 +767,13 @@ def _c2r_fft1d_impl(
755767

756768

757769
def _direct_fftnd(
758-
x, direction=+1, double fsc=1.0, out=None
770+
x, direction=+1, double fsc=1.0, in_place=0, out=None
759771
):
760772
"""Perform n-dimensional FFT over all axes"""
761773
cdef int err
762774
cdef cnp.ndarray x_arr "xxnd_arrayObject"
763775
cdef cnp.ndarray f_arr "ffnd_arrayObject"
764-
cdef int in_place, x_type, f_type
776+
cdef int x_type, f_type
765777

766778
if direction not in [-1, +1]:
767779
raise ValueError("Direction of FFT should +1 or -1")
@@ -779,7 +791,7 @@ def _direct_fftnd(
779791
raise ValueError("An input argument x is not an array-like object")
780792

781793
# a copy was made, so we can work in place.
782-
in_place = 1 if _datacopied(x_arr, x) else 0
794+
in_place = 1 if _datacopied(x_arr, x) else in_place
783795

784796
x_type = cnp.PyArray_TYPE(x_arr)
785797
if (
@@ -798,15 +810,35 @@ def _direct_fftnd(
798810
assert x_type == cnp.NPY_CDOUBLE
799811
in_place = 1
800812

801-
if out is not None:
802-
in_place = 0
803-
804813
if in_place:
805814
if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_CFLOAT:
806815
in_place = 1
807816
else:
808817
in_place = 0
809818

819+
f_arr = None
820+
if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE:
821+
f_type = cnp.NPY_CDOUBLE
822+
else:
823+
f_type = cnp.NPY_CFLOAT
824+
825+
if out is not None:
826+
out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
827+
_validate_out_array(out, x, out_dtype)
828+
if x is out:
829+
in_place = 1
830+
elif (
831+
_get_element_strides(x) == _get_element_strides(out)
832+
and not np.shares_memory(x, out)
833+
):
834+
# out array that is used in OneMKL c2c FFT must have the same stride
835+
# as input array and must have no common elements with input array.
836+
# If these conditions are not met, we need to allocate a new array,
837+
# which is done later.
838+
# TODO: check to see if the same stride condition can be relaxed
839+
f_arr = <cnp.ndarray> out
840+
in_place = 0
841+
810842
if in_place:
811843
if x_type == cnp.NPY_CDOUBLE:
812844
if direction == 1:
@@ -821,24 +853,14 @@ def _direct_fftnd(
821853
else:
822854
raise ValueError("An input argument x is not complex type array")
823855

824-
return x_arr
825-
else:
826-
if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE:
827-
f_type = cnp.NPY_CDOUBLE
856+
if out is not None:
857+
out[...] = x_arr
858+
return out
828859
else:
829-
f_type = cnp.NPY_CFLOAT
830-
if out is None:
860+
return x_arr
861+
else:
862+
if f_arr is None:
831863
f_arr = _allocate_result(x_arr, -1, 0, f_type)
832-
else:
833-
out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
834-
_validate_out_array(out, x, out_dtype)
835-
# out array that is used in OneMKL c2c FFT must have the exact same
836-
# stride as input array. If not, we need to allocate a new array.
837-
# TODO: check to see if this condition can be relaxed
838-
if _get_element_strides(x) == _get_element_strides(out):
839-
f_arr = <cnp.ndarray> out
840-
else:
841-
f_arr = _allocate_result(x_arr, -1, 0, f_type)
842864

843865
if x_type == cnp.NPY_CDOUBLE:
844866
if direction == 1:

0 commit comments

Comments
 (0)