Skip to content

Commit d6a58cb

Browse files
author
Mathieu Scheltienne
authored
[MRG] Add copy and channel selection for a Layout object (#12338)
1 parent 9a222ba commit d6a58cb

File tree

7 files changed

+282
-33
lines changed

7 files changed

+282
-33
lines changed
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Adding :meth:`mne.channels.Layout.copy` and :meth:`mne.channels.Layout.pick` to copy and select channels from a :class:`mne.channels.Layout` object. Plotting 2D topographies of evoked responses with :func:`mne.viz.plot_evoked_topo` with both arguments ``layout`` and ``exclude`` now ignores excluded channels from the :class:`mne.channels.Layout`. By `Mathieu Scheltienne`_.

doc/development/contributing.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ Describe your changes in the changelog
592592

593593
Include in your changeset a brief description of the change in the
594594
:ref:`changelog <whats_new>` using towncrier_ format, which aggregates small,
595-
properly-named ``.rst`` files to create a change log. This can be
595+
properly-named ``.rst`` files to create a changelog. This can be
596596
skipped for very minor changes like correcting typos in the documentation.
597597

598598
There are six separate sections for changes, based on change type.

mne/channels/layout.py

+136-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import logging
1414
from collections import defaultdict
15+
from copy import deepcopy
1516
from itertools import combinations
1617
from pathlib import Path
1718

@@ -28,8 +29,10 @@
2829
_check_option,
2930
_check_sphere,
3031
_clean_names,
32+
_ensure_int,
3133
fill_doc,
3234
logger,
35+
verbose,
3336
warn,
3437
)
3538
from ..viz.topomap import plot_layout
@@ -50,9 +53,9 @@ class Layout:
5053
pos : array, shape=(n_channels, 4)
5154
The unit-normalized positions of the channels in 2d
5255
(x, y, width, height).
53-
names : list
56+
names : list of str
5457
The channel names.
55-
ids : list
58+
ids : array-like of int
5659
The channel ids.
5760
kind : str
5861
The type of Layout (e.g. 'Vectorview-all').
@@ -62,9 +65,25 @@ def __init__(self, box, pos, names, ids, kind):
6265
self.box = box
6366
self.pos = pos
6467
self.names = names
65-
self.ids = ids
68+
self.ids = np.array(ids)
69+
if self.ids.ndim != 1:
70+
raise ValueError("The channel indices should be a 1D array-like.")
6671
self.kind = kind
6772

73+
def copy(self):
74+
"""Return a copy of the layout.
75+
76+
Returns
77+
-------
78+
layout : instance of Layout
79+
A deepcopy of the layout.
80+
81+
Notes
82+
-----
83+
.. versionadded:: 1.7
84+
"""
85+
return deepcopy(self)
86+
6887
def save(self, fname, overwrite=False):
6988
"""Save Layout to disk.
7089
@@ -135,6 +154,119 @@ def plot(self, picks=None, show_axes=False, show=True):
135154
"""
136155
return plot_layout(self, picks=picks, show_axes=show_axes, show=show)
137156

157+
@verbose
158+
def pick(self, picks=None, exclude=(), *, verbose=None):
159+
"""Pick a subset of channels.
160+
161+
Parameters
162+
----------
163+
%(picks_layout)s
164+
exclude : str | int | array-like of str or int
165+
Set of channels to exclude, only used when ``picks`` is set to ``'all'`` or
166+
``None``. Exclude will not drop channels explicitly provided in ``picks``.
167+
%(verbose)s
168+
169+
Returns
170+
-------
171+
layout : instance of Layout
172+
The modified layout.
173+
174+
Notes
175+
-----
176+
.. versionadded:: 1.7
177+
"""
178+
# TODO: all the picking functions operates on an 'info' object which is missing
179+
# for a layout, thus we have to do the extra work here. The logic below can be
180+
# replaced when https://github.com/mne-tools/mne-python/issues/11913 is solved.
181+
if (isinstance(picks, str) and picks == "all") or (picks is None):
182+
picks = deepcopy(self.names)
183+
apply_exclude = True
184+
elif isinstance(picks, str):
185+
picks = [picks]
186+
apply_exclude = False
187+
elif isinstance(picks, slice):
188+
try:
189+
picks = np.arange(len(self.names))[picks]
190+
except TypeError:
191+
raise TypeError(
192+
"If a slice is provided, it must be a slice of integers."
193+
)
194+
apply_exclude = False
195+
else:
196+
try:
197+
picks = [_ensure_int(picks)]
198+
except TypeError:
199+
picks = (
200+
list(picks) if isinstance(picks, (tuple, set)) else deepcopy(picks)
201+
)
202+
apply_exclude = False
203+
if apply_exclude:
204+
if isinstance(exclude, str):
205+
exclude = [exclude]
206+
else:
207+
try:
208+
exclude = [_ensure_int(exclude)]
209+
except TypeError:
210+
exclude = (
211+
list(exclude)
212+
if isinstance(exclude, (tuple, set))
213+
else deepcopy(exclude)
214+
)
215+
for var, var_name in ((picks, "picks"), (exclude, "exclude")):
216+
if var_name == "exclude" and not apply_exclude:
217+
continue
218+
if not isinstance(var, (list, tuple, set, np.ndarray)):
219+
raise TypeError(
220+
f"'{var_name}' must be a list, tuple, set or ndarray. "
221+
f"Got {type(var)} instead."
222+
)
223+
if isinstance(var, np.ndarray) and var.ndim != 1:
224+
raise ValueError(
225+
f"'{var_name}' must be a 1D array-like. Got {var.ndim}D instead."
226+
)
227+
for k, elt in enumerate(var):
228+
if isinstance(elt, str) and elt in self.names:
229+
var[k] = self.names.index(elt)
230+
continue
231+
elif isinstance(elt, str):
232+
raise ValueError(
233+
f"The channel name {elt} provided in {var_name} does not match "
234+
"any channels from the layout."
235+
)
236+
try:
237+
var[k] = _ensure_int(elt)
238+
except TypeError:
239+
raise TypeError(
240+
f"All elements in '{var_name}' must be integers or strings."
241+
)
242+
if not (0 <= var[k] < len(self.names)):
243+
raise ValueError(
244+
f"The value {elt} provided in {var_name} does not match any "
245+
f"channels from the layout. The layout has {len(self.names)} "
246+
"channels."
247+
)
248+
if len(var) != len(set(var)):
249+
warn(
250+
f"The provided '{var_name}' has duplicates which will be ignored.",
251+
RuntimeWarning,
252+
)
253+
picks = picks.astype(int) if isinstance(picks, np.ndarray) else picks
254+
exclude = exclude.astype(int) if isinstance(exclude, np.ndarray) else exclude
255+
if apply_exclude:
256+
picks = np.array(list(set(picks) - set(exclude)), dtype=int)
257+
if len(picks) == 0:
258+
raise RuntimeError(
259+
"The channel selection yielded no remaining channels. Please edit "
260+
"the arguments 'picks' and 'exclude' to include at least one "
261+
"channel."
262+
)
263+
else:
264+
picks = np.array(list(set(picks)), dtype=int)
265+
self.pos = self.pos[picks]
266+
self.ids = self.ids[picks]
267+
self.names = [self.names[k] for k in picks]
268+
return self
269+
138270

139271
def _read_lout(fname):
140272
"""Aux function."""
@@ -533,7 +665,7 @@ def find_layout(info, ch_type=None, exclude="bads"):
533665
idx = [ii for ii, name in enumerate(layout.names) if name not in exclude]
534666
layout.names = [layout.names[ii] for ii in idx]
535667
layout.pos = layout.pos[idx]
536-
layout.ids = [layout.ids[ii] for ii in idx]
668+
layout.ids = layout.ids[idx]
537669

538670
return layout
539671

mne/channels/tests/test_layout.py

+116-12
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from mne._fiff.constants import FIFF
2424
from mne._fiff.meas_info import _empty_info
2525
from mne.channels import (
26+
Layout,
2627
find_layout,
2728
make_eeg_layout,
2829
make_grid_layout,
@@ -94,6 +95,18 @@ def _get_test_info():
9495
return test_info
9596

9697

98+
@pytest.fixture(scope="module")
99+
def layout():
100+
"""Get a layout."""
101+
return Layout(
102+
(0.1, 0.2, 0.1, 1.2),
103+
pos=np.array([[0, 0, 0.1, 0.1], [0.2, 0.2, 0.1, 0.1], [0.4, 0.4, 0.1, 0.1]]),
104+
names=["0", "1", "2"],
105+
ids=[0, 1, 2],
106+
kind="test",
107+
)
108+
109+
97110
def test_io_layout_lout(tmp_path):
98111
"""Test IO with .lout files."""
99112
layout = read_layout(fname="Vectorview-all", scale=False)
@@ -224,23 +237,17 @@ def test_make_grid_layout(tmp_path):
224237

225238
def test_find_layout():
226239
"""Test finding layout."""
227-
pytest.raises(ValueError, find_layout, _get_test_info(), ch_type="meep")
240+
with pytest.raises(ValueError, match="Invalid value for the 'ch_type'"):
241+
find_layout(_get_test_info(), ch_type="meep")
228242

229243
sample_info = read_info(fif_fname)
230-
grads = pick_types(sample_info, meg="grad")
231-
sample_info2 = pick_info(sample_info, grads)
232-
233-
mags = pick_types(sample_info, meg="mag")
234-
sample_info3 = pick_info(sample_info, mags)
235-
236-
# mock new convention
244+
sample_info2 = pick_info(sample_info, pick_types(sample_info, meg="grad"))
245+
sample_info3 = pick_info(sample_info, pick_types(sample_info, meg="mag"))
237246
sample_info4 = copy.deepcopy(sample_info)
238-
for ii, name in enumerate(sample_info4["ch_names"]):
247+
for ii, name in enumerate(sample_info4["ch_names"]): # mock new convention
239248
new = name.replace(" ", "")
240249
sample_info4["chs"][ii]["ch_name"] = new
241-
242-
eegs = pick_types(sample_info, meg=False, eeg=True)
243-
sample_info5 = pick_info(sample_info, eegs)
250+
sample_info5 = pick_info(sample_info, pick_types(sample_info, meg=False, eeg=True))
244251

245252
lout = find_layout(sample_info, ch_type=None)
246253
assert lout.kind == "Vectorview-all"
@@ -404,3 +411,100 @@ def test_generate_2d_layout():
404411
# Make sure background image normalizing is correct
405412
lt_bg = generate_2d_layout(xy, bg_image=bg_image)
406413
assert_allclose(lt_bg.pos[:, :2].max(), xy.max() / float(sbg))
414+
415+
416+
def test_layout_copy(layout):
417+
"""Test copying a layout."""
418+
layout2 = layout.copy()
419+
assert_allclose(layout.pos, layout2.pos)
420+
assert layout.names == layout2.names
421+
layout2.names[0] = "foo"
422+
layout2.pos[0, 0] = 0.8
423+
assert layout.names != layout2.names
424+
assert layout.pos[0, 0] != layout2.pos[0, 0]
425+
426+
427+
@pytest.mark.parametrize(
428+
"picks, exclude",
429+
[
430+
([0, 1], ()),
431+
(["0", 1], ()),
432+
(None, ["2"]),
433+
(None, "2"),
434+
(None, [2]),
435+
(None, 2),
436+
("all", 2),
437+
("all", "2"),
438+
(slice(0, 2), ()),
439+
(("0", "1"), ("0", "1")),
440+
(("0", 1), ("0", "1")),
441+
(("0", 1), (0, "1")),
442+
(set(["0", 1]), ()),
443+
(set([0, 1]), set()),
444+
(None, set([2])),
445+
(np.array([0, 1]), ()),
446+
(None, np.array([2])),
447+
(np.array(["0", "1"]), ()),
448+
],
449+
)
450+
def test_layout_pick(layout, picks, exclude):
451+
"""Test selection of channels in a layout."""
452+
layout2 = layout.copy()
453+
layout2.pick(picks, exclude)
454+
assert layout2.names == layout.names[:2]
455+
assert_allclose(layout2.pos, layout.pos[:2, :])
456+
457+
458+
def test_layout_pick_more(layout):
459+
"""Test more channel selection in a layout."""
460+
layout2 = layout.copy()
461+
layout2.pick(0)
462+
assert len(layout2.names) == 1
463+
assert layout2.names[0] == layout.names[0]
464+
assert_allclose(layout2.pos, layout.pos[:1, :])
465+
466+
layout2 = layout.copy()
467+
layout2.pick("all", exclude=("0", "1"))
468+
assert len(layout2.names) == 1
469+
assert layout2.names[0] == layout.names[2]
470+
assert_allclose(layout2.pos, layout.pos[2:, :])
471+
472+
layout2 = layout.copy()
473+
layout2.pick("all", exclude=("0", 1))
474+
assert len(layout2.names) == 1
475+
assert layout2.names[0] == layout.names[2]
476+
assert_allclose(layout2.pos, layout.pos[2:, :])
477+
478+
479+
def test_layout_pick_errors(layout):
480+
"""Test validation of layout.pick."""
481+
with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"):
482+
layout.pick(lambda x: x)
483+
with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"):
484+
layout.pick(None, lambda x: x)
485+
with pytest.raises(TypeError, match="must be integers or strings"):
486+
layout.pick([0, lambda x: x])
487+
with pytest.raises(TypeError, match="must be integers or strings"):
488+
layout.pick(None, [0, lambda x: x])
489+
with pytest.raises(ValueError, match="does not match any channels"):
490+
layout.pick("foo")
491+
with pytest.raises(ValueError, match="does not match any channels"):
492+
layout.pick(None, "foo")
493+
with pytest.raises(ValueError, match="does not match any channels"):
494+
layout.pick(101)
495+
with pytest.raises(ValueError, match="does not match any channels"):
496+
layout.pick(None, 101)
497+
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
498+
layout.copy().pick(["0", "0"])
499+
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
500+
layout.copy().pick(["0", 0])
501+
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
502+
layout.copy().pick(None, ["0", "0"])
503+
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
504+
layout.copy().pick(None, ["0", 0])
505+
with pytest.raises(RuntimeError, match="selection yielded no remaining channels"):
506+
layout.copy().pick(None, ["0", "1", "2"])
507+
with pytest.raises(ValueError, match="must be a 1D array-like"):
508+
layout.copy().pick(None, np.array([[0, 1]]))
509+
with pytest.raises(TypeError, match="slice of integers"):
510+
layout.copy().pick(slice("2342342342", 0, 3), ())

0 commit comments

Comments
 (0)