@@ -1421,7 +1421,6 @@ def __setstate__(self, state):
1421
1421
1422
1422
defaults = dict (
1423
1423
method = "unknown" ,
1424
- dims = ("epoch" , "channel" , "freq" , "time" )[- state ["data" ].ndim :],
1425
1424
baseline = None ,
1426
1425
decim = 1 ,
1427
1426
data_type = "TFR" ,
@@ -1445,7 +1444,7 @@ def __setstate__(self, state):
1445
1444
unknown_class = Epochs if "epoch" in self ._dims else Evoked
1446
1445
inst_types = dict (Raw = Raw , Epochs = Epochs , Evoked = Evoked , Unknown = unknown_class )
1447
1446
self ._inst_type = inst_types [defaults ["inst_type_str" ]]
1448
- # sanity check data/freqs/times/info agreement
1447
+ # sanity check data/freqs/times/info/weights agreement
1449
1448
self ._check_state ()
1450
1449
1451
1450
def __repr__ (self ):
@@ -1498,14 +1497,26 @@ def _check_compatibility(self, other):
1498
1497
raise RuntimeError (msg .format (problem , extra ))
1499
1498
1500
1499
def _check_state (self ):
1501
- """Check data/freqs/times/info agreement during __setstate__."""
1500
+ """Check data/freqs/times/info/weights agreement during __setstate__."""
1502
1501
msg = "{} axis of data ({}) doesn't match {} attribute ({})"
1503
1502
n_chan_info = len (self .info ["chs" ])
1504
1503
n_chan = self ._data .shape [self ._dims .index ("channel" )]
1504
+ n_taper = (
1505
+ self ._data .shape [self ._dims .index ("taper" )]
1506
+ if "taper" in self ._dims
1507
+ else None
1508
+ )
1505
1509
n_freq = self ._data .shape [self ._dims .index ("freq" )]
1506
1510
n_time = self ._data .shape [self ._dims .index ("time" )]
1507
1511
if n_chan_info != n_chan :
1508
1512
msg = msg .format ("Channel" , n_chan , "info" , n_chan_info )
1513
+ elif n_taper is not None :
1514
+ if self ._weights is None :
1515
+ raise RuntimeError ("Taper dimension in data, but no weights found." )
1516
+ if n_taper != self ._weights .shape [0 ]:
1517
+ msg = msg .format ("Taper" , n_taper , "weights" , self ._weights .shape [0 ])
1518
+ elif n_freq != self ._weights .shape [1 ]:
1519
+ msg = msg .format ("Frequency" , n_freq , "weights" , self ._weights .shape [1 ])
1509
1520
elif n_freq != len (self .freqs ):
1510
1521
msg = msg .format ("Frequency" , n_freq , "freqs" , self .freqs .size )
1511
1522
elif n_time != len (self .times ):
@@ -2775,6 +2786,7 @@ class AverageTFR(BaseTFR):
2775
2786
%(nave_tfr_attr)s
2776
2787
%(sfreq_tfr_attr)s
2777
2788
%(shape_tfr_attr)s
2789
+ %(weights_tfr_attr)s
2778
2790
2779
2791
See Also
2780
2792
--------
@@ -2891,6 +2903,10 @@ def __getstate__(self):
2891
2903
2892
2904
def __setstate__ (self , state ):
2893
2905
"""Unpack AverageTFR from serialized format."""
2906
+ if state ["data" ].ndim != 3 :
2907
+ raise ValueError (f"RawTFR data should be 3D, got { state ['data' ].ndim } ." )
2908
+ # Set dims now since optional tapers makes it difficult to disentangle later
2909
+ state ["dims" ] = ("channel" , "freq" , "time" )
2894
2910
super ().__setstate__ (state )
2895
2911
self ._comment = state .get ("comment" , "" )
2896
2912
self ._nave = state .get ("nave" , 1 )
@@ -3046,6 +3062,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin):
3046
3062
%(selection_attr)s
3047
3063
%(sfreq_tfr_attr)s
3048
3064
%(shape_tfr_attr)s
3065
+ %(weights_tfr_attr)s
3049
3066
3050
3067
See Also
3051
3068
--------
@@ -3130,8 +3147,15 @@ def __getstate__(self):
3130
3147
3131
3148
def __setstate__ (self , state ):
3132
3149
"""Unpack EpochsTFR from serialized format."""
3133
- if state ["data" ].ndim != 4 :
3134
- raise ValueError (f"EpochsTFR data should be 4D, got { state ['data' ].ndim } ." )
3150
+ if state ["data" ].ndim not in [4 , 5 ]:
3151
+ raise ValueError (
3152
+ f"EpochsTFR data should be 4D or 5D, got { state ['data' ].ndim } ."
3153
+ )
3154
+ # Set dims now since optional tapers makes it difficult to disentangle later
3155
+ state ["dims" ] = ("epoch" , "channel" )
3156
+ if state ["data" ].ndim == 5 :
3157
+ state ["dims" ] += ("taper" ,)
3158
+ state ["dims" ] += ("freq" , "time" )
3135
3159
super ().__setstate__ (state )
3136
3160
self ._metadata = state .get ("metadata" , None )
3137
3161
n_epochs = self .shape [0 ]
@@ -3235,7 +3259,16 @@ def average(self, method="mean", *, dim="epochs", copy=False):
3235
3259
See discussion here:
3236
3260
3237
3261
https://github.com/scipy/scipy/pull/12676#issuecomment-783370228
3262
+
3263
+ Averaging is not supported for data containing a taper dimension.
3238
3264
"""
3265
+ if "taper" in self ._dims :
3266
+ raise NotImplementedError (
3267
+ "Averaging multitaper tapers across epochs, frequencies, or times is "
3268
+ "not supported. If averaging across epochs, consider averaging the "
3269
+ "epochs before computing the complex/phase spectrum."
3270
+ )
3271
+
3239
3272
_check_option ("dim" , dim , ("epochs" , "freqs" , "times" ))
3240
3273
axis = self ._dims .index (dim [:- 1 ]) # self._dims entries aren't plural
3241
3274
@@ -3607,6 +3640,7 @@ class EpochsTFRArray(EpochsTFR):
3607
3640
%(selection)s
3608
3641
%(drop_log)s
3609
3642
%(metadata_epochstfr)s
3643
+ %(weights_tfr_array)s
3610
3644
3611
3645
Attributes
3612
3646
----------
@@ -3623,6 +3657,7 @@ class EpochsTFRArray(EpochsTFR):
3623
3657
%(selection_attr)s
3624
3658
%(sfreq_tfr_attr)s
3625
3659
%(shape_tfr_attr)s
3660
+ %(weights_tfr_attr)s
3626
3661
3627
3662
See Also
3628
3663
--------
@@ -3645,6 +3680,7 @@ def __init__(
3645
3680
selection = None ,
3646
3681
drop_log = None ,
3647
3682
metadata = None ,
3683
+ weights = None ,
3648
3684
):
3649
3685
state = dict (info = info , data = data , times = times , freqs = freqs )
3650
3686
optional = dict (
@@ -3655,6 +3691,7 @@ def __init__(
3655
3691
selection = selection ,
3656
3692
drop_log = drop_log ,
3657
3693
metadata = metadata ,
3694
+ weights = weights ,
3658
3695
)
3659
3696
for name , value in optional .items ():
3660
3697
if value is not None :
@@ -3697,6 +3734,7 @@ class RawTFR(BaseTFR):
3697
3734
method : str
3698
3735
The method used to compute the spectra (``'morlet'``, ``'multitaper'``
3699
3736
or ``'stockwell'``).
3737
+ %(weights_tfr_attr)s
3700
3738
3701
3739
See Also
3702
3740
--------
@@ -3746,6 +3784,19 @@ def __init__(
3746
3784
** method_kw ,
3747
3785
)
3748
3786
3787
+ def __setstate__ (self , state ):
3788
+ """Unpack RawTFR from serialized format."""
3789
+ if state ["data" ].ndim not in [3 , 4 ]:
3790
+ raise ValueError (
3791
+ f"RawTFR data should be 3D or 4D, got { state ['data' ].ndim } ."
3792
+ )
3793
+ # Set dims now since optional tapers makes it difficult to disentangle later
3794
+ state ["dims" ] = ("channel" ,)
3795
+ if state ["data" ].ndim == 4 :
3796
+ state ["dims" ] += ("taper" ,)
3797
+ state ["dims" ] += ("freq" , "time" )
3798
+ super ().__setstate__ (state )
3799
+
3749
3800
def __getitem__ (self , item ):
3750
3801
"""Get RawTFR data.
3751
3802
@@ -3811,6 +3862,7 @@ class RawTFRArray(RawTFR):
3811
3862
%(times)s
3812
3863
%(freqs_tfr_array)s
3813
3864
%(method_tfr_array)s
3865
+ %(weights_tfr_array)s
3814
3866
3815
3867
Attributes
3816
3868
----------
@@ -3821,6 +3873,7 @@ class RawTFRArray(RawTFR):
3821
3873
%(method_tfr_attr)s
3822
3874
%(sfreq_tfr_attr)s
3823
3875
%(shape_tfr_attr)s
3876
+ %(weights_tfr_attr)s
3824
3877
3825
3878
See Also
3826
3879
--------
@@ -3838,10 +3891,13 @@ def __init__(
3838
3891
freqs ,
3839
3892
* ,
3840
3893
method = None ,
3894
+ weights = None ,
3841
3895
):
3842
3896
state = dict (info = info , data = data , times = times , freqs = freqs )
3843
- if method is not None :
3844
- state ["method" ] = method
3897
+ optional = dict (method = method , weights = weights )
3898
+ for name , value in optional .items ():
3899
+ if value is not None :
3900
+ state [name ] = value
3845
3901
self .__setstate__ (state )
3846
3902
3847
3903
0 commit comments