diff --git a/.travis.yml b/.travis.yml index f07772455..6c79c32f2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,28 +16,37 @@ matrix: - PYFLAKES=1 - PEP8=1 - NUMPYSPEC=numpy + - MPLSPEC=matplotlib before_install: - pip install pep8==1.5.1 - pip install pyflakes script: - PYFLAKES_NODOCTEST=1 pyflakes pywt demo | grep -E -v 'unable to detect undefined names|assigned to but never used|imported but unused|redefinition of unused' > test.out; cat test.out; test \! -s test.out - pep8 pywt demo - - python: 3.5 env: - NUMPYSPEC=numpy - - python: 3.4 + - MPLSPEC=matplotlib + - USE_WHEEL=1 + - os: linux + python: 3.4 env: - NUMPYSPEC=numpy - - python: 2.6 + - MPLSPEC=matplotlib + - USE_SDIST=1 + - os: linux + python: 2.6 env: - NUMPYSPEC="numpy==1.9.3" + - MPLSPEC="matplotlib<2" - python: 2.7 env: - NUMPYSPEC=numpy + - MPLSPEC=matplotlib - python: 3.5 env: - NUMPYSPEC=numpy + - MPLSPEC=matplotlib - REFGUIDE_CHECK=1 # run doctests only cache: pip @@ -52,8 +61,9 @@ before_install: - pip install --upgrade wheel # Set numpy version first, other packages link against it - pip install $NUMPYSPEC - - pip install Cython matplotlib nose coverage codecov + - pip install Cython $MPLSPEC nose coverage codecov futures - set -o pipefail + - if [ "${USE_WHEEL}" == "1" ]; then pip install wheel; fi - | if [ "${REFGUIDE_CHECK}" == "1" ]; then pip install sphinx numpydoc @@ -62,7 +72,21 @@ before_install: script: # Define a fixed build dir so next step works - | - if [ "${REFGUIDE_CHECK}" == "1" ]; then + if [ "${USE_WHEEL}" == "1" ]; then + # Need verbose output or TravisCI will terminate after 10 minutes + pip wheel . -v + pip install PyWavelets*.whl -v + pushd demo + nosetests pywt + popd + elif [ "${USE_SDIST}" == "1" ]; then + python setup.py sdist + # Move out of source directory to avoid finding local pywt + pushd dist + pip install PyWavelets* -v + nosetests pywt + popd + elif [ "${REFGUIDE_CHECK}" == "1" ]; then pip install -e . -v python util/refguide_check.py --doctests else diff --git a/appveyor.yml b/appveyor.yml index 9d201315a..ca3dc2ae1 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -23,7 +23,7 @@ install: - "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install numpy --cache-dir c:\\tmp\\pip-cache" - "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install - Cython nose coverage matplotlib --cache-dir c:\\tmp\\pip-cache" + Cython nose coverage matplotlib futures --cache-dir c:\\tmp\\pip-cache" test_script: - "util\\appveyor\\build.cmd %PYTHON%\\python.exe setup.py build --build-lib build\\lib\\" diff --git a/demo/wp_scalogram.py b/demo/wp_scalogram.py index a0172eb65..6334d462e 100644 --- a/demo/wp_scalogram.py +++ b/demo/wp_scalogram.py @@ -40,7 +40,8 @@ # Show spectrogram and wavelet packet coefficients fig2 = plt.figure() ax2 = fig2.add_subplot(211) -ax2.specgram(data, NFFT=64, noverlap=32, cmap=cmap) +ax2.specgram(data, NFFT=64, noverlap=32, Fs=2, cmap=cmap, + interpolation='bilinear') ax2.set_title("Spectrogram of signal") ax3 = fig2.add_subplot(212) ax3.imshow(values, origin='upper', extent=[-1, 1, -1, 1], diff --git a/doc/release/0.5.1-notes.rst b/doc/release/0.5.1-notes.rst new file mode 100644 index 000000000..b04110f51 --- /dev/null +++ b/doc/release/0.5.1-notes.rst @@ -0,0 +1,34 @@ +============================== +PyWavelets 0.5.1 Release Notes +============================== + +PyWavelets 0.5.1 is a bug-fix release with no new features compared to 0.5.0 + + +Bugs Fixed +========== + +In release 0.5.0 the wrong edge mode was used for the following three +deprecated modes: ``ppd``, ``sp1``, and ``per``. All deprecated edge mode +names are now correctly converted to the corresponding new names. + +One-dimensional discrete wavelet transforms did not properly respect the +``axis`` argument for complex-valued data. Prior to this release, the last +axis was always transformed for arrays with complex dtype. This fix affects +``dwt``, ``idwt``, ``wavedec``, ``waverec``. + +Authors +======= + +* Gregory R. Lee + +Issues closed for v0.5.1 +------------------------ + +- `#245 `__: Keyword "per" for dwt extension mode + +Pull requests for v0.5.1 +------------------------ + +- `#244 `__: FIX: dwt, idwt with complex data now pass axis argument properly +- `#246 `__: fix bug in deprecated mode name conversion diff --git a/doc/release/0.5.2-notes.rst b/doc/release/0.5.2-notes.rst new file mode 100644 index 000000000..10f0264fd --- /dev/null +++ b/doc/release/0.5.2-notes.rst @@ -0,0 +1,52 @@ +============================== +PyWavelets 0.5.2 Release Notes +============================== + +PyWavelets 0.5.2 is a bug-fix release with no new features compared to 0.5.1. + + +Bugs Fixed +========== + +The ``pywt.data.nino`` data reader is now compatible with numpy 1.12. (#273) + +The ``wp_scalogram.py`` demo is now compatibile with matplotlib 2.0. (#276) + +Fixed a sporadic segmentation fault affecting stationary wavelet transforms of +multi-dimensional data. (#289) + +``idwtn`` now treats coefficients set to None to be treated as zeros (#291). +This makes the behavior consistent with its docstring as well as idwt2. +Previously this raised an error. + +The tests are now included when installing from wheels or when running +``python setup.py install``. (#292) + +A bug leading to a potential ``RuntimeError`` was fixed in ``waverec``. +This bug only affected transforms where the data was >1D and the transformed +axis was not the first axis of the array. (#294). + +Authors +======= + +* Ralf Gommers +* Gregory R. Lee + +Issues closed for v0.5.2 +------------------------ + +- `#280 `__: No tests found from installed version +- `#288 `__: RuntimeErrors and segfaults from swt2() in threaded environments +- `#290 `__: idwtn should treat coefficients set to None as zeros +- `#293 `__: bug in waverec of n-dimensional data when axis != 0 + +Pull requests for v0.5.2 +------------------------ + +- `#273 `__: fix non-integer index error +- `#276 `__: update wp_scalogram demo work with matplotlib 2.0 +- `#289 `__: fix memory leak in swt_axis +- `#291 `__: idwtn should allow coefficients to be set as None +- `#292 `__: MAINT: ensure tests are included in wheels +- `#294 `__: FIX: shape adjustment in waverec should not assume a transform along … +- `#295 `__: MAINT: fix readthedocs build issue, update numpy version specifier diff --git a/doc/source/regression/dwt-idwt.rst b/doc/source/regression/dwt-idwt.rst index 9b2bd18b0..e0ec8a17f 100644 --- a/doc/source/regression/dwt-idwt.rst +++ b/doc/source/regression/dwt-idwt.rst @@ -74,10 +74,7 @@ extension mode (please refer to the PyWavelets' documentation for the :ref:`extension modes ` available: >>> pywt.Modes.modes - ['zero', 'constant', 'symmetric', 'reflect', 'periodic', 'smooth', 'periodization'] - - >>> [int(pywt.dwt_coeff_len(len(x), w.dec_len, mode)) for mode in pywt.Modes.modes] - [6, 6, 6, 6, 6, 6, 4] + ['zero', 'constant', 'symmetric', 'periodic', 'smooth', 'periodization', 'reflect'] As you see in the above example, the :ref:`periodization ` (periodization) mode is slightly different from the others. It's aim when diff --git a/doc/source/regression/modes.rst b/doc/source/regression/modes.rst index 72b3e5e62..531ca071c 100644 --- a/doc/source/regression/modes.rst +++ b/doc/source/regression/modes.rst @@ -19,7 +19,7 @@ Import :mod:`pywt` first List of available signal extension :ref:`modes `: >>> print(pywt.Modes.modes) - ['zero', 'constant', 'symmetric', 'reflect', 'periodic', 'smooth', 'periodization'] + ['zero', 'constant', 'symmetric', 'periodic', 'smooth', 'periodization', 'reflect'] Invalid mode name should rise a :exc:`ValueError`: diff --git a/doc/source/release.0.5.1.rst b/doc/source/release.0.5.1.rst new file mode 100644 index 000000000..f4f6e7f1d --- /dev/null +++ b/doc/source/release.0.5.1.rst @@ -0,0 +1 @@ +.. include:: ../release/0.5.1-notes.rst diff --git a/doc/source/release.0.5.2.rst b/doc/source/release.0.5.2.rst new file mode 100644 index 000000000..a43094d6d --- /dev/null +++ b/doc/source/release.0.5.2.rst @@ -0,0 +1 @@ +.. include:: ../release/0.5.2-notes.rst diff --git a/doc/source/releasenotes.rst b/doc/source/releasenotes.rst index 965c1a0b1..6f73b549d 100644 --- a/doc/source/releasenotes.rst +++ b/doc/source/releasenotes.rst @@ -7,3 +7,5 @@ Release Notes release.0.3.0 release.0.4.0 release.0.5.0 + release.0.5.1 + release.0.5.2 diff --git a/pywt/_dwt.py b/pywt/_dwt.py index 1c5d05278..0374c8e09 100644 --- a/pywt/_dwt.py +++ b/pywt/_dwt.py @@ -126,8 +126,8 @@ def dwt(data, wavelet, mode='symmetric', axis=-1): """ if np.iscomplexobj(data): data = np.asarray(data) - cA_r, cD_r = dwt(data.real, wavelet, mode) - cA_i, cD_i = dwt(data.imag, wavelet, mode) + cA_r, cD_r = dwt(data.real, wavelet, mode, axis) + cA_i, cD_i = dwt(data.imag, wavelet, mode, axis) return (cA_r + 1j*cA_i, cD_r + 1j*cD_i) # accept array_like input; make a copy to ensure a contiguous array @@ -196,8 +196,8 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1): elif cD is None: cA = np.asarray(cA) cD = np.zeros_like(cA) - return (idwt(cA.real, cD.real, wavelet, mode) + - 1j*idwt(cA.imag, cD.imag, wavelet, mode)) + return (idwt(cA.real, cD.real, wavelet, mode, axis) + + 1j*idwt(cA.imag, cD.imag, wavelet, mode, axis)) if cA is not None: dt = _check_dtype(cA) diff --git a/pywt/_extensions/_pywt.pyx b/pywt/_extensions/_pywt.pyx index f0422fdf6..b5be60f89 100644 --- a/pywt/_extensions/_pywt.pyx +++ b/pywt/_extensions/_pywt.pyx @@ -19,6 +19,7 @@ from libc.math cimport pow, sqrt import numpy as np +# Caution: order of _old_modes entries must match _Modes.modes below _old_modes = ['zpd', 'cpd', 'sym', @@ -89,8 +90,9 @@ class _Modes(object): smooth = common.MODE_SMOOTH periodization = common.MODE_PERIODIZATION - modes = ["zero", "constant", "symmetric", "reflect", "periodic", - "smooth", "periodization"] + # Caution: order in modes list below must match _old_modes above + modes = ["zero", "constant", "symmetric", "periodic", "smooth", + "periodization", "reflect"] def from_object(self, mode): if isinstance(mode, int): diff --git a/pywt/_extensions/_swt.pyx b/pywt/_extensions/_swt.pyx index c355b3c85..6eb0b0eef 100644 --- a/pywt/_extensions/_swt.pyx +++ b/pywt/_extensions/_swt.pyx @@ -102,8 +102,7 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, # memory-views do not support n-dimensional arrays, use np.ndarray instead cdef common.ArrayInfo data_info, output_info cdef np.ndarray cD, cA - # Explicit input_shape necessary to prevent memory leak - cdef size_t[::1] input_shape, output_shape + cdef size_t[::1] output_shape cdef size_t end_level = start_level + level cdef int i, retval @@ -122,28 +121,23 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, raise ValueError(msg) data = data.astype(_check_dtype(data), copy=False) - - input_shape = data.shape - output_shape = input_shape.copy() - output_shape[axis] = common.swt_buffer_length(data.shape[axis]) - if output_shape[axis] != input_shape[axis]: - raise RuntimeError("swt_axis assumes output_shape is the same as " - "input_shape") + # For SWT, the output matches the shape of the input + output_shape = data.shape data_info.ndim = data.ndim data_info.strides = data.strides data_info.shape = data.shape - cA = np.empty(output_shape, data.dtype) - output_info.ndim = cA.ndim - output_info.strides = cA.strides - output_info.shape = cA.shape + output_info.ndim = data.ndim ret = [] for i in range(start_level+1, end_level+1): - + cA = np.empty(output_shape, dtype=data.dtype) + cD = np.empty(output_shape, dtype=data.dtype) + # strides won't match data_info.strides if data is not C-contiguous + output_info.strides = cA.strides + output_info.shape = cA.shape if data.dtype == np.float64: - cA = np.zeros(output_shape, dtype=np.float64) with nogil: retval = c_wt.double_downcoef_axis( data.data, data_info, @@ -152,8 +146,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, common.COEF_APPROX, common.MODE_PERIODIZATION, i, common.SWT_TRANSFORM) if retval: - raise RuntimeError("C wavelet transform failed") - cD = np.zeros(output_shape, dtype=np.float64) + raise RuntimeError( + "C wavelet transform failed with error code %d" % retval) with nogil: retval = c_wt.double_downcoef_axis( data.data, data_info, @@ -162,9 +156,9 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, common.COEF_DETAIL, common.MODE_PERIODIZATION, i, common.SWT_TRANSFORM) if retval: - raise RuntimeError("C wavelet transform failed") + raise RuntimeError( + "C wavelet transform failed with error code %d" % retval) elif data.dtype == np.float32: - cA = np.zeros(output_shape, dtype=np.float32) with nogil: retval = c_wt.float_downcoef_axis( data.data, data_info, @@ -173,8 +167,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, common.COEF_APPROX, common.MODE_PERIODIZATION, i, common.SWT_TRANSFORM) if retval: - raise RuntimeError("C wavelet transform failed") - cD = np.zeros(output_shape, dtype=np.float32) + raise RuntimeError( + "C wavelet transform failed with error code %d" % retval) with nogil: retval = c_wt.float_downcoef_axis( data.data, data_info, @@ -183,7 +177,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, common.COEF_DETAIL, common.MODE_PERIODIZATION, i, common.SWT_TRANSFORM) if retval: - raise RuntimeError("C wavelet transform failed") + raise RuntimeError( + "C wavelet transform failed with error code %d" % retval) else: raise TypeError("Array must be floating point, not {}" .format(data.dtype)) @@ -191,7 +186,9 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, # previous approx coeffs are the data for the next level data = cA - data_info = output_info + # update data_info to match the new data array + data_info.strides = data.strides + data_info.shape = data.shape ret.reverse() return ret diff --git a/pywt/_extensions/c/wt.template.c b/pywt/_extensions/c/wt.template.c index 9f2468f1a..db3380fed 100644 --- a/pywt/_extensions/c/wt.template.c +++ b/pywt/_extensions/c/wt.template.c @@ -36,7 +36,7 @@ int CAT(TYPE, _downcoef_axis)(const TYPE * const restrict input, const ArrayInfo if (input_info.ndim != output_info.ndim) return 1; if (axis >= input_info.ndim) - return 1; + return 2; for (i = 0; i < input_info.ndim; ++i){ if (i == axis){ @@ -44,17 +44,17 @@ int CAT(TYPE, _downcoef_axis)(const TYPE * const restrict input, const ArrayInfo case DWT_TRANSFORM: if (dwt_buffer_length(input_info.shape[i], wavelet->dec_len, dwt_mode) != output_info.shape[i]) - return 1; + return 3; break; case SWT_TRANSFORM: if (swt_buffer_length(input_info.shape[i]) != output_info.shape[i]) - return 1; + return 4; break; } } else { if (input_info.shape[i] != output_info.shape[i]) - return 1; + return 5; } } @@ -160,7 +160,7 @@ int CAT(TYPE, _downcoef_axis)(const TYPE * const restrict input, const ArrayInfo cleanup: free(temp_input); free(temp_output); - return 2; + return 6; } diff --git a/pywt/_extensions/wavelets_list.pxi b/pywt/_extensions/wavelets_list.pxi index 1ae4e96f3..3e4d34776 100644 --- a/pywt/_extensions/wavelets_list.pxi +++ b/pywt/_extensions/wavelets_list.pxi @@ -5,7 +5,7 @@ # Mapping of wavelet names to the C backend codes -cdef extern from "c/wavelets.h": +cdef extern from "c/wavelets.h": ctypedef enum WAVELET_NAME: HAAR RBIO @@ -25,7 +25,7 @@ cdef extern from "c/wavelets.h": cdef __wname_to_code __wname_to_code = { "haar": (HAAR, 0), - + "db1": (DB, 1), "db2": (DB, 2), "db3": (DB, 3), @@ -35,7 +35,7 @@ __wname_to_code = { "db7": (DB, 7), "db8": (DB, 8), "db9": (DB, 9), - + "db10": (DB, 10), "db11": (DB, 11), "db12": (DB, 12), @@ -46,7 +46,7 @@ __wname_to_code = { "db17": (DB, 17), "db18": (DB, 18), "db19": (DB, 19), - + "db20": (DB, 20), "db21": (DB, 21), "db22": (DB, 22), @@ -57,7 +57,7 @@ __wname_to_code = { "db27": (DB, 27), "db28": (DB, 28), "db29": (DB, 29), - + "db30": (DB, 30), "db31": (DB, 31), "db32": (DB, 32), @@ -67,7 +67,7 @@ __wname_to_code = { "db36": (DB, 36), "db37": (DB, 37), "db38": (DB, 38), - + "sym2": (SYM, 2), "sym3": (SYM, 3), "sym4": (SYM, 4), @@ -76,7 +76,7 @@ __wname_to_code = { "sym7": (SYM, 7), "sym8": (SYM, 8), "sym9": (SYM, 9), - + "sym10": (SYM, 10), "sym11": (SYM, 11), "sym12": (SYM, 12), @@ -98,7 +98,7 @@ __wname_to_code = { "coif7": (COIF, 7), "coif8": (COIF, 8), "coif9": (COIF, 9), - + "coif10": (COIF, 10), "coif11": (COIF, 11), "coif12": (COIF, 12), @@ -142,7 +142,7 @@ __wname_to_code = { "rbio6.8": (RBIO, 68), "dmey": (DMEY, 0), - + "gaus1": (GAUS, 1), "gaus2": (GAUS, 2), "gaus3": (GAUS, 3), @@ -151,11 +151,11 @@ __wname_to_code = { "gaus6": (GAUS, 6), "gaus7": (GAUS, 7), "gaus8": (GAUS, 8), - + "mexh": (MEXH, 0), - + "morl": (MORL, 0), - + "cgau1": (CGAU, 1), "cgau2": (CGAU, 2), "cgau3": (CGAU, 3), @@ -164,11 +164,11 @@ __wname_to_code = { "cgau6": (CGAU, 6), "cgau7": (CGAU, 7), "cgau8": (CGAU, 8), - + "shan": (SHAN, 0), - + "fbsp": (FBSP, 0), - + "cmor": (CMOR, 0), } @@ -178,7 +178,7 @@ cdef __wfamily_list_short, __wfamily_list_long __wfamily_list_short = [ "haar", "db", "sym", "coif", "bior", "rbio", "dmey", "gaus", "mexh", "morl", "cgau", "shan", "fbsp", "cmor"] -_wfamily_list_long = [ +__wfamily_list_long = [ "Haar", "Daubechies", "Symlets", "Coiflets", "Biorthogonal", "Reverse biorthogonal", "Discrete Meyer (FIR Approximation)", "Gaussian", "Mexican hat wavelet", "Morlet wavelet", "Complex Gaussian wavelets", diff --git a/pywt/_multidim.py b/pywt/_multidim.py index 12964c503..e7396730f 100644 --- a/pywt/_multidim.py +++ b/pywt/_multidim.py @@ -79,7 +79,8 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): ---------- coeffs : tuple (cA, (cH, cV, cD)) A tuple with approximation coefficients and three - details coefficients 2D arrays like from `dwt2()` + details coefficients 2D arrays like from `dwt2`. If any of these + components are set to ``None``, it will be treated as zeros. wavelet : Wavelet object or name string Wavelet to use mode : str, optional @@ -106,10 +107,6 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): raise ValueError("Expected 2 axes") coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH} - - # drop the keys corresponding to value = None - coeffs = dict((k, v) for k, v in coeffs.items() if v is not None) - return idwtn(coeffs, wavelet, mode, axes) @@ -215,8 +212,8 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None): Parameters ---------- coeffs: dict - Dictionary as in output of `dwtn`. Missing or None items - will be treated as zeroes. + Dictionary as in output of `dwtn`. Missing or ``None`` items + will be treated as zeros. wavelet : Wavelet object or name string Wavelet to use mode : str, optional @@ -240,6 +237,9 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None): wavelet = Wavelet(wavelet) mode = Modes.from_object(mode) + # drop the keys corresponding to value = None + coeffs = dict((k, v) for k, v in coeffs.items() if v is not None) + # Raise error for invalid key combinations coeffs = _fix_coeffs(coeffs) diff --git a/pywt/_multilevel.py b/pywt/_multilevel.py index c0dd4d468..9bfc5a13b 100644 --- a/pywt/_multilevel.py +++ b/pywt/_multilevel.py @@ -142,8 +142,14 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1): a, ds = coeffs[0], coeffs[1:] for d in ds: - if (a is not None) and (d is not None) and (len(a) == len(d) + 1): - a = a[:-1] + if (a is not None) and (d is not None): + try: + if a.shape[axis] == d.shape[axis] + 1: + a = a[[slice(s) for s in d.shape]] + elif a.shape[axis] != d.shape[axis]: + raise ValueError("coefficient shape mismatch") + except IndexError: + raise ValueError("Axis greater than coefficient dimensions") a = idwt(a, d, wavelet, mode, axis) return a diff --git a/pywt/data/_readers.py b/pywt/data/_readers.py index 8d6499352..d9c0b777d 100644 --- a/pywt/data/_readers.py +++ b/pywt/data/_readers.py @@ -174,10 +174,10 @@ def nino(): sst_csv = np.load(fname)['sst_csv'] # sst_csv = pd.read_csv("http://www.cpc.ncep.noaa.gov/data/indices/ersst4.nino.mth.81-10.ascii", sep=' ', skipinitialspace=True) # take only full years - n = np.floor(sst_csv.shape[0]/12.)*12. + n = int(np.floor(sst_csv.shape[0]/12.)*12.) # Building the mean of three mounth # the 4. column is nino 3 - sst = np.mean(np.reshape(np.array(sst_csv)[:n,4],(n/3,-1)),axis=1) + sst = np.mean(np.reshape(np.array(sst_csv)[:n, 4], (n//3, -1)), axis=1) sst = (sst - np.mean(sst)) / np.std(sst, ddof=1) dt = 0.25 diff --git a/pywt/tests/test_concurrent.py b/pywt/tests/test_concurrent.py new file mode 100644 index 000000000..4bece6620 --- /dev/null +++ b/pywt/tests/test_concurrent.py @@ -0,0 +1,119 @@ +""" +Tests used to verify running PyWavelets transforms in parallel via +concurrent.futures.ThreadPoolExecutor does not raise errors. +""" + +from __future__ import division, print_function, absolute_import + +import sys +import warnings +import multiprocessing +import numpy as np +from functools import partial +from numpy.testing import dec, run_module_suite, assert_array_equal + +import pywt + +try: + if sys.version_info[0] == 2: + import futures + else: + from concurrent import futures + max_workers = multiprocessing.cpu_count() + futures_available = True +except ImportError: + futures_available = False + + +def _assert_all_coeffs_equal(coefs1, coefs2): + # return True only if all coefficients of SWT or DWT match over all levels + if len(coefs1) != len(coefs2): + return False + for (c1, c2) in zip(coefs1, coefs2): + if isinstance(c1, tuple): + # for swt, swt2, dwt, dwt2, wavedec, wavedec2 + for a1, a2 in zip(c1, c2): + assert_array_equal(a1, a2) + elif isinstance(c1, dict): + # for swtn, dwtn, wavedecn + for k, v in c1.items(): + assert_array_equal(v, c2[k]) + else: + return False + return True + + +@dec.skipif(not futures_available) +def test_concurrent_swt(): + # tests error-free concurrent operation (see gh-288) + # swt on 1D data calls the Cython swt + # other cases call swt_axes + with warnings.catch_warnings(): + # can remove catch_warnings once the swt2 FutureWarning is removed + warnings.simplefilter('ignore', FutureWarning) + for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn], + [np.ones(8), np.eye(16), np.eye(16)]): + transform = partial(swt_func, wavelet='haar', level=1) + for _ in range(10): + arrs = [x.copy() for _ in range(100)] + with futures.ThreadPoolExecutor(max_workers=max_workers) as ex: + results = list(ex.map(transform, arrs)) + + # validate result from one of the concurrent runs + expected_result = transform(x) + _assert_all_coeffs_equal(expected_result, results[-1]) + + +@dec.skipif(not futures_available) +def test_concurrent_wavedec(): + # wavedec on 1D data calls the Cython dwt_single + # other cases call dwt_axis + for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn], + [np.ones(8), np.eye(16), np.eye(16)]): + transform = partial(wavedec_func, wavelet='haar', level=1) + for _ in range(10): + arrs = [x.copy() for _ in range(100)] + with futures.ThreadPoolExecutor(max_workers=max_workers) as ex: + results = list(ex.map(transform, arrs)) + + # validate result from one of the concurrent runs + expected_result = transform(x) + _assert_all_coeffs_equal(expected_result, results[-1]) + + +@dec.skipif(not futures_available) +def test_concurrent_dwt(): + # dwt on 1D data calls the Cython dwt_single + # other cases call dwt_axis + for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn], + [np.ones(8), np.eye(16), np.eye(16)]): + transform = partial(dwt_func, wavelet='haar') + for _ in range(10): + arrs = [x.copy() for _ in range(100)] + with futures.ThreadPoolExecutor(max_workers=max_workers) as ex: + results = list(ex.map(transform, arrs)) + + # validate result from one of the concurrent runs + expected_result = transform(x) + _assert_all_coeffs_equal([expected_result, ], [results[-1], ]) + + +@dec.skipif(not futures_available) +def test_concurrent_cwt(): + time, sst = pywt.data.nino() + dt = time[1]-time[0] + transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor', + sampling_period=dt) + for _ in range(10): + arrs = [sst.copy() for _ in range(50)] + with futures.ThreadPoolExecutor(max_workers=max_workers) as ex: + results = list(ex.map(transform, arrs)) + + # validate result from one of the concurrent runs + expected_result = transform(sst) + for a1, a2 in zip(expected_result, results[-1]): + assert_array_equal(a1, a2) + + +if __name__ == '__main__': + run_module_suite() diff --git a/pywt/tests/test_deprecations.py b/pywt/tests/test_deprecations.py index 8a9e12f4d..b4868efa6 100644 --- a/pywt/tests/test_deprecations.py +++ b/pywt/tests/test_deprecations.py @@ -1,4 +1,7 @@ -from numpy.testing import assert_warns +import warnings + +import numpy as np +from numpy.testing import assert_warns, run_module_suite, assert_array_equal import pywt @@ -50,11 +53,6 @@ def get_mode(Modes, name): assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode) -def test_MODES_attributes_usage_deprecation(): - for mode in old_modes: - assert_warns(DeprecationWarning, pywt.dwt, range(10), 'db3', mode) - - def test_MODES_deprecation_new(): def use_MODES_new(): return pywt.MODES.symmetric @@ -74,3 +72,22 @@ def use_MODES_new(): return getattr(pywt.MODES, 'symmetric') assert_warns(DeprecationWarning, use_MODES_new) + + +def test_mode_equivalence(): + old_new = [('zpd', 'zero'), + ('cpd', 'constant'), + ('sym', 'symmetric'), + ('ppd', 'periodic'), + ('sp1', 'smooth'), + ('per', 'periodization')] + x = np.arange(8.) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + for old, new in old_new: + assert_array_equal(pywt.dwt(x, 'db2', mode=old), + pywt.dwt(x, 'db2', mode=new)) + + +if __name__ == '__main__': + run_module_suite() diff --git a/pywt/tests/test_dwt_idwt.py b/pywt/tests/test_dwt_idwt.py index a84a3a539..70a1edb91 100644 --- a/pywt/tests/test_dwt_idwt.py +++ b/pywt/tests/test_dwt_idwt.py @@ -97,10 +97,14 @@ def test_dwt_coeff_len(): w = pywt.Wavelet('sym3') ln_modes = [pywt.dwt_coeff_len(len(x), w.dec_len, mode) for mode in pywt.Modes.modes] - assert_allclose(ln_modes, [6, 6, 6, 6, 6, 6, 4]) + + expected_result = [6, ] * len(pywt.Modes.modes) + expected_result[pywt.Modes.modes.index('periodization')] = 4 + + assert_allclose(ln_modes, expected_result) ln_modes = [pywt.dwt_coeff_len(len(x), w, mode) for mode in pywt.Modes.modes] - assert_allclose(ln_modes, [6, 6, 6, 6, 6, 6, 4]) + assert_allclose(ln_modes, expected_result) def test_idwt_none_input(): @@ -142,6 +146,8 @@ def test_idwt_single_axis(): x = [[3, 7, 1, 1], [-2, 5, 4, 6]] + x = np.asarray(x) + x = x + 1j*x # test with complex data cA, cD = pywt.dwt(x, 'db2', axis=-1) x0 = pywt.idwt(cA[0], cD[0], 'db2', axis=-1) diff --git a/pywt/tests/test_modes.py b/pywt/tests/test_modes.py index 4c0df5863..b2cf90dcc 100644 --- a/pywt/tests/test_modes.py +++ b/pywt/tests/test_modes.py @@ -9,8 +9,8 @@ def test_available_modes(): - modes = ['zero', 'constant', 'symmetric', 'reflect', - 'periodic', 'smooth', 'periodization'] + modes = ['zero', 'constant', 'symmetric', 'periodic', 'smooth', + 'periodization', 'reflect'] assert_equal(pywt.Modes.modes, modes) assert_equal(pywt.Modes.from_object('constant'), 2) diff --git a/pywt/tests/test_multidim.py b/pywt/tests/test_multidim.py index c12cd8857..9e59b388b 100644 --- a/pywt/tests/test_multidim.py +++ b/pywt/tests/test_multidim.py @@ -194,10 +194,6 @@ def test_error_on_invalid_keys(): d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH} assert_raises(ValueError, pywt.idwtn, d, wavelet) - # a key whose value is None - d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': None} - assert_raises(ValueError, pywt.idwtn, d, wavelet) - # mismatched key lengths d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH} assert_raises(ValueError, pywt.idwtn, d, wavelet) @@ -266,6 +262,40 @@ def test_idwtn_axes(): assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14) +def test_idwt2_none_coeffs(): + data = np.array([[0, 1, 2, 3], + [1, 1, 1, 1], + [1, 4, 2, 8]]) + data = data + 1j*data # test with complex data + cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1)) + + # verify setting coefficients to None is the same as zeroing them + cD = np.zeros_like(cD) + result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1)) + + cD = None + result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1)) + + assert_equal(result_zeros, result_none) + + +def test_idwtn_none_coeffs(): + data = np.array([[0, 1, 2, 3], + [1, 1, 1, 1], + [1, 4, 2, 8]]) + data = data + 1j*data # test with complex data + coefs = pywt.dwtn(data, 'haar', axes=(1, 1)) + + # verify setting coefficients to None is the same as zeroing them + coefs['dd'] = np.zeros_like(coefs['dd']) + result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1)) + + coefs['dd'] = None + result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1)) + + assert_equal(result_zeros, result_none) + + def test_idwt2_axes(): data = np.array([[0, 1, 2, 3], [1, 1, 1, 1], diff --git a/pywt/tests/test_multilevel.py b/pywt/tests/test_multilevel.py index bd0c52acc..15a7adad4 100644 --- a/pywt/tests/test_multilevel.py +++ b/pywt/tests/test_multilevel.py @@ -486,6 +486,16 @@ def test_waverec_axes_subsets(): assert_allclose(rec, data, atol=1e-14) +def test_waverec_axis_db2(): + # test for fix to issue gh-293 + rstate = np.random.RandomState(0) + data = rstate.standard_normal((16, 16)) + for axis in [0, 1]: + coefs = pywt.wavedec(data, 'db2', axis=axis) + rec = pywt.waverec(coefs, 'db2', axis=axis) + assert_allclose(rec, data, atol=1e-14) + + def test_waverec2_axes_subsets(): rstate = np.random.RandomState(0) data = rstate.standard_normal((8, 8, 8)) @@ -528,6 +538,13 @@ def test_waverec_axis_error(): assert_raises(ValueError, pywt.waverec, c, 'haar', axis=1) +def test_waverec_shape_mismatch_error(): + c = pywt.wavedec(np.ones(16), 'haar') + # truncate a detail coefficient to an incorrect shape + c[3] = c[3][:-1] + assert_raises(ValueError, pywt.waverec, c, 'haar', axis=1) + + def test_wavedec2_axes_errors(): data = np.ones((4, 4)) # integer axes not allowed diff --git a/setup.py b/setup.py index 9d8d0eb2a..aece1b2bc 100755 --- a/setup.py +++ b/setup.py @@ -25,8 +25,8 @@ MAJOR = 0 MINOR = 5 -MICRO = 0 -ISRELEASED = True +MICRO = 3 +ISRELEASED = False VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO) @@ -249,12 +249,14 @@ def install_for_development(self): version=get_version_info()[0], packages=['pywt', 'pywt._extensions', 'pywt.data'], - package_data={'pywt.data': ['*.npy', '*.npz']}, + package_data={'pywt.data': ['*.npy', '*.npz'], + 'pywt': ['tests/*.py', 'tests/data/*.npz', + 'tests/data/*.py']}, ext_modules=ext_modules, libraries=[c_lib], cmdclass={'develop': develop_build_clib}, test_suite='nose.collector', # A function is imported in setup.py, so not really useful - install_requires=["numpy"], + install_requires=["numpy>=1.9.1"], ) diff --git a/util/readthedocs/requirements.txt b/util/readthedocs/requirements.txt index e7e7f83ea..267b91c03 100644 --- a/util/readthedocs/requirements.txt +++ b/util/readthedocs/requirements.txt @@ -8,7 +8,7 @@ # fix the versions of numpy to force the use of numpy to use the whl # of the rackspace folder instead of trying to install from more recent # source tarball published on PyPI -numpy==1.8.1 +numpy==1.9.1 Cython==0.20.2 nose wheel