Skip to content

Commit 647d164

Browse files
committed
add an interface for dask.fft
1 parent 46c8b3f commit 647d164

File tree

12 files changed

+303
-8
lines changed

12 files changed

+303
-8
lines changed

.github/workflows/build-with-clang.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,5 @@ jobs:
7373
- name: Run mkl_fft tests
7474
run: |
7575
source ${{ env.ONEAPI_ROOT }}/setvars.sh
76-
pip install scipy mkl-service pytest
76+
pip install pytest mkl-service scipy dask
7777
pytest -s -v --pyargs mkl_fft

.github/workflows/conda-package-cf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ jobs:
142142
- name: Install mkl_fft
143143
run: |
144144
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
145-
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy $CHANNELS
145+
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy dask $CHANNELS
146146
# Test installed packages
147147
conda list -n ${{ env.TEST_ENV_NAME }}
148148
@@ -318,7 +318,7 @@ jobs:
318318
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
319319
SET PACKAGE_VERSION=%%F
320320
)
321-
SET "TEST_DEPENDENCIES=pytest scipy"
321+
SET "TEST_DEPENDENCIES=pytest scipy dask"
322322
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
323323
324324
- name: Report content of test environment

.github/workflows/conda-package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ jobs:
140140
- name: Install mkl_fft
141141
run: |
142142
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
143-
conda create -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME=${{ env.PACKAGE_VERSION }} python=${{ matrix.python }} pytest $CHANNELS
143+
conda create -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME=${{ env.PACKAGE_VERSION }} python=${{ matrix.python }} pytest dask $CHANNELS
144144
if [[ "${{ matrix.python }}" != 3.9* ]]; then
145145
# Intel channel only has scipy=1.10 for Python 3.9, which needs mkl<2025
146146
# while scipy needs to install numpy and mkl_random and mkl_random-1.2.11 requires mkl>=2025
@@ -313,7 +313,7 @@ jobs:
313313
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
314314
SET PACKAGE_VERSION=%%F
315315
)
316-
SET "TEST_DEPENDENCIES=pytest"
316+
SET "TEST_DEPENDENCIES=pytest dask"
317317
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
318318
if ("${{ matrix.python }}" -ne "3.9") {
319319
conda install -n ${{ env.TEST_ENV_NAME }} scipy -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010
* Enabled support of Python 3.13 [gh-164](https://github.com/IntelPython/mkl_fft/pull/164)
11+
* Added a new interface for FFT module of Dask accessible through `mkl_fft.interfaces.dask_fft` [gh-214](https://github.com/IntelPython/mkl_fft/pull/214)
1112

1213
### Changed
1314
* Replaced `fwd_scale` parameter with `norm` in `mkl_fft` [gh-189](https://github.com/IntelPython/mkl_fft/pull/189)

conda-recipe-cf/meta.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ test:
3333
requires:
3434
- pytest
3535
- scipy >=1.10
36+
- dask
3637
imports:
3738
- mkl_fft
3839
- mkl_fft.interfaces
3940
- mkl_fft.interfaces.numpy_fft
4041
- mkl_fft.interfaces.scipy_fft
42+
- mkl_fft.interfaces.dask_fft
4143

4244
about:
4345
home: http://github.com/IntelPython/mkl_fft

conda-recipe/meta.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ test:
3838
requires:
3939
- pytest
4040
- scipy >=1.10
41+
- dask
4142
imports:
4243
- mkl_fft
4344
- mkl_fft.interfaces
4445
- mkl_fft.interfaces.numpy_fft
4546
- mkl_fft.interfaces.scipy_fft
47+
- mkl_fft.interfaces.dask_fft
4648

4749
about:
4850
home: http://github.com/IntelPython/mkl_fft

mkl_fft/interfaces/README.md

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Interfaces
2-
The `mkl_fft` package provides interfaces that serve as drop-in replacements for equivalent functions in NumPy and SciPy.
2+
The `mkl_fft` package provides interfaces that serve as drop-in replacements for equivalent functions in NumPy, SciPy, and Dask.
33

44
---
55

@@ -124,3 +124,43 @@ with mkl_fft.set_workers(4):
124124
y = scipy.signal.fftconvolve(a, a) # Note that Nthr:4
125125
# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:4,unaligned_output,desc:0x563aefe86180) 187.37us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:4
126126
```
127+
128+
---
129+
130+
## Dask interface - `mkl_fft.interfaces.dask_fft`
131+
132+
This interface is a drop-in replacement for the [`dask.fft`](https://dask.pydata.org/en/latest/array-api.html#fast-fourier-transforms) module and includes **all** the functions available there:
133+
134+
* complex-to-complex FFTs: `fft`, `ifft`, `fft2`, `ifft2`, `fftn`, `ifftn`.
135+
136+
* real-to-complex and complex-to-real FFTs: `rfft`, `irfft`, `rfft2`, `irfft2`, `rfftn`, `irfftn`.
137+
138+
* Hermitian FFTs: `hfft`, `ihfft`.
139+
140+
* Helper routines: `fft_wrap`, `fftfreq`, `rfftfreq`, `fftshift`, `ifftshift`. These routines serve as a fallback to the Dask implementation and are included for completeness.
141+
142+
The following example shows how to use this interface for calculating a 2D FFT.
143+
144+
```python
145+
import numpy, dask
146+
import mkl_fft.interfaces.dask_fft as dask_fft
147+
148+
a = numpy.random.randn(128, 64) + 1j*numpy.random.randn(128, 64)
149+
x = dask.array.from_array(a, chunks=(64, 64))
150+
lazy_res = dask_fft.fft(x)
151+
mkl_res = lazy_res.compute()
152+
np_res = numpy.fft.fft(a)
153+
numpy.allclose(mkl_res, np_res)
154+
# True
155+
156+
# There are two chunks in this example based on the size of input array (128, 64) and chunk size (64, 64)
157+
# to confirm that MKL FFT is called twice, turn on verbosity
158+
import mkl
159+
mkl.verbose(1)
160+
# True
161+
162+
mkl_res = lazy_res.compute() # MKL_VERBOSE FFT is shown twice below which means MKL FFT is called twice
163+
# MKL_VERBOSE oneMKL 2024.0 Update 2 Patch 2 Product build 20240823 for Intel(R) 64 architecture Intel(R) Advanced Vector Extensions 512 (Intel(R) AVX-512) with support for INT8, BF16, FP16 (limited) instructions, and Intel(R) Advanced Matrix Extensions (Intel(R) AMX) with INT8 and BF16, Lnx 3.80GHz intel_thread
164+
# MKL_VERBOSE FFT(dcfo64*64,input_strides:{0,1},output_strides:{0,1},input_distance:64,output_distance:64,bScale:0.015625,tLim:32,unaligned_input,desc:0x7fd000010e40) 432.84us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112
165+
# MKL_VERBOSE FFT(dcfo64*64,input_strides:{0,1},output_strides:{0,1},input_distance:64,output_distance:64,bScale:0.015625,tLim:32,unaligned_input,desc:0x7fd480011300) 499.00us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112
166+
```

mkl_fft/interfaces/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,17 @@
2626
from . import numpy_fft
2727

2828
try:
29-
import scipy.fft
29+
# check to see if scipy is installed
30+
import scipy
3031
except ImportError:
3132
pass
3233
else:
3334
from . import scipy_fft
35+
36+
try:
37+
# check to see if dask is installed
38+
import dask
39+
except ImportError:
40+
pass
41+
else:
42+
from . import dask_fft

mkl_fft/interfaces/dask_fft.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2025, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
from dask.array.fft import fft_wrap, fftfreq, fftshift, ifftshift, rfftfreq
28+
29+
from . import numpy_fft as _numpy_fft
30+
31+
__all__ = [
32+
"fft",
33+
"ifft",
34+
"fft2",
35+
"ifft2",
36+
"fftn",
37+
"ifftn",
38+
"rfft",
39+
"irfft",
40+
"rfft2",
41+
"irfft2",
42+
"rfftn",
43+
"irfftn",
44+
"hfft",
45+
"ihfft",
46+
"fftshift",
47+
"ifftshift",
48+
"fftfreq",
49+
"rfftfreq",
50+
"fft_wrap",
51+
]
52+
53+
54+
fft = fft_wrap(_numpy_fft.fft)
55+
ifft = fft_wrap(_numpy_fft.ifft)
56+
fft2 = fft_wrap(_numpy_fft.fft2)
57+
ifft2 = fft_wrap(_numpy_fft.ifft2)
58+
fftn = fft_wrap(_numpy_fft.fftn)
59+
ifftn = fft_wrap(_numpy_fft.ifftn)
60+
rfft = fft_wrap(_numpy_fft.rfft)
61+
irfft = fft_wrap(_numpy_fft.irfft)
62+
rfft2 = fft_wrap(_numpy_fft.rfft2)
63+
irfft2 = fft_wrap(_numpy_fft.irfft2)
64+
rfftn = fft_wrap(_numpy_fft.rfftn)
65+
irfftn = fft_wrap(_numpy_fft.irfftn)
66+
hfft = fft_wrap(_numpy_fft.hfft)
67+
ihfft = fft_wrap(_numpy_fft.ihfft)

mkl_fft/tests/test_interfaces.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,19 @@
3434
except AttributeError:
3535
scipy_fft = None
3636

37+
try:
38+
dask_fft = mfi.dask_fft
39+
except AttributeError:
40+
dask_fft = None
41+
3742
interfaces = []
3843
ids = []
3944
if scipy_fft is not None:
4045
interfaces.append(scipy_fft)
4146
ids.append("scipy")
47+
if dask_fft is not None:
48+
interfaces.append(dask_fft)
49+
ids.append("dask")
4250
interfaces.append(mfi.numpy_fft)
4351
ids.append("numpy")
4452

@@ -189,3 +197,7 @@ def test_axes(interface):
189197
)
190198
def test_interface_helper_functions(interface, func):
191199
assert hasattr(interface, func)
200+
201+
202+
def test_dask_fftwrap():
203+
assert hasattr(mfi.dask_fft, "fft_wrap")

0 commit comments

Comments
 (0)