@@ -29,7 +29,7 @@ import sys
29
29
30
30
import numpy as np
31
31
32
- if np.lib.NumpyVersion(np.__version__) >= " 2.0.0a0 " :
32
+ if np.lib.NumpyVersion(np.__version__) >= " 2.0.0 " :
33
33
from numpy._core._multiarray_tests import internal_overlap
34
34
else :
35
35
from numpy.core._multiarray_tests import internal_overlap
@@ -389,9 +389,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
389
389
x_arr = _process_arguments(x, n, axis, & axis_, & n_, & in_place, & xnd, 0 )
390
390
x_type = cnp.PyArray_TYPE(x_arr)
391
391
392
- if out is not None :
393
- in_place = 0
394
- elif x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
392
+ if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
395
393
# we can operate in place if requested.
396
394
if in_place:
397
395
if not cnp.PyArray_ISONESEGMENT(x_arr):
@@ -416,6 +414,29 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
416
414
x_type = cnp.PyArray_TYPE(x_arr)
417
415
in_place = 1
418
416
417
+ f_arr = None
418
+ if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT:
419
+ f_type = cnp.NPY_CFLOAT
420
+ else :
421
+ f_type = cnp.NPY_CDOUBLE
422
+
423
+ if out is not None :
424
+ out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
425
+ _validate_out_array(out, x, out_dtype, axis = axis_, n = n_)
426
+ if x is out:
427
+ in_place = 1
428
+ elif (
429
+ _get_element_strides(x) == _get_element_strides(out)
430
+ and not np.shares_memory(x, out)
431
+ ):
432
+ # out array that is used in OneMKL c2c FFT must have the same stride
433
+ # as input array and must have no common elements with input array.
434
+ # If these conditions are not met, we need to allocate a new array,
435
+ # which is done later.
436
+ # TODO: check to see if the same stride condition can be relaxed
437
+ f_arr = < cnp.ndarray> out
438
+ in_place = 0
439
+
419
440
if in_place:
420
441
_cache_capsule = _tls_dfti_cache_capsule()
421
442
_cache = < DftiCache * > cpython.pycapsule.PyCapsule_GetPointer(
@@ -453,25 +474,14 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
453
474
ind[axis_] = slice (0 , n_, None )
454
475
x_arr = x_arr[tuple (ind)]
455
476
456
- return x_arr
457
- else :
458
- if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT:
459
- f_type = cnp.NPY_CFLOAT
477
+ if out is not None :
478
+ out[...] = x_arr
479
+ return out
460
480
else :
461
- f_type = cnp.NPY_CDOUBLE
462
-
463
- if out is None :
481
+ return x_arr
482
+ else :
483
+ if f_arr is None :
464
484
f_arr = _allocate_result(x_arr, n_, axis_, f_type)
465
- else :
466
- out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
467
- _validate_out_array(out, x, out_dtype, axis = axis_, n = n_)
468
- # out array that is used in OneMKL c2c FFT must have the exact same
469
- # stride as input array. If not, we need to allocate a new array.
470
- # TODO: check to see if this condition can be relaxed
471
- if _get_element_strides(x) == _get_element_strides(out):
472
- f_arr = < cnp.ndarray> out
473
- else :
474
- f_arr = _allocate_result(x_arr, n_, axis_, f_type)
475
485
476
486
# call out-of-place FFT
477
487
_cache_capsule = _tls_dfti_cache_capsule()
@@ -612,9 +622,10 @@ def _r2c_fft1d_impl(
612
622
# be compared directly.
613
623
# TODO: currently instead of this condition, we check both input
614
624
# and output to be c_contig or f_contig, relax this condition
625
+ # In addition, input and output data sets must have no common elements
615
626
c_contig = x.flags.c_contiguous and out.flags.c_contiguous
616
627
f_contig = x.flags.f_contiguous and out.flags.f_contiguous
617
- if c_contig or f_contig:
628
+ if c_contig or f_contig and not np.shares_memory(x, out) :
618
629
f_arr = < cnp.ndarray> out
619
630
else :
620
631
f_arr = _allocate_result(x_arr, f_shape, axis_, f_type)
@@ -715,9 +726,10 @@ def _c2r_fft1d_impl(
715
726
# strides cannot be compared directly.
716
727
# TODO: currently instead of this condition, we check both input
717
728
# and output to be c_contig or f_contig, relax this condition
729
+ # Also input and output data sets must have no common elements
718
730
c_contig = x.flags.c_contiguous and out.flags.c_contiguous
719
731
f_contig = x.flags.f_contiguous and out.flags.f_contiguous
720
- if c_contig or f_contig:
732
+ if c_contig or f_contig and not np.shares_memory(x, out) :
721
733
f_arr = < cnp.ndarray> out
722
734
else :
723
735
f_arr = _allocate_result(x_arr, n_, axis_, f_type)
@@ -755,13 +767,13 @@ def _c2r_fft1d_impl(
755
767
756
768
757
769
def _direct_fftnd (
758
- x , direction = + 1 , double fsc = 1.0 , out = None
770
+ x , direction = + 1 , double fsc = 1.0 , in_place = 0 , out = None
759
771
):
760
772
""" Perform n-dimensional FFT over all axes"""
761
773
cdef int err
762
774
cdef cnp.ndarray x_arr " xxnd_arrayObject"
763
775
cdef cnp.ndarray f_arr " ffnd_arrayObject"
764
- cdef int in_place, x_type, f_type
776
+ cdef int x_type, f_type
765
777
766
778
if direction not in [- 1 , + 1 ]:
767
779
raise ValueError (" Direction of FFT should +1 or -1" )
@@ -779,7 +791,7 @@ def _direct_fftnd(
779
791
raise ValueError (" An input argument x is not an array-like object" )
780
792
781
793
# a copy was made, so we can work in place.
782
- in_place = 1 if _datacopied(x_arr, x) else 0
794
+ in_place = 1 if _datacopied(x_arr, x) else in_place
783
795
784
796
x_type = cnp.PyArray_TYPE(x_arr)
785
797
if (
@@ -798,15 +810,35 @@ def _direct_fftnd(
798
810
assert x_type == cnp.NPY_CDOUBLE
799
811
in_place = 1
800
812
801
- if out is not None :
802
- in_place = 0
803
-
804
813
if in_place:
805
814
if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_CFLOAT:
806
815
in_place = 1
807
816
else :
808
817
in_place = 0
809
818
819
+ f_arr = None
820
+ if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE:
821
+ f_type = cnp.NPY_CDOUBLE
822
+ else :
823
+ f_type = cnp.NPY_CFLOAT
824
+
825
+ if out is not None :
826
+ out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
827
+ _validate_out_array(out, x, out_dtype)
828
+ if x is out:
829
+ in_place = 1
830
+ elif (
831
+ _get_element_strides(x) == _get_element_strides(out)
832
+ and not np.shares_memory(x, out)
833
+ ):
834
+ # out array that is used in OneMKL c2c FFT must have the same stride
835
+ # as input array and must have no common elements with input array.
836
+ # If these conditions are not met, we need to allocate a new array,
837
+ # which is done later.
838
+ # TODO: check to see if the same stride condition can be relaxed
839
+ f_arr = < cnp.ndarray> out
840
+ in_place = 0
841
+
810
842
if in_place:
811
843
if x_type == cnp.NPY_CDOUBLE:
812
844
if direction == 1 :
@@ -821,24 +853,14 @@ def _direct_fftnd(
821
853
else :
822
854
raise ValueError (" An input argument x is not complex type array" )
823
855
824
- return x_arr
825
- else :
826
- if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE:
827
- f_type = cnp.NPY_CDOUBLE
856
+ if out is not None :
857
+ out[...] = x_arr
858
+ return out
828
859
else :
829
- f_type = cnp.NPY_CFLOAT
830
- if out is None :
860
+ return x_arr
861
+ else :
862
+ if f_arr is None :
831
863
f_arr = _allocate_result(x_arr, - 1 , 0 , f_type)
832
- else :
833
- out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
834
- _validate_out_array(out, x, out_dtype)
835
- # out array that is used in OneMKL c2c FFT must have the exact same
836
- # stride as input array. If not, we need to allocate a new array.
837
- # TODO: check to see if this condition can be relaxed
838
- if _get_element_strides(x) == _get_element_strides(out):
839
- f_arr = < cnp.ndarray> out
840
- else :
841
- f_arr = _allocate_result(x_arr, - 1 , 0 , f_type)
842
864
843
865
if x_type == cnp.NPY_CDOUBLE:
844
866
if direction == 1 :
0 commit comments