8000 Begin add support for tapers in array objs · tsbinns/mne-python@01c486c · GitHub
[go: up one dir, main page]

Skip to content

Commit 01c486c

Browse files
committed
Begin add support for tapers in array objs
1 parent 54f2a32 commit 01c486c

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

mne/time_frequency/tfr.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,7 +1421,6 @@ def __setstate__(self, state):
14211421

14221422
defaults = dict(
14231423
method="unknown",
1424-
dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :],
14251424
baseline=None,
14261425
decim=1,
14271426
data_type="TFR",
@@ -1445,7 +1444,7 @@ def __setstate__(self, state):
14451444
unknown_class = Epochs if "epoch" in self._dims else Evoked
14461445
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class)
14471446
self._inst_type = inst_types[defaults["inst_type_str"]]
1448-
# sanity check data/freqs/times/info agreement
1447+
# sanity check data/freqs/times/info/weights agreement
14491448
self._check_state()
14501449

14511450
def __repr__(self):
@@ -1498,14 +1497,26 @@ def _check_compatibility(self, other):
14981497
raise RuntimeError(msg.format(problem, extra))
14991498

15001499
def _check_state(self):
1501-
"""Check data/freqs/times/info agreement during __setstate__."""
1500+
"""Check data/freqs/times/info/weights agreement during __setstate__."""
15021501
msg = "{} axis of data ({}) doesn't match {} attribute ({})"
15031502
n_chan_info = len(self.info["chs"])
15041503
n_chan = self._data.shape[self._dims.index("channel")]
1504+
n_taper = (
1505+
self._data.shape[self._dims.index("taper")]
1506+
if "taper" in self._dims
1507+
else None
1508+
)
15051509
n_freq = self._data.shape[self._dims.index("freq")]
15061510
n_time = self._data.shape[self._dims.index("time")]
15071511
if n_chan_info != n_chan:
15081512
msg = msg.format("Channel", n_chan, "info", n_chan_info)
1513+
elif n_taper is not None:
1514+
if self._weights is None:
1515+
raise RuntimeError("Taper dimension in data, but no weights found.")
1516+
if n_taper != self._weights.shape[0]:
1517+
msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0])
1518+
elif n_freq != self._weights.shape[1]:
1519+
msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1])
15091520
elif n_freq != len(self.freqs):
15101521
msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size)
15111522
elif n_time != len(self.times):
@@ -2775,6 +2786,7 @@ class AverageTFR(BaseTFR):
27752786
%(nave_tfr_attr)s
27762787
%(sfreq_tfr_attr)s
27772788
%(shape_tfr_attr)s
2789+
%(weights_tfr_attr)s
27782790
27792791
See Also
27802792
--------
@@ -2891,6 +2903,10 @@ def __getstate__(self):
28912903

28922904
def __setstate__(self, state):
28932905
"""Unpack AverageTFR from serialized format."""
2906+
if state["data"].ndim != 3:
2907+
raise ValueError(f"RawTFR data should be 3D, got {state['data'].ndim}.")
2908+
# Set dims now since optional tapers makes it difficult to disentangle later
2909+
state["dims"] = ("channel", "freq", "time")
28942910
super().__setstate__(state)
28952911
self._comment = state.get("comment", "")
28962912
self._nave = state.get("nave", 1)
@@ -3046,6 +3062,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin):
30463062
%(selection_attr)s
30473063
%(sfreq_tfr_attr)s
30483064
%(shape_tfr_attr)s
3065+
%(weights_tfr_attr)s
30493066
30503067
See Also
30513068
--------
@@ -3130,8 +3147,15 @@ def __getstate__(self):
31303147

31313148
def __setstate__(self, state):
31323149
"""Unpack EpochsTFR from serialized format."""
3133-
if state["data"].ndim != 4:
3134-
raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.")
3150+
if state["data"].ndim not in [4, 5]:
3151+
raise ValueError(
3152+
f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}."
3153+
)
3154+
# Set dims now since optional tapers makes it difficult to disentangle later
3155+
state["dims"] = ("epoch", "channel")
3156+
if state["data"].ndim == 5:
3157+
state["dims"] += ("taper",)
3158+
state["dims"] += ("freq", "time")
31353159
super().__setstate__(state)
31363160
self._metadata = state.get("metadata", None)
31373161
n_epochs = self.shape[0]
@@ -3235,7 +3259,16 @@ def average(self, method="mean", *, dim="epochs", copy=False):
32353259
See discussion here:
32363260
32373261
https://github.com/scipy/scipy/pull/12676#issuecomment-783370228
3262+
3263+
Averaging is not supported for data containing a taper dimension.
32383264
"""
3265+
if "taper" in self._dims:
3266+
raise NotImplementedError(
3267+
"Averaging multitaper tapers across epochs, frequencies, or times is "
3268+
"not supported. If averaging across epochs, consider averaging the "
3269+
"epochs before computing the complex/phase spectrum."
3270+
)
3271+
32393272
_check_option("dim", dim, ("epochs", "freqs", "times"))
32403273
axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural
32413274

@@ -3607,6 +3640,7 @@ class EpochsTFRArray(EpochsTFR):
36073640
%(selection)s
36083641
%(drop_log)s
36093642
%(metadata_epochstfr)s
3643+
%(weights_tfr_array)s
36103644
36113645
Attributes
36123646
----------
@@ -3623,6 +3657,7 @@ class EpochsTFRArray(EpochsTFR):
36233657
%(selection_attr)s
36243658
%(sfreq_tfr_attr)s
36253659
%(shape_tfr_attr)s
3660+
%(weights_tfr_attr)s
36263661
36273662
See Also
36283663
--------
@@ -3645,6 +3680,7 @@ def __init__(
36453680
selection=None,
36463681
drop_log=None,
36473682
metadata=None,
3683+
weights=None,
36483684
):
36493685
state = dict(info=info, data=data, times=times, freqs=freqs)
36503686
optional = dict(
@@ -3655,6 +3691,7 @@ def __init__(
36553691
selection=selection,
36563692
drop_log=drop_log,
36573693
metadata=metadata,
3694+
weights=weights,
36583695
)
36593696
for name, value in optional.items():
36603697
if value is not None:
@@ -3697,6 +3734,7 @@ class RawTFR(BaseTFR):
36973734
method : str
36983735
The method used to compute the spectra (``'morlet'``, ``'multitaper'``
36993736
or ``'stockwell'``).
3737+
%(weights_tfr_attr)s
37003738
37013739
See Also
37023740
--------
@@ -3746,6 +3784,19 @@ def __init__(
37463784
**method_kw,
37473785
)
37483786

3787+
def __setstate__(self, state):
3788+
"""Unpack RawTFR from serialized format."""
3789+
if state["data"].ndim not in [3, 4]:
3790+
raise ValueError(
3791+
f"RawTFR data should be 3D or 4D, got {state['data'].ndim}."
3792+
)
3793+
# Set dims now since optional tapers makes it difficult to disentangle later
3794+
state["dims"] = ("channel",)
3795+
if state["data"].ndim == 4:
3796+
state["dims"] += ("taper",)
3797+
state["dims"] += ("freq", "time")
3798+
super().__setstate__(state)
3799+
37493800
def __getitem__(self, item):
37503801
"""Get RawTFR data.
37513802
@@ -3811,6 +3862,7 @@ class RawTFRArray(RawTFR):
38113862
%(times)s
38123863
%(freqs_tfr_array)s
38133864
%(method_tfr_array)s
3865+
%(weights_tfr_array)s
38143866
38153867
Attributes
38163868
----------
@@ -3821,6 +3873,7 @@ class RawTFRArray(RawTFR):
38213873
%(method_tfr_attr)s
38223874
%(sfreq_tfr_attr)s
38233875
%(shape_tfr_attr)s
3876+
%(weights_tfr_attr)s
38243877
38253878
See Also
38263879
--------
@@ -3838,10 +3891,13 @@ def __init__(
38383891
freqs,
38393892
*,
38403893
method=None,
3894+
weights=None,
38413895
):
38423896
state = dict(info=info, data=data, times=times, freqs=freqs)
3843-
if method is not None:
3844-
state["method"] = method
3897+
optional = dict(method=method, weights=weights)
3898+
for name, value in optional.items():
3899+
if value is not None:
3900+
state[name] = value
38453901
self.__setstate__(state)
38463902

38473903

mne/utils/docs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5008,6 +5008,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
50085008
solution.
50095009
"""
50105010

5011+
docdict["weight_tfr_array"] = """
5012+
weights : array of shape (n_tapers, n_freqs) | None
5013+
The weights for each taper. Must be provided if ``data`` has a taper dimension, such
5014+
as for complex or phase multitaper data.
5015+
"""
5016+
docdict["weight_tfr_attr"] = """
5017+
weights : array of shape (n_tapers, n_freqs) | None
5018+
The weights for each taper, if present in the data.
5019+
"""
5020+
50115021
docdict["window_psd"] = """\
50125022
window : str | float | tuple
50135023
Windowing function to use. See :func:`scipy.signal.get_window`.

0 commit comments

Comments
 (0)
0