8000 Merge remote-tracking branch 'upstream/main' into ebm · emma-bailey/mne-python@72facf0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 72facf0

Browse files
committed
Merge remote-tracking branch 'upstream/main' into ebm
* upstream/main: [DOC] extend documentation for add_channels (mne-tools#13051) Add `combine_tfr` to API (mne-tools#13054) Add `combine_spectrum()` function and allow `grand_average()` to support `Spectrum` data (mne-tools#13058) BUG: Fix bug with helium anon (mne-tools#13056) [ENH] Add option to store and return TFR taper weights (mne-tools#12910)
2 parents ebda34e + 5fec4e0 commit 72facf0

20 files changed

+744
-209
lines changed

doc/api/time_frequency.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ Functions that operate on mne-python objects:
3131
.. autosummary::
3232
:toctree: ../generated/
3333

34+
combine_spectrum
35+
combine_tfr
3436
csd_tfr
3537
csd_fourier
3638
csd_multitaper
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added the option to return taper weights from :func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the :class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added :func:`mne.time_frequency.combine_tfr` to allow combining TFRs across tapers, by `Thomas Binns`_.

doc/changes/devel/13056.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug with saving of anonymized data when helium info is present in measurement info, by `Eric Larson`_.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add the function :func:`mne.time_frequency.combine_spectrum` for combining data across :class:`mne.time_frequency.Spectrum` objects, and allow :func:`mne.grand_average` to operate on :class:`mne.time_frequency.Spectrum` objects, by `Thomas Binns`_.

mne/_fiff/meas_info.py

Lines changed: 11 additions & 4 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -2493,6 +2493,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
24932493
hi["meas_date"] = _ensure_meas_date_none_or_dt(
24942494
tuple(int(t) for t in tag.data),
24952495
)
2496+
if "meas_date" not in hi:
2497+
hi["meas_date"] = None
24962498
info["helium_info"] = hi
24972499
del hi
24982500

@@ -2879,7 +2881,8 @@ def write_meas_info(fid, info, data_type=None, reset_range=True):
28792881
write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"])
28802882
if hi.get("orig_file_guid") is not None:
28812883
write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"])
2882-
write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"]))
2884+
if hi["meas_date"] is not None:
2885+
write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"]))
28832886
end_block(fid, FIFF.FIFFB_HELIUM)
28842887
del hi
28852888

@@ -2916,8 +2919,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True):
29162919
_write_proc_history(fid, info)
29172920

29182921

2919-
@fill_doc
2920-
def write_info(fname, info, data_type=None, reset_range=True):
2922+
@verbose
2923+
def write_info(
2924+
fname, info, *, data_type=None, reset_range=True, overwrite=False, verbose=None
2925+
):
29212926
"""Write measurement info in fif file.
29222927
29232928
Parameters
@@ -2931,8 +2936,10 @@ def write_info(fname, info, data_type=None, reset_range=True):
29312936
raw data.
29322937
reset_range : bool
29332938
If True, info['chs'][k]['range'] will be set to unity.
2939+
%(overwrite)s
2940+
%(verbose)s
29342941
"""
2935-
with start_and_end_file(fname) as fid:
2942+
with start_and_end_file(fname, overwrite=overwrite) as fid:
29362943
start_block(fid, FIFF.FIFFB_MEAS)
29372944
write_meas_info(fid, info, data_type, reset_range)
29382945
end_block(fid, FIFF.FIFFB_MEAS)

mne/_fiff/tests/test_meas_info.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ def test_read_write_info(tmp_path):
306306
gantry_angle = info["gantry_angle"]
307307

308308
meas_id = info["meas_id"]
309-
write_info(temp_file, info)
309+
with pytest.raises(FileExistsError, match="Destination file exists"):
310+
write_info(temp_file, info)
311+
write_info(temp_file, info, overwrite=True)
310312
info = read_info(temp_file)
311313
assert info["proc_history"][0]["creator"] == creator
312314
assert info["hpi_meas"][0]["creator"] == creator
@@ -348,7 +350,7 @@ def test_read_write_info(tmp_path):
348350
info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
349351
fname = tmp_path / "test.fif"
350352
with pytest.raises(RuntimeError, match="must be between "):
351-
write_info(fname, info)
353+
write_info(fname, info, overwrite=True)
352354

353355

354356
@testing.requires_testing_data
@@ -377,7 +379,7 @@ def test_io_coord_frame(tmp_path):
377379
for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"):
378380
info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type])
379381
info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03]
380-
write_info(fname, info)
382+
write_info(fname, info, overwrite=True)
381383
info2 = read_info(fname)
382384
assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD
383385

@@ -585,7 +587,7 @@ def test_check_consistency():
585587
info2["subject_info"] = {"height": "bad"}
586588

587589

588-
def _test_anonymize_info(base_info):
590+
def _test_anonymize_info(base_info, tmp_path):
589591
"""Test that sensitive information can be anonymized."""
590592
pytest.raises(TypeError, anonymize_info, "foo")
591593
assert isinstance(base_info, Info)
@@ -692,14 +694,25 @@ def _adjust_back(e_i, dt):
692694
# exp 4 tests is a supplied daysback
693695
delta_t_3 = timedelta(days=223 + 364 * 500)
694696

697+
def _check_equiv(got, want, err_msg):
698+
__tracebackhide__ = True
699+
fname_temp = tmp_path / "test.fif"
700+
assert_object_equal(got, want, err_msg=err_msg)
701+
write_info(fname_temp, got, reset_range=False, overwrite=True)
702+
got = read_info(fname_temp)
703+
# this gets changed on write but that's expected
704+
with got._unlock():
705+
got["file_id"] = want["file_id"]
706+
assert_object_equal(got, want, err_msg=f"{err_msg} (on I/O round trip)")
707+
695708
new_info = anonymize_info(base_info.copy())
696-
assert_object_equal(new_info, exp_info, err_msg="anon mismatch")
709+
_check_equiv(new_info, exp_info, err_msg="anon mismatch")
697710

698711
new_info = anonymize_info(base_info.copy(), keep_his=True)
699-
assert_object_equal(new_info, exp_info_2, err_msg="anon keep_his mismatch")
712+
_check_equiv(new_info, exp_info_2, err_msg="anon keep_his mismatch")
700713

701714
new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
702-
assert_object_equal(new_info, exp_info_3, err_msg="anon daysback mismatch")
715+
_check_equiv(new_info, exp_info_3, err_msg="anon daysback mismatch")
703716

704717
with pytest.raises(RuntimeError, match="anonymize_info generated"):
705718
anonymize_info(base_info.copy(), daysback=delta_t_3.days)
@@ -726,15 +739,15 @@ def _adjust_back(e_i, dt):
726739
new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
727740
else:
728741
new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
729-
assert_object_equal(
742+
_check_equiv(
730743
new_info,
731744
exp_info_3,
732745
err_msg="meas_date=None daysback mismatch",
733746
)
734747

735748
with _record_warnings(): # meas_date is None
736749
new_info = anonymize_info(base_info.copy())
737-
assert_object_equal(new_info, exp_info_3, err_msg="meas_date=None mismatch")
750+
_check_equiv(new_info, exp_info_3, err_msg="meas_date=None mismatch")
738751

739752

740753
@pytest.mark.parametrize(
@@ -777,8 +790,8 @@ def _complete_info(info):
777790
height=2.0,
778791
)
779792
info["helium_info"] = dict(
780-
he_level_raw=12.34,
781-
helium_level=45.67,
793+
he_level_raw=np.float32(12.34),
794+
helium_level=np.float32(45.67),
782795
meas_date=datetime(2024, 11, 14, 14, 8, 2, tzinfo=timezone.utc),
783796
orig_file_guid="e",
784797
)
@@ -796,14 +809,13 @@ def _complete_info(info):
796809
machid=np.ones(2, int),
797810
secs=d[0],
798811
usecs=d[1],
799-
date=d,
800812
),
801813
experimenter="j",
802814
max_info=dict(
803-
max_st=[],
804-
sss_ctc=[],
805-
sss_cal=[],
806-
sss_info=dict(head_pos=None, in_order=8),
815+
max_st=dict(),
816+
sss_ctc=dict(),
817+
sss_cal=dict(),
818+
sss_info=dict(in_order=8),
807819
),
808820
date=d,
809821
),
@@ -830,8 +842,8 @@ def test_anonymize(tmp_path):
830842
# test mne.anonymize_info()
831843
events = read_events(event_name)
832844
epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None)
833-
_test_anonymize_info(raw.info)
834-
_test_anonymize_info(epochs.info)
845+
_test_anonymize_info(raw.info, tmp_path)
846+
_test_anonymize_info(epochs.info, tmp_path)
835847

836848
# test instance methods & I/O roundtrip
837849
for inst, keep_his in zip((raw, epochs), (True, False)):

mne/_fiff/write.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414
from scipy.sparse import csc_array, csr_array
1515

16-
from ..utils import _file_like, _validate_type, logger
16+
from ..utils import _check_fname, _file_like, _validate_type, logger
1717
from ..utils.numerics import _date_to_julian
1818
from .constants import FIFF
1919

@@ -277,7 +277,7 @@ def end_block(fid, kind):
277277
write_int(fid, FIFF.FIFF_BLOCK_END, kind)
278278

279279

280-
def start_file(fname, id_=None):
280+
def start_file(fname, id_=None, *, overwrite=True):
281281
"""Open a fif file for writing and writes the compulsory header tags.
282282
283283
Parameters
@@ -294,6 +294,7 @@ def start_file(fname, id_=None):
294294
fid = fname
295295
fid.seek(0)
296296
else:
297+
fname = _check_fname(fname, overwrite=overwrite)
297298
fname = str(fname)
298299
if op.splitext(fname)[1].lower() == ".gz":
299300
logger.debug("Writing using gzip")
@@ -311,9 +312,9 @@ def start_file(fname, id_=None):
311312

312313

313314
@contextmanager
314-
def start_and_end_file(fname, id_=None):
315+
def start_and_end_file(fname, id_=None, *, overwrite=True):
315316
"""Start and (if successfully written) close the file."""
316-
with start_file(fname, id_=id_) as fid:
317+
with start_file(fname, id_=id_, overwrite=overwrite) as fid:
317318
yield fid
318319
end_file(fid) # we only hit this line if the yield does not err
319320

mne/channels/channels.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -661,17 +661,21 @@ def _pick_projs(self):
661661
return self
662662

663663
def add_channels(self, add_list, force_update_info=False):
664-
"""Append new channels to the instance.
664+
"""Append new channels from other MNE objects to the instance.
665665
666666
Parameters
667667
----------
668668
add_list : list
669-
A list of objects to append to self. Must contain all the same
670-
type as the current object.
669+
A list of MNE objects to append to the current instance.
670+
The channels contained in the other instances are appended to the
671+
channels of the current instance. Therefore, all other instances
672+
must be of the same type as the current object.
673+
See notes on how to add data coming from an array.
671674
force_update_info : bool
672675
If True, force the info for objects to be appended to match the
673-
values in ``self``. This should generally only be used when adding
674-
stim channels for which important metadata won't be overwritten.
676+
values of the current instance. This should generally only be
677+
used when adding stim channels for which important metadata won't
678+
be overwritten.
675679
676680
.. versionadded:: 0.12
677681
@@ -688,6 +692,12 @@ def add_channels(self, add_list, force_update_info=False):
688692
-----
689693
If ``self`` is a Raw instance that has been preloaded into a
690694
:obj:`numpy.memmap` instance, the memmap will be resized.
695+
696+
This function expects an MNE object to be appended (e.g. :class:`~mne.io.Raw`,
697+
:class:`~mne.Epochs`, :class:`~mne.Evoked`). If you simply want to add a
698+
channel based on values of an np.ndarray, you need to create a
699+
:class:`~mne.io.RawArray`.
700+
See <https://mne.tools/mne-project-template/auto_examples/plot_mne_objects_from_arrays.html>`_
691701
"""
692702
# avoid circular imports
693703
from ..epochs import BaseEpochs

mne/time_frequency/__init__.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ __all__ = [
1111
"RawTFRArray",
1212
"Spectrum",
1313
"SpectrumArray",
14+
"combine_spectrum",
15+
"combine_tfr",
1416
"csd_array_fourier",
1517
"csd_array_morlet",
1618
"csd_array_multitaper",
@@ -61,6 +63,7 @@ from .spectrum import (
6163
EpochsSpectrumArray,
6264
Spectrum,
6365
SpectrumArray,
66+
combine_spectrum,
6467
read_spectrum,
6568
)
6669
from .tfr import (
@@ -71,6 +74,7 @@ from .tfr import (
7174
EpochsTFRArray,
7275
RawTFR,
7376
RawTFRArray,
77+
combine_tfr,
7478
fwhm,
7579
morlet,
7680
read_tfrs,

0 commit comments

Comments
 (0)
0