@@ -2693,42 +2693,55 @@ def to_data_frame(
2693
2693
"""
2694
2694
# check pandas once here, instead of in each private utils function
2695
2695
pd = _check_pandas_installed () # noqa
2696
+ # triage for Epoch-derived or unaggregated spectra
2697
+ from_epo = isinstance (self , EpochsTFR )
2698
+ unagg_mt = "taper" in self ._dims
2696
2699
# arg checking
2697
2700
valid_index_args = ["time" , "freq" ]
2698
- if isinstance ( self , EpochsTFR ) :
2701
+ if from_epo :
2699
2702
valid_index_args .extend (["epoch" , "condition" ])
2700
2703
valid_time_formats = ["ms" , "timedelta" ]
2701
2704
index = _check_pandas_index_arguments (index , valid_index_args )
2702
2705
time_format = _check_time_format (time_format , valid_time_formats )
2703
2706
# get data
2704
2707
picks = _picks_to_idx (self .info , picks , "all" , exclude = ())
2705
2708
data , times , freqs = self .get_data (picks , return_times = True , return_freqs = True )
2706
- axis = self ._dims .index ("channel" )
2707
- if not isinstance ( self , EpochsTFR ) :
2709
+ ch_axis = self ._dims .index ("channel" )
2710
+ if not from_epo :
2708
2711
data = data [np .newaxis ] # add singleton "epochs" axis
2709
- axis += 1
2710
- n_epochs , n_picks , n_freqs , n_times = data .shape
2711
- # reshape to (epochs*freqs*times) x signals
2712
- data = np .moveaxis (data , axis , - 1 )
2713
- data = data .reshape (n_epochs * n_freqs * n_times , n_picks )
2712
+ ch_axis += 1
2713
+ if not unagg_mt :
2714
+ data = np .expand_dims (data , - 3 ) # add singleton "tapers" axis
2715
+ n_epochs , n_picks , n_tapers , n_freqs , n_times = data .shape
2716
+ # reshape to (epochs*tapers*freqs*times) x signals
2717
+ data = np .moveaxis (data , ch_axis , - 1 )
2718
+ data = data .reshape (n_epochs * n_tapers * n_freqs * n_times , n_picks )
2714
2719
# prepare extra columns / multiindex
2715
2720
mindex = list ()
2721
+ default_index = list ()
2716
2722
times = _convert_times (times , time_format , self .info ["meas_date" ])
2717
- times = np .tile (times , n_epochs * n_freqs )
2718
- freqs = np .tile (np .repeat (freqs , n_times ), n_epochs )
2723
+ times = np .tile (times , n_epochs * n_freqs * n_tapers )
2724
+ freqs = np .tile (np .repeat (freqs , n_times * n_tapers ), n_epochs )
2719
2725
mindex .append (("time" , times ))
2720
2726
mindex .append (("freq" , freqs ))
2721
- if isinstance (self , EpochsTFR ):
2722
- mindex .append (("epoch" , np .repeat (self .selection , n_times * n_freqs )))
2727
+ if from_epo :
2728
+ mindex .append (
2729
+ ("epoch" , np .repeat (self .selection , n_times * n_freqs * n_tapers ))
2730
+ )
2723
2731
rev_event_id = {v : k for k , v in self .event_id .items ()}
2724
2732
conditions = [rev_event_id [k ] for k in self .events [:, 2 ]]
2725
- mindex .append (("condition" , np .repeat (conditions , n_times * n_freqs )))
2733
+ mindex .append (
2734
+ ("condition" , np .repeat (conditions , n_times * n_freqs * n_tapers ))
2735
+ )
2736
+ default_index .extend (["condition" , "epoch" ])
2737
+ default_index .extend (["freq" , "time" ])
2738
+ if unagg_mt :
2739
+ name = "taper"
2740
+ taper_nums = np .tile (np .arange (n_tapers ), n_epochs * n_freqs * n_times )
2741
+ mindex .append ((name , taper_nums ))
2742
+ default_index .append (name )
2726
2743
assert all (len(mdx ) == len (mindex [0 ]) for mdx in mindex [1 :])
2727
2744
# build DataFrame
2728
- if isinstance (self , EpochsTFR ):
2729
- default_index = ["condition" , "epoch" , "freq" , "time" ]
2730
- else :
2731
- default_index = ["freq" , "time" ]
2732
2745
df = _build_data_frame (
2733
2746
self , data , picks , long_format , mindex , index , default_index = default_index
2734
2747
)
0 commit comments