8000 Reapply "Update docstrings" · tsbinns/mne-python@2f9a4b4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2f9a4b4

Browse files
committed
Reapply "Update docstrings"
This reverts commit 8c16716.
1 parent 51b8cd0 commit 2f9a4b4

File tree

2 files changed

+33
-19
lines changed

2 files changed

+33
-19
lines changed

mne/time_frequency/multitaper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,8 @@ def tfr_array_multitaper(
505505
coherence across trials.
506506
507507
return_weights : bool, default False
508-
If True, return the taper weights. Only applies if ``output="complex"``.
508+
If True, return the taper weights. Only applies if ``output='complex'`` or
509+
``'phase'``.
509510
510511
.. versionadded:: 1.9.0
511512
@@ -528,7 +529,7 @@ def tfr_array_multitaper(
528529
contain the average power and the imaginary values contain the
529530
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
530531
weights : array of shape (n_tapers, n_freqs)
531-
The taper weights. Only returned if ``output="complex"`` and
532+
The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and
532533
``return_weights=True``.
533534
534535
See Also

mne/time_frequency/tfr.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2693,42 +2693,55 @@ def to_data_frame(
26932693
"""
26942694
# check pandas once here, instead of in each private utils function
26952695
pd = _check_pandas_installed() # noqa
2696+
# triage for Epoch-derived or unaggregated spectra
2697+
from_epo = isinstance(self, EpochsTFR)
2698+
unagg_mt = "taper" in self._dims
26962699
# arg checking
26972700
valid_index_args = ["time", "freq"]
2698-
if isinstance(self, EpochsTFR):
2701+
if from_epo:
26992702
valid_index_args.extend(["epoch", "condition"])
27002703
valid_time_formats = ["ms", "timedelta"]
27012704
index = _check_pandas_index_arguments(index, valid_index_args)
27022705
time_format = _check_time_format(time_format, valid_time_formats)
27032706
# get data
27042707
picks = _picks_to_idx(self.info, picks, "all", exclude=())
27052708
data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True)
2706-
axis = self._dims.index("channel")
2707-
if not isinstance(self, EpochsTFR):
2709+
ch_axis = self._dims.index("channel")
2710+
if not from_epo:
27082711
data = data[np.newaxis] # add singleton "epochs" axis
2709-
axis += 1
2710-
n_epochs, n_picks, n_freqs, n_times = data.shape
2711-
# reshape to (epochs*freqs*times) x signals
2712-
data = np.moveaxis(data, axis, -1)
2713-
data = data.reshape(n_epochs * n_freqs * n_times, n_picks)
2712+
ch_axis += 1
2713+
if not unagg_mt:
2714+
data = np.expand_dims(data, -3) # add singleton "tapers" axis
2715+
n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape
2716+
# reshape to (epochs*tapers*freqs*times) x signals
2717+
data = np.moveaxis(data, ch_axis, -1)
2718+
data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks)
27142719
# prepare extra columns / multiindex
27152720
mindex = list()
2721+
default_index = list()
27162722
times = _convert_times(times, time_format, self.info["meas_date"])
2717-
times = np.tile(times, n_epochs * n_freqs)
2718-
freqs = np.tile(np.repeat(freqs, n_times), n_epochs)
2723+
times = np.tile(times, n_epochs * n_freqs * n_tapers)
2724+
freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs)
27192725
mindex.append(("time", times))
27202726
mindex.append(("freq", freqs))
2721-
if isinstance(self, EpochsTFR):
2722-
mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs)))
2727+
if from_epo:
2728+
mindex.append(
2729+
("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers))
2730+
)
27232731
rev_event_id = {v: k for k, v in self.event_id.items()}
27242732
conditions = [rev_event_id[k] for k in self.events[:, 2]]
2725-
mindex.append(("condition", np.repeat(conditions, n_times * n_freqs)))
2733+
mindex.append(
2734+
("condition", np.repeat(conditions, n_times * n_freqs * n_tapers))
2735+
)
2736+
default_index.extend(["condition", "epoch"])
2737+
default_index.extend(["freq", "time"])
2738+
if unagg_mt:
2739+
name = "taper"
2740+
taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times)
2741+
mindex.append((name, taper_nums))
2742+
default_index.append(name)
27262743
assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:])
27272744
# build DataFrame
2728-
if isinstance(self, EpochsTFR):
2729-
default_index = ["condition", "epoch", "freq", "time"]
2730-
else:
2731-
default_index = ["freq", "time"]
27322745
df = _build_data_frame(
27332746
self, data, picks, long_format, mindex, index, default_index=default_index
27342747
)

0 commit comments

Comments
 (0)
0