8000 Fix to_data_frame bug with tapers · tsbinns/mne-python@80126a7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 80126a7

Browse files
committed
Fix to_data_frame bug with tapers
1 parent de39d25 commit 80126a7

File tree

2 files changed

+73
-30
lines changed

2 files changed

+73
-30
lines changed

mne/time_frequency/tests/test_tfr.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,12 +1292,15 @@ def test_to_data_frame():
12921292
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"]
12931293
n_picks = len(ch_names)
12941294
ch_types = ["eeg"] * n_picks
1295+
n_tapers = 2
12951296
n_freqs = 5
12961297
n_times = 6
1297-
data = np.random.rand(n_epos, n_picks, n_freqs, n_times)
1298-
times = np.arange(6)
1298+
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times)
1299+
times = np.arange(n_times)
12991300
srate = 1000.0
1300-
freqs = np.arange(5)
1301+
freqs = np.arange(n_freqs)
1302+
tapers = np.arange(n_tapers)
1303+
weights = np.ones((n_tapers, n_freqs))
13011304
events = np.zeros((n_epos, 3), dtype=int)
13021305
events[:, 0] = np.arange(n_epos)
13031306
events[:, 2] = np.arange(5, 5 + n_epos)
@@ -1310,6 +1313,7 @@ def test_to_data_frame():
13101313
freqs=freqs,
13111314
events=events,
13121315
event_id=event_id,
1316+
weights=weights,
13131317
)
13141318
# test index checking
13151319
with pytest.raises(ValueError, match="options. Valid index options are"):
@@ -1321,32 +1325,51 @@ def test_to_data_frame():
13211325
# test wide format
13221326
df_wide = tfr.to_data_frame()
13231327
assert all(np.isin(tfr.ch_names, df_wide.columns))
1324-
assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns))
1328+
assert all(
1329+
np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns)
1330+
)
13251331
# test long format
13261332
df_long = tfr.to_data_frame(long_format=True)
1327-
expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value")
1333+
expected = (
1334+
"condition",
1335+
"epoch",
1336+
"freq",
1337+
"time",
1338+
"channel",
1339+
"ch_type",
1340+
"value",
1341+
"taper",
1342+
)
13281343
assert set(expected) == set(df_long.columns)
13291344
assert set(tfr.ch_names) == set(df_long["channel"])
13301345
assert len(df_long) == tfr.data.size
13311346
# test long format w/ index
13321347
df_long = tfr.to_data_frame(long_format=True, index=["freq"])
13331348
del df_wide, df_long
13341349
# test whether data is in correct shape
1335-
df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"])
1350+
df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"])
13361351
data = tfr.data
13371352
assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze())
13381353
# compare arbitrary observation:
13391354
assert (
1340-
df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0]
1341-
== data[1, 3, 1, 2]
1355+
df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0]
1356+
== data[1, 3, 1, 1, 2]
13421357
)
13431358

13441359
# Check also for AverageTFR:
1360+
# (remove taper dimension before averaging)
1361+
state = tfr.__getstate__()
1362+
state["data"] = state["data"][:, :, 0]
1363+
state["dims"] = ("epoch", "channel", "freq", "time")
1364+
state["weights"] = None
1365+
tfr = EpochsTFR(inst=state)
13451366
tfr = tfr.average()
13461367
with pytest.raises(ValueError, match="options. Valid index options are"):
13471368
tfr.to_data_frame(index=["epoch", "condition"])
13481369
with pytest.raises(ValueError, match='"epoch" is not a valid option'):
13491370
tfr.to_data_frame(index="epoch")
1371+
with pytest.raises(ValueError, match='"taper" is not a valid option'):
1372+
tfr.to_data_frame(index="taper")
13501373
with pytest.raises(TypeError, match="index must be `None` or a string "):
13511374
tfr.to_data_frame(index=np.arange(400))
13521375
# test wide format
@@ -1382,11 +1405,13 @@ def test_to_data_frame_index(index):
13821405
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"]
13831406
n_picks = len(ch_names)
13841407
ch_types = ["eeg"] * n_picks
1408+
n_tapers = 2
13851409
n_freqs = 5
13861410
n_times = 6
1387-
data = np.random.rand(n_epos, n_picks, n_freqs, n_times)
1388-
times = np.arange(6)
1389-
freqs = np.arange(5)
1411+
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times)
1412+
times = np.arange(n_times)
1413+
freqs = np.arange(n_freqs)
1414+
weights = np.ones((n_tapers, n_freqs))
13901415
events = np.zeros((n_epos, 3), dtype=int)
13911416
events[:, 0] = np.arange(n_epos)
13921417
events[:, 2] = np.arange(5, 8)
@@ -1399,14 +1424,15 @@ def test_to_data_frame_index(index):
13991424
freqs=freqs,
14001425
events=events,
14011426
event_id=event_id,
1427+
weights=weights,
14021428
)
14031429
df = tfr.to_data_frame(picks=[0, 2, 3], index=index)
14041430
# test index order/hierarchy preservation
14051431
if not isinstance(index, list):
14061432
index = [index]
14071433
assert list(df.index.names) == index
14081434
# test that non-indexed data were present as columns
1409-
non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index))
1435+
non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index))
14101436
if len(non_index):
14111437
assert all(np.isin(non_index, df.columns))
14121438

mne/time_frequency/tfr.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,6 +1837,7 @@ def get_data(
18371837
tmax=None,
18381838
return_times=False,
18391839
return_freqs=False,
1840+
return_tapers=False,
18401841
):
18411842
"""Get time-frequency data in NumPy array format.
18421843
@@ -1852,6 +1853,10 @@ def get_data(
18521853
return_freqs : bool
18531854
Whether to return the frequency bin values for the requested
18541855
frequency range. Default is ``False``.
1856+
return_tapers : bool
1857+
Whe 10000 ther to return the taper numbers. Default is ``False``.
1858+
1859+
.. versionadded:: 1.X.0
18551860
18561861
Returns
18571862
-------
@@ -1863,6 +1868,9 @@ def get_data(
18631868
freqs : array
18641869
The frequency values for the requested data range. Only returned if
18651870
``return_freqs`` is ``True``.
1871+
tapers : array | None
1872+
The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be
1873+
``None`` if a taper dimension is not present in the data.
18661874
18671875
Notes
18681876
-----
@@ -1900,7 +1908,13 @@ def get_data(
19001908
if return_freqs:
19011909
freqs = self._freqs[fmin_idx:fmax_idx]
19021910
out.append(freqs)
1903-
if not return_times and not return_freqs:
1911+
if return_tapers:
1912+
if "taper" in self._dims:
1913+
tapers = np.arange(self.shape[self._dims.index("taper")])
1914+
else:
1915+
tapers = None
1916+
out.append(tapers)
1917+
if not return_times and not return_freqs and not return_tapers:
19041918
return out[0]
19051919
return tuple(out)
19061920

@@ -2676,21 +2690,21 @@ def to_data_frame(
26762690
):
26772691
"""Export data in tabular structure as a pandas DataFrame.
26782692
2679-
Channels are converted to columns in the DataFrame. By default,
2680-
additional columns ``'time'``, ``'freq'``, ``'epoch'``, and
2681-
``'condition'`` (epoch event description) are added, unless ``index``
2682-
is not ``None`` (in which case the columns specified in ``index`` will
2683-
be used to form the DataFrame's index instead). ``'epoch'``, and
2684-
``'condition'`` are not supported for ``AverageTFR``.
2693+
Channels are converted to columns in the DataFrame. By default, additional
2694+
columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'``
2695+
(epoch event description) are added, unless ``index`` is not ``None`` (in which
2696+
case the columns specified in ``index`` will be used to for 10000 m the DataFrame's
2697+
index instead). ``'epoch'``, and ``'condition'`` are not supported for
2698+
``AverageTFR``. ``'taper'`` is only supported when a taper dimensions is
2699+
present, such as for complex or phase multitaper data.
26852700
26862701
Parameters
26872702
----------
26882703
%(picks_all)s
26892704
%(index_df_epo)s
2690-
Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and
2691-
``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'``
2692-
for ``AverageTFR``.
2693-
Defaults to ``None``.
2705+
Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``,
2706+
and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and
2707+
``'taper'`` for ``AverageTFR``. Defaults to ``None``.
26942708
%(long_format_df_epo)s
26952709
%(time_format_df)s
26962710
@@ -2710,12 +2724,16 @@ def to_data_frame(
27102724
valid_index_args = ["time", "freq"]
27112725
if from_epo:
27122726
valid_index_args.extend(["epoch", "condition"])
2727+
if unagg_mt:
2728+
valid_index_args.append("taper")
27132729
valid_time_formats = ["ms", "timedelta"]
27142730
index = _check_pandas_index_arguments(index, valid_index_args)
27152731
time_format = _check_time_format(time_format, valid_time_formats)
27162732
# get data
27172733
picks = _picks_to_idx(self.info, picks, "all", exclude=())
2718-
data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True)
2734+
data, times, freqs, tapers = self.get_data(
2735+
picks, return_times=True, return_freqs=True, return_tapers=True
2736+
)
27192737
ch_axis = self._dims.index("channel")
27202738
if not from_epo:
27212739
data = data[np.newaxis] # add singleton "epochs" axis
@@ -2731,7 +2749,7 @@ def to_data_frame(
27312749
default_index = list()
27322750
times = _convert_times(times, time_format, self.info["meas_date"])
27332751
times = np.tile(times, n_epochs * n_freqs * n_tapers)
2734-
freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs)
2752+
freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers)
27352753
mindex.append(("time", times))
27362754
mindex.append(("freq", freqs))
27372755
if from_epo:
@@ -2744,12 +2762,11 @@ def to_data_frame(
27442762
("condition", np.repeat(conditions, n_times * n_freqs * n_tapers))
27452763
)
27462764
default_index.extend(["condition", "epoch"])
2747-
default_index.extend(["freq", "time"])
27482765
if unagg_mt:
2749-
name = "taper"
2750-
taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times)
2751-
mindex.append((name, taper_nums))
2752-
default_index.append(name)
2766+
tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times)
2767+
mindex.append(("taper", tapers))
2768+
default_index.append("taper")
2769+
default_index.extend(["freq", "time"])
27532770
assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:])
27542771
# build DataFrame
27552772
df = _build_data_frame(

0 commit comments

Comments
 (0)
0