diff --git a/.circleci/config.yml b/.circleci/config.yml index 6618188621b..9f30ce574f5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -22,58 +22,6 @@ _check_skip: &check_skip fi jobs: - pytest-macos-arm64: - parameters: - scheduled: - type: string - default: "false" - macos: - xcode: "14.2.0" - resource_class: macos.m1.medium.gen1 - environment: - HOMEBREW_NO_AUTO_UPDATE: 1 - steps: - - checkout - - run: - <<: *check_skip - - run: - name: Install Python and dependencies - command: | - set -eo pipefail - brew install python@3.11 - which python - which pip - pip install --upgrade pip - pip install --upgrade --only-binary "numpy,scipy,dipy,statsmodels" -ve .[full,test_extra] - # 3D too slow on Apple's software renderer, and numba causes us problems - pip uninstall -y vtk pyvista pyvistaqt numba - mkdir -p test-results - echo "set -eo pipefail" >> $BASH_ENV - - run: - command: mne sys_info - - run: - command: ./tools/get_testing_version.sh && cat testing_version.txt - - restore_cache: - keys: - - data-cache-testing-{{ checksum "testing_version.txt" }} - - run: - command: python -c "import mne; mne.datasets.testing.data_path(verbose=True)" - - save_cache: - key: data-cache-testing-{{ checksum "testing_version.txt" }} - paths: - - ~/mne_data/MNE-testing-data # (2.5 G) - - run: - command: pytest -m "not slowtest" --tb=short --cov=mne --cov-report xml -vv mne - - run: - name: Prepare test data upload - command: cp -av junit-results.xml test-results/junit.xml - - store_test_results: - path: ./test-results - # Codecov orb has bugs on macOS (gpg issues) - # - codecov/upload - - run: - command: bash <(curl -s https://codecov.io/bash) - build_docs: parameters: scheduled: @@ -454,6 +402,7 @@ jobs: default: "false" docker: - image: cimg/base:current-22.04 + resource_class: large steps: - restore_cache: keys: @@ -496,8 +445,8 @@ jobs: deploy: - machine: - image: ubuntu-2004:202111-01 + docker: + - image: cimg/base:current-22.04 steps: - attach_workspace: at: /tmp/build @@ -591,20 +540,6 @@ workflows: only: - main - weekly: - jobs: - - pytest-macos-arm64: - name: pytest_macos_arm64_weekly - scheduled: "true" - triggers: - - schedule: - # "At 6:00 AM GMT every Monday" - cron: "0 6 * * 1" - filters: - branches: - only: - - main - monthly: jobs: - linkcheck: diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 3e511b1a194..c9248c01bb0 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,3 +1,4 @@ e81ec528a42ac687f3d961ed5cf8e25f236925b0 # black 12395f9d9cf6ea3c72b225b62e052dd0d17d9889 # YAML indentation d6d2f8c6a2ed4a0b27357da9ddf8e0cd14931b59 # isort +e7dd1588013179013a50d3f6b8e8f9ae0a185783 # ruff format diff --git a/.git_archival.txt b/.git_archival.txt new file mode 100644 index 00000000000..8fb235d7045 --- /dev/null +++ b/.git_archival.txt @@ -0,0 +1,4 @@ +node: $Format:%H$ +node-date: $Format:%cI$ +describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ +ref-names: $Format:%D$ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000000..00a7b00c94e --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +.git_archival.txt export-subst diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 53e02d49867..b7ab58dc917 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -10,4 +10,4 @@ This project and everyone participating in it is governed by the [MNE-Python's C ## How to contribute -Before contributing make sure you are familiar with [our contributing 
guide](https://mne.tools/dev/install/contributing.html). +Before contributing make sure you are familiar with [our contributing guide](https://mne.tools/dev/development/contributing.html). diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index ea102484a7f..1ca19246c37 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,5 +1,5 @@ Thanks for contributing a pull request! Please make sure you have read the -[contribution guidelines](https://mne.tools/dev/install/contributing.html) +[contribution guidelines](https://mne.tools/dev/development/contributing.html) before submitting. Please be aware that we are a loose team of volunteers so patience is diff --git a/.github/actions/rename_towncrier/rename_towncrier.py b/.github/actions/rename_towncrier/rename_towncrier.py new file mode 100755 index 00000000000..68971d1c83f --- /dev/null +++ b/.github/actions/rename_towncrier/rename_towncrier.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +# Adapted from action-towncrier-changelog +import json +import os +import re +import subprocess +import sys +from pathlib import Path + +from github import Github +from tomllib import loads + +event_name = os.getenv('GITHUB_EVENT_NAME', 'pull_request') +if not event_name.startswith('pull_request'): + print(f'No-op for {event_name}') + sys.exit(0) +if 'GITHUB_EVENT_PATH' in os.environ: + with open(os.environ['GITHUB_EVENT_PATH'], encoding='utf-8') as fin: + event = json.load(fin) + pr_num = event['number'] + basereponame = event['pull_request']['base']['repo']['full_name'] + real = True +else: # local testing + pr_num = 12318 # added some towncrier files + basereponame = "mne-tools/mne-python" + real = False + +g = Github(os.environ.get('GITHUB_TOKEN')) +baserepo = g.get_repo(basereponame) + +# Grab config from upstream's default branch +toml_cfg = loads(Path("pyproject.toml").read_text("utf-8")) + +config = toml_cfg["tool"]["towncrier"] +pr = baserepo.get_pull(pr_num) +modified_files = [f.filename for f in pr.get_files()] + +# Get types from config +types = [ent["directory"] for ent in toml_cfg["tool"]["towncrier"]["type"]] +type_pipe = "|".join(types) + +# Get files that potentially match the types +directory = toml_cfg["tool"]["towncrier"]["directory"] +assert directory.endswith("/"), directory + +file_re = re.compile(rf"^{directory}({type_pipe})\.rst$") +found_stubs = [ + f for f in modified_files if file_re.match(f) +] +for stub in found_stubs: + fro = stub + to = file_re.sub(rf"{directory}{pr_num}.\1.rst", fro) + print(f"Renaming {fro} to {to}") + if real: + subprocess.check_call(["mv", fro, to]) diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml new file mode 100644 index 00000000000..2c0b693750e --- /dev/null +++ b/.github/workflows/autofix.yml @@ -0,0 +1,21 @@ +name: autofix.ci + +on: # yamllint disable-line rule:truthy + pull_request: + types: [opened, synchronize, labeled, unlabeled] + +permissions: + contents: read + +jobs: + autofix: + name: Autoupdate changelog entry + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - run: pip install --upgrade towncrier pygithub + - run: python ./.github/actions/rename_towncrier/rename_towncrier.py + - uses: autofix-ci/action@ea32e3a12414e6d3183163c3424a7d7a8631ad84 diff --git a/.github/workflows/check_changelog.yml b/.github/workflows/check_changelog.yml new file mode 100644 index 00000000000..cf59c165258 --- /dev/null +++ 
b/.github/workflows/check_changelog.yml @@ -0,0 +1,15 @@ +name: Changelog + +on: # yamllint disable-line rule:truthy + pull_request: + types: [opened, synchronize, labeled, unlabeled] + +jobs: + changelog_checker: + name: Check towncrier entry in doc/changes/devel/ + runs-on: ubuntu-latest + steps: + - uses: larsoner/action-towncrier-changelog@co # revert to scientific-python @ 0.1.1 once bug is fixed + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + BOT_USERNAME: changelog-bot diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index a06f3336543..7f348f80778 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -42,7 +42,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -56,7 +56,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 # ℹ️ Command-line programs to run using the OS shell. # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun @@ -69,4 +69,4 @@ jobs: # ./location_of_script_within_repo/buildscript.sh - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index dd85f1bb8a4..c9895e11919 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: '3.10' - name: Install dependencies @@ -28,7 +28,7 @@ jobs: pip install build twine - run: python -m build --sdist --wheel - run: twine check --strict dist/* - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: dist path: dist @@ -43,7 +43,7 @@ jobs: name: pypi url: https://pypi.org/p/mne steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: dist path: dist diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 09555ac5eb9..68979e20033 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,11 +4,9 @@ concurrency: cancel-in-progress: true on: # yamllint disable-line rule:truthy push: - branches: - - '*' + branches: ["main", "maint/*"] pull_request: - branches: - - '*' + branches: ["main", "maint/*"] permissions: contents: read @@ -20,18 +18,17 @@ jobs: timeout-minutes: 3 steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: - python-version: '3.11' - - uses: psf/black@stable - - uses: pre-commit/action@v3.0.0 + python-version: '3.12' + - uses: pre-commit/action@v3.0.1 bandit: name: Bandit needs: style runs-on: ubuntu-latest steps: - - uses: davidslusser/actions_python_bandit@v1.0.0 + - uses: davidslusser/actions_python_bandit@v1.0.1 with: src: "mne" options: "-c pyproject.toml -ll -r" @@ -57,22 +54,25 @@ jobs: matrix: include: - os: ubuntu-latest - python: '3.10' - kind: conda - - os: ubuntu-latest - python: '3.11' + python: '3.12' kind: pip-pre - - os: 
macos-latest - python: '3.8' + - os: ubuntu-latest + python: '3.12' + kind: conda + - os: macos-14 # arm64 + python: '3.12' + kind: mamba + - os: macos-latest # intel + python: '3.12' kind: mamba - os: windows-latest python: '3.10' kind: mamba - os: ubuntu-latest - python: '3.8' + python: '3.9' kind: minimal - os: ubuntu-20.04 - python: '3.8' + python: '3.9' kind: old steps: - uses: actions/checkout@v4 @@ -85,25 +85,33 @@ jobs: qt: true pyvista: false # Python (if pip) - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} if: startswith(matrix.kind, 'pip') # Python (if conda) - - uses: conda-incubator/setup-miniconda@v2 + - name: Remove numba and dipy + run: | # TODO: Remove when numba 0.59 and dipy 1.8 land on conda-forge + sed -i '/numba/d' environment.yml + sed -i '/dipy/d' environment.yml + sed -i 's/- mne$/- mne-base/' environment.yml + if: matrix.os == 'ubuntu-latest' && startswith(matrix.kind, 'conda') && matrix.python == '3.12' + - uses: mamba-org/setup-micromamba@v1 with: - python-version: ${{ env.PYTHON_VERSION }} environment-file: ${{ env.CONDA_ENV }} - activate-environment: mne - miniforge-version: latest - miniforge-variant: Mambaforge - use-mamba: ${{ matrix.kind != 'conda' }} + environment-name: mne + create-args: >- + python=${{ env.PYTHON_VERSION }} + mamba + fmt!=10.2.0 if: ${{ !startswith(matrix.kind, 'pip') }} + # Make sure we have the right Python + - run: python -c "import platform; assert platform.machine() == 'arm64', platform.machine()" + if: matrix.os == 'macos-14' - run: ./tools/github_actions_dependencies.sh # Minimal commands on Linux (macOS stalls) - run: ./tools/get_minimal_commands.sh if: ${{ startswith(matrix.os, 'ubuntu') }} - - run: ./tools/github_actions_install.sh - run: ./tools/github_actions_infos.sh # Check Qt - run: ./tools/check_qt_import.sh $MNE_QT_BACKEND @@ -112,11 +120,13 @@ jobs: run: MNE_SKIP_TESTING_DATASET_TESTS=true pytest -m "not (ultraslowtest or pgtest)" --tb=short --cov=mne --cov-report xml -vv -rfE mne/ if: matrix.kind == 'minimal' - run: ./tools/get_testing_version.sh - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: key: ${{ env.TESTING_VERSION }} path: ~/mne_data - run: ./tools/github_actions_download.sh - run: ./tools/github_actions_test.sh - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} if: success() diff --git a/.gitignore b/.gitignore index be502ec189a..118eebd9c76 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,11 @@ junit-results.xml *.tmproj *.png *.dat +# make sure we ship data files +!mne/data/**/*.dat +!mne/data/**/*.fif +!mne/data/**/*.fif.gz +!mne/icons/**/*.png .DS_Store events.eve foo-lh.label @@ -27,7 +32,6 @@ foo.lout bar.lout foobar.lout epochs_data.mat -memmap*.dat tmp-*.w tmtags auto_examples @@ -41,7 +45,6 @@ MNE-brainstorm-data* physionet-sleep-data* MEGSIM* build -mne/_version.py coverage htmlcov .cache/ @@ -63,14 +66,16 @@ tutorials/misc/report.h5 tutorials/io/fnirs.csv pip-log.txt .coverage* +!.coveragerc coverage.xml tags doc/coverages doc/samples -doc/*.dat doc/fil-result doc/optipng.exe sg_execution_times.rst +sg_api_usage.rst +sg_api_unused.dot cover *.html @@ -92,6 +97,7 @@ cover .venv/ venv/ *.json +!codemeta.json .hypothesis/ .ruff_cache/ .ipynb_checkpoints/ diff --git a/.mailmap b/.mailmap index e6d5377c402..d71df509cc2 100644 --- a/.mailmap +++ b/.mailmap @@ -114,11 +114,13 @@ Giorgio Marinato neurogima <76406896+neurogima@users Guillaume Dumas 
deep-introspection Guillaume Dumas Guillaume Dumas Hamid Maymandi <46011104+HamidMandi@users.noreply.github.com> Hamid <46011104+HamidMandi@users.noreply.github.com> +Hasrat Ali Arzoo <56307533+hasrat17@users.noreply.github.com> hasrat17 <56307533+hasrat17@users.noreply.github.com> Hongjiang Ye YE Hongjiang Hubert Banville hubertjb Hüseyin Orkun Elmas Hüseyin Hyonyoung Shin <55095699+mcvain@users.noreply.github.com> mcvain <55095699+mcvain@users.noreply.github.com> Ingoo Lee dlsrnsi +Ivo de Jong ivopascal Jaakko Leppakangas Jaakko Leppakangas Jaakko Leppakangas jaeilepp Jaakko Leppakangas jaeilepp @@ -220,6 +222,7 @@ Mikołaj Magnuski Mikolaj Magnuski mmagnuski Mohamed Sherif mohdsherif Mohammad Daneshzand <55800429+mdaneshzand@users.noreply.github.com> mdaneshzand <55800429+mdaneshzand@users.noreply.github.com> +Motofumi Fushimi <30593537+motofumi-fushimi@users.noreply.github.com> motofumi-fushimi <30593537+motofumi-fushimi@users.noreply.github.com> Natalie Klein natalieklein Nathalie Gayraud Nathalie Nathalie Gayraud Nathalie diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 436fbbb80a7..744e28edcf7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,29 +1,24 @@ repos: - - repo: https://github.com/psf/black - rev: 23.11.0 - hooks: - - id: black - args: [--quiet] - # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.5 + rev: v0.3.7 hooks: - id: ruff - name: ruff mne + name: ruff lint mne args: ["--fix"] files: ^mne/ - - # Ruff tutorials and examples - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.5 - hooks: - id: ruff - name: ruff tutorials and examples + name: ruff lint mne preview + args: ["--fix", "--preview", "--select=NPY201"] + files: ^mne/ + - id: ruff + name: ruff lint doc, tutorials, and examples # D103: missing docstring in public function # D400: docstring first line must end with period args: ["--ignore=D103,D400", "--fix"] - files: ^tutorials/|^examples/ + files: ^doc/|^tutorials/|^examples/ + - id: ruff-format + files: ^mne/|^doc/|^tutorials/|^examples/ # Codespell - repo: https://github.com/codespell-project/codespell @@ -37,7 +32,7 @@ repos: # yamllint - repo: https://github.com/adrienverge/yamllint.git - rev: v1.33.0 + rev: v1.35.1 hooks: - id: yamllint args: [--strict, -c, .yamllint.yml] @@ -50,3 +45,12 @@ repos: additional_dependencies: - tomli files: ^doc/.*\.(rst|inc)$ + + # mypy + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 + hooks: + - id: mypy + # Avoid the conflict between mne/__init__.py and mne/__init__.pyi by ignoring the former + exclude: ^mne/(beamformer|channels|commands|datasets|decoding|export|forward|gui|html_templates|inverse_sparse|io|minimum_norm|preprocessing|report|simulation|source_space|stats|time_frequency|utils|viz)?/?__init__\.py$ + additional_dependencies: ["numpy==1.26.2"] diff --git a/CITATION.cff b/CITATION.cff index c1850a2f55b..936f3f90677 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,9 +1,9 @@ cff-version: 1.2.0 title: "MNE-Python" message: "If you use this software, please cite both the software itself, and the paper listed in the preferred-citation field." 
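The rename_towncrier.py script added above derives its filename pattern from the towncrier settings in pyproject.toml and renames unnumbered changelog stubs to PR-numbered ones. A minimal sketch of that regex logic, assuming the configured directory is doc/changes/devel/ and that "bugfix" and "newfeature" are among the configured types (the actual values live in pyproject.toml, which is not part of this diff):

    import re

    # Assumed values; the real script reads these from [tool.towncrier]
    # in pyproject.toml and the PR number from the GitHub event payload.
    directory = "doc/changes/devel/"
    type_pipe = "|".join(["bugfix", "newfeature"])
    pr_num = 12318  # the PR number the script uses for local testing

    # Match an unnumbered stub like doc/changes/devel/bugfix.rst ...
    file_re = re.compile(rf"^{directory}({type_pipe})\.rst$")
    stub = "doc/changes/devel/bugfix.rst"
    # ... and rename it to doc/changes/devel/12318.bugfix.rst,
    # the numbered form towncrier expects when building the changelog.
    print(file_re.sub(rf"{directory}{pr_num}.\1.rst", stub))

In the workflow itself this runs under autofix.ci, so the rename is committed back to the pull request automatically rather than printed.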
-version: 1.6.0 -date-released: "2023-11-20" -commit: 498cf789685ede0b29e712a1e7220c69443e8744 +version: 1.7.0 +date-released: "2024-04-19" +commit: a3743420a8eef774dafd2908f0de89c4d37fcd01 doi: 10.5281/zenodo.592483 keywords: - MEG @@ -35,10 +35,10 @@ authors: given-names: Teon - family-names: Sassenhagen given-names: Jona - - family-names: Luessi - given-names: Martin - family-names: McCloy given-names: Daniel + - family-names: Luessi + given-names: Martin - family-names: King given-names: Jean-Remi - family-names: Höchenberger @@ -53,10 +53,10 @@ authors: given-names: Marijn - family-names: Wronkiewicz given-names: Mark - - family-names: Holdgraf - given-names: Chris - family-names: Rockhill given-names: Alex + - family-names: Holdgraf + given-names: Chris - family-names: Massich given-names: Joan - family-names: Bekhti @@ -117,12 +117,12 @@ authors: given-names: Martin - family-names: Foti given-names: Nick + - family-names: Huberty + given-names: Scott - family-names: Nangini given-names: Cathy - family-names: García Alanis given-names: José C - - family-names: Huberty - given-names: Scott - family-names: Hauk given-names: Olaf - family-names: Maddox @@ -165,6 +165,10 @@ authors: given-names: Christopher - family-names: Raimundo given-names: Félix + - family-names: Woessner + given-names: Jacob + - family-names: Kaneda + given-names: Michiru - family-names: Alday given-names: Phillip - family-names: Pari @@ -189,6 +193,10 @@ authors: given-names: Alexandre - family-names: Gütlin given-names: Dirk + - family-names: Heinila + given-names: Erkka + - family-names: Armeni + given-names: Kristijan - name: kjs - family-names: Weinstein given-names: Alejandro @@ -202,14 +210,10 @@ authors: given-names: Dmitrii - family-names: Peterson given-names: Erica - - family-names: Heinila - given-names: Erkka - family-names: Hanna given-names: Jevri - family-names: Houck given-names: Jon - - family-names: Kaneda - given-names: Michiru - family-names: Klein given-names: Natalie - family-names: Roujansky @@ -220,16 +224,18 @@ authors: given-names: Antti - family-names: Maess given-names: Burkhard + - family-names: Forster + given-names: Carina - family-names: O'Reilly given-names: Christian + - family-names: Welke + given-names: Dominik - family-names: Kolkhorst given-names: Henrich - family-names: Banville given-names: Hubert - family-names: Zhang given-names: Jack - - family-names: Woessner - given-names: Jacob - family-names: Maksymenko given-names: Kostiantyn - family-names: Clarke @@ -242,8 +248,6 @@ authors: given-names: Pierre-Antoine - family-names: Choudhary given-names: Saket - - family-names: Forster - given-names: Carina - family-names: Kim given-names: Cora - family-names: Klotzsche @@ -268,6 +272,8 @@ authors: given-names: Nick - family-names: Ruuskanen given-names: Santeri + - family-names: Herbst + given-names: Sophie - family-names: Radanovic given-names: Ana - family-names: Quinn @@ -278,8 +284,6 @@ authors: given-names: Basile - family-names: Welke given-names: Dominik - - family-names: Welke - given-names: Dominik - family-names: Stephen given-names: Emily - family-names: Hornberger @@ -294,22 +298,30 @@ authors: given-names: Giorgio - family-names: Anevar given-names: Hafeza + - family-names: Abdelhedi + given-names: Hamza - family-names: Sosulski given-names: Jan - family-names: Stout given-names: Jeff - family-names: Calder-Travis given-names: Joshua + - family-names: Zhu + given-names: Judy D - family-names: Eisenman given-names: Larry - family-names: Esch given-names: Lorenz - 
family-names: Dovgialo given-names: Marian + - family-names: Alibou + given-names: Nabil - family-names: Barascud given-names: Nicolas - family-names: Legrand given-names: Nicolas + - family-names: Kapralov + given-names: Nikolai - family-names: Falach given-names: Rotem - family-names: Deslauriers-Gauthier @@ -320,6 +332,10 @@ authors: given-names: Steve - family-names: Bierer given-names: Steven + - family-names: Binns + given-names: Thomas Samuel + - family-names: Stenner + given-names: Tristan - family-names: Férat given-names: Victor - family-names: Peterson @@ -350,8 +366,6 @@ authors: given-names: Gennadiy - family-names: O'Neill given-names: George - - family-names: Abdelhedi - given-names: Hamza - family-names: Schiratti given-names: Jean-Baptiste - family-names: Evans @@ -362,16 +376,14 @@ authors: given-names: Jordan - family-names: Teves given-names: Joshua - - family-names: Zhu - given-names: Judy D - - family-names: Armeni - given-names: Kristijan - family-names: Mathewson given-names: Kyle - family-names: Gwilliams given-names: Laura - family-names: Varghese given-names: Lenny + - family-names: Hamilton + given-names: Liberty - family-names: Gemein given-names: Lukas - family-names: Hecker @@ -393,6 +405,8 @@ authors: given-names: Niklas - family-names: Kozynets given-names: Oleh + - family-names: Molfese + given-names: Peter J - family-names: Ablin given-names: Pierre - family-names: Bertrand @@ -407,24 +421,20 @@ authors: given-names: Sena - family-names: Khan given-names: Sheraz - - family-names: Herbst - given-names: Sophie - family-names: Datta given-names: Sumalyo - family-names: Papadopoulo given-names: Theodore + - family-names: Donoghue + given-names: Thomas - family-names: Jochmann given-names: Thomas - - family-names: Binns - given-names: Thomas Samuel - family-names: Merk given-names: Timon - family-names: Flak given-names: Tod - family-names: Dupré la Tour given-names: Tom - - family-names: Stenner - given-names: Tristan - family-names: NessAiver given-names: Tziona - name: akshay0724 @@ -441,6 +451,8 @@ authors: given-names: Adina - family-names: Ciok given-names: Alex + - family-names: Kiefer + given-names: Alexander - family-names: Gilbert given-names: Andy - family-names: Pradhan @@ -515,6 +527,8 @@ authors: given-names: Evgeny - family-names: Zamberlan given-names: Federico + - family-names: Hofer + given-names: Florian - family-names: Pop given-names: Florin - family-names: Weber @@ -530,6 +544,8 @@ authors: given-names: Gonzalo - family-names: Maymandi given-names: Hamid + - family-names: Arzoo + given-names: Hasrat Ali - family-names: Sonntag given-names: Hermann - family-names: Ye @@ -540,10 +556,10 @@ authors: given-names: Hüseyin Orkun - family-names: Machairas given-names: Ilias - - family-names: Skelin - given-names: Ivan - family-names: Zubarev given-names: Ivan + - family-names: de Jong + given-names: Ivo - family-names: Kaczmarzyk given-names: Jakub - family-names: Zerfowski @@ -576,8 +592,6 @@ authors: given-names: Lau Møller - family-names: Barbosa given-names: Leonardo S - - family-names: Hamilton - given-names: Liberty - family-names: Alfine given-names: Lorenzo - family-names: Hejtmánek @@ -596,6 +610,8 @@ authors: given-names: Marcin - family-names: Henney given-names: Mark Alexander + - family-names: Oberg + given-names: Martin - family-names: Schulz given-names: Martin - family-names: van Harmelen @@ -639,8 +655,6 @@ authors: given-names: Padma - family-names: Silva given-names: Pedro - - family-names: Molfese - given-names: Peter J - 
family-names: Das given-names: Proloy - family-names: Chu @@ -661,6 +675,8 @@ authors: given-names: Reza - family-names: Koehler given-names: Richard + - family-names: Scholz + given-names: Richard - family-names: Stargardsky given-names: Riessarius - family-names: Oostenveld @@ -691,6 +707,8 @@ authors: given-names: Senwen - family-names: Antopolskiy given-names: Sergey + - family-names: Shirazi + given-names: Seyed (Yahya) - family-names: Wong given-names: Simeon - family-names: Wong @@ -711,8 +729,6 @@ authors: given-names: Svea Marie - family-names: Wang given-names: T - - family-names: Donoghue - given-names: Thomas - family-names: Moreau given-names: Thomas - family-names: Radman @@ -727,12 +743,17 @@ authors: given-names: Tommy - family-names: Anijärv given-names: Toomas Erik + - family-names: Kumaravel + given-names: Velu Prabhakar + - family-names: Turner + given-names: Will - family-names: Xia given-names: Xiaokai - family-names: Zuo given-names: Yiping - family-names: Zhang given-names: Zhi + - name: btkcodedev - name: buildqa - name: luzpaz preferred-citation: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bec834c7fdb..e653797b3ad 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,5 +5,5 @@ MNE-Python is maintained by a community of scientists and research labs. The pro Users and contributors to MNE-Python are expected to follow our [code of conduct](https://github.com/mne-tools/.github/blob/main/CODE_OF_CONDUCT.md). -The [contributing guide](https://mne.tools/dev/install/contributing.html) has details on the preferred contribution workflow +The [contributing guide](https://mne.tools/dev/development/contributing.html) has details on the preferred contribution workflow and the recommended system configuration for a smooth contribution/development experience. diff --git a/LICENSE.txt b/LICENSE.txt index 6d98ee83925..c9197c42f20 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,24 +1,11 @@ -Copyright © 2011-2022, authors of MNE-Python -All rights reserved. +Copyright 2011-2023 MNE-Python authors -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. 
IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY -DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 5a06c9c814b..00000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,86 +0,0 @@ -include *.rst -include LICENSE.txt -include SECURITY.md -include mne/__init__.py - -recursive-include examples *.py -recursive-include examples *.txt -recursive-include tutorials *.py -recursive-include tutorials *.txt - -recursive-include mne *.py -recursive-include mne *.pyi -recursive-include mne/data * -recursive-include mne/icons * -recursive-include mne/data/helmets * -recursive-include mne/data/image * -recursive-include mne/data/fsaverage * -include mne/datasets/_fsaverage/root.txt -include mne/datasets/_fsaverage/bem.txt -include mne/datasets/_infant/*.txt -include mne/datasets/_phantom/*.txt -include mne/data/dataset_checksums.txt -include mne/data/eegbci_checksums.txt - -recursive-include mne/html_templates *.html.jinja - -recursive-include mne/channels/data/layouts * -recursive-include mne/channels/data/montages * -recursive-include mne/channels/data/neighbors * - -recursive-include mne/gui/help *.json - -recursive-include mne/html *.js -recursive-include mne/html *.css - -recursive-include mne/report * - -recursive-include mne/io/artemis123/resources * - -recursive-include mne mne/datasets *.csv -include mne/io/edf/gdf_encodes.txt -include mne/datasets/sleep_physionet/SHA1SUMS - -### Exclude - -recursive-exclude examples/MNE-sample-data * -recursive-exclude examples/MNE-testing-data * -recursive-exclude examples/MNE-spm-face * -recursive-exclude examples/MNE-somato-data * -recursive-exclude tools * -exclude tools -exclude Makefile -exclude .coveragerc -exclude *.yml -exclude *.yaml -exclude .git-blame-ignore-revs -exclude ignore_words.txt -exclude .mailmap -exclude codemeta.json -exclude CITATION.cff -recursive-exclude mne *.pyc - -recursive-exclude doc * -recursive-exclude logo * - -exclude CONTRIBUTING.md -exclude CODE_OF_CONDUCT.md -exclude .github -exclude .github/CONTRIBUTING.md -exclude .github/ISSUE_TEMPLATE -exclude .github/ISSUE_TEMPLATE/blank.md -exclude .github/ISSUE_TEMPLATE/bug_report.md -exclude .github/ISSUE_TEMPLATE/feature_request.md -exclude .github/PULL_REQUEST_TEMPLATE.md - -# Test files - -recursive-exclude mne/io/tests/data * -recursive-exclude mne/io/besa/tests/data * -recursive-exclude mne/io/bti/tests/data * -recursive-exclude mne/io/edf/tests/data * -recursive-exclude mne/io/kit/tests/data * -recursive-exclude mne/io/brainvision/tests/data * -recursive-exclude mne/io/egi/tests/data * -recursive-exclude mne/io/nicolet/tests/data * -recursive-exclude mne/preprocessing/tests/data * diff --git a/Makefile b/Makefile index 2843e0193b5..8a79bf966c5 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ clean-cache: clean: clean-build clean-pyc clean-so clean-ctags clean-cache wheel: - $(PYTHON) -m build + $(PYTHON) -m build -w sample_data: @python -c "import mne; mne.datasets.sample.data_path(verbose=True);" @@ -54,9 +54,6 @@ pep: pre-commit codespell: # running manually @codespell --builtin clear,rare,informal,names,usage -w -i 3 -q 3 -S $(CODESPELL_SKIPS) --ignore-words=ignore_words.txt --uri-ignore-words-list=bu $(CODESPELL_DIRS) -check-manifest: - check-manifest -q --ignore .circleci/config.yml,doc,logo,mne/io/*/tests/data*,mne/io/tests/data,mne/preprocessing/tests/data,.DS_Store,mne/_version.py - check-readme: clean wheel twine check dist/* diff --git a/README.rst b/README.rst index e8690281bcb..153dcf0a5ef 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,6 @@ .. 
-*- mode: rst -*- -|MNE|_ +|MNE| MNE-Python ========== @@ -43,7 +43,7 @@ only, use pip_ in a terminal: $ pip install --upgrade mne -The current MNE-Python release requires Python 3.8 or higher. MNE-Python 0.17 +The current MNE-Python release requires Python 3.9 or higher. MNE-Python 0.17 was the last release to support Python 2.7. For more complete instructions, including our standalone installers and more @@ -73,7 +73,7 @@ Dependencies The minimum required dependencies to run MNE-Python are: -- `Python `__ ≥ 3.8 +- `Python `__ ≥ 3.9 - `NumPy `__ ≥ 1.21.2 - `SciPy `__ ≥ 1.7.1 - `Matplotlib `__ ≥ 3.5.0 @@ -88,12 +88,12 @@ For full functionality, some functions require: - `scikit-learn `__ ≥ 1.0 - `Joblib `__ ≥ 0.15 (for parallelization) - `mne-qt-browser `__ ≥ 0.1 (for fast raw data visualization) -- `Qt `__ ≥ 5.12 via one of the following bindings (for fast raw data visualization and interactive 3D visualization): +- `Qt `__ ≥ 5.15 via one of the following bindings (for fast raw data visualization and interactive 3D visualization): - `PyQt6 `__ ≥ 6.0 - `PySide6 `__ ≥ 6.0 - - `PyQt5 `__ ≥ 5.12 - - `PySide2 `__ ≥ 5.12 + - `PyQt5 `__ ≥ 5.15 + - `PySide2 `__ ≥ 5.15 - `Numba `__ ≥ 0.54.0 - `NiBabel `__ ≥ 3.2.1 @@ -121,53 +121,20 @@ About ^^^^^ +---------+------------+----------------+ -| CI | |Codecov|_ | |Bandit|_ | +| CI | |Codecov| | |Bandit| | +---------+------------+----------------+ -| Package | |PyPI|_ | |conda-forge|_ | +| Package | |PyPI| | |conda-forge| | +---------+------------+----------------+ -| Docs | |Docs|_ | |Discourse|_ | +| Docs | |Docs| | |Discourse| | +---------+------------+----------------+ -| Meta | |Zenodo|_ | |OpenSSF|_ | +| Meta | |Zenodo| | |OpenSSF| | +---------+------------+----------------+ License ^^^^^^^ -MNE-Python is **BSD-licensed** (BSD-3-Clause): - - This software is OSI Certified Open Source Software. - OSI Certified is a certification mark of the Open Source Initiative. - - Copyright (c) 2011-2022, authors of MNE-Python. - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - * Neither the names of MNE-Python authors nor the names of any - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - **This software is provided by the copyright holders and contributors - "as is" and any express or implied warranties, including, but not - limited to, the implied warranties of merchantability and fitness for - a particular purpose are disclaimed. In no event shall the copyright - owner or contributors be liable for any direct, indirect, incidental, - special, exemplary, or consequential damages (including, but not - limited to, procurement of substitute goods or services; loss of use, - data, or profits; or business interruption) however caused and on any - theory of liability, whether in contract, strict liability, or tort - (including negligence or otherwise) arising in any way out of the use - of this software, even if advised of the possibility of such - damage.** +MNE-Python is licensed under the BSD-3-Clause license. .. 
_Documentation: https://mne.tools/dev/ @@ -176,28 +143,28 @@ MNE-Python is **BSD-licensed** (BSD-3-Clause): .. _pip: https://pip.pypa.io/en/stable/ .. |PyPI| image:: https://img.shields.io/pypi/dm/mne.svg?label=PyPI -.. _PyPI: https://pypi.org/project/mne/ + :target: https://pypi.org/project/mne/ .. |conda-forge| image:: https://img.shields.io/conda/dn/conda-forge/mne.svg?label=Conda -.. _conda-forge: https://anaconda.org/conda-forge/mne + :target: https://anaconda.org/conda-forge/mne .. |Docs| image:: https://img.shields.io/badge/Docs-online-green?label=Documentation -.. _Docs: https://mne.tools/dev/ + :target: https://mne.tools/dev/ .. |Zenodo| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.592483.svg -.. _Zenodo: https://doi.org/10.5281/zenodo.592483 + :target: https://doi.org/10.5281/zenodo.592483 .. |Discourse| image:: https://img.shields.io/discourse/status?label=Forum&server=https%3A%2F%2Fmne.discourse.group%2F -.. _Discourse: https://mne.discourse.group/ + :target: https://mne.discourse.group/ .. |Codecov| image:: https://img.shields.io/codecov/c/github/mne-tools/mne-python?label=Coverage -.. _Codecov: https://codecov.io/gh/mne-tools/mne-python + :target: https://codecov.io/gh/mne-tools/mne-python .. |Bandit| image:: https://img.shields.io/badge/Security-Bandit-yellow.svg -.. _Bandit: https://github.com/PyCQA/bandit + :target: https://github.com/PyCQA/bandit .. |OpenSSF| image:: https://www.bestpractices.dev/projects/7783/badge -.. _OpenSSF: https://www.bestpractices.dev/projects/7783 + :target: https://www.bestpractices.dev/projects/7783 .. |MNE| image:: https://mne.tools/dev/_static/mne_logo_gray.svg -.. _MNE: https://mne.tools/dev/ + :target: https://mne.tools/dev/ diff --git a/SECURITY.md b/SECURITY.md index e627242d244..82d4c9e45de 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -10,9 +10,9 @@ without a proper 6-month deprecation cycle. 
| Version | Supported | | ------- | ------------------------ | -| 1.7.x | :heavy_check_mark: (dev) | -| 1.6.x | :heavy_check_mark: | -| < 1.6 | :x: | +| 1.8.x | :heavy_check_mark: (dev) | +| 1.7.x | :heavy_check_mark: | +| < 1.7 | :x: | ## Reporting a Vulnerability diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 5cee5568623..7e2fa2bd397 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -64,10 +64,6 @@ stages: make nesting displayName: make nesting condition: always() - - bash: | - make check-manifest - displayName: make check-manifest - condition: always() - bash: | make check-readme displayName: make check-readme @@ -108,7 +104,7 @@ stages: - bash: | set -e python -m pip install --progress-bar off --upgrade pip - python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard "PyQt6!=6.5.1" qtpy nibabel sphinx-gallery + python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard qtpy nibabel sphinx-gallery -r tools/pyqt6_requirements.txt python -m pip uninstall -yq mne python -m pip install --progress-bar off --upgrade -e .[test] displayName: 'Install dependencies with pip' @@ -117,10 +113,9 @@ stages: mne sys_info -pd mne sys_info -pd | grep "qtpy .*(PyQt6=.*)$" displayName: Print config - # Uncomment if "xcb not found" Qt errors/segfaults come up again - # - bash: | - # set -e - # LD_DEBUG=libs python -c "from PyQt6.QtWidgets import QApplication, QWidget; app = QApplication([]); import matplotlib; matplotlib.use('QtAgg'); import matplotlib.pyplot as plt; plt.figure()" + - bash: | + set -e + LD_DEBUG=libs python -c "from PyQt6.QtWidgets import QApplication, QWidget; app = QApplication([]); import matplotlib; matplotlib.use('QtAgg'); import matplotlib.pyplot as plt; plt.figure()" - bash: source tools/get_testing_version.sh displayName: 'Get testing version' - task: Cache@2 @@ -188,9 +183,9 @@ stages: displayName: 'Get test data' - bash: | set -e - python -m pip install PyQt6 - # Uncomment if "xcb not found" Qt errors/segfaults come up again - # LD_DEBUG=libs python -c "from PyQt6.QtWidgets import QApplication, QWidget; app = QApplication([]); import matplotlib; matplotlib.use('QtAgg'); import matplotlib.pyplot as plt; plt.figure()" + python -m pip install -r tools/pyqt6_requirements.txt + LD_DEBUG=libs python -c "from PyQt6.QtWidgets import QApplication, QWidget; app = QApplication([]); import matplotlib; matplotlib.use('QtAgg'); import matplotlib.pyplot as plt; plt.figure()" + - bash: | mne sys_info -pd mne sys_info -pd | grep "qtpy .* (PyQt6=.*)$" PYTEST_QT_API=PyQt6 pytest -m "not slowtest" ${TEST_OPTIONS} @@ -198,7 +193,7 @@ stages: displayName: 'PyQt6' - bash: | set -e - python -m pip install PySide6 + python -m pip install "PySide6!=6.7.0" mne sys_info -pd mne sys_info -pd | grep "qtpy .* (PySide6=.*)$" PYTEST_QT_API=PySide6 pytest -m "not slowtest" ${TEST_OPTIONS} @@ -243,7 +238,7 @@ stages: variables: MNE_LOGGING_LEVEL: 'warning' MNE_FORCE_SERIAL: 'true' - OPENBLAS_NUM_THREADS: 2 + OPENBLAS_NUM_THREADS: '2' OMP_DYNAMIC: 'false' PYTHONUNBUFFERED: 1 PYTHONIOENCODING: 'utf-8' @@ -256,9 +251,9 @@ stages: 3.9 pip: TEST_MODE: 'pip' PYTHON_VERSION: '3.9' - 3.11 pip pre: + 3.12 pip pre: TEST_MODE: 'pip-pre' - PYTHON_VERSION: '3.11' + PYTHON_VERSION: '3.12' steps: - task: UsePythonVersion@0 inputs: @@ -279,6 +274,8 @@ stages: 
displayName: 'Print config' - script: python -c "import numpy; numpy.show_config()" displayName: Print NumPy config + - script: python -c "import numpy; import scipy.linalg; import sklearn.neighbors; from threadpoolctl import threadpool_info; from pprint import pprint; pprint(threadpool_info())" + displayName: Print threadpoolctl info - bash: source tools/get_testing_version.sh displayName: 'Get testing version' - task: Cache@2 diff --git a/codemeta.json b/codemeta.json index b2922b2194d..ebfe798c648 100644 --- a/codemeta.json +++ b/codemeta.json @@ -5,11 +5,11 @@ "codeRepository": "git+https://github.com/mne-tools/mne-python.git", "dateCreated": "2010-12-26", "datePublished": "2014-08-04", - "dateModified": "2023-11-20", - "downloadUrl": "https://github.com/mne-tools/mne-python/archive/v1.6.0.zip", + "dateModified": "2024-04-19", + "downloadUrl": "https://github.com/mne-tools/mne-python/archive/v1.7.0.zip", "issueTracker": "https://github.com/mne-tools/mne-python/issues", "name": "MNE-Python", - "version": "1.6.0", + "version": "1.7.0", "description": "MNE-Python is an open-source Python package for exploring, visualizing, and analyzing human neurophysiological data. It provides methods for data input/output, preprocessing, visualization, source estimation, time-frequency analysis, connectivity analysis, machine learning, and statistics.", "applicationCategory": "Neuroscience", "developmentStatus": "active", @@ -37,7 +37,7 @@ "macOS" ], "softwareRequirements": [ - "python>=3.8", + "python>=3.9", "numpy>=1.21.2", "scipy>=1.7.1", "matplotlib>=3.5.0", @@ -46,9 +46,7 @@ "decorator", "packaging", "jinja2", - "importlib_resources>=5.10.2; python_version<'3.9'", - "lazy_loader>=0.3", - "defusedxml" + "lazy_loader>=0.3" ], "author": [ { @@ -99,18 +97,18 @@ "givenName":"Jona", "familyName": "Sassenhagen" }, - { - "@type":"Person", - "email":"mluessi@nmr.mgh.harvard.edu", - "givenName":"Martin", - "familyName": "Luessi" - }, { "@type":"Person", "email":"dan@mccloy.info", "givenName":"Daniel", "familyName": "McCloy" }, + { + "@type":"Person", + "email":"mluessi@nmr.mgh.harvard.edu", + "givenName":"Martin", + "familyName": "Luessi" + }, { "@type":"Person", "email":"jeanremi.king+github@gmail.com", @@ -153,18 +151,18 @@ "givenName":"Mark", "familyName": "Wronkiewicz" }, - { - "@type":"Person", - "email":"choldgraf@gmail.com", - "givenName":"Chris", - "familyName": "Holdgraf" - }, { "@type":"Person", "email":"aprockhill206@gmail.com", "givenName":"Alex", "familyName": "Rockhill" }, + { + "@type":"Person", + "email":"choldgraf@gmail.com", + "givenName":"Chris", + "familyName": "Holdgraf" + }, { "@type":"Person", "email":"mailsik@gmail.com", @@ -345,6 +343,12 @@ "givenName":"Nick", "familyName": "Foti" }, + { + "@type":"Person", + "email":"", + "givenName":"Scott", + "familyName": "Huberty" + }, { "@type":"Person", "email":"cnangini@gmail.com", @@ -357,12 +361,6 @@ "givenName":"José C", "familyName": "García Alanis" }, - { - "@type":"Person", - "email":"", - "givenName":"Scott", - "familyName": "Huberty" - }, { "@type":"Person", "email":"olaf.hauk@mrc-cbu.cam.ac.uk", @@ -489,6 +487,18 @@ "givenName":"Félix", "familyName": "Raimundo" }, + { + "@type":"Person", + "email":"Woessner.jacob@gmail.com", + "givenName":"Jacob", + "familyName": "Woessner" + }, + { + "@type":"Person", + "email":"rcmdnk@gmail.com", + "givenName":"Michiru", + "familyName": "Kaneda" + }, { "@type":"Person", "email":"phillip.alday@mpi.nl", @@ -561,6 +571,18 @@ "givenName":"Dirk", "familyName": "Gütlin" }, + { + "@type":"Person", + 
"email":"erkkahe@gmail.com", + "givenName":"Erkka", + "familyName": "Heinila" + }, + { + "@type":"Person", + "email":"kristijan.armeni@gmail.com", + "givenName":"Kristijan", + "familyName": "Armeni" + }, { "@type":"Person", "email":"kjs@llama", @@ -603,12 +625,6 @@ "givenName":"Erica", "familyName": "Peterson" }, - { - "@type":"Person", - "email":"erkkahe@gmail.com", - "givenName":"Erkka", - "familyName": "Heinila" - }, { "@type":"Person", "email":"jevri.hanna@gmail.com", @@ -621,12 +637,6 @@ "givenName":"Jon", "familyName": "Houck" }, - { - "@type":"Person", - "email":"rcmdnk@gmail.com", - "givenName":"Michiru", - "familyName": "Kaneda" - }, { "@type":"Person", "email":"neklein@andrew.cmu.edu", @@ -657,12 +667,24 @@ "givenName":"Burkhard", "familyName": "Maess" }, + { + "@type":"Person", + "email":"carinaforster0611@gmail.com", + "givenName":"Carina", + "familyName": "Forster" + }, { "@type":"Person", "email":"christian.oreilly@gmail.com", "givenName":"Christian", "familyName": "O'Reilly" }, + { + "@type":"Person", + "email":"dominik.welke@ae.mpg.de", + "givenName":"Dominik", + "familyName": "Welke" + }, { "@type":"Person", "email":"", @@ -681,12 +703,6 @@ "givenName":"Jack", "familyName": "Zhang" }, - { - "@type":"Person", - "email":"Woessner.jacob@gmail.com", - "givenName":"Jacob", - "familyName": "Woessner" - }, { "@type":"Person", "email":"makkostya@ukr.net", @@ -723,12 +739,6 @@ "givenName":"Saket", "familyName": "Choudhary" }, - { - "@type":"Person", - "email":"carinaforster0611@gmail.com", - "givenName":"Carina", - "familyName": "Forster" - }, { "@type":"Person", "email":"", @@ -801,6 +811,12 @@ "givenName":"Santeri", "familyName": "Ruuskanen" }, + { + "@type":"Person", + "email":"ksherbst@gmail.com", + "givenName":"Sophie", + "familyName": "Herbst" + }, { "@type":"Person", "email":"", @@ -825,12 +841,6 @@ "givenName":"Basile", "familyName": "Pinsard" }, - { - "@type":"Person", - "email":"dominik.welke@ae.mpg.de", - "givenName":"Dominik", - "familyName": "Welke" - }, { "@type":"Person", "email":"dominik.welke@web.de", @@ -879,6 +889,12 @@ "givenName":"Hafeza", "familyName": "Anevar" }, + { + "@type":"Person", + "email":"hamza.abdelhedii@gmail.com", + "givenName":"Hamza", + "familyName": "Abdelhedi" + }, { "@type":"Person", "email":"mail@jan-sosulski.de", @@ -897,6 +913,12 @@ "givenName":"Joshua", "familyName": "Calder-Travis" }, + { + "@type":"Person", + "email":"", + "givenName":"Judy D", + "familyName": "Zhu" + }, { "@type":"Person", "email":"leisenman@wustl.edu", @@ -915,6 +937,12 @@ "givenName":"Marian", "familyName": "Dovgialo" }, + { + "@type":"Person", + "email":"", + "givenName":"Nabil", + "familyName": "Alibou" + }, { "@type":"Person", "email":"", @@ -927,6 +955,12 @@ "givenName":"Nicolas", "familyName": "Legrand" }, + { + "@type":"Person", + "email":"4dvlup@gmail.com", + "givenName":"Nikolai", + "familyName": "Kapralov" + }, { "@type":"Person", "email":"falachrotem@gmail.com", @@ -957,6 +991,18 @@ "givenName":"Steven", "familyName": "Bierer" }, + { + "@type":"Person", + "email":"t.s.binns@outlook.com", + "givenName":"Thomas Samuel", + "familyName": "Binns" + }, + { + "@type":"Person", + "email":"ttstenner@gmail.com", + "givenName":"Tristan", + "familyName": "Stenner" + }, { "@type":"Person", "email":"victor.ferat@live.Fr", @@ -1047,12 +1093,6 @@ "givenName":"George", "familyName": "O'Neill" }, - { - "@type":"Person", - "email":"hamza.abdelhedii@gmail.com", - "givenName":"Hamza", - "familyName": "Abdelhedi" - }, { "@type":"Person", 
"email":"jean.baptiste.schiratti@gmail.com", @@ -1083,18 +1123,6 @@ "givenName":"Joshua", "familyName": "Teves" }, - { - "@type":"Person", - "email":"", - "givenName":"Judy D", - "familyName": "Zhu" - }, - { - "@type":"Person", - "email":"kristijan.armeni@gmail.com", - "givenName":"Kristijan", - "familyName": "Armeni" - }, { "@type":"Person", "email":"kylemath@gmail.com", @@ -1113,6 +1141,12 @@ "givenName":"Lenny", "familyName": "Varghese" }, + { + "@type":"Person", + "email":"", + "givenName":"Liberty", + "familyName": "Hamilton" + }, { "@type":"Person", "email":"", @@ -1179,6 +1213,12 @@ "givenName":"Oleh", "familyName": "Kozynets" }, + { + "@type":"Person", + "email":"pmolfese@gmail.com", + "givenName":"Peter J", + "familyName": "Molfese" + }, { "@type":"Person", "email":"pierreablin@gmail.com", @@ -1221,12 +1261,6 @@ "givenName":"Sheraz", "familyName": "Khan" }, - { - "@type":"Person", - "email":"ksherbst@gmail.com", - "givenName":"Sophie", - "familyName": "Herbst" - }, { "@type":"Person", "email":"", @@ -1241,15 +1275,15 @@ }, { "@type":"Person", - "email":"", + "email":"tdonoghue.research@gmail.com", "givenName":"Thomas", - "familyName": "Jochmann" + "familyName": "Donoghue" }, { "@type":"Person", - "email":"t.s.binns@outlook.com", - "givenName":"Thomas Samuel", - "familyName": "Binns" + "email":"", + "givenName":"Thomas", + "familyName": "Jochmann" }, { "@type":"Person", @@ -1269,12 +1303,6 @@ "givenName":"Tom", "familyName": "Dupré la Tour" }, - { - "@type":"Person", - "email":"ttstenner@gmail.com", - "givenName":"Tristan", - "familyName": "Stenner" - }, { "@type":"Person", "email":"tzionan@mail.tau.ac.il", @@ -1329,6 +1357,12 @@ "givenName":"Alex", "familyName": "Ciok" }, + { + "@type":"Person", + "email":"", + "givenName":"Alexander", + "familyName": "Kiefer" + }, { "@type":"Person", "email":"7andy121@gmail.com", @@ -1551,6 +1585,12 @@ "givenName":"Federico", "familyName": "Zamberlan" }, + { + "@type":"Person", + "email":"hofaflo@gmail.com", + "givenName":"Florian", + "familyName": "Hofer" + }, { "@type":"Person", "email":"florinpop@me.com", @@ -1599,6 +1639,12 @@ "givenName":"Hamid", "familyName": "Maymandi" }, + { + "@type":"Person", + "email":"", + "givenName":"Hasrat Ali", + "familyName": "Arzoo" + }, { "@type":"Person", "email":"hermann.sonntag@gmail.com", @@ -1631,15 +1677,15 @@ }, { "@type":"Person", - "email":"", + "email":"ivan.zubarev@aalto.fi", "givenName":"Ivan", - "familyName": "Skelin" + "familyName": "Zubarev" }, { "@type":"Person", - "email":"ivan.zubarev@aalto.fi", - "givenName":"Ivan", - "familyName": "Zubarev" + "email":"ivopascal@gmail.com", + "givenName":"Ivo", + "familyName": "de Jong" }, { "@type":"Person", @@ -1737,12 +1783,6 @@ "givenName":"Leonardo S", "familyName": "Barbosa" }, - { - "@type":"Person", - "email":"", - "givenName":"Liberty", - "familyName": "Hamilton" - }, { "@type":"Person", "email":"lorenzo.alfine@gmail.com", @@ -1797,6 +1837,12 @@ "givenName":"Mark Alexander", "familyName": "Henney" }, + { + "@type":"Person", + "email":"", + "givenName":"Martin", + "familyName": "Oberg" + }, { "@type":"Person", "email":"dev@mgschulz.de", @@ -1929,12 +1975,6 @@ "givenName":"Pedro", "familyName": "Silva" }, - { - "@type":"Person", - "email":"pmolfese@gmail.com", - "givenName":"Peter J", - "familyName": "Molfese" - }, { "@type":"Person", "email":"proloy@umd.edu", @@ -1995,6 +2035,12 @@ "givenName":"Richard", "familyName": "Koehler" }, + { + "@type":"Person", + "email":"", + "givenName":"Richard", + "familyName": "Scholz" + }, { "@type":"Person", 
"email":"rie.acad@gmail.com", @@ -2085,6 +2131,12 @@ "givenName":"Sergey", "familyName": "Antopolskiy" }, + { + "@type":"Person", + "email":"shirazi@ieee.org", + "givenName":"Seyed (Yahya)", + "familyName": "Shirazi" + }, { "@type":"Person", "email":"", @@ -2145,12 +2197,6 @@ "givenName":"T", "familyName": "Wang" }, - { - "@type":"Person", - "email":"tdonoghue.research@gmail.com", - "givenName":"Thomas", - "familyName": "Donoghue" - }, { "@type":"Person", "email":"thomas.moreau.2010@gmail.com", @@ -2193,6 +2239,18 @@ "givenName":"Toomas Erik", "familyName": "Anijärv" }, + { + "@type":"Person", + "email":"", + "givenName":"Velu Prabhakar", + "familyName": "Kumaravel" + }, + { + "@type":"Person", + "email":"williamfrancisturner@gmail.com", + "givenName":"Will", + "familyName": "Turner" + }, { "@type":"Person", "email":"xia@xiaokai.me", @@ -2211,6 +2269,12 @@ "givenName":"Zhi", "familyName": "Zhang" }, + { + "@type":"Person", + "email":"btk.codedev@gmail.com", + "givenName":"", + "familyName": "btkcodedev" + }, { "@type":"Person", "email":"", diff --git a/doc/Makefile b/doc/Makefile index 70d7429f4ad..3c251069045 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -76,6 +76,6 @@ doctest: "results in _build/doctest/output.txt." view: - @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/index.html')" + @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/sg_execution_times.html')" show: view diff --git a/doc/_includes/channel_interpolation.rst b/doc/_includes/channel_interpolation.rst index 4639604af58..e90a763d214 100644 --- a/doc/_includes/channel_interpolation.rst +++ b/doc/_includes/channel_interpolation.rst @@ -59,7 +59,7 @@ where :math:`G_{ds} \in R^{M \times N}` computes :math:`g_{m}(\boldsymbol{r_i}, To interpolate bad channels, one can simply do: - >>> evoked.interpolate_bads(reset_bads=False) # doctest: +SKIP + >>> evoked.interpolate_bads(reset_bads=False) # doctest: +SKIP and the bad channel will be fixed. @@ -67,4 +67,4 @@ and the bad channel will be fixed. .. topic:: Examples: - * :ref:`ex-interpolate-bad-channels` + * :ref:`ex-interpolate-bad-channels` diff --git a/doc/_includes/channel_types.rst b/doc/_includes/channel_types.rst index 647dab25ba4..0a2ea0ab007 100644 --- a/doc/_includes/channel_types.rst +++ b/doc/_includes/channel_types.rst @@ -10,6 +10,11 @@ Supported channel types from the include: channel-types-begin-content +.. NOTE: In the future, this table should be automatically synchronized with + the sensor types listed in the glossary. Perhaps a table showing data type + channels as well as non-data type channels should be added to the glossary + and displayed here too. + Channel types are represented in MNE-Python with shortened or abbreviated names. This page lists all supported channel types, their abbreviated names, and the measurement unit used to represent data of that type. Where channel @@ -23,50 +28,77 @@ parentheses. More information about measurement units is given in the .. cssclass:: table-bordered .. 
rst-class:: midvalign -============= ========================================= ================= -Channel type Description Measurement unit -============= ========================================= ================= -eeg scalp electroencephalography (EEG) Volts +================= ========================================= ================= +Channel type Description Measurement unit +================= ========================================= ================= +eeg scalp electroencephalography (EEG) Volts + +meg (mag) Magnetoencephalography (magnetometers) Teslas + +meg (grad) Magnetoencephalography (gradiometers) Teslas/meter + +ecg Electrocardiography (ECG) Volts + +seeg Stereotactic EEG channels Volts + +dbs Deep brain stimulation (DBS) Volts + +ecog Electrocorticography (ECoG) Volts + +fnirs (hbo) Functional near-infrared spectroscopy Moles/liter + (oxyhemoglobin) + +fnirs (hbr) Functional near-infrared spectroscopy Moles/liter + (deoxyhemoglobin) + +emg Electromyography (EMG) Volts + +eog Electrooculography (EOG) Volts + +bio Miscellaneous biological channels (e.g., Arbitrary units + skin conductance) -meg (mag) Magnetoencephalography (magnetometers) Teslas +stim stimulus (a.k.a. trigger) channels Arbitrary units -meg (grad) Magnetoencephalography (gradiometers) Teslas/meter +resp respiration monitoring channel Volts -ecg Electrocardiography (ECG) Volts +chpi continuous head position indicator Teslas + (HPI) coil channels -seeg Stereotactic EEG channels Volts +exci Flux excitation channel -dbs Deep brain stimulation (DBS) Volts +ias Internal Active Shielding data + (Triux systems only?) -ecog Electrocorticography (ECoG) Volts +syst System status channel information + (Triux systems only) -fnirs (hbo) Functional near-infrared spectroscopy Moles/liter - (oxyhemoglobin) +temperature Temperature Degrees Celsius -fnirs (hbr) Functional near-infrared spectroscopy Moles/liter - (deoxyhemoglobin) +gsr Galvanic skin response Siemens -emg Electromyography (EMG) Volts +ref_meg Reference Magnetometers Teslas -bio Miscellaneous biological channels (e.g., Arbitrary units - skin conductance) +dipole Dipole amplitude Amperes -stim stimulus (a.k.a. trigger) channels Arbitrary units +gof Goodness of fit (GOF) Goodness-of-fit -resp respiration monitoring channel Volts +cw-nirs (amp) Continuous-wave functional near-infrared Volts + spectroscopy (CW-fNIRS) (CW amplitude) -chpi continuous head position indicator Teslas - (HPI) coil channels +fd-nirs (ac amp) Frequency-domain near-infrared Volts + spectroscopy (FD-NIRS AC amplitude) -exci Flux excitation channel +fd-nirs (phase) Frequency-domain near-infrared Radians + spectroscopy (FD-NIRS phase) -ias Internal Active Shielding data - (Triux systems only?) 
+fnirs (od) Functional near-infrared spectroscopy Volts + (optical density) -syst System status channel information - (Triux systems only) +csd Current source density Volts per square + meter -temperature Temperature Degrees Celsius +eyegaze Eye-tracking (gaze position) Arbitrary units -gsr Galvanic skin response Siemens -============= ========================================= ================= +pupil Eye-tracking (pupil size) Arbitrary units +================= ========================================= ================= \ No newline at end of file diff --git a/doc/_includes/forward.rst b/doc/_includes/forward.rst index f92632f8220..d04eeba7b5b 100644 --- a/doc/_includes/forward.rst +++ b/doc/_includes/forward.rst @@ -130,26 +130,26 @@ transformation symbols (:math:`T_x`) indicate the transformations actually present in the FreeSurfer files. Generally, .. math:: \begin{bmatrix} - x_2 \\ - y_2 \\ - z_2 \\ - 1 - \end{bmatrix} = T_{12} \begin{bmatrix} - x_1 \\ - y_1 \\ - z_1 \\ - 1 - \end{bmatrix} = \begin{bmatrix} - R_{11} & R_{12} & R_{13} & x_0 \\ - R_{21} & R_{22} & R_{23} & y_0 \\ - R_{31} & R_{32} & R_{33} & z_0 \\ - 0 & 0 & 0 & 1 - \end{bmatrix} \begin{bmatrix} - x_1 \\ - y_1 \\ - z_1 \\ - 1 - \end{bmatrix}\ , + x_2 \\ + y_2 \\ + z_2 \\ + 1 + \end{bmatrix} = T_{12} \begin{bmatrix} + x_1 \\ + y_1 \\ + z_1 \\ + 1 + \end{bmatrix} = \begin{bmatrix} + R_{11} & R_{12} & R_{13} & x_0 \\ + R_{21} & R_{22} & R_{23} & y_0 \\ + R_{31} & R_{32} & R_{33} & z_0 \\ + 0 & 0 & 0 & 1 + \end{bmatrix} \begin{bmatrix} + x_1 \\ + y_1 \\ + z_1 \\ + 1 + \end{bmatrix}\ , where :math:`x_k`, :math:`y_k`,and :math:`z_k` are the location coordinates in two coordinate systems, :math:`T_{12}` is the coordinate transformation from @@ -161,20 +161,20 @@ files produced by FreeSurfer and MNE. The fixed transformations :math:`T_-` and :math:`T_+` are: .. math:: T_{-} = \begin{bmatrix} - 0.99 & 0 & 0 & 0 \\ - 0 & 0.9688 & 0.042 & 0 \\ - 0 & -0.0485 & 0.839 & 0 \\ - 0 & 0 & 0 & 1 - \end{bmatrix} + 0.99 & 0 & 0 & 0 \\ + 0 & 0.9688 & 0.042 & 0 \\ + 0 & -0.0485 & 0.839 & 0 \\ + 0 & 0 & 0 & 1 + \end{bmatrix} and .. math:: T_{+} = \begin{bmatrix} - 0.99 & 0 & 0 & 0 \\ - 0 & 0.9688 & 0.046 & 0 \\ - 0 & -0.0485 & 0.9189 & 0 \\ - 0 & 0 & 0 & 1 - \end{bmatrix} + 0.99 & 0 & 0 & 0 \\ + 0 & 0.9688 & 0.046 & 0 \\ + 0 & -0.0485 & 0.9189 & 0 \\ + 0 & 0 & 0 & 1 + \end{bmatrix} .. note:: This section does not discuss the transformation between the MRI voxel @@ -352,11 +352,11 @@ coordinates (:math:`r_D`) by where .. math:: T = \begin{bmatrix} - e_x & 0 \\ - e_y & 0 \\ - e_z & 0 \\ - r_{0D} & 1 - \end{bmatrix}\ . + e_x & 0 \\ + e_y & 0 \\ + e_z & 0 \\ + r_{0D} & 1 + \end{bmatrix}\ . Calculation of the magnetic field --------------------------------- diff --git a/doc/_includes/ssp.rst b/doc/_includes/ssp.rst index 1bc860d15db..40b25a237db 100644 --- a/doc/_includes/ssp.rst +++ b/doc/_includes/ssp.rst @@ -101,12 +101,12 @@ The EEG average reference is the mean signal over all the sensors. It is typical in EEG analysis to subtract the average reference from all the sensor signals :math:`b^{1}(t), ..., b^{n}(t)`. That is: -.. math:: {b}^{j}_{s}(t) = b^{j}(t) - \frac{1}{n}\sum_{k}{b^k(t)} +.. math:: {b}^{j}_{s}(t) = b^{j}(t) - \frac{1}{n}\sum_{k}{b^k(t)} :name: eeg_proj where the noise term :math:`b_{n}^{j}(t)` is given by -.. math:: b_{n}^{j}(t) = \frac{1}{n}\sum_{k}{b^k(t)} +.. 
math:: b_{n}^{j}(t) = \frac{1}{n}\sum_{k}{b^k(t)} :name: noise_term Thus, the projector vector :math:`P_{\perp}` will be given by diff --git a/doc/_static/style.css b/doc/_static/style.css index 61eea678830..11a27b72c92 100644 --- a/doc/_static/style.css +++ b/doc/_static/style.css @@ -7,8 +7,6 @@ --pst-font-family-monospace: 'Source Code Pro', var(--pst-font-family-monospace-system); /* colors that aren't responsive to light/dark mode */ --mne-color-discord: #5865F2; - --mne-color-primary: #007bff; - --mne-color-primary-highlight: #0063cc; /* font weight */ --mne-font-weight-semibold: 600; } @@ -17,7 +15,7 @@ html[data-theme="light"] { /* topbar logo links */ --mne-color-github: #000; - --mne-color-discourse: #000; + --mne-color-discourse: #d0232b; --mne-color-mastodon: #2F0C7A; /* code block copy button */ --copybtn-opacity: 0.75; @@ -25,18 +23,11 @@ html[data-theme="light"] { --mne-color-card-header: rgba(0, 0, 0, 0.05); /* section headings */ --mne-color-heading: #003e80; - /* pydata-sphinx-theme overrides */ - --pst-color-primary: var(--mne-color-primary); - --pst-color-primary-highlight: var(--mne-color-primary-highlight); - --pst-color-info: var(--pst-color-primary); - --pst-color-border: #ccc; - --pst-color-background: #fff; - --pst-color-link: var(--pst-color-primary-highlight); /* sphinx-gallery overrides */ --sg-download-a-background-color: var(--pst-color-primary); --sg-download-a-background-image: unset; --sg-download-a-border-color: var(--pst-color-border); - --sg-download-a-color: #fff; + --sg-download-a-color: var(--sd-color-primary-text); --sg-download-a-hover-background-color: var(--pst-color-primary-highlight); --sg-download-a-hover-box-shadow-1: none; --sg-download-a-hover-box-shadow-2: none; @@ -52,19 +43,11 @@ html[data-theme="dark"] { --mne-color-card-header: rgba(255, 255, 255, 0.2); /* section headings */ --mne-color-heading: #b8cbe0; - /* pydata-sphinx-theme overrides */ - --pst-color-primary: var(--mne-color-primary); - --pst-color-primary-highlight: var(--mne-color-primary-highlight); - --pst-color-info: var(--pst-color-primary); - --pst-color-border: #333; - --pst-color-background: #000; - --pst-color-link: #66b0ff; - --pst-color-on-background: #1e1e1e; /* sphinx-gallery overrides */ --sg-download-a-background-color: var(--pst-color-primary); --sg-download-a-background-image: unset; --sg-download-a-border-color: var(--pst-color-border); - --sg-download-a-color: #000; + --sg-download-a-color: var(--sd-color-primary-text); --sg-download-a-hover-background-color: var(--pst-color-primary-highlight); --sg-download-a-hover-box-shadow-1: none; --sg-download-a-hover-box-shadow-2: none; @@ -99,11 +82,6 @@ html[data-theme="dark"] img { filter: none; } -/* prev/next links */ -.prev-next-area a p.prev-next-title { - color: var(--pst-color-link); -} - /* make versionadded smaller and inline with param name */ /* don't do for deprecated / versionchanged; they have extra info (too long to fit) */ div.versionadded > p { @@ -148,8 +126,12 @@ p.sphx-glr-signature { border-radius: 0.5rem; /* ↓↓↓↓↓↓↓ these two rules copied from sphinx-design */ box-shadow: 0 .125rem .25rem var(--sd-color-shadow) !important; + color: var(--sg-download-a-color); transition: color .15s ease-in-out,background-color .15s ease-in-out,border-color .15s ease-in-out,box-shadow .15s ease-in-out; } +.sphx-glr-download a.download::before { + color: var(--sg-download-a-color); +} /* Report embedding */ iframe.sg_report { width: 95%; @@ -222,16 +204,16 @@ aside.footnote:last-child { } /* 
******************************************************* navbar icon links */ -#navbar-icon-links i.fa-square-github::before { +.navbar-icon-links i.fa-square-github::before { color: var(--mne-color-github); } -#navbar-icon-links i.fa-discourse::before { +.navbar-icon-links i.fa-discourse::before { color: var(--mne-color-discourse); } -#navbar-icon-links i.fa-discord::before { +.navbar-icon-links i.fa-discord::before { color: var(--mne-color-discord); } -#navbar-icon-links i.fa-mastodon::before { +.navbar-icon-links i.fa-mastodon::before { color: var(--mne-color-mastodon); } @@ -242,7 +224,6 @@ aside.footnote:last-child { } /* topbar nav active */ .bd-header.navbar-light#navbar-main .navbar-nav > li.active > .nav-link { - color: var(--pst-color-link); font-weight: var(--mne-font-weight-semibold); } /* topbar nav hover */ @@ -250,18 +231,6 @@ aside.footnote:last-child { .bd-header.navbar-light#navbar-main .navbar-nav li a.nav-link:hover { color: var(--pst-color-secondary); } -/* sidebar nav */ -nav.bd-links .active > a, -nav.bd-links .active:hover > a, -.toc-entry a.nav-link.active, -.toc-entry a.nav-link.active:hover { - color: var(--pst-color-link); -} -/* sidebar nav hover */ -nav.bd-links li > a:hover, -.toc-entry a.nav-link:hover { - color: var(--pst-color-secondary); -} /* *********************************************************** homepage logo */ img.logo { @@ -273,10 +242,10 @@ img.logo { ul.quicklinks a { font-weight: var(--mne-font-weight-semibold); color: var(--pst-color-text-base); + text-decoration: none; } ul.quicklinks a:hover { text-decoration: none; - color: var(--pst-color-secondary); } h5.card-header { margin-top: 0px; @@ -287,7 +256,6 @@ h5.card-header::before { height: 0px; margin-top: 0px; } - /* ******************************************************* homepage carousel */ div.frontpage-gallery { overflow: hidden; @@ -342,7 +310,6 @@ div#contributor-avatars div.card img { div#contributor-avatars div.card img { width: 3em; } - .contributor-avatar { clip-path: circle(closest-side); } @@ -380,3 +347,9 @@ img.hidden { td.justify { text-align-last: justify; } + +/* Matplotlib HTML5 video embedding */ +div.sphx-glr-animation video { + max-width: 100%; + height: auto; +} diff --git a/doc/_static/versions.json b/doc/_static/versions.json index 8141440bd16..48e4006f494 100644 --- a/doc/_static/versions.json +++ b/doc/_static/versions.json @@ -1,14 +1,19 @@ [ { - "name": "1.7 (devel)", + "name": "1.8 (devel)", "version": "dev", "url": "https://mne.tools/dev/" }, { - "name": "1.6 (stable)", + "name": "1.7 (stable)", "version": "stable", "url": "https://mne.tools/stable/" }, + { + "name": "1.6", + "version": "1.6", + "url": "https://mne.tools/1.6/" + }, { "name": "1.5", "version": "1.5", diff --git a/doc/api/events.rst b/doc/api/events.rst index f9447741a09..3f7159a22d5 100644 --- a/doc/api/events.rst +++ b/doc/api/events.rst @@ -55,4 +55,4 @@ Events average_movements combine_event_ids equalize_epoch_counts - make_metadata \ No newline at end of file + make_metadata diff --git a/doc/api/file_io.rst b/doc/api/file_io.rst index 3b43de6ce64..2da9059deb3 100644 --- a/doc/api/file_io.rst +++ b/doc/api/file_io.rst @@ -63,4 +63,4 @@ Base class: :toctree: ../generated/ :template: autosummary/class_no_members.rst - BaseEpochs \ No newline at end of file + BaseEpochs diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst index 54d4bfa2999..1e0e9e56079 100644 --- a/doc/api/preprocessing.rst +++ b/doc/api/preprocessing.rst @@ -93,6 +93,7 @@ Projections: cortical_signal_suppression 
create_ecg_epochs create_eog_epochs + find_bad_channels_lof find_bad_channels_maxwell find_ecg_events find_eog_events @@ -162,6 +163,8 @@ Projections: Calibration read_eyelink_calibration set_channel_types_eyetrack + convert_units + get_screen_visual_angle interpolate_blinks EEG referencing: diff --git a/doc/api/reading_raw_data.rst b/doc/api/reading_raw_data.rst index 1b8ebae2abf..50f524ce7c8 100644 --- a/doc/api/reading_raw_data.rst +++ b/doc/api/reading_raw_data.rst @@ -40,6 +40,7 @@ Reading raw data read_raw_nihon read_raw_fil read_raw_nsx + read_raw_neuralynx Base class: diff --git a/doc/api/realtime.rst b/doc/api/realtime.rst index 91c027a9e3f..0df65ad0d56 100644 --- a/doc/api/realtime.rst +++ b/doc/api/realtime.rst @@ -1,5 +1,6 @@ +.. include:: ../links.inc Realtime ======== -Realtime functionality has moved to the standalone module :mod:`mne_realtime`. +Realtime functionality has moved to the standalone module `MNE-LSL`_. diff --git a/doc/api/time_frequency.rst b/doc/api/time_frequency.rst index f8948909491..8923920bdba 100644 --- a/doc/api/time_frequency.rst +++ b/doc/api/time_frequency.rst @@ -14,7 +14,12 @@ Time-Frequency :toctree: ../generated/ AverageTFR + AverageTFRArray + BaseTFR EpochsTFR + EpochsTFRArray + RawTFR + RawTFRArray CrossSpectralDensity Spectrum SpectrumArray diff --git a/doc/changes/devel.rst.template b/doc/changes/devel.rst.template deleted file mode 100644 index 09c49cad107..00000000000 --- a/doc/changes/devel.rst.template +++ /dev/null @@ -1,34 +0,0 @@ -.. NOTE: we use cross-references to highlight new functions and classes. - Please follow the examples below like :func:`mne.stats.f_mway_rm`, so the - whats_new page will have a link to the function/class documentation. - -.. NOTE: there are 3 separate sections for changes, based on type: - - "Enhancements" for new features - - "Bugs" for bug fixes - - "API changes" for backward-incompatible changes - -.. NOTE: changes from first-time contributors should be added to the TOP of - the relevant section (Enhancements / Bugs / API changes), and should look - like this (where xxxx is the pull request number): - - - description of enhancement/bugfix/API change (:gh:`xxxx` by - :newcontrib:`Firstname Lastname`) - - Also add a corresponding entry for yourself in doc/changes/names.inc - -.. _current: - -Version X.Y.dev0 (development) ------------------------------- - -Enhancements -~~~~~~~~~~~~ -- None yet - -Bugs -~~~~ -- None yet - -API changes -~~~~~~~~~~~ -- None yet diff --git a/doc/changes/devel/.gitignore b/doc/changes/devel/.gitignore new file mode 100644 index 00000000000..f935021a8f8 --- /dev/null +++ b/doc/changes/devel/.gitignore @@ -0,0 +1 @@ +!.gitignore diff --git a/doc/changes/names.inc b/doc/changes/names.inc index da884792c4f..112418f7e72 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -20,13 +20,15 @@ .. _Alex Gramfort: https://alexandre.gramfort.net +.. _Alex Kiefer: https://home.alexk101.dev + .. _Alex Rockhill: https://github.com/alexrockhill/ .. _Alexander Rudiuk: https://github.com/ARudiuk .. _Alexandre Barachant: https://alexandre.barachant.org -.. _Andrea Brovelli: https://andrea-brovelli.net +.. _Andrea Brovelli: https://brovelli.github.io/ .. _Andreas Hojlund: https://github.com/ahoejlund @@ -68,11 +70,13 @@ .. _Bruno Nicenboim: https://bnicenboim.github.io +.. _btkcodedev: https://github.com/btkcodedev + .. _buildqa: https://github.com/buildqa .. _Carlos de la Torre-Ortiz: https://ctorre.me -.. _Carina Forster: https://github.com/carinafo +.. 
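The newly listed :func:`mne.io.read_raw_neuralynx` reader can be exercised with a sketch like this (the session directory is hypothetical; Neuralynx recordings are directories of ``.ncs`` files rather than single files)::

    import mne

    raw = mne.io.read_raw_neuralynx("path/to/neuralynx_session")  # pass the folder
    raw.load_data()
    print(raw.info)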
_Carina Forster: https://github.com/CarinaFo .. _Cathy Nangini: https://github.com/KatiRG @@ -172,6 +176,8 @@ .. _Felix Raimundo: https://github.com/gamazeps +.. _Florian Hofer: https://github.com/hofaflo + .. _Florin Pop: https://github.com/florin-pop .. _Frederik Weber: https://github.com/Frederik-D-Weber @@ -184,7 +190,7 @@ .. _George O'Neill: https://georgeoneill.github.io -.. _Gonzalo Reina: https://github.com/Gon-reina +.. _Gonzalo Reina: https://orcid.org/0000-0003-4219-2306 .. _Guillaume Dumas: https://mila.quebec/en/person/guillaume-dumas @@ -200,9 +206,11 @@ .. _Hari Bharadwaj: https://github.com/haribharadwaj +.. _Hasrat Ali Arzoo: https://github.com/hasrat17 + .. _Henrich Kolkhorst: https://github.com/hekolk -.. _Hongjiang Ye: https://github.com/rubyyhj +.. _Hongjiang Ye: https://github.com/hongjiang-ye .. _Hubert Banville: https://github.com/hubertjb @@ -218,6 +226,8 @@ .. _Ivana Kojcic: https://github.com/ikojcic +.. _Ivo de Jong: https://github.com/ivopascal + .. _Jaakko Leppakangas: https://github.com/jaeilepp .. _Jack Zhang: https://github.com/jackz314 @@ -338,7 +348,7 @@ .. _Mark Alexander Henney: https://github.com/henneysq -.. _Mark Wronkiewicz: https://ml.jpl.nasa.gov/people/wronkiewicz/wronkiewicz.html +.. _Mark Wronkiewicz: https://github.com/wronk .. _Marmaduke Woodman: https://github.com/maedoc @@ -346,6 +356,8 @@ .. _Martin Luessi: https://github.com/mluessi +.. _Martin Oberg: https://github.com/obergmartin + .. _Martin Schulz: https://github.com/marsipu .. _Mathieu Scheltienne: https://github.com/mscheltienne @@ -382,6 +394,10 @@ .. _Moritz Gerster: https://github.com/moritz-gerster +.. _Motofumi Fushimi: https://github.com/motofumi-fushimi/motofumi-fushimi.github.io + +.. _Nabil Alibou: https://github.com/nabilalibou + .. _Natalie Klein: https://github.com/natalieklein .. _Nathalie Gayraud: https://github.com/ngayraud @@ -400,11 +416,13 @@ .. _Nikolai Chapochnikov: https://github.com/chapochn +.. _Nikolai Kapralov: https://github.com/ctrltz + .. _Nikolas Chalas: https://github.com/Nichalas .. _Okba Bekhelifi: https://github.com/okbalefthanded -.. _Olaf Hauk: https://www.neuroscience.cam.ac.uk/directory/profile.php?olafhauk +.. _Olaf Hauk: https://neuroscience.cam.ac.uk/member/olafhauk .. _Oleh Kozynets: https://github.com/OlehKSS @@ -450,7 +468,7 @@ .. _Rasmus Aagaard: https://github.com/rasgaard -.. _Rasmus Zetter: https://people.aalto.fi/rasmus.zetter +.. _Rasmus Zetter: https://github.com/rzetter .. _Reza Nasri: https://github.com/rznas @@ -460,6 +478,8 @@ .. _Richard Koehler: https://github.com/richardkoehler +.. _Richard Scholz: https://github.com/scholzri + .. _Riessarius Stargardsky: https://github.com/Riessarius .. _Roan LaPlante: https://github.com/aestrivex @@ -542,7 +562,7 @@ .. _Tal Linzen: https://tallinzen.net/ -.. _Teon Brooks: https://teonbrooks.com +.. _Teon Brooks: https://github.com/teonbrooks .. _Théodore Papadopoulo: https://github.com/papadop @@ -572,12 +592,16 @@ .. _Valerii Chirkov: https://github.com/vagechirkov +.. _Velu Prabhakar Kumaravel: https://github.com/vpKumaravel + .. _Victor Ferat: https://github.com/vferat .. _Victoria Peterson: https://github.com/vpeterson .. _Xiaokai Xia: https://github.com/dddd1007 +.. _Will Turner: https://bootstrapbill.github.io + .. _Yaroslav Halchenko: http://haxbylab.dartmouth.edu/ppl/yarik.html .. _Yiping Zuo: https://github.com/frostime @@ -589,3 +613,5 @@ .. _Zhi Zhang: https://github.com/tczhangzhi/ .. _Zvi Baratz: https://github.com/ZviBaratz + +.. 
_Seyed Yahya Shirazi: https://neuromechanist.github.io diff --git a/doc/changes/v0.10.rst b/doc/changes/v0.10.rst index 6a0c3322e88..ac4f2e42857 100644 --- a/doc/changes/v0.10.rst +++ b/doc/changes/v0.10.rst @@ -91,7 +91,7 @@ BUG - Fix dropping of events after downsampling stim channels by `Marijn van Vliet`_ -- Fix scaling in :func:``mne.viz.utils._setup_vmin_vmax`` by `Jaakko Leppakangas`_ +- Fix scaling in ``mne.viz.utils._setup_vmin_vmax`` by `Jaakko Leppakangas`_ - Fix order of component selection in :class:`mne.decoding.CSP` by `Clemens Brunner`_ diff --git a/doc/changes/v0.12.rst b/doc/changes/v0.12.rst index cf01f8ff62c..b3b7aba1a39 100644 --- a/doc/changes/v0.12.rst +++ b/doc/changes/v0.12.rst @@ -129,7 +129,7 @@ BUG - Fix bug in :func:`mne.io.Raw.save` where, in rare cases, automatically split files could end up writing an extra empty file that wouldn't be read properly by `Eric Larson`_ -- Fix :class:``mne.realtime.StimServer`` by removing superfluous argument ``ip`` used while initializing the object by `Mainak Jas`_. +- Fix ``mne.realtime.StimServer`` by removing superfluous argument ``ip`` used while initializing the object by `Mainak Jas`_. - Fix removal of projectors in :func:`mne.preprocessing.maxwell_filter` in ``st_only=True`` mode by `Eric Larson`_ @@ -175,37 +175,37 @@ Authors The committer list for this release is the following (preceded by number of commits): -* 348 Eric Larson -* 347 Jaakko Leppakangas -* 157 Alexandre Gramfort -* 139 Jona Sassenhagen -* 67 Jean-Remi King -* 32 Chris Holdgraf -* 31 Denis A. Engemann -* 30 Mainak Jas -* 16 Christopher J. Bailey -* 13 Marijn van Vliet -* 10 Mark Wronkiewicz -* 9 Teon Brooks -* 9 kaichogami -* 8 Clément Moutard -* 5 Camilo Lamus -* 5 mmagnuski -* 4 Christian Brodbeck -* 4 Daniel McCloy -* 4 Yousra Bekhti -* 3 Fede Raimondo -* 1 Jussi Nurminen -* 1 MartinBaBer -* 1 Mikolaj Magnuski -* 1 Natalie Klein -* 1 Niklas Wilming -* 1 Richard Höchenberger -* 1 Sagun Pai -* 1 Sourav Singh -* 1 Tom Dupré la Tour -* 1 jona-sassenhagen@ -* 1 kambysese -* 1 pbnsilva -* 1 sviter -* 1 zuxfoucault +* 348 Eric Larson +* 347 Jaakko Leppakangas +* 157 Alexandre Gramfort +* 139 Jona Sassenhagen +* 67 Jean-Remi King +* 32 Chris Holdgraf +* 31 Denis A. Engemann +* 30 Mainak Jas +* 16 Christopher J. 
Bailey +* 13 Marijn van Vliet +* 10 Mark Wronkiewicz +* 9 Teon Brooks +* 9 kaichogami +* 8 Clément Moutard +* 5 Camilo Lamus +* 5 mmagnuski +* 4 Christian Brodbeck +* 4 Daniel McCloy +* 4 Yousra Bekhti +* 3 Fede Raimondo +* 1 Jussi Nurminen +* 1 MartinBaBer +* 1 Mikolaj Magnuski +* 1 Natalie Klein +* 1 Niklas Wilming +* 1 Richard Höchenberger +* 1 Sagun Pai +* 1 Sourav Singh +* 1 Tom Dupré la Tour +* 1 jona-sassenhagen@ +* 1 kambysese +* 1 pbnsilva +* 1 sviter +* 1 zuxfoucault diff --git a/doc/changes/v0.13.rst b/doc/changes/v0.13.rst index 425ba4c76a1..aee297d9d2d 100644 --- a/doc/changes/v0.13.rst +++ b/doc/changes/v0.13.rst @@ -198,7 +198,7 @@ API - Deprecated ``mne.time_frequency.cwt_morlet`` and ``mne.time_frequency.single_trial_power`` in favour of :func:`mne.time_frequency.tfr_morlet` with parameter average=False, by `Jean-Remi King`_ and `Alex Gramfort`_ -- Add argument ``mask_type`` to func:`mne.read_events` and func:`mne.find_events` to support MNE-C style of trigger masking by `Teon Brooks`_ and `Eric Larson`_ +- Add argument ``mask_type`` to :func:`mne.read_events` and :func:`mne.find_events` to support MNE-C style of trigger masking by `Teon Brooks`_ and `Eric Larson`_ - Extended Infomax is now the new default in :func:`mne.preprocessing.infomax` (``extended=True``), by `Clemens Brunner`_ diff --git a/doc/changes/v0.15.rst b/doc/changes/v0.15.rst index ada8180d4ac..e2de7301973 100644 --- a/doc/changes/v0.15.rst +++ b/doc/changes/v0.15.rst @@ -226,7 +226,7 @@ API - ``mne.viz.decoding.plot_gat_times``, ``mne.viz.decoding.plot_gat_matrix`` are now deprecated. Use matplotlib instead as shown in the examples, by `Jean-Remi King`_ and `Alex Gramfort`_ -- Add ``norm_trace`` parameter to control single-epoch covariance normalization in :class:mne.decoding.CSP, by `Jean-Remi King`_ +- Add ``norm_trace`` parameter to control single-epoch covariance normalization in :class:`mne.decoding.CSP`, by `Jean-Remi King`_ - Allow passing a list of channel names as ``show_names`` in function :func:`mne.viz.plot_sensors` and methods :meth:`mne.Evoked.plot_sensors`, :meth:`mne.Epochs.plot_sensors` and :meth:`mne.io.Raw.plot_sensors` to show only a subset of channel names by `Jaakko Leppakangas`_ diff --git a/doc/changes/v0.17.rst b/doc/changes/v0.17.rst index 40896b6f383..49e722c584d 100644 --- a/doc/changes/v0.17.rst +++ b/doc/changes/v0.17.rst @@ -234,7 +234,7 @@ API In 0.19 The ``stim_channel`` keyword arguments will be removed from ``read_raw_...`` functions. -- Calling :meth:``mne.io.pick.pick_info`` removing channels that are needed by compensation matrices (``info['comps']``) no longer raises ``RuntimeException`` but instead logs an info level message. By `Luke Bloy`_ +- Calling ``mne.io.pick.pick_info`` removing channels that are needed by compensation matrices (``info['comps']``) no longer raises ``RuntimeException`` but instead logs an info level message. By `Luke Bloy`_ - :meth:`mne.Epochs.save` now has the parameter ``fmt`` to specify the desired format (precision) saving epoched data, by `Stefan Repplinger`_, `Eric Larson`_ and `Alex Gramfort`_ @@ -274,44 +274,44 @@ Authors People who contributed to this release (in alphabetical order): -* Alexandre Gramfort -* Antoine Gauthier -* Britta Westner -* Christian Brodbeck -* Clemens Brunner -* Daniel McCloy -* David Sabbagh -* Denis A. Engemann -* Eric Larson -* Ezequiel Mikulan -* Henrich Kolkhorst -* Hubert Banville -* Jasper J.F. 
van den Bosch -* Jen Evans -* Joan Massich -* Johan van der Meer -* Jona Sassenhagen -* Kambiz Tavabi -* Lorenz Esch -* Luke Bloy -* Mainak Jas -* Manu Sutela -* Marcin Koculak -* Marijn van Vliet -* Mikolaj Magnuski -* Peter J. Molfese -* Sam Perry -* Sara Sommariva -* Sergey Antopolskiy -* Sheraz Khan -* Stefan Appelhoff -* Stefan Repplinger -* Steven Bethard -* Teekuningas -* Teon Brooks -* Thomas Hartmann -* Thomas Jochmann -* Tom Dupré la Tour -* Tristan Stenner -* buildqa -* jeythekey +* Alexandre Gramfort +* Antoine Gauthier +* Britta Westner +* Christian Brodbeck +* Clemens Brunner +* Daniel McCloy +* David Sabbagh +* Denis A. Engemann +* Eric Larson +* Ezequiel Mikulan +* Henrich Kolkhorst +* Hubert Banville +* Jasper J.F. van den Bosch +* Jen Evans +* Joan Massich +* Johan van der Meer +* Jona Sassenhagen +* Kambiz Tavabi +* Lorenz Esch +* Luke Bloy +* Mainak Jas +* Manu Sutela +* Marcin Koculak +* Marijn van Vliet +* Mikolaj Magnuski +* Peter J. Molfese +* Sam Perry +* Sara Sommariva +* Sergey Antopolskiy +* Sheraz Khan +* Stefan Appelhoff +* Stefan Repplinger +* Steven Bethard +* Teekuningas +* Teon Brooks +* Thomas Hartmann +* Thomas Jochmann +* Tom Dupré la Tour +* Tristan Stenner +* buildqa +* jeythekey diff --git a/doc/changes/v0.23.rst b/doc/changes/v0.23.rst index bf8ed2042e5..0fa34b0dc2d 100644 --- a/doc/changes/v0.23.rst +++ b/doc/changes/v0.23.rst @@ -246,7 +246,7 @@ Bugs - Fix bug with :func:`mne.grow_labels` where ``overlap=False`` could run forever or raise an error (:gh:`9317` by `Eric Larson`_) -- Fix compatibility bugs with :mod:`mne_realtime` (:gh:`8845` by `Eric Larson`_) +- Fix compatibility bugs with ``mne_realtime`` (:gh:`8845` by `Eric Larson`_) - Fix bug with `mne.viz.Brain` where non-inflated surfaces had an X-offset imposed by default (:gh:`8794` by `Eric Larson`_) diff --git a/doc/changes/v0.24.rst b/doc/changes/v0.24.rst index 425fd5d5759..5f92e3dbdf6 100644 --- a/doc/changes/v0.24.rst +++ b/doc/changes/v0.24.rst @@ -37,7 +37,7 @@ Enhancements ~~~~~~~~~~~~ .. 
- Add something cool (:gh:`9192` **by new contributor** |New Contributor|_) -- Add `pooch` to system information reports (:gh:`9801` **by new contributor** |Joshua Teves|_) +- Add ``pooch`` to system information reports (:gh:`9801` **by new contributor** |Joshua Teves|_) - Get annotation descriptions from the name field of SNIRF stimulus groups when reading SNIRF files via `mne.io.read_raw_snirf` (:gh:`9575` **by new contributor** |Darin Erat Sleiter|_) @@ -89,7 +89,7 @@ Enhancements - :func:`mne.concatenate_raws`, :func:`mne.concatenate_epochs`, and :func:`mne.write_evokeds` gained a new parameter ``on_mismatch``, which controls behavior in case not all of the supplied instances share the same device-to-head transformation (:gh:`9438` by `Richard Höchenberger`_) -- Add support for multiple datablocks (acquistions with pauses) in :func:`mne.io.read_raw_nihon` (:gh:`9437` by `Federico Raimondo`_) +- Add support for multiple datablocks (acquisitions with pauses) in :func:`mne.io.read_raw_nihon` (:gh:`9437` by `Federico Raimondo`_) - Add new function :func:`mne.preprocessing.annotate_break` to automatically detect and mark "break" periods without any marked experimental events in the continuous data (:gh:`9445` by `Richard Höchenberger`_) diff --git a/doc/changes/v1.2.rst b/doc/changes/v1.2.rst index b6a8b5a8edf..e292b472b03 100644 --- a/doc/changes/v1.2.rst +++ b/doc/changes/v1.2.rst @@ -63,7 +63,7 @@ Bugs API changes ~~~~~~~~~~~ -- In meth:`mne.Evoked.plot`, the default value of the ``spatial_colors`` parameter has been changed to ``'auto'``, which will use spatial colors if channel locations are available (:gh:`11201` by :newcontrib:`Hüseyin Orkun Elmas` and `Daniel McCloy`_) +- In :meth:`mne.Evoked.plot`, the default value of the ``spatial_colors`` parameter has been changed to ``'auto'``, which will use spatial colors if channel locations are available (:gh:`11201` by :newcontrib:`Hüseyin Orkun Elmas` and `Daniel McCloy`_) - Starting with this release we now follow the Python convention of using ``FutureWarning`` instead of ``DeprecationWarning`` to signal user-facing changes to our API (:gh:`11120` by `Daniel McCloy`_) - The ``names`` parameter of :func:`mne.viz.plot_arrowmap` and :func:`mne.viz.plot_regression_weights` has been deprecated; sensor names will be automatically drawn from the ``info_from`` or ``model`` parameter (respectively), and can be hidden, shown, or altered via the ``show_names`` parameter (:gh:`11123` by `Daniel McCloy`_) - The ``bands`` parameter of :meth:`mne.Epochs.plot_psd_topomap` now accepts :class:`dict` input; legacy :class:`tuple` input is supported, but discouraged for new code (:gh:`11050` by `Daniel McCloy`_) diff --git a/doc/changes/v1.6.rst b/doc/changes/v1.6.rst index ee58c7a527f..f770b5046d2 100644 --- a/doc/changes/v1.6.rst +++ b/doc/changes/v1.6.rst @@ -1,11 +1,11 @@ .. 
_changes_1_6_0: -Version 1.6.0 (2022-11-20) +Version 1.6.0 (2023-11-20) -------------------------- Enhancements ~~~~~~~~~~~~ -- Add support for Neuralynx data files with ``mne.io.read_raw_neuralynx`` (:gh:`11969` by :newcontrib:`Kristijan Armeni` and :newcontrib:`Ivan Skelin`) +- Add support for Neuralynx data files with :func:`mne.io.read_raw_neuralynx` (:gh:`11969` by :newcontrib:`Kristijan Armeni` and :newcontrib:`Ivan Skelin`) - Improve tests for saving splits with :class:`mne.Epochs` (:gh:`11884` by `Dmitrii Altukhov`_) - Added functionality for linking interactive figures together, such that changing one figure will affect another, see :ref:`tut-ui-events` and :mod:`mne.viz.ui_events`. Current figures implementing UI events are :func:`mne.viz.plot_topomap` and :func:`mne.viz.plot_source_estimates` (:gh:`11685` :gh:`11891` by `Marijn van Vliet`_) - HTML anchors for :class:`mne.Report` now reflect the ``section-title`` of the report items rather than using a global incrementor ``global-N`` (:gh:`11890` by `Eric Larson`_) diff --git a/doc/changes/v1.7.rst b/doc/changes/v1.7.rst new file mode 100644 index 00000000000..e8f8e2e8e7b --- /dev/null +++ b/doc/changes/v1.7.rst @@ -0,0 +1,180 @@ +.. _changes_1_7_0: + +1.7.0 (2024-04-19) +================== + +Notable changes +--------------- + +- In this version, we started adding type hints (also known as "type annotations") to select parts of the codebase. + This meta information will be used by development environments (IDEs) like VS Code and PyCharm automatically to provide + better assistance such as tab completion or error detection even before running your code. + + So far, we've only added return type hints to :func:`mne.io.read_raw`, :func:`mne.read_epochs`, :func:`mne.read_evokeds` and + all format-specific ``read_raw_*()`` and ``read_epochs_*()`` functions. Now your editors will know: + these functions return raw, epochs, and evoked data, respectively. We are planning to add type hints to more functions after careful + evaluation in the future. + + You don't need to do anything to benefit from these changes – your editor will pick them up automatically and provide the + enhanced experience if it supports it! (`#12250 `__) + + +Dependencies +------------ + +- ``defusedxml`` is now an optional (rather than required) dependency and is needed when reading EGI-MFF data, NEDF data, and BrainVision montages, by `Eric Larson`_. (`#12264 `__) +- For developers, ``pytest>=8.0`` is now required for running unit tests, by `Eric Larson`_. (`#12376 `__) +- ``pytest-harvest`` is no longer used as a test dependency, by `Eric Larson`_. (`#12451 `__) +- The minimum supported version of Qt bindings is 5.15, by `Eric Larson`_. (`#12491 `__) + + +Bugfixes +-------- + +- Fix bug where the ``section`` parameter in :meth:`mne.Report.add_html` was not being utilized, resulting in improper formatting, by :newcontrib:`Martin Oberg`. (`#12319 `__) +- Fix bug in :func:`mne.preprocessing.maxwell_filter` where calibration was incorrectly applied during virtual sensor reconstruction, by `Eric Larson`_ and :newcontrib:`Motofumi Fushimi`. (`#12348 `__) +- Reformat channel and detector lookup in :func:`mne.io.read_raw_snirf` from array-based to dictionary-based. Remove incorrect assertions that every detector and source must have data associated with every registered optode position, by :newcontrib:`Alex Kiefer`. (`#12430 `__)
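To see what the return-type hints described in the notable changes provide, consider this sketch (file names hypothetical; an IDE with type-checking support infers the types without running the code)::

    import mne

    raw = mne.io.read_raw("sample_audvis_raw.fif")     # editor infers a Raw-like object
    epochs = mne.read_epochs("sample_audvis-epo.fif")  # editor infers mne.Epochs
    # Tab completion on raw.filter(...) or epochs.average() now works pre-execution.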
+- Remove FDT file format check for strings in EEGLAB's EEG.data in :func:`mne.io.read_raw_eeglab` and related functions by :newcontrib:`Seyed Yahya Shirazi` (`#12523 `__) +- Fixes to interactivity in time-frequency objects: the rectangle selector now works on TFR image plots of gradiometer data; and in ``TFR.plot_joint()`` plots, the colormap limits of interactively-generated topomaps match the colormap limits of the main plot. By `Daniel McCloy`_. (`#11282 `__) +- Allow :func:`mne.viz.plot_compare_evokeds` to plot eyetracking channels, and improve error handling, by `Scott Huberty`_. (`#12190 `__) +- Fix bug in :meth:`mne.Epochs.apply_function` where data was handed down incorrectly in parallel processing, by `Dominik Welke`_. (`#12206 `__) +- Remove incorrect type hints in :func:`mne.io.read_raw_neuralynx`, by `Richard Höchenberger`_. (`#12236 `__) +- Fix bug with accessing the last data sample using ``raw[:, -1]`` where an empty array was returned, by `Eric Larson`_. (`#12248 `__) +- Correctly handle temporal gaps in Neuralynx .ncs files via :func:`mne.io.read_raw_neuralynx`, by `Kristijan Armeni`_ and `Eric Larson`_. (`#12279 `__) +- Fix bug where parent directory existence was not checked properly in :meth:`mne.io.Raw.save`, by `Eric Larson`_. (`#12282 `__) +- Add ``tol`` parameter to :func:`mne.events_from_annotations` so that the user can specify the tolerance for ignoring rounding errors of event onsets when ``chunk_duration`` is not None (default is 1e-8), by `Michiru Kaneda`_ (`#12324 `__) +- Allow :meth:`mne.io.Raw.interpolate_bads` and :meth:`mne.Epochs.interpolate_bads` to work on ``ecog`` and ``seeg`` data; for ``seeg`` data a spline is fit to neighboring electrode contacts on the same shaft, by `Alex Rockhill`_ (`#12336 `__) +- Fix clicking on an axis of :func:`mne.viz.plot_evoked_topo` when multiple vertical lines ``vlines`` are used, by `Mathieu Scheltienne`_. (`#12345 `__) +- Fix bug in :meth:`mne.viz.EvokedField.set_vmax` that prevented setting the color limits of the MEG magnetic field density, by `Marijn van Vliet`_ (`#12354 `__) +- Fix faulty indexing in :func:`mne.io.read_raw_neuralynx` when picking a single channel, by `Kristijan Armeni`_. (`#12357 `__) +- Fix bug where :func:`mne.preprocessing.compute_proj_ecg` and :func:`mne.preprocessing.compute_proj_eog` could modify the default ``reject`` and ``flat`` arguments on multiple calls based on channel types present, by `Eric Larson`_. (`#12380 `__) +- Fix bad channels not handled properly in :func:`mne.stc_near_sensors` by `Alex Rockhill`_. (`#12382 `__) +- Fix bug where :func:`mne.preprocessing.regress_artifact` projection check was not specific to the channels being processed, by `Eric Larson`_. (`#12389 `__) +- Change how samples are read when using ``data_format='auto'`` in :func:`mne.io.read_raw_cnt`, by `Jacob Woessner`_. (`#12393 `__) +- Fix bugs with :class:`mne.Report` CSS where TOC items could disappear at the bottom of the page, by `Eric Larson`_. (`#12399 `__) +- In :func:`~mne.viz.plot_compare_evokeds`, actually plot GFP (not RMS amplitude) for EEG channels when global field power is requested, by `Daniel McCloy`_. (`#12410 `__)
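The new ``tol`` parameter mentioned above can be used like this (a minimal sketch; the file name is hypothetical)::

    import mne

    raw = mne.io.read_raw_fif("annotated_raw.fif")
    # Chunk each annotation into 1-second events; ``tol`` absorbs tiny
    # rounding errors in the computed onsets (default 1e-8 seconds).
    events, event_id = mne.events_from_annotations(raw, chunk_duration=1.0, tol=1e-8)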
+- Fix :ref:`tut-working-with-seeg` use of :func:`mne.stc_near_sensors` to use the :class:`mne.VolSourceEstimate` positions and not the pial surface, by `Alex Rockhill`_ (`#12436 `__) +- Fix prefiltering information management for EDF/BDF, by `Michiru Kaneda`_ (`#12441 `__) +- Fix validation of ``ch_type`` in :func:`mne.preprocessing.annotate_muscle_zscore`, by `Mathieu Scheltienne`_. (`#12444 `__) +- Fix errant redundant use of ``BIDSPath.split`` when writing split raw and epochs data, by `Eric Larson`_. (`#12451 `__) +- Disable config parser interpolation when reading BrainVision files, which allows using the percent sign as a regular character in channel units, by `Clemens Brunner`_. (`#12456 `__) +- Fix the default color of :meth:`mne.viz.Brain.add_text` to properly contrast with the figure background color, by `Marijn van Vliet`_. (`#12470 `__) +- Changed default ECoG and sEEG electrode sizes in brain plots to better reflect real-world sizes, by `Liberty Hamilton`_ (`#12474 `__) +- Fixed bugs with handling of rank in :class:`mne.decoding.CSP`, by `Eric Larson`_. (`#12476 `__) +- Fix reading segmented recordings with :func:`mne.io.read_raw_eyelink` by `Dominik Welke`_. (`#12481 `__) +- Improve compatibility with other Qt-based GUIs by handling theme icons better, by `Eric Larson`_. (`#12483 `__) +- Fix problem caused by onsets with NaN values using :func:`mne.io.read_raw_eeglab` by `Jacob Woessner`_ (`#12484 `__) +- Fix cleaning of channel names for non-VectorView or CTF datasets with whitespace or dashes in their channel names, by `Mathieu Scheltienne`_. (`#12489 `__) +- Fix bug with :meth:`mne.preprocessing.ICA.plot_sources` for ``evoked`` data where the + legend contained too many entries, by `Eric Larson`_. (`#12498 `__) +- Fix bug where using ``phase="minimum"`` in filtering functions like + :meth:`mne.io.Raw.filter` constructed a filter half the desired length with + compromised attenuation. Now ``phase="minimum"`` has the same length and comparable + suppression as ``phase="zero"``, and the old (incorrect) behavior can be achieved + with ``phase="minimum-half"``, by `Eric Larson`_. (`#12507 `__) +- Correct reading of ``info["subject_info"]["his_id"]`` in :func:`mne.io.read_raw_snirf`, by `Eric Larson`_. (`#12526 `__) +- Calling :meth:`~mne.io.Raw.compute_psd` with ``method="multitaper"`` is now expressly disallowed when ``reject_by_annotation=True`` and ``bad_*`` annotations are present (previously this was nominally allowed but resulted in ``nan`` values in the PSD). By `Daniel McCloy`_. (`#12535 `__) +- :meth:`~mne.io.Raw.compute_psd` and :func:`~mne.time_frequency.psd_array_welch` will now use FFT windows aligned to the onsets of good data spans when ``bad_*`` annotations are present. By `Daniel McCloy`_. (`#12536 `__) +- Fix bug in loading of complex/phase TFRs. By `Daniel McCloy`_. (`#12537 `__) +- Fix bug with :func:`mne.SourceSpaces.export_volume` where the ``img.affine`` was not set properly, by `Eric Larson`_. (`#12544 `__) + + +API changes by deprecation +-------------------------- + +- The default value of the ``zero_mean`` parameter of :func:`mne.time_frequency.tfr_array_morlet` will change from ``False`` to ``True`` in version 1.8, for consistency with related functions. By `Daniel McCloy`_. (`#11282 `__) +- The parameter for providing data to :func:`mne.time_frequency.tfr_array_morlet` and :func:`mne.time_frequency.tfr_array_multitaper` has been switched from ``epoch_data`` to ``data``. Only use the ``data`` parameter to avoid a warning. Changes by `Thomas Binns`_. (`#12308 `__)
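The ``phase="minimum"`` fix above changes filter construction, not the calling convention; a sketch with illustrative cutoff frequencies and a hypothetical file name::

    import mne

    raw = mne.io.read_raw_fif("sample_audvis_raw.fif", preload=True)
    # As of 1.7, minimum-phase FIR filters match the length and attenuation
    # of phase="zero"; the old half-length variant is phase="minimum-half".
    raw.filter(l_freq=1.0, h_freq=40.0, phase="minimum")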
+- Change :func:`mne.stc_near_sensors` ``surface`` default from the ``'pial'`` surface to the surface in ``src`` if ``src`` is not ``None`` in version 1.8, by `Alex Rockhill`_. (`#12382 `__) + + +New features +------------ + +- Detect bad EEG/MEG channels using the local outlier factor (LOF) algorithm in :func:`mne.preprocessing.find_bad_channels_lof`, by :newcontrib:`Velu Prabhakar Kumaravel`. (`#11234 `__) +- Inform the user about channel discrepancy between provided info, forward operator, and/or covariance matrices in :func:`mne.beamformer.make_lcmv`, by :newcontrib:`Nikolai Kapralov`. (`#12238 `__) +- Support partial pathlength factors for each wavelength in :func:`mne.preprocessing.nirs.beer_lambert_law`, by :newcontrib:`Richard Scholz`. (`#12446 `__) +- Add ``picks`` parameter to :meth:`mne.io.Raw.plot`, allowing users to select which channels to plot. This makes the raw data plotting API consistent with :meth:`mne.Epochs.plot` and :meth:`mne.Evoked.plot`, by :newcontrib:`Ivo de Jong`. (`#12467 `__) +- New class :class:`mne.time_frequency.RawTFR` and new methods :meth:`mne.io.Raw.compute_tfr`, :meth:`mne.Epochs.compute_tfr`, and :meth:`mne.Evoked.compute_tfr`. These new methods supersede the functions :func:`mne.time_frequency.tfr_morlet`, :func:`mne.time_frequency.tfr_multitaper`, and :func:`mne.time_frequency.tfr_stockwell`, which are now considered "legacy" functions. By `Daniel McCloy`_. (`#11282 `__) +- Add ability to reject :class:`mne.Epochs` using callables, by `Jacob Woessner`_. (`#12195 `__) +- Custom functions applied via :meth:`mne.io.Raw.apply_function`, :meth:`mne.Epochs.apply_function` or :meth:`mne.Evoked.apply_function` can now use ``ch_idx`` or ``ch_name`` to get access to the currently processed channel during channel-wise processing. +- :meth:`mne.Evoked.apply_function` can now also work on the full data array instead of just channel-wise, analogous to :meth:`mne.io.Raw.apply_function` and :meth:`mne.Epochs.apply_function`, by `Dominik Welke`_. (`#12206 `__) +- Allow :class:`mne.time_frequency.EpochsTFR` as input to :func:`mne.epochs.equalize_epoch_counts`, by `Carina Forster`_. (`#12207 `__) +- Speed up export to .edf in :func:`mne.export.export_raw` by using ``edfio`` instead of ``EDFlib-Python``. (`#12218 `__) +- Added a helper function :func:`mne.preprocessing.eyetracking.convert_units` to convert eyegaze data from pixel-on-screen values to radians of visual angle. Also added a helper function :func:`mne.preprocessing.eyetracking.get_screen_visual_angle` to get the visual angle that the participant screen subtends, by `Scott Huberty`_. (`#12237 `__) +- We added type hints for the return values of :func:`mne.read_evokeds` and :func:`mne.io.read_raw`. Development environments like VS Code or PyCharm will now provide more help when using these functions in your code. By `Richard Höchenberger`_ and `Eric Larson`_. (:gh:`12297`) (`#12250 `__) +- Add ``method="polyphase"`` to :meth:`mne.io.Raw.resample` and related functions to allow resampling using :func:`scipy.signal.upfirdn`, by `Eric Larson`_. (`#12268 `__) +- The package build backend was switched from ``setuptools`` to ``hatchling``. This will only affect users who build and install MNE-Python from source. By `Richard Höchenberger`_. (:gh:`12281`) (`#12269 `__) +- :meth:`mne.Annotations.to_data_frame` can now output different formats for the ``onset`` column: seconds, milliseconds, datetime objects, and timedelta objects. By `Daniel McCloy`_. (`#12289 `__)
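A minimal sketch of the LOF-based detection listed among the new features above (defaults assumed; the data must be preloaded, and ``picks`` may be needed to restrict detection to a single channel type)::

    import mne

    raw = mne.io.read_raw_fif("sample_audvis_raw.fif", preload=True)
    bads = mne.preprocessing.find_bad_channels_lof(raw)  # returns a list of names
    raw.info["bads"] = sorted(set(raw.info["bads"]) | set(bads))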
+- Add method :meth:`mne.SourceEstimate.save_as_surface` to allow saving GIFTI files from surface source estimates, by `Peter Molfese`_. (`#12309 `__) +- :class:`mne.Epochs` can now be constructed using :class:`mne.Annotations` stored in the ``raw`` object, by specifying ``events=None``. By `Alex Rockhill`_. (`#12311 `__) +- Add :meth:`~mne.SourceEstimate.savgol_filter`, :meth:`~mne.SourceEstimate.filter`, :meth:`~mne.SourceEstimate.apply_hilbert`, and :meth:`~mne.SourceEstimate.apply_function` methods to :class:`mne.SourceEstimate` and related classes, by `Hamza Abdelhedi`_. (`#12323 `__) +- Add ability to export STIM channels to EDF in :meth:`mne.io.Raw.export`, by `Clemens Brunner`_. (`#12332 `__) +- Speed up raw FIF reading when using small buffer sizes, by `Eric Larson`_. (`#12343 `__) +- Speed up :func:`mne.io.read_raw_neuralynx` on large datasets with many gaps, by `Kristijan Armeni`_. (`#12371 `__) +- Add ability to detect minimum peaks in :class:`mne.Evoked` if the data are all positive, and maximum peaks if the data are all negative. (`#12383 `__) +- Add ability to remove bad marker coils in :func:`mne.io.read_raw_kit`, by `Judy D Zhu`_. (`#12394 `__) +- Add option to pass ``image_kwargs`` to :class:`mne.Report.add_epochs` to allow adjusting e.g. ``vmin`` and ``vmax`` of the epochs image in the report, by `Sophie Herbst`_. (`#12443 `__) +- Add support for multiple raw instances in :func:`mne.preprocessing.compute_average_dev_head_t`, by `Eric Larson`_. (`#12445 `__) +- Completing PR 12453: add option to pass ``image_kwargs`` per channel type to :class:`mne.Report.add_epochs`. (`#12454 `__) +- :func:`mne.epochs.make_metadata` now accepts strings as ``tmin`` and ``tmax`` parameter values, simplifying metadata creation based on time-varying events such as responses to a stimulus, by `Richard Höchenberger`_. (`#12462 `__) +- Include date of acquisition and filter parameters in ``raw.info`` for :func:`mne.io.read_raw_neuralynx`, by `Kristijan Armeni`_. (`#12463 `__) +- Add ``physical_range="channelwise"`` to :meth:`mne.io.Raw.export` for exporting to EDF, which can improve amplitude resolution if individual channels vary greatly in their offsets, by `Clemens Brunner`_. (`#12510 `__) +- Added the ability to reorder report contents via :meth:`mne.Report.reorder` (with + helper to get contents with :meth:`mne.Report.get_contents`), by `Eric Larson`_. (`#12513 `__) +- Add ``exclude_after_unique`` option to :meth:`mne.io.read_raw_edf` and :meth:`mne.io.read_raw_bdf` to search for ``exclude`` channels after making channel names unique, by `Michiru Kaneda`_ (`#12518 `__) + + +Other changes +------------- + +- Updated the text in the preprocessing tutorial to use :meth:`mne.io.Raw.pick` instead of the legacy :meth:`mne.io.Raw.pick_types`, by :newcontrib:`btkcodedev`. (`#12326 `__) +- Clarify in the :ref:`EEG referencing tutorial ` that an average reference projector is required for inverse modeling, by :newcontrib:`Nabil Alibou` (`#12420 `__) +- Fix dead links in ``README.rst`` documentation, by :newcontrib:`Will Turner`. (`#12461 `__) +- Replace percent-style format strings with f-string format specifiers, by :newcontrib:`Hasrat Ali Arzoo`. (`#12464 `__) +- Adopted towncrier_ for changelog entries, by `Eric Larson`_. (`#12299 `__) +- Automate adding of PR number to towncrier stubs, by `Eric Larson`_. (`#12318 `__) +- Refresh code base to use Python 3.9 syntax using Ruff UP rules (pyupgrade), by `Clemens Brunner`_.
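A sketch of the string-valued ``tmin``/``tmax`` support noted above (file, event names, and IDs are hypothetical; keyword semantics per the 1.7 documentation)::

    import mne

    raw = mne.io.read_raw_fif("sample_audvis_raw.fif")
    events = mne.find_events(raw)
    metadata, events_out, event_id_out = mne.epochs.make_metadata(
        events=events,
        event_id={"stimulus": 1, "response": 2},
        row_events="stimulus",
        tmin=0.0,
        tmax="response",  # window ends at the next response event
        sfreq=raw.info["sfreq"],
    )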
(`#12358 `__) +- Move private data preparation functions for BrainVision export from ``pybv`` to ``mne``, by `Clemens Brunner`_. (`#12450 `__) +- Update the list of sensor types in docstrings, tutorials and the glossary by `Nabil Alibou`_. (`#12509 `__) + + +Authors +------- +* Alex Rockhill +* Alexander Kiefer+ +* Alexandre Gramfort +* Britta Westner +* Carina Forster +* Clemens Brunner +* Daniel McCloy +* Dominik Welke +* Eric Larson +* Erkka Heinila +* Florian Hofer +* Hamza Abdelhedi +* Hasrat Ali Arzoo+ +* Ivo de Jong+ +* Jacob Woessner +* Judy D Zhu +* Kristijan Armeni +* Liberty Hamilton +* Marijn van Vliet +* Martin Oberg+ +* Mathieu Scheltienne +* Michiru Kaneda +* Motofumi Fushimi+ +* Nabil Alibou+ +* Nikolai Kapralov+ +* Peter J. Molfese +* Richard Höchenberger +* Richard Scholz+ +* Scott Huberty +* Seyed (Yahya) Shirazi+ +* Sophie Herbst +* Stefan Appelhoff +* Thomas Donoghue +* Thomas Samuel Binns +* Tristan Stenner +* Velu Prabhakar Kumaravel+ +* Will Turner+ +* btkcodedev+ diff --git a/doc/conf.py b/doc/conf.py index 2267fcb1026..ae7ab9677fd 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -6,35 +6,36 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from datetime import datetime, timezone import faulthandler import gc -from importlib.metadata import metadata import os -from pathlib import Path import subprocess import sys import time import warnings +from datetime import datetime, timezone +from importlib.metadata import metadata +from pathlib import Path -import numpy as np import matplotlib +import numpy as np import sphinx -from sphinx.domains.changeset import versionlabels -from sphinx_gallery.sorting import FileNameSortKey, ExplicitOrder from numpydoc import docscrape +from sphinx.domains.changeset import versionlabels +from sphinx_gallery.sorting import ExplicitOrder, FileNameSortKey import mne import mne.html_templates._templates from mne.tests.test_docstring_parameters import error_ignores from mne.utils import ( - linkcode_resolve, # noqa, analysis:ignore _assert_no_instances, - sizeof_fmt, + linkcode_resolve, run_subprocess, + sizeof_fmt, ) from mne.viz import Brain # noqa +assert linkcode_resolve is not None # avoid flake warnings, used by numpydoc matplotlib.use("agg") faulthandler.enable() os.environ["_MNE_BROWSER_NO_BLOCK"] = "true" @@ -51,9 +52,8 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -curdir = os.path.dirname(__file__) -sys.path.append(os.path.abspath(os.path.join(curdir, "..", "mne"))) -sys.path.append(os.path.abspath(os.path.join(curdir, "sphinxext"))) +curpath = Path(__file__).parent.resolve(strict=True) +sys.path.append(str(curpath / "sphinxext")) # -- Project information ----------------------------------------------------- @@ -63,12 +63,12 @@ # We need to triage which date type we use so that incremental builds work # (Sphinx looks at variable changes and rewrites all files if some change) -copyright = ( +copyright = ( # noqa: A001 f'2012–{td.year}, MNE Developers. Last updated \n' # noqa: E501 '' # noqa: E501 ) if os.getenv("MNE_FULL_DATE", "false").lower() != "true": - copyright = f"2012–{td.year}, MNE Developers. Last updated locally." + copyright = f"2012–{td.year}, MNE Developers. Last updated locally." 
# noqa: A001 # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -107,6 +107,7 @@ "sphinx_gallery.gen_gallery", "sphinxcontrib.bibtex", "sphinxcontrib.youtube", + "sphinxcontrib.towncrier.ext", # homegrown "contrib_avatars", "gen_commands", @@ -123,7 +124,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ["_includes"] +exclude_patterns = ["_includes", "changes/devel"] # The suffix of source filenames. source_suffix = ".rst" @@ -149,6 +150,10 @@ copybutton_prompt_text = r">>> |\.\.\. |\$ " copybutton_prompt_is_regexp = True +# -- sphinxcontrib-towncrier configuration ----------------------------------- + +towncrier_draft_working_directory = str(curpath.parent) + # -- Intersphinx configuration ----------------------------------------------- intersphinx_mapping = { @@ -172,18 +177,11 @@ "patsy": ("https://patsy.readthedocs.io/en/latest", None), "pyvista": ("https://docs.pyvista.org", None), "imageio": ("https://imageio.readthedocs.io/en/latest", None), - "mne_realtime": ("https://mne.tools/mne-realtime", None), "picard": ("https://pierreablin.github.io/picard/", None), - "qdarkstyle": ("https://qdarkstylesheet.readthedocs.io/en/latest", None), "eeglabio": ("https://eeglabio.readthedocs.io/en/latest", None), - "dipy": ( - "https://dipy.org/documentation/1.7.0/", - "https://dipy.org/documentation/1.7.0/objects.inv/", - ), - "pooch": ("https://www.fatiando.org/pooch/latest/", None), + "dipy": ("https://docs.dipy.org/stable", None), "pybv": ("https://pybv.readthedocs.io/en/latest/", None), "pyqtgraph": ("https://pyqtgraph.readthedocs.io/en/latest/", None), - "openmeeg": ("https://openmeeg.github.io", None), } @@ -233,7 +231,11 @@ "EvokedArray": "mne.EvokedArray", "BiHemiLabel": "mne.BiHemiLabel", "AverageTFR": "mne.time_frequency.AverageTFR", + "AverageTFRArray": "mne.time_frequency.AverageTFRArray", "EpochsTFR": "mne.time_frequency.EpochsTFR", + "EpochsTFRArray": "mne.time_frequency.EpochsTFRArray", + "RawTFR": "mne.time_frequency.RawTFR", + "RawTFRArray": "mne.time_frequency.RawTFRArray", "Raw": "mne.io.Raw", "ICA": "mne.preprocessing.ICA", "Covariance": "mne.Covariance", @@ -273,6 +275,30 @@ "EOGRegression": "mne.preprocessing.EOGRegression", "Spectrum": "mne.time_frequency.Spectrum", "EpochsSpectrum": "mne.time_frequency.EpochsSpectrum", + "EpochsFIF": "mne.Epochs", + "EpochsEEGLAB": "mne.Epochs", + "EpochsKIT": "mne.Epochs", + "RawBOXY": "mne.io.Raw", + "RawBrainVision": "mne.io.Raw", + "RawBTi": "mne.io.Raw", + "RawCTF": "mne.io.Raw", + "RawCurry": "mne.io.Raw", + "RawEDF": "mne.io.Raw", + "RawEEGLAB": "mne.io.Raw", + "RawEGI": "mne.io.Raw", + "RawEximia": "mne.io.Raw", + "RawEyelink": "mne.io.Raw", + "RawFIL": "mne.io.Raw", + "RawGDF": "mne.io.Raw", + "RawHitachi": "mne.io.Raw", + "RawKIT": "mne.io.Raw", + "RawNedf": "mne.io.Raw", + "RawNeuralynx": "mne.io.Raw", + "RawNihon": "mne.io.Raw", + "RawNIRX": "mne.io.Raw", + "RawPersyst": "mne.io.Raw", + "RawSNIRF": "mne.io.Raw", + "Calibration": "mne.preprocessing.eyetracking.Calibration", # dipy "dipy.align.AffineMap": "dipy.align.imaffine.AffineMap", "dipy.align.DiffeomorphicMap": "dipy.align.imwarp.DiffeomorphicMap", @@ -371,34 +397,17 @@ "n_moments", "n_patterns", "n_new_events", - # Undocumented (on purpose) - "RawKIT", - "RawEximia", - "RawEGI", - "RawEEGLAB", - "RawEDF", 
- "RawCTF", - "RawBTi", - "RawBrainVision", - "RawCurry", - "RawNIRX", - "RawGDF", - "RawSNIRF", - "RawBOXY", - "RawPersyst", - "RawNihon", - "RawNedf", - "RawHitachi", - "RawFIL", - "RawEyelink", # sklearn subclasses "mapping", "to", "any", # unlinkable "CoregistrationUI", - "IntracranialElectrodeLocator", "mne_qt_browser.figure.MNEQtBrowser", + # pooch, since its website is unreliable and users will rarely need the links + "pooch.Unzip", + "pooch.Untar", + "pooch.HTTPDownloader", } numpydoc_validate = True numpydoc_validation_checks = {"all"} | set(error_ignores) @@ -442,16 +451,18 @@ # -- Sphinx-gallery configuration -------------------------------------------- -class Resetter(object): +class Resetter: """Simple class to make the str(obj) static for Sphinx build env hash.""" def __init__(self): self.t0 = time.time() def __repr__(self): + """Make a stable repr.""" return f"<{self.__class__.__name__}>" def __call__(self, gallery_conf, fname, when): + """Do the reset.""" import matplotlib.pyplot as plt try: @@ -479,6 +490,8 @@ def __call__(self, gallery_conf, fname, when): plt.ioff() plt.rcParams["animation.embed_limit"] = 40.0 plt.rcParams["figure.raise_window"] = False + # https://github.com/sphinx-gallery/sphinx-gallery/pull/1243#issue-2043332860 + plt.rcParams["animation.html"] = "html5" # neo holds on to an exception, which in turn holds a stack frame, # which will keep alive the global vars during SG execution try: @@ -674,7 +687,9 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): if what in ("attribute", "method"): size = os.path.getsize( os.path.join( - os.path.dirname(__file__), "generated", "%s.examples" % (name,) + os.path.dirname(__file__), + "generated", + f"{name}.examples", ) ) if size > 0: @@ -685,11 +700,7 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): .. minigallery:: {1} -""".format( - name.split(".")[-1], name - ).split( - "\n" - ) +""".format(name.split(".")[-1], name).split("\n") # -- Other extension configuration ------------------------------------------- @@ -725,8 +736,8 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): "https://doi.org/10.3109/", # www.tandfonline.com "https://www.researchgate.net/profile/", "https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html", - "https://scholar.google.com/scholar?cites=12188330066413208874&as_ylo=2014", - "https://scholar.google.com/scholar?cites=1521584321377182930&as_ylo=2013", + r"https://scholar.google.com/scholar\?cites=12188330066413208874&as_ylo=2014", + r"https://scholar.google.com/scholar\?cites=1521584321377182930&as_ylo=2013", # 500 server error "https://openwetware.org/wiki/Beauchamp:FreeSurfer", # 503 Server error @@ -743,6 +754,8 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # Too slow "https://speakerdeck.com/dengemann/", "https://www.dtu.dk/english/service/phonebook/person", + # SSL problems sometimes + "http://ilabs.washington.edu", ] linkcheck_anchors = False # saves a bit of time linkcheck_timeout = 15 # some can be quite slow @@ -761,6 +774,7 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # -- Nitpicky ---------------------------------------------------------------- nitpicky = True +show_warning_types = True nitpick_ignore = [ ("py:class", "None. Remove all items from D."), ("py:class", "a set-like object providing a view on D's items"), @@ -778,14 +792,22 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): ("py:class", "None. 
Remove all items from od."), ] nitpick_ignore_regex = [ - ("py:.*", r"mne\.io\.BaseRaw.*"), - ("py:.*", r"mne\.BaseEpochs.*"), + # Classes whose methods we purposefully do not document + ("py:.*", r"mne\.io\.BaseRaw.*"), # use mne.io.Raw + ("py:.*", r"mne\.BaseEpochs.*"), # use mne.Epochs + # Type hints for undocumented types + ("py:.*", r"mne\.io\..*\.Raw.*"), # RawEDF etc. + ("py:.*", r"mne\.epochs\.EpochsFIF.*"), + ("py:.*", r"mne\.io\..*\.Epochs.*"), # EpochsKIT etc. ( "py:obj", "(filename|metadata|proj|times|tmax|tmin|annotations|ch_names|compensation_grade|filenames|first_samp|first_time|last_samp|n_times|proj|times|tmax|tmin)", ), # noqa: E501 ] -suppress_warnings = ["image.nonlocal_uri"] # we intentionally link outside +suppress_warnings = [ + "image.nonlocal_uri", # we intentionally link outside + "config.cache", # our rebuild is okay +] # -- Sphinx hacks / overrides ------------------------------------------------ @@ -802,7 +824,7 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -switcher_version_match = "dev" if release.endswith("dev0") else version +switcher_version_match = "dev" if ".dev" in version else version html_theme_options = { "icon_links": [ dict( @@ -833,7 +855,7 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): ), ], "icon_links_label": "External Links", # for screen reader - "use_edit_page_button": True, + "use_edit_page_button": False, "navigation_with_keys": False, "show_toc_level": 1, "article_header_start": [], # disable breadcrumbs @@ -1182,21 +1204,21 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): "carousel": [ dict( title="Source Estimation", - text="Distributed, sparse, mixed-norm, beam\u00ADformers, dipole fitting, and more.", # noqa E501 + text="Distributed, sparse, mixed-norm, beam\u00adformers, dipole fitting, and more.", # noqa E501 url="auto_tutorials/inverse/index.html", img="sphx_glr_30_mne_dspm_loreta_008.gif", alt="dSPM", ), dict( title="Machine Learning", - text="Advanced decoding models including time general\u00ADiza\u00ADtion.", # noqa E501 + text="Advanced decoding models including time general\u00adiza\u00adtion.", # noqa E501 url="auto_tutorials/machine-learning/50_decoding.html", img="sphx_glr_50_decoding_006.png", alt="Decoding", ), dict( title="Encoding Models", - text="Receptive field estima\u00ADtion with optional smooth\u00ADness priors.", # noqa E501 + text="Receptive field estima\u00adtion with optional smooth\u00adness priors.", # noqa E501 url="auto_tutorials/machine-learning/30_strf.html", img="sphx_glr_30_strf_001.png", alt="STRF", @@ -1210,7 +1232,7 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): ), dict( title="Connectivity", - text="All-to-all spectral and effective connec\u00ADtivity measures.", # noqa E501 + text="All-to-all spectral and effective connec\u00adtivity measures.", # noqa E501 url="https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_label_connectivity.html", # noqa E501 img="https://mne.tools/mne-connectivity/stable/_images/sphx_glr_mne_inverse_label_connectivity_001.png", # noqa E501 alt="Connectivity", @@ -1304,6 +1326,7 @@ def reset_warnings(gallery_conf, fname): for key in ( "invalid version and will not be supported", # pyxdf "distutils Version classes are deprecated", # seaborn and neo + "is_categorical_dtype is deprecated", # seaborn 
"`np.object` is a deprecated alias for the builtin `object`", # pyxdf # nilearn, should be fixed in > 0.9.1 "In future, it will be an error for 'np.bool_' scalars to", @@ -1326,8 +1349,19 @@ def reset_warnings(gallery_conf, fname): # nilearn "pkg_resources is deprecated as an API", r"The .* was deprecated in Matplotlib 3\.7", - # scipy - r"scipy.signal.morlet2 is deprecated in SciPy 1\.12", + # Matplotlib->tz + r"datetime\.datetime\.utcfromtimestamp", + # joblib + r"ast\.Num is deprecated", + r"Attribute n is deprecated and will be removed in Python 3\.14", + # numpydoc + r"ast\.NameConstant is deprecated and will be removed in Python 3\.14", + # pooch + r"Python 3\.14 will, by default, filter extracted tar archives.*", + # seaborn + r"DataFrameGroupBy\.apply operated on the grouping columns.*", + # pandas + r"\nPyarrow will become a required dependency of pandas.*", ): warnings.filterwarnings( # deal with other modules having bad imports "ignore", message=".*%s.*" % key, category=DeprecationWarning @@ -1366,6 +1400,7 @@ def reset_warnings(gallery_conf, fname): r"iteritems is deprecated.*Use \.items instead\.", "is_categorical_dtype is deprecated.*", "The default of observed=False.*", + "When grouping with a length-1 list-like.*", ): warnings.filterwarnings( "ignore", @@ -1732,7 +1767,7 @@ def reset_warnings(gallery_conf, fname): def check_existing_redirect(path): """Make sure existing HTML files are redirects, before overwriting.""" if path.is_file(): - with open(path, "r") as fid: + with open(path) as fid: for _ in range(8): next(fid) line = fid.readline() diff --git a/doc/development/contributing.rst b/doc/development/contributing.rst index 4a9e7f52d0e..04fa49e924b 100644 --- a/doc/development/contributing.rst +++ b/doc/development/contributing.rst @@ -93,8 +93,8 @@ Setting up your local development environment Configuring git ~~~~~~~~~~~~~~~ -.. note:: Git GUI alternative - :class: sidebar +.. admonition:: Git GUI alternative + :class: sidebar note `GitHub desktop`_ is a GUI alternative to command line git that some users appreciate; it is available for |windows| Windows and |apple| MacOS. @@ -230,8 +230,8 @@ of how that structure is set up is given here: Creating the virtual environment ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. note:: Supported Python environments - :class: sidebar +.. admonition:: Supported Python environments + :class: sidebar note We strongly recommend the `Anaconda`_ or `Miniconda`_ environment managers for Python. Other setups are possible but are not officially supported by @@ -243,7 +243,7 @@ Creating the virtual environment These instructions will set up a Python environment that is separated from your system-level Python and any other managed Python environments on your computer. This lets you switch between different versions of Python (MNE-Python requires -version 3.8 or higher) and also switch between the stable and development +version 3.9 or higher) and also switch between the stable and development versions of MNE-Python (so you can, for example, use the same computer to analyze your data with the stable release, and also work with the latest development version to fix bugs or add new features). 
Even if you've already @@ -304,11 +304,11 @@ be reflected the next time you open a Python interpreter and ``import mne`` Finally, we'll add a few dependencies that are not needed for running MNE-Python, but are needed for locally running our test suite:: - $ pip install -e .[test] + $ pip install -e ".[test]" And for building our documentation:: - $ pip install -e .[doc] + $ pip install -e ".[doc]" $ conda install graphviz .. note:: @@ -375,7 +375,7 @@ feature, you should first synchronize your local ``main`` branch with the $ git merge upstream/main # synchronize local main branch with remote upstream main branch $ git checkout -b new-feature-x # create local branch "new-feature-x" and check it out -.. note:: Alternative +.. tip:: :class: sidebar You can save some typing by using ``git pull upstream/main`` to replace @@ -591,42 +591,54 @@ Describe your changes in the changelog -------------------------------------- Include in your changeset a brief description of the change in the -:ref:`changelog ` (:file:`doc/changes/devel.rst`; this can be -skipped for very minor changes like correcting typos in the documentation). - -There are different sections of the changelog for each release, and separate -**subsections for bugfixes, new features, and changes to the public API.** -Please be sure to add your entry to the appropriate subsection. - -The styling and positioning of the entry depends on whether you are a -first-time contributor or have been mentioned in the changelog before. - -First-time contributors -""""""""""""""""""""""" - -Welcome to MNE-Python! We're very happy to have you here. 🤗 And to ensure you -get proper credit for your work, please add a changelog entry with the -following pattern **at the top** of the respective subsection (bugs, -enhancements, etc.): - -.. code-block:: rst - - - Bugs - ---- - - - Short description of the changes (:gh:`0000` by :newcontrib:`Firstname Lastname`) - - - ... - -where ``0000`` must be replaced with the respective GitHub pull request (PR) -number, and ``Firstname Lastname`` must be replaced with your full name. - -It is usually best to wait to add a line to the changelog until your PR is -finalized, to avoid merge conflicts (since the changelog is updated with -almost every PR). - -Lastly, make sure that your name is included in the list of authors in +:ref:`changelog ` using towncrier_ format, which aggregates small, +properly-named ``.rst`` files to create a change log. This can be +skipped for very minor changes like correcting typos in the documentation. + +There are six separate sections for changes, based on change type. +To add a changelog entry to a given section, name it as +:file:`doc/changes/devel/<PR number>.<type>.rst`. The types are: + +notable + For overarching changes, e.g., adding type hints package-wide. These are rare. +dependency + For changes to dependencies, e.g., adding a new dependency or changing + the minimum version of an existing dependency. +bugfix + For bug fixes. Can change code behavior with no deprecation period. +apichange + Code behavior changes that require a deprecation period. +newfeature + For new features. +other + For changes that don't fit into any of the above categories, e.g., + internal refactorings. +
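To check locally how the entries will be aggregated (a sketch; this assumes
towncrier_ is installed in your environment, and depending on the configuration
you may also need to pass ``--version`` explicitly), you can render a draft
changelog from the repository root::

    $ towncrier build --draft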
+For example, for an enhancement PR with number 12345, the changelog entry should be
+added as a new file :file:`doc/changes/devel/12345.newfeature.rst`. The file should
+contain:
+
+1. A brief description of the change, typically in a single line of one or two
+   sentences.
+2. reST links to **public** API endpoints like functions (``:func:``),
+   classes (``:class:``), and methods (``:meth:``). If changes are only internal
+   to private functions/attributes, mention internal refactoring rather than name
+   the private attributes changed.
+3. Author credit. If you are a new contributor (we're very happy to have you here! 🤗),
+   you should use the ``:newcontrib:`` reST role, whereas previous contributors should
+   use a standard reST link to their name. For example, a new contributor could write:
+
+   .. code-block:: rst
+
+       Short description of the changes, by :newcontrib:`Firstname Lastname`.
+
+   And a previous contributor could write:
+
+   .. code-block:: rst
+
+       Short description of the changes, by `Firstname Lastname`_.
+
+Make sure that your name is included in the list of authors in :file:`doc/changes/names.inc`, otherwise the documentation build will fail. To add an author name, append a line with the following pattern (note how the syntax is different from that used in the changelog): @@ -638,27 +650,13 @@ how the syntax is different from that used in the changelog): Many contributors opt to link to their GitHub profile that way. Have a look at the existing entries in the file to get some inspiration. -Recurring contributors -"""""""""""""""""""""" - -The changelog entry should follow the following patterns: - -.. code-block:: rst - - - Short description of the changes from one contributor (:gh:`0000` by `Contributor Name`_) - - Short description of the changes from several contributors (:gh:`0000` by `Contributor Name`_, `Second Contributor`_, and `Third Contributor`_) - -where ``0000`` must be replaced with the respective GitHub pull request (PR) -number. Mind the Oxford comma in the case of multiple contributors. - Sometimes, changes that shall appear as a single changelog entry are spread out -across multiple PRs. In this case, name all relevant PRs, separated by -commas: +across multiple PRs. In this case, edit the existing towncrier file for the relevant +change, and append additional PR numbers in parentheticals with the ``:gh:`` role like: .. code-block:: rst - - Short description of the changes from one contributor in multiple PRs (:gh:`0000`, :gh:`1111` by `Contributor Name`_) - - Short description of the changes from several contributors in multiple PRs (:gh:`0000`, :gh:`1111` by `Contributor Name`_, `Second Contributor`_, and `Third Contributor`_) + Short description of the changes, by `Firstname Lastname`_. (:gh:`12346`) Test locally before opening pull requests (PRs) ----------------------------------------------- @@ -867,8 +865,8 @@ to both visualization functions and tutorials/examples. Running the test suite ~~~~~~~~~~~~~~~~~~~~~~ -.. note:: pytest flags - :class: sidebar +.. admonition:: pytest flags + :class: sidebar tip The ``-x`` flag exits the pytest run when any test fails; this can speed up debugging when running all tests in a file or module. diff --git a/doc/development/index.rst b/doc/development/index.rst index 1bdc5322f36..98fc28f8e7f 100644 --- a/doc/development/index.rst +++ b/doc/development/index.rst @@ -24,7 +24,7 @@ experience. .. _`opening an issue`: https://github.com/mne-tools/mne-python/issues/new/choose .. _`MNE Forum`: https://mne.discourse.group .. _`code of conduct`: https://github.com/mne-tools/.github/blob/main/CODE_OF_CONDUCT.md -.. _`contributing guide`: https://mne.tools/dev/install/contributing.html +.. _`contributing guide`: https://mne.tools/dev/development/contributing.html ..
toctree:: :hidden: diff --git a/doc/development/roadmap.rst b/doc/development/roadmap.rst index ced61c7e4a1..defd4eac5cc 100644 --- a/doc/development/roadmap.rst +++ b/doc/development/roadmap.rst @@ -6,8 +6,6 @@ MNE-Python. These are goals that require substantial effort and/or API design considerations. Some of these may be suitable for Google Summer of Code projects, while others require more extensive work. -.. contents:: Page contents - :local: Open ---- diff --git a/doc/development/whats_new.rst b/doc/development/whats_new.rst index 0e8c96ebe4d..920194e7fb2 100644 --- a/doc/development/whats_new.rst +++ b/doc/development/whats_new.rst @@ -8,6 +8,7 @@ Changes for each version of MNE-Python are listed below. .. toctree:: :maxdepth: 1 + ../changes/v1.7.rst ../changes/v1.6.rst ../changes/v1.5.rst ../changes/v1.4.rst diff --git a/doc/documentation/cited.rst b/doc/documentation/cited.rst index 7654cf3fd40..31c19589b16 100644 --- a/doc/documentation/cited.rst +++ b/doc/documentation/cited.rst @@ -3,7 +3,7 @@ Papers citing MNE-Python ======================== -Estimates provided by Google Scholar as of 14 August 2023: +Estimates provided by Google Scholar as of 19 April 2024: -- `MNE (1540) `_ -- `MNE-Python (2040) `_ +- `MNE (1730) `_ +- `MNE-Python (2570) `_ diff --git a/doc/documentation/datasets.rst b/doc/documentation/datasets.rst index 063d06da363..70da39cccd8 100644 --- a/doc/documentation/datasets.rst +++ b/doc/documentation/datasets.rst @@ -516,7 +516,7 @@ Contains both EEG (EGI) and eye-tracking (ASCII format) data recorded from a pupillary light reflex experiment, stored in separate files. 1 participant fixated on the screen while short light flashes appeared. Event onsets were recorded by a photodiode attached to the screen and were sent to both the EEG and eye-tracking -systems. +systems. .. topic:: Examples diff --git a/doc/documentation/glossary.rst b/doc/documentation/glossary.rst index 91b8922e8c6..89a5c477a75 100644 --- a/doc/documentation/glossary.rst +++ b/doc/documentation/glossary.rst @@ -41,15 +41,15 @@ general neuroimaging concepts. If you think a term is missing, please consider Channels refer to MEG sensors, EEG electrodes or other sensors such as EOG, ECG, sEEG, ECoG, etc. Channels usually have a type (such as gradiometer), and a unit (such as T/m) used e.g. for - plotting. See also :term:`data channels`. + plotting. See also :term:`data channels` and :term:`non-data channels`. data channels Many functions in MNE-Python operate on "data channels" by default. These are channels that contain electrophysiological data from the brain, as opposed to other channel types such as EOG, ECG, stimulus/trigger, - or acquisition system status data. The set of channels considered - "data channels" in MNE contains the following types (together with scale - factors for plotting): + or acquisition system status data (see :term:`non-data channels`). + The set of channels considered "data channels" in MNE contains the + following types (together with scale factors for plotting): .. mne:: data channels list @@ -287,6 +287,13 @@ general neuroimaging concepts. If you think a term is missing, please consider data into a common space for statistical analysis. See :ref:`ch_morph` for more details. + non-data channels + All types of channels other than :term:`data channels`. + The set of channels considered "non-data channels" in MNE contains the + following types (together with scale factors for plotting): + + .. 
mne:: non-data channels list + OPM optically pumped magnetometer An optically pumped magnetometer (OPM) is a type of magnetometer @@ -350,6 +357,10 @@ general neuroimaging concepts. If you think a term is missing, please consider A selection is a set of picked channels (for example, all sensors falling within a :term:`region of interest`). + sensor types + All the sensors handled by MNE-Python can be divided into two categories: + :term:`data channels` and :term:`non-data channels`. + STC source estimate source time course diff --git a/doc/install/installers.rst b/doc/install/installers.rst index 2d1d75323b8..26199483d60 100644 --- a/doc/install/installers.rst +++ b/doc/install/installers.rst @@ -15,7 +15,7 @@ Got any questions? Let us know on the `MNE Forum`_! :class-content: text-center :name: linux-installers - .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.5.1/MNE-Python-1.5.1_0-Linux.sh + .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.6.1/MNE-Python-1.6.1_0-Linux.sh :ref-type: ref :color: primary :shadow: @@ -29,14 +29,14 @@ Got any questions? Let us know on the `MNE Forum`_! .. code-block:: console - $ sh ./MNE-Python-1.5.1_0-Linux.sh + $ sh ./MNE-Python-1.6.1_0-Linux.sh .. tab-item:: macOS (Intel) :class-content: text-center :name: macos-intel-installers - .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.5.1/MNE-Python-1.5.1_0-macOS_Intel.pkg + .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.6.1/MNE-Python-1.6.1_0-macOS_Intel.pkg :ref-type: ref :color: primary :shadow: @@ -52,7 +52,7 @@ Got any questions? Let us know on the `MNE Forum`_! :class-content: text-center :name: macos-apple-installers - .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.5.1/MNE-Python-1.5.1_0-macOS_M1.pkg + .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.6.1/MNE-Python-1.6.1_0-macOS_M1.pkg :ref-type: ref :color: primary :shadow: @@ -68,7 +68,7 @@ Got any questions? Let us know on the `MNE Forum`_! :class-content: text-center :name: windows-installers - .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.5.1/MNE-Python-1.5.1_0-Windows.exe + .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.6.1/MNE-Python-1.6.1_0-Windows.exe :ref-type: ref :color: primary :shadow: @@ -120,7 +120,7 @@ information, including a line that will read something like: .. code-block:: - Using Python: /some/directory/mne-python_1.5.1_0/bin/python + Using Python: /some/directory/mne-python_1.6.1_0/bin/python This path is what you need to enter in VS Code when selecting the Python interpreter. diff --git a/doc/install/manual_install.rst b/doc/install/manual_install.rst index 57932648bf6..ab7ad074e51 100644 --- a/doc/install/manual_install.rst +++ b/doc/install/manual_install.rst @@ -15,19 +15,20 @@ Installing MNE-Python with all dependencies ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ If you use Anaconda, we suggest installing MNE-Python into its own ``conda`` environment. -The dependency stack is large and may take a long time (several tens of -minutes) to resolve on some systems via the default ``conda`` solver. We -therefore highly recommend using the new `libmamba `__ -solver instead, which is **much** faster. To permanently change to this solver, -you can set ``CONDA_SOLVER=libmamba`` in your environment or run -``conda config --set solver libmamba``. 
Below we just use ``--solver`` in each command. +First, please ensure you're using a recent version of ``conda``. Run in your terminal: -Run in your terminal: +.. code-block:: console + + $ conda update --name=base conda # update conda + $ conda --version + +The installed ``conda`` version should be ``23.10.0`` or newer. + +Now, you can install MNE-Python: .. code-block:: console - $ conda install --channel=conda-forge --name=base conda-libmamba-solver - $ conda create --solver=libmamba --override-channels --channel=conda-forge --name=mne mne + $ conda create --channel=conda-forge --strict-channel-priority --name=mne mne This will create a new ``conda`` environment called ``mne`` (you can adjust this by passing a different name via ``--name``) and install all @@ -50,7 +51,7 @@ or via :code:`conda`: .. code-block:: console - $ conda create --override-channels --channel=conda-forge --name=mne mne-base + $ conda create --channel=conda-forge --strict-channel-priority --name=mne mne-base This will create a new ``conda`` environment called ``mne`` (you can adjust this by passing a different name via ``--name``). @@ -67,7 +68,7 @@ others), you should run via :code:`pip`: .. code-block:: console - $ pip install mne[hdf5] + $ pip install "mne[hdf5]" or via :code:`conda`: diff --git a/doc/install/mne_tools_suite.rst b/doc/install/mne_tools_suite.rst index 03b65671826..64b3933ea0f 100644 --- a/doc/install/mne_tools_suite.rst +++ b/doc/install/mne_tools_suite.rst @@ -62,11 +62,11 @@ MNE-Python, including packages for: - automatic multi-dipole localization and uncertainty quantification with the Bayesian algorithm SESAME (`sesameeg`_) - GLM and group level analysis of near-infrared spectroscopy data (`MNE-NIRS`_) -- high-level EEG Python library for all kinds of EEG inverse solutions (`invertmeeg`_) - All-Resolutions Inference (ARI) for statistically valid circular inference and effect localization (`MNE-ARI`_) - real-time analysis (`MNE-Realtime`_) - non-parametric sequential analyses and adaptive sample size determination (`niseq`_) +- a graphical user interface for multi-subject MEG/EEG analysis with plugin support (`Meggie`_) What should I install? ^^^^^^^^^^^^^^^^^^^^^^ @@ -100,7 +100,6 @@ Help with installation is available through the `MNE Forum`_. See the .. _MNELAB: https://github.com/cbrnr/mnelab .. _autoreject: https://autoreject.github.io/ .. _alphaCSC: https://alphacsc.github.io/ -.. _picard: https://pierreablin.github.io/picard/ .. _pactools: https://pactools.github.io/ .. _rsa: https://github.com/wmvanvliet/mne-rsa .. _microstate: https://github.com/wmvanvliet/mne_microstates @@ -112,5 +111,6 @@ Help with installation is available through the `MNE Forum`_. See the .. _invertmeeg: https://github.com/LukeTheHecker/invert .. _MNE-ARI: https://github.com/john-veillette/mne_ari .. _niseq: https://github.com/john-veillette/niseq +.. _Meggie: https://github.com/cibr-jyu/meggie .. include:: ../links.inc diff --git a/doc/install/updating.rst b/doc/install/updating.rst index 0737ee7c6a0..c946d5e496e 100644 --- a/doc/install/updating.rst +++ b/doc/install/updating.rst @@ -78,8 +78,8 @@ Sometimes, new features or bugfixes become available that are important to your research and you just can't wait for the next official release of MNE-Python to start taking advantage of them. In such cases, you can use ``pip`` to install the *development version* of MNE-Python. Ensure to activate the MNE conda -environment first by running ``conda activate name_of_environment``. 
+environment first by running ``conda activate mne``. .. code-block:: console - $ pip install -U --no-deps git+https://github.com/mne-tools/mne-python@main + $ pip install -U --no-deps https://github.com/mne-tools/mne-python/archive/refs/heads/main.zip diff --git a/doc/links.inc b/doc/links.inc index 52dfec9b068..27e61c850bc 100644 --- a/doc/links.inc +++ b/doc/links.inc @@ -26,6 +26,7 @@ .. _`MNE-ICAlabel`: https://github.com/mne-tools/mne-icalabel .. _`MNE-Connectivity`: https://github.com/mne-tools/mne-connectivity .. _`MNE-NIRS`: https://github.com/mne-tools/mne-nirs +.. _PICARD: https://pierreablin.github.io/picard/ .. _OpenMEEG: https://openmeeg.github.io .. _openneuro-py: https://pypi.org/project/openneuro-py .. _EOSS2: https://chanzuckerberg.com/eoss/proposals/improving-usability-of-core-neuroscience-analysis-tools-with-mne-python @@ -95,6 +96,7 @@ .. _PIL: https://pypi.python.org/pypi/PIL .. _tqdm: https://tqdm.github.io/ .. _pooch: https://www.fatiando.org/pooch/latest/ +.. _towncrier: https://towncrier.readthedocs.io/ .. python editors @@ -107,7 +109,7 @@ .. _anaconda: https://www.anaconda.com/products/individual .. _miniconda: https://conda.io/en/latest/miniconda.html .. _miniforge: https://github.com/conda-forge/miniforge -.. _mambaforge: https://mamba.readthedocs.io/en/latest/mamba-installation.html#mamba-install +.. _mambaforge: https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html .. _installation instructions for Anaconda: http://docs.continuum.io/anaconda/install .. _installation instructions for Miniconda: https://conda.io/projects/conda/en/latest/user-guide/install/index.html .. _Anaconda troubleshooting guide: http://conda.pydata.org/docs/troubleshooting.html diff --git a/doc/references.bib b/doc/references.bib index 9263379209a..7a992b2c1fa 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -2450,6 +2450,37 @@ @article{TierneyEtAl2022 author = {Tierney, Tim M. and Mellor, Stephanie nd O'Neill, George C. and Timms, Ryan C. and Barnes, Gareth R.}, } +@article{KumaravelEtAl2022, + doi = {10.3390/s22197314}, + url = {https://doi.org/10.3390/s22197314}, + year = {2022}, + month = sep, + publisher = {{MDPI} {AG}}, + volume = {22}, + number = {19}, + pages = {7314}, + author = {Velu Prabhakar Kumaravel and Marco Buiatti and Eugenio Parise and Elisabetta Farella}, + title = {Adaptable and Robust {EEG} Bad Channel Detection Using Local Outlier Factor ({LOF})}, + journal = {Sensors} +} + +@article{BreunigEtAl2000, + author = {Breunig, Markus M. and Kriegel, Hans-Peter and Ng, Raymond T. and Sander, J\"{o}rg}, + title = {LOF: Identifying Density-Based Local Outliers}, + year = {2000}, + issue_date = {June 2000}, + publisher = {Association for Computing Machinery}, + address = {New York, NY, USA}, + volume = {29}, + number = {2}, + url = {https://doi.org/10.1145/335191.335388}, + doi = {10.1145/335191.335388}, + journal = {SIGMOD Rec.}, + month = {may}, + pages = {93–104}, + numpages = {12}, + keywords = {outlier detection, database mining} +} @article{OyamaEtAl2015, title = {Dry phantom for magnetoencephalography —{Configuration}, calibration, and contribution}, diff --git a/doc/sphinxext/contrib_avatars.py b/doc/sphinxext/contrib_avatars.py index bbfd17de7d3..04583ac4c77 100644 --- a/doc/sphinxext/contrib_avatars.py +++ b/doc/sphinxext/contrib_avatars.py @@ -1,33 +1,41 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
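# The rendering below is gated behind an environment variable so that cheap
# doc builds can skip the Selenium step entirely. A minimal sketch of this
# opt-in pattern (the names here are hypothetical, for illustration only):
#
#     import os
#
#     if os.getenv("MY_EXPENSIVE_STEP", "false").lower() != "true":
#         body = "<p>placeholder</p>"  # cheap static fallback
#     else:
#         body = run_expensive_renderer()  # e.g., drive a headless browser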
+import os from pathlib import Path -from selenium import webdriver -from selenium.webdriver.common.by import By -from selenium.webdriver.support.ui import WebDriverWait -from selenium.common.exceptions import WebDriverException - def generate_contrib_avatars(app, config): """Render a template webpage with avatars generated by JS and a GitHub API call.""" root = Path(app.srcdir) infile = root / "sphinxext" / "_avatar_template.html" outfile = root / "_templates" / "avatars.html" - try: - options = webdriver.ChromeOptions() - options.add_argument("--headless=new") - driver = webdriver.Chrome(options=options) - except WebDriverException: - options = webdriver.FirefoxOptions() - options.add_argument("--headless=new") - driver = webdriver.Firefox(options=options) - driver.get(f"file://{infile}") - wait = WebDriverWait(driver, 20) - wait.until(lambda d: d.find_element(by=By.ID, value="contributor-avatars")) - body = driver.find_element(by=By.TAG_NAME, value="body").get_attribute("innerHTML") + if os.getenv("MNE_ADD_CONTRIBUTOR_IMAGE", "false").lower() != "true": + body = """\ +
<div id="contributor-avatars">
+<p>Contributor avatars will appear here in full doc builds. Set \
+MNE_ADD_CONTRIBUTOR_IMAGE=true in your environment to generate it.</p>
+</div>
+
""" + else: + from selenium import webdriver + from selenium.common.exceptions import WebDriverException + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + + try: + options = webdriver.ChromeOptions() + options.add_argument("--headless=new") + driver = webdriver.Chrome(options=options) + except WebDriverException: + options = webdriver.FirefoxOptions() + options.add_argument("-headless") + driver = webdriver.Firefox(options=options) + driver.get(f"file://{infile}") + wait = WebDriverWait(driver, 20) + wait.until(lambda d: d.find_element(by=By.ID, value="contributor-avatars")) + body = driver.find_element(by=By.TAG_NAME, value="body").get_attribute( + "innerHTML" + ) + driver.quit() with open(outfile, "w") as fid: fid.write(body) - driver.quit() def setup(app): diff --git a/doc/sphinxext/flow_diagram.py b/doc/sphinxext/flow_diagram.py index ba374c60f88..cefe6713a7d 100644 --- a/doc/sphinxext/flow_diagram.py +++ b/doc/sphinxext/flow_diagram.py @@ -12,18 +12,14 @@ sensor_color = "#7bbeca" source_color = "#ff6347" -legend = """ -< +legend = f""" +< - - -
+
Sensor (M/EEG) space
+
Source (brain) space
>""" % ( - edge_size, - sensor_color, - source_color, -) +
>""" legend = "".join(legend.split("\n")) nodes = dict( diff --git a/doc/sphinxext/gen_commands.py b/doc/sphinxext/gen_commands.py index 5fa9cd7418a..e50e243eb48 100644 --- a/doc/sphinxext/gen_commands.py +++ b/doc/sphinxext/gen_commands.py @@ -2,10 +2,9 @@ # Copyright the MNE-Python contributors. import glob from importlib import import_module -import os from pathlib import Path -from mne.utils import _replace_md5, ArgvSetter +from mne.utils import ArgvSetter, _replace_md5 def setup(app): diff --git a/doc/sphinxext/gen_names.py b/doc/sphinxext/gen_names.py index 1871ae0068c..fd667ec0951 100644 --- a/doc/sphinxext/gen_names.py +++ b/doc/sphinxext/gen_names.py @@ -25,7 +25,7 @@ def generate_name_links_rst(app=None): ) with open(out_fname, "w", encoding="utf8") as fout: fout.write(":orphan:\n\n") - with open(names_path, "r") as fin: + with open(names_path) as fin: for line in fin: if line.startswith(".. _"): fout.write(f"- {line[4:]}") diff --git a/doc/sphinxext/gh_substitutions.py b/doc/sphinxext/gh_substitutions.py index bccc16d13d0..890a71f1c47 100644 --- a/doc/sphinxext/gh_substitutions.py +++ b/doc/sphinxext/gh_substitutions.py @@ -4,7 +4,7 @@ from docutils.parsers.rst.roles import set_classes -def gh_role(name, rawtext, text, lineno, inliner, options={}, content=[]): +def gh_role(name, rawtext, text, lineno, inliner, options={}, content=[]): # noqa: B006 """Link to a GitHub issue. adapted from diff --git a/doc/sphinxext/mne_substitutions.py b/doc/sphinxext/mne_substitutions.py index 6a5cdbb6797..0c4f9a2f3dd 100644 --- a/doc/sphinxext/mne_substitutions.py +++ b/doc/sphinxext/mne_substitutions.py @@ -4,12 +4,13 @@ from docutils.parsers.rst import Directive from docutils.statemachine import StringList -from mne.defaults import DEFAULTS from mne._fiff.pick import ( - _PICK_TYPES_DATA_DICT, - _DATA_CH_TYPES_SPLIT, _DATA_CH_TYPES_ORDER_DEFAULT, + _DATA_CH_TYPES_SPLIT, + _EYETRACK_CH_TYPES_SPLIT, + _PICK_TYPES_DATA_DICT, ) +from mne.defaults import DEFAULTS class MNESubstitution(Directive): # noqa: D101 @@ -29,18 +30,35 @@ def run(self, **kwargs): # noqa: D102 ): keys.append(key) rst = "- " + "\n- ".join( - "``%r``: **%s** (scaled by %g to plot in *%s*)" - % ( - key, - DEFAULTS["titles"][key], - DEFAULTS["scalings"][key], - DEFAULTS["units"][key], - ) + f"``{repr(key)}``: **{DEFAULTS['titles'][key]}** " + f"(scaled by {DEFAULTS['scalings'][key]:g} to " + f"plot in *{DEFAULTS['units'][key]}*)" for key in keys ) + elif self.arguments[0] == "non-data channels list": + keys = list() + rst = "" + for key in _DATA_CH_TYPES_ORDER_DEFAULT: + if ( + not _PICK_TYPES_DATA_DICT.get(key, True) + or key in _EYETRACK_CH_TYPES_SPLIT + or key in ("ref_meg", "whitened") + ): + keys.append(key) + for key in keys: + if DEFAULTS["scalings"].get(key, False) and DEFAULTS["units"].get( + key, False + ): + rst += ( + f"- ``{repr(key)}``: **{DEFAULTS['titles'][key]}** " + f"(scaled by {DEFAULTS['scalings'][key]:g} to " + f"plot in *{DEFAULTS['units'][key]}*)\n" + ) + else: + rst += f"- ``{repr(key)}``: **{DEFAULTS['titles'][key]}**\n" else: raise self.error( - "MNE directive unknown in %s: %r" + "MNE directive unknown in %s: %r" # noqa: UP031 % ( env.doc2path(env.docname, base=None), self.arguments[0], diff --git a/doc/sphinxext/newcontrib_substitutions.py b/doc/sphinxext/newcontrib_substitutions.py index 41cf348c7c4..c38aeb86219 100644 --- a/doc/sphinxext/newcontrib_substitutions.py +++ b/doc/sphinxext/newcontrib_substitutions.py @@ -3,7 +3,7 @@ from docutils.nodes import reference, strong, target -def 
newcontrib_role(name, rawtext, text, lineno, inliner, options={}, content=[]): +def newcontrib_role(name, rawtext, text, lineno, inliner, options={}, content=[]): # noqa: B006 """Create a role to highlight new contributors in changelog entries.""" newcontrib = f"new contributor {text}" alias_text = f" <{text}_>" diff --git a/doc/sphinxext/unit_role.py b/doc/sphinxext/unit_role.py index b882aedc6b1..4d9c9d94252 100644 --- a/doc/sphinxext/unit_role.py +++ b/doc/sphinxext/unit_role.py @@ -3,7 +3,7 @@ from docutils import nodes -def unit_role(name, rawtext, text, lineno, inliner, options={}, content=[]): +def unit_role(name, rawtext, text, lineno, inliner, options={}, content=[]): # noqa: B006 parts = text.split() def pass_error_to_sphinx(rawtext, text, lineno, inliner): @@ -24,7 +24,7 @@ def pass_error_to_sphinx(rawtext, text, lineno, inliner): except ValueError: return pass_error_to_sphinx(rawtext, text, lineno, inliner) # input is well-formatted: proceed - node = nodes.Text("\u202F".join(parts)) + node = nodes.Text("\u202f".join(parts)) return [node], [] diff --git a/environment.yml b/environment.yml index 9f0971b2fb3..cc2f8e752d5 100644 --- a/environment.yml +++ b/environment.yml @@ -2,10 +2,11 @@ name: mne channels: - conda-forge dependencies: - - python>=3.8 + - python>=3.9 - pip - numpy - scipy + - openblas - matplotlib - tqdm - pooch>=1.5 @@ -14,6 +15,7 @@ dependencies: - packaging - numba - pandas + - pyarrow - xlrd - scikit-learn - h5py @@ -33,7 +35,7 @@ dependencies: - traitlets - pyvista>=0.32,!=0.35.2,!=0.38.0,!=0.38.1,!=0.38.2,!=0.38.3,!=0.38.4,!=0.38.5,!=0.38.6,!=0.42.0 - pyvistaqt>=0.4 - - qdarkstyle + - qdarkstyle!=3.2.2 - darkdetect - dipy - nibabel @@ -56,8 +58,9 @@ dependencies: - mne-qt-browser - pymatreader - eeglabio - - edflib-python + - edfio>=0.2.1 - pybv - mamba - lazy_loader - defusedxml + - python-neo diff --git a/examples/datasets/brainstorm_data.py b/examples/datasets/brainstorm_data.py index 0f32c704284..6331c9f1b29 100644 --- a/examples/datasets/brainstorm_data.py +++ b/examples/datasets/brainstorm_data.py @@ -41,7 +41,9 @@ raw.set_eeg_reference("average", projection=True) # show power line interference and remove it -raw.compute_psd(tmax=60).plot(average=False, picks="data", exclude="bads") +raw.compute_psd(tmax=60).plot( + average=False, amplitude=False, picks="data", exclude="bads" +) raw.notch_filter(np.arange(60, 181, 60), fir_design="firwin") events = mne.find_events(raw, stim_channel="UPPT001") diff --git a/examples/datasets/hf_sef_data.py b/examples/datasets/hf_sef_data.py index ec6ef61bcb2..44aa6e8f9a4 100644 --- a/examples/datasets/hf_sef_data.py +++ b/examples/datasets/hf_sef_data.py @@ -14,7 +14,6 @@ # %% - import os import mne diff --git a/examples/datasets/limo_data.py b/examples/datasets/limo_data.py index 4a0f96ed8ff..54a2f34a530 100644 --- a/examples/datasets/limo_data.py +++ b/examples/datasets/limo_data.py @@ -190,7 +190,7 @@ # get levels of phase coherence levels = sorted(phase_coh.unique()) # create labels for levels of phase coherence (i.e., 0 - 85%) -labels = ["{0:.2f}".format(i) for i in np.arange(0.0, 0.90, 0.05)] +labels = [f"{i:.2f}" for i in np.arange(0.0, 0.90, 0.05)] # create dict of evokeds for each level of phase-coherence evokeds = { diff --git a/examples/datasets/opm_data.py b/examples/datasets/opm_data.py index 3f1903b3010..fcc60d80934 100644 --- a/examples/datasets/opm_data.py +++ b/examples/datasets/opm_data.py @@ -114,8 +114,8 @@ ) idx = np.argmax(dip_opm.gof) print( - "Best dipole at t=%0.1f ms with %0.1f%% GOF" - % (1000 
* dip_opm.times[idx], dip_opm.gof[idx]) + f"Best dipole at t={1000 * dip_opm.times[idx]:0.1f} ms with " + f"{dip_opm.gof[idx]:0.1f}% GOF" ) # Plot N20m dipole as an example diff --git a/examples/datasets/spm_faces_dataset_sgskip.py b/examples/datasets/spm_faces_dataset.py similarity index 60% rename from examples/datasets/spm_faces_dataset_sgskip.py rename to examples/datasets/spm_faces_dataset.py index 1357fc513b6..32df7d1a9ed 100644 --- a/examples/datasets/spm_faces_dataset_sgskip.py +++ b/examples/datasets/spm_faces_dataset.py @@ -5,15 +5,8 @@ From raw data to dSPM on SPM Faces dataset ========================================== -Runs a full pipeline using MNE-Python: - - - artifact removal - - averaging Epochs - - forward model computation - - source reconstruction using dSPM on the contrast : "faces - scrambled" - -.. note:: This example does quite a bit of processing, so even on a - fast machine it can take several minutes to complete. +Runs a full pipeline using MNE-Python. This example does quite a bit of processing, so +even on a fast machine it can take several minutes to complete. """ # Authors: Alexandre Gramfort # Denis Engemann # # License: BSD-3-Clause # Copyright the MNE-Python contributors. -# %% - -# sphinx_gallery_thumbnail_number = 10 - -import matplotlib.pyplot as plt - import mne from mne import combine_evoked, io from mne.datasets import spm_face @@ -40,109 +27,72 @@ spm_path = data_path / "MEG" / "spm" # %% -# Load and filter data, set up epochs +# Load data, filter it, and fit ICA. raw_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D.ds" - raw = io.read_raw_ctf(raw_fname, preload=True) # Take first run # Here to save memory and time we'll downsample heavily -- this is not # advised for real data as it can effectively jitter events! -raw.resample(120.0, npad="auto") - -picks = mne.pick_types(raw.info, meg=True, exclude="bads") -raw.filter(1, 30, method="fir", fir_design="firwin") +raw.resample(100) +raw.filter(1.0, None) # high-pass +reject = dict(mag=5e-12) +ica = ICA(n_components=0.95, max_iter="auto", random_state=0) +ica.fit(raw, reject=reject) +# compute correlation scores, get bad indices sorted by score +eog_epochs = create_eog_epochs(raw, ch_name="MRT31-2908", reject=reject) +eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name="MRT31-2908") +ica.plot_scores(eog_scores, eog_inds) # see scores the selection is based on +ica.plot_components(eog_inds) # view topographic sensitivity of components +ica.exclude += eog_inds[:1] # we saw the 2nd EOG component looked too dipolar +ica.plot_overlay(eog_epochs.average()) # inspect artifact removal +# %% +# Epoch data and apply ICA.
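# (``mne.find_events`` scans the stimulus channel ``UPPT001`` for trigger
# onsets, and ``mne.Epochs`` segments the continuous data around those events;
# the ``reject`` dict drops any epoch whose peak-to-peak magnetometer
# amplitude exceeds 5e-12 T.)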
events = mne.find_events(raw, stim_channel="UPPT001") - -# plot the events to get an idea of the paradigm -mne.viz.plot_events(events, raw.info["sfreq"]) - event_ids = {"faces": 1, "scrambled": 2} - tmin, tmax = -0.2, 0.6 -baseline = None # no baseline as high-pass is applied -reject = dict(mag=5e-12) - epochs = mne.Epochs( raw, events, event_ids, tmin, tmax, - picks=picks, - baseline=baseline, + picks="meg", + baseline=None, preload=True, reject=reject, ) - -# Fit ICA, find and remove major artifacts -ica = ICA(n_components=0.95, max_iter="auto", random_state=0) -ica.fit(raw, decim=1, reject=reject) - -# compute correlation scores, get bad indices sorted by score -eog_epochs = create_eog_epochs(raw, ch_name="MRT31-2908", reject=reject) -eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name="MRT31-2908") -ica.plot_scores(eog_scores, eog_inds) # see scores the selection is based on -ica.plot_components(eog_inds) # view topographic sensitivity of components -ica.exclude += eog_inds[:1] # we saw the 2nd ECG component looked too dipolar -ica.plot_overlay(eog_epochs.average()) # inspect artifact removal +del raw ica.apply(epochs) # clean data, default in place - evoked = [epochs[k].average() for k in event_ids] - contrast = combine_evoked(evoked, weights=[-1, 1]) # Faces - scrambled - evoked.append(contrast) - for e in evoked: e.plot(ylim=dict(mag=[-400, 400])) -plt.show() - -# estimate noise covarariance -noise_cov = mne.compute_covariance(epochs, tmax=0, method="shrunk", rank=None) - # %% -# Visualize fields on MEG helmet - -# The transformation here was aligned using the dig-montage. It's included in -# the spm_faces dataset and is named SPM_dig_montage.fif. -trans_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D_raw-trans.fif" - -maps = mne.make_field_map( - evoked[0], trans_fname, subject="spm", subjects_dir=subjects_dir, n_jobs=None -) - -evoked[0].plot_field(maps, time=0.170, time_viewer=False) - -# %% -# Look at the whitened evoked daat +# Estimate noise covariance and look at the whitened evoked data +noise_cov = mne.compute_covariance(epochs, tmax=0, method="shrunk", rank=None) evoked[0].plot_white(noise_cov) # %% # Compute forward model +trans_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D_raw-trans.fif" src = subjects_dir / "spm" / "bem" / "spm-oct-6-src.fif" bem = subjects_dir / "spm" / "bem" / "spm-5120-5120-5120-bem-sol.fif" forward = mne.make_forward_solution(contrast.info, trans_fname, src, bem) # %% -# Compute inverse solution +# Compute inverse solution and plot + +# sphinx_gallery_thumbnail_number = 8 snr = 3.0 lambda2 = 1.0 / snr**2 -method = "dSPM" - -inverse_operator = make_inverse_operator( - contrast.info, forward, noise_cov, loose=0.2, depth=0.8 -) - -# Compute inverse solution on contrast -stc = apply_inverse(contrast, inverse_operator, lambda2, method, pick_ori=None) -# stc.save('spm_%s_dSPM_inverse' % contrast.comment) - -# Plot contrast in 3D with mne.viz.Brain if available +inverse_operator = make_inverse_operator(contrast.info, forward, noise_cov) +stc = apply_inverse(contrast, inverse_operator, lambda2, method="dSPM", pick_ori=None) brain = stc.plot( hemi="both", subjects_dir=subjects_dir, @@ -150,4 +100,3 @@ views=["ven"], clim={"kind": "value", "lims": [3.0, 6.0, 9.0]}, ) -# brain.save_image('dSPM_map.png') diff --git a/examples/decoding/decoding_csp_eeg.py b/examples/decoding/decoding_csp_eeg.py index 85a468cb590..2ffd18d34b4 100644 --- a/examples/decoding/decoding_csp_eeg.py +++ b/examples/decoding/decoding_csp_eeg.py @@ -20,14 +20,13 @@ # %% - 
import matplotlib.pyplot as plt import numpy as np from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.model_selection import ShuffleSplit, cross_val_score from sklearn.pipeline import Pipeline -from mne import Epochs, events_from_annotations, pick_types +from mne import Epochs, pick_types from mne.channels import make_standard_montage from mne.datasets import eegbci from mne.decoding import CSP @@ -41,7 +40,6 @@ # avoid classification of evoked responses by using epochs that start 1s after # cue onset. tmin, tmax = -1.0, 4.0 -event_id = dict(hands=2, feet=3) subject = 1 runs = [6, 10, 14] # motor imagery: hands vs feet @@ -50,22 +48,21 @@ eegbci.standardize(raw) # set channel names montage = make_standard_montage("standard_1005") raw.set_montage(montage) +raw.annotations.rename(dict(T1="hands", T2="feet")) +raw.set_eeg_reference(projection=True) # Apply band-pass filter raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge") -events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3)) - picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads") # Read epochs (train will be done only between 1 and 2s) # Testing will be done with a running classifier epochs = Epochs( raw, - events, - event_id, - tmin, - tmax, + event_id=["hands", "feet"], + tmin=tmin, + tmax=tmax, proj=True, picks=picks, baseline=None, @@ -95,9 +92,7 @@ # Printing the results class_balance = np.mean(labels == labels[0]) class_balance = max(class_balance, 1.0 - class_balance) -print( - "Classification accuracy: %f / Chance level: %f" % (np.mean(scores), class_balance) -) +print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}") # plot CSP patterns estimated on full data for visualization csp.fit_transform(epochs_data, labels) diff --git a/examples/decoding/decoding_csp_timefreq.py b/examples/decoding/decoding_csp_timefreq.py index f81e4fc0fea..c389645d668 100644 --- a/examples/decoding/decoding_csp_timefreq.py +++ b/examples/decoding/decoding_csp_timefreq.py @@ -21,7 +21,6 @@ # %% - import matplotlib.pyplot as plt import numpy as np from sklearn.discriminant_analysis import LinearDiscriminantAnalysis @@ -29,23 +28,22 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import LabelEncoder -from mne import Epochs, create_info, events_from_annotations +from mne import Epochs, create_info from mne.datasets import eegbci from mne.decoding import CSP from mne.io import concatenate_raws, read_raw_edf -from mne.time_frequency import AverageTFR +from mne.time_frequency import AverageTFRArray # %% # Set parameters and read data -event_id = dict(hands=2, feet=3) # motor imagery: hands vs feet subject = 1 runs = [6, 10, 14] raw_fnames = eegbci.load_data(subject, runs) raw = concatenate_raws([read_raw_edf(f) for f in raw_fnames]) +raw.annotations.rename(dict(T1="hands", T2="feet")) # Extract information from the raw file sfreq = raw.info["sfreq"] -events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3)) raw.pick(picks="eeg", exclude="bads") raw.load_data() @@ -95,10 +93,9 @@ # Extract epochs from filtered data, padded by window size epochs = Epochs( raw_filter, - events, - event_id, - tmin - w_size, - tmax + w_size, + event_id=["hands", "feet"], + tmin=tmin - w_size, + tmax=tmax + w_size, proj=False, baseline=None, preload=True, @@ -148,10 +145,9 @@ # Extract epochs from filtered data, padded by window size epochs = Epochs( raw_filter, - events, - event_id, - tmin - w_size, - tmax + w_size, + 
event_id=["hands", "feet"], + tmin=tmin - w_size, + tmax=tmax + w_size, proj=False, baseline=None, preload=True, @@ -177,13 +173,15 @@ # Plot time-frequency results # Set up time frequency object -av_tfr = AverageTFR( - create_info(["freq"], sfreq), - tf_scores[np.newaxis, :], - centered_w_times, - freqs[1:], - 1, +av_tfr = AverageTFRArray( + info=create_info(["freq"], sfreq), + data=tf_scores[np.newaxis, :], + times=centered_w_times, + freqs=freqs[1:], + nave=1, ) chance = np.mean(y) # set chance level to white in the plot -av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds) +av_tfr.plot( + [0], vlim=(chance, None), title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds +) diff --git a/examples/decoding/decoding_spoc_CMC.py b/examples/decoding/decoding_spoc_CMC.py index 4d49fb1e350..0a02a61052c 100644 --- a/examples/decoding/decoding_spoc_CMC.py +++ b/examples/decoding/decoding_spoc_CMC.py @@ -64,7 +64,7 @@ # Define a two fold cross-validation cv = KFold(n_splits=2, shuffle=False) -# Run cross validaton +# Run cross validation y_preds = cross_val_predict(clf, X, y, cv=cv) # Plot the True EMG power and the EMG power predicted from MEG data diff --git a/examples/decoding/receptive_field_mtrf.py b/examples/decoding/receptive_field_mtrf.py index 24b459f192f..6d20b9ac582 100644 --- a/examples/decoding/receptive_field_mtrf.py +++ b/examples/decoding/receptive_field_mtrf.py @@ -17,7 +17,7 @@ .. _figure 1: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F1 .. _figure 2: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F2 .. _figure 5: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F5 -""" # noqa: E501 +""" # Authors: Chris Holdgraf # Eric Larson @@ -26,9 +26,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -# %% -# sphinx_gallery_thumbnail_number = 3 - from os.path import join import matplotlib.pyplot as plt @@ -58,8 +55,8 @@ speech = data["envelope"].T sfreq = float(data["Fs"].item()) sfreq /= decim -speech = mne.filter.resample(speech, down=decim, npad="auto") -raw = mne.filter.resample(raw, down=decim, npad="auto") +speech = mne.filter.resample(speech, down=decim, method="polyphase") +raw = mne.filter.resample(raw, down=decim, method="polyphase") # Read in channel positions and create our MNE objects from the raw data montage = mne.channels.make_standard_montage("biosemi128") @@ -105,7 +102,7 @@ coefs = np.zeros((n_splits, n_channels, n_delays)) scores = np.zeros((n_splits, n_channels)) for ii, (train, test) in enumerate(cv.split(speech)): - print("split %s / %s" % (ii + 1, n_splits)) + print(f"split {ii + 1} / {n_splits}") rf.fit(speech[train], Y[train]) scores[ii] = rf.score(speech[test], Y[test]) # coef_ is shape (n_outputs, n_features, n_delays). we only have 1 feature @@ -131,6 +128,8 @@ # across the scalp. We will recreate `figure 1`_ and `figure 2`_ from # :footcite:`CrosseEtAl2016`. +# sphinx_gallery_thumbnail_number = 3 + # Print mean coefficients across all time delays / channels (see Fig 1) time_plot = 0.180 # For highlighting a specific time. fig, ax = plt.subplots(figsize=(4, 8), layout="constrained") @@ -213,7 +212,7 @@ patterns = coefs.copy() scores = np.zeros((n_splits,)) for ii, (train, test) in enumerate(cv.split(speech)): - print("split %s / %s" % (ii + 1, n_splits)) + print(f"split {ii + 1} / {n_splits}") sr.fit(Y[train], speech[train]) scores[ii] = sr.score(Y[test], speech[test])[0] # coef_ is shape (n_outputs, n_features, n_delays). 
We have 128 features @@ -273,9 +272,7 @@ show=False, vlim=(-max_coef, max_coef), ) -ax[0].set( - title="Model coefficients\nbetween delays %s and %s" % (time_plot[0], time_plot[1]) -) +ax[0].set(title=f"Model coefficients\nbetween delays {time_plot[0]} and {time_plot[1]}") mne.viz.plot_topomap( np.mean(mean_patterns[:, ix_plot], axis=1), @@ -285,8 +282,10 @@ vlim=(-max_patterns, max_patterns), ) ax[1].set( - title="Inverse-transformed coefficients\nbetween delays %s and %s" - % (time_plot[0], time_plot[1]) + title=( + f"Inverse-transformed coefficients\nbetween delays {time_plot[0]} and " + f"{time_plot[1]}" + ) ) # %% diff --git a/examples/decoding/ssd_spatial_filters.py b/examples/decoding/ssd_spatial_filters.py index 5f4ea3fbcf7..b7c8c4f2c94 100644 --- a/examples/decoding/ssd_spatial_filters.py +++ b/examples/decoding/ssd_spatial_filters.py @@ -20,7 +20,6 @@ # %% - import matplotlib.pyplot as plt import mne diff --git a/examples/forward/left_cerebellum_volume_source.py b/examples/forward/left_cerebellum_volume_source.py index 22e46073d88..ff810493e99 100644 --- a/examples/forward/left_cerebellum_volume_source.py +++ b/examples/forward/left_cerebellum_volume_source.py @@ -73,8 +73,8 @@ # And display source positions in freeview:: # # >>> from mne.utils import run_subprocess -# >>> mri_fname = subjects_dir + '/sample/mri/brain.mgz' -# >>> run_subprocess(['freeview', '-v', mri_fname, '-v', -# '%s:colormap=lut:opacity=0.5' % aseg_fname, '-v', -# '%s:colormap=jet:colorscale=0,2' % nii_fname, -# '-slice', '157 75 105']) +# >>> mri_fname = subjects_dir / "sample" / "mri" / "brain.mgz" +# >>> run_subprocess(["freeview", "-v", str(mri_fname), "-v", +# f"{aseg_fname}:colormap=lut:opacity=0.5", +# "-v", f"{nii_fname}:colormap=jet:colorscale=0,2", +# "--slice", "157", "75", "105"]) diff --git a/examples/inverse/compute_mne_inverse_raw_in_label.py b/examples/inverse/compute_mne_inverse_raw_in_label.py index d2d7b8be3d2..ac97df8ff4b 100644 --- a/examples/inverse/compute_mne_inverse_raw_in_label.py +++ b/examples/inverse/compute_mne_inverse_raw_in_label.py @@ -49,7 +49,7 @@ ) # Save result in stc files -stc.save("mne_%s_raw_inverse_%s" % (method, label_name), overwrite=True) +stc.save(f"mne_{method}_raw_inverse_{label_name}", overwrite=True) # %% # View activation time-series diff --git a/examples/inverse/compute_mne_inverse_volume.py b/examples/inverse/compute_mne_inverse_volume.py index 8283dfdeeca..39b455f464b 100644 --- a/examples/inverse/compute_mne_inverse_volume.py +++ b/examples/inverse/compute_mne_inverse_volume.py @@ -56,5 +56,5 @@ index_img(img, 61), str(t1_fname), threshold=8.0, - title="%s (t=%.1f s.)" % (method, stc.times[61]), + title=f"{method} (t={stc.times[61]:.1f} s.)", ) diff --git a/examples/inverse/dics_epochs.py b/examples/inverse/dics_epochs.py index d480b13f8a4..c359c30c0fb 100644 --- a/examples/inverse/dics_epochs.py +++ b/examples/inverse/dics_epochs.py @@ -22,7 +22,7 @@ import mne from mne.beamformer import apply_dics_tfr_epochs, make_dics from mne.datasets import somato -from mne.time_frequency import csd_tfr, tfr_morlet +from mne.time_frequency import csd_tfr print(__doc__) @@ -67,8 +67,8 @@ # decomposition for each epoch. We must pass ``output='complex'`` if we wish to # use this TFR later with a DICS beamformer. We also pass ``average=False`` to # compute the TFR for each individual epoch. 
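# (A brief aside, not used in this example: because ``output="complex"`` keeps
# the complex TFR coefficients, power could still be recovered afterwards via
# ``np.abs(epochs_tfr.data) ** 2``.)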
-epochs_tfr = tfr_morlet( - epochs, freqs, n_cycles=5, return_itc=False, output="complex", average=False +epochs_tfr = epochs.compute_tfr( + "morlet", freqs, n_cycles=5, return_itc=False, output="complex", average=False ) # crop either side to use a buffer to remove edge artifact diff --git a/examples/inverse/evoked_ers_source_power.py b/examples/inverse/evoked_ers_source_power.py index 7ae7fa86424..f118a217c9e 100644 --- a/examples/inverse/evoked_ers_source_power.py +++ b/examples/inverse/evoked_ers_source_power.py @@ -34,12 +34,7 @@ data_path = somato.data_path() subject = "01" task = "somato" -raw_fname = ( - data_path - / "sub-{}".format(subject) - / "meg" - / "sub-{}_task-{}_meg.fif".format(subject, task) -) +raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif" # crop to 5 minutes to save memory raw = mne.io.read_raw_fif(raw_fname).crop(0, 300) @@ -59,10 +54,7 @@ # Read forward operator and point to freesurfer subject directory fname_fwd = ( - data_path - / "derivatives" - / "sub-{}".format(subject) - / "sub-{}_task-{}-fwd.fif".format(subject, task) + data_path / "derivatives" / f"sub-{subject}" / f"sub-{subject}_task-{task}-fwd.fif" ) subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects" diff --git a/examples/inverse/label_source_activations.py b/examples/inverse/label_source_activations.py index 4a92ea27962..7640a468ebd 100644 --- a/examples/inverse/label_source_activations.py +++ b/examples/inverse/label_source_activations.py @@ -113,7 +113,7 @@ ax.set( xlabel="Time (ms)", ylabel="Source amplitude", - title="Mean vector activations in Label %r" % (label.name,), + title=f"Mean vector activations in Label {label.name!r}", xlim=xlim, ylim=ylim, ) diff --git a/examples/inverse/mixed_norm_inverse.py b/examples/inverse/mixed_norm_inverse.py index 038bbad0d8b..bc6b91bfeae 100644 --- a/examples/inverse/mixed_norm_inverse.py +++ b/examples/inverse/mixed_norm_inverse.py @@ -137,7 +137,7 @@ forward["src"], stc, bgcolor=(1, 1, 1), - fig_name="%s (cond %s)" % (solver, condition), + fig_name=f"{solver} (cond {condition})", opacity=0.1, ) @@ -159,7 +159,7 @@ src_fsaverage, stc_fsaverage, bgcolor=(1, 1, 1), - fig_name="Morphed %s (cond %s)" % (solver, condition), + fig_name=f"Morphed {solver} (cond {condition})", opacity=0.1, ) diff --git a/examples/inverse/read_stc.py b/examples/inverse/read_stc.py index 9b2823bd7a7..b06f61d14f8 100644 --- a/examples/inverse/read_stc.py +++ b/examples/inverse/read_stc.py @@ -29,9 +29,7 @@ stc = mne.read_source_estimate(fname) n_vertices, n_samples = stc.data.shape -print( - "stc data size: %s (nb of vertices) x %s (nb of samples)" % (n_vertices, n_samples) -) +print(f"stc data size: {n_vertices} (nb of vertices) x {n_samples} (nb of samples)") # View source activations plt.plot(stc.times, stc.data[::100, :].T) diff --git a/examples/io/elekta_epochs.py b/examples/io/elekta_epochs.py index 5619a0e5174..4afa0ad888d 100644 --- a/examples/io/elekta_epochs.py +++ b/examples/io/elekta_epochs.py @@ -15,7 +15,6 @@ # %% - import os import mne diff --git a/examples/preprocessing/css.py b/examples/preprocessing/css.py index 9095094d93c..ba4e2385d0c 100644 --- a/examples/preprocessing/css.py +++ b/examples/preprocessing/css.py @@ -75,9 +75,9 @@ def subcortical_waveform(times): labels=[postcenlab, hiplab], data_fun=cortical_waveform, ) -stc.data[ - np.where(np.isin(stc.vertices[0], hiplab.vertices))[0], : -] = subcortical_waveform(times) +stc.data[np.where(np.isin(stc.vertices[0], hiplab.vertices))[0], :] = ( + 
subcortical_waveform(times) +) evoked = simulate_evoked(fwd, stc, raw.info, cov, nave=15) ############################################################################### diff --git a/examples/preprocessing/define_target_events.py b/examples/preprocessing/define_target_events.py index 5672b8d69ad..5aa1becbb6b 100644 --- a/examples/preprocessing/define_target_events.py +++ b/examples/preprocessing/define_target_events.py @@ -100,7 +100,7 @@ # average epochs and get an Evoked dataset. -early, late = [epochs[k].average() for k in event_id] +early, late = (epochs[k].average() for k in event_id) # %% # View evoked response diff --git a/examples/preprocessing/eeg_bridging.py b/examples/preprocessing/eeg_bridging.py index 6c7052cb028..87e1d8621f0 100644 --- a/examples/preprocessing/eeg_bridging.py +++ b/examples/preprocessing/eeg_bridging.py @@ -320,12 +320,9 @@ # compute variance of residuals print( "Variance of residual (interpolated data - original data)\n\n" - "With adding virtual channel: {}\n" - "Compared to interpolation only using other channels: {}" - "".format( - np.mean(np.var(data_virtual - data_orig, axis=1)), - np.mean(np.var(data_comp - data_orig, axis=1)), - ) + f"With adding virtual channel: {np.mean(np.var(data_virtual - data_orig, axis=1))}\n" + f"Compared to interpolation only using other channels: {np.mean(np.var(data_comp - data_orig, axis=1))}" + "" ) # plot results @@ -384,16 +381,7 @@ raw = raw_data[1] # typically impedances < 25 kOhm are acceptable for active systems and # impedances < 5 kOhm are desirable for a passive system -impedances = ( - rng.random( - ( - len( - raw.ch_names, - ) - ) - ) - * 30 -) +impedances = rng.random(len(raw.ch_names)) * 30 impedances[10] = 80 # set a few bad impendances impedances[25] = 99 cmap = LinearSegmentedColormap.from_list( diff --git a/examples/preprocessing/eeg_csd.py b/examples/preprocessing/eeg_csd.py index 73515e1f043..35ba959c34d 100644 --- a/examples/preprocessing/eeg_csd.py +++ b/examples/preprocessing/eeg_csd.py @@ -49,8 +49,8 @@ # %% # Also look at the power spectral densities: -raw.compute_psd().plot(picks="data", exclude="bads") -raw_csd.compute_psd().plot(picks="data", exclude="bads") +raw.compute_psd().plot(picks="data", exclude="bads", amplitude=False) +raw_csd.compute_psd().plot(picks="data", exclude="bads", amplitude=False) # %% # CSD can also be computed on Evoked (averaged) data. diff --git a/examples/preprocessing/eog_artifact_histogram.py b/examples/preprocessing/eog_artifact_histogram.py index d883fa427f8..8a89f9d8a44 100644 --- a/examples/preprocessing/eog_artifact_histogram.py +++ b/examples/preprocessing/eog_artifact_histogram.py @@ -15,7 +15,6 @@ # %% - import matplotlib.pyplot as plt import numpy as np diff --git a/examples/preprocessing/epochs_metadata.py b/examples/preprocessing/epochs_metadata.py new file mode 100644 index 00000000000..d1ea9a85996 --- /dev/null +++ b/examples/preprocessing/epochs_metadata.py @@ -0,0 +1,171 @@ +""" +.. _epochs-metadata: + +=============================================================== +Automated epochs metadata generation with variable time windows +=============================================================== + +When working with :class:`~mne.Epochs`, :ref:`metadata ` can be +invaluable. There is an extensive tutorial on +:ref:`how it can be generated automatically `. +In the brief examples below, we will demonstrate different ways to bound the time +windows used to generate the metadata. 
+ +""" +# Authors: Richard Höchenberger +# +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# %% +# We will use data from an EEG recording during an Eriksen flanker task. For the +# purpose of demonstration, we'll only load the first 60 seconds of data. + +import mne + +data_dir = mne.datasets.erp_core.data_path() +infile = data_dir / "ERP-CORE_Subject-001_Task-Flankers_eeg.fif" + +raw = mne.io.read_raw(infile, preload=True) +raw.crop(tmax=60).filter(l_freq=0.1, h_freq=40) + +# %% +# Visualizing the events +# ^^^^^^^^^^^^^^^^^^^^^^ +# +# All experimental events are stored in the :class:`~mne.io.Raw` instance as +# :class:`~mne.Annotations`. We first need to convert these to events and the +# corresponding mapping from event codes to event names (``event_id``). We then +# visualize the events. +all_events, all_event_id = mne.events_from_annotations(raw) +mne.viz.plot_events(events=all_events, event_id=all_event_id, sfreq=raw.info["sfreq"]) + + +# %% +# As you can see, there are four types of ``stimulus`` and two types of ``response`` +# events. +# +# Declaring "row events" +# ^^^^^^^^^^^^^^^^^^^^^^ +# +# For the sake of this example, we will assume that during analysis our epochs will be +# time-locked to the stimulus onset events. Hence, we would like to create metadata with +# one row per ``stimulus``. We can achieve this by specifying all stimulus event names +# as ``row_events``. + +row_events = [ + "stimulus/compatible/target_left", + "stimulus/compatible/target_right", + "stimulus/incompatible/target_left", + "stimulus/incompatible/target_right", +] + +# %% +# Specifying metadata time windows +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Now, we will explore different ways of specifying the time windows around the +# ``row_events`` when generating metadata. Any events falling within the same time +# window will be added to the same row in the metadata table. +# +# Fixed time window +# ~~~~~~~~~~~~~~~~~ +# +# A simple way to specify the time window extent is by specifying the time in seconds +# relative to the row event. In the following example, the time window spans from the +# row event (time point zero) up until three seconds later. + +metadata_tmin = 0.0 +metadata_tmax = 3.0 + +metadata, events, event_id = mne.epochs.make_metadata( + events=all_events, + event_id=all_event_id, + tmin=metadata_tmin, + tmax=metadata_tmax, + sfreq=raw.info["sfreq"], + row_events=row_events, +) + +metadata + +# %% +# This looks good at the first glance. However, for example in the 2nd and 3rd row, we +# have two responses listed (left and right). This is because the 3-second time window +# is obviously a bit too wide and captures more than one trial. While we could make it +# narrower, this could lead to a loss of events – if the window might become **too** +# narrow. Ultimately, this problem arises because the response time varies from trial +# to trial, so it's difficult for us to set a fixed upper bound for the time window. +# +# Fixed time window with ``keep_first`` +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# One workaround is using the ``keep_first`` parameter, which will create a new column +# containing the first event of the specified type. 
+ +metadata_tmin = 0.0 +metadata_tmax = 3.0 +keep_first = "response"  # <-- new + +metadata, events, event_id = mne.epochs.make_metadata( + events=all_events, + event_id=all_event_id, + tmin=metadata_tmin, + tmax=metadata_tmax, + sfreq=raw.info["sfreq"], + row_events=row_events, + keep_first=keep_first,  # <-- new +) + +metadata + +# %% +# As you can see, a new column ``response`` was created with the time of the first +# response event falling inside the time window. The ``first_response`` column specifies +# **which** response occurred first (left or right). +# +# Variable time window +# ~~~~~~~~~~~~~~~~~~~~ +# +# Another way to address the challenge of variable time windows **without** the need to +# create new columns is by specifying ``tmin`` and ``tmax`` as event names. In this +# example, we use ``tmin=row_events``, because we want the time window to start +# with the time-locked event. ``tmax``, on the other hand, is set to the response events: +# The first response event following ``tmin`` will be used to determine the duration of +# the time window. + +metadata_tmin = row_events +metadata_tmax = ["response/left", "response/right"] + +metadata, events, event_id = mne.epochs.make_metadata( + events=all_events, + event_id=all_event_id, + tmin=metadata_tmin, + tmax=metadata_tmax, + sfreq=raw.info["sfreq"], + row_events=row_events, +) + +metadata + +# %% +# Variable time window (simplified) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can slightly simplify the above code: Since ``tmin`` shall be set to +# ``row_events``, we can pass ``tmin=None``, which is a more convenient way to express +# ``tmin=row_events``. The resulting metadata looks the same as in the previous example. + +metadata_tmin = None  # <-- new +metadata_tmax = ["response/left", "response/right"] + +metadata, events, event_id = mne.epochs.make_metadata( + events=all_events, + event_id=all_event_id, + tmin=metadata_tmin, + tmax=metadata_tmax, + sfreq=raw.info["sfreq"], + row_events=row_events, +) + +metadata diff --git a/examples/preprocessing/find_ref_artifacts.py b/examples/preprocessing/find_ref_artifacts.py index 93b96e89e9c..90e3d1fb0da 100644 --- a/examples/preprocessing/find_ref_artifacts.py +++ b/examples/preprocessing/find_ref_artifacts.py @@ -70,7 +70,7 @@ # %% # The PSD of these data show the noise as clear peaks. -raw.compute_psd(fmax=30).plot(picks="data", exclude="bads") +raw.compute_psd(fmax=30).plot(picks="data", exclude="bads", amplitude=False) # %% # Run the "together" algorithm. @@ -99,7 +99,7 @@ # %% # Cleaned data: -raw_tog.compute_psd(fmax=30).plot(picks="data", exclude="bads") +raw_tog.compute_psd(fmax=30).plot(picks="data", exclude="bads", amplitude=False) # %% # Now try the "separate" algorithm.
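For reviewers unfamiliar with the new signature: ``Spectrum.plot`` can draw either power or amplitude, and the examples in this changeset now pin that choice down explicitly rather than relying on the implicit default. A minimal sketch of the call, assuming the bundled ``sample`` dataset is available (illustration only, not part of this diff):

import mne

# hedged sketch: standard sample-dataset path, for illustration only
sample_dir = mne.datasets.sample.data_path() / "MEG" / "sample"
raw = mne.io.read_raw_fif(sample_dir / "sample_audvis_raw.fif")
spectrum = raw.compute_psd(fmax=30)
# amplitude=False -> plot power spectral density; amplitude=True -> amplitude
spectrum.plot(picks="data", exclude="bads", amplitude=False)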
@@ -143,7 +143,7 @@ # %% # Cleaned raw data PSD: -raw_sep.compute_psd(fmax=30).plot(picks="data", exclude="bads") +raw_sep.compute_psd(fmax=30).plot(picks="data", exclude="bads", amplitude=False) ############################################################################## # References diff --git a/examples/preprocessing/ica_comparison.py b/examples/preprocessing/ica_comparison.py index 02930174435..d4246b80362 100644 --- a/examples/preprocessing/ica_comparison.py +++ b/examples/preprocessing/ica_comparison.py @@ -55,7 +55,7 @@ def run_ica(method, fit_params=None): t0 = time() ica.fit(raw, reject=reject) fit_time = time() - t0 -    title = "ICA decomposition using %s (took %.1fs)" % (method, fit_time) +    title = f"ICA decomposition using {method} (took {fit_time:.1f}s)" ica.plot_components(title=title) diff --git a/examples/preprocessing/muscle_ica.py b/examples/preprocessing/muscle_ica.py index f57e24a678b..64c14f5f5af 100644 --- a/examples/preprocessing/muscle_ica.py +++ b/examples/preprocessing/muscle_ica.py @@ -11,7 +11,6 @@ artifact is produced during postural maintenance. This is more appropriately removed by ICA otherwise there wouldn't be any epochs left! Note that muscle artifacts of this kind are much more pronounced in EEG than they are in MEG. - """ # Authors: Alex Rockhill # # diff --git a/examples/preprocessing/otp.py b/examples/preprocessing/otp.py index aa235e79a78..df3a6c74ffe 100644 --- a/examples/preprocessing/otp.py +++ b/examples/preprocessing/otp.py @@ -79,15 +79,9 @@ def compute_bias(raw): bias = compute_bias(raw) -print("Raw bias: %0.1fmm (worst: %0.1fmm)" % (np.mean(bias), np.max(bias))) +print(f"Raw bias: {np.mean(bias):0.1f}mm (worst: {np.max(bias):0.1f}mm)") bias_clean = compute_bias(raw_clean) -print( -    "OTP bias: %0.1fmm (worst: %0.1fmm)" -    % ( -        np.mean(bias_clean), -        np.max(bias_clean), -    ) -) +print(f"OTP bias: {np.mean(bias_clean):0.1f}mm (worst: {np.max(bias_clean):0.1f}mm)") # %% # References diff --git a/examples/preprocessing/xdawn_denoising.py b/examples/preprocessing/xdawn_denoising.py index 6fc38a55b94..20a6abc72fb 100644 --- a/examples/preprocessing/xdawn_denoising.py +++ b/examples/preprocessing/xdawn_denoising.py @@ -25,7 +25,6 @@ # %% - from mne import Epochs, compute_raw_covariance, io, pick_types, read_events from mne.datasets import sample from mne.preprocessing import Xdawn diff --git a/examples/simulation/simulate_raw_data.py b/examples/simulation/simulate_raw_data.py index ef375bfec38..e413a8deb75 100644 --- a/examples/simulation/simulate_raw_data.py +++ b/examples/simulation/simulate_raw_data.py @@ -55,9 +55,9 @@ def data_fun(times): global n n_samp = len(times) window = np.zeros(n_samp) -    start, stop = [ +    start, stop = ( int(ii * float(n_samp) / (2 * n_dipoles)) for ii in (2 * n, 2 * n + 1) -    ] +    ) window[start:stop] = 1.0 n += 1 data = 25e-9 * np.sin(2.0 * np.pi * 10.0 * n * times) diff --git a/examples/time_frequency/source_power_spectrum_opm.py b/examples/time_frequency/source_power_spectrum_opm.py index dd142138784..8a12b78a9d3 100644 --- a/examples/time_frequency/source_power_spectrum_opm.py +++ b/examples/time_frequency/source_power_spectrum_opm.py @@ -58,16 +58,16 @@ raw_erms = dict() new_sfreq = 60.0  # Nyquist frequency (30 Hz) < line noise freq (50 Hz) raws["vv"] = mne.io.read_raw_fif(vv_fname, verbose="error")  # ignore naming -raws["vv"].load_data().resample(new_sfreq) +raws["vv"].load_data().resample(new_sfreq, method="polyphase") raws["vv"].info["bads"] = ["MEG2233", "MEG1842"] raw_erms["vv"] = mne.io.read_raw_fif(vv_erm_fname,
verbose="error") -raw_erms["vv"].load_data().resample(new_sfreq) +raw_erms["vv"].load_data().resample(new_sfreq, method="polyphase") raw_erms["vv"].info["bads"] = ["MEG2233", "MEG1842"] raws["opm"] = mne.io.read_raw_fif(opm_fname) -raws["opm"].load_data().resample(new_sfreq) +raws["opm"].load_data().resample(new_sfreq, method="polyphase") raw_erms["opm"] = mne.io.read_raw_fif(opm_erm_fname) -raw_erms["opm"].load_data().resample(new_sfreq) +raw_erms["opm"].load_data().resample(new_sfreq, method="polyphase") # Make sure our assumptions later hold assert raws["opm"].info["sfreq"] == raws["vv"].info["sfreq"] @@ -82,7 +82,7 @@ fig = ( raws[kind] .compute_psd(n_fft=n_fft, proj=True) - .plot(picks="data", exclude="bads") + .plot(picks="data", exclude="bads", amplitude=True) ) fig.suptitle(titles[kind]) diff --git a/examples/time_frequency/time_frequency_erds.py b/examples/time_frequency/time_frequency_erds.py index ee2dd62a2ba..1d805121739 100644 --- a/examples/time_frequency/time_frequency_erds.py +++ b/examples/time_frequency/time_frequency_erds.py @@ -45,29 +45,29 @@ from mne.datasets import eegbci from mne.io import concatenate_raws, read_raw_edf from mne.stats import permutation_cluster_1samp_test as pcluster_test -from mne.time_frequency import tfr_multitaper # %% # First, we load and preprocess the data. We use runs 6, 10, and 14 from # subject 1 (these runs contains hand and feet motor imagery). + fnames = eegbci.load_data(subject=1, runs=(6, 10, 14)) raw = concatenate_raws([read_raw_edf(f, preload=True) for f in fnames]) raw.rename_channels(lambda x: x.strip(".")) # remove dots from channel names - -events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3)) +# rename descriptions to be more easily interpretable +raw.annotations.rename(dict(T1="hands", T2="feet")) # %% # Now we can create 5-second epochs around events of interest. + tmin, tmax = -1, 4 event_ids = dict(hands=2, feet=3) # map event IDs to tasks epochs = mne.Epochs( raw, - events, - event_ids, - tmin - 0.5, - tmax + 0.5, + event_id=["hands", "feet"], + tmin=tmin - 0.5, + tmax=tmax + 0.5, picks=("C3", "Cz", "C4"), baseline=None, preload=True, @@ -95,8 +95,8 @@ # %% # Finally, we perform time/frequency decomposition over all epochs. -tfr = tfr_multitaper( - epochs, +tfr = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=freqs, use_fft=True, diff --git a/examples/time_frequency/time_frequency_simulated.py b/examples/time_frequency/time_frequency_simulated.py index 9dfe38eab8b..dc42f16da3a 100644 --- a/examples/time_frequency/time_frequency_simulated.py +++ b/examples/time_frequency/time_frequency_simulated.py @@ -25,16 +25,8 @@ from matplotlib import pyplot as plt from mne import Epochs, create_info -from mne.baseline import rescale from mne.io import RawArray -from mne.time_frequency import ( - AverageTFR, - tfr_array_morlet, - tfr_morlet, - tfr_multitaper, - tfr_stockwell, -) -from mne.viz import centers_to_edges +from mne.time_frequency import AverageTFRArray, EpochsTFRArray, tfr_array_morlet print(__doc__) @@ -112,12 +104,13 @@ "Sim: Less time smoothing,\nmore frequency smoothing", ], ): - power = tfr_multitaper( - epochs, + power = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, return_itc=False, + average=True, ) ax.set_title(title) # Plot results. Baseline correct based on first 100 ms. 
@@ -125,8 +118,7 @@ [0], baseline=(0.0, 0.1), mode="mean", -        vmin=vmin, -        vmax=vmax, +        vlim=(vmin, vmax), axes=ax, show=False, colorbar=False, @@ -146,11 +138,11 @@ fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") fmin, fmax = freqs[[0, -1]] for width, ax in zip((0.2, 0.7, 3.0), axs): -    power = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, width=width) +    power = epochs.compute_tfr(method="stockwell", freqs=(fmin, fmax), width=width) power.plot( [0], baseline=(0.0, 0.1), mode="mean", axes=ax, show=False, colorbar=False ) -    ax.set_title("Sim: Using S transform, width = {:0.1f}".format(width)) +    ax.set_title(f"Sim: Using S transform, width = {width:0.1f}") # %% # Morlet Wavelets @@ -164,13 +156,14 @@ fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") all_n_cycles = [1, 3, freqs / 2.0] for n_cycles, ax in zip(all_n_cycles, axs): -    power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False) +    power = epochs.compute_tfr( +        method="morlet", freqs=freqs, n_cycles=n_cycles, return_itc=False, average=True +    ) power.plot( [0], baseline=(0.0, 0.1), mode="mean", -        vmin=vmin, -        vmax=vmax, +        vlim=(vmin, vmax), axes=ax, show=False, colorbar=False, @@ -190,7 +183,9 @@ fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") bandwidths = [1.0, 2.0, 4.0] for bandwidth, ax in zip(bandwidths, axs): -    data = np.zeros((len(ch_names), freqs.size, epochs.times.size), dtype=complex) +    data = np.zeros( +        (len(epochs), len(ch_names), freqs.size, epochs.times.size), dtype=complex +    ) for idx, freq in enumerate(freqs): # Filter raw data and re-epoch to avoid the filter being longer than # the epoch data for low frequencies and short epochs, such as here. @@ -210,17 +205,13 @@ epochs_hilb = Epochs( raw_filter, events, tmin=0, tmax=n_times / sfreq, baseline=(0, 0.1) ) -        tfr_data = epochs_hilb.get_data() -        tfr_data = tfr_data * tfr_data.conj()  # compute power -        tfr_data = np.mean(tfr_data, axis=0)  # average over epochs -        data[:, idx] = tfr_data -    power = AverageTFR(info, data, epochs.times, freqs, nave=n_epochs) -    power.plot( +        data[:, :, idx] = epochs_hilb.get_data() +    power = EpochsTFRArray(epochs.info, data, epochs.times, freqs, method="hilbert") +    power.average().plot( [0], baseline=(0.0, 0.1), mode="mean", -        vmin=-0.1, -        vmax=0.1, +        vlim=(0, 0.1), axes=ax, show=False, colorbar=False, @@ -241,8 +232,8 @@ # :class:`mne.time_frequency.EpochsTFR` is returned. n_cycles = freqs / 2.0 -power = tfr_morlet( -    epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False +power = epochs.compute_tfr( +    method="morlet", freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False ) print(type(power)) avgpower = power.average() @@ -250,8 +241,7 @@ [0], baseline=(0.0, 0.1), mode="mean", -    vmin=vmin, -    vmax=vmax, +    vlim=(vmin, vmax), title="Using Morlet wavelets and EpochsTFR", show=False, ) @@ -260,10 +250,12 @@ # Operating on arrays # ------------------- # -# MNE also has versions of the functions above which operate on numpy arrays -# instead of MNE objects. They expect inputs of the shape -# ``(n_epochs, n_channels, n_times)``. They will also return a numpy array -# of shape ``(n_epochs, n_channels, n_freqs, n_times)``. +# MNE-Python also has functions that operate on :class:`NumPy arrays <numpy.ndarray>` +# instead of MNE-Python objects. These are :func:`~mne.time_frequency.tfr_array_morlet` +# and :func:`~mne.time_frequency.tfr_array_multitaper`.
They expect inputs of the shape +# ``(n_epochs, n_channels, n_times)`` and return an array of shape +# ``(n_epochs, n_channels, n_freqs, n_times)`` (or optionally, can collapse the epochs +# dimension if you want average power or inter-trial coherence; see ``output`` param). power = tfr_array_morlet( epochs.get_data(), @@ -271,12 +263,16 @@ freqs=freqs, n_cycles=n_cycles, output="avg_power", + zero_mean=False, +) +# Put it into a TFR container for easy plotting +tfr = AverageTFRArray( + info=epochs.info, data=power, times=epochs.times, freqs=freqs, nave=len(epochs) +) +tfr.plot( + baseline=(0.0, 0.1), + picks=[0], + mode="mean", + vlim=(vmin, vmax), + title="TFR calculated on a NumPy array", ) -# Baseline the output -rescale(power, epochs.times, (0.0, 0.1), mode="mean", copy=False) -fig, ax = plt.subplots(layout="constrained") -x, y = centers_to_edges(epochs.times * 1000, freqs) -mesh = ax.pcolormesh(x, y, power[0], cmap="RdBu_r", vmin=vmin, vmax=vmax) -ax.set_title("TFR calculated on a numpy array") -ax.set(ylim=freqs[[0, -1]], xlabel="Time (ms)") -fig.colorbar(mesh) diff --git a/examples/visualization/3d_to_2d.py b/examples/visualization/3d_to_2d.py index 6d8e8674fa3..47b223e8396 100644 --- a/examples/visualization/3d_to_2d.py +++ b/examples/visualization/3d_to_2d.py @@ -23,8 +23,6 @@ # Copyright the MNE-Python contributors. # %% -from os.path import dirname -from pathlib import Path import numpy as np from matplotlib import pyplot as plt @@ -43,8 +41,7 @@ ecog_data_fname = subjects_dir / "sample_ecog_ieeg.fif" # We've already clicked and exported -layout_path = Path(dirname(mne.__file__)) / "data" / "image" -layout_name = "custom_layout.lout" +layout_name = subjects_dir / "custom_layout.lout" # %% # Load data @@ -128,10 +125,10 @@ # # Generate a layout from our clicks and normalize by the image # print('Generating and saving layout...') # lt = click.to_layout() -# lt.save(layout_path / layout_name) # save if we want +# lt.save(layout_name) # save if we want # # We've already got the layout, load it -lt = mne.channels.read_layout(layout_path / layout_name, scale=False) +lt = mne.channels.read_layout(layout_name, scale=False) x = lt.pos[:, 0] * float(im.shape[1]) y = (1 - lt.pos[:, 1]) * float(im.shape[0]) # Flip the y-position fig, ax = plt.subplots(layout="constrained") diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index f75869383a9..c01cdd80d71 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -111,7 +111,7 @@ colorbar=False, sphere=(0.0, 0.0, 0.0, 0.09), ) - ax.set_title("%s %s" % (ch_type.upper(), extr), fontsize=14) + ax.set_title(f"{ch_type.upper()} {extr}", fontsize=14) # %% # More advanced usage diff --git a/examples/visualization/evoked_whitening.py b/examples/visualization/evoked_whitening.py index e213408276a..9a474d9ea36 100644 --- a/examples/visualization/evoked_whitening.py +++ b/examples/visualization/evoked_whitening.py @@ -84,7 +84,7 @@ print("Covariance estimates sorted from best to worst") for c in noise_covs: - print("%s : %s" % (c["method"], c["loglik"])) + print(f'{c["method"]} : {c["loglik"]}') # %% # Show the evoked data: diff --git a/examples/visualization/eyetracking_plot_heatmap.py b/examples/visualization/eyetracking_plot_heatmap.py index c12aa689984..9225493ef88 100644 --- a/examples/visualization/eyetracking_plot_heatmap.py +++ b/examples/visualization/eyetracking_plot_heatmap.py @@ -24,7 +24,6 @@ # :ref:`example data `: eye-tracking data recorded from SR 
research's # ``'.asc'`` file format. - import matplotlib.pyplot as plt import mne @@ -35,6 +34,12 @@ stim_fpath = task_fpath / "stim" / "naturalistic.png" raw = mne.io.read_raw_eyelink(et_fpath) +calibration = mne.preprocessing.eyetracking.read_eyelink_calibration( +    et_fpath, +    screen_resolution=(1920, 1080), +    screen_size=(0.53, 0.3), +    screen_distance=0.9, +)[0] # %% # Process and epoch the data @@ -44,12 +49,8 @@ mne.preprocessing.eyetracking.interpolate_blinks(raw, interpolate_gaze=True) raw.annotations.rename({"dvns": "natural"})  # more intuitive -event_ids = {"natural": 1} -events, event_dict = mne.events_from_annotations(raw, event_id=event_ids) -epochs = mne.Epochs( -    raw, events=events, event_id=event_dict, tmin=0, tmax=20, baseline=None -) +epochs = mne.Epochs(raw, event_id=["natural"], tmin=0, tmax=20, baseline=None) # %% @@ -62,9 +63,8 @@ # screen resolution of the participant screen (1920x1080) as the width and height. We # can also use the sigma parameter to smooth the plot. -px_width, px_height = 1920, 1080 cmap = plt.get_cmap("viridis") -plot_gaze(epochs["natural"], width=px_width, height=px_height, cmap=cmap, sigma=50) +plot_gaze(epochs["natural"], calibration=calibration, cmap=cmap, sigma=50) # %% # Overlaying plots with images @@ -81,10 +81,26 @@ ax.imshow(plt.imread(stim_fpath)) plot_gaze( epochs["natural"], -    width=px_width, -    height=px_height, +    calibration=calibration, vlim=(0.0003, None), sigma=50, cmap=cmap, axes=ax, ) + +# %% +# Displaying the heatmap in units of visual angle +# ----------------------------------------------- +# +# In scientific publications it is common to report gaze data as the visual angle +# from the participant's eye to the screen. We can convert the units of our gaze data to +# radians of visual angle before plotting the heatmap: + +# %% +epochs.load_data() +mne.preprocessing.eyetracking.convert_units(epochs, calibration, to="radians") +plot_gaze( +    epochs["natural"], +    calibration=calibration, +    sigma=50, +) diff --git a/examples/visualization/topo_compare_conditions.py b/examples/visualization/topo_compare_conditions.py index 7572eab47e5..3ab4e46d5f2 100644 --- a/examples/visualization/topo_compare_conditions.py +++ b/examples/visualization/topo_compare_conditions.py @@ -19,7 +19,6 @@ # %% - import matplotlib.pyplot as plt import mne diff --git a/examples/visualization/topo_customized.py b/examples/visualization/topo_customized.py index 2d3c6662ebc..2303961f9da 100644 --- a/examples/visualization/topo_customized.py +++ b/examples/visualization/topo_customized.py @@ -19,7 +19,6 @@ # %% - import matplotlib.pyplot as plt import numpy as np diff --git a/mne/__init__.py b/mne/__init__.py index 594eddefdd2..10ff0c23738 100644 --- a/mne/__init__.py +++ b/mne/__init__.py @@ -23,11 +23,10 @@ __version__ = version("mne") except Exception: -    try: -        from ._version import __version__ -    except ImportError: -        __version__ = "0.0.0" +    __version__ = "0.0.0" + (__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__) + # initialize logging from .utils import set_log_level, set_log_file diff --git a/mne/_fiff/__init__.py b/mne/_fiff/__init__.py index 6402d78b325..877068fe54d 100644 --- a/mne/_fiff/__init__.py +++ b/mne/_fiff/__init__.py @@ -7,22 +7,3 @@ # All imports should be done directly to submodules, so we don't import # anything here or use lazy_loader. - -# This warn import (made private as _warn) is just for the temporary -_io_dep_getattr and can be removed in 1.6 along with _dep_msg and _io_dep_getattr.
-from ..utils import warn as _warn - - -_dep_msg = ( -    "is deprecated will be removed in 1.6, use documented public API instead. " -    "If no appropriate public API exists, please open an issue on GitHub." -) - - -def _io_dep_getattr(name, mod): -    import importlib - -    fiff_mod = importlib.import_module(f"mne._fiff.{mod}") -    obj = getattr(fiff_mod, name) -    _warn(f"mne.io.{mod}.{name} {_dep_msg}", FutureWarning) -    return obj diff --git a/mne/_fiff/_digitization.py b/mne/_fiff/_digitization.py index dab0427ac6a..dcbf9e8d24d 100644 --- a/mne/_fiff/_digitization.py +++ b/mne/_fiff/_digitization.py @@ -132,14 +132,15 @@ def __repr__(self):  # noqa: D105 id_ = _cardinal_kind_rev.get(self["ident"], "Unknown cardinal") else: id_ = _dig_kind_proper[_dig_kind_rev.get(self["kind"], "unknown")] -        id_ = "%s #%s" % (id_, self["ident"]) +        id_ = f"{id_} #{self['ident']}" id_ = id_.rjust(10) cf = _coord_frame_name(self["coord_frame"]) +        x, y, z = self["r"] if "voxel" in cf: -            pos = ("(%0.1f, %0.1f, %0.1f)" % tuple(self["r"])).ljust(25) +            pos = (f"({x:0.1f}, {y:0.1f}, {z:0.1f})").ljust(25) else: -            pos = ("(%0.1f, %0.1f, %0.1f) mm" % tuple(1000 * self["r"])).ljust(25) -        return "<DigPoint | %s : %s : %s frame>" % (id_, pos, cf) +            pos = (f"({x * 1e3:0.1f}, {y * 1e3:0.1f}, {z * 1e3:0.1f}) mm").ljust(25) +        return f"<DigPoint | {id_} : {pos} : {cf} frame>" # speed up info copy by only deep copying the mutable item def __deepcopy__(self, memodict): @@ -362,8 +363,8 @@ def _coord_frame_const(coord_frame): if not isinstance(coord_frame, str) or coord_frame not in _str_to_frame: raise ValueError( -            "coord_frame must be one of %s, got %s" -            % (sorted(_str_to_frame.keys()), coord_frame) +            f"coord_frame must be one of {sorted(_str_to_frame.keys())}, got " +            f"{coord_frame}" ) return _str_to_frame[coord_frame] @@ -414,9 +415,7 @@ def _make_dig_points( if lpa is not None: lpa = np.asarray(lpa) if lpa.shape != (3,): -            raise ValueError( -                "LPA should have the shape (3,) instead of %s" % (lpa.shape,) -            ) +            raise ValueError(f"LPA should have the shape (3,) instead of {lpa.shape}") dig.append( { "r": lpa, @@ -429,7 +428,7 @@ nasion = np.asarray(nasion) if nasion.shape != (3,): raise ValueError( -                "Nasion should have the shape (3,) instead of %s" % (nasion.shape,) +                f"Nasion should have the shape (3,) instead of {nasion.shape}" ) dig.append( { @@ -442,9 +441,7 @@ if rpa is not None: rpa = np.asarray(rpa) if rpa.shape != (3,): -            raise ValueError( -                "RPA should have the shape (3,) instead of %s" % (rpa.shape,) -            ) +            raise ValueError(f"RPA should have the shape (3,) instead of {rpa.shape}") dig.append( { "r": rpa, @@ -457,8 +454,7 @@ hpi = np.asarray(hpi) if hpi.ndim != 2 or hpi.shape[1] != 3: raise ValueError( -                "HPI should have the shape (n_points, 3) instead " -                "of %s" % (hpi.shape,) +                f"HPI should have the shape (n_points, 3) instead of {hpi.shape}" ) for idx, point in enumerate(hpi): dig.append( @@ -473,8 +469,8 @@ extra_points = np.asarray(extra_points) if len(extra_points) and extra_points.shape[1] != 3: raise ValueError( -                "Points should have the shape (n_points, 3) " -                "instead of %s" % (extra_points.shape,) +                "Points should have the shape (n_points, 3) instead of " +                f"{extra_points.shape}" ) for idx, point in enumerate(extra_points): dig.append( diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 483ddc34b52..a2928a9f2a6 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -454,8 +454,8 @@ def _check_set(ch, projs, ch_type): for proj in projs: if ch["ch_name"] in
proj["data"]["col_names"]: raise RuntimeError( - "Cannot change channel type for channel %s " - 'in projector "%s"' % (ch["ch_name"], proj["desc"]) + f'Cannot change channel type for channel {ch["ch_name"]} in ' + f'projector "{proj["desc"]}"' ) ch["kind"] = new_kind @@ -482,7 +482,7 @@ def _get_channel_positions(self, picks=None): n_zero = np.sum(np.sum(np.abs(pos), axis=1) == 0) if n_zero > 1: # XXX some systems have origin (0, 0, 0) raise ValueError( - "Could not extract channel positions for " "{} channels".format(n_zero) + f"Could not extract channel positions for {n_zero} channels" ) return pos @@ -507,8 +507,8 @@ def _set_channel_positions(self, pos, names): ) pos = np.asarray(pos, dtype=np.float64) if pos.shape[-1] != 3 or pos.ndim != 2: - msg = "Channel positions must have the shape (n_points, 3) " "not %s." % ( - pos.shape, + msg = ( + f"Channel positions must have the shape (n_points, 3) not {pos.shape}." ) raise ValueError(msg) for name, p in zip(names, pos): @@ -545,12 +545,12 @@ def set_channel_types(self, mapping, *, on_unit_change="warn", verbose=None): Notes ----- - The following sensor types are accepted: + The following :term:`sensor types` are accepted: - ecg, eeg, emg, eog, exci, ias, misc, resp, seeg, dbs, stim, syst, - ecog, hbo, hbr, fnirs_cw_amplitude, fnirs_fd_ac_amplitude, - fnirs_fd_phase, fnirs_od, eyetrack_pos, eyetrack_pupil, - temperature, gsr + bio, chpi, csd, dbs, dipole, ecg, ecog, eeg, emg, eog, exci, + eyegaze, fnirs_cw_amplitude, fnirs_fd_ac_amplitude, fnirs_fd_phase, + fnirs_od, gof, gsr, hbo, hbr, ias, misc, pupil, ref_meg, resp, + seeg, stim, syst, temperature. .. versionadded:: 0.9.0 """ @@ -568,9 +568,9 @@ def set_channel_types(self, mapping, *, on_unit_change="warn", verbose=None): c_ind = ch_names.index(ch_name) if ch_type not in _human2fiff: raise ValueError( - "This function cannot change to this " - "channel type: %s. Accepted channel types " - "are %s." % (ch_type, ", ".join(sorted(_human2unit.keys()))) + f"This function cannot change to this channel type: {ch_type}. " + "Accepted channel types are " + f"{', '.join(sorted(_human2unit.keys()))}." ) # Set sensor type _check_set(info["chs"][c_ind], info["projs"], ch_type) @@ -578,8 +578,8 @@ def set_channel_types(self, mapping, *, on_unit_change="warn", verbose=None): unit_new = _human2unit[ch_type] if unit_old not in _unit2human: raise ValueError( - "Channel '%s' has unknown unit (%s). Please " - "fix the measurement info of your data." % (ch_name, unit_old) + f"Channel '{ch_name}' has unknown unit ({unit_old}). Please fix the" + " measurement info of your data." ) if unit_old != _human2unit[ch_type]: this_change = (_unit2human[unit_old], _unit2human[unit_new]) @@ -1104,9 +1104,9 @@ class Info(dict, SetChannelsMixin, MontageMixin, ContainsMixin): The transformation from 4D/CTF head coordinates to Neuromag head coordinates. This is only present in 4D/CTF data. custom_ref_applied : int - Whether a custom (=other than average) reference has been applied to - the EEG data. This flag is checked by some algorithms that require an - average reference to be set. + Whether a custom (=other than an average projector) reference has been + applied to the EEG data. This flag is checked by some algorithms that + require an average reference to be set. description : str | None String description of the recording. 
dev_ctf_t : Transform | None @@ -1659,7 +1659,7 @@ def __repr__(self): non_empty -= 1 # don't count as non-empty elif k == "bads": if v: - entr = "{} items (".format(len(v)) + entr = f"{len(v)} items (" entr += ", ".join(v) entr = shorten(entr, MAX_WIDTH, placeholder=" ...") + ")" else: @@ -1695,11 +1695,11 @@ def __repr__(self): if not np.allclose(v["trans"], np.eye(v["trans"].shape[0])): frame1 = _coord_frame_name(v["from"]) frame2 = _coord_frame_name(v["to"]) - entr = "%s -> %s transform" % (frame1, frame2) + entr = f"{frame1} -> {frame2} transform" else: entr = "" elif k in ["sfreq", "lowpass", "highpass"]: - entr = "{:.1f} Hz".format(v) + entr = f"{v:.1f} Hz" elif isinstance(v, str): entr = shorten(v, MAX_WIDTH, placeholder=" ...") elif k == "chs": @@ -1719,7 +1719,7 @@ def __repr__(self): try: this_len = len(v) except TypeError: - entr = "{}".format(v) if v is not None else "" + entr = f"{v}" if v is not None else "" else: if this_len > 0: entr = "%d item%s (%s)" % ( @@ -1731,7 +1731,7 @@ def __repr__(self): entr = "" if entr != "": non_empty += 1 - strs.append("%s: %s" % (k, entr)) + strs.append(f"{k}: {entr}") st = "\n ".join(sorted(strs)) st += "\n>" st %= non_empty @@ -1784,12 +1784,8 @@ def _check_consistency(self, prepend_error=""): or self["meas_date"].tzinfo is not datetime.timezone.utc ): raise RuntimeError( - '%sinfo["meas_date"] must be a datetime ' - "object in UTC or None, got %r" - % ( - prepend_error, - repr(self["meas_date"]), - ) + f'{prepend_error}info["meas_date"] must be a datetime object in UTC' + f' or None, got {repr(self["meas_date"])!r}' ) chs = [ch["ch_name"] for ch in self["chs"]] @@ -1799,8 +1795,8 @@ def _check_consistency(self, prepend_error=""): or self["nchan"] != len(chs) ): raise RuntimeError( - "%sinfo channel name inconsistency detected, " - "please notify mne-python developers" % (prepend_error,) + f"{prepend_error}info channel name inconsistency detected, please " + "notify MNE-Python developers" ) # make sure we have the proper datatypes @@ -2649,16 +2645,9 @@ def _check_dates(info, prepend_error=""): or value[key_2] > np.iinfo(">i4").max ): raise RuntimeError( - "%sinfo[%s][%s] must be between " - '"%r" and "%r", got "%r"' - % ( - prepend_error, - key, - key_2, - np.iinfo(">i4").min, - np.iinfo(">i4").max, - value[key_2], - ), + f"{prepend_error}info[{key}][{key_2}] must be between " + f'"{np.iinfo(">i4").min!r}" and "{np.iinfo(">i4").max!r}", got ' + f'"{value[key_2]!r}"' ) meas_date = info.get("meas_date") @@ -2671,14 +2660,9 @@ def _check_dates(info, prepend_error=""): or meas_date_stamp[0] > np.iinfo(">i4").max ): raise RuntimeError( - '%sinfo["meas_date"] seconds must be between "%r" ' - 'and "%r", got "%r"' - % ( - prepend_error, - (np.iinfo(">i4").min, 0), - (np.iinfo(">i4").max, 0), - meas_date_stamp[0], - ) + f'{prepend_error}info["meas_date"] seconds must be between ' + f'"{(np.iinfo(">i4").min, 0)!r}" and "{(np.iinfo(">i4").max, 0)!r}", got ' + f'"{meas_date_stamp[0]!r}"' ) @@ -2954,8 +2938,8 @@ def _merge_info_values(infos, key, verbose=None): """ values = [d[key] for d in infos] msg = ( - "Don't know how to merge '%s'. Make sure values are " - "compatible, got types:\n %s" % (key, [type(v) for v in values]) + f"Don't know how to merge '{key}'. Make sure values are compatible, got types:" + f"\n {[type(v) for v in values]}" ) def _flatten(lists): @@ -3171,11 +3155,14 @@ def create_info(ch_names, sfreq, ch_types="misc", verbose=None): sfreq : float Sample rate of the data. 
ch_types : list of str | str - Channel types, default is ``'misc'`` which is not a - :term:`data channel `. - Currently supported fields are 'ecg', 'bio', 'stim', 'eog', 'misc', - 'seeg', 'dbs', 'ecog', 'mag', 'eeg', 'ref_meg', 'grad', 'emg', 'hbr' - 'eyetrack' or 'hbo'. + Channel types, default is ``'misc'`` which is a + :term:`non-data channel `. + Currently supported fields are 'bio', 'chpi', 'csd', 'dbs', 'dipole', + 'ecg', 'ecog', 'eeg', 'emg', 'eog', 'exci', 'eyegaze', + 'fnirs_cw_amplitude', 'fnirs_fd_ac_amplitude', 'fnirs_fd_phase', + 'fnirs_od', 'gof', 'gsr', 'hbo', 'hbr', 'ias', 'misc', 'pupil', + 'ref_meg', 'resp', 'seeg', 'stim', 'syst', 'temperature' (see also + :term:`sensor types`). If str, then all channels are assumed to be of the same type. %(verbose)s @@ -3195,12 +3182,18 @@ def create_info(ch_names, sfreq, ch_types="misc", verbose=None): Proper units of measure: - * V: eeg, eog, seeg, dbs, emg, ecg, bio, ecog - * T: mag + * V: eeg, eog, seeg, dbs, emg, ecg, bio, ecog, resp, fnirs_fd_ac_amplitude, + fnirs_cw_amplitude, fnirs_od + * T: mag, chpi, ref_meg * T/m: grad * M: hbo, hbr + * rad: fnirs_fd_phase * Am: dipole - * AU: misc + * S: gsr + * C: temperature + * V/m²: csd + * GOF: gof + * AU: misc, stim, eyegaze, pupil """ try: ch_names = operator.index(ch_names) # int-like @@ -3218,8 +3211,8 @@ def create_info(ch_names, sfreq, ch_types="misc", verbose=None): ch_types = np.atleast_1d(np.array(ch_types, np.str_)) if ch_types.ndim != 1 or len(ch_types) != nchan: raise ValueError( - "ch_types and ch_names must be the same length " - "(%s != %s) for ch_types=%s" % (len(ch_types), nchan, ch_types) + f"ch_types and ch_names must be the same length ({len(ch_types)} != " + f"{nchan}) for ch_types={ch_types}" ) info = _empty_info(sfreq) ch_types_dict = get_channel_type_constants(include_defaults=True) diff --git a/mne/_fiff/open.py b/mne/_fiff/open.py index d1794317772..5bfcb83a951 100644 --- a/mne/_fiff/open.py +++ b/mne/_fiff/open.py @@ -13,7 +13,13 @@ from ..utils import _file_like, logger, verbose, warn from .constants import FIFF -from .tag import Tag, _call_dict_names, _matrix_info, read_tag, read_tag_info +from .tag import ( + Tag, + _call_dict_names, + _matrix_info, + _read_tag_header, + read_tag, +) from .tree import dir_tree_find, make_dir_tree @@ -139,7 +145,7 @@ def _fiff_open(fname, fid, preload): with fid as fid_old: fid = BytesIO(fid_old.read()) - tag = read_tag_info(fid) + tag = _read_tag_header(fid, 0) # Check that this looks like a fif file prefix = f"file {repr(fname)} does not" @@ -152,7 +158,7 @@ def _fiff_open(fname, fid, preload): if tag.size != 20: raise ValueError(f"{prefix} start with a file id tag") - tag = read_tag(fid) + tag = read_tag(fid, tag.next_pos) if tag.kind != FIFF.FIFF_DIR_POINTER: raise ValueError(f"{prefix} have a directory pointer") @@ -176,16 +182,15 @@ def _fiff_open(fname, fid, preload): directory = dir_tag.data read_slow = False if read_slow: - fid.seek(0, 0) + pos = 0 + fid.seek(pos, 0) directory = list() - while tag.next >= 0: - pos = fid.tell() - tag = read_tag_info(fid) + while pos is not None: + tag = _read_tag_header(fid, pos) if tag is None: break # HACK : to fix file ending with empty tag... 
- else: - tag.pos = pos - directory.append(tag) + pos = tag.next_pos + directory.append(tag) tree, _ = make_dir_tree(fid, directory) @@ -258,12 +263,12 @@ def show_fiff( tag_id=tag, show_bytes=show_bytes, ) - if output == str: + if output is str: out = "\n".join(out) return out -def _find_type(value, fmts=["FIFF_"], exclude=["FIFF_UNIT"]): +def _find_type(value, fmts=("FIFF_",), exclude=("FIFF_UNIT",)): """Find matching values.""" value = int(value) vals = [ @@ -309,7 +314,7 @@ def _show_tree( for k, kn, size, pos, type_ in zip(kinds[:-1], kinds[1:], sizes, poss, types): if not tag_found and k != tag_id: continue - tag = Tag(k, size, 0, pos) + tag = Tag(kind=k, type=type_, size=size, next=FIFF.FIFFV_NEXT_NONE, pos=pos) if read_limit is None or size <= read_limit: try: tag = read_tag(fid, pos) @@ -342,17 +347,17 @@ def _show_tree( elif isinstance(tag.data, (list, tuple)): postpend += " ... list len=" + str(len(tag.data)) elif issparse(tag.data): - postpend += " ... sparse (%s) shape=%s" % ( - tag.data.getformat(), - tag.data.shape, + postpend += ( + f" ... sparse ({tag.data.getformat()}) shape=" + f"{tag.data.shape}" ) else: postpend += " ... type=" + str(type(tag.data)) - postpend = ">" * 20 + "BAD" if not good else postpend + postpend = ">" * 20 + f"BAD @{pos}" if not good else postpend matrix_info = _matrix_info(tag) if matrix_info is not None: _, type_, _, _ = matrix_info - type_ = _call_dict_names.get(type_, "?%s?" % (type_,)) + type_ = _call_dict_names.get(type_, f"?{type_}?") this_type = "/".join(this_type) out += [ f"{next_idt}{prepend}{str(k).ljust(4)} = " diff --git a/mne/_fiff/pick.py b/mne/_fiff/pick.py index 4c5854f36fe..2af49c7b921 100644 --- a/mne/_fiff/pick.py +++ b/mne/_fiff/pick.py @@ -17,7 +17,6 @@ fill_doc, logger, verbose, - warn, ) from .constants import FIFF @@ -237,10 +236,9 @@ def channel_type(info, idx): type : str Type of channel. Will be one of:: - {'grad', 'mag', 'eeg', 'csd', 'stim', 'eog', 'emg', 'ecg', - 'ref_meg', 'resp', 'exci', 'ias', 'syst', 'misc', 'seeg', 'dbs', - 'bio', 'chpi', 'dipole', 'gof', 'ecog', 'hbo', 'hbr', - 'temperature', 'gsr', 'eyetrack'} + {'bio', 'chpi', 'dbs', 'dipole', 'ecg', 'ecog', 'eeg', 'emg', + 'eog', 'exci', 'eyetrack', 'fnirs', 'gof', 'gsr', 'ias', 'misc', + 'meg', 'ref_meg', 'resp', 'seeg', 'stim', 'syst', 'temperature'} """ # This is faster than the original _channel_type_old now in test_pick.py # because it uses (at most!) two dict lookups plus one conditional @@ -250,7 +248,7 @@ def channel_type(info, idx): first_kind = _first_rule[ch["kind"]] except KeyError: raise ValueError( - 'Unknown channel type (%s) for channel "%s"' % (ch["kind"], ch["ch_name"]) + f'Unknown channel type ({ch["kind"]}) for channel "{ch["ch_name"]}"' ) if first_kind in _second_rules: key, second_rule = _second_rules[first_kind] @@ -259,7 +257,7 @@ def channel_type(info, idx): @verbose -def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None): +def pick_channels(ch_names, include, exclude=(), ordered=True, *, verbose=None): """Pick channels by names. Returns the indices of ``ch_names`` in ``include`` but not in ``exclude``. 
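The signature change above finalizes the ``ordered`` transition: the default flips from the deprecated ``None`` to ``True``, and the ``FutureWarning`` scaffolding is deleted in the hunk that follows. A short sketch of the resulting behavior (hypothetical channel names, for illustration only):

from mne import pick_channels

ch_names = ["EEG 001", "EEG 002", "EEG 003"]  # hypothetical names
# new default ordered=True: a missing name now raises
try:
    pick_channels(ch_names, include=["EEG 003", "EEG 999"])
except ValueError as err:
    print(err)
# ordered=False keeps the old behavior: drop missing names, return sorted picks
print(pick_channels(ch_names, include=["EEG 003", "EEG 001"], ordered=False))  # [0 2]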
@@ -291,7 +289,7 @@ def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None): """ if len(np.unique(ch_names)) != len(ch_names): raise RuntimeError("ch_names is not a unique list, picking is unsafe") - _validate_type(ordered, (bool, None), "ordered") + _validate_type(ordered, bool, "ordered") _check_excludes_includes(include) _check_excludes_includes(exclude) if not isinstance(include, list): @@ -307,35 +305,12 @@ def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None): sel.append(ch_names.index(name)) else: missing.append(name) - dep_msg = ( - "The default for pick_channels will change from ordered=False to " - "ordered=True in 1.5" - ) - if len(missing): - if ordered is None: - warn( - f"{dep_msg} and this will result in an error because the " - f"following channel names are missing:\n{missing}\n" - "Either fix your included names or explicitly pass " - "ordered=False.", - FutureWarning, - ) - elif ordered: - raise ValueError( - "Missing channels from ch_names required by " - "include:\n%s" % (missing,) - ) + if len(missing) and ordered: + raise ValueError( + f"Missing channels from ch_names required by include:\n{missing}" + ) if not ordered: - out_sel = np.unique(sel) - if ordered is None and not np.array_equal(out_sel, sel): - warn( - f"{dep_msg} and this will result in a change of behavior " - "because the resulting channel order will not match. Either " - "use a channel order that matches your instance or " - "pass ordered=False.", - FutureWarning, - ) - sel = out_sel + sel = np.unique(sel) return np.array(sel, int) @@ -436,7 +411,7 @@ def _check_meg_type(meg, allow_auto=False): allowed_types += ["auto"] if allow_auto else [] if meg not in allowed_types: raise ValueError( - "meg value must be one of %s or bool, not %s" % (allowed_types, meg) + f"meg value must be one of {allowed_types} or bool, not {meg}" ) @@ -650,7 +625,8 @@ def pick_info(info, sel=(), copy=True, verbose=None): return info elif len(sel) == 0: raise ValueError("No channels match the selection.") - n_unique = len(np.unique(np.arange(len(info["ch_names"]))[sel])) + ch_set = set(info["ch_names"][k] for k in sel) + n_unique = len(ch_set) if n_unique != len(sel): raise ValueError( "Found %d / %d unique names, sel is not unique" % (n_unique, len(sel)) @@ -688,6 +664,15 @@ def pick_info(info, sel=(), copy=True, verbose=None): if info.get("custom_ref_applied", False) and not _electrode_types(info): with info._unlock(): info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_OFF + # remove unused projectors + if info.get("projs", False): + projs = list() + for p in info["projs"]: + if any(ch_name in ch_set for ch_name in p["data"]["col_names"]): + projs.append(p) + if len(projs) != len(info["projs"]): + with info._unlock(): + info["projs"] = projs info._check_consistency() return info @@ -707,7 +692,7 @@ def _has_kit_refs(info, picks): @verbose def pick_channels_forward( - orig, include=[], exclude=[], ordered=None, copy=True, *, verbose=None + orig, include=(), exclude=(), ordered=True, copy=True, *, verbose=None ): """Pick channels from forward operator. @@ -798,8 +783,8 @@ def pick_types_forward( seeg=False, ecog=False, dbs=False, - include=[], - exclude=[], + include=(), + exclude=(), ): """Pick by channel type and names from a forward operator. 
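The new ``# remove unused projectors`` block above changes ``pick_info`` semantics: any projector whose channels are all dropped by the selection is now removed as well (exercised by ``test_clean_info_bads`` later in this diff). A minimal sketch with hypothetical channels:

import numpy as np
import mne

info = mne.create_info(
    ["EEG 001", "EEG 002", "MISC 001"], 1000.0, ["eeg", "eeg", "misc"]
)
raw = mne.io.RawArray(np.zeros((3, 1000)), info)
raw.set_eeg_reference(projection=True)  # average-ref projector over the EEG pair
assert len(raw.info["projs"]) == 1
info_misc = mne.pick_info(raw.info, sel=[2])  # keep only MISC 001
assert len(info_misc["projs"]) == 0  # projector referenced no picked channel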
@@ -894,7 +879,7 @@ def channel_indices_by_type(info, picks=None): @verbose def pick_channels_cov( - orig, include=[], exclude="bads", ordered=None, copy=True, *, verbose=None + orig, include=(), exclude="bads", ordered=True, copy=True, *, verbose=None ): """Pick channels from covariance matrix. @@ -983,8 +968,7 @@ def _contains_ch_type(info, ch_type): _check_option("ch_type", ch_type, valid_channel_types) if info is None: raise ValueError( - 'Cannot check for channels of type "%s" because info ' - "is None" % (ch_type,) + f'Cannot check for channels of type "{ch_type}" because info is None' ) return any(ch_type == channel_type(info, ii) for ii in range(info["nchan"])) @@ -1078,8 +1062,8 @@ def _check_excludes_includes(chs, info=None, allow_bads=False): chs = info["bads"] else: raise ValueError( - 'include/exclude must be list, tuple, ndarray, or "bads". ' - + "You provided type {}".format(type(chs)) + 'include/exclude must be list, tuple, ndarray, or "bads". You provided ' + f"type {type(chs)}." ) return chs @@ -1252,7 +1236,7 @@ def _picks_to_idx( extra_repr = ", treated as range(%d)" % (n_chan,) else: picks = none # let _picks_str_to_idx handle it - extra_repr = 'None, treated as "%s"' % (none,) + extra_repr = f'None, treated as "{none}"' # # slice @@ -1266,7 +1250,7 @@ def _picks_to_idx( picks = np.atleast_1d(picks) # this works even for picks == 'something' picks = np.array([], dtype=int) if len(picks) == 0 else picks if picks.ndim != 1: - raise ValueError("picks must be 1D, got %sD" % (picks.ndim,)) + raise ValueError(f"picks must be 1D, got {picks.ndim}D") if picks.dtype.char in ("S", "U"): picks = _picks_str_to_idx( info, @@ -1296,8 +1280,7 @@ def _picks_to_idx( # if len(picks) == 0 and not allow_empty: raise ValueError( - "No appropriate %s found for the given picks " - "(%r)" % (picks_on, orig_picks) + f"No appropriate {picks_on} found for the given picks ({orig_picks!r})" ) if (picks < -n_chan).any(): raise IndexError("All picks must be >= %d, got %r" % (-n_chan, orig_picks)) @@ -1341,8 +1324,8 @@ def _picks_str_to_idx( picks_generic = _pick_data_or_ica(info, exclude=exclude) if len(picks_generic) == 0 and orig_picks is None and not allow_empty: raise ValueError( - "picks (%s) yielded no channels, consider " - "passing picks explicitly" % (repr(orig_picks) + extra_repr,) + f"picks ({repr(orig_picks) + extra_repr}) yielded no channels, " + "consider passing picks explicitly" ) # @@ -1407,10 +1390,9 @@ def _picks_str_to_idx( if sum(any_found) == 0: if not allow_empty: raise ValueError( - "picks (%s) could not be interpreted as " - 'channel names (no channel "%s"), channel types (no ' - 'type "%s" present), or a generic type (just "all" or "data")' - % (repr(orig_picks) + extra_repr, str(bad_names), bad_type) + f"picks ({repr(orig_picks) + extra_repr}) could not be interpreted as " + f'channel names (no channel "{str(bad_names)}"), channel types (no type' + f' "{bad_type}" present), or a generic type (just "all" or "data")' ) picks = np.array([], int) elif sum(any_found) > 1: diff --git a/mne/_fiff/proj.py b/mne/_fiff/proj.py index 26bba36bc13..0036257d00c 100644 --- a/mne/_fiff/proj.py +++ b/mne/_fiff/proj.py @@ -729,7 +729,7 @@ def _write_proj(fid, projs, *, ch_names_mapping=None): def _check_projs(projs, copy=True): """Check that projs is a list of Projection.""" if not isinstance(projs, (list, tuple)): - raise TypeError("projs must be a list or tuple, got %s" % (type(projs),)) + raise TypeError(f"projs must be a list or tuple, got {type(projs)}") for pi, p in enumerate(projs): 
if not isinstance(p, Projection): raise TypeError( diff --git a/mne/_fiff/reference.py b/mne/_fiff/reference.py index 6bd422637bc..5822e87e17b 100644 --- a/mne/_fiff/reference.py +++ b/mne/_fiff/reference.py @@ -67,7 +67,7 @@ def _check_before_reference(inst, ref_from, ref_to, ch_type): else: extra = "channels supplied" if len(ref_to) == 0: - raise ValueError("No %s to apply the reference to" % (extra,)) + raise ValueError(f"No {extra} to apply the reference to") # After referencing, existing SSPs might not be valid anymore. projs_to_remove = [] @@ -301,8 +301,8 @@ def _check_can_reref(inst): FIFF.FIFFV_MNE_CUSTOM_REF_OFF, ): raise RuntimeError( - "Cannot set new reference on data with custom " - "reference type %r" % (_ref_dict[current_custom],) + "Cannot set new reference on data with custom reference type " + f"{_ref_dict[current_custom]!r}" ) @@ -363,8 +363,8 @@ def set_eeg_reference( if projection: # average reference projector if ref_channels != "average": raise ValueError( - "Setting projection=True is only supported for " - 'ref_channels="average", got %r.' % (ref_channels,) + 'Setting projection=True is only supported for ref_channels="average", ' + f"got {ref_channels!r}." ) # We need verbose='error' here in case we add projs sequentially if _has_eeg_average_ref_proj(inst.info, ch_type=ch_type, verbose="error"): diff --git a/mne/_fiff/tag.py b/mne/_fiff/tag.py index 1b87d828619..e1ae5ae571a 100644 --- a/mne/_fiff/tag.py +++ b/mne/_fiff/tag.py @@ -7,7 +7,9 @@ import html import re import struct +from dataclasses import dataclass from functools import partial +from typing import Any import numpy as np from scipy.sparse import csc_matrix, csr_matrix @@ -28,40 +30,16 @@ # HELPERS +@dataclass class Tag: - """Tag in FIF tree structure. + """Tag in FIF tree structure.""" - Parameters - ---------- - kind : int - Kind of Tag. - type_ : int - Type of Tag. - size : int - Size in bytes. - int : next - Position of next Tag. - pos : int - Position of Tag is the original file. - """ - - def __init__(self, kind, type_, size, next, pos=None): # noqa: D102 - self.kind = int(kind) - self.type = int(type_) - self.size = int(size) - self.next = int(next) - self.pos = pos if pos is not None else next - self.pos = int(self.pos) - self.data = None - - def __repr__(self): # noqa: D105 - attrs = list() - for attr in ("kind", "type", "size", "next", "pos", "data"): - try: - attrs.append(f"{attr} {getattr(self, attr)}") - except AttributeError: - pass - return "" + kind: int + type: int + size: int + next: int + pos: int + data: Any = None def __eq__(self, tag): # noqa: D105 return int( @@ -73,17 +51,15 @@ def __eq__(self, tag): # noqa: D105 and self.data == tag.data ) - -def read_tag_info(fid): - """Read Tag info (or header).""" - tag = _read_tag_header(fid) - if tag is None: - return None - if tag.next == 0: - fid.seek(tag.size, 1) - elif tag.next > 0: - fid.seek(tag.next, 0) - return tag + @property + def next_pos(self): + """The next tag position.""" + if self.next == FIFF.FIFFV_NEXT_SEQ: # 0 + return self.pos + 16 + self.size + elif self.next > 0: + return self.next + else: # self.next should be -1 if we get here + return None # safest to return None so that things like fid.seek die def _frombuffer_rows(fid, tag_size, dtype=None, shape=None, rlims=None): @@ -157,16 +133,18 @@ def _loc_to_eeg_loc(loc): # by the function names. 
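With ``Tag`` now a dataclass, the new ``next_pos`` property centralizes the chaining rule that callers previously re-implemented with ad-hoc ``fid.seek`` calls. A pure-Python sketch of the contract (the constant's value is assumed from the inline ``# 0`` comment above):

FIFFV_NEXT_SEQ = 0  # assumed value, per the inline comment in next_pos

def next_pos(pos, size, next_):
    """Start of the following 16-byte tag header, or None at end of chain."""
    if next_ == FIFFV_NEXT_SEQ:  # sequential: skip this header plus payload
        return pos + 16 + size
    elif next_ > 0:  # explicit jump to an absolute file offset
        return next_
    return None  # next_ == -1: callers stop instead of seeking

# so the directory walk in _fiff/open.py becomes:
#     pos = 0
#     while pos is not None:
#         tag = _read_tag_header(fid, pos)
#         ...
#         pos = tag.next_pos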
-def _read_tag_header(fid): +def _read_tag_header(fid, pos): """Read only the header of a Tag.""" - s = fid.read(4 * 4) + fid.seek(pos, 0) + s = fid.read(16) if len(s) != 16: where = fid.tell() - len(s) extra = f" in file {fid.name}" if hasattr(fid, "name") else "" warn(f"Invalid tag with only {len(s)}/16 bytes at position {where}{extra}") return None # struct.unpack faster than np.frombuffer, saves ~10% of time some places - return Tag(*struct.unpack(">iIii", s)) + kind, type_, size, next_ = struct.unpack(">iIii", s) + return Tag(kind, type_, size, next_, pos) def _read_matrix(fid, tag, shape, rlims): @@ -178,10 +156,10 @@ def _read_matrix(fid, tag, shape, rlims): matrix_coding, matrix_type, bit, dtype = _matrix_info(tag) + pos = tag.pos + 16 + fid.seek(pos + tag.size - 4, 0) if matrix_coding == "dense": # Find dimensions and return to the beginning of tag data - pos = fid.tell() - fid.seek(tag.size - 4, 1) ndim = int(np.frombuffer(fid.read(4), dtype=">i4").item()) fid.seek(-(ndim + 1) * 4, 1) dims = np.frombuffer(fid.read(4 * ndim), dtype=">i4")[::-1] @@ -205,8 +183,6 @@ def _read_matrix(fid, tag, shape, rlims): data.shape = dims else: # Find dimensions and return to the beginning of tag data - pos = fid.tell() - fid.seek(tag.size - 4, 1) ndim = int(np.frombuffer(fid.read(4), dtype=">i4").item()) fid.seek(-(ndim + 2) * 4, 1) dims = np.frombuffer(fid.read(4 * (ndim + 1)), dtype=">i4") @@ -388,7 +364,16 @@ def _read_old_pack(fid, tag, shape, rlims): def _read_dir_entry_struct(fid, tag, shape, rlims): """Read dir entry struct tag.""" - return [_read_tag_header(fid) for _ in range(tag.size // 16 - 1)] + pos = tag.pos + 16 + entries = list() + for offset in range(1, tag.size // 16): + ent = _read_tag_header(fid, pos + offset * 16) + # The position of the real tag on disk is stored in the "next" entry within the + # directory, so we need to overwrite ent.pos. For safety let's also overwrite + # ent.next to point nowhere + ent.pos, ent.next = ent.next, FIFF.FIFFV_NEXT_NONE + entries.append(ent) + return entries def _read_julian(fid, tag, shape, rlims): @@ -439,7 +424,7 @@ def _read_julian(fid, tag, shape, rlims): _call_dict_names[key] = dtype -def read_tag(fid, pos=None, shape=None, rlims=None): +def read_tag(fid, pos, shape=None, rlims=None): """Read a Tag from a file at a given position. Parameters @@ -462,9 +447,7 @@ def read_tag(fid, pos=None, shape=None, rlims=None): tag : Tag The Tag read. """ - if pos is not None: - fid.seek(pos, 0) - tag = _read_tag_header(fid) + tag = _read_tag_header(fid, pos) if tag is None: return tag if tag.size > 0: @@ -477,10 +460,6 @@ def read_tag(fid, pos=None, shape=None, rlims=None): except KeyError: raise Exception(f"Unimplemented tag data type {tag.type}") from None tag.data = fun(fid, tag, shape, rlims) - if tag.next != FIFF.FIFFV_NEXT_SEQ: - # f.seek(tag.next,0) - fid.seek(tag.next, 1) # XXX : fix? 
pb when tag.next < 0 - return tag diff --git a/mne/_fiff/tests/test_compensator.py b/mne/_fiff/tests/test_compensator.py index 0a1b6f65fb3..350fb212032 100644 --- a/mne/_fiff/tests/test_compensator.py +++ b/mne/_fiff/tests/test_compensator.py @@ -14,7 +14,7 @@ from mne.io import read_raw_fif from mne.utils import requires_mne, run_subprocess -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" ctf_comp_fname = base_dir / "test_ctf_comp_raw.fif" diff --git a/mne/_fiff/tests/test_constants.py b/mne/_fiff/tests/test_constants.py index 3fc33513635..45a9899423d 100644 --- a/mne/_fiff/tests/test_constants.py +++ b/mne/_fiff/tests/test_constants.py @@ -342,7 +342,7 @@ def test_constants(tmp_path): break else: if name not in _tag_ignore_names: - raise RuntimeError("Could not find %s" % (name,)) + raise RuntimeError(f"Could not find {name}") assert check in used_enums, name if "SSS" in check: raise RuntimeError @@ -353,13 +353,13 @@ def test_constants(tmp_path): else: unknowns.append((name, val)) if check is not None and name not in _tag_ignore_names: - assert val in fif[check], "%s: %s, %s" % (check, val, name) + assert val in fif[check], f"{check}: {val}, {name}" if val in con[check]: - msg = "%s='%s' ?" % (name, con[check][val]) + msg = f"{name}='{con[check][val]}' ?" assert _aliases.get(name) == con[check][val], msg else: con[check][val] = name - unknowns = "\n\t".join("%s (%s)" % u for u in unknowns) + unknowns = "\n\t".join("{} ({})".format(*u) for u in unknowns) assert len(unknowns) == 0, "Unknown types\n\t%s" % unknowns # Assert that all the FIF defs are in our constants @@ -385,16 +385,16 @@ def test_constants(tmp_path): for key in fif["coil"]: if key not in _missing_coil_def and key not in coil_def: bad_list.append((" %s," % key).ljust(10) + " # " + fif["coil"][key][1]) - assert ( - len(bad_list) == 0 - ), "\nIn fiff-constants, missing from coil_def:\n" + "\n".join(bad_list) + assert len(bad_list) == 0, ( + "\nIn fiff-constants, missing from coil_def:\n" + "\n".join(bad_list) + ) # Assert that enum(coil) has all `coil_def.dat` entries for key, desc in zip(coil_def, coil_desc): if key not in fif["coil"]: bad_list.append((" %s," % key).ljust(10) + " # " + desc) - assert ( - len(bad_list) == 0 - ), "In coil_def, missing from fiff-constants:\n" + "\n".join(bad_list) + assert len(bad_list) == 0, ( + "In coil_def, missing from fiff-constants:\n" + "\n".join(bad_list) + ) @pytest.mark.parametrize( diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index b8aa28b9e1d..8552585eec4 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -73,7 +73,7 @@ from mne.transforms import Transform from mne.utils import _empty_hash, _record_warnings, assert_object_equal, catch_logging -root_dir = Path(__file__).parent.parent.parent +root_dir = Path(__file__).parents[2] fiducials_fname = root_dir / "data" / "fsaverage" / "fsaverage-fiducials.fif" base_dir = root_dir / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" @@ -350,9 +350,11 @@ def test_read_write_info(tmp_path): @testing.requires_testing_data def test_dir_warning(): """Test that trying to read a bad filename emits a warning before an error.""" - with pytest.raises(OSError, match="directory"): - with pytest.warns(RuntimeWarning, match="foo"): - read_info(ctf_fname) + with ( + pytest.raises(OSError, match="directory"), + pytest.warns(RuntimeWarning, match="does not conform"), + ): + 
read_info(ctf_fname) def test_io_dig_points(tmp_path): diff --git a/mne/_fiff/tests/test_pick.py b/mne/_fiff/tests/test_pick.py index 841ce2be9bd..ab9edeaec15 100644 --- a/mne/_fiff/tests/test_pick.py +++ b/mne/_fiff/tests/test_pick.py @@ -46,7 +46,7 @@ fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_mc = data_path / "SSS" / "test_move_anon_movecomp_raw_sss.fif" -io_dir = Path(__file__).parent.parent.parent / "io" +io_dir = Path(__file__).parents[2] / "io" ctf_fname = io_dir / "tests" / "data" / "test_ctf_raw.fif" fif_fname = io_dir / "tests" / "data" / "test_raw.fif" @@ -522,8 +522,7 @@ def test_picks_by_channels(): # duplicate check names = ["MEG 002", "MEG 002"] assert len(pick_channels(raw.info["ch_names"], names, ordered=False)) == 1 - with pytest.warns(FutureWarning, match="ordered=False"): - assert len(raw.copy().pick_channels(names)[0][0]) == 1 # legacy method OK here + assert len(raw.copy().pick_channels(names, ordered=False)[0][0]) == 1 # missing ch_name bad_names = names + ["BAD"] @@ -558,11 +557,17 @@ def test_clean_info_bads(): # simulate the bad channels raw.info["bads"] = eeg_bad_ch + meg_bad_ch + assert len(raw.info["projs"]) == 3 + raw.set_eeg_reference(projection=True) + assert len(raw.info["projs"]) == 4 + # simulate the call to pick_info excluding the bad eeg channels info_eeg = pick_info(raw.info, picks_eeg) + assert len(info_eeg["projs"]) == 1 # simulate the call to pick_info excluding the bad meg channels info_meg = pick_info(raw.info, picks_meg) + assert len(info_meg["projs"]) == 3 assert info_eeg["bads"] == eeg_bad_ch assert info_meg["bads"] == meg_bad_ch diff --git a/mne/_fiff/tests/test_proc_history.py b/mne/_fiff/tests/test_proc_history.py index d63fafc1648..eb0880271b0 100644 --- a/mne/_fiff/tests/test_proc_history.py +++ b/mne/_fiff/tests/test_proc_history.py @@ -11,7 +11,7 @@ from mne._fiff.constants import FIFF from mne.io import read_info -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_chpi_raw_sss.fif" diff --git a/mne/_fiff/tests/test_reference.py b/mne/_fiff/tests/test_reference.py index d82338e5f63..166b06e460a 100644 --- a/mne/_fiff/tests/test_reference.py +++ b/mne/_fiff/tests/test_reference.py @@ -38,7 +38,7 @@ from mne.io import RawArray, read_raw_fif from mne.utils import _record_warnings, catch_logging -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" data_dir = testing.data_path(download=False) / "MEG" / "sample" fif_fname = data_dir / "sample_audvis_trunc_raw.fif" diff --git a/mne/_fiff/tests/test_show_fiff.py b/mne/_fiff/tests/test_show_fiff.py index 41fad7c22d5..e25f248b02c 100644 --- a/mne/_fiff/tests/test_show_fiff.py +++ b/mne/_fiff/tests/test_show_fiff.py @@ -7,7 +7,7 @@ from mne.io import show_fiff -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" fname_evoked = base_dir / "test-ave.fif" fname_raw = base_dir / "test_raw.fif" fname_c_annot = base_dir / "test_raw-annot.fif" diff --git a/mne/_fiff/tree.py b/mne/_fiff/tree.py index 6aa7b5f4539..556dab1a537 100644 --- a/mne/_fiff/tree.py +++ b/mne/_fiff/tree.py @@ -4,12 +4,10 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
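A note on the tag-reading refactor earlier in this patch: `_read_tag_header` now takes an absolute `pos`, seeks there, and parses a fixed 16-byte header. A minimal sketch of that layout, using only what the diff itself shows; `peek_tag_header` is a hypothetical name, not MNE-Python API:

```python
import struct

def peek_tag_header(fid, pos):
    """Sketch: parse the 16-byte FIF tag header at an absolute offset."""
    fid.seek(pos, 0)  # absolute seek, as in the new _read_tag_header
    s = fid.read(16)  # kind (>i4), type (>u4), size (>i4), next (>i4)
    if len(s) != 16:  # truncated file; the real code warns and returns None
        return None
    kind, type_, size, next_ = struct.unpack(">iIii", s)
    # the tag payload (`size` bytes) starts at pos + 16, which is why the
    # updated _read_matrix and _read_dir_entry_struct compute tag.pos + 16
    return dict(kind=kind, type=type_, size=size, next=next_, pos=pos)
```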
-import numpy as np from ..utils import logger, verbose from .constants import FIFF -from .tag import Tag, read_tag -from .write import _write, end_block, start_block, write_id +from .tag import read_tag def dir_tree_find(tree, kind): @@ -108,46 +106,3 @@ def make_dir_tree(fid, directory, start=0, indent=0, verbose=None): logger.debug(" " * indent + "end } %d" % block) last = this return tree, last - - -############################################################################### -# Writing - - -def copy_tree(fidin, in_id, nodes, fidout): - """Copy directory subtrees from fidin to fidout.""" - if len(nodes) <= 0: - return - - if not isinstance(nodes, list): - nodes = [nodes] - - for node in nodes: - start_block(fidout, node["block"]) - if node["id"] is not None: - if in_id is not None: - write_id(fidout, FIFF.FIFF_PARENT_FILE_ID, in_id) - - write_id(fidout, FIFF.FIFF_BLOCK_ID, in_id) - write_id(fidout, FIFF.FIFF_PARENT_BLOCK_ID, node["id"]) - - if node["directory"] is not None: - for d in node["directory"]: - # Do not copy these tags - if ( - d.kind == FIFF.FIFF_BLOCK_ID - or d.kind == FIFF.FIFF_PARENT_BLOCK_ID - or d.kind == FIFF.FIFF_PARENT_FILE_ID - ): - continue - - # Read and write tags, pass data through transparently - fidin.seek(d.pos, 0) - tag = Tag(*np.fromfile(fidin, (">i4,>I4,>i4,>i4"), 1)[0]) - tag.data = np.fromfile(fidin, ">B", tag.size) - _write(fidout, tag.data, tag.kind, 1, tag.type, ">B") - - for child in node["children"]: - copy_tree(fidin, in_id, child, fidout) - - end_block(fidout, node["block"]) diff --git a/mne/_fiff/utils.py b/mne/_fiff/utils.py index cdda8784e8a..09cc3046d6c 100644 --- a/mne/_fiff/utils.py +++ b/mne/_fiff/utils.py @@ -239,9 +239,8 @@ def _read_segments_file( block = np.fromfile(fid, dtype, count) if block.size != count: raise RuntimeError( - "Incorrect number of samples (%s != %s), " - "please report this error to MNE-Python " - "developers" % (block.size, count) + f"Incorrect number of samples ({block.size} != {count}), please " + "report this error to MNE-Python developers" ) block = block.reshape(n_channels, -1, order="F") n_samples = block.shape[1] # = count // n_channels @@ -340,7 +339,7 @@ def _construct_bids_filename(base, ext, part_idx, validate=True): ) suffix = deconstructed_base[-1] base = "_".join(deconstructed_base[:-1]) - use_fname = "{}_split-{:02}_{}{}".format(base, part_idx + 1, suffix, ext) + use_fname = f"{base}_split-{part_idx + 1:02}_{suffix}{ext}" if dirname: use_fname = op.join(dirname, use_fname) return use_fname diff --git a/mne/_fiff/what.py b/mne/_fiff/what.py index 9f0efa67453..5c248fe2c8f 100644 --- a/mne/_fiff/what.py +++ b/mne/_fiff/what.py @@ -65,7 +65,7 @@ def what(fname): try: func(fname, **kwargs) except Exception as exp: - logger.debug("Not %s: %s" % (what, exp)) + logger.debug(f"Not {what}: {exp}") else: return what return "unknown" diff --git a/mne/_freesurfer.py b/mne/_freesurfer.py index 6938e4f39fc..dd868c1ee0d 100644 --- a/mne/_freesurfer.py +++ b/mne/_freesurfer.py @@ -249,7 +249,7 @@ def get_volume_labels_from_aseg(mgz_fname, return_colors=False, atlas_ids=None): if atlas_ids is None: atlas_ids, colors = read_freesurfer_lut() elif return_colors: - raise ValueError("return_colors must be False if atlas_ids are " "provided") + raise ValueError("return_colors must be False if atlas_ids are provided") # restrict to the ones in the MRI, sorted by label name keep = np.isin(list(atlas_ids.values()), want) keys = sorted( @@ -554,7 +554,7 @@ def read_lta(fname, verbose=None): The affine transformation described 
by the lta file. """ _check_fname(fname, "read", must_exist=True) - with open(fname, "r") as fid: + with open(fname) as fid: lines = fid.readlines() # 0 is linear vox2vox, 1 is linear ras2ras trans_type = int(lines[0].split("=")[1].strip()[0]) @@ -715,7 +715,7 @@ def _get_lut(fname=None): ("A", "= 0).all(): raise ValueError( - "All control points must be positive (got %s)" - % (self.control_points[:3],) + f"All control points must be positive (got {self.control_points[:3]})" ) if isinstance(values, np.ndarray): values = [values] @@ -61,14 +60,13 @@ def __init__(self, control_points, values, interp="hann"): for v in values: if not (v is None or isinstance(v, np.ndarray)): raise TypeError( - 'All entries in "values" must be ndarray ' - "or None, got %s" % (type(v),) + 'All entries in "values" must be ndarray or None, got ' + f"{type(v)}" ) if v is not None and v.shape[0] != len(self.control_points): raise ValueError( - "Values, if provided, must be the same " - "length as the number of control points " - "(%s), got %s" % (len(self.control_points), v.shape[0]) + "Values, if provided, must be the same length as the number of " + f"control points ({len(self.control_points)}), got {v.shape[0]}" ) use_values = values @@ -84,9 +82,7 @@ def val(pt): self._left = self._right = self._use_interp = None known_types = ("cos2", "linear", "zero", "hann") if interp not in known_types: - raise ValueError( - 'interp must be one of %s, got "%s"' % (known_types, interp) - ) + raise ValueError(f'interp must be one of {known_types}, got "{interp}"') self._interp = interp def feed_generator(self, n_pts): @@ -95,10 +91,10 @@ def feed_generator(self, n_pts): n_pts = _ensure_int(n_pts, "n_pts") original_position = self._position stop = self._position + n_pts - logger.debug("Feed %s (%s-%s)" % (n_pts, self._position, stop)) + logger.debug(f"Feed {n_pts} ({self._position}-{stop})") used = np.zeros(n_pts, bool) if self._left is None: # first one - logger.debug(" Eval @ %s (%s)" % (0, self.control_points[0])) + logger.debug(f" Eval @ 0 ({self.control_points[0]})") self._left = self.values(self.control_points[0]) if len(self.control_points) == 1: self._right = self._left @@ -132,7 +128,7 @@ def feed_generator(self, n_pts): self._left_idx += 1 self._use_interp = None # need to recreate it eval_pt = self.control_points[self._left_idx + 1] - logger.debug(" Eval @ %s (%s)" % (self._left_idx + 1, eval_pt)) + logger.debug(f" Eval @ {self._left_idx + 1} ({eval_pt})") self._right = self.values(eval_pt) assert self._right is not None left_point = self.control_points[self._left_idx] @@ -153,8 +149,7 @@ def feed_generator(self, n_pts): n_use = min(stop, right_point) - self._position if n_use > 0: logger.debug( - " Interp %s %s (%s-%s)" - % (self._interp, n_use, left_point, right_point) + f" Interp {self._interp} {n_use} ({left_point}-{right_point})" ) interp_start = self._position - left_point assert interp_start >= 0 @@ -223,7 +218,7 @@ def _check_store(store): ): store = _Storer(*store) if not callable(store): - raise TypeError("store must be callable, got type %s" % (type(store),)) + raise TypeError(f"store must be callable, got type {type(store)}") return store @@ -288,11 +283,11 @@ def __init__( n_overlap = _ensure_int(n_overlap, "n_overlap") n_total = _ensure_int(n_total, "n_total") if n_samples <= 0: - raise ValueError("n_samples must be > 0, got %s" % (n_samples,)) + raise ValueError(f"n_samples must be > 0, got {n_samples}") if n_overlap < 0: - raise ValueError("n_overlap must be >= 0, got %s" % (n_overlap,)) + raise 
ValueError(f"n_overlap must be >= 0, got {n_overlap}") if n_total < 0: - raise ValueError("n_total must be >= 0, got %s" % (n_total,)) + raise ValueError(f"n_total must be >= 0, got {n_total}") self._n_samples = int(n_samples) self._n_overlap = int(n_overlap) del n_samples, n_overlap @@ -302,7 +297,7 @@ def __init__( "most the total number of samples (%s)" % (self._n_samples, n_total) ) if not callable(process): - raise TypeError("process must be callable, got type %s" % (type(process),)) + raise TypeError(f"process must be callable, got type {type(process)}") self._process = process self._step = self._n_samples - self._n_overlap self._store = _check_store(store) @@ -337,8 +332,7 @@ def __init__( del window, window_name if delta > 0: logger.info( - " The final %0.3f s will be lumped into the " - "final window" % (delta / sfreq,) + f" The final {delta / sfreq} s will be lumped into the final window" ) @property @@ -376,14 +370,9 @@ def feed(self, *datas, verbose=None, **kwargs): or self._in_buffers[di].dtype != data.dtype ): raise TypeError( - "data must dtype %s and shape[:-1]==%s, " - "got dtype %s shape[:-1]=%s" - % ( - self._in_buffers[di].dtype, - self._in_buffers[di].shape[:-1], - data.dtype, - data.shape[:-1], - ) + f"data must dtype {self._in_buffers[di].dtype} and " + f"shape[:-1]=={self._in_buffers[di].shape[:-1]}, got dtype " + f"{data.dtype} shape[:-1]={data.shape[:-1]}" ) logger.debug( " + Appending %d->%d" @@ -392,9 +381,8 @@ def feed(self, *datas, verbose=None, **kwargs): self._in_buffers[di] = np.concatenate([self._in_buffers[di], data], -1) if self._in_offset > self.stops[-1]: raise ValueError( - "data (shape %s) exceeded expected total " - "buffer size (%s > %s)" - % (data.shape, self._in_offset, self.stops[-1]) + f"data (shape {data.shape}) exceeded expected total buffer size (" + f"{self._in_offset} > {self.stops[-1]})" ) # Check to see if we can process the next chunk and dump outputs while self._idx < len(self.starts) and self._in_offset >= self.stops[self._idx]: @@ -411,7 +399,7 @@ def feed(self, *datas, verbose=None, **kwargs): if self._idx == 0: for offset in range(self._n_samples - self._step, 0, -self._step): this_window[:offset] += self._window[-offset:] - logger.debug(" * Processing %d->%d" % (start, stop)) + logger.debug(f" * Processing {start}->{stop}") this_proc = [in_[..., :this_len].copy() for in_ in self._in_buffers] if not all( proc.shape[-1] == this_len == this_window.size for proc in this_proc @@ -466,7 +454,7 @@ class _Storer: def __init__(self, *outs, picks=None): for oi, out in enumerate(outs): if not isinstance(out, np.ndarray) or out.ndim < 1: - raise TypeError("outs[oi] must be >= 1D ndarray, got %s" % (out,)) + raise TypeError(f"outs[oi] must be >= 1D ndarray, got {out}") self.outs = outs self.idx = 0 self.picks = picks diff --git a/mne/annotations.py b/mne/annotations.py index be62dac9dba..1c66fee1be5 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -38,6 +38,8 @@ _check_fname, _check_option, _check_pandas_installed, + _check_time_format, + _convert_times, _DefaultEventParser, _dt_to_stamp, _is_numeric, @@ -63,15 +65,15 @@ def _check_o_d_s_c(onset, duration, description, ch_names): onset = np.atleast_1d(np.array(onset, dtype=float)) if onset.ndim != 1: raise ValueError( - "Onset must be a one dimensional array, got %s " - "(shape %s)." % (onset.ndim, onset.shape) + f"Onset must be a one dimensional array, got {onset.ndim} (shape " + f"{onset.shape})." 
) duration = np.array(duration, dtype=float) if duration.ndim == 0 or duration.shape == (1,): duration = np.repeat(duration, len(onset)) if duration.ndim != 1: raise ValueError( - "Duration must be a one dimensional array, " "got %d." % (duration.ndim,) + f"Duration must be a one dimensional array, got {duration.ndim}." ) description = np.array(description, dtype=str) @@ -79,8 +81,7 @@ def _check_o_d_s_c(onset, duration, description, ch_names): description = np.repeat(description, len(onset)) if description.ndim != 1: raise ValueError( - "Description must be a one dimensional array, " - "got %d." % (description.ndim,) + f"Description must be a one dimensional array, got {description.ndim}." ) _safe_name_list(description, "write", "description") @@ -276,9 +277,7 @@ class Annotations: :meth:`Raw.save() ` notes for details. """ # noqa: E501 - def __init__( - self, onset, duration, description, orig_time=None, ch_names=None - ): # noqa: D102 + def __init__(self, onset, duration, description, orig_time=None, ch_names=None): self._orig_time = _handle_meas_date(orig_time) self.onset, self.duration, self.description, self.ch_names = _check_o_d_s_c( onset, duration, description, ch_names @@ -305,14 +304,12 @@ def __eq__(self, other): def __repr__(self): """Show the representation.""" counter = Counter(self.description) - kinds = ", ".join(["%s (%s)" % k for k in sorted(counter.items())]) + kinds = ", ".join(["{} ({})".format(*k) for k in sorted(counter.items())]) kinds = (": " if len(kinds) > 0 else "") + kinds ch_specific = ", channel-specific" if self._any_ch_names() else "" - s = "Annotations | %s segment%s%s%s" % ( - len(self.onset), - _pl(len(self.onset)), - ch_specific, - kinds, + s = ( + f"Annotations | {len(self.onset)} segment" + f"{_pl(len(self.onset))}{ch_specific}{kinds}" ) return "<" + shorten(s, width=77, placeholder=" ...") + ">" @@ -341,9 +338,8 @@ def __iadd__(self, other): self._orig_time = other.orig_time if self.orig_time != other.orig_time: raise ValueError( - "orig_time should be the same to " - "add/concatenate 2 annotations " - "(got %s != %s)" % (self.orig_time, other.orig_time) + "orig_time should be the same to add/concatenate 2 annotations (got " + f"{self.orig_time} != {other.orig_time})" ) return self.append( other.onset, other.duration, other.description, other.ch_names @@ -444,9 +440,16 @@ def delete(self, idx): self.description = np.delete(self.description, idx) self.ch_names = np.delete(self.ch_names, idx) - def to_data_frame(self): + @fill_doc + def to_data_frame(self, time_format="datetime"): """Export annotations in tabular structure as a pandas DataFrame. + Parameters + ---------- + %(time_format_df_raw)s + + .. versionadded:: 1.7 + Returns ------- result : pandas.DataFrame @@ -455,12 +458,14 @@ def to_data_frame(self): annotations are channel-specific. 
""" pd = _check_pandas_installed(strict=True) + valid_time_formats = ["ms", "timedelta", "datetime"] dt = _handle_meas_date(self.orig_time) if dt is None: dt = _handle_meas_date(0) + time_format = _check_time_format(time_format, valid_time_formats, dt) dt = dt.replace(tzinfo=None) - onsets_dt = [dt + timedelta(seconds=o) for o in self.onset] - df = dict(onset=onsets_dt, duration=self.duration, description=self.description) + times = _convert_times(self.onset, time_format, dt) + df = dict(onset=times, duration=self.duration, description=self.description) if self._any_ch_names(): df.update(ch_names=self.ch_names) df = pd.DataFrame(df) @@ -612,10 +617,10 @@ def crop( del tmin, tmax if absolute_tmin > absolute_tmax: raise ValueError( - "tmax should be greater than or equal to tmin " - "(%s < %s)." % (absolute_tmin, absolute_tmax) + f"tmax should be greater than or equal to tmin ({absolute_tmin} < " + f"{absolute_tmax})." ) - logger.debug("Cropping annotations %s - %s" % (absolute_tmin, absolute_tmax)) + logger.debug(f"Cropping annotations {absolute_tmin} - {absolute_tmax}") onsets, durations, descriptions, ch_names = [], [], [], [] out_of_bounds, clip_left_elem, clip_right_elem = [], [], [] @@ -813,11 +818,10 @@ def set_annotations(self, annotations, on_missing="raise", *, verbose=None): else: if getattr(self, "_unsafe_annot_add", False): warn( - "Adding annotations to Epochs created (and saved to " - "disk) before 1.0 will yield incorrect results if " - "decimation or resampling was performed on the instance, " - "we recommend regenerating the Epochs and re-saving them " - "to disk." + "Adding annotations to Epochs created (and saved to disk) before " + "1.0 will yield incorrect results if decimation or resampling was " + "performed on the instance, we recommend regenerating the Epochs " + "and re-saving them to disk." ) new_annotations = annotations.copy() new_annotations._prune_ch_names(self.info, on_missing) @@ -1141,7 +1145,9 @@ def _write_annotations_txt(fname, annot): @fill_doc -def read_annotations(fname, sfreq="auto", uint16_codec=None, encoding="utf8"): +def read_annotations( + fname, sfreq="auto", uint16_codec=None, encoding="utf8" +) -> Annotations: r"""Read annotations from a file. This function reads a ``.fif``, ``.fif.gz``, ``.vmrk``, ``.amrk``, @@ -1174,7 +1180,7 @@ def read_annotations(fname, sfreq="auto", uint16_codec=None, encoding="utf8"): Returns ------- - annot : instance of Annotations | None + annot : instance of Annotations The annotations. Notes @@ -1485,7 +1491,7 @@ def _check_event_id(event_id, raw): else: raise ValueError( "Invalid type for event_id (should be None, str, " - "dict or callable). Got {}".format(type(event_id)) + f"dict or callable). Got {type(event_id)}." ) @@ -1500,16 +1506,14 @@ def _check_event_description(event_desc, events): elif isinstance(event_desc, Iterable): event_desc = np.asarray(event_desc) if event_desc.ndim != 1: - raise ValueError( - "event_desc must be 1D, got shape {}".format(event_desc.shape) - ) + raise ValueError(f"event_desc must be 1D, got shape {event_desc.shape}") event_desc = dict(zip(event_desc, map(str, event_desc))) elif callable(event_desc): pass else: raise ValueError( "Invalid type for event_desc (should be None, list, " - "1darray, dict or callable). Got {}".format(type(event_desc)) + f"1darray, dict or callable). Got {type(event_desc)}." 
) return event_desc @@ -1522,6 +1526,7 @@ def events_from_annotations( regexp=r"^(?![Bb][Aa][Dd]|[Ee][Dd][Gg][Ee]).*$", use_rounding=True, chunk_duration=None, + tol=1e-8, verbose=None, ): """Get :term:`events` and ``event_id`` from an Annotations object. @@ -1565,6 +1570,11 @@ def events_from_annotations( they fit within the annotation duration spaced according to ``chunk_duration``. As a consequence annotations with duration shorter than ``chunk_duration`` will not contribute events. + tol : float + The tolerance used to check if a chunk fits within an annotation when + ``chunk_duration`` is not ``None``. If the duration from a computed + chunk onset to the end of the annotation is smaller than + ``chunk_duration`` minus ``tol``, the onset will be discarded. %(verbose)s Returns @@ -1609,10 +1619,8 @@ def events_from_annotations( inds = values = np.array([]).astype(int) for annot in annotations[event_sel]: annot_offset = annot["onset"] + annot["duration"] - _onsets = np.arange( - start=annot["onset"], stop=annot_offset, step=chunk_duration - ) - good_events = annot_offset - _onsets >= chunk_duration + _onsets = np.arange(annot["onset"], annot_offset, chunk_duration) + good_events = annot_offset - _onsets >= chunk_duration - tol if good_events.any(): _onsets = _onsets[good_events] _inds = raw.time_as_index( @@ -1629,7 +1637,7 @@ def events_from_annotations( events = np.c_[inds, np.zeros(len(inds)), values].astype(int) - logger.info("Used Annotations descriptions: %s" % (list(event_id_.keys()),)) + logger.info(f"Used Annotations descriptions: {list(event_id_.keys())}") return events, event_id_ diff --git a/mne/baseline.py b/mne/baseline.py index 3994c5522e5..36ab0fc514f 100644 --- a/mne/baseline.py +++ b/mne/baseline.py @@ -77,7 +77,7 @@ def rescale(data, times, baseline, mode="mean", copy=True, picks=None, verbose=N imin = np.where(times >= bmin)[0] if len(imin) == 0: raise ValueError( - "bmin is too large (%s), it exceeds the largest " "time value" % (bmin,) + f"bmin is too large ({bmin}), it exceeds the largest time value" ) imin = int(imin[0]) if bmax is None: @@ -86,14 +86,13 @@ def rescale(data, times, baseline, mode="mean", copy=True, picks=None, verbose=N imax = np.where(times <= bmax)[0] if len(imax) == 0: raise ValueError( - "bmax is too small (%s), it is smaller than the " - "smallest time value" % (bmax,) + f"bmax is too small ({bmax}), it is smaller than the smallest time " + "value" ) imax = int(imax[-1]) + 1 if imin >= imax: raise ValueError( - "Bad rescaling slice (%s:%s) from time values %s, %s" - % (imin, imax, bmin, bmax) + f"Bad rescaling slice ({imin}:{imax}) from time values {bmin}, {bmax}" ) # technically this is inefficient when `picks` is given, but assuming @@ -188,8 +187,8 @@ def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"): # check default value of baseline and `tmin=0` if baseline == (None, 0) and tmin == 0: raise ValueError( - "Baseline interval is only one sample. Use " - "`baseline=(0, 0)` if this is desired." + "Baseline interval is only one sample. Use `baseline=(0, 0)` if this is " + "desired." 
) baseline_tmin, baseline_tmax = baseline @@ -204,15 +203,14 @@ def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"): if baseline_tmin > baseline_tmax: raise ValueError( - "Baseline min (%s) must be less than baseline max (%s)" - % (baseline_tmin, baseline_tmax) + f"Baseline min ({baseline_tmin}) must be less than baseline max (" + f"{baseline_tmax})" ) if (baseline_tmin < tmin - tstep) or (baseline_tmax > tmax + tstep): msg = ( - f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s " - f"is outside of epochs data [{tmin}, {tmax}] s. Epochs were " - f"probably cropped." + f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s is outside of " + f"epochs data [{tmin}, {tmax}] s. Epochs were probably cropped." ) if on_baseline_outside_data == "raise": raise ValueError(msg) diff --git a/mne/beamformer/_compute_beamformer.py b/mne/beamformer/_compute_beamformer.py index 975f0852208..16cbc18e6d7 100644 --- a/mne/beamformer/_compute_beamformer.py +++ b/mne/beamformer/_compute_beamformer.py @@ -120,14 +120,13 @@ def _prepare_beamformer_input( nn[...] = [0, 0, 1] # align to local +Z coordinate if pick_ori is not None and not is_free_ori: raise ValueError( - "Normal or max-power orientation (got %r) can only be picked when " - "a forward operator with free orientation is used." % (pick_ori,) + f"Normal or max-power orientation (got {pick_ori!r}) can only be picked " + "when a forward operator with free orientation is used." ) if pick_ori == "normal" and not forward["surf_ori"]: raise ValueError( - "Normal orientation can only be picked when a " - "forward operator oriented in surface coordinates is " - "used." + "Normal orientation can only be picked when a forward operator oriented in " + "surface coordinates is used." ) _check_src_normal(pick_ori, forward["src"]) del forward, info @@ -505,21 +504,21 @@ def __repr__(self): # noqa: D105 if self["subject"] is None: subject = "unknown" else: - subject = '"%s"' % (self["subject"],) - out = "aso", projection, G) diff --git a/mne/beamformer/resolution_matrix.py b/mne/beamformer/resolution_matrix.py index 108fb7a4dbf..ce55a09584b 100644 --- a/mne/beamformer/resolution_matrix.py +++ b/mne/beamformer/resolution_matrix.py @@ -1,4 +1,5 @@ """Compute resolution matrix for beamformers.""" + # Authors: olaf.hauk@mrc-cbu.cam.ac.uk # # License: BSD-3-Clause diff --git a/mne/beamformer/tests/test_dics.py b/mne/beamformer/tests/test_dics.py index 1daaaf17eb0..bcde4503307 100644 --- a/mne/beamformer/tests/test_dics.py +++ b/mne/beamformer/tests/test_dics.py @@ -30,7 +30,7 @@ from mne.io import read_info from mne.proj import compute_proj_evoked, make_projector from mne.surface import _compute_nearest -from mne.time_frequency import CrossSpectralDensity, EpochsTFR, csd_morlet, csd_tfr +from mne.time_frequency import CrossSpectralDensity, EpochsTFRArray, csd_morlet, csd_tfr from mne.time_frequency.csd import _sym_mat_to_vector from mne.transforms import apply_trans, invert_transform from mne.utils import catch_logging, object_diff @@ -727,7 +727,7 @@ def test_apply_dics_tfr(return_generator): data = rng.random((n_epochs, n_chans, len(freqs), n_times)) data *= 1e-6 data = data + data * 1j # add imag. component to simulate phase - epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs) + epochs_tfr = EpochsTFRArray(info=info, data=data, times=times, freqs=freqs) # Create a DICS beamformer and convert the EpochsTFR to source space. 
csd = csd_tfr(epochs_tfr)
diff --git a/mne/beamformer/tests/test_lcmv.py b/mne/beamformer/tests/test_lcmv.py index 15c1a2ba5eb..509afbcf79e 100644 --- a/mne/beamformer/tests/test_lcmv.py +++ b/mne/beamformer/tests/test_lcmv.py @@ -589,19 +589,22 @@ def test_make_lcmv_sphere(pick_ori, weight_norm): fwd_sphere = mne.make_forward_solution(evoked.info, None, src, sphere) # Test that we get an error if not reducing rank - with pytest.raises(ValueError, match="Singular matrix detected"): - with pytest.warns(RuntimeWarning, match="positive semidefinite"): - make_lcmv( - evoked.info, - fwd_sphere, - data_cov, - reg=0.1, - noise_cov=noise_cov, - weight_norm=weight_norm, - pick_ori=pick_ori, - reduce_rank=False, - rank="full", - ) + with ( + pytest.raises(ValueError, match="Singular matrix detected"), + _record_warnings(), + pytest.warns(RuntimeWarning, match="positive semidefinite"), + ): + make_lcmv( + evoked.info, + fwd_sphere, + data_cov, + reg=0.1, + noise_cov=noise_cov, + weight_norm=weight_norm, + pick_ori=pick_ori, + reduce_rank=False, + rank="full", + ) # Now let's reduce it filters = make_lcmv(
diff --git a/mne/beamformer/tests/test_rap_music.py b/mne/beamformer/tests/test_rap_music.py index c98c83a7722..de6f047def8 100644 --- a/mne/beamformer/tests/test_rap_music.py +++ b/mne/beamformer/tests/test_rap_music.py @@ -50,9 +50,7 @@ def simu_data(evoked, forward, noise_cov, n_dipoles, times, nave=1): # Generate the two dipoles data mu, sigma = 0.1, 0.005 s1 = ( - 1 - / (sigma * np.sqrt(2 * np.pi)) - * np.exp(-((times - mu) ** 2) / (2 * sigma**2)) + 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((times - mu) ** 2) / (2 * sigma**2)) ) mu, sigma = 0.075, 0.008
diff --git a/mne/bem.py b/mne/bem.py index a78309ef626..88104ea9cc2 100644 --- a/mne/bem.py +++ b/mne/bem.py @@ -100,16 +100,16 @@ def __repr__(self): # noqa: D105 if rad is None: # no radius / MEG only extra = "Sphere (no layers): r0=[%s] mm" % center else: - extra = "Sphere (%s layer%s): r0=[%s] R=%1.f mm" % ( + extra = "Sphere ({} layer{}): r0=[{}] R={:1.0f} mm".format( len(self["layers"]) - 1, _pl(self["layers"]), center, rad * 1000.0, ) else: - extra = "BEM (%s layer%s)" % (len(self["surfs"]), _pl(self["surfs"])) - extra += " solver=%s" % self["solver"] - return "<ConductorModel | %s>" % extra + extra = f"BEM ({len(self['surfs'])} layer{_pl(self['surfs'])})" + extra += f" solver={self['solver']}" + return f"<ConductorModel | {extra}>" def copy(self): """Return copy of ConductorModel instance.""" @@ -415,8 +415,8 @@ def make_bem_solution(surfs, *, solver="mne", verbose=None): surfs : list of dict The BEM surfaces to use (from :func:`mne.make_bem_model`). solver : str - Can be ``'mne'`` (default) to use MNE-Python, or ``'openmeeg'`` to use - the :doc:`OpenMEEG ` package. + Can be ``'mne'`` (default) to use MNE-Python, or ``'openmeeg'`` to use the + `OpenMEEG <https://openmeeg.github.io>`__ package. .. versionadded:: 1.2 %(verbose)s @@ -542,8 +542,9 @@ def _assert_complete_surface(surf, incomplete="raise"): # Center of mass....
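Several test updates in this patch (`test_dir_warning` and `test_make_lcmv_sphere` above) collapse nested `pytest.raises`/`pytest.warns` blocks into a single parenthesized `with` statement, adding `_record_warnings()` where unrelated warnings must be swallowed. A self-contained sketch of the pattern (Python 3.10+ syntax; `fit_singular` is an invented stand-in for the function under test):

```python
import warnings

import pytest

def fit_singular():
    """Invented example: warn, then raise, like make_lcmv in the test above."""
    warnings.warn("covariance is not positive semidefinite", RuntimeWarning, stacklevel=2)
    raise ValueError("Singular matrix detected")

def test_fit_singular():
    # one flat `with` block instead of two nested ones
    with (
        pytest.raises(ValueError, match="Singular matrix detected"),
        pytest.warns(RuntimeWarning, match="positive semidefinite"),
    ):
        fit_singular()
```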
cm = surf["rr"].mean(axis=0) logger.info( - "%s CM is %6.2f %6.2f %6.2f mm" - % (_bem_surf_name[surf["id"]], 1000 * cm[0], 1000 * cm[1], 1000 * cm[2]) + "{} CM is {:6.2f} {:6.2f} {:6.2f} mm".format( + _bem_surf_name[surf["id"]], 1000 * cm[0], 1000 * cm[1], 1000 * cm[2] + ) ) tot_angle = _get_solids(surf["rr"][surf["tris"]], cm[np.newaxis, :])[0] prop = tot_angle / (2 * np.pi) @@ -897,18 +898,18 @@ def make_sphere_model( param = locals()[name] if isinstance(param, str): if param != "auto": - raise ValueError('%s, if str, must be "auto" not "%s"' % (name, param)) + raise ValueError(f'{name}, if str, must be "auto" not "{param}"') relative_radii = np.array(relative_radii, float).ravel() sigmas = np.array(sigmas, float).ravel() if len(relative_radii) != len(sigmas): raise ValueError( - "relative_radii length (%s) must match that of " - "sigmas (%s)" % (len(relative_radii), len(sigmas)) + f"relative_radii length ({len(relative_radii)}) must match that of sigmas (" + f"{len(sigmas)})" ) if len(sigmas) <= 1 and head_radius is not None: raise ValueError( - "at least 2 sigmas must be supplied if " - "head_radius is not None, got %s" % (len(sigmas),) + "at least 2 sigmas must be supplied if head_radius is not None, got " + f"{len(sigmas)}" ) if (isinstance(r0, str) and r0 == "auto") or ( isinstance(head_radius, str) and head_radius == "auto" @@ -964,8 +965,7 @@ def make_sphere_model( ) ) logger.info( - "Set up EEG sphere model with scalp radius %7.1f mm\n" - % (1000 * head_radius,) + f"Set up EEG sphere model with scalp radius {1000 * head_radius:7.1f} mm\n" ) return sphere @@ -1082,10 +1082,9 @@ def get_fitting_dig(info, dig_kinds="auto", exclude_frontal=True, verbose=None): if len(hsp) <= 10: kinds_str = ", ".join(['"%s"' % _dig_kind_rev[d] for d in sorted(dig_kinds)]) - msg = "Only %s head digitization points of the specified kind%s (%s,)" % ( - len(hsp), - _pl(dig_kinds), - kinds_str, + msg = ( + f"Only {len(hsp)} head digitization points of the specified " + f"kind{_pl(dig_kinds)} ({kinds_str},)" ) if len(hsp) < 4: raise ValueError(msg + ", at least 4 required") @@ -1105,22 +1104,22 @@ def _fit_sphere_to_headshape(info, dig_kinds, verbose=None): dev_head_t = Transform("meg", "head") head_to_dev = _ensure_trans(dev_head_t, "head", "meg") origin_device = apply_trans(head_to_dev, origin_head) - logger.info("Fitted sphere radius:".ljust(30) + "%0.1f mm" % (radius * 1e3,)) + logger.info("Fitted sphere radius:".ljust(30) + f"{radius * 1e3:0.1f} mm") _check_head_radius(radius) # > 2 cm away from head center in X or Y is strange if np.linalg.norm(origin_head[:2]) > 0.02: warn( - "(X, Y) fit (%0.1f, %0.1f) more than 20 mm from " - "head frame origin" % tuple(1e3 * origin_head[:2]) + "(X, Y) fit ({:0.1f}, {:0.1f}) more than 20 mm from head frame " + "origin".format(*tuple(1e3 * origin_head[:2])) ) logger.info( "Origin head coordinates:".ljust(30) - + "%0.1f %0.1f %0.1f mm" % tuple(1e3 * origin_head) + + "{:0.1f} {:0.1f} {:0.1f} mm".format(*tuple(1e3 * origin_head)) ) logger.info( "Origin device coordinates:".ljust(30) - + "%0.1f %0.1f %0.1f mm" % tuple(1e3 * origin_device) + + "{:0.1f} {:0.1f} {:0.1f} mm".format(*tuple(1e3 * origin_device)) ) return radius, origin_head, origin_device @@ -1163,15 +1162,13 @@ def _check_origin(origin, info, coord_frame="head", disp=False): if isinstance(origin, str): if origin != "auto": raise ValueError( - 'origin must be a numerical array, or "auto", ' "not %s" % (origin,) + f'origin must be a numerical array, or "auto", not {origin}' ) if coord_frame == "head": R, 
origin = fit_sphere_to_headshape( info, verbose=_verbose_safe_false(), units="m" )[:2] - logger.info( - " Automatic origin fit: head of radius %0.1f mm" % (R * 1000.0,) - ) + logger.info(f" Automatic origin fit: head of radius {R * 1000:0.1f} mm") del R else: origin = (0.0, 0.0, 0.0) @@ -1179,12 +1176,12 @@ def _check_origin(origin, info, coord_frame="head", disp=False): if origin.shape != (3,): raise ValueError("origin must be a 3-element array") if disp: - origin_str = ", ".join(["%0.1f" % (o * 1000) for o in origin]) - msg = " Using origin %s mm in the %s frame" % (origin_str, coord_frame) + origin_str = ", ".join([f"{o * 1000:0.1f}" for o in origin]) + msg = f" Using origin {origin_str} mm in the {coord_frame} frame" if coord_frame == "meg" and info["dev_head_t"] is not None: o_dev = apply_trans(info["dev_head_t"], origin) - origin_str = ", ".join("%0.1f" % (o * 1000,) for o in o_dev) - msg += " (%s mm in the head frame)" % (origin_str,) + origin_str = ", ".join(f"{o * 1000:0.1f}" for o in o_dev) + msg += f" ({origin_str} mm in the head frame)" logger.info(msg) return origin @@ -1299,7 +1296,7 @@ def make_watershed_bem( if gcaatlas: fname = op.join(env["FREESURFER_HOME"], "average", "RB_all_withskull_*.gca") fname = sorted(glob.glob(fname))[::-1][0] - logger.info("Using GCA atlas: %s" % (fname,)) + logger.info(f"Using GCA atlas: {fname}") cmd += [ "-atlas", "-brain_atlas", @@ -1326,9 +1323,8 @@ def make_watershed_bem( ] # report and run logger.info( - "\nRunning mri_watershed for BEM segmentation with the " - "following parameters:\n\nResults dir = %s\nCommand = %s\n" - % (ws_dir, " ".join(cmd)) + "\nRunning mri_watershed for BEM segmentation with the following parameters:\n" + f"\nResults dir = {ws_dir}\nCommand = {' '.join(cmd)}\n" ) os.makedirs(op.join(ws_dir)) run_subprocess_env(cmd) @@ -1337,12 +1333,12 @@ def make_watershed_bem( new_info = _extract_volume_info(T1_mgz) if not new_info: warn( - "nibabel is not available or the volume info is invalid." - "Volume info not updated in the written surface." + "nibabel is not available or the volume info is invalid. Volume info " + "not updated in the written surface." ) surfs = ["brain", "inner_skull", "outer_skull", "outer_skin"] for s in surfs: - surf_ws_out = op.join(ws_dir, "%s_%s_surface" % (subject, s)) + surf_ws_out = op.join(ws_dir, f"{subject}_{s}_surface") rr, tris, volume_info = read_surface(surf_ws_out, read_metadata=True) # replace volume info, 'head' stays @@ -1352,7 +1348,7 @@ def make_watershed_bem( ) # Create symbolic links - surf_out = op.join(bem_dir, "%s.surf" % s) + surf_out = op.join(bem_dir, f"{s}.surf") if not overwrite and op.exists(surf_out): skip_symlink = True else: @@ -1363,9 +1359,8 @@ def make_watershed_bem( if skip_symlink: logger.info( - "Unable to create all symbolic links to .surf files " - "in bem folder. Use --overwrite option to recreate " - "them." + "Unable to create all symbolic links to .surf files in bem folder. Use " + "--overwrite option to recreate them." ) dest = op.join(bem_dir, "watershed") else: @@ -1373,8 +1368,8 @@ def make_watershed_bem( dest = bem_dir logger.info( - "\nThank you for waiting.\nThe BEM triangulations for this " - "subject are now available at:\n%s." % dest + "\nThank you for waiting.\nThe BEM triangulations for this subject are now " + f"available at:\n{dest}." ) # Write a head file for coregistration @@ -1399,7 +1394,7 @@ def make_watershed_bem( show=True, ) - logger.info("Created %s\n\nComplete." 
% (fname_head,)) + logger.info(f"Created {fname_head}\n\nComplete.") def _extract_volume_info(mgz): @@ -1929,9 +1924,7 @@ def _prepare_env(subject, subjects_dir): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) subject_dir = subjects_dir / subject if not subject_dir.is_dir(): - raise RuntimeError( - 'Could not find the subject data directory "%s"' % (subject_dir,) - ) + raise RuntimeError(f'Could not find the subject data directory "{subject_dir}"') env.update(SUBJECT=subject, SUBJECTS_DIR=str(subjects_dir), FREESURFER_HOME=fs_home) mri_dir = subject_dir / "mri" bem_dir = subject_dir / "bem" @@ -2152,11 +2145,9 @@ def make_flash_bem( flash_path.mkdir(exist_ok=True, parents=True) logger.info( - "\nProcessing the flash MRI data to produce BEM meshes with " - "the following parameters:\n" - "SUBJECTS_DIR = %s\n" - "SUBJECT = %s\n" - "Result dir = %s\n" % (subjects_dir, subject, bem_dir / "flash") + "\nProcessing the flash MRI data to produce BEM meshes with the following " + f"parameters:\nSUBJECTS_DIR = {subjects_dir}\nSUBJECT = {subject}\nResult dir =" + f"{bem_dir / 'flash'}\n" ) # Step 4 : Register with MPRAGE flash5 = flash_path / "flash5.mgz" @@ -2463,7 +2454,7 @@ def check_seghead(surf_path=subj_path / "surf"): surf = check_seghead() if surf is None: - raise RuntimeError("mkheadsurf did not produce the standard output " "file.") + raise RuntimeError("mkheadsurf did not produce the standard output file.") bem_dir = subjects_dir / subject / "bem" if not bem_dir.is_dir(): diff --git a/mne/channels/_dig_montage_utils.py b/mne/channels/_dig_montage_utils.py index 4d2e9e6af3f..2136934972d 100644 --- a/mne/channels/_dig_montage_utils.py +++ b/mne/channels/_dig_montage_utils.py @@ -13,9 +13,8 @@ # Copyright the MNE-Python contributors. import numpy as np -from defusedxml import ElementTree -from ..utils import Bunch, _check_fname, warn +from ..utils import Bunch, _check_fname, _soft_import, warn def _read_dig_montage_egi( @@ -28,10 +27,10 @@ def _read_dig_montage_egi( "hsp, hpi, elp, point_names, fif must all be " "None if egi is not None" ) _check_fname(fname, overwrite="read", must_exist=True) - - root = ElementTree.parse(fname).getroot() + defusedxml = _soft_import("defusedxml", "reading EGI montages") + root = defusedxml.ElementTree.parse(fname).getroot() ns = root.tag[root.tag.index("{") : root.tag.index("}") + 1] - sensors = root.find("%ssensorLayout/%ssensors" % (ns, ns)) + sensors = root.find(f"{ns}sensorLayout/{ns}sensors") fids = dict() dig_ch_pos = dict() @@ -76,8 +75,8 @@ def _read_dig_montage_egi( def _parse_brainvision_dig_montage(fname, scale): FID_NAME_MAP = {"Nasion": "nasion", "RPA": "rpa", "LPA": "lpa"} - - root = ElementTree.parse(fname).getroot() + defusedxml = _soft_import("defusedxml", "reading BrainVision montages") + root = defusedxml.ElementTree.parse(fname).getroot() sensors = root.find("CapTrakElectrodeList") fids, dig_ch_pos = dict(), dict() diff --git a/mne/channels/_standard_montage_utils.py b/mne/channels/_standard_montage_utils.py index 43c8fa6aecd..4df6c685912 100644 --- a/mne/channels/_standard_montage_utils.py +++ b/mne/channels/_standard_montage_utils.py @@ -9,11 +9,10 @@ from functools import partial import numpy as np -from defusedxml import ElementTree from .._freesurfer import get_mni_fiducials from ..transforms import _sph_to_cart -from ..utils import _pl, warn +from ..utils import _pl, _soft_import, warn from . 
import __file__ as _CHANNELS_INIT_FILE from .montage import make_dig_montage @@ -100,7 +99,7 @@ def _mgh_or_standard(basename, head_size, coord_frame="unknown"): pos = np.array(pos) / 1000.0 ch_pos = _check_dupes_odict(ch_names_, pos) - nasion, lpa, rpa = [ch_pos.pop(n) for n in fid_names] + nasion, lpa, rpa = (ch_pos.pop(n) for n in fid_names) if head_size is None: scale = 1.0 else: @@ -110,7 +109,7 @@ def _mgh_or_standard(basename, head_size, coord_frame="unknown"): # if we are in MRI/MNI coordinates, we need to replace nasion, LPA, and RPA # with those of fsaverage for ``trans='fsaverage'`` to work if coord_frame == "mri": - lpa, nasion, rpa = [x["r"].copy() for x in get_mni_fiducials("fsaverage")] + lpa, nasion, rpa = (x["r"].copy() for x in get_mni_fiducials("fsaverage")) nasion *= scale lpa *= scale rpa *= scale @@ -185,7 +184,7 @@ def _read_sfp(fname, head_size): ch_pos = _check_dupes_odict(ch_names, pos) del xs, ys, zs, ch_names # no one grants that fid names are there. - nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] + nasion, lpa, rpa = (ch_pos.pop(n, None) for n in fid_names) if head_size is not None: scale = head_size / np.median(np.linalg.norm(pos, axis=-1)) @@ -275,7 +274,7 @@ def _read_elc(fname, head_size): pos *= head_size / np.median(np.linalg.norm(pos, axis=1)) ch_pos = _check_dupes_odict(ch_names_, pos) - nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] + nasion, lpa, rpa = (ch_pos.pop(n, None) for n in fid_names) return make_dig_montage( ch_pos=ch_pos, coord_frame="unknown", nasion=nasion, lpa=lpa, rpa=rpa @@ -305,7 +304,7 @@ def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, add_fiducials=F nasion, lpa, rpa = None, None, None if fid_names is not None: - nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] + nasion, lpa, rpa = (ch_pos.pop(n, None) for n in fid_names) return make_dig_montage( ch_pos=ch_pos, coord_frame="unknown", nasion=nasion, lpa=lpa, rpa=rpa @@ -333,7 +332,7 @@ def _read_elp_besa(fname, head_size): fid_names = ("Nz", "LPA", "RPA") # No one grants that the fid names actually exist. 
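The montage readers above now unpack fiducials from a generator expression rather than a list comprehension. For tuple unpacking the two are equivalent, since unpacking consumes the iterator immediately, and the generator avoids building a throwaway list. A quick sketch of the pattern with made-up coordinates:

```python
# mirrors the `nasion, lpa, rpa = (ch_pos.pop(n, None) for n in fid_names)`
# pattern in the readers above; positions are invented for illustration
ch_pos = {"Nz": (0.0, 0.09, 0.0), "LPA": (-0.08, 0.0, 0.0), "RPA": (0.08, 0.0, 0.0)}
fid_names = ("Nz", "LPA", "RPA")
nasion, lpa, rpa = (ch_pos.pop(n, None) for n in fid_names)
assert not ch_pos  # all three fiducials were removed from the dict
```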
- nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] + nasion, lpa, rpa = (ch_pos.pop(n, None) for n in fid_names) return make_dig_montage(ch_pos=ch_pos, nasion=nasion, lpa=lpa, rpa=rpa) @@ -344,7 +343,8 @@ def _read_brainvision(fname, head_size): # standard electrode positions: X-axis from T7 to T8, Y-axis from Oz to # Fpz, Z-axis orthogonal from XY-plane through Cz, fit to a sphere if # idealized (when radius=1), specified in millimeters - root = ElementTree.parse(fname).getroot() + defusedxml = _soft_import("defusedxml", "reading BrainVision montages") + root = defusedxml.ElementTree.parse(fname).getroot() ch_names = [s.text for s in root.findall("./Electrode/Name")] theta = [float(s.text) for s in root.findall("./Electrode/Theta")] pol = np.deg2rad(np.array(theta)) @@ -383,7 +383,7 @@ def _read_xyz(fname): ch_names = [] pos = [] file_format = op.splitext(fname)[1].lower() - with open(fname, "r") as f: + with open(fname) as f: if file_format != ".xyz": f.readline() # skip header delimiter = "," if file_format == ".csv" else "\t" diff --git a/mne/channels/channels.py b/mne/channels/channels.py index bc0f52cb56c..6ad43f32ee5 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -153,7 +153,7 @@ def equalize_channels(instances, copy=True, verbose=None): from ..evoked import Evoked from ..forward import Forward from ..io import BaseRaw - from ..time_frequency import CrossSpectralDensity, _BaseTFR + from ..time_frequency import BaseTFR, CrossSpectralDensity # Instances need to have a `ch_names` attribute and a `pick_channels` # method that supports `ordered=True`. @@ -161,7 +161,7 @@ def equalize_channels(instances, copy=True, verbose=None): BaseRaw, BaseEpochs, Evoked, - _BaseTFR, + BaseTFR, Forward, Covariance, CrossSpectralDensity, @@ -447,7 +447,7 @@ def pick_types( @verbose @legacy(alt="inst.pick(...)") - def pick_channels(self, ch_names, ordered=None, *, verbose=None): + def pick_channels(self, ch_names, ordered=True, *, verbose=None): """Pick some channels. Parameters @@ -549,7 +549,7 @@ def reorder_channels(self, ch_names): for ch_name in ch_names: ii = self.ch_names.index(ch_name) if ii in idx: - raise ValueError("Channel name repeated: %s" % (ch_name,)) + raise ValueError(f"Channel name repeated: {ch_name}") idx.append(ii) return self._pick_drop_channels(idx) @@ -585,14 +585,13 @@ def drop_channels(self, ch_names, on_missing="raise"): all_str = all([isinstance(ch, str) for ch in ch_names]) except TypeError: raise ValueError( - "'ch_names' must be iterable, got " - "type {} ({}).".format(type(ch_names), ch_names) + f"'ch_names' must be iterable, got type {type(ch_names)} ({ch_names})." ) if not all_str: raise ValueError( "Each element in 'ch_names' must be str, got " - "{}.".format([type(ch) for ch in ch_names]) + f"{[type(ch) for ch in ch_names]}." 
) missing = [ch for ch in ch_names if ch not in self.ch_names] @@ -608,8 +607,6 @@ def drop_channels(self, ch_names, on_missing="raise"): def _pick_drop_channels(self, idx, *, verbose=None): # avoid circular imports from ..io import BaseRaw - from ..time_frequency import AverageTFR, EpochsTFR - from ..time_frequency.spectrum import BaseSpectrum msg = "adding, dropping, or reordering channels" if isinstance(self, BaseRaw): @@ -634,10 +631,8 @@ def _pick_drop_channels(self, idx, *, verbose=None): if mat is not None: setattr(self, key, mat[idx][:, idx]) - if isinstance(self, BaseSpectrum): + if hasattr(self, "_dims"): # Spectrum and "new-style" TFRs axis = self._dims.index("channel") - elif isinstance(self, (AverageTFR, EpochsTFR)): - axis = -3 else: # All others (Evoked, Epochs, Raw) have chs axis=-2 axis = -2 if hasattr(self, "_data"): # skip non-preloaded Raw @@ -840,6 +835,8 @@ def interpolate_bads( - ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"`` - ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"`` - ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"`` + - ``"ecog"`` channels support ``"spline"`` (default) and ``"nan"`` + - ``"seeg"`` channels support ``"spline"`` (default) and ``"nan"`` None is an alias for:: @@ -871,9 +868,12 @@ def interpolate_bads( .. versionadded:: 0.9.0 """ from .interpolation import ( + _interpolate_bads_ecog, _interpolate_bads_eeg, _interpolate_bads_meeg, + _interpolate_bads_nan, _interpolate_bads_nirs, + _interpolate_bads_seeg, ) _check_preload(self, "interpolation") @@ -895,35 +895,48 @@ def interpolate_bads( "eeg": ("spline", "MNE", "nan"), "meg": ("MNE", "nan"), "fnirs": ("nearest", "nan"), + "ecog": ("spline", "nan"), + "seeg": ("spline", "nan"), } for key in method: - _check_option("method[key]", key, ("meg", "eeg", "fnirs")) + _check_option("method[key]", key, tuple(valids)) _check_option(f"method['{key}']", method[key], valids[key]) logger.info("Setting channel interpolation method to %s.", method) idx = _picks_to_idx(self.info, list(method), exclude=(), allow_empty=True) if idx.size == 0 or len(pick_info(self.info, idx)["bads"]) == 0: warn("No bad channels to interpolate. 
Doing nothing...") return self + for ch_type in method.copy(): + idx = _picks_to_idx(self.info, ch_type, exclude=(), allow_empty=True) + if len(pick_info(self.info, idx)["bads"]) == 0: + method.pop(ch_type) logger.info("Interpolating bad channels.") - origin = _check_origin(origin, self.info) + needs_origin = [key != "seeg" and val != "nan" for key, val in method.items()] + if any(needs_origin): + origin = _check_origin(origin, self.info) + for ch_type, interp in method.items(): + if interp == "nan": + _interpolate_bads_nan(self, ch_type, exclude=exclude) if method.get("eeg", "") == "spline": _interpolate_bads_eeg(self, origin=origin, exclude=exclude) - eeg_mne = False - elif "eeg" not in method: - eeg_mne = False - else: - eeg_mne = True - if "meg" in method or eeg_mne: + meg_mne = method.get("meg", "") == "MNE" + eeg_mne = method.get("eeg", "") == "MNE" + if meg_mne or eeg_mne: _interpolate_bads_meeg( self, mode=mode, - origin=origin, + meg=meg_mne, eeg=eeg_mne, + origin=origin, exclude=exclude, method=method, ) - if "fnirs" in method: - _interpolate_bads_nirs(self, exclude=exclude, method=method["fnirs"]) + if method.get("fnirs", "") == "nearest": + _interpolate_bads_nirs(self, exclude=exclude) + if method.get("ecog", "") == "spline": + _interpolate_bads_ecog(self, origin=origin, exclude=exclude) + if method.get("seeg", "") == "spline": + _interpolate_bads_seeg(self, exclude=exclude) if reset_bads is True: if "nan" in method.values(): @@ -967,7 +980,7 @@ def rename_channels(info, mapping, allow_duplicates=False, *, verbose=None): elif callable(mapping): new_names = [(ci, mapping(ch_name)) for ci, ch_name in enumerate(ch_names)] else: - raise ValueError("mapping must be callable or dict, not %s" % (type(mapping),)) + raise ValueError(f"mapping must be callable or dict, not {type(mapping)}") # check we got all strings out of the mapping for new_name in new_names: @@ -1057,9 +1070,7 @@ class _BuiltinChannelAdjacency: name="bti248grad", description="BTI 248 gradiometer system", fname="bti248grad_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="bti248grad_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="bti248grad_neighb.mat"), ), _BuiltinChannelAdjacency( name="ctf64", @@ -1083,25 +1094,19 @@ class _BuiltinChannelAdjacency: name="easycap32ch-avg", description="", fname="easycap32ch-avg_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="easycap32ch-avg_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="easycap32ch-avg_neighb.mat"), ), _BuiltinChannelAdjacency( name="easycap64ch-avg", description="", fname="easycap64ch-avg_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="easycap64ch-avg_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="easycap64ch-avg_neighb.mat"), ), _BuiltinChannelAdjacency( name="easycap128ch-avg", description="", fname="easycap128ch-avg_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="easycap128ch-avg_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="easycap128ch-avg_neighb.mat"), ), _BuiltinChannelAdjacency( name="easycapM1", @@ -1113,25 +1118,19 @@ class _BuiltinChannelAdjacency: name="easycapM11", description="Easycap M11", fname="easycapM11_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="easycapM11_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="easycapM11_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="easycapM14", 
description="Easycap M14", fname="easycapM14_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="easycapM14_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="easycapM14_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="easycapM15", description="Easycap M15", fname="easycapM15_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="easycapM15_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="easycapM15_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="KIT-157", @@ -1179,49 +1178,37 @@ class _BuiltinChannelAdjacency: name="neuromag306mag", description="Neuromag306, only magnetometers", fname="neuromag306mag_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="neuromag306mag_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="neuromag306mag_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="neuromag306planar", description="Neuromag306, only planar gradiometers", fname="neuromag306planar_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="neuromag306planar_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="neuromag306planar_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="neuromag122cmb", description="Neuromag122, only combined planar gradiometers", fname="neuromag122cmb_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="neuromag122cmb_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="neuromag122cmb_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="neuromag306cmb", description="Neuromag306, only combined planar gradiometers", fname="neuromag306cmb_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="neuromag306cmb_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="neuromag306cmb_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="ecog256", description="ECOG 256channels, average referenced", fname="ecog256_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="ecog256_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="ecog256_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="ecog256bipolar", description="ECOG 256channels, bipolar referenced", fname="ecog256bipolar_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="ecog256bipolar_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="ecog256bipolar_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="eeg1010_neighb", @@ -1263,33 +1250,25 @@ class _BuiltinChannelAdjacency: name="language29ch-avg", description="MPI for Psycholinguistic: Averaged 29-channel cap", fname="language29ch-avg_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="language29ch-avg_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="language29ch-avg_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="mpi_59_channels", description="MPI for Psycholinguistic: 59-channel cap", fname="mpi_59_channels_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="mpi_59_channels_neighb.mat" - ), # noqa: E501 + source_url=_ft_neighbor_url_t.substitute(fname="mpi_59_channels_neighb.mat"), # noqa: E501 ), _BuiltinChannelAdjacency( name="yokogawa160", description="", fname="yokogawa160_neighb.mat", - source_url=_ft_neighbor_url_t.substitute( - fname="yokogawa160_neighb.mat" - ), # noqa: E501 + 
source_url=_ft_neighbor_url_t.substitute(fname="yokogawa160_neighb.mat"),  # noqa: E501
     ),
     _BuiltinChannelAdjacency(
         name="yokogawa440",
         description="",
         fname="yokogawa440_neighb.mat",
-        source_url=_ft_neighbor_url_t.substitute(
-            fname="yokogawa440_neighb.mat"
-        ),  # noqa: E501
+        source_url=_ft_neighbor_url_t.substitute(fname="yokogawa440_neighb.mat"),  # noqa: E501
     ),
 ]
@@ -2129,9 +2108,7 @@ def read_vectorview_selection(name, fname=None, info=None, verbose=None):
         else:
             spacing = "old"
     elif info is not None:
-        raise TypeError(
-            "info must be an instance of Info or None, not %s" % (type(info),)
-        )
+        raise TypeError(f"info must be an instance of Info or None, not {type(info)}")
     else:  # info is None
         spacing = "old"
@@ -2143,7 +2120,7 @@ def read_vectorview_selection(name, fname=None, info=None, verbose=None):
     # use this to make sure we find at least one match for each name
     name_found = {n: False for n in name}
-    with open(fname, "r") as fid:
+    with open(fname) as fid:
         sel = []
         for line in fid:
             line = line.strip()
diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py
index 77e660901a9..6c5042d1d04 100644
--- a/mne/channels/interpolation.py
+++ b/mne/channels/interpolation.py
@@ -6,13 +6,14 @@
 import numpy as np
 from numpy.polynomial.legendre import legval
+from scipy.interpolate import RectBivariateSpline
 from scipy.linalg import pinv
 from scipy.spatial.distance import pdist, squareform
 
 from .._fiff.meas_info import _simplify_info
 from .._fiff.pick import pick_channels, pick_info, pick_types
 from ..surface import _normalize_vectors
-from ..utils import _check_option, _validate_type, logger, verbose, warn
+from ..utils import _validate_type, logger, verbose, warn
 
 
 def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
@@ -132,13 +133,13 @@ def _do_interp_dots(inst, interpolation, goods_idx, bads_idx):
 
 @verbose
-def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None):
+def _interpolate_bads_eeg(inst, origin, exclude=None, ecog=False, verbose=None):
     if exclude is None:
         exclude = list()
     bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
     goods_idx = np.zeros(len(inst.ch_names), dtype=bool)
-    picks = pick_types(inst.info, meg=False, eeg=True, exclude=exclude)
+    picks = pick_types(inst.info, meg=False, eeg=not ecog, ecog=ecog, exclude=exclude)
     inst.info._check_consistency()
     bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks]
@@ -165,16 +166,18 @@ def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None):
     pos_good = pos[goods_idx_pos] - origin
     pos_bad = pos[bads_idx_pos] - origin
-    logger.info(
-        "Computing interpolation matrix from {} sensor "
-        "positions".format(len(pos_good))
-    )
+    logger.info(f"Computing interpolation matrix from {len(pos_good)} sensor positions")
     interpolation = _make_interpolation_matrix(pos_good, pos_bad)
 
-    logger.info("Interpolating {} sensors".format(len(pos_bad)))
+    logger.info(f"Interpolating {len(pos_bad)} sensors")
     _do_interp_dots(inst, interpolation, goods_idx, bads_idx)
 
 
+@verbose
+def _interpolate_bads_ecog(inst, origin, exclude=None, verbose=None):
+    _interpolate_bads_eeg(inst, origin, exclude=exclude, ecog=True, verbose=verbose)
+
+
 def _interpolate_bads_meg(
     inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False
 ):
@@ -183,6 +186,26 @@ def _interpolate_bads_meg(
     )
 
 
+@verbose
+def _interpolate_bads_nan(
+    inst,
+    ch_type,
+    ref_meg=False,
+    exclude=(),
+    *,
+    verbose=None,
+):
+    info = _simplify_info(inst.info)
+    picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **{ch_type: True})
+    use_ch_names = [inst.info["ch_names"][p] for p in picks_type]
+    bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names]
+    if len(bads_type) == 0 or len(picks_type) == 0:
+        return
+    # select the bad channels to be interpolated
+    picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
+    inst._data[..., picks_bad, :] = np.nan
+
+
 @verbose
 def _interpolate_bads_meeg(
     inst,
@@ -216,10 +239,6 @@ def _interpolate_bads_meeg(
         # select the bad channels to be interpolated
         picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
 
-        if method[ch_type] == "nan":
-            inst._data[picks_bad] = np.nan
-            continue
-
         # do MNE based interpolation
         if ch_type == "eeg":
             picks_to = picks_type
@@ -235,7 +254,7 @@ def _interpolate_bads_meeg(
 
 
 @verbose
-def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
+def _interpolate_bads_nirs(inst, exclude=(), verbose=None):
     from mne.preprocessing.nirs import _validate_nirs_info
 
     if len(pick_types(inst.info, fnirs=True, exclude=())) == 0:
@@ -254,25 +273,93 @@
     chs = [inst.info["chs"][i] for i in picks_nirs]
     locs3d = np.array([ch["loc"][:3] for ch in chs])
 
-    _check_option("fnirs_method", method, ["nearest", "nan"])
-
-    if method == "nearest":
-        dist = pdist(locs3d)
-        dist = squareform(dist)
-
-        for bad in picks_bad:
-            dists_to_bad = dist[bad]
-            # Ignore distances to self
-            dists_to_bad[dists_to_bad == 0] = np.inf
-            # Ignore distances to other bad channels
-            dists_to_bad[bads_mask] = np.inf
-            # Find closest remaining channels for same frequency
-            closest_idx = np.argmin(dists_to_bad) + (bad % 2)
-            inst._data[bad] = inst._data[closest_idx]
-    else:
-        assert method == "nan"
-        inst._data[picks_bad] = np.nan
+    dist = pdist(locs3d)
+    dist = squareform(dist)
+
+    for bad in picks_bad:
+        dists_to_bad = dist[bad]
+        # Ignore distances to self
+        dists_to_bad[dists_to_bad == 0] = np.inf
+        # Ignore distances to other bad channels
+        dists_to_bad[bads_mask] = np.inf
+        # Find closest remaining channels for same frequency
+        closest_idx = np.argmin(dists_to_bad) + (bad % 2)
+        inst._data[bad] = inst._data[closest_idx]
+
     # TODO: this seems like a bug because it does not respect reset_bads
     inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]
 
     return inst
+
+
+def _find_seeg_electrode_shaft(pos, tol=2e-3):
+    # 1) find nearest neighbor to define the electrode shaft line
+    # 2) find all contacts on the same line
+
+    dist = squareform(pdist(pos))
+    np.fill_diagonal(dist, np.inf)
+
+    shafts = list()
+    for i, n1 in enumerate(pos):
+        if any([i in shaft for shaft in shafts]):
+            continue
+        n2 = pos[np.argmin(dist[i])]  # 1
+        # https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
+        shaft_dists = np.linalg.norm(
+            np.cross((pos - n1), (pos - n2)), axis=1
+        ) / np.linalg.norm(n2 - n1)
+        shafts.append(np.where(shaft_dists < tol)[0])  # 2
+    return shafts
+
+
+@verbose
+def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None):
+    if exclude is None:
+        exclude = list()
+    picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude)
+    inst.info._check_consistency()
+    bads_idx = np.isin(np.array(inst.ch_names)[picks], inst.info["bads"])
+
+    if len(picks) == 0 or bads_idx.sum() == 0:
+        return
+
+    pos = inst._get_channel_positions(picks)
+
+    # Make sure only sEEG are used
+    bads_idx_pos = bads_idx[picks]
+
+    shafts = _find_seeg_electrode_shaft(pos, tol=tol)
+
+    # interpolate the bad contacts
+    picks_bad = list(np.where(bads_idx_pos)[0])
+    for shaft in shafts:
+        bads_shaft = np.array([idx for idx in picks_bad if idx in shaft])
+        if bads_shaft.size == 0:
+            continue
+        goods_shaft = shaft[np.isin(shaft, bads_shaft, invert=True)]
+        if goods_shaft.size < 2:
+            raise RuntimeError(
+                f"{goods_shaft.size} good contact(s) found in a line "
+                f"with {np.array(inst.ch_names)[bads_shaft]}, "
+                "at least 2 are required for interpolation. "
+                "Dropping this channel/these channels is recommended."
+            )
+        logger.debug(
+            f"Interpolating {np.array(inst.ch_names)[bads_shaft]} using "
+            f"data from {np.array(inst.ch_names)[goods_shaft]}"
+        )
+        bads_shaft_idx = np.where(np.isin(shaft, bads_shaft))[0]
+        goods_shaft_idx = np.where(~np.isin(shaft, bads_shaft))[0]
+        n1, n2 = pos[shaft][:2]
+        ts = np.array(
+            [
+                -np.dot(n1 - n0, n2 - n1) / np.linalg.norm(n2 - n1) ** 2
+                for n0 in pos[shaft]
+            ]
+        )
+        if np.any(np.diff(ts) < 0):
+            ts *= -1
+        y = np.arange(inst._data.shape[-1])
+        inst._data[bads_shaft] = RectBivariateSpline(
+            x=ts[goods_shaft_idx], y=y, z=inst._data[goods_shaft]
+        )(x=ts[bads_shaft_idx], y=y)  # 3
diff --git a/mne/channels/layout.py b/mne/channels/layout.py
index a0f12cc594f..d19794115d7 100644
--- a/mne/channels/layout.py
+++ b/mne/channels/layout.py
@@ -58,7 +58,7 @@ class Layout:
         The type of Layout (e.g. 'Vectorview-all').
     """
 
-    def __init__(self, box, pos, names, ids, kind):  # noqa: D102
+    def __init__(self, box, pos, names, ids, kind):
         self.box = box
         self.pos = pos
         self.names = names
@@ -85,7 +85,7 @@ def save(self, fname, overwrite=False):
         height = self.pos[:, 3]
         fname = _check_fname(fname, overwrite=overwrite, name=fname)
         if fname.suffix == ".lout":
-            out_str = "%8.2f %8.2f %8.2f %8.2f\n" % self.box
+            out_str = "{:8.2f} {:8.2f} {:8.2f} {:8.2f}\n".format(*self.box)
         elif fname.suffix == ".lay":
             out_str = ""
         else:
@@ -107,7 +107,7 @@ def save(self, fname, overwrite=False):
 
     def __repr__(self):
         """Return the string representation."""
-        return "<Layout | %s - Channels: %s ...>" % (
+        return "<Layout | {} - Channels: {} ...>".format(
             self.kind,
             ", ".join(self.names[:3]),
         )
@@ -1181,7 +1181,7 @@
     if ch_indices is None:
         ch_indices = np.arange(xy.shape[0])
     if ch_names is None:
-        ch_names = ["{}".format(i) for i in ch_indices]
+        ch_names = list(map(str, ch_indices))
 
     if len(ch_names) != len(ch_indices):
         raise ValueError("# channel names and indices must be equal")
@@ -1205,7 +1205,7 @@
     # Create box and pos variable
     box = _box_size(np.vstack([x, y]).T, padding=pad)
     box = (0, 0, box[0], box[1])
-    w, h = [np.array([i] * x.shape[0]) for i in [w, h]]
+    w, h = (np.array([i] * x.shape[0]) for i in [w, h])
     loc_params = np.vstack([x, y, w, h]).T
 
     layout = Layout(box, loc_params, ch_names, ch_indices, name)
diff --git a/mne/channels/montage.py b/mne/channels/montage.py
index 3d7ad340df8..abc9f2f62b7 100644
--- a/mne/channels/montage.py
+++ b/mne/channels/montage.py
@@ -785,7 +785,7 @@ def read_dig_dat(fname):
 
     fname = _check_fname(fname, overwrite="read", must_exist=True)
 
-    with open(fname, "r") as fid:
+    with open(fname) as fid:
         lines = fid.readlines()
 
     ch_names, poss = list(), list()
@@ -796,8 +796,8 @@ def read_dig_dat(fname):
             continue
         elif len(items) != 5:
             raise ValueError(
-                "Error reading %s, line %s has unexpected number of entries:\n"
-                "%s" % (fname, i, line.rstrip())
+                f"Error reading {fname}, line {i} has unexpected number of entries:\n"
+                f"{line.rstrip()}"
             )
         num = items[1]
         if num == "67":
@@ -1352,7 +1352,7 @@ def _read_isotrak_elp_points(fname):
     and 'points'.
""" value_pattern = r"\-?\d+\.?\d*e?\-?\d*" - coord_pattern = r"({0})\s+({0})\s+({0})\s*$".format(value_pattern) + coord_pattern = rf"({value_pattern})\s+({value_pattern})\s+({value_pattern})\s*$" with open(fname) as fid: file_str = fid.read() @@ -1474,11 +1474,9 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"): data["ch_pos"] = OrderedDict(zip(ch_names, points)) else: raise ValueError( - ( - "Length of ``ch_names`` does not match the number of points" - " in {fname}. Expected ``ch_names`` length {n_points:d}," - " given {n_chnames:d}" - ).format(fname=fname, n_points=points.shape[0], n_chnames=len(ch_names)) + "Length of ``ch_names`` does not match the number of points in " + f"{fname}. Expected ``ch_names`` length {points.shape[0]}, given " + f"{len(ch_names)}" ) return make_dig_montage(**data) @@ -1486,7 +1484,7 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"): def _is_polhemus_fastscan(fname): header = "" - with open(fname, "r") as fid: + with open(fname) as fid: for line in fid: if not line.startswith("%"): break @@ -1621,7 +1619,7 @@ def read_custom_montage(fname, head_size=HEAD_SIZE_DEFAULT, coord_frame=None): if ext in SUPPORTED_FILE_EXT["eeglab"]: if head_size is None: - raise ValueError("``head_size`` cannot be None for '{}'".format(ext)) + raise ValueError(f"``head_size`` cannot be None for '{ext}'") ch_names, pos = _read_eeglab_locations(fname) scale = head_size / np.median(np.linalg.norm(pos, axis=-1)) pos *= scale @@ -1642,7 +1640,7 @@ def read_custom_montage(fname, head_size=HEAD_SIZE_DEFAULT, coord_frame=None): elif ext in SUPPORTED_FILE_EXT["generic (Theta-phi in degrees)"]: if head_size is None: - raise ValueError("``head_size`` cannot be None for '{}'".format(ext)) + raise ValueError(f"``head_size`` cannot be None for '{ext}'") montage = _read_theta_phi_in_degrees( fname, head_size=head_size, fid_names=("Nz", "LPA", "RPA") ) @@ -1711,11 +1709,9 @@ def compute_dev_head_t(montage): if not (len(hpi_head) == len(hpi_dev) and len(hpi_dev) > 0): raise ValueError( - ( - "To compute Device-to-Head transformation, the same number of HPI" - " points in device and head coordinates is required. (Got {dev}" - " points in device and {head} points in head coordinate systems)" - ).format(dev=len(hpi_dev), head=len(hpi_head)) + "To compute Device-to-Head transformation, the same number of HPI" + f" points in device and head coordinates is required. 
(Got {len(hpi_dev)}" + f" points in device and {len(hpi_head)} points in head coordinate systems)" ) trans = _quat_to_affine(_fit_matched_points(hpi_dev, hpi_head)[0]) diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index c3bbcdb33dc..adfe63f93d9 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -54,7 +54,7 @@ from mne.parallel import parallel_func from mne.utils import requires_good_network -io_dir = Path(__file__).parent.parent.parent / "io" +io_dir = Path(__file__).parents[2] / "io" base_dir = io_dir / "tests" / "data" raw_fname = base_dir / "test_raw.fif" eve_fname = base_dir / "test-eve.fif" @@ -438,8 +438,8 @@ def test_1020_selection(): raw = raw.rename_channels(dict(zip(raw.ch_names, montage.ch_names))) raw.set_montage(montage) - for input in ("a_string", 100, raw, [1, 2]): - pytest.raises(TypeError, make_1020_channel_selections, input) + for input_ in ("a_string", 100, raw, [1, 2]): + pytest.raises(TypeError, make_1020_channel_selections, input_) sels = make_1020_channel_selections(raw.info) # are all frontal channels placed before all occipital channels? diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 9630607caae..7e282562955 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -10,6 +10,7 @@ from mne import Epochs, pick_channels, pick_types, read_events from mne._fiff.constants import FIFF from mne._fiff.proj import _has_eeg_average_ref_proj +from mne.channels import make_dig_montage from mne.channels.interpolation import _make_interpolation_matrix from mne.datasets import testing from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx @@ -20,7 +21,7 @@ ) from mne.utils import _record_warnings -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" event_name = base_dir / "test-eve.fif" raw_fname_ctf = base_dir / "test_ctf_raw.fif" @@ -329,6 +330,55 @@ def test_interpolation_nirs(): assert raw_haemo.info["bads"] == [] +@testing.requires_testing_data +def test_interpolation_ecog(): + """Test interpolation for ECoG.""" + raw, epochs_eeg = _load_data("eeg") + bads = ["EEG 012"] + bads_mask = np.isin(epochs_eeg.ch_names, bads) + + epochs_ecog = epochs_eeg.set_channel_types( + {ch: "ecog" for ch in epochs_eeg.ch_names} + ) + epochs_ecog.info["bads"] = bads + + # check that interpolation changes the data in raw + raw_ecog = RawArray(data=epochs_ecog._data[0], info=epochs_ecog.info) + raw_before = raw_ecog.copy() + raw_after = raw_ecog.interpolate_bads(method=dict(ecog="spline")) + assert not np.all(raw_before._data[bads_mask] == raw_after._data[bads_mask]) + assert_array_equal(raw_before._data[~bads_mask], raw_after._data[~bads_mask]) + + +@testing.requires_testing_data +def test_interpolation_seeg(): + """Test interpolation for sEEG.""" + raw, epochs_eeg = _load_data("eeg") + bads = ["EEG 012"] + bads_mask = np.isin(epochs_eeg.ch_names, bads) + epochs_seeg = epochs_eeg.set_channel_types( + {ch: "seeg" for ch in epochs_eeg.ch_names} + ) + epochs_seeg.info["bads"] = bads + + # check that interpolation changes the data in raw + raw_seeg = RawArray(data=epochs_seeg._data[0], info=epochs_seeg.info) + raw_before = raw_seeg.copy() + with pytest.raises(RuntimeError, match="1 good contact"): + raw_seeg.interpolate_bads(method=dict(seeg="spline")) + montage = 
raw_seeg.get_montage() + pos = montage.get_positions() + ch_pos = pos.pop("ch_pos") + n0 = ch_pos[epochs_seeg.ch_names[0]] + n1 = ch_pos[epochs_seeg.ch_names[1]] + for i, ch in enumerate(epochs_seeg.ch_names[2:]): + ch_pos[ch] = n0 + (n1 - n0) * (i + 2) + raw_seeg.set_montage(make_dig_montage(ch_pos, **pos)) + raw_after = raw_seeg.interpolate_bads(method=dict(seeg="spline")) + assert not np.all(raw_before._data[bads_mask] == raw_after._data[bads_mask]) + assert_array_equal(raw_before._data[~bads_mask], raw_after._data[~bads_mask]) + + def test_nan_interpolation(raw): """Test 'nan' method for interpolating bads.""" ch_to_interp = [raw.ch_names[1]] # don't use channel 0 (type is IAS not MEG) diff --git a/mne/channels/tests/test_layout.py b/mne/channels/tests/test_layout.py index 05caa37735b..15eb50b7975 100644 --- a/mne/channels/tests/test_layout.py +++ b/mne/channels/tests/test_layout.py @@ -32,7 +32,7 @@ from mne.defaults import HEAD_SIZE_DEFAULT from mne.io import read_info, read_raw_kit -io_dir = Path(__file__).parent.parent.parent / "io" +io_dir = Path(__file__).parents[2] / "io" fif_fname = io_dir / "tests" / "data" / "test_raw.fif" lout_path = io_dir / "tests" / "data" bti_dir = io_dir / "bti" / "tests" / "data" diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index f4da1e6932e..08971ab803b 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -95,7 +95,7 @@ mgh70_fname = data_path / "SSS" / "mgh70_raw.fif" subjects_dir = data_path / "subjects" -io_dir = Path(__file__).parent.parent.parent / "io" +io_dir = Path(__file__).parents[2] / "io" kit_dir = io_dir / "kit" / "tests" / "data" elp = kit_dir / "test_elp.txt" hsp = kit_dir / "test_hsp.txt" @@ -513,6 +513,8 @@ def test_documented(): ) def test_montage_readers(reader, file_content, expected_dig, ext, warning, tmp_path): """Test that we have an equivalent of read_montage for all file formats.""" + if file_content.startswith("= 3 @@ -470,11 +470,9 @@ def _get_hpi_initial_fit(info, adjust=False, verbose=None): if "moments" in hpi_result: logger.debug("Hpi coil moments (%d %d):" % hpi_result["moments"].shape[::-1]) for moment in hpi_result["moments"]: - logger.debug("%g %g %g" % tuple(moment)) + logger.debug("{:g} {:g} {:g}".format(*tuple(moment))) errors = np.linalg.norm(hpi_rrs - hpi_rrs_fit, axis=1) - logger.debug( - "HPIFIT errors: %s mm." % ", ".join("%0.1f" % (1000.0 * e) for e in errors) - ) + logger.debug(f"HPIFIT errors: {', '.join(f'{1000 * e:0.1f}' for e in errors)} mm.") if errors.sum() < len(errors) * dist_limit: logger.info("HPI consistency of isotrak and hpifit is OK.") elif not adjust and (len(used) == len(hpi_dig)): @@ -487,24 +485,22 @@ def _get_hpi_initial_fit(info, adjust=False, verbose=None): if not adjust: if err >= dist_limit: warn( - "Discrepancy of HPI coil %d isotrak and hpifit is " - "%.1f mm!" % (hi + 1, d) + f"Discrepancy of HPI coil {hi + 1} isotrak and hpifit is " + f"{d:.1f} mm!" ) elif hi + 1 not in used: if goodness[hi] >= good_limit: logger.info( - "Note: HPI coil %d isotrak is adjusted by " - "%.1f mm!" % (hi + 1, d) + f"Note: HPI coil {hi + 1} isotrak is adjusted by {d:.1f} mm!" ) hpi_rrs[hi] = r_fit else: warn( - "Discrepancy of HPI coil %d isotrak and hpifit of " - "%.1f mm was not adjusted!" % (hi + 1, d) + f"Discrepancy of HPI coil {hi + 1} isotrak and hpifit of " + f"{d:.1f} mm was not adjusted!" ) logger.debug( - "HP fitting limits: err = %.1f mm, gval = %.3f." 
- % (1000 * dist_limit, good_limit) + f"HP fitting limits: err = {1000 * dist_limit:.1f} mm, gval = {good_limit:.3f}." ) return hpi_rrs.astype(float) @@ -643,8 +639,9 @@ def _setup_hpi_amplitude_fitting( else: line_freqs = np.zeros([0]) logger.info( - "Line interference frequencies: %s Hz" - % " ".join(["%d" % lf for lf in line_freqs]) + "Line interference frequencies: {} Hz".format( + " ".join([f"{lf}" for lf in line_freqs]) + ) ) # worry about resampled/filtered data. # What to do e.g. if Raw has been resampled and some of our @@ -657,8 +654,8 @@ def _setup_hpi_amplitude_fitting( hpi_ons = hpi_ons[keepers] elif not keepers.all(): raise RuntimeError( - "Found HPI frequencies %s above the lowpass " - "(or Nyquist) frequency %0.1f" % (hpi_freqs[~keepers].tolist(), highest) + f"Found HPI frequencies {hpi_freqs[~keepers].tolist()} above the lowpass (" + f"or Nyquist) frequency {highest:0.1f}" ) # calculate optimal window length. if isinstance(t_window, str): @@ -671,8 +668,8 @@ def _setup_hpi_amplitude_fitting( t_window = 0.2 t_window = float(t_window) if t_window <= 0: - raise ValueError("t_window (%s) must be > 0" % (t_window,)) - logger.info("Using time window: %0.1f ms" % (1000 * t_window,)) + raise ValueError(f"t_window ({t_window}) must be > 0") + logger.info(f"Using time window: {1000 * t_window:0.1f} ms") window_nsamp = np.rint(t_window * info["sfreq"]).astype(int) model = _setup_hpi_glm(hpi_freqs, line_freqs, info["sfreq"], window_nsamp) inv_model = np.linalg.pinv(model) @@ -869,25 +866,22 @@ def _check_chpi_param(chpi_, name): want_keys = list(want_ndims.keys()) + extra_keys if set(want_keys).symmetric_difference(chpi_): raise ValueError( - "%s must be a dict with entries %s, got %s" - % (name, want_keys, sorted(chpi_.keys())) + f"{name} must be a dict with entries {want_keys}, got " + f"{sorted(chpi_.keys())}" ) n_times = None for key, want_ndim in want_ndims.items(): - key_str = "%s[%s]" % (name, key) + key_str = f"{name}[{key}]" val = chpi_[key] _validate_type(val, np.ndarray, key_str) shape = val.shape if val.ndim != want_ndim: - raise ValueError( - "%s must have ndim=%d, got %d" % (key_str, want_ndim, val.ndim) - ) + raise ValueError(f"{key_str} must have ndim={want_ndim}, got {val.ndim}") if n_times is None and key != "proj": n_times = shape[0] if n_times != shape[0] and key != "proj": raise ValueError( - "%s have inconsistent number of time " - "points in %s" % (name, want_keys) + f"{name} have inconsistent number of time points in {want_keys}" ) if name == "chpi_locs": n_coils = chpi_["rrs"].shape[1] @@ -895,15 +889,14 @@ def _check_chpi_param(chpi_, name): val = chpi_[key] if val.shape[1] != n_coils: raise ValueError( - 'chpi_locs["rrs"] had values for %d coils but' - ' chpi_locs["%s"] had values for %d coils' - % (n_coils, key, val.shape[1]) + f'chpi_locs["rrs"] had values for {n_coils} coils but ' + f'chpi_locs["{key}"] had values for {val.shape[1]} coils' ) for key in ("rrs", "moments"): val = chpi_[key] if val.shape[2] != 3: raise ValueError( - 'chpi_locs["%s"].shape[2] must be 3, got ' "shape %s" % (key, shape) + f'chpi_locs["{key}"].shape[2] must be 3, got shape {shape}' ) else: assert name == "chpi_amplitudes" @@ -912,8 +905,8 @@ def _check_chpi_param(chpi_, name): n_ch = len(proj["data"]["col_names"]) if slopes.shape[0] != n_times or slopes.shape[2] != n_ch: raise ValueError( - "slopes must have shape[0]==%d and shape[2]==%d," - " got shape %s" % (n_times, n_ch, slopes.shape) + f"slopes must have shape[0]=={n_times} and shape[2]=={n_ch}, got shape " + 
f"{slopes.shape}" ) @@ -1003,9 +996,9 @@ def compute_head_pos( n_good = ((g_coils >= gof_limit) & (errs < dist_limit)).sum() if n_good < 3: warn( - _time_prefix(fit_time) + "%s/%s good HPI fits, cannot " - "determine the transformation (%s mm/GOF)!" - % ( + _time_prefix(fit_time) + + "{}/{} good HPI fits, cannot " + "determine the transformation ({} mm/GOF)!".format( n_good, n_coils, ", ".join( @@ -1068,13 +1061,13 @@ def compute_head_pos( v = d / dt # m/s d = 100 * np.linalg.norm(this_quat[3:] - pos_0) # dis from 1st logger.debug( - " #t = %0.3f, #e = %0.2f cm, #g = %0.3f, " - "#v = %0.2f cm/s, #r = %0.2f rad/s, #d = %0.2f cm" - % (fit_time, 100 * errs.mean(), g, 100 * v, r, d) + f" #t = {fit_time:0.3f}, #e = {100 * errs.mean():0.2f} cm, #g = {g:0.3f}" + f", #v = {100 * v:0.2f} cm/s, #r = {r:0.2f} rad/s, #d = {d:0.2f} cm" ) logger.debug( - " #t = %0.3f, #q = %s " - % (fit_time, " ".join(map("{:8.5f}".format, this_quat))) + " #t = {:0.3f}, #q = {} ".format( + fit_time, " ".join(map("{:8.5f}".format, this_quat)) + ) ) quats.append( @@ -1504,7 +1497,7 @@ def filter_chpi( raise RuntimeError("raw data must be preloaded") t_step = float(t_step) if t_step <= 0: - raise ValueError("t_step (%s) must be > 0" % (t_step,)) + raise ValueError(f"t_step ({t_step}) must be > 0") n_step = int(np.ceil(t_step * raw.info["sfreq"])) if include_line and raw.info["line_freq"] is None: raise RuntimeError( @@ -1617,11 +1610,8 @@ def get_active_chpi(raw, *, on_missing="raise", verbose=None): # check whether we have a neuromag system if system not in ["122m", "306m"]: raise NotImplementedError( - ( - "Identifying active HPI channels" - " is not implemented for other systems" - " than neuromag." - ) + "Identifying active HPI channels is not implemented for other systems than " + "neuromag." ) # extract hpi info chpi_info = get_chpi_info(raw.info, on_missing=on_missing) diff --git a/mne/commands/mne_anonymize.py b/mne/commands/mne_anonymize.py index 3c0a7ebfd27..a282f016ede 100644 --- a/mne/commands/mne_anonymize.py +++ b/mne/commands/mne_anonymize.py @@ -52,7 +52,7 @@ def mne_anonymize(fif_fname, out_fname, keep_his, daysback, overwrite): dir_name = op.split(fif_fname)[0] if out_fname is None: fif_bname = op.basename(fif_fname) - out_fname = op.join(dir_name, "{}-{}".format(ANONYMIZE_FILE_PREFIX, fif_bname)) + out_fname = op.join(dir_name, f"{ANONYMIZE_FILE_PREFIX}-{fif_bname}") elif not op.isabs(out_fname): out_fname = op.join(dir_name, out_fname) diff --git a/mne/commands/mne_browse_raw.py b/mne/commands/mne_browse_raw.py index 0c3d81a16e9..2e662e1768b 100644 --- a/mne/commands/mne_browse_raw.py +++ b/mne/commands/mne_browse_raw.py @@ -84,7 +84,7 @@ def run(): "-p", "--preload", dest="preload", - help="Preload raw data (for faster navigaton)", + help="Preload raw data (for faster navigation)", default=False, action="store_true", ) diff --git a/mne/commands/mne_bti2fiff.py b/mne/commands/mne_bti2fiff.py index c8664ca5a35..2c4e4083df1 100644 --- a/mne/commands/mne_bti2fiff.py +++ b/mne/commands/mne_bti2fiff.py @@ -30,7 +30,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. - import sys import mne diff --git a/mne/commands/mne_clean_eog_ecg.py b/mne/commands/mne_clean_eog_ecg.py index 8f18f16f6cb..10b84540756 100644 --- a/mne/commands/mne_clean_eog_ecg.py +++ b/mne/commands/mne_clean_eog_ecg.py @@ -14,7 +14,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
- import sys import mne diff --git a/mne/commands/mne_compute_proj_ecg.py b/mne/commands/mne_compute_proj_ecg.py index f5f798a4968..caab628bbb2 100644 --- a/mne/commands/mne_compute_proj_ecg.py +++ b/mne/commands/mne_compute_proj_ecg.py @@ -256,7 +256,7 @@ def run(): raise ValueError('qrsthr must be "auto" or a float') if bad_fname is not None: - with open(bad_fname, "r") as fid: + with open(bad_fname) as fid: bads = [w.rstrip() for w in fid.readlines()] print("Bad channels read : %s" % bads) else: diff --git a/mne/commands/mne_compute_proj_eog.py b/mne/commands/mne_compute_proj_eog.py index 456bf3b6080..165818facc4 100644 --- a/mne/commands/mne_compute_proj_eog.py +++ b/mne/commands/mne_compute_proj_eog.py @@ -253,7 +253,7 @@ def run(): ch_name = options.ch_name if bad_fname is not None: - with open(bad_fname, "r") as fid: + with open(bad_fname) as fid: bads = [w.rstrip() for w in fid.readlines()] print("Bad channels read : %s" % bads) else: diff --git a/mne/commands/mne_coreg.py b/mne/commands/mne_coreg.py index b32e8b9e3d7..b0551346e43 100644 --- a/mne/commands/mne_coreg.py +++ b/mne/commands/mne_coreg.py @@ -41,25 +41,6 @@ def run(): default=None, help="FIFF file with digitizer data for coregistration", ) - parser.add_option( - "-t", - "--tabbed", - dest="tabbed", - action="store_true", - default=None, - help="Option for small screens: Combine " - "the data source panel and the coregistration panel " - "into a single panel with tabs.", - ) - parser.add_option( - "--no-guess-mri", - dest="guess_mri_subject", - action="store_false", - default=None, - help="Prevent the GUI from automatically guessing and " - "changing the MRI subject when a new head shape source " - "file is selected.", - ) parser.add_option( "--head-opacity", type=float, @@ -94,20 +75,6 @@ def run(): dest="interaction", help='Interaction style to use, can be "trackball" or ' '"terrain".', ) - parser.add_option( - "--scale", - type=float, - default=None, - dest="scale", - help="Scale factor for the scene.", - ) - parser.add_option( - "--simple-rendering", - action="store_false", - dest="advanced_rendering", - default=None, - help="Use simplified OpenGL rendering", - ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -134,18 +101,13 @@ def run(): faulthandler.enable() mne.gui.coregistration( - tabbed=options.tabbed, inst=options.inst, subject=options.subject, subjects_dir=subjects_dir, - guess_mri_subject=options.guess_mri_subject, head_opacity=options.head_opacity, head_high_res=head_high_res, trans=trans, - scrollable=None, interaction=options.interaction, - scale=options.scale, - advanced_rendering=options.advanced_rendering, show=True, block=True, verbose=options.verbose, diff --git a/mne/commands/mne_freeview_bem_surfaces.py b/mne/commands/mne_freeview_bem_surfaces.py index 7f1c6491ba1..504ca3378bf 100644 --- a/mne/commands/mne_freeview_bem_surfaces.py +++ b/mne/commands/mne_freeview_bem_surfaces.py @@ -41,8 +41,7 @@ def freeview_bem_surfaces(subject, subjects_dir, method): if not op.isdir(subject_dir): raise ValueError( - "Wrong path: '{}'. Check subjects-dir or" - "subject argument.".format(subject_dir) + f"Wrong path: '{subject_dir}'. Check subjects-dir or subject argument." 
) env = os.environ.copy() diff --git a/mne/commands/mne_make_scalp_surfaces.py b/mne/commands/mne_make_scalp_surfaces.py index 91ed2fdae60..85b7acd2883 100644 --- a/mne/commands/mne_make_scalp_surfaces.py +++ b/mne/commands/mne_make_scalp_surfaces.py @@ -17,6 +17,7 @@ $ mne make_scalp_surfaces --overwrite --subject sample """ + import os import sys diff --git a/mne/commands/mne_maxfilter.py b/mne/commands/mne_maxfilter.py index 5c631dcf457..4cbb1dc9522 100644 --- a/mne/commands/mne_maxfilter.py +++ b/mne/commands/mne_maxfilter.py @@ -222,7 +222,7 @@ def run(): out_fname = prefix + "_sss.fif" if origin is not None and os.path.exists(origin): - with open(origin, "r") as fid: + with open(origin) as fid: origin = fid.readlines()[0].strip() origin = mne.preprocessing.apply_maxfilter( diff --git a/mne/commands/mne_setup_source_space.py b/mne/commands/mne_setup_source_space.py index b5654ecab7f..f5f5dc8b343 100644 --- a/mne/commands/mne_setup_source_space.py +++ b/mne/commands/mne_setup_source_space.py @@ -120,7 +120,7 @@ def run(): subjects_dir = options.subjects_dir spacing = options.spacing ico = options.ico - oct = options.oct + oct_ = options.oct surface = options.surface n_jobs = options.n_jobs add_dist = options.add_dist @@ -130,20 +130,22 @@ def run(): overwrite = True if options.overwrite is not None else False # Parse source spacing option - spacing_options = [ico, oct, spacing] + spacing_options = [ico, oct_, spacing] n_options = len([x for x in spacing_options if x is not None]) + use_spacing = "oct6" if n_options > 1: raise ValueError("Only one spacing option can be set at the same time") elif n_options == 0: # Default to oct6 - use_spacing = "oct6" + pass elif n_options == 1: if ico is not None: use_spacing = "ico" + str(ico) - elif oct is not None: - use_spacing = "oct" + str(oct) + elif oct_ is not None: + use_spacing = "oct" + str(oct_) elif spacing is not None: use_spacing = spacing + del ico, oct_, spacing # Generate filename if fname is None: if subject_to is None: diff --git a/mne/commands/tests/test_commands.py b/mne/commands/tests/test_commands.py index ae5e84cbd58..fced5272efc 100644 --- a/mne/commands/tests/test_commands.py +++ b/mne/commands/tests/test_commands.py @@ -31,7 +31,6 @@ mne_flash_bem, mne_kit2fiff, mne_make_scalp_surfaces, - mne_maxfilter, mne_prepare_bem_model, mne_report, mne_setup_forward_model, @@ -44,7 +43,7 @@ mne_what, ) from mne.datasets import testing -from mne.io import read_info, read_raw_fif +from mne.io import read_info, read_raw_fif, show_fiff from mne.utils import ( ArgvSetter, _record_warnings, @@ -101,13 +100,22 @@ def test_compare_fiff(): check_usage(mne_compare_fiff) -def test_show_fiff(): +def test_show_fiff(tmp_path): """Test mne compare_fiff.""" check_usage(mne_show_fiff) with ArgvSetter((raw_fname,)): mne_show_fiff.run() with ArgvSetter((raw_fname, "--tag=102")): mne_show_fiff.run() + bad_fname = tmp_path / "test_bad_raw.fif" + with open(bad_fname, "wb") as fout: + with open(raw_fname, "rb") as fin: + fout.write(fin.read(100000)) + with pytest.warns(RuntimeWarning, match="Invalid tag"): + lines = show_fiff(bad_fname, output=list) + last_line = lines[-1] + assert last_line.endswith(">>>>BAD @9015") + assert "302 = FIFF_EPOCH (734412b >f4)" in last_line @requires_mne @@ -122,7 +130,7 @@ def test_clean_eog_ecg(tmp_path): with ArgvSetter(("-i", use_fname, "--quiet")): mne_clean_eog_ecg.run() for key, count in (("proj", 2), ("-eve", 3)): - fnames = glob.glob(op.join(tempdir, "*%s.fif" % key)) + fnames = glob.glob(op.join(tempdir, 
f"*{key}.fif")) assert len(fnames) == count @@ -206,32 +214,6 @@ def test_make_scalp_surfaces(tmp_path, monkeypatch): assert "SUBJECTS_DIR" not in os.environ -def test_maxfilter(): - """Test mne maxfilter.""" - check_usage(mne_maxfilter) - with ArgvSetter( - ( - "-i", - raw_fname, - "--st", - "--movecomp", - "--linefreq", - "60", - "--trans", - raw_fname, - ) - ) as out: - with pytest.warns(RuntimeWarning, match="Don't use"): - os.environ["_MNE_MAXFILTER_TEST"] = "true" - try: - mne_maxfilter.run() - finally: - del os.environ["_MNE_MAXFILTER_TEST"] - out = out.stdout.getvalue() - for check in ("maxfilter", "-trans", "-movecomp"): - assert check in out, check - - @pytest.mark.slowtest @testing.requires_testing_data def test_report(tmp_path): @@ -295,14 +277,14 @@ def test_watershed_bem(tmp_path): mne_watershed_bem.run() os.chmod(new_fname, old_mode) for s in ("outer_skin", "outer_skull", "inner_skull"): - assert not op.isfile(op.join(subject_path_new, "bem", "%s.surf" % s)) + assert not op.isfile(op.join(subject_path_new, "bem", f"{s}.surf")) with ArgvSetter(args): mne_watershed_bem.run() kwargs = dict(rtol=1e-5, atol=1e-5) for s in ("outer_skin", "outer_skull", "inner_skull"): rr, tris, vol_info = read_surface( - op.join(subject_path_new, "bem", "%s.surf" % s), read_metadata=True + op.join(subject_path_new, "bem", f"{s}.surf"), read_metadata=True ) assert_equal(len(tris), 20480) assert_equal(tris.min(), 0) @@ -390,14 +372,12 @@ def test_flash_bem(tmp_path): kwargs = dict(rtol=1e-5, atol=1e-5) for s in ("outer_skin", "outer_skull", "inner_skull"): - rr, tris = read_surface(op.join(subject_path_new, "bem", "%s.surf" % s)) + rr, tris = read_surface(op.join(subject_path_new, "bem", f"{s}.surf")) assert_equal(len(tris), 5120) assert_equal(tris.min(), 0) assert_equal(rr.shape[0], tris.max() + 1) # compare to the testing flash surfaces - rr_c, tris_c = read_surface( - op.join(subjects_dir, "sample", "bem", "%s.surf" % s) - ) + rr_c, tris_c = read_surface(op.join(subjects_dir, "sample", "bem", f"{s}.surf")) assert_allclose(rr, rr_c, **kwargs) assert_allclose(tris, tris_c, **kwargs) diff --git a/mne/commands/utils.py b/mne/commands/utils.py index 10334ce0acb..112ff27deca 100644 --- a/mne/commands/utils.py +++ b/mne/commands/utils.py @@ -68,7 +68,7 @@ def get_optparser(cmdpath, usage=None, prog_prefix="mne", version=None): command = command[len(prog_prefix) + 1 :] # +1 is for `_` character # Set prog - prog = prog_prefix + " {}".format(command) + prog = prog_prefix + f" {command}" # Set version if version is None: @@ -106,6 +106,6 @@ def print_help(): # noqa print_help() else: cmd = sys.argv[1] - cmd = importlib.import_module(".mne_%s" % (cmd,), "mne.commands") + cmd = importlib.import_module(f".mne_{cmd}", "mne.commands") sys.argv = sys.argv[1:] cmd.run() diff --git a/mne/conftest.py b/mne/conftest.py index 40a317b7da9..93657339b26 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -10,6 +10,7 @@ import shutil import sys import warnings +from collections import defaultdict from contextlib import contextmanager from pathlib import Path from textwrap import dedent @@ -32,6 +33,7 @@ _assert_no_instances, _check_qt_version, _pl, + _record_warnings, _TempDir, numerics, ) @@ -176,9 +178,6 @@ def pytest_configure(config): ignore:numpy\.core\.numeric is deprecated.*:DeprecationWarning ignore:numpy\.core\.multiarray is deprecated.*:DeprecationWarning ignore:The numpy\.fft\.helper has been made private.*:DeprecationWarning - # TODO: Should actually fix these two - ignore:scipy.signal.morlet2 is deprecated 
in SciPy.*:DeprecationWarning - ignore:The `needs_threshold` and `needs_proba`.*:FutureWarning # tqdm (Fedora) ignore:.*'tqdm_asyncio' object has no attribute 'last_print_t':pytest.PytestUnraisableExceptionWarning # Until mne-qt-browser > 0.5.2 is released @@ -187,6 +186,20 @@ def pytest_configure(config): ignore:Mesa version 10\.2\.4 is too old for translucent.*:RuntimeWarning # Matplotlib <-> NumPy 2.0 ignore:`row_stack` alias is deprecated.*:DeprecationWarning + # Matplotlib->tz + ignore:datetime.datetime.utcfromtimestamp.*:DeprecationWarning + # joblib + ignore:ast\.Num is deprecated.*:DeprecationWarning + ignore:Attribute n is deprecated and will be removed in Python 3\.14.*:DeprecationWarning + # numpydoc + ignore:ast\.NameConstant is deprecated and will be removed in Python 3\.14.*:DeprecationWarning + # pooch + ignore:Python 3\.14 will, by default, filter extracted tar archives.*:DeprecationWarning + # pandas + ignore:\n*Pyarrow will become a required dependency of pandas.*:DeprecationWarning + ignore:np\.find_common_type is deprecated.*:DeprecationWarning + # pyvista <-> NumPy 2.0 + ignore:__array_wrap__ must accept context and return_scalar arguments.*:DeprecationWarning """ # noqa: E501 for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() @@ -274,7 +287,7 @@ def matplotlib_config(): class CallbackRegistryReraise(orig): def __init__(self, exception_handler=None, signals=None): - super(CallbackRegistryReraise, self).__init__(exception_handler) + super().__init__(exception_handler) cbook.CallbackRegistry = CallbackRegistryReraise @@ -381,6 +394,34 @@ def epochs_spectrum(): return _get_epochs().load_data().compute_psd() +@pytest.fixture() +def epochs_tfr(): + """Get an EpochsTFR computed from mne.io.tests.data.""" + epochs = _get_epochs().load_data() + return epochs.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + +@pytest.fixture() +def average_tfr(epochs_tfr): + """Get an AverageTFR computed by averaging an EpochsTFR (this is small & fast).""" + return epochs_tfr.average() + + +@pytest.fixture() +def full_average_tfr(full_evoked): + """Get an AverageTFR computed from Evoked. + + This is slower than the `average_tfr` fixture, but a few TFR.plot_* tests need it. 
+ """ + return full_evoked.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + +@pytest.fixture() +def raw_tfr(raw): + """Get a RawTFR computed from mne.io.tests.data.""" + return raw.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + @pytest.fixture() def epochs_empty(): """Get empty epochs from mne.io.tests.data.""" @@ -392,22 +433,31 @@ def epochs_empty(): @pytest.fixture(scope="session", params=[testing._pytest_param()]) -def _evoked(): - # This one is session scoped, so be sure not to modify it (use evoked - # instead) - evoked = mne.read_evokeds( - fname_evoked, condition="Left Auditory", baseline=(None, 0) - ) - evoked.crop(0, 0.2) - return evoked +def _full_evoked(): + # This is session scoped, so be sure not to modify its return value (use + # `full_evoked` fixture instead) + return mne.read_evokeds(fname_evoked, condition="Left Auditory", baseline=(None, 0)) + + +@pytest.fixture(scope="session", params=[testing._pytest_param()]) +def _evoked(_full_evoked): + # This is session scoped, so be sure not to modify its return value (use `evoked` + # fixture instead) + return _full_evoked.copy().crop(0, 0.2) @pytest.fixture() def evoked(_evoked): - """Get evoked data.""" + """Get truncated evoked data.""" return _evoked.copy() +@pytest.fixture() +def full_evoked(_full_evoked): + """Get full-duration evoked data (needed for, e.g., testing TFR).""" + return _full_evoked.copy() + + @pytest.fixture(scope="function", params=[testing._pytest_param()]) def noise_cov(): """Get a noise cov from the testing dataset.""" @@ -782,13 +832,14 @@ def mixed_fwd_cov_evoked(_evoked_cov_sphere, _all_src_types_fwd): @pytest.fixture(scope="session") -@pytest.mark.slowtest -@pytest.mark.parametrize(params=[testing._pytest_param()]) def src_volume_labels(): """Create a 7mm source space with labels.""" pytest.importorskip("nibabel") volume_labels = mne.get_volume_labels_from_aseg(fname_aseg) - with pytest.warns(RuntimeWarning, match="Found no usable.*Left-vessel.*"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="Found no usable.*t-vessel.*"), + ): src = mne.setup_volume_source_space( "sample", 7.0, @@ -885,11 +936,10 @@ def protect_config(): def _test_passed(request): - try: - outcome = request.node.harvest_rep_call - except Exception: - outcome = "passed" - return outcome == "passed" + if _phase_report_key not in request.node.stash: + return True + report = request.node.stash[_phase_report_key] + return "call" in report and report["call"].outcome == "passed" @pytest.fixture() @@ -916,7 +966,6 @@ def brain_gc(request): ignore = set(id(o) for o in gc.get_objects()) yield close_func() - # no need to warn if the test itself failed, pytest-harvest helps us here if not _test_passed(request): return _assert_no_instances(Brain, "after") @@ -945,16 +994,14 @@ def pytest_sessionfinish(session, exitstatus): if n is None: return print("\n") - try: - import pytest_harvest - except ImportError: - print("Module-level timings require pytest-harvest") - return # get the number to print - res = pytest_harvest.get_session_synthesis_dct(session) - files = dict() - for key, val in res.items(): - parts = Path(key.split(":")[0]).parts + files = defaultdict(lambda: 0.0) + for item in session.items: + if _phase_report_key not in item.stash: + continue + report = item.stash[_phase_report_key] + dur = sum(x.duration for x in report.values()) + parts = Path(item.nodeid.split(":")[0]).parts # split mne/tests/test_whatever.py into separate categories since these # are essentially 
submodule-level tests. Keeping just [:3] works,
         # except for mne/viz where we want level-4 granulatity
@@ -963,7 +1010,7 @@ def pytest_sessionfinish(session, exitstatus):
         if not parts[-1].endswith(".py"):
             parts = parts + ("",)
         file_key = "/".join(parts)
-        files[file_key] = files.get(file_key, 0) + val["pytest_duration_s"]
+        files[file_key] += dur
     files = sorted(list(files.items()), key=lambda x: x[1])[::-1]
     # print
     _files[:] = files[:n]
@@ -984,6 +1031,11 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
         writer.line(f"{timing.ljust(15)}{name}")
 
 
+def pytest_report_header(config, startdir=None):
+    """Add information to the pytest run header."""
+    return f"MNE {mne.__version__} -- {str(Path(mne.__file__).parent)}"
+
+
 @pytest.fixture(scope="function", params=("Numba", "NumPy"))
 def numba_conditional(monkeypatch, request):
     """Test both code paths on machines that have Numba."""
@@ -1102,7 +1154,6 @@ def run(nbexec=nbexec, code=code):
     return
 
 
-@pytest.mark.filterwarnings("ignore:.*Extraction of measurement.*:")
 @pytest.fixture(
     params=(
         [nirsport2, nirsport2_snirf, testing._pytest_param()],
@@ -1140,8 +1191,7 @@ def qt_windows_closed(request):
     if "allow_unclosed_pyside2" in marks and API_NAME.lower() == "pyside2":
         return
     # Don't check when the test fails
-    report = request.node.stash[_phase_report_key]
-    if ("call" not in report) or report["call"].failed:
+    if not _test_passed(request):
         return
     widgets = app.topLevelWidgets()
     n_after = len(widgets)
@@ -1158,3 +1208,53 @@ def pytest_runtest_makereport(item, call):
     outcome = yield
     rep = outcome.get_result()
     item.stash.setdefault(_phase_report_key, {})[rep.when] = rep
+
+
+@pytest.fixture(scope="function")
+def eyetrack_cal():
+    """Create a toy calibration instance."""
+    screen_size = (0.4, 0.225)  # width, height in meters
+    screen_resolution = (1920, 1080)
+    screen_distance = 0.7  # meters
+    onset = 0
+    model = "HV9"
+    eye = "R"
+    avg_error = 0.5
+    max_error = 1.0
+    positions = np.zeros((9, 2))
+    offsets = np.zeros((9,))
+    gaze = np.zeros((9, 2))
+    cal = mne.preprocessing.eyetracking.Calibration(
+        screen_size=screen_size,
+        screen_distance=screen_distance,
+        screen_resolution=screen_resolution,
+        eye=eye,
+        model=model,
+        positions=positions,
+        offsets=offsets,
+        gaze=gaze,
+        onset=onset,
+        avg_error=avg_error,
+        max_error=max_error,
+    )
+    return cal
+
+
+@pytest.fixture(scope="function")
+def eyetrack_raw():
+    """Create a toy raw instance with eyetracking channels."""
+    # simulate a steady fixation at the center pixel of a 1920x1080 resolution screen
+    shape = (1, 100)  # x or y, time
+    data = np.vstack([np.full(shape, 960), np.full(shape, 540), np.full(shape, 0)])
+
+    info = mne.create_info(
+        ch_names=["xpos", "ypos", "pupil"], sfreq=100, ch_types="eyegaze"
+    )
+    more_info = dict(
+        xpos=("eyegaze", "px", "right", "x"),
+        ypos=("eyegaze", "px", "right", "y"),
+        pupil=("pupil", "au", "right"),
+    )
+    raw = mne.io.RawArray(data, info)
+    raw = mne.preprocessing.eyetracking.set_channel_types_eyetrack(raw, more_info)
+    return raw
diff --git a/mne/coreg.py b/mne/coreg.py
index fe6895270b7..7dae561c2a2 100644
--- a/mne/coreg.py
+++ b/mne/coreg.py
@@ -167,7 +167,7 @@ def coregister_fiducials(info, fiducials, tol=0.01):
     coord_frame_to = FIFF.FIFFV_COORD_MRI
     frames_from = {d["coord_frame"] for d in info["dig"]}
    if len(frames_from) > 1:
-        raise ValueError("info contains fiducials from different coordinate " "frames")
+        raise ValueError("info contains fiducials from different coordinate frames")
    else:
        coord_frame_from =
frames_from.pop() coords_from = _fiducial_coords(info["dig"]) @@ -220,14 +220,14 @@ def create_default_subject(fs_home=None, update=False, subjects_dir=None, verbos fs_src = os.path.join(fs_home, "subjects", "fsaverage") if not os.path.exists(fs_src): raise OSError( - "fsaverage not found at %r. Is fs_home specified " "correctly?" % fs_src + "fsaverage not found at %r. Is fs_home specified correctly?" % fs_src ) for name in ("label", "mri", "surf"): dirname = os.path.join(fs_src, name) if not os.path.isdir(dirname): raise OSError( - "Freesurfer fsaverage seems to be incomplete: No " - "directory named %s found in %s" % (name, fs_src) + "Freesurfer fsaverage seems to be incomplete: No directory named " + f"{name} found in {fs_src}" ) # make sure destination does not already exist @@ -241,9 +241,9 @@ def create_default_subject(fs_home=None, update=False, subjects_dir=None, verbos ) elif (not update) and os.path.exists(dest): raise OSError( - "Can not create fsaverage because %r already exists in " - "subjects_dir %r. Delete or rename the existing fsaverage " - "subject folder." % ("fsaverage", subjects_dir) + "Can not create fsaverage because {!r} already exists in " + "subjects_dir {!r}. Delete or rename the existing fsaverage " + "subject folder.".format("fsaverage", subjects_dir) ) # copy fsaverage from freesurfer @@ -422,19 +422,15 @@ def fit_matched_points( tgt_pts = np.atleast_2d(tgt_pts) if src_pts.shape != tgt_pts.shape: raise ValueError( - "src_pts and tgt_pts must have same shape (got " - "{}, {})".format(src_pts.shape, tgt_pts.shape) + "src_pts and tgt_pts must have same shape " + f"(got {src_pts.shape}, {tgt_pts.shape})" ) if weights is not None: weights = np.asarray(weights, src_pts.dtype) if weights.ndim != 1 or weights.size not in (src_pts.shape[0], 1): raise ValueError( - "weights (shape=%s) must be None or have shape " - "(%s,)" - % ( - weights.shape, - src_pts.shape[0], - ) + f"weights (shape={weights.shape}) must be None or have shape " + f"({src_pts.shape[0]},)" ) weights = weights[:, np.newaxis] @@ -472,7 +468,7 @@ def fit_matched_points( return trans else: raise ValueError( - "Invalid out parameter: %r. Needs to be 'params' or " "'trans'." % out + "Invalid out parameter: %r. Needs to be 'params' or 'trans'." % out ) @@ -541,7 +537,7 @@ def error(x): else: raise NotImplementedError( "The specified parameter combination is not implemented: " - "rotate=%r, translate=%r, scale=%r" % param_info + "rotate={!r}, translate={!r}, scale={!r}".format(*param_info) ) x, _, _, _, _ = leastsq(error, x0, full_output=True) @@ -827,8 +823,8 @@ def read_mri_cfg(subject, subjects_dir=None): if not fname.exists(): raise OSError( - "%r does not seem to be a scaled mri subject: %r does " - "not exist." % (subject, fname) + f"{subject!r} does not seem to be a scaled mri subject: {fname!r} does not" + "exist." ) logger.info("Reading MRI cfg file %s" % fname) @@ -916,8 +912,8 @@ def _scale_params(subject_to, subject_from, scale, subjects_dir): scale = np.atleast_1d(scale) if scale.ndim != 1 or scale.shape[0] not in (1, 3): raise ValueError( - "Invalid shape for scale parameter. Need scalar " - "or array of length 3. Got shape %s." % (scale.shape,) + "Invalid shape for scale parameter. Need scalar or array of length 3. Got " + f"shape {scale.shape}." 
) n_params = len(scale) return str(subjects_dir), subject_from, scale, n_params == 1 @@ -1105,14 +1101,14 @@ def scale_mri( if np.isclose(scale[1], scale[0]) and np.isclose(scale[2], scale[0]): scale = scale[0] # speed up scaling conditionals using a singleton elif scale.shape != (1,): - raise ValueError("scale must have shape (3,) or (1,), got %s" % (scale.shape,)) + raise ValueError(f"scale must have shape (3,) or (1,), got {scale.shape}") # make sure we have an empty target directory dest = subject_dirname.format(subject=subject_to, subjects_dir=subjects_dir) if os.path.exists(dest): if not overwrite: raise OSError( - "Subject directory for %s already exists: %r" % (subject_to, dest) + f"Subject directory for {subject_to} already exists: {dest!r}" ) shutil.rmtree(dest) @@ -1949,7 +1945,7 @@ def fit_fiducials( n_scale_params = self._n_scale_params if n_scale_params == 3: # enforce 1 even for 3-axis here (3 points is not enough) - logger.info("Enforcing 1 scaling parameter for fit " "with fiducials.") + logger.info("Enforcing 1 scaling parameter for fit with fiducials.") n_scale_params = 1 self._lpa_weight = lpa_weight self._nasion_weight = nasion_weight @@ -2014,12 +2010,12 @@ def _setup_icp(self, n_scale_params): self._processed_high_res_mri_points[ getattr( self, - "_nearest_transformed_high_res_mri_idx_%s" % (key,), + f"_nearest_transformed_high_res_mri_idx_{key}", ) ] ) weights.append( - np.full(len(mri_pts[-1]), getattr(self, "_%s_weight" % key)) + np.full(len(mri_pts[-1]), getattr(self, f"_{key}_weight")) ) if self._has_eeg_data and self._eeg_weight > 0: head_pts.append(self._dig_dict["dig_ch_pos_location"]) diff --git a/mne/cov.py b/mne/cov.py index 1b2d4cd8ebe..7772a0a8324 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -59,7 +59,7 @@ empirical_covariance, log_likelihood, ) -from .rank import compute_rank +from .rank import _compute_rank from .utils import ( _array_repr, _check_fname, @@ -85,12 +85,12 @@ def _check_covs_algebra(cov1, cov2): if cov1.ch_names != cov2.ch_names: - raise ValueError("Both Covariance do not have the same list of " "channels.") + raise ValueError("Both Covariance do not have the same list of channels.") projs1 = [str(c) for c in cov1["projs"]] projs2 = [str(c) for c in cov1["projs"]] if projs1 != projs2: raise ValueError( - "Both Covariance do not have the same list of " "SSP projections." + "Both Covariance do not have the same list of SSP projections." ) @@ -310,7 +310,7 @@ def __iadd__(self, cov): def plot( self, info, - exclude=[], + exclude=(), colorbar=True, proj=False, show_svd=True, @@ -453,7 +453,7 @@ def plot_topomap( ) @verbose - def pick_channels(self, ch_names, ordered=None, *, verbose=None): + def pick_channels(self, ch_names, ordered=True, *, verbose=None): """Pick channels from this covariance matrix. Parameters @@ -704,7 +704,7 @@ def compute_raw_covariance( tstep = tmax - tmin if tstep is None else float(tstep) tstep_m1 = tstep - dt # inclusive! 
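    # ``make_fixed_length_events`` below drops evenly spaced dummy events
    # (id=1) every ``tstep`` seconds between ``tmin`` and ``tmax``; the raw
    # data are then epoched around those events with duration ``tstep_m1``
    # (``tstep`` minus one sample, hence the "inclusive" note above), so
    # consecutive segments abut without overlapping before the covariance is
    # accumulated over all usable segments.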
events = make_fixed_length_events(raw, 1, tmin, tmax, tstep) - logger.info("Using up to %s segment%s" % (len(events), _pl(events))) + logger.info(f"Using up to {len(events)} segment{_pl(events)}") # don't exclude any bad channels, inverses expect all channels present if picks is None: @@ -819,13 +819,13 @@ def _check_method_params( for key, values in method_params.items(): if key not in _method_params: raise ValueError( - 'key (%s) must be "%s"' % (key, '" or "'.join(_method_params)) + 'key ({}) must be "{}"'.format(key, '" or "'.join(_method_params)) ) _method_params[key].update(method_params[key]) shrinkage = method_params.get("shrinkage", {}).get("shrinkage", 0.1) if not 0 <= shrinkage <= 1: - raise ValueError("shrinkage must be between 0 and 1, got %s" % (shrinkage,)) + raise ValueError(f"shrinkage must be between 0 and 1, got {shrinkage}") was_auto = False if method is None: @@ -839,10 +839,8 @@ def _check_method_params( if not all(k in accepted_methods for k in method): raise ValueError( - "Invalid {name} ({method}). Accepted values (individually or " - 'in a list) are any of "{accepted_methods}" or None.'.format( - name=name, method=method, accepted_methods=accepted_methods - ) + f"Invalid {name} ({method}). Accepted values (individually or " + f"in a list) are any of '{accepted_methods}' or None." ) if not (isinstance(rank, str) and rank == "full"): if was_auto: @@ -850,19 +848,18 @@ def _check_method_params( for method_ in method: if method_ in ("pca", "factor_analysis"): raise ValueError( - '%s can so far only be used with rank="full",' - " got rank=%r" % (method_, rank) + f'{method_} can so far only be used with rank="full", got rank=' + f"{rank!r}" ) if not keep_sample_mean: if len(method) != 1 or "empirical" not in method: raise ValueError( - "`keep_sample_mean=False` is only supported" - 'with %s="empirical"' % (name,) + f'`keep_sample_mean=False` is only supported with {name}="empirical"' ) for p, v in _method_params.items(): if v.get("assume_centered", None) is False: raise ValueError( - "`assume_centered` must be True" " if `keep_sample_mean` is False" + "`assume_centered` must be True if `keep_sample_mean` is False" ) return method, _method_params @@ -1077,9 +1074,7 @@ def _unpack_epochs(epochs): and keep_sample_mean for epochs_t in epochs ): - warn( - "Epochs are not baseline corrected, covariance " "matrix may be inaccurate" - ) + warn("Epochs are not baseline corrected, covariance matrix may be inaccurate") orig = epochs[0].info["dev_head_t"] _check_on_missing(on_mismatch, "on_mismatch") @@ -1090,8 +1085,8 @@ def _unpack_epochs(epochs): and not np.allclose(orig["trans"], epoch.info["dev_head_t"]["trans"]) ): msg = ( - "MEG<->Head transform mismatch between epochs[0]:\n%s\n\n" - "and epochs[%s]:\n%s" % (orig, ei, epoch.info["dev_head_t"]) + "MEG<->Head transform mismatch between epochs[0]:\n{}\n\n" + "and epochs[{}]:\n{}".format(orig, ei, epoch.info["dev_head_t"]) ) _on_missing(on_mismatch, msg, "on_mismatch") @@ -1196,7 +1191,7 @@ def _unpack_epochs(epochs): if len(covs) > 1: msg = ["log-likelihood on unseen data (descending order):"] for c in covs: - msg.append("%s: %0.3f" % (c["method"], c["loglik"])) + msg.append(f"{c['method']}: {c['loglik']:0.3f}") logger.info("\n ".join(msg)) if return_estimators: out = covs @@ -1216,7 +1211,7 @@ def _check_scalings_user(scalings): _check_option("the keys in `scalings`", k, ["mag", "grad", "eeg"]) elif scalings is not None and not isinstance(scalings, np.ndarray): raise TypeError( - "scalings must be a dict, ndarray, or None, got 
%s" % type(scalings) + f"scalings must be a dict, ndarray, or None, got {type(scalings)}" ) scalings = _handle_default("scalings", scalings) return scalings @@ -1231,6 +1226,21 @@ def _eigvec_subspace(eig, eigvec, mask): return eig, eigvec +@verbose +def _compute_rank_raw_array( + data, info, rank, scalings, *, log_ch_type=None, verbose=None +): + from .io import RawArray + + return _compute_rank( + RawArray(data, info, copy=None, verbose=_verbose_safe_false()), + rank, + scalings, + info, + log_ch_type=log_ch_type, + ) + + def _compute_covariance_auto( data, method, @@ -1242,22 +1252,31 @@ def _compute_covariance_auto( stop_early, picks_list, rank, + *, + cov_kind="", + log_ch_type=None, + log_rank=True, ): """Compute covariance auto mode.""" - from .io import RawArray - # rescale to improve numerical stability orig_rank = rank - rank = compute_rank( - RawArray(data.T, info, copy=None, verbose=_verbose_safe_false()), - rank, - scalings, + rank = _compute_rank_raw_array( + data.T, info, + rank=rank, + scalings=scalings, + verbose=_verbose_safe_false(), ) with _scaled_array(data.T, picks_list, scalings): C = np.dot(data.T, data) _, eigvec, mask = _smart_eigh( - C, info, rank, proj_subspace=True, do_compute_rank=False + C, + info, + rank, + proj_subspace=True, + do_compute_rank=False, + log_ch_type=log_ch_type, + verbose=None if log_rank else _verbose_safe_false(), ) eigvec = eigvec[mask] data = np.dot(data, eigvec.T) @@ -1266,21 +1285,24 @@ def _compute_covariance_auto( (key, np.searchsorted(used, picks)) for key, picks in picks_list ] sub_info = pick_info(info, used) if len(used) != len(mask) else info - logger.info("Reducing data rank from %s -> %s" % (len(mask), eigvec.shape[0])) + if log_rank: + logger.info(f"Reducing data rank from {len(mask)} -> {eigvec.shape[0]}") estimator_cov_info = list() - msg = "Estimating covariance using %s" ok_sklearn = check_version("sklearn") if not ok_sklearn and (len(method) != 1 or method[0] != "empirical"): raise ValueError( - "scikit-learn is not installed, `method` must be " - "`empirical`, got %s" % (method,) + 'scikit-learn is not installed, `method` must be "empirical", got ' + f"{repr(method)}" ) for method_ in method: data_ = data.copy() name = method_.__name__ if callable(method_) else method_ - logger.info(msg % name.upper()) + logger.info( + f'Estimating {cov_kind + (" " if cov_kind else "")}' + f"covariance using {name.upper()}" + ) mp = method_params[method_] _info = {} @@ -1375,7 +1397,7 @@ def _compute_covariance_auto( estimator_cov_info.append((fa, fa.get_covariance(), _info)) del fa else: - raise ValueError("Oh no! Your estimator does not have" " a .fit method") + raise ValueError("Oh no! Your estimator does not have a .fit method") logger.info("Done.") if len(method) > 1: @@ -1696,8 +1718,8 @@ def _get_ch_whitener(A, pca, ch_type, rank): mask[:-rank] = False logger.info( - " Setting small %s eigenvalues to zero (%s)" - % (ch_type, "using PCA" if pca else "without PCA") + f" Setting small {ch_type} eigenvalues to zero " + f'({"using" if pca else "without"} PCA)' ) if pca: # No PCA case. 
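        # ``mask[:-rank] = False`` above zeroes every eigenvalue below the
        # estimated rank, so the whitener only operates within the estimated
        # signal subspace; in the PCA branch that follows, the rank-deficient
        # eigenvectors are additionally dropped, reducing the dimensionality
        # of the whitened data rather than just zeroing those components.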
# This line will reduce the actual number of variables in data @@ -1795,6 +1817,8 @@ def _smart_eigh( proj_subspace=False, do_compute_rank=True, on_rank_mismatch="ignore", + *, + log_ch_type=None, verbose=None, ): """Compute eigh of C taking into account rank and ch_type scalings.""" @@ -1817,8 +1841,13 @@ def _smart_eigh( noise_cov = Covariance(C, ch_names, [], projs, 0) if do_compute_rank: # if necessary - rank = compute_rank( - noise_cov, rank, scalings, info, on_rank_mismatch=on_rank_mismatch + rank = _compute_rank( + noise_cov, + rank, + scalings, + info, + on_rank_mismatch=on_rank_mismatch, + log_ch_type=log_ch_type, ) assert C.ndim == 2 and C.shape[0] == C.shape[1] @@ -1842,7 +1871,11 @@ def _smart_eigh( else: this_rank = rank[ch_type] - e, ev, m = _get_ch_whitener(this_C, False, ch_type.upper(), this_rank) + if log_ch_type is not None: + ch_type_ = log_ch_type + else: + ch_type_ = ch_type.upper() + e, ev, m = _get_ch_whitener(this_C, False, ch_type_, this_rank) if proj_subspace: # Choose the subspace the same way we do for projections e, ev = _eigvec_subspace(e, ev, m) @@ -1991,16 +2024,15 @@ def regularize( if len(picks_dict.get("meg", [])) > 0 and rank != "full": # combined if mag != grad: raise ValueError( - "On data where magnetometers and gradiometers " - "are dependent (e.g., SSSed data), mag (%s) must " - "equal grad (%s)" % (mag, grad) + "On data where magnetometers and gradiometers are dependent (e.g., " + f"SSSed data), mag ({mag}) must equal grad ({grad})" ) logger.info("Regularizing MEG channels jointly") regs["meg"] = mag else: regs.update(mag=mag, grad=grad) if rank != "full": - rank = compute_rank(cov, rank, scalings, info) + rank = _compute_rank(cov, rank, scalings, info) info_ch_names = info["ch_names"] ch_names_by_type = dict() @@ -2039,9 +2071,9 @@ def regularize( continue reg = regs[ch_type] if reg == 0.0: - logger.info(" %s regularization : None" % desc) + logger.info(f" {desc} regularization : None") continue - logger.info(" %s regularization : %s" % (desc, reg)) + logger.info(f" {desc} regularization : {reg}") this_C = C[np.ix_(idx, idx)] U = np.eye(this_C.shape[0]) @@ -2053,8 +2085,7 @@ def regularize( # This adjustment ends up being redundant if rank is None: U = _safe_svd(P)[0][:, :-ncomp] logger.info( - " Created an SSP operator for %s " - "(dimension = %d)" % (desc, ncomp) + f" Created an SSP operator for {desc} (dimension = {ncomp})" ) else: this_picks = pick_channels(info["ch_names"], this_ch_names) @@ -2077,7 +2108,17 @@ def regularize( return cov -def _regularized_covariance(data, reg=None, method_params=None, info=None, rank=None): +def _regularized_covariance( + data, + reg=None, + method_params=None, + info=None, + rank=None, + *, + log_ch_type=None, + log_rank=None, + cov_kind="", +): """Compute a regularized covariance from data using sklearn. 
This is a convenience wrapper for mne.decoding functions, which @@ -2095,8 +2136,8 @@ def _regularized_covariance(data, reg=None, method_params=None, info=None, rank= reg = float(reg) if method_params is not None: raise ValueError( - "If reg is a float, method_params must be None " - "(got %s)" % (type(method_params),) + "If reg is a float, method_params must be None (got " + f"{type(method_params)})" ) method_params = dict( shrinkage=dict(shrinkage=reg, assume_centered=True, store_precision=False) @@ -2120,6 +2161,9 @@ def _regularized_covariance(data, reg=None, method_params=None, info=None, rank= picks_list=picks_list, scalings=scalings, rank=rank, + cov_kind=cov_kind, + log_ch_type=log_ch_type, + log_rank=log_rank, )[reg]["data"] return cov @@ -2190,12 +2234,12 @@ def compute_whitener( _validate_type(pca, (str, bool), "space") _valid_pcas = (True, "white", False) if pca not in _valid_pcas: - raise ValueError("space must be one of %s, got %s" % (_valid_pcas, pca)) + raise ValueError(f"space must be one of {_valid_pcas}, got {pca}") if info is None: if "eig" not in noise_cov: raise ValueError( - "info can only be None if the noise cov has " - "already been prepared with prepare_noise_cov" + "info can only be None if the noise cov has already been prepared with " + "prepare_noise_cov" ) ch_names = deepcopy(noise_cov["names"]) else: @@ -2342,7 +2386,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): names = _safe_name_list(tag.data, "read", "names") if len(names) != dim: raise ValueError( - "Number of names does not match " "covariance matrix dimension" + "Number of names does not match covariance matrix dimension" ) tag = find_tag(fid, this, FIFF.FIFF_MNE_COV) @@ -2488,7 +2532,7 @@ def _write_cov(fid, cov): @verbose def _ensure_cov(cov, name="cov", *, verbose=None): _validate_type(cov, ("path-like", Covariance), name) - logger.info("Noise covariance : %s" % (cov,)) + logger.info(f"Noise covariance : {cov}") if not isinstance(cov, Covariance): cov = read_cov(cov, verbose=_verbose_safe_false()) return cov diff --git a/mne/cuda.py b/mne/cuda.py index b4aa7c37bf3..be645506de3 100644 --- a/mne/cuda.py +++ b/mne/cuda.py @@ -120,7 +120,7 @@ def _set_cuda_device(device_id, verbose=None): import cupy cupy.cuda.Device(device_id).use() - logger.info("Now using CUDA device {}".format(device_id)) + logger.info(f"Now using CUDA device {device_id}") ############################################################################### @@ -330,7 +330,7 @@ def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, pad="reflect_li Number of samples to remove after resampling. cuda_dict : dict Dictionary constructed using setup_cuda_multiply_repeated(). - %(pad)s + %(pad_resample)s The default is ``'reflect_limited'``. .. 
versionadded:: 0.15 diff --git a/mne/data/image/custom_layout.lout b/mne/data/image/custom_layout.lout deleted file mode 100644 index ab5b81408cb..00000000000 --- a/mne/data/image/custom_layout.lout +++ /dev/null @@ -1,257 +0,0 @@ - 0.00 0.00 0.01 0.02 -000 0.79 0.46 0.07 0.05 0 -001 0.78 0.48 0.07 0.05 1 -002 0.76 0.51 0.07 0.05 2 -003 0.74 0.53 0.07 0.05 3 -004 0.72 0.55 0.07 0.05 4 -005 0.71 0.57 0.07 0.05 5 -006 0.69 0.59 0.07 0.05 6 -007 0.67 0.62 0.07 0.05 7 -008 0.66 0.64 0.07 0.05 8 -009 0.64 0.66 0.07 0.05 9 -010 0.62 0.68 0.07 0.05 10 -011 0.61 0.69 0.07 0.05 11 -012 0.59 0.71 0.07 0.05 12 -013 0.58 0.73 0.07 0.05 13 -014 0.56 0.75 0.07 0.05 14 -015 0.54 0.77 0.07 0.05 15 -016 0.77 0.44 0.07 0.05 16 -017 0.75 0.46 0.07 0.05 17 -018 0.73 0.49 0.07 0.05 18 -019 0.72 0.51 0.07 0.05 19 -020 0.70 0.54 0.07 0.05 20 -021 0.68 0.56 0.07 0.05 21 -022 0.66 0.58 0.07 0.05 22 -023 0.65 0.60 0.07 0.05 23 -024 0.63 0.62 0.07 0.05 24 -025 0.62 0.64 0.07 0.05 25 -026 0.60 0.66 0.07 0.05 26 -027 0.58 0.68 0.07 0.05 27 -028 0.57 0.70 0.07 0.05 28 -029 0.55 0.71 0.07 0.05 29 -030 0.53 0.73 0.07 0.05 30 -031 0.52 0.75 0.07 0.05 31 -032 0.75 0.42 0.07 0.05 32 -033 0.73 0.45 0.07 0.05 33 -034 0.71 0.47 0.07 0.05 34 -035 0.69 0.50 0.07 0.05 35 -036 0.68 0.52 0.07 0.05 36 -037 0.66 0.54 0.07 0.05 37 -038 0.64 0.57 0.07 0.05 38 -039 0.62 0.58 0.07 0.05 39 -040 0.61 0.61 0.07 0.05 40 -041 0.59 0.62 0.07 0.05 41 -042 0.58 0.64 0.07 0.05 42 -043 0.56 0.66 0.07 0.05 43 -044 0.54 0.68 0.07 0.05 44 -045 0.53 0.70 0.07 0.05 45 -046 0.51 0.72 0.07 0.05 46 -047 0.50 0.74 0.07 0.05 47 -048 0.72 0.41 0.07 0.05 48 -049 0.71 0.43 0.07 0.05 49 -050 0.69 0.46 0.07 0.05 50 -051 0.67 0.48 0.07 0.05 51 -052 0.65 0.50 0.07 0.05 52 -053 0.63 0.52 0.07 0.05 53 -054 0.62 0.55 0.07 0.05 54 -055 0.60 0.57 0.07 0.05 55 -056 0.58 0.59 0.07 0.05 56 -057 0.57 0.61 0.07 0.05 57 -058 0.55 0.63 0.07 0.05 58 -059 0.54 0.65 0.07 0.05 59 -060 0.52 0.67 0.07 0.05 60 -061 0.51 0.69 0.07 0.05 61 -062 0.49 0.71 0.07 0.05 62 -063 0.47 0.73 0.07 0.05 63 -064 0.70 0.39 0.07 0.05 64 -065 0.68 0.41 0.07 0.05 65 -066 0.66 0.44 0.07 0.05 66 -067 0.65 0.46 0.07 0.05 67 -068 0.63 0.49 0.07 0.05 68 -069 0.61 0.51 0.07 0.05 69 -070 0.59 0.53 0.07 0.05 70 -071 0.58 0.55 0.07 0.05 71 -072 0.56 0.57 0.07 0.05 72 -073 0.55 0.59 0.07 0.05 73 -074 0.53 0.61 0.07 0.05 74 -075 0.51 0.64 0.07 0.05 75 -076 0.50 0.66 0.07 0.05 76 -077 0.48 0.68 0.07 0.05 77 -078 0.47 0.69 0.07 0.05 78 -079 0.45 0.72 0.07 0.05 79 -080 0.68 0.38 0.07 0.05 80 -081 0.66 0.40 0.07 0.05 81 -082 0.64 0.42 0.07 0.05 82 -083 0.62 0.44 0.07 0.05 83 -084 0.60 0.47 0.07 0.05 84 -085 0.59 0.49 0.07 0.05 85 -086 0.57 0.51 0.07 0.05 86 -087 0.55 0.54 0.07 0.05 87 -088 0.54 0.56 0.07 0.05 88 -089 0.52 0.58 0.07 0.05 89 -090 0.50 0.60 0.07 0.05 90 -091 0.49 0.62 0.07 0.05 91 -092 0.47 0.64 0.07 0.05 92 -093 0.46 0.66 0.07 0.05 93 -094 0.44 0.68 0.07 0.05 94 -095 0.42 0.70 0.07 0.05 95 -096 0.65 0.36 0.07 0.05 96 -097 0.63 0.38 0.07 0.05 97 -098 0.61 0.41 0.07 0.05 98 -099 0.60 0.43 0.07 0.05 99 -100 0.58 0.45 0.07 0.05 100 -101 0.56 0.47 0.07 0.05 101 -102 0.55 0.50 0.07 0.05 102 -103 0.53 0.52 0.07 0.05 103 -104 0.51 0.54 0.07 0.05 104 -105 0.50 0.56 0.07 0.05 105 -106 0.48 0.58 0.07 0.05 106 -107 0.47 0.61 0.07 0.05 107 -108 0.45 0.63 0.07 0.05 108 -109 0.44 0.65 0.07 0.05 109 -110 0.42 0.67 0.07 0.05 110 -111 0.41 0.69 0.07 0.05 111 -112 0.63 0.34 0.07 0.05 112 -113 0.61 0.36 0.07 0.05 113 -114 0.59 0.39 0.07 0.05 114 -115 0.58 0.41 0.07 0.05 115 -116 0.56 0.43 0.07 0.05 116 -117 0.54 0.46 
0.07 0.05 117 -118 0.52 0.48 0.07 0.05 118 -119 0.51 0.51 0.07 0.05 119 -120 0.49 0.52 0.07 0.05 120 -121 0.47 0.55 0.07 0.05 121 -122 0.46 0.57 0.07 0.05 122 -123 0.44 0.59 0.07 0.05 123 -124 0.43 0.61 0.07 0.05 124 -125 0.41 0.63 0.07 0.05 125 -126 0.40 0.65 0.07 0.05 126 -127 0.38 0.67 0.07 0.05 127 -128 0.60 0.32 0.07 0.05 128 -129 0.59 0.35 0.07 0.05 129 -130 0.56 0.37 0.07 0.05 130 -131 0.55 0.39 0.07 0.05 131 -132 0.53 0.42 0.07 0.05 132 -133 0.52 0.44 0.07 0.05 133 -134 0.50 0.46 0.07 0.05 134 -135 0.48 0.49 0.07 0.05 135 -136 0.47 0.51 0.07 0.05 136 -137 0.45 0.53 0.07 0.05 137 -138 0.43 0.56 0.07 0.05 138 -139 0.42 0.57 0.07 0.05 139 -140 0.40 0.60 0.07 0.05 140 -141 0.39 0.61 0.07 0.05 141 -142 0.37 0.63 0.07 0.05 142 -143 0.36 0.66 0.07 0.05 143 -144 0.58 0.31 0.07 0.05 144 -145 0.56 0.33 0.07 0.05 145 -146 0.54 0.35 0.07 0.05 146 -147 0.53 0.38 0.07 0.05 147 -148 0.51 0.40 0.07 0.05 148 -149 0.49 0.42 0.07 0.05 149 -150 0.48 0.45 0.07 0.05 150 -151 0.46 0.47 0.07 0.05 151 -152 0.44 0.49 0.07 0.05 152 -153 0.42 0.51 0.07 0.05 153 -154 0.41 0.53 0.07 0.05 154 -155 0.39 0.56 0.07 0.05 155 -156 0.38 0.58 0.07 0.05 156 -157 0.36 0.60 0.07 0.05 157 -158 0.35 0.62 0.07 0.05 158 -159 0.33 0.64 0.07 0.05 159 -160 0.55 0.29 0.07 0.05 160 -161 0.54 0.32 0.07 0.05 161 -162 0.52 0.34 0.07 0.05 162 -163 0.50 0.36 0.07 0.05 163 -164 0.49 0.38 0.07 0.05 164 -165 0.47 0.41 0.07 0.05 165 -166 0.45 0.43 0.07 0.05 166 -167 0.43 0.45 0.07 0.05 167 -168 0.42 0.48 0.07 0.05 168 -169 0.40 0.50 0.07 0.05 169 -170 0.39 0.52 0.07 0.05 170 -171 0.37 0.54 0.07 0.05 171 -172 0.36 0.56 0.07 0.05 172 -173 0.34 0.58 0.07 0.05 173 -174 0.33 0.60 0.07 0.05 174 -175 0.31 0.62 0.07 0.05 175 -176 0.53 0.27 0.07 0.05 176 -177 0.52 0.30 0.07 0.05 177 -178 0.50 0.32 0.07 0.05 178 -179 0.48 0.34 0.07 0.05 179 -180 0.46 0.37 0.07 0.05 180 -181 0.45 0.39 0.07 0.05 181 -182 0.43 0.41 0.07 0.05 182 -183 0.41 0.43 0.07 0.05 183 -184 0.40 0.46 0.07 0.05 184 -185 0.38 0.48 0.07 0.05 185 -186 0.36 0.50 0.07 0.05 186 -187 0.35 0.53 0.07 0.05 187 -188 0.33 0.55 0.07 0.05 188 -189 0.32 0.57 0.07 0.05 189 -190 0.30 0.59 0.07 0.05 190 -191 0.29 0.61 0.07 0.05 191 -192 0.51 0.26 0.07 0.05 192 -193 0.49 0.28 0.07 0.05 193 -194 0.47 0.31 0.07 0.05 194 -195 0.46 0.33 0.07 0.05 195 -196 0.44 0.35 0.07 0.05 196 -197 0.42 0.37 0.07 0.05 197 -198 0.41 0.40 0.07 0.05 198 -199 0.39 0.42 0.07 0.05 199 -200 0.37 0.44 0.07 0.05 200 -201 0.36 0.46 0.07 0.05 201 -202 0.34 0.49 0.07 0.05 202 -203 0.32 0.51 0.07 0.05 203 -204 0.31 0.53 0.07 0.05 204 -205 0.29 0.55 0.07 0.05 205 -206 0.28 0.57 0.07 0.05 206 -207 0.27 0.59 0.07 0.05 207 -208 0.48 0.24 0.07 0.05 208 -209 0.47 0.26 0.07 0.05 209 -210 0.45 0.28 0.07 0.05 210 -211 0.43 0.31 0.07 0.05 211 -212 0.41 0.33 0.07 0.05 212 -213 0.40 0.35 0.07 0.05 213 -214 0.38 0.38 0.07 0.05 214 -215 0.37 0.40 0.07 0.05 215 -216 0.35 0.42 0.07 0.05 216 -217 0.33 0.45 0.07 0.05 217 -218 0.32 0.47 0.07 0.05 218 -219 0.30 0.49 0.07 0.05 219 -220 0.28 0.51 0.07 0.05 220 -221 0.27 0.53 0.07 0.05 221 -222 0.25 0.55 0.07 0.05 222 -223 0.24 0.58 0.07 0.05 223 -224 0.46 0.23 0.07 0.05 224 -225 0.45 0.25 0.07 0.05 225 -226 0.43 0.27 0.07 0.05 226 -227 0.41 0.29 0.07 0.05 227 -228 0.39 0.31 0.07 0.05 228 -229 0.38 0.34 0.07 0.05 229 -230 0.36 0.36 0.07 0.05 230 -231 0.34 0.38 0.07 0.05 231 -232 0.33 0.41 0.07 0.05 232 -233 0.31 0.43 0.07 0.05 233 -234 0.29 0.45 0.07 0.05 234 -235 0.28 0.47 0.07 0.05 235 -236 0.26 0.50 0.07 0.05 236 -237 0.25 0.52 0.07 0.05 237 -238 0.24 0.54 0.07 0.05 238 -239 0.22 0.56 0.07 0.05 239 
-240 0.44 0.21 0.07 0.05 240 -241 0.42 0.23 0.07 0.05 241 -242 0.41 0.25 0.07 0.05 242 -243 0.39 0.27 0.07 0.05 243 -244 0.37 0.30 0.07 0.05 244 -245 0.35 0.32 0.07 0.05 245 -246 0.33 0.34 0.07 0.05 246 -247 0.32 0.37 0.07 0.05 247 -248 0.30 0.39 0.07 0.05 248 -249 0.28 0.41 0.07 0.05 249 -250 0.27 0.43 0.07 0.05 250 -251 0.25 0.46 0.07 0.05 251 -252 0.24 0.48 0.07 0.05 252 -253 0.23 0.50 0.07 0.05 253 -254 0.21 0.52 0.07 0.05 254 -255 0.20 0.54 0.07 0.05 255 diff --git a/mne/data/image/mni_brain.gif b/mne/data/image/mni_brain.gif deleted file mode 100644 index 3d6cc08edbd..00000000000 Binary files a/mne/data/image/mni_brain.gif and /dev/null differ diff --git a/mne/datasets/__init__.pyi b/mne/datasets/__init__.pyi index 22cb6acce7b..44cee84fe7f 100644 --- a/mne/datasets/__init__.pyi +++ b/mne/datasets/__init__.pyi @@ -66,7 +66,7 @@ from . import ( ) from ._fetch import fetch_dataset from ._fsaverage.base import fetch_fsaverage -from ._infant.base import fetch_infant_template +from ._infant import fetch_infant_template from ._phantom.base import fetch_phantom from .utils import ( _download_all_example_data, diff --git a/mne/datasets/_fetch.py b/mne/datasets/_fetch.py index 82d68d6e9f6..2b07ea29be0 100644 --- a/mne/datasets/_fetch.py +++ b/mne/datasets/_fetch.py @@ -56,7 +56,7 @@ def fetch_dataset( What to do after downloading the file. ``"unzip"`` and ``"untar"`` will decompress the downloaded file in place; for custom extraction (e.g., only extracting certain files from the archive) pass an instance of - :class:`pooch.Unzip` or :class:`pooch.Untar`. If ``None`` (the + ``pooch.Unzip`` or ``pooch.Untar``. If ``None`` (the default), the files are left as-is. path : None | str Directory in which to put the dataset. If ``None``, the dataset @@ -87,10 +87,10 @@ def fetch_dataset( Default is ``False``. auth : tuple | None Optional authentication tuple containing the username and - password/token, passed to :class:`pooch.HTTPDownloader` (e.g., + password/token, passed to ``pooch.HTTPDownloader`` (e.g., ``auth=('foo', 012345)``). token : str | None - Optional authentication token passed to :class:`pooch.HTTPDownloader`. + Optional authentication token passed to ``pooch.HTTPDownloader``. Returns ------- diff --git a/mne/datasets/_infant/__init__.py b/mne/datasets/_infant/__init__.py new file mode 100644 index 00000000000..7347d36fcd0 --- /dev/null +++ b/mne/datasets/_infant/__init__.py @@ -0,0 +1 @@ +from .base import fetch_infant_template diff --git a/mne/datasets/config.py b/mne/datasets/config.py index 6778a1e7cc9..22fd45475bc 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,8 +87,13 @@ # To update the `testing` or `misc` datasets, push or merge commits to their # respective repos, and make a new release of the dataset on GitHub. 
Then # update the checksum in the MNE_DATASETS dict below, and change version -# here: ↓↓↓↓↓ ↓↓↓ -RELEASES = dict(testing="0.150", misc="0.26") +# here: ↓↓↓↓↓↓↓↓ +RELEASES = dict( + testing="0.152", + misc="0.27", + phantom_kit="0.2", + ucl_opm_auditory="0.2", +) TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}' MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}' @@ -112,7 +117,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:0b7452daef4d19132505b5639d695628", + hash="md5:df48cdabcf13ebeaafc617cb8e55b6fc", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f'tar.gz/{RELEASES["testing"]}' @@ -126,7 +131,7 @@ ) MNE_DATASETS["misc"] = dict( archive_name=f"{MISC_VERSIONED}.tar.gz", # 'mne-misc-data', - hash="md5:868b484fadd73b1d1a3535b7194a0d03", + hash="md5:e343d3a00cb49f8a2f719d14f4758afe", url=( "https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/" f'{RELEASES["misc"]}' @@ -145,8 +150,8 @@ MNE_DATASETS["ucl_opm_auditory"] = dict( archive_name="auditory_OPM_stationary.zip", - hash="md5:9ed0d8d554894542b56f8e7c4c0041fe", - url="https://osf.io/download/mwrt3/?version=1", + hash="md5:b2d69aa2d656b960bd0c18968dc1a14d", + url="https://osf.io/download/tp324/?version=1", # original is mwrt3 folder_name="auditory_OPM_stationary", config_key="MNE_DATASETS_UCL_OPM_AUDITORY_PATH", ) @@ -176,9 +181,9 @@ ) MNE_DATASETS["phantom_kit"] = dict( - archive_name="MNE-phantom-KIT-24bit.zip", - hash="md5:CAF82EE978DD473C7DE6C1034D9CCD45", - url="https://osf.io/download/svnt3/", + archive_name="MNE-phantom-KIT-data.tar.gz", + hash="md5:7bfdf40bbeaf17a66c99c695640e0740", + url="https://osf.io/fb6ya/download?version=1", folder_name="MNE-phantom-KIT-data", config_key="MNE_DATASETS_PHANTOM_KIT_PATH", ) diff --git a/mne/datasets/eegbci/eegbci.py b/mne/datasets/eegbci/eegbci.py index 3af5661e5f7..93c6c731932 100644 --- a/mne/datasets/eegbci/eegbci.py +++ b/mne/datasets/eegbci/eegbci.py @@ -7,19 +7,13 @@ import os import re import time +from importlib.resources import files from os import path as op from pathlib import Path from ...utils import _url_to_local_path, logger, verbose from ..utils import _do_path_update, _downloader_params, _get_path, _log_time_size -# TODO: remove try/except when our min version is py 3.9 -try: - from importlib.resources import files -except ImportError: - from importlib_resources import files - - EEGMI_URL = "https://physionet.org/files/eegmmidb/1.0.0/" diff --git a/mne/datasets/phantom_kit/phantom_kit.py b/mne/datasets/phantom_kit/phantom_kit.py index a4eac6c4a50..d57ca875f2c 100644 --- a/mne/datasets/phantom_kit/phantom_kit.py +++ b/mne/datasets/phantom_kit/phantom_kit.py @@ -10,7 +10,7 @@ def data_path( ): # noqa: D103 return _download_mne_dataset( name="phantom_kit", - processor="unzip", + processor="untar", path=path, force_update=force_update, update_path=update_path, diff --git a/mne/datasets/sleep_physionet/_utils.py b/mne/datasets/sleep_physionet/_utils.py index db9982e8c71..acff836366c 100644 --- a/mne/datasets/sleep_physionet/_utils.py +++ b/mne/datasets/sleep_physionet/_utils.py @@ -20,9 +20,7 @@ ) TEMAZEPAM_RECORDS_URL_SHA1 = "f52fffe5c18826a2bd4c5d5cb375bb4a9008c885" -AGE_RECORDS_URL = ( - "https://physionet.org/physiobank/database/sleep-edfx/SC-subjects.xls" # noqa: E501 -) +AGE_RECORDS_URL = "https://physionet.org/physiobank/database/sleep-edfx/SC-subjects.xls" AGE_RECORDS_URL_SHA1 = "0ba6650892c5d33a8e2b3f62ce1cc9f30438c54f" 
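# The pinned hashes above (md5 for dataset archives, sha1 for these record
# files) can be regenerated with hashlib; a small sketch with a hypothetical
# helper name, illustrative only:
#
#     import hashlib
#
#     def file_hash(path, algo="md5", chunk=1 << 20):
#         h = hashlib.new(algo)
#         with open(path, "rb") as f:
#             for block in iter(lambda: f.read(chunk), b""):
#                 h.update(block)
#         return f"{algo}:{h.hexdigest()}"  # e.g. "md5:df48..."
#
# A mismatch against the pinned value is how a stale or corrupt download
# gets caught.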
sha1sums_fname = op.join(op.dirname(__file__), "SHA1SUMS") @@ -134,9 +132,7 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): "level_3": "drug", } ) - data["id"] = [ - "ST7{:02d}{:1d}".format(s, n) for s, n in zip(data.subject, data["night nr"]) - ] + data["id"] = [f"ST7{s:02d}{n:1d}" for s, n in zip(data.subject, data["night nr"])] data = pd.merge(sha1_df, data, how="outer", on="id") data["record type"] = ( @@ -200,9 +196,7 @@ def _update_sleep_age_records(fname=AGE_SLEEP_RECORDS): {1: "female", 2: "male"} ) - data["id"] = [ - "SC4{:02d}{:1d}".format(s, n) for s, n in zip(data.subject, data.night) - ] + data["id"] = [f"SC4{s:02d}{n:1d}" for s, n in zip(data.subject, data.night)] data = data.set_index("id").join(sha1_df.set_index("id")).dropna() diff --git a/mne/datasets/sleep_physionet/age.py b/mne/datasets/sleep_physionet/age.py index 29afe9d9562..f947874aa0d 100644 --- a/mne/datasets/sleep_physionet/age.py +++ b/mne/datasets/sleep_physionet/age.py @@ -21,9 +21,7 @@ data_path = _data_path # expose _data_path(..) as data_path(..) -BASE_URL = ( - "https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette/" # noqa: E501 -) +BASE_URL = "https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette/" @verbose diff --git a/mne/datasets/sleep_physionet/tests/test_physionet.py b/mne/datasets/sleep_physionet/tests/test_physionet.py index 08b13c832c7..5147be94ab9 100644 --- a/mne/datasets/sleep_physionet/tests/test_physionet.py +++ b/mne/datasets/sleep_physionet/tests/test_physionet.py @@ -30,8 +30,8 @@ def _keep_basename_only(paths): def _get_expected_url(name): base = "https://physionet.org/physiobank/database/sleep-edfx/" - midle = "sleep-cassette/" if name.startswith("SC") else "sleep-telemetry/" - return base + midle + "/" + name + middle = "sleep-cassette/" if name.startswith("SC") else "sleep-telemetry/" + return base + middle + "/" + name def _get_expected_path(base, name): @@ -46,12 +46,12 @@ def _check_mocked_function_calls(mocked_func, call_fname_hash_pairs, base_path): # order. 
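# The format specs survive these rewrites unchanged: "%02d" maps onto {:02d},
# so the old and new spellings agree character for character; a quick
# standalone self-check (not part of the patch under test):
#
#     s, n = 7, 1
#     assert "ST7%02d%1d" % (s, n) == f"ST7{s:02d}{n:1d}" == "ST7071"
#     assert "SC4{:02d}{:1d}".format(s, n) == f"SC4{s:02d}{n:1d}"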
for idx, current in enumerate(call_fname_hash_pairs): _, call_kwargs = mocked_func.call_args_list[idx] - hash_type, hash = call_kwargs["known_hash"].split(":") + hash_type, hash_ = call_kwargs["known_hash"].split(":") assert call_kwargs["url"] == _get_expected_url(current["name"]), idx assert Path(call_kwargs["path"], call_kwargs["fname"]) == _get_expected_path( base_path, current["name"] ) - assert hash == current["hash"] + assert hash_ == current["hash"] assert hash_type == "sha1" diff --git a/mne/datasets/tests/test_datasets.py b/mne/datasets/tests/test_datasets.py index b84b3a2f367..d3a361786d7 100644 --- a/mne/datasets/tests/test_datasets.py +++ b/mne/datasets/tests/test_datasets.py @@ -60,7 +60,7 @@ def test_datasets_basic(tmp_path, monkeypatch): else: assert dataset.get_version() is None assert not datasets.has_dataset(dname) - print("%s: %s" % (dname, datasets.has_dataset(dname))) + print(f"{dname}: {datasets.has_dataset(dname)}") tempdir = str(tmp_path) # Explicitly test one that isn't preset (given the config) monkeypatch.setenv("MNE_DATASETS_SAMPLE_PATH", tempdir) diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index eddee6f5684..d4a8f4af459 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -87,7 +87,7 @@ def _dataset_version(path, name): """Get the version of the dataset.""" ver_fname = op.join(path, "version.txt") if op.exists(ver_fname): - with open(ver_fname, "r") as fid: + with open(ver_fname) as fid: version = fid.readline().strip() # version is on first line else: logger.debug(f"Version file missing: {ver_fname}") @@ -147,8 +147,8 @@ def _do_path_update(path, update_path, key, name): answer = "y" else: msg = ( - "Do you want to set the path:\n %s\nas the default " - "%s dataset path in the mne-python config [y]/n? " % (path, name) + f"Do you want to set the path:\n {path}\nas the default {name} " + "dataset path in the mne-python config [y]/n? 
" ) answer = _safe_input(msg, alt="pass update_path=True") if answer.lower() == "n": @@ -747,7 +747,7 @@ def fetch_hcp_mmp_parcellation( assert used.all() assert len(labels_out) == 46 for hemi, side in (("lh", "left"), ("rh", "right")): - table_name = "./%s.fsaverage164.label.gii" % (side,) + table_name = f"./{side}.fsaverage164.label.gii" write_labels_to_annot( labels_out, "fsaverage", @@ -762,7 +762,7 @@ def fetch_hcp_mmp_parcellation( def _manifest_check_download(manifest_path, destination, url, hash_): import pooch - with open(manifest_path, "r") as fid: + with open(manifest_path) as fid: names = [name.strip() for name in fid.readlines()] manifest_path = op.basename(manifest_path) need = list() @@ -787,18 +787,17 @@ def _manifest_check_download(manifest_path, destination, url, hash_): fname=op.basename(fname_path), ) - logger.info("Extracting missing file%s" % (_pl(need),)) + logger.info(f"Extracting missing file{_pl(need)}") with zipfile.ZipFile(fname_path, "r") as ff: members = set(f for f in ff.namelist() if not f.endswith("/")) missing = sorted(members.symmetric_difference(set(names))) if len(missing): raise RuntimeError( - "Zip file did not have correct names:" - "\n%s" % ("\n".join(missing)) + "Zip file did not have correct names:\n{'\n'.join(missing)}" ) for name in need: ff.extract(name, path=destination) - logger.info("Successfully extracted %d file%s" % (len(need), _pl(need))) + logger.info(f"Successfully extracted {len(need)} file{_pl(need)}") def _log_time_size(t0, sz): diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 08a0d65e951..8e36ee412a8 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -11,6 +11,7 @@ import numbers import numpy as np +from scipy.sparse import issparse from ..fixes import BaseEstimator, _check_fit_params, _get_check_scoring from ..parallel import parallel_func @@ -64,7 +65,7 @@ class LinearModel(BaseEstimator): "classes_", ) - def __init__(self, model=None): # noqa: D102 + def __init__(self, model=None): if model is None: from sklearn.linear_model import LogisticRegression @@ -106,16 +107,21 @@ def fit(self, X, y, **fit_params): self : instance of LinearModel Returns the modified instance. """ + # Once we require sklearn 1.1+ we should do: + # from sklearn.utils import check_array + # X = check_array(X, input_name="X") + # y = check_array(y, dtype=None, ensure_2d=False, input_name="y") + if issparse(X): + raise TypeError("X should be a dense array, got sparse instead.") X, y = np.asarray(X), np.asarray(y) if X.ndim != 2: raise ValueError( - "LinearModel only accepts 2-dimensional X, got " - "%s instead." % (X.shape,) + f"LinearModel only accepts 2-dimensional X, got {X.shape} instead." ) if y.ndim > 2: raise ValueError( - "LinearModel only accepts up to 2-dimensional y, " - "got %s instead." % (y.shape,) + f"LinearModel only accepts up to 2-dimensional y, got {y.shape} " + "instead." 
) # fit the Model @@ -267,9 +273,7 @@ def get_coef(estimator, attr="filters_", inverse_transform=False): coef = coef[np.newaxis] # fake a sample dimension squeeze_first_dim = True elif not hasattr(est, attr): - raise ValueError( - "This estimator does not have a %s attribute:\n%s" % (attr, est) - ) + raise ValueError(f"This estimator does not have a {attr} attribute:\n{est}") else: coef = getattr(est, attr) @@ -281,7 +285,7 @@ def get_coef(estimator, attr="filters_", inverse_transform=False): if inverse_transform: if not hasattr(estimator, "steps") and not hasattr(est, "estimators_"): raise ValueError( - "inverse_transform can only be applied onto " "pipeline estimators." + "inverse_transform can only be applied onto pipeline estimators." ) # The inverse_transform parameter will call this method on any # estimator contained in the pipeline, in reverse order. @@ -458,15 +462,13 @@ def _fit_and_score( if return_train_score: train_score = error_score warn( - "Classifier fit failed. The score on this train-test" - " partition for these parameters will be set to %f. " - "Details: \n%r" % (error_score, e) + "Classifier fit failed. The score on this train-test partition for " + f"these parameters will be set to {error_score}. Details: \n{e!r}" ) else: raise ValueError( - "error_score must be the string 'raise' or a" - " numeric value. (Hint: if using 'raise', please" - " make sure that it has been spelled correctly.)" + "error_score must be the string 'raise' or a numeric value. (Hint: if " + "using 'raise', please make sure that it has been spelled correctly.)" ) else: diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 1656db50b36..ba76acd2d7c 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -12,11 +12,18 @@ import numpy as np from scipy.linalg import eigh -from ..cov import _regularized_covariance +from .._fiff.meas_info import create_info +from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ..evoked import EvokedArray from ..fixes import pinv -from ..utils import _check_option, _validate_type, copy_doc, fill_doc +from ..utils import ( + _check_option, + _validate_type, + _verbose_safe_false, + copy_doc, + fill_doc, +) from .base import BaseEstimator from .mixin import TransformerMixin @@ -181,12 +188,13 @@ def fit(self, X, y): raise ValueError("n_classes must be >= 2.") if n_classes > 2 and self.component_order == "alternate": raise ValueError( - "component_order='alternate' requires two " - "classes, but data contains {} classes; use " - "component_order='mutual_info' " - "instead.".format(n_classes) + "component_order='alternate' requires two classes, but data contains " + f"{n_classes} classes; use component_order='mutual_info' instead." 
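# For two classes, the decomposition done in _decompose_covs below reduces to
# a generalized eigenvalue problem; a bare sketch with hypothetical names,
# assuming two symmetric (n_channels, n_channels) class covariances:
#
#     import numpy as np
#     from scipy.linalg import eigh
#
#     def csp_filters(C1, C2, n_components=3):
#         evals, evecs = eigh(C1, C1 + C2)  # generalized eigh, evals in [0, 1]
#         order = np.argsort(np.abs(evals - 0.5))[::-1]  # extremes first
#         return evecs[:, order[:n_components]].T  # one spatial filter per row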
) + # Convert rank to one that will run + _validate_type(self.rank, (dict, None), "rank") + covs, sample_weights = self._compute_covariance_matrices(X, y) eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights) ix = self._order_components( @@ -521,10 +529,28 @@ def _compute_covariance_matrices(self, X, y): elif self.cov_est == "epoch": cov_estimator = self._epoch_cov + # Someday we could allow the user to pass this, then we wouldn't need to convert + # but in the meantime they can use a pipeline with a scaler + self._info = create_info(n_channels, 1000.0, "mag") + if self.rank is None: + self._rank = _compute_rank_raw_array( + X.transpose(1, 0, 2).reshape(X.shape[1], -1), + self._info, + rank=None, + scalings=None, + log_ch_type="data", + ) + else: + self._rank = {"mag": sum(self.rank.values())} + covs = [] sample_weights = [] - for this_class in self._classes: - cov, weight = cov_estimator(X[y == this_class]) + for ci, this_class in enumerate(self._classes): + cov, weight = cov_estimator( + X[y == this_class], + cov_kind=f"class={this_class}", + log_rank=ci == 0, + ) if self.norm_trace: cov /= np.trace(cov) @@ -534,29 +560,39 @@ def _compute_covariance_matrices(self, X, y): return np.stack(covs), np.array(sample_weights) - def _concat_cov(self, x_class): + def _concat_cov(self, x_class, *, cov_kind, log_rank): """Concatenate epochs before computing the covariance.""" _, n_channels, _ = x_class.shape - x_class = np.transpose(x_class, [1, 0, 2]) - x_class = x_class.reshape(n_channels, -1) + x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1) cov = _regularized_covariance( - x_class, reg=self.reg, method_params=self.cov_method_params, rank=self.rank + x_class, + reg=self.reg, + method_params=self.cov_method_params, + rank=self._rank, + info=self._info, + cov_kind=cov_kind, + log_rank=log_rank, + log_ch_type="data", ) weight = x_class.shape[0] return cov, weight - def _epoch_cov(self, x_class): + def _epoch_cov(self, x_class, *, cov_kind, log_rank): """Mean of per-epoch covariances.""" cov = sum( _regularized_covariance( this_X, reg=self.reg, method_params=self.cov_method_params, - rank=self.rank, + rank=self._rank, + info=self._info, + cov_kind=cov_kind, + log_rank=log_rank and ii == 0, + log_ch_type="data", ) - for this_X in x_class + for ii, this_X in enumerate(x_class) ) cov /= len(x_class) weight = len(x_class) @@ -565,6 +601,20 @@ def _epoch_cov(self, x_class): def _decompose_covs(self, covs, sample_weights): n_classes = len(covs) + n_channels = covs[0].shape[0] + assert self._rank is not None # should happen in _compute_covariance_matrices + _, sub_vec, mask = _smart_eigh( + covs.mean(0), + self._info, + self._rank, + proj_subspace=True, + do_compute_rank=False, + log_ch_type="data", + verbose=_verbose_safe_false(), + ) + sub_vec = sub_vec[mask] + covs = np.array([sub_vec @ cov @ sub_vec.T for cov in covs], float) + assert covs[0].shape == (mask.sum(),) * 2 if n_classes == 2: eigen_values, eigen_vectors = eigh(covs[0], covs.sum(0)) else: @@ -575,6 +625,9 @@ def _decompose_covs(self, covs, sample_weights): eigen_vectors.T, covs, sample_weights ) eigen_values = None + # project back + eigen_vectors = sub_vec.T @ eigen_vectors + assert eigen_vectors.shape == (n_channels, mask.sum()) return eigen_vectors, eigen_values def _compute_mutual_info(self, covs, sample_weights, eigen_vectors): @@ -773,7 +826,7 @@ def __init__( rank=None, ): """Init of SPoC.""" - super(SPoC, self).__init__( + super().__init__( n_components=n_components, reg=reg, log=log, @@ -826,6 +879,8 
@@ def fit(self, X, y): reg=self.reg, method_params=self.cov_method_params, rank=self.rank, + log_ch_type="data", + log_rank=ii == 0, ) C = covs.mean(0) @@ -873,4 +928,4 @@ def transform(self, X): If self.transform_into == 'csp_space' then returns the data in CSP space and shape is (n_epochs, n_sources, n_times). """ - return super(SPoC, self).transform(X) + return super().transform(X) diff --git a/mne/decoding/mixin.py b/mne/decoding/mixin.py index 2a0adee19eb..3916c156873 100644 --- a/mne/decoding/mixin.py +++ b/mne/decoding/mixin.py @@ -69,9 +69,8 @@ def set_params(self, **params): name, sub_name = split if name not in valid_params: raise ValueError( - "Invalid parameter %s for estimator %s. " - "Check the list of available parameters " - "with `estimator.get_params().keys()`." % (name, self) + f"Invalid parameter {name} for estimator {self}. Check the list" + " of available parameters with `estimator.get_params().keys()`." ) sub_object = valid_params[name] sub_object.set_params(**{sub_name: value}) @@ -79,10 +78,9 @@ def set_params(self, **params): # simple objects case if key not in valid_params: raise ValueError( - "Invalid parameter %s for estimator %s. " - "Check the list of available parameters " - "with `estimator.get_params().keys()`." - % (key, self.__class__.__name__) + f"Invalid parameter {key} for estimator " + f"{self.__class__.__name__}. Check the list of available " + "parameters with `estimator.get_params().keys()`." ) setattr(self, key, value) return self diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index c3c07cfa42f..fdf7dea9211 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -134,24 +134,24 @@ def _more_tags(self): return {"no_validation": True} def __repr__(self): # noqa: D105 - s = "tmin, tmax : (%.3f, %.3f), " % (self.tmin, self.tmax) + s = f"tmin, tmax : ({self.tmin:.3f}, {self.tmax:.3f}), " estimator = self.estimator if not isinstance(estimator, str): estimator = type(self.estimator) - s += "estimator : %s, " % (estimator,) + s += f"estimator : {estimator}, " if hasattr(self, "coef_"): if self.feature_names is not None: feats = self.feature_names if len(feats) == 1: - s += "feature: %s, " % feats[0] + s += f"feature: {feats[0]}, " else: - s += "features : [%s, ..., %s], " % (feats[0], feats[-1]) + s += f"features : [{feats[0]}, ..., {feats[-1]}], " s += "fit: True" else: s += "fit: False" if hasattr(self, "scores_"): - s += "scored (%s)" % self.scoring - return "<ReceptiveField | %s>" % s + s += f"scored ({self.scoring})" + return f"<ReceptiveField | {s}>" def _delay_and_reshape(self, X, y=None): """Delay and reshape the variables.""" @@ -187,17 +187,14 @@ def fit(self, X, y): """ if self.scoring not in _SCORERS.keys(): raise ValueError( - "scoring must be one of %s, got" - "%s " % (sorted(_SCORERS.keys()), self.scoring) + f"scoring must be one of {sorted(_SCORERS.keys())}, got {self.scoring} " ) from sklearn.base import clone, is_regressor X, y, _, self._y_dim = self._check_dimensions(X, y) if self.tmin > self.tmax: - raise ValueError( - "tmin (%s) must be at most tmax (%s)" % (self.tmin, self.tmax) - ) + raise ValueError(f"tmin ({self.tmin}) must be at most tmax ({self.tmax})") # Initialize delays self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq) @@ -225,17 +222,16 @@ def fit(self, X, y): and estimator.fit_intercept != self.fit_intercept ): raise ValueError( - "Estimator fit_intercept (%s) != initialization " - "fit_intercept (%s), initialize ReceptiveField with the " - "same fit_intercept value or use 
fit_intercept=None" - % (estimator.fit_intercept, self.fit_intercept) + f"Estimator fit_intercept ({estimator.fit_intercept}) != " + f"initialization fit_intercept ({self.fit_intercept}), initialize " + "ReceptiveField with the same fit_intercept value or use " + "fit_intercept=None" ) self.fit_intercept_ = estimator.fit_intercept else: raise ValueError( - "`estimator` must be a float or an instance" - " of `BaseEstimator`," - " got type %s." % type(self.estimator) + "`estimator` must be a float or an instance of `BaseEstimator`, got " + f"type {self.estimator}." ) self.estimator_ = estimator del estimator @@ -249,8 +245,8 @@ def fit(self, X, y): # Update feature names if we have none if (self.feature_names is not None) and (len(self.feature_names) != n_feats): raise ValueError( - "n_features in X does not match feature names " - "(%s != %s)" % (n_feats, len(self.feature_names)) + f"n_features in X does not match feature names ({n_feats} != " + f"{len(self.feature_names)})" ) # Create input features @@ -377,8 +373,8 @@ def _check_dimensions(self, X, y, predict=False): y = y[:, np.newaxis, :] # epochs else: raise ValueError( - "y must be shape (n_times[, n_epochs]" - "[,n_outputs], got %s" % (y.shape,) + "y must be shape (n_times[, n_epochs][,n_outputs], got " + f"{y.shape}" ) elif X.ndim == 3: if y is not None: @@ -390,24 +386,22 @@ def _check_dimensions(self, X, y, predict=False): ) else: raise ValueError( - "X must be shape (n_times[, n_epochs]," - " n_features), got %s" % (X.shape,) + f"X must be shape (n_times[, n_epochs], n_features), got {X.shape}" ) if y is not None: if X.shape[0] != y.shape[0]: raise ValueError( - "X and y do not have the same n_times\n" - "%s != %s" % (X.shape[0], y.shape[0]) + f"X and y do not have the same n_times\n{X.shape[0]} != " + f"{y.shape[0]}" ) if X.shape[1] != y.shape[1]: raise ValueError( - "X and y do not have the same n_epochs\n" - "%s != %s" % (X.shape[1], y.shape[1]) + f"X and y do not have the same n_epochs\n{X.shape[1]} != " + f"{y.shape[1]}" ) if predict and y.shape[-1] != len(self.estimator_.coef_): raise ValueError( - "Number of outputs does not match" - " estimator coefficients dimensions" + "Number of outputs does not match estimator coefficients dimensions" ) return X, y, X_dim, y_dim @@ -517,7 +511,7 @@ def _corr_score(y_true, y, multioutput=None): for this_y in (y_true, y): if this_y.ndim != 2: raise ValueError( - "inputs must be shape (samples, outputs), got %s" % (this_y.shape,) + f"inputs must be shape (samples, outputs), got {this_y.shape}" ) return np.array([pearsonr(y_true[:, ii], y[:, ii])[0] for ii in range(y.shape[-1])]) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 06b7b010651..c8d56b88d6e 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -5,6 +5,7 @@ import logging import numpy as np +from scipy.sparse import issparse from ..fixes import _get_check_scoring from ..parallel import parallel_func @@ -46,7 +47,7 @@ def __init__( position=0, allow_2d=False, verbose=None, - ): # noqa: D102 + ): _check_estimator(base_estimator) self.base_estimator = base_estimator self.n_jobs = n_jobs @@ -63,7 +64,7 @@ def _estimator_type(self): return getattr(self.base_estimator, "_estimator_type", None) def __repr__(self): # noqa: D105 - repr_str = "<" + super(SlidingEstimator, self).__repr__() + repr_str = "<" + super().__repr__() if hasattr(self, "estimators_"): repr_str = repr_str[:-1] repr_str += ", fitted with %i estimators" % len(self.estimators_) @@ -254,6 +255,12 @@ def 
decision_function(self, X): def _check_Xy(self, X, y=None): """Aux. function to check input data.""" + # Once we require sklearn 1.1+ we should do something like: + # from sklearn.utils import check_array + # X = check_array(X, ensure_2d=False, input_name="X") + # y = check_array(y, dtype=None, ensure_2d=False, input_name="y") + if issparse(X): + raise TypeError("X should be a dense array, got sparse instead.") X = np.asarray(X) if y is not None: y = np.asarray(y) @@ -320,9 +327,8 @@ def score(self, X, y): def classes_(self): if not hasattr(self.estimators_[0], "classes_"): raise AttributeError( - "classes_ attribute available only if " - "base_estimator has it, and estimator %s does" - " not" % (self.estimators_[0],) + "classes_ attribute available only if base_estimator has it, and " + f"estimator {self.estimators_[0]} does not" ) return self.estimators_[0].classes_ @@ -466,7 +472,7 @@ class GeneralizingEstimator(SlidingEstimator): """ def __repr__(self): # noqa: D105 - repr_str = super(GeneralizingEstimator, self).__repr__() + repr_str = super().__repr__() if hasattr(self, "estimators_"): repr_str = repr_str[:-1] repr_str += ", fitted with %i estimators>" % len(self.estimators_) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 961444b122c..64e84cdbde9 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -112,8 +112,7 @@ def __init__( key = ("signal", "noise")[dd] if param + "_freq" not in dicts[key]: raise ValueError( - "%s must be defined in filter parameters for %s" - % (param + "_freq", key) + f"{param + '_freq'} must be defined in filter parameters for {key}" ) val = dicts[key][param + "_freq"] if not isinstance(val, (int, float)): diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index 1f72eacbc48..1e8d138f83b 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -13,12 +13,14 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_equal from mne import Epochs, io, pick_types, read_events -from mne.decoding.csp import CSP, SPoC, _ajd_pham +from mne.decoding import CSP, Scaler, SPoC +from mne.decoding.csp import _ajd_pham +from mne.utils import catch_logging -data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" -tmin, tmax = -0.2, 0.5 +tmin, tmax = -0.1, 0.2 event_id = dict(aud_l=1, vis_l=3) # if stop is too small pca may fail in some cases, but we're okay on this file start, stop = 0, 8 @@ -245,40 +247,95 @@ def test_csp(): assert np.abs(corr) > 0.95 -def test_regularized_csp(): +# Even the "reg is None and rank is None" case should pass now thanks to the +# do_compute_rank +@pytest.mark.parametrize("ch_type", ("mag", "eeg", ("mag", "eeg"))) +@pytest.mark.parametrize("rank", (None, "correct")) +@pytest.mark.parametrize("reg", [None, 0.001, "oas"]) +def test_regularized_csp(ch_type, rank, reg): """Test Common Spatial Patterns algorithm using regularized covariance.""" pytest.importorskip("sklearn") - raw = io.read_raw_fif(raw_fname) + from sklearn.linear_model import LogisticRegression + from sklearn.model_selection import StratifiedKFold, cross_val_score + from sklearn.pipeline import make_pipeline + + raw = io.read_raw_fif(raw_fname).pick(ch_type, exclude="bads").load_data() + n_orig = len(raw.ch_names) + ch_decim = 2 + raw.pick_channels(raw.ch_names[::ch_decim]) + if "eeg" in ch_type: + 
raw.set_eeg_reference(projection=True) + n_eig = len(raw.ch_names) - len(raw.info["projs"]) + n_ch = n_orig // ch_decim + if ch_type == "eeg": + assert n_eig == n_ch - 1 + elif ch_type == "mag": + assert n_eig == n_ch - 3 + else: + assert n_eig == n_ch - 4 + if rank == "correct": + if isinstance(ch_type, str): + rank = {ch_type: n_eig} + else: + assert ch_type == ("mag", "eeg") + rank = dict( + mag=102 // ch_decim - 3, + eeg=60 // ch_decim - 1, + ) + else: + assert rank is None, rank + raw.info.normalize_proj() + raw.filter(2, 40) events = read_events(event_name) - picks = pick_types( - raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" - ) - picks = picks[1:13:3] - epochs = Epochs( - raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True - ) + # make left and right events the same + events[events[:, 2] == 2, 2] = 1 + events[events[:, 2] == 4, 2] = 3 + epochs = Epochs(raw, events, event_id, tmin, tmax, decim=5, preload=True) + epochs.equalize_event_counts() + assert 25 < len(epochs) < 30 epochs_data = epochs.get_data(copy=False) n_channels = epochs_data.shape[1] - + assert n_channels == n_ch n_components = 3 - reg_cov = [None, 0.05, "ledoit_wolf", "oas"] - for reg in reg_cov: - csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=None) - csp.fit(epochs_data, epochs.events[:, -1]) - y = epochs.events[:, -1] - X = csp.fit_transform(epochs_data, y) - assert csp.filters_.shape == (n_channels, n_channels) - assert csp.patterns_.shape == (n_channels, n_channels) - assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X) - - # test init exception - pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) - - csp.n_components = n_components - sources = csp.transform(epochs_data) - assert sources.shape[1] == n_components + + sc = Scaler(epochs.info) + epochs_data = sc.fit_transform(epochs_data) + csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=rank) + with catch_logging(verbose=True) as log: + X = csp.fit_transform(epochs_data, epochs.events[:, -1]) + log = log.getvalue() + assert "Setting small MAG" not in log + assert "Setting small data eigen" in log + if rank is None: + assert "Computing rank from data" in log + assert " mag: rank" not in log.lower() + assert " data: rank" in log + assert "rank (mag)" not in log.lower() + assert "rank (data)" in log + else: # if rank is passed no computation is done + assert "Computing rank" not in log + assert ": rank" not in log + assert "rank (" not in log + assert "reducing mag" not in log.lower() + assert f"Reducing data rank from {n_channels} " in log + y = epochs.events[:, -1] + assert csp.filters_.shape == (n_eig, n_channels) + assert csp.patterns_.shape == (n_eig, n_channels) + assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X) + + # test init exception + pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) + pytest.raises(ValueError, csp.fit, epochs, y) + pytest.raises(ValueError, csp.transform, epochs) + + csp.n_components = n_components + sources = csp.transform(epochs_data) + assert sources.shape[1] == n_components + + cv = StratifiedKFold(5) + clf = make_pipeline(csp, LogisticRegression(solver="liblinear")) + score = cross_val_score(clf, epochs_data, y, cv=cv, scoring="roc_auc").mean() + assert 0.75 <= score <= 1.0 def test_csp_pipeline(): diff --git 
a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py index 6b52ee7f6e1..e32664608ce 100644 --- a/mne/decoding/tests/test_ems.py +++ b/mne/decoding/tests/test_ems.py @@ -12,7 +12,7 @@ from mne import Epochs, io, pick_types, read_events from mne.decoding import EMS, compute_ems -data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" tmin, tmax = -0.2, 0.5 diff --git a/mne/decoding/tests/test_receptive_field.py b/mne/decoding/tests/test_receptive_field.py index dfc570e374e..8585aa0170e 100644 --- a/mne/decoding/tests/test_receptive_field.py +++ b/mne/decoding/tests/test_receptive_field.py @@ -20,7 +20,7 @@ ) from mne.decoding.time_delaying_ridge import _compute_corrs, _compute_reg_neighbors -data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" @@ -73,7 +73,7 @@ def test_compute_reg_neighbors(): reg_direct, reg_csgraph, atol=1e-7, - err_msg="%s: %s" % (reg_type, (n_ch_x, n_delays)), + err_msg=f"{reg_type}: {(n_ch_x, n_delays)}", ) @@ -155,7 +155,7 @@ def test_time_delay(): del_zero = int(round(-tmin * isfreq)) for ii in range(-2, 3): idx = del_zero + ii - err_msg = "[%s,%s] (%s): %s %s" % (tmin, tmax, isfreq, ii, idx) + err_msg = f"[{tmin},{tmax}] ({isfreq}): {ii} {idx}" if 0 <= idx < X_delayed.shape[-1]: if ii == 0: assert_array_equal(X_delayed[:, :, idx], X, err_msg=err_msg) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 992efbfec30..21d4eda6d0f 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -63,19 +63,19 @@ def test_search_light(): # transforms pytest.raises(ValueError, sl.predict, X[:, :, :2]) y_trans = sl.transform(X) - assert X.dtype == y_trans.dtype == float + assert X.dtype == y_trans.dtype == np.dtype(float) y_pred = sl.predict(X) - assert y_pred.dtype == int + assert y_pred.dtype == np.dtype(int) assert_array_equal(y_pred.shape, [n_epochs, n_time]) y_proba = sl.predict_proba(X) - assert y_proba.dtype == float + assert y_proba.dtype == np.dtype(float) assert_array_equal(y_proba.shape, [n_epochs, n_time, 2]) # score score = sl.score(X, y) assert_array_equal(score.shape, [n_time]) assert np.sum(np.abs(score)) != 0 - assert score.dtype == float + assert score.dtype == np.dtype(float) sl = SlidingEstimator(logreg) assert_equal(sl.scoring, None) @@ -94,8 +94,13 @@ def test_search_light(): with pytest.raises(ValueError, match="for two-class"): sl.score(X, y) # But check that valid ones should work with new enough sklearn + kwargs = dict() + if check_version("sklearn", "1.4"): + kwargs["response_method"] = "predict_proba" + else: + kwargs["needs_proba"] = True if "multi_class" in signature(roc_auc_score).parameters: - scoring = make_scorer(roc_auc_score, needs_proba=True, multi_class="ovo") + scoring = make_scorer(roc_auc_score, multi_class="ovo", **kwargs) sl = SlidingEstimator(logreg, scoring=scoring) sl.fit(X, y) sl.score(X, y) # smoke test @@ -122,10 +127,15 @@ def test_search_light(): X = rng.randn(*X.shape) # randomize X to avoid AUCs in [0, 1] score_sl = sl1.score(X, y) assert_array_equal(score_sl.shape, [n_time]) - assert score_sl.dtype == float + assert score_sl.dtype == np.dtype(float) # Check that scoring was applied adequately - scoring = 
make_scorer(roc_auc_score, needs_threshold=True) + kwargs = dict() + if check_version("sklearn", "1.4"): + kwargs["response_method"] = ("decision_function", "predict_proba") + else: + kwargs["needs_threshold"] = True + scoring = make_scorer(roc_auc_score, **kwargs) score_manual = [ scoring(est, x, y) for est, x in zip(sl1.estimators_, X.transpose(2, 0, 1)) ] @@ -146,7 +156,7 @@ def test_search_light(): # pipeline class _LogRegTransformer(LogisticRegression): def transform(self, X): - return super(_LogRegTransformer, self).predict_proba(X)[..., 1] + return super().predict_proba(X)[..., 1] logreg_transformer = _LogRegTransformer( random_state=0, multi_class="ovr", solver="liblinear" @@ -195,9 +205,9 @@ def test_generalization_light(): # transforms y_pred = gl.predict(X) assert_array_equal(y_pred.shape, [n_epochs, n_time, n_time]) - assert y_pred.dtype == int + assert y_pred.dtype == np.dtype(int) y_proba = gl.predict_proba(X) - assert y_proba.dtype == float + assert y_proba.dtype == np.dtype(float) assert_array_equal(y_proba.shape, [n_epochs, n_time, n_time, 2]) # transform to different datasize @@ -208,7 +218,7 @@ def test_generalization_light(): score = gl.score(X[:, :, :3], y) assert_array_equal(score.shape, [n_time, 3]) assert np.sum(np.abs(score)) != 0 - assert score.dtype == float + assert score.dtype == np.dtype(float) gl = GeneralizingEstimator(logreg, scoring="roc_auc") gl.fit(X, y) diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index bdb3f74c545..e72e0ff81ad 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -19,7 +19,7 @@ def simulate_data( - freqs_sig=[9, 12], + freqs_sig=(9, 12), n_trials=100, n_channels=20, n_samples=500, diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py index 88a8345d4b8..1c2a29bdf8e 100644 --- a/mne/decoding/tests/test_transformer.py +++ b/mne/decoding/tests/test_transformer.py @@ -30,7 +30,7 @@ tmin, tmax = -0.2, 0.5 event_id = dict(aud_l=1, vis_l=3) start, stop = 0, 8 -data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" @@ -62,7 +62,7 @@ def test_scaler(info, method): epochs_data_t = epochs_data.transpose([1, 0, 2]) if method in ("mean", "median"): if not check_version("sklearn"): - with pytest.raises(ImportError, match="No module"): + with pytest.raises((ImportError, RuntimeError), match=" module "): Scaler(info, method) return diff --git a/mne/decoding/time_delaying_ridge.py b/mne/decoding/time_delaying_ridge.py index b89b4e98ac2..3ef2403bf34 100644 --- a/mne/decoding/time_delaying_ridge.py +++ b/mne/decoding/time_delaying_ridge.py @@ -157,12 +157,10 @@ def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method="direct", normed=F if isinstance(reg_type, str): reg_type = (reg_type,) * 2 if len(reg_type) != 2: - raise ValueError("reg_type must have two elements, got %s" % (len(reg_type),)) + raise ValueError(f"reg_type must have two elements, got {len(reg_type)}") for r in reg_type: if r not in known_types: - raise ValueError( - "reg_type entries must be one of %s, got %s" % (known_types, r) - ) + raise ValueError(f"reg_type entries must be one of {known_types}, got {r}") reg_time = reg_type[0] == "laplacian" and n_delays > 1 reg_chs = reg_type[1] == "laplacian" and n_ch_x > 1 if not reg_time and not reg_chs: @@ -290,7 +288,7 @@ def __init__( edge_correction=True, ): if tmin > tmax: - raise 
ValueError("tmin must be <= tmax, got %s and %s" % (tmin, tmax)) + raise ValueError(f"tmin must be <= tmax, got {tmin} and {tmax}") self.tmin = float(tmin) self.tmax = float(tmax) self.sfreq = float(sfreq) diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index e085e9e2706..0555d190ddd 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -74,7 +74,7 @@ def __init__( output="complex", n_jobs=1, verbose=None, - ): # noqa: D102 + ): """Init TimeFrequency transformer.""" # Check non-average output output = _check_option("output", output, ["complex", "power", "phase"]) @@ -150,17 +150,17 @@ def transform(self, X): # Compute time-frequency Xt = _compute_tfr( X, - self.freqs, - self.sfreq, - self.method, - self.n_cycles, - True, - self.time_bandwidth, - self.use_fft, - self.decim, - self.output, - self.n_jobs, - self.verbose, + freqs=self.freqs, + sfreq=self.sfreq, + method=self.method, + n_cycles=self.n_cycles, + zero_mean=True, + time_bandwidth=self.time_bandwidth, + use_fft=self.use_fft, + decim=self.decim, + output=self.output, + n_jobs=self.n_jobs, + verbose=self.verbose, ) # Back to original shape diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index 9cb22a43355..3ba47b99700 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -107,9 +107,7 @@ class Scaler(TransformerMixin, BaseEstimator): if ``scalings`` is a dict or None). """ - def __init__( - self, info=None, scalings=None, with_mean=True, with_std=True - ): # noqa: D102 + def __init__(self, info=None, scalings=None, with_mean=True, with_std=True): self.info = info self.with_mean = with_mean self.with_std = with_std @@ -333,7 +331,7 @@ def inverse_transform(self, X): X = np.asarray(X) if X.ndim not in (2, 3): raise ValueError( - "X should be of 2 or 3 dimensions but has shape " "%s" % (X.shape,) + "X should be of 2 or 3 dimensions but has shape " f"{X.shape}" ) return X.reshape(X.shape[:-1] + self.features_shape_) @@ -384,7 +382,7 @@ def __init__( normalization="length", *, verbose=None, - ): # noqa: D102 + ): self.sfreq = sfreq self.fmin = fmin self.fmax = fmax @@ -492,10 +490,9 @@ class FilterEstimator(TransformerMixin): Notes ----- - This is primarily meant for use in conjunction with - :class:`mne_realtime.RtEpochs`. In general it is not recommended in a - normal processing pipeline as it may result in edge artifacts. Use with - caution. + This is primarily meant for use in realtime applications. + In general it is not recommended in a normal processing pipeline as it may result + in edge artifacts. Use with caution. """ def __init__( @@ -513,7 +510,7 @@ def __init__( fir_design="firwin", *, verbose=None, - ): # noqa: D102 + ): self.info = info self.l_freq = l_freq self.h_freq = h_freq @@ -626,7 +623,7 @@ class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): (e.g. epochs). 
""" - def __init__(self, estimator, average=False): # noqa: D102 + def __init__(self, estimator, average=False): # XXX: Use _check_estimator #3381 for attr in ("fit", "transform", "fit_transform"): if not hasattr(estimator, attr): @@ -839,7 +836,7 @@ def __init__( fir_design="firwin", *, verbose=None, - ): # noqa: D102 + ): self.l_freq = l_freq self.h_freq = h_freq self.sfreq = sfreq diff --git a/mne/defaults.py b/mne/defaults.py index 8732280998f..0418feb6788 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -216,6 +216,12 @@ temperature="Temperature", eyegaze="Eye-tracking (Gaze position)", pupil="Eye-tracking (Pupil size)", + resp="Respiration monitoring channel", + chpi="Continuous head position indicator (HPI) coil channels", + exci="Flux excitation channel", + ias="Internal Active Shielding data (Triux systems)", + syst="System status channel information (Triux systems)", + whitened="Whitened data", ), mask_params=dict( marker="o", @@ -235,8 +241,8 @@ eeg_scale=4e-3, eegp_scale=20e-3, eegp_height=0.1, - ecog_scale=5e-3, - seeg_scale=5e-3, + ecog_scale=2e-3, + seeg_scale=2e-3, meg_scale=1.0, # sensors are already in SI units ref_meg_scale=1.0, dbs_scale=5e-3, @@ -278,7 +284,9 @@ combine_xyz="fro", allow_fixed_depth=True, ), - interpolation_method=dict(eeg="spline", meg="MNE", fnirs="nearest"), + interpolation_method=dict( + eeg="spline", meg="MNE", fnirs="nearest", ecog="spline", seeg="spline" + ), volume_options=dict( alpha=None, resolution=1.0, diff --git a/mne/dipole.py b/mne/dipole.py index 59531463da8..9dcc88c2b01 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -130,7 +130,7 @@ def __init__( nfree=None, *, verbose=None, - ): # noqa: D102 + ): self._set_times(np.array(times)) self.pos = np.array(pos) self.amplitude = np.array(amplitude) @@ -481,7 +481,7 @@ class DipoleFixed(ExtendedTimeMixin): @verbose def __init__( self, info, data, times, nave, aspect_kind, comment="", *, verbose=None - ): # noqa: D102 + ): self.info = info self.nave = nave self._aspect_kind = aspect_kind @@ -626,7 +626,7 @@ def _read_dipole_text(fname): # There is a bug in older np.loadtxt regarding skipping fields, # so just read the data ourselves (need to get name and header anyway) data = list() - with open(fname, "r") as fid: + with open(fname) as fid: for line in fid: if not (line.startswith("%") or line.startswith("#")): need_header = False @@ -642,8 +642,8 @@ def _read_dipole_text(fname): data = np.atleast_2d(np.array(data, float)) if def_line is None: raise OSError( - "Dipole text file is missing field definition " - "comment, cannot parse %s" % (fname,) + "Dipole text file is missing field definition comment, cannot parse " + f"{fname}" ) # actually parse the fields def_line = def_line.lstrip("%").lstrip("#").strip() @@ -654,7 +654,9 @@ def _read_dipole_text(fname): def_line, ) fields = re.sub( - r"\((.*?)\)", lambda match: "/" + match.group(1), fields # "Q(nAm)", etc. + r"\((.*?)\)", + lambda match: "/" + match.group(1), + fields, # "Q(nAm)", etc. 
) fields = re.sub( "(begin|end) ", # "begin" and "end" with no units @@ -688,20 +690,20 @@ def _read_dipole_text(fname): missing_fields = sorted(set(required_fields) - set(fields)) if len(missing_fields) > 0: raise RuntimeError( - "Could not find necessary fields in header: %s" % (missing_fields,) + f"Could not find necessary fields in header: {missing_fields}" ) handled_fields = set(required_fields) | set(optional_fields) assert len(handled_fields) == len(required_fields) + len(optional_fields) ignored_fields = sorted(set(fields) - set(handled_fields) - {"end/ms"}) if len(ignored_fields) > 0: - warn("Ignoring extra fields in dipole file: %s" % (ignored_fields,)) + warn(f"Ignoring extra fields in dipole file: {ignored_fields}") if len(fields) != data.shape[1]: raise OSError( - "More data fields (%s) found than data columns (%s): %s" - % (len(fields), data.shape[1], fields) + f"More data fields ({len(fields)}) found than data columns ({data.shape[1]}" + f"): {fields}" ) - logger.info("%d dipole(s) found" % len(data)) + logger.info(f"{len(data)} dipole(s) found") if "end/ms" in fields: if np.diff( @@ -774,7 +776,7 @@ def _write_dipole_text(fname, dip): # NB CoordinateSystem is hard-coded as Head here with open(fname, "wb") as fid: - fid.write('# CoordinateSystem "Head"\n'.encode("utf-8")) + fid.write(b'# CoordinateSystem "Head"\n') fid.write((header + "\n").encode("utf-8")) np.savetxt(fid, out, fmt=fmt) if dip.name is not None: @@ -886,13 +888,15 @@ def _make_guesses(surf, grid, exclude, mindist, n_jobs=None, verbose=None): """Make a guess space inside a sphere or BEM surface.""" if "rr" in surf: logger.info( - "Guess surface (%s) is in %s coordinates" - % (_bem_surf_name[surf["id"]], _coord_frame_name(surf["coord_frame"])) + "Guess surface ({}) is in {} coordinates".format( + _bem_surf_name[surf["id"]], _coord_frame_name(surf["coord_frame"]) + ) ) else: logger.info( - "Making a spherical guess space with radius %7.1f mm..." - % (1000 * surf["R"]) + "Making a spherical guess space with radius {:7.1f} mm...".format( + 1000 * surf["R"] + ) ) logger.info("Filtering (grid = %6.f mm)..." % (1000 * grid)) src = _make_volume_source_space( @@ -1508,9 +1512,8 @@ def fit_dipole( r0 = apply_trans(mri_head_t["trans"], r0[np.newaxis, :])[0] inner_skull["r0"] = r0 logger.info( - "Head origin : " - "%6.1f %6.1f %6.1f mm rad = %6.1f mm." - % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], 1000 * R) + f"Head origin : {1000 * r0[0]:6.1f} {1000 * r0[1]:6.1f} " + f"{1000 * r0[2]:6.1f} mm rad = {1000 * R:6.1f} mm." 
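# The sphere radius logged just below ("max_rad") is simply the smallest
# origin-to-sensor distance in the device frame; schematically (hypothetical
# helper, positions assumed already transformed to the device frame):
#
#     import numpy as np
#
#     def max_rad(r0_dev, meg_locs_dev):
#         d = meg_locs_dev - r0_dev  # (n_sensors, 3) offsets
#         return np.sqrt((d * d).sum(axis=1)).min()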
) del R, r0 else: @@ -1522,22 +1525,20 @@ def fit_dipole( # Use the minimum distance to the MEG sensors as the radius then R = np.dot( np.linalg.inv(info["dev_head_t"]["trans"]), np.hstack([r0, [1.0]]) - )[ - :3 - ] # r0 -> device + )[:3] # r0 -> device R = R - [ info["chs"][pick]["loc"][:3] for pick in pick_types(info, meg=True, exclude=[]) ] if len(R) == 0: raise RuntimeError( - "No MEG channels found, but MEG-only " "sphere model used" + "No MEG channels found, but MEG-only sphere model used" ) R = np.min(np.sqrt(np.sum(R * R, axis=1))) # use dist to sensors kind = "max_rad" logger.info( - "Sphere model : origin at (% 7.2f % 7.2f % 7.2f) mm, " - "%s = %6.1f mm" % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], kind, R) + f"Sphere model : origin at ({1000 * r0[0]: 7.2f} {1000 * r0[1]: 7.2f} " + f"{1000 * r0[2]: 7.2f}) mm, {kind} = {R:6.1f} mm" ) inner_skull = dict(R=R, r0=r0) # NB sphere model defined in head frame del R, r0 @@ -1547,20 +1548,22 @@ def fit_dipole( fixed_position = True pos = np.array(pos, float) if pos.shape != (3,): - raise ValueError( - "pos must be None or a 3-element array-like," " got %s" % (pos,) - ) - logger.info("Fixed position : %6.1f %6.1f %6.1f mm" % tuple(1000 * pos)) + raise ValueError(f"pos must be None or a 3-element array-like, got {pos}") + logger.info( + "Fixed position : {:6.1f} {:6.1f} {:6.1f} mm".format(*tuple(1000 * pos)) + ) if ori is not None: ori = np.array(ori, float) if ori.shape != (3,): raise ValueError( - "oris must be None or a 3-element array-like," " got %s" % (ori,) + f"oris must be None or a 3-element array-like, got {ori}" ) norm = np.sqrt(np.sum(ori * ori)) if not np.isclose(norm, 1): - raise ValueError("ori must be a unit vector, got length %s" % (norm,)) - logger.info("Fixed orientation : %6.4f %6.4f %6.4f mm" % tuple(ori)) + raise ValueError(f"ori must be a unit vector, got length {norm}") + logger.info( + "Fixed orientation : {:6.4f} {:6.4f} {:6.4f} mm".format(*tuple(ori)) + ) else: logger.info("Free orientation : ") fit_n_jobs = 1 # only use 1 job to do the guess fitting @@ -1572,11 +1575,11 @@ def fit_dipole( guess_mindist = max(0.005, min_dist_to_inner_skull) guess_exclude = 0.02 - logger.info("Guess grid : %6.1f mm" % (1000 * guess_grid,)) + logger.info(f"Guess grid : {1000 * guess_grid:6.1f} mm") if guess_mindist > 0.0: - logger.info("Guess mindist : %6.1f mm" % (1000 * guess_mindist,)) + logger.info(f"Guess mindist : {1000 * guess_mindist:6.1f} mm") if guess_exclude > 0: - logger.info("Guess exclude : %6.1f mm" % (1000 * guess_exclude,)) + logger.info(f"Guess exclude : {1000 * guess_exclude:6.1f} mm") logger.info(f"Using {accuracy} MEG coil definitions.") fit_n_jobs = n_jobs cov = _ensure_cov(cov) @@ -1584,7 +1587,7 @@ def fit_dipole( _print_coord_trans(mri_head_t) _print_coord_trans(info["dev_head_t"]) - logger.info("%d bad channels total" % len(info["bads"])) + logger.info(f"{len(info['bads'])} bad channels total") # Forward model setup (setup_forward_model from setup.c) ch_types = evoked.get_channel_types() @@ -1645,8 +1648,8 @@ def fit_dipole( ) if check <= 0: raise ValueError( - "fixed position is %0.1fmm outside the inner " - "skull boundary" % (-1000 * check,) + f"fixed position is {-1000 * check:0.1f}mm outside the inner skull " + "boundary" ) # C code computes guesses w/sphere model for speed, don't bother here diff --git a/mne/epochs.py b/mne/epochs.py index 864f4021b42..9e48936f8bf 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -16,6 +16,7 @@ from collections import Counter from copy import deepcopy from functools 
import partial +from inspect import getfullargspec import numpy as np from scipy.interpolate import interp1d @@ -39,7 +40,7 @@ pick_info, ) from ._fiff.proj import ProjMixin, setup_proj -from ._fiff.tag import read_tag, read_tag_info +from ._fiff.tag import _read_tag_header, read_tag from ._fiff.tree import dir_tree_find from ._fiff.utils import _make_split_fnames from ._fiff.write import ( @@ -62,6 +63,7 @@ EpochAnnotationsMixin, _read_annotations_fif, _write_annotations, + events_from_annotations, ) from .baseline import _check_baseline, _log_rescale, rescale from .bem import _check_origin @@ -73,6 +75,7 @@ from .html_templates import _get_html_template from .parallel import parallel_func from .time_frequency.spectrum import EpochsSpectrum, SpectrumMixin, _validate_method +from .time_frequency.tfr import AverageTFR, EpochsTFR from .utils import ( ExtendedTimeMixin, GetEpochsMixin, @@ -416,6 +419,8 @@ class BaseEpochs( filename : str | None The filename (if the epochs are read from disk). %(metadata_epochs)s + + .. versionadded:: 0.16 %(event_repeated_epochs)s %(raw_sfreq)s annotations : instance of mne.Annotations | None @@ -465,7 +470,7 @@ def __init__( raw_sfreq=None, annotations=None, verbose=None, - ): # noqa: D102 + ): if events is not None: # RtEpochs can have events=None events = _ensure_events(events) # Allow reading empty epochs (ToDo: Maybe not anymore in the future) @@ -487,10 +492,7 @@ def __init__( if events is not None: # RtEpochs can have events=None for key, val in self.event_id.items(): if val not in events[:, 2]: - msg = "No matching events found for %s " "(event id %i)" % ( - key, - val, - ) + msg = f"No matching events found for {key} (event id {val})" _on_missing(on_missing, msg) # ensure metadata matches original events size @@ -510,8 +512,8 @@ def __init__( selection = np.array(selection, int) if selection.shape != (len(selected),): raise ValueError( - "selection must be shape %s got shape %s" - % (selected.shape, selection.shape) + f"selection must be shape {selected.shape} got shape " + f"{selection.shape}" ) self.selection = selection if drop_log is None: @@ -667,7 +669,7 @@ def __init__( # do the rest valid_proj = [True, "delayed", False] if proj not in valid_proj: - raise ValueError('"proj" must be one of %s, not %s' % (valid_proj, proj)) + raise ValueError(f'"proj" must be one of {valid_proj}, not {proj}') if proj == "delayed": self._do_delayed_proj = True logger.info("Entering delayed SSP mode.") @@ -698,7 +700,7 @@ def _check_consistency(self): if hasattr(self, "events"): assert len(self.selection) == len(self.events) assert len(self.drop_log) >= len(self.events) - assert len(self.selection) == sum((len(dl) == 0 for dl in self.drop_log)) + assert len(self.selection) == sum(len(dl) == 0 for dl in self.drop_log) assert hasattr(self, "_times_readonly") assert not self.times.flags["WRITEABLE"] assert isinstance(self.drop_log, tuple) @@ -789,7 +791,7 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): self.baseline = baseline return self - def _reject_setup(self, reject, flat): + def _reject_setup(self, reject, flat, *, allow_callable=False): """Set self._reject_time and self._channel_type_idx.""" idx = channel_indices_by_type(self.info) reject = deepcopy(reject) if reject is not None else dict() @@ -801,7 +803,7 @@ def _reject_setup(self, reject, flat): ) bads = set(rej.keys()) - set(idx.keys()) if len(bads) > 0: - raise KeyError("Unknown channel types found in %s: %s" % (kind, bads)) + raise KeyError(f"Unknown channel types found in {kind}: 
{bads}") for key in idx.keys(): # don't throw an error if rejection/flat would do nothing @@ -812,17 +814,25 @@ def _reject_setup(self, reject, flat): # self.allow_missing_reject_keys check to allow users to # provide keys that don't exist in data raise ValueError( - "No %s channel found. Cannot reject based on " - "%s." % (key.upper(), key.upper()) + f"No {key.upper()} channel found. Cannot reject based on " + f"{key.upper()}." ) - # check for invalid values - for rej, kind in zip((reject, flat), ("Rejection", "Flat")): - for key, val in rej.items(): - if val is None or val < 0: - raise ValueError( - '%s value must be a number >= 0, not "%s"' % (kind, val) - ) + # check for invalid values + for rej, kind in zip((reject, flat), ("Rejection", "Flat")): + for key, val in rej.items(): + name = f"{kind} dict value for {key}" + if callable(val) and allow_callable: + continue + extra_str = "" + if allow_callable: + extra_str = "or callable" + _validate_type(val, "numeric", name, extra=extra_str) + if val is None or val < 0: + raise ValueError( + f"If using numerical {name} criteria, the value " + f"must be >= 0, not {repr(val)}" + ) # now check to see if our rejection and flat are getting more # restrictive @@ -840,6 +850,9 @@ def _reject_setup(self, reject, flat): reject[key] = old_reject[key] # make sure new thresholds are at least as stringent as the old ones for key in reject: + # Skip this check if old_reject and reject are callables + if callable(reject[key]) and allow_callable: + continue if key in old_reject and reject[key] > old_reject[key]: raise ValueError( bad_msg.format( @@ -855,6 +868,8 @@ def _reject_setup(self, reject, flat): for key in set(old_flat) - set(flat): flat[key] = old_flat[key] for key in flat: + if callable(flat[key]) and allow_callable: + continue if key in old_flat and flat[key] < old_flat[key]: raise ValueError( bad_msg.format( @@ -1150,8 +1165,8 @@ def _compute_aggregate(self, picks, mode="mean"): assert len(self.events) == len(self._data) if data.shape != self._data.shape[1:]: raise RuntimeError( - "You passed a function that resulted n data of shape {}, " - "but it should be {}.".format(data.shape, self._data.shape[1:]) + f"You passed a function that resulted n data of shape " + f"{data.shape}, but it should be {self._data.shape[1:]}." ) else: if mode not in {"mean", "std"}: @@ -1397,7 +1412,7 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): Dropping bad epochs can be done multiple times with different ``reject`` and ``flat`` parameters. However, once an epoch is dropped, it is dropped forever, so if more lenient thresholds may - subsequently be applied, `epochs.copy ` should be + subsequently be applied, :meth:`epochs.copy ` should be used. """ if reject == "existing": @@ -1408,7 +1423,7 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): flat = self.flat if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)): raise ValueError('reject and flat, if strings, must be "existing"') - self._reject_setup(reject, flat) + self._reject_setup(reject, flat, allow_callable=True) self._get_data(out=False, verbose=verbose) return self @@ -1524,8 +1539,9 @@ def drop(self, indices, reason="USER", verbose=None): Set epochs to remove by specifying indices to remove or a boolean mask to apply (where True values get removed). Events are correspondingly modified. - reason : str - Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc). 
+ reason : list | tuple | str + Reason(s) for dropping the epochs ('ECG', 'timeout', 'blink' etc). + Reason(s) are applied to all indices specified. Default: 'USER'. %(verbose)s @@ -1537,9 +1553,11 @@ def drop(self, indices, reason="USER", verbose=None): indices = np.atleast_1d(indices) if indices.ndim > 1: - raise ValueError("indices must be a scalar or a 1-d array") + raise TypeError("indices must be a scalar or a 1-d array") + # Check if indices and reasons are of the same length + # if using collection to drop epochs - if indices.dtype == bool: + if indices.dtype == np.dtype(bool): indices = np.where(indices)[0] try_idx = np.where(indices < 0, indices + len(self.events), indices) @@ -1958,22 +1976,52 @@ def apply_function( if dtype is not None and dtype != self._data.dtype: self._data = self._data.astype(dtype) + args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs + if channel_wise is False: + if ("ch_idx" in args) or ("ch_name" in args): + raise ValueError( + "apply_function cannot access ch_idx or ch_name " + "when channel_wise=False" + ) + if "ch_idx" in args: + logger.info("apply_function requested to access ch_idx") + if "ch_name" in args: + logger.info("apply_function requested to access ch_name") + if channel_wise: parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs) if n_jobs == 1: - _fun = partial(_check_fun, fun, **kwargs) + _fun = partial(_check_fun, fun) # modify data inplace to save memory - for idx in picks: - self._data[:, idx, :] = np.apply_along_axis( - _fun, -1, data_in[:, idx, :] + for ch_idx in picks: + if "ch_idx" in args: + kwargs.update(ch_idx=ch_idx) + if "ch_name" in args: + kwargs.update(ch_name=self.info["ch_names"][ch_idx]) + self._data[:, ch_idx, :] = np.apply_along_axis( + _fun, -1, data_in[:, ch_idx, :], **kwargs ) else: # use parallel function + _fun = partial(np.apply_along_axis, fun, -1) data_picks_new = parallel( - p_fun(fun, data_in[:, p, :], **kwargs) for p in picks + p_fun( + _fun, + data_in[:, ch_idx, :], + **kwargs, + **{ + k: v + for k, v in [ + ("ch_name", self.info["ch_names"][ch_idx]), + ("ch_idx", ch_idx), + ] + if k in args + }, + ) + for ch_idx in picks ) - for pp, p in enumerate(picks): - self._data[:, p, :] = data_picks_new[pp] + for run_idx, ch_idx in enumerate(picks): + self._data[:, ch_idx, :] = data_picks_new[run_idx] else: self._data = _check_fun(fun, data_in, **kwargs) @@ -1986,9 +2034,9 @@ def filename(self): def __repr__(self): """Build string representation.""" - s = " %s events " % len(self.events) + s = f" {len(self.events)} events " s += "(all good)" if self._bad_dropped else "(good & bad)" - s += ", %g – %g s" % (self.tmin, self.tmax) + s += f", {self.tmin:g} – {self.tmax:g} s" s += ", baseline " if self.baseline is None: s += "off" @@ -2002,12 +2050,12 @@ def __repr__(self): ): s += " (baseline period was cropped after baseline correction)" - s += ", ~%s" % (sizeof_fmt(self._size),) - s += ", data%s loaded" % ("" if self.preload else " not") + s += f", ~{sizeof_fmt(self._size)}" + s += f", data{'' if self.preload else ' not'} loaded" s += ", with metadata" if self.metadata is not None else "" max_events = 10 counts = [ - "%r: %i" % (k, sum(self.events[:, 2] == v)) + f"{k!r}: {sum(self.events[:, 2] == v)}" for k, v in list(self.event_id.items())[:max_events] ] if len(self.event_id) > 0: @@ -2017,7 +2065,7 @@ def __repr__(self): s += f"\n and {not_shown_events} more events ..." 
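Because `apply_function` now inspects the function signature with `getfullargspec`, a transform can declare `ch_idx` and/or `ch_name` parameters and have them injected per channel. A small sketch with a hypothetical channel name, assuming `epochs` is preloaded:

```python
def scale_cz(x, *, ch_name):
    """Double one named channel, pass the rest through unchanged."""
    return x * 2.0 if ch_name == "Cz" else x

# ch_name appears in the signature, so it is filled in automatically
epochs.apply_function(scale_cz, channel_wise=True)
```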
class_name = self.__class__.__name__ class_name = "Epochs" if class_name == "BaseEpochs" else class_name - return "<%s | %s>" % (class_name, s) + return f"<{class_name} | {s}>" @repr_html def _repr_html_(self): @@ -2166,7 +2214,14 @@ def save( ) # check for file existence and expand `~` if present - fname = str(_check_fname(fname=fname, overwrite=overwrite)) + fname = str( + _check_fname( + fname=fname, + overwrite=overwrite, + check_bids_split=True, + name="fname", + ) + ) split_size_bytes = _get_split_size(split_size) @@ -2402,7 +2457,7 @@ def equalize_event_counts(self, event_ids=None, method="mintime"): # 3. do this for every input event_ids = [ [ - k for k in ids if all((tag in k.split("/") for tag in id_)) + k for k in ids if all(tag in k.split("/") for tag in id_) ] # ids matching all tags if all(id__ not in ids for id__ in id_) else id_ # straight pass for non-tag inputs @@ -2411,7 +2466,7 @@ def equalize_event_counts(self, event_ids=None, method="mintime"): for ii, id_ in enumerate(event_ids): if len(id_) == 0: raise KeyError( - f"{orig_ids[ii]} not found in the epoch " "object's event_id." + f"{orig_ids[ii]} not found in the epoch object's event_id." ) elif len({sub_id in ids for sub_id in id_}) != 1: err = ( @@ -2434,8 +2489,8 @@ def equalize_event_counts(self, event_ids=None, method="mintime"): for eq in event_ids: eq_inds.append(self._keys_to_idx(eq)) - event_times = [self.events[e, 0] for e in eq_inds] - indices = _get_drop_indices(event_times, method) + sample_nums = [self.events[e, 0] for e in eq_inds] + indices = _get_drop_indices(sample_nums, method) # need to re-index indices indices = np.concatenate([e[idx] for e, idx in zip(eq_inds, indices)]) self.drop(indices, reason="EQUALIZED_COUNT") @@ -2507,6 +2562,139 @@ def compute_psd( **method_kw, ) + @verbose + def compute_tfr( + self, + method, + freqs, + *, + tmin=None, + tmax=None, + picks=None, + proj=False, + output="power", + average=False, + return_itc=False, + decim=1, + n_jobs=None, + verbose=None, + **method_kw, + ): + """Compute a time-frequency representation of epoched data. + + Parameters + ---------- + %(method_tfr_epochs)s + %(freqs_tfr_epochs)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(output_compute_tfr)s + average : bool + Whether to return average power across epochs (instead of single-trial + power). ``average=True`` is not compatible with ``output="complex"`` or + ``output="phase"``. Ignored if ``method="stockwell"`` (Stockwell method + *requires* averaging). Default is ``False``. + return_itc : bool + Whether to return inter-trial coherence (ITC) as well as power estimates. + If ``True`` then must specify ``average=True`` (or ``method="stockwell", + average="auto"``). Default is ``False``. + %(decim_tfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_epochs_tfr)s + + Returns + ------- + tfr : instance of EpochsTFR or AverageTFR + The time-frequency-resolved power estimates. + itc : instance of AverageTFR + The inter-trial coherence (ITC). Only returned if ``return_itc=True``. + + Notes + ----- + If ``average=True`` (or ``method="stockwell", average="auto"``) the result will + be an :class:`~mne.time_frequency.AverageTFR` instead of an + :class:`~mne.time_frequency.EpochsTFR`. + + .. versionadded:: 1.7 + + References + ---------- + .. footbibliography:: + """ + if method == "stockwell" and not average: # stockwell method *must* average + logger.info( + 'Requested `method="stockwell"` so ignoring parameter `average=False`.' 
+ ) + average = True + if average: + # augment `output` value for use by tfr_array_* functions + _check_option("output", output, ("power",), extra=" when average=True") + method_kw["output"] = "avg_power_itc" if return_itc else "avg_power" + else: + msg = ( + "compute_tfr() got incompatible parameters `average=False` and `{}` " + "({} requires averaging over epochs)." + ) + if return_itc: + raise ValueError(msg.format("return_itc=True", "computing ITC")) + if method == "stockwell": + raise ValueError(msg.format('method="stockwell"', "Stockwell method")) + # `average` and `return_itc` both False, so "phase" and "complex" are OK + _check_option("output", output, ("power", "phase", "complex")) + method_kw["output"] = output + + if method == "stockwell": + method_kw["return_itc"] = return_itc + method_kw.pop("output") + if isinstance(freqs, str): + _check_option("freqs", freqs, "auto") + else: + _validate_type(freqs, "array-like") + _check_option( + "freqs", np.array(freqs).shape, ((2,),), extra=" (wrong shape)." + ) + if average: + out = AverageTFR( + inst=self, + method=method, + freqs=freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + # tfr_array_stockwell always returns ITC (but sometimes it's None) + if hasattr(out, "_itc"): + if out._itc is not None: + state = out.__getstate__() + state["data"] = out._itc + state["data_type"] = "Inter-trial coherence" + itc = AverageTFR(inst=state) + del out._itc + return out, itc + del out._itc + return out + # now handle average=False + return EpochsTFR( + inst=self, + method=method, + freqs=freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + @verbose def plot_psd( self, @@ -2661,7 +2849,7 @@ def to_data_frame( # prepare extra columns / multiindex mindex = list() times = np.tile(times, n_epochs) - times = _convert_times(self, times, time_format) + times = _convert_times(times, time_format, self.info["meas_date"]) mindex.append(("time", times)) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] @@ -2776,14 +2964,15 @@ def make_metadata( A mapping from event names (keys) to event IDs (values). The event names will be incorporated as columns of the returned metadata :class:`~pandas.DataFrame`. - tmin, tmax : float | None - Start and end of the time interval for metadata generation in seconds, relative - to the time-locked event of the respective time window (the "row events"). + tmin, tmax : float | str | list of str | None + If float, start and end of the time interval for metadata generation in seconds, + relative to the time-locked event of the respective time window (the "row + events"). .. note:: If you are planning to attach the generated metadata to `~mne.Epochs` and intend to include only events that fall inside - your epochs time interval, pass the same ``tmin`` and ``tmax`` + your epoch's time interval, pass the same ``tmin`` and ``tmax`` values here as you use for your epochs. If ``None``, the time window used for metadata generation is bounded by the @@ -2796,8 +2985,17 @@ def make_metadata( the first row event. If ``tmax=None``, the last time window for metadata generation ends with the last event in ``events``. + If a string or a list of strings, the events bounding the metadata around each + "row event". 
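The `compute_tfr` method assembled above is the object-oriented entry point to the `EpochsTFR`/`AverageTFR` machinery; extra arguments such as `n_cycles` travel through `**method_kw` to the chosen backend. A sketch assuming `epochs` exists:

```python
import numpy as np

freqs = np.arange(4.0, 40.0, 2.0)

# Single-trial power -> EpochsTFR
power = epochs.compute_tfr("morlet", freqs=freqs, n_cycles=freqs / 2.0, decim=2)

# Averaged power plus inter-trial coherence -> two AverageTFR objects
avg_power, itc = epochs.compute_tfr(
    "morlet", freqs=freqs, n_cycles=freqs / 2.0, average=True, return_itc=True
)
```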
For ``tmin``, the events are assumed to occur **before** the row + event, and for ``tmax``, the events are assumed to occur **after** – unless + ``tmin`` or ``tmax`` are equal to a row event, in which case the row event + serves as the bound. + .. versionchanged:: 1.6.0 Added support for ``None``. + + .. versionadded:: 1.7.0 + Added support for strings. sfreq : float The sampling frequency of the data from which the events array was extracted. @@ -2883,8 +3081,8 @@ def make_metadata( be attached; it may well be much shorter or longer, or not overlap at all, if desired. This can be useful, for example, to include events that occurred before or after an epoch, e.g. during the inter-trial interval. - If either ``tmin``, ``tmax``, or both are ``None``, the time window will - typically vary, too. + If either ``tmin``, ``tmax``, or both are ``None``, or a string referring e.g. to a + response event, the time window will typically vary, too. .. versionadded:: 0.23 @@ -2897,11 +3095,11 @@ def make_metadata( _validate_type(events, types=("array-like",), item_name="events") _validate_type(event_id, types=(dict,), item_name="event_id") _validate_type(sfreq, types=("numeric",), item_name="sfreq") - _validate_type(tmin, types=("numeric", None), item_name="tmin") - _validate_type(tmax, types=("numeric", None), item_name="tmax") - _validate_type(row_events, types=(None, str, list, tuple), item_name="row_events") - _validate_type(keep_first, types=(None, str, list, tuple), item_name="keep_first") - _validate_type(keep_last, types=(None, str, list, tuple), item_name="keep_last") + _validate_type(tmin, types=("numeric", str, "array-like", None), item_name="tmin") + _validate_type(tmax, types=("numeric", str, "array-like", None), item_name="tmax") + _validate_type(row_events, types=(None, str, "array-like"), item_name="row_events") + _validate_type(keep_first, types=(None, str, "array-like"), item_name="keep_first") + _validate_type(keep_last, types=(None, str, "array-like"), item_name="keep_last") if not event_id: raise ValueError("event_id dictionary must contain at least one entry") @@ -2918,6 +3116,19 @@ def _ensure_list(x): keep_first = _ensure_list(keep_first) keep_last = _ensure_list(keep_last) + # Turn tmin, tmax into a list if they're strings or arrays of strings + try: + _validate_type(tmin, types=(str, "array-like"), item_name="tmin") + tmin = _ensure_list(tmin) + except TypeError: + pass + + try: + _validate_type(tmax, types=(str, "array-like"), item_name="tmax") + tmax = _ensure_list(tmax) + except TypeError: + pass + keep_first_and_last = set(keep_first) & set(keep_last) if keep_first_and_last: raise ValueError( @@ -2937,18 +3148,40 @@ def _ensure_list(x): f"{param_name}, cannot be found in event_id dictionary" ) - event_name_diff = sorted(set(row_events) - set(event_id.keys())) - if event_name_diff: - raise ValueError( - f"Present in row_events, but missing from event_id: " - f'{", ".join(event_name_diff)}' + # If tmin, tmax are strings, ensure these event names are present in event_id + def _diff_input_strings_vs_event_id(input_strings, input_name, event_id): + event_name_diff = sorted(set(input_strings) - set(event_id.keys())) + if event_name_diff: + raise ValueError( + f"Present in {input_name}, but missing from event_id: " + f'{", ".join(event_name_diff)}' + ) + + _diff_input_strings_vs_event_id( + input_strings=row_events, input_name="row_events", event_id=event_id + ) + if isinstance(tmin, list): + _diff_input_strings_vs_event_id( + input_strings=tmin, input_name="tmin", 
event_id=event_id
+        )
+    if isinstance(tmax, list):
+        _diff_input_strings_vs_event_id(
+            input_strings=tmax, input_name="tmax", event_id=event_id
         )
-    del event_name_diff
 
     # First and last sample of each epoch, relative to the time-locked event
     # This follows the approach taken in mne.Epochs
-    start_sample = None if tmin is None else int(round(tmin * sfreq))
-    stop_sample = None if tmax is None else int(round(tmax * sfreq)) + 1
+    # For strings and None, we don't know the start and stop samples in advance as the
+    # time window can vary.
+    if isinstance(tmin, (type(None), list)):
+        start_sample = None
+    else:
+        start_sample = int(round(tmin * sfreq))
+
+    if isinstance(tmax, (type(None), list)):
+        stop_sample = None
+    else:
+        stop_sample = int(round(tmax * sfreq)) + 1
 
     # Make indexing easier
     # We create the DataFrame before subsetting the events so we end up with
@@ -2979,11 +3212,11 @@ def _ensure_list(x):
         *last_cols,
     ]
 
-    data = np.empty((len(events_df), len(columns)))
+    data = np.empty((len(events_df), len(columns)), float)
     metadata = pd.DataFrame(data=data, columns=columns, index=events_df.index)
 
     # Event names
-    metadata.iloc[:, 0] = ""
+    metadata["event_name"] = ""
 
     # Event times
     start_idx = 1
@@ -2992,7 +3225,7 @@ def _ensure_list(x):
 
     # keep_first and keep_last names
     start_idx = stop_idx
-    metadata.iloc[:, start_idx:] = None
+    metadata[columns[start_idx:]] = ""
 
     # We're all set, let's iterate over all events and fill in in the
     # respective cells in the metadata. We will subset this to include only
@@ -3002,14 +3235,47 @@ def _ensure_list(x):
         metadata.loc[row_idx, "event_name"] = id_to_name_map[row_event.id]
 
         # Determine which events fall into the current time window
-        if start_sample is None:
+        if start_sample is None and isinstance(tmin, list):
+            # Lower bound is the current or the closest previous event with a name
+            # in "tmin"; if there is no such event (e.g., the beginning of the
+            # recording is being approached), the lower bound becomes the row event
+            # itself.
+            prev_matching_events = events_df.loc[
+                (events_df["sample"] <= row_event.sample)
+                & (events_df["id"].isin([event_id[name] for name in tmin])),
+                :,
+            ]
+            if prev_matching_events.size == 0:
+                # No earlier matching event. Use the current one as the beginning of the
+                # time window. This may occur at the beginning of a recording.
+                window_start_sample = row_event.sample
+            else:
+                # At least one earlier matching event. Use the closest one.
+                window_start_sample = prev_matching_events.iloc[-1]["sample"]
+        elif start_sample is None:
             # Lower bound is the current event.
             window_start_sample = row_event.sample
         else:
             # Lower bound is determined by tmin.
             window_start_sample = row_event.sample + start_sample
 
-        if stop_sample is None:
+        if stop_sample is None and isinstance(tmax, list):
+            # Upper bound is the current or the closest following event with a name
+            # in "tmax"; if there is no such event (e.g., the end of the recording is
+            # being approached), the upper bound becomes the last event in the
+            # recording.
+            next_matching_events = events_df.loc[
+                (events_df["sample"] >= row_event.sample)
+                & (events_df["id"].isin([event_id[name] for name in tmax])),
+                :,
+            ]
+            if next_matching_events.size == 0:
+                # No matching event after the current one; use the end of the recording
+                # as upper bound. This may occur at the end of a recording.
+                window_stop_sample = events_df["sample"].iloc[-1]
+            else:
+                # At least one matching later event. Use the closest one.
+ window_stop_sample = next_matching_events.iloc[0]["sample"] + elif stop_sample is None: # Upper bound: next event of the same type, or the last event (of # any type) if no later event of the same type can be found. next_events = events_df.loc[ @@ -3104,6 +3370,40 @@ def _ensure_list(x): return metadata, events, event_id +def _events_from_annotations(raw, events, event_id, annotations, on_missing): + """Generate events and event_ids from annotations.""" + events, event_id_tmp = events_from_annotations(raw) + if events.size == 0: + raise RuntimeError( + "No usable annotations found in the raw object. " + "Either `events` must be provided or the raw " + "object must have annotations to construct epochs" + ) + if any(raw.annotations.duration > 0): + logger.info( + "Ignoring annotation durations and creating fixed-duration epochs " + "around annotation onsets." + ) + if event_id is None: + event_id = event_id_tmp + # if event_id is the names of events, map to events integers + if isinstance(event_id, str): + event_id = [event_id] + if isinstance(event_id, (list, tuple, set)): + if not set(event_id).issubset(set(event_id_tmp)): + msg = ( + "No matching annotations found for event_id(s) " + f"{set(event_id) - set(event_id_tmp)}" + ) + _on_missing(on_missing, msg) + # remove extras if on_missing not error + event_id = set(event_id) & set(event_id_tmp) + event_id = {my_id: event_id_tmp[my_id] for my_id in event_id} + # remove any non-selected annotations + annotations.delete(~np.isin(raw.annotations.description, list(event_id))) + return events, event_id, annotations + + @fill_doc class Epochs(BaseEpochs): """Epochs extracted from a Raw instance. @@ -3111,7 +3411,16 @@ class Epochs(BaseEpochs): Parameters ---------- %(raw_epochs)s + + .. note:: + If ``raw`` contains annotations, ``Epochs`` can be constructed around + ``raw.annotations.onset``, but note that the durations of the annotations + are ignored in this case. %(events_epochs)s + + .. versionchanged:: 1.7 + Allow ``events=None`` to use ``raw.annotations.onset`` as the source of + epoch times. %(event_id)s %(epochs_tmin_tmax)s %(baseline_epochs)s @@ -3129,20 +3438,18 @@ class Epochs(BaseEpochs): %(on_missing_epochs)s %(reject_by_annotation_epochs)s %(metadata_epochs)s + + .. versionadded:: 0.16 %(event_repeated_epochs)s %(verbose)s Attributes ---------- %(info_not_none)s - event_id : dict - Names of conditions corresponding to event_ids. + %(event_id_attr)s ch_names : list of string List of channel names. - selection : array - List of indices of selected events (not dropped or ignored etc.). For - example, if the original event array had 4 events and the second event - has been dropped, this attribute would be np.array([0, 2, 3]). + %(selection_attr)s preload : bool Indicates whether epochs are in memory. drop_log : tuple of tuple @@ -3160,6 +3467,10 @@ class Epochs(BaseEpochs): See :meth:`~mne.Epochs.equalize_event_counts` - 'USER' For user-defined reasons (see :meth:`~mne.Epochs.drop`). + + When dropping based on flat or reject parameters the tuple of + reasons contains a tuple of channels that satisfied the rejection + criteria. filename : str The filename of the object. 
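To make the string-valued `tmin`/`tmax` behaviour above concrete, here is a hypothetical stimulus-response sketch (event names and samples invented for illustration): each metadata window runs from the row event to the next matching event rather than over fixed times:

```python
import numpy as np
import mne

events = np.array(
    [[1000, 0, 1], [1500, 0, 2], [3000, 0, 1], [3600, 0, 2]]  # sample, _, id
)
event_id = {"stimulus": 1, "response": 2}

metadata, events_out, event_id_out = mne.epochs.make_metadata(
    events=events,
    event_id=event_id,
    tmin="stimulus",   # window starts at the row event itself
    tmax="response",   # ... and ends at the following response
    sfreq=1000.0,
    row_events="stimulus",
)
print(metadata[["event_name", "response"]])  # response latencies: 0.5 s, 0.6 s
```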
times : ndarray @@ -3212,7 +3523,7 @@ class Epochs(BaseEpochs): def __init__( self, raw, - events, + events=None, event_id=None, tmin=-0.2, tmax=0.5, @@ -3231,7 +3542,7 @@ def __init__( metadata=None, event_repeated="error", verbose=None, - ): # noqa: D102 + ): from .io import BaseRaw if not isinstance(raw, BaseRaw): @@ -3240,6 +3551,7 @@ def __init__( "instance of mne.io.BaseRaw" ) info = deepcopy(raw.info) + annotations = raw.annotations.copy() # proj is on when applied in Raw proj = proj or raw.proj @@ -3249,8 +3561,14 @@ def __init__( # keep track of original sfreq (needed for annotations) raw_sfreq = raw.info["sfreq"] + # get events from annotations if no events given + if events is None: + events, event_id, annotations = _events_from_annotations( + raw, events, event_id, annotations, on_missing + ) + # call BaseEpochs constructor - super(Epochs, self).__init__( + super().__init__( info, None, events, @@ -3273,7 +3591,7 @@ def __init__( event_repeated=event_repeated, verbose=verbose, raw_sfreq=raw_sfreq, - annotations=raw.annotations, + annotations=annotations, ) @verbose @@ -3350,6 +3668,8 @@ class EpochsArray(BaseEpochs): %(proj_epochs)s %(on_missing_epochs)s %(metadata_epochs)s + + .. versionadded:: 0.16 %(selection)s %(drop_log)s @@ -3403,23 +3723,23 @@ def __init__( drop_log=None, raw_sfreq=None, verbose=None, - ): # noqa: D102 + ): dtype = np.complex128 if np.any(np.iscomplex(data)) else np.float64 data = np.asanyarray(data, dtype=dtype) if data.ndim != 3: raise ValueError( - "Data must be a 3D array of shape (n_epochs, " "n_channels, n_samples)" + "Data must be a 3D array of shape (n_epochs, n_channels, n_samples)" ) if len(info["ch_names"]) != data.shape[1]: - raise ValueError("Info and data must have same number of " "channels.") + raise ValueError("Info and data must have same number of channels.") if events is None: n_epochs = len(data) events = _gen_events(n_epochs) info = info.copy() # do not modify original info tmax = (data.shape[2] - 1) / info["sfreq"] + tmin - super(EpochsArray, self).__init__( + super().__init__( info, data, events, @@ -3517,7 +3837,7 @@ def combine_event_ids(epochs, old_event_ids, new_event_id, copy=True): def equalize_epoch_counts(epochs_list, method="mintime"): - """Equalize the number of trials in multiple Epoch instances. + """Equalize the number of trials in multiple Epochs or EpochsTFR instances. 
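With `events=None` now wired through `_events_from_annotations`, annotation onsets alone can drive epoching. A minimal runnable sketch with synthetic data:

```python
import numpy as np
import mne

info = mne.create_info(["EEG1"], sfreq=100.0, ch_types="eeg")
raw = mne.io.RawArray(np.zeros((1, 1000)), info)
raw.set_annotations(
    mne.Annotations(onset=[2.0, 5.0], duration=0.0, description=["go", "stop"])
)

# No events array: onsets come from raw.annotations (durations are ignored),
# and event_id can subset the annotation descriptions.
epochs = mne.Epochs(raw, event_id=["go"], tmin=-0.2, tmax=0.5, baseline=None)
```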
Parameters ---------- @@ -3544,33 +3864,32 @@ def equalize_epoch_counts(epochs_list, method="mintime"): -------- >>> equalize_epoch_counts([epochs1, epochs2]) # doctest: +SKIP """ - if not all(isinstance(e, BaseEpochs) for e in epochs_list): + if not all(isinstance(epoch, (BaseEpochs, EpochsTFR)) for epoch in epochs_list): raise ValueError("All inputs must be Epochs instances") # make sure bad epochs are dropped - for e in epochs_list: - if not e._bad_dropped: - e.drop_bad() - event_times = [e.events[:, 0] for e in epochs_list] - indices = _get_drop_indices(event_times, method) - for e, inds in zip(epochs_list, indices): - e.drop(inds, reason="EQUALIZED_COUNT") + for epoch in epochs_list: + if not epoch._bad_dropped: + epoch.drop_bad() + sample_nums = [epoch.events[:, 0] for epoch in epochs_list] + indices = _get_drop_indices(sample_nums, method) + for epoch, inds in zip(epochs_list, indices): + epoch.drop(inds, reason="EQUALIZED_COUNT") -def _get_drop_indices(event_times, method): +def _get_drop_indices(sample_nums, method): """Get indices to drop from multiple event timing lists.""" - small_idx = np.argmin([e.shape[0] for e in event_times]) - small_e_times = event_times[small_idx] + small_idx = np.argmin([e.shape[0] for e in sample_nums]) + small_epoch_indices = sample_nums[small_idx] _check_option("method", method, ["mintime", "truncate"]) indices = list() - for e in event_times: + for event in sample_nums: if method == "mintime": - mask = _minimize_time_diff(small_e_times, e) + mask = _minimize_time_diff(small_epoch_indices, event) else: - mask = np.ones(e.shape[0], dtype=bool) - mask[small_e_times.shape[0] :] = False + mask = np.ones(event.shape[0], dtype=bool) + mask[small_epoch_indices.shape[0] :] = False indices.append(np.where(np.logical_not(mask))[0]) - return indices @@ -3584,7 +3903,7 @@ def _minimize_time_diff(t_shorter, t_longer): idx = np.argmin(np.abs(t_longer - t_shorter)) keep[idx] = True return keep - scores = np.ones((len(t_longer))) + scores = np.ones(len(t_longer)) x1 = np.arange(len(t_shorter)) # The first set of keep masks to test kwargs = dict(copy=False, bounds_error=False, assume_sorted=True) @@ -3616,11 +3935,13 @@ def _is_good( reject, flat, full_report=False, - ignore_chs=[], + ignore_chs=(), verbose=None, ): """Test if data segment e is good according to reject and flat. + The reject and flat parameters can accept functions as values. + If full_report=True, it will give True/False as well as a list of all offending channels. 
""" @@ -3628,31 +3949,60 @@ def _is_good( has_printed = False checkable = np.ones(len(ch_names), dtype=bool) checkable[np.array([c in ignore_chs for c in ch_names], dtype=bool)] = False + for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: - for key, thresh in refl.items(): + for key, refl in refl.items(): + criterion = refl idx = channel_type_idx[key] name = key.upper() if len(idx) > 0: e_idx = e[idx] - deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) checkable_idx = checkable[idx] - idx_deltas = np.where( - np.logical_and(f(deltas, thresh), checkable_idx) - )[0] + # Check if criterion is a function and apply it + if callable(criterion): + result = criterion(e_idx) + _validate_type(result, tuple, "reject/flat output") + if len(result) != 2: + raise TypeError( + "Function criterion must return a tuple of length 2" + ) + cri_truth, reasons = result + _validate_type(cri_truth, (bool, np.bool_), cri_truth, "bool") + _validate_type( + reasons, (str, list, tuple), reasons, "str, list, or tuple" + ) + idx_deltas = np.where(np.logical_and(cri_truth, checkable_idx))[ + 0 + ] + else: + deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) + idx_deltas = np.where( + np.logical_and(f(deltas, criterion), checkable_idx) + )[0] if len(idx_deltas) > 0: - bad_names = [ch_names[idx[i]] for i in idx_deltas] - if not has_printed: - logger.info( - " Rejecting %s epoch based on %s : " - "%s" % (t, name, bad_names) - ) - has_printed = True - if not full_report: - return False + # Check to verify that refl is a callable that returns + # (bool, reason). Reason must be a str/list/tuple. + # If using tuple + if callable(refl): + if isinstance(reasons, str): + reasons = (reasons,) + for idx, reason in enumerate(reasons): + _validate_type(reason, str, reason) + bad_tuple += tuple(reasons) else: - bad_tuple += tuple(bad_names) + bad_names = [ch_names[idx[i]] for i in idx_deltas] + if not has_printed: + logger.info( + f" Rejecting {t} epoch based on {name} : " + f"{bad_names}" + ) + has_printed = True + if not full_report: + return False + else: + bad_tuple += tuple(bad_names) if not full_report: return True @@ -3731,8 +4081,7 @@ def _read_one_epoch_file(f, tree, preload): elif kind == FIFF.FIFF_EPOCH: # delay reading until later fid.seek(pos, 0) - data_tag = read_tag_info(fid) - data_tag.pos = pos + data_tag = _read_tag_header(fid, pos) data_tag.type = data_tag.type ^ (1 << 30) elif kind in [FIFF.FIFF_MNE_BASELINE_MIN, 304]: # Constant 304 was used before v0.11 @@ -3763,12 +4112,12 @@ def _read_one_epoch_file(f, tree, preload): n_samp = last - first + 1 logger.info(" Found the data of interest:") logger.info( - " t = %10.2f ... %10.2f ms" - % (1000 * first / info["sfreq"], 1000 * last / info["sfreq"]) + f" t = {1000 * first / info['sfreq']:10.2f} ... " + f"{1000 * last / info['sfreq']:10.2f} ms" ) if info["comps"] is not None: logger.info( - " %d CTF compensation matrices available" % len(info["comps"]) + f" {len(info['comps'])} CTF compensation matrices available" ) # Inspect the data @@ -3850,7 +4199,7 @@ def _read_one_epoch_file(f, tree, preload): @verbose -def read_epochs(fname, proj=True, preload=True, verbose=None): +def read_epochs(fname, proj=True, preload=True, verbose=None) -> "EpochsFIF": """Read epochs from a fif file. 
Parameters @@ -3873,9 +4222,7 @@ def read_epochs(fname, proj=True, preload=True, verbose=None): class _RawContainer: """Helper for a raw data container.""" - def __init__( - self, fid, data_tag, event_samps, epoch_shape, cals, fmt - ): # noqa: D102 + def __init__(self, fid, data_tag, event_samps, epoch_shape, cals, fmt): self.fid = fid self.data_tag = data_tag self.event_samps = event_samps @@ -3909,7 +4256,7 @@ class EpochsFIF(BaseEpochs): """ @verbose - def __init__(self, fname, proj=True, preload=True, verbose=None): # noqa: D102 + def __init__(self, fname, proj=True, preload=True, verbose=None): from .io.base import _get_fname_rep if _path_like(fname): @@ -4036,7 +4383,7 @@ def __init__(self, fname, proj=True, preload=True, verbose=None): # noqa: D102 # call BaseEpochs constructor # again, ensure we're retaining the baseline period originally loaded # from disk without trying to re-apply baseline correction - super(EpochsFIF, self).__init__( + super().__init__( info, data, events, @@ -4159,9 +4506,7 @@ def _concatenate_epochs( ): """Auxiliary function for concatenating epochs.""" if not isinstance(epochs_list, (list, tuple)): - raise TypeError( - "epochs_list must be a list or tuple, got %s" % (type(epochs_list),) - ) + raise TypeError(f"epochs_list must be a list or tuple, got {type(epochs_list)}") # to make warning messages only occur once during concatenation warned = False @@ -4169,8 +4514,7 @@ def _concatenate_epochs( for ei, epochs in enumerate(epochs_list): if not isinstance(epochs, BaseEpochs): raise TypeError( - "epochs_list[%d] must be an instance of Epochs, " - "got %s" % (ei, type(epochs)) + f"epochs_list[{ei}] must be an instance of Epochs, got {type(epochs)}" ) if ( @@ -4180,8 +4524,8 @@ def _concatenate_epochs( ): warned = True warn( - "Concatenation of Annotations within Epochs is not supported " - "yet. All annotations will be dropped." + "Concatenation of Annotations within Epochs is not supported yet. All " + "annotations will be dropped." 
) # create a copy, so that the Annotations are not modified in place @@ -4473,9 +4817,7 @@ def average_movements( from .chpi import head_pos_to_trans_rot_t if not isinstance(epochs, BaseEpochs): - raise TypeError( - "epochs must be an instance of Epochs, not %s" % (type(epochs),) - ) + raise TypeError(f"epochs must be an instance of Epochs, not {type(epochs)}") orig_sfreq = epochs.info["sfreq"] if orig_sfreq is None else orig_sfreq orig_sfreq = float(orig_sfreq) if isinstance(head_pos, np.ndarray): @@ -4486,7 +4828,7 @@ def average_movements( origin = _check_origin(origin, epochs.info, "head") recon_trans = _check_destination(destination, epochs.info, True) - logger.info("Aligning and averaging up to %s epochs" % (len(epochs.events))) + logger.info(f"Aligning and averaging up to {len(epochs.events)} epochs") if not np.array_equal(epochs.events[:, 0], np.unique(epochs.events[:, 0])): raise RuntimeError("Epochs must have monotonically increasing events") info_to = epochs.info.copy() @@ -4528,12 +4870,12 @@ def average_movements( loc_str = ", ".join("%0.1f" % tr for tr in (trans[:3, 3] * 1000)) if last_trans is None or not np.allclose(last_trans, trans): logger.info( - " Processing epoch %s (device location: %s mm)" % (ei + 1, loc_str) + f" Processing epoch {ei + 1} (device location: {loc_str} mm)" ) reuse = False last_trans = trans else: - logger.info(" Processing epoch %s (device location: same)" % (ei + 1,)) + logger.info(f" Processing epoch {ei + 1} (device location: same)") reuse = True epoch = epoch.copy() # because we operate inplace if not reuse: @@ -4582,7 +4924,7 @@ def average_movements( data, info_to, picks, n_events=count, kind="average", comment=epochs._name ) _remove_meg_projs_comps(evoked, ignore_ref) - logger.info("Created Evoked dataset from %s epochs" % (count,)) + logger.info(f"Created Evoked dataset from {count} epochs") return (evoked, mapping) if return_mapping else evoked @@ -4594,7 +4936,7 @@ def make_fixed_length_epochs( reject_by_annotation=True, proj=True, overlap=0.0, - id=1, + id=1, # noqa: A002 verbose=None, ): """Divide continuous raw data into equal-sized consecutive epochs. diff --git a/mne/event.py b/mne/event.py index 211ed4e5d5d..a79ea13dbcc 100644 --- a/mne/event.py +++ b/mne/event.py @@ -321,9 +321,7 @@ def read_events( event_list = _mask_trigs(event_list, mask, mask_type) masked_len = event_list.shape[0] if masked_len < unmasked_len: - warn( - "{} of {} events masked".format(unmasked_len - masked_len, unmasked_len) - ) + warn(f"{unmasked_len - masked_len} of {unmasked_len} events masked") out = event_list if return_event_id: if event_id is None: @@ -927,7 +925,13 @@ def shift_time_events(events, ids, tshift, sfreq): @fill_doc def make_fixed_length_events( - raw, id=1, start=0, stop=None, duration=1.0, first_samp=True, overlap=0.0 + raw, + id=1, # noqa: A002 + start=0, + stop=None, + duration=1.0, + first_samp=True, + overlap=0.0, ): """Make a set of :term:`events` separated by a fixed duration. 
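Rounding out the fixed-length helpers touched here (`make_fixed_length_epochs` and `make_fixed_length_events` both keep the builtin-shadowing `id` parameter, now with a lint suppression), a short sketch of overlapping fixed-length segmentation on synthetic data:

```python
import numpy as np
import mne

info = mne.create_info(["EEG1"], sfreq=100.0, ch_types="eeg")
raw = mne.io.RawArray(np.random.default_rng(0).normal(size=(1, 3000)), info)

# 1 s windows stepping by 0.5 s, first as an events array, then directly as epochs
events = mne.make_fixed_length_events(raw, id=1, duration=1.0, overlap=0.5)
epochs = mne.make_fixed_length_epochs(raw, duration=1.0, overlap=0.5, preload=True)
```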
@@ -969,7 +973,7 @@ def make_fixed_length_events( duration, overlap = float(duration), float(overlap) if not 0 <= overlap < duration: raise ValueError( - "overlap must be >=0 but < duration (%s), got %s" % (duration, overlap) + f"overlap must be >=0 but < duration ({duration}), got {overlap}" ) start = raw.time_as_index(start, use_rounding=True)[0] @@ -1026,7 +1030,7 @@ def concatenate_events(events, first_samps, last_samps): _validate_type(events, list, "events") if not (len(events) == len(last_samps) and len(events) == len(first_samps)): raise ValueError( - "events, first_samps, and last_samps must all have " "the same lengths" + "events, first_samps, and last_samps must all have the same lengths" ) first_samps = np.array(first_samps) last_samps = np.array(last_samps) @@ -1144,7 +1148,7 @@ class AcqParserFIF: "OldMask", ) - def __init__(self, info): # noqa: D102 + def __init__(self, info): acq_pars = info["acq_pars"] if not acq_pars: raise ValueError("No acquisition parameters") diff --git a/mne/evoked.py b/mne/evoked.py index 31083795507..2e36f47f81b 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -9,6 +9,8 @@ # Copyright the MNE-Python contributors. from copy import deepcopy +from inspect import getfullargspec +from typing import Union import numpy as np @@ -46,6 +48,7 @@ from .html_templates import _get_html_template from .parallel import parallel_func from .time_frequency.spectrum import Spectrum, SpectrumMixin, _validate_method +from .time_frequency.tfr import AverageTFR from .utils import ( ExtendedTimeMixin, SizeMixin, @@ -174,7 +177,7 @@ def __init__( allow_maxshield=False, *, verbose=None, - ): # noqa: D102 + ): _validate_type(proj, bool, "'proj'") # Read the requested data fname = str(_check_fname(fname=fname, must_exist=True, overwrite="read")) @@ -257,7 +260,15 @@ def get_data(self, picks=None, units=None, tmin=None, tmax=None): @verbose def apply_function( - self, fun, picks=None, dtype=None, n_jobs=None, verbose=None, **kwargs + self, + fun, + picks=None, + dtype=None, + n_jobs=None, + channel_wise=True, + *, + verbose=None, + **kwargs, ): """Apply a function to a subset of channels. @@ -270,6 +281,9 @@ def apply_function( %(dtype_applyfun)s %(n_jobs)s Ignored if ``channel_wise=False`` as the workload is split across channels. + %(channel_wise_applyfun)s + + .. 
versionadded:: 1.6 %(verbose)s %(kwargs_fun)s @@ -288,21 +302,55 @@ def apply_function( if dtype is not None and dtype != self._data.dtype: self._data = self._data.astype(dtype) + args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs + if channel_wise is False: + if ("ch_idx" in args) or ("ch_name" in args): + raise ValueError( + "apply_function cannot access ch_idx or ch_name " + "when channel_wise=False" + ) + if "ch_idx" in args: + logger.info("apply_function requested to access ch_idx") + if "ch_name" in args: + logger.info("apply_function requested to access ch_name") + # check the dimension of the incoming evoked data _check_option("evoked.ndim", self._data.ndim, [2]) - parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs) - if n_jobs == 1: - # modify data inplace to save memory - for idx in picks: - self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs) + if channel_wise: + parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs) + if n_jobs == 1: + # modify data inplace to save memory + for ch_idx in picks: + if "ch_idx" in args: + kwargs.update(ch_idx=ch_idx) + if "ch_name" in args: + kwargs.update(ch_name=self.info["ch_names"][ch_idx]) + self._data[ch_idx, :] = _check_fun( + fun, data_in[ch_idx, :], **kwargs + ) + else: + # use parallel function + data_picks_new = parallel( + p_fun( + fun, + data_in[ch_idx, :], + **kwargs, + **{ + k: v + for k, v in [ + ("ch_name", self.info["ch_names"][ch_idx]), + ("ch_idx", ch_idx), + ] + if k in args + }, + ) + for ch_idx in picks + ) + for run_idx, ch_idx in enumerate(picks): + self._data[ch_idx, :] = data_picks_new[run_idx] else: - # use parallel function - data_picks_new = parallel( - p_fun(fun, data_in[p, :], **kwargs) for p in picks - ) - for pp, p in enumerate(picks): - self._data[p, :] = data_picks_new[pp] + self._data[picks, :] = _check_fun(fun, data_in[picks, :], **kwargs) return self @@ -399,8 +447,8 @@ def __repr__(self): # noqa: D105 comment += "..." else: comment = self.comment - s = "'%s' (%s, N=%s)" % (comment, self.kind, self.nave) - s += ", %0.5g – %0.5g s" % (self.times[0], self.times[-1]) + s = f"'{comment}' ({self.kind}, N={self.nave})" + s += f", {self.times[0]:0.5g} – {self.times[-1]:0.5g} s" s += ", baseline " if self.baseline is None: s += "off" @@ -414,8 +462,8 @@ def __repr__(self): # noqa: D105 ): s += " (baseline period was cropped after baseline correction)" s += ", %s ch" % self.data.shape[0] - s += ", ~%s" % (sizeof_fmt(self._size),) - return "" % s + s += f", ~{sizeof_fmt(self._size)}" + return f"" @repr_html def _repr_html_(self): @@ -548,7 +596,7 @@ def plot_topo( scalings=None, title=None, proj=False, - vline=[0.0], + vline=(0.0,), fig_background=None, merge_grads=False, legend=True, @@ -913,6 +961,8 @@ def get_peak( time_as_index=False, merge_grads=False, return_amplitude=False, + *, + strict=True, ): """Get location and latency of peak amplitude. @@ -940,6 +990,12 @@ def get_peak( If True, return also the amplitude at the maximum response. .. versionadded:: 0.16 + strict : bool + If True, raise an error if values are all positive when detecting + a minimum (mode='neg'), or all negative when detecting a maximum + (mode='pos'). Defaults to True. + + .. 
versionadded:: 1.7 Returns ------- @@ -1031,7 +1087,14 @@ def get_peak( data, _ = _merge_ch_data(data, ch_type, []) ch_names = [ch_name[:-1] + "X" for ch_name in ch_names[::2]] - ch_idx, time_idx, max_amp = _get_peak(data, self.times, tmin, tmax, mode) + ch_idx, time_idx, max_amp = _get_peak( + data, + self.times, + tmin, + tmax, + mode, + strict=strict, + ) out = (ch_names[ch_idx], time_idx if time_as_index else self.times[time_idx]) @@ -1106,6 +1169,66 @@ def compute_psd( **method_kw, ) + @verbose + def compute_tfr( + self, + method, + freqs, + *, + tmin=None, + tmax=None, + picks=None, + proj=False, + output="power", + decim=1, + n_jobs=None, + verbose=None, + **method_kw, + ): + """Compute a time-frequency representation of evoked data. + + Parameters + ---------- + %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(output_compute_tfr)s + %(decim_tfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_tfr)s + + Returns + ------- + tfr : instance of AverageTFR + The time-frequency-resolved power estimates of the data. + + Notes + ----- + .. versionadded:: 1.7 + + References + ---------- + .. footbibliography:: + """ + _check_option("output", output, ("power", "phase", "complex")) + method_kw["output"] = output + return AverageTFR( + inst=self, + method=method, + freqs=freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + @verbose def plot_psd( self, @@ -1255,7 +1378,7 @@ def to_data_frame( data = _scale_dataframe_data(self, data, picks, scalings) # prepare extra columns / multiindex mindex = list() - times = _convert_times(self, times, time_format) + times = _convert_times(times, time_format, self.info["meas_date"]) mindex.append(("time", times)) # build DataFrame df = _build_data_frame( @@ -1316,20 +1439,20 @@ def __init__( baseline=None, *, verbose=None, - ): # noqa: D102 + ): dtype = np.complex128 if np.iscomplexobj(data) else np.float64 data = np.asanyarray(data, dtype=dtype) if data.ndim != 2: raise ValueError( - "Data must be a 2D array of shape (n_channels, " - "n_samples), got shape %s" % (data.shape,) + "Data must be a 2D array of shape (n_channels, n_samples), got shape " + f"{data.shape}" ) if len(info["ch_names"]) != np.shape(data)[0]: raise ValueError( - "Info (%s) and data (%s) must have same number " - "of channels." % (len(info["ch_names"]), np.shape(data)[0]) + f"Info ({len(info['ch_names'])}) and data ({np.shape(data)[0]}) must " + "have same number of channels." ) self.data = data @@ -1351,8 +1474,7 @@ def __init__( _validate_type(self.kind, "str", "kind") if self.kind not in _aspect_dict: raise ValueError( - 'unknown kind "%s", should be "average" or ' - '"standard_error"' % (self.kind,) + f'unknown kind "{self.kind}", should be "average" or "standard_error"' ) self._aspect_kind = _aspect_dict[self.kind] @@ -1420,18 +1542,14 @@ def _check_evokeds_ch_names_times(all_evoked): for ii, ev in enumerate(all_evoked[1:]): if ev.ch_names != ch_names: if set(ev.ch_names) != set(ch_names): - raise ValueError( - "%s and %s do not contain the same channels." 
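`Evoked.compute_tfr`, added just below, mirrors the epochs version but always yields an `AverageTFR` (the `average`/`return_itc` knobs do not apply to a single average). A sketch assuming `evoked` exists:

```python
import numpy as np

freqs = np.arange(5.0, 30.0, 3.0)
tfr = evoked.compute_tfr("morlet", freqs=freqs, n_cycles=freqs / 2.0)
tfr.plot(picks="eeg", combine="mean")  # one averaged time-frequency image
```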
% (evoked, ev) - ) + raise ValueError(f"{evoked} and {ev} do not contain the same channels.") else: warn("Order of channels differs, reordering channels ...") ev = ev.copy() ev.reorder_channels(ch_names) all_evoked[ii + 1] = ev if not np.max(np.abs(ev.times - evoked.times)) < 1e-7: - raise ValueError( - "%s and %s do not contain the same time instants" % (evoked, ev) - ) + raise ValueError(f"{evoked} and {ev} do not contain the same time instants") return all_evoked @@ -1538,7 +1656,7 @@ def read_evokeds( proj=True, allow_maxshield=False, verbose=None, -): +) -> Union[list[Evoked], Evoked]: """Read evoked dataset(s). Parameters @@ -1660,17 +1778,16 @@ def _read_evoked(fname, condition=None, kind="average", allow_maxshield=False): found_cond = np.where(goods)[0] if len(found_cond) != 1: raise ValueError( - 'condition "%s" (%s) not found, out of ' - "found datasets:\n%s" % (condition, kind, t) + f'condition "{condition}" ({kind}) not found, out of found ' + f"datasets:\n{t}" ) condition = found_cond[0] elif condition is None: if len(evoked_node) > 1: _, _, conditions = _get_entries(fid, evoked_node, allow_maxshield) raise TypeError( - "Evoked file has more than one " - "condition, the condition parameters " - "must be specified from:\n%s" % conditions + "Evoked file has more than one condition, the condition parameters " + f"must be specified from:\n{conditions}" ) else: condition = 0 @@ -1804,19 +1921,18 @@ def _read_evoked(fname, condition=None, kind="average", allow_maxshield=False): del first, last if nsamp is not None and data.shape[1] != nsamp: raise ValueError( - "Incorrect number of samples (%d instead of " - " %d)" % (data.shape[1], nsamp) + f"Incorrect number of samples ({data.shape[1]} instead of {nsamp})" ) logger.info(" Found the data of interest:") logger.info( - " t = %10.2f ... %10.2f ms (%s)" - % (1000 * times[0], 1000 * times[-1], comment) + f" t = {1000 * times[0]:10.2f} ... {1000 * times[-1]:10.2f} ms (" + f"{comment})" ) if info["comps"] is not None: logger.info( - " %d CTF compensation matrices available" % len(info["comps"]) + f" {len(info['comps'])} CTF compensation matrices available" ) - logger.info(" nave = %d - aspect type = %d" % (nave, aspect_kind)) + logger.info(f" nave = {nave} - aspect type = {aspect_kind}") # Calibrate cals = np.array( @@ -1955,7 +2071,7 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch="raise", overwrite= end_block(fid, FIFF.FIFFB_MEAS) -def _get_peak(data, times, tmin=None, tmax=None, mode="abs"): +def _get_peak(data, times, tmin=None, tmax=None, mode="abs", *, strict=True): """Get feature-index and time of maximum signal from 2D array. Note. This is a 'getter', not a 'finder'. For non-evoked type @@ -1976,6 +2092,10 @@ def _get_peak(data, times, tmin=None, tmax=None, mode="abs"): values will be considered. If 'neg' only negative values will be considered. If 'abs' absolute values will be considered. Defaults to 'abs'. + strict : bool + If True, raise an error if values are all positive when detecting + a minimum (mode='neg'), or all negative when detecting a maximum + (mode='pos'). Defaults to True. Returns ------- @@ -2014,14 +2134,14 @@ def _get_peak(data, times, tmin=None, tmax=None, mode="abs"): maxfun = np.argmax if mode == "pos": - if not np.any(data[~mask] > 0): + if strict and not np.any(data[~mask] > 0): raise ValueError( - "No positive values encountered. Cannot " "operate in pos mode." + "No positive values encountered. Cannot operate in pos mode." 
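The new `strict` flag threaded into `_get_peak` is reachable from the public `Evoked.get_peak`; with `strict=False` the polarity check is skipped instead of raising. A sketch assuming `evoked` exists:

```python
# Most negative deflection between 80 and 120 ms; with strict=False this
# succeeds even if every value in the window is positive (it then simply
# returns the smallest one).
ch_name, latency, amplitude = evoked.get_peak(
    ch_type="eeg",
    tmin=0.08,
    tmax=0.12,
    mode="neg",
    return_amplitude=True,
    strict=False,
)
```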
) elif mode == "neg": - if not np.any(data[~mask] < 0): + if strict and not np.any(data[~mask] < 0): raise ValueError( - "No negative values encountered. Cannot " "operate in neg mode." + "No negative values encountered. Cannot operate in neg mode." ) maxfun = np.argmin diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index 0da7647ecb7..d705d8cef9d 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -4,11 +4,150 @@ # Copyright the MNE-Python contributors. import os +from pathlib import Path -from ..utils import _check_pybv_installed +import numpy as np + +from mne.channels.channels import _unit2human +from mne.io.constants import FIFF +from mne.utils import _check_pybv_installed, warn _check_pybv_installed() -from pybv._export import _export_mne_raw # noqa: E402 +from pybv import write_brainvision # noqa: E402 + + +def _export_mne_raw(*, raw, fname, events=None, overwrite=False): + """Export raw data from MNE-Python. + + Parameters + ---------- + raw : mne.io.Raw + The raw data to export. + fname : str | pathlib.Path + The name of the file where raw data will be exported to. Must end with + ``".vhdr"``, and accompanying *.vmrk* and *.eeg* files will be written inside + the same directory. + events : np.ndarray | None + Events to be written to the marker file (*.vmrk*). If array, must be in + `MNE-Python format `_. If + ``None`` (default), events will be written based on ``raw.annotations``. + overwrite : bool + Whether or not to overwrite existing data. Defaults to ``False``. + + """ + # prepare file location + if not str(fname).endswith(".vhdr"): + raise ValueError("`fname` must have the '.vhdr' extension for BrainVision.") + fname = Path(fname) + folder_out = fname.parents[0] + fname_base = fname.stem + + # prepare data from raw + data = raw.get_data() # gets data starting from raw.first_samp + sfreq = raw.info["sfreq"] # in Hz + meas_date = raw.info["meas_date"] # datetime.datetime + ch_names = raw.ch_names + + # write voltage units as micro-volts and all other units without scaling + # write units that we don't know as n/a + unit = [] + for ch in raw.info["chs"]: + if ch["unit"] == FIFF.FIFF_UNIT_V: + unit.append("µV") + elif ch["unit"] == FIFF.FIFF_UNIT_CEL: + unit.append("°C") + else: + unit.append(_unit2human.get(ch["unit"], "n/a")) + unit = [u if u != "NA" else "n/a" for u in unit] + + # enforce conversion to float32 format + # XXX: Could add a feature that checks data and optimizes `unit`, `resolution`, and + # `format` so that raw.orig_format could be retained if reasonable. + if raw.orig_format != "single": + warn( + f"Encountered data in '{raw.orig_format}' format. Converting to float32.", + RuntimeWarning, + ) + + fmt = "binary_float32" + resolution = 0.1 + + # handle events + # if we got an ndarray, this is in MNE-Python format + msg = "`events` must be None or array in MNE-Python format." 
+ if events is not None: + # subtract raw.first_samp because brainvision marks events starting from the + # first available data point and ignores the raw.first_samp + assert isinstance(events, np.ndarray), msg + assert events.ndim == 2, msg + assert events.shape[-1] == 3, msg + events[:, 0] -= raw.first_samp + events = events[:, [0, 2]] # reorder for pybv required order + else: # else, prepare pybv style events from raw.annotations + events = _mne_annots2pybv_events(raw) + + # no information about reference channels in mne currently + ref_ch_names = None + + # write to BrainVision + write_brainvision( + data=data, + sfreq=sfreq, + ch_names=ch_names, + ref_ch_names=ref_ch_names, + fname_base=fname_base, + folder_out=folder_out, + overwrite=overwrite, + events=events, + resolution=resolution, + unit=unit, + fmt=fmt, + meas_date=meas_date, + ) + + +def _mne_annots2pybv_events(raw): + """Convert mne Annotations to pybv events.""" + events = [] + for annot in raw.annotations: + # handle onset and duration: seconds to sample, relative to + # raw.first_samp / raw.first_time + onset = annot["onset"] - raw.first_time + onset = raw.time_as_index(onset).astype(int)[0] + duration = int(annot["duration"] * raw.info["sfreq"]) + + # triage type and description + # defaults to type="Comment" and the full description + etype = "Comment" + description = annot["description"] + for start in ["Stimulus/S", "Response/R", "Comment/"]: + if description.startswith(start): + etype = start.split("/")[0] + description = description.replace(start, "") + break + + if etype in ["Stimulus", "Response"] and description.strip().isdigit(): + description = int(description.strip()) + else: + # if cannot convert to int, we must use this as "Comment" + etype = "Comment" + + event_dict = dict( + onset=onset, # in samples + duration=duration, # in samples + description=description, + type=etype, + ) + + if "ch_names" in annot: + # handle channels + channels = list(annot["ch_names"]) + event_dict["channels"] = channels + + # add a "pybv" event + events += [event_dict] + + return events def _export_raw(fname, raw, overwrite): diff --git a/mne/export/_edf.py b/mne/export/_edf.py index 7097f7bd85d..3f7e55b3d77 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -3,127 +3,86 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from contextlib import contextmanager +import datetime as dt +from typing import Callable import numpy as np -from ..utils import _check_edflib_installed, warn +from ..utils import _check_edfio_installed, warn -_check_edflib_installed() -from EDFlib.edfwriter import EDFwriter # noqa: E402 +_check_edfio_installed() +from edfio import Edf, EdfAnnotation, EdfSignal, Patient, Recording # noqa: E402 -def _try_to_set_value(header, key, value, channel_index=None): - """Set key/value pairs in EDF header.""" - # all EDFLib set functions are set - # for example "setPatientName()" - func_name = f"set{key}" - func = getattr(header, func_name) - - # some setter functions are indexed by channels - if channel_index is None: - return_val = func(value) - else: - return_val = func(channel_index, value) - - # a nonzero return value indicates an error - if return_val != 0: - raise RuntimeError( - f"Setting {key} with {value} " f"returned an error value " f"{return_val}." 
- ) - - -@contextmanager -def _auto_close(fid): - # try to close the handle no matter what - try: - yield fid - finally: - try: - fid.close() - except Exception: - pass # we did our best +# copied from edfio (Apache license) +def _round_float_to_8_characters( + value: float, + round_func: Callable[[float], int], +) -> float: + if isinstance(value, int) or value.is_integer(): + return value + length = 8 + integer_part_length = str(value).find(".") + if integer_part_length == length: + return round_func(value) + factor = 10 ** (length - 1 - integer_part_length) + return round_func(value * factor) / factor def _export_raw(fname, raw, physical_range, add_ch_type): """Export Raw objects to EDF files. - TODO: if in future the Info object supports transducer or - technician information, allow writing those here. + TODO: if in future the Info object supports transducer or technician information, + allow writing those here. """ - # scale to save data in EDF - phys_dims = "uV" - - # get EEG-related data in uV + # get voltage-based data in uV units = dict( eeg="uV", ecog="uV", seeg="uV", eog="uV", ecg="uV", emg="uV", bio="uV", dbs="uV" ) - digital_min = -32767 - digital_max = 32767 - file_type = EDFwriter.EDFLIB_FILETYPE_EDFPLUS + digital_min, digital_max = -32767, 32767 # load data first raw.load_data() - # remove extra STI channels - orig_ch_types = raw.get_channel_types() - drop_chs = [] - if "stim" in orig_ch_types: - stim_index = np.argwhere(np.array(orig_ch_types) == "stim") - stim_index = np.atleast_1d(stim_index.squeeze()).tolist() - drop_chs.extend([raw.ch_names[idx] for idx in stim_index]) - - # Add warning if any channel types are not voltage based. - # Users are expected to only export data that is voltage based, - # such as EEG, ECoG, sEEG, etc. - # Non-voltage channels are dropped by the export function. - # Note: we can write these other channels, such as 'misc' - # but these are simply a "catch all" for unknown or undesired - # channels. - voltage_types = list(units) + ["stim", "misc"] - non_voltage_ch = [ch not in voltage_types for ch in orig_ch_types] - if any(non_voltage_ch): - warn( - f"Non-voltage channels detected: {non_voltage_ch}. MNE-Python's " - "EDF exporter only supports voltage-based channels, because the " - "EDF format cannot accommodate much of the accompanying data " - "necessary for channel types like MEG and fNIRS (channel " - "orientations, coordinate frame transforms, etc). You can " - "override this restriction by setting those channel types to " - '"misc" but no guarantees are made of the fidelity of that ' - "approach." - ) - - ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] - ch_types = np.array(raw.get_channel_types(picks=ch_names)) - n_channels = len(ch_names) + ch_types = np.array(raw.get_channel_types()) n_times = raw.n_times - # Sampling frequency in EDF only supports integers, so to allow for - # float sampling rates from Raw, we adjust the output sampling rate - # for all channels and the data record duration. + # get the entire dataset in uV + data = raw.get_data(units=units) + + # Sampling frequency in EDF only supports integers, so to allow for float sampling + # rates from Raw, we adjust the output sampling rate for all channels and the data + # record duration. 
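To see what `_round_float_to_8_characters` buys us, here is a standalone rerun of its logic on a made-up non-integer rate (EDF header fields are eight ASCII characters wide, so the data record duration string must fit that budget):

```python
def round_float_to_8_characters(value, round_func):
    # same logic as the helper above, restated for illustration
    if isinstance(value, int) or value.is_integer():
        return value
    length = 8
    integer_part_length = str(value).find(".")
    if integer_part_length == length:
        return round_func(value)
    factor = 10 ** (length - 1 - integer_part_length)
    return round_func(value * factor) / factor

# Hypothetical 512.5 Hz recording: the record duration floor(512.5) / 512.5
# is shortened so that its string representation fits in 8 characters.
duration = round_float_to_8_characters(512 / 512.5, round)
print(duration, len(str(duration)))  # 0.999024 8
```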
sfreq = raw.info["sfreq"]
     if float(sfreq).is_integer():
         out_sfreq = int(sfreq)
         data_record_duration = None
+        # make non-integer second durations work
+        if (pad_width := int(np.ceil(n_times / sfreq) * sfreq - n_times)) > 0:
+            warn(
+                "EDF format requires equal-length data blocks, so "
+                f"{pad_width / sfreq:.3g} seconds of zeros were appended to all "
+                "channels when writing the final block."
+            )
+            data = np.pad(data, (0, int(pad_width)))
     else:
-        out_sfreq = np.floor(sfreq).astype(int)
-        data_record_duration = int(np.around(out_sfreq / sfreq, decimals=6) * 1e6)
-
+        data_record_duration = _round_float_to_8_characters(
+            np.floor(sfreq) / sfreq, round
+        )
+        out_sfreq = np.floor(sfreq) / data_record_duration
         warn(
-            f"Data has a non-integer sampling rate of {sfreq}; writing to "
-            "EDF format may cause a small change to sample times."
+            f"Data has a non-integer sampling rate of {sfreq}; writing to EDF format "
+            "may cause a small change to sample times."
         )

     # get any filter information applied to the data
     lowpass = raw.info["lowpass"]
     highpass = raw.info["highpass"]
     linefreq = raw.info["line_freq"]
-    filter_str_info = f"HP:{highpass}Hz LP:{lowpass}Hz N:{linefreq}Hz"
-
-    # get the entire dataset in uV
-    data = raw.get_data(units=units, picks=ch_names)
+    filter_str_info = f"HP:{highpass}Hz LP:{lowpass}Hz"
+    if linefreq is not None:
+        filter_str_info += f" N:{linefreq}Hz"

     if physical_range == "auto":
         # get max and min for each channel type data
@@ -131,204 +90,133 @@ def _export_raw(fname, raw, physical_range, add_ch_type):
         ch_types_phys_min = dict()

         for _type in np.unique(ch_types):
-            _picks = np.nonzero(ch_types == _type)[0]
+            _picks = [n for n, t in zip(raw.ch_names, ch_types) if t == _type]
             _data = raw.get_data(units=units, picks=_picks)
             ch_types_phys_max[_type] = _data.max()
             ch_types_phys_min[_type] = _data.min()
+    elif physical_range == "channelwise":
+        prange = None
     else:
         # get the physical min and max of the data in uV
-        # Physical ranges of the data in uV is usually set by the manufacturer
-        # and properties of the electrode. In general, physical max and min
-        # should be the clipping levels of the ADC input and they should be
-        # the same for all channels. For example, Nihon Kohden uses +3200 uV
-        # and -3200 uV for all EEG channels (which are the actual clipping
-        # levels of their input amplifiers & ADC).
-        # For full discussion, see: https://github.com/sccn/eeglab/issues/246
+        # Physical ranges of the data in uV are usually set by the manufacturer and
+        # electrode properties. In general, physical min and max should be the clipping
+        # levels of the ADC input, and they should be the same for all channels. For
+        # example, Nihon Kohden uses ±3200 uV for all EEG channels (corresponding to the
+        # actual clipping levels of their input amplifiers & ADC). For a discussion,
+        # see https://github.com/sccn/eeglab/issues/246
         pmin, pmax = physical_range[0], physical_range[1]

         # check that physical min and max is not exceeded
         if data.max() > pmax:
             warn(
-                f"The maximum μV of the data {data.max()} is "
-                f"more than the physical max passed in {pmax}.",
+                f"The maximum μV of the data {data.max()} is more than the physical max"
+                f" passed in {pmax}."
             )
         if data.min() < pmin:
             warn(
-                f"The minimum μV of the data {data.min()} is "
-                f"less than the physical min passed in {pmin}.",
+                f"The minimum μV of the data {data.min()} is less than the physical min"
+                f" passed in {pmin}."
+            )
+        data = np.clip(data, pmin, pmax)
+        prange = pmin, pmax

+    signals = []
+    for idx, ch in enumerate(raw.ch_names):
+        ch_type = ch_types[idx]
+        signal_label = f"{ch_type.upper()} {ch}" if add_ch_type else ch
+        if len(signal_label) > 16:
+            raise RuntimeError(
+                f"Signal label for {ch} ({ch_type}) is longer than 16 characters, which"
+                " is not supported by the EDF standard. Please shorten the channel name"
+                " before exporting to EDF."
             )
-    # create instance of EDF Writer
-    with _auto_close(EDFwriter(fname, file_type, n_channels)) as hdl:
-        # set channel data
-        for idx, ch in enumerate(ch_names):
-            ch_type = ch_types[idx]
-            signal_label = f"{ch_type.upper()} {ch}" if add_ch_type else ch
-            if len(signal_label) > 16:
-                raise RuntimeError(
-                    f"Signal label for {ch} ({ch_type}) is "
-                    f"longer than 16 characters, which is not "
-                    f"supported in EDF. Please shorten the "
-                    f"channel name before exporting to EDF."
-                )
-
-            if physical_range == "auto":
-                # take the channel type minimum and maximum
-                pmin = ch_types_phys_min[ch_type]
-                pmax = ch_types_phys_max[ch_type]
-            for key, val in [
-                ("PhysicalMaximum", pmax),
-                ("PhysicalMinimum", pmin),
-                ("DigitalMaximum", digital_max),
-                ("DigitalMinimum", digital_min),
-                ("PhysicalDimension", phys_dims),
-                ("SampleFrequency", out_sfreq),
-                ("SignalLabel", signal_label),
-                ("PreFilter", filter_str_info),
-            ]:
-                _try_to_set_value(hdl, key, val, channel_index=idx)
-
-        # set patient info
-        subj_info = raw.info.get("subject_info")
-        if subj_info is not None:
-            # get the full name of subject if available
-            first_name = subj_info.get("first_name", "")
-            middle_name = subj_info.get("middle_name", "")
-            last_name = subj_info.get("last_name", "")
-            name = " ".join(filter(None, [first_name, middle_name, last_name]))
-
-            birthday = subj_info.get("birthday")
-            hand = subj_info.get("hand")
-            weight = subj_info.get("weight")
-            height = subj_info.get("height")
-            sex = subj_info.get("sex")
-
-            additional_patient_info = []
-            for key, value in [("height", height), ("weight", weight), ("hand", hand)]:
-                if value:
-                    additional_patient_info.append(f"{key}={value}")
-            if len(additional_patient_info) == 0:
-                additional_patient_info = None
-            else:
-                additional_patient_info = " ".join(additional_patient_info)
-
-            if birthday is not None:
-                if hdl.setPatientBirthDate(birthday[0], birthday[1], birthday[2]) != 0:
-                    raise RuntimeError(
-                        f"Setting patient birth date to {birthday} "
-                        f"returned an error"
-                    )
-            for key, val in [
-                ("PatientCode", subj_info.get("his_id", "")),
-                ("PatientName", name),
-                ("PatientGender", sex),
-                ("AdditionalPatientInfo", additional_patient_info),
-            ]:
-                # EDFwriter compares integer encodings of sex and will
-                # raise a TypeError if value is None as returned by
-                # subj_info.get(key) if key is missing.
- if val is not None: - _try_to_set_value(hdl, key, val) - - # set measurement date - meas_date = raw.info["meas_date"] - if meas_date: - subsecond = int(meas_date.microsecond / 100) - if ( - hdl.setStartDateTime( - year=meas_date.year, - month=meas_date.month, - day=meas_date.day, - hour=meas_date.hour, - minute=meas_date.minute, - second=meas_date.second, - subsecond=subsecond, - ) - != 0 - ): - raise RuntimeError( - f"Setting start date time {meas_date} " f"returned an error" - ) - - device_info = raw.info.get("device_info") - if device_info is not None: - device_type = device_info.get("type") - _try_to_set_value(hdl, "Equipment", device_type) - - # set data record duration - if data_record_duration is not None: - _try_to_set_value(hdl, "DataRecordDuration", data_record_duration) - - # compute number of data records to loop over - n_blocks = np.ceil(n_times / out_sfreq).astype(int) - - # increase the number of annotation signals if necessary - annots = raw.annotations - if annots is not None: - n_annotations = len(raw.annotations) - n_annot_chans = int(n_annotations / n_blocks) + 1 - if n_annot_chans > 1: - hdl.setNumberOfAnnotationSignals(n_annot_chans) - - # Write each data record sequentially - for idx in range(n_blocks): - end_samp = (idx + 1) * out_sfreq - if end_samp > n_times: - end_samp = n_times - start_samp = idx * out_sfreq - - # then for each datarecord write each channel - for jdx in range(n_channels): - # create a buffer with sampling rate - buf = np.zeros(out_sfreq, np.float64, "C") + if physical_range == "auto": # per channel type + pmin = ch_types_phys_min[ch_type] + pmax = ch_types_phys_max[ch_type] + prange = pmin, pmax + + signals.append( + EdfSignal( + data[idx], + out_sfreq, + label=signal_label, + transducer_type="", + physical_dimension="" if ch_type == "stim" else "uV", + physical_range=prange, + digital_range=(digital_min, digital_max), + prefiltering=filter_str_info, + ) + ) - # get channel data for this block - ch_data = data[jdx, start_samp:end_samp] + # set patient info + subj_info = raw.info.get("subject_info") + if subj_info is not None: + # get the full name of subject if available + first_name = subj_info.get("first_name", "") + middle_name = subj_info.get("middle_name", "") + last_name = subj_info.get("last_name", "") + name = "_".join(filter(None, [first_name, middle_name, last_name])) + + birthday = subj_info.get("birthday") + if birthday is not None: + birthday = dt.date(*birthday) + hand = subj_info.get("hand") + weight = subj_info.get("weight") + height = subj_info.get("height") + sex = subj_info.get("sex") + + additional_patient_info = [] + for key, value in [("height", height), ("weight", weight), ("hand", hand)]: + if value: + additional_patient_info.append(f"{key}={value}") + + patient = Patient( + code=subj_info.get("his_id") or "X", + sex={0: "X", 1: "M", 2: "F", None: "X"}[sex], + birthdate=birthday, + name=name or "X", + additional=additional_patient_info, + ) + else: + patient = None - # assign channel data to the buffer and write to EDF - buf[: len(ch_data)] = ch_data - err = hdl.writeSamples(buf) - if err != 0: - raise RuntimeError( - f"writeSamples() for channel{ch_names[jdx]} " - f"returned error: {err}" - ) + # set measurement date + if (meas_date := raw.info["meas_date"]) is not None: + startdate = dt.date(meas_date.year, meas_date.month, meas_date.day) + starttime = dt.time( + meas_date.hour, meas_date.minute, meas_date.second, meas_date.microsecond + ) + else: + startdate = None + starttime = None - # there was an incomplete 
datarecord - if len(ch_data) != len(buf): - warn( - f"EDF format requires equal-length data blocks, " - f"so {(len(buf) - len(ch_data)) / sfreq} seconds of " - "zeros were appended to all channels when writing the " - "final block." + device_info = raw.info.get("device_info") + if device_info is not None: + device_type = device_info.get("type") or "X" + recording = Recording(startdate=startdate, equipment_code=device_type) + else: + recording = Recording(startdate=startdate) + + annotations = [] + for desc, onset, duration, ch_names in zip( + raw.annotations.description, + raw.annotations.onset, + raw.annotations.duration, + raw.annotations.ch_names, + ): + if ch_names: + for ch_name in ch_names: + annotations.append( + EdfAnnotation(onset, duration, desc + f"@@{ch_name}") ) - - # write annotations - if annots is not None: - for desc, onset, duration, ch_names in zip( - raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration, - raw.annotations.ch_names, - ): - # annotations are written in terms of 100 microseconds - onset = onset * 10000 - duration = duration * 10000 - if ch_names: - for ch_name in ch_names: - if ( - hdl.writeAnnotation(onset, duration, desc + f"@@{ch_name}") - != 0 - ): - raise RuntimeError( - f"writeAnnotation() returned an error " - f"trying to write {desc}@@{ch_name} at {onset} " - f"for {duration} seconds." - ) - else: - if hdl.writeAnnotation(onset, duration, desc) != 0: - raise RuntimeError( - f"writeAnnotation() returned an error " - f"trying to write {desc} at {onset} " - f"for {duration} seconds." - ) + else: + annotations.append(EdfAnnotation(onset, duration, desc)) + + Edf( + signals=signals, + patient=patient, + recording=recording, + starttime=starttime, + data_record_duration=data_record_duration, + annotations=annotations, + ).write(fname) diff --git a/mne/export/_egimff.py b/mne/export/_egimff.py index ef10c71acfc..70462a96841 100644 --- a/mne/export/_egimff.py +++ b/mne/export/_egimff.py @@ -50,7 +50,6 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, verbose= using MFF read functions. 
""" mffpy = _import_mffpy("Export evokeds to MFF.") - import pytz info = evoked[0].info if np.round(info["sfreq"]) != info["sfreq"]: @@ -73,7 +72,7 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, verbose= if op.exists(fname): os.remove(fname) if op.isfile(fname) else shutil.rmtree(fname) writer = mffpy.Writer(fname) - current_time = pytz.utc.localize(datetime.datetime.utcnow()) + current_time = datetime.datetime.now(datetime.timezone.utc) writer.addxml("fileInfo", recordTime=current_time) try: device = info["device_info"]["type"] diff --git a/mne/export/_export.py b/mne/export/_export.py index 80a18d090d2..aed7e44e0c8 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -211,9 +211,9 @@ def _infer_check_export_fmt(fmt, fname, supported_formats): if fmt not in supported_formats: supported = [] - for format, extensions in supported_formats.items(): + for supp_format, extensions in supported_formats.items(): ext_str = ", ".join(f"*.{ext}" for ext in extensions) - supported.append(f"{format} ({ext_str})") + supported.append(f"{supp_format} ({ext_str})") supported_str = ", ".join(supported) raise ValueError( diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 67bd417bb50..9c8a60f50bb 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -6,7 +6,6 @@ from contextlib import nullcontext from datetime import datetime, timezone -from os import remove from pathlib import Path import numpy as np @@ -33,7 +32,7 @@ ) from mne.tests.test_epochs import _get_data from mne.utils import ( - _check_edflib_installed, + _check_edfio_installed, _record_warnings, _resource_path, object_diff, @@ -79,7 +78,10 @@ def test_export_raw_pybv(tmp_path, meas_date, orig_time, ext): raw.set_annotations(annots) temp_fname = tmp_path / ("test" + ext) - with pytest.warns(RuntimeWarning, match="'short' format. Converting"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="'short' format. 
Converting"), + ): raw.export(temp_fname) raw_read = read_raw_brainvision(str(temp_fname).replace(".eeg", ".vhdr")) assert raw.ch_names == raw_read.ch_names @@ -120,17 +122,11 @@ def test_export_raw_eeglab(tmp_path): raw.export(temp_fname, overwrite=True) -@pytest.mark.skipif( - not _check_edflib_installed(strict=False), reason="edflib-python not installed" -) -def test_double_export_edf(tmp_path): - """Test exporting an EDF file multiple times.""" - rng = np.random.RandomState(123456) - format = "edf" +def _create_raw_for_edf_tests(stim_channel_index=None): + rng = np.random.RandomState(12345) ch_types = [ "eeg", "eeg", - "stim", "ecog", "ecog", "seeg", @@ -140,12 +136,27 @@ def test_double_export_edf(tmp_path): "dbs", "bio", ] - info = create_info(len(ch_types), sfreq=1000, ch_types=ch_types) - info = info.set_meas_date("2023-09-04 14:53:09.000") - data = rng.random(size=(len(ch_types), 1000)) * 1e-5 + if stim_channel_index is not None: + ch_types.insert(stim_channel_index, "stim") + ch_names = np.arange(len(ch_types)).astype(str).tolist() + info = create_info(ch_names, sfreq=1000, ch_types=ch_types) + data = rng.random(size=(len(ch_names), 2000)) * 1e-5 + return RawArray(data, info) + + +edfio_mark = pytest.mark.skipif( + not _check_edfio_installed(strict=False), reason="unsafe use of private module" +) + + +@edfio_mark() +def test_double_export_edf(tmp_path): + """Test exporting an EDF file multiple times.""" + raw = _create_raw_for_edf_tests(stim_channel_index=2) + raw.info.set_meas_date("2023-09-04 14:53:09.000") # include subject info and measurement date - info["subject_info"] = dict( + raw.info["subject_info"] = dict( his_id="12345", first_name="mne", last_name="python", @@ -155,69 +166,60 @@ def test_double_export_edf(tmp_path): height=1.75, hand=3, ) - raw = RawArray(data, info) # export once - temp_fname = tmp_path / f"test.{format}" + temp_fname = tmp_path / "test.edf" raw.export(temp_fname, add_ch_type=True) raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) # export again - raw_read.load_data() raw_read.export(temp_fname, add_ch_type=True, overwrite=True) raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) - # stim channel should be dropped - raw.drop_channels("2") - assert raw.ch_names == raw_read.ch_names - # only compare the original length, since extra zeros are appended - orig_raw_len = len(raw) - assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 - ) - assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) + assert_array_equal(raw.times, raw_read.times) # check info for key in set(raw.info) - {"chs"}: assert raw.info[key] == raw_read.info[key] - # check channel types except for 'bio', which loses its type orig_ch_types = raw.get_channel_types() read_ch_types = raw_read.get_channel_types() assert_array_equal(orig_ch_types, read_ch_types) - # check handling of missing subject metadata - del info["subject_info"]["sex"] - raw_2 = RawArray(data, info) - raw_2.export(temp_fname, add_ch_type=True, overwrite=True) - -@pytest.mark.skipif( - not _check_edflib_installed(strict=False), reason="edflib-python not installed" -) -def test_export_edf_annotations(tmp_path): - """Test that exporting EDF preserves annotations.""" - rng = np.random.RandomState(123456) - format = "edf" - ch_types = [ - "eeg", - "eeg", - "stim", - "ecog", - "ecog", - "seeg", - "eog", - "ecg", - "emg", - "dbs", - "bio", - ] 
+@edfio_mark() +def test_edf_physical_range(tmp_path): + """Test exporting an EDF file with different physical range settings.""" + ch_types = ["eeg"] * 4 ch_names = np.arange(len(ch_types)).astype(str).tolist() - info = create_info(ch_names, sfreq=1000, ch_types=ch_types) - data = rng.random(size=(len(ch_names), 2000)) * 1.0e-5 + fs = 1000 + info = create_info(len(ch_types), sfreq=fs, ch_types=ch_types) + data = np.tile( + np.sin(2 * np.pi * 10 * np.arange(0, 2, 1 / fs)) * 1e-5, (len(ch_names), 1) + ) + data = (data.T + [0.1, 0, 0, -0.1]).T # add offsets raw = RawArray(data, info) + # export with physical range per channel type (default) + temp_fname = tmp_path / "test_auto.edf" + raw.export(temp_fname) + raw_read = read_raw_edf(temp_fname, preload=True) + with pytest.raises(AssertionError, match="Arrays are not almost equal"): + assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) + + # export with physical range per channel + temp_fname = tmp_path / "test_per_channel.edf" + raw.export(temp_fname, physical_range="channelwise") + raw_read = read_raw_edf(temp_fname, preload=True) + assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) + + +@edfio_mark() +def test_export_edf_annotations(tmp_path): + """Test that exporting EDF preserves annotations.""" + raw = _create_raw_for_edf_tests() annotations = Annotations( onset=[0.01, 0.05, 0.90, 1.05], duration=[0, 1, 0, 0], @@ -227,7 +229,7 @@ def test_export_edf_annotations(tmp_path): raw.set_annotations(annotations) # export - temp_fname = tmp_path / f"test.{format}" + temp_fname = tmp_path / "test.edf" raw.export(temp_fname) # read in the file @@ -238,24 +240,19 @@ def test_export_edf_annotations(tmp_path): assert_array_equal(raw.annotations.ch_names, raw_read.annotations.ch_names) -@pytest.mark.skipif( - not _check_edflib_installed(strict=False), reason="edflib-python not installed" -) +@edfio_mark() def test_rawarray_edf(tmp_path): """Test saving a Raw array with integer sfreq to EDF.""" - rng = np.random.RandomState(12345) - format = "edf" - ch_types = ["eeg", "eeg", "stim", "ecog", "seeg", "eog", "ecg", "emg", "dbs", "bio"] - ch_names = np.arange(len(ch_types)).astype(str).tolist() - info = create_info(ch_names, sfreq=1000, ch_types=ch_types) - data = rng.random(size=(len(ch_names), 1000)) * 1e-5 + raw = _create_raw_for_edf_tests() # include subject info and measurement date - subject_info = dict( - first_name="mne", last_name="python", birthday=(1992, 1, 20), sex=1, hand=3 + raw.info["subject_info"] = dict( + first_name="mne", + last_name="python", + birthday=(1992, 1, 20), + sex=1, + hand=3, ) - info["subject_info"] = subject_info - raw = RawArray(data, info) time_now = datetime.now() meas_date = datetime( year=time_now.year, @@ -267,125 +264,108 @@ def test_rawarray_edf(tmp_path): tzinfo=timezone.utc, ) raw.set_meas_date(meas_date) - temp_fname = tmp_path / f"test.{format}" + temp_fname = tmp_path / "test.edf" raw.export(temp_fname, add_ch_type=True) raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) - # stim channel should be dropped - raw.drop_channels("2") - assert raw.ch_names == raw_read.ch_names - # only compare the original length, since extra zeros are appended - orig_raw_len = len(raw) - assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 - ) - assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) + assert_array_equal(raw.times, 
raw_read.times) - # check channel types except for 'bio', which loses its type orig_ch_types = raw.get_channel_types() read_ch_types = raw_read.get_channel_types() assert_array_equal(orig_ch_types, read_ch_types) assert raw.info["meas_date"] == raw_read.info["meas_date"] - # channel name can't be longer than 16 characters with the type added - raw_bad = raw.copy() - raw_bad.rename_channels({"1": "abcdefghijklmnopqrstuvwxyz"}) - with pytest.raises(RuntimeError, match="Signal label"), pytest.warns( - RuntimeWarning, match="Data has a non-integer" - ): - raw_bad.export(temp_fname, overwrite=True) - - # include bad birthday that is non-EDF compliant - bad_info = info.copy() - bad_info["subject_info"]["birthday"] = (1700, 1, 20) - raw = RawArray(data, bad_info) - with pytest.raises(RuntimeError, match="Setting patient birth date"): - raw.export(temp_fname, overwrite=True) - # include bad measurement date that is non-EDF compliant - raw = RawArray(data, info) - meas_date = datetime(year=1984, month=1, day=1, tzinfo=timezone.utc) - raw.set_meas_date(meas_date) - with pytest.raises(RuntimeError, match="Setting start date time"): - raw.export(temp_fname, overwrite=True) +@edfio_mark() +def test_edf_export_non_voltage_channels(tmp_path): + """Test saving a Raw array containing a non-voltage channel.""" + temp_fname = tmp_path / "test.edf" - # test that warning is raised if there are non-voltage based channels - raw = RawArray(data, info) + raw = _create_raw_for_edf_tests() raw.set_channel_types({"9": "hbr"}, on_unit_change="ignore") - with pytest.warns(RuntimeWarning, match="Non-voltage channels"): - raw.export(temp_fname, overwrite=True) + raw.export(temp_fname, overwrite=True) # data should match up to the non-accepted channel raw_read = read_raw_edf(temp_fname, preload=True) - orig_raw_len = len(raw) - assert_array_almost_equal( - raw.get_data()[:-1, :], raw_read.get_data()[:, :orig_raw_len], decimal=4 - ) - assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) - - # the data should still match though - raw_read = read_raw_edf(temp_fname, preload=True) - raw.drop_channels("2") assert raw.ch_names == raw_read.ch_names - orig_raw_len = len(raw) - assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 - ) - assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + assert_array_almost_equal(raw.get_data()[:-1], raw_read.get_data()[:-1], decimal=10) + assert_array_almost_equal(raw.get_data()[-1], raw_read.get_data()[-1], decimal=5) + assert_array_equal(raw.times, raw_read.times) + + +@edfio_mark() +def test_channel_label_too_long_for_edf_raises_error(tmp_path): + """Test trying to save an EDF where a channel label is longer than 16 characters.""" + raw = _create_raw_for_edf_tests() + raw.rename_channels({"1": "abcdefghijklmnopqrstuvwxyz"}) + with pytest.raises(RuntimeError, match="Signal label"): + raw.export(tmp_path / "test.edf") + + +@edfio_mark() +def test_measurement_date_outside_range_valid_for_edf(tmp_path): + """Test trying to save an EDF with a measurement date before 1985-01-01.""" + raw = _create_raw_for_edf_tests() + raw.set_meas_date(datetime(year=1984, month=1, day=1, tzinfo=timezone.utc)) + with pytest.raises(ValueError, match="EDF only allows dates from 1985 to 2084"): + raw.export(tmp_path / "test.edf", overwrite=True) -@pytest.mark.skipif( - not _check_edflib_installed(strict=False), reason="edflib-python not installed" +@pytest.mark.filterwarnings("ignore:Data has a non-integer:RuntimeWarning") 
+@pytest.mark.parametrize( + ("physical_range", "exceeded_bound"), + [ + ((-1e6, 0), "maximum"), + ((0, 1e6), "minimum"), + ], ) +@edfio_mark() +def test_export_edf_signal_clipping(tmp_path, physical_range, exceeded_bound): + """Test if exporting data exceeding physical min/max clips and emits a warning.""" + raw = read_raw_fif(fname_raw) + raw.pick(picks=["eeg", "ecog", "seeg"]).load_data() + temp_fname = tmp_path / "test.edf" + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match=f"The {exceeded_bound}"), + ): + raw.export(temp_fname, physical_range=physical_range) + raw_read = read_raw_edf(temp_fname, preload=True) + assert raw_read.get_data().min() >= physical_range[0] + assert raw_read.get_data().max() <= physical_range[1] + + +@edfio_mark() @pytest.mark.parametrize( - ["dataset", "format"], + ("input_path", "warning_msg"), [ - ["test", "edf"], - pytest.param("misc", "edf", marks=[pytest.mark.slowtest, misc._pytest_mark()]), + (fname_raw, "Data has a non-integer"), + pytest.param( + misc_path / "ecog" / "sample_ecog_ieeg.fif", + "EDF format requires", + marks=[pytest.mark.slowtest, misc._pytest_mark()], + ), ], ) -def test_export_raw_edf(tmp_path, dataset, format): +def test_export_raw_edf(tmp_path, input_path, warning_msg): """Test saving a Raw instance to EDF format.""" - if dataset == "test": - raw = read_raw_fif(fname_raw) - elif dataset == "misc": - fname = misc_path / "ecog" / "sample_ecog_ieeg.fif" - raw = read_raw_fif(fname) + raw = read_raw_fif(input_path) # only test with EEG channels raw.pick(picks=["eeg", "ecog", "seeg"]).load_data() - orig_ch_names = raw.ch_names - temp_fname = tmp_path / f"test.{format}" - - # test runtime errors - with pytest.warns() as record: - raw.export(temp_fname, physical_range=(-1e6, 0)) - if dataset == "test": - assert any("Data has a non-integer" in str(rec.message) for rec in record) - assert any("The maximum" in str(rec.message) for rec in record) - remove(temp_fname) - - with pytest.warns() as record: - raw.export(temp_fname, physical_range=(0, 1e6)) - if dataset == "test": - assert any("Data has a non-integer" in str(rec.message) for rec in record) - assert any("The minimum" in str(rec.message) for rec in record) - remove(temp_fname) - - if dataset == "test": - with pytest.warns(RuntimeWarning, match="Data has a non-integer"): - raw.export(temp_fname) - elif dataset == "misc": - with pytest.warns(RuntimeWarning, match="EDF format requires"): - raw.export(temp_fname) + temp_fname = tmp_path / "test.edf" + + with pytest.warns(RuntimeWarning, match=warning_msg): + raw.export(temp_fname) if "epoc" in raw.ch_names: raw.drop_channels(["epoc"]) raw_read = read_raw_edf(temp_fname, preload=True) - assert orig_ch_names == raw_read.ch_names + assert raw.ch_names == raw_read.ch_names # only compare the original length, since extra zeros are appended orig_raw_len = len(raw) @@ -395,7 +375,7 @@ def test_export_raw_edf(tmp_path, dataset, format): # will result in a resolution of 0.09 uV. This resolution # though is acceptable for most EEG manufacturers. 
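The ~0.09 µV figure in the comment just above follows from the EDF quantization model; a quick back-of-envelope check, assuming the ±3200 µV physical range cited in the exporter's comments and its ±32767 digital range:

```python
# EDF stores samples as 16-bit integers: the step size (resolution) is the
# physical span divided by the digital span.
physical_min, physical_max = -3200.0, 3200.0  # uV (example from the comments)
digital_min, digital_max = -32767, 32767
resolution = (physical_max - physical_min) / (digital_max - digital_min)
print(f"{resolution:.4f} uV")  # 0.0977 uV, i.e. the ~0.09 uV mentioned above
```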
assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=8 ) # Due to the data record duration limitations of EDF files, one @@ -407,6 +387,27 @@ def test_export_raw_edf(tmp_path, dataset, format): assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) +@edfio_mark() +def test_export_raw_edf_does_not_fail_on_empty_header_fields(tmp_path): + """Test writing a Raw instance with empty header fields to EDF.""" + rng = np.random.RandomState(123456) + + ch_types = ["eeg"] + info = create_info(len(ch_types), sfreq=1000, ch_types=ch_types) + info["subject_info"] = { + "his_id": "", + "first_name": "", + "middle_name": "", + "last_name": "", + } + info["device_info"] = {"type": "123"} + + data = rng.random(size=(len(ch_types), 1000)) * 1e-5 + raw = RawArray(data, info) + + raw.export(tmp_path / "test.edf", add_ch_type=True) + + @pytest.mark.xfail(reason="eeglabio (usage?) bugs that should be fixed") @pytest.mark.parametrize("preload", (True, False)) def test_export_epochs_eeglab(tmp_path, preload): @@ -459,6 +460,7 @@ def test_export_epochs_eeglab(tmp_path, preload): def test_export_evokeds_to_mff(tmp_path, fmt, do_history): """Test exporting evoked dataset to MFF.""" pytest.importorskip("mffpy", "0.5.7") + pytest.importorskip("defusedxml") evoked = read_evokeds_mff(egi_evoked_fname) export_fname = tmp_path / "evoked.mff" history = [ @@ -515,6 +517,7 @@ def test_export_evokeds_to_mff(tmp_path, fmt, do_history): def test_export_to_mff_no_device(): """Test no device type throws ValueError.""" pytest.importorskip("mffpy", "0.5.7") + pytest.importorskip("defusedxml") evoked = read_evokeds_mff(egi_evoked_fname, condition="Category 1") evoked.info["device_info"] = None with pytest.raises(ValueError, match="No device type."): diff --git a/mne/filter.py b/mne/filter.py index 528128822b8..82b77a17a7c 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -5,6 +5,7 @@ from collections import Counter from copy import deepcopy from functools import partial +from math import gcd import numpy as np from scipy import fft, signal @@ -19,6 +20,7 @@ _setup_cuda_fft_resample, _smart_pad, ) +from .fixes import minimum_phase from .parallel import parallel_func from .utils import ( _check_option, @@ -306,39 +308,7 @@ def _overlap_add_filter( copy=True, pad="reflect_limited", ): - """Filter the signal x using h with overlap-add FFTs. - - Parameters - ---------- - x : array, shape (n_signals, n_times) - Signals to filter. - h : 1d array - Filter impulse response (FIR filter coefficients). Must be odd length - if ``phase='linear'``. - n_fft : int - Length of the FFT. If None, the best size is determined automatically. - phase : str - If ``'zero'``, the delay for the filter is compensated (and it must be - an odd-length symmetric filter). If ``'linear'``, the response is - uncompensated. If ``'zero-double'``, the filter is applied in the - forward and reverse directions. If 'minimum', a minimum-phase - filter will be used. - picks : list | None - See calling functions. - n_jobs : int | str - Number of jobs to run in parallel. Can be ``'cuda'`` if ``cupy`` - is installed properly. - copy : bool - If True, a copy of x, filtered, is returned. Otherwise, it operates - on x in place. - pad : str - Padding type for ``_smart_pad``. - - Returns - ------- - x : array, shape (n_signals, n_times) - x filtered. 
- """ + """Filter the signal x using h with overlap-add FFTs.""" # set up array for filtering, reshape to 2D, operate on last axis x, orig_shape, picks = _prep_for_filtering(x, copy, picks) # Extend the signal by mirroring the edges to reduce transient filter @@ -379,8 +349,8 @@ def _overlap_add_filter( logger.debug("FFT block length: %s" % n_fft) if n_fft < min_fft: raise ValueError( - "n_fft is too short, has to be at least " - "2 * len(h) - 1 (%s), got %s" % (min_fft, n_fft) + f"n_fft is too short, has to be at least 2 * len(h) - 1 ({min_fft}), got " + f"{n_fft}" ) # Figure out if we should use CUDA @@ -492,9 +462,9 @@ def _firwin_design(N, freq, gain, window, sfreq): this_N += 1 - this_N % 2 # make it odd if this_N > N: raise ValueError( - "The requested filter length %s is too short " - "for the requested %0.2f Hz transition band, " - "which requires %s samples" % (N, transition * sfreq / 2.0, this_N) + f"The requested filter length {N} is too short for the requested " + f"{transition * sfreq / 2.0:0.2f} Hz transition band, which " + f"requires {this_N} samples" ) # Construct a lowpass this_h = signal.firwin( @@ -525,34 +495,6 @@ def _construct_fir_filter( (windowing is a smoothing in frequency domain). If x is multi-dimensional, this operates along the last dimension. - - Parameters - ---------- - sfreq : float - Sampling rate in Hz. - freq : 1d array - Frequency sampling points in Hz. - gain : 1d array - Filter gain at frequency sampling points. - Must be all 0 and 1 for fir_design=="firwin". - filter_length : int - Length of the filter to use. Must be odd length if phase == "zero". - phase : str - If 'zero', the delay for the filter is compensated (and it must be - an odd-length symmetric filter). If 'linear', the response is - uncompensated. If 'zero-double', the filter is applied in the - forward and reverse directions. If 'minimum', a minimum-phase - filter will be used. - fir_window : str - The window to use in FIR design, can be "hamming" (default), - "hann", or "blackman". - fir_design : str - Can be "firwin2" or "firwin". - - Returns - ------- - h : array - Filter coefficients. """ assert freq[0] == 0 if fir_design == "firwin2": @@ -561,24 +503,26 @@ def _construct_fir_filter( assert fir_design == "firwin" fir_design = partial(_firwin_design, sfreq=sfreq) # issue a warning if attenuation is less than this - min_att_db = 12 if phase == "minimum" else 20 + min_att_db = 12 if phase == "minimum-half" else 20 # normalize frequencies freq = np.array(freq) / (sfreq / 2.0) if freq[0] != 0 or freq[-1] != 1: raise ValueError( - "freq must start at 0 and end an Nyquist (%s), got %s" % (sfreq / 2.0, freq) + f"freq must start at 0 and end an Nyquist ({sfreq / 2.0}), got {freq}" ) gain = np.array(gain) # Use overlap-add filter with a fixed length N = _check_zero_phase_length(filter_length, phase, gain[-1]) # construct symmetric (linear phase) filter - if phase == "minimum": + if phase == "minimum-half": h = fir_design(N * 2 - 1, freq, gain, window=fir_window) - h = signal.minimum_phase(h) + h = minimum_phase(h) else: h = fir_design(N, freq, gain, window=fir_window) + if phase == "minimum": + h = minimum_phase(h, half=False) assert h.size == N att_db, att_freq = _filter_attenuation(h, freq, gain) if phase == "zero-double": @@ -586,8 +530,8 @@ def _construct_fir_filter( if att_db < min_att_db: att_freq *= sfreq / 2.0 warn( - "Attenuation at stop frequency %0.2f Hz is only %0.2f dB. " - "Increase filter_length for higher attenuation." 
% (att_freq, att_db) + f"Attenuation at stop frequency {att_freq:0.2f} Hz is only {att_db:0.2f} " + "dB. Increase filter_length for higher attenuation." ) return h @@ -596,9 +540,7 @@ def _check_zero_phase_length(N, phase, gain_nyq=0): N = int(N) if N % 2 == 0: if phase == "zero": - raise RuntimeError( - 'filter_length must be odd if phase="zero", ' "got %s" % N - ) + raise RuntimeError(f'filter_length must be odd if phase="zero", got {N}') elif phase == "zero-double" and gain_nyq == 1: N += 1 return N @@ -874,26 +816,20 @@ def construct_iir_filter( # ensure we have a valid ftype if "ftype" not in iir_params: raise RuntimeError( - "ftype must be an entry in iir_params if " - "b" - " " - "and " - "a" - " are not specified" + "ftype must be an entry in iir_params if 'b' and 'a' are not specified." ) ftype = iir_params["ftype"] if ftype not in known_filters: raise RuntimeError( - "ftype must be in filter_dict from " - "scipy.signal (e.g., butter, cheby1, etc.) not " - "%s" % ftype + "ftype must be in filter_dict from scipy.signal (e.g., butter, cheby1, " + f"etc.) not {ftype}" ) # use order-based design f_pass = np.atleast_1d(f_pass) if f_pass.ndim > 1: raise ValueError("frequencies must be 1D, got %dD" % f_pass.ndim) - edge_freqs = ", ".join("%0.2f" % (f,) for f in f_pass) + edge_freqs = ", ".join(f"{f:0.2f}" for f in f_pass) Wp = f_pass / (float(sfreq) / 2) # IT will de designed ftype_nice = _ftype_dict.get(ftype, ftype) @@ -934,14 +870,7 @@ def construct_iir_filter( Ws = np.asanyarray(f_stop) / (float(sfreq) / 2) if "gpass" not in iir_params or "gstop" not in iir_params: raise ValueError( - "iir_params must have at least " - "gstop" - " and" - " " - "gpass" - " (or " - "N" - ") entries" + "iir_params must have at least 'gstop' and 'gpass' (or N) entries." ) system = signal.iirdesign( Wp, @@ -967,8 +896,8 @@ def construct_iir_filter( # 2 * 20 here because we do forward-backward filtering if phase in ("zero", "zero-double"): cutoffs *= 2 - cutoffs = ", ".join(["%0.2f" % (c,) for c in cutoffs]) - logger.info("- Cutoff%s at %s Hz: %s dB" % (_pl(f_pass), edge_freqs, cutoffs)) + cutoffs = ", ".join([f"{c:0.2f}" for c in cutoffs]) + logger.info(f"- Cutoff{_pl(f_pass)} at {edge_freqs} Hz: {cutoffs} dB") # now deal with padding if "padlen" not in iir_params: padlen = estimate_ringing_samples(system) @@ -1253,16 +1182,15 @@ def create_filter( # If no data specified, sanity checking will be skipped if data is None: logger.info( - "No data specified. Sanity checks related to the length of" - " the signal relative to the filter order will be" - " skipped." + "No data specified. Sanity checks related to the length of the signal " + "relative to the filter order will be skipped." 
) if h_freq is not None: h_freq = np.array(h_freq, float).ravel() if (h_freq > (sfreq / 2.0)).any(): raise ValueError( - "h_freq (%s) must be less than the Nyquist " - "frequency %s" % (h_freq, sfreq / 2.0) + f"h_freq ({h_freq}) must be less than the Nyquist frequency " + f"{sfreq / 2.0}" ) if l_freq is not None: l_freq = np.array(l_freq, float).ravel() @@ -1302,7 +1230,7 @@ def create_filter( gain = [1.0, 1.0] if l_freq is None and h_freq is not None: h_freq = h_freq.item() - logger.info("Setting up low-pass filter at %0.2g Hz" % (h_freq,)) + logger.info(f"Setting up low-pass filter at {h_freq:0.2g} Hz") ( data, sfreq, @@ -1339,7 +1267,7 @@ def create_filter( gain += [0] elif l_freq is not None and h_freq is None: l_freq = l_freq.item() - logger.info("Setting up high-pass filter at %0.2g Hz" % (l_freq,)) + logger.info(f"Setting up high-pass filter at {l_freq:0.2g} Hz") ( data, sfreq, @@ -1378,7 +1306,7 @@ def create_filter( if (l_freq < h_freq).any(): l_freq, h_freq = l_freq.item(), h_freq.item() logger.info( - "Setting up band-pass filter from %0.2g - %0.2g Hz" % (l_freq, h_freq) + f"Setting up band-pass filter from {l_freq:0.2g} - {h_freq:0.2g} Hz" ) ( data, @@ -1430,7 +1358,7 @@ def create_filter( msg = "Setting up band-stop filter" if len(l_freq) == 1: l_freq, h_freq = l_freq.item(), h_freq.item() - msg += " from %0.2g - %0.2g Hz" % (h_freq, l_freq) + msg += f" from {h_freq:0.2g} - {l_freq:0.2g} Hz" logger.info(msg) # Note: order of outputs is intentionally switched here! ( @@ -1491,7 +1419,7 @@ def create_filter( freq = np.r_[freq, [sfreq / 2.0]] gain = np.r_[gain, [1.0]] if np.any(np.abs(np.diff(gain, 2)) > 1): - raise ValueError("Stop bands are not sufficiently " "separated.") + raise ValueError("Stop bands are not sufficiently separated.") if method == "fir": out = _construct_fir_filter( sfreq, freq, gain, filter_length, phase, fir_window, fir_design @@ -1870,21 +1798,14 @@ def _check_filterable(x, kind="filtered", alternative="filter"): pass else: raise TypeError( - "This low-level function only operates on np.ndarray " - f"instances. To get a {kind} {name} instance, use a method " - f"like `inst_new = inst.copy().{alternative}(...)` " - "instead." + "This low-level function only operates on np.ndarray instances. To get " + f"a {kind} {name} instance, use a method like `inst_new = inst.copy()." + f"{alternative}(...)` instead." ) _validate_type(x, (np.ndarray, list, tuple), f"Data to be {kind}") x = np.asanyarray(x) if x.dtype != np.float64: - raise ValueError( - "Data to be %s must be real floating, got %s" - % ( - kind, - x.dtype, - ) - ) + raise ValueError(f"Data to be {kind} must be real floating, got {x.dtype}") return x @@ -1898,12 +1819,13 @@ def resample( x, up=1.0, down=1.0, - npad=100, + *, axis=-1, - window="boxcar", + window="auto", n_jobs=None, - pad="reflect_limited", - *, + pad="auto", + npad=100, + method="fft", verbose=None, ): """Resample an array. @@ -1918,15 +1840,18 @@ def resample( Factor to upsample by. down : float Factor to downsample by. - %(npad)s axis : int Axis along which to resample (default is the last axis). %(window_resample)s %(n_jobs_cuda)s - %(pad)s - The default is ``'reflect_limited'``. + ``n_jobs='cuda'`` is only supported when ``method="fft"``. + %(pad_resample_auto)s .. versionadded:: 0.15 + %(npad_resample)s + %(method_resample)s + + .. 
versionadded:: 1.7
     %(verbose)s

     Returns
@@ -1936,26 +1861,16 @@ def resample(

     Notes
     -----
-    This uses (hopefully) intelligent edge padding and frequency-domain
-    windowing improve scipy.signal.resample's resampling method, which
+    When using ``method="fft"`` (default),
+    this uses (hopefully) intelligent edge padding and frequency-domain
+    windowing to improve :func:`scipy.signal.resample`'s resampling method, which
     we have adapted for our use here. Choices of npad and window have
     important consequences, and the default choices should work well
     for most natural signals.
-
-    Resampling arguments are broken into "up" and "down" components for future
-    compatibility in case we decide to use an upfirdn implementation. The
-    current implementation is functionally equivalent to passing
-    up=up/down and down=1.
     """
-    # check explicitly for backwards compatibility
-    if not isinstance(axis, int):
-        err = (
-            "The axis parameter needs to be an integer (got %s). "
-            "The axis parameter was missing from this function for a "
-            "period of time, you might be intending to specify the "
-            "subsequent window parameter." % repr(axis)
-        )
-        raise TypeError(err)
+    _validate_type(method, str, "method")
+    _validate_type(pad, str, "pad")
+    _check_option("method", method, ("fft", "polyphase"))

     # make sure our arithmetic will work
     x = _check_filterable(x, "resampled", "resample")
@@ -1963,31 +1878,89 @@ def resample(
     del up, down
     if axis < 0:
         axis = x.ndim + axis
-    orig_last_axis = x.ndim - 1
-    if axis != orig_last_axis:
-        x = x.swapaxes(axis, orig_last_axis)
-    orig_shape = x.shape
-    x_len = orig_shape[-1]
-    if x_len == 0:
-        warn("x has zero length along last axis, returning a copy of x")
+    if x.shape[axis] == 0:
+        warn(f"x has zero length along axis={axis}, returning a copy of x")
         return x.copy()
-    bad_msg = 'npad must be "auto" or an integer'
+
+    # prep for resampling along the last axis (swap axis with last then reshape)
+    out_shape = list(x.shape)
+    out_shape.pop(axis)
+    out_shape.append(final_len)
+    x = np.atleast_2d(x.swapaxes(axis, -1).reshape((-1, x.shape[axis])))
+
+    # do the resampling using FFT or polyphase methods
+    kwargs = dict(pad=pad, window=window, n_jobs=n_jobs)
+    if method == "fft":
+        y = _resample_fft(x, npad=npad, ratio=ratio, final_len=final_len, **kwargs)
+    else:
+        up, down, kwargs["window"] = _prep_polyphase(
+            ratio, x.shape[-1], final_len, window
+        )
+        half_len = len(kwargs["window"]) // 2
+        logger.info(
+            f"Polyphase resampling neighborhood: ±{half_len} "
+            f"input sample{_pl(half_len)}"
+        )
+        y = _resample_polyphase(x, up=up, down=down, **kwargs)
+    assert y.shape[-1] == final_len
+
+    # restore dimensions (reshape then swap axis with last)
+    y = y.reshape(out_shape).swapaxes(axis, -1)
+
+    return y
+
+
+def _prep_polyphase(ratio, x_len, final_len, window):
+    if isinstance(window, str) and window == "auto":
+        window = ("kaiser", 5.0)  # SciPy default
+    up = final_len
+    down = x_len
+    g_ = gcd(up, down)
+    up = up // g_
+    down = down // g_
+    # Figure out our signal neighborhood and design window (adapted from SciPy)
+    if not isinstance(window, (list, np.ndarray)):
+        # Design a linear-phase low-pass FIR filter
+        max_rate = max(up, down)
+        f_c = 1.0 / max_rate  # cutoff of FIR filter (rel.
to Nyquist) + half_len = 10 * max_rate # reasonable cutoff for sinc-like function + window = signal.firwin(2 * half_len + 1, f_c, window=window) + return up, down, window + + +def _resample_polyphase(x, *, up, down, pad, window, n_jobs): + if pad == "auto": + pad = "reflect" + kwargs = dict(padtype=pad, window=window, up=up, down=down) + _validate_type( + n_jobs, (None, "int-like"), "n_jobs", extra="when method='polyphase'" + ) + parallel, p_fun, n_jobs = parallel_func(signal.resample_poly, n_jobs) + if n_jobs == 1: + y = signal.resample_poly(x, axis=-1, **kwargs) + else: + y = np.array(parallel(p_fun(x_, **kwargs) for x_ in x)) + return y + + +def _resample_fft(x_flat, *, ratio, final_len, pad, window, npad, n_jobs): + x_len = x_flat.shape[-1] + pad = "reflect_limited" if pad == "auto" else pad + if (isinstance(window, str) and window == "auto") or window is None: + window = "boxcar" if isinstance(npad, str): - if npad != "auto": - raise ValueError(bad_msg) + _check_option("npad", npad, ("auto",), extra="when a string") # Figure out reasonable pad that gets us to a power of 2 min_add = min(x_len // 8, 100) * 2 npad = 2 ** int(np.ceil(np.log2(x_len + min_add))) - x_len npad, extra = divmod(npad, 2) npads = np.array([npad, npad + extra], int) else: - if npad != int(npad): - raise ValueError(bad_msg) + npad = _ensure_int(npad, "npad", extra="or 'auto'") npads = np.array([npad, npad], int) del npad # prep for resampling now - x_flat = x.reshape((-1, x_len)) orig_len = x_len + npads.sum() # length after padding new_len = max(int(round(ratio * orig_len)), 1) # length after resampling to_removes = [int(round(ratio * npads[0]))] @@ -1997,15 +1970,12 @@ def resample( # assert np.abs(to_removes[1] - to_removes[0]) <= int(np.ceil(ratio)) # figure out windowing function - if window is not None: - if callable(window): - W = window(fft.fftfreq(orig_len)) - elif isinstance(window, np.ndarray) and window.shape == (orig_len,): - W = window - else: - W = fft.ifftshift(signal.get_window(window, orig_len)) + if callable(window): + W = window(fft.fftfreq(orig_len)) + elif isinstance(window, np.ndarray) and window.shape == (orig_len,): + W = window else: - W = np.ones(orig_len) + W = fft.ifftshift(signal.get_window(window, orig_len)) W *= float(new_len) / float(orig_len) # figure out if we should use CUDA @@ -2015,7 +1985,7 @@ def resample( # use of the 'flat' window is recommended for minimal ringing parallel, p_fun, n_jobs = parallel_func(_fft_resample, n_jobs) if n_jobs == 1: - y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x.dtype) + y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x_flat.dtype) for xi, x_ in enumerate(x_flat): y[xi] = _fft_resample(x_, new_len, npads, to_removes, cuda_dict, pad) else: @@ -2024,12 +1994,6 @@ def resample( ) y = np.array(y) - # Restore the original array shape (modified for resampling) - y.shape = orig_shape[:-1] + (y.shape[1],) - if axis != orig_last_axis: - y = y.swapaxes(axis, orig_last_axis) - assert y.shape[axis] == final_len - return y @@ -2141,7 +2105,7 @@ def detrend(x, order=1, axis=-1): "blackman": dict(name="Blackman", ripple=0.0017, attenuation=74), } _known_fir_windows = tuple(sorted(_fir_window_dict.keys())) -_known_phases_fir = ("linear", "zero", "zero-double", "minimum") +_known_phases_fir = ("linear", "zero", "zero-double", "minimum", "minimum-half") _known_phases_iir = ("zero", "zero-double", "forward") _known_fir_designs = ("firwin", "firwin2") _fir_design_dict = { @@ -2235,15 +2199,12 @@ def float_array(c): if l_freq is not 
None: l_freq = cast(l_freq) if np.any(l_freq <= 0): - raise ValueError( - "highpass frequency %s must be greater than zero" % (l_freq,) - ) + raise ValueError(f"highpass frequency {l_freq} must be greater than zero") if h_freq is not None: h_freq = cast(h_freq) if np.any(h_freq >= sfreq / 2.0): raise ValueError( - "lowpass frequency %s must be less than Nyquist " - "(%s)" % (h_freq, sfreq / 2.0) + f"lowpass frequency {h_freq} must be less than Nyquist ({sfreq / 2.0})" ) dB_cutoff = False # meaning, don't try to compute or report @@ -2263,12 +2224,9 @@ def float_array(c): logger.info("FIR filter parameters") logger.info("---------------------") logger.info( - "Designing a %s, %s, %s %s filter:" - % (report_pass, report_phase, causality, kind) - ) - logger.info( - "- %s design (%s) method" % (_fir_design_dict[fir_design], fir_design) + f"Designing a {report_pass}, {report_phase}, {causality} {kind} filter:" ) + logger.info(f"- {_fir_design_dict[fir_design]} design ({fir_design}) method") this_dict = _fir_window_dict[fir_window] if fir_design == "firwin": logger.info( @@ -2282,8 +2240,8 @@ def float_array(c): if isinstance(l_trans_bandwidth, str): if l_trans_bandwidth != "auto": raise ValueError( - 'l_trans_bandwidth must be "auto" if ' - 'string, got "%s"' % l_trans_bandwidth + 'l_trans_bandwidth must be "auto" if string, got "' + f'{l_trans_bandwidth}"' ) l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.0), l_freq) l_trans_rep = np.array(l_trans_bandwidth, float) @@ -2305,7 +2263,7 @@ def float_array(c): l_trans_bandwidth = cast(l_trans_bandwidth) if np.any(l_trans_bandwidth <= 0): raise ValueError( - "l_trans_bandwidth must be positive, got %s" % (l_trans_bandwidth,) + f"l_trans_bandwidth must be positive, got {l_trans_bandwidth}" ) l_stop = l_freq - l_trans_bandwidth if reverse: # band-stop style @@ -2313,10 +2271,9 @@ def float_array(c): l_freq += l_trans_bandwidth if np.any(l_stop < 0): raise ValueError( - "Filter specification invalid: Lower stop " - "frequency negative (%0.2f Hz). Increase pass" - " frequency or reduce the transition " - "bandwidth (l_trans_bandwidth)" % l_stop + "Filter specification invalid: Lower stop frequency negative (" + f"{l_stop:0.2f} Hz). 
Increase pass frequency or reduce the " + "transition bandwidth (l_trans_bandwidth)" ) if h_freq is not None: # low-pass component if isinstance(h_trans_bandwidth, str): @@ -2346,7 +2303,7 @@ def float_array(c): h_trans_bandwidth = cast(h_trans_bandwidth) if np.any(h_trans_bandwidth <= 0): raise ValueError( - "h_trans_bandwidth must be positive, got %s" % (h_trans_bandwidth,) + f"h_trans_bandwidth must be positive, got {h_trans_bandwidth}" ) h_stop = h_freq + h_trans_bandwidth if reverse: # band-stop style @@ -2354,8 +2311,8 @@ def float_array(c): h_freq -= h_trans_bandwidth if np.any(h_stop > sfreq / 2): raise ValueError( - "Effective band-stop frequency (%s) is too " - "high (maximum based on Nyquist is %s)" % (h_stop, sfreq / 2.0) + f"Effective band-stop frequency ({h_stop}) is too high (maximum " + f"based on Nyquist is {sfreq / 2.0})" ) if isinstance(filter_length, str) and filter_length.lower() == "auto": @@ -2366,9 +2323,7 @@ def float_array(c): if l_freq is not None: l_check = min(np.atleast_1d(l_trans_bandwidth)) mult_fact = 2.0 if fir_design == "firwin2" else 1.0 - filter_length = "%ss" % ( - _length_factors[fir_window] * mult_fact / float(min(h_check, l_check)), - ) + filter_length = f"{_length_factors[fir_window] * mult_fact / float(min(h_check, l_check))}s" # noqa: E501 next_pow_2 = False # disable old behavior else: next_pow_2 = isinstance(filter_length, str) and phase == "zero-double" @@ -2381,15 +2336,12 @@ def float_array(c): filter_length += (filter_length - 1) % 2 logger.info( - "- Filter length: %s samples (%0.3f s)" - % (filter_length, filter_length / sfreq) + f"- Filter length: {filter_length} samples ({filter_length / sfreq:0.3f} s)" ) logger.info("") if filter_length <= 0: - raise ValueError( - "filter_length must be positive, got %s" % (filter_length,) - ) + raise ValueError(f"filter_length must be positive, got {filter_length}") if next_pow_2: filter_length = 2 ** int(np.ceil(np.log2(filter_length))) @@ -2404,9 +2356,8 @@ def float_array(c): filter_length = len_x if filter_length > len_x and not (l_freq is None and h_freq is None): warn( - "filter_length (%s) is longer than the signal (%s), " - "distortion is likely. Reduce filter length or filter a " - "longer signal." % (filter_length, len_x) + f"filter_length ({filter_length}) is longer than the signal ({len_x}), " + "distortion is likely. Reduce filter length or filter a longer signal." ) logger.debug("Using filter length: %s" % filter_length) @@ -2427,10 +2378,8 @@ def float_array(c): def _check_resamp_noop(sfreq, o_sfreq, rtol=1e-6): if np.isclose(sfreq, o_sfreq, atol=0, rtol=rtol): logger.info( - ( - f"Sampling frequency of the instance is already {sfreq}, " - "returning unmodified." - ) + f"Sampling frequency of the instance is already {sfreq}, returning " + "unmodified." ) return True return False @@ -2456,7 +2405,7 @@ def savgol_filter(self, h_freq, verbose=None): Returns ------- - inst : instance of Epochs or Evoked + inst : instance of Epochs, Evoked or SourceEstimate The object with the filtering applied. See Also @@ -2469,6 +2418,8 @@ def savgol_filter(self, h_freq, verbose=None): https://gist.github.com/larsoner/bbac101d50176611136b + When working on SourceEstimates the sample rate of the original data is inferred from tstep. + .. versionadded:: 0.9.0 References @@ -2484,13 +2435,19 @@ def savgol_filter(self, h_freq, verbose=None): >>> evoked.savgol_filter(10.) 
# low-pass at around 10 Hz # doctest:+SKIP >>> evoked.plot() # doctest:+SKIP """ # noqa: E501 + from .source_estimate import _BaseSourceEstimate + _check_preload(self, "inst.savgol_filter") + if not isinstance(self, _BaseSourceEstimate): + s_freq = self.info["sfreq"] + else: + s_freq = 1 / self.tstep h_freq = float(h_freq) - if h_freq >= self.info["sfreq"] / 2.0: + if h_freq >= s_freq / 2.0: raise ValueError("h_freq must be less than half the sample rate") # savitzky-golay filtering - window_length = (int(np.round(self.info["sfreq"] / h_freq)) // 2) * 2 + 1 + window_length = (int(np.round(s_freq / h_freq)) // 2) * 2 + 1 logger.info("Using savgol length %d" % window_length) self._data[:] = signal.savgol_filter( self._data, axis=-1, polyorder=5, window_length=window_length @@ -2517,7 +2474,7 @@ def filter( *, verbose=None, ): - """Filter a subset of channels. + """Filter a subset of channels/vertices. Parameters ---------- @@ -2541,7 +2498,7 @@ def filter( Returns ------- - inst : instance of Epochs, Evoked, or Raw + inst : instance of Epochs, Evoked, SourceEstimate, or Raw The filtered data. See Also @@ -2578,6 +2535,9 @@ def filter( ``len(picks) * n_times`` additional time points need to be temporarily stored in memory. + When working on SourceEstimates, the sample rate of the original + data is inferred from tstep. + For more information, see the tutorials :ref:`disc-filtering` and :ref:`tut-filter-resample` and :func:`mne.filter.create_filter`. @@ -2586,11 +2546,16 @@ """ from .annotations import _annotations_starts_stops from .io import BaseRaw + from .source_estimate import _BaseSourceEstimate _check_preload(self, "inst.filter") + if not isinstance(self, _BaseSourceEstimate): + update_info, picks = _filt_check_picks(self.info, picks, l_freq, h_freq) + s_freq = self.info["sfreq"] + else: + s_freq = 1.0 / self.tstep if pad is None and method != "iir": pad = "edge" - update_info, picks = _filt_check_picks(self.info, picks, l_freq, h_freq) if isinstance(self, BaseRaw): # Deal with annotations onsets, ends = _annotations_starts_stops( @@ -2609,7 +2574,7 @@ use_verbose = verbose if si == max_idx else "error" filter_data( self._data[:, start:stop], - self.info["sfreq"], + s_freq, l_freq, h_freq, picks, @@ -2626,20 +2591,22 @@ pad=pad, verbose=use_verbose, ) - # update info if filter is applied to all data channels, + # update info if filter is applied to all data channels/vertices, # and it's not a band-stop filter - _filt_update_info(self.info, update_info, l_freq, h_freq) + if not isinstance(self, _BaseSourceEstimate): + _filt_update_info(self.info, update_info, l_freq, h_freq) return self @verbose def resample( self, sfreq, + *, npad="auto", - window="boxcar", + window="auto", n_jobs=None, pad="edge", - *, + method="fft", verbose=None, ): """Resample data. @@ -2656,11 +2623,12 @@ def resample( %(npad)s %(window_resample)s %(n_jobs_cuda)s - %(pad)s - The default is ``'edge'``, which pads with the edge values of each - vector. + %(pad_resample)s .. versionadded:: 0.15 + %(method_resample)s + + .. 
versionadded:: 1.7 %(verbose)s Returns ------- @@ -2681,7 +2649,7 @@ def resample( from .evoked import Evoked # Should be guaranteed by our inheritance, and the fact that - # mne.io.BaseRaw overrides this method + # mne.io.BaseRaw and _BaseSourceEstimate override this method assert isinstance(self, (BaseEpochs, Evoked)) sfreq = float(sfreq) @@ -2691,7 +2659,14 @@ def resample( _check_preload(self, "inst.resample") self._data = resample( - self._data, sfreq, o_sfreq, npad, window=window, n_jobs=n_jobs, pad=pad + self._data, + sfreq, + o_sfreq, + npad=npad, + window=window, + n_jobs=n_jobs, + pad=pad, + method=method, ) lowpass = self.info.get("lowpass") lowpass = np.inf if lowpass is None else lowpass @@ -2711,13 +2686,13 @@ def resample( def apply_hilbert( self, picks=None, envelope=False, n_jobs=None, n_fft="auto", *, verbose=None ): - """Compute analytic signal or envelope for a subset of channels. + """Compute analytic signal or envelope for a subset of channels/vertices. Parameters ---------- %(picks_all_data_noref)s envelope : bool - Compute the envelope signal of each channel. Default False. + Compute the envelope signal of each channel/vertex. Default False. See Notes. %(n_jobs)s n_fft : int | None | str @@ -2729,19 +2704,19 @@ def apply_hilbert( Returns ------- - self : instance of Raw, Epochs, or Evoked + self : instance of Raw, Epochs, Evoked, or SourceEstimate The raw object with transformed data. Notes ----- **Parameters** - If ``envelope=False``, the analytic signal for the channels defined in + If ``envelope=False``, the analytic signal for the channels/vertices defined in ``picks`` is computed and the data of the Raw object is converted to a complex representation (the analytic signal is complex valued). If ``envelope=True``, the absolute value of the analytic signal for the - channels defined in ``picks`` is computed, resulting in the envelope + channels/vertices defined in ``picks`` is computed, resulting in the envelope signal. .. warning: Do not use ``envelope=True`` if you intend to compute @@ -2774,24 +2749,30 @@ def apply_hilbert( by computing the analytic signal in sensor space, applying the MNE inverse, and computing the envelope in source space. 
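For illustration, a minimal sketch (not part of the patch) of the SourceEstimate path these hunks enable, on synthetic data; it assumes the full patch is applied so that ``_BaseSourceEstimate`` exposes these mixin methods:

import numpy as np

import mne

# Synthetic SourceEstimate: one vertex per hemisphere, 1 s of data at 100 Hz.
rng = np.random.default_rng(0)
stc = mne.SourceEstimate(
    rng.standard_normal((2, 100)),
    vertices=[np.array([0]), np.array([0])],
    tmin=0.0,
    tstep=0.01,  # no info["sfreq"] here; the rate is inferred as 1 / tstep = 100 Hz
)
stc.filter(l_freq=None, h_freq=20.0)  # low-pass at 20 Hz, applied vertex-wise
stc.savgol_filter(10.0)  # Savitzky-Golay low-pass; window length from 1 / tstep
stc.apply_hilbert(envelope=True)  # amplitude envelope per vertex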
""" + from .source_estimate import _BaseSourceEstimate + + if not isinstance(self, _BaseSourceEstimate): + use_info = self.info + else: + use_info = len(self._data) _check_preload(self, "inst.apply_hilbert") + picks = _picks_to_idx(use_info, picks, exclude=(), with_ref_meg=False) + if n_fft is None: n_fft = len(self.times) elif isinstance(n_fft, str): if n_fft != "auto": raise ValueError( - "n_fft must be an integer, string, or None, " - "got %s" % (type(n_fft),) + f"n_fft must be an integer, string, or None, got {type(n_fft)}" ) n_fft = next_fast_len(len(self.times)) n_fft = int(n_fft) if n_fft < len(self.times): raise ValueError( - "n_fft (%d) must be at least the number of time " - "points (%d)" % (n_fft, len(self.times)) + f"n_fft ({n_fft}) must be at least the number of time points (" + f"{len(self.times)})" ) dtype = None if envelope else np.complex128 - picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False) args, kwargs = (), dict(n_fft=n_fft, envelope=envelope) data_in = self._data @@ -2822,9 +2803,7 @@ def _check_fun(fun, d, *args, **kwargs): if not isinstance(d, np.ndarray): raise TypeError("Return value must be an ndarray") if d.shape != want_shape: - raise ValueError( - "Return data must have shape %s not %s" % (want_shape, d.shape) - ) + raise ValueError(f"Return data must have shape {want_shape} not {d.shape}") return d diff --git a/mne/fixes.py b/mne/fixes.py index 1d3cc5aadb4..f7534377b5a 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -31,9 +31,8 @@ ############################################################################### # distutils -# distutils has been deprecated since Python 3.10 and is scheduled for removal -# from the standard library with the release of Python 3.12. For version -# comparisons, we use setuptools's `parse_version` if available. +# distutils has been deprecated since Python 3.10 and was removed +# from the standard library with the release of Python 3.12. def _compare_version(version_a, operator, version_b): @@ -99,7 +98,7 @@ def _safe_svd(A, **kwargs): except np.linalg.LinAlgError as exp: from .utils import warn - warn("SVD error (%s), attempting to use GESVD instead of GESDD" % (exp,)) + warn(f"SVD error ({exp}), attempting to use GESVD instead of GESDD") return linalg.svd(A, lapack_driver="gesvd", **kwargs) @@ -114,7 +113,7 @@ def _csc_matrix_cast(x): def rng_uniform(rng): - """Get the unform/randint from the rng.""" + """Get the uniform/randint from the rng.""" # prefer Generator.integers, fall back to RandomState.randint return getattr(rng, "integers", getattr(rng, "randint", None)) @@ -193,8 +192,8 @@ def _get_param_names(cls): "scikit-learn estimators should always " "specify their parameters in the signature" " of their __init__ (no varargs)." - " %s with constructor %s doesn't " - " follow this convention." % (cls, init_signature) + f" {cls} with constructor {init_signature} doesn't " + " follow this convention." ) # Extract and sort argument names excluding 'self' return sorted([p.name for p in parameters]) @@ -223,7 +222,7 @@ def get_params(self, deep=True): try: with warnings.catch_warnings(record=True) as w: value = getattr(self, key, None) - if len(w) and w[0].category == DeprecationWarning: + if len(w) and w[0].category is DeprecationWarning: # if the parameter is deprecated, don't show it continue finally: @@ -265,9 +264,9 @@ def set_params(self, **params): name, sub_name = split if name not in valid_params: raise ValueError( - "Invalid parameter %s for estimator %s. 
" + f"Invalid parameter {name} for estimator {self}. " "Check the list of available parameters " - "with `estimator.get_params().keys()`." % (name, self) + "with `estimator.get_params().keys()`." ) sub_object = valid_params[name] sub_object.set_params(**{sub_name: value}) @@ -275,10 +274,10 @@ def set_params(self, **params): # simple objects case if key not in valid_params: raise ValueError( - "Invalid parameter %s for estimator %s. " + f"Invalid parameter {key} for estimator " + f"{self.__class__.__name__}. " "Check the list of available parameters " "with `estimator.get_params().keys()`." - % (key, self.__class__.__name__) ) setattr(self, key, value) return self @@ -288,7 +287,7 @@ def __repr__(self): # noqa: D105 pprint(self.get_params(deep=False), params) params.seek(0) class_name = self.__class__.__name__ - return "%s(%s)" % (class_name, params.read().strip()) + return f"{class_name}({params.read().strip()})" # __getstate__ and __setstate__ are omitted because they only contain # conditionals that are not satisfied by our objects (e.g., @@ -867,7 +866,7 @@ def pinvh(a, rtol=None): def pinv(a, rtol=None): """Compute a pseudo-inverse of a matrix.""" - u, s, vh = np.linalg.svd(a, full_matrices=False) + u, s, vh = _safe_svd(a, full_matrices=False) del a maxS = np.max(s) if rtol is None: @@ -890,3 +889,58 @@ def _numpy_h5py_dep(): "ignore", "`product` is deprecated.*", DeprecationWarning ) yield + + +def minimum_phase(h, method="homomorphic", n_fft=None, *, half=True): + """Wrap scipy.signal.minimum_phase with half option.""" + # Can be removed once + from scipy.fft import fft, ifft + from scipy.signal import minimum_phase as sp_minimum_phase + + assert isinstance(method, str) and method == "homomorphic" + + if "half" in inspect.getfullargspec(sp_minimum_phase).kwonlyargs: + return sp_minimum_phase(h, method=method, n_fft=n_fft, half=half) + h = np.asarray(h) + if np.iscomplexobj(h): + raise ValueError("Complex filters not supported") + if h.ndim != 1 or h.size <= 2: + raise ValueError("h must be 1-D and at least 2 samples long") + n_half = len(h) // 2 + if not np.allclose(h[-n_half:][::-1], h[:n_half]): + warnings.warn( + "h does not appear to by symmetric, conversion may fail", + RuntimeWarning, + stacklevel=2, + ) + if n_fft is None: + n_fft = 2 ** int(np.ceil(np.log2(2 * (len(h) - 1) / 0.01))) + n_fft = int(n_fft) + if n_fft < len(h): + raise ValueError("n_fft must be at least len(h)==%s" % len(h)) + + # zero-pad; calculate the DFT + h_temp = np.abs(fft(h, n_fft)) + # take 0.25*log(|H|**2) = 0.5*log(|H|) + h_temp += 1e-7 * h_temp[h_temp > 0].min() # don't let log blow up + np.log(h_temp, out=h_temp) + if half: # halving of magnitude spectrum optional + h_temp *= 0.5 + # IDFT + h_temp = ifft(h_temp).real + # multiply pointwise by the homomorphic filter + # lmin[n] = 2u[n] - d[n] + # i.e., double the positive frequencies and zero out the negative ones; + # Oppenheim+Shafer 3rd ed p991 eq13.42b and p1004 fig13.7 + win = np.zeros(n_fft) + win[0] = 1 + stop = n_fft // 2 + win[1:stop] = 2 + if n_fft % 2: + win[stop] = 1 + h_temp *= win + h_temp = ifft(np.exp(fft(h_temp))) + h_minimum = h_temp.real + + n_out = (n_half + len(h) % 2) if half else len(h) + return h_minimum[:n_out] diff --git a/mne/forward/_compute_forward.py b/mne/forward/_compute_forward.py index 6c4e157f7f9..641f315239a 100644 --- a/mne/forward/_compute_forward.py +++ b/mne/forward/_compute_forward.py @@ -661,7 +661,7 @@ def _magnetic_dipole_field_vec(rrs, coils, too_close="raise"): rmags, cosmags, ws, bins = 
diff --git a/mne/forward/_compute_forward.py b/mne/forward/_compute_forward.py index 6c4e157f7f9..641f315239a 100644 --- a/mne/forward/_compute_forward.py +++ b/mne/forward/_compute_forward.py @@ -661,7 +661,7 @@ def _magnetic_dipole_field_vec(rrs, coils, too_close="raise"): rmags, cosmags, ws, bins = _triage_coils(coils) fwd, min_dist = _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close) if min_dist < _MIN_DIST_LIMIT: - msg = "Coil too close (dist = %g mm)" % (min_dist * 1000,) + msg = f"Coil too close (dist = {min_dist * 1000:g} mm)" if too_close == "raise": raise RuntimeError(msg) func = warn if too_close == "warning" else logger.info diff --git a/mne/forward/_lead_dots.py b/mne/forward/_lead_dots.py index b158f6db07f..3b2118de409 100644 --- a/mne/forward/_lead_dots.py +++ b/mne/forward/_lead_dots.py @@ -69,12 +69,12 @@ def _get_legen_table( # Updated due to API change (GH 1167) os.makedirs(fname) if ch_type == "meg": - fname = op.join(fname, "legder_%s_%s.bin" % (n_coeff, n_interp)) + fname = op.join(fname, f"legder_{n_coeff}_{n_interp}.bin") leg_fun = _get_legen_der extra_str = " derivative" lut_shape = (n_interp + 1, n_coeff, 3) else: # 'eeg' - fname = op.join(fname, "legval_%s_%s.bin" % (n_coeff, n_interp)) + fname = op.join(fname, f"legval_{n_coeff}_{n_interp}.bin") leg_fun = _get_legen extra_str = "" lut_shape = (n_interp + 1, n_coeff) diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 04b0eaf9592..24131ad4a10 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -93,7 +93,7 @@ def _read_coil_def_file(fname, use_registry=True): if not use_registry or fname not in _coil_registry: big_val = 0.5 coils = list() - with open(fname, "r") as fid: + with open(fname) as fid: lines = fid.readlines() lines = lines[::-1] while len(lines) > 0: @@ -299,8 +299,8 @@ def _setup_bem(bem, bem_extra, neeg, mri_head_t, allow_none=False, verbose=None) else: if bem["surfs"][0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: raise RuntimeError( - "BEM is in %s coordinates, should be in MRI" - % (_coord_frame_name(bem["surfs"][0]["coord_frame"]),) + f'BEM is in {_coord_frame_name(bem["surfs"][0]["coord_frame"])} ' + 'coordinates, should be in MRI' ) if neeg > 0 and len(bem["surfs"]) == 1: raise RuntimeError( @@ -661,7 +661,7 @@ def make_forward_solution( followed by :func:`mne.convert_forward_solution`. .. note:: - If the BEM solution was computed with :doc:`OpenMEEG ` + If the BEM solution was computed with `OpenMEEG <https://openmeeg.github.io>`__ in :func:`mne.make_bem_solution`, then OpenMEEG will automatically be used to compute the forward solution. 
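A hedged sketch of the behavior described in that note; it assumes the standard sample-dataset layout and the ``solver`` option of :func:`mne.make_bem_solution` (OpenMEEG itself must be installed):

import mne

data_path = mne.datasets.sample.data_path()  # downloads the sample data once
subjects_dir = data_path / "subjects"
info = mne.io.read_info(data_path / "MEG" / "sample" / "sample_audvis_raw.fif")
trans = data_path / "MEG" / "sample" / "sample_audvis_raw-trans.fif"
src = mne.setup_source_space("sample", spacing="oct4", subjects_dir=subjects_dir)
model = mne.make_bem_model("sample", ico=3, subjects_dir=subjects_dir)
bem = mne.make_bem_solution(model, solver="openmeeg")
# Because ``bem`` was computed by OpenMEEG, the forward solution below is also
# computed with OpenMEEG automatically -- no extra argument is needed.
fwd = mne.make_forward_solution(info, trans=trans, src=src, bem=bem)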
@@ -693,7 +693,7 @@ def make_forward_solution( logger.info("MRI -> head transform : %s" % trans) logger.info("Measurement data : %s" % info_extra) if isinstance(bem, ConductorModel) and bem["is_sphere"]: - logger.info("Sphere model : origin at %s mm" % (bem["r0"],)) + logger.info(f"Sphere model : origin at {bem['r0']} mm") logger.info("Standard field computations") else: logger.info("Conductor model : %s" % bem_extra) @@ -819,8 +819,9 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, verbose=N head = "The following dipoles are outside the inner skull boundary" msg = len(head) * "#" + "\n" + head + "\n" for t, pos in zip(times[np.logical_not(inuse)], pos[np.logical_not(inuse)]): - msg += " t={:.0f} ms, pos=({:.0f}, {:.0f}, {:.0f}) mm\n".format( - t * 1000.0, pos[0] * 1000.0, pos[1] * 1000.0, pos[2] * 1000.0 + msg += ( + f" t={t * 1000.0:.0f} ms, pos=({pos[0] * 1000.0:.0f}, " + f"{pos[1] * 1000.0:.0f}, {pos[2] * 1000.0:.0f}) mm\n" ) msg += len(head) * "#" logger.error(msg) diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 9531445edd1..ebe005787fb 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -525,7 +525,7 @@ def _merge_fwds(fwds, *, verbose=None): @verbose -def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbose=None): +def read_forward_solution(fname, include=(), exclude=(), *, ordered=True, verbose=None): """Read a forward solution a.k.a. lead field. Parameters @@ -1258,7 +1258,7 @@ def compute_orient_prior(forward, loose="auto", verbose=None): if any(v > 0.0 for v in loose.values()): raise ValueError( "loose must be 0. with forward operator " - "with fixed orientation, got %s" % (loose,) + f"with fixed orientation, got {loose}" ) return orient_prior if all(v == 1.0 for v in loose.values()): @@ -1269,7 +1269,7 @@ def compute_orient_prior(forward, loose="auto", verbose=None): raise ValueError( "Forward operator is not oriented in surface " "coordinates. loose parameter should be 1. " - "not %s." % (loose,) + f"not {loose}." ) start = 0 logged = dict() @@ -1419,13 +1419,12 @@ def compute_depth_prior( if isinstance(limit_depth_chs, str): if limit_depth_chs != "whiten": raise ValueError( - 'limit_depth_chs, if str, must be "whiten", got ' - "%s" % (limit_depth_chs,) + f'limit_depth_chs, if str, must be "whiten", got {limit_depth_chs}' ) if not isinstance(noise_cov, Covariance): raise ValueError( 'With limit_depth_chs="whiten", noise_cov must be' - " a Covariance, got %s" % (type(noise_cov),) + f" a Covariance, got {type(noise_cov)}" ) if combine_xyz is not False: # private / expert option _check_option("combine_xyz", combine_xyz, ("fro", "spectral")) @@ -1456,8 +1455,10 @@ def compute_depth_prior( # d[k] = linalg.svdvals(x)[0] G.shape = (G.shape[0], -1, 3) d = np.linalg.norm( - np.einsum("svj,svk->vjk", G, G), ord=2, axis=(1, 2) # vector dot prods - ) # ord=2 spectral (largest s.v.) + np.einsum("svj,svk->vjk", G, G), # vector dot prods + ord=2, # ord=2 spectral (largest s.v.) 
+ axis=(1, 2), + ) G.shape = (G.shape[0], -1) # XXX Currently the fwd solns never have "patch_areas" defined @@ -1489,7 +1490,7 @@ def compute_depth_prior( " limit = %d/%d = %f" % (n_limit + 1, len(d), np.sqrt(limit / ws[0])) ) scale = 1.0 / limit - logger.info(" scale = %g exp = %g" % (scale, exp)) + logger.info(f" scale = {scale:g} exp = {exp:g}") w = np.minimum(w / limit, 1) depth_prior = w**exp @@ -1511,8 +1512,8 @@ def _stc_src_sel( del stc if not len(src) == len(vertices): raise RuntimeError( - "Mismatch between number of source spaces (%s) and " - "STC vertices (%s)" % (len(src), len(vertices)) + f"Mismatch between number of source spaces ({len(src)}) and " + f"STC vertices ({len(vertices)})" ) src_sels, stc_sels, out_vertices = [], [], [] src_offset = stc_offset = 0 diff --git a/mne/forward/tests/test_field_interpolation.py b/mne/forward/tests/test_field_interpolation.py index f19b844d46c..4f09a90df73 100644 --- a/mne/forward/tests/test_field_interpolation.py +++ b/mne/forward/tests/test_field_interpolation.py @@ -237,10 +237,16 @@ def test_make_field_map_meeg(): assert_allclose(map_["data"].min(), min_, rtol=5e-2) # calculated from correct looking mapping on 2015/12/26 assert_allclose( - np.sqrt(np.sum(maps[0]["data"] ** 2)), 19.0903, atol=1e-3, rtol=1e-3 # 16.6088, + np.sqrt(np.sum(maps[0]["data"] ** 2)), + 19.0903, + atol=1e-3, + rtol=1e-3, ) assert_allclose( - np.sqrt(np.sum(maps[1]["data"] ** 2)), 19.4748, atol=1e-3, rtol=1e-3 # 20.1245, + np.sqrt(np.sum(maps[1]["data"] ** 2)), + 19.4748, + atol=1e-3, + rtol=1e-3, ) diff --git a/mne/forward/tests/test_forward.py b/mne/forward/tests/test_forward.py index f636a424813..7442c68959c 100644 --- a/mne/forward/tests/test_forward.py +++ b/mne/forward/tests/test_forward.py @@ -37,16 +37,14 @@ ) from mne.io import read_info from mne.label import read_label -from mne.utils import requires_mne, run_subprocess +from mne.utils import _record_warnings, requires_mne, run_subprocess data_path = testing.data_path(download=False) fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_meeg_grad = ( data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-2-grad-fwd.fif" ) -fname_evoked = ( - Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test-ave.fif" -) +fname_evoked = Path(__file__).parents[2] / "io" / "tests" / "data" / "test-ave.fif" label_path = data_path / "MEG" / "sample" / "labels" @@ -232,7 +230,10 @@ def test_apply_forward(): # Evoked evoked = read_evokeds(fname_evoked, condition=0) evoked.pick(picks="meg") - with pytest.warns(RuntimeWarning, match="only .* positive values"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="only .* positive values"), + ): evoked = apply_forward(fwd, stc, evoked.info, start=start, stop=stop) data = evoked.data times = evoked.times @@ -250,13 +251,14 @@ def test_apply_forward(): stc.tmin, stc.tstep, ) - with pytest.warns(RuntimeWarning, match="very large"): + large_ctx = pytest.warns(RuntimeWarning, match="very large") + with large_ctx: evoked_2 = apply_forward(fwd, stc_vec, evoked.info) assert np.abs(evoked_2.data).mean() > 1e-5 assert_allclose(evoked.data, evoked_2.data, atol=1e-10) # Raw - with pytest.warns(RuntimeWarning, match="only .* positive values"): + with large_ctx, pytest.warns(RuntimeWarning, match="only .* positive values"): raw_proj = apply_forward_raw(fwd, stc, evoked.info, start=start, stop=stop) data, times = raw_proj[:, :] diff --git a/mne/forward/tests/test_make_forward.py 
b/mne/forward/tests/test_make_forward.py index 7c0dfa110aa..4e52b9a50b0 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -44,6 +44,7 @@ from mne.surface import _get_ico_surface from mne.transforms import Transform from mne.utils import ( + _record_warnings, catch_logging, requires_mne, requires_mne_mark, @@ -53,9 +54,7 @@ data_path = testing.data_path(download=False) fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -fname_raw = ( - Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif" -) +fname_raw = Path(__file__).parents[2] / "io" / "tests" / "data" / "test_raw.fif" fname_evo = data_path / "MEG" / "sample" / "sample_audvis_trunc-ave.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" fname_dip = data_path / "MEG" / "sample" / "sample_audvis_trunc_set1.dip" @@ -66,7 +65,7 @@ fname_aseg = subjects_dir / "sample" / "mri" / "aseg.mgz" fname_bem_meg = subjects_dir / "sample" / "bem" / "sample-1280-bem-sol.fif" -io_path = Path(__file__).parent.parent.parent / "io" +io_path = Path(__file__).parents[2] / "io" bti_dir = io_path / "bti" / "tests" / "data" kit_dir = io_path / "kit" / "tests" / "data" trans_path = kit_dir / "trans-sample.fif" @@ -200,7 +199,7 @@ def test_magnetic_dipole(): r0 = coils[0]["rmag"][[0]] with pytest.raises(RuntimeError, match="Coil too close"): _magnetic_dipole_field_vec(r0, coils[:1]) - with pytest.warns(RuntimeWarning, match="Coil too close"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="Coil too close"): fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close="warning") assert not np.isfinite(fwd).any() with np.errstate(invalid="ignore"): diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index 6063f6f628f..983b4b5b067 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -1907,7 +1907,7 @@ def _configure_dock(self): func=self._save_trans, tooltip="Save the transform file to disk", layout=save_trans_layout, - filter="Head->MRI transformation (*-trans.fif *_trans.fif)", + filter_="Head->MRI transformation (*-trans.fif *_trans.fif)", initial_directory=str(Path(self._info_file).parent), ) self._widgets["load_trans"] = self._renderer._dock_add_file_button( @@ -1916,7 +1916,7 @@ def _configure_dock(self): func=self._load_trans, tooltip="Load the transform file from disk", layout=save_trans_layout, - filter="Head->MRI transformation (*-trans.fif *_trans.fif)", + filter_="Head->MRI transformation (*-trans.fif *_trans.fif)", initial_directory=str(Path(self._info_file).parent), ) self._renderer._layout_add_widget(trans_layout, save_trans_layout) diff --git a/mne/gui/_gui.py b/mne/gui/_gui.py index 3c437f9d266..e522986a60c 100644 --- a/mne/gui/_gui.py +++ b/mne/gui/_gui.py @@ -3,31 +3,24 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
-from ..utils import get_config, verbose, warn +from ..utils import get_config, verbose @verbose def coregistration( *, - tabbed=None, - split=None, width=None, + height=None, inst=None, subject=None, subjects_dir=None, - guess_mri_subject=None, - height=None, head_opacity=None, head_high_res=None, trans=None, - scrollable=None, orient_to_surface=None, scale_by_distance=None, mark_inside=None, interaction=None, - scale=None, - advanced_rendering=None, - head_inside=None, fullscreen=None, show=True, block=False, @@ -45,29 +38,20 @@ def coregistration( Parameters ---------- - tabbed : bool - Combine the data source panel and the coregistration panel into a - single panel with tabs. - split : bool - Split the main panels with a movable splitter (good for QT4 but - unnecessary for wx backend). width : int | None Specify the width for window (in logical pixels). Default is None, which uses ``MNE_COREG_WINDOW_WIDTH`` config value (which defaults to 800). + height : int | None + Specify a height for window (in logical pixels). + Default is None, which uses ``MNE_COREG_WINDOW_HEIGHT`` config value + (which defaults to 400). inst : None | str Path to an instance file containing the digitizer data. Compatible for Raw, Epochs, and Evoked files. subject : None | str Name of the mri subject. %(subjects_dir)s - guess_mri_subject : bool - When selecting a new head shape file, guess the subject's name based - on the filename and change the MRI subject accordingly (default True). - height : int | None - Specify a height for window (in logical pixels). - Default is None, which uses ``MNE_COREG_WINDOW_WIDTH`` config value - (which defaults to 400). head_opacity : float | None The opacity of the head surface in the range [0., 1.]. Default is None, which uses ``MNE_COREG_HEAD_OPACITY`` config value @@ -78,8 +62,6 @@ def coregistration( (which defaults to True). trans : path-like | None The transform file to use. - scrollable : bool - Make the coregistration panel vertically scrollable (default True). orient_to_surface : bool | None If True (default), orient EEG electrode and head shape points to the head surface. @@ -102,21 +84,6 @@ def coregistration( .. versionchanged:: 1.0 Default interaction mode if ``None`` and no config setting found changed from ``'trackball'`` to ``'terrain'``. - scale : float | None - The scaling for the scene. - - .. versionadded:: 0.16 - advanced_rendering : bool - Use advanced OpenGL rendering techniques (default True). - For some renderers (such as MESA software) this can cause rendering - bugs. - - .. versionadded:: 0.18 - head_inside : bool - If True (default), add opaque inner scalp head surface to help occlude - points behind the head. - - .. versionadded:: 0.23 %(fullscreen)s Default is None, which uses ``MNE_COREG_FULLSCREEN`` config value (which defaults to False). @@ -143,28 +110,6 @@ def coregistration( .. 
youtube:: ALV5qqMHLlQ """ - unsupported_params = { - "tabbed": tabbed, - "split": split, - "scrollable": scrollable, - "head_inside": head_inside, - "guess_mri_subject": guess_mri_subject, - "scale": scale, - "advanced_rendering": advanced_rendering, - } - for key, val in unsupported_params.items(): - if isinstance(val, tuple): - to_raise = val[0] != val[1] - else: - to_raise = val is not None - if to_raise: - warn( - f"The parameter {key} is deprecated and will be removed in 1.7, do " - "not pass a value for it", - FutureWarning, - ) - del tabbed, split, scrollable, head_inside, guess_mri_subject, scale - del advanced_rendering config = get_config() if head_high_res is None: head_high_res = config.get("MNE_COREG_HEAD_HIGH_RES", "true") == "true" diff --git a/mne/gui/tests/test_gui_api.py b/mne/gui/tests/test_gui_api.py index 004c670a5ca..ae04124dd14 100644 --- a/mne/gui/tests/test_gui_api.py +++ b/mne/gui/tests/test_gui_api.py @@ -11,10 +11,9 @@ pytest.importorskip("nibabel") -def test_gui_api(renderer_notebook, nbexec, *, n_warn=0, backend="qt"): +def test_gui_api(renderer_notebook, nbexec, *, backend="qt"): """Test GUI API.""" import contextlib - import sys import warnings import mne @@ -25,7 +24,6 @@ def test_gui_api(renderer_notebook, nbexec, *, n_warn=0, backend="qt"): except Exception: # Notebook standalone mode backend = "notebook" - n_warn = 0 # nbexec does not expose renderer_notebook so I use a # temporary variable to synchronize the tests if backend == "notebook": @@ -44,8 +42,7 @@ def test_gui_api(renderer_notebook, nbexec, *, n_warn=0, backend="qt"): with mne.utils._record_warnings() as w: renderer._window_set_theme("dark") w = [ww for ww in w if "is not yet supported" in str(ww.message)] - if sys.platform != "darwin": # sometimes this is fine - assert len(w) == n_warn, [ww.message for ww in w] + assert len(w) == 0, [ww.message for ww in w] # window without 3d plotter if backend == "qt": @@ -387,10 +384,9 @@ def _check_widget_trigger( def test_gui_api_qt(renderer_interactive_pyvistaqt): """Test GUI API with the Qt backend.""" _, api = _check_qt_version(return_api=True) - n_warn = int(api in ("PySide6", "PyQt6")) # TODO: After merging https://github.com/mne-tools/mne-python/pull/11567 # The Qt CI run started failing about 50% of the time, so let's skip this # for now. 
if api == "PySide6": pytest.skip("PySide6 causes segfaults on CIs sometimes") - test_gui_api(None, None, n_warn=n_warn, backend="qt") + test_gui_api(None, None, backend="qt") diff --git a/mne/html/d3.v3.min.js b/mne/html/d3.v3.min.js deleted file mode 100644 index eed58e6a572..00000000000 --- a/mne/html/d3.v3.min.js +++ /dev/null @@ -1,5 +0,0 @@ -!function(){function n(n){return null!=n&&!isNaN(n)}function t(n){return n.length}function e(n){for(var t=1;n*t%1;)t*=10;return t}function r(n,t){try{for(var e in t)Object.defineProperty(n.prototype,e,{value:t[e],enumerable:!1})}catch(r){n.prototype=t}}function u(){}function i(n){return aa+n in this}function o(n){return n=aa+n,n in this&&delete this[n]}function a(){var n=[];return this.forEach(function(t){n.push(t)}),n}function c(){var n=0;for(var t in this)t.charCodeAt(0)===ca&&++n;return n}function s(){for(var n in this)if(n.charCodeAt(0)===ca)return!1;return!0}function l(){}function f(n,t,e){return function(){var r=e.apply(t,arguments);return r===t?n:r}}function h(n,t){if(t in n)return t;t=t.charAt(0).toUpperCase()+t.substring(1);for(var e=0,r=sa.length;r>e;++e){var u=sa[e]+t;if(u in n)return u}}function g(){}function p(){}function v(n){function t(){for(var t,r=e,u=-1,i=r.length;++ue;e++)for(var u,i=n[e],o=0,a=i.length;a>o;o++)(u=i[o])&&t(u,o,e);return n}function D(n){return fa(n,ya),n}function P(n){var t,e;return function(r,u,i){var o,a=n[i].update,c=a.length;for(i!=e&&(e=i,t=0),u>=t&&(t=u+1);!(o=a[t])&&++t0&&(n=n.substring(0,a));var s=Ma.get(n);return s&&(n=s,c=F),a?t?u:r:t?g:i}function H(n,t){return function(e){var r=Xo.event;Xo.event=e,t[0]=this.__data__;try{n.apply(this,t)}finally{Xo.event=r}}}function F(n,t){var e=H(n,t);return function(n){var t=this,r=n.relatedTarget;r&&(r===t||8&r.compareDocumentPosition(t))||e.call(t,n)}}function O(){var n=".dragsuppress-"+ ++ba,t="click"+n,e=Xo.select(Go).on("touchmove"+n,d).on("dragstart"+n,d).on("selectstart"+n,d);if(_a){var r=Jo.style,u=r[_a];r[_a]="none"}return function(i){function o(){e.on(t,null)}e.on(n,null),_a&&(r[_a]=u),i&&(e.on(t,function(){d(),o()},!0),setTimeout(o,0))}}function Y(n,t){t.changedTouches&&(t=t.changedTouches[0]);var e=n.ownerSVGElement||n;if(e.createSVGPoint){var r=e.createSVGPoint();if(0>wa&&(Go.scrollX||Go.scrollY)){e=Xo.select("body").append("svg").style({position:"absolute",top:0,left:0,margin:0,padding:0,border:"none"},"important");var u=e[0][0].getScreenCTM();wa=!(u.f||u.e),e.remove()}return wa?(r.x=t.pageX,r.y=t.pageY):(r.x=t.clientX,r.y=t.clientY),r=r.matrixTransform(n.getScreenCTM().inverse()),[r.x,r.y]}var i=n.getBoundingClientRect();return[t.clientX-i.left-n.clientLeft,t.clientY-i.top-n.clientTop]}function I(n){return n>0?1:0>n?-1:0}function Z(n,t,e){return(t[0]-n[0])*(e[1]-n[1])-(t[1]-n[1])*(e[0]-n[0])}function V(n){return n>1?0:-1>n?Sa:Math.acos(n)}function X(n){return n>1?Ea:-1>n?-Ea:Math.asin(n)}function $(n){return((n=Math.exp(n))-1/n)/2}function B(n){return((n=Math.exp(n))+1/n)/2}function W(n){return((n=Math.exp(2*n))-1)/(n+1)}function J(n){return(n=Math.sin(n/2))*n}function G(){}function K(n,t,e){return new Q(n,t,e)}function Q(n,t,e){this.h=n,this.s=t,this.l=e}function nt(n,t,e){function r(n){return n>360?n-=360:0>n&&(n+=360),60>n?i+(o-i)*n/60:180>n?o:240>n?i+(o-i)*(240-n)/60:i}function u(n){return Math.round(255*r(n))}var i,o;return n=isNaN(n)?0:(n%=360)<0?n+360:n,t=isNaN(t)?0:0>t?0:t>1?1:t,e=0>e?0:e>1?1:e,o=.5>=e?e*(1+t):e+t-e*t,i=2*e-o,gt(u(n+120),u(n),u(n-120))}function tt(n,t,e){return new et(n,t,e)}function et(n,t,e){this.h=n,this.c=t,this.l=e}function 
rt(n,t,e){return isNaN(n)&&(n=0),isNaN(t)&&(t=0),ut(e,Math.cos(n*=Na)*t,Math.sin(n)*t)}function ut(n,t,e){return new it(n,t,e)}function it(n,t,e){this.l=n,this.a=t,this.b=e}function ot(n,t,e){var r=(n+16)/116,u=r+t/500,i=r-e/200;return u=ct(u)*Fa,r=ct(r)*Oa,i=ct(i)*Ya,gt(lt(3.2404542*u-1.5371385*r-.4985314*i),lt(-.969266*u+1.8760108*r+.041556*i),lt(.0556434*u-.2040259*r+1.0572252*i))}function at(n,t,e){return n>0?tt(Math.atan2(e,t)*La,Math.sqrt(t*t+e*e),n):tt(0/0,0/0,n)}function ct(n){return n>.206893034?n*n*n:(n-4/29)/7.787037}function st(n){return n>.008856?Math.pow(n,1/3):7.787037*n+4/29}function lt(n){return Math.round(255*(.00304>=n?12.92*n:1.055*Math.pow(n,1/2.4)-.055))}function ft(n){return gt(n>>16,255&n>>8,255&n)}function ht(n){return ft(n)+""}function gt(n,t,e){return new pt(n,t,e)}function pt(n,t,e){this.r=n,this.g=t,this.b=e}function vt(n){return 16>n?"0"+Math.max(0,n).toString(16):Math.min(255,n).toString(16)}function dt(n,t,e){var r,u,i,o=0,a=0,c=0;if(r=/([a-z]+)\((.*)\)/i.exec(n))switch(u=r[2].split(","),r[1]){case"hsl":return e(parseFloat(u[0]),parseFloat(u[1])/100,parseFloat(u[2])/100);case"rgb":return t(Mt(u[0]),Mt(u[1]),Mt(u[2]))}return(i=Va.get(n))?t(i.r,i.g,i.b):(null!=n&&"#"===n.charAt(0)&&(4===n.length?(o=n.charAt(1),o+=o,a=n.charAt(2),a+=a,c=n.charAt(3),c+=c):7===n.length&&(o=n.substring(1,3),a=n.substring(3,5),c=n.substring(5,7)),o=parseInt(o,16),a=parseInt(a,16),c=parseInt(c,16)),t(o,a,c))}function mt(n,t,e){var r,u,i=Math.min(n/=255,t/=255,e/=255),o=Math.max(n,t,e),a=o-i,c=(o+i)/2;return a?(u=.5>c?a/(o+i):a/(2-o-i),r=n==o?(t-e)/a+(e>t?6:0):t==o?(e-n)/a+2:(n-t)/a+4,r*=60):(r=0/0,u=c>0&&1>c?0:r),K(r,u,c)}function yt(n,t,e){n=xt(n),t=xt(t),e=xt(e);var r=st((.4124564*n+.3575761*t+.1804375*e)/Fa),u=st((.2126729*n+.7151522*t+.072175*e)/Oa),i=st((.0193339*n+.119192*t+.9503041*e)/Ya);return ut(116*u-16,500*(r-u),200*(u-i))}function xt(n){return(n/=255)<=.04045?n/12.92:Math.pow((n+.055)/1.055,2.4)}function Mt(n){var t=parseFloat(n);return"%"===n.charAt(n.length-1)?Math.round(2.55*t):t}function _t(n){return"function"==typeof n?n:function(){return n}}function bt(n){return n}function wt(n){return function(t,e,r){return 2===arguments.length&&"function"==typeof e&&(r=e,e=null),St(t,e,n,r)}}function St(n,t,e,r){function u(){var n,t=c.status;if(!t&&c.responseText||t>=200&&300>t||304===t){try{n=e.call(i,c)}catch(r){return o.error.call(i,r),void 0}o.load.call(i,n)}else o.error.call(i,c)}var i={},o=Xo.dispatch("beforesend","progress","load","error"),a={},c=new XMLHttpRequest,s=null;return!Go.XDomainRequest||"withCredentials"in c||!/^(http(s)?:)?\/\//.test(n)||(c=new XDomainRequest),"onload"in c?c.onload=c.onerror=u:c.onreadystatechange=function(){c.readyState>3&&u()},c.onprogress=function(n){var t=Xo.event;Xo.event=n;try{o.progress.call(i,c)}finally{Xo.event=t}},i.header=function(n,t){return n=(n+"").toLowerCase(),arguments.length<2?a[n]:(null==t?delete a[n]:a[n]=t+"",i)},i.mimeType=function(n){return arguments.length?(t=null==n?null:n+"",i):t},i.responseType=function(n){return arguments.length?(s=n,i):s},i.response=function(n){return e=n,i},["get","post"].forEach(function(n){i[n]=function(){return i.send.apply(i,[n].concat(Bo(arguments)))}}),i.send=function(e,r,u){if(2===arguments.length&&"function"==typeof r&&(u=r,r=null),c.open(e,n,!0),null==t||"accept"in a||(a.accept=t+",*/*"),c.setRequestHeader)for(var l in a)c.setRequestHeader(l,a[l]);return 
null!=t&&c.overrideMimeType&&c.overrideMimeType(t),null!=s&&(c.responseType=s),null!=u&&i.on("error",u).on("load",function(n){u(null,n)}),o.beforesend.call(i,c),c.send(null==r?null:r),i},i.abort=function(){return c.abort(),i},Xo.rebind(i,o,"on"),null==r?i:i.get(kt(r))}function kt(n){return 1===n.length?function(t,e){n(null==t?e:null)}:n}function Et(){var n=At(),t=Ct()-n;t>24?(isFinite(t)&&(clearTimeout(Wa),Wa=setTimeout(Et,t)),Ba=0):(Ba=1,Ga(Et))}function At(){var n=Date.now();for(Ja=Xa;Ja;)n>=Ja.t&&(Ja.f=Ja.c(n-Ja.t)),Ja=Ja.n;return n}function Ct(){for(var n,t=Xa,e=1/0;t;)t.f?t=n?n.n=t.n:Xa=t.n:(t.t8?function(n){return n/e}:function(n){return n*e},symbol:n}}function zt(n){var t=n.decimal,e=n.thousands,r=n.grouping,u=n.currency,i=r?function(n){for(var t=n.length,u=[],i=0,o=r[0];t>0&&o>0;)u.push(n.substring(t-=o,t+o)),o=r[i=(i+1)%r.length];return u.reverse().join(e)}:bt;return function(n){var e=Qa.exec(n),r=e[1]||" ",o=e[2]||">",a=e[3]||"",c=e[4]||"",s=e[5],l=+e[6],f=e[7],h=e[8],g=e[9],p=1,v="",d="",m=!1;switch(h&&(h=+h.substring(1)),(s||"0"===r&&"="===o)&&(s=r="0",o="=",f&&(l-=Math.floor((l-1)/4))),g){case"n":f=!0,g="g";break;case"%":p=100,d="%",g="f";break;case"p":p=100,d="%",g="r";break;case"b":case"o":case"x":case"X":"#"===c&&(v="0"+g.toLowerCase());case"c":case"d":m=!0,h=0;break;case"s":p=-1,g="r"}"$"===c&&(v=u[0],d=u[1]),"r"!=g||h||(g="g"),null!=h&&("g"==g?h=Math.max(1,Math.min(21,h)):("e"==g||"f"==g)&&(h=Math.max(0,Math.min(20,h)))),g=nc.get(g)||qt;var y=s&&f;return function(n){var e=d;if(m&&n%1)return"";var u=0>n||0===n&&0>1/n?(n=-n,"-"):a;if(0>p){var c=Xo.formatPrefix(n,h);n=c.scale(n),e=c.symbol+d}else n*=p;n=g(n,h);var x=n.lastIndexOf("."),M=0>x?n:n.substring(0,x),_=0>x?"":t+n.substring(x+1);!s&&f&&(M=i(M));var b=v.length+M.length+_.length+(y?0:u.length),w=l>b?new Array(b=l-b+1).join(r):"";return y&&(M=i(w+M)),u+=v,n=M+_,("<"===o?u+n+w:">"===o?w+u+n:"^"===o?w.substring(0,b>>=1)+u+n+w.substring(b):u+(y?n:w+n))+e}}}function qt(n){return n+""}function Tt(){this._=new Date(arguments.length>1?Date.UTC.apply(this,arguments):arguments[0])}function Rt(n,t,e){function r(t){var e=n(t),r=i(e,1);return r-t>t-e?e:r}function u(e){return t(e=n(new ec(e-1)),1),e}function i(n,e){return t(n=new ec(+n),e),n}function o(n,r,i){var o=u(n),a=[];if(i>1)for(;r>o;)e(o)%i||a.push(new Date(+o)),t(o,1);else for(;r>o;)a.push(new Date(+o)),t(o,1);return a}function a(n,t,e){try{ec=Tt;var r=new Tt;return r._=n,o(r,t,e)}finally{ec=Date}}n.floor=n,n.round=r,n.ceil=u,n.offset=i,n.range=o;var c=n.utc=Dt(n);return c.floor=c,c.round=Dt(r),c.ceil=Dt(u),c.offset=Dt(i),c.range=a,n}function Dt(n){return function(t,e){try{ec=Tt;var r=new Tt;return r._=t,n(r,e)._}finally{ec=Date}}}function Pt(n){function t(n){function t(t){for(var e,u,i,o=[],a=-1,c=0;++aa;){if(r>=s)return-1;if(u=t.charCodeAt(a++),37===u){if(o=t.charAt(a++),i=N[o in uc?t.charAt(a++):o],!i||(r=i(n,e,r))<0)return-1}else if(u!=e.charCodeAt(r++))return-1}return r}function r(n,t,e){b.lastIndex=0;var r=b.exec(t.substring(e));return r?(n.w=w.get(r[0].toLowerCase()),e+r[0].length):-1}function u(n,t,e){M.lastIndex=0;var r=M.exec(t.substring(e));return r?(n.w=_.get(r[0].toLowerCase()),e+r[0].length):-1}function i(n,t,e){E.lastIndex=0;var r=E.exec(t.substring(e));return r?(n.m=A.get(r[0].toLowerCase()),e+r[0].length):-1}function o(n,t,e){S.lastIndex=0;var r=S.exec(t.substring(e));return r?(n.m=k.get(r[0].toLowerCase()),e+r[0].length):-1}function a(n,t,r){return e(n,C.c.toString(),t,r)}function c(n,t,r){return e(n,C.x.toString(),t,r)}function s(n,t,r){return 
e(n,C.X.toString(),t,r)}function l(n,t,e){var r=x.get(t.substring(e,e+=2).toLowerCase());return null==r?-1:(n.p=r,e)}var f=n.dateTime,h=n.date,g=n.time,p=n.periods,v=n.days,d=n.shortDays,m=n.months,y=n.shortMonths;t.utc=function(n){function e(n){try{ec=Tt;var t=new ec;return t._=n,r(t)}finally{ec=Date}}var r=t(n);return e.parse=function(n){try{ec=Tt;var t=r.parse(n);return t&&t._}finally{ec=Date}},e.toString=r.toString,e},t.multi=t.utc.multi=ee;var x=Xo.map(),M=jt(v),_=Ht(v),b=jt(d),w=Ht(d),S=jt(m),k=Ht(m),E=jt(y),A=Ht(y);p.forEach(function(n,t){x.set(n.toLowerCase(),t)});var C={a:function(n){return d[n.getDay()]},A:function(n){return v[n.getDay()]},b:function(n){return y[n.getMonth()]},B:function(n){return m[n.getMonth()]},c:t(f),d:function(n,t){return Ut(n.getDate(),t,2)},e:function(n,t){return Ut(n.getDate(),t,2)},H:function(n,t){return Ut(n.getHours(),t,2)},I:function(n,t){return Ut(n.getHours()%12||12,t,2)},j:function(n,t){return Ut(1+tc.dayOfYear(n),t,3)},L:function(n,t){return Ut(n.getMilliseconds(),t,3)},m:function(n,t){return Ut(n.getMonth()+1,t,2)},M:function(n,t){return Ut(n.getMinutes(),t,2)},p:function(n){return p[+(n.getHours()>=12)]},S:function(n,t){return Ut(n.getSeconds(),t,2)},U:function(n,t){return Ut(tc.sundayOfYear(n),t,2)},w:function(n){return n.getDay()},W:function(n,t){return Ut(tc.mondayOfYear(n),t,2)},x:t(h),X:t(g),y:function(n,t){return Ut(n.getFullYear()%100,t,2)},Y:function(n,t){return Ut(n.getFullYear()%1e4,t,4)},Z:ne,"%":function(){return"%"}},N={a:r,A:u,b:i,B:o,c:a,d:Bt,e:Bt,H:Jt,I:Jt,j:Wt,L:Qt,m:$t,M:Gt,p:l,S:Kt,U:Ot,w:Ft,W:Yt,x:c,X:s,y:Zt,Y:It,Z:Vt,"%":te};return t}function Ut(n,t,e){var r=0>n?"-":"",u=(r?-n:n)+"",i=u.length;return r+(e>i?new Array(e-i+1).join(t)+u:u)}function jt(n){return new RegExp("^(?:"+n.map(Xo.requote).join("|")+")","i")}function Ht(n){for(var t=new u,e=-1,r=n.length;++e68?1900:2e3)}function $t(n,t,e){ic.lastIndex=0;var r=ic.exec(t.substring(e,e+2));return r?(n.m=r[0]-1,e+r[0].length):-1}function Bt(n,t,e){ic.lastIndex=0;var r=ic.exec(t.substring(e,e+2));return r?(n.d=+r[0],e+r[0].length):-1}function Wt(n,t,e){ic.lastIndex=0;var r=ic.exec(t.substring(e,e+3));return r?(n.j=+r[0],e+r[0].length):-1}function Jt(n,t,e){ic.lastIndex=0;var r=ic.exec(t.substring(e,e+2));return r?(n.H=+r[0],e+r[0].length):-1}function Gt(n,t,e){ic.lastIndex=0;var r=ic.exec(t.substring(e,e+2));return r?(n.M=+r[0],e+r[0].length):-1}function Kt(n,t,e){ic.lastIndex=0;var r=ic.exec(t.substring(e,e+2));return r?(n.S=+r[0],e+r[0].length):-1}function Qt(n,t,e){ic.lastIndex=0;var r=ic.exec(t.substring(e,e+3));return r?(n.L=+r[0],e+r[0].length):-1}function ne(n){var t=n.getTimezoneOffset(),e=t>0?"-":"+",r=~~(oa(t)/60),u=oa(t)%60;return e+Ut(r,"0",2)+Ut(u,"0",2)}function te(n,t,e){oc.lastIndex=0;var r=oc.exec(t.substring(e,e+1));return r?e+r[0].length:-1}function ee(n){for(var t=n.length,e=-1;++ea;++a)u.point((e=n[a])[0],e[1]);return u.lineEnd(),void 0}var c=new ke(e,n,null,!0),s=new ke(e,null,c,!1);c.o=s,i.push(c),o.push(s),c=new ke(r,n,null,!1),s=new ke(r,null,c,!0),c.o=s,i.push(c),o.push(s)}}),o.sort(t),Se(i),Se(o),i.length){for(var a=0,c=e,s=o.length;s>a;++a)o[a].e=c=!c;for(var l,f,h=i[0];;){for(var g=h,p=!0;g.v;)if((g=g.n)===h)return;l=g.z,u.lineStart();do{if(g.v=g.o.v=!0,g.e){if(p)for(var a=0,s=l.length;s>a;++a)u.point((f=l[a])[0],f[1]);else r(g.x,g.n.x,1,u);g=g.n}else{if(p){l=g.p.z;for(var a=l.length-1;a>=0;--a)u.point((f=l[a])[0],f[1])}else r(g.x,g.p.x,-1,u);g=g.p}g=g.o,l=g.z,p=!p}while(!g.v);u.lineEnd()}}}function Se(n){if(t=n.length){for(var 
t,e,r=0,u=n[0];++r1&&2&t&&e.push(e.pop().concat(e.shift())),g.push(e.filter(Ae))}}var g,p,v,d=t(i),m=u.invert(r[0],r[1]),y={point:o,lineStart:c,lineEnd:s,polygonStart:function(){y.point=l,y.lineStart=f,y.lineEnd=h,g=[],p=[],i.polygonStart()},polygonEnd:function(){y.point=o,y.lineStart=c,y.lineEnd=s,g=Xo.merge(g);var n=Le(m,p);g.length?we(g,Ne,n,e,i):n&&(i.lineStart(),e(null,null,1,i),i.lineEnd()),i.polygonEnd(),g=p=null},sphere:function(){i.polygonStart(),i.lineStart(),e(null,null,1,i),i.lineEnd(),i.polygonEnd()}},x=Ce(),M=t(x);return y}}function Ae(n){return n.length>1}function Ce(){var n,t=[];return{lineStart:function(){t.push(n=[])},point:function(t,e){n.push([t,e])},lineEnd:g,buffer:function(){var e=t;return t=[],n=null,e},rejoin:function(){t.length>1&&t.push(t.pop().concat(t.shift()))}}}function Ne(n,t){return((n=n.x)[0]<0?n[1]-Ea-Aa:Ea-n[1])-((t=t.x)[0]<0?t[1]-Ea-Aa:Ea-t[1])}function Le(n,t){var e=n[0],r=n[1],u=[Math.sin(e),-Math.cos(e),0],i=0,o=0;hc.reset();for(var a=0,c=t.length;c>a;++a){var s=t[a],l=s.length;if(l)for(var f=s[0],h=f[0],g=f[1]/2+Sa/4,p=Math.sin(g),v=Math.cos(g),d=1;;){d===l&&(d=0),n=s[d];var m=n[0],y=n[1]/2+Sa/4,x=Math.sin(y),M=Math.cos(y),_=m-h,b=oa(_)>Sa,w=p*x;if(hc.add(Math.atan2(w*Math.sin(_),v*M+w*Math.cos(_))),i+=b?_+(_>=0?ka:-ka):_,b^h>=e^m>=e){var S=fe(se(f),se(n));pe(S);var k=fe(u,S);pe(k);var E=(b^_>=0?-1:1)*X(k[2]);(r>E||r===E&&(S[0]||S[1]))&&(o+=b^_>=0?1:-1)}if(!d++)break;h=m,p=x,v=M,f=n}}return(-Aa>i||Aa>i&&0>hc)^1&o}function ze(n){var t,e=0/0,r=0/0,u=0/0;return{lineStart:function(){n.lineStart(),t=1},point:function(i,o){var a=i>0?Sa:-Sa,c=oa(i-e);oa(c-Sa)0?Ea:-Ea),n.point(u,r),n.lineEnd(),n.lineStart(),n.point(a,r),n.point(i,r),t=0):u!==a&&c>=Sa&&(oa(e-u)Aa?Math.atan((Math.sin(t)*(i=Math.cos(r))*Math.sin(e)-Math.sin(r)*(u=Math.cos(t))*Math.sin(n))/(u*i*o)):(t+r)/2}function Te(n,t,e,r){var u;if(null==n)u=e*Ea,r.point(-Sa,u),r.point(0,u),r.point(Sa,u),r.point(Sa,0),r.point(Sa,-u),r.point(0,-u),r.point(-Sa,-u),r.point(-Sa,0),r.point(-Sa,u);else if(oa(n[0]-t[0])>Aa){var i=n[0]i}function e(n){var e,i,c,s,l;return{lineStart:function(){s=c=!1,l=1},point:function(f,h){var g,p=[f,h],v=t(f,h),d=o?v?0:u(f,h):v?u(f+(0>f?Sa:-Sa),h):0;if(!e&&(s=c=v)&&n.lineStart(),v!==c&&(g=r(e,p),(de(e,g)||de(p,g))&&(p[0]+=Aa,p[1]+=Aa,v=t(p[0],p[1]))),v!==c)l=0,v?(n.lineStart(),g=r(p,e),n.point(g[0],g[1])):(g=r(e,p),n.point(g[0],g[1]),n.lineEnd()),e=g;else if(a&&e&&o^v){var m;d&i||!(m=r(p,e,!0))||(l=0,o?(n.lineStart(),n.point(m[0][0],m[0][1]),n.point(m[1][0],m[1][1]),n.lineEnd()):(n.point(m[1][0],m[1][1]),n.lineEnd(),n.lineStart(),n.point(m[0][0],m[0][1])))}!v||e&&de(e,p)||n.point(p[0],p[1]),e=p,c=v,i=d},lineEnd:function(){c&&n.lineEnd(),e=null},clean:function(){return l|(s&&c)<<1}}}function r(n,t,e){var r=se(n),u=se(t),o=[1,0,0],a=fe(r,u),c=le(a,a),s=a[0],l=c-s*s;if(!l)return!e&&n;var f=i*c/l,h=-i*s/l,g=fe(o,a),p=ge(o,f),v=ge(a,h);he(p,v);var d=g,m=le(p,d),y=le(d,d),x=m*m-y*(le(p,p)-1);if(!(0>x)){var M=Math.sqrt(x),_=ge(d,(-m-M)/y);if(he(_,p),_=ve(_),!e)return _;var b,w=n[0],S=t[0],k=n[1],E=t[1];w>S&&(b=w,w=S,S=b);var A=S-w,C=oa(A-Sa)A;if(!C&&k>E&&(b=k,k=E,E=b),N?C?k+E>0^_[1]<(oa(_[0]-w)Sa^(w<=_[0]&&_[0]<=S)){var L=ge(d,(-m+M)/y);return he(L,p),[_,ve(L)]}}}function u(t,e){var r=o?n:Sa-n,u=0;return-r>t?u|=1:t>r&&(u|=2),-r>e?u|=4:e>r&&(u|=8),u}var i=Math.cos(n),o=i>0,a=oa(i)>Aa,c=cr(n,6*Na);return Ee(t,e,c,o?[0,-n]:[-Sa,n-Sa])}function De(n,t,e,r){return function(u){var 
i,o=u.a,a=u.b,c=o.x,s=o.y,l=a.x,f=a.y,h=0,g=1,p=l-c,v=f-s;if(i=n-c,p||!(i>0)){if(i/=p,0>p){if(h>i)return;g>i&&(g=i)}else if(p>0){if(i>g)return;i>h&&(h=i)}if(i=e-c,p||!(0>i)){if(i/=p,0>p){if(i>g)return;i>h&&(h=i)}else if(p>0){if(h>i)return;g>i&&(g=i)}if(i=t-s,v||!(i>0)){if(i/=v,0>v){if(h>i)return;g>i&&(g=i)}else if(v>0){if(i>g)return;i>h&&(h=i)}if(i=r-s,v||!(0>i)){if(i/=v,0>v){if(i>g)return;i>h&&(h=i)}else if(v>0){if(h>i)return;g>i&&(g=i)}return h>0&&(u.a={x:c+h*p,y:s+h*v}),1>g&&(u.b={x:c+g*p,y:s+g*v}),u}}}}}}function Pe(n,t,e,r){function u(r,u){return oa(r[0]-n)0?0:3:oa(r[0]-e)0?2:1:oa(r[1]-t)0?1:0:u>0?3:2}function i(n,t){return o(n.x,t.x)}function o(n,t){var e=u(n,1),r=u(t,1);return e!==r?e-r:0===e?t[1]-n[1]:1===e?n[0]-t[0]:2===e?n[1]-t[1]:t[0]-n[0]}return function(a){function c(n){for(var t=0,e=d.length,r=n[1],u=0;e>u;++u)for(var i,o=1,a=d[u],c=a.length,s=a[0];c>o;++o)i=a[o],s[1]<=r?i[1]>r&&Z(s,i,n)>0&&++t:i[1]<=r&&Z(s,i,n)<0&&--t,s=i;return 0!==t}function s(i,a,c,s){var l=0,f=0;if(null==i||(l=u(i,c))!==(f=u(a,c))||o(i,a)<0^c>0){do s.point(0===l||3===l?n:e,l>1?r:t);while((l=(l+c+4)%4)!==f)}else s.point(a[0],a[1])}function l(u,i){return u>=n&&e>=u&&i>=t&&r>=i}function f(n,t){l(n,t)&&a.point(n,t)}function h(){N.point=p,d&&d.push(m=[]),S=!0,w=!1,_=b=0/0}function g(){v&&(p(y,x),M&&w&&A.rejoin(),v.push(A.buffer())),N.point=f,w&&a.lineEnd()}function p(n,t){n=Math.max(-Ac,Math.min(Ac,n)),t=Math.max(-Ac,Math.min(Ac,t));var e=l(n,t);if(d&&m.push([n,t]),S)y=n,x=t,M=e,S=!1,e&&(a.lineStart(),a.point(n,t));else if(e&&w)a.point(n,t);else{var r={a:{x:_,y:b},b:{x:n,y:t}};C(r)?(w||(a.lineStart(),a.point(r.a.x,r.a.y)),a.point(r.b.x,r.b.y),e||a.lineEnd(),k=!1):e&&(a.lineStart(),a.point(n,t),k=!1)}_=n,b=t,w=e}var v,d,m,y,x,M,_,b,w,S,k,E=a,A=Ce(),C=De(n,t,e,r),N={point:f,lineStart:h,lineEnd:g,polygonStart:function(){a=A,v=[],d=[],k=!0},polygonEnd:function(){a=E,v=Xo.merge(v);var t=c([n,r]),e=k&&t,u=v.length;(e||u)&&(a.polygonStart(),e&&(a.lineStart(),s(null,null,1,a),a.lineEnd()),u&&we(v,i,t,s,a),a.polygonEnd()),v=d=m=null}};return N}}function Ue(n,t){function e(e,r){return e=n(e,r),t(e[0],e[1])}return n.invert&&t.invert&&(e.invert=function(e,r){return e=t.invert(e,r),e&&n.invert(e[0],e[1])}),e}function je(n){var t=0,e=Sa/3,r=nr(n),u=r(t,e);return u.parallels=function(n){return arguments.length?r(t=n[0]*Sa/180,e=n[1]*Sa/180):[180*(t/Sa),180*(e/Sa)]},u}function He(n,t){function e(n,t){var e=Math.sqrt(i-2*u*Math.sin(t))/u;return[e*Math.sin(n*=u),o-e*Math.cos(n)]}var r=Math.sin(n),u=(r+Math.sin(t))/2,i=1+r*(2*u-r),o=Math.sqrt(i)/u;return e.invert=function(n,t){var e=o-t;return[Math.atan2(n,e)/u,X((i-(n*n+e*e)*u*u)/(2*u))]},e}function Fe(){function n(n,t){Nc+=u*n-r*t,r=n,u=t}var t,e,r,u;Rc.point=function(i,o){Rc.point=n,t=r=i,e=u=o},Rc.lineEnd=function(){n(t,e)}}function Oe(n,t){Lc>n&&(Lc=n),n>qc&&(qc=n),zc>t&&(zc=t),t>Tc&&(Tc=t)}function Ye(){function n(n,t){o.push("M",n,",",t,i)}function t(n,t){o.push("M",n,",",t),a.point=e}function e(n,t){o.push("L",n,",",t)}function r(){a.point=n}function u(){o.push("Z")}var i=Ie(4.5),o=[],a={point:n,lineStart:function(){a.point=t},lineEnd:r,polygonStart:function(){a.lineEnd=u},polygonEnd:function(){a.lineEnd=r,a.point=n},pointRadius:function(n){return i=Ie(n),a},result:function(){if(o.length){var n=o.join("");return o=[],n}}};return a}function Ie(n){return"m0,"+n+"a"+n+","+n+" 0 1,1 0,"+-2*n+"a"+n+","+n+" 0 1,1 0,"+2*n+"z"}function Ze(n,t){dc+=n,mc+=t,++yc}function Ve(){function n(n,r){var u=n-t,i=r-e,o=Math.sqrt(u*u+i*i);xc+=o*(t+n)/2,Mc+=o*(e+r)/2,_c+=o,Ze(t=n,e=r)}var 
t,e;Pc.point=function(r,u){Pc.point=n,Ze(t=r,e=u)}}function Xe(){Pc.point=Ze}function $e(){function n(n,t){var e=n-r,i=t-u,o=Math.sqrt(e*e+i*i);xc+=o*(r+n)/2,Mc+=o*(u+t)/2,_c+=o,o=u*n-r*t,bc+=o*(r+n),wc+=o*(u+t),Sc+=3*o,Ze(r=n,u=t)}var t,e,r,u;Pc.point=function(i,o){Pc.point=n,Ze(t=r=i,e=u=o)},Pc.lineEnd=function(){n(t,e)}}function Be(n){function t(t,e){n.moveTo(t,e),n.arc(t,e,o,0,ka)}function e(t,e){n.moveTo(t,e),a.point=r}function r(t,e){n.lineTo(t,e)}function u(){a.point=t}function i(){n.closePath()}var o=4.5,a={point:t,lineStart:function(){a.point=e},lineEnd:u,polygonStart:function(){a.lineEnd=i},polygonEnd:function(){a.lineEnd=u,a.point=t},pointRadius:function(n){return o=n,a},result:g};return a}function We(n){function t(n){return(a?r:e)(n)}function e(t){return Ke(t,function(e,r){e=n(e,r),t.point(e[0],e[1])})}function r(t){function e(e,r){e=n(e,r),t.point(e[0],e[1])}function r(){x=0/0,S.point=i,t.lineStart()}function i(e,r){var i=se([e,r]),o=n(e,r);u(x,M,y,_,b,w,x=o[0],M=o[1],y=e,_=i[0],b=i[1],w=i[2],a,t),t.point(x,M)}function o(){S.point=e,t.lineEnd()}function c(){r(),S.point=s,S.lineEnd=l}function s(n,t){i(f=n,h=t),g=x,p=M,v=_,d=b,m=w,S.point=i}function l(){u(x,M,y,_,b,w,g,p,f,v,d,m,a,t),S.lineEnd=o,o()}var f,h,g,p,v,d,m,y,x,M,_,b,w,S={point:e,lineStart:r,lineEnd:o,polygonStart:function(){t.polygonStart(),S.lineStart=c},polygonEnd:function(){t.polygonEnd(),S.lineStart=r}};return S}function u(t,e,r,a,c,s,l,f,h,g,p,v,d,m){var y=l-t,x=f-e,M=y*y+x*x;if(M>4*i&&d--){var _=a+g,b=c+p,w=s+v,S=Math.sqrt(_*_+b*b+w*w),k=Math.asin(w/=S),E=oa(oa(w)-1)i||oa((y*L+x*z)/M-.5)>.3||o>a*g+c*p+s*v)&&(u(t,e,r,a,c,s,C,N,E,_/=S,b/=S,w,d,m),m.point(C,N),u(C,N,E,_,b,w,l,f,h,g,p,v,d,m))}}var i=.5,o=Math.cos(30*Na),a=16;return t.precision=function(n){return arguments.length?(a=(i=n*n)>0&&16,t):Math.sqrt(i)},t}function Je(n){var t=We(function(t,e){return n([t*La,e*La])});return function(n){return tr(t(n))}}function Ge(n){this.stream=n}function Ke(n,t){return{point:t,sphere:function(){n.sphere()},lineStart:function(){n.lineStart()},lineEnd:function(){n.lineEnd()},polygonStart:function(){n.polygonStart()},polygonEnd:function(){n.polygonEnd()}}}function Qe(n){return nr(function(){return n})()}function nr(n){function t(n){return n=a(n[0]*Na,n[1]*Na),[n[0]*h+c,s-n[1]*h]}function e(n){return n=a.invert((n[0]-c)/h,(s-n[1])/h),n&&[n[0]*La,n[1]*La]}function r(){a=Ue(o=ur(m,y,x),i);var n=i(v,d);return c=g-n[0]*h,s=p+n[1]*h,u()}function u(){return l&&(l.valid=!1,l=null),t}var i,o,a,c,s,l,f=We(function(n,t){return n=i(n,t),[n[0]*h+c,s-n[1]*h]}),h=150,g=480,p=250,v=0,d=0,m=0,y=0,x=0,M=Ec,_=bt,b=null,w=null;return t.stream=function(n){return l&&(l.valid=!1),l=tr(M(o,f(_(n)))),l.valid=!0,l},t.clipAngle=function(n){return arguments.length?(M=null==n?(b=n,Ec):Re((b=+n)*Na),u()):b -},t.clipExtent=function(n){return arguments.length?(w=n,_=n?Pe(n[0][0],n[0][1],n[1][0],n[1][1]):bt,u()):w},t.scale=function(n){return arguments.length?(h=+n,r()):h},t.translate=function(n){return arguments.length?(g=+n[0],p=+n[1],r()):[g,p]},t.center=function(n){return arguments.length?(v=n[0]%360*Na,d=n[1]%360*Na,r()):[v*La,d*La]},t.rotate=function(n){return arguments.length?(m=n[0]%360*Na,y=n[1]%360*Na,x=n.length>2?n[2]%360*Na:0,r()):[m*La,y*La,x*La]},Xo.rebind(t,f,"precision"),function(){return i=n.apply(this,arguments),t.invert=i.invert&&e,r()}}function tr(n){return Ke(n,function(t,e){n.point(t*Na,e*Na)})}function er(n,t){return[n,t]}function rr(n,t){return[n>Sa?n-ka:-Sa>n?n+ka:n,t]}function ur(n,t,e){return 
n?t||e?Ue(or(n),ar(t,e)):or(n):t||e?ar(t,e):rr}function ir(n){return function(t,e){return t+=n,[t>Sa?t-ka:-Sa>t?t+ka:t,e]}}function or(n){var t=ir(n);return t.invert=ir(-n),t}function ar(n,t){function e(n,t){var e=Math.cos(t),a=Math.cos(n)*e,c=Math.sin(n)*e,s=Math.sin(t),l=s*r+a*u;return[Math.atan2(c*i-l*o,a*r-s*u),X(l*i+c*o)]}var r=Math.cos(n),u=Math.sin(n),i=Math.cos(t),o=Math.sin(t);return e.invert=function(n,t){var e=Math.cos(t),a=Math.cos(n)*e,c=Math.sin(n)*e,s=Math.sin(t),l=s*i-c*o;return[Math.atan2(c*i+s*o,a*r+l*u),X(l*r-a*u)]},e}function cr(n,t){var e=Math.cos(n),r=Math.sin(n);return function(u,i,o,a){var c=o*t;null!=u?(u=sr(e,u),i=sr(e,i),(o>0?i>u:u>i)&&(u+=o*ka)):(u=n+o*ka,i=n-.5*c);for(var s,l=u;o>0?l>i:i>l;l-=c)a.point((s=ve([e,-r*Math.cos(l),-r*Math.sin(l)]))[0],s[1])}}function sr(n,t){var e=se(t);e[0]-=n,pe(e);var r=V(-e[1]);return((-e[2]<0?-r:r)+2*Math.PI-Aa)%(2*Math.PI)}function lr(n,t,e){var r=Xo.range(n,t-Aa,e).concat(t);return function(n){return r.map(function(t){return[n,t]})}}function fr(n,t,e){var r=Xo.range(n,t-Aa,e).concat(t);return function(n){return r.map(function(t){return[t,n]})}}function hr(n){return n.source}function gr(n){return n.target}function pr(n,t,e,r){var u=Math.cos(t),i=Math.sin(t),o=Math.cos(r),a=Math.sin(r),c=u*Math.cos(n),s=u*Math.sin(n),l=o*Math.cos(e),f=o*Math.sin(e),h=2*Math.asin(Math.sqrt(J(r-t)+u*o*J(e-n))),g=1/Math.sin(h),p=h?function(n){var t=Math.sin(n*=h)*g,e=Math.sin(h-n)*g,r=e*c+t*l,u=e*s+t*f,o=e*i+t*a;return[Math.atan2(u,r)*La,Math.atan2(o,Math.sqrt(r*r+u*u))*La]}:function(){return[n*La,t*La]};return p.distance=h,p}function vr(){function n(n,u){var i=Math.sin(u*=Na),o=Math.cos(u),a=oa((n*=Na)-t),c=Math.cos(a);Uc+=Math.atan2(Math.sqrt((a=o*Math.sin(a))*a+(a=r*i-e*o*c)*a),e*i+r*o*c),t=n,e=i,r=o}var t,e,r;jc.point=function(u,i){t=u*Na,e=Math.sin(i*=Na),r=Math.cos(i),jc.point=n},jc.lineEnd=function(){jc.point=jc.lineEnd=g}}function dr(n,t){function e(t,e){var r=Math.cos(t),u=Math.cos(e),i=n(r*u);return[i*u*Math.sin(t),i*Math.sin(e)]}return e.invert=function(n,e){var r=Math.sqrt(n*n+e*e),u=t(r),i=Math.sin(u),o=Math.cos(u);return[Math.atan2(n*i,r*o),Math.asin(r&&e*i/r)]},e}function mr(n,t){function e(n,t){var e=oa(oa(t)-Ea)u;u++){for(;r>1&&Z(n[e[r-2]],n[e[r-1]],n[u])<=0;)--r;e[r++]=u}return e.slice(0,r)}function kr(n,t){return n[0]-t[0]||n[1]-t[1]}function Er(n,t,e){return(e[0]-t[0])*(n[1]-t[1])<(e[1]-t[1])*(n[0]-t[0])}function Ar(n,t,e,r){var u=n[0],i=e[0],o=t[0]-u,a=r[0]-i,c=n[1],s=e[1],l=t[1]-c,f=r[1]-s,h=(a*(c-s)-f*(u-i))/(f*o-a*l);return[u+h*o,c+h*l]}function Cr(n){var t=n[0],e=n[n.length-1];return!(t[0]-e[0]||t[1]-e[1])}function Nr(){Jr(this),this.edge=this.site=this.circle=null}function Lr(n){var t=Jc.pop()||new Nr;return t.site=n,t}function zr(n){Or(n),$c.remove(n),Jc.push(n),Jr(n)}function qr(n){var t=n.circle,e=t.x,r=t.cy,u={x:e,y:r},i=n.P,o=n.N,a=[n];zr(n);for(var c=i;c.circle&&oa(e-c.circle.x)l;++l)s=a[l],c=a[l-1],$r(s.edge,c.site,s.site,u);c=a[0],s=a[f-1],s.edge=Vr(c.site,s.site,null,u),Fr(c),Fr(s)}function Tr(n){for(var t,e,r,u,i=n.x,o=n.y,a=$c._;a;)if(r=Rr(a,o)-i,r>Aa)a=a.L;else{if(u=i-Dr(a,o),!(u>Aa)){r>-Aa?(t=a.P,e=a):u>-Aa?(t=a,e=a.N):t=e=a;break}if(!a.R){t=a;break}a=a.R}var c=Lr(n);if($c.insert(t,c),t||e){if(t===e)return Or(t),e=Lr(t.site),$c.insert(c,e),c.edge=e.edge=Vr(t.site,c.site),Fr(t),Fr(e),void 0;if(!e)return c.edge=Vr(t.site,c.site),void 0;Or(t),Or(e);var 
s=t.site,l=s.x,f=s.y,h=n.x-l,g=n.y-f,p=e.site,v=p.x-l,d=p.y-f,m=2*(h*d-g*v),y=h*h+g*g,x=v*v+d*d,M={x:(d*y-g*x)/m+l,y:(h*x-v*y)/m+f};$r(e.edge,s,p,M),c.edge=Vr(s,n,null,M),e.edge=Vr(n,p,null,M),Fr(t),Fr(e)}}function Rr(n,t){var e=n.site,r=e.x,u=e.y,i=u-t;if(!i)return r;var o=n.P;if(!o)return-1/0;e=o.site;var a=e.x,c=e.y,s=c-t;if(!s)return a;var l=a-r,f=1/i-1/s,h=l/s;return f?(-h+Math.sqrt(h*h-2*f*(l*l/(-2*s)-c+s/2+u-i/2)))/f+r:(r+a)/2}function Dr(n,t){var e=n.N;if(e)return Rr(e,t);var r=n.site;return r.y===t?r.x:1/0}function Pr(n){this.site=n,this.edges=[]}function Ur(n){for(var t,e,r,u,i,o,a,c,s,l,f=n[0][0],h=n[1][0],g=n[0][1],p=n[1][1],v=Xc,d=v.length;d--;)if(i=v[d],i&&i.prepare())for(a=i.edges,c=a.length,o=0;c>o;)l=a[o].end(),r=l.x,u=l.y,s=a[++o%c].start(),t=s.x,e=s.y,(oa(r-t)>Aa||oa(u-e)>Aa)&&(a.splice(o,0,new Br(Xr(i.site,l,oa(r-f)Aa?{x:f,y:oa(t-f)Aa?{x:oa(e-p)Aa?{x:h,y:oa(t-h)Aa?{x:oa(e-g)=-Ca)){var g=c*c+s*s,p=l*l+f*f,v=(f*g-s*p)/h,d=(c*p-l*g)/h,f=d+a,m=Gc.pop()||new Hr;m.arc=n,m.site=u,m.x=v+o,m.y=f+Math.sqrt(v*v+d*d),m.cy=f,n.circle=m;for(var y=null,x=Wc._;x;)if(m.yd||d>=a)return;if(h>p){if(i){if(i.y>=s)return}else i={x:d,y:c};e={x:d,y:s}}else{if(i){if(i.yr||r>1)if(h>p){if(i){if(i.y>=s)return}else i={x:(c-u)/r,y:c};e={x:(s-u)/r,y:s}}else{if(i){if(i.yg){if(i){if(i.x>=a)return}else i={x:o,y:r*o+u};e={x:a,y:r*a+u}}else{if(i){if(i.xr;++r)if(o=l[r],o.x==e[0]){if(o.i)if(null==s[o.i+1])for(s[o.i-1]+=o.x,s.splice(o.i,1),u=r+1;i>u;++u)l[u].i--;else for(s[o.i-1]+=o.x+s[o.i+1],s.splice(o.i,2),u=r+1;i>u;++u)l[u].i-=2;else if(null==s[o.i+1])s[o.i]=o.x;else for(s[o.i]=o.x+s[o.i+1],s.splice(o.i+1,1),u=r+1;i>u;++u)l[u].i--;l.splice(r,1),i--,r--}else o.x=su(parseFloat(e[0]),parseFloat(o.x));for(;i>r;)o=l.pop(),null==s[o.i+1]?s[o.i]=o.x:(s[o.i]=o.x+s[o.i+1],s.splice(o.i+1,1)),i--;return 1===s.length?null==s[0]?(o=l[0].x,function(n){return o(n)+""}):function(){return t}:function(n){for(r=0;i>r;++r)s[(o=l[r]).i]=o.x(n);return s.join("")}}function fu(n,t){for(var e,r=Xo.interpolators.length;--r>=0&&!(e=Xo.interpolators[r](n,t)););return e}function hu(n,t){var e,r=[],u=[],i=n.length,o=t.length,a=Math.min(n.length,t.length);for(e=0;a>e;++e)r.push(fu(n[e],t[e]));for(;i>e;++e)u[e]=n[e];for(;o>e;++e)u[e]=t[e];return function(n){for(e=0;a>e;++e)u[e]=r[e](n);return u}}function gu(n){return function(t){return 0>=t?0:t>=1?1:n(t)}}function pu(n){return function(t){return 1-n(1-t)}}function vu(n){return function(t){return.5*(.5>t?n(2*t):2-n(2-2*t))}}function du(n){return n*n}function mu(n){return n*n*n}function yu(n){if(0>=n)return 0;if(n>=1)return 1;var t=n*n,e=t*n;return 4*(.5>n?e:3*(n-t)+e-.75)}function xu(n){return function(t){return Math.pow(t,n)}}function Mu(n){return 1-Math.cos(n*Ea)}function _u(n){return Math.pow(2,10*(n-1))}function bu(n){return 1-Math.sqrt(1-n*n)}function wu(n,t){var e;return arguments.length<2&&(t=.45),arguments.length?e=t/ka*Math.asin(1/n):(n=1,e=t/4),function(r){return 1+n*Math.pow(2,-10*r)*Math.sin((r-e)*ka/t)}}function Su(n){return n||(n=1.70158),function(t){return t*t*((n+1)*t-n)}}function ku(n){return 1/2.75>n?7.5625*n*n:2/2.75>n?7.5625*(n-=1.5/2.75)*n+.75:2.5/2.75>n?7.5625*(n-=2.25/2.75)*n+.9375:7.5625*(n-=2.625/2.75)*n+.984375}function Eu(n,t){n=Xo.hcl(n),t=Xo.hcl(t);var e=n.h,r=n.c,u=n.l,i=t.h-e,o=t.c-r,a=t.l-u;return isNaN(o)&&(o=0,r=isNaN(r)?t.c:r),isNaN(i)?(i=0,e=isNaN(e)?t.h:e):i>180?i-=360:-180>i&&(i+=360),function(n){return rt(e+i*n,r+o*n,u+a*n)+""}}function Au(n,t){n=Xo.hsl(n),t=Xo.hsl(t);var e=n.h,r=n.s,u=n.l,i=t.h-e,o=t.s-r,a=t.l-u;return 
isNaN(o)&&(o=0,r=isNaN(r)?t.s:r),isNaN(i)?(i=0,e=isNaN(e)?t.h:e):i>180?i-=360:-180>i&&(i+=360),function(n){return nt(e+i*n,r+o*n,u+a*n)+""}}function Cu(n,t){n=Xo.lab(n),t=Xo.lab(t);var e=n.l,r=n.a,u=n.b,i=t.l-e,o=t.a-r,a=t.b-u;return function(n){return ot(e+i*n,r+o*n,u+a*n)+""}}function Nu(n,t){return t-=n,function(e){return Math.round(n+t*e)}}function Lu(n){var t=[n.a,n.b],e=[n.c,n.d],r=qu(t),u=zu(t,e),i=qu(Tu(e,t,-u))||0;t[0]*e[1]180?l+=360:l-s>180&&(s+=360),u.push({i:r.push(r.pop()+"rotate(",null,")")-2,x:su(s,l)})):l&&r.push(r.pop()+"rotate("+l+")"),f!=h?u.push({i:r.push(r.pop()+"skewX(",null,")")-2,x:su(f,h)}):h&&r.push(r.pop()+"skewX("+h+")"),g[0]!=p[0]||g[1]!=p[1]?(e=r.push(r.pop()+"scale(",null,",",null,")"),u.push({i:e-4,x:su(g[0],p[0])},{i:e-2,x:su(g[1],p[1])})):(1!=p[0]||1!=p[1])&&r.push(r.pop()+"scale("+p+")"),e=u.length,function(n){for(var t,i=-1;++ie;++e)(t=n[e][1])>u&&(r=e,u=t);return r}function ei(n){return n.reduce(ri,0)}function ri(n,t){return n+t[1]}function ui(n,t){return ii(n,Math.ceil(Math.log(t.length)/Math.LN2+1))}function ii(n,t){for(var e=-1,r=+n[0],u=(n[1]-r)/t,i=[];++e<=t;)i[e]=u*e+r;return i}function oi(n){return[Xo.min(n),Xo.max(n)]}function ai(n,t){return n.parent==t.parent?1:2}function ci(n){var t=n.children;return t&&t.length?t[0]:n._tree.thread}function si(n){var t,e=n.children;return e&&(t=e.length)?e[t-1]:n._tree.thread}function li(n,t){var e=n.children;if(e&&(u=e.length))for(var r,u,i=-1;++i0&&(n=r);return n}function fi(n,t){return n.x-t.x}function hi(n,t){return t.x-n.x}function gi(n,t){return n.depth-t.depth}function pi(n,t){function e(n,r){var u=n.children;if(u&&(o=u.length))for(var i,o,a=null,c=-1;++c=0;)t=u[i]._tree,t.prelim+=e,t.mod+=e,e+=t.shift+(r+=t.change)}function di(n,t,e){n=n._tree,t=t._tree;var r=e/(t.number-n.number);n.change+=r,t.change-=r,t.shift+=e,t.prelim+=e,t.mod+=e}function mi(n,t,e){return n._tree.ancestor.parent==t.parent?n._tree.ancestor:e}function yi(n,t){return n.value-t.value}function xi(n,t){var e=n._pack_next;n._pack_next=t,t._pack_prev=n,t._pack_next=e,e._pack_prev=t}function Mi(n,t){n._pack_next=t,t._pack_prev=n}function _i(n,t){var e=t.x-n.x,r=t.y-n.y,u=n.r+t.r;return.999*u*u>e*e+r*r}function bi(n){function t(n){l=Math.min(n.x-n.r,l),f=Math.max(n.x+n.r,f),h=Math.min(n.y-n.r,h),g=Math.max(n.y+n.r,g)}if((e=n.children)&&(s=e.length)){var e,r,u,i,o,a,c,s,l=1/0,f=-1/0,h=1/0,g=-1/0;if(e.forEach(wi),r=e[0],r.x=-r.r,r.y=0,t(r),s>1&&(u=e[1],u.x=u.r,u.y=0,t(u),s>2))for(i=e[2],Ei(r,u,i),t(i),xi(r,i),r._pack_prev=i,xi(i,u),u=r._pack_next,o=3;s>o;o++){Ei(r,u,i=e[o]);var p=0,v=1,d=1;for(a=u._pack_next;a!==u;a=a._pack_next,v++)if(_i(a,i)){p=1;break}if(1==p)for(c=r._pack_prev;c!==a._pack_prev&&!_i(c,i);c=c._pack_prev,d++);p?(d>v||v==d&&u.ro;o++)i=e[o],i.x-=m,i.y-=y,x=Math.max(x,i.r+Math.sqrt(i.x*i.x+i.y*i.y));n.r=x,e.forEach(Si)}}function wi(n){n._pack_next=n._pack_prev=n}function Si(n){delete n._pack_next,delete n._pack_prev}function ki(n,t,e,r){var u=n.children;if(n.x=t+=r*n.x,n.y=e+=r*n.y,n.r*=r,u)for(var i=-1,o=u.length;++iu&&(e+=u/2,u=0),0>i&&(r+=i/2,i=0),{x:e,y:r,dx:u,dy:i}}function Ti(n){var t=n[0],e=n[n.length-1];return e>t?[t,e]:[e,t]}function Ri(n){return n.rangeExtent?n.rangeExtent():Ti(n.range())}function Di(n,t,e,r){var u=e(n[0],n[1]),i=r(t[0],t[1]);return function(n){return i(u(n))}}function Pi(n,t){var e,r=0,u=n.length-1,i=n[r],o=n[u];return i>o&&(e=r,r=u,u=e,e=i,i=o,o=e),n[r]=t.floor(i),n[u]=t.ceil(o),n}function Ui(n){return n?{floor:function(t){return Math.floor(t/n)*n},ceil:function(t){return 
Math.ceil(t/n)*n}}:ls}function ji(n,t,e,r){var u=[],i=[],o=0,a=Math.min(n.length,t.length)-1;for(n[a]2?ji:Di,c=r?Pu:Du;return o=u(n,t,c,e),a=u(t,n,c,fu),i}function i(n){return o(n)}var o,a;return i.invert=function(n){return a(n)},i.domain=function(t){return arguments.length?(n=t.map(Number),u()):n},i.range=function(n){return arguments.length?(t=n,u()):t},i.rangeRound=function(n){return i.range(n).interpolate(Nu)},i.clamp=function(n){return arguments.length?(r=n,u()):r},i.interpolate=function(n){return arguments.length?(e=n,u()):e},i.ticks=function(t){return Ii(n,t)},i.tickFormat=function(t,e){return Zi(n,t,e)},i.nice=function(t){return Oi(n,t),u()},i.copy=function(){return Hi(n,t,e,r)},u()}function Fi(n,t){return Xo.rebind(n,t,"range","rangeRound","interpolate","clamp")}function Oi(n,t){return Pi(n,Ui(Yi(n,t)[2]))}function Yi(n,t){null==t&&(t=10);var e=Ti(n),r=e[1]-e[0],u=Math.pow(10,Math.floor(Math.log(r/t)/Math.LN10)),i=t/r*u;return.15>=i?u*=10:.35>=i?u*=5:.75>=i&&(u*=2),e[0]=Math.ceil(e[0]/u)*u,e[1]=Math.floor(e[1]/u)*u+.5*u,e[2]=u,e}function Ii(n,t){return Xo.range.apply(Xo,Yi(n,t))}function Zi(n,t,e){var r=Yi(n,t);return Xo.format(e?e.replace(Qa,function(n,t,e,u,i,o,a,c,s,l){return[t,e,u,i,o,a,c,s||"."+Xi(l,r),l].join("")}):",."+Vi(r[2])+"f")}function Vi(n){return-Math.floor(Math.log(n)/Math.LN10+.01)}function Xi(n,t){var e=Vi(t[2]);return n in fs?Math.abs(e-Vi(Math.max(Math.abs(t[0]),Math.abs(t[1]))))+ +("e"!==n):e-2*("%"===n)}function $i(n,t,e,r){function u(n){return(e?Math.log(0>n?0:n):-Math.log(n>0?0:-n))/Math.log(t)}function i(n){return e?Math.pow(t,n):-Math.pow(t,-n)}function o(t){return n(u(t))}return o.invert=function(t){return i(n.invert(t))},o.domain=function(t){return arguments.length?(e=t[0]>=0,n.domain((r=t.map(Number)).map(u)),o):r},o.base=function(e){return arguments.length?(t=+e,n.domain(r.map(u)),o):t},o.nice=function(){var t=Pi(r.map(u),e?Math:gs);return n.domain(t),r=t.map(i),o},o.ticks=function(){var n=Ti(r),o=[],a=n[0],c=n[1],s=Math.floor(u(a)),l=Math.ceil(u(c)),f=t%1?2:t;if(isFinite(l-s)){if(e){for(;l>s;s++)for(var h=1;f>h;h++)o.push(i(s)*h);o.push(i(s))}else for(o.push(i(s));s++0;h--)o.push(i(s)*h);for(s=0;o[s]c;l--);o=o.slice(s,l)}return o},o.tickFormat=function(n,t){if(!arguments.length)return hs;arguments.length<2?t=hs:"function"!=typeof t&&(t=Xo.format(t));var r,a=Math.max(.1,n/o.ticks().length),c=e?(r=1e-12,Math.ceil):(r=-1e-12,Math.floor);return function(n){return n/i(c(u(n)+r))<=a?t(n):""}},o.copy=function(){return $i(n.copy(),t,e,r)},Fi(o,n)}function Bi(n,t,e){function r(t){return n(u(t))}var u=Wi(t),i=Wi(1/t);return r.invert=function(t){return i(n.invert(t))},r.domain=function(t){return arguments.length?(n.domain((e=t.map(Number)).map(u)),r):e},r.ticks=function(n){return Ii(e,n)},r.tickFormat=function(n,t){return Zi(e,n,t)},r.nice=function(n){return r.domain(Oi(e,n))},r.exponent=function(o){return arguments.length?(u=Wi(t=o),i=Wi(1/t),n.domain(e.map(u)),r):t},r.copy=function(){return Bi(n.copy(),t,e)},Fi(r,n)}function Wi(n){return function(t){return 0>t?-Math.pow(-t,n):Math.pow(t,n)}}function Ji(n,t){function e(e){return o[((i.get(e)||"range"===t.t&&i.set(e,n.push(e)))-1)%o.length]}function r(t,e){return Xo.range(n.length).map(function(n){return t+e*n})}var i,o,a;return e.domain=function(r){if(!arguments.length)return n;n=[],i=new u;for(var o,a=-1,c=r.length;++ae?[0/0,0/0]:[e>0?u[e-1]:n[0],et?0/0:t/i+n,[t,t+1/i]},r.copy=function(){return Ki(n,t,e)},u()}function Qi(n,t){function e(e){return e>=e?t[Xo.bisect(n,e)]:void 0}return e.domain=function(t){return 
arguments.length?(n=t,e):n},e.range=function(n){return arguments.length?(t=n,e):t},e.invertExtent=function(e){return e=t.indexOf(e),[n[e-1],n[e]]},e.copy=function(){return Qi(n,t)},e}function no(n){function t(n){return+n}return t.invert=t,t.domain=t.range=function(e){return arguments.length?(n=e.map(t),t):n},t.ticks=function(t){return Ii(n,t)},t.tickFormat=function(t,e){return Zi(n,t,e)},t.copy=function(){return no(n)},t}function to(n){return n.innerRadius}function eo(n){return n.outerRadius}function ro(n){return n.startAngle}function uo(n){return n.endAngle}function io(n){function t(t){function o(){s.push("M",i(n(l),a))}for(var c,s=[],l=[],f=-1,h=t.length,g=_t(e),p=_t(r);++f1&&u.push("H",r[0]),u.join("")}function so(n){for(var t=0,e=n.length,r=n[0],u=[r[0],",",r[1]];++t1){a=t[1],i=n[c],c++,r+="C"+(u[0]+o[0])+","+(u[1]+o[1])+","+(i[0]-a[0])+","+(i[1]-a[1])+","+i[0]+","+i[1];for(var s=2;s9&&(u=3*t/Math.sqrt(u),o[a]=u*e,o[a+1]=u*r));for(a=-1;++a<=c;)u=(n[Math.min(c,a+1)][0]-n[Math.max(0,a-1)][0])/(6*(1+o[a]*o[a])),i.push([u||0,o[a]*u||0]);return i}function Eo(n){return n.length<3?oo(n):n[0]+po(n,ko(n))}function Ao(n){for(var t,e,r,u=-1,i=n.length;++ue?s():(i.active=e,o.event&&o.event.start.call(n,l,t),o.tween.forEach(function(e,r){(r=r.call(n,l,t))&&v.push(r)}),Xo.timer(function(){return p.c=c(r||1)?be:c,1},0,a),void 0)}function c(r){if(i.active!==e)return s();for(var u=r/g,a=f(u),c=v.length;c>0;)v[--c].call(n,a);return u>=1?(o.event&&o.event.end.call(n,l,t),s()):void 0}function s(){return--i.count?delete i[e]:delete n.__transition__,1}var l=n.__data__,f=o.ease,h=o.delay,g=o.duration,p=Ja,v=[];return p.t=h+a,r>=h?u(r-h):(p.c=u,void 0)},0,a)}}function Ho(n,t){n.attr("transform",function(n){return"translate("+t(n)+",0)"})}function Fo(n,t){n.attr("transform",function(n){return"translate(0,"+t(n)+")"})}function Oo(n){return n.toISOString()}function Yo(n,t,e){function r(t){return n(t)}function u(n,e){var r=n[1]-n[0],u=r/e,i=Xo.bisect(js,u);return i==js.length?[t.year,Yi(n.map(function(n){return n/31536e6}),e)[2]]:i?t[u/js[i-1]1?{floor:function(t){for(;e(t=n.floor(t));)t=Io(t-1);return t},ceil:function(t){for(;e(t=n.ceil(t));)t=Io(+t+1);return t}}:n))},r.ticks=function(n,t){var e=Ti(r.domain()),i=null==n?u(e,10):"number"==typeof n?u(e,n):!n.range&&[{range:n},t];return i&&(n=i[0],t=i[1]),n.range(e[0],Io(+e[1]+1),1>t?1:t)},r.tickFormat=function(){return e},r.copy=function(){return Yo(n.copy(),t,e)},Fi(r,n)}function Io(n){return new Date(n)}function Zo(n){return JSON.parse(n.responseText)}function Vo(n){var t=Wo.createRange();return t.selectNode(Wo.body),t.createContextualFragment(n.responseText)}var Xo={version:"3.4.2"};Date.now||(Date.now=function(){return+new Date});var $o=[].slice,Bo=function(n){return $o.call(n)},Wo=document,Jo=Wo.documentElement,Go=window;try{Bo(Jo.childNodes)[0].nodeType}catch(Ko){Bo=function(n){for(var t=n.length,e=new Array(t);t--;)e[t]=n[t];return e}}try{Wo.createElement("div").style.setProperty("opacity",0,"")}catch(Qo){var na=Go.Element.prototype,ta=na.setAttribute,ea=na.setAttributeNS,ra=Go.CSSStyleDeclaration.prototype,ua=ra.setProperty;na.setAttribute=function(n,t){ta.call(this,n,t+"")},na.setAttributeNS=function(n,t,e){ea.call(this,n,t,e+"")},ra.setProperty=function(n,t,e){ua.call(this,n,t+"",e)}}Xo.ascending=function(n,t){return t>n?-1:n>t?1:n>=t?0:0/0},Xo.descending=function(n,t){return n>t?-1:t>n?1:t>=n?0:0/0},Xo.min=function(n,t){var e,r,u=-1,i=n.length;if(1===arguments.length){for(;++u=e);)e=void 0;for(;++ur&&(e=r)}else{for(;++u=e);)e=void 
0;for(;++ur&&(e=r)}return e},Xo.max=function(n,t){var e,r,u=-1,i=n.length;if(1===arguments.length){for(;++u=e);)e=void 0;for(;++ue&&(e=r)}else{for(;++u=e);)e=void 0;for(;++ue&&(e=r)}return e},Xo.extent=function(n,t){var e,r,u,i=-1,o=n.length;if(1===arguments.length){for(;++i=e);)e=u=void 0;for(;++ir&&(e=r),r>u&&(u=r))}else{for(;++i=e);)e=void 0;for(;++ir&&(e=r),r>u&&(u=r))}return[e,u]},Xo.sum=function(n,t){var e,r=0,u=n.length,i=-1;if(1===arguments.length)for(;++i1&&(t=t.map(e)),t=t.filter(n),t.length?Xo.quantile(t.sort(Xo.ascending),.5):void 0},Xo.bisector=function(n){return{left:function(t,e,r,u){for(arguments.length<3&&(r=0),arguments.length<4&&(u=t.length);u>r;){var i=r+u>>>1;n.call(t,t[i],i)r;){var i=r+u>>>1;er?0:r);r>e;)i[e]=[t=u,u=n[++e]];return i},Xo.zip=function(){if(!(u=arguments.length))return[];for(var n=-1,e=Xo.min(arguments,t),r=new Array(e);++n=0;)for(r=n[u],t=r.length;--t>=0;)e[--o]=r[t];return e};var oa=Math.abs;Xo.range=function(n,t,r){if(arguments.length<3&&(r=1,arguments.length<2&&(t=n,n=0)),1/0===(t-n)/r)throw new Error("infinite range");var u,i=[],o=e(oa(r)),a=-1;if(n*=o,t*=o,r*=o,0>r)for(;(u=n+r*++a)>t;)i.push(u/o);else for(;(u=n+r*++a)=o.length)return r?r.call(i,a):e?a.sort(e):a;for(var s,l,f,h,g=-1,p=a.length,v=o[c++],d=new u;++g=o.length)return n;var r=[],u=a[e++];return n.forEach(function(n,u){r.push({key:n,values:t(u,e)})}),u?r.sort(function(n,t){return u(n.key,t.key)}):r}var e,r,i={},o=[],a=[];return i.map=function(t,e){return n(e,t,0)},i.entries=function(e){return t(n(Xo.map,e,0),0)},i.key=function(n){return o.push(n),i},i.sortKeys=function(n){return a[o.length-1]=n,i},i.sortValues=function(n){return e=n,i},i.rollup=function(n){return r=n,i},i},Xo.set=function(n){var t=new l;if(n)for(var e=0,r=n.length;r>e;++e)t.add(n[e]);return t},r(l,{has:i,add:function(n){return this[aa+n]=!0,n},remove:function(n){return n=aa+n,n in this&&delete this[n]},values:a,size:c,empty:s,forEach:function(n){for(var t in this)t.charCodeAt(0)===ca&&n.call(this,t.substring(1))}}),Xo.behavior={},Xo.rebind=function(n,t){for(var e,r=1,u=arguments.length;++r=0&&(r=n.substring(e+1),n=n.substring(0,e)),n)return arguments.length<2?this[n].on(r):this[n].on(r,t);if(2===arguments.length){if(null==t)for(n in this)this.hasOwnProperty(n)&&this[n].on(r,null);return this}},Xo.event=null,Xo.requote=function(n){return n.replace(la,"\\$&")};var la=/[\\\^\$\*\+\?\|\[\]\(\)\.\{\}]/g,fa={}.__proto__?function(n,t){n.__proto__=t}:function(n,t){for(var e in t)n[e]=t[e]},ha=function(n,t){return t.querySelector(n)},ga=function(n,t){return t.querySelectorAll(n)},pa=Jo[h(Jo,"matchesSelector")],va=function(n,t){return pa.call(n,t)};"function"==typeof Sizzle&&(ha=function(n,t){return Sizzle(n,t)[0]||null},ga=function(n,t){return Sizzle.uniqueSort(Sizzle(n,t))},va=Sizzle.matchesSelector),Xo.selection=function(){return xa};var da=Xo.selection.prototype=[];da.select=function(n){var t,e,r,u,i=[];n=M(n);for(var o=-1,a=this.length;++o=0&&(e=n.substring(0,t),n=n.substring(t+1)),ma.hasOwnProperty(e)?{space:ma[e],local:n}:n}},da.attr=function(n,t){if(arguments.length<2){if("string"==typeof n){var e=this.node();return n=Xo.ns.qualify(n),n.local?e.getAttributeNS(n.space,n.local):e.getAttribute(n)}for(t in n)this.each(b(t,n[t]));return this}return this.each(b(n,t))},da.classed=function(n,t){if(arguments.length<2){if("string"==typeof n){var e=this.node(),r=(n=k(n)).length,u=-1;if(t=e.classList){for(;++ur){if("string"!=typeof n){2>r&&(t="");for(e in n)this.each(C(e,n[e],t));return this}if(2>r)return 
Go.getComputedStyle(this.node(),null).getPropertyValue(n);e=""}return this.each(C(n,t,e))},da.property=function(n,t){if(arguments.length<2){if("string"==typeof n)return this.node()[n];for(t in n)this.each(N(t,n[t]));return this}return this.each(N(n,t))},da.text=function(n){return arguments.length?this.each("function"==typeof n?function(){var t=n.apply(this,arguments);this.textContent=null==t?"":t}:null==n?function(){this.textContent=""}:function(){this.textContent=n}):this.node().textContent},da.html=function(n){return arguments.length?this.each("function"==typeof n?function(){var t=n.apply(this,arguments);this.innerHTML=null==t?"":t}:null==n?function(){this.innerHTML=""}:function(){this.innerHTML=n}):this.node().innerHTML},da.append=function(n){return n=L(n),this.select(function(){return this.appendChild(n.apply(this,arguments))})},da.insert=function(n,t){return n=L(n),t=M(t),this.select(function(){return this.insertBefore(n.apply(this,arguments),t.apply(this,arguments)||null)})},da.remove=function(){return this.each(function(){var n=this.parentNode;n&&n.removeChild(this)})},da.data=function(n,t){function e(n,e){var r,i,o,a=n.length,f=e.length,h=Math.min(a,f),g=new Array(f),p=new Array(f),v=new Array(a);if(t){var d,m=new u,y=new u,x=[];for(r=-1;++rr;++r)p[r]=z(e[r]);for(;a>r;++r)v[r]=n[r]}p.update=g,p.parentNode=g.parentNode=v.parentNode=n.parentNode,c.push(p),s.push(g),l.push(v)}var r,i,o=-1,a=this.length;if(!arguments.length){for(n=new Array(a=(r=this[0]).length);++oi;i++){u.push(t=[]),t.parentNode=(e=this[i]).parentNode;for(var a=0,c=e.length;c>a;a++)(r=e[a])&&n.call(r,r.__data__,a,i)&&t.push(r)}return x(u)},da.order=function(){for(var n=-1,t=this.length;++n=0;)(e=r[u])&&(i&&i!==e.nextSibling&&i.parentNode.insertBefore(e,i),i=e);return this},da.sort=function(n){n=T.apply(this,arguments);for(var t=-1,e=this.length;++tn;n++)for(var e=this[n],r=0,u=e.length;u>r;r++){var i=e[r];if(i)return i}return null},da.size=function(){var n=0;return this.each(function(){++n}),n};var ya=[];Xo.selection.enter=D,Xo.selection.enter.prototype=ya,ya.append=da.append,ya.empty=da.empty,ya.node=da.node,ya.call=da.call,ya.size=da.size,ya.select=function(n){for(var t,e,r,u,i,o=[],a=-1,c=this.length;++ar){if("string"!=typeof n){2>r&&(t=!1);for(e in n)this.each(j(e,n[e],t));return this}if(2>r)return(r=this.node()["__on"+n])&&r._;e=!1}return this.each(j(n,t,e))};var Ma=Xo.map({mouseenter:"mouseover",mouseleave:"mouseout"});Ma.forEach(function(n){"on"+n in Wo&&Ma.remove(n)});var _a="onselectstart"in Wo?null:h(Jo.style,"userSelect"),ba=0;Xo.mouse=function(n){return Y(n,m())};var wa=/WebKit/.test(Go.navigator.userAgent)?-1:0;Xo.touches=function(n,t){return arguments.length<2&&(t=m().touches),t?Bo(t).map(function(t){var e=Y(n,t);return e.identifier=t.identifier,e}):[]},Xo.behavior.drag=function(){function n(){this.on("mousedown.drag",o).on("touchstart.drag",a)}function t(){return Xo.event.changedTouches[0].identifier}function e(n,t){return Xo.touches(n).filter(function(n){return n.identifier===t})[0]}function r(n,t,e,r){return function(){function o(){var n=t(l,g),e=n[0]-v[0],r=n[1]-v[1];d|=e|r,v=n,f({type:"drag",x:n[0]+c[0],y:n[1]+c[1],dx:e,dy:r})}function a(){m.on(e+"."+p,null).on(r+"."+p,null),y(d&&Xo.event.target===h),f({type:"dragend"})}var c,s=this,l=s.parentNode,f=u.of(s,arguments),h=Xo.event.target,g=n(),p=null==g?"drag":"drag-"+g,v=t(l,g),d=0,m=Xo.select(Go).on(e+"."+p,o).on(r+"."+p,a),y=O();i?(c=i.apply(s,arguments),c=[c.x-v[0],c.y-v[1]]):c=[0,0],f({type:"dragstart"})}}var 
u=y(n,"drag","dragstart","dragend"),i=null,o=r(g,Xo.mouse,"mousemove","mouseup"),a=r(t,e,"touchmove","touchend");return n.origin=function(t){return arguments.length?(i=t,n):i},Xo.rebind(n,u,"on")};var Sa=Math.PI,ka=2*Sa,Ea=Sa/2,Aa=1e-6,Ca=Aa*Aa,Na=Sa/180,La=180/Sa,za=Math.SQRT2,qa=2,Ta=4;Xo.interpolateZoom=function(n,t){function e(n){var t=n*y;if(m){var e=B(v),o=i/(qa*h)*(e*W(za*t+v)-$(v));return[r+o*s,u+o*l,i*e/B(za*t+v)]}return[r+n*s,u+n*l,i*Math.exp(za*t)]}var r=n[0],u=n[1],i=n[2],o=t[0],a=t[1],c=t[2],s=o-r,l=a-u,f=s*s+l*l,h=Math.sqrt(f),g=(c*c-i*i+Ta*f)/(2*i*qa*h),p=(c*c-i*i-Ta*f)/(2*c*qa*h),v=Math.log(Math.sqrt(g*g+1)-g),d=Math.log(Math.sqrt(p*p+1)-p),m=d-v,y=(m||Math.log(c/i))/za;return e.duration=1e3*y,e},Xo.behavior.zoom=function(){function n(n){n.on(A,s).on(Pa+".zoom",f).on(C,h).on("dblclick.zoom",g).on(L,l)}function t(n){return[(n[0]-S.x)/S.k,(n[1]-S.y)/S.k]}function e(n){return[n[0]*S.k+S.x,n[1]*S.k+S.y]}function r(n){S.k=Math.max(E[0],Math.min(E[1],n))}function u(n,t){t=e(t),S.x+=n[0]-t[0],S.y+=n[1]-t[1]}function i(){_&&_.domain(M.range().map(function(n){return(n-S.x)/S.k}).map(M.invert)),w&&w.domain(b.range().map(function(n){return(n-S.y)/S.k}).map(b.invert))}function o(n){n({type:"zoomstart"})}function a(n){i(),n({type:"zoom",scale:S.k,translate:[S.x,S.y]})}function c(n){n({type:"zoomend"})}function s(){function n(){l=1,u(Xo.mouse(r),g),a(i)}function e(){f.on(C,Go===r?h:null).on(N,null),p(l&&Xo.event.target===s),c(i)}var r=this,i=z.of(r,arguments),s=Xo.event.target,l=0,f=Xo.select(Go).on(C,n).on(N,e),g=t(Xo.mouse(r)),p=O();U.call(r),o(i)}function l(){function n(){var n=Xo.touches(g);return h=S.k,n.forEach(function(n){n.identifier in v&&(v[n.identifier]=t(n))}),n}function e(){for(var t=Xo.event.changedTouches,e=0,i=t.length;i>e;++e)v[t[e].identifier]=null;var o=n(),c=Date.now();if(1===o.length){if(500>c-x){var s=o[0],l=v[s.identifier];r(2*S.k),u(s,l),d(),a(p)}x=c}else if(o.length>1){var s=o[0],f=o[1],h=s[0]-f[0],g=s[1]-f[1];m=h*h+g*g}}function i(){for(var n,t,e,i,o=Xo.touches(g),c=0,s=o.length;s>c;++c,i=null)if(e=o[c],i=v[e.identifier]){if(t)break;n=e,t=i}if(i){var l=(l=e[0]-n[0])*l+(l=e[1]-n[1])*l,f=m&&Math.sqrt(l/m);n=[(n[0]+e[0])/2,(n[1]+e[1])/2],t=[(t[0]+i[0])/2,(t[1]+i[1])/2],r(f*h)}x=null,u(n,t),a(p)}function f(){if(Xo.event.touches.length){for(var t=Xo.event.changedTouches,e=0,r=t.length;r>e;++e)delete v[t[e].identifier];for(var u in v)return void n()}b.on(M,null).on(_,null),w.on(A,s).on(L,l),k(),c(p)}var h,g=this,p=z.of(g,arguments),v={},m=0,y=Xo.event.changedTouches[0].identifier,M="touchmove.zoom-"+y,_="touchend.zoom-"+y,b=Xo.select(Go).on(M,i).on(_,f),w=Xo.select(g).on(A,null).on(L,e),k=O();U.call(g),e(),o(p)}function f(){var n=z.of(this,arguments);m?clearTimeout(m):(U.call(this),o(n)),m=setTimeout(function(){m=null,c(n)},50),d();var e=v||Xo.mouse(this);p||(p=t(e)),r(Math.pow(2,.002*Ra())*S.k),u(e,p),a(n)}function h(){p=null}function g(){var n=z.of(this,arguments),e=Xo.mouse(this),i=t(e),s=Math.log(S.k)/Math.LN2;o(n),r(Math.pow(2,Xo.event.shiftKey?Math.ceil(s)-1:Math.floor(s)+1)),u(e,i),a(n),c(n)}var p,v,m,x,M,_,b,w,S={x:0,y:0,k:1},k=[960,500],E=Da,A="mousedown.zoom",C="mousemove.zoom",N="mouseup.zoom",L="touchstart.zoom",z=y(n,"zoomstart","zoom","zoomend");return n.event=function(n){n.each(function(){var n=z.of(this,arguments),t=S;ks?Xo.select(this).transition().each("start.zoom",function(){S=this.__chart__||{x:0,y:0,k:1},o(n)}).tween("zoom:zoom",function(){var 
e=k[0],r=k[1],u=e/2,i=r/2,o=Xo.interpolateZoom([(u-S.x)/S.k,(i-S.y)/S.k,e/S.k],[(u-t.x)/t.k,(i-t.y)/t.k,e/t.k]);return function(t){var r=o(t),c=e/r[2];this.__chart__=S={x:u-r[0]*c,y:i-r[1]*c,k:c},a(n)}}).each("end.zoom",function(){c(n)}):(this.__chart__=S,o(n),a(n),c(n))})},n.translate=function(t){return arguments.length?(S={x:+t[0],y:+t[1],k:S.k},i(),n):[S.x,S.y]},n.scale=function(t){return arguments.length?(S={x:S.x,y:S.y,k:+t},i(),n):S.k},n.scaleExtent=function(t){return arguments.length?(E=null==t?Da:[+t[0],+t[1]],n):E},n.center=function(t){return arguments.length?(v=t&&[+t[0],+t[1]],n):v},n.size=function(t){return arguments.length?(k=t&&[+t[0],+t[1]],n):k},n.x=function(t){return arguments.length?(_=t,M=t.copy(),S={x:0,y:0,k:1},n):_},n.y=function(t){return arguments.length?(w=t,b=t.copy(),S={x:0,y:0,k:1},n):w},Xo.rebind(n,z,"on")};var Ra,Da=[0,1/0],Pa="onwheel"in Wo?(Ra=function(){return-Xo.event.deltaY*(Xo.event.deltaMode?120:1)},"wheel"):"onmousewheel"in Wo?(Ra=function(){return Xo.event.wheelDelta},"mousewheel"):(Ra=function(){return-Xo.event.detail},"MozMousePixelScroll");G.prototype.toString=function(){return this.rgb()+""},Xo.hsl=function(n,t,e){return 1===arguments.length?n instanceof Q?K(n.h,n.s,n.l):dt(""+n,mt,K):K(+n,+t,+e)};var Ua=Q.prototype=new G;Ua.brighter=function(n){return n=Math.pow(.7,arguments.length?n:1),K(this.h,this.s,this.l/n)},Ua.darker=function(n){return n=Math.pow(.7,arguments.length?n:1),K(this.h,this.s,n*this.l)},Ua.rgb=function(){return nt(this.h,this.s,this.l)},Xo.hcl=function(n,t,e){return 1===arguments.length?n instanceof et?tt(n.h,n.c,n.l):n instanceof it?at(n.l,n.a,n.b):at((n=yt((n=Xo.rgb(n)).r,n.g,n.b)).l,n.a,n.b):tt(+n,+t,+e)};var ja=et.prototype=new G;ja.brighter=function(n){return tt(this.h,this.c,Math.min(100,this.l+Ha*(arguments.length?n:1)))},ja.darker=function(n){return tt(this.h,this.c,Math.max(0,this.l-Ha*(arguments.length?n:1)))},ja.rgb=function(){return rt(this.h,this.c,this.l).rgb()},Xo.lab=function(n,t,e){return 1===arguments.length?n instanceof it?ut(n.l,n.a,n.b):n instanceof et?rt(n.l,n.c,n.h):yt((n=Xo.rgb(n)).r,n.g,n.b):ut(+n,+t,+e)};var Ha=18,Fa=.95047,Oa=1,Ya=1.08883,Ia=it.prototype=new G;Ia.brighter=function(n){return ut(Math.min(100,this.l+Ha*(arguments.length?n:1)),this.a,this.b)},Ia.darker=function(n){return ut(Math.max(0,this.l-Ha*(arguments.length?n:1)),this.a,this.b)},Ia.rgb=function(){return ot(this.l,this.a,this.b)},Xo.rgb=function(n,t,e){return 1===arguments.length?n instanceof pt?gt(n.r,n.g,n.b):dt(""+n,gt,nt):gt(~~n,~~t,~~e)};var Za=pt.prototype=new G;Za.brighter=function(n){n=Math.pow(.7,arguments.length?n:1);var t=this.r,e=this.g,r=this.b,u=30;return t||e||r?(t&&u>t&&(t=u),e&&u>e&&(e=u),r&&u>r&&(r=u),gt(Math.min(255,~~(t/n)),Math.min(255,~~(e/n)),Math.min(255,~~(r/n)))):gt(u,u,u)},Za.darker=function(n){return n=Math.pow(.7,arguments.length?n:1),gt(~~(n*this.r),~~(n*this.g),~~(n*this.b))},Za.hsl=function(){return mt(this.r,this.g,this.b)},Za.toString=function(){return"#"+vt(this.r)+vt(this.g)+vt(this.b)};var 
Va=Xo.map({aliceblue:15792383,antiquewhite:16444375,aqua:65535,aquamarine:8388564,azure:15794175,beige:16119260,bisque:16770244,black:0,blanchedalmond:16772045,blue:255,blueviolet:9055202,brown:10824234,burlywood:14596231,cadetblue:6266528,chartreuse:8388352,chocolate:13789470,coral:16744272,cornflowerblue:6591981,cornsilk:16775388,crimson:14423100,cyan:65535,darkblue:139,darkcyan:35723,darkgoldenrod:12092939,darkgray:11119017,darkgreen:25600,darkgrey:11119017,darkkhaki:12433259,darkmagenta:9109643,darkolivegreen:5597999,darkorange:16747520,darkorchid:10040012,darkred:9109504,darksalmon:15308410,darkseagreen:9419919,darkslateblue:4734347,darkslategray:3100495,darkslategrey:3100495,darkturquoise:52945,darkviolet:9699539,deeppink:16716947,deepskyblue:49151,dimgray:6908265,dimgrey:6908265,dodgerblue:2003199,firebrick:11674146,floralwhite:16775920,forestgreen:2263842,fuchsia:16711935,gainsboro:14474460,ghostwhite:16316671,gold:16766720,goldenrod:14329120,gray:8421504,green:32768,greenyellow:11403055,grey:8421504,honeydew:15794160,hotpink:16738740,indianred:13458524,indigo:4915330,ivory:16777200,khaki:15787660,lavender:15132410,lavenderblush:16773365,lawngreen:8190976,lemonchiffon:16775885,lightblue:11393254,lightcoral:15761536,lightcyan:14745599,lightgoldenrodyellow:16448210,lightgray:13882323,lightgreen:9498256,lightgrey:13882323,lightpink:16758465,lightsalmon:16752762,lightseagreen:2142890,lightskyblue:8900346,lightslategray:7833753,lightslategrey:7833753,lightsteelblue:11584734,lightyellow:16777184,lime:65280,limegreen:3329330,linen:16445670,magenta:16711935,maroon:8388608,mediumaquamarine:6737322,mediumblue:205,mediumorchid:12211667,mediumpurple:9662683,mediumseagreen:3978097,mediumslateblue:8087790,mediumspringgreen:64154,mediumturquoise:4772300,mediumvioletred:13047173,midnightblue:1644912,mintcream:16121850,mistyrose:16770273,moccasin:16770229,navajowhite:16768685,navy:128,oldlace:16643558,olive:8421376,olivedrab:7048739,orange:16753920,orangered:16729344,orchid:14315734,palegoldenrod:15657130,palegreen:10025880,paleturquoise:11529966,palevioletred:14381203,papayawhip:16773077,peachpuff:16767673,peru:13468991,pink:16761035,plum:14524637,powderblue:11591910,purple:8388736,red:16711680,rosybrown:12357519,royalblue:4286945,saddlebrown:9127187,salmon:16416882,sandybrown:16032864,seagreen:3050327,seashell:16774638,sienna:10506797,silver:12632256,skyblue:8900331,slateblue:6970061,slategray:7372944,slategrey:7372944,snow:16775930,springgreen:65407,steelblue:4620980,tan:13808780,teal:32896,thistle:14204888,tomato:16737095,turquoise:4251856,violet:15631086,wheat:16113331,white:16777215,whitesmoke:16119285,yellow:16776960,yellowgreen:10145074});Va.forEach(function(n,t){Va.set(n,ft(t))}),Xo.functor=_t,Xo.xhr=wt(bt),Xo.dsv=function(n,t){function e(n,e,i){arguments.length<3&&(i=e,e=null);var o=St(n,t,null==e?r:u(e),i);return o.row=function(n){return arguments.length?o.response(null==(e=n)?r:u(n)):e},o}function r(n){return e.parse(n.responseText)}function u(n){return function(t){return e.parse(t.responseText,n)}}function i(t){return t.map(o).join(n)}function o(n){return a.test(n)?'"'+n.replace(/\"/g,'""')+'"':n}var a=new RegExp('["'+n+"\n]"),c=n.charCodeAt(0);return e.parse=function(n,t){var r;return e.parseRows(n,function(n,e){if(r)return r(n,e-1);var u=new Function("d","return {"+n.map(function(n,t){return JSON.stringify(n)+": d["+t+"]"}).join(",")+"}");r=t?function(n,e){return t(u(n),e)}:u})},e.parseRows=function(n,t){function e(){if(l>=s)return o;if(u)return u=!1,i;var 
t=l;if(34===n.charCodeAt(t)){for(var e=t;e++l;){var r=n.charCodeAt(l++),a=1;if(10===r)u=!0;else if(13===r)u=!0,10===n.charCodeAt(l)&&(++l,++a);else if(r!==c)continue;return n.substring(t,l-a)}return n.substring(t)}for(var r,u,i={},o={},a=[],s=n.length,l=0,f=0;(r=e())!==o;){for(var h=[];r!==i&&r!==o;)h.push(r),r=e();(!t||(h=t(h,f++)))&&a.push(h)}return a},e.format=function(t){if(Array.isArray(t[0]))return e.formatRows(t);var r=new l,u=[];return t.forEach(function(n){for(var t in n)r.has(t)||u.push(r.add(t))}),[u.map(o).join(n)].concat(t.map(function(t){return u.map(function(n){return o(t[n])}).join(n)})).join("\n")},e.formatRows=function(n){return n.map(i).join("\n")},e},Xo.csv=Xo.dsv(",","text/csv"),Xo.tsv=Xo.dsv(" ","text/tab-separated-values");var Xa,$a,Ba,Wa,Ja,Ga=Go[h(Go,"requestAnimationFrame")]||function(n){setTimeout(n,17)};Xo.timer=function(n,t,e){var r=arguments.length;2>r&&(t=0),3>r&&(e=Date.now());var u=e+t,i={c:n,t:u,f:!1,n:null};$a?$a.n=i:Xa=i,$a=i,Ba||(Wa=clearTimeout(Wa),Ba=1,Ga(Et))},Xo.timer.flush=function(){At(),Ct()},Xo.round=function(n,t){return t?Math.round(n*(t=Math.pow(10,t)))/t:Math.round(n)};var Ka=["y","z","a","f","p","n","\xb5","m","","k","M","G","T","P","E","Z","Y"].map(Lt);Xo.formatPrefix=function(n,t){var e=0;return n&&(0>n&&(n*=-1),t&&(n=Xo.round(n,Nt(n,t))),e=1+Math.floor(1e-12+Math.log(n)/Math.LN10),e=Math.max(-24,Math.min(24,3*Math.floor((0>=e?e+1:e-1)/3)))),Ka[8+e/3]};var Qa=/(?:([^{])?([<>=^]))?([+\- ])?([$#])?(0)?(\d+)?(,)?(\.-?\d+)?([a-z%])?/i,nc=Xo.map({b:function(n){return n.toString(2)},c:function(n){return String.fromCharCode(n)},o:function(n){return n.toString(8)},x:function(n){return n.toString(16)},X:function(n){return n.toString(16).toUpperCase()},g:function(n,t){return n.toPrecision(t)},e:function(n,t){return n.toExponential(t)},f:function(n,t){return n.toFixed(t)},r:function(n,t){return(n=Xo.round(n,Nt(n,t))).toFixed(Math.max(0,Math.min(20,Nt(n*(1+1e-15),t))))}}),tc=Xo.time={},ec=Date;Tt.prototype={getDate:function(){return this._.getUTCDate()},getDay:function(){return this._.getUTCDay()},getFullYear:function(){return this._.getUTCFullYear()},getHours:function(){return this._.getUTCHours()},getMilliseconds:function(){return this._.getUTCMilliseconds()},getMinutes:function(){return this._.getUTCMinutes()},getMonth:function(){return this._.getUTCMonth()},getSeconds:function(){return this._.getUTCSeconds()},getTime:function(){return this._.getTime()},getTimezoneOffset:function(){return 0},valueOf:function(){return this._.valueOf()},setDate:function(){rc.setUTCDate.apply(this._,arguments)},setDay:function(){rc.setUTCDay.apply(this._,arguments)},setFullYear:function(){rc.setUTCFullYear.apply(this._,arguments)},setHours:function(){rc.setUTCHours.apply(this._,arguments)},setMilliseconds:function(){rc.setUTCMilliseconds.apply(this._,arguments)},setMinutes:function(){rc.setUTCMinutes.apply(this._,arguments)},setMonth:function(){rc.setUTCMonth.apply(this._,arguments)},setSeconds:function(){rc.setUTCSeconds.apply(this._,arguments)},setTime:function(){rc.setTime.apply(this._,arguments)}};var rc=Date.prototype;tc.year=Rt(function(n){return n=tc.day(n),n.setMonth(0,1),n},function(n,t){n.setFullYear(n.getFullYear()+t)},function(n){return n.getFullYear()}),tc.years=tc.year.range,tc.years.utc=tc.year.utc.range,tc.day=Rt(function(n){var t=new ec(2e3,0);return t.setFullYear(n.getFullYear(),n.getMonth(),n.getDate()),t},function(n,t){n.setDate(n.getDate()+t)},function(n){return 
n.getDate()-1}),tc.days=tc.day.range,tc.days.utc=tc.day.utc.range,tc.dayOfYear=function(n){var t=tc.year(n);return Math.floor((n-t-6e4*(n.getTimezoneOffset()-t.getTimezoneOffset()))/864e5)},["sunday","monday","tuesday","wednesday","thursday","friday","saturday"].forEach(function(n,t){t=7-t;var e=tc[n]=Rt(function(n){return(n=tc.day(n)).setDate(n.getDate()-(n.getDay()+t)%7),n},function(n,t){n.setDate(n.getDate()+7*Math.floor(t))},function(n){var e=tc.year(n).getDay();return Math.floor((tc.dayOfYear(n)+(e+t)%7)/7)-(e!==t)});tc[n+"s"]=e.range,tc[n+"s"].utc=e.utc.range,tc[n+"OfYear"]=function(n){var e=tc.year(n).getDay();return Math.floor((tc.dayOfYear(n)+(e+t)%7)/7)}}),tc.week=tc.sunday,tc.weeks=tc.sunday.range,tc.weeks.utc=tc.sunday.utc.range,tc.weekOfYear=tc.sundayOfYear;var uc={"-":"",_:" ",0:"0"},ic=/^\s*\d+/,oc=/^%/;Xo.locale=function(n){return{numberFormat:zt(n),timeFormat:Pt(n)}};var ac=Xo.locale({decimal:".",thousands:",",grouping:[3],currency:["$",""],dateTime:"%a %b %e %X %Y",date:"%m/%d/%Y",time:"%H:%M:%S",periods:["AM","PM"],days:["Sunday","Monday","Tuesday","Wednesday","Thursday","Friday","Saturday"],shortDays:["Sun","Mon","Tue","Wed","Thu","Fri","Sat"],months:["January","February","March","April","May","June","July","August","September","October","November","December"],shortMonths:["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]});Xo.format=ac.numberFormat,Xo.geo={},re.prototype={s:0,t:0,add:function(n){ue(n,this.t,cc),ue(cc.s,this.s,this),this.s?this.t+=cc.t:this.s=cc.t},reset:function(){this.s=this.t=0},valueOf:function(){return this.s}};var cc=new re;Xo.geo.stream=function(n,t){n&&sc.hasOwnProperty(n.type)?sc[n.type](n,t):ie(n,t)};var sc={Feature:function(n,t){ie(n.geometry,t)},FeatureCollection:function(n,t){for(var e=n.features,r=-1,u=e.length;++rn?4*Sa+n:n,gc.lineStart=gc.lineEnd=gc.point=g}};Xo.geo.bounds=function(){function n(n,t){x.push(M=[l=n,h=n]),f>t&&(f=t),t>g&&(g=t)}function t(t,e){var r=se([t*Na,e*Na]);if(m){var u=fe(m,r),i=[u[1],-u[0],0],o=fe(i,u);pe(o),o=ve(o);var c=t-p,s=c>0?1:-1,v=o[0]*La*s,d=oa(c)>180;if(d^(v>s*p&&s*t>v)){var y=o[1]*La;y>g&&(g=y)}else if(v=(v+360)%360-180,d^(v>s*p&&s*t>v)){var y=-o[1]*La;f>y&&(f=y)}else f>e&&(f=e),e>g&&(g=e);d?p>t?a(l,t)>a(l,h)&&(h=t):a(t,h)>a(l,h)&&(l=t):h>=l?(l>t&&(l=t),t>h&&(h=t)):t>p?a(l,t)>a(l,h)&&(h=t):a(t,h)>a(l,h)&&(l=t)}else n(t,e);m=r,p=t}function e(){_.point=t}function r(){M[0]=l,M[1]=h,_.point=n,m=null}function u(n,e){if(m){var r=n-p;y+=oa(r)>180?r+(r>0?360:-360):r}else v=n,d=e;gc.point(n,e),t(n,e)}function i(){gc.lineStart()}function o(){u(v,d),gc.lineEnd(),oa(y)>Aa&&(l=-(h=180)),M[0]=l,M[1]=h,m=null}function a(n,t){return(t-=n)<0?t+360:t}function c(n,t){return n[0]-t[0]}function s(n,t){return t[0]<=t[1]?t[0]<=n&&n<=t[1]:nhc?(l=-(h=180),f=-(g=90)):y>Aa?g=90:-Aa>y&&(f=-90),M[0]=l,M[1]=h -}};return function(n){g=h=-(l=f=1/0),x=[],Xo.geo.stream(n,_);var t=x.length;if(t){x.sort(c);for(var e,r=1,u=x[0],i=[u];t>r;++r)e=x[r],s(e[0],u)||s(e[1],u)?(a(u[0],e[1])>a(u[0],u[1])&&(u[1]=e[1]),a(e[0],u[1])>a(u[0],u[1])&&(u[0]=e[0])):i.push(u=e);for(var o,e,p=-1/0,t=i.length-1,r=0,u=i[t];t>=r;u=e,++r)e=i[r],(o=a(u[1],e[0]))>p&&(p=o,l=e[0],h=u[1])}return x=M=null,1/0===l||1/0===f?[[0/0,0/0],[0/0,0/0]]:[[l,f],[h,g]]}}(),Xo.geo.centroid=function(n){pc=vc=dc=mc=yc=xc=Mc=_c=bc=wc=Sc=0,Xo.geo.stream(n,kc);var t=bc,e=wc,r=Sc,u=t*t+e*e+r*r;return Ca>u&&(t=xc,e=Mc,r=_c,Aa>vc&&(t=dc,e=mc,r=yc),u=t*t+e*e+r*r,Ca>u)?[0/0,0/0]:[Math.atan2(e,t)*La,X(r/Math.sqrt(u))*La]};var 
pc,vc,dc,mc,yc,xc,Mc,_c,bc,wc,Sc,kc={sphere:g,point:me,lineStart:xe,lineEnd:Me,polygonStart:function(){kc.lineStart=_e},polygonEnd:function(){kc.lineStart=xe}},Ec=Ee(be,ze,Te,[-Sa,-Sa/2]),Ac=1e9;Xo.geo.clipExtent=function(){var n,t,e,r,u,i,o={stream:function(n){return u&&(u.valid=!1),u=i(n),u.valid=!0,u},extent:function(a){return arguments.length?(i=Pe(n=+a[0][0],t=+a[0][1],e=+a[1][0],r=+a[1][1]),u&&(u.valid=!1,u=null),o):[[n,t],[e,r]]}};return o.extent([[0,0],[960,500]])},(Xo.geo.conicEqualArea=function(){return je(He)}).raw=He,Xo.geo.albers=function(){return Xo.geo.conicEqualArea().rotate([96,0]).center([-.6,38.7]).parallels([29.5,45.5]).scale(1070)},Xo.geo.albersUsa=function(){function n(n){var i=n[0],o=n[1];return t=null,e(i,o),t||(r(i,o),t)||u(i,o),t}var t,e,r,u,i=Xo.geo.albers(),o=Xo.geo.conicEqualArea().rotate([154,0]).center([-2,58.5]).parallels([55,65]),a=Xo.geo.conicEqualArea().rotate([157,0]).center([-3,19.9]).parallels([8,18]),c={point:function(n,e){t=[n,e]}};return n.invert=function(n){var t=i.scale(),e=i.translate(),r=(n[0]-e[0])/t,u=(n[1]-e[1])/t;return(u>=.12&&.234>u&&r>=-.425&&-.214>r?o:u>=.166&&.234>u&&r>=-.214&&-.115>r?a:i).invert(n)},n.stream=function(n){var t=i.stream(n),e=o.stream(n),r=a.stream(n);return{point:function(n,u){t.point(n,u),e.point(n,u),r.point(n,u)},sphere:function(){t.sphere(),e.sphere(),r.sphere()},lineStart:function(){t.lineStart(),e.lineStart(),r.lineStart()},lineEnd:function(){t.lineEnd(),e.lineEnd(),r.lineEnd()},polygonStart:function(){t.polygonStart(),e.polygonStart(),r.polygonStart()},polygonEnd:function(){t.polygonEnd(),e.polygonEnd(),r.polygonEnd()}}},n.precision=function(t){return arguments.length?(i.precision(t),o.precision(t),a.precision(t),n):i.precision()},n.scale=function(t){return arguments.length?(i.scale(t),o.scale(.35*t),a.scale(t),n.translate(i.translate())):i.scale()},n.translate=function(t){if(!arguments.length)return i.translate();var s=i.scale(),l=+t[0],f=+t[1];return e=i.translate(t).clipExtent([[l-.455*s,f-.238*s],[l+.455*s,f+.238*s]]).stream(c).point,r=o.translate([l-.307*s,f+.201*s]).clipExtent([[l-.425*s+Aa,f+.12*s+Aa],[l-.214*s-Aa,f+.234*s-Aa]]).stream(c).point,u=a.translate([l-.205*s,f+.212*s]).clipExtent([[l-.214*s+Aa,f+.166*s+Aa],[l-.115*s-Aa,f+.234*s-Aa]]).stream(c).point,n},n.scale(1070)};var Cc,Nc,Lc,zc,qc,Tc,Rc={point:g,lineStart:g,lineEnd:g,polygonStart:function(){Nc=0,Rc.lineStart=Fe},polygonEnd:function(){Rc.lineStart=Rc.lineEnd=Rc.point=g,Cc+=oa(Nc/2)}},Dc={point:Oe,lineStart:g,lineEnd:g,polygonStart:g,polygonEnd:g},Pc={point:Ze,lineStart:Ve,lineEnd:Xe,polygonStart:function(){Pc.lineStart=$e},polygonEnd:function(){Pc.point=Ze,Pc.lineStart=Ve,Pc.lineEnd=Xe}};Xo.geo.path=function(){function n(n){return n&&("function"==typeof a&&i.pointRadius(+a.apply(this,arguments)),o&&o.valid||(o=u(i)),Xo.geo.stream(n,o)),i.result()}function t(){return o=null,n}var e,r,u,i,o,a=4.5;return n.area=function(n){return Cc=0,Xo.geo.stream(n,u(Rc)),Cc},n.centroid=function(n){return dc=mc=yc=xc=Mc=_c=bc=wc=Sc=0,Xo.geo.stream(n,u(Pc)),Sc?[bc/Sc,wc/Sc]:_c?[xc/_c,Mc/_c]:yc?[dc/yc,mc/yc]:[0/0,0/0]},n.bounds=function(n){return qc=Tc=-(Lc=zc=1/0),Xo.geo.stream(n,u(Dc)),[[Lc,zc],[qc,Tc]]},n.projection=function(n){return arguments.length?(u=(e=n)?n.stream||Je(n):bt,t()):e},n.context=function(n){return arguments.length?(i=null==(r=n)?new Ye:new Be(n),"function"!=typeof a&&i.pointRadius(a),t()):r},n.pointRadius=function(t){return arguments.length?(a="function"==typeof 
t?t:(i.pointRadius(+t),+t),n):a},n.projection(Xo.geo.albersUsa()).context(null)},Xo.geo.transform=function(n){return{stream:function(t){var e=new Ge(t);for(var r in n)e[r]=n[r];return e}}},Ge.prototype={point:function(n,t){this.stream.point(n,t)},sphere:function(){this.stream.sphere()},lineStart:function(){this.stream.lineStart()},lineEnd:function(){this.stream.lineEnd()},polygonStart:function(){this.stream.polygonStart()},polygonEnd:function(){this.stream.polygonEnd()}},Xo.geo.projection=Qe,Xo.geo.projectionMutator=nr,(Xo.geo.equirectangular=function(){return Qe(er)}).raw=er.invert=er,Xo.geo.rotation=function(n){function t(t){return t=n(t[0]*Na,t[1]*Na),t[0]*=La,t[1]*=La,t}return n=ur(n[0]%360*Na,n[1]*Na,n.length>2?n[2]*Na:0),t.invert=function(t){return t=n.invert(t[0]*Na,t[1]*Na),t[0]*=La,t[1]*=La,t},t},rr.invert=er,Xo.geo.circle=function(){function n(){var n="function"==typeof r?r.apply(this,arguments):r,t=ur(-n[0]*Na,-n[1]*Na,0).invert,u=[];return e(null,null,1,{point:function(n,e){u.push(n=t(n,e)),n[0]*=La,n[1]*=La}}),{type:"Polygon",coordinates:[u]}}var t,e,r=[0,0],u=6;return n.origin=function(t){return arguments.length?(r=t,n):r},n.angle=function(r){return arguments.length?(e=cr((t=+r)*Na,u*Na),n):t},n.precision=function(r){return arguments.length?(e=cr(t*Na,(u=+r)*Na),n):u},n.angle(90)},Xo.geo.distance=function(n,t){var e,r=(t[0]-n[0])*Na,u=n[1]*Na,i=t[1]*Na,o=Math.sin(r),a=Math.cos(r),c=Math.sin(u),s=Math.cos(u),l=Math.sin(i),f=Math.cos(i);return Math.atan2(Math.sqrt((e=f*o)*e+(e=s*l-c*f*a)*e),c*l+s*f*a)},Xo.geo.graticule=function(){function n(){return{type:"MultiLineString",coordinates:t()}}function t(){return Xo.range(Math.ceil(i/d)*d,u,d).map(h).concat(Xo.range(Math.ceil(s/m)*m,c,m).map(g)).concat(Xo.range(Math.ceil(r/p)*p,e,p).filter(function(n){return oa(n%d)>Aa}).map(l)).concat(Xo.range(Math.ceil(a/v)*v,o,v).filter(function(n){return oa(n%m)>Aa}).map(f))}var e,r,u,i,o,a,c,s,l,f,h,g,p=10,v=p,d=90,m=360,y=2.5;return n.lines=function(){return t().map(function(n){return{type:"LineString",coordinates:n}})},n.outline=function(){return{type:"Polygon",coordinates:[h(i).concat(g(c).slice(1),h(u).reverse().slice(1),g(s).reverse().slice(1))]}},n.extent=function(t){return arguments.length?n.majorExtent(t).minorExtent(t):n.minorExtent()},n.majorExtent=function(t){return arguments.length?(i=+t[0][0],u=+t[1][0],s=+t[0][1],c=+t[1][1],i>u&&(t=i,i=u,u=t),s>c&&(t=s,s=c,c=t),n.precision(y)):[[i,s],[u,c]]},n.minorExtent=function(t){return arguments.length?(r=+t[0][0],e=+t[1][0],a=+t[0][1],o=+t[1][1],r>e&&(t=r,r=e,e=t),a>o&&(t=a,a=o,o=t),n.precision(y)):[[r,a],[e,o]]},n.step=function(t){return arguments.length?n.majorStep(t).minorStep(t):n.minorStep()},n.majorStep=function(t){return arguments.length?(d=+t[0],m=+t[1],n):[d,m]},n.minorStep=function(t){return arguments.length?(p=+t[0],v=+t[1],n):[p,v]},n.precision=function(t){return arguments.length?(y=+t,l=lr(a,o,90),f=fr(r,e,y),h=lr(s,c,90),g=fr(i,u,y),n):y},n.majorExtent([[-180,-90+Aa],[180,90-Aa]]).minorExtent([[-180,-80-Aa],[180,80+Aa]])},Xo.geo.greatArc=function(){function n(){return{type:"LineString",coordinates:[t||r.apply(this,arguments),e||u.apply(this,arguments)]}}var t,e,r=hr,u=gr;return n.distance=function(){return Xo.geo.distance(t||r.apply(this,arguments),e||u.apply(this,arguments))},n.source=function(e){return arguments.length?(r=e,t="function"==typeof e?null:e,n):r},n.target=function(t){return arguments.length?(u=t,e="function"==typeof t?null:t,n):u},n.precision=function(){return 
arguments.length?n:0},n},Xo.geo.interpolate=function(n,t){return pr(n[0]*Na,n[1]*Na,t[0]*Na,t[1]*Na)},Xo.geo.length=function(n){return Uc=0,Xo.geo.stream(n,jc),Uc};var Uc,jc={sphere:g,point:g,lineStart:vr,lineEnd:g,polygonStart:g,polygonEnd:g},Hc=dr(function(n){return Math.sqrt(2/(1+n))},function(n){return 2*Math.asin(n/2)});(Xo.geo.azimuthalEqualArea=function(){return Qe(Hc)}).raw=Hc;var Fc=dr(function(n){var t=Math.acos(n);return t&&t/Math.sin(t)},bt);(Xo.geo.azimuthalEquidistant=function(){return Qe(Fc)}).raw=Fc,(Xo.geo.conicConformal=function(){return je(mr)}).raw=mr,(Xo.geo.conicEquidistant=function(){return je(yr)}).raw=yr;var Oc=dr(function(n){return 1/n},Math.atan);(Xo.geo.gnomonic=function(){return Qe(Oc)}).raw=Oc,xr.invert=function(n,t){return[n,2*Math.atan(Math.exp(t))-Ea]},(Xo.geo.mercator=function(){return Mr(xr)}).raw=xr;var Yc=dr(function(){return 1},Math.asin);(Xo.geo.orthographic=function(){return Qe(Yc)}).raw=Yc;var Ic=dr(function(n){return 1/(1+n)},function(n){return 2*Math.atan(n)});(Xo.geo.stereographic=function(){return Qe(Ic)}).raw=Ic,_r.invert=function(n,t){return[-t,2*Math.atan(Math.exp(n))-Ea]},(Xo.geo.transverseMercator=function(){var n=Mr(_r),t=n.center,e=n.rotate;return n.center=function(n){return n?t([-n[1],n[0]]):(n=t(),[-n[1],n[0]])},n.rotate=function(n){return n?e([n[0],n[1],n.length>2?n[2]+90:90]):(n=e(),[n[0],n[1],n[2]-90])},n.rotate([0,0])}).raw=_r,Xo.geom={},Xo.geom.hull=function(n){function t(n){if(n.length<3)return[];var t,u=_t(e),i=_t(r),o=n.length,a=[],c=[];for(t=0;o>t;t++)a.push([+u.call(this,n[t],t),+i.call(this,n[t],t),t]);for(a.sort(kr),t=0;o>t;t++)c.push([a[t][0],-a[t][1]]);var s=Sr(a),l=Sr(c),f=l[0]===s[0],h=l[l.length-1]===s[s.length-1],g=[];for(t=s.length-1;t>=0;--t)g.push(n[a[s[t]][2]]);for(t=+f;t=r&&s.x<=i&&s.y>=u&&s.y<=o?[[r,o],[i,o],[i,u],[r,u]]:[];l.point=n[a]}),t}function e(n){return n.map(function(n,t){return{x:Math.round(i(n,t)/Aa)*Aa,y:Math.round(o(n,t)/Aa)*Aa,i:t}})}var r=br,u=wr,i=r,o=u,a=Kc;return n?t(n):(t.links=function(n){return nu(e(n)).edges.filter(function(n){return n.l&&n.r}).map(function(t){return{source:n[t.l.i],target:n[t.r.i]}})},t.triangles=function(n){var t=[];return nu(e(n)).cells.forEach(function(e,r){for(var u,i,o=e.site,a=e.edges.sort(jr),c=-1,s=a.length,l=a[s-1].edge,f=l.l===o?l.r:l.l;++c=s,h=r>=l,g=(h<<1)+f;n.leaf=!1,n=n.nodes[g]||(n.nodes[g]=iu()),f?u=s:a=s,h?o=l:c=l,i(n,t,e,r,u,o,a,c)}var l,f,h,g,p,v,d,m,y,x=_t(a),M=_t(c);if(null!=t)v=t,d=e,m=r,y=u;else if(m=y=-(v=d=1/0),f=[],h=[],p=n.length,o)for(g=0;p>g;++g)l=n[g],l.xm&&(m=l.x),l.y>y&&(y=l.y),f.push(l.x),h.push(l.y);else for(g=0;p>g;++g){var _=+x(l=n[g],g),b=+M(l,g);v>_&&(v=_),d>b&&(d=b),_>m&&(m=_),b>y&&(y=b),f.push(_),h.push(b)}var w=m-v,S=y-d;w>S?y=d+w:m=v+S;var k=iu();if(k.add=function(n){i(k,n,+x(n,++g),+M(n,g),v,d,m,y)},k.visit=function(n){ou(n,k,v,d,m,y)},g=-1,null==t){for(;++g=0?n.substring(0,t):n,r=t>=0?n.substring(t+1):"in";return e=ts.get(e)||ns,r=es.get(r)||bt,gu(r(e.apply(null,$o.call(arguments,1))))},Xo.interpolateHcl=Eu,Xo.interpolateHsl=Au,Xo.interpolateLab=Cu,Xo.interpolateRound=Nu,Xo.transform=function(n){var t=Wo.createElementNS(Xo.ns.prefix.svg,"g");return(Xo.transform=function(n){if(null!=n){t.setAttribute("transform",n);var e=t.transform.baseVal.consolidate()}return new Lu(e?e.matrix:rs)})(n)},Lu.prototype.toString=function(){return"translate("+this.translate+")rotate("+this.rotate+")skewX("+this.skew+")scale("+this.scale+")"};var rs={a:1,b:0,c:0,d:1,e:0,f:0};Xo.interpolateTransform=Ru,Xo.layout={},Xo.layout.bundle=function(){return 
function(n){for(var t=[],e=-1,r=n.length;++ea*a/d){if(p>c){var s=t.charge/c;n.px-=i*s,n.py-=o*s}return!0}if(t.point&&c&&p>c){var s=t.pointCharge/c;n.px-=i*s,n.py-=o*s}}return!t.charge}}function t(n){n.px=Xo.event.x,n.py=Xo.event.y,a.resume()}var e,r,u,i,o,a={},c=Xo.dispatch("start","tick","end"),s=[1,1],l=.9,f=us,h=is,g=-30,p=os,v=.1,d=.64,m=[],y=[];return a.tick=function(){if((r*=.99)<.005)return c.end({type:"end",alpha:r=0}),!0;var t,e,a,f,h,p,d,x,M,_=m.length,b=y.length;for(e=0;b>e;++e)a=y[e],f=a.source,h=a.target,x=h.x-f.x,M=h.y-f.y,(p=x*x+M*M)&&(p=r*i[e]*((p=Math.sqrt(p))-u[e])/p,x*=p,M*=p,h.x-=x*(d=f.weight/(h.weight+f.weight)),h.y-=M*d,f.x+=x*(d=1-d),f.y+=M*d);if((d=r*v)&&(x=s[0]/2,M=s[1]/2,e=-1,d))for(;++e<_;)a=m[e],a.x+=(x-a.x)*d,a.y+=(M-a.y)*d;if(g)for(Zu(t=Xo.geom.quadtree(m),r,o),e=-1;++e<_;)(a=m[e]).fixed||t.visit(n(a));for(e=-1;++e<_;)a=m[e],a.fixed?(a.x=a.px,a.y=a.py):(a.x-=(a.px-(a.px=a.x))*l,a.y-=(a.py-(a.py=a.y))*l);c.tick({type:"tick",alpha:r})},a.nodes=function(n){return arguments.length?(m=n,a):m},a.links=function(n){return arguments.length?(y=n,a):y},a.size=function(n){return arguments.length?(s=n,a):s},a.linkDistance=function(n){return arguments.length?(f="function"==typeof n?n:+n,a):f},a.distance=a.linkDistance,a.linkStrength=function(n){return arguments.length?(h="function"==typeof n?n:+n,a):h},a.friction=function(n){return arguments.length?(l=+n,a):l},a.charge=function(n){return arguments.length?(g="function"==typeof n?n:+n,a):g},a.chargeDistance=function(n){return arguments.length?(p=n*n,a):Math.sqrt(p)},a.gravity=function(n){return arguments.length?(v=+n,a):v},a.theta=function(n){return arguments.length?(d=n*n,a):Math.sqrt(d)},a.alpha=function(n){return arguments.length?(n=+n,r?r=n>0?n:0:n>0&&(c.start({type:"start",alpha:r=n}),Xo.timer(a.tick)),a):r},a.start=function(){function n(n,r){if(!e){for(e=new Array(c),a=0;c>a;++a)e[a]=[];for(a=0;s>a;++a){var u=y[a];e[u.source.index].push(u.target),e[u.target.index].push(u.source)}}for(var i,o=e[t],a=-1,s=o.length;++at;++t)(r=m[t]).index=t,r.weight=0;for(t=0;l>t;++t)r=y[t],"number"==typeof r.source&&(r.source=m[r.source]),"number"==typeof r.target&&(r.target=m[r.target]),++r.source.weight,++r.target.weight;for(t=0;c>t;++t)r=m[t],isNaN(r.x)&&(r.x=n("x",p)),isNaN(r.y)&&(r.y=n("y",v)),isNaN(r.px)&&(r.px=r.x),isNaN(r.py)&&(r.py=r.y);if(u=[],"function"==typeof f)for(t=0;l>t;++t)u[t]=+f.call(this,y[t],t);else for(t=0;l>t;++t)u[t]=f;if(i=[],"function"==typeof h)for(t=0;l>t;++t)i[t]=+h.call(this,y[t],t);else for(t=0;l>t;++t)i[t]=h;if(o=[],"function"==typeof g)for(t=0;c>t;++t)o[t]=+g.call(this,m[t],t);else for(t=0;c>t;++t)o[t]=g;return a.resume()},a.resume=function(){return a.alpha(.1)},a.stop=function(){return a.alpha(0)},a.drag=function(){return e||(e=Xo.behavior.drag().origin(bt).on("dragstart.force",Fu).on("drag.force",t).on("dragend.force",Ou)),arguments.length?(this.on("mouseover.force",Yu).on("mouseout.force",Iu).call(e),void 0):e},Xo.rebind(a,c,"on")};var us=20,is=1,os=1/0;Xo.layout.hierarchy=function(){function n(t,o,a){var c=u.call(e,t,o);if(t.depth=o,a.push(t),c&&(s=c.length)){for(var s,l,f=-1,h=t.children=new Array(s),g=0,p=o+1;++fg;++g)for(u.call(n,s[0][g],p=v[g],l[0][g][1]),h=1;d>h;++h)u.call(n,s[h][g],p+=l[h-1][g][1],l[h][g][1]);return a}var t=bt,e=Qu,r=ni,u=Ku,i=Ju,o=Gu;return n.values=function(e){return arguments.length?(t=e,n):t},n.order=function(t){return arguments.length?(e="function"==typeof t?t:cs.get(t)||Qu,n):e},n.offset=function(t){return arguments.length?(r="function"==typeof 
-[… remainder of the deleted vendored minified d3 library (one long machine-generated line) …]
\ No newline at end of file
diff --git a/mne/html/mpld3.v0.2.min.js b/mne/html/mpld3.v0.2.min.js
deleted file mode 100644
index adefb15efa7..00000000000
--- a/mne/html/mpld3.v0.2.min.js
+++ /dev/null
@@ -1,2 +0,0 @@
-[… contents of the deleted vendored minified mpld3.v0.2.min.js (two machine-generated lines) …]
\ No newline at end of file
diff --git a/mne/html_templates/repr/tfr.html.jinja b/mne/html_templates/repr/tfr.html.jinja
new file mode 100644
index 00000000000..f6ab107ab0b
--- /dev/null
+++ b/mne/html_templates/repr/tfr.html.jinja
@@ -0,0 +1,60 @@
+<table>
+    <tr>
+        <th>Data type</th>
+        <td>{{ tfr._data_type }}</td>
+    </tr>
+    {%- for unit in units %}
+    <tr>
+        {%- if loop.index == 1 %}
+        <th>Units</th>
+        {%- endif %}
+        <td>{{ unit }}</td>
+    </tr>
+    {%- endfor %}
+    <tr>
+        <th>Data source</th>
+        <td>{{ inst_type }}</td>
+    </tr>
+    {%- if inst_type == "Epochs" %}
+    <tr>
+        <th>Number of epochs</th>
+        <td>{{ tfr.shape[0] }}</td>
+    </tr>
+    {% endif -%}
+    {%- if inst_type == "Evoked" %}
+    <tr>
+        <th>Number of averaged trials</th>
+        <td>{{ nave }}</td>
+    </tr>
+    {% endif -%}
+    <tr>
+        <th>Dims</th>
+        <td>{{ tfr._dims | join(", ") }}</td>
+    </tr>
+    <tr>
+        <th>Estimation method</th>
+        <td>{{ tfr.method }}</td>
+    </tr>
+    {% if "taper" in tfr._dims %}
+    <tr>
+        <th>Number of tapers</th>
+        <td>{{ tfr._mt_weights.size }}</td>
+    </tr>
+    {% endif %}
+    <tr>
+        <th>Number of channels</th>
+        <td>{{ tfr.ch_names|length }}</td>
+    </tr>
+    <tr>
+        <th>Number of timepoints</th>
+        <td>{{ tfr.times|length }}</td>
+    </tr>
+    <tr>
+        <th>Number of frequency bins</th>
+        <td>{{ tfr.freqs|length }}</td>
+    </tr>
+    <tr>
+        <th>Frequency range</th>
+        <td>{{ '%.2f'|format(tfr.freqs[0]) }} – {{ '%.2f'|format(tfr.freqs[-1]) }} Hz</td>
+    </tr>
+</table>
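For context on how such a fragment becomes HTML: jinja2 substitutes the `{{ … }}` expressions when the repr is built. A minimal, self-contained sketch of rendering a cut-down version of the table; every variable value below is an invented placeholder, not an MNE default:

    # Sketch only: rendering a stripped-down repr fragment with jinja2.
    from jinja2 import Template

    fragment = """<table>
        <tr><th>Data source</th><td>{{ inst_type }}</td></tr>
        <tr><th>Estimation method</th><td>{{ method }}</td></tr>
        <tr><th>Frequency range</th>
            <td>{{ '%.2f'|format(freqs[0]) }} – {{ '%.2f'|format(freqs[-1]) }} Hz</td></tr>
    </table>"""

    html = Template(fragment).render(
        inst_type="Raw",         # hypothetical data source
        method="morlet",         # hypothetical estimation method
        freqs=[4.0, 8.0, 12.0],  # hypothetical frequency grid
    )
    print(html)  # frequency row renders as "4.00 – 12.00 Hz"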
diff --git a/mne/inverse_sparse/mxne_inverse.py b/mne/inverse_sparse/mxne_inverse.py index 9a2d8c4b5c8..703a0d30ca4 100644 --- a/mne/inverse_sparse/mxne_inverse.py +++ b/mne/inverse_sparse/mxne_inverse.py @@ -55,9 +55,7 @@ def _prepare_weights(forward, gain, source_weighting, weights, weights_min): weights = np.max(np.abs(weights.data), axis=1) weights_max = np.max(weights) if weights_min > weights_max: - raise ValueError( - "weights_min > weights_max (%s > %s)" % (weights_min, weights_max) - ) + raise ValueError(f"weights_min > weights_max ({weights_min} > {weights_max})") weights_min = weights_min / weights_max weights = weights / weights_max n_dip_per_pos = 1 if is_fixed_orient(forward) else 3 @@ -813,7 +811,7 @@ def tf_mixed_norm( if len(tstep) != len(wsize): raise ValueError( "The same number of window sizes and steps must be " - "passed. Got tstep = %s and wsize = %s" % (tstep, wsize) + f"passed. Got tstep = {tstep} and wsize = {wsize}" ) forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain( @@ -1090,7 +1088,7 @@ def _compute_sure_val(coef1, coef2, gain, M, sigma, delta, eps): for i, (coef1, coef2) in enumerate(zip(coefs_grid_1, coefs_grid_2)): sure_path[i] = _compute_sure_val(coef1, coef2, gain, M, sigma, delta, eps) if verbose: - logger.info("alpha %s :: sure %s" % (alpha_grid[i], sure_path[i])) + logger.info(f"alpha {alpha_grid[i]} :: sure {sure_path[i]}") best_alpha_ = alpha_grid[np.argmin(sure_path)] X = coefs_grid_1[np.argmin(sure_path)] diff --git a/mne/inverse_sparse/mxne_optim.py b/mne/inverse_sparse/mxne_optim.py index b70476991a2..dbac66a96f9 100644 --- a/mne/inverse_sparse/mxne_optim.py +++ b/mne/inverse_sparse/mxne_optim.py @@ -243,7 +243,7 @@ def _mixed_norm_solver_bcd( ) if gap < tol: - logger.debug("Convergence reached ! (gap: %s < %s)" % (gap, tol)) + logger.debug(f"Convergence reached ! (gap: {gap} < {tol})") break # using Anderson acceleration of the primal variable for faster @@ -525,7 +525,7 @@ def mixed_norm_solver( ) ) if gap < tol: - logger.info("Convergence reached ! (gap: %s < %s)" % (gap, tol)) + logger.info(f"Convergence reached ! (gap: {gap} < {tol})") break # add sources if not last iteration @@ -545,7 +545,7 @@ def mixed_norm_solver( idx = np.searchsorted(idx_active_set, idx_old_active_set) X_init[idx] = X else: - warn("Did NOT converge ! (gap: %s > %s)" % (gap, tol)) + warn(f"Did NOT converge ! (gap: {gap} > {tol})") else: X, active_set, E = l21_solver( M, G, alpha, lc, maxit=maxit, tol=tol, n_orient=n_orient, init=None @@ -640,8 +640,8 @@ def gprime(w): if weight_init is not None and weight_init.shape != (G.shape[1],): raise ValueError( - "Wrong dimension for weight initialization. Got %s. " - "Expected %s." % (weight_init.shape, (G.shape[1],)) + f"Wrong dimension for weight initialization. Got {weight_init.shape}. " + f"Expected {(G.shape[1],)}." 
) weights = weight_init if weight_init is not None else np.ones(G.shape[1]) @@ -778,7 +778,7 @@ def safe_max_abs_diff(A, ia, B, ib): class _Phi: """Have phi stft as callable w/o using a lambda that does not pickle.""" - def __init__(self, wsize, tstep, n_coefs, n_times): # noqa: D102 + def __init__(self, wsize, tstep, n_coefs, n_times): self.wsize = np.atleast_1d(wsize) self.tstep = np.atleast_1d(tstep) self.n_coefs = np.atleast_1d(n_coefs) @@ -799,7 +799,7 @@ def __call__(self, x): # noqa: D105 else: return np.hstack([x @ op for op in self.ops]) / np.sqrt(self.n_dicts) - def norm(self, z, ord=2): + def norm(self, z, ord=2): # noqa: A002 """Squared L2 norm if ord == 2 and L1 norm if order == 1.""" if ord not in (1, 2): raise ValueError( @@ -819,7 +819,7 @@ def norm(self, z, ord=2): class _PhiT: """Have phi.T istft as callable w/o using a lambda that does not pickle.""" - def __init__(self, tstep, n_freqs, n_steps, n_times): # noqa: D102 + def __init__(self, tstep, n_freqs, n_steps, n_times): self.tstep = tstep self.n_freqs = n_freqs self.n_steps = n_steps @@ -977,9 +977,9 @@ def norm_epsilon(Y, l1_ratio, phi, w_space=1.0, w_time=None): p_sum_w2 = np.cumsum(w_time**2) p_sum_Yw = np.cumsum(Y * w_time) upper = p_sum_Y2 / (Y / w_time) ** 2 - 2.0 * p_sum_Yw / (Y / w_time) + p_sum_w2 - upper_greater = np.where( - upper > w_space**2 * (1.0 - l1_ratio) ** 2 / l1_ratio**2 - )[0] + upper_greater = np.where(upper > w_space**2 * (1.0 - l1_ratio) ** 2 / l1_ratio**2)[ + 0 + ] i0 = upper_greater[0] - 1 if upper_greater.size else K - 1 @@ -1270,7 +1270,7 @@ def _tf_mixed_norm_solver_bcd_( "\n Iteration %d :: n_active %d" % (i + 1, np.sum(active_set) / n_orient) ) - logger.info(" dgap %.2e :: p_obj %f :: d_obj %f" % (gap, p_obj, d_obj)) + logger.info(f" dgap {gap:.2e} :: p_obj {p_obj} :: d_obj {d_obj}") if converged: break @@ -1504,7 +1504,7 @@ def tf_mixed_norm_solver( if len(tstep) != len(wsize): raise ValueError( "The same number of window sizes and steps must be " - "passed. Got tstep = %s and wsize = %s" % (tstep, wsize) + f"passed. Got tstep = {tstep} and wsize = {wsize}" ) n_steps = np.ceil(M.shape[1] / tstep.astype(float)).astype(int) @@ -1624,7 +1624,7 @@ def iterative_tf_mixed_norm_solver( if len(tstep) != len(wsize): raise ValueError( "The same number of window sizes and steps must be " - "passed. Got tstep = %s and wsize = %s" % (tstep, wsize) + f"passed. Got tstep = {tstep} and wsize = {wsize}" ) n_steps = np.ceil(n_times / tstep.astype(float)).astype(int) diff --git a/mne/io/_read_raw.py b/mne/io/_read_raw.py index f9e715be6b0..6df23ee02f1 100644 --- a/mne/io/_read_raw.py +++ b/mne/io/_read_raw.py @@ -5,11 +5,11 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. - from functools import partial from pathlib import Path from ..utils import fill_doc +from .base import BaseRaw def _read_unsupported(fname, **kwargs): @@ -110,7 +110,7 @@ def split_name_ext(fname): @fill_doc -def read_raw(fname, *, preload=False, verbose=None, **kwargs): +def read_raw(fname, *, preload=False, verbose=None, **kwargs) -> BaseRaw: """Read raw file. This function is a convenient wrapper for readers defined in `mne.io`. 
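The new `-> BaseRaw` return annotation on `read_raw` is mainly a convenience for IDEs and static type checkers. A short usage sketch; the filename is a placeholder, not a shipped file:

    # Sketch only: read_raw dispatches to the right reader based on the file
    # extension, and the annotation tells type checkers the result is a BaseRaw.
    import mne

    raw = mne.io.read_raw("my_recording.edf", preload=False)
    print(type(raw).__name__)  # e.g. RawEDF; isinstance(raw, mne.io.BaseRaw) holds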
The diff --git a/mne/io/array/array.py b/mne/io/array/array.py index 456bd763015..dda73b80a23 100644 --- a/mne/io/array/array.py +++ b/mne/io/array/array.py @@ -52,9 +52,7 @@ class RawArray(BaseRaw): """ @verbose - def __init__( - self, data, info, first_samp=0, copy="auto", verbose=None - ): # noqa: D102 + def __init__(self, data, info, first_samp=0, copy="auto", verbose=None): _validate_type(info, "info", "info") _check_option("copy", copy, ("data", "info", "both", "auto", None)) dtype = np.complex128 if np.any(np.iscomplex(data)) else np.float64 @@ -62,13 +60,14 @@ def __init__( data = np.asanyarray(orig_data, dtype=dtype) if data.ndim != 2: raise ValueError( - "Data must be a 2D array of shape (n_channels, " - "n_samples), got shape %s" % (data.shape,) + "Data must be a 2D array of shape (n_channels, n_samples), got shape " + f"{data.shape}" ) if len(data) != len(info["ch_names"]): raise ValueError( - "len(data) (%s) does not match " - 'len(info["ch_names"]) (%s)' % (len(data), len(info["ch_names"])) + 'len(data) ({}) does not match len(info["ch_names"]) ({})'.format( + len(data), len(info["ch_names"]) + ) ) assert len(info["ch_names"]) == info["nchan"] if copy in ("auto", "info", "both"): @@ -78,15 +77,14 @@ def __init__( data = data.copy() elif copy != "auto" and data is not orig_data: raise ValueError( - "data copying was not requested by copy=%r but " - "it was required to get to double floating point " - "precision" % (copy,) + f"data copying was not requested by copy={copy!r} but it was required " + "to get to double floating point precision" ) logger.info( - "Creating RawArray with %s data, n_channels=%s, n_times=%s" - % (dtype.__name__, data.shape[0], data.shape[1]) + f"Creating RawArray with {dtype.__name__} data, " + f"n_channels={data.shape[0]}, n_times={data.shape[1]}" ) - super(RawArray, self).__init__( + super().__init__( info, data, first_samps=(int(first_samp),), dtype=dtype, verbose=verbose ) logger.info( diff --git a/mne/io/array/tests/test_array.py b/mne/io/array/tests/test_array.py index 59e9913175e..10b7c834d98 100644 --- a/mne/io/array/tests/test_array.py +++ b/mne/io/array/tests/test_array.py @@ -18,7 +18,7 @@ from mne.io.array import RawArray from mne.io.tests.test_raw import _test_raw_reader -base_dir = Path(__file__).parent.parent.parent / "tests" / "data" +base_dir = Path(__file__).parents[2] / "tests" / "data" fif_fname = base_dir / "test_raw.fif" @@ -151,7 +151,9 @@ def test_array_raw(): # plotting raw2.plot() - raw2.compute_psd(tmax=2.0, n_fft=1024).plot(average=True, spatial_colors=False) + raw2.compute_psd(tmax=2.0, n_fft=1024).plot( + average=True, amplitude=False, spatial_colors=False + ) plt.close("all") # epoching @@ -184,5 +186,5 @@ def test_array_raw(): raw = RawArray(data, info) raw.set_montage(montage) spectrum = raw.compute_psd() - spectrum.plot(average=False) # looking for nonexistent layout + spectrum.plot(average=False, amplitude=False) # looking for nonexistent layout spectrum.plot_topo() diff --git a/mne/io/artemis123/artemis123.py b/mne/io/artemis123/artemis123.py index 3cdedb3770d..99b00d36f45 100644 --- a/mne/io/artemis123/artemis123.py +++ b/mne/io/artemis123/artemis123.py @@ -23,7 +23,7 @@ @verbose def read_raw_artemis123( input_fname, preload=False, verbose=None, pos_fname=None, add_head_trans=True -): +) -> "RawArtemis123": """Read Artemis123 data as raw object. 
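The reworked `RawArray` messages above fire at construction time. A minimal sketch of a valid construction with synthetic data, plus (commented out) the shape mismatch that would raise the new f-string error:

    # Sketch only: RawArray expects data shaped (n_channels, n_samples), with the
    # first dimension matching info["ch_names"]; synthetic zeros for illustration.
    import numpy as np
    import mne

    info = mne.create_info(["EEG 001", "EEG 002"], sfreq=256.0, ch_types="eeg")
    raw = mne.io.RawArray(np.zeros((2, 512)), info)  # copy="auto" upcasts to float64

    # A mismatched first dimension raises the reworked error:
    # mne.io.RawArray(np.zeros((3, 512)), info)
    #   -> ValueError: len(data) (3) does not match len(info["ch_names"]) (2)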
Parameters @@ -83,7 +83,7 @@ def _get_artemis123_info(fname, pos_fname=None): header_info["comments"] = "" header_info["channels"] = [] - with open(header, "r") as fid: + with open(header) as fid: # section flag # 0 - None # 1 - main header @@ -131,10 +131,7 @@ def _get_artemis123_info(fname, pos_fname=None): tmp[k] = v header_info["channels"].append(tmp) elif sectionFlag == 3: - header_info["comments"] = "%s%s" % ( - header_info["comments"], - line.strip(), - ) + header_info["comments"] = f"{header_info['comments']}{line.strip()}" elif sectionFlag == 4: header_info["num_samples"] = int(line.strip()) elif sectionFlag == 5: @@ -173,7 +170,7 @@ def _get_artemis123_info(fname, pos_fname=None): # build description desc = "" for k in ["Purpose", "Notes"]: - desc += "{} : {}\n".format(k, header_info[k]) + desc += f"{k} : {header_info[k]}\n" desc += "Comments : {}".format(header_info["comments"]) info.update( @@ -340,7 +337,7 @@ def __init__( verbose=None, pos_fname=None, add_head_trans=True, - ): # noqa: D102 + ): from ...chpi import ( _fit_coil_order_dev_head_trans, compute_chpi_amplitudes, @@ -363,7 +360,7 @@ def __init__( last_samps = [header_info.get("num_samples", 1) - 1] - super(RawArtemis123, self).__init__( + super().__init__( info, preload, filenames=[input_fname], diff --git a/mne/io/artemis123/tests/test_artemis123.py b/mne/io/artemis123/tests/test_artemis123.py index 9a1cdb36eec..9b002c7b712 100644 --- a/mne/io/artemis123/tests/test_artemis123.py +++ b/mne/io/artemis123/tests/test_artemis123.py @@ -36,11 +36,10 @@ def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.0): angle = np.rad2deg(_angle_between_quats(quat_est, quat)) dist = np.linalg.norm(trans - trans_est) - assert dist <= dist_tol, "%0.3f > %0.3f mm translation" % ( - 1000 * dist, - 1000 * dist_tol, + assert dist <= dist_tol, ( + f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} " "mm translation" ) - assert angle <= angle_tol, "%0.3f > %0.3f° rotation" % (angle, angle_tol) + assert angle <= angle_tol, f"{angle:0.3f} > {angle_tol:0.3f}° rotation" @pytest.mark.timeout(60) # ~25 s on Travis Linux OpenBLAS @@ -97,7 +96,10 @@ def test_dev_head_t(): assert_equal(raw.info["sfreq"], 5000.0) # test with head loc and digitization - with pytest.warns(RuntimeWarning, match="Large difference"): + with ( + pytest.warns(RuntimeWarning, match="consistency"), + pytest.warns(RuntimeWarning, match="Large difference"), + ): raw = read_raw_artemis123( short_HPI_dip_fname, add_head_trans=True, pos_fname=dig_fname ) diff --git a/mne/io/artemis123/utils.py b/mne/io/artemis123/utils.py index 95f307058ea..432e593553d 100644 --- a/mne/io/artemis123/utils.py +++ b/mne/io/artemis123/utils.py @@ -19,9 +19,9 @@ def _load_mne_locs(fname=None): if not op.exists(fname): raise OSError('MNE locs file "%s" does not exist' % (fname)) - logger.info("Loading mne loc file {}".format(fname)) + logger.info(f"Loading mne loc file {fname}") locs = dict() - with open(fname, "r") as fid: + with open(fname) as fid: for line in fid: vals = line.strip().split(",") locs[vals[0]] = np.array(vals[1::], np.float64) @@ -50,7 +50,7 @@ def _generate_mne_locs_file(output_fname): def _load_tristan_coil_locs(coil_loc_path): """Load the Coil locations from Tristan CAD drawings.""" channel_info = dict() - with open(coil_loc_path, "r") as fid: + with open(coil_loc_path) as fid: # skip 2 Header lines fid.readline() fid.readline() @@ -72,7 +72,7 @@ def _compute_mne_loc(coil_loc): Note input coil locations are in inches. 
""" - loc = np.zeros((12)) + loc = np.zeros(12) if (np.linalg.norm(coil_loc["inner_coil"]) == 0) and ( np.linalg.norm(coil_loc["outer_coil"]) == 0 ): @@ -91,7 +91,7 @@ def _compute_mne_loc(coil_loc): def _read_pos(fname): """Read the .pos file and return positions as dig points.""" nas, lpa, rpa, hpi, extra = None, None, None, None, None - with open(fname, "r") as fid: + with open(fname) as fid: for line in fid: line = line.strip() if len(line) > 0: diff --git a/mne/io/base.py b/mne/io/base.py index de6f3aa589d..ae622cfa307 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -18,6 +18,7 @@ from copy import deepcopy from dataclasses import dataclass, field from datetime import timedelta +from inspect import getfullargspec import numpy as np @@ -81,6 +82,7 @@ from ..html_templates import _get_html_template from ..parallel import parallel_func from ..time_frequency.spectrum import Spectrum, SpectrumMixin, _validate_method +from ..time_frequency.tfr import RawTFR from ..utils import ( SizeMixin, TimeMixin, @@ -203,7 +205,7 @@ def __init__( orig_units=None, *, verbose=None, - ): # noqa: D102 + ): # wait until the end to preload data, but triage here if isinstance(preload, np.ndarray): # some functions (e.g., filtering) only work w/64-bit data @@ -265,8 +267,7 @@ def __init__( if orig_units: if not isinstance(orig_units, dict): raise ValueError( - "orig_units must be of type dict, but got " - " {}".format(type(orig_units)) + f"orig_units must be of type dict, but got {type(orig_units)}" ) # original units need to be truncated to 15 chars or renamed @@ -291,8 +292,7 @@ def __init__( if not all(ch_correspond): ch_without_orig_unit = ch_names[ch_correspond.index(False)] raise ValueError( - "Channel {} has no associated original " - "unit.".format(ch_without_orig_unit) + f"Channel {ch_without_orig_unit} has no associated original unit." 
) # Final check of orig_units, editing a unit if it is not a valid @@ -415,8 +415,8 @@ def _read_segment( if isinstance(data_buffer, np.ndarray): if data_buffer.shape != data_shape: raise ValueError( - "data_buffer has incorrect shape: %s != %s" - % (data_buffer.shape, data_shape) + f"data_buffer has incorrect shape: " + f"{data_buffer.shape} != {data_shape}" ) data = data_buffer else: @@ -660,8 +660,7 @@ def time_as_index(self, times, use_rounding=False, origin=None): delta = 0 elif self.info["meas_date"] is None: raise ValueError( - 'origin must be None when info["meas_date"] ' - "is None, got %s" % (origin,) + f'origin must be None when info["meas_date"] is None, got {origin}' ) else: first_samp_in_abs_time = self.info["meas_date"] + timedelta( @@ -670,7 +669,7 @@ def time_as_index(self, times, use_rounding=False, origin=None): delta = (origin - first_samp_in_abs_time).total_seconds() times = np.atleast_1d(times) + delta - return super(BaseRaw, self).time_as_index(times, use_rounding) + return super().time_as_index(times, use_rounding) @property def _raw_lengths(self): @@ -797,6 +796,9 @@ def _parse_get_set_params(self, item): item1 = int(item1) if isinstance(item1, (int, np.integer)): start, stop, step = item1, item1 + 1, 1 + # Need to special case -1, because -1:0 will be empty + if start == -1: + stop = None else: raise ValueError("Must pass int or slice to __getitem__") @@ -1087,19 +1089,50 @@ def apply_function( if dtype is not None and dtype != self._data.dtype: self._data = self._data.astype(dtype) + args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs + if channel_wise is False: + if ("ch_idx" in args) or ("ch_name" in args): + raise ValueError( + "apply_function cannot access ch_idx or ch_name " + "when channel_wise=False" + ) + if "ch_idx" in args: + logger.info("apply_function requested to access ch_idx") + if "ch_name" in args: + logger.info("apply_function requested to access ch_name") + if channel_wise: parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs) if n_jobs == 1: # modify data inplace to save memory - for idx in picks: - self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs) + for ch_idx in picks: + if "ch_idx" in args: + kwargs.update(ch_idx=ch_idx) + if "ch_name" in args: + kwargs.update(ch_name=self.info["ch_names"][ch_idx]) + self._data[ch_idx, :] = _check_fun( + fun, data_in[ch_idx, :], **kwargs + ) else: # use parallel function data_picks_new = parallel( - p_fun(fun, data_in[p], **kwargs) for p in picks + p_fun( + fun, + data_in[ch_idx], + **kwargs, + **{ + k: v + for k, v in [ + ("ch_name", self.info["ch_names"][ch_idx]), + ("ch_idx", ch_idx), + ] + if k in args + }, + ) + for ch_idx in picks ) - for pp, p in enumerate(picks): - self._data[p, :] = data_picks_new[pp] + for run_idx, ch_idx in enumerate(picks): + self._data[ch_idx, :] = data_picks_new[run_idx] else: self._data[picks, :] = _check_fun(fun, data_in[picks, :], **kwargs) @@ -1124,7 +1157,7 @@ def filter( skip_by_annotation=("edge", "bad_acq_skip"), pad="reflect_limited", verbose=None, - ): # noqa: D102 + ): return super().filter( l_freq, h_freq, @@ -1259,12 +1292,14 @@ def notch_filter( def resample( self, sfreq, + *, npad="auto", - window="boxcar", + window="auto", stim_picks=None, n_jobs=None, events=None, - pad="reflect_limited", + pad="auto", + method="fft", verbose=None, ): """Resample all channels. @@ -1293,7 +1328,7 @@ def resample( ---------- sfreq : float New sample rate to use. 
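The `apply_function` change above inspects the callback signature with `getfullargspec` and, when `channel_wise=True`, forwards `ch_idx`/`ch_name` only to callbacks that declare them; declaring them with `channel_wise=False` raises a `ValueError`. A sketch with a hypothetical callback:

    # Sketch only: a channel-aware callback. apply_function injects ch_idx and
    # ch_name solely because they appear in this function's signature.
    import numpy as np

    def demean_channel(data, ch_idx=None, ch_name=None):
        print(f"demeaning {ch_name} (picks entry {ch_idx})")
        return data - np.mean(data)

    # raw.apply_function(demean_channel, picks="eeg", channel_wise=True)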
- %(npad)s + %(npad_resample)s %(window_resample)s stim_picks : list of int | None Stim channels. These channels are simply subsampled or @@ -1306,10 +1341,12 @@ def resample( An optional event matrix. When specified, the onsets of the events are resampled jointly with the data. NB: The input events are not modified, but a new array is returned with the raw instead. - %(pad)s - The default is ``'reflect_limited'``. + %(pad_resample_auto)s .. versionadded:: 0.15 + %(method_resample)s + + .. versionadded:: 1.7 %(verbose)s Returns @@ -1363,7 +1400,13 @@ def resample( ) kwargs = dict( - up=sfreq, down=o_sfreq, npad=npad, window=window, n_jobs=n_jobs, pad=pad + up=sfreq, + down=o_sfreq, + npad=npad, + window=window, + n_jobs=n_jobs, + pad=pad, + method=method, ) ratio, n_news = zip( *( @@ -1474,13 +1517,13 @@ def crop(self, tmin=0.0, tmax=None, include_tmax=True, *, verbose=None): tmax = max_time if tmin > tmax: - raise ValueError("tmin (%s) must be less than tmax (%s)" % (tmin, tmax)) + raise ValueError(f"tmin ({tmin}) must be less than tmax ({tmax})") if tmin < 0.0: - raise ValueError("tmin (%s) must be >= 0" % (tmin,)) + raise ValueError(f"tmin ({tmin}) must be >= 0") elif tmax - int(not include_tmax) / self.info["sfreq"] > max_time: raise ValueError( - "tmax (%s) must be less than or equal to the max " - "time (%0.4f s)" % (tmax, max_time) + f"tmax ({tmax}) must be less than or equal to the max " + f"time ({max_time:0.4f} s)" ) smin, smax = np.where( @@ -1652,7 +1695,13 @@ def save( endings_err = (".fif", ".fif.gz") # convert to str, check for overwrite a few lines later - fname = _check_fname(fname, overwrite=True, verbose="error") + fname = _check_fname( + fname, + overwrite=True, + verbose="error", + check_bids_split=True, + name="fname", + ) check_fname(fname, "raw", endings, endings_err=endings_err) split_size = _get_split_size(split_size) @@ -1760,9 +1809,7 @@ def _tmin_tmax_to_start_stop(self, tmin, tmax): stop = self.time_as_index(float(tmax), use_rounding=True)[0] + 1 stop = min(stop, self.last_samp - self.first_samp + 1) if stop <= start or stop <= 0: - raise ValueError( - "tmin (%s) and tmax (%s) yielded no samples" % (tmin, tmax) - ) + raise ValueError(f"tmin ({tmin}) and tmax ({tmax}) yielded no samples") return start, stop @copy_function_doc_to_method_doc(plot_raw) @@ -1800,6 +1847,7 @@ def plot( precompute=None, use_opengl=None, *, + picks=None, theme=None, overview_mode=None, splash=True, @@ -1838,6 +1886,7 @@ def plot( time_format=time_format, precompute=precompute, use_opengl=use_opengl, + picks=picks, theme=theme, overview_mode=overview_mode, splash=splash, @@ -2048,15 +2097,12 @@ def __repr__(self): # noqa: D105 name = self.filenames[0] name = "" if name is None else op.basename(name) + ", " size_str = str(sizeof_fmt(self._size)) # str in case it fails -> None - size_str += ", data%s loaded" % ("" if self.preload else " not") - s = "%s%s x %s (%0.1f s), ~%s" % ( - name, - len(self.ch_names), - self.n_times, - self.times[-1], - size_str, + size_str += f", data{'' if self.preload else ' not'} loaded" + s = ( + f"{name}{len(self.ch_names)} x {self.n_times} " + f"({self.times[-1]:0.1f} s), ~{size_str}" ) - return "<%s | %s>" % (self.__class__.__name__, s) + return f"<{self.__class__.__name__} | {s}>" @repr_html def _repr_html_(self, caption=None): @@ -2114,8 +2160,8 @@ def add_events(self, events, stim_channel=None, replace=False): idx = events[:, 0].astype(int) if np.any(idx < self.first_samp) or np.any(idx > self.last_samp): raise ValueError( - "event sample numbers must 
be between %s and %s" - % (self.first_samp, self.last_samp) + f"event sample numbers must be between {self.first_samp} " + f"and {self.last_samp}" ) if not all(idx == events[:, 0]): raise ValueError("event sample numbers must be integers") @@ -2153,7 +2199,9 @@ def compute_psd( Parameters ---------- %(method_psd)s - Default is ``'welch'``. + Note that ``"multitaper"`` cannot be used if ``reject_by_annotation=True`` + and there are ``"bad_*"`` annotations in the :class:`~mne.io.Raw` data; + in such cases use ``"welch"``. Default is ``'welch'``. %(fmin_fmax_psd)s %(tmin_tmax_psd)s %(picks_good_data_noref)s @@ -2198,6 +2246,69 @@ def compute_psd( **method_kw, ) + @verbose + def compute_tfr( + self, + method, + freqs, + *, + tmin=None, + tmax=None, + picks=None, + proj=False, + output="power", + reject_by_annotation=True, + decim=1, + n_jobs=None, + verbose=None, + **method_kw, + ): + """Compute a time-frequency representation of sensor data. + + Parameters + ---------- + %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(output_compute_tfr)s + %(reject_by_annotation_tfr)s + %(decim_tfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_tfr)s + + Returns + ------- + tfr : instance of RawTFR + The time-frequency-resolved power estimates of the data. + + Notes + ----- + .. versionadded:: 1.7 + + References + ---------- + .. footbibliography:: + """ + _check_option("output", output, ("power", "phase", "complex")) + method_kw["output"] = output + return RawTFR( + self, + method=method, + freqs=freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=reject_by_annotation, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + @verbose def to_data_frame( self, @@ -2260,7 +2371,9 @@ def to_data_frame( data = _scale_dataframe_data(self, data, picks, scalings) # prepare extra columns / multiindex mindex = list() - times = _convert_times(self, times, time_format) + times = _convert_times( + times, time_format, self.info["meas_date"], self.first_time + ) mindex.append(("time", times)) # build DataFrame df = _build_data_frame( @@ -2519,7 +2632,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): class _RawShell: """Create a temporary raw object.""" - def __init__(self): # noqa: D102 + def __init__(self): self.first_samp = None self.last_samp = None self._first_time = None @@ -2552,6 +2665,13 @@ def set_annotations(self, annotations): def _write_raw(raw_fid_writer, fpath, split_naming, overwrite): """Write raw file with splitting.""" dir_path = fpath.parent + _check_fname( + dir_path, + overwrite="read", + must_exist=True, + name="parent directory", + need_dir=True, + ) # We have to create one extra filename here to make the for loop below happy, # but it will raise an error if it actually gets used split_fnames = _make_split_fnames( @@ -2676,8 +2796,9 @@ def _check_start_stop_within_bounds(self): # we've done something wrong if we hit this n_times_max = len(self.raw.times) error_msg = ( - "Can't write raw file with no data: {0} -> {1} (max: {2}) requested" - ).format(self.start, self.stop, n_times_max) + f"Can't write raw file with no data: {self.start} -> {self.stop} " + f"(max: {n_times_max}) requested" + ) if self.start >= self.stop or self.stop > n_times_max: raise RuntimeError(error_msg) @@ -2781,17 +2902,12 @@ def _write_raw_data( # This should occur on the first buffer write of the file, so # we should mention the space required for the meas info raise ValueError( - "buffer size (%s) is too 
large for the given split size (%s) " - "by %s bytes after writing info (%s) and leaving enough space " - 'for end tags (%s): decrease "buffer_size_sec" or increase ' - '"split_size".' - % ( - this_buff_size_bytes, - split_size, - overage, - pos_prev, - _NEXT_FILE_BUFFER, - ) + f"buffer size ({this_buff_size_bytes}) is too large for the " + f"given split size ({split_size}) " + f"by {overage} bytes after writing info ({pos_prev}) and " + "leaving enough space " + f'for end tags ({_NEXT_FILE_BUFFER}): decrease "buffer_size_sec" ' + 'or increase "split_size".' ) new_start = last diff --git a/mne/io/besa/tests/test_besa.py b/mne/io/besa/tests/test_besa.py index aeecf48cd63..2ee2843840b 100644 --- a/mne/io/besa/tests/test_besa.py +++ b/mne/io/besa/tests/test_besa.py @@ -1,6 +1,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. """Test reading BESA fileformats.""" + import inspect from pathlib import Path diff --git a/mne/io/boxy/boxy.py b/mne/io/boxy/boxy.py index b2afe096f64..a3beefc218c 100644 --- a/mne/io/boxy/boxy.py +++ b/mne/io/boxy/boxy.py @@ -15,7 +15,7 @@ @fill_doc -def read_raw_boxy(fname, preload=False, verbose=None): +def read_raw_boxy(fname, preload=False, verbose=None) -> "RawBOXY": """Reader for an optical imaging recording. This function has been tested using the ISS Imagent I and II systems @@ -68,7 +68,7 @@ def __init__(self, fname, preload=False, verbose=None): raw_extras["offsets"] = list() # keep track of our offsets sfreq = None fname = str(_check_fname(fname, "read", True, "fname")) - with open(fname, "r") as fid: + with open(fname) as fid: line_num = 0 i_line = fid.readline() while i_line: @@ -170,7 +170,7 @@ def __init__(self, fname, preload=False, verbose=None): assert len(raw_extras["offsets"]) == delta + 1 if filetype == "non-parsed": delta //= raw_extras["source_num"] - super(RawBOXY, self).__init__( + super().__init__( info, preload, filenames=[fname], @@ -235,7 +235,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): # Loop through our data. one = np.zeros((len(col_names), stop_read - start_read)) - with open(boxy_file, "r") as fid: + with open(boxy_file) as fid: # Just a more efficient version of this: # ii = 0 # for line_num, i_line in enumerate(fid): diff --git a/mne/io/brainvision/brainvision.py b/mne/io/brainvision/brainvision.py index 3a4f63718c3..1942744afe3 100644 --- a/mne/io/brainvision/brainvision.py +++ b/mne/io/brainvision/brainvision.py @@ -111,7 +111,7 @@ def __init__( orig_format = "single" if isinstance(fmt, dict) else fmt raw_extras = dict(offsets=offsets, fmt=fmt, order=order, n_samples=n_samples) - super(RawBrainVision, self).__init__( + super().__init__( info, last_samps=[n_samples - 1], filenames=[data_fname], @@ -124,7 +124,7 @@ def __init__( self.set_montage(montage) - settings, cfg, cinfo, _ = _aux_hdr_info(hdr_fname) + settings, _, _, _ = _aux_hdr_info(hdr_fname) split_settings = settings.splitlines() self.impedances = _parse_impedance(split_settings, self.info["meas_date"]) @@ -344,9 +344,10 @@ def _read_annotations_brainvision(fname, sfreq="auto"): def _check_bv_version(header, kind): """Check the header version.""" - _data_err = """\ - MNE-Python currently only supports %s versions 1.0 and 2.0, got unparsable\ - %r. Contact MNE-Python developers for support.""" + _data_err = ( + "MNE-Python currently only supports %s versions 1.0 and 2.0, got unparsable " + "%r. Contact MNE-Python developers for support." 
+ ) # optional space, optional Core or V-Amp, optional Exchange, # Version/Header, optional comma, 1/2 _data_re = ( @@ -355,14 +356,15 @@ def _check_bv_version(header, kind): assert kind in ("header", "marker") - if header == "": - warn(f"Missing header in {kind} file.") for version in range(1, 3): this_re = _data_re % (kind.capitalize(), version) if re.search(this_re, header) is not None: return version else: - warn(_data_err % (kind, header)) + if header == "": + warn(f"Missing header in {kind} file.") + else: + warn(_data_err % (kind, header)) _orientation_dict = dict(MULTIPLEXED="F", VECTORIZED="C") @@ -445,7 +447,7 @@ def _aux_hdr_info(hdr_fname): params, settings = settings.split("[Comment]") else: params, settings = settings, "" - cfg = configparser.ConfigParser() + cfg = configparser.ConfigParser(interpolation=None) with StringIO(params) as fid: cfg.read_file(fid) @@ -542,7 +544,7 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): # Try to get measurement date from marker file # Usually saved with a marker "New Segment", see BrainVision documentation regexp = r"^Mk\d+=New Segment,.*,\d+,\d+,-?\d+,(\d{20})$" - with open(mrk_fname, "r") as tmp_mrk_f: + with open(mrk_fname) as tmp_mrk_f: lines = tmp_mrk_f.readlines() for line in lines: @@ -636,7 +638,7 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): ch_name = ch_dict[ch[0]] montage_names.append(ch_name) # 1: radius, 2: theta, 3: phi - rad, theta, phi = [float(c) for c in ch[1].split(",")] + rad, theta, phi = (float(c) for c in ch[1].split(",")) pol = np.deg2rad(theta) az = np.deg2rad(phi) # Coordinates could be "idealized" (spherical head model) @@ -656,9 +658,9 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): if len(to_misc) > 0: misc += to_misc warn( - "No coordinate information found for channels {}. " - "Setting channel types to misc. To avoid this warning, set " - "channel types explicitly.".format(to_misc) + f"No coordinate information found for channels {to_misc}. Setting " + "channel types to misc. To avoid this warning, set channel types " + "explicitly." ) if np.isnan(cals).any(): @@ -865,8 +867,8 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): nyquist = "" warn( "Channels contain different lowpass filters. " - "Highest (weakest) filter setting (%0.2f Hz%s) " - "will be stored." % (info["lowpass"], nyquist) + f"Highest (weakest) filter setting ({info['lowpass']:0.2f} " + f"Hz{nyquist}) will be stored." ) # Creates a list of dicts of eeg channels for raw.info @@ -921,7 +923,7 @@ def read_raw_brainvision( scale=1.0, preload=False, verbose=None, -): +) -> RawBrainVision: """Reader for Brain Vision EEG file. 
Parameters @@ -988,9 +990,7 @@ def __call__(self, description): elif description in _OTHER_ACCEPTED_MARKERS: code = _OTHER_ACCEPTED_MARKERS[description] else: - code = super(_BVEventParser, self).__call__( - description, offset=_OTHER_OFFSET - ) + code = super().__call__(description, offset=_OTHER_OFFSET) return code diff --git a/mne/io/brainvision/tests/test_brainvision.py b/mne/io/brainvision/tests/test_brainvision.py index 1688963296a..309e44e3cf8 100644 --- a/mne/io/brainvision/tests/test_brainvision.py +++ b/mne/io/brainvision/tests/test_brainvision.py @@ -20,7 +20,7 @@ from mne.datasets import testing from mne.io import read_raw_brainvision, read_raw_fif from mne.io.tests.test_raw import _test_raw_reader -from mne.utils import _stamp_to_dt, object_diff +from mne.utils import _record_warnings, _stamp_to_dt, object_diff data_dir = Path(__file__).parent / "data" vhdr_path = data_dir / "test.vhdr" @@ -72,6 +72,8 @@ # This should be amend in its own PR. montage = data_dir / "test.hpts" +_no_dig = pytest.warns(RuntimeWarning, match="No info on DataPoints") + def test_orig_units(recwarn): """Test exposure of original channel units.""" @@ -133,14 +135,14 @@ def _mocked_meas_date_data(tmp_path_factory): """Prepare files for mocked_meas_date_file fixture.""" # Prepare the files tmp_path = tmp_path_factory.mktemp("brainvision_mocked_meas_date") - vhdr_fname, vmrk_fname, eeg_fname = [ + vhdr_fname, vmrk_fname, eeg_fname = ( tmp_path / ff.name for ff in [vhdr_path, vmrk_path, eeg_path] - ] + ) for orig, dest in zip([vhdr_path, eeg_path], [vhdr_fname, eeg_fname]): shutil.copyfile(orig, dest) # Get the marker information - with open(vmrk_path, "r") as fin: + with open(vmrk_path) as fin: lines = fin.readlines() return vhdr_fname, vmrk_fname, lines @@ -331,7 +333,7 @@ def test_ch_names_comma(tmp_path): shutil.copyfile(src, tmp_path / dest) comma_vhdr = tmp_path / "test.vhdr" - with open(comma_vhdr, "r") as fin: + with open(comma_vhdr) as fin: lines = fin.readlines() new_lines = [] @@ -473,7 +475,7 @@ def test_brainvision_data_partially_disabled_hw_filters(): def test_brainvision_data_software_filters_latin1_global_units(): """Test reading raw Brain Vision files.""" - with pytest.warns(RuntimeWarning, match="software filter"): + with _no_dig, pytest.warns(RuntimeWarning, match="software filter"): raw = _test_raw_reader( read_raw_brainvision, vhdr_fname=vhdr_old_path, @@ -485,7 +487,7 @@ def test_brainvision_data_software_filters_latin1_global_units(): assert raw.info["lowpass"] == 50.0 # test sensor name with spaces (#9299) - with pytest.warns(RuntimeWarning, match="software filter"): + with _no_dig, pytest.warns(RuntimeWarning, match="software filter"): raw = _test_raw_reader( read_raw_brainvision, vhdr_fname=vhdr_old_longname_path, @@ -566,7 +568,7 @@ def test_brainvision_data(): def test_brainvision_vectorized_data(): """Test reading BrainVision data files with vectorized data.""" - with pytest.warns(RuntimeWarning, match="software filter"): + with _no_dig, pytest.warns(RuntimeWarning, match="software filter"): raw = read_raw_brainvision(vhdr_old_path, preload=True) assert_array_equal(raw._data.shape, (29, 251)) @@ -611,7 +613,10 @@ def test_brainvision_vectorized_data(): def test_coodinates_extraction(): """Test reading of [Coordinates] section if present.""" # vhdr 2 has a Coordinates section - with pytest.warns(RuntimeWarning, match="coordinate information"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="coordinate information"), + ): raw = 
read_raw_brainvision(vhdr_v2_path) # Basic check of extracted coordinates diff --git a/mne/io/bti/bti.py b/mne/io/bti/bti.py index 190625f8ee0..616602892dd 100644 --- a/mne/io/bti/bti.py +++ b/mne/io/bti/bti.py @@ -72,13 +72,13 @@ def _instantiate_default_info_chs(): class _bytes_io_mock_context: """Make a context for BytesIO.""" - def __init__(self, target): # noqa: D102 + def __init__(self, target): self.target = target def __enter__(self): # noqa: D105 return self.target - def __exit__(self, type, value, tb): # noqa: D105 + def __exit__(self, exception_type, value, tb): # noqa: D105 pass @@ -138,7 +138,7 @@ def _rename_channels(names, ecg_ch="E31", eog_ch=("E63", "E64")): List of names, channel names in Neuromag style """ new = list() - ref_mag, ref_grad, eog, eeg, ext = [count(1) for _ in range(5)] + ref_mag, ref_grad, eog, eeg, ext = (count(1) for _ in range(5)) for i, name in enumerate(names, 1): if name.startswith("A"): name = "MEG %3.3d" % i @@ -176,7 +176,7 @@ def _read_head_shape(fname): dig_points = read_double_matrix(fid, _n_dig_points, 3) # reorder to lpa, rpa, nasion so = is direct. - nasion, lpa, rpa = [idx_points[_, :] for _ in [2, 0, 1]] + nasion, lpa, rpa = (idx_points[_, :] for _ in [2, 0, 1]) hpi = idx_points[3 : len(idx_points), :] return nasion, lpa, rpa, hpi, dig_points @@ -1077,7 +1077,7 @@ def __init__( eog_ch=("E63", "E64"), preload=False, verbose=None, - ): # noqa: D102 + ): _validate_type(pdf_fname, ("path-like", BytesIO), "pdf_fname") info, bti_info = _get_bti_info( pdf_fname=pdf_fname, @@ -1096,7 +1096,7 @@ def __init__( filename = bti_info["pdf"] if isinstance(filename, BytesIO): filename = repr(filename) - super(RawBTi, self).__init__( + super().__init__( info, preload, filenames=[filename], @@ -1435,7 +1435,7 @@ def read_raw_bti( eog_ch=("E63", "E64"), preload=False, verbose=None, -): +) -> RawBTi: """Raw object from 4D Neuroimaging MagnesWH3600 data. .. note:: diff --git a/mne/io/bti/tests/test_bti.py b/mne/io/bti/tests/test_bti.py index de2a5fdd79c..afe387b8769 100644 --- a/mne/io/bti/tests/test_bti.py +++ b/mne/io/bti/tests/test_bti.py @@ -155,16 +155,16 @@ def test_raw(pdf, config, hs, exported, tmp_path): ) assert len(ex.info["dig"]) in (3563, 5154) assert_dig_allclose(ex.info, ra.info, limit=100) - coil1, coil2 = [ + coil1, coil2 = ( np.concatenate([d["loc"].flatten() for d in r_.info["chs"][:NCH]]) for r_ in (ra, ex) - ] + ) assert_array_almost_equal(coil1, coil2, 7) - loc1, loc2 = [ + loc1, loc2 = ( np.concatenate([d["loc"].flatten() for d in r_.info["chs"][:NCH]]) for r_ in (ra, ex) - ] + ) assert_allclose(loc1, loc2) assert_allclose(ra[:NCH][0], ex[:NCH][0]) diff --git a/mne/io/cnt/cnt.py b/mne/io/cnt/cnt.py index a242e85952b..5e5c60ee1a1 100644 --- a/mne/io/cnt/cnt.py +++ b/mne/io/cnt/cnt.py @@ -174,7 +174,7 @@ def read_raw_cnt( header="auto", preload=False, verbose=None, -): +) -> "RawCNT": """Read CNT data as raw object. .. Note:: @@ -292,7 +292,7 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc, data_format, date_format, he fid.seek(205) session_label = read_str(fid, 20) - session_date = "%s %s" % (read_str(fid, 10), read_str(fid, 12)) + session_date = f"{read_str(fid, 10)} {read_str(fid, 12)}" meas_date = _session_date_2_meas_date(session_date, date_format) fid.seek(370) @@ -309,7 +309,8 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc, data_format, date_format, he # Header has a field for number of samples, but it does not seem to be # too reliable. That's why we have option for setting n_bytes manually. 
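The cnt.py hunks around this point rework how the CNT reader reconciles the header's unreliable sample count with the measured size of the data block. The underlying relationship is simply data_size = n_samples * n_channels * n_bytes, so fixing the sample width (int16 vs. int32) determines the per-channel sample count. A small sketch of that arithmetic with hypothetical numbers (``data_size`` and ``n_channels`` here are illustrative, not values from a real file):

    # Hypothetical header values illustrating the size arithmetic the CNT
    # reader falls back on when the header's n_samples field cannot be trusted.
    data_size = 1_000_000  # bytes of sample data (illustrative)
    n_channels = 50

    for n_bytes in (2, 4):  # "int16" vs "int32" storage
        n_samples = data_size // (n_bytes * n_channels)
        print(f"{n_bytes}-byte samples -> {n_samples} samples per channel")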
fid.seek(864) - n_samples = np.fromfile(fid, dtype=" n_samples: + n_bytes = 4 + n_samples = n_samples_header + warn( + "Annotations are outside data range. " + "Changing data format to 'int32'." + ) else: n_bytes = data_size // (n_samples * n_channels) else: n_bytes = 2 if data_format == "int16" else 4 n_samples = data_size // (n_bytes * n_channels) + # See PR #12393 + if n_samples_header != 0: + n_samples = n_samples_header # Channel offset refers to the size of blocks per channel in the file. cnt_info["channel_offset"] = np.fromfile(fid, dtype=" 1: @@ -508,7 +521,7 @@ def __init__( header="auto", preload=False, verbose=None, - ): # noqa: D102 + ): _check_option("date_format", date_format, ["mm/dd/yy", "dd/mm/yy"]) if date_format == "dd/mm/yy": _date_format = "%d/%m/%y %H:%M:%S" @@ -520,7 +533,7 @@ def __init__( input_fname, eog, ecg, emg, misc, data_format, _date_format, header ) last_samps = [cnt_info["n_samples"] - 1] - super(RawCNT, self).__init__( + super().__init__( info, preload, filenames=[input_fname], @@ -548,6 +561,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): channel_offset = self._raw_extras[fi]["channel_offset"] baselines = self._raw_extras[fi]["baselines"] n_bytes = self._raw_extras[fi]["n_bytes"] + n_samples = self._raw_extras[fi]["n_samples"] dtype = " "RawCTF": """Raw object from CTF directory. Parameters @@ -55,11 +55,6 @@ def read_raw_ctf( ------- raw : instance of RawCTF The raw data. - See :class:`mne.io.Raw` for documentation of attributes and methods. - - See Also - -------- - mne.io.Raw : Documentation of attributes and methods of RawCTF. Notes ----- @@ -111,7 +106,7 @@ def __init__( preload=False, verbose=None, clean_names=False, - ): # noqa: D102 + ): # adapted from mne_ctf2fiff.c directory = str( _check_fname(directory, "read", True, "directory", need_dir=True) @@ -169,7 +164,7 @@ def __init__( f"file(s): {missing_names}, and the following file(s) had no " f"valid samples: {no_samps}" ) - super(RawCTF, self).__init__( + super().__init__( info, preload, first_samps=first_samps, @@ -232,7 +227,7 @@ def _clean_names_inst(inst): def _get_sample_info(fname, res4, system_clock): """Determine the number of valid samples.""" - logger.info("Finding samples for %s: " % (fname,)) + logger.info(f"Finding samples for {fname}: ") if CTF.SYSTEM_CLOCK_CH in res4["ch_names"]: clock_ch = res4["ch_names"].index(CTF.SYSTEM_CLOCK_CH) else: diff --git a/mne/io/ctf/eeg.py b/mne/io/ctf/eeg.py index 29ece5e9f74..36ce3321b31 100644 --- a/mne/io/ctf/eeg.py +++ b/mne/io/ctf/eeg.py @@ -79,7 +79,7 @@ def _read_pos(directory, transformations): fname = fname[0] digs = list() i = 2000 - with open(fname, "r") as fid: + with open(fname) as fid: for line in fid: line = line.strip() if len(line) > 0: diff --git a/mne/io/ctf/info.py b/mne/io/ctf/info.py index b177e29bf9d..791fdceaf51 100644 --- a/mne/io/ctf/info.py +++ b/mne/io/ctf/info.py @@ -171,8 +171,8 @@ def _check_comp_ch(cch, kind, desired=None): desired = cch["grad_order_no"] if cch["grad_order_no"] != desired: raise RuntimeError( - "%s channel with inconsistent compensation " - "grade %s, should be %s" % (kind, cch["grad_order_no"], desired) + f"{kind} channel with inconsistent compensation " + f"grade {cch['grad_order_no']}, should be {desired}" ) return desired @@ -217,8 +217,8 @@ def _convert_channel_info(res4, t, use_eeg_pos): if cch["sensor_type_index"] != CTF.CTFV_MEG_CH: text += " ref" warn( - "%s channel %s did not have position assigned, so " - "it was changed to a MISC channel" % (text, ch["ch_name"]) 
+ f"{text} channel {ch['ch_name']} did not have position " + "assigned, so it was changed to a MISC channel" ) continue ch["unit"] = FIFF.FIFF_UNIT_T @@ -535,7 +535,7 @@ def _read_bad_chans(directory, info): if not op.exists(fname): return [] mapping = dict(zip(_clean_names(info["ch_names"]), info["ch_names"])) - with open(fname, "r") as fid: + with open(fname) as fid: bad_chans = [mapping[f.strip()] for f in fid.readlines()] return bad_chans @@ -549,7 +549,7 @@ def _annotate_bad_segments(directory, start_time, meas_date): onsets = [] durations = [] desc = [] - with open(fname, "r") as fid: + with open(fname) as fid: for f in fid.readlines(): tmp = f.strip().split() desc.append("bad_%s" % tmp[0]) diff --git a/mne/io/ctf/tests/test_ctf.py b/mne/io/ctf/tests/test_ctf.py index f5340421a70..bf4415d90b8 100644 --- a/mne/io/ctf/tests/test_ctf.py +++ b/mne/io/ctf/tests/test_ctf.py @@ -92,9 +92,7 @@ def test_read_ctf(tmp_path): args = ( str(ch_num + 1), raw.ch_names[ch_num], - ) + tuple( - "%0.5f" % x for x in 100 * pos[ii] - ) # convert to cm + ) + tuple("%0.5f" % x for x in 100 * pos[ii]) # convert to cm fid.write(("\t".join(args) + "\n").encode("ascii")) pos_read_old = np.array([raw.info["chs"][p]["loc"][:3] for p in picks]) with pytest.warns(RuntimeWarning, match="RMSP .* changed to a MISC ch"): @@ -115,7 +113,7 @@ def test_read_ctf(tmp_path): shutil.copytree(ctf_eeg_fname, ctf_no_hc_fname) remove_base = op.join(ctf_no_hc_fname, op.basename(ctf_fname_catch[:-3])) os.remove(remove_base + ".hc") - with pytest.warns(RuntimeWarning, match="MISC channel"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="MISC channel"): pytest.raises(RuntimeError, read_raw_ctf, ctf_no_hc_fname) os.remove(remove_base + ".eeg") shutil.copy( diff --git a/mne/io/curry/curry.py b/mne/io/curry/curry.py index e5b8ce02ed3..3d0fb9afbca 100644 --- a/mne/io/curry/curry.py +++ b/mne/io/curry/curry.py @@ -197,10 +197,10 @@ def _read_curry_parameters(fname): if any(var_name in line for var_name in var_names): key, val = line.replace(" ", "").replace("\n", "").split("=") param_dict[key.lower().replace("_", "")] = val - for type in CHANTYPES: - if "DEVICE_PARAMETERS" + CHANTYPES[type] + " START" in line: + for key, type_ in CHANTYPES.items(): + if f"DEVICE_PARAMETERS{type_} START" in line: data_unit = next(fid) - unit_dict[type] = ( + unit_dict[key] = ( data_unit.replace(" ", "").replace("\n", "").split("=")[-1] ) @@ -425,7 +425,7 @@ def _make_trans_dig(curry_paths, info, curry_dev_dev_t): ) ) dist = 1000 * np.linalg.norm(unknown_curry_t["trans"][:3, 3]) - logger.info(" Fit a %0.1f° rotation, %0.1f mm translation" % (angle, dist)) + logger.info(f" Fit a {angle:0.1f}° rotation, {dist:0.1f} mm translation") unknown_dev_t = combine_transforms( unknown_curry_t, curry_dev_dev_t, "unknown", "meg" ) @@ -464,7 +464,7 @@ def _make_trans_dig(curry_paths, info, curry_dev_dev_t): def _first_hpi(fname): # Get the first HPI result - with open(fname, "r") as fid: + with open(fname) as fid: for line in fid: line = line.strip() if any(x in line for x in ("FileVersion", "NumCoils")) or not line: @@ -472,7 +472,7 @@ def _first_hpi(fname): hpi = np.array(line.split(), float) break else: - raise RuntimeError("Could not find valid HPI in %s" % (fname,)) + raise RuntimeError(f"Could not find valid HPI in {fname}") # t is the first entry assert hpi.ndim == 1 hpi = hpi[1:] @@ -542,7 +542,7 @@ def _read_annotations_curry(fname, sfreq="auto"): @verbose -def read_raw_curry(fname, preload=False, verbose=None): +def read_raw_curry(fname, 
preload=False, verbose=None) -> "RawCurry": """Read raw data from Curry files. Parameters @@ -596,7 +596,7 @@ def __init__(self, fname, preload=False, verbose=None): last_samps = [n_samples - 1] raw_extras = dict(is_ascii=is_ascii) - super(RawCurry, self).__init__( + super().__init__( info, preload, filenames=[data_fname], diff --git a/mne/io/curry/tests/test_curry.py b/mne/io/curry/tests/test_curry.py index c4710ecb679..de5247fb3de 100644 --- a/mne/io/curry/tests/test_curry.py +++ b/mne/io/curry/tests/test_curry.py @@ -325,7 +325,7 @@ def test_check_missing_files(): def _mock_info_file(src, dst, sfreq, time_step): - with open(src, "r") as in_file, open(dst, "w") as out_file: + with open(src) as in_file, open(dst, "w") as out_file: for line in in_file: if "SampleFreqHz" in line: out_file.write(line.replace("500", str(sfreq))) diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index 7c02642ec8f..8a982f43e86 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -40,6 +40,7 @@ "TEMP": FIFF.FIFFV_TEMPERATURE_CH, "MISC": FIFF.FIFFV_MISC_CH, "SAO2": FIFF.FIFFV_BIO_CH, + "STIM": FIFF.FIFFV_STIM_CH, } @@ -86,6 +87,7 @@ class RawEDF(BaseRaw): %(preload)s %(units_edf_bdf_io)s %(encoding_edf)s + %(exclude_after_unique)s %(verbose)s See Also @@ -147,13 +149,22 @@ def __init__( include=None, units=None, encoding="utf8", + exclude_after_unique=False, *, verbose=None, ): - logger.info("Extracting EDF parameters from {}...".format(input_fname)) + logger.info(f"Extracting EDF parameters from {input_fname}...") input_fname = os.path.abspath(input_fname) info, edf_info, orig_units = _get_info( - input_fname, stim_channel, eog, misc, exclude, infer_types, preload, include + input_fname, + stim_channel, + eog, + misc, + exclude, + infer_types, + preload, + include, + exclude_after_unique, ) logger.info("Creating raw.info structure...") @@ -284,7 +295,7 @@ def __init__( include=None, verbose=None, ): - logger.info("Extracting EDF parameters from {}...".format(input_fname)) + logger.info(f"Extracting EDF parameters from {input_fname}...") input_fname = os.path.abspath(input_fname) info, edf_info, orig_units = _get_info( input_fname, stim_channel, eog, misc, exclude, True, preload, include @@ -369,7 +380,7 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, # We could read this one EDF block at a time, which would be this: ch_offsets = np.cumsum(np.concatenate([[0], n_samps]), dtype=np.int64) - block_start_idx, r_lims, d_lims = _blk_read_lims(start, stop, buf_len) + block_start_idx, r_lims, _ = _blk_read_lims(start, stop, buf_len) # But to speed it up, we really need to read multiple blocks at once, # Otherwise we can end up with e.g. 18,181 chunks for a 20 MB file! # Let's do ~10 MB chunks: @@ -472,7 +483,8 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, return tal_data -def _read_header(fname, exclude, infer_types, include=None): +@fill_doc +def _read_header(fname, exclude, infer_types, include=None, exclude_after_unique=False): """Unify EDF, BDF and GDF _read_header call. Parameters @@ -494,6 +506,7 @@ def _read_header(fname, exclude, infer_types, include=None): include : list of str | str Channel names to be included. A str is interpreted as a regular expression. 'exclude' must be empty if include is assigned. 
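As the ``include`` description above notes, a string is interpreted as a regular expression over channel names. A minimal sketch of that selection rule; ``select_channels`` is a hypothetical helper for illustration, not the module's actual ``_find_exclude_idx``:

    import re

    def select_channels(ch_names, include):
        # A string include is compiled and matched against the start of each
        # channel name; matching indices are kept.
        pattern = re.compile(include)
        return [i for i, name in enumerate(ch_names) if pattern.match(name)]

    ch_names = ["EEG F1-Ref", "EEG F2-Ref", "EDF Annotations"]
    assert select_channels(ch_names, "EEG F[12]-Ref") == [0, 1]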
+ %(exclude_after_unique)s Returns ------- @@ -502,7 +515,9 @@ def _read_header(fname, exclude, infer_types, include=None): ext = os.path.splitext(fname)[1][1:].lower() logger.info("%s file detected" % ext.upper()) if ext in ("bdf", "edf"): - return _read_edf_header(fname, exclude, infer_types, include) + return _read_edf_header( + fname, exclude, infer_types, include, exclude_after_unique + ) elif ext == "gdf": return _read_gdf_header(fname, exclude, include), None else: @@ -512,13 +527,23 @@ def _read_header(fname, exclude, infer_types, include=None): def _get_info( - fname, stim_channel, eog, misc, exclude, infer_types, preload, include=None + fname, + stim_channel, + eog, + misc, + exclude, + infer_types, + preload, + include=None, + exclude_after_unique=False, ): """Extract information from EDF+, BDF or GDF file.""" eog = eog if eog is not None else [] misc = misc if misc is not None else [] - edf_info, orig_units = _read_header(fname, exclude, infer_types, include) + edf_info, orig_units = _read_header( + fname, exclude, infer_types, include, exclude_after_unique + ) # XXX: `tal_ch_names` to pass to `_check_stim_channel` should be computed # from `edf_info['ch_names']` and `edf_info['tal_idx']` but 'tal_idx' @@ -681,46 +706,10 @@ def _get_info( info["subject_info"]["weight"] = float(edf_info["subject_info"]["weight"]) # Filter settings - highpass = edf_info["highpass"] - lowpass = edf_info["lowpass"] - if highpass.size == 0: - pass - elif all(highpass): - if highpass[0] == "NaN": - # Placeholder for future use. Highpass set in _empty_info. - pass - elif highpass[0] == "DC": - info["highpass"] = 0.0 - else: - hp = highpass[0] - try: - hp = float(hp) - except Exception: - hp = 0.0 - info["highpass"] = hp - else: - info["highpass"] = float(np.max(highpass)) - warn( - "Channels contain different highpass filters. Highest filter " - "setting will be stored." - ) - if np.isnan(info["highpass"]): - info["highpass"] = 0.0 - if lowpass.size == 0: - # Placeholder for future use. Lowpass set in _empty_info. - pass - elif all(lowpass): - if lowpass[0] in ("NaN", "0", "0.0"): - # Placeholder for future use. Lowpass set in _empty_info. - pass - else: - info["lowpass"] = float(lowpass[0]) - else: - info["lowpass"] = float(np.min(lowpass)) - warn( - "Channels contain different lowpass filters. Lowest filter " - "setting will be stored." 
- ) + if filt_ch_idxs := [x for x in sel if x not in stim_channel_idxs]: + _set_prefilter(info, edf_info, filt_ch_idxs, "highpass") + _set_prefilter(info, edf_info, filt_ch_idxs, "lowpass") + if np.isnan(info["lowpass"]): info["lowpass"] = info["sfreq"] / 2.0 @@ -760,25 +749,47 @@ def _get_info( def _parse_prefilter_string(prefiltering): """Parse prefilter string from EDF+ and BDF headers.""" - highpass = np.array( - [ - v - for hp in [ - re.findall(r"HP:\s*([0-9]+[.]*[0-9]*)", filt) for filt in prefiltering - ] - for v in hp - ] - ) - lowpass = np.array( - [ - v - for hp in [ - re.findall(r"LP:\s*([0-9]+[.]*[0-9]*)", filt) for filt in prefiltering - ] - for v in hp - ] - ) - return highpass, lowpass + filter_types = ["HP", "LP"] + filter_strings = {t: [] for t in filter_types} + for filt in prefiltering: + for t in filter_types: + matches = re.findall(rf"{t}:\s*([a-zA-Z0-9,.]+)(Hz)?", filt) + value = "" + for match in matches: + if match[0]: + value = match[0].replace("Hz", "").replace(",", ".") + filter_strings[t].append(value) + return np.array(filter_strings["HP"]), np.array(filter_strings["LP"]) + + +def _prefilter_float(filt): + if isinstance(filt, (int, float, np.number)): + return filt + if filt == "DC": + return 0.0 + if filt.replace(".", "", 1).isdigit(): + return float(filt) + return np.nan + + +def _set_prefilter(info, edf_info, ch_idxs, key): + value = 0 + if len(values := edf_info.get(key, [])): + values = [x for i, x in enumerate(values) if i in ch_idxs] + if len(np.unique(values)) > 1: + warn( + f"Channels contain different {key} filters. " + f"{'Highest' if key == 'highpass' else 'Lowest'} filter " + "setting will be stored." + ) + if key == "highpass": + value = np.nanmax([_prefilter_float(x) for x in values]) + else: + value = np.nanmin([_prefilter_float(x) for x in values]) + else: + value = _prefilter_float(values[0]) + if not np.isnan(value) and value != 0: + info[key] = value def _edf_str(x): @@ -789,7 +800,9 @@ def _edf_str_num(x): return _edf_str(x).replace(",", ".") -def _read_edf_header(fname, exclude, infer_types, include=None): +def _read_edf_header( + fname, exclude, infer_types, include=None, exclude_after_unique=False +): """Read header information from EDF+ or BDF file.""" edf_info = {"events": []} @@ -846,11 +859,11 @@ def _read_edf_header(fname, exclude, infer_types, include=None): fid.read(8) # skip file's meas_date else: meas_date = fid.read(8).decode("latin-1") - day, month, year = [int(x) for x in meas_date.split(".")] + day, month, year = (int(x) for x in meas_date.split(".")) year = year + 2000 if year < 85 else year + 1900 meas_time = fid.read(8).decode("latin-1") - hour, minute, sec = [int(x) for x in meas_time.split(".")] + hour, minute, sec = (int(x) for x in meas_time.split(".")) try: meas_date = datetime( year, month, day, hour, minute, sec, tzinfo=timezone.utc @@ -912,10 +925,15 @@ def _read_edf_header(fname, exclude, infer_types, include=None): else: ch_types, ch_names = ["EEG"] * nchan, ch_labels - exclude = _find_exclude_idx(ch_names, exclude, include) tal_idx = _find_tal_idx(ch_names) + if exclude_after_unique: + # make sure channel names are unique + ch_names = _unique_channel_names(ch_names) + + exclude = _find_exclude_idx(ch_names, exclude, include) exclude = np.concatenate([exclude, tal_idx]) sel = np.setdiff1d(np.arange(len(ch_names)), exclude) + for ch in channels: fid.read(80) # transducer units = [fid.read(8).strip().decode("latin-1") for ch in channels] @@ -924,7 +942,7 @@ def _read_edf_header(fname, exclude, infer_types, 
include=None): if i in exclude: continue # allow μ (greek mu), µ (micro symbol) and μ (sjis mu) codepoints - if unit in ("\u03BCV", "\u00B5V", "\x83\xCAV", "uV"): + if unit in ("\u03bcV", "\u00b5V", "\x83\xcaV", "uV"): edf_info["units"].append(1e-6) elif unit == "mV": edf_info["units"].append(1e-3) @@ -935,8 +953,9 @@ def _read_edf_header(fname, exclude, infer_types, include=None): ch_names = [ch_names[idx] for idx in sel] units = [units[idx] for idx in sel] - # make sure channel names are unique - ch_names = _unique_channel_names(ch_names) + if not exclude_after_unique: + # make sure channel names are unique + ch_names = _unique_channel_names(ch_names) orig_units = dict(zip(ch_names, units)) physical_min = np.array([float(_edf_str_num(fid.read(8))) for ch in channels])[ @@ -951,7 +970,7 @@ def _read_edf_header(fname, exclude, infer_types, include=None): digital_max = np.array([float(_edf_str_num(fid.read(8))) for ch in channels])[ sel ] - prefiltering = [_edf_str(fid.read(80)).strip() for ch in channels][:-1] + prefiltering = np.array([_edf_str(fid.read(80)).strip() for ch in channels]) highpass, lowpass = _parse_prefilter_string(prefiltering) # number of samples per record @@ -1129,7 +1148,7 @@ def _read_gdf_header(fname, exclude, include=None): physical_max = np.fromfile(fid, FLOAT64, len(channels)) digital_min = np.fromfile(fid, INT64, len(channels)) digital_max = np.fromfile(fid, INT64, len(channels)) - prefiltering = [_edf_str(fid.read(80)) for ch in channels][:-1] + prefiltering = [_edf_str(fid.read(80)) for ch in channels] highpass, lowpass = _parse_prefilter_string(prefiltering) # n samples per record @@ -1269,7 +1288,9 @@ def _read_gdf_header(fname, exclude, include=None): if patient["birthday"] != datetime(1, 1, 1, 0, 0, tzinfo=timezone.utc): today = datetime.now(tz=timezone.utc) patient["age"] = today.year - patient["birthday"].year - today = today.replace(year=patient["birthday"].year) + # fudge the day by -1 if today happens to be a leap day + day = 28 if today.month == 2 and today.day == 29 else today.day + today = today.replace(year=patient["birthday"].year, day=day) if today < patient["birthday"]: patient["age"] -= 1 else: @@ -1454,7 +1475,9 @@ def _read_gdf_header(fname, exclude, include=None): def _check_stim_channel( - stim_channel, ch_names, tal_ch_names=["EDF Annotations", "BDF Annotations"] + stim_channel, + ch_names, + tal_ch_names=("EDF Annotations", "BDF Annotations"), ): """Check that the stimulus channel exists in the current datafile.""" DEFAULT_STIM_CH_NAMES = ["status", "trigger"] @@ -1498,10 +1521,10 @@ def _check_stim_channel( ] if len(tal_ch_names_found): _msg = ( - "The synthesis of the stim channel is not supported" - " since 0.18. Please remove {} from `stim_channel`" - " and use `mne.events_from_annotations` instead" - ).format(tal_ch_names_found) + "The synthesis of the stim channel is not supported since 0.18. Please " + f"remove {tal_ch_names_found} from `stim_channel` and use " + "`mne.events_from_annotations` instead." + ) raise ValueError(_msg) ch_names_low = [ch.lower() for ch in ch_names] @@ -1565,9 +1588,10 @@ def read_raw_edf( preload=False, units=None, encoding="utf8", + exclude_after_unique=False, *, verbose=None, -): +) -> RawEDF: """Reader function for EDF and EDF+ files. 
Parameters @@ -1609,6 +1633,7 @@ def read_raw_edf( %(preload)s %(units_edf_bdf_io)s %(encoding_edf)s + %(exclude_after_unique)s %(verbose)s Returns @@ -1683,6 +1708,7 @@ def read_raw_edf( include=include, units=units, encoding=encoding, + exclude_after_unique=exclude_after_unique, verbose=verbose, ) @@ -1699,9 +1725,10 @@ def read_raw_bdf( preload=False, units=None, encoding="utf8", + exclude_after_unique=False, *, verbose=None, -): +) -> RawEDF: """Reader function for BDF files. Parameters @@ -1743,6 +1770,7 @@ def read_raw_bdf( %(preload)s %(units_edf_bdf_io)s %(encoding_edf)s + %(exclude_after_unique)s %(verbose)s Returns @@ -1814,6 +1842,7 @@ def read_raw_bdf( include=include, units=units, encoding=encoding, + exclude_after_unique=exclude_after_unique, verbose=verbose, ) @@ -1828,7 +1857,7 @@ def read_raw_gdf( include=None, preload=False, verbose=None, -): +) -> RawGDF: """Reader function for GDF files. Parameters diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index 38532e062c1..7517693b6ea 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -31,13 +31,16 @@ from mne.io.edf.edf import ( _edf_str, _parse_prefilter_string, + _prefilter_float, _read_annotations_edf, _read_ch, _read_edf_header, _read_header, + _set_prefilter, ) from mne.io.tests.test_raw import _test_raw_reader from mne.tests.test_annotations import _assert_annotations_equal +from mne.utils import _record_warnings td_mark = testing._pytest_mark() @@ -172,24 +175,26 @@ def test_bdf_data(): # XXX BDF data for these is around 0.01 when it should be in the uV range, # probably some bug test_scaling = False - raw_py = _test_raw_reader( - read_raw_bdf, - input_fname=bdf_path, - eog=eog, - misc=misc, - exclude=["M2", "IEOG"], - test_scaling=test_scaling, - ) + with pytest.warns(RuntimeWarning, match="Channels contain different"): + raw_py = _test_raw_reader( + read_raw_bdf, + input_fname=bdf_path, + eog=eog, + misc=misc, + exclude=["M2", "IEOG"], + test_scaling=test_scaling, + ) assert len(raw_py.ch_names) == 71 - raw_py = _test_raw_reader( - read_raw_bdf, - input_fname=bdf_path, - montage="biosemi64", - eog=eog, - misc=misc, - exclude=["M2", "IEOG"], - test_scaling=test_scaling, - ) + with pytest.warns(RuntimeWarning, match="Channels contain different"): + raw_py = _test_raw_reader( + read_raw_bdf, + input_fname=bdf_path, + montage="biosemi64", + eog=eog, + misc=misc, + exclude=["M2", "IEOG"], + test_scaling=test_scaling, + ) assert len(raw_py.ch_names) == 71 assert "RawEDF" in repr(raw_py) picks = pick_types(raw_py.info, meg=False, eeg=True, exclude="bads") @@ -408,7 +413,7 @@ def test_no_data_channels(): annot_2 = raw.annotations _assert_annotations_equal(annot, annot_2) # only annotations (should warn) - with pytest.warns(RuntimeWarning, match="read_annotations"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="read_annotations"): read_raw_edf(edf_annot_only) @@ -630,27 +635,101 @@ def test_read_latin1_annotations(tmp_path): _read_annotations_edf(str(annot_file)) # default encoding="utf8" fails -def test_edf_prefilter_parse(): +@pytest.mark.parametrize( + "prefiltering, hp, lp", + [ + pytest.param(["HP: 1Hz LP: 30Hz"], ["1"], ["30"], id="basic edf"), + pytest.param(["LP: 30Hz HP: 1Hz"], ["1"], ["30"], id="reversed order"), + pytest.param(["HP: 1 LP: 30"], ["1"], ["30"], id="w/o Hz"), + pytest.param(["HP: 0,1 LP: 30,5"], ["0.1"], ["30.5"], id="using comma"), + pytest.param( + ["HP:0.1Hz LP:75Hz N:50Hz"], ["0.1"], ["75"], id="with notch filter" + ), + 
pytest.param([""], [""], [""], id="empty string"), + pytest.param(["HP: DC; LP: 410"], ["DC"], ["410"], id="bdf_dc"), + pytest.param( + ["", "HP:0.1Hz LP:75Hz N:50Hz", ""], + ["", "0.1", ""], + ["", "75", ""], + id="multi-ch", + ), + ], +) +def test_edf_parse_prefilter_string(prefiltering, hp, lp): """Test prefilter strings from header are parsed correctly.""" - prefilter_basic = ["HP: 0Hz LP: 0Hz"] - highpass, lowpass = _parse_prefilter_string(prefilter_basic) - assert_array_equal(highpass, ["0"]) - assert_array_equal(lowpass, ["0"]) + highpass, lowpass = _parse_prefilter_string(prefiltering) + assert_array_equal(highpass, hp) + assert_array_equal(lowpass, lp) - prefilter_normal_multi_ch = ["HP: 1Hz LP: 30Hz"] * 10 - highpass, lowpass = _parse_prefilter_string(prefilter_normal_multi_ch) - assert_array_equal(highpass, ["1"] * 10) - assert_array_equal(lowpass, ["30"] * 10) - prefilter_unfiltered_ch = prefilter_normal_multi_ch + [""] - highpass, lowpass = _parse_prefilter_string(prefilter_unfiltered_ch) - assert_array_equal(highpass, ["1"] * 10) - assert_array_equal(lowpass, ["30"] * 10) +@pytest.mark.parametrize( + "prefilter_string, expected", + [ + ("0", 0), + ("1.1", 1.1), + ("DC", 0), + ("", np.nan), + ("1.1.1", np.nan), + (1.1, 1.1), + (1, 1), + (np.float32(1.1), np.float32(1.1)), + (np.nan, np.nan), + ], +) +def test_edf_prefilter_float(prefilter_string, expected): + """Test to make float from prefilter string.""" + assert_equal(_prefilter_float(prefilter_string), expected) - prefilter_edf_specs_doc = ["HP:0.1Hz LP:75Hz N:50Hz"] - highpass, lowpass = _parse_prefilter_string(prefilter_edf_specs_doc) - assert_array_equal(highpass, ["0.1"]) - assert_array_equal(lowpass, ["75"]) + +@pytest.mark.parametrize( + "edf_info, hp, lp, hp_warn, lp_warn", + [ + ({"highpass": ["0"], "lowpass": ["1.1"]}, -1, 1.1, False, False), + ({"highpass": [""], "lowpass": [""]}, -1, -1, False, False), + ({"highpass": ["DC"], "lowpass": [""]}, -1, -1, False, False), + ({"highpass": [1], "lowpass": [2]}, 1, 2, False, False), + ({"highpass": [np.nan], "lowpass": [np.nan]}, -1, -1, False, False), + ({"highpass": ["1", "2"], "lowpass": ["3", "4"]}, 2, 3, True, True), + ({"highpass": [np.nan, 1], "lowpass": ["", 3]}, 1, 3, True, True), + ({"highpass": [np.nan, np.nan], "lowpass": [1, 2]}, -1, 1, False, True), + ({}, -1, -1, False, False), + ], +) +def test_edf_set_prefilter(edf_info, hp, lp, hp_warn, lp_warn): + """Test _set_prefilter function.""" + info = {"lowpass": -1, "highpass": -1} + + if hp_warn: + ctx = pytest.warns( + RuntimeWarning, + match=( + "Channels contain different highpass filters. " + "Highest filter setting will be stored." + ), + ) + else: + ctx = nullcontext() + with ctx: + _set_prefilter( + info, edf_info, list(range(len(edf_info.get("highpass", [])))), "highpass" + ) + + if lp_warn: + ctx = pytest.warns( + RuntimeWarning, + match=( + "Channels contain different lowpass filters. " + "Lowest filter setting will be stored." 
+ ), + ) + else: + ctx = nullcontext() + with ctx: + _set_prefilter( + info, edf_info, list(range(len(edf_info.get("lowpass", [])))), "lowpass" + ) + assert info["highpass"] == hp + assert info["lowpass"] == lp @testing.requires_testing_data @@ -730,9 +809,23 @@ def test_edf_stim_ch_pick_up(test_input, EXPECTED): @testing.requires_testing_data -def test_bdf_multiple_annotation_channels(): +@pytest.mark.parametrize( + "exclude_after_unique, warns", + [ + (False, False), + (True, True), + ], +) +def test_bdf_multiple_annotation_channels(exclude_after_unique, warns): """Test BDF with multiple annotation channels.""" - raw = read_raw_bdf(bdf_multiple_annotations_path) + if warns: + ctx = pytest.warns(RuntimeWarning, match="Channel names are not unique") + else: + ctx = nullcontext() + with ctx: + raw = read_raw_bdf( + bdf_multiple_annotations_path, exclude_after_unique=exclude_after_unique + ) assert len(raw.annotations) == 10 descriptions = np.array( [ @@ -817,37 +910,40 @@ def test_empty_chars(): def _hp_lp_rev(*args, **kwargs): out, orig_units = _read_edf_header(*args, **kwargs) out["lowpass"], out["highpass"] = out["highpass"], out["lowpass"] - # this will happen for test_edf_stim_resamp.edf - if ( - len(out["lowpass"]) - and out["lowpass"][0] == "0.000" - and len(out["highpass"]) - and out["highpass"][0] == "0.0" - ): - out["highpass"][0] = "10.0" + return out, orig_units + + +def _hp_lp_mod(*args, **kwargs): + out, orig_units = _read_edf_header(*args, **kwargs) + out["lowpass"][:] = "1" + out["highpass"][:] = "10" return out, orig_units @pytest.mark.filterwarnings("ignore:.*too long.*:RuntimeWarning") @pytest.mark.parametrize( - "fname, lo, hi, warns", + "fname, lo, hi, warns, patch_func", [ - (edf_path, 256, 0, False), - (edf_uneven_path, 50, 0, False), - (edf_stim_channel_path, 64, 0, False), - pytest.param(edf_overlap_annot_path, 64, 0, False, marks=td_mark), - pytest.param(edf_reduced, 256, 0, False, marks=td_mark), - pytest.param(test_generator_edf, 100, 0, False, marks=td_mark), - pytest.param(edf_stim_resamp_path, 256, 0, True, marks=td_mark), + (edf_path, 256, 0, False, "rev"), + (edf_uneven_path, 50, 0, False, "rev"), + (edf_stim_channel_path, 64, 0, False, "rev"), + pytest.param(edf_overlap_annot_path, 64, 0, False, "rev", marks=td_mark), + pytest.param(edf_reduced, 256, 0, False, "rev", marks=td_mark), + pytest.param(test_generator_edf, 100, 0, False, "rev", marks=td_mark), + pytest.param(edf_stim_resamp_path, 256, 0, False, "rev", marks=td_mark), + pytest.param(edf_stim_resamp_path, 256, 0, True, "mod", marks=td_mark), ], ) -def test_hp_lp_reversed(fname, lo, hi, warns, monkeypatch): +def test_hp_lp_reversed(fname, lo, hi, warns, patch_func, monkeypatch): """Test HP/LP reversed (gh-8584).""" fname = str(fname) raw = read_raw_edf(fname) assert raw.info["lowpass"] == lo assert raw.info["highpass"] == hi - monkeypatch.setattr(edf.edf, "_read_edf_header", _hp_lp_rev) + if patch_func == "rev": + monkeypatch.setattr(edf.edf, "_read_edf_header", _hp_lp_rev) + elif patch_func == "mod": + monkeypatch.setattr(edf.edf, "_read_edf_header", _hp_lp_mod) if warns: ctx = pytest.warns(RuntimeWarning, match="greater than lowpass") new_lo, new_hi = raw.info["sfreq"] / 2.0, 0.0 @@ -885,6 +981,32 @@ def test_exclude(): assert ch not in raw.ch_names +@pytest.mark.parametrize( + "EXPECTED, exclude, exclude_after_unique, warns", + [ + (["EEG F2-Ref"], "EEG F1-Ref", False, False), + (["EEG F1-Ref-0", "EEG F2-Ref", "EEG F1-Ref-1"], "EEG F1-Ref-1", False, True), + (["EEG F2-Ref"], ["EEG F1-Ref"], 
False, False), + (["EEG F2-Ref"], "EEG F1-Ref", True, True), + (["EEG F1-Ref-0", "EEG F2-Ref"], "EEG F1-Ref-1", True, True), + (["EEG F1-Ref-0", "EEG F2-Ref", "EEG F1-Ref-1"], ["EEG F1-Ref"], True, True), + ], +) +def test_exclude_duplicate_channel_data(exclude, exclude_after_unique, warns, EXPECTED): + """Test exclude parameter for duplicate channel data.""" + if warns: + ctx = pytest.warns(RuntimeWarning, match="Channel names are not unique") + else: + ctx = nullcontext() + with ctx: + raw = read_raw_edf( + duplicate_channel_labels_path, + exclude=exclude, + exclude_after_unique=exclude_after_unique, + ) + assert raw.ch_names == EXPECTED + + def test_include(): """Test include parameter.""" raw = read_raw_edf(edf_path, include=["I1", "I2"]) @@ -898,6 +1020,32 @@ def test_include(): assert str(e.value) == "'exclude' must be empty" "if 'include' is assigned." +@pytest.mark.parametrize( + "EXPECTED, include, exclude_after_unique, warns", + [ + (["EEG F1-Ref-0", "EEG F1-Ref-1"], "EEG F1-Ref", False, True), + ([], "EEG F1-Ref-1", False, False), + (["EEG F1-Ref-0", "EEG F1-Ref-1"], ["EEG F1-Ref"], False, True), + (["EEG F1-Ref-0", "EEG F1-Ref-1"], "EEG F1-Ref", True, True), + (["EEG F1-Ref-1"], "EEG F1-Ref-1", True, True), + ([], ["EEG F1-Ref"], True, True), + ], +) +def test_include_duplicate_channel_data(include, exclude_after_unique, warns, EXPECTED): + """Test include parameter for duplicate channel data.""" + if warns: + ctx = pytest.warns(RuntimeWarning, match="Channel names are not unique") + else: + ctx = nullcontext() + with ctx: + raw = read_raw_edf( + duplicate_channel_labels_path, + include=include, + exclude_after_unique=exclude_after_unique, + ) + assert raw.ch_names == EXPECTED + + @testing.requires_testing_data def test_ch_types(): """Test reading of channel types from EDF channel label.""" diff --git a/mne/io/edf/tests/test_gdf.py b/mne/io/edf/tests/test_gdf.py index c029cc3280c..8942d13f8a6 100644 --- a/mne/io/edf/tests/test_gdf.py +++ b/mne/io/edf/tests/test_gdf.py @@ -8,7 +8,6 @@ from datetime import datetime, timedelta, timezone import numpy as np -import pytest import scipy.io as sio from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_equal @@ -38,7 +37,7 @@ def test_gdf_data(): # Test Status is added as event EXPECTED_EVS_ONSETS = raw._raw_extras[0]["events"][1] EXPECTED_EVS_ID = { - "{}".format(evs): i + f"{evs}": i for i, evs in enumerate( [ 32769, @@ -153,8 +152,7 @@ def test_gdf2_data(): @testing.requires_testing_data def test_one_channel_gdf(): """Test a one-channel GDF file.""" - with pytest.warns(RuntimeWarning, match="different highpass"): - ecg = read_raw_gdf(gdf_1ch_path, preload=True) + ecg = read_raw_gdf(gdf_1ch_path, preload=True) assert ecg["ECG"][0].shape == (1, 4500) assert 150.0 == ecg.info["sfreq"] diff --git a/mne/io/eeglab/eeglab.py b/mne/io/eeglab/eeglab.py index 413a8ae4bfc..905e9620010 100644 --- a/mne/io/eeglab/eeglab.py +++ b/mne/io/eeglab/eeglab.py @@ -52,8 +52,6 @@ def _check_eeglab_fname(fname, dataname): "Old data format .dat detected. Please update your EEGLAB " "version and resave the data in .fdt format" ) - elif fmt != ".fdt": - raise OSError("Expected .fdt file format. Found %s format" % fmt) basedir = op.dirname(fname) data_fname = op.join(basedir, dataname) @@ -293,7 +291,7 @@ def read_raw_eeglab( uint16_codec=None, montage_units="auto", verbose=None, -): +) -> "RawEEGLAB": r"""Read an EEGLAB .set file. 
Parameters @@ -349,7 +347,7 @@ def read_epochs_eeglab( uint16_codec=None, montage_units="auto", verbose=None, -): +) -> "EpochsEEGLAB": r"""Reader function for EEGLAB epochs files. Parameters @@ -449,7 +447,7 @@ def __init__( uint16_codec=None, montage_units="auto", verbose=None, - ): # noqa: D102 + ): input_fname = str(_check_fname(input_fname, "read", True, "input_fname")) eeg = _check_load_mat(input_fname, uint16_codec) if eeg.trials != 1: @@ -467,7 +465,7 @@ def __init__( data_fname = _check_eeglab_fname(input_fname, eeg.data) logger.info("Reading %s" % data_fname) - super(RawEEGLAB, self).__init__( + super().__init__( info, preload, filenames=[data_fname], @@ -491,7 +489,7 @@ def __init__( data = np.empty((n_chan, n_times), dtype=float) data[:n_chan] = eeg.data data *= CAL - super(RawEEGLAB, self).__init__( + super().__init__( info, data, filenames=[input_fname], @@ -602,7 +600,7 @@ def __init__( uint16_codec=None, montage_units="auto", verbose=None, - ): # noqa: D102 + ): input_fname = str( _check_fname(fname=input_fname, must_exist=True, overwrite="read") ) @@ -694,7 +692,7 @@ def __init__( assert data.shape == (eeg.trials, eeg.nbchan, eeg.pnts) tmin, tmax = eeg.xmin, eeg.xmax - super(EpochsEEGLAB, self).__init__( + super().__init__( info, data, events, @@ -799,6 +797,22 @@ def _read_annotations_eeglab(eeg, uint16_codec=None): ) duration[idx] = np.nan if is_empty_array else event.duration + # Drop events with NaN onset see PR #12484 + valid_indices = [ + idx for idx, onset_idx in enumerate(onset) if not np.isnan(onset_idx) + ] + n_dropped = len(onset) - len(valid_indices) + if len(valid_indices) != len(onset): + warn( + f"{n_dropped} events have an onset that is NaN. These values are " + "usually ignored by EEGLAB and will be dropped from the " + "annotations." 
+ ) + + onset = np.array([onset[idx] for idx in valid_indices]) + duration = np.array([duration[idx] for idx in valid_indices]) + description = [description[idx] for idx in valid_indices] + return Annotations( onset=np.array(onset) / eeg.srate, duration=duration / eeg.srate, diff --git a/mne/io/eeglab/tests/test_eeglab.py b/mne/io/eeglab/tests/test_eeglab.py index 7d78f95ef6a..ebd5a6a6706 100644 --- a/mne/io/eeglab/tests/test_eeglab.py +++ b/mne/io/eeglab/tests/test_eeglab.py @@ -28,7 +28,7 @@ from mne.io.eeglab._eeglab import _readmat from mne.io.eeglab.eeglab import _dol_to_lod, _get_montage_information from mne.io.tests.test_raw import _test_raw_reader -from mne.utils import Bunch, _check_pymatreader_installed +from mne.utils import Bunch, _check_pymatreader_installed, _record_warnings base_dir = testing.data_path(download=False) / "EEGLAB" raw_fname_mat = base_dir / "test_raw.set" @@ -71,7 +71,7 @@ def test_io_set_raw(fname): """Test importing EEGLAB .set files.""" montage = read_custom_montage(montage_path) - montage.ch_names = ["EEG {0:03d}".format(ii) for ii in range(len(montage.ch_names))] + montage.ch_names = [f"EEG {ii:03d}" for ii in range(len(montage.ch_names))] kws = dict(reader=read_raw_eeglab, input_fname=fname) if fname.name == "test_raw_chanloc.set": @@ -140,7 +140,10 @@ def test_io_set_raw_more(tmp_path): shutil.copyfile( base_dir / "test_raw.fdt", negative_latency_fname.with_suffix(".fdt") ) - with pytest.warns(RuntimeWarning, match="has a sample index of -1."): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="has a sample index of -1."), + ): read_raw_eeglab(input_fname=negative_latency_fname, preload=True) # test negative event latencies @@ -163,7 +166,7 @@ def test_io_set_raw_more(tmp_path): oned_as="row", ) with pytest.raises(ValueError, match="event sample index is negative"): - with pytest.warns(RuntimeWarning, match="has a sample index of -1."): + with _record_warnings(): read_raw_eeglab(input_fname=negative_latency_fname, preload=True) # test overlapping events @@ -350,9 +353,9 @@ def test_io_set_raw_more(tmp_path): def test_io_set_epochs(fnames): """Test importing EEGLAB .set epochs files.""" epochs_fname, epochs_fname_onefile = fnames - with pytest.warns(RuntimeWarning, match="multiple events"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="multiple events"): epochs = read_epochs_eeglab(epochs_fname) - with pytest.warns(RuntimeWarning, match="multiple events"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="multiple events"): epochs2 = read_epochs_eeglab(epochs_fname_onefile) # one warning for each read_epochs_eeglab because both files have epochs # associated with multiple events @@ -568,9 +571,7 @@ def test_position_information(three_chanpos_fname): input_fname=three_chanpos_fname, preload=True, montage_units="cm", - ).set_montage( - None - ) # Flush the montage builtin within input_fname + ).set_montage(None) # Flush the montage builtin within input_fname _assert_array_allclose_nan( np.array([ch["loc"] for ch in raw.info["chs"]]), EXPECTED_LOCATIONS_FROM_MONTAGE @@ -718,3 +719,36 @@ def get_bad_information(eeg, get_pos, *, montage_units): assert len(pos["lpa"]) == 3 assert len(pos["rpa"]) == 3 assert len(raw.info["dig"]) == n_eeg + 3 + + +@testing.requires_testing_data +def test_eeglab_drop_nan_annotations(tmp_path): + """Test reading file with NaN annotations.""" + pytest.importorskip("eeglabio") + from eeglabio.raw import export_set + + file_path = tmp_path / "test_nan_anno.set" + raw = 
read_raw_eeglab(raw_fname_mat, preload=True) + data = raw.get_data() + sfreq = raw.info["sfreq"] + ch_names = raw.ch_names + anno = [ + raw.annotations.description, + raw.annotations.onset, + raw.annotations.duration, + ] + anno[1][0] = np.nan + + export_set( + str(file_path), + data, + sfreq, + ch_names, + ch_locs=None, + annotations=anno, + ref_channels="common", + ch_types=np.repeat("EEG", len(ch_names)), + ) + + with pytest.raises(RuntimeWarning, match="1 .* have an onset that is NaN.*"): + raw = read_raw_eeglab(file_path, preload=True) diff --git a/mne/io/egi/egi.py b/mne/io/egi/egi.py index 0b62d7b6389..b0124bdc541 100644 --- a/mne/io/egi/egi.py +++ b/mne/io/egi/egi.py @@ -104,7 +104,7 @@ def read_raw_egi( preload=False, channel_naming="E%d", verbose=None, -): +) -> "RawEGI": """Read EGI simple binary as raw object. .. note:: This function attempts to create a synthetic trigger channel. @@ -193,7 +193,7 @@ def __init__( preload=False, channel_naming="E%d", verbose=None, - ): # noqa: D102 + ): input_fname = str(_check_fname(input_fname, "read", True, "input_fname")) if eog is None: eog = [] @@ -307,7 +307,7 @@ def __init__( orig_format = ( egi_info["orig_format"] if egi_info["orig_format"] != "float" else "single" ) - super(RawEGI, self).__init__( + super().__init__( info, preload, orig_format=orig_format, diff --git a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py index 457cab90087..3a039b0c784 100644 --- a/mne/io/egi/egimff.py +++ b/mne/io/egi/egimff.py @@ -10,7 +10,6 @@ from pathlib import Path import numpy as np -from defusedxml.minidom import parse from ..._fiff.constants import FIFF from ..._fiff.meas_info import _empty_info, _ensure_meas_date_none_or_dt, create_info @@ -19,7 +18,7 @@ from ...annotations import Annotations from ...channels.montage import make_dig_montage from ...evoked import EvokedArray -from ...utils import _check_fname, _check_option, logger, verbose, warn +from ...utils import _check_fname, _check_option, _soft_import, logger, verbose, warn from ..base import BaseRaw from .events import _combine_triggers, _read_events from .general import ( @@ -36,6 +35,9 @@ def _read_mff_header(filepath): """Read mff header.""" + _soft_import("defusedxml", "reading EGI MFF data") + from defusedxml.minidom import parse + all_files = _get_signalfname(filepath) eeg_file = all_files["EEG"]["signal"] eeg_info_file = all_files["EEG"]["info"] @@ -62,7 +64,7 @@ def _read_mff_header(filepath): record_time, ) if g is None: - raise RuntimeError("Could not parse recordTime %r" % (record_time,)) + raise RuntimeError(f"Could not parse recordTime {repr(record_time)}") frac = g.groups()[0] assert len(frac) in (6, 9) and all(f.isnumeric() for f in frac) # regex div = 1000 if len(frac) == 6 else 1000000 @@ -70,7 +72,7 @@ def _read_mff_header(filepath): # convert from times in µS to samples for ei, e in enumerate(epochs[key]): if e % div != 0: - raise RuntimeError("Could not parse epoch time %s" % (e,)) + raise RuntimeError(f"Could not parse epoch time {e}") epochs[key][ei] = e // div epochs[key] = np.array(epochs[key], np.uint64) # I guess they refer to times in milliseconds? 
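The epoch-boundary conversion in ``_read_mff_header`` above reduces times recorded at sub-second precision to sample-aligned units: each boundary must divide evenly by ``div`` (1000 or 1000000, chosen from the length of the ``recordTime`` fraction), otherwise the file is rejected. A standalone sketch with assumed inputs (``div`` and the epoch times are illustrative):

    div = 1000  # would be derived from a 6-digit recordTime fraction
    epoch_times = [0, 2_000_000, 4_000_000]  # illustrative values from epochs.xml
    for ei, e in enumerate(epoch_times):
        if e % div != 0:
            raise RuntimeError(f"Could not parse epoch time {e}")
        epoch_times[ei] = e // div
    print(epoch_times)  # [0, 2000, 4000]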
@@ -102,7 +104,7 @@ def _read_mff_header(filepath): if bad: raise RuntimeError( "EGI epoch first/last samps could not be parsed:\n" - "%s\n%s" % (list(epochs["first_samps"]), list(epochs["last_samps"])) + f'{list(epochs["first_samps"])}\n{list(epochs["last_samps"])}' ) summaryinfo.update(epochs) # index which samples in raw are actually readable from disk (i.e., not @@ -154,7 +156,7 @@ def _read_mff_header(filepath): if not same_blocks: raise RuntimeError( "PNS and signals samples did not match:\n" - "%s\nvs\n%s" % (list(pns_samples), list(signal_samples)) + f"{list(pns_samples)}\nvs\n{list(signal_samples)}" ) pns_file = op.join(filepath, "pnsSet.xml") @@ -289,6 +291,9 @@ def _get_eeg_calibration_info(filepath, egi_info): def _read_locs(filepath, egi_info, channel_naming): """Read channel locations.""" + _soft_import("defusedxml", "reading EGI MFF data") + from defusedxml.minidom import parse + fname = op.join(filepath, "coordinates.xml") if not op.exists(fname): logger.warn("File coordinates.xml not found, not setting channel locations") @@ -642,7 +647,7 @@ def __init__( self._filenames = [file_bin] self._raw_extras = [egi_info] - super(RawMff, self).__init__( + super().__init__( info, preload=preload, orig_format="single", diff --git a/mne/io/egi/events.py b/mne/io/egi/events.py index 6f0ea1472c8..500848ee715 100644 --- a/mne/io/egi/events.py +++ b/mne/io/egi/events.py @@ -7,9 +7,8 @@ from os.path import basename, join, splitext import numpy as np -from defusedxml.ElementTree import parse -from ...utils import logger +from ...utils import _soft_import, logger def _read_events(input_fname, info): @@ -82,7 +81,8 @@ def _read_mff_events(filename, sfreq): def _parse_xml(xml_file): """Parse XML file.""" - xml = parse(xml_file) + defusedxml = _soft_import("defusedxml", "reading EGI MFF data") + xml = defusedxml.ElementTree.parse(xml_file) root = xml.getroot() return _xml2list(root) diff --git a/mne/io/egi/general.py b/mne/io/egi/general.py index ebd5a700363..9ca6dc7f0b9 100644 --- a/mne/io/egi/general.py +++ b/mne/io/egi/general.py @@ -6,13 +6,15 @@ import re import numpy as np -from defusedxml.minidom import parse -from ...utils import _pl +from ...utils import _pl, _soft_import def _extract(tags, filepath=None, obj=None): """Extract info from XML.""" + _soft_import("defusedxml", "reading EGI MFF data") + from defusedxml.minidom import parse + if obj is not None: fileobj = obj elif filepath is not None: @@ -30,6 +32,9 @@ def _extract(tags, filepath=None, obj=None): def _get_gains(filepath): """Parse gains.""" + _soft_import("defusedxml", "reading EGI MFF data") + from defusedxml.minidom import parse + file_obj = parse(filepath) objects = file_obj.getElementsByTagName("calibration") gains = dict() @@ -46,6 +51,9 @@ def _get_gains(filepath): def _get_ep_info(filepath): """Get epoch info.""" + _soft_import("defusedxml", "reading EGI MFF data") + from defusedxml.minidom import parse + epochfile = filepath + "/epochs.xml" epochlist = parse(epochfile) epochs = epochlist.getElementsByTagName("epoch") @@ -123,6 +131,9 @@ def _get_blocks(filepath): def _get_signalfname(filepath): """Get filenames.""" + _soft_import("defusedxml", "reading EGI MFF data") + from defusedxml.minidom import parse + listfiles = os.listdir(filepath) binfiles = list( f for f in listfiles if "signal" in f and f[-4:] == ".bin" and f[0] != "." 
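The repeated ``_soft_import("defusedxml", "reading EGI MFF data")`` hunks in these EGI modules defer the XML-parser import from module scope into each function, so ``mne.io.egi`` stays importable when the optional dependency is absent. A generic sketch of the pattern; ``soft_import`` here is a stand-in under assumed behavior, not the actual ``mne.utils._soft_import`` (whose implementation is not part of this diff):

    import importlib

    def soft_import(name, purpose):
        # Import lazily; fail with an actionable message if the optional
        # dependency is missing.
        try:
            return importlib.import_module(name)
        except ImportError as err:
            raise ImportError(f"'{name}' is required for {purpose}.") from err

    def parse_xml_file(path):
        et = soft_import("defusedxml.ElementTree", "reading EGI MFF data")
        return et.parse(path)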
@@ -140,7 +151,7 @@ def _get_signalfname(filepath): elif len(infobj.getElementsByTagName("PNSData")): signal_type = "PNS" all_files[signal_type] = { - "signal": "signal{}.bin".format(bin_num_str), + "signal": f"signal{bin_num_str}.bin", "info": infofile, } if "EEG" not in all_files: diff --git a/mne/io/egi/tests/test_egi.py b/mne/io/egi/tests/test_egi.py index 8da704243fd..71120d8d6f7 100644 --- a/mne/io/egi/tests/test_egi.py +++ b/mne/io/egi/tests/test_egi.py @@ -70,6 +70,7 @@ ) def test_egi_mff_pause(fname, skip_times, event_times): """Test EGI MFF with pauses.""" + pytest.importorskip("defusedxml") if fname == egi_pause_w1337_fname: # too slow to _test_raw_reader raw = read_raw_egi(fname).load_data() @@ -129,6 +130,7 @@ def test_egi_mff_pause(fname, skip_times, event_times): ) def test_egi_mff_pause_chunks(fname, tmp_path): """Test that on-demand of all short segments works (via I/O).""" + pytest.importorskip("defusedxml") fname_temp = tmp_path / "test_raw.fif" raw_data = read_raw_egi(fname, preload=True).get_data() raw = read_raw_egi(fname) @@ -142,6 +144,7 @@ def test_egi_mff_pause_chunks(fname, tmp_path): @requires_testing_data def test_io_egi_mff(): """Test importing EGI MFF simple binary files.""" + pytest.importorskip("defusedxml") # want vars for n chans n_ref = 1 n_eeg = 128 @@ -258,6 +261,7 @@ def test_io_egi(): @requires_testing_data def test_io_egi_pns_mff(tmp_path): """Test importing EGI MFF with PNS data.""" + pytest.importorskip("defusedxml") raw = read_raw_egi(egi_mff_pns_fname, include=None, preload=True, verbose="error") assert "RawMff" in repr(raw) pns_chans = pick_types(raw.info, ecg=True, bio=True, emg=True) @@ -293,7 +297,7 @@ def test_io_egi_pns_mff(tmp_path): egi_fname_mat = testing_path / "EGI" / "test_egi_pns.mat" mc = sio.loadmat(egi_fname_mat) for ch_name, ch_idx, mat_name in zip(pns_names, pns_chans, mat_names): - print("Testing {}".format(ch_name)) + print(f"Testing {ch_name}") mc_key = [x for x in mc.keys() if mat_name in x][0] cal = raw.info["chs"][ch_idx]["cal"] mat_data = mc[mc_key] * cal @@ -314,6 +318,7 @@ def test_io_egi_pns_mff(tmp_path): @pytest.mark.parametrize("preload", (True, False)) def test_io_egi_pns_mff_bug(preload): """Test importing EGI MFF with PNS data (BUG).""" + pytest.importorskip("defusedxml") egi_fname_mff = testing_path / "EGI" / "test_egi_pns_bug.mff" with pytest.warns(RuntimeWarning, match="EGI PSG sample bug"): raw = read_raw_egi( @@ -344,7 +349,7 @@ def test_io_egi_pns_mff_bug(preload): "EMGLeg", ] for ch_name, ch_idx, mat_name in zip(pns_names, pns_chans, mat_names): - print("Testing {}".format(ch_name)) + print(f"Testing {ch_name}") mc_key = [x for x in mc.keys() if mat_name in x][0] cal = raw.info["chs"][ch_idx]["cal"] mat_data = mc[mc_key] * cal @@ -356,6 +361,7 @@ def test_io_egi_pns_mff_bug(preload): @requires_testing_data def test_io_egi_crop_no_preload(): """Test crop non-preloaded EGI MFF data (BUG).""" + pytest.importorskip("defusedxml") raw = read_raw_egi(egi_mff_fname, preload=False) raw.crop(17.5, 20.5) raw.load_data() @@ -383,6 +389,8 @@ def test_io_egi_crop_no_preload(): def test_io_egi_evokeds_mff(idx, cond, tmax, signals, bads): """Test reading evoked MFF file.""" pytest.importorskip("mffpy", "0.5.7") + + pytest.importorskip("defusedxml") # expected n channels n_eeg = 256 n_ref = 1 @@ -468,6 +476,7 @@ def test_read_evokeds_mff_bad_input(): @requires_testing_data def test_egi_coord_frame(): """Test that EGI coordinate frame is changed to head.""" + pytest.importorskip("defusedxml") info = 
read_raw_egi(egi_mff_fname).info want_idents = ( FIFF.FIFFV_POINT_LPA, @@ -505,6 +514,7 @@ def test_egi_coord_frame(): ) def test_meas_date(fname, timestamp, utc_offset): """Test meas date conversion.""" + pytest.importorskip("defusedxml") raw = read_raw_egi(fname, verbose="warning") dt = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f%z") measdate = dt.astimezone(timezone.utc) @@ -526,6 +536,7 @@ def test_meas_date(fname, timestamp, utc_offset): ) def test_set_standard_montage_mff(fname, standard_montage): """Test setting a standard montage.""" + pytest.importorskip("defusedxml") raw = read_raw_egi(fname, verbose="warning") n_eeg = int(standard_montage.split("-")[-1]) n_dig = n_eeg + 3 diff --git a/mne/io/eximia/eximia.py b/mne/io/eximia/eximia.py index 0af9d9daf5d..1d253f369d1 100644 --- a/mne/io/eximia/eximia.py +++ b/mne/io/eximia/eximia.py @@ -13,7 +13,7 @@ @fill_doc -def read_raw_eximia(fname, preload=False, verbose=None): +def read_raw_eximia(fname, preload=False, verbose=None) -> "RawEximia": """Reader for an eXimia EEG file. Parameters @@ -87,12 +87,12 @@ def __init__(self, fname, preload=False, verbose=None): n_samples, extra = divmod(n_bytes, (n_chan * 2)) if extra != 0: warn( - "Incorrect number of samples in file (%s), the file is " - "likely truncated" % (n_samples,) + f"Incorrect number of samples in file ({n_samples}), the file is likely" + " truncated" ) for ch, cal in zip(info["chs"], cals): ch["cal"] = cal - super(RawEximia, self).__init__( + super().__init__( info, preload=preload, last_samps=(n_samples - 1,), diff --git a/mne/io/eyelink/_utils.py b/mne/io/eyelink/_utils.py index f6ab2f8790d..99c1e1c96f6 100644 --- a/mne/io/eyelink/_utils.py +++ b/mne/io/eyelink/_utils.py @@ -3,7 +3,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. - import re from datetime import datetime, timedelta, timezone @@ -508,7 +507,7 @@ def _adjust_times( np.arange(first, last + step / 2, step), columns=[time_col] ) return pd.merge_asof( - new_times, df, on=time_col, direction="nearest", tolerance=step / 10 + new_times, df, on=time_col, direction="nearest", tolerance=step / 2 ) diff --git a/mne/io/eyelink/eyelink.py b/mne/io/eyelink/eyelink.py index 196aef408b1..2ab00d22b58 100644 --- a/mne/io/eyelink/eyelink.py +++ b/mne/io/eyelink/eyelink.py @@ -28,7 +28,7 @@ def read_raw_eyelink( find_overlaps=False, overlap_threshold=0.05, verbose=None, -): +) -> "RawEyelink": """Reader for an Eyelink ``.asc`` file. 
Parameters @@ -99,7 +99,7 @@ def __init__( overlap_threshold=0.05, verbose=None, ): - logger.info("Loading {}".format(fname)) + logger.info(f"Loading {fname}") fname = Path(fname) @@ -108,7 +108,7 @@ def __init__( fname, find_overlaps, overlap_threshold, apply_offsets ) # ======================== Create Raw Object ========================= - super(RawEyelink, self).__init__( + super().__init__( info, preload=eye_ch_data, filenames=[fname], diff --git a/mne/io/eyelink/tests/test_eyelink.py b/mne/io/eyelink/tests/test_eyelink.py index 47b25e94489..7f57596ac38 100644 --- a/mne/io/eyelink/tests/test_eyelink.py +++ b/mne/io/eyelink/tests/test_eyelink.py @@ -12,6 +12,7 @@ from mne.io import read_raw_eyelink from mne.io.eyelink._utils import _adjust_times, _find_overlaps from mne.io.tests.test_raw import _test_raw_reader +from mne.utils import _record_warnings pd = pytest.importorskip("pandas") @@ -233,19 +234,16 @@ def _simulate_eye_tracking_data(in_file, out_file): else: fp.write("%s\n" % line) - fp.write("%s\n" % "START\t7452389\tRIGHT\tSAMPLES\tEVENTS") - fp.write("%s\n" % new_samples_line) + fp.write("START\t7452389\tRIGHT\tSAMPLES\tEVENTS\n") + fp.write(f"{new_samples_line}\n") for timestamp in np.arange(7452389, 7453390): # simulate a second block fp.write( - "%s\n" - % ( - f"{timestamp}\t-2434.0\t-1760.0\t840.0\t100\t20\t45\t45\t127.0\t" - "...\t1497\t5189\t512.5\t............." - ) + f"{timestamp}\t-2434.0\t-1760.0\t840.0\t100\t20\t45\t45\t127.0\t" + "...\t1497\t5189\t512.5\t.............\n" ) - fp.write("%s\n" % "END\t7453390\tRIGHT\tSAMPLES\tEVENTS") + fp.write("END\t7453390\tRIGHT\tSAMPLES\tEVENTS\n") @requires_testing_data @@ -255,7 +253,10 @@ def test_multi_block_misc_channels(fname, tmp_path): out_file = tmp_path / "tmp_eyelink.asc" _simulate_eye_tracking_data(fname, out_file) - with pytest.warns(RuntimeWarning, match="Raw eyegaze coordinates"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="Raw eyegaze coordinates"), + ): raw = read_raw_eyelink(out_file, apply_offsets=True) chs_in_file = [ @@ -295,7 +296,7 @@ def test_annotations_without_offset(tmp_path): out_file = tmp_path / "tmp_eyelink.asc" # create fake dataset - with open(fname_href, "r") as file: + with open(fname_href) as file: lines = file.readlines() ts = lines[-3].split("\t")[0] line = f"MSG\t{ts} test string\n" diff --git a/mne/io/fieldtrip/fieldtrip.py b/mne/io/fieldtrip/fieldtrip.py index 8d054b076ee..3dac2992be1 100644 --- a/mne/io/fieldtrip/fieldtrip.py +++ b/mne/io/fieldtrip/fieldtrip.py @@ -20,7 +20,7 @@ ) -def read_raw_fieldtrip(fname, info, data_name="data"): +def read_raw_fieldtrip(fname, info, data_name="data") -> RawArray: """Load continuous (raw) data from a FieldTrip preprocessing structure. This function expects to find single trial raw data (FT_DATATYPE_RAW) in @@ -83,7 +83,9 @@ def read_raw_fieldtrip(fname, info, data_name="data"): return raw -def read_epochs_fieldtrip(fname, info, data_name="data", trialinfo_column=0): +def read_epochs_fieldtrip( + fname, info, data_name="data", trialinfo_column=0 +) -> EpochsArray: """Load epoched data from a FieldTrip preprocessing structure. 
This function expects to find epoched data in the structure data_name is diff --git a/mne/io/fieldtrip/tests/helpers.py b/mne/io/fieldtrip/tests/helpers.py index 5ab02286b66..66cb582dde9 100644 --- a/mne/io/fieldtrip/tests/helpers.py +++ b/mne/io/fieldtrip/tests/helpers.py @@ -185,7 +185,7 @@ def get_epochs(system): else: event_id = [int(cfg_local["eventvalue"])] - event_id = [id for id in event_id if id in events[:, 2]] + event_id = [id_ for id_ in event_id if id_ in events[:, 2]] epochs = mne.Epochs( raw_data, diff --git a/mne/io/fieldtrip/tests/test_fieldtrip.py b/mne/io/fieldtrip/tests/test_fieldtrip.py index 0f66d1b1fae..11546e82607 100644 --- a/mne/io/fieldtrip/tests/test_fieldtrip.py +++ b/mne/io/fieldtrip/tests/test_fieldtrip.py @@ -68,16 +68,14 @@ def test_read_evoked(cur_system, version, use_info): """Test comparing reading an Evoked object and the FieldTrip version.""" test_data_folder_ft = get_data_paths(cur_system) mne_avg = get_evoked(cur_system) + cur_fname = test_data_folder_ft / f"averaged_{version}.mat" if use_info: info = get_raw_info(cur_system) - ctx = nullcontext() + avg_ft = mne.io.read_evoked_fieldtrip(cur_fname, info) else: info = None - ctx = pytest.warns(**no_info_warning) - - cur_fname = test_data_folder_ft / f"averaged_{version}.mat" - with ctx: - avg_ft = mne.io.read_evoked_fieldtrip(cur_fname, info) + with _record_warnings(), pytest.warns(**no_info_warning): + avg_ft = mne.io.read_evoked_fieldtrip(cur_fname, info) mne_data = mne_avg.data[:, :-1] ft_data = avg_ft.data @@ -98,6 +96,7 @@ def test_read_epochs(cur_system, version, use_info, monkeypatch): has_pandas = pandas is not False test_data_folder_ft = get_data_paths(cur_system) mne_epoched = get_epochs(cur_system) + cur_fname = test_data_folder_ft / f"epoched_{version}.mat" if use_info: info = get_raw_info(cur_system) ctx = nullcontext() @@ -105,9 +104,8 @@ def test_read_epochs(cur_system, version, use_info, monkeypatch): info = None ctx = pytest.warns(**no_info_warning) - cur_fname = test_data_folder_ft / f"epoched_{version}.mat" if has_pandas: - with ctx: + with _record_warnings(), ctx: epoched_ft = mne.io.read_epochs_fieldtrip(cur_fname, info) assert isinstance(epoched_ft.metadata, pandas.DataFrame) else: @@ -133,7 +131,7 @@ def modify_mat(fname, variable_names=None, ignore_fields=None): return out monkeypatch.setattr(pymatreader, "read_mat", modify_mat) - with pytest.warns(RuntimeWarning, match="multiple"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="multiple"): mne.io.read_epochs_fieldtrip(cur_fname, info) @@ -160,7 +158,7 @@ def test_read_raw_fieldtrip(cur_system, version, use_info): cur_fname = test_data_folder_ft / f"raw_{version}.mat" - with ctx: + with _record_warnings(), ctx: raw_fiff_ft = mne.io.read_raw_fieldtrip(cur_fname, info) if cur_system == "BTI" and not use_info: @@ -253,19 +251,19 @@ def test_one_channel_elec_bug(version): @pytest.mark.filterwarnings("ignore:.*parse meas date.*:RuntimeWarning") @pytest.mark.filterwarnings("ignore:.*number of bytes.*:RuntimeWarning") @pytest.mark.parametrize("version", all_versions) -@pytest.mark.parametrize("type", ["averaged", "epoched", "raw"]) -def test_throw_exception_on_cellarray(version, type): +@pytest.mark.parametrize("type_", ["averaged", "epoched", "raw"]) +def test_throw_exception_on_cellarray(version, type_): """Test for a meaningful exception when the data is a cell array.""" - fname = get_data_paths("cellarray") / f"{type}_{version}.mat" + fname = get_data_paths("cellarray") / f"{type_}_{version}.mat" info = 
get_raw_info("CNT") with pytest.raises( RuntimeError, match="Loading of data in cell arrays " "is not supported" ): - if type == "averaged": + if type_ == "averaged": mne.read_evoked_fieldtrip(fname, info) - elif type == "epoched": + elif type_ == "epoched": mne.read_epochs_fieldtrip(fname, info) - elif type == "raw": + elif type_ == "raw": mne.io.read_raw_fieldtrip(fname, info) diff --git a/mne/io/fieldtrip/utils.py b/mne/io/fieldtrip/utils.py index c4950d45bea..9a4274f6a43 100644 --- a/mne/io/fieldtrip/utils.py +++ b/mne/io/fieldtrip/utils.py @@ -54,9 +54,8 @@ def _create_info(ft_struct, raw_info): if missing_channels: warn( "The following channels are present in the FieldTrip data " - "but cannot be found in the provided info: %s.\n" + f"but cannot be found in the provided info: {str(missing_channels)}.\n" "These channels will be removed from the resulting data!" - % (str(missing_channels),) ) missing_chan_idx = [ch_names.index(ch) for ch in missing_channels] @@ -174,8 +173,8 @@ def _create_info_chs_dig(ft_struct): cur_ch["coil_type"] = FIFF.FIFFV_COIL_NONE else: warn( - "Cannot guess the correct type of channel %s. Making " - "it a MISC channel." % (cur_channel_label,) + f"Cannot guess the correct type of channel {cur_channel_label}. " + "Making it a MISC channel." ) cur_ch["kind"] = FIFF.FIFFV_MISC_CH cur_ch["coil_type"] = FIFF.FIFFV_COIL_NONE @@ -363,7 +362,7 @@ def _process_channel_meg(cur_ch, grad): cur_ch["coil_type"] = FIFF.FIFFV_COIL_AXIAL_GRAD_5CM cur_ch["unit"] = FIFF.FIFF_UNIT_T else: - raise RuntimeError("Unexpected coil type: %s." % (chantype,)) + raise RuntimeError(f"Unexpected coil type: {chantype}.") cur_ch["coord_frame"] = FIFF.FIFFV_COORD_HEAD diff --git a/mne/io/fiff/raw.py b/mne/io/fiff/raw.py index d81fd99c556..54bfe9e1921 100644 --- a/mne/io/fiff/raw.py +++ b/mne/io/fiff/raw.py @@ -16,7 +16,7 @@ from ..._fiff.constants import FIFF from ..._fiff.meas_info import read_meas_info from ..._fiff.open import _fiff_get_fid, _get_next_fname, fiff_open -from ..._fiff.tag import read_tag, read_tag_info +from ..._fiff.tag import _call_dict, read_tag from ..._fiff.tree import dir_tree_find from ..._fiff.utils import _mult_cal_one from ...annotations import Annotations, _read_annotations_fif @@ -97,7 +97,7 @@ def __init__( preload=False, on_split_missing="raise", verbose=None, - ): # noqa: D102 + ): raws = [] do_check_ext = not _file_like(fname) next_fname = fname @@ -124,7 +124,7 @@ def __init__( fname = None # noqa _check_raw_compatibility(raws) - super(Raw, self).__init__( + super().__init__( copy.deepcopy(raws[0].info), False, [r.first_samp for r in raws], @@ -255,48 +255,40 @@ def _read_raw_file( nskip = 0 orig_format = None + _byte_dict = { + FIFF.FIFFT_DAU_PACK16: 2, + FIFF.FIFFT_SHORT: 2, + FIFF.FIFFT_FLOAT: 4, + FIFF.FIFFT_DOUBLE: 8, + FIFF.FIFFT_INT: 4, + FIFF.FIFFT_COMPLEX_FLOAT: 8, + FIFF.FIFFT_COMPLEX_DOUBLE: 16, + } + _orig_format_dict = { + FIFF.FIFFT_DAU_PACK16: "short", + FIFF.FIFFT_SHORT: "short", + FIFF.FIFFT_FLOAT: "single", + FIFF.FIFFT_DOUBLE: "double", + FIFF.FIFFT_INT: "int", + FIFF.FIFFT_COMPLEX_FLOAT: "single", + FIFF.FIFFT_COMPLEX_DOUBLE: "double", + } + for k in range(first, nent): ent = directory[k] # There can be skips in the data (e.g., if the user unclicked) # an re-clicked the button - if ent.kind == FIFF.FIFF_DATA_SKIP: - tag = read_tag(fid, ent.pos) - nskip = int(tag.data.item()) - elif ent.kind == FIFF.FIFF_DATA_BUFFER: + if ent.kind == FIFF.FIFF_DATA_BUFFER: # Figure out the number of samples in this buffer - if ent.type == 
FIFF.FIFFT_DAU_PACK16: - nsamp = ent.size // (2 * nchan) - elif ent.type == FIFF.FIFFT_SHORT: - nsamp = ent.size // (2 * nchan) - elif ent.type == FIFF.FIFFT_FLOAT: - nsamp = ent.size // (4 * nchan) - elif ent.type == FIFF.FIFFT_DOUBLE: - nsamp = ent.size // (8 * nchan) - elif ent.type == FIFF.FIFFT_INT: - nsamp = ent.size // (4 * nchan) - elif ent.type == FIFF.FIFFT_COMPLEX_FLOAT: - nsamp = ent.size // (8 * nchan) - elif ent.type == FIFF.FIFFT_COMPLEX_DOUBLE: - nsamp = ent.size // (16 * nchan) - else: - raise ValueError( - "Cannot handle data buffers of type " "%d" % ent.type - ) + try: + div = _byte_dict[ent.type] + except KeyError: + raise RuntimeError( + f"Cannot handle data buffers of type {ent.type}" + ) from None + nsamp = ent.size // (div * nchan) if orig_format is None: - if ent.type == FIFF.FIFFT_DAU_PACK16: - orig_format = "short" - elif ent.type == FIFF.FIFFT_SHORT: - orig_format = "short" - elif ent.type == FIFF.FIFFT_FLOAT: - orig_format = "single" - elif ent.type == FIFF.FIFFT_DOUBLE: - orig_format = "double" - elif ent.type == FIFF.FIFFT_INT: - orig_format = "int" - elif ent.type == FIFF.FIFFT_COMPLEX_FLOAT: - orig_format = "single" - elif ent.type == FIFF.FIFFT_COMPLEX_DOUBLE: - orig_format = "double" + orig_format = _orig_format_dict[ent.type] # Do we have an initial skip pending? if first_skip > 0: @@ -327,6 +319,9 @@ def _read_raw_file( ) ) first_samp += nsamp + elif ent.kind == FIFF.FIFF_DATA_SKIP: + tag = read_tag(fid, ent.pos) + nskip = int(tag.data.item()) next_fname = _get_next_fname(fid, fname_rep, tree) @@ -381,22 +376,17 @@ def _dtype(self): if self._dtype_ is not None: return self._dtype_ dtype = None - for raw_extra, filename in zip(self._raw_extras, self._filenames): + for raw_extra in self._raw_extras: for ent in raw_extra["ent"]: if ent is not None: - with _fiff_get_fid(filename) as fid: - fid.seek(ent.pos, 0) - tag = read_tag_info(fid) - if tag is not None: - if tag.type in ( - FIFF.FIFFT_COMPLEX_FLOAT, - FIFF.FIFFT_COMPLEX_DOUBLE, - ): - dtype = np.complex128 - else: - dtype = np.float64 - if dtype is not None: - break + if ent.type in ( + FIFF.FIFFT_COMPLEX_FLOAT, + FIFF.FIFFT_COMPLEX_DOUBLE, + ): + dtype = np.complex128 + else: + dtype = np.float64 + break if dtype is not None: break if dtype is None: @@ -421,27 +411,31 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): first_pick = max(start - first, 0) last_pick = min(nsamp, stop - first) picksamp = last_pick - first_pick - # only read data if it exists - if ent is not None: - one = read_tag( - fid, - ent.pos, - shape=(nsamp, nchan), - rlims=(first_pick, last_pick), - ).data - try: - one.shape = (picksamp, nchan) - except AttributeError: # one is None - n_bad += picksamp - else: - _mult_cal_one( - data[:, offset : (offset + picksamp)], - one.T, - idx, - cals, - mult, - ) + this_start = offset offset += picksamp + this_stop = offset + # only read data if it exists + if ent is None: + continue # just use zeros for gaps + # faster to always read full tag, taking advantage of knowing the header + # already (cutting out some of read_tag) ... + fid.seek(ent.pos + 16, 0) + one = _call_dict[ent.type](fid, ent, shape=None, rlims=None) + try: + one.shape = (nsamp, nchan) + except AttributeError: # one is None + n_bad += picksamp + else: + # ... 
then pick samples we want + if first_pick != 0 or last_pick != nsamp: + one = one[first_pick:last_pick] + _mult_cal_one( + data[:, this_start:this_stop], + one.T, + idx, + cals, + mult, + ) if n_bad: warn( f"FIF raw buffer could not be read, acquisition error " @@ -502,7 +496,7 @@ def _check_entry(first, nent): @fill_doc def read_raw_fif( fname, allow_maxshield=False, preload=False, on_split_missing="raise", verbose=None -): +) -> Raw: """Reader function for Raw FIF data. Parameters diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 985688a9c7e..dc3c732979d 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -30,8 +30,7 @@ pick_types, ) from mne._fiff.constants import FIFF -from mne._fiff.open import read_tag, read_tag_info -from mne._fiff.tag import _read_tag_header +from mne._fiff.tag import _read_tag_header, read_tag from mne.annotations import Annotations from mne.datasets import testing from mne.filter import filter_data @@ -42,6 +41,7 @@ _record_warnings, assert_and_remove_boundary_annot, assert_object_equal, + catch_logging, requires_mne, run_subprocess, ) @@ -52,7 +52,7 @@ ms_fname = testing_path / "SSS" / "test_move_anon_raw.fif" skip_fname = testing_path / "misc" / "intervalrecording_raw.fif" -base_dir = Path(__file__).parent.parent.parent / "tests" / "data" +base_dir = Path(__file__).parents[2] / "tests" / "data" test_fif_fname = base_dir / "test_raw.fif" test_fif_gz_fname = base_dir / "test_raw.fif.gz" ctf_fname = base_dir / "test_ctf_raw.fif" @@ -658,9 +658,9 @@ def test_split_files(tmp_path, mod, monkeypatch): m.setattr(base, "MAX_N_SPLITS", 2) with pytest.raises(RuntimeError, match="Exceeded maximum number of splits"): raw.save(fname, split_naming="bids", **kwargs) - fname_1, fname_2, fname_3 = [ + fname_1, fname_2, fname_3 = ( (tmp_path / f"test_split-{ii:02d}_{mod}.fif") for ii in range(1, 4) - ] + ) assert not fname.is_file() assert fname_1.is_file() assert fname_2.is_file() @@ -669,12 +669,40 @@ def test_split_files(tmp_path, mod, monkeypatch): m.setattr(base, "MAX_N_SPLITS", 2) with pytest.raises(RuntimeError, match="Exceeded maximum number of splits"): raw.save(fname, split_naming="neuromag", **kwargs) - fname_2, fname_3 = [(tmp_path / f"test_{mod}-{ii}.fif") for ii in range(1, 3)] + fname_2, fname_3 = ((tmp_path / f"test_{mod}-{ii}.fif") for ii in range(1, 3)) assert fname.is_file() assert fname_2.is_file() assert not fname_3.is_file() +def test_bids_split_files(tmp_path): + """Test that BIDS split files are written safely.""" + mne_bids = pytest.importorskip("mne_bids") + bids_path = mne_bids.BIDSPath( + root=tmp_path, + subject="01", + datatype="meg", + split="01", + suffix="raw", + extension=".fif", + check=False, + ) + (tmp_path / "sub-01" / "meg").mkdir(parents=True) + raw = read_raw_fif(test_fif_fname) + save_kwargs = dict( + buffer_size_sec=1.0, split_size="10MB", split_naming="bids", verbose=True + ) + with pytest.raises(ValueError, match="Passing a BIDSPath"): + raw.save(bids_path, **save_kwargs) + bids_path.split = None + want_paths = [Path(bids_path.copy().update(split=ii).fpath) for ii in range(1, 3)] + for want_path in want_paths: + assert not want_path.is_file() + raw.save(bids_path, **save_kwargs) + for want_path in want_paths: + assert want_path.is_file() + + def _err(*args, **kwargs): raise RuntimeError("Killed mid-write") @@ -770,6 +798,10 @@ def test_io_raw(tmp_path): sl = slice(inds[0], inds[1]) assert_allclose(data[:, sl], raw[:, sl][0], rtol=1e-6, atol=1e-20) + # 
missing dir raises informative error + with pytest.raises(FileNotFoundError, match="parent directory does not exist"): + raw.save(tmp_path / "foo" / "test_raw.fif", split_size="1MB") + @pytest.mark.parametrize( "fname_in, fname_out", @@ -897,7 +929,7 @@ def test_io_complex(tmp_path, dtype): @testing.requires_testing_data def test_getitem(): """Test getitem/indexing of Raw.""" - for preload in [False, True, "memmap.dat"]: + for preload in [False, True, "memmap1.dat"]: raw = read_raw_fif(fif_fname, preload=preload) data, times = raw[0, :] data1, times1 = raw[0] @@ -1016,7 +1048,7 @@ def test_proj(tmp_path): @testing.requires_testing_data -@pytest.mark.parametrize("preload", [False, True, "memmap.dat"]) +@pytest.mark.parametrize("preload", [False, True, "memmap2.dat"]) def test_preload_modify(preload, tmp_path): """Test preloading and modifying data.""" rng = np.random.RandomState(0) @@ -1260,7 +1292,7 @@ def test_crop(): assert raw1[:][0].shape == (1, 2001) # degenerate - with pytest.raises(ValueError, match="No samples.*when include_tmax=Fals"): + with pytest.raises(ValueError, match="No samples.*when include_tmax=False"): raw.crop(0, 0, include_tmax=False) # edge cases cropping to exact duration +/- 1 sample @@ -1290,23 +1322,28 @@ def test_resample_equiv(): @pytest.mark.slowtest @testing.requires_testing_data @pytest.mark.parametrize( - "preload, n, npad", + "preload, n, npad, method", [ - (True, 512, "auto"), - (False, 512, 0), + (True, 512, "auto", "fft"), + (True, 512, "auto", "polyphase"), + (False, 512, 0, "fft"), # only test one with non-preload because it's slow ], ) -def test_resample(tmp_path, preload, n, npad): +def test_resample(tmp_path, preload, n, npad, method): """Test resample (with I/O and multiple files).""" + kwargs = dict(npad=npad, method=method) raw = read_raw_fif(fif_fname) raw.crop(0, raw.times[n - 1]) + # Reduce to a few MEG channels and a few stim channels to speed up + n_meg = 5 + raw.pick(raw.ch_names[:n_meg] + raw.ch_names[312:320]) # 10 MEG + 3 STIM + 5 EEG assert len(raw.times) == n if preload: raw.load_data() raw_resamp = raw.copy() sfreq = raw.info["sfreq"] # test parallel on upsample - raw_resamp.resample(sfreq * 2, n_jobs=2, npad=npad) + raw_resamp.resample(sfreq * 2, n_jobs=2, **kwargs) assert raw_resamp.n_times == len(raw_resamp.times) raw_resamp.save(tmp_path / "raw_resamp-raw.fif") raw_resamp = read_raw_fif(tmp_path / "raw_resamp-raw.fif", preload=True) @@ -1315,7 +1352,13 @@ def test_resample(tmp_path, preload, n, npad): assert raw_resamp.get_data().shape[1] == raw_resamp.n_times assert raw.get_data().shape[0] == raw_resamp._data.shape[0] # test non-parallel on downsample - raw_resamp.resample(sfreq, n_jobs=None, npad=npad) + with catch_logging() as log: + raw_resamp.resample(sfreq, n_jobs=None, verbose=True, **kwargs) + log = log.getvalue() + if method == "fft": + assert "neighborhood" not in log + else: + assert "neighborhood" in log assert raw_resamp.info["sfreq"] == sfreq assert raw.get_data().shape == raw_resamp._data.shape assert raw.first_samp == raw_resamp.first_samp @@ -1324,18 +1367,12 @@ def test_resample(tmp_path, preload, n, npad): # works (hooray). 
Note that the stim channels had to be sub-sampled # without filtering to be accurately preserved # note we have to treat MEG and EEG+STIM channels differently (tols) - assert_allclose( - raw.get_data()[:306, 200:-200], - raw_resamp._data[:306, 200:-200], - rtol=1e-2, - atol=1e-12, - ) - assert_allclose( - raw.get_data()[306:, 200:-200], - raw_resamp._data[306:, 200:-200], - rtol=1e-2, - atol=1e-7, - ) + want_meg = raw.get_data()[:n_meg, 200:-200] + got_meg = raw_resamp._data[:n_meg, 200:-200] + want_non_meg = raw.get_data()[n_meg:, 200:-200] + got_non_meg = raw_resamp._data[n_meg:, 200:-200] + assert_allclose(got_meg, want_meg, rtol=1e-2, atol=1e-12) + assert_allclose(want_non_meg, got_non_meg, rtol=1e-2, atol=1e-7) # now check multiple file support w/resampling, as order of operations # (concat, resample) should not affect our data @@ -1344,9 +1381,9 @@ def test_resample(tmp_path, preload, n, npad): raw3 = raw.copy() raw4 = raw.copy() raw1 = concatenate_raws([raw1, raw2]) - raw1.resample(10.0, npad=npad) - raw3.resample(10.0, npad=npad) - raw4.resample(10.0, npad=npad) + raw1.resample(10.0, **kwargs) + raw3.resample(10.0, **kwargs) + raw4.resample(10.0, **kwargs) raw3 = concatenate_raws([raw3, raw4]) assert_array_equal(raw1._data, raw3._data) assert_array_equal(raw1._first_samps, raw3._first_samps) @@ -1364,12 +1401,12 @@ def test_resample(tmp_path, preload, n, npad): # basic decimation stim = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - assert_allclose(raw.resample(8.0, npad=npad)._data, [[1, 1, 0, 0, 1, 1, 0, 0]]) + assert_allclose(raw.resample(8.0, **kwargs)._data, [[1, 1, 0, 0, 1, 1, 0, 0]]) # decimation of multiple stim channels raw = RawArray(2 * [stim], create_info(2, len(stim), 2 * ["stim"])) assert_allclose( - raw.resample(8.0, npad=npad, verbose="error")._data, + raw.resample(8.0, **kwargs, verbose="error")._data, [[1, 1, 0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1, 0, 0]], ) @@ -1377,19 +1414,19 @@ def test_resample(tmp_path, preload, n, npad): # done naively stim = [0, 0, 0, 1, 1, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - assert_allclose(raw.resample(4.0, npad=npad)._data, [[0, 1, 1, 0]]) + assert_allclose(raw.resample(4.0, **kwargs)._data, [[0, 1, 1, 0]]) # two events are merged in this case (warning) stim = [0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) with pytest.warns(RuntimeWarning, match="become unreliable"): - raw.resample(8.0, npad=npad) + raw.resample(8.0, **kwargs) # events are dropped in this case (warning) stim = [0, 1, 1, 0, 0, 1, 1, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) with pytest.warns(RuntimeWarning, match="become unreliable"): - raw.resample(4.0, npad=npad) + raw.resample(4.0, **kwargs) # test resampling events: this should no longer give a warning # we often have first_samp != 0, include it here too @@ -1400,7 +1437,7 @@ def test_resample(tmp_path, preload, n, npad): first_samp = len(stim) // 2 raw = RawArray([stim], create_info(1, o_sfreq, ["stim"]), first_samp=first_samp) events = find_events(raw) - raw, events = raw.resample(n_sfreq, events=events, npad=npad) + raw, events = raw.resample(n_sfreq, events=events, **kwargs) # Try index into raw.times with resampled events: raw.times[events[:, 0] - raw.first_samp] n_fsamp = int(first_samp * sfreq_ratio) # how it's calc'd in base.py @@ -1425,16 +1462,16 @@ def test_resample(tmp_path, preload, n, npad): # test copy flag stim = [1, 1, 1, 
1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - raw_resampled = raw.copy().resample(4.0, npad=npad) + raw_resampled = raw.copy().resample(4.0, **kwargs) assert raw_resampled is not raw - raw_resampled = raw.resample(4.0, npad=npad) + raw_resampled = raw.resample(4.0, **kwargs) assert raw_resampled is raw # resample should still work even when no stim channel is present raw = RawArray(np.random.randn(1, 100), create_info(1, 100, ["eeg"])) with raw.info._unlock(): raw.info["lowpass"] = 50.0 - raw.resample(10, npad=npad) + raw.resample(10, **kwargs) assert raw.info["lowpass"] == 5.0 assert len(raw) == 10 @@ -1917,7 +1954,7 @@ def test_equalize_channels(): def test_memmap(tmp_path): """Test some interesting memmapping cases.""" # concatenate_raw - memmaps = [str(tmp_path / str(ii)) for ii in range(3)] + memmaps = [str(tmp_path / str(ii)) for ii in range(4)] raw_0 = read_raw_fif(test_fif_fname, preload=memmaps[0]) assert raw_0._data.filename == memmaps[0] raw_1 = read_raw_fif(test_fif_fname, preload=memmaps[1]) @@ -1942,8 +1979,8 @@ def test_memmap(tmp_path): # now let's see if .copy() actually works; it does, but eventually # we should make it optionally memmap to a new filename rather than # create an in-memory version (filename=None) - raw_0 = read_raw_fif(test_fif_fname, preload=memmaps[0]) - assert raw_0._data.filename == memmaps[0] + raw_0 = read_raw_fif(test_fif_fname, preload=memmaps[3]) + assert raw_0._data.filename == memmaps[3] assert raw_0._data[:1, 3:5].all() raw_1 = raw_0.copy() assert isinstance(raw_1._data, np.memmap) @@ -2034,8 +2071,7 @@ def test_bad_acq(fname): raw = read_raw_fif(fname, allow_maxshield="yes").load_data() with open(fname, "rb") as fid: for ent in raw._raw_extras[0]["ent"]: - fid.seek(ent.pos, 0) - tag = _read_tag_header(fid) + tag = _read_tag_header(fid, ent.pos) # hack these, others (kind, type) should be correct tag.pos, tag.next = ent.pos, ent.next assert tag == ent @@ -2075,16 +2111,19 @@ def test_corrupted(tmp_path, offset): # at the end, so use the skip one (straight from acq). raw = read_raw_fif(skip_fname) with open(skip_fname, "rb") as fid: - tag = read_tag_info(fid) - tag = read_tag(fid) - dirpos = int(tag.data.item()) + file_id_tag = read_tag(fid, 0) + dir_pos_tag = read_tag(fid, file_id_tag.next_pos) + dirpos = int(dir_pos_tag.data.item()) assert dirpos == 12641532 fid.seek(0) data = fid.read(dirpos + offset) bad_fname = tmp_path / "test_raw.fif" with open(bad_fname, "wb") as fid: fid.write(data) - with pytest.warns(RuntimeWarning, match=".*tag directory.*corrupt.*"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match=".*tag directory.*corrupt.*"), + ): raw_bad = read_raw_fif(bad_fname) assert_allclose(raw.get_data(), raw_bad.get_data()) diff --git a/mne/io/fil/fil.py b/mne/io/fil/fil.py index 08b7778398a..eba8662f342 100644 --- a/mne/io/fil/fil.py +++ b/mne/io/fil/fil.py @@ -25,7 +25,9 @@ @verbose -def read_raw_fil(binfile, precision="single", preload=False, *, verbose=None): +def read_raw_fil( + binfile, precision="single", preload=False, *, verbose=None +) -> "RawFIL": """Raw object from FIL-OPMEG formatted data. 
Parameters @@ -115,11 +117,11 @@ def __init__(self, binfile, precision="single", preload=False): else: warn("No sensor position information found.") - with open(files["meg"], "r") as fid: + with open(files["meg"]) as fid: meg = json.load(fid) info = _compose_meas_info(meg, chans) - super(RawFIL, self).__init__( + super().__init__( info, preload, filenames=[files["bin"]], @@ -129,7 +131,7 @@ def __init__(self, binfile, precision="single", preload=False): ) if files["coordsystem"].is_file(): - with open(files["coordsystem"], "r") as fid: + with open(files["coordsystem"]) as fid: csys = json.load(fid) hc = csys["HeadCoilCoordinates"] @@ -311,8 +313,8 @@ def _from_tsv(fname, dtypes=None): dtypes = [dtypes] * info.shape[1] if not len(dtypes) == info.shape[1]: raise ValueError( - "dtypes length mismatch. Provided: {0}, " - "Expected: {1}".format(len(dtypes), info.shape[1]) + f"dtypes length mismatch. Provided: {len(dtypes)}, " + f"Expected: {info.shape[1]}" ) for i, name in enumerate(column_names): data_dict[name] = info[:, i].astype(dtypes[i]).tolist() diff --git a/mne/io/hitachi/hitachi.py b/mne/io/hitachi/hitachi.py index 0f046bb37e6..4b5c0b9fac6 100644 --- a/mne/io/hitachi/hitachi.py +++ b/mne/io/hitachi/hitachi.py @@ -17,7 +17,7 @@ @fill_doc -def read_raw_hitachi(fname, preload=False, verbose=None): +def read_raw_hitachi(fname, preload=False, verbose=None) -> "RawHitachi": """Reader for a Hitachi fNIRS recording. Parameters @@ -268,7 +268,7 @@ def _get_hitachi_info(fname, S_offset, D_offset, ignore_names): "3x11": "ETG-4000", } _check_option("Hitachi mode", mode, sorted(names)) - n_row, n_col = [int(x) for x in mode.split("x")] + n_row, n_col = (int(x) for x in mode.split("x")) logger.info(f"Constructing pairing matrix for {names[mode]} ({mode})") pairs = _compute_pairs(n_row, n_col, n=1 + (mode == "3x3")) assert n_nirs == len(pairs) * 2 diff --git a/mne/io/hitachi/tests/test_hitachi.py b/mne/io/hitachi/tests/test_hitachi.py index edad56dc75e..300af7cf5e8 100644 --- a/mne/io/hitachi/tests/test_hitachi.py +++ b/mne/io/hitachi/tests/test_hitachi.py @@ -22,9 +22,7 @@ ) CONTENTS = dict() -CONTENTS[ - "1.18" -] = b"""\ +CONTENTS["1.18"] = b"""\ Header,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, File Version,1.18,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, Patient Information,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, @@ -129,9 +127,7 @@ """ # noqa: E501 -CONTENTS[ - "1.25" -] = b"""\ +CONTENTS["1.25"] = b"""\ Header File Version,1.25 Patient Information diff --git a/mne/io/kit/coreg.py b/mne/io/kit/coreg.py index 4e5bd0bdf8f..3e691249790 100644 --- a/mne/io/kit/coreg.py +++ b/mne/io/kit/coreg.py @@ -72,7 +72,7 @@ def read_mrk(fname): elif fname.suffix == ".pickled": warn( "Reading pickled files is unsafe and not future compatible, save " - "to a standard format (text or FIF) instea, e.g. with:\n" + "to a standard format (text or FIF) instead, e.g. with:\n" r"np.savetxt(fid, pts, delimiter=\"\\t\", newline=\"\\n\")", FutureWarning, ) @@ -86,7 +86,7 @@ def read_mrk(fname): # check output mrk_points = np.asarray(mrk_points) if mrk_points.shape != (5, 3): - err = "%r is no marker file, shape is " "%s" % (fname, mrk_points.shape) + err = f"{repr(fname)} is no marker file, shape is {mrk_points.shape}" raise ValueError(err) return mrk_points @@ -114,7 +114,7 @@ def read_sns(fname): return locs -def _set_dig_kit(mrk, elp, hsp, eeg): +def _set_dig_kit(mrk, elp, hsp, eeg, *, bad_coils=()): """Add landmark points and head shape data to the KIT instance. 
Digitizer data (elp and hsp) are represented in [mm] in the Polhemus @@ -133,6 +133,9 @@ def _set_dig_kit(mrk, elp, hsp, eeg): Digitizer head shape points, or path to head shape file. If more than 10`000 points are in the head shape, they are automatically decimated. + bad_coils : list + Indices of bad marker coils (up to two). Bad coils will be excluded + when computing the device-head transformation. eeg : dict Ordered dict of EEG dig points. @@ -154,28 +157,31 @@ def _set_dig_kit(mrk, elp, hsp, eeg): hsp = _decimate_points(hsp, res=0.005) n_new = len(hsp) warn( - "The selected head shape contained {n_in} points, which is " - "more than recommended ({n_rec}), and was automatically " - "downsampled to {n_new} points. The preferred way to " - "downsample is using FastScan.".format( - n_in=n_pts, n_rec=KIT.DIG_POINTS, n_new=n_new - ) + f"The selected head shape contained {n_pts} points, which is more than " + f"recommended ({KIT.DIG_POINTS}), and was automatically downsampled to " + f"{n_new} points. The preferred way to downsample is using FastScan." ) if isinstance(elp, (str, Path, PathLike)): elp_points = _read_dig_kit(elp) if len(elp_points) != 8: raise ValueError( - "File %r should contain 8 points; got shape " - "%s." % (elp, elp_points.shape) + f"File {repr(elp)} should contain 8 points; got shape " + f"{elp_points.shape}." ) elp = elp_points - elif len(elp) not in (6, 7, 8): - raise ValueError( - "ELP should contain 6 ~ 8 points; got shape " "%s." % (elp.shape,) - ) + if len(bad_coils) > 0: + elp = np.delete(elp, np.array(bad_coils) + 3, 0) + # check we have at least 3 marker coils (whether read from file or + # passed in directly) + if len(elp) not in (6, 7, 8): + raise ValueError(f"ELP should contain 6 ~ 8 points; got shape {elp.shape}.") if isinstance(mrk, (str, Path, PathLike)): mrk = read_mrk(mrk) + if len(bad_coils) > 0: + mrk = np.delete(mrk, bad_coils, 0) + if len(mrk) not in (3, 4, 5): + raise ValueError(f"MRK should contain 3 ~ 5 points; got shape {mrk.shape}.") mrk = apply_trans(als_ras_trans, mrk) diff --git a/mne/io/kit/kit.py b/mne/io/kit/kit.py index fa9ff8cfeea..71cc38e6c94 100644 --- a/mne/io/kit/kit.py +++ b/mne/io/kit/kit.py @@ -43,7 +43,7 @@ INT32 = " RawKIT: r"""Reader function for Ricoh/KIT conversion to FIF. Parameters @@ -932,6 +945,7 @@ def read_raw_kit( Force reading old data that is not officially supported. Alternatively, read and re-save the data with the KIT MEG Laboratory application. %(standardize_names)s + %(kit_badcoils)s %(verbose)s Returns @@ -966,6 +980,7 @@ def read_raw_kit( stim_code=stim_code, allow_unknown_format=allow_unknown_format, standardize_names=standardize_names, + bad_coils=bad_coils, verbose=verbose, ) @@ -981,7 +996,7 @@ def read_epochs_kit( allow_unknown_format=False, standardize_names=False, verbose=None, -): +) -> EpochsKIT: """Reader function for Ricoh/KIT epochs files. Parameters diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py deleted file mode 100644 index f971fff18d4..00000000000 --- a/mne/io/meas_info.py +++ /dev/null @@ -1,11 +0,0 @@ -# Author: Eric Larson -# -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. 
- - -from .._fiff import _io_dep_getattr - - -def __getattr__(name): - return _io_dep_getattr(name, "meas_info") diff --git a/mne/io/nedf/nedf.py b/mne/io/nedf/nedf.py index 55661511f83..df6030f31c1 100644 --- a/mne/io/nedf/nedf.py +++ b/mne/io/nedf/nedf.py @@ -6,11 +6,10 @@ from datetime import datetime, timezone import numpy as np -from defusedxml import ElementTree from ..._fiff.meas_info import create_info from ..._fiff.utils import _mult_cal_one -from ...utils import _check_fname, verbose, warn +from ...utils import _check_fname, _soft_import, verbose, warn from ..base import BaseRaw @@ -52,6 +51,7 @@ def _parse_nedf_header(header): n_samples : int The number of data samples. """ + defusedxml = _soft_import("defusedxml", "reading NEDF data") info = {} # nedf files have three accelerometer channels sampled at 100Hz followed # by five EEG samples + TTL trigger sampled at 500Hz @@ -69,7 +69,7 @@ def _parse_nedf_header(header): headerend = header.find(b"\0") if headerend == -1: raise RuntimeError("End of header null not found") - headerxml = ElementTree.fromstring(header[:headerend]) + headerxml = defusedxml.ElementTree.fromstring(header[:headerend]) nedfversion = headerxml.findtext("NEDFversion", "") if nedfversion not in ["1.3", "1.4"]: warn("NEDFversion unsupported, use with caution") @@ -202,7 +202,7 @@ def _convert_eeg(chunks, n_eeg, n_tot): @verbose -def read_raw_nedf(filename, preload=False, verbose=None): +def read_raw_nedf(filename, preload=False, verbose=None) -> RawNedf: """Read NeuroElectrics .nedf files. NEDF file versions starting from 1.3 are supported. diff --git a/mne/io/nedf/tests/test_nedf.py b/mne/io/nedf/tests/test_nedf.py index cf0043bbeeb..f06e1376e59 100644 --- a/mne/io/nedf/tests/test_nedf.py +++ b/mne/io/nedf/tests/test_nedf.py @@ -29,6 +29,8 @@ \x00""" +pytest.importorskip("defusedxml") + @pytest.mark.parametrize("nacc", (0, 3)) def test_nedf_header_parser(nacc): diff --git a/mne/io/neuralynx/neuralynx.py b/mne/io/neuralynx/neuralynx.py index 4bfad0fea2c..0390fb70071 100644 --- a/mne/io/neuralynx/neuralynx.py +++ b/mne/io/neuralynx/neuralynx.py @@ -1,5 +1,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import datetime import glob import os @@ -7,6 +8,7 @@ from ..._fiff.meas_info import create_info from ..._fiff.utils import _mult_cal_one +from ...annotations import Annotations from ...utils import _check_fname, _soft_import, fill_doc, logger, verbose from ..base import BaseRaw @@ -14,7 +16,7 @@ @fill_doc def read_raw_neuralynx( fname, *, preload=False, exclude_fname_patterns=None, verbose=None -): +) -> "RawNeuralynx": """Reader for Neuralynx files. Parameters @@ -37,8 +39,39 @@ def read_raw_neuralynx( See Also -------- mne.io.Raw : Documentation of attributes and methods of RawNeuralynx. + + Notes + ----- + Neuralynx files are read from disk using the `Neo package + `__. + Currently, only reading of the ``.ncs files`` is supported. + + ``raw.info["meas_date"]`` is read from the ``recording_opened`` property + of the first ``.ncs`` file (i.e. channel) in the dataset (a warning is issued + if files have different dates of acquisition). + + Channel-specific high and lowpass frequencies of online filters are determined + based on the ``DspLowCutFrequency`` and ``DspHighCutFrequency`` header fields, + respectively. If no filters were used for a channel, the default lowpass is set + to the Nyquist frequency and the default highpass is set to 0. 
+    If channels have different high/low cutoffs, ``raw.info["highpass"]`` and
+    ``raw.info["lowpass"]`` are then set to the maximum highpass and minimum
+    lowpass values across channels, respectively.
+
+    Other header variables can be inspected using Neo directly. For example::
+
+        from neo.io import NeuralynxIO  # doctest: +SKIP
+        fname = 'path/to/your/data'  # doctest: +SKIP
+        nlx_reader = NeuralynxIO(dirname=fname)  # doctest: +SKIP
+        print(nlx_reader.header)  # doctest: +SKIP
+        print(nlx_reader.file_headers.items())  # doctest: +SKIP
     """
-    return RawNeuralynx(fname, preload, verbose, exclude_fname_patterns)
+    return RawNeuralynx(
+        fname,
+        preload=preload,
+        exclude_fname_patterns=exclude_fname_patterns,
+        verbose=verbose,
+    )
 
 
 @fill_doc
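With this refactor ``preload``, ``exclude_fname_patterns``, and ``verbose``
become keyword-only, so positional calls such as
``read_raw_neuralynx(fname, True)`` stop working. A minimal usage sketch under
the new signature (the directory path and exclude pattern below are
placeholders, not shipped data)::

    import mne

    # keyword-only arguments after ``fname``; the directory holds the .ncs files
    raw = mne.io.read_raw_neuralynx(
        "path/to/ncs_dir",
        preload=True,
        exclude_fname_patterns=["*u*.ncs"],  # skip unwanted channels by pattern
    )
    # online filter settings are now read from the .ncs headers
    print(raw.info["highpass"], raw.info["lowpass"])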
@@ -47,13 +80,18 @@ class RawNeuralynx(BaseRaw):
 
     @verbose
     def __init__(
-        self, fname, preload=False, verbose=None, exclude_fname_patterns: list = None
+        self,
+        fname,
+        *,
+        preload=False,
+        exclude_fname_patterns=None,
+        verbose=None,
     ):
+        fname = _check_fname(fname, "read", True, "fname", need_dir=True)
+
+        _soft_import("neo", "Reading NeuralynxIO files", strict=True)
         from neo.io import NeuralynxIO
 
-        fname = _check_fname(fname, "read", True, "fname", need_dir=True)
-
         logger.info(f"Checking files in {fname}")
 
         # construct a list of filenames to ignore
@@ -71,12 +109,18 @@ def __init__(
         try:
             nlx_reader = NeuralynxIO(dirname=fname, exclude_filename=exclude_fnames)
         except ValueError as e:
-            raise ValueError(
-                "It seems some .ncs channels might have different number of samples. "
-                + "This is likely due to different sampling rates. "
-                + "Try excluding them with `exclude_fname_patterns` input arg."
-                + f"\nOriginal neo.NeuralynxIO.parse_header() ValueError:\n{e}"
-            )
+            # give a more informative error message and tell the user what to do
+            if "Incompatible section structures across streams" in str(e):
+                raise ValueError(
+                    "It seems .ncs channels have different numbers of samples. "
+                    + "This is likely due to different sampling rates. "
+                    + "Try reading in only channels with uniform sampling rate "
+                    + "by excluding other channels with `exclude_fname_patterns` "
+                    + "input argument."
+                    + f"\nOriginal neo.NeuralynxRawIO ValueError:\n{e}"
+                ) from None
+            else:
+                raise
 
         info = create_info(
             ch_types="seeg",
@@ -84,35 +128,177 @@ def __init__(
             sfreq=nlx_reader.get_signal_sampling_rate(),
         )
 
-        # find total number of samples per .ncs file (`channel`) by summing
-        # the sample sizes of all segments
+        ncs_fnames = nlx_reader.ncs_filenames.values()
+        ncs_hdrs = [
+            hdr
+            for hdr_key, hdr in nlx_reader.file_headers.items()
+            if hdr_key in ncs_fnames
+        ]
+
+        # if all files have the same recording_opened date, write it to info
+        meas_dates = np.array([hdr["recording_opened"] for hdr in ncs_hdrs])
+        # to be sure, only write if all dates are the same
+        meas_diff = []
+        for md in meas_dates:
+            meas_diff.append((md - meas_dates[0]).total_seconds())
+
+        # tolerate a +/-1 second meas_date difference (arbitrary threshold)
+        # else issue a warning
+        warn_meas = (np.abs(meas_diff) > 1.0).any()
+        if warn_meas:
+            logger.warning(
+                "Not all .ncs files have the same recording_opened date. "
+                + "Writing meas_date based on the first .ncs file."
+            )
+
+        # Neuralynx allows channel-specific low/highpass filters
+        # if not enabled, assume default lowpass = Nyquist, highpass = 0
+        default_lowpass = info["sfreq"] / 2  # Nyquist
+        default_highpass = 0
+
+        has_hp = [hdr["DSPLowCutFilterEnabled"] for hdr in ncs_hdrs]
+        has_lp = [hdr["DSPHighCutFilterEnabled"] for hdr in ncs_hdrs]
+        if not all(has_hp) or not all(has_lp):
+            logger.warning(
+                "Not all .ncs files have the same high/lowpass filter settings. "
+                + "Assuming default highpass = 0, lowpass = nyquist."
+            )
+
+        highpass_freqs = [
+            float(hdr["DspLowCutFrequency"])
+            if hdr["DSPLowCutFilterEnabled"]
+            else default_highpass
+            for hdr in ncs_hdrs
+        ]
+
+        lowpass_freqs = [
+            float(hdr["DspHighCutFrequency"])
+            if hdr["DSPHighCutFilterEnabled"]
+            else default_lowpass
+            for hdr in ncs_hdrs
+        ]
+
+        with info._unlock():
+            info["meas_date"] = meas_dates[0].astimezone(datetime.timezone.utc)
+            info["highpass"] = np.max(highpass_freqs)
+            info["lowpass"] = np.min(lowpass_freqs)
+
+        # Neo reads only valid contiguous .ncs samples grouped as segments
         n_segments = nlx_reader.header["nb_segment"][0]
         block_id = 0  # assumes there's only one block of recording
-        n_total_samples = sum(
-            nlx_reader.get_signal_size(block_id, segment)
-            for segment in range(n_segments)
+
+        # get segment start/stop times
+        start_times = np.array(
+            [nlx_reader.segment_t_start(block_id, i) for i in range(n_segments)]
+        )
+        stop_times = np.array(
+            [nlx_reader.segment_t_stop(block_id, i) for i in range(n_segments)]
         )
-        # construct an array of shape (n_total_samples,) indicating
-        # segment membership for each sample
-        sample2segment = np.concatenate(
+        # find discontinuous boundaries (of length n-1)
+        next_start_times = start_times[1::]
+        previous_stop_times = stop_times[:-1]
+        seg_diffs = next_start_times - previous_stop_times
+
+        # mark as discontinuous any two segments that have
+        # start/stop delta larger than sampling period (1.5/sampling_rate)
+        logger.info("Checking for temporal discontinuities in Neo data segments.")
+        delta = 1.5 / info["sfreq"]
+        gaps = seg_diffs > delta
+
+        seg_gap_dict = {}
+
+        logger.info(
+            f"N = {gaps.sum()} discontinuous Neo segments detected "
+            + f"with delta > {delta} sec. "
+            + "Annotating gaps as BAD_ACQ_SKIP."
+            if gaps.any()
+            else "No discontinuities detected."
+ ) + + gap_starts = stop_times[:-1][gaps] # gap starts at segment offset + gap_stops = start_times[1::][gaps] # gap stops at segment onset + + # (n_gaps,) array of ints giving number of samples per inferred gap + gap_n_samps = np.array( [ - np.full(shape=(nlx_reader.get_signal_size(block_id, i),), fill_value=i) - for i in range(n_segments) + int(round(stop * info["sfreq"])) - int(round(start * info["sfreq"])) + for start, stop in zip(gap_starts, gap_stops) + ] + ).astype(int) # force an int array (if no gaps, empty array is a float) + + # get sort indices for all segments (valid and gap) in ascending order + all_starts_ids = np.argsort(np.concatenate([start_times, gap_starts])) + + # variable indicating whether each segment is a gap or not + gap_indicator = np.concatenate( + [ + np.full(len(start_times), fill_value=0), + np.full(len(gap_starts), fill_value=1), ] ) + gap_indicator = gap_indicator[all_starts_ids].astype(bool) + + # store this in a dict to be passed to _raw_extras + seg_gap_dict = { + "gap_n_samps": gap_n_samps, + "isgap": gap_indicator, # False (data segment) or True (gap segment) + } + + valid_segment_sizes = [ + nlx_reader.get_signal_size(block_id, i) for i in range(n_segments) + ] + + sizes_sorted = np.concatenate([valid_segment_sizes, gap_n_samps])[ + all_starts_ids + ] - super(RawNeuralynx, self).__init__( + # now construct an (n_samples,) indicator variable + sample2segment = np.concatenate( + [np.full(shape=(n,), fill_value=i) for i, n in enumerate(sizes_sorted)] + ) + + # get the start sample index for each gap segment () + gap_start_ids = np.cumsum(np.hstack([[0], sizes_sorted[:-1]]))[gap_indicator] + + # recreate time axis for gap annotations + mne_times = np.arange(0, len(sample2segment)) / info["sfreq"] + + assert len(gap_start_ids) == len(gap_n_samps) + annotations = Annotations( + onset=[mne_times[onset_id] for onset_id in gap_start_ids], + duration=[ + mne_times[onset_id + (n - 1)] - mne_times[onset_id] + for onset_id, n in zip(gap_start_ids, gap_n_samps) + ], + description=["BAD_ACQ_SKIP"] * len(gap_start_ids), + ) + + super().__init__( info=info, - last_samps=[n_total_samples - 1], + last_samps=[sizes_sorted.sum() - 1], filenames=[fname], preload=preload, - raw_extras=[dict(smp2seg=sample2segment, exclude_fnames=exclude_fnames)], + raw_extras=[ + dict( + smp2seg=sample2segment, + exclude_fnames=exclude_fnames, + segment_sizes=sizes_sorted, + seg_gap_dict=seg_gap_dict, + ) + ], ) + self.set_annotations(annotations) + def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" + from neo import AnalogSignal, Segment from neo.io import NeuralynxIO + from neo.io.proxyobjects import AnalogSignalProxy + + # quantities is a dependency of neo so we are guaranteed it exists + from quantities import Hz nlx_reader = NeuralynxIO( dirname=self._filenames[fi], @@ -126,13 +312,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): [len(segment.analogsignals) for segment in neo_block[0].segments] ) == len(neo_block[0].segments) - # collect sizes of each segment - segment_sizes = np.array( - [ - nlx_reader.get_signal_size(0, segment_id) - for segment_id in range(len(neo_block[0].segments)) - ] - ) + segment_sizes = self._raw_extras[fi]["segment_sizes"] # construct a (n_segments, 2) array of the first and last # sample index for each segment relative to the start of the recording @@ -167,9 +347,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): sel_samples_local[0:-1, 1] = ( 
sel_samples_global[0:-1, 1] - sel_samples_global[0:-1, 0] ) - sel_samples_local[ - 1::, 0 - ] = 0 # now set the start sample for all segments after the first to 0 + sel_samples_local[1::, 0] = ( + 0 # now set the start sample for all segments after the first to 0 + ) sel_samples_local[0, 0] = ( start - sel_samples_global[0, 0] @@ -178,15 +358,47 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): -1, 0 ] # express stop sample relative to segment onset - # now load data from selected segments/channels via - # neo.Segment.AnalogSignal.load() + # array containing Segments + segments_arr = np.array(neo_block[0].segments, dtype=object) + + # if gaps were detected, correctly insert gap Segments in between valid Segments + gap_samples = self._raw_extras[fi]["seg_gap_dict"]["gap_n_samps"] + gap_segments = [Segment(f"gap-{i}") for i in range(len(gap_samples))] + + # create AnalogSignal objects representing gap data filled with 0's + sfreq = nlx_reader.get_signal_sampling_rate() + n_chans = ( + np.arange(idx.start, idx.stop, idx.step).size + if type(idx) is slice + else len(idx) # idx can be a slice or an np.array so check both + ) + + for seg, n in zip(gap_segments, gap_samples): + asig = AnalogSignal( + signal=np.zeros((n, n_chans)), units="uV", sampling_rate=sfreq * Hz + ) + seg.analogsignals.append(asig) + + n_total_segments = len(neo_block[0].segments + gap_segments) + segments_arr = np.zeros((n_total_segments,), dtype=object) + + # insert inferred gap segments at the right place in between valid segments + isgap = self._raw_extras[0]["seg_gap_dict"]["isgap"] + segments_arr[~isgap] = neo_block[0].segments + segments_arr[isgap] = gap_segments + + # now load data for selected segments/channels via + # neo.Segment.AnalogSignalProxy.load() or + # pad directly as AnalogSignal.magnitude for any gap data all_data = np.concatenate( [ signal.load(channel_indexes=idx).magnitude[ samples[0] : samples[-1] + 1, : ] + if isinstance(signal, AnalogSignalProxy) + else signal.magnitude[samples[0] : samples[-1] + 1, :] for seg, samples in zip( - neo_block[0].segments[first_seg : last_seg + 1], sel_samples_local + segments_arr[first_seg : last_seg + 1], sel_samples_local ) for signal in seg.analogsignals ] diff --git a/mne/io/neuralynx/tests/test_neuralynx.py b/mne/io/neuralynx/tests/test_neuralynx.py index 21cb73927a8..ceebdd3c975 100644 --- a/mne/io/neuralynx/tests/test_neuralynx.py +++ b/mne/io/neuralynx/tests/test_neuralynx.py @@ -2,7 +2,7 @@ # Copyright the MNE-Python contributors. import os from ast import literal_eval -from typing import Dict +from datetime import datetime, timezone import numpy as np import pytest @@ -15,8 +15,10 @@ testing_path = data_path(download=False) / "neuralynx" +pytest.importorskip("neo") -def _nlxheader_to_dict(matdict: Dict) -> Dict: + +def _nlxheader_to_dict(matdict: dict) -> dict: """Convert the read-in "Header" field into a dict. 
All the key-value pairs of Header entries are formatted as strings @@ -65,32 +67,78 @@ def _read_nlx_mat_chan(matfile: str) -> np.ndarray: return x -mne_testing_ncs = [ - "LAHC1.ncs", - "LAHC2.ncs", - "LAHC3.ncs", - "LAHCu1.ncs", # the 'u' files are going to be filtered out - "xAIR1.ncs", - "xEKG1.ncs", -] +def _read_nlx_mat_chan_keep_gaps(matfile: str) -> np.ndarray: + """Read a single channel from a Neuralynx .mat file and keep invalid samples.""" + mat = loadmat(matfile) + + hdr_dict = _nlxheader_to_dict(mat) + + # Nlx2MatCSC.m reads the data in N equal-sized (512-item) chunks + # this array (1, n_chunks) stores the number of valid samples + # per chunk (the last chunk is usually shorter) + n_valid_samples = mat["NumberOfValidSamples"].ravel() + + # read in the artificial zeros so that + # we can compare with the mne padded arrays + ncs_records_with_gaps = [9, 15, 20] + for i in ncs_records_with_gaps: + n_valid_samples[i] = 512 + + # concatenate chunks, respecting the number of valid samples + x = np.concatenate( + [mat["Samples"][0:n, i] for i, n in enumerate(n_valid_samples)] + ) # in ADBits + + # this value is the same for all channels and + # converts data from ADBits to Volts + conversionf = literal_eval(hdr_dict["ADBitVolts"]) + x = x * conversionf + + # if header says input was inverted at acquisition + # (possibly for spike detection or so?), flip it back + # NeuralynxIO does this under the hood in NeuralynxIO.parse_header() + # see this discussion: https://github.com/NeuralEnsemble/python-neo/issues/819 + if hdr_dict["InputInverted"] == "True": + x *= -1 + return x + + +# set known values for the Neuralynx data for testing expected_chan_names = ["LAHC1", "LAHC2", "LAHC3", "xAIR1", "xEKG1"] +expected_hp_freq = 0.1 +expected_lp_freq = 500.0 +expected_sfreq = 2000.0 +expected_meas_date = datetime.strptime("2023/11/02 13:39:27", "%Y/%m/%d %H:%M:%S") @requires_testing_data def test_neuralynx(): """Test basic reading.""" - pytest.importorskip("neo") - from neo.io import NeuralynxIO - excluded_ncs_files = ["LAHCu1.ncs", "LAHCu2.ncs", "LAHCu3.ncs"] + excluded_ncs_files = [ + "LAHCu1.ncs", + "LAHC1_3_gaps.ncs", + "LAHC2_3_gaps.ncs", + ] # ==== MNE-Python ==== # + fname_patterns = ["*u*.ncs", "*3_gaps.ncs"] raw = read_raw_neuralynx( - fname=testing_path, preload=True, exclude_fname_patterns=["*u*.ncs"] + fname=testing_path, + preload=True, + exclude_fname_patterns=fname_patterns, ) + # test that we picked the right info from headers + assert raw.info["highpass"] == expected_hp_freq, "highpass freq not set correctly" + assert raw.info["lowpass"] == expected_lp_freq, "lowpass freq not set correctly" + assert raw.info["sfreq"] == expected_sfreq, "sampling freq not set correctly" + + meas_date_utc = expected_meas_date.astimezone(timezone.utc) + assert raw.info["meas_date"] == meas_date_utc, "meas_date not set correctly" + # test that channel selection worked assert ( raw.ch_names == expected_chan_names @@ -136,5 +184,63 @@ def test_neuralynx(): ) # data _test_raw_reader( - read_raw_neuralynx, fname=testing_path, exclude_fname_patterns=["*u*.ncs"] + read_raw_neuralynx, + fname=testing_path, + exclude_fname_patterns=fname_patterns, + ) + + +@requires_testing_data +def test_neuralynx_gaps(): + """Test gap detection.""" + # ignore files with no gaps + ignored_ncs_files = [ + "LAHC1.ncs", + "LAHC2.ncs", + "LAHC3.ncs", + "xAIR1.ncs", + "xEKG1.ncs", + "LAHCu1.ncs", + ] + raw = read_raw_neuralynx( + fname=testing_path, + preload=True, + exclude_fname_patterns=ignored_ncs_files, + ) + mne_y, _ = 
raw.get_data(return_times=True) # in V + + # there should be 2 channels with 3 gaps (of 130 samples in total) + n_expected_gaps = 3 + n_expected_missing_samples = 130 + assert len(raw.annotations) == n_expected_gaps, "Wrong number of gaps detected" + assert ( + (mne_y[0, :] == 0).sum() == n_expected_missing_samples + ), "Number of true and inferred missing samples differ" + + # read in .mat files containing original gaps + matchans = ["LAHC1_3_gaps.mat", "LAHC2_3_gaps.mat"] + + # (n_chan, n_samples) array, in V + mat_y = np.stack( + [ + _read_nlx_mat_chan_keep_gaps(os.path.join(testing_path, ch)) + for ch in matchans + ] ) + + # compare originally modified .ncs arrays with MNE-padded arrays + # and test that we back-inserted 0's at the right places + assert_allclose( + mne_y, mat_y, rtol=1e-6, err_msg="MNE and Nlx2MatCSC.m not all close" + ) + + # test that channel selection works + raw = read_raw_neuralynx( + fname=testing_path, + preload=False, + exclude_fname_patterns=ignored_ncs_files, + ) + + raw.pick("LAHC2") + assert raw.ch_names == ["LAHC2"] + raw.load_data() # before gh-12357 this would fail diff --git a/mne/io/nicolet/nicolet.py b/mne/io/nicolet/nicolet.py index 85a7d1e5607..0ef0c0a4f4a 100644 --- a/mne/io/nicolet/nicolet.py +++ b/mne/io/nicolet/nicolet.py @@ -19,7 +19,7 @@ @fill_doc def read_raw_nicolet( input_fname, ch_type, eog=(), ecg=(), emg=(), misc=(), preload=False, verbose=None -): +) -> "RawNicolet": """Read Nicolet data as raw object. ..note:: This reader takes data files with the extension ``.data`` as an @@ -84,7 +84,7 @@ def _get_nicolet_info(fname, ch_type, eog, ecg, emg, misc): logger.info("Reading header...") header_info = dict() - with open(header, "r") as fid: + with open(header) as fid: for line in fid: var, value = line.split("=") if var == "elec_names": @@ -183,11 +183,11 @@ def __init__( misc=(), preload=False, verbose=None, - ): # noqa: D102 + ): input_fname = path.abspath(input_fname) info, header_info = _get_nicolet_info(input_fname, ch_type, eog, ecg, emg, misc) last_samps = [header_info["num_samples"] - 1] - super(RawNicolet, self).__init__( + super().__init__( info, preload, filenames=[input_fname], diff --git a/mne/io/nihon/nihon.py b/mne/io/nihon/nihon.py index b39a18af838..ef14a735ca9 100644 --- a/mne/io/nihon/nihon.py +++ b/mne/io/nihon/nihon.py @@ -24,7 +24,7 @@ def _ensure_path(fname): @fill_doc -def read_raw_nihon(fname, preload=False, verbose=None): +def read_raw_nihon(fname, preload=False, verbose=None) -> "RawNihon": """Reader for an Nihon Kohden EEG file. Parameters @@ -70,7 +70,7 @@ def _read_nihon_metadata(fname): warn("No PNT file exists. 
Metadata will be blank") return metadata logger.info("Found PNT file, reading metadata.") - with open(pnt_fname, "r") as fid: + with open(pnt_fname) as fid: version = np.fromfile(fid, "|S16", 1).astype("U16")[0] if version not in _valid_headers: raise ValueError(f"Not a valid Nihon Kohden PNT file ({version})") @@ -135,7 +135,7 @@ def _read_21e_file(fname): logger.info("Found 21E file, reading channel names.") for enc in _encodings: try: - with open(e_fname, "r", encoding=enc) as fid: + with open(e_fname, encoding=enc) as fid: keep_parsing = False for line in fid: if line.startswith("["): @@ -169,17 +169,16 @@ def _read_nihon_header(fname): _chan_labels = _read_21e_file(fname) header = {} logger.info(f"Reading header from {fname}") - with open(fname, "r") as fid: + with open(fname) as fid: version = np.fromfile(fid, "|S16", 1).astype("U16")[0] if version not in _valid_headers: - raise ValueError("Not a valid Nihon Kohden EEG file ({})".format(version)) + raise ValueError(f"Not a valid Nihon Kohden EEG file ({version})") fid.seek(0x0081) control_block = np.fromfile(fid, "|S16", 1).astype("U16")[0] if control_block not in _valid_headers: raise ValueError( - "Not a valid Nihon Kohden EEG file " - "(control block {})".format(version) + f"Not a valid Nihon Kohden EEG file (control block {version})" ) fid.seek(0x17FE) @@ -285,10 +284,10 @@ def _read_nihon_annotations(fname): warn("No LOG file exists. Annotations will not be read") return dict(onset=[], duration=[], description=[]) logger.info("Found LOG file, reading events.") - with open(log_fname, "r") as fid: + with open(log_fname) as fid: version = np.fromfile(fid, "|S16", 1).astype("U16")[0] if version not in _valid_headers: - raise ValueError("Not a valid Nihon Kohden LOG file ({})".format(version)) + raise ValueError(f"Not a valid Nihon Kohden LOG file ({version})") fid.seek(0x91) n_logblocks = np.fromfile(fid, np.uint8, 1)[0] @@ -416,7 +415,7 @@ def __init__(self, fname, preload=False, verbose=None): info["chs"][i_ch]["range"] = t_range info["chs"][i_ch]["cal"] = 1 / t_range - super(RawNihon, self).__init__( + super().__init__( info, preload=preload, last_samps=(n_samples - 1,), diff --git a/mne/io/nirx/nirx.py b/mne/io/nirx/nirx.py index 98d81f9c268..52826f266f3 100644 --- a/mne/io/nirx/nirx.py +++ b/mne/io/nirx/nirx.py @@ -34,7 +34,9 @@ @fill_doc -def read_raw_nirx(fname, saturated="annotate", preload=False, verbose=None): +def read_raw_nirx( + fname, saturated="annotate", preload=False, verbose=None +) -> "RawNIRX": """Reader for a NIRX fNIRS recording. 
Parameters @@ -63,7 +65,7 @@ def read_raw_nirx(fname, saturated="annotate", preload=False, verbose=None): def _open(fname): - return open(fname, "r", encoding="latin-1") + return open(fname, encoding="latin-1") @fill_doc @@ -99,7 +101,7 @@ def __init__(self, fname, saturated, preload=False, verbose=None): fname = str(_check_fname(fname, "read", True, "fname", need_dir=True)) - json_config = glob.glob("%s/*%s" % (fname, "config.json")) + json_config = glob.glob(f"{fname}/*config.json") if len(json_config): is_aurora = True else: @@ -128,7 +130,7 @@ def __init__(self, fname, saturated, preload=False, verbose=None): "config.txt", "probeInfo.mat", ) - n_dat = len(glob.glob("%s/*%s" % (fname, "dat"))) + n_dat = len(glob.glob(f"{fname}/*dat")) if n_dat != 1: warn( "A single dat file was expected in the specified path, " @@ -141,7 +143,7 @@ def __init__(self, fname, saturated, preload=False, verbose=None): files = dict() nan_mask = dict() for key in keys: - files[key] = glob.glob("%s/*%s" % (fname, key)) + files[key] = glob.glob(f"{fname}/*{key}") fidx = 0 if len(files[key]) != 1: if key not in ("wl1", "wl2"): @@ -200,7 +202,7 @@ def __init__(self, fname, saturated, preload=False, verbose=None): if hdr["GeneralInfo"]["NIRStar"] not in ['"15.0"', '"15.2"', '"15.3"']: raise RuntimeError( "MNE does not support this NIRStar version" - " (%s)" % (hdr["GeneralInfo"]["NIRStar"],) + f" ({hdr['GeneralInfo']['NIRStar']})" ) if ( "NIRScout" not in hdr["GeneralInfo"]["Device"] @@ -474,7 +476,7 @@ def __init__(self, fname, saturated, preload=False, verbose=None): annot_mask |= mask nan_mask[key] = None # shouldn't need again - super(RawNIRX, self).__init__( + super().__init__( info, preload, filenames=[fname], diff --git a/mne/io/nsx/nsx.py b/mne/io/nsx/nsx.py index 95448b1b22c..c20e19b29ed 100644 --- a/mne/io/nsx/nsx.py +++ b/mne/io/nsx/nsx.py @@ -88,7 +88,7 @@ @fill_doc def read_raw_nsx( input_fname, stim_channel=True, eog=None, misc=None, preload=False, *, verbose=None -): +) -> "RawNSX": """Reader function for NSx (Blackrock Microsystems) files. 
Parameters @@ -178,7 +178,7 @@ def __init__( preload=False, verbose=None, ): - logger.info("Extracting NSX parameters from {}...".format(input_fname)) + logger.info(f"Extracting NSX parameters from {input_fname}...") input_fname = os.path.abspath(input_fname) ( info, @@ -191,7 +191,7 @@ def __init__( ) = _get_hdr_info(input_fname, stim_channel=stim_channel, eog=eog, misc=misc) raw_extras["orig_format"] = orig_format first_samps = (raw_extras["timestamp"][0],) - super(RawNSX, self).__init__( + super().__init__( info, first_samps=first_samps, last_samps=[first_samps[0] + n_samples - 1], @@ -311,7 +311,7 @@ def _read_header_22_and_above(fname): basic_header[x] = basic_header[x] * 1e-3 ver_major, ver_minor = basic_header.pop("ver_major"), basic_header.pop("ver_minor") - basic_header["spec"] = "{}.{}".format(ver_major, ver_minor) + basic_header["spec"] = f"{ver_major}.{ver_minor}" data_header = list() index = 0 @@ -355,9 +355,9 @@ def _get_hdr_info(fname, stim_channel=True, eog=None, misc=None): ch_names = list(nsx_info["extended"]["electrode_label"]) ch_types = list(nsx_info["extended"]["type"]) ch_units = list(nsx_info["extended"]["units"]) - ch_names, ch_types, ch_units = [ + ch_names, ch_types, ch_units = ( list(map(bytes.decode, xx)) for xx in (ch_names, ch_types, ch_units) - ] + ) max_analog_val = nsx_info["extended"]["max_analog_val"].astype("double") min_analog_val = nsx_info["extended"]["min_analog_val"].astype("double") max_digital_val = nsx_info["extended"]["max_digital_val"].astype("double") diff --git a/mne/io/persyst/persyst.py b/mne/io/persyst/persyst.py index 44334fa4555..11f8a3a35ea 100644 --- a/mne/io/persyst/persyst.py +++ b/mne/io/persyst/persyst.py @@ -18,7 +18,7 @@ @fill_doc -def read_raw_persyst(fname, preload=False, verbose=None): +def read_raw_persyst(fname, preload=False, verbose=None) -> "RawPersyst": """Reader for a Persyst (.lay/.dat) recording. 
Parameters @@ -226,7 +226,7 @@ def __init__(self, fname, preload=False, verbose=None): raw_extras = {"dtype": dtype, "n_chs": n_chs, "n_samples": n_samples} # create Raw object - super(RawPersyst, self).__init__( + super().__init__( info, preload, filenames=[dat_fpath], @@ -351,7 +351,7 @@ def _read_lay_contents(fname): # initialize all section to empty str section = "" - with open(fname, "r") as fin: + with open(fname) as fin: for line in fin: # break a line into a status, key and value status, key, val = _process_lay_line(line, section) diff --git a/mne/io/persyst/tests/test_persyst.py b/mne/io/persyst/tests/test_persyst.py index c81b53f2b79..76e117817fd 100644 --- a/mne/io/persyst/tests/test_persyst.py +++ b/mne/io/persyst/tests/test_persyst.py @@ -85,7 +85,7 @@ def test_persyst_dates(tmp_path): # reformat the lay file to have testdate with # "/" character - with open(fname_lay, "r") as fin: + with open(fname_lay) as fin: with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): @@ -101,7 +101,7 @@ def test_persyst_dates(tmp_path): # reformat the lay file to have testdate with # "-" character os.remove(new_fname_lay) - with open(fname_lay, "r") as fin: + with open(fname_lay) as fin: with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): @@ -163,7 +163,7 @@ def test_persyst_moved_file(tmp_path): # to the full path, but it should still not work # as reader requires lay and dat file to be in # same directory - with open(fname_lay, "r") as fin: + with open(fname_lay) as fin: with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): @@ -216,7 +216,7 @@ def test_persyst_errors(tmp_path): shutil.copy(fname_dat, new_fname_dat) # reformat the lay file - with open(fname_lay, "r") as fin: + with open(fname_lay) as fin: with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): @@ -229,7 +229,7 @@ def test_persyst_errors(tmp_path): # reformat the lay file os.remove(new_fname_lay) - with open(fname_lay, "r") as fin: + with open(fname_lay) as fin: with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): @@ -243,7 +243,7 @@ def test_persyst_errors(tmp_path): # reformat the lay file to have testdate # improperly specified os.remove(new_fname_lay) - with open(fname_lay, "r") as fin: + with open(fname_lay) as fin: with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): diff --git a/mne/io/pick.py b/mne/io/pick.py index f7c77b1af14..4ae1d25b3c5 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -4,7 +4,6 @@ # Copyright the MNE-Python contributors. -from .._fiff import _io_dep_getattr from .._fiff.pick import ( _DATA_CH_TYPES_ORDER_DEFAULT, _DATA_CH_TYPES_SPLIT, @@ -18,11 +17,3 @@ "_DATA_CH_TYPES_ORDER_DEFAULT", "_DATA_CH_TYPES_SPLIT", ] - - -def __getattr__(name): - try: - return globals()[name] - except KeyError: - pass - return _io_dep_getattr(name, "pick") diff --git a/mne/io/proj.py b/mne/io/proj.py deleted file mode 100644 index 98445f1ce7e..00000000000 --- a/mne/io/proj.py +++ /dev/null @@ -1,11 +0,0 @@ -# Author: Eric Larson -# -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. 
- - -from .._fiff import _io_dep_getattr - - -def __getattr__(name): - return _io_dep_getattr(name, "proj") diff --git a/mne/io/reference.py b/mne/io/reference.py deleted file mode 100644 index 850d6bd7294..00000000000 --- a/mne/io/reference.py +++ /dev/null @@ -1,11 +0,0 @@ -# Author: Eric Larson -# -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - - -from .._fiff import _io_dep_getattr - - -def __getattr__(name): - return _io_dep_getattr(name, "reference") diff --git a/mne/io/snirf/_snirf.py b/mne/io/snirf/_snirf.py index e32b32370b3..bde3e045528 100644 --- a/mne/io/snirf/_snirf.py +++ b/mne/io/snirf/_snirf.py @@ -21,7 +21,9 @@ @fill_doc -def read_raw_snirf(fname, optode_frame="unknown", preload=False, verbose=None): +def read_raw_snirf( + fname, optode_frame="unknown", preload=False, verbose=None +) -> "RawSNIRF": """Reader for a continuous wave SNIRF data. .. note:: This reader supports the .snirf file type only, @@ -57,7 +59,7 @@ def read_raw_snirf(fname, optode_frame="unknown", preload=False, verbose=None): def _open(fname): - return open(fname, "r", encoding="latin-1") + return open(fname, encoding="latin-1") @fill_doc @@ -166,7 +168,7 @@ def natural_keys(text): for c in channels ] ) - sources = [f"S{int(s)}" for s in sources] + sources = {int(s): f"S{int(s)}" for s in sources} if "detectorLabels_disabled" in dat["nirs/probe"]: # This is disabled as @@ -183,7 +185,7 @@ def natural_keys(text): for c in channels ] ) - detectors = [f"D{int(d)}" for d in detectors] + detectors = {int(d): f"D{int(d)}" for d in detectors} # Extract source and detector locations # 3D positions are optional in SNIRF, @@ -222,9 +224,6 @@ def natural_keys(text): "location information" ) - assert len(sources) == srcPos3D.shape[0] - assert len(detectors) == detPos3D.shape[0] - chnames = [] ch_types = [] for chan in channels: @@ -246,9 +245,9 @@ def natural_keys(text): )[0] ) ch_name = ( - sources[src_idx - 1] + sources[src_idx] + "_" - + detectors[det_idx - 1] + + detectors[det_idx] + " " + str(fnirs_wavelengths[wve_idx - 1]) ) @@ -263,7 +262,7 @@ def natural_keys(text): # Convert between SNIRF processed names and MNE type names dt_id = dt_id.lower().replace("dod", "fnirs_od") - ch_name = sources[src_idx - 1] + "_" + detectors[det_idx - 1] + ch_name = sources[src_idx] + "_" + detectors[det_idx] if dt_id == "fnirs_od": wve_idx = int( @@ -286,7 +285,8 @@ def natural_keys(text): subject_info = {} names = np.array(dat.get("nirs/metaDataTags/SubjectID")) - subject_info["first_name"] = _correct_shape(names)[0].decode("UTF-8") + names = _correct_shape(names)[0].decode("UTF-8") + subject_info["his_id"] = names # Read non standard (but allowed) custom metadata tags if "lastName" in dat.get("nirs/metaDataTags/"): ln = dat.get("/nirs/metaDataTags/lastName")[0].decode("UTF-8") @@ -294,6 +294,12 @@ def natural_keys(text): if "middleName" in dat.get("nirs/metaDataTags/"): m = dat.get("/nirs/metaDataTags/middleName")[0].decode("UTF-8") subject_info["middle_name"] = m + if "firstName" in dat.get("nirs/metaDataTags/"): + fn = dat.get("/nirs/metaDataTags/firstName")[0].decode("UTF-8") + subject_info["first_name"] = fn + else: + # MNE < 1.7 used to not write the firstName tag, so pull it from names + subject_info["first_name"] = names.split("_")[0] if "sex" in dat.get("nirs/metaDataTags/"): s = dat.get("/nirs/metaDataTags/sex")[0].decode("UTF-8") if s in {"M", "Male", "1", "m"}: @@ -413,10 +419,10 @@ def natural_keys(text): info["dig"] = dig str_date = _correct_shape( - 
np.array((dat.get("/nirs/metaDataTags/MeasurementDate"))) + np.array(dat.get("/nirs/metaDataTags/MeasurementDate")) )[0].decode("UTF-8") str_time = _correct_shape( - np.array((dat.get("/nirs/metaDataTags/MeasurementTime"))) + np.array(dat.get("/nirs/metaDataTags/MeasurementTime")) )[0].decode("UTF-8") str_datetime = str_date + str_time @@ -458,7 +464,7 @@ def natural_keys(text): with info._unlock(): info["subject_info"]["birthday"] = birthday - super(RawSNIRF, self).__init__( + super().__init__( info, preload, filenames=[fname], diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index 2d2ad2c6324..f298a030bea 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -133,6 +133,7 @@ def test_snirf_gowerlabs(): def test_snirf_basic(): """Test reading SNIRF files.""" raw = read_raw_snirf(sfnirs_homer_103_wShort, preload=True) + assert raw.info["subject_info"]["his_id"] == "default" # Test data import assert raw._data.shape == (26, 145) @@ -243,21 +244,27 @@ def test_snirf_nonstandard(tmp_path): fname = str(tmp_path) + "/mod.snirf" # Manually mark up the file to match MNE-NIRS custom tags with h5py.File(fname, "r+") as f: - f.create_dataset("nirs/metaDataTags/middleName", data=["X".encode("UTF-8")]) - f.create_dataset("nirs/metaDataTags/lastName", data=["Y".encode("UTF-8")]) - f.create_dataset("nirs/metaDataTags/sex", data=["1".encode("UTF-8")]) + f.create_dataset("nirs/metaDataTags/middleName", data=[b"X"]) + f.create_dataset("nirs/metaDataTags/lastName", data=[b"Y"]) + f.create_dataset("nirs/metaDataTags/sex", data=[b"1"]) raw = read_raw_snirf(fname, preload=True) + assert raw.info["subject_info"]["first_name"] == "default" # pull from his_id + with h5py.File(fname, "r+") as f: + f.create_dataset("nirs/metaDataTags/firstName", data=[b"W"]) + raw = read_raw_snirf(fname, preload=True) + assert raw.info["subject_info"]["first_name"] == "W" assert raw.info["subject_info"]["middle_name"] == "X" assert raw.info["subject_info"]["last_name"] == "Y" assert raw.info["subject_info"]["sex"] == 1 + assert raw.info["subject_info"]["his_id"] == "default" with h5py.File(fname, "r+") as f: del f["nirs/metaDataTags/sex"] - f.create_dataset("nirs/metaDataTags/sex", data=["2".encode("UTF-8")]) + f.create_dataset("nirs/metaDataTags/sex", data=[b"2"]) raw = read_raw_snirf(fname, preload=True) assert raw.info["subject_info"]["sex"] == 2 with h5py.File(fname, "r+") as f: del f["nirs/metaDataTags/sex"] - f.create_dataset("nirs/metaDataTags/sex", data=["0".encode("UTF-8")]) + f.create_dataset("nirs/metaDataTags/sex", data=[b"0"]) raw = read_raw_snirf(fname, preload=True) assert raw.info["subject_info"]["sex"] == 0 diff --git a/mne/io/tag.py b/mne/io/tag.py deleted file mode 100644 index 41dc15fd40d..00000000000 --- a/mne/io/tag.py +++ /dev/null @@ -1,11 +0,0 @@ -# Author: Eric Larson -# -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. 
- - -from .._fiff import _io_dep_getattr - - -def __getattr__(name): - return _io_dep_getattr(name, "tag") diff --git a/mne/io/tests/test_apply_function.py b/mne/io/tests/test_apply_function.py index b1869e1dae6..f250e9489b9 100644 --- a/mne/io/tests/test_apply_function.py +++ b/mne/io/tests/test_apply_function.py @@ -63,3 +63,32 @@ def test_apply_function_verbose(): assert out is raw raw.apply_function(printer, verbose=True) assert sio.getvalue().count("\n") == n_chan + + +def test_apply_function_ch_access(): + """Test apply_function is able to access channel idx.""" + + def _bad_ch_idx(x, ch_idx): + assert x[0] == ch_idx + return x + + def _bad_ch_name(x, ch_name): + assert isinstance(ch_name, str) + assert x[0] == float(ch_name) + return x + + data = np.full((2, 10), np.arange(2).reshape(-1, 1)) + raw = RawArray(data, create_info(2, 1.0, "mag")) + + # test ch_idx access in both code paths (parallel / 1 job) + raw.apply_function(_bad_ch_idx) + raw.apply_function(_bad_ch_idx, n_jobs=2) + raw.apply_function(_bad_ch_name) + raw.apply_function(_bad_ch_name, n_jobs=2) + + # test input catches + with pytest.raises( + ValueError, + match="cannot access.*when channel_wise=False", + ): + raw.apply_function(_bad_ch_idx, channel_wise=False) diff --git a/mne/io/tests/test_deprecation.py b/mne/io/tests/test_deprecation.py deleted file mode 100644 index fecf9a78091..00000000000 --- a/mne/io/tests/test_deprecation.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Test deprecation of mne.io private attributes to mne._fiff.""" - -# Author: Eric Larson -# -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - -import pytest - - -def test_deprecation(): - """Test deprecation of mne.io FIFF stuff.""" - import mne.io - - # Shouldn't warn (backcompat) - mne.io.constants.FIFF - mne.io.pick._picks_to_idx - mne.io.get_channel_type_constants() - - # Should warn - with pytest.warns(FutureWarning, match=r"mne\.io\.pick\.pick_channels is dep"): - from mne.io.pick import pick_channels # noqa: F401 - with pytest.warns(FutureWarning, match=r"mne\.io\.pick\.pick_channels is dep"): - mne.io.pick.pick_channels - with pytest.warns(FutureWarning, match=r"mne\.io\.meas_info\.read_info is dep"): - from mne.io.meas_info import read_info # noqa: F401 - from mne.io import meas_info - - with pytest.warns(FutureWarning, match=r"mne\.io\.meas_info\.read_info is dep"): - meas_info.read_info diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index bac32f83f65..33384c1e0e4 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -764,7 +764,7 @@ def raw_factory(meas_date): ) return raw - raw_A, raw_B = [raw_factory((x, 0)) for x in [0, 2]] + raw_A, raw_B = (raw_factory((x, 0)) for x in [0, 2]) raw_A.append(raw_B) assert_array_equal(raw_A.annotations.onset, EXPECTED_ONSET) @@ -1022,3 +1022,10 @@ def test_concatenate_raw_dev_head_t(): raw.info["dev_head_t"]["trans"][0, 0] = np.nan raw2 = raw.copy() concatenate_raws([raw, raw2]) + + +def test_last_samp(): + """Test that getting the last sample works.""" + raw = read_raw_fif(raw_fname).crop(0, 0.1).load_data() + last_data = raw._data[:, [-1]] + assert_array_equal(raw[:, -1][0], last_data) diff --git a/mne/io/tests/test_read_raw.py b/mne/io/tests/test_read_raw.py index a1e27166b0a..eccd074d9a0 100644 --- a/mne/io/tests/test_read_raw.py +++ b/mne/io/tests/test_read_raw.py @@ -14,7 +14,7 @@ from mne.io import read_raw from mne.io._read_raw import _get_readers, split_name_ext -base = Path(__file__).parent.parent +base = Path(__file__).parents[1] test_base = 
Path(testing.data_path(download=False)) @@ -50,7 +50,13 @@ def test_read_raw_suggested(fname): base / "tests/data/test_raw.fif", base / "tests/data/test_raw.fif.gz", base / "edf/tests/data/test.edf", - base / "edf/tests/data/test.bdf", + pytest.param( + base / "edf/tests/data/test.bdf", + marks=( + _testing_mark, + pytest.mark.filterwarnings("ignore:Channels contain different"), + ), + ), base / "brainvision/tests/data/test.vhdr", base / "kit/tests/data/test.sqd", pytest.param(test_base / "KIT" / "data_berlin.con", marks=_testing_mark), diff --git a/mne/io/utils.py b/mne/io/utils.py deleted file mode 100644 index 9460ceed55e..00000000000 --- a/mne/io/utils.py +++ /dev/null @@ -1,11 +0,0 @@ -# Author: Eric Larson -# -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - - -from .._fiff import _io_dep_getattr - - -def __getattr__(name): - return _io_dep_getattr(name, "utils") diff --git a/mne/io/write.py b/mne/io/write.py deleted file mode 100644 index 12c0ae00ca0..00000000000 --- a/mne/io/write.py +++ /dev/null @@ -1,11 +0,0 @@ -# Author: Eric Larson -# -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - - -from .._fiff import _io_dep_getattr - - -def __getattr__(name): - return _io_dep_getattr(name, "write") diff --git a/mne/label.py b/mne/label.py index b57b466df27..5c8a1b8ca30 100644 --- a/mne/label.py +++ b/mne/label.py @@ -242,7 +242,7 @@ def __init__( color=None, *, verbose=None, - ): # noqa: D102 + ): # check parameters if not isinstance(hemi, str): raise ValueError("hemi must be a string, not %s" % type(hemi)) @@ -335,13 +335,13 @@ def __add__(self, other): if self.subject != other.subject: raise ValueError( "Label subject parameters must match, got " - '"%s" and "%s". Consider setting the ' + f'"{self.subject}" and "{other.subject}". Consider setting the ' "subject parameter on initialization, or " "setting label.subject manually before " - "combining labels." % (self.subject, other.subject) + "combining labels." ) if self.hemi != other.hemi: - name = "%s + %s" % (self.name, other.name) + name = f"{self.name} + {other.name}" if self.hemi == "lh": lh, rh = self.copy(), other.copy() else: @@ -357,8 +357,8 @@ def __add__(self, other): other_dup = [np.where(other.vertices == d)[0][0] for d in duplicates] if not np.all(self.pos[self_dup] == other.pos[other_dup]): err = ( - "Labels %r and %r: vertices overlap but differ in " - "position values" % (self.name, other.name) + f"Labels {repr(self.name)} and {repr(other.name)}: vertices " + "overlap but differ in position values" ) raise ValueError(err) @@ -383,11 +383,11 @@ def __add__(self, other): indcs = np.argsort(vertices) vertices, pos, values = vertices[indcs], pos[indcs, :], values[indcs] - comment = "%s + %s" % (self.comment, other.comment) + comment = f"{self.comment} + {other.comment}" name0 = self.name if self.name else "unnamed" name1 = other.name if other.name else "unnamed" - name = "%s + %s" % (name0, name1) + name = f"{name0} + {name1}" color = _blend_colors(self.color, other.color) @@ -408,10 +408,10 @@ def __sub__(self, other): if self.subject != other.subject: raise ValueError( "Label subject parameters must match, got " - '"%s" and "%s". Consider setting the ' + f'"{self.subject}" and "{other.subject}". Consider setting the ' "subject parameter on initialization, or " "setting label.subject manually before " - "combining labels." % (self.subject, other.subject) + "combining labels." 
) if self.hemi == other.hemi: @@ -419,7 +419,7 @@ def __sub__(self, other): else: keep = np.arange(len(self.vertices)) - name = "%s - %s" % (self.name or "unnamed", other.name or "unnamed") + name = f'{self.name or "unnamed"} - {other.name or "unnamed"}' return Label( self.vertices[keep], self.pos[keep], @@ -870,7 +870,7 @@ def center_of_mass( .. footbibliography:: """ if not isinstance(surf, str): - raise TypeError("surf must be a string, got %s" % (type(surf),)) + raise TypeError(f"surf must be a string, got {type(surf)}") subject = _check_subject(self.subject, subject) if np.any(self.values < 0): raise ValueError("Cannot compute COM with negative values") @@ -980,7 +980,7 @@ def _get_label_src(label, src): if src.kind != "surface": raise RuntimeError( "Cannot operate on SourceSpaces that are not " - "surface type, got %s" % (src.kind,) + f"surface type, got {src.kind}" ) if label.hemi == "lh": hemi_src = src[0] @@ -1017,11 +1017,10 @@ class BiHemiLabel: The name of the subject. """ - def __init__(self, lh, rh, name=None, color=None): # noqa: D102 + def __init__(self, lh, rh, name=None, color=None): if lh.subject != rh.subject: raise ValueError( - "lh.subject (%s) and rh.subject (%s) must " - "agree" % (lh.subject, rh.subject) + f"lh.subject ({lh.subject}) and rh.subject ({rh.subject}) must agree" ) self.lh = lh self.rh = rh @@ -1061,7 +1060,7 @@ def __add__(self, other): else: raise TypeError("Need: Label or BiHemiLabel. Got: %r" % other) - name = "%s + %s" % (self.name, other.name) + name = f"{self.name} + {other.name}" color = _blend_colors(self.color, other.color) return BiHemiLabel(lh, rh, name, color) @@ -1084,7 +1083,7 @@ def __sub__(self, other): elif len(rh.vertices) == 0: return lh else: - name = "%s - %s" % (self.name, other.name) + name = f"{self.name} - {other.name}" return BiHemiLabel(lh, rh, name, self.color) @@ -1133,8 +1132,8 @@ def read_label(filename, subject=None, color=None, *, verbose=None): hemi = "rh" else: raise ValueError( - "Cannot find which hemisphere it is. File should end" - " with lh.label or rh.label: %s" % (basename,) + "Cannot find which hemisphere it is. File should end with lh.label or " + f"rh.label: {basename}" ) # find name @@ -1144,10 +1143,10 @@ def read_label(filename, subject=None, color=None, *, verbose=None): basename_ = basename[:-6] else: basename_ = basename[:-9] - name = "%s-%s" % (basename_, hemi) + name = f"{basename_}-{hemi}" # read the file - with open(filename, "r") as fid: + with open(filename) as fid: comment = fid.readline().replace("\n", "")[1:] nv = int(fid.readline()) data = np.empty((5, nv)) @@ -1240,9 +1239,8 @@ def _prep_label_split(label, subject=None, subjects_dir=None): pass elif subject != label.subject: raise ValueError( - "The label specifies a different subject (%r) from " - "the subject parameter (%r)." % label.subject, - subject, + f"The label specifies a different subject ({repr(label.subject)}) from " + f"the subject parameter ({repr(subject)})." 
) return label, subject, subjects_dir @@ -1296,7 +1294,7 @@ def _split_label_contig(label_to_split, subject=None, subjects_dir=None): else: basename = label_to_split.name name_ext = "" - name_pattern = "%s_div%%i%s" % (basename, name_ext) + name_pattern = f"{basename}_div%i{name_ext}" names = tuple(name_pattern % i for i in range(1, n_parts + 1)) # Colors @@ -1368,7 +1366,7 @@ def split_label(label, parts=2, subject=None, subjects_dir=None, freesurfer=Fals else: basename = label.name name_ext = "" - name_pattern = "%s_div%%i%s" % (basename, name_ext) + name_pattern = f"{basename}_div%i{name_ext}" names = tuple(name_pattern % i for i in range(1, n_parts + 1)) else: names = parts @@ -1482,7 +1480,7 @@ def label_sign_flip(label, src): vertno_sel = np.intersect1d(rh_vertno, vertices) ori.append(src[1]["nn"][vertno_sel]) if len(ori) == 0: - raise Exception('Unknown hemisphere type "%s"' % (label.hemi,)) + raise Exception(f'Unknown hemisphere type "{label.hemi}"') ori = np.concatenate(ori, axis=0) if len(ori) == 0: return np.array([], int) @@ -1707,7 +1705,7 @@ def _grow_labels(seeds, extents, hemis, names, dist, vert, subject): seed_repr = str(seed) else: seed_repr = ",".join(map(str, seed)) - comment = "Circular label: seed=%s, extent=%0.1fmm" % (seed_repr, extent) + comment = f"Circular label: seed={seed_repr}, extent={extent:0.1f}mm" label = Label( vertices=label_verts, pos=vert[hemi][label_verts], @@ -1794,10 +1792,10 @@ def grow_labels( n_seeds = len(seeds) if len(extents) != 1 and len(extents) != n_seeds: - raise ValueError("The extents parameter has to be of length 1 or " "len(seeds)") + raise ValueError("The extents parameter has to be of length 1 or len(seeds)") if len(hemis) != 1 and len(hemis) != n_seeds: - raise ValueError("The hemis parameter has to be of length 1 or " "len(seeds)") + raise ValueError("The hemis parameter has to be of length 1 or len(seeds)") if colors is not None: if len(colors.shape) == 1: # if one color for all seeds @@ -2159,8 +2157,8 @@ def _read_annot(fname): ) else: raise OSError( - "No such file %s, candidate parcellations in " - "that directory:\n%s" % (fname, "\n".join(cands)) + f"No such file {fname}, candidate parcellations in " + "that directory:\n" + "\n".join(cands) ) with open(fname, "rb") as fid: n_verts = np.fromfile(fid, ">i4", 1)[0] @@ -2238,14 +2236,14 @@ def _get_annot_fname(annot_fname, subject, hemi, parc, subjects_dir): def _load_vert_pos(subject, subjects_dir, surf_name, hemi, n_expected, extra=""): - fname_surf = op.join(subjects_dir, subject, "surf", "%s.%s" % (hemi, surf_name)) + fname_surf = op.join(subjects_dir, subject, "surf", f"{hemi}.{surf_name}") vert_pos, _ = read_surface(fname_surf) vert_pos /= 1e3 # the positions in labels are in meters if len(vert_pos) != n_expected: raise RuntimeError( - "Number of surface vertices (%s) for subject %s" + f"Number of surface vertices ({len(vert_pos)}) for subject {subject}" " does not match the expected number of vertices" - "(%s)%s" % (len(vert_pos), subject, n_expected, extra) + f"({n_expected}){extra}" ) return vert_pos @@ -2388,12 +2386,11 @@ def _check_labels_subject(labels, subject, name): if subject is not None: # label.subject can be None, depending on init if subject != label.subject: raise ValueError( - "Got multiple values of %s: %s and %s" - % (name, subject, label.subject) + f"Got multiple values of {name}: {subject} and {label.subject}" ) if subject is None: raise ValueError( - "if label.subject is None for all labels, " "%s must be provided" % name + f"if label.subject is None 
for all labels, {name} must be provided." ) return subject @@ -2521,7 +2518,7 @@ def labels_to_stc( if values.ndim == 1: values = values[:, np.newaxis] if values.ndim != 2: - raise ValueError("values must have 1 or 2 dimensions, got %s" % (values.ndim,)) + raise ValueError(f"values must have 1 or 2 dimensions, got {values.ndim}") _validate_type(src, (SourceSpaces, None)) if src is None: data, vertices, subject = _labels_to_stc_surf( @@ -2748,11 +2745,11 @@ def write_labels_to_annot( ) if any(i > 255 for i in color): - msg = "%s: %s (%s)" % (color, ", ".join(names), hemi) + msg = f"{color}: {', '.join(names)} ({hemi})" invalid_colors.append(msg) if len(names) > 1: - msg = "%s: %s (%s)" % (color, ", ".join(names), hemi) + msg = f"{color}: {', '.join(names)} ({hemi})" duplicate_colors.append(msg) # replace None values (labels with unspecified color) @@ -2801,7 +2798,7 @@ def write_labels_to_annot( other_indices = (annot_ids.index(i) for i in other_ids) other_names = (hemi_labels[i].name for i in other_indices) other_repr = ", ".join(other_names) - msg = "%s: %s overlaps %s" % (hemi, label.name, other_repr) + msg = f"{hemi}: {label.name} overlaps {other_repr}" overlap.append(msg) annot[label.vertices] = annot_id diff --git a/mne/minimum_norm/_eloreta.py b/mne/minimum_norm/_eloreta.py index 8f15365e5b4..b49b0a4a338 100644 --- a/mne/minimum_norm/_eloreta.py +++ b/mne/minimum_norm/_eloreta.py @@ -60,8 +60,8 @@ def _compute_eloreta(inv, lambda2, options): logger.info(" Computing optimized source covariance (eLORETA)...") if n_orient == 3: logger.info( - " Using %s orientation weights" - % ("uniform" if force_equal else "independent",) + f" Using {'uniform' if force_equal else 'independent'} " + "orientation weights" ) # src, sens, 3 G_3 = _get_G_3(G, n_orient) @@ -120,8 +120,7 @@ def _compute_eloreta(inv, lambda2, options): R_last.ravel() ) logger.debug( - " Iteration %s / %s ...%s (%0.1e)" - % (kk + 1, max_iter, extra, delta) + f" Iteration {kk + 1} / {max_iter} ...{extra} ({delta:0.1e})" ) if delta < eps: logger.info( diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index f41f660ac4e..440ed3735f2 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -725,7 +725,7 @@ def prepare_inverse_operator( # # w = diag(diag(R)) ** 0.5 # - noise_weight = inv["reginv"] * np.sqrt((1.0 + inv["sing"] ** 2 / lambda2)) + noise_weight = inv["reginv"] * np.sqrt(1.0 + inv["sing"] ** 2 / lambda2) noise_norm = np.zeros(inv["eigen_leads"]["nrow"]) (nrm2,) = linalg.get_blas_funcs(("nrm2",), (noise_norm,)) @@ -1079,7 +1079,7 @@ def _apply_inverse( # Pick the correct channels from the data # sel = _pick_channels_inverse_operator(evoked.ch_names, inv) - logger.info('Applying inverse operator to "%s"...' % (evoked.comment,)) + logger.info(f'Applying inverse operator to "{evoked.comment}"...') logger.info(" Picked %d channels from the data" % len(sel)) logger.info(" Computing inverse...") K, noise_norm, vertno, source_nn = _assemble_kernel( @@ -1108,7 +1108,7 @@ def _apply_inverse( sol = combine_xyz(sol) if noise_norm is not None: - logger.info(" %s..." 
% (method,)) + logger.info(f" {method}...") if is_free_ori and pick_ori == "vector": noise_norm = noise_norm.repeat(3, axis=0) sol *= noise_norm diff --git a/mne/minimum_norm/resolution_matrix.py b/mne/minimum_norm/resolution_matrix.py index 3dd24ac6847..655ca991914 100644 --- a/mne/minimum_norm/resolution_matrix.py +++ b/mne/minimum_norm/resolution_matrix.py @@ -1,4 +1,5 @@ """Compute resolution matrix for linear estimators.""" + # Authors: olaf.hauk@mrc-cbu.cam.ac.uk # # License: BSD-3-Clause diff --git a/mne/minimum_norm/spatial_resolution.py b/mne/minimum_norm/spatial_resolution.py index d68be423494..c9d28aef4d8 100644 --- a/mne/minimum_norm/spatial_resolution.py +++ b/mne/minimum_norm/spatial_resolution.py @@ -7,6 +7,7 @@ Resolution metrics: localisation error, spatial extent, relative amplitude. Metrics can be computed for point-spread and cross-talk functions (PSFs/CTFs). """ + import numpy as np from ..source_estimate import SourceEstimate diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index 58722a19fd5..e3be18a3fc9 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -55,7 +55,7 @@ from mne.source_estimate import VolSourceEstimate, read_source_estimate from mne.source_space._source_space import _get_src_nn from mne.surface import _normal_orth -from mne.time_frequency import EpochsTFR +from mne.time_frequency import EpochsTFRArray from mne.utils import _record_warnings, catch_logging test_path = testing.data_path(download=False) @@ -130,7 +130,7 @@ def _compare(a, b): if k not in b and k not in skip_types: raise ValueError( "First one had one second one didn't:\n" - "%s not in %s" % (k, b.keys()) + f"{k} not in {b.keys()}" ) if k not in skip_types: last_keys.pop() @@ -140,7 +140,7 @@ def _compare(a, b): if k not in a and k not in skip_types: raise ValueError( "Second one had one first one didn't:\n" - "%s not in %s" % (k, sorted(a.keys())) + f"{k} not in {sorted(a.keys())}" ) elif isinstance(a, list): assert len(a) == len(b) @@ -225,9 +225,7 @@ def _compare_inverses_approx( stc_2 /= norms corr = np.corrcoef(stc_1.ravel(), stc_2.ravel())[0, 1] assert corr > ctol - assert_allclose( - stc_1, stc_2, rtol=rtol, atol=atol, err_msg="%s: %s" % (method, corr) - ) + assert_allclose(stc_1, stc_2, rtol=rtol, atol=atol, err_msg=f"{method}: {corr}") def _compare_io(inv_op, *, out_file_ext=".fif", tmp_path): @@ -1377,11 +1375,11 @@ def test_apply_inverse_tfr(return_generator): times = np.arange(sfreq) / sfreq # make epochs 1s long data = rng.random((n_epochs, len(info.ch_names), freqs.size, times.size)) data = data + 1j * data # make complex to simulate amplitude + phase - epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs) + epochs_tfr = EpochsTFRArray(info=info, data=data, times=times, freqs=freqs) epochs_tfr.apply_baseline((0, 0.5)) pick_ori = "vector" - with pytest.raises(ValueError, match="Expected 2 inverse operators, " "got 3"): + with pytest.raises(ValueError, match="Expected 2 inverse operators, got 3"): apply_inverse_tfr_epochs(epochs_tfr, [inverse_operator] * 3, lambda2) # test epochs diff --git a/mne/minimum_norm/time_frequency.py b/mne/minimum_norm/time_frequency.py index 9561e3cd53a..16b76875941 100644 --- a/mne/minimum_norm/time_frequency.py +++ b/mne/minimum_norm/time_frequency.py @@ -861,9 +861,7 @@ def compute_source_psd( tmin = 0.0 if tmin is None else float(tmin) overlap = float(overlap) if not 0 <= overlap < 1: - raise ValueError( - "Overlap must be at least 0 and less than 
1, got %s" % (overlap,) - ) + raise ValueError(f"Overlap must be at least 0 and less than 1, got {overlap}") n_fft = int(n_fft) duration = ((1.0 - overlap) * n_fft) / raw.info["sfreq"] events = make_fixed_length_events(raw, 1, tmin, tmax, duration) @@ -935,7 +933,7 @@ def _compute_source_psd_epochs( use_cps=True, ): """Generate compute_source_psd_epochs.""" - logger.info("Considering frequencies %g ... %g Hz" % (fmin, fmax)) + logger.info(f"Considering frequencies {fmin} ... {fmax} Hz") if label: # TODO: add multi-label support @@ -987,10 +985,10 @@ def _compute_source_psd_epochs( else: extra = "on %d epochs" % (n_epochs,) if isinstance(bandwidth, str): - bandwidth = "%s windowing" % (bandwidth,) + bandwidth = f"{bandwidth} windowing" else: - bandwidth = "%d tapers with bandwidth %0.1f Hz" % (n_tapers, bandwidth) - logger.info("Using %s %s" % (bandwidth, extra)) + bandwidth = f"{n_tapers} tapers with bandwidth {bandwidth:0.1f} Hz" + logger.info(f"Using {bandwidth} {extra}") if adaptive: parallel, my_psd_from_mt_adaptive, n_jobs = parallel_func( diff --git a/mne/misc.py b/mne/misc.py index 937f0eb4c9e..9313f048cbc 100644 --- a/mne/misc.py +++ b/mne/misc.py @@ -24,7 +24,7 @@ def parse_config(fname): """ reject_params = read_reject_parameters(fname) - with open(fname, "r") as f: + with open(fname) as f: lines = f.readlines() cat_ind = [i for i, x in enumerate(lines) if "category {" in x] @@ -69,7 +69,7 @@ def read_reject_parameters(fname): params : dict The rejection parameters. """ - with open(fname, "r") as f: + with open(fname) as f: lines = f.readlines() reject_names = ["gradReject", "magReject", "eegReject", "eogReject", "ecgReject"] @@ -85,7 +85,7 @@ def read_reject_parameters(fname): def read_flat_parameters(fname): """Read flat channel rejection parameters from .cov or .ave config file.""" - with open(fname, "r") as f: + with open(fname) as f: lines = f.readlines() reject_names = ["gradFlat", "magFlat", "eegFlat", "eogFlat", "ecgFlat"] diff --git a/mne/morph.py b/mne/morph.py index eb201e34451..5b8bfba41a7 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -190,7 +190,7 @@ def compute_source_morph( .. footbibliography:: """ src_data, kind, src_subject = _get_src_data(src) - subject_from = _check_subject_src(subject_from, src_subject) + subject_from = _check_subject_src(subject_from, src_subject, warn_none=True) del src _validate_type(src_to, (SourceSpaces, None), "src_to") _validate_type(subject_to, (str, None), "subject_to") @@ -241,7 +241,7 @@ def compute_source_morph( if src_to is None: if kind == "mixed": raise ValueError( - "src_to must be provided when using a " "mixed source space" + "src_to must be provided when using a mixed source space" ) else: surf_offset = 2 if src_to.kind == "mixed" else 0 @@ -268,9 +268,9 @@ def compute_source_morph( vertices_from = src_data["vertices_from"] if sparse: if spacing is not None: - raise ValueError("spacing must be set to None if " "sparse=True.") + raise ValueError("spacing must be set to None if sparse=True.") if xhemi: - raise ValueError("xhemi=True can only be used with " "sparse=False") + raise ValueError("xhemi=True can only be used with sparse=False") vertices_to_surf, morph_mat = _compute_sparse_morph( vertices_from, subject_from, subject_to, subjects_dir ) @@ -556,8 +556,8 @@ def apply( if stc.subject != self.subject_from: raise ValueError( "stc_from.subject and " - "morph.subject_from must match. (%s != %s)" - % (stc.subject, self.subject_from) + "morph.subject_from " + f"must match. 
({stc.subject} != {self.subject_from})" ) out = _apply_morph_data(self, stc) if output != "stc": # convert to volume @@ -736,14 +736,14 @@ def _morph_vols(self, vols, mesg, subselect=True): def __repr__(self): # noqa: D105 s = "%s" % self.kind - s += ", %s -> %s" % (self.subject_from, self.subject_to) + s += f", {self.subject_from} -> {self.subject_to}" if self.kind == "volume": - s += ", zooms : {}".format(self.zooms) - s += ", niter_affine : {}".format(self.niter_affine) - s += ", niter_sdr : {}".format(self.niter_sdr) + s += f", zooms : {self.zooms}" + s += f", niter_affine : {self.niter_affine}" + s += f", niter_sdr : {self.niter_sdr}" elif self.kind in ("surface", "vector"): - s += ", spacing : {}".format(self.spacing) - s += ", smooth : %s" % self.smooth + s += f", spacing : {self.spacing}" + s += f", smooth : {self.smooth}" s += ", xhemi" if self.xhemi else "" return "" % s @@ -802,7 +802,7 @@ def _check_zooms(mri_from, zooms, zooms_src_to): if zooms.shape != (3,): raise ValueError( "zooms must be None, a singleton, or have shape (3,)," - " got shape %s" % (zooms.shape,) + f" got shape {zooms.shape}" ) zooms = tuple(zooms) return zooms @@ -823,30 +823,31 @@ def _resample_from_to(img, affine, to_vox_map): ############################################################################### # I/O -def _check_subject_src(subject, src, name="subject_from", src_name="src"): +def _check_subject_src( + subject, src, name="subject_from", src_name="src", *, warn_none=False +): if isinstance(src, str): subject_check = src elif src is None: # assume it's correct although dangerous but unlikely subject_check = subject else: subject_check = src._subject - if subject_check is None: - warn( - "The source space does not contain the subject name, we " - "recommend regenerating the source space (and forward / " - "inverse if applicable) for better code reliability" - ) + warn_none = True + if subject_check is None and warn_none: + warn( + "The source space does not contain the subject name, we " + "recommend regenerating the source space (and forward / " + "inverse if applicable) for better code reliability" + ) if subject is None: subject = subject_check elif subject_check is not None and subject != subject_check: raise ValueError( - "%s does not match %s subject (%s != %s)" - % (name, src_name, subject, subject_check) + f"{name} does not match {src_name} subject ({subject} != {subject_check})" ) if subject is None: raise ValueError( - "%s could not be inferred from %s, it must be " - "specified" % (name, src_name) + f"{name} could not be inferred from {src_name}, it must be specified" ) return subject @@ -898,8 +899,8 @@ def _check_dep(nibabel="2.1.0", dipy="0.10.1"): if not passed: raise ImportError( - "%s %s or higher must be correctly " - "installed and accessible from Python" % (lib, ver) + f"{lib} {ver} or higher must be correctly " + "installed and accessible from Python" ) @@ -1295,7 +1296,7 @@ def grade_to_vertices(subject, grade, subjects_dir=None, n_jobs=None, verbose=No spheres_to = [ subjects_dir / subject / "surf" / (xh + ".sphere.reg") for xh in ["lh", "rh"] ] - lhs, rhs = [read_surface(s)[0] for s in spheres_to] + lhs, rhs = (read_surface(s)[0] for s in spheres_to) if grade is not None: # fill a subset of vertices if isinstance(grade, list): @@ -1314,18 +1315,18 @@ def grade_to_vertices(subject, grade, subjects_dir=None, n_jobs=None, verbose=No # Compute nearest vertices in high dim mesh parallel, my_compute_nearest, _ = parallel_func(_compute_nearest, n_jobs) - lhs, rhs, rr = 
[a.astype(np.float32) for a in [lhs, rhs, ico["rr"]]] + lhs, rhs, rr = (a.astype(np.float32) for a in [lhs, rhs, ico["rr"]]) vertices = parallel(my_compute_nearest(xhs, rr) for xhs in [lhs, rhs]) # Make sure the vertices are ordered vertices = [np.sort(verts) for verts in vertices] for verts in vertices: if (np.diff(verts) == 0).any(): raise ValueError( - "Cannot use icosahedral grade %s with subject %s, " - "mapping %s vertices onto the high-resolution mesh " + f"Cannot use icosahedral grade {grade} with subject " + f"{subject}, mapping {len(verts)} vertices onto the " + "high-resolution mesh " "yields repeated vertices, use a lower grade or a " "list of vertices from an existing source space" - % (grade, subject, len(verts)) ) else: # potentially fill the surface vertices = [np.arange(lhs.shape[0]), np.arange(rhs.shape[0])] @@ -1449,9 +1450,9 @@ def _check_vertices_match(v1, v2, name): if np.isin(v2, v1).all(): ext = " Vertices were likely excluded during forward computation." raise ValueError( - "vertices do not match between morph (%s) and stc (%s) for %s:\n%s" '\n%s\nPerhaps src_to=fwd["src"] needs to be passed when calling ' - "compute_source_morph.%s" % (len(v1), len(v2), name, v1, v2, ext) + f"vertices do not match between morph ({len(v1)}) and stc ({len(v2)}) " + f'for {name}:\n{v1}\n{v2}\nPerhaps src_to=fwd["src"] needs to be passed ' + f"when calling compute_source_morph.{ext}" ) @@ -1462,8 +1463,8 @@ def _apply_morph_data(morph, stc_from): """Morph a source estimate from one subject to another.""" if stc_from.subject is not None and stc_from.subject != morph.subject_from: raise ValueError( - "stc.subject (%s) != morph.subject_from (%s)" - % (stc_from.subject, morph.subject_from) + f"stc.subject ({stc_from.subject}) != morph.subject_from " + f"({morph.subject_from})" ) _check_option("morph.kind", morph.kind, ("surface", "volume", "mixed")) if morph.kind == "surface": @@ -1540,7 +1541,7 @@ def _apply_morph_data(morph, stc_from): for hemi, v1, v2 in zip( ("left", "right"), morph.src_data["vertices_from"], stc_from.vertices[:2] ): - _check_vertices_match(v1, v2, "%s hemisphere" % (hemi,)) + _check_vertices_match(v1, v2, f"{hemi} hemisphere") from_sl = slice(0, from_surf_stop) assert not from_used[from_sl].any() from_used[from_sl] = True diff --git a/mne/morph_map.py b/mne/morph_map.py index 64eb537b181..643cacf8dea 100644 --- a/mne/morph_map.py +++ b/mne/morph_map.py @@ -155,7 +155,7 @@ def _write_morph_map(fname, subject_from, subject_to, mmap_1, mmap_2): with start_and_end_file(fname) as fid: _write_morph_map_(fid, subject_from, subject_to, mmap_1, mmap_2) except Exception as exp: - warn('Could not write morph-map file "%s" (error: %s)' % (fname, exp)) + warn(f'Could not write morph-map file "{fname}" (error: {exp})') def _write_morph_map_(fid, subject_from, subject_to, mmap_1, mmap_2): diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index 7d0741ab30a..54f1c825c13 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -7,7 +7,6 @@ __all__ = [ "annotate_movement", "annotate_muscle_zscore", "annotate_nan", - "apply_maxfilter", "compute_average_dev_head_t", "compute_bridged_electrodes", "compute_current_source_density", @@ -22,6 +21,7 @@ __all__ = [ "create_eog_epochs", "equalize_bads", "eyetracking", + "find_bad_channels_lof", "find_bad_channels_maxwell", "find_ecg_events", "find_eog_events", @@ -55,6 +55,7 @@ from ._fine_cal import ( read_fine_calibration, write_fine_calibration, ) +from ._lof import find_bad_channels_lof 
from ._peak_finder import peak_finder from ._regress import EOGRegression, read_eog_regression, regress_artifact from .artifact_detection import ( @@ -77,7 +78,6 @@ from .ica import ( ) from .infomax_ import infomax from .interpolate import equalize_bads, interpolate_bridged_electrodes -from .maxfilter import apply_maxfilter from .maxwell import ( compute_maxwell_basis, find_bad_channels_maxwell, diff --git a/mne/preprocessing/_annotate_amplitude.py b/mne/preprocessing/_annotate_amplitude.py index 527e74650f0..2f61b19c3db 100644 --- a/mne/preprocessing/_annotate_amplitude.py +++ b/mne/preprocessing/_annotate_amplitude.py @@ -249,7 +249,7 @@ def _check_min_duration(min_duration, raw_duration): def _reject_short_segments(arr, min_duration_samples): """Check if flat or peak segments are longer than the minimum duration.""" - assert arr.dtype == bool and arr.ndim == 2 + assert arr.dtype == np.dtype(bool) and arr.ndim == 2 for k, ch in enumerate(arr): onsets, offsets = _mask_to_onsets_offsets(ch) _mark_inner(arr[k], onsets, offsets, min_duration_samples) diff --git a/mne/preprocessing/_fine_cal.py b/mne/preprocessing/_fine_cal.py index ca14c4de7e8..585b03fa10c 100644 --- a/mne/preprocessing/_fine_cal.py +++ b/mne/preprocessing/_fine_cal.py @@ -154,13 +154,13 @@ def compute_fine_calibration( cal_list = list() z_list = list() logger.info( - "Adjusting normals for %s magnetometers " - "(averaging over %s time intervals)" % (len(mag_picks), len(time_idxs) - 1) + f"Adjusting normals for {len(mag_picks)} magnetometers " + f"(averaging over {len(time_idxs) - 1} time intervals)" ) for start, stop in zip(time_idxs[:-1], time_idxs[1:]): logger.info( - " Processing interval %0.3f - %0.3f s" - % (start / info["sfreq"], stop / info["sfreq"]) + f" Processing interval {start / info['sfreq']:0.3f} - " + f"{stop / info['sfreq']:0.3f} s" ) data = raw[picks, start:stop][0] if ctc is not None: @@ -190,14 +190,12 @@ def compute_fine_calibration( # if len(grad_picks) > 0: extra = "X direction" if n_imbalance == 1 else ("XYZ directions") - logger.info( - "Computing imbalance for %s gradiometers (%s)" % (len(grad_picks), extra) - ) + logger.info(f"Computing imbalance for {len(grad_picks)} gradiometers ({extra})") imb_list = list() for start, stop in zip(time_idxs[:-1], time_idxs[1:]): logger.info( - " Processing interval %0.3f - %0.3f s" - % (start / info["sfreq"], stop / info["sfreq"]) + f" Processing interval {start / info['sfreq']:0.3f} - " + f"{stop / info['sfreq']:0.3f} s" ) data = raw[picks, start:stop][0] if ctc is not None: @@ -512,7 +510,7 @@ def read_fine_calibration(fname): fname = _check_fname(fname, overwrite="read", must_exist=True) check_fname(fname, "cal", (".dat",)) ch_names, locs, imb_cals = list(), list(), list() - with open(fname, "r") as fid: + with open(fname) as fid: for line in fid: if line[0] in "#\n": continue @@ -521,7 +519,7 @@ def read_fine_calibration(fname): raise RuntimeError( "Error parsing fine calibration file, " "should have 14 or 16 entries per line " - "but found %s on line:\n%s" % (len(vals), line) + f"but found {len(vals)} on line:\n{line}" ) # `vals` contains channel number ch_name = vals[0] diff --git a/mne/preprocessing/_lof.py b/mne/preprocessing/_lof.py new file mode 100644 index 00000000000..6d777599a8a --- /dev/null +++ b/mne/preprocessing/_lof.py @@ -0,0 +1,99 @@ +"""Bad channel detection using Local Outlier Factor (LOF).""" + +# Authors: Velu Prabhakar Kumaravel +# +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. 
+ +import numpy as np + +from .._fiff.pick import _picks_to_idx +from ..io.base import BaseRaw +from ..utils import _soft_import, _validate_type, logger, verbose + + +@verbose +def find_bad_channels_lof( + raw, + n_neighbors=20, + *, + picks=None, + metric="euclidean", + threshold=1.5, + return_scores=False, + verbose=None, +): + """Find bad channels using the Local Outlier Factor (LOF) algorithm. + + Parameters + ---------- + raw : instance of Raw + Raw data to process. + n_neighbors : int + Number of neighbors defining the local neighborhood (default is 20). + Smaller values will lead to higher LOF scores. + %(picks_good_data)s + metric : str + Metric to use for distance computation. Default is "euclidean", + see :func:`sklearn.metrics.pairwise.distance_metrics` for details. + threshold : float + Threshold to define outliers. The theoretical threshold ranges from + 1.0 upward. Default: 1.5. + It is recommended to treat this as a hyperparameter to optimize. + return_scores : bool + If ``True``, also return the LOF scores for each + evaluated channel. Default is ``False``. + %(verbose)s + + Returns + ------- + noisy_chs : list + List of bad M/EEG channels that were automatically detected. + scores : ndarray, shape (n_picks,) + Only returned when ``return_scores`` is ``True``. It contains the + LOF outlier score for each channel in ``picks``. + + See Also + -------- + maxwell_filter + annotate_amplitude + + Notes + ----- + See :footcite:`KumaravelEtAl2022` and :footcite:`BreunigEtAl2000` for background on + choosing ``threshold``. + + .. versionadded:: 1.7 + + References + ---------- + .. footbibliography:: + """ # noqa: E501 + _soft_import("sklearn", "using LOF detection", strict=True) + from sklearn.neighbors import LocalOutlierFactor + + _validate_type(raw, BaseRaw, "raw") + # Get the channel types + channel_types = raw.get_channel_types() + picks = _picks_to_idx(raw.info, picks=picks, none="data", exclude="bads") + picked_ch_types = {channel_types[p] for p in picks} + + # Check if there are different channel types + if len(picked_ch_types) != 1: + raise ValueError( + f"Need exactly one channel type in picks, got {sorted(picked_ch_types)}" + ) + ch_names = [raw.ch_names[pick] for pick in picks] + data = raw.get_data(picks=picks) + clf = LocalOutlierFactor(n_neighbors=n_neighbors, metric=metric) + clf.fit_predict(data) + scores_lof = clf.negative_outlier_factor_ + bad_channel_indices = [ + i for i, v in enumerate(np.abs(scores_lof)) if v >= threshold + ] + bads = [ch_names[idx] for idx in bad_channel_indices] + logger.info(f"LOF: Detected bad channel(s): {bads}") + if return_scores: + return bads, scores_lof + else: + return bads diff --git a/mne/preprocessing/_peak_finder.py b/mne/preprocessing/_peak_finder.py index c1808397991..078e4aadb23 100644 --- a/mne/preprocessing/_peak_finder.py +++ b/mne/preprocessing/_peak_finder.py @@ -56,7 +56,7 @@ def peak_finder(x0, thresh=None, extrema=1, verbose=None): if thresh is None: thresh = (np.max(x0) - np.min(x0)) / 4 - logger.debug("Peak finder automatic threshold: %0.2g" % (thresh,)) + logger.debug(f"Peak finder automatic threshold: {thresh:0.2g}") assert extrema in [-1, 1] @@ -85,7 +85,7 @@ def peak_finder(x0, thresh=None, extrema=1, verbose=None): left_min = min_mag # Deal with first point a little differently since tacked it on - # Calculate the sign of the derivative since we taked the first point + # Calculate the sign of the derivative since we took the first point # on it does not necessarily alternate 
like the rest. signDx = np.sign(np.diff(x[:3])) if signDx[0] <= 0: # The first point is larger or equal to the second diff --git a/mne/preprocessing/_regress.py b/mne/preprocessing/_regress.py index 31a842f7d4f..260796a221d 100644 --- a/mne/preprocessing/_regress.py +++ b/mne/preprocessing/_regress.py @@ -6,7 +6,7 @@ import numpy as np -from .._fiff.pick import _picks_to_idx +from .._fiff.pick import _picks_to_idx, pick_info from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ..epochs import BaseEpochs from ..evoked import Evoked @@ -178,9 +178,7 @@ def fit(self, inst): reference (see :func:`mne.set_eeg_reference`) before performing EOG regression. """ - self._check_inst(inst) - picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude) - picks_artifact = _picks_to_idx(inst.info, self.picks_artifact) + picks, picks_artifact = self._check_inst(inst) # Calculate regression coefficients. Add a row of ones to also fit the # intercept. @@ -232,9 +230,7 @@ def apply(self, inst, copy=True): """ if copy: inst = inst.copy() - self._check_inst(inst) - picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude) - picks_artifact = _picks_to_idx(inst.info, self.picks_artifact) + picks, picks_artifact = self._check_inst(inst) # Check that the channels are compatible with the regression weights. ref_picks = _picks_to_idx( @@ -324,19 +320,25 @@ def _check_inst(self, inst): _validate_type( inst, (BaseRaw, BaseEpochs, Evoked), "inst", "Raw, Epochs, Evoked" ) - if _needs_eeg_average_ref_proj(inst.info): + picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude) + picks_artifact = _picks_to_idx(inst.info, self.picks_artifact) + all_picks = np.unique(np.concatenate([picks, picks_artifact])) + use_info = pick_info(inst.info, all_picks) + del all_picks + if _needs_eeg_average_ref_proj(use_info): raise RuntimeError( - "No reference for the EEG channels has been " - "set. Use inst.set_eeg_reference() to do so." + "No average reference for the EEG channels has been " + "set. Use inst.set_eeg_reference(projection=True) to do so." ) if self.proj and not inst.proj: inst.apply_proj() - if not inst.proj and len(inst.info.get("projs", [])) > 0: + if not inst.proj and len(use_info.get("projs", [])) > 0: raise RuntimeError( "Projections need to be applied before " "regression can be performed. Use the " ".apply_proj() method to do so." ) + return picks, picks_artifact def __repr__(self): """Produce a string representation of this object.""" diff --git a/mne/preprocessing/artifact_detection.py b/mne/preprocessing/artifact_detection.py index d5bcfccb730..6b69bc9abca 100644 --- a/mne/preprocessing/artifact_detection.py +++ b/mne/preprocessing/artifact_detection.py @@ -25,7 +25,15 @@ apply_trans, quat_to_rot, ) -from ..utils import _mask_to_onsets_offsets, _pl, _validate_type, logger, verbose +from ..utils import ( + _check_option, + _mask_to_onsets_offsets, + _pl, + _validate_type, + logger, + verbose, + warn, +) @verbose @@ -94,16 +102,13 @@ def annotate_muscle_zscore( ch_type = "eeg" else: raise ValueError( - "No M/EEG channel types found, please specify a" - " ch_type or provide M/EEG sensor data" + "No M/EEG channel types found, please specify a 'ch_type' or provide " + "M/EEG sensor data." 
) - logger.info("Using %s sensors for muscle artifact detection" % (ch_type)) - - if ch_type in ("mag", "grad"): - raw_copy.pick(ch_type) + logger.info("Using %s sensors for muscle artifact detection", ch_type) else: - ch_type = {"meg": False, ch_type: True} - raw_copy.pick(**ch_type) + _check_option("ch_type", ch_type, ["mag", "grad", "eeg"]) + raw_copy.pick(ch_type) raw_copy.filter( filter_freq[0], @@ -245,7 +250,7 @@ def annotate_movement( if use_dev_head_trans not in ["average", "info"]: raise ValueError( "use_dev_head_trans must be either" - + " 'average' or 'info': got '%s'" % (use_dev_head_trans,) + f" 'average' or 'info': got '{use_dev_head_trans}'" ) if use_dev_head_trans == "average": @@ -289,7 +294,8 @@ def annotate_movement( return annot, disp -def compute_average_dev_head_t(raw, pos): +@verbose +def compute_average_dev_head_t(raw, pos, *, verbose=None): """Get new device to head transform based on good segments. Segments starting with "BAD" annotations are not included for calculating @@ -297,19 +303,59 @@ def compute_average_dev_head_t(raw, pos): Parameters ---------- - raw : instance of Raw - Data to compute head position. - pos : array, shape (N, 10) - The position and quaternion parameters from cHPI fitting. + raw : instance of Raw | list of Raw + Data to compute head position. Can be a list containing multiple raw + instances. + pos : array, shape (N, 10) | list of ndarray + The position and quaternion parameters from cHPI fitting. Can be + a list containing multiple position arrays, one per raw instance passed. + %(verbose)s Returns ------- dev_head_t : instance of Transform New ``dev_head_t`` transformation using the averaged good head positions. + + Notes + ----- + .. versionchanged:: 1.7 + Support for multiple raw instances and position arrays was added. 
""" + # Get weighted head pos trans and rot + if not isinstance(raw, (list, tuple)): + raw = [raw] + if not isinstance(pos, (list, tuple)): + pos = [pos] + if len(pos) != len(raw): + raise ValueError( + f"Number of head positions ({len(pos)}) must match the number of raw " + f"instances ({len(raw)})" + ) + hp = list() + dt = list() + for ri, (r, p) in enumerate(zip(raw, pos)): + _validate_type(r, BaseRaw, f"raw[{ri}]") + _validate_type(p, np.ndarray, f"pos[{ri}]") + hp_, dt_ = _raw_hp_weights(r, p) + hp.append(hp_) + dt.append(dt_) + hp = np.concatenate(hp, axis=0) + dt = np.concatenate(dt, axis=0) + dt /= dt.sum() + best_q = _average_quats(hp[:, 1:4], weights=dt) + trans = np.eye(4) + trans[:3, :3] = quat_to_rot(best_q) + trans[:3, 3] = dt @ hp[:, 4:7] + dist = np.linalg.norm(trans[:3, 3]) + if dist > 1: # less than 1 meter is sane + warn(f"Implausible head position detected: {dist} meters from device origin") + dev_head_t = Transform("meg", "head", trans) + return dev_head_t + + +def _raw_hp_weights(raw, pos): sfreq = raw.info["sfreq"] seg_good = np.ones(len(raw.times)) - trans_pos = np.zeros(3) hp = pos.copy() hp_ts = hp[:, 0] - raw._first_time @@ -349,19 +395,7 @@ def compute_average_dev_head_t(raw, pos): assert (dt >= 0).all() dt = dt / sfreq del seg_good, idx - - # Get weighted head pos trans and rot - trans_pos += np.dot(dt, hp[:, 4:7]) - - rot_qs = hp[:, 1:4] - best_q = _average_quats(rot_qs, weights=dt) - - trans = np.eye(4) - trans[:3, :3] = quat_to_rot(best_q) - trans[:3, 3] = trans_pos / dt.sum() - assert np.linalg.norm(trans[:3, 3]) < 1 # less than 1 meter is sane - dev_head_t = Transform("meg", "head", trans) - return dev_head_t + return hp, dt def _annotations_from_mask(times, mask, annot_name, orig_time=None): @@ -599,7 +633,7 @@ def annotate_break( # Log some info n_breaks = len(break_annotations) break_times = [ - f"{o:.1f} – {o+d:.1f} s [{d:.1f} s]" + f"{o:.1f} – {o + d:.1f} s [{d:.1f} s]" for o, d in zip(break_annotations.onset, break_annotations.duration) ] break_times = "\n ".join(break_times) diff --git a/mne/preprocessing/ecg.py b/mne/preprocessing/ecg.py index d773f72ba41..e36319316b1 100644 --- a/mne/preprocessing/ecg.py +++ b/mne/preprocessing/ecg.py @@ -322,7 +322,7 @@ def _get_ecg_channel_index(ch_name, inst): ) else: if ch_name not in inst.ch_names: - raise ValueError("%s not in channel list (%s)" % (ch_name, inst.ch_names)) + raise ValueError(f"{ch_name} not in channel list ({inst.ch_names})") ecg_idx = pick_channels(inst.ch_names, include=[ch_name]) if len(ecg_idx) == 0: diff --git a/mne/preprocessing/eyetracking/__init__.py b/mne/preprocessing/eyetracking/__init__.py index 01a30bf4436..efab0fb079d 100644 --- a/mne/preprocessing/eyetracking/__init__.py +++ b/mne/preprocessing/eyetracking/__init__.py @@ -5,6 +5,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
-from .eyetracking import set_channel_types_eyetrack +from .eyetracking import set_channel_types_eyetrack, convert_units from .calibration import Calibration, read_eyelink_calibration from ._pupillometry import interpolate_blinks +from .utils import get_screen_visual_angle diff --git a/mne/preprocessing/eyetracking/_pupillometry.py b/mne/preprocessing/eyetracking/_pupillometry.py index b1d544f24ab..8da124b2e1f 100644 --- a/mne/preprocessing/eyetracking/_pupillometry.py +++ b/mne/preprocessing/eyetracking/_pupillometry.py @@ -60,14 +60,14 @@ def interpolate_blinks(raw, buffer=0.05, match="BAD_blink", interpolate_gaze=Fal # get the blink annotations blink_annots = [annot for annot in raw.annotations if annot["description"] in match] if not blink_annots: - warn("No annotations matching {} found. Aborting.".format(match)) + warn(f"No annotations matching {match} found. Aborting.") return raw _interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze=interpolate_gaze) # remove bad from the annotation description for desc in match: if desc.startswith("BAD_"): - logger.info("Removing 'BAD_' from {}.".format(desc)) + logger.info(f"Removing 'BAD_' from {desc}.") raw.annotations.rename({desc: desc.replace("BAD_", "")}) return raw @@ -77,6 +77,7 @@ def _interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze): logger.info("Interpolating missing data during blinks...") pre_buffer, post_buffer = buffer # iterate over each eyetrack channel and interpolate the blinks + interpolated_chs = [] for ci, ch_info in enumerate(raw.info["chs"]): if interpolate_gaze: # interpolate over all eyetrack channels if ch_info["kind"] != FIFF.FIFFV_EYETRACK_CH: @@ -107,3 +108,10 @@ def _interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze): ) # Replace the samples at the blink_indices with the interpolated values raw._data[ci, blink_indices] = interpolated_samples + interpolated_chs.append(ch_info["ch_name"]) + if interpolated_chs: + logger.info( + f"Interpolated {len(interpolated_chs)} channels: {interpolated_chs}" + ) + else: + warn("No channels were interpolated.") diff --git a/mne/preprocessing/eyetracking/calibration.py b/mne/preprocessing/eyetracking/calibration.py index e405e72f9eb..84b53ee3006 100644 --- a/mne/preprocessing/eyetracking/calibration.py +++ b/mne/preprocessing/eyetracking/calibration.py @@ -219,6 +219,6 @@ def read_eyelink_calibration( each eye of every calibration that was performed during the recording session. 
""" fname = _check_fname(fname, overwrite="read", must_exist=True, name="fname") - logger.info("Reading calibration data from {}".format(fname)) + logger.info(f"Reading calibration data from {fname}") lines = fname.read_text(encoding="ASCII").splitlines() return _parse_calibration(lines, screen_size, screen_distance, screen_resolution) diff --git a/mne/preprocessing/eyetracking/eyetracking.py b/mne/preprocessing/eyetracking/eyetracking.py index ab3d51c6af1..883cf1934c6 100644 --- a/mne/preprocessing/eyetracking/eyetracking.py +++ b/mne/preprocessing/eyetracking/eyetracking.py @@ -8,6 +8,12 @@ import numpy as np from ..._fiff.constants import FIFF +from ...epochs import BaseEpochs +from ...evoked import Evoked +from ...io import BaseRaw +from ...utils import _check_option, _validate_type, logger, warn +from .calibration import Calibration +from .utils import _check_calibration # specific function to set eyetrack channels @@ -78,8 +84,7 @@ def set_channel_types_eyetrack(inst, mapping): ch_type = ch_desc[0].lower() if ch_type not in valid_types: raise ValueError( - "ch_type must be one of {}. " - "Got '{}' instead.".format(valid_types, ch_type) + f"ch_type must be one of {valid_types}. Got '{ch_type}' instead." ) if ch_type == "eyegaze": coil_type = FIFF.FIFFV_COIL_EYETRACK_POS @@ -165,3 +170,162 @@ def _convert_mm_to_m(array): def _convert_deg_to_rad(array): return array * np.pi / 180.0 + + +def convert_units(inst, calibration, to="radians"): + """Convert Eyegaze data from pixels to radians of visual angle or vice versa. + + .. warning:: + Currently, depending on the units (pixels or radians), eyegaze channels may not + be reported correctly in visualization functions like :meth:`mne.io.Raw.plot`. + They will be shown correctly in :func:`mne.viz.eyetracking.plot_gaze`. + See :gh:`11879` for more information. + + .. Important:: + There are important considerations to keep in mind when using this function, + see the Notes section below. + + Parameters + ---------- + inst : instance of Raw, Epochs, or Evoked + The Raw, Epochs, or Evoked instance with eyegaze channels. + calibration : Calibration + Instance of Calibration, containing information about the screen size + (in meters), viewing distance (in meters), and the screen resolution + (in pixels). + to : str + Must be either ``"radians"`` or ``"pixels"``, indicating the desired unit. + + Returns + ------- + inst : instance of Raw | Epochs | Evoked + The Raw, Epochs, or Evoked instance, modified in place. + + Notes + ----- + There are at least two important considerations to keep in mind when using this + function: + + 1. Converting between on-screen pixels and visual angle is not a linear + transformation. If the visual angle subtends less than approximately ``.44`` + radians (``25`` degrees), the conversion could be considered to be approximately + linear. However, as the visual angle increases, the conversion becomes + increasingly non-linear. This may lead to unexpected results after converting + between pixels and visual angle. + + * This function assumes that the head is fixed in place and aligned with the center + of the screen, such that gaze to the center of the screen results in a visual + angle of ``0`` radians. + + .. 
versionadded:: 1.7 + """ + _validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst") + _validate_type(calibration, Calibration, "calibration") + _check_option("to", to, ("radians", "pixels")) + _check_calibration(calibration) + + # get screen parameters + screen_size = calibration["screen_size"] + screen_resolution = calibration["screen_resolution"] + dist = calibration["screen_distance"] + + # loop through channels and convert units + converted_chs = [] + for ch_dict in inst.info["chs"]: + if ch_dict["coil_type"] != FIFF.FIFFV_COIL_EYETRACK_POS: + continue + unit = ch_dict["unit"] + name = ch_dict["ch_name"] + + if ch_dict["loc"][4] == -1: # x-coordinate + size = screen_size[0] + res = screen_resolution[0] + elif ch_dict["loc"][4] == 1: # y-coordinate + size = screen_size[1] + res = screen_resolution[1] + else: + raise ValueError( + f"loc array not set properly for channel '{name}'. Index 4 should" + f" be -1 or 1, but got {ch_dict['loc'][4]}" + ) + # check unit, convert, and set new unit + if to == "radians": + if unit != FIFF.FIFF_UNIT_PX: + raise ValueError( + f"Data must be in pixels in order to convert to radians." + f" Got {unit} for {name}" + ) + inst.apply_function(_pix_to_rad, picks=name, size=size, res=res, dist=dist) + ch_dict["unit"] = FIFF.FIFF_UNIT_RAD + elif to == "pixels": + if unit != FIFF.FIFF_UNIT_RAD: + raise ValueError( + f"Data must be in radians in order to convert to pixels." + f" Got {unit} for {name}" + ) + inst.apply_function(_rad_to_pix, picks=name, size=size, res=res, dist=dist) + ch_dict["unit"] = FIFF.FIFF_UNIT_PX + converted_chs.append(name) + if converted_chs: + logger.info(f"Converted {converted_chs} to {to}.") + if to == "radians": + # check if any values are greaater than .44 radians + # (25 degrees) and warn user + data = inst.get_data(picks=converted_chs) + if np.any(np.abs(data) > 0.52): + warn( + "Some visual angle values subtend greater than .52 radians " + "(30 degrees), meaning that the conversion between pixels " + "and visual angle may be very non-linear. Take caution when " + "interpreting these values. Max visual angle value in data:" + f" {np.nanmax(data):0.2f} radians.", + UserWarning, + ) + else: + warn("Could not find any eyegaze channels. Doing nothing.", UserWarning) + return inst + + +def _pix_to_rad(data, size, res, dist): + """Convert pixel coordinates to radians of visual angle. + + Parameters + ---------- + data : array-like, shape (n_samples,) + A vector of pixel coordinates. + size : float + The width or height of the screen, in meters. + res : int + The screen resolution in pixels, along the x or y axis. + dist : float + The viewing distance from the screen, in meters. + + Returns + ------- + rad : ndarray, shape (n_samples) + the data in radians. + """ + # Center the data so that 0 radians will be the center of the screen + data -= res / 2 + # How many meters is the pixel width or height + px_size = size / res + # Convert to radians + return np.arctan((data * px_size) / dist) + + +def _rad_to_pix(data, size, res, dist): + """Convert radians of visual angle to pixel coordinates. + + See the parameters section of _pix_to_rad for more information. + + Returns + ------- + pix : ndarray, shape (n_samples) + the data in pixels. + """ + # How many meters is the pixel width or height + px_size = size / res + # 1. calculate length of opposite side of triangle (in meters) + # 2. convert meters to pixel coordinates + # 3. 
add half of screen resolution to uncenter the pixel data (0,0 is top left) + return np.tan(data) * dist / px_size + res / 2 diff --git a/mne/preprocessing/eyetracking/tests/test_eyetracking.py b/mne/preprocessing/eyetracking/tests/test_eyetracking.py new file mode 100644 index 00000000000..8bea006d9fd --- /dev/null +++ b/mne/preprocessing/eyetracking/tests/test_eyetracking.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest +from numpy.testing import assert_allclose + +import mne +from mne._fiff.constants import FIFF +from mne.utils import _record_warnings + + +def test_set_channel_types_eyetrack(eyetrack_raw): + """Test that set_channel_types_eyetrack worked on the fixture.""" + assert eyetrack_raw.info["chs"][0]["kind"] == FIFF.FIFFV_EYETRACK_CH + assert eyetrack_raw.info["chs"][1]["coil_type"] == FIFF.FIFFV_COIL_EYETRACK_POS + assert eyetrack_raw.info["chs"][0]["unit"] == FIFF.FIFF_UNIT_PX + assert eyetrack_raw.info["chs"][2]["unit"] == FIFF.FIFF_UNIT_NONE + + +def test_convert_units(eyetrack_raw, eyetrack_cal): + """Test unit conversion.""" + raw, cal = eyetrack_raw, eyetrack_cal # shorter names + + # roundtrip conversion should be identical to original data + data_orig = raw.get_data(picks=[0]) # take the first x-coord channel + mne.preprocessing.eyetracking.convert_units(raw, cal, "radians") + assert raw.info["chs"][0]["unit"] == FIFF.FIFF_UNIT_RAD + # Gaze was to center of screen, so x-coord and y-coord should now be 0 radians + assert_allclose(raw.get_data(picks=[0, 1]), 0) + + # Should raise an error if we try to convert to radians again + with pytest.raises(ValueError, match="Data must be in"): + mne.preprocessing.eyetracking.convert_units(raw, cal, "radians") + + # Convert back to pixels + mne.preprocessing.eyetracking.convert_units(raw, cal, "pixels") + assert raw.info["chs"][1]["unit"] == FIFF.FIFF_UNIT_PX + data_new = raw.get_data(picks=[0]) + assert_allclose(data_orig, data_new) + + # Should raise an error if we try to convert to pixels again + with pytest.raises(ValueError, match="Data must be in"): + mne.preprocessing.eyetracking.convert_units(raw, cal, "pixels") + + # Finally, check that we raise other errors or warnings when we should + # warn if no eyegaze channels found + raw_misc = raw.copy() + with _record_warnings(): # channel units change warning + raw_misc.set_channel_types({ch: "misc" for ch in raw_misc.ch_names}) + with pytest.warns(UserWarning, match="Could not"): + mne.preprocessing.eyetracking.convert_units(raw_misc, cal, "radians") + + # raise an error if the calibration is missing a key + bad_cal = cal.copy() + bad_cal.pop("screen_size") + bad_cal["screen_distance"] = None + with pytest.raises(KeyError, match="Calibration object must have the following"): + mne.preprocessing.eyetracking.convert_units(raw, bad_cal, "radians") + + # warn if visual angle is too large + cal_tmp = cal.copy() + cal_tmp["screen_distance"] = 0.1 + raw_tmp = raw.copy() + raw_tmp._data[0, :10] = 1900 # gaze to extremity of screen + with pytest.warns(UserWarning, match="Some visual angle values"): + mne.preprocessing.eyetracking.convert_units(raw_tmp, cal_tmp, "radians") + + # raise an error if channel locations not set + raw_missing = raw.copy() + raw_missing.info["chs"][0]["loc"] = np.zeros(12) + with pytest.raises(ValueError, match="loc array not set"): + mne.preprocessing.eyetracking.convert_units(raw_missing, cal, "radians") + + +def test_get_screen_visual_angle(eyetrack_cal): + """Test calculating the radians of visual angle for a screen.""" + # Our toy calibration should 
subtend .56 x .32 radians i.e 31.5 x 18.26 degrees + viz_angle = mne.preprocessing.eyetracking.get_screen_visual_angle(eyetrack_cal) + assert viz_angle.shape == (2,) + np.testing.assert_allclose(np.round(viz_angle, 2), (0.56, 0.32)) diff --git a/mne/preprocessing/eyetracking/utils.py b/mne/preprocessing/eyetracking/utils.py new file mode 100644 index 00000000000..89c379c9760 --- /dev/null +++ b/mne/preprocessing/eyetracking/utils.py @@ -0,0 +1,41 @@ +import numpy as np + +from ...utils import _validate_type +from .calibration import Calibration + + +def _check_calibration( + calibration, want_keys=("screen_size", "screen_resolution", "screen_distance") +): + missing_keys = [] + for key in want_keys: + if calibration.get(key, None) is None: + missing_keys.append(key) + + if missing_keys: + raise KeyError( + "Calibration object must have the following keys with valid values:" + f" {', '.join(missing_keys)}" + ) + else: + return True + + +def get_screen_visual_angle(calibration): + """Calculate the radians of visual angle that the participant screen subtends. + + Parameters + ---------- + calibration : Calibration + An instance of Calibration. Must have valid values for ``"screen_size"`` and + ``"screen_distance"`` keys. + + Returns + ------- + visual angle in radians : ndarray, shape (2,) + The visual angle of the monitor width and height, respectively. + """ + _validate_type(calibration, Calibration, "calibration") + _check_calibration(calibration, want_keys=("screen_size", "screen_distance")) + size = np.array(calibration["screen_size"]) + return 2 * np.arctan(size / (2 * calibration["screen_distance"])) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 64667185330..85bd312f3b2 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -16,7 +16,7 @@ from inspect import Parameter, isfunction, signature from numbers import Integral from time import time -from typing import Dict, List, Literal, Optional, Union +from typing import Literal, Optional, Union import numpy as np from scipy import linalg, stats @@ -189,8 +189,8 @@ def _check_for_unsupported_ica_channels(picks, info, allow_ref_meg=False): check = all([ch in types for ch in chs]) if not check: raise ValueError( - "Invalid channel type%s passed for ICA: %s." - "Only the following types are supported: %s" % (_pl(chs), chs, types) + f"Invalid channel type{_pl(chs)} passed for ICA: {chs}." 
+ f"Only the following types are supported: {types}" ) @@ -445,7 +445,7 @@ def __init__( max_iter="auto", allow_ref_meg=False, verbose=None, - ): # noqa: D102 + ): _validate_type(method, str, "method") _validate_type(n_components, (float, "int-like", None)) @@ -508,13 +508,13 @@ def _get_infos_for_repr(self): class _InfosForRepr: fit_on: Optional[Literal["raw data", "epochs"]] fit_method: Literal["fastica", "infomax", "extended-infomax", "picard"] - fit_params: Dict[str, Union[str, float]] + fit_params: dict[str, Union[str, float]] fit_n_iter: Optional[int] fit_n_samples: Optional[int] fit_n_components: Optional[int] fit_n_pca_components: Optional[int] - ch_types: List[str] - excludes: List[str] + ch_types: list[str] + excludes: list[str] if self.current_fit == "unfitted": fit_on = None @@ -754,7 +754,7 @@ def fit( var_ord = var.argsort()[::-1] _sort_components(self, var_ord, copy=False) t_stop = time() - logger.info("Fitting ICA took {:.1f}s.".format(t_stop - t_start)) + logger.info(f"Fitting ICA took {t_stop - t_start:.1f}s.") return self def _reset(self): @@ -818,8 +818,8 @@ def _fit_epochs(self, epochs, picks, decim, verbose): """Aux method.""" if epochs.events.size == 0: raise RuntimeError( - "Tried to fit ICA with epochs, but none were " - 'found: epochs.events is "{}".'.format(epochs.events) + "Tried to fit ICA with epochs, but none were found: epochs.events is " + f'"{epochs.events}".' ) # this should be a copy (picks a list of int) @@ -935,7 +935,7 @@ def _fit(self, data, fit_type): f"n_pca_components ({self.n_pca_components}) results in " f"only {n_pca} components (EV={evs[1]:0.1f}%)" ) - logger.info("%s: %s components" % (msg, self.n_components_)) + logger.info(f"{msg}: {self.n_components_} components") # the things to store for PCA self.pca_mean_ = pca.mean_ @@ -1550,7 +1550,7 @@ def _find_bads_ch( elif measure == "correlation": this_idx = np.where(abs(scores[-1]) > threshold)[0] else: - raise ValueError("Unknown measure {}".format(measure)) + raise ValueError(f"Unknown measure {measure}") idx += [this_idx] self.labels_["%s/%i/" % (prefix, ii) + ch] = list(this_idx) @@ -2784,7 +2784,7 @@ def _get_target_ch(container, target): picks = list(set(picks) - set(ref_picks)) if len(picks) == 0: - raise ValueError("%s not in channel list (%s)" % (target, container.ch_names)) + raise ValueError(f"{target} not in channel list ({container.ch_names})") return picks @@ -3063,7 +3063,7 @@ def read_ica(fname, verbose=None): fid.close() - ica_init, ica_misc = [_deserialize(k) for k in (ica_init, ica_misc)] + ica_init, ica_misc = (_deserialize(k) for k in (ica_init, ica_misc)) n_pca_components = ica_init.pop("n_pca_components") current_fit = ica_init.pop("current_fit") max_pca_components = ica_init.pop("max_pca_components") @@ -3341,7 +3341,7 @@ def corrmap( template_fig, labelled_ics = None, None if plot is True: if is_subject: # plotting from an ICA object - ttl = "Template from subj. {}".format(str(template[0])) + ttl = f"Template from subj. 
{str(template[0])}" template_fig = icas[template[0]].plot_components( picks=template[1], ch_type=ch_type, @@ -3376,8 +3376,8 @@ def corrmap( threshold = np.atleast_1d(np.array(threshold, float)).ravel() threshold_err = ( "No component detected using when z-scoring " - "threshold%s %s, consider using a more lenient " - "threshold" % (threshold_extra, threshold) + f"threshold{threshold_extra} {threshold}, consider using a more lenient " + "threshold" ) if len(all_maps) == 0: raise RuntimeError(threshold_err) @@ -3393,7 +3393,7 @@ def corrmap( # find iteration with highest avg correlation with target _, median_corr, _, max_corrs = paths[np.argmax([path[1] for path in paths])] - allmaps, indices, subjs, nones = [list() for _ in range(4)] + allmaps, indices, subjs, nones = (list() for _ in range(4)) logger.info("Median correlation with constructed map: %0.3f" % median_corr) del median_corr if plot is True: diff --git a/mne/preprocessing/infomax_.py b/mne/preprocessing/infomax_.py index 9b2841caa20..0f873c9d0bd 100644 --- a/mne/preprocessing/infomax_.py +++ b/mne/preprocessing/infomax_.py @@ -145,7 +145,7 @@ def infomax( if block is None: block = int(math.floor(math.sqrt(n_samples / 3.0))) - logger.info("Computing%sInfomax ICA" % " Extended " if extended else " ") + logger.info(f"Computing{' Extended ' if extended else ' '}Infomax ICA") # collect parameters nblock = n_samples // block diff --git a/mne/preprocessing/interpolate.py b/mne/preprocessing/interpolate.py index 8e69f364a10..fc9b3c0fdec 100644 --- a/mne/preprocessing/interpolate.py +++ b/mne/preprocessing/interpolate.py @@ -45,9 +45,7 @@ def equalize_bads(insts, interp_thresh=1.0, copy=True): them, possibly with some formerly bad channels interpolated. """ if not 0 <= interp_thresh <= 1: - raise ValueError( - "interp_thresh must be between 0 and 1, got %s" % (interp_thresh,) - ) + raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}") all_bads = list(set(chain.from_iterable([inst.info["bads"] for inst in insts]))) if isinstance(insts[0], BaseEpochs): @@ -123,8 +121,7 @@ def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4): pos = montage.get_positions() if pos["coord_frame"] != "head": raise RuntimeError( - "Montage channel positions must be in ``head``" - "got {}".format(pos["coord_frame"]) + f"Montage channel positions must be in ``head`` got {pos['coord_frame']}" ) # store bads orig to put back at the end bads_orig = inst.info["bads"] @@ -164,7 +161,7 @@ def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4): # compute centroid position in spherical "head" coordinates pos_virtual = _find_centroid_sphere(pos["ch_pos"], group_names) # create the virtual channel info and set the position - virtual_info = create_info([f"virtual {k+1}"], inst.info["sfreq"], "eeg") + virtual_info = create_info([f"virtual {k + 1}"], inst.info["sfreq"], "eeg") virtual_info["chs"][0]["loc"][:3] = pos_virtual # create virtual channel data = inst.get_data(picks=group_names) @@ -183,7 +180,7 @@ def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4): nave=inst.nave, kind=inst.kind, ) - virtual_chs[f"virtual {k+1}"] = virtual_ch + virtual_chs[f"virtual {k + 1}"] = virtual_ch # add the virtual channels inst.add_channels(list(virtual_chs.values()), force_update_info=True) diff --git a/mne/preprocessing/maxfilter.py b/mne/preprocessing/maxfilter.py deleted file mode 100644 index 64a48b68cf3..00000000000 --- a/mne/preprocessing/maxfilter.py +++ /dev/null @@ -1,230 +0,0 @@ -# Authors: Alexandre Gramfort -# 
Matti Hämäläinen -# Martin Luessi -# -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - -import os - -from ..bem import fit_sphere_to_headshape -from ..io import read_raw_fif -from ..utils import deprecated, logger, verbose, warn - - -def _mxwarn(msg): - """Warn about a bug.""" - warn( - "Possible MaxFilter bug: %s, more info: " - "http://imaging.mrc-cbu.cam.ac.uk/meg/maxbugs" % msg - ) - - -@deprecated( - "apply_maxfilter will be removed in 1.7, use mne.preprocessing.maxwell_filter or " - "the MEGIN command-line utility maxfilter and mne.bem.fit_sphere_to_headshape " - "instead." -) -@verbose -def apply_maxfilter( - in_fname, - out_fname, - origin=None, - frame="device", - bad=None, - autobad="off", - skip=None, - force=False, - st=False, - st_buflen=16.0, - st_corr=0.96, - mv_trans=None, - mv_comp=False, - mv_headpos=False, - mv_hp=None, - mv_hpistep=None, - mv_hpisubt=None, - mv_hpicons=True, - linefreq=None, - cal=None, - ctc=None, - mx_args="", - overwrite=True, - verbose=None, -): - """Apply NeuroMag MaxFilter to raw data. - - Needs Maxfilter license, maxfilter has to be in PATH. - - Parameters - ---------- - in_fname : path-like - Input file name. - out_fname : path-like - Output file name. - origin : array-like or str - Head origin in mm. If None it will be estimated from headshape points. - frame : ``'device'`` | ``'head'`` - Coordinate frame for head center. - bad : str, list (or None) - List of static bad channels. Can be a list with channel names, or a - string with channels (names or logical channel numbers). - autobad : str ('on', 'off', 'n') - Sets automated bad channel detection on or off. - skip : str or a list of float-tuples (or None) - Skips raw data sequences, time intervals pairs in s, - e.g.: 0 30 120 150. - force : bool - Ignore program warnings. - st : bool - Apply the time-domain MaxST extension. - st_buflen : float - MaxSt buffer length in s (disabled if st is False). - st_corr : float - MaxSt subspace correlation limit (disabled if st is False). - mv_trans : str (filename or 'default') (or None) - Transforms the data into the coil definitions of in_fname, or into the - default frame (None: don't use option). - mv_comp : bool (or 'inter') - Estimates and compensates head movements in continuous raw data. - mv_headpos : bool - Estimates and stores head position parameters, but does not compensate - movements (disabled if mv_comp is False). - mv_hp : str (or None) - Stores head position data in an ascii file - (disabled if mv_comp is False). - mv_hpistep : float (or None) - Sets head position update interval in ms (disabled if mv_comp is - False). - mv_hpisubt : str ('amp', 'base', 'off') (or None) - Subtracts hpi signals: sine amplitudes, amp + baseline, or switch off - (disabled if mv_comp is False). - mv_hpicons : bool - Check initial consistency isotrak vs hpifit - (disabled if mv_comp is False). - linefreq : int (50, 60) (or None) - Sets the basic line interference frequency (50 or 60 Hz) - (None: do not use line filter). - cal : str - Path to calibration file. - ctc : str - Path to Cross-talk compensation file. - mx_args : str - Additional command line arguments to pass to MaxFilter. - %(overwrite)s - %(verbose)s - - Returns - ------- - origin: str - Head origin in selected coordinate frame. 
- """ - # check for possible maxfilter bugs - if mv_trans is not None and mv_comp: - _mxwarn("Don't use '-trans' with head-movement compensation " "'-movecomp'") - - if autobad != "off" and (mv_headpos or mv_comp): - _mxwarn( - "Don't use '-autobad' with head-position estimation " - "'-headpos' or movement compensation '-movecomp'" - ) - - if st and autobad != "off": - _mxwarn("Don't use '-autobad' with '-st' option") - - # determine the head origin if necessary - if origin is None: - logger.info("Estimating head origin from headshape points..") - raw = read_raw_fif(in_fname) - r, o_head, o_dev = fit_sphere_to_headshape(raw.info, units="mm") - raw.close() - logger.info("[done]") - if frame == "head": - origin = o_head - elif frame == "device": - origin = o_dev - else: - raise RuntimeError("invalid frame for origin") - - if not isinstance(origin, str): - origin = "%0.1f %0.1f %0.1f" % (origin[0], origin[1], origin[2]) - - # format command - cmd = "maxfilter -f %s -o %s -frame %s -origin %s " % ( - in_fname, - out_fname, - frame, - origin, - ) - - if bad is not None: - # format the channels - if not isinstance(bad, list): - bad = bad.split() - bad = map(str, bad) - bad_logic = [ch[3:] if ch.startswith("MEG") else ch for ch in bad] - bad_str = " ".join(bad_logic) - - cmd += "-bad %s " % bad_str - - cmd += "-autobad %s " % autobad - - if skip is not None: - if isinstance(skip, list): - skip = " ".join(["%0.3f %0.3f" % (s[0], s[1]) for s in skip]) - cmd += "-skip %s " % skip - - if force: - cmd += "-force " - - if st: - cmd += "-st " - cmd += " %d " % st_buflen - cmd += "-corr %0.4f " % st_corr - - if mv_trans is not None: - cmd += "-trans %s " % mv_trans - - if mv_comp: - cmd += "-movecomp " - if mv_comp == "inter": - cmd += " inter " - - if mv_headpos: - cmd += "-headpos " - - if mv_hp is not None: - cmd += "-hp %s " % mv_hp - - if mv_hpisubt is not None: - cmd += "hpisubt %s " % mv_hpisubt - - if mv_hpicons: - cmd += "-hpicons " - - if linefreq is not None: - cmd += "-linefreq %d " % linefreq - - if cal is not None: - cmd += "-cal %s " % cal - - if ctc is not None: - cmd += "-ctc %s " % ctc - - cmd += mx_args - - if overwrite and os.path.exists(out_fname): - os.remove(out_fname) - - logger.info("Running MaxFilter: %s " % cmd) - if os.getenv("_MNE_MAXFILTER_TEST", "") != "true": # fake maxfilter - # OK to `nosec` because it's deprecated / will be removed - st = os.system(cmd) # nosec B605 - else: - print(cmd) # we can check the output - st = 0 - if st != 0: - raise RuntimeError("MaxFilter returned non-zero exit status %d" % st) - logger.info("[done]") - - return origin diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index 25430db6f9e..8f4f5c64521 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -503,8 +503,8 @@ def _prep_maxwell_filter( missing = sorted(set(good_names) - set(got_names)) if missing: raise ValueError( - "%s channel names were missing some " - "good MEG channel names:\n%s" % (item, ", ".join(missing)) + f"{item} channel names were missing some " + f"good MEG channel names:\n{', '.join(missing)}" ) idx = [got_names.index(name) for name in good_names] extended_proj_.append(proj["data"]["data"][:, idx]) @@ -519,8 +519,12 @@ def _prep_maxwell_filter( # sss_cal = dict() if calibration is not None: + # Modifies info in place, so make a copy for recon later + info_recon = info.copy() calibration, sss_cal = _update_sensor_geometry(info, calibration, ignore_ref) mag_or_fine.fill(True) # all channels now have some mag-type data + 
else:
+        info_recon = info
 
     # Determine/check the origin of the expansion
     origin = _check_origin(origin, info, coord_frame, disp=True)
@@ -553,7 +557,8 @@ def _prep_maxwell_filter(
     #
     exp = dict(origin=origin_head, int_order=int_order, ext_order=0)
     all_coils = _prep_mf_coils(info, ignore_ref)
-    S_recon = _trans_sss_basis(exp, all_coils, recon_trans, coil_scale)
+    all_coils_recon = _prep_mf_coils(info_recon, ignore_ref)
+    S_recon = _trans_sss_basis(exp, all_coils_recon, recon_trans, coil_scale)
     exp["ext_order"] = ext_order
     exp["extended_proj"] = extended_proj
     del extended_proj
@@ -564,8 +569,8 @@ def _prep_maxwell_filter(
         dist = np.sqrt(np.sum(_sq(diff)))
         if dist > 25.0:
             warn(
-                "Head position change is over 25 mm (%s) = %0.1f mm"
-                % (", ".join("%0.1f" % x for x in diff), dist)
+                f'Head position change is over 25 mm '
+                f'({", ".join("%0.1f" % x for x in diff)}) = {dist:0.1f} mm'
             )
 
     # Reconstruct raw file object with spatiotemporal processed data
@@ -699,9 +704,9 @@ def _run_maxwell_filter(
         max_samps = (ends - onsets).max()
         if not 0.0 < st_duration <= max_samps + 1.0:
             raise ValueError(
-                "st_duration (%0.1fs) must be between 0 and the "
+                f"st_duration ({st_duration / sfreq:0.1f}s) must be between 0 and the "
                 "longest contiguous duration of the data "
-                "(%0.1fs)." % (st_duration / sfreq, max_samps / sfreq)
+                f"({max_samps / sfreq:0.1f}s)."
             )
     # Generate time points to break up data into equal-length windows
     starts, stops = list(), list()
@@ -717,16 +722,16 @@ def _run_maxwell_filter(
             if n_last_buf >= st_duration:
                 logger.info(
                     "    Spatiotemporal window did not fit evenly into"
-                    "contiguous data segment. %0.2f seconds were lumped "
-                    "into the previous window."
-                    % ((n_last_buf - st_duration) / sfreq,)
+                    " contiguous data segment. "
+                    f"{(n_last_buf - st_duration) / sfreq:0.2f} seconds "
+                    "were lumped into the previous window."
) else: logger.info( - " Contiguous data segment of duration %0.2f " + f" Contiguous data segment of duration " + f"{n_last_buf / sfreq:0.2f} " "seconds is too short to be processed with tSSS " - "using duration %0.2f" - % (n_last_buf / sfreq, st_duration / sfreq) + f"using duration {st_duration / sfreq:0.2f}" ) assert len(read_lims) >= 2 assert read_lims[0] == onset and read_lims[-1] == end @@ -737,13 +742,13 @@ def _run_maxwell_filter( # Loop through buffer windows of data n_sig = int(np.floor(np.log10(max(len(starts), 0)))) + 1 - logger.info(" Processing %s data chunk%s" % (len(starts), _pl(starts))) + logger.info(f" Processing {len(starts)} data chunk{_pl(starts)}") for ii, (start, stop) in enumerate(zip(starts, stops)): if start == stop: continue # Skip zero-length annotations tsss_valid = (stop - start) >= st_duration rel_times = raw_sss.times[start:stop] - t_str = "%8.3f - %8.3f s" % tuple(rel_times[[0, -1]]) + t_str = f"{rel_times[[0, -1]][0]:8.3f} - {rel_times[[0, -1]][1]:8.3f} s" t_str += ("(#%d/%d)" % (ii + 1, len(starts))).rjust(2 * n_sig + 5) # Get original data @@ -899,8 +904,8 @@ def _get_coil_scale(meg_picks, mag_picks, grad_picks, mag_scale, info): grad_base = list(grad_base)[0] mag_scale = 1.0 / grad_base logger.info( - " Setting mag_scale=%0.2f based on gradiometer " - "distance %0.2f mm" % (mag_scale, 1000 * grad_base) + f" Setting mag_scale={mag_scale:0.2f} based on gradiometer " + f"distance {1000 * grad_base:0.2f} mm" ) mag_scale = float(mag_scale) coil_scale = np.ones((len(meg_picks), 1)) @@ -957,7 +962,7 @@ def _check_destination(destination, info, head_frame): if recon_trans.to_str != "head" or recon_trans.from_str != "MEG device": raise RuntimeError( "Destination transform is not MEG device -> head, " - "got %s -> %s" % (recon_trans.from_str, recon_trans.to_str) + f"got {recon_trans.from_str} -> {recon_trans.to_str}" ) return recon_trans @@ -1149,14 +1154,14 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq): if not _time_mask(t, tmin=raw._first_time - 1e-3, tmax=None, sfreq=sfreq).all(): raise ValueError( "Head position time points must be greater than " - "first sample offset, but found %0.4f < %0.4f" % (t[0], raw._first_time) + f"first sample offset, but found {t[0]:0.4f} < {raw._first_time:0.4f}" ) max_dist = np.sqrt(np.sum(pos[:, 4:7] ** 2, axis=1)).max() if max_dist > 1.0: warn( - "Found a distance greater than 1 m (%0.3g m) from the device " + f"Found a distance greater than 1 m ({max_dist:0.3g} m) from the device " "origin, positions may be invalid and Maxwell filtering could " - "fail" % (max_dist,) + "fail" ) dev_head_ts = np.zeros((len(t), 4, 4)) dev_head_ts[:, 3, 3] = 1.0 @@ -1311,17 +1316,8 @@ def _regularize( S_decomp = S_decomp.take(reg_moments, axis=1) if regularize is not None or n_use_out != n_out: logger.info( - " Using %s/%s harmonic components for %s " - "(%s/%s in, %s/%s out)" - % ( - n_use_in + n_use_out, - n_in + n_out, - t_str, - n_use_in, - n_in, - n_use_out, - n_out, - ) + f" Using {n_use_in + n_use_out}/{n_in + n_out} harmonic components " + f"for {t_str} ({n_use_in}/{n_in} in, {n_use_out}/{n_out} out)" ) return S_decomp, reg_moments, n_use_in @@ -1348,8 +1344,8 @@ def _get_mf_picks_fix_mags(info, int_order, ext_order, ignore_ref=False, verbose n_bases = _get_n_moments([int_order, ext_order]).sum() if n_bases > good_mask.sum(): raise ValueError( - "Number of requested bases (%s) exceeds number of " - "good sensors (%s)" % (str(n_bases), good_mask.sum()) + f"Number of requested bases ({str(n_bases)}) exceeds number of " + f"good 
sensors ({good_mask.sum()})" ) recons = [ch for ch in meg_info["bads"]] if len(recons) > 0: @@ -1377,9 +1373,9 @@ def _get_mf_picks_fix_mags(info, int_order, ext_order, ignore_ref=False, verbose FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD, ] mag_or_fine[np.isin(coil_types, ctf_grads)] = False - msg = " Processing %s gradiometers and %s magnetometers" % ( - len(grad_picks), - len(mag_picks), + msg = ( + f" Processing {len(grad_picks)} gradiometers " + f"and {len(mag_picks)} magnetometers" ) n_kit = len(mag_picks) - mag_or_fine.sum() if n_kit > 0: @@ -2113,7 +2109,7 @@ def _prep_fine_cal(info, fine_cal): ) ) if len(missing): - warn("Found cal channel%s not in data: %s" % (_pl(missing), missing)) + warn(f"Found cal channel{_pl(missing)} not in data: {missing}") return info_to_cal, fine_cal, ch_names @@ -2204,8 +2200,8 @@ def _update_sensor_geometry(info, fine_cal, ignore_ref): np.rad2deg(np.arccos(ang_shift), ang_shift) # Convert to degrees logger.info( " Adjusted coil positions by (μ ± σ): " - "%0.1f° ± %0.1f° (max: %0.1f°)" - % (np.mean(ang_shift), np.std(ang_shift), np.max(np.abs(ang_shift))) + f"{np.mean(ang_shift):0.1f}° ± {np.std(ang_shift):0.1f}° " + f"(max: {np.max(np.abs(ang_shift)):0.1f}°)" ) return calibration, sss_cal @@ -2759,7 +2755,7 @@ def find_bad_channels_maxwell( break name = raw.ch_names[these_picks[idx]] - logger.debug(" Bad: %s %0.1f" % (name, max_)) + logger.debug(f" Bad: {name} {max_:0.1f}") these_picks.pop(idx) chunk_noisy.append(name) noisy_chs.update(chunk_noisy) @@ -2780,8 +2776,8 @@ def find_bad_channels_maxwell( scores_noisy = scores_noisy[params["meg_picks"]] thresh_noisy = thresh_noisy[params["meg_picks"]] - logger.info(" Static bad channels: %s" % (noisy_chs,)) - logger.info(" Static flat channels: %s" % (flat_chs,)) + logger.info(f" Static bad channels: {noisy_chs}") + logger.info(f" Static flat channels: {flat_chs}") logger.info("[done]") if return_scores: diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index f6f17a1ae04..9a39a342e50 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -25,8 +25,11 @@ def beer_lambert_law(raw, ppf=6.0): ---------- raw : instance of Raw The optical density data. - ppf : float - The partial pathlength factor. + ppf : tuple | float + The partial pathlength factors for each wavelength. + + .. versionchanged:: 1.7 + Support for different factors for the two wavelengths. 
Returns ------- @@ -35,8 +38,15 @@ def beer_lambert_law(raw, ppf=6.0): """ raw = raw.copy().load_data() _validate_type(raw, BaseRaw, "raw") - _validate_type(ppf, "numeric", "ppf") - ppf = float(ppf) + _validate_type(ppf, ("numeric", "array-like"), "ppf") + ppf = np.array(ppf, float) + if ppf.ndim == 0: # upcast single float to shape (2,) + ppf = np.array([ppf, ppf]) + if ppf.shape != (2,): + raise ValueError( + f"ppf must be float or array-like of shape (2,), got shape {ppf.shape}" + ) + ppf = ppf[:, np.newaxis] # shape (2, 1) picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert") # This is the one place we *really* need the actual/accurate frequencies freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float) diff --git a/mne/preprocessing/nirs/_tddr.py b/mne/preprocessing/nirs/_tddr.py index 59c2ec926d9..a7d0af9a305 100644 --- a/mne/preprocessing/nirs/_tddr.py +++ b/mne/preprocessing/nirs/_tddr.py @@ -111,7 +111,6 @@ def _TDDR(signal, sample_rate): tune = 4.685 D = np.sqrt(np.finfo(signal.dtype).eps) mu = np.inf - iter = 0 # Step 1. Compute temporal derivative of the signal deriv = np.diff(signal_low) @@ -120,8 +119,7 @@ def _TDDR(signal, sample_rate): w = np.ones(deriv.shape) # Step 3. Iterative estimation of robust weights - while iter < 50: - iter = iter + 1 + for _ in range(50): mu0 = mu # Step 3a. Estimate weighted mean diff --git a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py index 29dd6b3bd4d..da5341b17d5 100644 --- a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py +++ b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py @@ -78,7 +78,7 @@ def test_beer_lambert_v_matlab(): pymatreader = pytest.importorskip("pymatreader") raw = read_raw_nirx(fname_nirx_15_0) raw = optical_density(raw) - raw = beer_lambert_law(raw, ppf=0.121) + raw = beer_lambert_law(raw, ppf=(0.121, 0.121)) raw._data *= 1e6 # Scale to uM for comparison to MATLAB matlab_fname = ( diff --git a/mne/preprocessing/nirs/tests/test_optical_density.py b/mne/preprocessing/nirs/tests/test_optical_density.py index 77d7a559bb9..4ac662e0c9a 100644 --- a/mne/preprocessing/nirs/tests/test_optical_density.py +++ b/mne/preprocessing/nirs/tests/test_optical_density.py @@ -52,7 +52,7 @@ def test_optical_density_manual(): test_tol = 0.01 raw = read_raw_nirx(fname_nirx, preload=True) # log(1) = 0 - raw._data[4] = np.ones((145)) + raw._data[4] = np.ones(145) # log(0.5)/-1 = 0.69 # log(1.5)/-1 = -0.40 test_data = np.tile([0.5, 1.5], 73)[:145] diff --git a/mne/preprocessing/otp.py b/mne/preprocessing/otp.py index 6cbd3822641..572e99ec7e2 100644 --- a/mne/preprocessing/otp.py +++ b/mne/preprocessing/otp.py @@ -88,9 +88,8 @@ def oversampled_temporal_projection(raw, duration=10.0, picks=None, verbose=None n_samples = int(round(float(duration) * raw.info["sfreq"])) if n_samples < len(picks_good) - 1: raise ValueError( - "duration (%s) yielded %s samples, which is fewer " - "than the number of channels -1 (%s)" - % (n_samples / raw.info["sfreq"], n_samples, len(picks_good) - 1) + f"duration ({n_samples / raw.info['sfreq']}) yielded {n_samples} samples, " + f"which is fewer than the number of channels -1 ({len(picks_good) - 1})" ) n_overlap = n_samples // 2 raw_otp = raw.copy().load_data(verbose=False) @@ -105,7 +104,8 @@ def oversampled_temporal_projection(raw, duration=10.0, picks=None, verbose=None read_lims = list(range(0, len(raw.times), n_samples)) + [len(raw.times)] for start, stop in zip(read_lims[:-1], read_lims[1:]): 
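        # descriptive note: each chunk below is logged and then handed to the
        # overlap-add processor via ``otp.feed``; successive windows overlap by
        # n_overlap = n_samples // 2 samples so chunk boundaries are stitched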
logger.info( - " Denoising % 8.2f – % 8.2f s" % tuple(raw.times[[start, stop - 1]]) + f" Denoising {raw.times[[start, stop - 1]][0]: 8.2f} – " + f"{raw.times[[start, stop - 1]][1]: 8.2f} s" ) otp.feed(raw[picks, start:stop][0]) return raw_otp diff --git a/mne/preprocessing/ssp.py b/mne/preprocessing/ssp.py index 985a30a6e9d..271f9195416 100644 --- a/mne/preprocessing/ssp.py +++ b/mne/preprocessing/ssp.py @@ -13,7 +13,7 @@ from .._fiff.reference import make_eeg_average_ref_proj from ..epochs import Epochs from ..proj import compute_proj_epochs, compute_proj_evoked -from ..utils import logger, verbose, warn +from ..utils import _validate_type, logger, verbose, warn from .ecg import find_ecg_events from .eog import find_eog_events @@ -112,7 +112,10 @@ def _compute_exg_proj( my_info["bads"] += bads # Handler rejection parameters + _validate_type(reject, (None, dict), "reject") + _validate_type(flat, (None, dict), "flat") if reject is not None: # make sure they didn't pass None + reject = reject.copy() # must make a copy or we modify default! if ( len( pick_types( @@ -170,6 +173,7 @@ def _compute_exg_proj( ): _safe_del_key(reject, "eog") if flat is not None: # make sure they didn't pass None + flat = flat.copy() if ( len( pick_types( @@ -300,9 +304,9 @@ def compute_proj_ecg( filter_length="10s", n_jobs=None, ch_name=None, - reject=dict(grad=2000e-13, mag=3000e-15, eeg=50e-6, eog=250e-6), + reject=dict(grad=2000e-13, mag=3000e-15, eeg=50e-6, eog=250e-6), # noqa: B006 flat=None, - bads=[], + bads=(), avg_ref=False, no_proj=False, event_id=999, @@ -461,9 +465,9 @@ def compute_proj_eog( average=True, filter_length="10s", n_jobs=None, - reject=dict(grad=2000e-13, mag=3000e-15, eeg=500e-6, eog=np.inf), + reject=dict(grad=2000e-13, mag=3000e-15, eeg=500e-6, eog=np.inf), # noqa: B006 flat=None, - bads=[], + bads=(), avg_ref=False, no_proj=False, event_id=998, diff --git a/mne/preprocessing/tests/test_annotate_amplitude.py b/mne/preprocessing/tests/test_annotate_amplitude.py index d39fabdb3ce..3618e480657 100644 --- a/mne/preprocessing/tests/test_annotate_amplitude.py +++ b/mne/preprocessing/tests/test_annotate_amplitude.py @@ -247,11 +247,11 @@ def test_flat_bad_acq_skip(): raw = read_raw_fif(skip_fname, preload=True) annots, bads = annotate_amplitude(raw, flat=0) assert len(annots) == 0 - assert bads == [ # MaxFilter finds the same 21 channels - "MEG%04d" % (int(num),) + assert bads == [ + f"MEG{num.zfill(4)}" for num in "141 331 421 431 611 641 1011 1021 1031 1241 1421 " "1741 1841 2011 2131 2141 2241 2531 2541 2611 2621".split() - ] + ] # MaxFilter finds the same 21 channels # -- overlap of flat segment with bad_acq_skip -- n_ch, n_times = 11, 1000 diff --git a/mne/preprocessing/tests/test_annotate_nan.py b/mne/preprocessing/tests/test_annotate_nan.py index 48e8e95ce00..5e56a83f979 100644 --- a/mne/preprocessing/tests/test_annotate_nan.py +++ b/mne/preprocessing/tests/test_annotate_nan.py @@ -12,9 +12,7 @@ import mne from mne.preprocessing import annotate_nan -raw_fname = ( - Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif" -) +raw_fname = Path(__file__).parents[2] / "io" / "tests" / "data" / "test_raw.fif" @pytest.mark.parametrize("meas_date", (None, "orig")) diff --git a/mne/preprocessing/tests/test_artifact_detection.py b/mne/preprocessing/tests/test_artifact_detection.py index af01fa4416d..6aa386d0b05 100644 --- a/mne/preprocessing/tests/test_artifact_detection.py +++ b/mne/preprocessing/tests/test_artifact_detection.py @@ -18,6 +18,7 @@ compute_average_dev_head_t, 
)
 from mne.tests.test_annotations import _assert_annotations_equal
+from mne.transforms import _angle_dist_between_rigid, quat_to_rot, rot_to_quat
 
 data_path = testing.data_path(download=False)
 sss_path = data_path / "SSS"
@@ -35,6 +36,7 @@ def test_movement_annotation_head_correction(meas_date):
         raw.set_meas_date(None)
     else:
         assert meas_date == "orig"
+    raw_unannot = raw.copy()
 
     # Check 5 rotation segments are detected
     annot_rot, [] = annotate_movement(raw, pos, rotation_velocity_limit=5)
@@ -67,7 +69,7 @@ def test_movement_annotation_head_correction(meas_date):
     _assert_annotations_equal(annot_all_2, annot_all)
     assert annot_all.orig_time == raw.info["meas_date"]
     raw.set_annotations(annot_all)
-    dev_head_t = compute_average_dev_head_t(raw, pos)
+    dev_head_t = compute_average_dev_head_t(raw, pos)["trans"]
 
     dev_head_t_ori = np.array(
         [
@@ -78,13 +80,83 @@ def test_movement_annotation_head_correction(meas_date):
         ]
     )
 
-    assert_allclose(dev_head_t_ori, dev_head_t["trans"], rtol=1e-5, atol=0)
+    assert_allclose(dev_head_t_ori, dev_head_t, rtol=1e-5, atol=0)
+
+    with pytest.raises(ValueError, match="Number of .* must match .*"):
+        compute_average_dev_head_t([raw], [pos] * 2)
+    # Using two identical ones should be identical ...
+    dev_head_t_double = compute_average_dev_head_t([raw] * 2, [pos] * 2)["trans"]
+    assert_allclose(dev_head_t, dev_head_t_double)
+    # ... unannotated and annotated versions differ ...
+    dev_head_t_unannot = compute_average_dev_head_t(raw_unannot, pos)["trans"]
+    rot_tol = 1.5e-3
+    mov_tol = 1e-3
+    assert not np.allclose(
+        dev_head_t_unannot[:3, :3],
+        dev_head_t[:3, :3],
+        atol=rot_tol,
+        rtol=0,
+    )
+    assert not np.allclose(
+        dev_head_t_unannot[:3, 3],
+        dev_head_t[:3, 3],
+        atol=mov_tol,
+        rtol=0,
+    )
+    # ... and averaging the two is close (but not identical!) to operating on the two
+    # files. Note they shouldn't be identical because there are more time points
+    # included in the unannotated version!
+    dev_head_t_naive = np.eye(4)
+    dev_head_t_naive[:3, :3] = quat_to_rot(
+        np.mean(
+            rot_to_quat(np.array([dev_head_t[:3, :3], dev_head_t_unannot[:3, :3]])),
+            axis=0,
+        )
+    )
+    dev_head_t_naive[:3, 3] = np.mean(
+        [dev_head_t[:3, 3], dev_head_t_unannot[:3, 3]], axis=0
+    )
+    dev_head_t_combo = compute_average_dev_head_t([raw, raw_unannot], [pos] * 2)[
+        "trans"
+    ]
+    unit_kw = dict(distance_units="mm", angle_units="deg")
+    deg_annot_combo, mm_annot_combo = _angle_dist_between_rigid(
+        dev_head_t,
+        dev_head_t_combo,
+        **unit_kw,
+    )
+    deg_unannot_combo, mm_unannot_combo = _angle_dist_between_rigid(
+        dev_head_t_unannot,
+        dev_head_t_combo,
+        **unit_kw,
+    )
+    deg_annot_unannot, mm_annot_unannot = _angle_dist_between_rigid(
+        dev_head_t,
+        dev_head_t_unannot,
+        **unit_kw,
+    )
+    deg_combo_naive, mm_combo_naive = _angle_dist_between_rigid(
+        dev_head_t_combo,
+        dev_head_t_naive,
+        **unit_kw,
+    )
+    # combo<->naive closer than combo<->annotated closer than annotated<->unannotated
+    assert 0.05 < deg_combo_naive < deg_annot_combo < deg_annot_unannot < 1.5
+    assert 0.1 < mm_combo_naive < mm_annot_combo < mm_annot_unannot < 2
+    # combo<->naive closer than combo<->unannotated closer than annotated<->unannotated
+    assert 0.05 < deg_combo_naive < deg_unannot_combo < deg_annot_unannot < 1.5
+    assert 0.12 < mm_combo_naive < mm_unannot_combo < mm_annot_unannot < 2.0
 
     # Smoke test skipping time due to previous annotations.
raw.set_annotations(Annotations([raw.times[0]], 0.1, "bad")) annot_dis, _ = annotate_movement(raw, pos, mean_distance_limit=0.02) assert annot_dis.duration.size == 1 + # really far should warn + pos[:, 4] += 5 + with pytest.warns(RuntimeWarning, match="Implausible head position"): + compute_average_dev_head_t(raw, pos) + @testing.requires_testing_data @pytest.mark.parametrize("meas_date", (None, "orig")) diff --git a/mne/preprocessing/tests/test_csd.py b/mne/preprocessing/tests/test_csd.py index 31d3c64e5de..1c9be1a86cf 100644 --- a/mne/preprocessing/tests/test_csd.py +++ b/mne/preprocessing/tests/test_csd.py @@ -28,7 +28,7 @@ coords_fname = data_path / "test_eeg_pos.mat" csd_fname = data_path / "test_eeg_csd.mat" -io_path = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +io_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = io_path / "test_raw.fif" diff --git a/mne/preprocessing/tests/test_ecg.py b/mne/preprocessing/tests/test_ecg.py index 283009de5f1..73fee8c38f0 100644 --- a/mne/preprocessing/tests/test_ecg.py +++ b/mne/preprocessing/tests/test_ecg.py @@ -8,7 +8,7 @@ from mne.io import read_raw_fif from mne.preprocessing import create_ecg_epochs, find_ecg_events -data_path = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" event_fname = data_path / "test-eve.fif" proj_fname = data_path / "test-proj.fif" diff --git a/mne/preprocessing/tests/test_eeglab_infomax.py b/mne/preprocessing/tests/test_eeglab_infomax.py index f0835099c96..584406820a7 100644 --- a/mne/preprocessing/tests/test_eeglab_infomax.py +++ b/mne/preprocessing/tests/test_eeglab_infomax.py @@ -77,9 +77,9 @@ def test_mne_python_vs_eeglab(): Y = generate_data_for_comparing_against_eeglab_infomax(ch_type, random_state) N, T = Y.shape for method in methods: - eeglab_results_file = "eeglab_%s_results_%s_data.mat" % ( - method, - dict(eeg="eeg", mag="meg")[ch_type], + eeglab_results_file = ( + f"eeglab_{method}_results_" + f"{dict(eeg='eeg', mag='meg')[ch_type]}_data.mat" ) # For comparison against eeglab, make sure the following @@ -171,9 +171,7 @@ def test_mne_python_vs_eeglab(): sources = np.dot(unmixing, Y) mixing = pinv(unmixing) - mvar = ( - np.sum(mixing**2, axis=0) * np.sum(sources**2, axis=1) / (N * T - 1) - ) + mvar = np.sum(mixing**2, axis=0) * np.sum(sources**2, axis=1) / (N * T - 1) windex = np.argsort(mvar)[::-1] unmixing_ordered = unmixing[windex, :] diff --git a/mne/preprocessing/tests/test_eog.py b/mne/preprocessing/tests/test_eog.py index ad977cd581a..eb4163fcc13 100644 --- a/mne/preprocessing/tests/test_eog.py +++ b/mne/preprocessing/tests/test_eog.py @@ -6,7 +6,7 @@ from mne.io import read_raw_fif from mne.preprocessing.eog import find_eog_events -data_path = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" event_fname = data_path / "test-eve.fif" proj_fname = data_path / "test-proj.fif" diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index 95c9e7d63ba..2b3d4df0e3f 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -18,7 +18,7 @@ write_fine_calibration, ) from mne.preprocessing.tests.test_maxwell import _assert_shielding -from mne.transforms import _angle_between_quats, rot_to_quat +from mne.transforms import _angle_dist_between_rigid from 
mne.utils import object_diff # Define fine calibration filepaths @@ -75,16 +75,17 @@ def test_compute_fine_cal(): orig_trans = _loc_to_coil_trans(orig_locs) want_trans = _loc_to_coil_trans(want_locs) got_trans = _loc_to_coil_trans(got_locs) - dist = np.linalg.norm(got_trans[:, :3, 3] - want_trans[:, :3, 3], axis=1) - assert_allclose(dist, 0.0, atol=1e-6) - dist = np.linalg.norm(got_trans[:, :3, 3] - orig_trans[:, :3, 3], axis=1) - assert_allclose(dist, 0.0, atol=1e-6) - orig_quat = rot_to_quat(orig_trans[:, :3, :3]) - want_quat = rot_to_quat(want_trans[:, :3, :3]) - got_quat = rot_to_quat(got_trans[:, :3, :3]) - want_orig_angles = np.rad2deg(_angle_between_quats(want_quat, orig_quat)) - got_want_angles = np.rad2deg(_angle_between_quats(got_quat, want_quat)) - got_orig_angles = np.rad2deg(_angle_between_quats(got_quat, orig_quat)) + want_orig_angles, want_orig_dist = _angle_dist_between_rigid( + want_trans, orig_trans, angle_units="deg" + ) + got_want_angles, got_want_dist = _angle_dist_between_rigid( + got_trans, want_trans, angle_units="deg" + ) + got_orig_angles, got_orig_dist = _angle_dist_between_rigid( + got_trans, orig_trans, angle_units="deg" + ) + assert_allclose(got_want_dist, 0.0, atol=1e-6) + assert_allclose(got_orig_dist, 0.0, atol=1e-6) for key in ("mag", "grad"): # imb_cals value p = pick_types(raw.info, meg=key, exclude=()) diff --git a/mne/preprocessing/tests/test_hfc.py b/mne/preprocessing/tests/test_hfc.py index 50157bc8551..66af5304cf9 100644 --- a/mne/preprocessing/tests/test_hfc.py +++ b/mne/preprocessing/tests/test_hfc.py @@ -18,7 +18,7 @@ fil_path = testing.data_path(download=False) / "FIL" fname_root = "sub-noise_ses-001_task-noise220622_run-001" -io_dir = Path(__file__).parent.parent.parent / "io" +io_dir = Path(__file__).parents[2] / "io" ctf_fname = io_dir / "tests" / "data" / "test_ctf_raw.fif" fif_fname = io_dir / "tests" / "data" / "test_raw.fif" diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index d96cfbfcbc9..6caac588229 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -57,7 +57,7 @@ from mne.rank import _compute_rank_int from mne.utils import _record_warnings, catch_logging, check_version -data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" test_cov_name = data_dir / "test-cov.fif" @@ -78,6 +78,8 @@ ) pytest.importorskip("sklearn") +_baseline_corrected = pytest.warns(RuntimeWarning, match="were baseline-corrected") + def ICA(*args, **kwargs): """Fix the random state in tests.""" @@ -171,7 +173,10 @@ def test_ica_simple(method): info = create_info(data.shape[-2], 1000.0, "eeg") cov = make_ad_hoc_cov(info) ica = ICA(n_components=n_components, method=method, random_state=0, noise_cov=cov) - with pytest.warns(RuntimeWarning, match="No average EEG.*"): + with ( + pytest.warns(RuntimeWarning, match="high-pass filtered"), + pytest.warns(RuntimeWarning, match="No average EEG.*"), + ): ica.fit(RawArray(data, info)) transform = ica.unmixing_matrix_ @ ica.pca_components_ @ A amari_distance = np.mean( @@ -649,7 +654,7 @@ def test_ica_additional(method, tmp_path, short_raw_epochs): # test if n_components=None works ica = ICA(n_components=None, method=method, max_iter=1) - with pytest.warns(UserWarning, match="did not converge"): + with _baseline_corrected, pytest.warns(UserWarning, match="did not converge"): ica.fit(epochs) 
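        # sanity-check the fitted ICA's attributes against these epochs
        # (the limits are empirical bounds used by this test helper)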
_assert_ica_attributes(ica, epochs.get_data("data"), limits=(0.05, 20)) @@ -1032,7 +1037,7 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): with pytest.raises(ValueError, match="ICA must be fitted first"): ica.get_explained_variance_ratio(epochs) - with pytest.warns(RuntimeWarning, match="were baseline-corrected"): + with _record_warnings(), _baseline_corrected: ica.fit(epochs) # components = int, ch_type = None @@ -1255,7 +1260,10 @@ def test_fit_params_epochs_vs_raw(param_name, param_val, tmp_path): ica = ICA(n_components=n_components, max_iter=max_iter, method=method) fit_params = {param_name: param_val} - with pytest.warns(RuntimeWarning, match="parameters.*will be ignored"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="parameters.*will be ignored"), + ): ica.fit(inst=epochs, **fit_params) assert ica.reject_ == reject _assert_ica_attributes(ica) @@ -1448,7 +1456,7 @@ def test_ica_labels(): assert key in raw.ch_names raw.set_channel_types(rename) ica = ICA(n_components=4, max_iter=2, method="fastica", allow_ref_meg=True) - with pytest.warns(UserWarning, match="did not converge"): + with _record_warnings(), pytest.warns(UserWarning, match="did not converge"): ica.fit(raw) _assert_ica_attributes(ica) @@ -1473,7 +1481,7 @@ def test_ica_labels(): # derive reference ICA components and append them to raw ica_rf = ICA(n_components=2, max_iter=2, allow_ref_meg=True) - with pytest.warns(UserWarning, match="did not converge"): + with _record_warnings(): # high pass and/or no convergence ica_rf.fit(raw.copy().pick("ref_meg")) icacomps = ica_rf.get_sources(raw) # rename components so they are auto-detected by find_bads_ref @@ -1509,7 +1517,7 @@ def test_ica_labels(): assert_allclose(scores, [0.81, 0.14, 0.37, 0.05], atol=0.03) ica = ICA(n_components=4, max_iter=2, method="fastica", allow_ref_meg=True) - with pytest.warns(UserWarning, match="did not converge"): + with _record_warnings(), pytest.warns(UserWarning, match="did not converge"): ica.fit(raw, picks="eeg") ica.find_bads_muscle(raw) assert "muscle" in ica.labels_ diff --git a/mne/preprocessing/tests/test_interpolate.py b/mne/preprocessing/tests/test_interpolate.py index b2b05446c84..8bf4cf0e345 100644 --- a/mne/preprocessing/tests/test_interpolate.py +++ b/mne/preprocessing/tests/test_interpolate.py @@ -12,7 +12,7 @@ from mne.preprocessing.interpolate import _find_centroid_sphere from mne.transforms import _cart_to_sph -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" event_name = base_dir / "test-eve.fif" raw_fname_ctf = base_dir / "test_ctf_raw.fif" diff --git a/mne/preprocessing/tests/test_lof.py b/mne/preprocessing/tests/test_lof.py new file mode 100644 index 00000000000..858fa0e4432 --- /dev/null +++ b/mne/preprocessing/tests/test_lof.py @@ -0,0 +1,39 @@ +# Authors: Velu Prabhakar Kumaravel +# +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. 
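# The new test file below exercises ``find_bad_channels_lof``; a minimal
# usage sketch (assumes scikit-learn is installed and ``raw`` is a preloaded
# mne.io.Raw; keyword values mirror the test, not an exhaustive API):
from mne.preprocessing import find_bad_channels_lof

bads, scores = find_bad_channels_lof(
    raw, n_neighbors=20, picks="mag", return_scores=True
)
raw.info["bads"] += bads  # mark the detected outlier channels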
+ +from pathlib import Path + +import pytest + +from mne.io import read_raw_fif +from mne.preprocessing import find_bad_channels_lof + +base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +raw_fname = base_dir / "test_raw.fif" + + +@pytest.mark.parametrize( + "n_neighbors, ch_type, n_ch, n_bad", + [ + (8, "eeg", 60, 8), + (10, "grad", 204, 2), + (20, "mag", 102, 0), + (30, "grad", 204, 2), + ], +) +def test_lof(n_neighbors, ch_type, n_ch, n_bad): + """Test LOF detection.""" + pytest.importorskip("sklearn") + raw = read_raw_fif(raw_fname).load_data() + assert raw.info["bads"] == [] + bads, scores = find_bad_channels_lof( + raw, n_neighbors, picks=ch_type, return_scores=True + ) + bads_2 = find_bad_channels_lof(raw, n_neighbors, picks=ch_type) + assert len(scores) == n_ch + assert len(bads) == n_bad + assert bads == bads_2 + with pytest.raises(ValueError, match="channel type"): + find_bad_channels_lof(raw) diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py index 4bfd5cd396c..8f497178408 100644 --- a/mne/preprocessing/tests/test_maxwell.py +++ b/mne/preprocessing/tests/test_maxwell.py @@ -58,7 +58,7 @@ use_log_level, ) -io_path = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +io_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_small_fname = io_path / "test_raw.fif" data_path = testing.data_path(download=False) @@ -122,7 +122,7 @@ tri_ctc_fname = triux_path / "ct_sparse_BMLHUS.fif" tri_cal_fname = triux_path / "sss_cal_BMLHUS.dat" -io_dir = Path(__file__).parent.parent.parent / "io" +io_dir = Path(__file__).parents[2] / "io" fname_ctf_raw = io_dir / "tests" / "data" / "test_ctf_comp_raw.fif" ctf_fname_continuous = data_path / "CTF" / "testdata_ctf.ds" @@ -173,11 +173,7 @@ def _assert_n_free(raw_sss, lower, upper=None): """Check the DOF.""" upper = lower if upper is None else upper n_free = raw_sss.info["proc_history"][0]["max_info"]["sss_info"]["nfree"] - assert lower <= n_free <= upper, "nfree fail: %s <= %s <= %s" % ( - lower, - n_free, - upper, - ) + assert lower <= n_free <= upper, f"nfree fail: {lower} <= {n_free} <= {upper}" def _assert_mag_coil_type(info, coil_type): @@ -730,7 +726,8 @@ def test_spatiotemporal_only(): raw_tsss = maxwell_filter(raw, st_duration=tmax, st_correlation=1.0, st_only=True) assert_allclose(raw[:][0], raw_tsss[:][0]) # degenerate - pytest.raises(ValueError, maxwell_filter, raw, st_only=True) # no ST + with pytest.raises(ValueError, match="must not be None if st_only"): + maxwell_filter(raw, st_only=True) # two-step process equivalent to single-step process raw_tsss = maxwell_filter(raw, st_duration=tmax, st_only=True) raw_tsss = maxwell_filter(raw_tsss) @@ -771,7 +768,7 @@ def test_fine_calibration(): log = log.getvalue() assert "Using fine calibration" in log assert fine_cal_fname.stem in log - assert_meg_snr(raw_sss, sss_fine_cal, 82, 611) + assert_meg_snr(raw_sss, sss_fine_cal, 1.3, 180) # similar to MaxFilter py_cal = raw_sss.info["proc_history"][0]["max_info"]["sss_cal"] assert py_cal is not None assert len(py_cal) > 0 @@ -812,15 +809,11 @@ def test_fine_calibration(): regularize=None, bad_condition="ignore", ) - assert_meg_snr(raw_sss_3D, sss_fine_cal, 1.0, 6.0) + assert_meg_snr(raw_sss_3D, sss_fine_cal, 0.9, 6.0) + assert_meg_snr(raw_sss_3D, raw_sss, 1.1, 6.0) # slightly better than 1D raw_ctf = read_crop(fname_ctf_raw).apply_gradient_compensation(0) - pytest.raises( - RuntimeError, - maxwell_filter, - raw_ctf, - origin=(0.0, 0.0, 0.04), - 
calibration=fine_cal_fname, - ) + with pytest.raises(RuntimeError, match="Not all MEG channels"): + maxwell_filter(raw_ctf, origin=(0.0, 0.0, 0.04), calibration=fine_cal_fname) @pytest.mark.slowtest @@ -884,7 +877,8 @@ def test_cross_talk(tmp_path): assert len(py_ctc) > 0 with pytest.raises(TypeError, match="path-like"): maxwell_filter(raw, cross_talk=raw) - pytest.raises(ValueError, maxwell_filter, raw, cross_talk=raw_fname) + with pytest.raises(ValueError, match="Invalid cross-talk FIF"): + maxwell_filter(raw, cross_talk=raw_fname) mf_ctc = sss_ctc.info["proc_history"][0]["max_info"]["sss_ctc"] del mf_ctc["block_id"] # we don't write this assert isinstance(py_ctc["decoupler"], sparse.csc_matrix) @@ -916,13 +910,8 @@ def test_cross_talk(tmp_path): with pytest.warns(RuntimeWarning, match="Not all cross-talk channels"): maxwell_filter(raw_missing, cross_talk=ctc_fname) # MEG channels not in cross-talk - pytest.raises( - RuntimeError, - maxwell_filter, - raw_ctf, - origin=(0.0, 0.0, 0.04), - cross_talk=ctc_fname, - ) + with pytest.raises(RuntimeError, match="Missing MEG channels"): + maxwell_filter(raw_ctf, origin=(0.0, 0.0, 0.04), cross_talk=ctc_fname) @testing.requires_testing_data @@ -970,10 +959,10 @@ def test_head_translation(): read_info(sample_fname)["dev_head_t"]["trans"], ) # Degenerate cases - pytest.raises( - RuntimeError, maxwell_filter, raw, destination=mf_head_origin, coord_frame="meg" - ) - pytest.raises(ValueError, maxwell_filter, raw, destination=[0.0] * 4) + with pytest.raises(RuntimeError, match=".* can only be set .* head .*"): + maxwell_filter(raw, destination=mf_head_origin, coord_frame="meg") + with pytest.raises(ValueError, match="destination must be"): + maxwell_filter(raw, destination=[0.0] * 4) # TODO: Eventually add simulation tests mirroring Taulu's original paper @@ -994,7 +983,7 @@ def _assert_shielding(raw_sss, erm_power, min_factor, max_factor=np.inf, meg="ma factor = erm_power / sss_power assert ( min_factor <= factor < max_factor - ), "Shielding factor not %0.3f <= %0.3f < %0.3f" % (min_factor, factor, max_factor) + ), f"Shielding factor not {min_factor:0.3f} <= {factor:0.3f} < {max_factor:0.3f}" @buggy_mkl_svd @@ -1347,7 +1336,7 @@ def test_shielding_factor(tmp_path): assert counts[0] == 3 # Show it by rewriting the 3D as 1D and testing it temp_fname = tmp_path / "test_cal.dat" - with open(fine_cal_fname_3d, "r") as fid: + with open(fine_cal_fname_3d) as fid: with open(temp_fname, "w") as fid_out: for line in fid: fid_out.write(" ".join(line.strip().split(" ")[:14]) + "\n") @@ -1395,7 +1384,7 @@ def test_all(): coord_frames = ("head", "head", "meg", "head") ctcs = (ctc_fname, ctc_fname, ctc_fname, ctc_mgh_fname) mins = (3.5, 3.5, 1.2, 0.9) - meds = (10.8, 10.4, 3.2, 6.0) + meds = (10.8, 10.2, 3.2, 5.9) st_durs = (1.0, 1.0, 1.0, None) destinations = (None, sample_fname, None, None) origins = (mf_head_origin, mf_head_origin, mf_meg_origin, mf_head_origin) @@ -1436,7 +1425,7 @@ def test_triux(): sss_py = maxwell_filter( raw, coord_frame="meg", regularize=None, calibration=tri_cal_fname ) - assert_meg_snr(sss_py, read_crop(tri_sss_cal_fname), 22, 200) + assert_meg_snr(sss_py, read_crop(tri_sss_cal_fname), 5, 100) # ctc+cal sss_py = maxwell_filter( raw, @@ -1445,7 +1434,7 @@ def test_triux(): calibration=tri_cal_fname, cross_talk=tri_ctc_fname, ) - assert_meg_snr(sss_py, read_crop(tri_sss_ctc_cal_fname), 28, 200) + assert_meg_snr(sss_py, read_crop(tri_sss_ctc_cal_fname), 5, 100) # regularization sss_py = maxwell_filter(raw, coord_frame="meg", regularize="in") 
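# The test conversions above all follow one pattern: a bare
# ``pytest.raises(Err, fn, *args)`` call becomes a context manager with a
# ``match`` regex, so the error message is asserted as well. A generic sketch
# (``_frobnicate`` is a hypothetical function, not part of MNE):
import pytest

def _frobnicate(x):
    if x < 0:
        raise ValueError("x must be positive")
    return x

def test_frobnicate():
    with pytest.raises(ValueError, match="must be positive"):
        _frobnicate(-1)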
sss_mf = read_crop(tri_sss_reg_fname) diff --git a/mne/preprocessing/tests/test_realign.py b/mne/preprocessing/tests/test_realign.py index 60ec5b0d5ba..952c6ac30bb 100644 --- a/mne/preprocessing/tests/test_realign.py +++ b/mne/preprocessing/tests/test_realign.py @@ -158,7 +158,7 @@ def _assert_similarity(raw, other, n_events, ratio_other, events_raw=None): evoked_other = Epochs(other, events_other, **kwargs).average() assert evoked_raw.nave == evoked_other.nave == len(events_raw) assert len(evoked_raw.data) == len(evoked_other.data) == 1 # just EEG - if 0.99 <= ratio_other <= 1.01: # when drift is not too large + if 0.99 <= ratio_other <= 1.01: # when drift is not too large corr = np.corrcoef(evoked_raw.data[0], evoked_other.data[0])[0, 1] assert 0.9 <= corr <= 1.0 return evoked_raw, events_raw, evoked_other, events_other diff --git a/mne/preprocessing/tests/test_regress.py b/mne/preprocessing/tests/test_regress.py index 8050b6bebf7..48d960e0464 100644 --- a/mne/preprocessing/tests/test_regress.py +++ b/mne/preprocessing/tests/test_regress.py @@ -42,6 +42,19 @@ def test_regress_artifact(): epochs, betas = regress_artifact(epochs, picks="eog", picks_artifact="eog") assert np.ptp(epochs.get_data("eog")) < 1e-15 # constant value assert_allclose(betas, 1) + # proj should only be required of channels being processed + raw = read_raw_fif(raw_fname).crop(0, 1).load_data() + raw.del_proj() + raw.set_eeg_reference(projection=True) + model = EOGRegression(proj=False, picks="meg", picks_artifact="eog") + model.fit(raw) + model.apply(raw) + model = EOGRegression(proj=False, picks="eeg", picks_artifact="eog") + with pytest.raises(RuntimeError, match="Projections need to be applied"): + model.fit(raw) + raw.del_proj() + with pytest.raises(RuntimeError, match="No average reference for the EEG"): + model.fit(raw) @testing.requires_testing_data diff --git a/mne/preprocessing/tests/test_ssp.py b/mne/preprocessing/tests/test_ssp.py index fdcf4f9db23..a6ece5ea2e1 100644 --- a/mne/preprocessing/tests/test_ssp.py +++ b/mne/preprocessing/tests/test_ssp.py @@ -11,8 +11,9 @@ from mne.datasets import testing from mne.io import read_raw_ctf, read_raw_fif from mne.preprocessing.ssp import compute_proj_ecg, compute_proj_eog +from mne.utils import _record_warnings -data_path = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" dur_use = 5.0 eog_times = np.array([0.5, 2.3, 3.6, 14.5]) @@ -69,7 +70,10 @@ def test_compute_proj_ecg(short_raw, average): # XXX: better tests # without setting a bad channel, this should throw a warning - with pytest.warns(RuntimeWarning, match="No good epochs found"): + # (first with a call that makes sure we copy the mutable default "reject") + with pytest.warns(RuntimeWarning, match="longer than the signal"): + compute_proj_ecg(raw.copy().pick("mag"), l_freq=None, h_freq=None) + with _record_warnings(), pytest.warns(RuntimeWarning, match="No good epochs found"): projs, events, drop_log = compute_proj_ecg( raw, n_mag=2, @@ -130,7 +134,7 @@ def test_compute_proj_eog(average, short_raw): assert proj["explained_var"] > thresh_eeg # XXX: better tests - with pytest.warns(RuntimeWarning, match="longer"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="longer"): projs, events = compute_proj_eog( raw, n_mag=2, @@ -147,7 +151,10 @@ def test_compute_proj_eog(average, short_raw): assert projs == [] raw._data[raw.ch_names.index("EOG 061"), :] = 1.0 - with 
pytest.warns(RuntimeWarning, match="filter.*longer than the signal"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="filter.*longer than the signal"), + ): projs, events = compute_proj_eog(raw=raw, tmax=dur_use, ch_name="EOG 061") @@ -172,7 +179,7 @@ def test_compute_proj_parallel(short_raw): filter_length=100, ) raw_2 = short_raw.copy() - with pytest.warns(RuntimeWarning, match="Attenuation"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="Attenuation"): projs_2, _ = compute_proj_eog( raw_2, n_eeg=2, diff --git a/mne/preprocessing/tests/test_stim.py b/mne/preprocessing/tests/test_stim.py index 2ef1c6e367a..270b8d93354 100644 --- a/mne/preprocessing/tests/test_stim.py +++ b/mne/preprocessing/tests/test_stim.py @@ -14,7 +14,7 @@ from mne.io import read_raw_fif from mne.preprocessing.stim import fix_stim_artifact -data_path = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" event_fname = data_path / "test-eve.fif" diff --git a/mne/preprocessing/tests/test_xdawn.py b/mne/preprocessing/tests/test_xdawn.py index 31e751acb37..03bc445f2a7 100644 --- a/mne/preprocessing/tests/test_xdawn.py +++ b/mne/preprocessing/tests/test_xdawn.py @@ -24,7 +24,7 @@ from mne.io import read_raw_fif from mne.preprocessing.xdawn import Xdawn, _XdawnTransformer -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" event_name = base_dir / "test-eve.fif" @@ -335,7 +335,7 @@ def _simulate_erplike_mixed_data(n_epochs=100, n_channels=10): events[:, 2] = y info = create_info( - ch_names=["C{:02d}".format(i) for i in range(n_channels)], + ch_names=[f"C{i:02d}" for i in range(n_channels)], ch_types=["eeg"] * n_channels, sfreq=sfreq, ) diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index ffb0cb0e5cd..c0a0bb88cb3 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -202,7 +202,7 @@ def _fit_xdawn( except np.linalg.LinAlgError as exp: raise ValueError( "Could not compute eigenvalues, ensure " - "proper regularization (%s)" % (exp,) + f"proper regularization ({exp})" ) evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs) @@ -425,9 +425,7 @@ def __init__( self, n_components=2, signal_cov=None, correct_overlap="auto", reg=None ): """Init.""" - super(Xdawn, self).__init__( - n_components=n_components, signal_cov=signal_cov, reg=reg - ) + super().__init__(n_components=n_components, signal_cov=signal_cov, reg=reg) self.correct_overlap = _check_option( "correct_overlap", correct_overlap, ["auto", True, False] ) @@ -532,7 +530,7 @@ def transform(self, inst): elif isinstance(inst, np.ndarray): X = inst if X.ndim not in (2, 3): - raise ValueError("X must be 2D or 3D, got %s" % (X.ndim,)) + raise ValueError(f"X must be 2D or 3D, got {X.ndim}") else: raise ValueError("Data input must be of Epoch type or numpy array") diff --git a/mne/proj.py b/mne/proj.py index a5bb406b844..d72bbd27e06 100644 --- a/mne/proj.py +++ b/mne/proj.py @@ -151,8 +151,8 @@ def _compute_proj( nrow=1, ncol=u.size, ) - desc = f"{kind}-{desc_prefix}-PCA-{k+1:02d}" - logger.info("Adding projection: %s", desc) + desc = f"{kind}-{desc_prefix}-PCA-{k + 1:02d}" + logger.info(f"Adding projection: {desc} (exp var={100 * float(var):0.1f}%)") proj = Projection( active=False, data=proj_data, @@ 
-217,17 +217,20 @@ def compute_proj_epochs( else: event_id = "Multiple-events" if desc_prefix is None: - desc_prefix = "%s-%-.3f-%-.3f" % (event_id, epochs.tmin, epochs.tmax) + desc_prefix = f"{event_id}-{epochs.tmin:<.3f}-{epochs.tmax:<.3f}" return _compute_proj(data, epochs.info, n_grad, n_mag, n_eeg, desc_prefix, meg=meg) -def _compute_cov_epochs(epochs, n_jobs): +def _compute_cov_epochs(epochs, n_jobs, *, log_drops=False): """Compute epochs covariance.""" parallel, p_fun, n_jobs = parallel_func(np.dot, n_jobs) + n_start = len(epochs.events) data = parallel(p_fun(e, e.T) for e in epochs) n_epochs = len(data) if n_epochs == 0: raise RuntimeError("No good epochs found") + if log_drops: + logger.info(f"Dropped {n_start - n_epochs}/{n_start} epochs") n_chan, n_samples = epochs.info["nchan"], len(epochs.times) _check_n_samples(n_samples * n_epochs, n_chan) @@ -273,7 +276,7 @@ def compute_proj_evoked( """ data = np.dot(evoked.data, evoked.data.T) # compute data covariance if desc_prefix is None: - desc_prefix = "%-.3f-%-.3f" % (evoked.times[0], evoked.times[-1]) + desc_prefix = f"{evoked.times[0]:<.3f}-{evoked.times[-1]:<.3f}" return _compute_proj(data, evoked.info, n_grad, n_mag, n_eeg, desc_prefix, meg=meg) @@ -351,7 +354,7 @@ def compute_proj_raw( baseline=None, proj=False, ) - data = _compute_cov_epochs(epochs, n_jobs) + data = _compute_cov_epochs(epochs, n_jobs, log_drops=True) info = epochs.info if not stop: stop = raw.n_times / raw.info["sfreq"] @@ -368,7 +371,7 @@ def compute_proj_raw( start = start / raw.info["sfreq"] stop = stop / raw.info["sfreq"] - desc_prefix = "Raw-%-.3f-%-.3f" % (start, stop) + desc_prefix = f"Raw-{start:<.3f}-{stop:<.3f}" projs = _compute_proj(data, info, n_grad, n_mag, n_eeg, desc_prefix, meg=meg) return projs @@ -456,7 +459,7 @@ def sensitivity_map( elif ncomp == 0: raise RuntimeError( "No valid projectors found for channel type " - "%s, cannot compute %s" % (ch_type, mode) + f"{ch_type}, cannot compute {mode}" ) # can only run the last couple methods if there are projectors elif mode in residual_types: diff --git a/mne/rank.py b/mne/rank.py index 539f897a253..a176a1f5431 100644 --- a/mne/rank.py +++ b/mne/rank.py @@ -100,7 +100,7 @@ def _estimate_rank_from_s(s, tol="auto", tol_kind="absolute"): max_s = np.amax(s, axis=-1) if isinstance(tol, str): if tol not in ("auto", "float32"): - raise ValueError('tol must be "auto" or float, got %r' % (tol,)) + raise ValueError(f'tol must be "auto" or float, got {repr(tol)}') # XXX this should be float32 probably due to how we save and # load data, but it breaks test_make_inverse_operator (!) # The factor of 2 gets test_compute_covariance_auto_reg[None] @@ -139,7 +139,13 @@ def _estimate_rank_raw( @fill_doc def _estimate_rank_meeg_signals( - data, info, scalings, tol="auto", return_singular=False, tol_kind="absolute" + data, + info, + scalings, + tol="auto", + return_singular=False, + tol_kind="absolute", + log_ch_type=None, ): """Estimate rank for M/EEG data. 
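# The rank helpers here keep their public signatures and delegate to private
# variants that grow keyword-only options (the ``*`` marker plus
# ``log_ch_type``/``log_drops``). A toy sketch of the delegation pattern
# (names are illustrative, not the MNE implementation):
def compute_thing(data, *, tol="auto"):
    """Public, documented entry point."""
    return _compute_thing(data, tol=tol)

def _compute_thing(data, *, tol="auto", log_label=None):
    # extra keyword-only knob reserved for internal callers
    label = "data" if log_label is None else log_label
    print(f"computing for {label} with tol={tol}")
    return data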
@@ -187,14 +193,24 @@ def _estimate_rank_meeg_signals( tol_kind=tol_kind, ) rank = out[0] if isinstance(out, tuple) else out - ch_type = " + ".join(list(zip(*picks_list))[0]) + if log_ch_type is None: + ch_type = " + ".join(list(zip(*picks_list))[0]) + else: + ch_type = log_ch_type logger.info(" Estimated rank (%s): %d" % (ch_type, rank)) return out @verbose def _estimate_rank_meeg_cov( - data, info, scalings, tol="auto", return_singular=False, verbose=None + data, + info, + scalings, + tol="auto", + return_singular=False, + *, + log_ch_type=None, + verbose=None, ): """Estimate rank of M/EEG covariance data, given the covariance. @@ -235,8 +251,11 @@ def _estimate_rank_meeg_cov( ) out = estimate_rank(data, tol=tol, norm=False, return_singular=return_singular) rank = out[0] if isinstance(out, tuple) else out - ch_type = " + ".join(list(zip(*picks_list))[0]) - logger.info(" Estimated rank (%s): %d" % (ch_type, rank)) + if log_ch_type is None: + ch_type_ = " + ".join(list(zip(*picks_list))[0]) + else: + ch_type_ = log_ch_type + logger.info(f" Estimated rank ({ch_type_}): {rank}") _undo_scaling_cov(data, picks_list, scalings) return out @@ -352,6 +371,32 @@ def compute_rank( ----- .. versionadded:: 0.18 """ + return _compute_rank( + inst=inst, + rank=rank, + scalings=scalings, + info=info, + tol=tol, + proj=proj, + tol_kind=tol_kind, + on_rank_mismatch=on_rank_mismatch, + ) + + +@verbose +def _compute_rank( + inst, + rank=None, + scalings=None, + info=None, + *, + tol="auto", + proj=True, + tol_kind="absolute", + on_rank_mismatch="ignore", + log_ch_type=None, + verbose=None, +): from .cov import Covariance from .epochs import BaseEpochs from .io import BaseRaw @@ -377,7 +422,7 @@ def compute_rank( else: info = inst.info inst_type = "data" - logger.info("Computing rank from %s with rank=%r" % (inst_type, rank)) + logger.info(f"Computing rank from {inst_type} with rank={repr(rank)}") _validate_type(rank, (str, dict, None), "rank") if isinstance(rank, str): # string, either 'info' or 'full' @@ -417,25 +462,22 @@ def compute_rank( proj_op, n_proj, _ = make_projector(info["projs"], ch_names) else: proj_op, n_proj = None, 0 + if log_ch_type is None: + ch_type_ = ch_type.upper() + else: + ch_type_ = log_ch_type if rank_type == "info": # use info this_rank = _info_rank(info, ch_type, picks, info_type) if info_type != "full": this_rank -= n_proj logger.info( - " %s: rank %d after %d projector%s applied to " - "%d channel%s" - % ( - ch_type.upper(), - this_rank, - n_proj, - _pl(n_proj), - n_chan, - _pl(n_chan), - ) + f" {ch_type_}: rank {this_rank} after " + f"{n_proj} projector{_pl(n_proj)} applied to " + f"{n_chan} channel{_pl(n_chan)}" ) else: - logger.info(" %s: rank %d from info" % (ch_type.upper(), this_rank)) + logger.info(f" {ch_type_}: rank {this_rank} from info") else: # Use empirical estimation assert rank_type == "estimated" @@ -447,7 +489,13 @@ def compute_rank( if proj: data = np.dot(proj_op, data) this_rank = _estimate_rank_meeg_signals( - data, pick_info(simple_info, picks), scalings, tol, False, tol_kind + data, + pick_info(simple_info, picks), + scalings, + tol, + False, + tol_kind, + log_ch_type=log_ch_type, ) else: assert isinstance(inst, Covariance) @@ -464,6 +512,7 @@ def compute_rank( scalings, tol, return_singular=True, + log_ch_type=log_ch_type, verbose=est_verbose, ) if ch_type in rank: @@ -483,9 +532,9 @@ def compute_rank( continue this_info_rank = _info_rank(info, ch_type, picks, "info") logger.info( - " %s: rank %d computed from %d data channel%s " - "with %d projector%s" %
(ch_type.upper(), this_rank, n_chan, _pl(n_chan), n_proj, _pl(n_proj)) + f" {ch_type_}: rank {this_rank} computed from " + f"{n_chan} data channel{_pl(n_chan)} with " + f"{n_proj} projector{_pl(n_proj)}" ) if this_rank > this_info_rank: warn( diff --git a/mne/report/js_and_css/bootstrap-icons/gen_css_for_mne.py b/mne/report/js_and_css/bootstrap-icons/gen_css_for_mne.py index 95b99c306f7..7eac8ecdaa0 100644 --- a/mne/report/js_and_css/bootstrap-icons/gen_css_for_mne.py +++ b/mne/report/js_and_css/bootstrap-icons/gen_css_for_mne.py @@ -15,7 +15,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. - import base64 from pathlib import Path diff --git a/mne/report/js_and_css/report.css b/mne/report/js_and_css/report.css new file mode 100644 index 00000000000..724a13241a5 --- /dev/null +++ b/mne/report/js_and_css/report.css @@ -0,0 +1,19 @@ +#container { + position: relative; + padding-bottom: 8rem; +} + +#content { + margin-top: 90px; + scroll-behavior: smooth; + position: relative; /* for scrollspy */ +} + +#toc { + margin-top: 90px; + padding-bottom: 8rem; +} + +footer { + margin-top: 8rem; +} diff --git a/mne/report/js_and_css/report.sass b/mne/report/js_and_css/report.sass deleted file mode 100644 index 4d533d07011..00000000000 --- a/mne/report/js_and_css/report.sass +++ /dev/null @@ -1,19 +0,0 @@ -#container { - position: relative - padding-bottom: 5rem -} - -#content { - margin-top: 90px - scroll-behavior: smooth - position: relative // for scrollspy -} - -#toc { - margin-top: 90px - padding-bottom: 5rem -} - -footer { - margin-top: 5rem; -} diff --git a/mne/report/report.py b/mne/report/report.py index 6a37f095c2f..b2fafe5b446 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -8,6 +8,7 @@ # Copyright the MNE-Python contributors. 
import base64 +import copy import dataclasses import fnmatch import io @@ -23,7 +24,7 @@ from io import BytesIO, StringIO from pathlib import Path from shutil import copyfile -from typing import Optional, Tuple +from typing import Optional import numpy as np @@ -143,7 +144,7 @@ html_include_dir = Path(__file__).parent / "js_and_css" template_dir = Path(__file__).parent / "templates" JAVASCRIPT = (html_include_dir / "report.js").read_text(encoding="utf-8") -CSS = (html_include_dir / "report.sass").read_text(encoding="utf-8") +CSS = (html_include_dir / "report.css").read_text(encoding="utf-8") MAX_IMG_RES = 100 # in dots per inch MAX_IMG_WIDTH = 850 # in pixels @@ -302,11 +303,11 @@ class _ContentElement: name: str section: Optional[str] dom_id: str - tags: Tuple[str] + tags: tuple[str] html: str -def _check_tags(tags) -> Tuple[str]: +def _check_tags(tags) -> tuple[str]: # Must be iterable, but not a string if isinstance(tags, str): tags = (tags,) @@ -430,8 +431,8 @@ def _fig_to_img(fig, *, image_format="png", own_figure=True): if pil_kwargs: # matplotlib modifies the passed dict, which is a bug mpl_kwargs["pil_kwargs"] = pil_kwargs.copy() - with warnings.catch_warnings(): - fig.savefig(output, format=image_format, dpi=dpi, **mpl_kwargs) + + fig.savefig(output, format=image_format, dpi=dpi, **mpl_kwargs) if own_figure: plt.close(fig) @@ -846,7 +847,7 @@ def __init__( self.include = [] self.lang = "en-us" # language setting for the HTML file if not isinstance(raw_psd, bool) and not isinstance(raw_psd, dict): - raise TypeError("raw_psd must be bool or dict, got %s" % (type(raw_psd),)) + raise TypeError(f"raw_psd must be bool or dict, got {type(raw_psd)}") self.raw_psd = raw_psd self._init_render() # Initialize the renderer @@ -965,6 +966,64 @@ def _validate_input(self, items, captions, tag, comments=None): ) return items, captions, comments + def copy(self): + """Return a deepcopy of the report. + + Returns + ------- + report : instance of Report + The copied report. + """ + return copy.deepcopy(self) + + def get_contents(self): + """Get the content of the report. + + Returns + ------- + titles : list of str + The title of each content element. + tags : list of list of str + The tags for each content element, one list per element. + htmls : list of str + The HTML contents for each element. + + Notes + ----- + .. versionadded:: 1.7 + """ + htmls, _, titles, tags = self._content_as_html() + return titles, tags, htmls + + def reorder(self, order): + """Reorder the report content. + + Parameters + ---------- + order : array-like of int + The indices of the new order (as if you were reordering an array). + For example, if there are 4 elements in the report, + ``order=[3, 0, 1, 2]`` would take the last element and move it to + the front. In other words, ``elements = [elements[ii] for ii in order]``. + + Notes + ----- + .. versionadded:: 1.7 + """ + _validate_type(order, "array-like", "order") + order = np.array(order) + if order.dtype.kind != "i" or order.ndim != 1: + raise ValueError( + "order must be an array of integers, got " + f"{order.ndim}D array of dtype {order.dtype}" + ) + n_elements = len(self._content) + if not np.array_equal(np.sort(order), np.arange(n_elements)): + raise ValueError( + f"order must be a permutation of range({n_elements}), got:\n{order}" + ) + self._content = [self._content[ii] for ii in order] + def _content_as_html(self): """Generate HTML representations based on the added content & sections.
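# A minimal sketch of the new Report.copy/get_contents/reorder API defined
# above (titles are invented for illustration):
from mne import Report

report = Report()
report.add_code(code="print('a')", title="first")
report.add_code(code="print('b')", title="second")
titles, tags, htmls = report.get_contents()
assert titles == ["first", "second"]
backup = report.copy()  # deepcopy, unaffected by the reorder below
report.reorder([1, 0])  # move "second" in front of "first"
assert report.get_contents()[0] == ["second", "first"]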
@@ -1005,7 +1064,7 @@ def _content_as_html(self): ] section_htmls = [el.html for el in section_elements] section_tags = tuple( - sorted((set([t for el in section_elements for t in el.tags]))) + sorted(set([t for el in section_elements for t in el.tags])) ) section_dom_id = self._get_dom_id( section=None, # root level of document @@ -1039,18 +1098,12 @@ def _content_as_html(self): @property def html(self): """A list of HTML representations for all content elements.""" - htmls, _, _, _ = self._content_as_html() - return htmls + return self._content_as_html()[0] @property def tags(self): - """All tags currently used in the report.""" - tags = [] - for c in self._content: - tags.extend(c.tags) - - tags = tuple(sorted(set(tags))) - return tags + """A sorted tuple of all tags currently used in the report.""" + return tuple(sorted(set(sum(self._content_as_html()[3], ())))) def add_custom_css(self, css): """Add custom CSS to the report. @@ -1092,6 +1145,7 @@ def add_epochs( *, psd=True, projs=None, + image_kwargs=None, topomap_kwargs=None, drop_log_ignore=("IGNORED",), tags=("epochs",), @@ -1120,6 +1174,18 @@ def add_epochs( If ``True``, add PSD plots based on all ``epochs``. If ``False``, do not add PSD plots. %(projs_report)s + image_kwargs : dict | None + Keyword arguments to pass to the "epochs image"-generating + function (:meth:`mne.Epochs.plot_image`). + Keys are channel types, values are dicts containing kwargs to pass. + For example, to use the rejection limits per channel type you could pass:: + + image_kwargs=dict( + grad=dict(vmin=-reject['grad'], vmax=reject['grad']), + mag=dict(vmin=-reject['mag'], vmax=reject['mag']), + ) + + .. versionadded:: 1.7 %(topomap_kwargs)s drop_log_ignore : array-like of str The drop reasons to ignore when creating the drop log bar plot. @@ -1130,7 +1196,7 @@ def add_epochs( Notes ----- - .. versionadded:: 0.24.0 + ..
versionadded:: 0.24 """ tags = _check_tags(tags) add_projs = self.projs if projs is None else projs @@ -1138,6 +1204,7 @@ def add_epochs( epochs=epochs, psd=psd, add_projs=add_projs, + image_kwargs=image_kwargs, topomap_kwargs=topomap_kwargs, drop_log_ignore=drop_log_ignore, section=title, @@ -2271,7 +2338,7 @@ def add_figure( elif caption is None and len(figs) == 1: captions = [None] elif caption is None and len(figs) > 1: - captions = [f"Figure {i+1}" for i in range(len(figs))] + captions = [f"Figure {i + 1}" for i in range(len(figs))] else: captions = tuple(caption) @@ -2383,7 +2450,7 @@ def add_html( ) self._add_or_replace( title=title, - section=None, + section=section, tags=tags, html_partial=html_partial, replace=replace, @@ -2861,7 +2928,7 @@ def parse_folder( ) if sort_content: - self._content = self._sort(content=self._content, order=CONTENT_ORDER) + self._sort(order=CONTENT_ORDER) def __getstate__(self): """Get the state of the report as a dictionary.""" @@ -2940,7 +3007,7 @@ def save( fname = op.realpath(fname) # resolve symlinks if sort_content: - self._content = self._sort(content=self._content, order=CONTENT_ORDER) + self._sort(order=CONTENT_ORDER) if not overwrite and op.isfile(fname): msg = ( @@ -2998,35 +3065,28 @@ def __enter__(self): """Do nothing when entering the context block.""" return self - def __exit__(self, type, value, traceback): + def __exit__(self, exception_type, value, traceback): """Save the report when leaving the context block.""" if self.fname is not None: self.save(self.fname, open_browser=False, overwrite=True) - @staticmethod - def _sort(content, order): + def _sort(self, *, order): """Reorder content to reflect "natural" ordering.""" - content_unsorted = content.copy() - content_sorted = [] content_sorted_idx = [] - del content # First arrange content with known tags in the predefined order for tag in order: - for idx, content in enumerate(content_unsorted): + for idx, content in enumerate(self._content): if tag in content.tags: content_sorted_idx.append(idx) - content_sorted.append(content) # Now simply append the rest (custom tags) - content_remaining = [ - content - for idx, content in enumerate(content_unsorted) - if idx not in content_sorted_idx - ] - - content_sorted = [*content_sorted, *content_remaining] - return content_sorted + self.reorder( + np.r_[ + content_sorted_idx, + np.setdiff1d(np.arange(len(self._content)), content_sorted_idx), + ] + ) def _render_one_bem_axis( self, @@ -3143,7 +3203,7 @@ def _add_raw_butterfly_segments( del orig_annotations - captions = [f"Segment {i+1} of {len(images)}" for i in range(len(images))] + captions = [f"Segment {i + 1} of {len(images)}" for i in range(len(images))] self._add_slider( figs=None, @@ -3218,7 +3278,9 @@ def _add_raw( init_kwargs, plot_kwargs = _split_psd_kwargs(kwargs=add_psd) init_kwargs.setdefault("fmax", fmax) plot_kwargs.setdefault("show", False) - fig = raw.compute_psd(**init_kwargs).plot(**plot_kwargs) + with warnings.catch_warnings(): + warnings.simplefilter(action="ignore", category=FutureWarning) + fig = raw.compute_psd(**init_kwargs).plot(**plot_kwargs) _constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) self._add_figure( fig=fig, @@ -3785,7 +3847,7 @@ def _add_epochs_psd(self, *, epochs, psd, image_format, tags, section, replace): if fmax > 0.5 * epochs.info["sfreq"]: fmax = np.inf - fig = epochs_for_psd.compute_psd(fmax=fmax).plot(show=False) + fig = epochs_for_psd.compute_psd(fmax=fmax).plot(amplitude=False, show=False) _constrain_fig_resolution(fig, 
max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) duration = round(epoch_duration * len(epochs_for_psd), 1) caption = ( @@ -3826,63 +3888,9 @@ def _add_epochs_metadata(self, *, epochs, section, tags, replace): metadata.index.name = "Epoch #" assert metadata.index.is_unique - index_name = metadata.index.name # store for later use + data_id = metadata.index.name # store for later use metadata = metadata.reset_index() # We want "proper" columns only - html = metadata.to_html( - border=0, - index=False, - show_dimensions=True, - justify="unset", - float_format=lambda x: f"{round(x, 3):.3f}", - classes="table table-hover table-striped " - "table-sm table-responsive small", - ) - del metadata - - # Massage the table such that it woks nicely with bootstrap-table - htmls = html.split("\n") - header_pattern = "(.*)" - - for idx, html in enumerate(htmls): - if "' - ) - continue - - col_headers = re.findall(pattern=header_pattern, string=html) - if col_headers: - # Make columns sortable - assert len(col_headers) == 1 - col_header = col_headers[0] - htmls[idx] = html.replace( - "", - f'', - ) - - html = "\n".join(htmls) + html = _df_bootstrap_table(df=metadata, data_id=data_id) self._add_html_element( div_klass="epochs", tags=tags, @@ -3898,6 +3906,7 @@ def _add_epochs( epochs, psd, add_projs, + image_kwargs, topomap_kwargs, drop_log_ignore, image_format, @@ -3932,9 +3941,17 @@ def _add_epochs( ch_types = _get_data_ch_types(epochs) epochs.load_data() + _validate_type(image_kwargs, (dict, None), "image_kwargs") + # ensure dict with shallow copy because we will modify it + image_kwargs = dict() if image_kwargs is None else image_kwargs.copy() + for ch_type in ch_types: with use_log_level(_verbose_safe_false(level="error")): - figs = epochs.copy().pick(ch_type, verbose=False).plot_image(show=False) + figs = ( + epochs.copy() + .pick(ch_type, verbose=False) + .plot_image(show=False, **image_kwargs.pop(ch_type, dict())) + ) assert len(figs) == 1 fig = figs[0] @@ -3957,6 +3974,12 @@ def _add_epochs( replace=replace, own_figure=True, ) + if image_kwargs: + raise ValueError( + f"Ensure the keys in image_kwargs map onto channel types plotted in " + f"epochs.plot_image() of {ch_types}, could not use: " + f"{list(image_kwargs)}" + ) # Drop log if epochs._bad_dropped: @@ -4371,3 +4394,59 @@ def __call__(self, block, block_vars, gallery_conf): def copyfiles(self, *args, **kwargs): for key, value in self.files.items(): copyfile(key, value) + + +def _df_bootstrap_table(*, df, data_id): + html = df.to_html( + border=0, + index=False, + show_dimensions=True, + justify="unset", + float_format=lambda x: f"{x:.3f}", + classes="table table-hover table-striped table-sm table-responsive small", + na_rep="", + ) + htmls = html.split("\n") + header_pattern = "(.*)" + + for idx, html in enumerate(htmls): + if "' + ) + continue + + col_headers = re.findall(pattern=header_pattern, string=html) + if col_headers: + # Make columns sortable + assert len(col_headers) == 1 + col_header = col_headers[0] + htmls[idx] = html.replace( + "", + f'', + ) + + html = "\n".join(htmls) + return html diff --git a/mne/report/tests/test_report.py b/mne/report/tests/test_report.py index 4f307367b6a..3860e227318 100644 --- a/mne/report/tests/test_report.py +++ b/mne/report/tests/test_report.py @@ -5,7 +5,6 @@ # Copyright the MNE-Python contributors. 
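# A minimal sketch of the new per-channel-type ``image_kwargs`` handled above
# (``epochs`` assumed to be an mne.Epochs with magnetometer channels; the
# inner dict is forwarded to Epochs.plot_image for that channel type):
from mne import Report

report = Report(title="demo")
report.add_epochs(
    epochs=epochs,
    title="my epochs",
    psd=False,
    projs=False,
    image_kwargs=dict(mag=dict(colorbar=False)),
)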
import base64 -import copy import glob import os import pickle @@ -38,7 +37,7 @@ CONTENT_ORDER, _webp_supported, ) -from mne.utils import Bunch +from mne.utils import Bunch, _record_warnings from mne.utils._testing import assert_object_equal from mne.viz import plot_alignment @@ -57,13 +56,9 @@ inv_fname = sample_meg_dir / "sample_audvis_trunc-meg-eeg-oct-6-meg-inv.fif" stc_fname = sample_meg_dir / "sample_audvis_trunc-meg" mri_fname = subjects_dir / "sample" / "mri" / "T1.mgz" -bdf_fname = ( - Path(__file__).parent.parent.parent / "io" / "edf" / "tests" / "data" / "test.bdf" -) -edf_fname = ( - Path(__file__).parent.parent.parent / "io" / "edf" / "tests" / "data" / "test.edf" -) -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +bdf_fname = Path(__file__).parents[2] / "io" / "edf" / "tests" / "data" / "test.bdf" +edf_fname = Path(__file__).parents[2] / "io" / "edf" / "tests" / "data" / "test.edf" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" evoked_fname = base_dir / "test-ave.fif" nirs_fname = ( data_dir / "SNIRF" / "NIRx" / "NIRSport2" / "1.0.3" / "2021-05-05_001.snirf" @@ -103,6 +98,8 @@ def _make_invisible(fig, **kwargs): @testing.requires_testing_data def test_render_report(renderer_pyvistaqt, tmp_path, invisible_fig): """Test rendering *.fif files for mne report.""" + pytest.importorskip("pymatreader") + raw_fname_new = tmp_path / "temp_raw.fif" raw_fname_new_bids = tmp_path / "temp_meg.fif" ms_fname_new = tmp_path / "temp_ms_raw.fif" @@ -640,7 +637,7 @@ def test_remove(): r.add_figure(fig=fig2, title="figure2", tags=("slider",)) # Test removal by title - r2 = copy.deepcopy(r) + r2 = r.copy() removed_index = r2.remove(title="figure1") assert removed_index == 2 assert len(r2.html) == 3 @@ -649,7 +646,7 @@ def test_remove(): assert r2.html[2] == r.html[3] # Test restricting to section - r2 = copy.deepcopy(r) + r2 = r.copy() removed_index = r2.remove(title="figure1", tags=("othertag",)) assert removed_index == 1 assert len(r2.html) == 3 @@ -694,7 +691,7 @@ def test_add_or_replace(tags): assert len(r.html) == 4 assert len(r._content) == 4 - old_r = copy.deepcopy(r) + old_r = r.copy() # Replace our last occurrence of title='duplicate' r.add_figure( @@ -767,7 +764,7 @@ def test_add_or_replace_section(): assert len(r.html) == 3 assert len(r._content) == 3 - old_r = copy.deepcopy(r) + old_r = r.copy() assert r.html[0] == old_r.html[0] assert r.html[1] == old_r.html[1] assert r.html[2] == old_r.html[2] @@ -886,6 +883,8 @@ def test_manual_report_2d(tmp_path, invisible_fig): raw = read_raw_fif(raw_fname) raw.pick(raw.ch_names[:6]).crop(10, None) raw.info.normalize_proj() + raw_non_preloaded = raw.copy() + raw.load_data() cov = read_cov(cov_fname) cov = pick_channels_cov(cov, raw.ch_names) events = read_events(events_fname) @@ -901,7 +900,12 @@ def test_manual_report_2d(tmp_path, invisible_fig): events=events, event_id=event_id, tmin=-0.2, tmax=0.5, sfreq=raw.info["sfreq"] ) epochs_without_metadata = Epochs( - raw=raw, events=events, event_id=event_id, baseline=None + raw=raw, + events=events, + event_id=event_id, + baseline=None, + decim=10, + verbose="error", ) epochs_with_metadata = Epochs( raw=raw, @@ -909,9 +913,11 @@ def test_manual_report_2d(tmp_path, invisible_fig): event_id=metadata_event_id, baseline=None, metadata=metadata, + decim=10, + verbose="error", ) evokeds = read_evokeds(evoked_fname) - evoked = evokeds[0].pick("eeg") + evoked = evokeds[0].pick("eeg").decimate(10, verbose="error") with pytest.warns(ConvergenceWarning, match="did not 
converge"): ica = ICA(n_components=3, max_iter=1, random_state=42).fit( @@ -929,7 +935,10 @@ def test_manual_report_2d(tmp_path, invisible_fig): tags=("epochs",), psd=False, projs=False, + image_kwargs=dict(mag=dict(colorbar=False)), ) + with pytest.raises(ValueError, match="map onto channel types"): + r.add_epochs(epochs=epochs_without_metadata, image_kwargs=dict(a=1), title="a") r.add_epochs( epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False ) @@ -965,11 +974,11 @@ def test_manual_report_2d(tmp_path, invisible_fig): ) r.add_ica(ica=ica, title="my ica", inst=None) with pytest.raises(RuntimeError, match="not preloaded"): - r.add_ica(ica=ica, title="ica", inst=raw) + r.add_ica(ica=ica, title="ica", inst=raw_non_preloaded) r.add_ica( ica=ica, title="my ica with raw inst", - inst=raw.copy().load_data(), + inst=raw, picks=[2], ecg_evoked=ica_ecg_evoked, eog_evoked=ica_eog_evoked, @@ -1002,8 +1011,12 @@ def test_manual_report_2d(tmp_path, invisible_fig): for ch in evoked_no_ch_locs.info["chs"]: ch["loc"][:3] = np.nan - with pytest.warns( - RuntimeWarning, match="No EEG channel locations found, cannot create joint plot" + with ( + _record_warnings(), + pytest.warns( + RuntimeWarning, + match="No EEG channel locations found, cannot create joint plot", + ), ): r.add_evokeds( evokeds=evoked_no_ch_locs, @@ -1031,7 +1044,10 @@ def test_manual_report_2d(tmp_path, invisible_fig): for ch in ica_no_ch_locs.info["chs"]: ch["loc"][:3] = np.nan - with pytest.warns(RuntimeWarning, match="No Magnetometers channel locations"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="No Magnetometers channel locations"), + ): r.add_ica( ica=ica_no_ch_locs, picks=[0], inst=raw.copy().load_data(), title="ICA" ) @@ -1055,7 +1071,10 @@ def test_manual_report_3d(tmp_path, renderer): add_kwargs = dict( trans=trans_fname, info=info, subject="sample", subjects_dir=subjects_dir ) - with pytest.warns(RuntimeWarning, match="could not be calculated"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="could not be calculated"), + ): r.add_trans(title="coreg no dig", **add_kwargs) with info._unlock(): info["dig"] = dig @@ -1088,24 +1107,53 @@ def test_sorting(tmp_path): """Test that automated ordering based on tags works.""" r = Report() - r.add_code(code="E = m * c**2", title="intelligence >9000", tags=("bem",)) - r.add_code(code="a**2 + b**2 = c**2", title="Pythagoras", tags=("evoked",)) - r.add_code(code="🧠", title="source of truth", tags=("source-estimate",)) - r.add_code(code="🥦", title="veggies", tags=("raw",)) + titles = ["intelligence >9000", "Pythagoras", "source of truth", "veggies"] + r.add_code(code="E = m * c**2", title=titles[0], tags=("bem",)) + r.add_code(code="a**2 + b**2 = c**2", title=titles[1], tags=("evoked",)) + r.add_code(code="🧠", title=titles[2], tags=("source-estimate",)) + r.add_code(code="🥦", title=titles[3], tags=("raw",)) # Check that repeated calls of add_* actually continuously appended to # the report orig_order = ["bem", "evoked", "source-estimate", "raw"] assert [c.tags[0] for c in r._content] == orig_order + # tags property behavior and get_contents + assert list(r.tags) == sorted(orig_order) + titles, tags, htmls = r.get_contents() + assert set(sum(tags, ())) == set(r.tags) + assert len(titles) == len(tags) == len(htmls) == len(r._content) + for title, tag, html in zip(titles, tags, htmls): + title = title.replace(">", ">") + assert title in html + for t in tag: + assert t in html + # Now check the actual sorting - content_sorted = 
r._sort(content=r._content, order=CONTENT_ORDER) + r_sorted = r.copy() + r_sorted._sort(order=CONTENT_ORDER) expected_order = ["raw", "evoked", "bem", "source-estimate"] - assert content_sorted != r._content - assert [c.tags[0] for c in content_sorted] == expected_order + assert r_sorted._content != r._content + assert [c.tags[0] for c in r_sorted._content] == expected_order + assert [c.tags[0] for c in r._content] == orig_order + + r.copy().save(fname=tmp_path / "report.html", sort_content=True, open_browser=False) + + # Manual sorting should be the same + r_sorted = r.copy() + order = np.argsort([CONTENT_ORDER.index(t) for t in orig_order]) + r_sorted.reorder(order) + + assert r_sorted._content != r._content + got_order = [c.tags[0] for c in r_sorted._content] + assert [c.tags[0] for c in r._content] == orig_order # original unmodified + assert got_order == expected_order - r.save(fname=tmp_path / "report.html", sort_content=True, open_browser=False) + with pytest.raises(ValueError, match="order must be a permutation"): + r.reorder(np.arange(len(r._content) + 1)) + with pytest.raises(ValueError, match="array of integers"): + r.reorder([1.0]) @pytest.mark.parametrize( diff --git a/mne/simulation/raw.py b/mne/simulation/raw.py index 5e2a00c060f..b1c3428f9df 100644 --- a/mne/simulation/raw.py +++ b/mne/simulation/raw.py @@ -123,15 +123,15 @@ def _check_head_pos(head_pos, info, first_samp, times=None): bad = ts < 0 if bad.any(): raise RuntimeError( - "All position times must be >= 0, found %s/%s" "< 0" % (bad.sum(), len(bad)) + f"All position times must be >= 0, found {bad.sum()}/{len(bad)} < 0" ) if times is not None: bad = ts > times[-1] if bad.any(): raise RuntimeError( - "All position times must be <= t_end (%0.1f " - "s), found %s/%s bad values (is this a split " - "file?)" % (times[-1], bad.sum(), len(bad)) + f"All position times must be <= t_end ({times[-1]:0.1f} " + f"s), found {bad.sum()}/{len(bad)} bad values (is this a split " + "file?)" ) # If it starts close to zero, make it zero (else unique(offset) fails) if len(ts) > 0 and ts[0] < (0.5 / info["sfreq"]): @@ -313,8 +313,8 @@ def simulate_raw( # Extract necessary info meeg_picks = pick_types(info, meg=True, eeg=True, exclude=[]) logger.info( - 'Setting up raw simulation: %s position%s, "%s" interpolation' - % (len(dev_head_ts), _pl(dev_head_ts), interp) + f"Setting up raw simulation: {len(dev_head_ts)} " + f'position{_pl(dev_head_ts)}, "{interp}" interpolation' ) if isinstance(stc, SourceSimulator) and stc.first_samp != first_samp: @@ -356,8 +356,8 @@ def simulate_raw( this_n = stc_counted[1].data.shape[1] this_stop = this_start + this_n logger.info( - " Interval %0.3f–%0.3f s" - % (this_start / info["sfreq"], this_stop / info["sfreq"]) + f" Interval {this_start / info['sfreq']:0.3f}–" + f"{this_stop / info['sfreq']:0.3f} s" ) n_doing = this_stop - this_start assert n_doing > 0 @@ -498,7 +498,7 @@ def add_ecg( def _add_exg(raw, kind, head_pos, interp, n_jobs, random_state): assert isinstance(kind, str) and kind in ("ecg", "blink") _validate_type(raw, BaseRaw, "raw") - _check_preload(raw, "Adding %s noise " % (kind,)) + _check_preload(raw, f"Adding {kind} noise ") rng = check_random_state(random_state) info, times, first_samp = raw.info, raw.times, raw.first_samp data = raw._data @@ -686,7 +686,7 @@ def _stc_data_event(stc_counted, head_idx, sfreq, src=None, verts=None): stc_idx, stc = stc_counted if isinstance(stc, (list, tuple)): if len(stc) != 2: - raise ValueError("stc, if tuple, must be length 2, got %s" % (len(stc),)) +
raise ValueError(f"stc, if tuple, must be length 2, got {len(stc)}") stc, stim_data = stc else: stim_data = None @@ -705,22 +705,22 @@ def _stc_data_event(stc_counted, head_idx, sfreq, src=None, verts=None): if stim_data.dtype.kind != "i": raise ValueError( "stim_data in a stc tuple must be an integer ndarray," - " got dtype %s" % (stim_data.dtype,) + f" got dtype {stim_data.dtype}" ) if stim_data.shape != (len(stc.times),): raise ValueError( - "event data had shape %s but needed to be (%s,) to" - "match stc" % (stim_data.shape, len(stc.times)) + f"event data had shape {stim_data.shape} but needed to " + f"be ({len(stc.times)},) tomatch stc" ) # Validate STC if not np.allclose(sfreq, 1.0 / stc.tstep): raise ValueError( - "stc and info must have same sample rate, " - "got %s and %s" % (1.0 / stc.tstep, sfreq) + f"stc and info must have same sample rate, " + f"got {1.0 / stc.tstep} and {sfreq}" ) if len(stc.times) <= 2: # to ensure event encoding works raise ValueError( - "stc must have at least three time points, got %s" % (len(stc.times),) + f"stc must have at least three time points, got {len(stc.times)}" ) verts_ = stc.vertices if verts is None: @@ -844,9 +844,7 @@ def _iter_forward_solutions( for ti, dev_head_t in enumerate(dev_head_ts): # Could be *slightly* more efficient not to do this N times, # but the cost here is tiny compared to actual fwd calculation - logger.info( - "Computing gain matrix for transform #%s/%s" % (ti + 1, len(dev_head_ts)) - ) + logger.info(f"Computing gain matrix for transform #{ti + 1}/{len(dev_head_ts)}") _transform_orig_meg_coils(megcoils, dev_head_t) # Make sure our sensors are all outside our BEM @@ -863,8 +861,8 @@ def _iter_forward_solutions( outside = np.ones(len(coil_rr), bool) if not outside.all(): raise RuntimeError( - "%s MEG sensors collided with inner skull " - "surface for transform %s" % (np.sum(~outside), ti) + f"{np.sum(~outside)} MEG sensors collided with inner skull " + f"surface for transform {ti}" ) megfwd = _compute_forwards( rr, sensors=sensors, bem=bem, n_jobs=n_jobs, verbose=False diff --git a/mne/simulation/source.py b/mne/simulation/source.py index f87c9b420de..42c88c47a46 100644 --- a/mne/simulation/source.py +++ b/mne/simulation/source.py @@ -177,8 +177,8 @@ def simulate_sparse_stc( subject = subject_src elif subject_src is not None and subject != subject_src: raise ValueError( - "subject argument (%s) did not match the source " - "space subject_his_id (%s)" % (subject, subject_src) + f"subject argument ({subject}) did not match the source " + f"space subject_his_id ({subject_src})" ) data = np.zeros((n_dipoles, len(times))) for i_dip in range(n_dipoles): @@ -328,9 +328,8 @@ def simulate_stc( d = len(v) - len(np.unique(v)) if d > 0: raise RuntimeError( - "Labels had %s overlaps in the %s " - "hemisphere, " - "they must be non-overlapping" % (d, hemi) + f"Labels had {d} overlaps in the {hemi} " + "hemisphere, they must be non-overlapping" ) # the data is in the order left, right data = list() diff --git a/mne/simulation/tests/test_evoked.py b/mne/simulation/tests/test_evoked.py index bc33d6195a2..b8fc7f12ff8 100644 --- a/mne/simulation/tests/test_evoked.py +++ b/mne/simulation/tests/test_evoked.py @@ -34,15 +34,9 @@ data_path = testing.data_path(download=False) fwd_fname = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif" -raw_fname = ( - Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif" -) -ave_fname = ( - Path(__file__).parent.parent.parent / "io" / "tests" / "data" / 
"test-ave.fif" -) -cov_fname = ( - Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test-cov.fif" -) +raw_fname = Path(__file__).parents[2] / "io" / "tests" / "data" / "test_raw.fif" +ave_fname = Path(__file__).parents[2] / "io" / "tests" / "data" / "test-ave.fif" +cov_fname = Path(__file__).parents[2] / "io" / "tests" / "data" / "test-cov.fif" @testing.requires_testing_data diff --git a/mne/simulation/tests/test_raw.py b/mne/simulation/tests/test_raw.py index bf4caf3bdeb..2b047f758dd 100644 --- a/mne/simulation/tests/test_raw.py +++ b/mne/simulation/tests/test_raw.py @@ -59,9 +59,7 @@ from mne.tests.test_chpi import _assert_quats from mne.utils import catch_logging -raw_fname_short = ( - Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif" -) +raw_fname_short = Path(__file__).parents[2] / "io" / "tests" / "data" / "test_raw.fif" data_path = testing.data_path(download=False) raw_fname = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" @@ -400,7 +398,7 @@ def test_simulate_raw_bem(raw_data): fits = fit_dipole(evoked, cov, bem, trans, min_dist=1.0)[0].pos diffs = np.sqrt(np.sum((locs - fits) ** 2, axis=-1)) * 1000 med_diff = np.median(diffs) - assert med_diff < tol, "%s: %s" % (bem, med_diff) + assert med_diff < tol, f"{bem}: {med_diff}" # also test event timings with SourceSimulator first_samp = raw.first_samp events = find_events(raw, initial_event=True, verbose=False) diff --git a/mne/source_estimate.py b/mne/source_estimate.py index efc5a06515a..481ae84efab 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -17,13 +17,14 @@ from ._fiff.constants import FIFF from ._fiff.meas_info import Info -from ._fiff.pick import pick_types +from ._fiff.pick import _picks_to_idx, pick_types from ._freesurfer import _get_atlas_values, _get_mri_info_data, read_freesurfer_lut from .baseline import rescale from .cov import Covariance from .evoked import _get_peak -from .filter import resample +from .filter import FilterMixin, _check_fun, resample from .fixes import _safe_svd +from .parallel import parallel_func from .source_space._source_space import ( SourceSpaces, _check_volume_labels, @@ -31,6 +32,7 @@ _ensure_src_subject, _get_morph_src_reordering, _get_src_nn, + get_decimated_surfaces, ) from .surface import _get_ico_surface, _project_onto_surface, mesh_edges, read_surface from .transforms import _get_trans, apply_trans @@ -41,6 +43,7 @@ _check_option, _check_pandas_index_arguments, _check_pandas_installed, + _check_preload, _check_src_normal, _check_stc_units, _check_subject, @@ -381,8 +384,8 @@ def read_source_estimate(fname, subject=None): kwargs["subject"] = subject if subject is not None and subject != kwargs["subject"]: raise RuntimeError( - 'provided subject name "%s" does not match ' - 'subject name from the file "%s' % (subject, kwargs["subject"]) + f'provided subject name "{subject}" does not match ' + f'subject name from the file "{kwargs["subject"]}' ) if ftype in ("volume", "discrete"): @@ -477,7 +480,7 @@ def _verify_source_estimate_compat(a, b): """Make sure two SourceEstimates are compatible for arith. operations.""" compat = False if type(a) != type(b): - raise ValueError("Cannot combine %s and %s." 
% (type(a), type(b))) + raise ValueError(f"Cannot combine {type(a)} and {type(b)}.") if len(a.vertices) == len(b.vertices): if all(np.array_equal(av, vv) for av, vv in zip(a.vertices, b.vertices)): compat = True @@ -489,17 +492,15 @@ def _verify_source_estimate_compat(a, b): if a.subject != b.subject: raise ValueError( "source estimates do not have the same subject " - "names, %r and %r" % (a.subject, b.subject) + f"names, {repr(a.subject)} and {repr(b.subject)}" ) -class _BaseSourceEstimate(TimeMixin): +class _BaseSourceEstimate(TimeMixin, FilterMixin): _data_ndim = 2 @verbose - def __init__( - self, data, vertices, tmin, tstep, subject=None, verbose=None - ): # noqa: D102 + def __init__(self, data, vertices, tmin, tstep, subject=None, verbose=None): assert hasattr(self, "_data_ndim"), self.__class__.__name__ assert hasattr(self, "_src_type"), self.__class__.__name__ assert hasattr(self, "_src_count"), self.__class__.__name__ @@ -511,13 +512,12 @@ def __init__( data = None if kernel.shape[1] != sens_data.shape[0]: raise ValueError( - "kernel (%s) and sens_data (%s) have invalid " - "dimensions" % (kernel.shape, sens_data.shape) + f"kernel ({kernel.shape}) and sens_data ({sens_data.shape}) " + "have invalid dimensions" ) if sens_data.ndim != 2: raise ValueError( - "The sensor data must have 2 dimensions, got " - "%s" % (sens_data.ndim,) + f"The sensor data must have 2 dimensions, got {sens_data.ndim}" ) _validate_type(vertices, list, "vertices") @@ -537,8 +537,8 @@ def __init__( if data is not None: if data.ndim not in (self._data_ndim, self._data_ndim - 1): raise ValueError( - "Data (shape %s) must have %s dimensions for " - "%s" % (data.shape, self._data_ndim, self.__class__.__name__) + f"Data (shape {data.shape}) must have {self._data_ndim} " + f"dimensions for {self.__class__.__name__}" ) if data.shape[0] != n_src: raise ValueError( @@ -549,7 +549,7 @@ def __init__( if data.shape[1] != 3: raise ValueError( "Data for VectorSourceEstimate must have " - "shape[1] == 3, got shape %s" % (data.shape,) + f"shape[1] == 3, got shape {data.shape}" ) if data.ndim == self._data_ndim - 1: # allow upbroadcasting data = data[..., np.newaxis] @@ -572,10 +572,10 @@ def __repr__(self): # noqa: D105 s += ", tmin : %s (ms)" % (1e3 * self.tmin) s += ", tmax : %s (ms)" % (1e3 * self.times[-1]) s += ", tstep : %s (ms)" % (1e3 * self.tstep) - s += ", data shape : %s" % (self.shape,) + s += f", data shape : {self.shape}" sz = sum(object_size(x) for x in (self.vertices + [self.data])) s += f", ~{sizeof_fmt(sz)}" - return "<%s | %s>" % (type(self).__name__, s) + return f"<{type(self).__name__} | {s}>" @fill_doc def get_peak( @@ -643,6 +643,57 @@ def extract_label_time_course( verbose=verbose, ) + @verbose + def apply_function( + self, fun, picks=None, dtype=None, n_jobs=None, verbose=None, **kwargs + ): + """Apply a function to a subset of vertices. + + %(applyfun_summary_stc)s + + Parameters + ---------- + %(fun_applyfun_stc)s + %(picks_all)s + %(dtype_applyfun)s + %(n_jobs)s Ignored, as the workload + is split across vertices. + %(verbose)s + %(kwargs_fun)s + + Returns + ------- + self : instance of SourceEstimate + The SourceEstimate object with transformed data.
+ """ + _check_preload(self, "source_estimate.apply_function") + picks = _picks_to_idx(len(self._data), picks, exclude=(), with_ref_meg=False) + + if not callable(fun): + raise ValueError("fun needs to be a function") + + data_in = self._data + if dtype is not None and dtype != self._data.dtype: + self._data = self._data.astype(dtype) + + # check the dimension of the source estimate data + _check_option("source_estimate.ndim", self._data.ndim, [2, 3]) + + parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs) + if n_jobs == 1: + # modify data inplace to save memory + for idx in picks: + self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs) + else: + # use parallel function + data_picks_new = parallel( + p_fun(fun, data_in[p, :], **kwargs) for p in picks + ) + for pp, p in enumerate(picks): + self._data[p, :] = data_picks_new[pp] + + return self + @verbose def apply_baseline(self, baseline=(None, 0), *, verbose=None): """Baseline correct source estimate data. @@ -685,8 +736,7 @@ def save(self, fname, ftype="h5", *, overwrite=False, verbose=None): fname = _check_fname(fname=fname, overwrite=True) # check below if ftype != "h5": raise ValueError( - "%s objects can only be written as HDF5 files." - % (self.__class__.__name__,) + f"{self.__class__.__name__} objects can only be written as HDF5 files." ) _, write_hdf5 = _import_h5io_funcs() if fname.suffix != ".h5": @@ -821,7 +871,17 @@ def crop(self, tmin=None, tmax=None, include_tmax=True): return self # return self for chaining methods @verbose - def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=None): + def resample( + self, + sfreq, + *, + npad=100, + method="fft", + window="auto", + pad="auto", + n_jobs=None, + verbose=None, + ): """Resample data. If appropriate, an anti-aliasing filter is applied before resampling. @@ -835,8 +895,15 @@ def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=Non Amount to pad the start and end of the data. Can also be "auto" to use a padding that will result in a power-of-two size (can be much faster). - window : str | tuple - Window to use in resampling. See :func:`scipy.signal.resample`. + %(method_resample)s + + .. versionadded:: 1.7 + %(window_resample)s + + .. versionadded:: 1.7 + %(pad_resample_auto)s + + .. 
     @verbose
     def apply_baseline(self, baseline=(None, 0), *, verbose=None):
         """Baseline correct source estimate data.
@@ -685,8 +736,7 @@ def save(self, fname, ftype="h5", *, overwrite=False, verbose=None):
         fname = _check_fname(fname=fname, overwrite=True)  # check below
         if ftype != "h5":
             raise ValueError(
-                "%s objects can only be written as HDF5 files."
-                % (self.__class__.__name__,)
+                f"{self.__class__.__name__} objects can only be written as HDF5 files."
             )
         _, write_hdf5 = _import_h5io_funcs()
         if fname.suffix != ".h5":
@@ -821,7 +871,17 @@ def crop(self, tmin=None, tmax=None, include_tmax=True):
         return self  # return self for chaining methods

     @verbose
-    def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=None):
+    def resample(
+        self,
+        sfreq,
+        *,
+        npad=100,
+        method="fft",
+        window="auto",
+        pad="auto",
+        n_jobs=None,
+        verbose=None,
+    ):
         """Resample data.

         If appropriate, an anti-aliasing filter is applied before resampling.
@@ -835,8 +895,15 @@ def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=Non
             Amount to pad the start and end of the data.
             Can also be "auto" to use a padding that will result in
             a power-of-two size (can be much faster).
-        window : str | tuple
-            Window to use in resampling. See :func:`scipy.signal.resample`.
+        %(method_resample)s
+
+            .. versionadded:: 1.7
+        %(window_resample)s
+
+            .. versionadded:: 1.7
+        %(pad_resample_auto)s
+
+            .. versionadded:: 1.7
         %(n_jobs)s
         %(verbose)s
@@ -865,7 +932,10 @@ def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=Non
         data = self.data
         if data.dtype == np.float32:
             data = data.astype(np.float64)
-        self.data = resample(data, sfreq, o_sfreq, npad, n_jobs=n_jobs)
+        self.data = resample(
+            data, sfreq, o_sfreq, npad=npad, window=window, pad=pad,
+            n_jobs=n_jobs, method=method,
+        )

         # adjust indirectly affected variables
         self.tstep = 1.0 / sfreq
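A quick sketch of the reworked signature, reusing the toy ``stc`` above (1000 Hz, from ``tstep=0.001``); all tuning arguments are now keyword-only, and ``method="polyphase"`` is the alternative to the FFT default:

    stc_fft = stc.copy().resample(sfreq=250)  # method="fft" is the default
    stc_poly = stc.copy().resample(sfreq=250, method="polyphase")  # new in 1.7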
@@ -1381,15 +1450,15 @@ def to_data_frame(
         if self.subject is not None:
             default_index = ["subject", "time"]
             mindex.append(("subject", np.repeat(self.subject, data.shape[0])))
-        times = _convert_times(self, times, time_format)
+        times = _convert_times(times, time_format)
         mindex.append(("time", times))
         # triage surface vs volume source estimates
         col_names = list()
         kinds = ["VOL"] * len(self.vertices)
         if isinstance(self, (_BaseSurfaceSourceEstimate, _BaseMixedSourceEstimate)):
             kinds[:2] = ["LH", "RH"]
-        for ii, (kind, vertno) in enumerate(zip(kinds, self.vertices)):
-            col_names.extend(["{}_{}".format(kind, vert) for vert in vertno])
+        for kind, vertno in zip(kinds, self.vertices):
+            col_names.extend([f"{kind}_{vert}" for vert in vertno])
         # build DataFrame
         df = _build_data_frame(
             self,
@@ -1539,7 +1608,7 @@ def in_label(self, label):
         ):
             raise RuntimeError(
                 "label and stc must have same subject names, "
-                'currently "%s" and "%s"' % (label.subject, self.subject)
+                f'currently "{label.subject}" and "{self.subject}"'
             )

         if label.hemi == "both":
@@ -1567,6 +1636,76 @@ def in_label(self, label):
         )
         return label_stc

+    def save_as_surface(self, fname, src, *, scale=1, scale_rr=1e3):
+        """Save a surface source estimate (stc) as a GIFTI file.
+
+        Parameters
+        ----------
+        fname : path-like
+            Filename basename to save files as.
+            Will write anatomical GIFTI plus time series GIFTI for both lh/rh,
+            for example ``"basename"`` will write ``"basename-lh.gii"``,
+            ``"basename-lh.time.gii"``, ``"basename-rh.gii"``, and
+            ``"basename-rh.time.gii"``.
+        src : instance of SourceSpaces
+            The source space of the forward solution.
+        scale : float
+            Scale factor to apply to the data (functional) values.
+        scale_rr : float
+            Scale factor for the source vertex positions. The default (1e3) will
+            scale from meters to millimeters, which is more standard for GIFTI files.
+
+        Notes
+        -----
+        .. versionadded:: 1.7
+        """
+        nib = _import_nibabel()
+        _check_option("src.kind", src.kind, ("surface", "mixed"))
+        ss = get_decimated_surfaces(src)
+        assert len(ss) == 2  # should be guaranteed by _check_option above
+
+        # Create lists to put DataArrays into
+        hemis = ("lh", "rh")
+        for s, hemi in zip(ss, hemis):
+            darrays = list()
+            darrays.append(
+                nib.gifti.gifti.GiftiDataArray(
+                    data=(s["rr"] * scale_rr).astype(np.float32),
+                    intent="NIFTI_INTENT_POINTSET",
+                    datatype="NIFTI_TYPE_FLOAT32",
+                )
+            )
+
+            # Make the topology DataArray
+            darrays.append(
+                nib.gifti.gifti.GiftiDataArray(
+                    data=s["tris"].astype(np.int32),
+                    intent="NIFTI_INTENT_TRIANGLE",
+                    datatype="NIFTI_TYPE_INT32",
+                )
+            )
+
+            # Make the output GIFTI for anatomicals
+            topo_gi_hemi = nib.gifti.gifti.GiftiImage(darrays=darrays)
+
+            # actually save the file
+            nib.save(topo_gi_hemi, f"{fname}-{hemi}.gii")
+
+            # Make the Time Series data arrays
+            data = getattr(self, f"{hemi}_data") * scale
+            ts = [
+                nib.gifti.gifti.GiftiDataArray(
+                    data=data[:, idx].astype(np.float32),
+                    intent="NIFTI_INTENT_POINTSET",
+                    datatype="NIFTI_TYPE_FLOAT32",
+                )
+                for idx in range(data.shape[1])
+            ]
+
+            # save the time series
+            ts_gi = nib.gifti.gifti.GiftiImage(darrays=ts)
+            nib.save(ts_gi, f"{fname}-{hemi}.time.gii")
+
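Illustrative usage of the new GIFTI export, assuming an ``stc`` defined on the surface source space ``src`` (file names here are hypothetical); note the hyphenated output names produced by the code above:

    src = mne.read_source_spaces("sample-oct-6-src.fif")  # hypothetical file
    # writes "sample_aud-lh.gii", "sample_aud-lh.time.gii",
    # "sample_aud-rh.gii", and "sample_aud-rh.time.gii"
    stc.save_as_surface("sample_aud", src)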
     def expand(self, vertices):
         """Expand SourceEstimate to include more vertices.
@@ -1586,7 +1726,7 @@ def expand(self, vertices):
         if not isinstance(vertices, list):
             raise TypeError("vertices must be a list")
         if not len(self.vertices) == len(vertices):
-            raise ValueError("vertices must have the same length as " "stc.vertices")
+            raise ValueError("vertices must have the same length as stc.vertices")

         # can no longer use kernel and sensor data
         self._remove_kernel_sens_data_()
@@ -1800,7 +1940,7 @@ def save(self, fname, ftype="stc", *, overwrite=False, verbose=None):
             )
         elif ftype == "w":
             if self.shape[1] != 1:
-                raise ValueError("w files can only contain a single time " "point")
+                raise ValueError("w files can only contain a single time point.")
             logger.info("Writing STC to disk (w format)...")
             fname_l = str(_check_fname(fname + "-lh.w", overwrite=overwrite))
             fname_r = str(_check_fname(fname + "-rh.w", overwrite=overwrite))
@@ -1961,7 +2101,7 @@ def center_of_mass(
         .. footbibliography::
         """
         if not isinstance(surf, str):
-            raise TypeError("surf must be a string, got %s" % (type(surf),))
+            raise TypeError(f"surf must be a string, got {type(surf)}")
         subject = _check_subject(self.subject, subject)
         if np.any(self.data < 0):
             raise ValueError("Cannot compute COM with negative values")
@@ -2001,7 +2141,7 @@ class _BaseVectorSourceEstimate(_BaseSourceEstimate):
     @verbose
     def __init__(
         self, data, vertices=None, tmin=None, tstep=None, subject=None, verbose=None
-    ):  # noqa: D102
+    ):
         assert hasattr(self, "_scalar_class")
         super().__init__(data, vertices, tmin, tstep, subject, verbose)
@@ -2138,7 +2278,7 @@ def plot(
         add_data_kwargs=None,
         brain_kwargs=None,
         verbose=None,
-    ):  # noqa: D102
+    ):
         return plot_vector_source_estimates(
             self,
             subject=subject,
@@ -2358,7 +2498,7 @@ def in_label(self, label, mri, src, *, verbose=None):
         """
         if len(self.vertices) != 1:
             raise RuntimeError(
-                "This method can only be used with whole-brain " "volume source spaces"
+                "This method can only be used with whole-brain volume source spaces"
             )
         _validate_type(label, (str, "int-like"), "label")
         if isinstance(label, str):
@@ -2387,7 +2527,7 @@ def save_as_volume(
         src,
         dest="mri",
         mri_resolution=False,
-        format="nifti1",
+        format="nifti1",  # noqa: A002
         *,
         overwrite=False,
         verbose=None,
@@ -2436,7 +2576,13 @@ def save_as_volume(
         )
         nib.save(img, fname)

-    def as_volume(self, src, dest="mri", mri_resolution=False, format="nifti1"):
+    def as_volume(
+        self,
+        src,
+        dest="mri",
+        mri_resolution=False,
+        format="nifti1",  # noqa: A002
+    ):
         """Export volume source estimate as a nifti object.

         Parameters
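For the volume classes touched here, the (otherwise unchanged) export path looks roughly like this — ``vol_stc`` and ``vol_src`` stand for a hypothetical ``VolSourceEstimate`` and its matching volume source space, and ``format`` (which shadows the builtin, hence the new ``noqa: A002``) accepts ``"nifti1"`` or ``"nifti2"``:

    # in-memory nibabel image at source-space resolution
    img = vol_stc.as_volume(vol_src, dest="mri", mri_resolution=False, format="nifti1")
    # or straight to disk, upsampled onto the MRI grid
    vol_stc.save_as_volume("stc-vol.nii.gz", vol_src, mri_resolution=True, overwrite=True)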
@@ -2643,7 +2789,7 @@ def plot_3d(
         add_data_kwargs=None,
         brain_kwargs=None,
         verbose=None,
-    ):  # noqa: D102
+    ):
         return _BaseVectorSourceEstimate.plot(
             self,
             subject=subject,
@@ -2734,7 +2880,7 @@ class _BaseMixedSourceEstimate(_BaseSourceEstimate):
     @verbose
     def __init__(
         self, data, vertices=None, tmin=None, tstep=None, subject=None, verbose=None
-    ):  # noqa: D102
+    ):
         if not isinstance(vertices, list) or len(vertices) < 2:
             raise ValueError(
                 "Vertices must be a list of numpy arrays with "
@@ -2978,7 +3124,7 @@ def spatio_temporal_src_adjacency(src, n_times, dist=None, verbose=None):
     if src[0]["type"] == "vol":
         if dist is not None:
             raise ValueError(
-                "dist must be None for a volume " "source space. Got %s." % dist
+                f"dist must be None for a volume source space. Got {dist}."
             )

         adjacency = _spatio_temporal_src_adjacency_vol(src, n_times)
@@ -3417,13 +3563,12 @@ def _volume_labels(src, labels, mri_resolution):
     else:
         if len(labels) != 2:
             raise ValueError(
-                "labels, if list or tuple, must have length 2, "
-                "got %s" % (len(labels),)
+                f"labels, if list or tuple, must have length 2, got {len(labels)}"
             )
         mri, labels = labels
         infer_labels = False
     _validate_type(mri, "path-like", "labels[0]" + extra)
-    logger.info("Reading atlas %s" % (mri,))
+    logger.info(f"Reading atlas {mri}")
     vol_info = _get_mri_info_data(str(mri), data=True)
     atlas_data = vol_info["data"]
     atlas_values = np.unique(atlas_data)
@@ -3458,8 +3603,8 @@ def _volume_labels(src, labels, mri_resolution):
     atlas_shape = atlas_data.shape
     if atlas_shape != src_shape:
         raise RuntimeError(
-            "atlas shape %s does not match source space MRI "
-            "shape %s" % (atlas_shape, src_shape)
+            f"atlas shape {atlas_shape} does not match source space MRI "
+            f"shape {src_shape}"
         )
     atlas_data = atlas_data.ravel(order="F")
     if mri_resolution:
@@ -3561,10 +3706,10 @@ def _gen_extract_label_time_course(
             if len(vn) != len(svn):
                 raise ValueError(
                     "stc not compatible with source space. "
-                    "stc has %s time series but there are %s "
+                    f"stc has {len(svn)} time series but there are {len(vn)} "
                     "vertices in source space. Ensure you used "
                     "src from the forward or inverse operator, "
-                    "as forward computation can exclude vertices." % (len(svn), len(vn))
+                    "as forward computation can exclude vertices."
                 )
             if not np.array_equal(svn, vn):
                 raise ValueError("stc not compatible with source space")
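The ``(mri, labels)`` pair unpacked by ``_volume_labels`` above is the public ``labels`` argument of ``extract_label_time_course`` for volume source spaces; a hedged sketch (atlas file and variables hypothetical):

    label_tc = mne.extract_label_time_course(
        vol_stc,
        labels=("aparc+aseg.mgz", ["Left-Hippocampus", "Right-Hippocampus"]),
        src=vol_src,
        mode="mean",
    )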
@@ -3686,7 +3831,7 @@ def stc_near_sensors(
     subjects_dir=None,
     src=None,
     picks=None,
-    surface="pial",
+    surface="auto",
     verbose=None,
 ):
     """Create a STC from ECoG, sEEG and DBS sensor data.
@@ -3726,8 +3871,8 @@ def stc_near_sensors(
         .. versionadded:: 0.24
     surface : str | None
-        The surface to use if ``src=None``. Default is the pial surface.
-        If None, the source space surface will be used.
+        The surface to use. If ``src=None``, defaults to the pial surface.
+        Otherwise, the source space surface will be used.

         .. versionadded:: 0.24.1
     %(verbose)s
@@ -3781,12 +3926,31 @@ def stc_near_sensors(
     _validate_type(mode, str, "mode")
     _validate_type(src, (None, SourceSpaces), "src")
     _check_option("mode", mode, ("sum", "single", "nearest", "weighted"))
+    if surface == "auto":
+        if src is not None:
+            pial_fname = op.join(subjects_dir, subject, "surf", "lh.pial")
+            src_surf_is_pial = op.isfile(pial_fname)
+            if src_surf_is_pial:
+                pial_rr = read_surface(pial_fname)[0]
+                src_surf_is_pial = (
+                    src[0]["rr"].shape == pial_rr.shape
+                    and np.allclose(src[0]["rr"], pial_rr)
+                )
+            if not src_surf_is_pial:
+                warn(
+                    "In version 1.8, ``surface='auto'`` will be the default "
+                    "which will use the surface in ``src`` instead of the "
+                    "pial surface when ``src != None``. Pass ``surface='pial'`` "
+                    "or ``surface=None`` to suppress this warning",
+                    DeprecationWarning,
+                )
+        surface = "pial" if src is None or src.kind == "surface" else None

     # create a copy of Evoked using ecog, seeg and dbs
     if picks is None:
         picks = pick_types(evoked.info, ecog=True, seeg=True, dbs=True)
     evoked = evoked.copy().pick(picks)
-    frames = set(evoked.info["chs"][pick]["coord_frame"] for pick in picks)
+    frames = set(ch["coord_frame"] for ch in evoked.info["chs"])
     if not frames == {FIFF.FIFFV_COORD_HEAD}:
         raise RuntimeError(
             "Channels must be in the head coordinate frame, " f"got {sorted(frames)}"
         )
diff --git a/mne/source_space/_source_space.py b/mne/source_space/_source_space.py
index 3cfedb9d7a1..7f2910cbaad 100644
--- a/mne/source_space/_source_space.py
+++ b/mne/source_space/_source_space.py
@@ -35,13 +35,10 @@
     write_int_matrix,
     write_string,
 )
-
-# Remove get_mni_fiducials in 1.6 (deprecated)
 from .._freesurfer import (
     _check_mri,
     _get_atlas_values,
     _get_mri_info_data,
-    get_mni_fiducials,  # noqa: F401
     get_volume_labels_from_aseg,
     read_freesurfer_lut,
 )
@@ -289,10 +286,10 @@ class SourceSpaces(list):
     access, like ``src.kind``.
     """  # noqa: E501

-    def __init__(self, source_spaces, info=None):  # noqa: D102
+    def __init__(self, source_spaces, info=None):
         # First check the types is actually a valid config
         _validate_type(source_spaces, list, "source_spaces")
-        super(SourceSpaces, self).__init__(source_spaces)  # list
+        super().__init__(source_spaces)  # list
         self.kind  # will raise an error if there is a problem
         if info is None:
             self.info = dict()
@@ -323,7 +320,7 @@ def kind(self):
         else:
             kind = "volume"
             if any(k == "surf" for k in types[surf_check:]):
-                raise RuntimeError("Invalid source space with kinds %s" % (types,))
+                raise RuntimeError(f"Invalid source space with kinds {types}")
         return kind

     @verbose
@@ -449,9 +446,9 @@ def __repr__(self):  # noqa: D105
             r = _src_kind_dict[ss_type]
             if ss_type == "vol":
                 if "seg_name" in ss:
-                    r += " (%s)" % (ss["seg_name"],)
+                    r += f" ({ss['seg_name']})"
                 else:
-                    r += ", shape=%s" % (ss["shape"],)
+                    r += f", shape={ss['shape']}"
             elif ss_type == "surf":
                 r += " (%s), n_vertices=%i" % (_get_hemi(ss)[0], ss["np"])
             r += ", n_used=%i" % (ss["nuse"],)
@@ -460,11 +457,11 @@ def __repr__(self):  # noqa: D105
             ss_repr.append("<%s>" % r)
         subj = self._subject
         if subj is not None:
-            extra += ["subject %r" % (subj,)]
+            extra += [f"subject {repr(subj)}"]
         sz = object_size(self)
         if sz is not None:
             extra += [f"~{sizeof_fmt(sz)}"]
-        return "<SourceSpaces: [%s] %s>" % (", ".join(ss_repr), ", ".join(extra))
+        return f"<SourceSpaces: [{', '.join(ss_repr)}] {', '.join(extra)}>"

     @property
     def _subject(self):
@@ -671,10 +668,11 @@ def export_volume(
         # Figure out how to get from our input source space to output voxels
         fro_dst_t = invert_transform(transform)
-        dest = transform["to"]
         if coords == "head":
             head_mri_t = _get_trans(trans, "head", "mri")[0]
-            fro_dst_t = combine_transforms(head_mri_t, fro_dst_t, "head", dest)
+            fro_dst_t = combine_transforms(
+                head_mri_t, fro_dst_t, "head", transform["to"]
+            )
         else:
             fro_dst_t = fro_dst_t
@@ -1428,11 +1426,8 @@ def _check_spacing(spacing, verbose=None):
     """Check spacing parameter."""
     # check to make sure our parameters are good, parse 'spacing'
     types = 'a string with values "ico#", "oct#", "all", or an int >= 2'
-    space_err = '"spacing" must be %s, got type %s (%r)' % (
-        types,
-        type(spacing),
-        spacing,
-    )
+    space_err = f'"spacing" must be {types}, got type {type(spacing)} ({repr(spacing)})'
+
     if isinstance(spacing, str):
         if spacing == "all":
             stype = "all"
@@ -1444,13 +1439,11 @@ def _check_spacing(spacing,
verbose=None): sval = int(sval) except Exception: raise ValueError( - "%s subdivision must be an integer, got %r" % (stype, sval) + f"{stype} subdivision must be an integer, got {repr(sval)}" ) lim = 0 if stype == "ico" else 1 if sval < lim: - raise ValueError( - "%s subdivision must be >= %s, got %s" % (stype, lim, sval) - ) + raise ValueError(f"{stype} subdivision must be >= {lim}, got {sval}") else: raise ValueError(space_err) else: @@ -1463,7 +1456,7 @@ def _check_spacing(spacing, verbose=None): ico_surf = None src_type_str = "all" else: - src_type_str = "%s = %s" % (stype, sval) + src_type_str = f"{stype} = {sval}" if stype == "ico": logger.info("Icosahedron subdivision grade %s" % sval) ico_surf = _get_ico_surface(sval) @@ -1525,9 +1518,8 @@ def setup_source_space( setup_volume_source_space """ cmd = ( - "setup_source_space(%s, spacing=%s, surface=%s, " - "subjects_dir=%s, add_dist=%s, verbose=%s)" - % (subject, spacing, surface, subjects_dir, add_dist, verbose) + f"setup_source_space({subject}, spacing={spacing}, surface={surface}, " + f"subjects_dir={subjects_dir}, add_dist={add_dist}, verbose={verbose})" ) subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) @@ -1536,7 +1528,7 @@ def setup_source_space( ] for surf, hemi in zip(surfs, ["LH", "RH"]): if surf is not None and not op.isfile(surf): - raise OSError("Could not find the %s surface %s" % (hemi, surf)) + raise OSError(f"Could not find the {hemi} surface {surf}") logger.info("Setting up the source space with the following parameters:\n") logger.info("SUBJECTS_DIR = %s" % subjects_dir) @@ -1554,8 +1546,7 @@ def setup_source_space( # pre-load ico/oct surf (once) for speed, if necessary if stype not in ("spacing", "all"): logger.info( - "Doing the %shedral vertex picking..." - % (dict(ico="icosa", oct="octa")[stype],) + f'Doing the {dict(ico="icosa", oct="octa")[stype]}hedral vertex picking...' ) for hemi, surf in zip(["lh", "rh"], surfs): logger.info("Loading %s..." % surf) @@ -1608,7 +1599,7 @@ def setup_source_space( def _check_volume_labels(volume_label, mri, name="volume_label"): - _validate_type(mri, "path-like", "mri when %s is not None" % (name,)) + _validate_type(mri, "path-like", f"mri when {name} is not None") mri = str(_check_fname(mri, overwrite="read", must_exist=True)) if isinstance(volume_label, str): volume_label = [volume_label] @@ -1617,22 +1608,22 @@ def _check_volume_labels(volume_label, mri, name="volume_label"): # Turn it into a dict if not mri.endswith("aseg.mgz"): raise RuntimeError( - "Must use a *aseg.mgz file unless %s is a dict, got %s" - % (name, op.basename(mri)) + f"Must use a *aseg.mgz file unless {name} is a dict, " + f"got {op.basename(mri)}" ) lut, _ = read_freesurfer_lut() use_volume_label = dict() for label in volume_label: if label not in lut: raise ValueError( - "Volume %r not found in file %s. Double check " - "FreeSurfer lookup table.%s" % (label, mri, _suggest(label, lut)) + f"Volume {repr(label)} not found in file {mri}. 
Double check " + f"FreeSurfer lookup table.{_suggest(label, lut)}" ) use_volume_label[label] = lut[label] volume_label = use_volume_label for label, id_ in volume_label.items(): _validate_type(label, str, "volume_label keys") - _validate_type(id_, "int-like", "volume_labels[%r]" % (label,)) + _validate_type(id_, "int-like", f"volume_labels[{repr(label)}]") volume_label = {k: _ensure_int(v) for k, v in volume_label.items()} return volume_label @@ -1828,10 +1819,10 @@ def setup_volume_source_space( logger.info("Boundary surface file : %s", surf_extra) else: logger.info( - "Sphere : origin at (%.1f %.1f %.1f) mm" - % (1000 * sphere[0], 1000 * sphere[1], 1000 * sphere[2]) + f"Sphere : origin at ({1000 * sphere[0]:.1f} " + f"{1000 * sphere[1]:.1f} {1000 * sphere[2]:.1f}) mm" ) - logger.info(" radius : %.1f mm" % (1000 * sphere[3],)) + logger.info(f" radius : {1000 * sphere[3]:.1f} mm") # triage pos argument if isinstance(pos, dict): @@ -1889,8 +1880,8 @@ def setup_volume_source_space( assert surf["id"] == FIFF.FIFFV_BEM_SURF_ID_BRAIN if surf["coord_frame"] != FIFF.FIFFV_COORD_MRI: raise ValueError( - "BEM is not in MRI coordinates, got %s" - % (_coord_frame_name(surf["coord_frame"]),) + f"BEM is not in MRI coordinates, got " + f"{_coord_frame_name(surf['coord_frame'])}" ) logger.info("Taking inner skull from %s" % bem) elif surface is not None: @@ -1999,8 +1990,8 @@ def _make_discrete_source_space(pos, coord_frame="mri"): # Check that coordinate frame is valid if coord_frame not in _str_to_frame: # will fail if coord_frame not string raise KeyError( - 'coord_frame must be one of %s, not "%s"' - % (list(_str_to_frame.keys()), coord_frame) + f"coord_frame must be one of {list(_str_to_frame.keys())}, " + f'not "{coord_frame}"' ) coord_frame = _str_to_frame[coord_frame] # now an int @@ -2050,11 +2041,12 @@ def _make_volume_source_space( volume_labels=None, do_neighbors=True, n_jobs=None, - vol_info={}, + vol_info=None, single_volume=False, ): """Make a source space which covers the volume bounded by surf.""" # Figure out the grid size in the MRI coordinate frame + vol_info = {} if vol_info is None else vol_info if "rr" in surf: mins = np.min(surf["rr"], axis=0) maxs = np.max(surf["rr"], axis=0) @@ -2068,13 +2060,12 @@ def _make_volume_source_space( # Define the sphere which fits the surface logger.info( - "Surface CM = (%6.1f %6.1f %6.1f) mm" - % (1000 * cm[0], 1000 * cm[1], 1000 * cm[2]) + f"Surface CM = ({1000 * cm[0]:6.1f} {1000 * cm[1]:6.1f} {1000 * cm[2]:6.1f}) mm" ) logger.info("Surface fits inside a sphere with radius %6.1f mm" % (1000 * maxdist)) logger.info("Surface extent:") for c, mi, ma in zip("xyz", mins, maxs): - logger.info(" %s = %6.1f ... %6.1f mm" % (c, 1000 * mi, 1000 * ma)) + logger.info(f" {c} = {1000 * mi:6.1f} ... {1000 * ma:6.1f} mm") maxn = np.array( [ np.floor(np.abs(m) / grid) + 1 if m > 0 else -np.floor(np.abs(m) / grid) - 1 @@ -2091,9 +2082,7 @@ def _make_volume_source_space( ) logger.info("Grid extent:") for c, mi, ma in zip("xyz", minn, maxn): - logger.info( - " %s = %6.1f ... %6.1f mm" % (c, 1000 * mi * grid, 1000 * ma * grid) - ) + logger.info(f" {c} = {1000 * mi * grid:6.1f} ... 
{1000 * ma * grid:6.1f} mm") # Now make the initial grid ns = tuple(maxn - minn + 1) @@ -2336,7 +2325,7 @@ def _vol_vertex(width, height, jj, kk, pp): def _src_vol_dims(s): - w, h, d = [s[f"mri_{key}"] for key in ("width", "height", "depth")] + w, h, d = (s[f"mri_{key}"] for key in ("width", "height", "depth")) return w, h, d, np.prod([w, h, d]) @@ -2411,7 +2400,7 @@ def _grid_interp(from_shape, to_shape, trans, order=1, inuse=None): shape = (np.prod(to_shape), np.prod(from_shape)) if inuse is None: inuse = np.ones(shape[1], bool) - assert inuse.dtype == bool + assert inuse.dtype == np.dtype(bool) assert inuse.shape == (shape[1],) data, indices, indptr = _grid_interp_jit(from_shape, to_shape, trans, order, inuse) data = np.concatenate(data) @@ -2632,7 +2621,7 @@ def _adjust_patch_info(s, verbose=None): def _ensure_src(src, kind=None, extra="", verbose=None): """Ensure we have a source space.""" _check_option("kind", kind, (None, "surface", "volume", "mixed", "discrete")) - msg = "src must be a string or instance of SourceSpaces%s" % (extra,) + msg = f"src must be a string or instance of SourceSpaces{extra}" if _path_like(src): src = str(src) if not op.isfile(src): @@ -2640,7 +2629,7 @@ def _ensure_src(src, kind=None, extra="", verbose=None): logger.info("Reading %s..." % src) src = read_source_spaces(src, verbose=False) if not isinstance(src, SourceSpaces): - raise ValueError("%s, got %s (type %s)" % (msg, src, type(src))) + raise ValueError(f"{msg}, got {src} (type {type(src)})") if kind is not None: if src.kind != kind and src.kind == "mixed": if kind == "surface": @@ -2648,9 +2637,7 @@ def _ensure_src(src, kind=None, extra="", verbose=None): elif kind == "volume": src = src[2:] if src.kind != kind: - raise ValueError( - "Source space must contain %s type, got " "%s" % (kind, src.kind) - ) + raise ValueError(f"Source space must contain {kind} type, got {src.kind}") return src @@ -2662,8 +2649,8 @@ def _ensure_src_subject(src, subject): raise ValueError("source space is too old, subject must be " "provided") elif src_subject is not None and subject != src_subject: raise ValueError( - 'Mismatch between provided subject "%s" and subject ' - 'name "%s" in the source space' % (subject, src_subject) + f'Mismatch between provided subject "{subject}" and subject ' + f'name "{src_subject}" in the source space' ) return subject @@ -2714,7 +2701,7 @@ def add_source_space_distances(src, dist_limit=np.inf, n_jobs=None, *, verbose=N src = _ensure_src(src) dist_limit = float(dist_limit) if dist_limit < 0: - raise ValueError("dist_limit must be non-negative, got %s" % (dist_limit,)) + raise ValueError(f"dist_limit must be non-negative, got {dist_limit}") patch_only = dist_limit == 0 if src.kind != "surface": raise RuntimeError("Currently all source spaces must be of surface " "type") @@ -2723,7 +2710,7 @@ def add_source_space_distances(src, dist_limit=np.inf, n_jobs=None, *, verbose=N min_dists = list() min_idxs = list() msg = "patch information" if patch_only else "source space distances" - logger.info("Calculating %s (limit=%s mm)..." % (msg, 1000 * dist_limit)) + logger.info(f"Calculating {msg} (limit={1000 * dist_limit} mm)...") max_n = max(s["nuse"] for s in src) if not patch_only and max_n > _DIST_WARN_LIMIT: warn( @@ -2893,14 +2880,12 @@ def _get_vertex_map_nn( """ # adapted from mne_make_source_space.c, knowing accurate=False (i.e. # nearest-neighbor mode should be used) - logger.info( - "Mapping %s %s -> %s (nearest neighbor)..." 
% (hemi, subject_from, subject_to) - ) + logger.info(f"Mapping {hemi} {subject_from} -> {subject_to} (nearest neighbor)...") regs = [ subjects_dir / s / "surf" / f"{hemi}.sphere.reg" for s in (subject_from, subject_to) ] - reg_fro, reg_to = [read_surface(r, return_dict=True)[-1] for r in regs] + reg_fro, reg_to = (read_surface(r, return_dict=True)[-1] for r in regs) if to_neighbor_tri is not None: reg_to["neighbor_tri"] = to_neighbor_tri if "neighbor_tri" not in reg_to: @@ -2978,7 +2963,7 @@ def morph_source_spaces( for fro in src_from: hemi, idx, id_ = _get_hemi(fro) to = subjects_dir / subject_to / "surf" / f"{hemi}.{surf}" - logger.info("Reading destination surface %s" % (to,)) + logger.info(f"Reading destination surface {to}") to = read_surface(to, return_dict=True, verbose=False)[-1] complete_surface_info(to, copy=False) # Now we morph the vertices to the destination @@ -3172,8 +3157,8 @@ def _compare_source_spaces(src0, src1, mode="exact", nearest=True, dist_tol=1.5e assert_array_equal( s["vertno"], np.where(s["inuse"])[0], - 'src%s[%s]["vertno"] != ' - 'np.where(src%s[%s]["inuse"])[0]' % (ii, si, ii, si), + f'src{ii}[{si}]["vertno"] != ' + f'np.where(src{ii}[{si}]["inuse"])[0]', ) assert_equal(len(s0["vertno"]), len(s1["vertno"])) agreement = np.mean(s0["inuse"] == s1["inuse"]) diff --git a/mne/source_space/tests/test_source_space.py b/mne/source_space/tests/test_source_space.py index 4db0286a2a5..628428fd84e 100644 --- a/mne/source_space/tests/test_source_space.py +++ b/mne/source_space/tests/test_source_space.py @@ -66,7 +66,7 @@ fname_src = data_path / "subjects" / "sample" / "bem" / "sample-oct-4-src.fif" fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" trans_fname = data_path / "MEG" / "sample" / "sample_audvis_trunc-trans.fif" -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" fname_small = base_dir / "small-src.fif.gz" fname_ave = base_dir / "test-ave.fif" rng = np.random.RandomState(0) @@ -679,6 +679,7 @@ def test_source_space_from_label(tmp_path, pass_ids): _compare_source_spaces(src, src_from_file, mode="approx") +@pytest.mark.slowtest @testing.requires_testing_data def test_source_space_exclusive_complete(src_volume_labels): """Test that we produce exclusive and complete labels.""" @@ -699,7 +700,10 @@ def test_source_space_exclusive_complete(src_volume_labels): for si, s in enumerate(src): assert_allclose(src_full[0]["rr"], s["rr"], atol=1e-6) # also check single_volume=True -- should be the same result - with pytest.warns(RuntimeWarning, match="Found no usable.*Left-vessel.*"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="Found no usable.*Left-vessel.*"), + ): src_single = setup_volume_source_space( src[0]["subject_his_id"], 7.0, @@ -870,6 +874,27 @@ def test_combine_source_spaces(tmp_path): with pytest.warns(RuntimeWarning, match="2 surf vertices lay outside"): src.export_volume(image_fname, mri_resolution="sparse", overwrite=True) + # gh-12495 + image_fname = tmp_path / "temp-image.nii" + lh_cereb = mne.setup_volume_source_space( + "sample", + mri=aseg_fname, + volume_label="Left-Cerebellum-Cortex", + add_interpolator=False, + subjects_dir=subjects_dir, + ) + lh_cereb.export_volume(image_fname, mri_resolution=True) + aseg = nib.load(str(aseg_fname)) + out = nib.load(str(image_fname)) + assert_allclose(out.affine, aseg.affine) + src_data = _get_img_fdata(out).astype(bool) + aseg_data = _get_img_fdata(aseg) == 8 + n_src = 
src_data.sum() + n_aseg = aseg_data.sum() + assert n_aseg == n_src + n_overlap = (src_data & aseg_data).sum() + assert n_src == n_overlap + @testing.requires_testing_data def test_morph_source_spaces(): diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py index 479bba3f45b..32243eeeff0 100644 --- a/mne/stats/cluster_level.py +++ b/mne/stats/cluster_level.py @@ -419,8 +419,8 @@ def _find_clusters( if show_info is True: if len(thresholds) == 0: warn( - 'threshold["start"] (%s) is more extreme than data ' - "statistics with most extreme value %s" % (threshold["start"], stop) + f'threshold["start"] ({threshold["start"]}) is more extreme ' + f"than data statistics with most extreme value {stop}" ) else: logger.info( @@ -479,7 +479,7 @@ def _find_clusters( len_c = c.stop - c.start elif isinstance(c, tuple): len_c = len(c) - elif c.dtype == bool: + elif c.dtype == np.dtype(bool): len_c = np.sum(c) else: len_c = len(c) @@ -928,8 +928,7 @@ def _permutation_cluster_test( and threshold < 0 ): raise ValueError( - "incompatible tail and threshold signs, got " - "%s and %s" % (tail, threshold) + f"incompatible tail and threshold signs, got {tail} and {threshold}" ) # check dimensions for each group in X (a list at this stage). @@ -956,7 +955,7 @@ def _permutation_cluster_test( # ------------------------------------------------------------- t_obs = stat_fun(*X) _validate_type(t_obs, np.ndarray, "return value of stat_fun") - logger.info("stat_fun(H1): min=%f max=%f" % (np.min(t_obs), np.max(t_obs))) + logger.info(f"stat_fun(H1): min={np.min(t_obs)} max={np.max(t_obs)}") # test if stat_fun treats variables independently if buffer_size is not None: @@ -976,9 +975,8 @@ def _permutation_cluster_test( # The stat should have the same shape as the samples for no adj. 
if t_obs.size != np.prod(sample_shape): raise ValueError( - "t_obs.shape %s provided by stat_fun %s is not " - "compatible with the sample shape %s" - % (t_obs.shape, stat_fun, sample_shape) + f"t_obs.shape {t_obs.shape} provided by stat_fun {stat_fun} is not " + f"compatible with the sample shape {sample_shape}" ) if adjacency is None or adjacency is False: t_obs.shape = sample_shape @@ -1138,14 +1136,14 @@ def _check_fun(X, stat_fun, threshold, tail=0, kind="within"): if stat_fun is not None and stat_fun is not ttest_1samp_no_p: warn( "Automatic threshold is only valid for stat_fun=None " - "(or ttest_1samp_no_p), got %s" % (stat_fun,) + f"(or ttest_1samp_no_p), got {stat_fun}" ) p_thresh = 0.05 / (1 + (tail == 0)) n_samples = len(X) threshold = -tstat.ppf(p_thresh, n_samples - 1) if np.sign(tail) < 0: threshold = -threshold - logger.info("Using a threshold of {:.6f}".format(threshold)) + logger.info(f"Using a threshold of {threshold:.6f}") stat_fun = ttest_1samp_no_p if stat_fun is None else stat_fun else: assert kind == "between" @@ -1153,7 +1151,7 @@ def _check_fun(X, stat_fun, threshold, tail=0, kind="within"): if stat_fun is not None and stat_fun is not f_oneway: warn( "Automatic threshold is only valid for stat_fun=None " - "(or f_oneway), got %s" % (stat_fun,) + f"(or f_oneway), got {stat_fun}" ) elif tail != 1: warn('Ignoring argument "tail", performing 1-tailed F-test') @@ -1161,7 +1159,7 @@ def _check_fun(X, stat_fun, threshold, tail=0, kind="within"): dfn = len(X) - 1 dfd = np.sum([len(x) for x in X]) - len(X) threshold = fstat.ppf(1.0 - p_thresh, dfn, dfd) - logger.info("Using a threshold of {:.6f}".format(threshold)) + logger.info(f"Using a threshold of {threshold:.6f}") stat_fun = f_oneway if stat_fun is None else stat_fun return stat_fun, threshold @@ -1634,7 +1632,7 @@ def _reshape_clusters(clusters, sample_shape): """Reshape cluster masks or indices to be of the correct shape.""" # format of the bool mask and indices are ndarrays if len(clusters) > 0 and isinstance(clusters[0], np.ndarray): - if clusters[0].dtype == bool: # format of mask + if clusters[0].dtype == np.dtype(bool): # format of mask clusters = [c.reshape(sample_shape) for c in clusters] else: # format of indices clusters = [np.unravel_index(c, sample_shape) for c in clusters] diff --git a/mne/stats/parametric.py b/mne/stats/parametric.py index e777bd7f53e..0da2d2d0732 100644 --- a/mne/stats/parametric.py +++ b/mne/stats/parametric.py @@ -197,14 +197,13 @@ def _map_effects(n_factors, effects): elif "*" in effects: pass # handle later else: - raise ValueError('"{}" is not a valid option for "effects"'.format(effects)) + raise ValueError(f'"{effects}" is not a valid option for "effects"') if isinstance(effects, list): bad_names = [e for e in effects if e not in factor_names] if len(bad_names) > 1: raise ValueError( - "Effect names: {} are not valid. They should " - "the first `n_factors` ({}) characters from the" - "alphabet".format(bad_names, n_factors) + f"Effect names: {bad_names} are not valid. They should consist of the " + f"first `n_factors` ({n_factors}) characters from the alphabet" ) indices = list(np.arange(2**n_factors - 1)) @@ -402,7 +401,7 @@ def f_mway_rm(data, factor_levels, effects="all", correction=False, return_pvals # numerical imprecision can cause eps=0.99999999999999989 # even with a single category, so never let our degrees of # freedom drop below 1. 
- df1, df2 = [np.maximum(d[None, :] * eps, 1.0) for d in (df1, df2)] + df1, df2 = (np.maximum(d[None, :] * eps, 1.0) for d in (df1, df2)) if return_pvals: pvals = stats.f(df1, df2).sf(fvals) diff --git a/mne/stats/permutations.py b/mne/stats/permutations.py index 3f515559c72..15c78ae0872 100644 --- a/mne/stats/permutations.py +++ b/mne/stats/permutations.py @@ -146,7 +146,7 @@ def stat_fun(x): rng = check_random_state(random_state) boot_indices = rng.choice(indices, replace=True, size=(n_bootstraps, len(indices))) stat = np.array([stat_fun(arr[inds]) for inds in boot_indices]) - ci = (((1 - ci) / 2) * 100, ((1 - ((1 - ci) / 2))) * 100) + ci = (((1 - ci) / 2) * 100, (1 - ((1 - ci) / 2)) * 100) ci_low, ci_up = np.percentile(stat, ci, axis=0) return np.array([ci_low, ci_up]) diff --git a/mne/stats/regression.py b/mne/stats/regression.py index 39bd8e63d95..c9c6c63a5dc 100644 --- a/mne/stats/regression.py +++ b/mne/stats/regression.py @@ -89,9 +89,7 @@ def linear_regression(inst, design_matrix, names=None): data = np.array([i.data for i in inst]) else: raise ValueError("Input must be epochs or iterable of source " "estimates") - logger.info( - msg + ", (%s targets, %s regressors)" % (np.prod(data.shape[1:]), len(names)) - ) + logger.info(msg + f", ({np.prod(data.shape[1:])} targets, {len(names)} regressors)") lm_params = _fit_lm(data, design_matrix, names) lm = namedtuple("lm", "beta stderr t_val p_val mlog10_p_val") lm_fits = {} @@ -266,7 +264,7 @@ def linear_regression_raw( """ if isinstance(solver, str): if solver not in {"cholesky"}: - raise ValueError("No such solver: {}".format(solver)) + raise ValueError(f"No such solver: {solver}") if solver == "cholesky": def solver(X, y): @@ -361,7 +359,7 @@ def _prepare_rerp_preds( else: tmin_s = {cond: int(round(tmin.get(cond, -0.1) * sfreq)) for cond in conds} if isinstance(tmax, (float, int)): - tmax_s = {cond: int(round((tmax * sfreq)) + 1) for cond in conds} + tmax_s = {cond: int(round(tmax * sfreq) + 1) for cond in conds} else: tmax_s = {cond: int(round(tmax.get(cond, 1.0) * sfreq)) + 1 for cond in conds} @@ -388,9 +386,9 @@ def _prepare_rerp_preds( covs = covariates[cond] if len(covs) != len(events): error = ( - "Condition {0} from ``covariates`` is " - "not the same length as ``events``" - ).format(cond) + f"Condition {cond} from ``covariates`` is not the same length as " + "``events``" + ) raise ValueError(error) onsets = -(events[np.where(covs != 0), 0] + tmin_)[0] v = np.asarray(covs)[np.nonzero(covs)].astype(float) diff --git a/mne/stats/tests/test_cluster_level.py b/mne/stats/tests/test_cluster_level.py index d0fe0672bde..1b020d11d28 100644 --- a/mne/stats/tests/test_cluster_level.py +++ b/mne/stats/tests/test_cluster_level.py @@ -96,7 +96,10 @@ def test_thresholds(numba_conditional): # nan handling in TFCE X = np.repeat(X[0], 2, axis=1) X[:, 1] = 0 - with pytest.warns(RuntimeWarning, match="invalid value"): # NumPy + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="invalid value"), + ): # NumPy out = permutation_cluster_1samp_test( X, seed=0, threshold=dict(start=0, step=0.1), out_type="mask" ) @@ -140,7 +143,7 @@ def test_cache_dir(tmp_path, numba_conditional): # ensure that non-independence yields warning stat_fun = partial(ttest_1samp_no_p, sigma=1e-3) random_state = np.random.default_rng(0) - with pytest.warns(RuntimeWarning, match="independently"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="independently"): permutation_cluster_1samp_test( X, buffer_size=10, @@ -509,7 +512,7 @@ def 
test_cluster_permutation_with_adjacency(numba_conditional, monkeypatch): assert np.min(out_adjacency_6[2]) < 0.05 with pytest.raises(ValueError, match="not compatible"): - with pytest.warns(RuntimeWarning, match="No clusters"): + with _record_warnings(): spatio_temporal_func( X1d_3, n_permutations=50, @@ -610,7 +613,7 @@ def test_permutation_adjacency_equiv(numba_conditional): ) # make sure our output datatype is correct assert isinstance(clusters[0], np.ndarray) - assert clusters[0].dtype == bool + assert clusters[0].dtype == np.dtype(bool) assert_array_equal(clusters[0].shape, X.shape[1:]) # make sure all comparisons were done; for TFCE, no perm @@ -847,7 +850,7 @@ def test_output_equiv(shape, out_type, adjacency): assert isinstance(clu[0], slice) else: assert isinstance(clu, np.ndarray) - assert clu.dtype == bool + assert clu.dtype == np.dtype(bool) assert clu.shape == shape got_mask[clu] = n else: diff --git a/mne/stats/tests/test_parametric.py b/mne/stats/tests/test_parametric.py index e1d64583777..de7aa237c40 100644 --- a/mne/stats/tests/test_parametric.py +++ b/mne/stats/tests/test_parametric.py @@ -148,14 +148,14 @@ def test_ttest_equiv(kind, kwargs, sigma, seed): rng = np.random.RandomState(seed) def theirs(*a, **kw): - f = getattr(scipy.stats, "ttest_%s" % (kind,)) + f = getattr(scipy.stats, f"ttest_{kind}") if kind == "1samp": func = partial(f, popmean=0, **kwargs) else: func = partial(f, **kwargs) return func(*a, **kw)[0] - ours = partial(getattr(mne.stats, "ttest_%s_no_p" % (kind,)), sigma=sigma, **kwargs) + ours = partial(getattr(mne.stats, f"ttest_{kind}_no_p"), sigma=sigma, **kwargs) X = rng.randn(3, 4, 5) if kind == "ind": diff --git a/mne/surface.py b/mne/surface.py index 285d6ab0be1..0334ee12ab0 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -122,7 +122,7 @@ def _get_head_surface(subject, source, subjects_dir, on_defects, raise_error=Tru surf = None for this_source in source: this_head = op.realpath( - op.join(subjects_dir, subject, "bem", "%s-%s.fif" % (subject, this_source)) + op.join(subjects_dir, subject, "bem", f"{subject}-{this_source}.fif") ) if op.exists(this_head): surf = read_bem_surfaces( @@ -137,7 +137,7 @@ def _get_head_surface(subject, source, subjects_dir, on_defects, raise_error=Tru path = op.join(subjects_dir, subject, "bem") if not op.isdir(path): raise OSError('Subject bem directory "%s" does not exist.' % path) - files = sorted(glob(op.join(path, "%s*%s.fif" % (subject, this_source)))) + files = sorted(glob(op.join(path, f"{subject}*{this_source}.fif"))) for this_head in files: try: surf = read_bem_surfaces( @@ -157,8 +157,8 @@ def _get_head_surface(subject, source, subjects_dir, on_defects, raise_error=Tru if surf is None: if raise_error: raise OSError( - 'No file matching "%s*%s" and containing a head ' - "surface found." % (subject, this_source) + f'No file matching "{subject}*{this_source}" and containing a head ' + "surface found." 
) else: return surf @@ -1032,7 +1032,7 @@ def _read_patch(fname): # This is adapted from PySurfer PR #269, Bruce Fischl's read_patch.m, # and PyCortex (BSD) patch = dict() - with open(fname, "r") as fid: + with open(fname) as fid: ver = np.fromfile(fid, dtype=">i4", count=1).item() if ver != -1: raise RuntimeError(f"incorrect version # {ver} (not -1) found") @@ -1454,9 +1454,7 @@ def _decimate_surface_sphere(rr, tris, n_triangles): ) func_map = dict(ico=_get_ico_surface, oct=_tessellate_sphere_surf) kind, level = map_[n_triangles] - logger.info( - "Decimating using Freesurfer spherical %s%s downsampling" % (kind, level) - ) + logger.info(f"Decimating using Freesurfer spherical {kind}{level} downsampling") ico_surf = func_map[kind](level) assert len(ico_surf["tris"]) == n_triangles tempdir = _TempDir() @@ -1539,8 +1537,8 @@ def decimate_surface(points, triangles, n_triangles, method="quadric", *, verbos _check_option("method", method, sorted(method_map)) if n_triangles > len(triangles): raise ValueError( - "Requested n_triangles (%s) exceeds number of " - "original triangles (%s)" % (n_triangles, len(triangles)) + f"Requested n_triangles ({n_triangles}) exceeds number of " + f"original triangles ({len(triangles)})" ) return method_map[method](points, triangles, n_triangles) @@ -1802,7 +1800,7 @@ def read_tri(fname_in, swap=False, verbose=None): ----- .. versionadded:: 0.13.0 """ - with open(fname_in, "r") as fid: + with open(fname_in) as fid: lines = fid.readlines() n_nodes = int(lines[0]) n_tris = int(lines[n_nodes + 1]) @@ -1829,8 +1827,7 @@ def read_tri(fname_in, swap=False, verbose=None): tris[:, [2, 1]] = tris[:, [1, 2]] tris -= 1 logger.info( - "Loaded surface from %s with %s nodes and %s triangles." - % (fname_in, n_nodes, n_tris) + f"Loaded surface from {fname_in} with {n_nodes} nodes and {n_tris} triangles." 
) if n_items in [3, 4]: logger.info("Node normals were not included in the source file.") @@ -1843,7 +1840,7 @@ def read_tri(fname_in, swap=False, verbose=None): def _get_solids(tri_rrs, fros): """Compute _sum_solids_div total angle in chunks.""" # NOTE: This incorporates the division by 4PI that used to be separate - tot_angle = np.zeros((len(fros))) + tot_angle = np.zeros(len(fros)) for ti in range(len(tri_rrs)): tri_rr = tri_rrs[ti] v1 = fros - tri_rr[0] @@ -2175,7 +2172,7 @@ def _get_neighbors(loc, image, voxels, thresh, dist_params): next_loc = tuple(next_loc) if ( image[next_loc] > thresh - and image[next_loc] < image[loc] + and image[next_loc] <= image[loc] and next_loc not in voxels ): neighbors.add(next_loc) diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 35ffca3d09a..c968f639e22 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -49,7 +49,7 @@ data_path = testing.data_path(download=False) data_dir = data_path / "MEG" / "sample" -fif_fname = Path(__file__).parent.parent / "io" / "tests" / "data" / "test_raw.fif" +fif_fname = Path(__file__).parents[1] / "io" / "tests" / "data" / "test_raw.fif" first_samps = pytest.mark.parametrize("first_samp", (0, 10000)) edf_reduced = data_path / "EDF" / "test_reduced.edf" edf_annot_only = data_path / "EDF" / "SC4001EC-Hypnogram.edf" @@ -236,7 +236,7 @@ def test_crop(tmp_path): assert_allclose( getattr(raw_concat.annotations, attr), getattr(raw.annotations, attr), - err_msg="Failed for %s:" % (attr,), + err_msg=f"Failed for {attr}:", ) raw.set_annotations(None) # undo @@ -278,7 +278,9 @@ def test_crop(tmp_path): assert raw_read.annotations is not None assert len(raw_read.annotations.onset) == 0 # test saving and reloading cropped annotations in raw instance - info = create_info([f"EEG{i+1}" for i in range(3)], ch_types=["eeg"] * 3, sfreq=50) + info = create_info( + [f"EEG{i + 1}" for i in range(3)], ch_types=["eeg"] * 3, sfreq=50 + ) raw = RawArray(np.zeros((3, 50 * 20)), info) annotation = mne.Annotations([8, 12, 15], [2] * 3, [1, 2, 3]) raw = raw.set_annotations(annotation) @@ -425,7 +427,11 @@ def test_raw_reject(first_samp): with pytest.warns(RuntimeWarning, match="outside the data range"): raw.set_annotations(Annotations([2, 100, 105, 148], [2, 8, 5, 8], "BAD")) data, times = raw.get_data( - [0, 1, 3, 4], 100, 11200, "omit", return_times=True # 1-112 s + [0, 1, 3, 4], + 100, + 11200, + "omit", + return_times=True, # 1-112 s ) bad_times = np.concatenate( [np.arange(200, 400), np.arange(10000, 10800), np.arange(10500, 11000)] @@ -813,6 +819,49 @@ def test_events_from_annot_onset_alingment(): assert raw.first_samp == event_latencies[0, 0] +@pytest.mark.parametrize( + "use_rounding,tol,shape,onsets,descriptions", + [ + pytest.param(True, 0, (2, 3), [202, 402], [0, 2], id="rounding-notol"), + pytest.param(True, 1e-8, (3, 3), [202, 302, 402], [0, 1, 2], id="rounding-tol"), + pytest.param(False, 0, (2, 3), [202, 401], [0, 2], id="norounding-notol"), + pytest.param( + False, 1e-8, (3, 3), [202, 302, 401], [0, 1, 2], id="norounding-tol" + ), + pytest.param(None, None, (3, 3), [202, 302, 402], [0, 1, 2], id="default"), + ], +) +def test_events_from_annot_with_tolerance( + use_rounding, tol, shape, onsets, descriptions +): + """Test events_from_annotations w/ and w/o tolerance.""" + info = create_info(ch_names=1, sfreq=100) + raw = RawArray(data=np.empty((1, 1000)), info=info, first_samp=0) + meas_date = _handle_meas_date(0) + with raw.info._unlock(check_after=True): + 
raw.info["meas_date"] = meas_date + chunk_duration = 1 + annot = Annotations([2.02, 3.02, 4.02], chunk_duration, ["0", "1", "2"], 0) + raw.set_annotations(annot) + event_id = {"0": 0, "1": 1, "2": 2} + + if use_rounding is None: + events, _ = events_from_annotations( + raw, event_id=event_id, chunk_duration=chunk_duration + ) + else: + events, _ = events_from_annotations( + raw, + event_id=event_id, + chunk_duration=chunk_duration, + use_rounding=use_rounding, + tol=tol, + ) + assert events.shape == shape + assert (events[:, 0] == onsets).all() + assert (events[:, 2] == descriptions).all() + + def _create_annotation_based_on_descr( description, annotation_start_sampl=0, duration=0, orig_time=0 ): @@ -1200,7 +1249,7 @@ def test_date_none(tmp_path): n_chans = 139 n_samps = 20 data = np.random.random_sample((n_chans, n_samps)) - ch_names = ["E{}".format(x) for x in range(n_chans)] + ch_names = [f"E{x}" for x in range(n_chans)] ch_types = ["eeg"] * n_chans info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=2048) assert info["meas_date"] is None @@ -1246,7 +1295,7 @@ def test_crop_when_negative_orig_time(windows_like_datetime): assert len(annot) == 10 # Crop with negative tmin, tmax - tmin, tmax = [orig_time_stamp + t for t in (0.25, 0.75)] + tmin, tmax = (orig_time_stamp + t for t in (0.25, 0.75)) assert tmin < 0 and tmax < 0 crop_annot = annot.crop(tmin=tmin, tmax=tmax) assert_allclose(crop_annot.onset, [0.3, 0.4, 0.5, 0.6, 0.7]) @@ -1349,7 +1398,7 @@ def test_annotations_from_events(): # 4. Try passing callable # ------------------------------------------------------------------------- - event_desc = lambda d: "event{}".format(d) # noqa:E731 + event_desc = lambda d: f"event{d}" # noqa:E731 annots = annotations_from_events( events, sfreq=raw.info["sfreq"], @@ -1412,7 +1461,8 @@ def test_repr(): assert r == "" -def test_annotation_to_data_frame(): +@pytest.mark.parametrize("time_format", (None, "ms", "datetime", "timedelta")) +def test_annotation_to_data_frame(time_format): """Test annotation class to data frame conversion.""" pytest.importorskip("pandas") onset = np.arange(1, 10) @@ -1423,11 +1473,15 @@ def test_annotation_to_data_frame(): onset=onset, duration=durations, description=description, orig_time=0 ) - df = a.to_data_frame() + df = a.to_data_frame(time_format=time_format) for col in ["onset", "duration", "description"]: assert col in df.columns assert df.description[0] == "yy" - assert (df.onset[1] - df.onset[0]).seconds == 1 + want = 1000 if time_format == "ms" else 1 + got = df.onset[1] - df.onset[0] + if time_format in ("datetime", "timedelta"): + got = got.seconds + assert want == got assert df.groupby("description").count().onset["yy"] == 9 diff --git a/mne/tests/test_bem.py b/mne/tests/test_bem.py index 0dd682606f6..3217205ba9f 100644 --- a/mne/tests/test_bem.py +++ b/mne/tests/test_bem.py @@ -37,16 +37,16 @@ _ico_downsample, _order_surfaces, distance_to_bem, + fit_sphere_to_headshape, make_scalp_surfaces, ) from mne.datasets import testing from mne.io import read_info -from mne.preprocessing.maxfilter import fit_sphere_to_headshape from mne.surface import _get_ico_surface, read_surface from mne.transforms import translation -from mne.utils import catch_logging, check_version +from mne.utils import _record_warnings, catch_logging, check_version -fname_raw = Path(__file__).parent.parent / "io" / "tests" / "data" / "test_raw.fif" +fname_raw = Path(__file__).parents[1] / "io" / "tests" / "data" / "test_raw.fif" subjects_dir = testing.data_path(download=False) / 
"subjects" fname_bem_3 = subjects_dir / "sample" / "bem" / "sample-320-320-320-bem.fif" fname_bem_1 = subjects_dir / "sample" / "bem" / "sample-320-bem.fif" @@ -54,6 +54,8 @@ fname_bem_sol_1 = subjects_dir / "sample" / "bem" / "sample-320-bem-sol.fif" fname_dense_head = subjects_dir / "sample" / "bem" / "sample-head-dense.fif" +_few_points = pytest.warns(RuntimeWarning, match="Only .* head digitization") + def _compare_bem_surfaces(surfs_1, surfs_2): """Compare BEM surfaces.""" @@ -414,7 +416,7 @@ def test_fit_sphere_to_headshape(): # # Test with 4 points that match a perfect sphere dig_kinds = (FIFF.FIFFV_POINT_CARDINAL, FIFF.FIFFV_POINT_EXTRA) - with pytest.warns(RuntimeWarning, match="Only .* head digitization"): + with _few_points: r, oh, od = fit_sphere_to_headshape(info, dig_kinds=dig_kinds, units="m") kwargs = dict(rtol=1e-3, atol=1e-5) assert_allclose(r, rad, **kwargs) @@ -424,7 +426,7 @@ def test_fit_sphere_to_headshape(): # Test with all points dig_kinds = ("cardinal", FIFF.FIFFV_POINT_EXTRA, "eeg") kwargs = dict(rtol=1e-3, atol=1e-3) - with pytest.warns(RuntimeWarning, match="Only .* head digitization"): + with _few_points: r, oh, od = fit_sphere_to_headshape(info, dig_kinds=dig_kinds, units="m") assert_allclose(r, rad, **kwargs) assert_allclose(oh, center, **kwargs) @@ -432,7 +434,7 @@ def test_fit_sphere_to_headshape(): # Test with some noisy EEG points only. dig_kinds = "eeg" - with pytest.warns(RuntimeWarning, match="Only .* head digitization"): + with _few_points: r, oh, od = fit_sphere_to_headshape(info, dig_kinds=dig_kinds, units="m") kwargs = dict(rtol=1e-3, atol=1e-2) assert_allclose(r, rad, **kwargs) @@ -446,7 +448,7 @@ def test_fit_sphere_to_headshape(): d["r"] -= center d["r"] *= big_rad / rad d["r"] += center - with pytest.warns(RuntimeWarning, match="Estimated head radius"): + with _few_points, pytest.warns(RuntimeWarning, match="Estimated head radius"): r, oh, od = fit_sphere_to_headshape(info_big, dig_kinds=dig_kinds, units="mm") assert_allclose(oh, center * 1000, atol=1e-3) assert_allclose(r, big_rad * 1000, atol=1e-3) @@ -459,27 +461,33 @@ def test_fit_sphere_to_headshape(): for d in info_shift["dig"]: d["r"] -= center d["r"] += shift_center - with pytest.warns(RuntimeWarning, match="from head frame origin"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="from head frame origin"), + ): r, oh, od = fit_sphere_to_headshape(info_shift, dig_kinds=dig_kinds, units="m") assert_allclose(oh, shift_center, atol=1e-6) assert_allclose(r, rad, atol=1e-6) # Test "auto" mode (default) # Should try "extra", fail, and go on to EEG - with pytest.warns(RuntimeWarning, match="Only .* head digitization"): + with _few_points: r, oh, od = fit_sphere_to_headshape(info, units="m") kwargs = dict(rtol=1e-3, atol=1e-3) assert_allclose(r, rad, **kwargs) assert_allclose(oh, center, **kwargs) assert_allclose(od, dev_center, **kwargs) - with pytest.warns(RuntimeWarning, match="Only .* head digitization"): + with _few_points: r2, oh2, od2 = fit_sphere_to_headshape(info, units="m") assert_allclose(r, r2, atol=1e-7) assert_allclose(oh, oh2, atol=1e-7) assert_allclose(od, od2, atol=1e-7) # this one should pass, 1 EXTRA point and 3 EEG (but the fit is terrible) info = Info(dig=dig[:7], dev_head_t=dev_head_t) - with pytest.warns(RuntimeWarning, match="Only .* head digitization"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="Estimated head radius"), + ): r, oh, od = fit_sphere_to_headshape(info, units="m") # this one should fail, 1 EXTRA point and 3 EEG 
(but the fit is terrible) info = Info(dig=dig[:6], dev_head_t=dev_head_t) @@ -499,12 +507,12 @@ def test_io_head_bem(tmp_path): with pytest.raises(ValueError, match="topological defects:"): write_head_bem(fname_defect, head["rr"], head["tris"]) - with pytest.warns(RuntimeWarning, match="topological defects:"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="topological defects:"): write_head_bem(fname_defect, head["rr"], head["tris"], on_defects="warn") # test on_defects in read_bem_surfaces with pytest.raises(ValueError, match="topological defects:"): read_bem_surfaces(fname_defect) - with pytest.warns(RuntimeWarning, match="topological defects:"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="topological defects:"): head_defect = read_bem_surfaces(fname_defect, on_defects="warn")[0] assert head["id"] == head_defect["id"] == FIFF.FIFFV_BEM_SURF_ID_HEAD @@ -550,12 +558,15 @@ def _decimate_surface(points, triangles, n_triangles): # These are ignorable monkeypatch.setattr(mne.bem, "_tri_levels", dict(sparse=315)) - with pytest.warns(RuntimeWarning, match=".*have fewer than three.*"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match=".*have fewer than three.*"), + ): make_scalp_surfaces(subject, subjects_dir, force=True, overwrite=True) (surf,) = read_bem_surfaces(sparse_path, on_defects="ignore") assert len(surf["tris"]) == 315 monkeypatch.setattr(mne.bem, "_tri_levels", dict(sparse=319)) - with pytest.warns(RuntimeWarning, match=".*is not complete.*"): + with _record_warnings(), pytest.warns(RuntimeWarning, match=".*is not complete.*"): make_scalp_surfaces(subject, subjects_dir, force=True, overwrite=True) (surf,) = read_bem_surfaces(sparse_path, on_defects="ignore") assert len(surf["tris"]) == 319 diff --git a/mne/tests/test_chpi.py b/mne/tests/test_chpi.py index 3e0e3fb1e87..cb9ccc60c26 100644 --- a/mne/tests/test_chpi.py +++ b/mne/tests/test_chpi.py @@ -43,10 +43,16 @@ ) from mne.simulation import add_chpi from mne.transforms import _angle_between_quats, rot_to_quat -from mne.utils import assert_meg_snr, catch_logging, object_diff, verbose +from mne.utils import ( + _record_warnings, + assert_meg_snr, + catch_logging, + object_diff, + verbose, +) from mne.viz import plot_head_positions -base_dir = Path(__file__).parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" ctf_fname = base_dir / "test_ctf_raw.fif" hp_fif_fname = base_dir / "test_chpi_raw_sss.fif" hp_fname = base_dir / "test_chpi_raw_hp.txt" @@ -204,7 +210,7 @@ def _assert_quats( # maxfilter produces some times that are implausibly large (weird) if not np.isclose(t[0], t_est[0], atol=1e-1): # within 100 ms raise AssertionError( - "Start times not within 100 ms: %0.3f != %0.3f" % (t[0], t_est[0]) + f"Start times not within 100 ms: {t[0]:0.3f} != {t_est[0]:0.3f}" ) use_mask = (t >= t_est[0]) & (t <= t_est[-1]) t = t[use_mask] @@ -223,10 +229,9 @@ def _assert_quats( distances = np.sqrt(np.sum((trans - trans_est_interp) ** 2, axis=1)) assert np.isfinite(distances).all() arg_worst = np.argmax(distances) - assert distances[arg_worst] <= dist_tol, "@ %0.3f seconds: %0.3f > %0.3f mm" % ( - t[arg_worst], - 1000 * distances[arg_worst], - 1000 * dist_tol, + assert distances[arg_worst] <= dist_tol, ( + f"@ {t[arg_worst]:0.3f} seconds: " + f"{1000 * distances[arg_worst]:0.3f} > {1000 * dist_tol:0.3f} mm" ) # limit rotation difference between MF and our estimation @@ -234,10 +239,9 @@ def _assert_quats( quats_est_interp = interp1d(t_est, 
quats_est, axis=0)(t) angles = 180 * _angle_between_quats(quats_est_interp, quats) / np.pi arg_worst = np.argmax(angles) - assert angles[arg_worst] <= angle_tol, "@ %0.3f seconds: %0.3f > %0.3f deg" % ( - t[arg_worst], - angles[arg_worst], - angle_tol, + assert angles[arg_worst] <= angle_tol, ( + f"@ {t[arg_worst]:0.3f} seconds: " + f"{angles[arg_worst]:0.3f} > {angle_tol:0.3f} deg" ) # error calculation difference @@ -366,7 +370,7 @@ def test_calculate_chpi_positions_vv(): ] ) raw_bad.pick([raw_bad.ch_names[pick] for pick in picks]) - with pytest.warns(RuntimeWarning, match="Discrepancy"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="Discrepancy"): with catch_logging() as log_file: _calculate_chpi_positions(raw_bad, t_step_min=1.0, verbose=True) # ignore HPI info header and [done] footer diff --git a/mne/tests/test_coreg.py b/mne/tests/test_coreg.py index af5801114a9..5f4c58fa8a5 100644 --- a/mne/tests/test_coreg.py +++ b/mne/tests/test_coreg.py @@ -218,7 +218,7 @@ def test_scale_mri_xfm(tmp_path, few_surfaces, subjects_dir_tmp_few): subjects_dir_tmp_few / subject_from / "bem" - / ("%s-%s-src.fif" % (subject_from, spacing)) + / (f"{subject_from}-{spacing}-src.fif") ) src_from = mne.setup_source_space( subject_from, @@ -273,7 +273,7 @@ def test_scale_mri_xfm(tmp_path, few_surfaces, subjects_dir_tmp_few): subjects_dir_tmp_few / subject_to / "bem" - / ("%s-%s-src.fif" % (subject_to, spacing)) + / (f"{subject_to}-{spacing}-src.fif") ) assert src_to_fname.exists(), "Source space was not scaled" # Check MRI scaling diff --git a/mne/tests/test_cov.py b/mne/tests/test_cov.py index 5398c07ace7..d23452a6a0b 100644 --- a/mne/tests/test_cov.py +++ b/mne/tests/test_cov.py @@ -53,7 +53,7 @@ from mne.rank import _compute_rank_int from mne.utils import _record_warnings, assert_snr, catch_logging -base_dir = Path(__file__).parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" cov_fname = base_dir / "test-cov.fif" cov_gz_fname = base_dir / "test-cov.fif.gz" cov_km_fname = base_dir / "test-km-cov.fif" @@ -294,7 +294,7 @@ def test_cov_estimation_on_raw(method, tmp_path): try: import sklearn # noqa: F401 except Exception as exp: - pytest.skip("sklearn is required, got %s" % (exp,)) + pytest.skip(f"sklearn is required, got {exp}") raw = read_raw_fif(raw_fname, preload=True) cov_mne = read_cov(erm_cov_fname) method_params = dict(shrunk=dict(shrinkage=[0])) @@ -352,7 +352,7 @@ def test_cov_estimation_on_raw(method, tmp_path): assert_snr(cov.data, cov_mne.data[:5, :5], 90) # cutoff samps # make sure we get a warning with too short a segment raw_2 = read_raw_fif(raw_fname).crop(0, 1) - with pytest.warns(RuntimeWarning, match="Too few samples"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="Too few samples"): cov = compute_raw_covariance(raw_2, method=method, method_params=method_params) # no epochs found due to rejection pytest.raises( @@ -384,7 +384,7 @@ def test_cov_estimation_on_raw_reg(): raw.info["sfreq"] /= 10.0 raw = RawArray(raw._data[:, ::10].copy(), raw.info) # decimate for speed cov_mne = read_cov(erm_cov_fname) - with pytest.warns(RuntimeWarning, match="Too few samples"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="Too few samples"): # "diagonal_fixed" is much faster. Use long epochs for speed. 
cov = compute_raw_covariance(raw, tstep=5.0, method="diagonal_fixed") assert_snr(cov.data, cov_mne.data, 5) @@ -393,7 +393,7 @@ def test_cov_estimation_on_raw_reg(): def _assert_cov(cov, cov_desired, tol=0.005, nfree=True): assert_equal(cov.ch_names, cov_desired.ch_names) err = np.linalg.norm(cov.data - cov_desired.data) / np.linalg.norm(cov.data) - assert err < tol, "%s >= %s" % (err, tol) + assert err < tol, f"{err} >= {tol}" if nfree: assert_equal(cov.nfree, cov_desired.nfree) @@ -891,13 +891,13 @@ def test_cov_ctf(): for comp in [0, 1]: raw.apply_gradient_compensation(comp) epochs = Epochs(raw, events, None, -0.2, 0.2, preload=True) - with pytest.warns(RuntimeWarning, match="Too few samples"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="Too few samples"): noise_cov = compute_covariance(epochs, tmax=0.0, method=["empirical"]) prepare_noise_cov(noise_cov, raw.info, ch_names) raw.apply_gradient_compensation(0) epochs = Epochs(raw, events, None, -0.2, 0.2, preload=True) - with pytest.warns(RuntimeWarning, match="Too few samples"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="Too few samples"): noise_cov = compute_covariance(epochs, tmax=0.0, method=["empirical"]) raw.apply_gradient_compensation(1) diff --git a/mne/tests/test_dipole.py b/mne/tests/test_dipole.py index 73aaeb7ad68..8f7c9508024 100644 --- a/mne/tests/test_dipole.py +++ b/mne/tests/test_dipole.py @@ -215,9 +215,8 @@ def test_dipole_fitting(tmp_path): # Sanity check: do our residuals have less power than orig data? data_rms = np.sqrt(np.sum(evoked.data**2, axis=0)) resi_rms = np.sqrt(np.sum(residual.data**2, axis=0)) - assert (data_rms > resi_rms * 0.95).all(), "%s (factor: %s)" % ( - (data_rms / resi_rms).min(), - 0.95, + assert (data_rms > resi_rms * 0.95).all(), ( + f"{(data_rms / resi_rms).min()} " f"(factor: {0.95})" ) # Compare to original points @@ -560,7 +559,7 @@ def test_bdip(fname_dip_, fname_bdip_, tmp_path): b = getattr(this_bdip, key) if key == "khi2" and dip_has_conf: if d is not None: - assert_allclose(d, b, atol=atol, err_msg="%s: %s" % (kind, key)) + assert_allclose(d, b, atol=atol, err_msg=f"{kind}: {key}") else: assert b is None if dip_has_conf: @@ -574,7 +573,7 @@ def test_bdip(fname_dip_, fname_bdip_, tmp_path): d, b, rtol=0.12, # no so great, text I/O - err_msg="%s: %s" % (kind, key), + err_msg=f"{kind}: {key}", ) # Not stored assert this_bdip.name is None diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py index 222165901a3..9e59c7302e7 100644 --- a/mne/tests/test_docstring_parameters.py +++ b/mne/tests/test_docstring_parameters.py @@ -278,14 +278,12 @@ def test_tabs(): whiten_evoked write_fiducials write_info -""".split( - "\n" -) +""".split("\n") def test_documented(): """Test that public functions and classes are documented.""" - doc_dir = (Path(__file__).parent.parent.parent / "doc").absolute() + doc_dir = (Path(__file__).parents[2] / "doc" / "api").absolute() doc_file = doc_dir / "python_reference.rst" if not doc_file.is_file(): pytest.skip("Documentation file not found: %s" % doc_file) @@ -357,9 +355,9 @@ def test_docdict_order(): from mne.utils.docs import docdict # read the file as text, and get entries via regex - docs_path = Path(__file__).parent.parent / "utils" / "docs.py" + docs_path = Path(__file__).parents[1] / "utils" / "docs.py" assert docs_path.is_file(), docs_path - with open(docs_path, "r", encoding="UTF-8") as fid: + with open(docs_path, encoding="UTF-8") as fid: docs = fid.read() entries = 
re.findall(r'docdict\[(?:\n )?["\'](.+)["\']\n?\] = ', docs) # test length & uniqueness diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 0200fb45f79..0bede8b53d4 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -64,6 +64,7 @@ from mne.preprocessing import maxwell_filter from mne.utils import ( _dt_to_stamp, + _record_warnings, assert_meg_snr, catch_logging, object_diff, @@ -76,7 +77,7 @@ fname_raw_movecomp_sss = data_path / "SSS" / "test_move_anon_movecomp_raw_sss.fif" fname_raw_move_pos = data_path / "SSS" / "test_move_anon_raw.pos" -base_dir = Path(__file__).parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" event_name = base_dir / "test-eve.fif" evoked_nf_name = base_dir / "test-nf-ave.fif" @@ -549,7 +550,20 @@ def test_reject(): preload=False, reject=dict(eeg=np.inf), ) - for val in (None, -1): # protect against older MNE-C types + + # Good function + def my_reject_1(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = "a" * len(bad_idxs[0]) + return len(bad_idxs) > 0, reasons + + # Bad function + def my_reject_2(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = "a" * len(bad_idxs[0]) + return len(bad_idxs), reasons + + for val in (-1, -2): # protect against older MNE-C types for kwarg in ("reject", "flat"): pytest.raises( ValueError, @@ -563,6 +577,44 @@ preload=False, **{kwarg: dict(grad=val)}, ) + + # Check that reject and flat in constructor are not callables + val = my_reject_1 + for kwarg in ("reject", "flat"): + with pytest.raises( + TypeError, + match=r".* must be an instance of numeric, got <class 'function'> instead.", + ): + Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=False, + **{kwarg: dict(grad=val)}, + ) + + # Check if callable returns a tuple with reasons + bad_types = [my_reject_2, ("Hi" "Hi"), (1, 1), None] + for val in bad_types: # protect against bad types + for kwarg in ("reject", "flat"): + with pytest.raises( + TypeError, + match=r".* must be an instance of .* got <class .*> instead.", + ): + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=True, + ) + epochs.drop_bad(**{kwarg: dict(grad=val)}) + pytest.raises( KeyError, Epochs, @@ -992,6 +1044,26 @@ def test_filter(tmp_path): assert_allclose(epochs.get_data(), data_filt, atol=1e-17) +def test_epochs_from_annotations(): + """Test epoch instantiation using annotations.""" + raw, events = _get_data()[:2] + with pytest.raises( + RuntimeError, match="No usable annotations found in the raw object" + ): + Epochs(raw) + raw.set_annotations( + mne.annotations_from_events( + events, raw.info["sfreq"], first_samp=raw.first_samp + ) + ) + # test on_missing + with pytest.raises(ValueError, match="No matching annotations"): + Epochs(raw, event_id="foo") + # test on_missing warn + with pytest.warns(match="No matching annotations"): + Epochs(raw, event_id=["1", "foo"], on_missing="warn") + + def test_epochs_hash(): """Test epoch hashing.""" raw, events = _get_data()[:2] @@ -1594,43 +1666,79 @@ def test_split_saving_and_loading_back(tmp_path, epochs_to_split, preload): @pytest.mark.parametrize( - "split_naming, dst_fname, split_fname_fn", + "split_naming, dst_fname, split_fname_fn, check_bids", [ ( "neuromag", "test_epo.fif", lambda i: f"test_epo-{i}.fif" if i else "test_epo.fif", + False, ), ( "bids", - "test_epo.fif", - lambda i:
f"test_split-{i + 1:02d}_epo.fif", + Path("sub-01") / "meg" / "sub-01_epo.fif", + lambda i: Path("sub-01") / "meg" / f"sub-01_split-{i + 1:02d}_epo.fif", + True, ), ( "bids", "a_b-epo.fif", # Merely stating the fact: lambda i: f"a_split-{i + 1:02d}_b-epo.fif", + False, ), ], ids=["neuromag", "bids", "mix"], ) def test_split_naming( - tmp_path, epochs_to_split, split_naming, dst_fname, split_fname_fn + tmp_path, epochs_to_split, split_naming, dst_fname, split_fname_fn, check_bids ): """Test naming of the split files.""" epochs, split_size, n_files = epochs_to_split dst_fpath = tmp_path / dst_fname save_kwargs = {"split_size": split_size, "split_naming": split_naming} # we don't test for reserved files as it's not implemented here + if dst_fpath.parent != tmp_path: + dst_fpath.parent.mkdir(parents=True) epochs.save(dst_fpath, verbose=True, **save_kwargs) # check that the filenames match the intended pattern - assert len(list(tmp_path.iterdir())) == n_files - for i in range(n_files): - assert (tmp_path / split_fname_fn(i)).is_file() + assert len(list(dst_fpath.parent.iterdir())) == n_files assert not (tmp_path / split_fname_fn(n_files)).is_file() + want_paths = [tmp_path / split_fname_fn(i) for i in range(n_files)] + for want_path in want_paths: + assert want_path.is_file() + + if not check_bids: + return + # gh-12451 + # If we load sub-01_split-01_epo.fif we should then we shouldn't + # write sub-01_split-01_split-01_epo.fif + mne_bids = pytest.importorskip("mne_bids") + # Let's try to prevent people from making a mistake + bids_path = mne_bids.BIDSPath( + root=tmp_path, + subject="01", + datatype="meg", + split="01", + suffix="epo", + extension=".fif", + check=False, + ) + assert bids_path.fpath.is_file(), bids_path.fpath + for want_path in want_paths: + want_path.unlink() + assert not bids_path.fpath.is_file() + with pytest.raises(ValueError, match="Passing a BIDSPath"): + epochs.save(bids_path, verbose=True, **save_kwargs) + bad_path = bids_path.fpath.parent / (bids_path.fpath.stem[:-3] + "split-01_epo.fif") + assert str(bad_path).count("_split-01") == 2 + assert not bad_path.is_file(), bad_path + bids_path.split = None + epochs.save(bids_path, verbose=True, **save_kwargs) + for want_path in want_paths: + assert want_path.is_file() @pytest.mark.parametrize( @@ -1766,7 +1874,7 @@ def _assert_splits(fname, n, size): bad_fname = next_fnames.pop(-1) for ii, this_fname in enumerate(next_fnames[:-1]): assert this_fname.is_file(), f"Missing file: {this_fname}" - with open(this_fname, "r") as fid: + with open(this_fname) as fid: fid.seek(0, 2) file_size = fid.tell() min_ = 0.1 if ii < len(next_fnames) - 1 else 0.1 @@ -2128,6 +2236,93 @@ def test_reject_epochs(tmp_path): assert epochs_cleaned.flat == dict(grad=new_flat["grad"], mag=flat["mag"]) +@testing.requires_testing_data +def test_callable_reject(): + """Test using a callable for rejection.""" + raw = read_raw_fif(fname_raw_testing, preload=True) + raw.crop(0, 5) + raw.del_proj() + chans = raw.info["ch_names"][-6:-1] + raw.pick(chans) + data = raw.get_data() + + # Add some artifacts + new_data = data + new_data[0, 180:200] *= 1e7 + new_data[0, 610:880] += 1e-3 + edit_raw = mne.io.RawArray(new_data, raw.info) + + events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) + epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True) + assert len(epochs) == 5 + + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, + ) + epochs.drop_bad( + reject=dict(eeg=lambda 
x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")) ) + + assert epochs.drop_log[2] == ("eeg median",) + + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, + ) + epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",))) + ) + + assert epochs.drop_log[0] == ("eeg max",) + + def reject_criteria(x): + max_condition = np.max(x, axis=1) > 1e-2 + median_condition = np.median(x, axis=1) > 1e-4 + return (max_condition.any() or median_condition.any()), "eeg max or median" + + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, + ) + epochs.drop_bad(reject=dict(eeg=reject_criteria)) + + assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == ( + "eeg max or median", + ) + + # Test reasons must be str or tuple of str + with pytest.raises( + TypeError, + match=r".* must be an instance of str, got <class 'int'> instead.", + ): + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, + ) + epochs.drop_bad( + reject=dict( + eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2)) + ) + ) + + def test_preload_epochs(): """Test preload of epochs.""" raw, events, picks = _get_data() @@ -2271,7 +2466,7 @@ def test_crop(tmp_path): reject=reject, flat=flat, ) - with pytest.warns(RuntimeWarning, match="tmax is set to"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="tmax is set to"): epochs2.crop(-20, 200) # indices for slicing @@ -2644,25 +2839,58 @@ def test_subtract_evoked(): def test_epoch_eq(): - """Test epoch count equalization and condition combining.""" + """Test for equalize_epoch_counts and equalize_event_counts functions.""" + # load data raw, events, picks = _get_data() - # equalizing epochs objects + # test equalize epoch counts + # create epochs with unequal counts events_1 = events[events[:, 2] == event_id] epochs_1 = Epochs(raw, events_1, event_id, tmin, tmax, picks=picks) events_2 = events[events[:, 2] == event_id_2] epochs_2 = Epochs(raw, events_2, event_id_2, tmin, tmax, picks=picks) + # events 2 has one more event than events 1 epochs_1.drop_bad() # make sure drops are logged + epochs_2.drop_bad() # make sure drops are logged + # make sure there is a difference in the number of events + assert len(epochs_1) != len(epochs_2) + # make sure bad epochs are dropped before equalizing epoch counts assert_equal( len([log for log in epochs_1.drop_log if not log]), len(epochs_1.events) ) - assert epochs_1.drop_log == ((),) * len(epochs_1.events) - assert_equal(len([lg for lg in epochs_1.drop_log if not lg]), len(epochs_1.events)) - assert epochs_1.events.shape[0] != epochs_2.events.shape[0] + assert epochs_2.drop_log == ((),) * len(epochs_2.events) + # test mintime method events_1[-1, 0] += 60 # hack: ensure mintime drops something other than last trial + # now run equalize_epoch_counts with mintime method equalize_epoch_counts([epochs_1, epochs_2], method="mintime") + # mintime method should give us the smallest difference between timings of epochs + alleged_mintime = np.sum(np.abs(epochs_1.events[:, 0] - epochs_2.events[:, 0])) + # test that "mintime" works as expected, by systematically dropping each event from + # events_2 and ensuring the latencies are actually smallest in the + # equalize_epoch_counts case.
NB: len(events_2) > len(events_1) + for idx in range(events_2.shape[0]): + # delete epoch from events_2 + test_events = np.delete(events_2.copy(), idx, axis=0) + assert test_events.shape == epochs_1.events.shape == epochs_2.events.shape + # difference (in samples) between epochs_1 event times and the event times we + # get from our deletion of row `idx` from events_2 + latencies = epochs_1.events[:, 0] - test_events[:, 0] + got_mintime = np.sum(np.abs(latencies)) + assert got_mintime >= alleged_mintime + # make sure the number of events is equal assert_equal(epochs_1.events.shape[0], epochs_2.events.shape[0]) + # create new epochs with the same event ids as epochs_1 and epochs_2 epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks) epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks) + epochs_3.drop_bad() # make sure drops are logged + epochs_4.drop_bad() # make sure drops are logged + # make sure there is a difference in the number of events + assert len(epochs_3) != len(epochs_4) + # test truncate method equalize_epoch_counts([epochs_3, epochs_4], method="truncate") + if len(epochs_3.events) > len(epochs_4.events): + assert_equal(epochs_3.events[-2, 0], epochs_3.events[-1, 0]) + elif len(epochs_3.events) < len(epochs_4.events): + assert_equal(epochs_4.events[-2, 0], epochs_4.events[-1, 0]) assert_equal(epochs_1.events.shape[0], epochs_3.events.shape[0]) assert_equal(epochs_3.events.shape[0], epochs_4.events.shape[0]) @@ -2927,7 +3155,7 @@ def test_to_data_frame_index(index): # test index order/hierarchy preservation if not isinstance(index, list): index = [index] - assert df.index.names == index + assert list(df.index.names) == index # test that non-indexed data were present as columns non_index = list(set(["condition", "time", "epoch"]) - set(index)) if len(non_index): @@ -3159,9 +3387,16 @@ def test_drop_epochs(): events1 = events[events[:, 2] == event_id] # Bound checks - pytest.raises(IndexError, epochs.drop, [len(epochs.events)]) - pytest.raises(IndexError, epochs.drop, [-len(epochs.events) - 1]) - pytest.raises(ValueError, epochs.drop, [[1, 2], [3, 4]]) + with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"): + epochs.drop([len(epochs.events)]) + with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"): + epochs.drop([-len(epochs.events) - 1]) + with pytest.raises(TypeError, match="indices must be a scalar or a 1-d array"): + epochs.drop([[1, 2], [3, 4]]) + with pytest.raises( + TypeError, match=r".* must be an instance of .* got <class .*> instead."
+ ): + epochs.drop([1], reason=("a", "b", 2)) # Test selection attribute assert_array_equal(epochs.selection, np.where(events[:, 2] == event_id)[0]) @@ -3181,6 +3416,18 @@ def test_drop_epochs(): assert_array_equal(events[epochs[3:].selection], events1[[5, 6]]) assert_array_equal(events[epochs["1"].selection], events1[[0, 1, 3, 5, 6]]) + # Test using tuple to drop epochs + raw, events, picks = _get_data() + epochs_tuple = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True) + selection_tuple = epochs_tuple.selection.copy() + epochs_tuple.drop((2, 3, 4), reason=("a", "b")) + n_events = len(epochs.events) + assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [ + ("a", "b"), + ("a", "b"), + ("a", "b"), + ] + @pytest.mark.parametrize("preload", (True, False)) def test_drop_epochs_mult(preload): @@ -3590,7 +3837,7 @@ def test_concatenate_epochs(): # check concatenating epochs where one of the objects is empty epochs2 = epochs.copy()[:0] - with pytest.warns(RuntimeWarning, match="was empty"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="was empty"): concatenate_epochs([epochs, epochs2]) # check concatenating epochs results are chronologically ordered @@ -4039,8 +4286,19 @@ def test_make_metadata(all_event_id, row_events, tmin, tmax, keep_first, keep_la Epochs(raw, events=events, event_id=event_id, metadata=metadata, verbose="warning") -def test_make_metadata_bounded_by_row_events(): - """Test make_metadata() with tmin, tmax set to None.""" +@pytest.mark.parametrize( + ("tmin", "tmax"), + [ + (None, None), + ("cue", "resp"), + (["cue"], ["resp"]), + (None, "resp"), + ("cue", None), + (["rec_start", "cue"], ["resp", "rec_end"]), + ], +) +def test_make_metadata_bounded_by_row_or_tmin_tmax_event_names(tmin, tmax): + """Test make_metadata() with tmin, tmax set to None or strings.""" pytest.importorskip("pandas") sfreq = 100 @@ -4085,8 +4343,8 @@ def test_make_metadata_bounded_by_row_events(): metadata, events_new, event_id_new = mne.epochs.make_metadata( events=events, event_id=event_id, - tmin=None, - tmax=None, + tmin=tmin, + tmax=tmax, sfreq=raw.info["sfreq"], row_events="cue", ) @@ -4109,8 +4367,15 @@ def test_make_metadata_bounded_by_row_events(): # 2nd trial assert np.isnan(metadata.iloc[1]["rec_end"]) - # 3rd trial until end of the recording - assert metadata.iloc[2]["resp"] < metadata.iloc[2]["rec_end"] + # 3rd trial + if tmax is None: + # until end of the recording + assert metadata.iloc[2]["resp"] < metadata.iloc[2]["rec_end"] + else: + # until tmax + assert np.isnan(metadata.iloc[2]["rec_end"]) + last_event_name = tmax[0] if isinstance(tmax, list) else tmax + assert metadata.iloc[2][last_event_name] > 0 def test_events_list(): @@ -4201,7 +4466,7 @@ def test_no_epochs(tmp_path): # and with no epochs remaining raw.info["bads"] = [] epochs = mne.Epochs(raw, events, reject=reject) - with pytest.warns(RuntimeWarning, match="no data"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="no data"): epochs.save(tmp_path / "sample-epo.fif", overwrite=True) assert len(epochs) == 0 # all dropped @@ -4586,6 +4851,39 @@ def fun(data): assert_array_equal(out.get_data(non_picks), epochs.get_data(non_picks)) +def test_apply_function_epo_ch_access(): + """Test ch-access within apply function to epoch objects.""" + + def _bad_ch_idx(x, ch_idx): + assert x.shape == (46,) + assert x[0] == ch_idx + return x + + def _bad_ch_name(x, ch_name): + assert x.shape == (46,) + assert isinstance(ch_name, str) + assert x[0] == float(ch_name) + return x + + 
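For context on the test_callable_reject additions above: reject/flat dict values may now be callables (accepted by drop_bad, not the constructor) that map one epoch's data for that channel type to a (bad, reasons) tuple, with reasons a str or tuple of str recorded in drop_log. A self-contained usage sketch on synthetic data; the threshold values are illustrative, and this requires an MNE-Python version that includes this feature:

    import numpy as np

    import mne

    rng = np.random.default_rng(0)
    info = mne.create_info(2, sfreq=100.0, ch_types="eeg")
    data = rng.normal(scale=1e-6, size=(2, 1000))
    data[0, 500:520] += 1e-3  # inject an artifact around t=5 s
    raw = mne.io.RawArray(data, info)
    events = mne.make_fixed_length_events(raw, id=1, duration=1.0)

    epochs = mne.Epochs(raw, events, tmin=0, tmax=1, baseline=None, preload=True)
    # the callable receives the epoch's (n_channels, n_times) array for "eeg"
    epochs.drop_bad(
        reject=dict(eeg=lambda x: ((np.abs(x) > 1e-4).any(), "abs amplitude"))
    )
    assert ("abs amplitude",) in epochs.drop_log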
data = np.full((2, 100), np.arange(2).reshape(-1, 1)) + raw = RawArray(data, create_info(2, 1.0, "mag")) + ev = np.array([[0, 0, 33], [50, 0, 33]]) + ep = Epochs(raw, ev, tmin=0, tmax=45, baseline=None, preload=True) + + # test ch_idx access in both code paths (parallel / 1 job) + ep.apply_function(_bad_ch_idx) + ep.apply_function(_bad_ch_idx, n_jobs=2) + ep.apply_function(_bad_ch_name) + ep.apply_function(_bad_ch_name, n_jobs=2) + + # test input catches + with pytest.raises( + ValueError, + match="cannot access.*when channel_wise=False", + ): + ep.apply_function(_bad_ch_idx, channel_wise=False) + + @testing.requires_testing_data def test_add_channels_picks(): """Check that add_channels properly deals with picks.""" diff --git a/mne/tests/test_event.py b/mne/tests/test_event.py index 0d6ad7e0416..7d899291232 100644 --- a/mne/tests/test_event.py +++ b/mne/tests/test_event.py @@ -40,7 +40,7 @@ ) from mne.io import RawArray, read_raw_fif -base_dir = Path(__file__).parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" fname = base_dir / "test-eve.fif" fname_raw = base_dir / "test_raw.fif" fname_gz = base_dir / "test-eve.fif.gz" diff --git a/mne/tests/test_evoked.py b/mne/tests/test_evoked.py index 8c1dd5631c4..17586b1a465 100644 --- a/mne/tests/test_evoked.py +++ b/mne/tests/test_evoked.py @@ -23,6 +23,7 @@ from mne import ( Epochs, EpochsArray, + SourceEstimate, combine_evoked, create_info, equalize_channels, @@ -34,9 +35,9 @@ from mne._fiff.constants import FIFF from mne.evoked import Evoked, EvokedArray, _get_peak from mne.io import read_raw_fif -from mne.utils import grand_average +from mne.utils import _record_warnings, grand_average -base_dir = Path(__file__).parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" fname = base_dir / "test-ave.fif" fname_gz = base_dir / "test-ave.fif.gz" raw_fname = base_dir / "test_raw.fif" @@ -589,6 +590,24 @@ def test_get_peak(): with pytest.raises(ValueError, match="No positive values"): evoked_all_neg.get_peak(mode="pos") + # Test finding minimum and maximum values + evoked_all_neg_outlier = evoked_all_neg.copy() + evoked_all_pos_outlier = evoked_all_pos.copy() + + # Add an outlier to the data + evoked_all_neg_outlier.data[0, 15] = -1e-20 + evoked_all_pos_outlier.data[0, 15] = 1e-20 + + ch_name, time_idx, max_amp = evoked_all_neg_outlier.get_peak( + mode="pos", return_amplitude=True, strict=False + ) + assert max_amp == -1e-20 + + ch_name, time_idx, min_amp = evoked_all_pos_outlier.get_peak( + mode="neg", return_amplitude=True, strict=False + ) + assert min_amp == 1e-20 + # Test interaction between `mode` and `tmin` / `tmax` # For the test, create an Evoked where half of the values are negative # and the rest is positive @@ -799,7 +818,7 @@ def test_time_as_index_and_crop(): ) evoked.crop(evoked.tmin, evoked.tmax, include_tmax=False) n_times = len(evoked.times) - with pytest.warns(RuntimeWarning, match="tmax is set to"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="tmax is set to"): evoked.crop(tmin, tmax, include_tmax=False) assert len(evoked.times) == n_times assert_allclose(evoked.times[[0, -1]], [tmin, tmax - delta], atol=atol) @@ -899,7 +918,7 @@ def test_evoked_baseline(tmp_path): def test_hilbert(): - """Test hilbert on raw, epochs, and evoked.""" + """Test hilbert on raw, epochs, evoked and SourceEstimate data.""" raw = read_raw_fif(raw_fname).load_data() raw.del_proj() raw.pick(raw.ch_names[:2]) @@ -909,10 +928,17 @@ def 
test_hilbert(): epochs.apply_hilbert() epochs.load_data() evoked = epochs.average() + # Create SourceEstimate stc data + verts = [np.arange(10), np.arange(90)] + data = np.random.default_rng(0).normal(size=(100, 10)) + stc = SourceEstimate(data, verts, 0, 1e-1, "foo") + raw_hilb = raw.apply_hilbert() epochs_hilb = epochs.apply_hilbert() evoked_hilb = evoked.copy().apply_hilbert() evoked_hilb_2_data = epochs_hilb.get_data(copy=False).mean(0) + stc_hilb = stc.copy().apply_hilbert() + stc_hilb_env = stc.copy().apply_hilbert(envelope=True) assert_allclose(evoked_hilb.data, evoked_hilb_2_data) # This one is only approximate because of edge artifacts evoked_hilb_3 = Epochs(raw_hilb, events).average() @@ -923,6 +949,8 @@ def test_hilbert(): # envelope=True mode evoked_hilb_env = evoked.apply_hilbert(envelope=True) assert_allclose(evoked_hilb_env.data, np.abs(evoked_hilb.data)) + assert len(stc_hilb.data) == len(stc.data) + assert_allclose(stc_hilb_env.data, np.abs(stc_hilb.data)) def test_apply_function_evk(): @@ -941,3 +969,33 @@ def fun(data, multiplier): applied = evoked.apply_function(fun, n_jobs=None, multiplier=mult) assert np.shape(applied.data) == np.shape(evoked_data) assert np.equal(applied.data, evoked_data * mult).all() + + +def test_apply_function_evk_ch_access(): + """Check ch-access within the apply_function method for evoked data.""" + + def _bad_ch_idx(x, ch_idx): + assert x[0] == ch_idx + return x + + def _bad_ch_name(x, ch_name): + assert isinstance(ch_name, str) + assert x[0] == float(ch_name) + return x + + # create fake evoked data to use for checking apply_function + data = np.full((2, 100), np.arange(2).reshape(-1, 1)) + evoked = EvokedArray(data, create_info(2, 1000.0, "eeg")) + + # test ch_idx access in both code paths (parallel / 1 job) + evoked.apply_function(_bad_ch_idx) + evoked.apply_function(_bad_ch_idx, n_jobs=2) + evoked.apply_function(_bad_ch_name) + evoked.apply_function(_bad_ch_name, n_jobs=2) + + # test input catches + with pytest.raises( + ValueError, + match="cannot access.*when channel_wise=False", + ): + evoked.apply_function(_bad_ch_idx, channel_wise=False) diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index 110a8f136c3..00dce484a08 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -32,6 +32,8 @@ from mne.io import RawArray, read_raw_fif from mne.utils import catch_logging, requires_mne, run_subprocess, sum_squared +resample_method_parametrize = pytest.mark.parametrize("method", ("fft", "polyphase")) + def test_filter_array(): """Test filtering an array.""" @@ -86,12 +88,8 @@ def test_estimate_ringing(): (0.0001, (30000, 60000)), ): # 37993 n_ring = estimate_ringing_samples(butter(3, thresh, output=kind)) - assert lims[0] <= n_ring <= lims[1], "%s %s: %s <= %s <= %s" % ( - kind, - thresh, - lims[0], - n_ring, - lims[1], + assert lims[0] <= n_ring <= lims[1], ( + f"{kind} {thresh}: {lims[0]} " f"<= {n_ring} <= {lims[1]}" ) with pytest.warns(RuntimeWarning, match="properly estimate"): assert estimate_ringing_samples(butter(4, 0.00001)) == 100000 @@ -372,20 +370,27 @@ def test_notch_filters(method, filter_length, line_freq, tol): assert_almost_equal(new_power, orig_power, tol) -def test_resample(): +@resample_method_parametrize +def test_resample(method): """Test resampling.""" rng = np.random.RandomState(0) x = rng.normal(0, 1, (10, 10, 10)) - x_rs = resample(x, 1, 2, 10) + with catch_logging() as log: + x_rs = resample(x, 1, 2, npad=10, method=method, verbose=True) + log = log.getvalue() + if method == "fft": + 
assert "neighborhood" not in log + else: + assert "neighborhood" in log assert x.shape == (10, 10, 10) assert x_rs.shape == (10, 10, 5) x_2 = x.swapaxes(0, 1) - x_2_rs = resample(x_2, 1, 2, 10) + x_2_rs = resample(x_2, 1, 2, npad=10, method=method) assert_array_equal(x_2_rs.swapaxes(0, 1), x_rs) x_3 = x.swapaxes(0, 2) - x_3_rs = resample(x_3, 1, 2, 10, 0) + x_3_rs = resample(x_3, 1, 2, npad=10, axis=0, method=method) assert_array_equal(x_3_rs.swapaxes(0, 2), x_rs) # make sure we cast to array if necessary @@ -398,15 +403,15 @@ def test_resample_scipy(): for window in ("boxcar", "hann"): for N in (100, 101, 102, 103): x = np.arange(N).astype(float) - err_msg = "%s: %s" % (N, window) + err_msg = f"{N}: {window}" x_2_sp = sp_resample(x, 2 * N, window=window) for n_jobs in n_jobs_test: - x_2 = resample(x, 2, 1, 0, window=window, n_jobs=n_jobs) + x_2 = resample(x, 2, 1, npad=0, window=window, n_jobs=n_jobs) assert_allclose(x_2, x_2_sp, atol=1e-12, err_msg=err_msg) new_len = int(round(len(x) * (1.0 / 2.0))) x_p5_sp = sp_resample(x, new_len, window=window) for n_jobs in n_jobs_test: - x_p5 = resample(x, 1, 2, 0, window=window, n_jobs=n_jobs) + x_p5 = resample(x, 1, 2, npad=0, window=window, n_jobs=n_jobs) assert_allclose(x_p5, x_p5_sp, atol=1e-12, err_msg=err_msg) @@ -450,23 +455,25 @@ def test_resamp_stim_channel(): assert new_data.shape[1] == new_data_len -def test_resample_raw(): +@resample_method_parametrize +def test_resample_raw(method): """Test resampling using RawArray.""" x = np.zeros((1, 1001)) sfreq = 2048.0 raw = RawArray(x, create_info(1, sfreq, "eeg")) - raw.resample(128, npad=10) + raw.resample(128, npad=10, method=method) data = raw.get_data() assert data.shape == (1, 63) -def test_resample_below_1_sample(): +@resample_method_parametrize +def test_resample_below_1_sample(method): """Test resampling doesn't yield datapoints.""" # Raw x = np.zeros((1, 100)) sfreq = 1000.0 raw = RawArray(x, create_info(1, sfreq, "eeg")) - raw.resample(5) + raw.resample(5, method=method) assert len(raw.times) == 1 assert raw.get_data().shape[1] == 1 @@ -487,7 +494,13 @@ def test_resample_below_1_sample(): preload=True, verbose=False, ) - epochs.resample(1) + with catch_logging() as log: + epochs.resample(1, method=method, verbose=True) + log = log.getvalue() + if method == "fft": + assert "neighborhood" not in log + else: + assert "neighborhood" in log assert len(epochs.times) == 1 assert epochs.get_data(copy=False).shape[2] == 1 @@ -593,12 +606,12 @@ def test_filters(): # try new default and old default freqs = fftfreq(a.shape[-1], 1.0 / sfreq) A = np.abs(fft(a)) - kwargs = dict(fir_design="firwin") + kw = dict(fir_design="firwin") for fl in ["auto", "10s", "5000ms", 1024, 1023]: - bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kwargs) - bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, **kwargs) - lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2, **kwargs) - hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kwargs) + bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kw) + bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, **kw) + lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2, **kw) + hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kw) assert_allclose(hp, bp, rtol=1e-3, atol=2e-3) assert_allclose(bp + bs, a, rtol=1e-3, atol=1e-3) # Sanity check ttenuation @@ -606,12 +619,18 @@ def test_filters(): assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.02) 
assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.2) # now the minimum-phase versions - bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum", **kwargs) + bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum-half", **kw) bs = filter_data( - a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum", **kwargs + a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum-half", **kw ) assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.11) assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.3) + bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum", **kw) + bs = filter_data( + a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum", **kw + ) + assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.12) + assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.27) # and since these are low-passed, downsampling/upsampling should be close n_resamp_ignore = 10 @@ -894,7 +913,7 @@ def test_reporting_iir(phase, ftype, btype, order, output): dB_cutoff = -7.58 dB_cutoff *= order_mult if btype == "lowpass": - keys += ["%0.2f dB" % (dB_cutoff,)] + keys += [f"{dB_cutoff:0.2f} dB"] for key in keys: assert key.lower() in log.lower() # Verify some of the filter properties @@ -1037,3 +1056,45 @@ def test_filter_picks(): raw.filter(picks=picks, **kwargs) want = want[1:] assert_allclose(raw.get_data(), want) + + +def test_filter_minimum_phase_bug(): + """Test gh-12267 is fixed.""" + sfreq = 1000.0 + n_taps = 1001 + l_freq = 10.0 # Hz + kwargs = dict( + data=None, + sfreq=sfreq, + l_freq=l_freq, + h_freq=None, + filter_length=n_taps, + l_trans_bandwidth=l_freq / 2.0, + ) + h = create_filter(phase="zero", **kwargs) + h_min = create_filter(phase="minimum", **kwargs) + h_min_half = create_filter(phase="minimum-half", **kwargs) + assert h_min.size == h.size + kwargs = dict(worN=10000, fs=sfreq) + w, H = freqz(h, **kwargs) + assert w[0] == 0 + dc_dB = 20 * np.log10(np.abs(H[0])) + assert dc_dB < -100 + # good + w_min, H_min = freqz(h_min, **kwargs) + assert_allclose(w, w_min) + dc_dB_min = 20 * np.log10(np.abs(H_min[0])) + assert dc_dB_min < -100 + mask = w < 5 + assert 10 < mask.sum() < 101 + assert_allclose(np.abs(H[mask]), np.abs(H_min[mask]), atol=1e-3, rtol=1e-3) + assert_array_less(20 * np.log10(np.abs(H[mask])), -40) + assert_array_less(20 * np.log10(np.abs(H_min[mask])), -40) + # bad + w_min_half, H_min_half = freqz(h_min_half, **kwargs) + assert_allclose(w, w_min_half) + dc_dB_min_half = 20 * np.log10(np.abs(H_min_half[0])) + assert -80 < dc_dB_min_half < 40 + dB_min_half = 20 * np.log10(np.abs(H_min_half[mask])) + assert_array_less(dB_min_half, -20) + assert not (dB_min_half < -30).all() diff --git a/mne/tests/test_label.py b/mne/tests/test_label.py index 35e41d91f6c..01d934417e2 100644 --- a/mne/tests/test_label.py +++ b/mne/tests/test_label.py @@ -64,7 +64,7 @@ src_bad_fname = data_path / "subjects" / "fsaverage" / "bem" / "fsaverage-ico-5-src.fif" label_dir = subjects_dir / "sample" / "label" / "aparc" -test_path = Path(__file__).parent.parent / "io" / "tests" / "data" +test_path = Path(__file__).parents[1] / "io" / "tests" / "data" label_fname = test_path / "test-lh.label" label_rh_fname = test_path / "test-rh.label" @@ -182,7 +182,7 @@ def assert_labels_equal(l0, l1, decimal=5, comment=True, color=True): for attr in ["hemi", "subject"]: attr0 = getattr(l0, attr) attr1 = getattr(l1, attr) - msg = "label.%s: %r != %r" % (attr, 
attr0, attr1) + msg = f"label.{attr}: {repr(attr0)} != {repr(attr1)}" assert_equal(attr0, attr1, msg) for attr in ["vertices", "pos", "values"]: a0 = getattr(l0, attr) diff --git a/mne/tests/test_line_endings.py b/mne/tests/test_line_endings.py index 5c91c29fd9a..8ee4f604c9f 100644 --- a/mne/tests/test_line_endings.py +++ b/mne/tests/test_line_endings.py @@ -74,8 +74,7 @@ def _assert_line_endings(dir_): ) if len(report) > 0: raise AssertionError( - "Found %s files with incorrect endings:\n%s" - % (len(report), "\n".join(report)) + f"Found {len(report)} files with incorrect endings:\n" + "\n".join(report) ) diff --git a/mne/tests/test_misc.py b/mne/tests/test_misc.py index 887d45a6ffb..54286669e57 100644 --- a/mne/tests/test_misc.py +++ b/mne/tests/test_misc.py @@ -7,7 +7,7 @@ from mne.misc import parse_config -ave_fname = Path(__file__).parent.parent / "io" / "tests" / "data" / "test.ave" +ave_fname = Path(__file__).parents[1] / "io" / "tests" / "data" / "test.ave" def test_parse_ave(): diff --git a/mne/tests/test_proj.py b/mne/tests/test_proj.py index f13272dd2c1..36437238f90 100644 --- a/mne/tests/test_proj.py +++ b/mne/tests/test_proj.py @@ -42,7 +42,7 @@ from mne.rank import _compute_rank_int from mne.utils import _record_warnings -base_dir = Path(__file__).parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" event_fname = base_dir / "test-eve.fif" proj_fname = base_dir / "test-proj.fif" diff --git a/mne/tests/test_rank.py b/mne/tests/test_rank.py index fb9efcba615..3832fe18bff 100644 --- a/mne/tests/test_rank.py +++ b/mne/tests/test_rank.py @@ -29,7 +29,7 @@ estimate_rank, ) -base_dir = Path(__file__).parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" cov_fname = base_dir / "test-cov.fif" raw_fname = base_dir / "test_raw.fif" ave_fname = base_dir / "test-ave.fif" @@ -177,7 +177,7 @@ def test_cov_rank_estimation(rank_method, proj, meg): # count channel types ch_types = this_info.get_channel_types() - n_eeg, n_mag, n_grad = [ch_types.count(k) for k in ["eeg", "mag", "grad"]] + n_eeg, n_mag, n_grad = (ch_types.count(k) for k in ["eeg", "mag", "grad"]) n_meg = n_mag + n_grad has_sss = n_meg > 0 and len(this_info["proc_history"]) > 0 if has_sss: diff --git a/mne/tests/test_read_vectorview_selection.py b/mne/tests/test_read_vectorview_selection.py index 844a30edc5d..e0ed6f0af20 100644 --- a/mne/tests/test_read_vectorview_selection.py +++ b/mne/tests/test_read_vectorview_selection.py @@ -7,7 +7,7 @@ from mne import read_vectorview_selection from mne.io import read_raw_fif -test_path = Path(__file__).parent.parent / "io" / "tests" / "data" +test_path = Path(__file__).parents[1] / "io" / "tests" / "data" raw_fname = test_path / "test_raw.fif" raw_new_fname = test_path / "test_chpi_raw_sss.fif" diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index 8c9e7df9389..08e08761ced 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -72,7 +72,7 @@ read_inverse_operator, ) from mne.morph_map import _make_morph_map_hemi -from mne.source_estimate import _get_vol_mask, grade_to_tris +from mne.source_estimate import _get_vol_mask, _make_stc, grade_to_tris from mne.source_space._source_space import _get_src_nn from mne.transforms import apply_trans, invert_transform, transform_surface_to from mne.utils import ( @@ -248,6 +248,34 @@ def test_volume_stc(tmp_path): assert_array_almost_equal(stc.data, 
stc_new.data) +@testing.requires_testing_data +def test_save_stc_as_gifti(tmp_path): + """Save the stc as a GIFTI file and export.""" + nib = pytest.importorskip("nibabel") + surfpath_src = bem_path / "sample-oct-6-src.fif" + surfpath_stc = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg" + src = read_source_spaces(surfpath_src) # need source space + stc = read_source_estimate(surfpath_stc) # need stc + assert isinstance(src, SourceSpaces) + assert isinstance(stc, SourceEstimate) + + surf_fname = tmp_path / "stc_write" + + stc.save_as_surface(surf_fname, src) + + # did structural get written? + img_lh = nib.load(f"{surf_fname}-lh.gii") + img_rh = nib.load(f"{surf_fname}-rh.gii") + assert isinstance(img_lh, nib.gifti.gifti.GiftiImage) + assert isinstance(img_rh, nib.gifti.gifti.GiftiImage) + + # did time series get written? + img_timelh = nib.load(f"{surf_fname}-lh.time.gii") + img_timerh = nib.load(f"{surf_fname}-rh.time.gii") + assert isinstance(img_timelh, nib.gifti.gifti.GiftiImage) + assert isinstance(img_timerh, nib.gifti.gifti.GiftiImage) + + @testing.requires_testing_data def test_stc_as_volume(): """Test previous volume source estimate morph.""" @@ -371,7 +399,7 @@ def test_stc_snr(): assert (stc.data < 0).any() with pytest.warns(RuntimeWarning, match="nAm"): stc.estimate_snr(evoked.info, fwd, cov) # dSPM - with pytest.warns(RuntimeWarning, match="free ori"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="free ori"): abs(stc).estimate_snr(evoked.info, fwd, cov) stc = apply_inverse(evoked, inv, method="MNE") snr = stc.estimate_snr(evoked.info, fwd, cov) @@ -558,61 +586,73 @@ def test_stc_arithmetic(): @pytest.mark.slowtest @testing.requires_testing_data -def test_stc_methods(): +@pytest.mark.parametrize("kind", ("scalar", "vector")) +@pytest.mark.parametrize("method", ("fft", "polyphase")) +def test_stc_methods(kind, method): """Test stc methods lh_data, rh_data, bin(), resample().""" - stc_ = read_source_estimate(fname_stc) + stc = read_source_estimate(fname_stc) - # Make a vector version of the above source estimate - x = stc_.data[:, np.newaxis, :] - yz = np.zeros((x.shape[0], 2, x.shape[2])) - vec_stc_ = VectorSourceEstimate( - np.concatenate((x, yz), 1), stc_.vertices, stc_.tmin, stc_.tstep, stc_.subject - ) + if kind == "vector": + # Make a vector version of the above source estimate + x = stc.data[:, np.newaxis, :] + yz = np.zeros((x.shape[0], 2, x.shape[2])) + stc = VectorSourceEstimate( + np.concatenate((x, yz), 1), + stc.vertices, + stc.tmin, + stc.tstep, + stc.subject, + ) - for stc in [stc_, vec_stc_]: - # lh_data / rh_data - assert_array_equal(stc.lh_data, stc.data[: len(stc.lh_vertno)]) - assert_array_equal(stc.rh_data, stc.data[len(stc.lh_vertno) :]) - - # bin - binned = stc.bin(0.12) - a = np.mean(stc.data[..., : np.searchsorted(stc.times, 0.12)], axis=-1) - assert_array_equal(a, binned.data[..., 0]) - - stc = read_source_estimate(fname_stc) - stc.subject = "sample" - label_lh = read_labels_from_annot( - "sample", "aparc", "lh", subjects_dir=subjects_dir - )[0] - label_rh = read_labels_from_annot( - "sample", "aparc", "rh", subjects_dir=subjects_dir - )[0] - label_both = label_lh + label_rh - for label in (label_lh, label_rh, label_both): - assert isinstance(stc.shape, tuple) and len(stc.shape) == 2 - stc_label = stc.in_label(label) - if label.hemi != "both": - if label.hemi == "lh": - verts = stc_label.vertices[0] - else: # label.hemi == 'rh': - verts = stc_label.vertices[1] - n_vertices_used = len(label.get_vertices_used(verts)) - 
assert_equal(len(stc_label.data), n_vertices_used) - stc_lh = stc.in_label(label_lh) - pytest.raises(ValueError, stc_lh.in_label, label_rh) - label_lh.subject = "foo" - pytest.raises(RuntimeError, stc.in_label, label_lh) - - stc_new = deepcopy(stc) - o_sfreq = 1.0 / stc.tstep - # note that using no padding for this STC reduces edge ringing... - stc_new.resample(2 * o_sfreq, npad=0) - assert stc_new.data.shape[1] == 2 * stc.data.shape[1] - assert stc_new.tstep == stc.tstep / 2 - stc_new.resample(o_sfreq, npad=0) - assert stc_new.data.shape[1] == stc.data.shape[1] - assert stc_new.tstep == stc.tstep - assert_array_almost_equal(stc_new.data, stc.data, 5) + # lh_data / rh_data + assert_array_equal(stc.lh_data, stc.data[: len(stc.lh_vertno)]) + assert_array_equal(stc.rh_data, stc.data[len(stc.lh_vertno) :]) + + # bin + binned = stc.bin(0.12) + a = np.mean(stc.data[..., : np.searchsorted(stc.times, 0.12)], axis=-1) + assert_array_equal(a, binned.data[..., 0]) + + stc = read_source_estimate(fname_stc) + stc.subject = "sample" + label_lh = read_labels_from_annot( + "sample", "aparc", "lh", subjects_dir=subjects_dir + )[0] + label_rh = read_labels_from_annot( + "sample", "aparc", "rh", subjects_dir=subjects_dir + )[0] + label_both = label_lh + label_rh + for label in (label_lh, label_rh, label_both): + assert isinstance(stc.shape, tuple) and len(stc.shape) == 2 + stc_label = stc.in_label(label) + if label.hemi != "both": + if label.hemi == "lh": + verts = stc_label.vertices[0] + else: # label.hemi == 'rh': + verts = stc_label.vertices[1] + n_vertices_used = len(label.get_vertices_used(verts)) + assert_equal(len(stc_label.data), n_vertices_used) + stc_lh = stc.in_label(label_lh) + pytest.raises(ValueError, stc_lh.in_label, label_rh) + label_lh.subject = "foo" + pytest.raises(RuntimeError, stc.in_label, label_lh) + + stc_new = deepcopy(stc) + o_sfreq = 1.0 / stc.tstep + # note that using no padding for this STC reduces edge ringing... 
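The rewritten test_stc_methods assertions just below encode the practical difference between the two backends: an FFT up/down round trip is essentially lossless, while the polyphase FIR low-passes near Nyquist, so only a high correlation can be required. A quick numeric check of that premise with SciPy, using white noise so the filter's transition band actually matters:

    import numpy as np
    from scipy.signal import resample, resample_poly

    rng = np.random.default_rng(0)
    x = rng.standard_normal(200)

    # FFT round trip: upsample 2x, then back down; recovers the input
    x_rt = resample(resample(x, 400), 200)
    assert np.allclose(x_rt, x, atol=1e-6)

    # polyphase round trip: the anti-aliasing FIR eats the top of the band
    x_rt_poly = resample_poly(resample_poly(x, 2, 1), 1, 2)
    corr = np.corrcoef(x_rt_poly, x)[0, 1]
    assert 0.9 < corr < 1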
+ stc_new.resample(2 * o_sfreq, npad=0, method=method) + assert stc_new.data.shape[1] == 2 * stc.data.shape[1] + assert stc_new.tstep == stc.tstep / 2 + stc_new.resample(o_sfreq, npad=0, method=method) + assert stc_new.data.shape[1] == stc.data.shape[1] + assert stc_new.tstep == stc.tstep + if method == "fft": + # no low-passing so survives round-trip + assert_allclose(stc_new.data, stc.data, atol=1e-5) + else: + # low-passing means we need something more flexible + corr = np.corrcoef(stc_new.data.ravel(), stc.data.ravel())[0, 1] + assert 0.99 < corr < 1 @testing.requires_testing_data @@ -1200,7 +1240,7 @@ def test_to_data_frame_index(index): # test index setting if not isinstance(index, list): index = [index] - assert df.index.names == index + assert list(df.index.names) == index # test that non-indexed data were present as columns non_index = list(set(["time", "subject"]) - set(index)) if len(non_index): @@ -1673,7 +1713,8 @@ def test_stc_near_sensors(tmp_path): for s in src: transform_surface_to(s, "head", trans, copy=False) assert src[0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD - stc_src = stc_near_sensors(evoked, src=src, **kwargs) + with pytest.warns(DeprecationWarning, match="instead of the pial"): + stc_src = stc_near_sensors(evoked, src=src, **kwargs) assert len(stc_src.data) == 7928 with pytest.warns(RuntimeWarning, match="not included"): # some removed stc_src_full = compute_source_morph( @@ -2014,3 +2055,31 @@ def test_label_extraction_subject(kind): stc.subject = None with pytest.raises(ValueError, match=r"label\.sub.*not match.* sour"): extract_label_time_course(stc, labels_fs, src) + + +def test_apply_function_stc(): + """Check the apply_function method for source estimate data.""" + # Create a sample _BaseSourceEstimate object + n_vertices = 100 + n_times = 200 + vertices = [np.array(np.arange(50)), np.array(np.arange(50, 100))] + tmin = 0.0 + tstep = 0.001 + data = np.random.default_rng(0).normal(size=(n_vertices, n_times)) + + stc = _make_stc(data, vertices, tmin=tmin, tstep=tstep, src_type="surface") + + # A sample function to apply to the data + def fun(data_row, **kwargs): + return 2 * data_row + + # Test applying the function to all vertices without parallelization + stc_copy = stc.copy() + stc.apply_function(fun) + for idx in range(n_vertices): + assert_allclose(stc.data[idx, :], 2 * stc_copy.data[idx, :]) + + # Test applying the function with parallelization + stc.apply_function(fun, n_jobs=2) + for idx in range(n_vertices): + assert_allclose(stc.data[idx, :], 4 * stc_copy.data[idx, :]) diff --git a/mne/tests/test_surface.py b/mne/tests/test_surface.py index 646b2793706..6199bdfbe41 100644 --- a/mne/tests/test_surface.py +++ b/mne/tests/test_surface.py @@ -49,7 +49,7 @@ def test_helmet(): """Test loading helmet surfaces.""" - base_dir = Path(__file__).parent.parent / "io" + base_dir = Path(__file__).parents[1] / "io" fname_raw = base_dir / "tests" / "data" / "test_raw.fif" fname_kit_raw = base_dir / "kit" / "tests" / "data" / "test_bin_raw.fif" fname_bti_raw = base_dir / "bti" / "tests" / "data" / "exported4D_linux_raw.fif" diff --git a/mne/tests/test_transforms.py b/mne/tests/test_transforms.py index 2246609bdb8..ef6433d951d 100644 --- a/mne/tests/test_transforms.py +++ b/mne/tests/test_transforms.py @@ -64,7 +64,7 @@ subjects_dir = data_path / "subjects" fname_t1 = subjects_dir / "fsaverage" / "mri" / "T1.mgz" -base_dir = Path(__file__).parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" fname_trans = base_dir / 
"sample-audvis-raw-trans.txt" test_fif_fname = base_dir / "test_raw.fif" ctf_fname = base_dir / "test_ctf_raw.fif" diff --git a/mne/time_frequency/__init__.pyi b/mne/time_frequency/__init__.pyi index 9fc0c271cc4..0faeb7263d8 100644 --- a/mne/time_frequency/__init__.pyi +++ b/mne/time_frequency/__init__.pyi @@ -1,12 +1,16 @@ __all__ = [ "AverageTFR", + "AverageTFRArray", + "BaseTFR", "CrossSpectralDensity", "EpochsSpectrum", "EpochsSpectrumArray", "EpochsTFR", + "EpochsTFRArray", + "RawTFR", + "RawTFRArray", "Spectrum", "SpectrumArray", - "_BaseTFR", "csd_array_fourier", "csd_array_morlet", "csd_array_multitaper", @@ -61,8 +65,12 @@ from .spectrum import ( ) from .tfr import ( AverageTFR, + AverageTFRArray, + BaseTFR, EpochsTFR, - _BaseTFR, + EpochsTFRArray, + RawTFR, + RawTFRArray, fwhm, morlet, read_tfrs, diff --git a/mne/time_frequency/_stockwell.py b/mne/time_frequency/_stockwell.py index 1abf0c8e5a6..08acf28b357 100644 --- a/mne/time_frequency/_stockwell.py +++ b/mne/time_frequency/_stockwell.py @@ -12,8 +12,8 @@ from .._fiff.pick import _pick_data_channels, pick_info from ..parallel import parallel_func -from ..utils import _validate_type, fill_doc, logger, verbose -from .tfr import AverageTFR, _get_data +from ..utils import _validate_type, legacy, logger, verbose +from .tfr import AverageTFRArray, _ensure_slice, _get_data def _check_input_st(x_in, n_fft): @@ -22,20 +22,19 @@ def _check_input_st(x_in, n_fft): n_times = x_in.shape[-1] def _is_power_of_two(n): - return not (n > 0 and ((n & (n - 1)))) + return not (n > 0 and (n & (n - 1))) if n_fft is None or (not _is_power_of_two(n_fft) and n_times > n_fft): # Compute next power of 2 n_fft = 2 ** int(np.ceil(np.log2(n_times))) elif n_fft < n_times: raise ValueError( - "n_fft cannot be smaller than signal size. " - "Got %s < %s." % (n_fft, n_times) + f"n_fft cannot be smaller than signal size. Got {n_fft} < {n_times}." ) if n_times < n_fft: logger.info( - 'The input signal is shorter ({}) than "n_fft" ({}). ' - "Applying zero padding.".format(x_in.shape[-1], n_fft) + f'The input signal is shorter ({x_in.shape[-1]}) than "n_fft" ({n_fft}). ' + "Applying zero padding." 
) zero_pad = n_fft - n_times pad_array = np.zeros(x_in.shape[:-1] + (zero_pad,), x_in.dtype) @@ -82,9 +81,10 @@ def _st(x, start_f, windows): def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W): """Aux function.""" + decim = _ensure_slice(decim) n_samp = x.shape[-1] - n_out = n_samp - zero_pad - n_out = n_out // decim + bool(n_out % decim) + decim_indices = decim.indices(n_samp - zero_pad) + n_out = len(range(*decim_indices)) psd = np.empty((len(W), n_out)) itc = np.empty_like(psd) if compute_itc else None X = fft(x) @@ -92,10 +92,7 @@ def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W): for i_f, window in enumerate(W): f = start_f + i_f ST = ifft(XX[:, f : f + n_samp] * window) - if zero_pad > 0: - TFR = ST[:, :-zero_pad:decim] - else: - TFR = ST[:, ::decim] + TFR = ST[:, slice(*decim_indices)] TFR_abs = np.abs(TFR) TFR_abs[TFR_abs == 0] = 1.0 if compute_itc: @@ -106,7 +103,22 @@ def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W): return psd, itc -@fill_doc +def _compute_freqs_st(fmin, fmax, n_fft, sfreq): + from scipy.fft import fftfreq + + freqs = fftfreq(n_fft, 1.0 / sfreq) + if fmin is None: + fmin = freqs[freqs > 0][0] + if fmax is None: + fmax = freqs.max() + + start_f = np.abs(freqs - fmin).argmin() + stop_f = np.abs(freqs - fmax).argmin() + freqs = freqs[start_f:stop_f] + return start_f, stop_f, freqs + + +@verbose def tfr_array_stockwell( data, sfreq, @@ -117,6 +129,8 @@ def tfr_array_stockwell( decim=1, return_itc=False, n_jobs=None, + *, + verbose=None, ): """Compute power and intertrial coherence using Stockwell (S) transform. @@ -144,11 +158,11 @@ def tfr_array_stockwell( The width of the Gaussian window. If < 1, increased temporal resolution, if > 1, increased frequency resolution. Defaults to 1. (classical S-Transform). - decim : int - The decimation factor on the time axis. To reduce memory usage. + %(decim_tfr)s return_itc : bool Return intertrial coherence (ITC) as well as averaged power. 
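The decim handling above funnels every value through _ensure_slice and derives the output length from slice.indices rather than ad-hoc integer arithmetic. A rough equivalent of the helper and of the length computation in _st_power_itc; _ensure_slice is private, so this is an assumption about its behavior, not its exact code:

    import numpy as np

    def ensure_slice(decim):
        # hypothetical equivalent of mne.time_frequency.tfr._ensure_slice
        return decim if isinstance(decim, slice) else slice(None, None, int(decim))

    n_samp, zero_pad = 1000, 24
    decim = ensure_slice(4)

    # output samples after trimming the zero padding and decimating
    n_out = len(range(*decim.indices(n_samp - zero_pad)))
    assert n_out == 244  # ceil(976 / 4)

    # consistent with actually slicing the data
    x = np.arange(n_samp)
    assert x[: n_samp - zero_pad][decim].size == n_out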
%(n_jobs)s + %(verbose)s Returns ------- @@ -178,26 +192,17 @@ "data must be 3D with shape (n_epochs, n_channels, n_times), " f"got {data.shape}" ) - n_epochs, n_channels = data.shape[:2] - n_out = data.shape[2] // decim + bool(data.shape[-1] % decim) + decim = _ensure_slice(decim) + _, n_channels, n_out = data[..., decim].shape data, n_fft_, zero_pad = _check_input_st(data, n_fft) - - freqs = fftfreq(n_fft_, 1.0 / sfreq) - if fmin is None: - fmin = freqs[freqs > 0][0] - if fmax is None: - fmax = freqs.max() - - start_f = np.abs(freqs - fmin).argmin() - stop_f = np.abs(freqs - fmax).argmin() - freqs = freqs[start_f:stop_f] + start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq) W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width) n_freq = stop_f - start_f psd = np.empty((n_channels, n_freq, n_out)) itc = np.empty((n_channels, n_freq, n_out)) if return_itc else None - parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs) + parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs, verbose=verbose) tfrs = parallel( my_st(data[:, c, :], start_f, return_itc, zero_pad, decim, W) for c in range(n_channels) @@ -210,6 +215,7 @@ return psd, itc, freqs +@legacy(alt='.compute_tfr(method="stockwell", freqs="auto")') @verbose def tfr_stockwell( inst, @@ -282,6 +288,7 @@ picks = _pick_data_channels(inst.info) info = pick_info(inst.info, picks) data = data[:, picks, :] + decim = _ensure_slice(decim) power, itc, freqs = tfr_array_stockwell( data, sfreq=info["sfreq"], @@ -293,18 +300,25 @@ return_itc=return_itc, n_jobs=n_jobs, ) - times = inst.times[::decim].copy() + times = inst.times[decim].copy() nave = len(data) - out = AverageTFR(info, power, times, freqs, nave, method="stockwell-power") + out = AverageTFRArray( + info=info, + data=power, + times=times, + freqs=freqs, + nave=nave, + method="stockwell-power", + ) if return_itc: out = ( out, - AverageTFR( - deepcopy(info), - itc, - times.copy(), - freqs.copy(), - nave, + AverageTFRArray( + info=deepcopy(info), + data=itc, + times=times.copy(), + freqs=freqs.copy(), + nave=nave, method="stockwell-itc", ), ) diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index ed395137103..e2ea5ac1ba7 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -35,7 +35,7 @@ @verbose def pick_channels_csd( - csd, include=[], exclude=[], ordered=None, copy=True, *, verbose=None + csd, include=(), exclude=(), ordered=True, copy=True, *, verbose=None ): """Pick channels from cross-spectral density matrix. @@ -189,17 +189,18 @@ def __repr__(self): # noqa: D105 elif len(f) == 1: freq_strs.append(str(f[0])) else: - freq_strs.append("{}-{}".format(np.min(f), np.max(f))) + freq_strs.append(f"{np.min(f)}-{np.max(f)}") freq_str = ", ".join(freq_strs) + " Hz." if self.tmin is not None and self.tmax is not None: - time_str = "{} to {} s".format(self.tmin, self.tmax) + time_str = f"{self.tmin} to {self.tmax} s" else: time_str = "unknown" return ( - "<CrossSpectralDensity | n_channels={}, time={}, frequencies={}>" - ).format(self.n_channels, time_str, freq_str) + f"<CrossSpectralDensity | n_channels={self.n_channels}, " + f"time={time_str}, frequencies={freq_str}>" + ) def sum(self, fmin=None, fmax=None): """Calculate the sum CSD in the given frequency range(s).
diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index c6af2b20c60..4a9e66c4673 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -285,9 +285,7 @@ def _compute_mt_params(n_times, sfreq, bandwidth, low_bias, adaptive, verbose=No """Triage windowing and multitaper parameters.""" # Compute standardized half-bandwidth if isinstance(bandwidth, str): - logger.info( - ' Using standard spectrum estimation with "%s" window' % (bandwidth,) - ) + logger.info(f' Using standard spectrum estimation with "{bandwidth}" window') window_fun = get_window(bandwidth, n_times)[np.newaxis] return window_fun, np.ones(1), False @@ -297,9 +295,8 @@ def _compute_mt_params(n_times, sfreq, bandwidth, low_bias, adaptive, verbose=No half_nbw = 4.0 if half_nbw < 0.5: raise ValueError( - "bandwidth value %s yields a normalized half-bandwidth of " - "%s < 0.5, use a value of at least %s" - % (bandwidth, half_nbw, sfreq / n_times) + f"bandwidth value {bandwidth} yields a normalized half-bandwidth of " + f"{half_nbw} < 0.5, use a value of at least {sfreq / n_times}" ) # Compute DPSS windows @@ -315,7 +312,7 @@ def _compute_mt_params(n_times, sfreq, bandwidth, low_bias, adaptive, verbose=No if adaptive and len(eigvals) < 3: warn( "Not adaptively combining the spectral estimators due to a " - "low number of tapers (%s < 3)." % (len(eigvals),) + f"low number of tapers ({len(eigvals)} < 3)." ) adaptive = False @@ -465,7 +462,7 @@ def psd_array_multitaper( @verbose def tfr_array_multitaper( - epoch_data, + data, sfreq, freqs, n_cycles=7.0, @@ -477,6 +474,7 @@ def tfr_array_multitaper( n_jobs=None, *, verbose=None, + epoch_data=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. @@ -486,11 +484,11 @@ def tfr_array_multitaper( Parameters ---------- - epoch_data : array of shape (n_epochs, n_channels, n_times) + data : array of shape (n_epochs, n_channels, n_times) The epochs. sfreq : float Sampling frequency of the data in Hz. - %(freqs_tfr)s + %(freqs_tfr_array)s %(n_cycles_tfr)s zero_mean : bool If True, make sure the wavelets have a mean of zero. Defaults to True. @@ -508,12 +506,17 @@ def tfr_array_multitaper( * ``'avg_power_itc'`` : average of single trial power and inter-trial coherence across trials. %(n_jobs)s + The parallelization is implemented across channels. %(verbose)s + epoch_data : None + Deprecated parameter for providing epoched data as of 1.7, will be replaced with + the ``data`` parameter in 1.8. New code should use the ``data`` parameter. If + ``epoch_data`` is not ``None``, a warning will be raised. Returns ------- out : array - Time frequency transform of ``epoch_data``. + Time frequency transform of ``data``. - if ``output in ('complex',' 'phase')``, array of shape ``(n_epochs, n_chans, n_tapers, n_freqs, n_times)`` @@ -543,8 +546,15 @@ def tfr_array_multitaper( """ from .tfr import _compute_tfr + if epoch_data is not None: + warn( + "The parameter for providing data will be switched from `epoch_data` to " + "`data` in 1.8. Use the `data` parameter to avoid this warning.", + FutureWarning, + ) + return _compute_tfr( - epoch_data, + data, freqs, sfreq=sfreq, method="multitaper", diff --git a/mne/time_frequency/psd.py b/mne/time_frequency/psd.py index 33bcd16df8c..b2083c22229 100644 --- a/mne/time_frequency/psd.py +++ b/mne/time_frequency/psd.py @@ -4,6 +4,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
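The tfr_array_multitaper hunk above renames the first parameter from ``epoch_data`` to ``data``, keeping ``epoch_data`` as a deprecated keyword that raises a ``FutureWarning`` during the transition. A generic sketch of that rename pattern with a hypothetical function (simplified in that the old value is forwarded to the new name here):

    import warnings

    import numpy as np

    def compute_tfr_array(data=None, sfreq=None, *, epoch_data=None):
        # hypothetical example of a keyword rename with a deprecation period
        if epoch_data is not None:
            warnings.warn(
                "`epoch_data` is deprecated; use `data` instead.",
                FutureWarning,
                stacklevel=2,
            )
            if data is None:
                data = epoch_data
        if data is None:
            raise TypeError("missing required argument: data")
        return np.asarray(data).shape

    # old call sites keep working but emit a FutureWarning
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        compute_tfr_array(epoch_data=np.zeros((2, 3, 10)), sfreq=100.0)
    assert any(issubclass(w.category, FutureWarning) for w in caught)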
+import warnings
 from functools import partial

 import numpy as np
@@ -11,6 +12,7 @@

 from ..parallel import parallel_func
 from ..utils import _check_option, _ensure_int, logger, verbose
+from ..utils.numerics import _mask_to_onsets_offsets

 # adapted from SciPy
@@ -214,7 +216,7 @@ def psd_array_welch(
     )
     parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
-    func = partial(
+    _func = partial(
         spectrogram,
         detrend=detrend,
         noverlap=n_overlap,
@@ -224,12 +226,54 @@ def psd_array_welch(
         window=window,
         mode=mode,
     )
-    x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
+    if np.any(np.isnan(x)):
+        good_mask = ~np.isnan(x)
+        # NaNs originate from annot, so must match for all channels. Note that we CANNOT
+        # use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
+        assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
+        t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
+        x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
+        # weights reflect the number of samples used from each span. For spans longer
+        # than `n_per_seg`, trailing samples may be discarded. For spans shorter than
+        # `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically
+        # reduces `n_per_seg` to match the span length (with a warning).
+        step = n_per_seg - n_overlap
+        span_lengths = [span.shape[-1] for span in x_splits]
+        weights = [
+            w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths
+        ]
+        agg_func = partial(np.average, weights=weights)
+        if n_jobs > 1:
+            logger.info(
+                f"Data split into {len(x_splits)} (probably unequal) chunks due to "
+                '"bad_*" annotations. Parallelization may be sub-optimal.'
+            )
+        if (np.array(span_lengths) < n_per_seg).any():
+            logger.info(
+                "At least one good data span is shorter than n_per_seg, and will be "
+                "analyzed with a shorter window than the rest of the file."
+            )
+
+        def func(*args, **kwargs):
+            # swallow SciPy warnings caused by short good data spans
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    action="ignore",
+                    module="scipy",
+                    category=UserWarning,
+                    message=r"nperseg = \d+ is greater than input length",
+                )
+                return _func(*args, **kwargs)
+
+    else:
+        x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
+        agg_func = np.concatenate
+        func = _func
     f_spect = parallel(
         my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
         for d in x_splits
     )
-    psds = np.concatenate(f_spect, axis=0)
+    psds = agg_func(f_spect, axis=0)
     shape = dshape + (len(freqs),)
     if average is None:
         shape = shape + (-1,)
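The weights in the NaN-span branch above count how many samples Welch's method actually consumes from each good span: spans shorter than ``n_per_seg`` are used whole (SciPy shrinks the window), while longer spans lose any trailing remainder that doesn't fill a segment. A standalone sketch of that arithmetic (pure Python, toy numbers):

    # samples actually used from a good span of length w
    n_per_seg, n_overlap = 256, 128
    step = n_per_seg - n_overlap

    def samples_used(w):
        if w < n_per_seg:  # SciPy shrinks the window to the span length
            return w
        return w - ((w - n_overlap) % step)  # discard the trailing remainder

    for w in (200, 256, 300, 384, 500):
        print(w, samples_used(w))  # -> 200, 256, 256, 384, 384

These counts are exactly the ``weights`` handed to ``np.average`` over the per-span PSD estimates.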
diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py
index aa459124347..a9006ac443f 100644
--- a/mne/time_frequency/spectrum.py
+++ b/mne/time_frequency/spectrum.py
@@ -1,7 +1,4 @@
 """Container classes for spectral data."""
-
-# Authors: Dan McCloy
-#
 # License: BSD-3-Clause
 # Copyright the MNE-Python contributors.
@@ -25,6 +22,7 @@
 from ..utils import (
     GetEpochsMixin,
     _build_data_frame,
+    _check_method_kwargs,
     _check_pandas_index_arguments,
     _check_pandas_installed,
     _check_sphere,
@@ -45,13 +43,14 @@
     _is_numeric,
     check_fname,
 )
-from ..utils.misc import _pl
-from ..utils.spectrum import _split_psd_kwargs
+from ..utils.misc import _identity_function, _pl
+from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs
 from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo
 from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap
 from ..viz.utils import (
     _format_units_psd,
     _get_plot_ch_type,
+    _make_combine_callable,
     _plot_psd,
     _prepare_sensor_names,
     plt_show,
@@ -60,10 +59,6 @@
 from .psd import _check_nfft, psd_array_welch


-def _identity_function(x):
-    return x
-
-
 class SpectrumMixin:
     """Mixin providing spectral plotting methods to sensor-space containers."""
@@ -318,29 +313,18 @@ def __init__(
         )
         # method
         self._inst_type = type(inst)
-        method = _validate_method(method, self._get_instance_type_string())
+        method = _validate_method(method, _get_instance_type_string(self))
         # don't allow complex output
         psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper)
         if method_kw.get("output", "") == "complex":
-            warn(
-                f"Complex output support in {type(self).__name__} objects is "
-                "deprecated and will be removed in version 1.7. If you need complex "
-                f"output please use mne.time_frequency.{psd_funcs[method].__name__}() "
-                "instead.",
-                FutureWarning,
+            raise ValueError(
+                f"Complex output is not supported in {type(self).__name__} objects. "
+                f"Please use mne.time_frequency.{psd_funcs[method].__name__}() instead."
             )
         # triage method and kwargs. partial() doesn't check validity of kwargs,
         # so we do it manually to save compute time if any are invalid.
-        invalid_ix = np.isin(
-            list(method_kw), list(signature(psd_funcs[method]).parameters), invert=True
-        )
-        if invalid_ix.any():
-            invalid_kw = np.array(list(method_kw))[invalid_ix].tolist()
-            s = _pl(invalid_kw)
-            raise TypeError(
-                f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} '
-                f'for PSD method "{method}".'
-            )
+        psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper)
+        _check_method_kwargs(psd_funcs[method], method_kw, msg=f'PSD method "{method}"')
         self._psd_func = partial(psd_funcs[method], remove_dc=remove_dc, **method_kw)

         # apply proj if desired
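The removed block above is what ``_check_method_kwargs`` now centralizes: validating user kwargs against the PSD function's signature before building the ``partial``. A hypothetical standalone version of the same idea (this is not MNE's actual helper, just the technique):

    from inspect import signature
    from mne.time_frequency import psd_array_welch

    def check_method_kwargs(func, kwargs, msg):
        # hypothetical helper: reject kwargs that ``func`` does not accept
        valid = signature(func).parameters
        invalid = [key for key in kwargs if key not in valid]
        if invalid:
            s = "s" if len(invalid) != 1 else ""
            raise TypeError(
                f"Got unexpected keyword argument{s} {', '.join(invalid)} for {msg}."
            )

    check_method_kwargs(psd_array_welch, dict(n_fft=2048), 'PSD method "welch"')  # OK
    # check_method_kwargs(psd_array_welch, dict(foo=1), ...)  # would raise TypeError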
@@ -359,7 +343,7 @@ def __init__(
         self.info = pick_info(inst.info, sel=self._picks, copy=True)

         # assign some attributes
-        self.preload = True  # needed for __getitem__, doesn't mean anything
+        self.preload = True  # needed for __getitem__, never False
         self._method = method
         # self._dims may also get updated by child classes
         self._dims = (
@@ -368,12 +352,12 @@ def __init__(
         )
         if method_kw.get("average", "") in (None, False):
             self._dims += ("segment",)
-        if self._returns_complex_tapers(**method_kw):
-            self._dims = self._dims[:-1] + ("taper",) + self._dims[-1:]
         # record data type (for repr and html_repr)
         self._data_type = (
             "Fourier Coefficients" if "taper" in self._dims else "Power Spectrum"
         )
+        # set nave (child constructor overrides this for Evoked input)
+        self._nave = None

     def __eq__(self, other):
         """Test equivalence of two Spectrum instances."""
@@ -381,7 +365,7 @@ def __eq__(self, other):

     def __getstate__(self):
         """Prepare object for serialization."""
-        inst_type_str = self._get_instance_type_string()
+        inst_type_str = _get_instance_type_string(self)
         out = dict(
             method=self.method,
             data=self._data,
@@ -391,6 +375,7 @@ def __getstate__(self):
             inst_type_str=inst_type_str,
             data_type=self._data_type,
             info=self.info,
+            nave=self.nave,
         )
         return out

@@ -407,16 +392,15 @@ def __setstate__(self, state):
         self._sfreq = state["sfreq"]
         self.info = Info(**state["info"])
         self._data_type = state["data_type"]
+        self._nave = state.get("nave")  # objs saved before #11282 won't have `nave`
         self.preload = True
         # instance type
         inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray)
         self._inst_type = inst_types[state["inst_type_str"]]
-        if "weights" in state and state["weights"] is not None:
-            self._mt_weights = state["weights"]

     def __repr__(self):
         """Build string representation of the Spectrum object."""
-        inst_type_str = self._get_instance_type_string()
+        inst_type_str = _get_instance_type_string(self)
         # shape & dimension names
         dims = " × ".join(
             [f"{dim[0]} {dim[1]}s" for dim in zip(self.shape, self._dims)]
@@ -430,7 +414,7 @@ def __repr__(self):
     @repr_html
     def _repr_html_(self, caption=None):
         """Build HTML representation of the Spectrum object."""
-        inst_type_str = self._get_instance_type_string()
+        inst_type_str = _get_instance_type_string(self)
         units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()]
         t = _get_html_template("repr", "spectrum.html.jinja")
         t = t.render(spectrum=self, inst_type=inst_type_str, units=units)
@@ -456,23 +440,14 @@ def _check_values(self):
             s = _pl(bad_value.sum())
             warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning)

-    def _returns_complex_tapers(self, **method_kw):
-        return method_kw.get("output", "") == "complex" and self.method == "multitaper"
-
     def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose):
         # make the spectra
         result = self._psd_func(
             data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose
         )
-        # assign ._data (handling unaggregated multitaper output)
-        if self._returns_complex_tapers(**method_kw):
-            fourier_coefs, freqs, weights = result
-            self._data = fourier_coefs
-            self._mt_weights = weights
-        else:
-            psds, freqs = result
-            self._data = psds
-        # assign properties (._data already assigned above)
+        # assign ._data, ._freqs, ._shape
+        psds, freqs = result
+        self._data = psds
         self._freqs = freqs
         # this is *expected* shape, it gets asserted later in _check_values()
         # (and then deleted afterwards)
@@ -481,33 +456,11 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose):
         if method_kw.get("average", "") in (None, False):
             n_welch_segments = _compute_n_welch_segments(data.shape[-1], method_kw)
             self._shape += (n_welch_segments,)
-        # insert n_tapers
-        if self._returns_complex_tapers(**method_kw):
-            self._shape = self._shape[:-1] + (self._mt_weights.size,) + self._shape[-1:]
         # we don't need these anymore, and they make save/load harder
         del self._picks
         del self._psd_func
         del self._time_mask

-    def _get_instance_type_string(self):
-        """Get string representation of the originating instance type."""
-        from ..epochs import BaseEpochs
-        from ..evoked import Evoked, EvokedArray
-        from ..io import BaseRaw
-
-        parent_classes = self._inst_type.__bases__
-        if BaseRaw in parent_classes:
-            inst_type_str = "Raw"
-        elif BaseEpochs in parent_classes:
-            inst_type_str = "Epochs"
-        elif self._inst_type in (Evoked, EvokedArray):
-            inst_type_str = "Evoked"
-        elif self._inst_type is np.ndarray:
-            inst_type_str = "Array"
-        else:
-            raise RuntimeError(f"Unknown instance type {self._inst_type} in Spectrum")
-        return inst_type_str
-
     @property
     def _detrend_picks(self):
         """Provide compatibility with __iter__."""
@@ -517,6 +470,10 @@ def _detrend_picks(self):
     def ch_names(self):
         return self.info["ch_names"]

+    @property
+    def data(self):
+        return self._data
+
     @property
     def freqs(self):
         return self._freqs
@@ -525,6 +482,10 @@ def freqs(self):
     def method(self):
         return self._method

+    @property
+    def nave(self):
+        return self._nave
+
     @property
     def sfreq(self):
         return self._sfreq
@@ -592,7 +553,7 @@ def plot(
         picks=None,
         average=False,
         dB=True,
-        amplitude="auto",
+        amplitude=None,
         xscale="linear",
         ci="sd",
         ci_alpha=0.3,
@@ -612,57 +573,56 @@ def plot(

             .. versionchanged:: 1.5
                 In version 1.5, the default behavior changed so that all
-                :term:`data channels` (not just "good" data channels) are shown
-                by default.
+                :term:`data channels` (not just "good" data channels) are shown by
+                default.
         average : bool
-            Whether to average across channels before plotting. If ``True``,
-            interactive plotting of scalp topography is disabled, and
-            parameters ``ci`` and ``ci_alpha`` control the style of the
-            confidence band around the mean. Default is ``False``.
+            Whether to average across channels before plotting. If ``True``, interactive
+            plotting of scalp topography is disabled, and parameters ``ci`` and
+            ``ci_alpha`` control the style of the confidence band around the mean.
+            Default is ``False``.
         %(dB_spectrum_plot)s
         amplitude : bool | 'auto'
             Whether to plot an amplitude spectrum (``True``) or power spectrum
-            (``False``). If ``'auto'``, will plot a power spectrum when
-            ``dB=True`` and an amplitude spectrum otherwise. Default is
-            ``'auto'``.
+            (``False``). If ``'auto'``, will plot a power spectrum when ``dB=True`` and
+            an amplitude spectrum otherwise. Default is ``'auto'``.
+
+            .. versionchanged:: 1.8
+                In version 1.8, the value ``amplitude="auto"`` will be removed. The
+                default value will change to ``amplitude=False``.
         %(xscale_plot_psd)s
         ci : float | 'sd' | 'range' | None
-            Type of confidence band drawn around the mean when
-            ``average=True``. If ``'sd'`` the band spans ±1 standard deviation
-            across channels. If ``'range'`` the band spans the range across
-            channels at each frequency. If a :class:`float`, it indicates the
-            (bootstrapped) confidence interval to display, and must satisfy
-            ``0 < ci <= 100``. If ``None``, no band is drawn. Default is
-            ``sd``.
+            Type of confidence band drawn around the mean when ``average=True``. If
+            ``'sd'`` the band spans ±1 standard deviation across channels. If
+            ``'range'`` the band spans the range across channels at each frequency. If a
+            :class:`float`, it indicates the (bootstrapped) confidence interval to
+            display, and must satisfy ``0 < ci <= 100``. If ``None``, no band is drawn.
+            Default is ``sd``.
         ci_alpha : float
-            Opacity of the confidence band. Must satisfy
-            ``0 <= ci_alpha <= 1``. Default is 0.3.
+            Opacity of the confidence band. Must satisfy ``0 <= ci_alpha <= 1``. Default
+            is 0.3.
         %(color_plot_psd)s
         alpha : float | None
             Opacity of the spectrum line(s). If :class:`float`, must satisfy
             ``0 <= alpha <= 1``. If ``None``, opacity will be ``1`` when
-            ``average=True`` and ``0.1`` when ``average=False``. Default is
-            ``None``.
+            ``average=True`` and ``0.1`` when ``average=False``. Default is ``None``.
         %(spatial_colors_psd)s
         %(sphere_topomap_auto)s
         %(exclude_spectrum_plot)s

             .. versionchanged:: 1.5
-                In version 1.5, the default behavior changed from
-                ``exclude='bads'`` to ``exclude=()``.
+                In version 1.5, the default behavior changed from ``exclude='bads'`` to
+                ``exclude=()``.
         %(axes_spectrum_plot_topomap)s
         %(show)s

         Returns
         -------
         fig : instance of matplotlib.figure.Figure
-            Figure with spectra plotted in separate subplots for each channel
-            type.
+            Figure with spectra plotted in separate subplots for each channel type.
         """
         # Must nest this _mpl_figure import because of the BACKEND global
         # stuff
         from ..viz._mpl_figure import _line_figure, _split_picks_by_type
-        from .multitaper import _psd_from_mt

         # arg checking
         ci = _check_ci(ci)
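For reference, the ``amplitude``/``dB`` combinations documented above reduce to two transforms of the stored power values; a sketch with toy numbers (this is the scaling logic, not MNE's plotting code):

    import numpy as np

    psd = np.array([1e-12, 4e-12])  # toy power values

    power_db = 10 * np.log10(psd)  # what dB=True, amplitude=False draws
    amplitude = np.sqrt(psd)       # what amplitude=True draws
    # the deprecated amplitude='auto' picked power when dB=True, amplitude otherwise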
@@ -672,10 +632,19 @@ def plot(
         scalings = _handle_default("scalings", None)
         titles = _handle_default("titles", None)
         units = _handle_default("units", None)
-        if amplitude == "auto":
+
+        depr_message = (
+            "The value of `amplitude='auto'` will be removed in MNE 1.8.0, and the new "
+            "default will be `amplitude=False`."
+        )
+        if amplitude is None or amplitude == "auto":
+            warn(depr_message, FutureWarning)
             estimate = "power" if dB else "amplitude"
-        else:  # amplitude is boolean
+        else:
             estimate = "amplitude" if amplitude else "power"
+
+        logger.info(f"Plotting {estimate} spectral density ({dB=}).")
+
         # split picks by channel type
         picks = _picks_to_idx(
             self.info, picks, "data", exclude=exclude, with_ref_meg=False
@@ -683,12 +652,8 @@ def plot(
         (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type(
             self, picks, units, scalings, titles
         )
-        # handle unaggregated multitaper
-        if hasattr(self, "_mt_weights"):
-            logger.info("Aggregating multitaper estimates before plotting...")
-            _f = partial(_psd_from_mt, weights=self._mt_weights)
         # handle unaggregated Welch
-        elif "segment" in self._dims:
+        if "segment" in self._dims:
             logger.info("Aggregating Welch estimates (median) before plotting...")
             seg_axis = self._dims.index("segment")
             _f = partial(np.nanmedian, axis=seg_axis)
@@ -996,7 +961,7 @@ def to_data_frame(
         # check pandas once here, instead of in each private utils function
         pd = _check_pandas_installed()  # noqa
         # triage for Epoch-derived or unaggregated spectra
-        from_epo = self._dims[0] == "epoch"
+        from_epo = _get_instance_type_string(self) == "Epochs"
         unagg_welch = "segment" in self._dims
         unagg_mt = "taper" in self._dims
         # arg checking
@@ -1059,9 +1024,8 @@ def units(self, latex=False):
             for that channel type.
         """
         units = _handle_default("si_units", None)
-        power = not hasattr(self, "_mt_weights")
         return {
-            ch_type: _format_units_psd(units[ch_type], power=power, latex=latex)
+            ch_type: _format_units_psd(units[ch_type], power=True, latex=latex)
             for ch_type in sorted(self.get_channel_types(unique=True))
         }
@@ -1103,8 +1067,10 @@ class Spectrum(BaseSpectrum):
         have been computed.
     %(info_not_none)s
     method : str
-        The method used to compute the spectrum (``'welch'`` or
-        ``'multitaper'``).
+        The method used to compute the spectrum (``'welch'`` or ``'multitaper'``).
+    nave : int | None
+        The number of trials averaged together when generating the spectrum. ``None``
+        indicates no averaging is known to have occurred.

     See Also
     --------
@@ -1166,13 +1132,21 @@ def __init__(
             data = self.inst.get_data(
                 self._picks, start, stop + 1, reject_by_annotation=rba
             )
+            if np.any(np.isnan(data)) and method == "multitaper":
+                raise NotImplementedError(
+                    'Cannot use method="multitaper" when reject_by_annotation=True. '
+                    'Please use method="welch" instead.'
+                )
+
         else:  # Evoked
             data = self.inst.data[self._picks][:, self._time_mask]
+        # set nave
+        self._nave = getattr(inst, "nave", None)
         # compute the spectra
         self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose)
         # check for correct shape and bad values
         self._check_values()
-        del self._shape
+        del self._shape  # calculated from self._data henceforth
         # save memory
         del self.inst
@@ -1205,7 +1179,8 @@ def __getitem__(self, item):
         requested data values and the corresponding times), accessing
         :class:`~mne.time_frequency.Spectrum` values via subscript does
         **not** return the corresponding frequency bin values. If you need
-        them, use ``spectrum.freqs[freq_indices]``.
+        them, use ``spectrum.freqs[freq_indices]`` or
+        ``spectrum.get_data(..., return_freqs=True)``.
         """
         from ..io import BaseRaw
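The new ``nave`` attribute above (and the ``data`` property added earlier in this file's diff) can be exercised with synthetic data; a sketch (toy channel names and values, made up for illustration):

    import numpy as np
    import mne

    info = mne.create_info(["EEG 001", "EEG 002"], sfreq=250.0, ch_types="eeg")
    data = np.random.default_rng(0).standard_normal((2, 500))
    evoked = mne.EvokedArray(data, info, nave=40)

    spectrum = evoked.compute_psd(method="welch", n_fft=256)
    print(spectrum.nave)        # 40, carried over from evoked.nave
    print(spectrum.data.shape)  # (n_channels, n_freqs), via the ``data`` property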
@@ -1240,7 +1215,7 @@ class SpectrumArray(Spectrum):
     data : array, shape (n_channels, n_freqs)
         The power spectral density for each channel.
     %(info_not_none)s
-    %(freqs_tfr)s
+    %(freqs_tfr_array)s
     %(verbose)s

     See Also
@@ -1438,21 +1413,17 @@ def average(self, method="mean"):
         spectrum : instance of Spectrum
             The aggregated spectrum object.
         """
-        if isinstance(method, str):
-            method = getattr(np, method)  # mean, median, std, etc
-            method = partial(method, axis=0)
+        _validate_type(method, ("str", "callable"))
+        method = _make_combine_callable(
+            method, axis=0, valid=("mean", "median"), keepdims=False
+        )
         if not callable(method):
             raise ValueError(
                 '"method" must be a valid string or callable, '
                 f"got a {type(method).__name__} ({method})."
             )
         # averaging unaggregated spectral estimates are not supported
-        if hasattr(self, "_mt_weights"):
-            raise NotImplementedError(
-                "Averaging complex spectra is not supported. Consider "
-                "averaging the signals before computing the complex spectrum."
-            )
-        elif "segment" in self._dims:
+        if "segment" in self._dims:
             raise NotImplementedError(
                 "Averaging individual Welch segments across epochs is not "
                 "supported. Consider averaging the signals before computing "
@@ -1460,6 +1431,7 @@ def average(self, method="mean"):
             )
         # serialize the object and update data, dims, and data type
         state = super().__getstate__()
+        state["nave"] = state["data"].shape[0]
         state["data"] = method(state["data"])
         state["dims"] = state["dims"][1:]
         state["data_type"] = f'Averaged {state["data_type"]}'
@@ -1489,7 +1461,7 @@ class EpochsSpectrumArray(EpochsSpectrum):
     data : array, shape (n_epochs, n_channels, n_freqs)
         The power spectral density for each channel in each epoch.
     %(info_not_none)s
-    %(freqs_tfr)s
+    %(freqs_tfr_array)s
     %(events_epochs)s
     %(event_id)s
     %(verbose)s
diff --git a/mne/time_frequency/tests/test_ar.py b/mne/time_frequency/tests/test_ar.py
index bef37e7dd18..f0ea9db2a1e 100644
--- a/mne/time_frequency/tests/test_ar.py
+++ b/mne/time_frequency/tests/test_ar.py
@@ -10,9 +10,7 @@
 from mne import io
 from mne.time_frequency.ar import _yule_walker, fit_iir_model_raw

-raw_fname = (
-    Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif"
-)
+raw_fname = Path(__file__).parents[2] / "io" / "tests" / "data" / "test_raw.fif"


 def test_yule_walker():
diff --git a/mne/time_frequency/tests/test_psd.py b/mne/time_frequency/tests/test_psd.py
index e02e561384f..363a4207ce9 100644
--- a/mne/time_frequency/tests/test_psd.py
+++ b/mne/time_frequency/tests/test_psd.py
@@ -36,6 +36,26 @@ def test_psd_nan():
     assert "hamming window" in log


+def test_bad_annot_handling():
+    """Make sure results equivalent with/without Annotations."""
+    n_per_seg = 256
+    n_chan = 3
+    n_times = 5 * n_per_seg
+    x = np.random.default_rng(seed=42).standard_normal(size=(n_chan, n_times))
+    want = psd_array_welch(x, sfreq=100)
+    # now simulate an annotation that breaks up the array into unequal spans. Using
+    # `n_per_seg` as the cut point is unrealistic/idealized, but it allows us to test
+    # whether we get results ~identical to `want` (which we should in this edge case)
+    x2 = np.concatenate(
+        (x[..., :n_per_seg], np.full((n_chan, 1), np.nan), x[..., n_per_seg:]), axis=-1
+    )
+    got = psd_array_welch(x2, sfreq=100)
+    # freqs should be identical
+    np.testing.assert_array_equal(got[1], want[1])
+    # powers should be very very close
+    np.testing.assert_allclose(got[0], want[0], rtol=1e-15, atol=0)
+
+
 def _make_psd_data():
     """Make noise data with sinusoids in 2 out of 7 channels."""
     rng = np.random.default_rng(0)
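With ``_make_combine_callable`` handling the triage in the ``average()`` change above, ``EpochsSpectrum.average()`` accepts ``'mean'``, ``'median'``, or any callable that collapses the epoch axis. A usage sketch on synthetic epochs (toy names and shapes):

    import numpy as np
    import mne

    rng = np.random.default_rng(0)
    info = mne.create_info(["EEG 001", "EEG 002"], sfreq=250.0, ch_types="eeg")
    epochs = mne.EpochsArray(rng.standard_normal((10, 2, 500)), info)

    spectrum = epochs.compute_psd(method="welch", n_fft=256)
    avg1 = spectrum.average("median")                  # string shortcut
    avg2 = spectrum.average(lambda x: x.mean(axis=0))  # callable over the epoch axis
    print(avg1.shape, avg2.shape)  # both (n_channels, n_freqs)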
""" - spect_no_annot = raw.compute_psd() + kw = dict(n_per_seg=512) # smaller than shortest good span, to avoid warning + spect_no_annot = raw.compute_psd(**kw) raw.set_annotations(Annotations([1, 5], [3, 3], ["test", "test"])) - spect_benign_annot = raw.compute_psd() + spect_benign_annot = raw.compute_psd(**kw) raw.annotations.description = np.array(["bad_test", "bad_test"]) - spect_reject_annot = raw.compute_psd() - spect_ignored_annot = raw.compute_psd(reject_by_annotation=False) + spect_reject_annot = raw.compute_psd(**kw) + spect_ignored_annot = raw.compute_psd(**kw, reject_by_annotation=False) # the only one that should be different is `spect_reject_annot` assert spect_no_annot == spect_benign_annot assert spect_no_annot == spect_ignored_annot @@ -205,78 +213,6 @@ def test_epochs_spectrum_average(epochs_spectrum, method): assert avg_spect._dims == ("channel", "freq") # no 'epoch' -def _agg_helper(df, weights, group_cols): - """Aggregate complex multitaper spectrum after conversion to DataFrame.""" - from pandas import Series - - unagged_columns = df[group_cols].iloc[0].values.tolist() - x_mt = df.drop(columns=group_cols).values[np.newaxis].T - psd = _psd_from_mt(x_mt, weights) - psd = np.atleast_1d(np.squeeze(psd)).tolist() - _df = dict(zip(df.columns, unagged_columns + psd)) - return Series(_df) - - -@pytest.mark.filterwarnings("ignore:Complex output support.*:FutureWarning") -@pytest.mark.parametrize("long_format", (False, True)) -@pytest.mark.parametrize( - "method, output", - [ - ("welch", "complex"), - ("welch", "power"), - ("multitaper", "complex"), - ], -) -def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output): - """Test converting complex multitaper spectra to data frame.""" - pytest.importorskip("pandas") - from pandas.testing import assert_frame_equal - - from mne.utils.dataframe import _inplace - - # aggregated spectrum → dataframe - orig_df = raw.compute_psd(method=method).to_data_frame(long_format=long_format) - # unaggregated welch or complex multitaper → - # aggregate w/ pandas (to make sure we did reshaping right) - kwargs = dict() - if method == "welch": - kwargs.update(average=False, verbose="error") - spectrum = raw.compute_psd(method=method, output=output, **kwargs) - df = spectrum.to_data_frame(long_format=long_format) - grouping_cols = ["freq"] - drop_cols = ["segment"] if method == "welch" else ["taper"] - if long_format: - grouping_cols.append("channel") - drop_cols.append("ch_type") - orig_df.drop(columns="ch_type", inplace=True) - # only do a couple freq bins, otherwise test takes forever for multitaper - subset = partial(np.isin, test_elements=spectrum.freqs[:2]) - df = df.loc[subset(df["freq"])] - orig_df = orig_df.loc[subset(orig_df["freq"])] - # sort orig_df, because at present we can't actually prevent pandas from - # sorting at the agg step *sigh* - _inplace(orig_df, "sort_values", by=grouping_cols, ignore_index=True) - # aggregate - df = df.drop(columns=drop_cols) - gb = df.groupby(grouping_cols, as_index=False, observed=False) - if method == "welch": - if output == "complex": - - def _fun(x): - return np.nanmean(np.abs(x)) - - agg_df = gb.agg(_fun) - else: - agg_df = gb.mean() # excludes missing values itself - else: - gb = gb[df.columns] # https://github.com/pandas-dev/pandas/pull/52477 - agg_df = gb.apply(_agg_helper, spectrum._mt_weights, grouping_cols) - # even with check_categorical=False, we know that the *data* matches; - # what may differ is the order of the "levels" in the *metadata* for the - # channel name column - 
assert_frame_equal(agg_df, orig_df, check_categorical=False) - - @pytest.mark.parametrize("inst", ("raw_spectrum", "epochs_spectrum", "evoked")) def test_spectrum_to_data_frame(inst, request, evoked): """Test the to_data_frame method for Spectrum.""" @@ -286,7 +222,7 @@ def test_spectrum_to_data_frame(inst, request, evoked): # setup is_already_psd = inst in ("raw_spectrum", "epochs_spectrum") is_epochs = inst == "epochs_spectrum" - inst = _get_inst(inst, request, evoked) + inst = _get_inst(inst, request, evoked=evoked) extra_dim = () if is_epochs else (1,) extra_cols = ["freq", "condition", "epoch"] if is_epochs else ["freq"] # compute PSD @@ -339,74 +275,18 @@ def test_spectrum_proj(inst, request): assert has_proj == no_proj -@pytest.mark.filterwarnings("ignore:Complex output support.*:FutureWarning") -@pytest.mark.parametrize( - "method, average", - [ - ("welch", False), - ("welch", "mean"), - ("multitaper", False), - ], -) -def test_spectrum_complex(method, average): - """Test output='complex' support.""" - sfreq = 100 - n = 10 * sfreq - freq = 3.0 - phase = np.pi / 4 # should be recoverable - data = np.cos(2 * np.pi * freq * np.arange(n) / sfreq + phase)[np.newaxis] - raw = RawArray(data, create_info(1, sfreq, "eeg")) - epochs = make_fixed_length_epochs(raw, duration=2.0, preload=True) - assert len(epochs) == 5 - assert len(epochs.times) == 2 * sfreq - kwargs = dict(output="complex", method=method) - if method == "welch": - kwargs["n_fft"] = sfreq - want_dims = ("epoch", "channel", "freq") - want_shape = (5, 1, sfreq // 2 + 1) - if not average: - want_dims = want_dims + ("segment",) - want_shape = want_shape + (2,) - kwargs["average"] = average - else: - assert method == "multitaper" - assert not average - want_dims = ("epoch", "channel", "taper", "freq") - want_shape = (5, 1, 7, sfreq + 1) - spectrum = epochs.compute_psd(**kwargs) - idx = np.argmin(np.abs(spectrum.freqs - freq)) - assert spectrum.freqs[idx] == freq - assert spectrum._dims == want_dims - assert spectrum.shape == want_shape - data = spectrum.get_data() - assert data.dtype == np.complex128 - coef = spectrum.get_data(fmin=freq, fmax=freq).mean(0) - if method == "multitaper": - coef = coef[..., 0, :] # first taper - elif not average: - coef = coef.mean(-1) # over segments - coef = coef.item() - assert_allclose(np.angle(coef), phase, rtol=1e-4) - # Now test that it warns appropriately - epochs._data[0, 0, :] = 0 # actually zero for one epoch and ch - with pytest.warns(UserWarning, match="Zero value.*channel 0"): - epochs.compute_psd(**kwargs) - # But not if we mark that channel as bad - epochs.info["bads"] = epochs.ch_names[:1] - epochs.compute_psd(**kwargs) - - def test_spectrum_kwarg_triaging(raw): """Test kwarg triaging in legacy plot_psd() method.""" import matplotlib.pyplot as plt regex = r"legacy plot_psd\(\) method.*unexpected keyword.*'axes'.*Try rewriting" - fig, axes = plt.subplots(1, 2) + _, axes = plt.subplots(1, 2) # `axes` is the new param name: technically only valid for Spectrum.plot() - with pytest.warns(RuntimeWarning, match=regex): + with _record_warnings(), pytest.warns(RuntimeWarning, match=regex): raw.plot_psd(axes=axes) # `ax` is the correct legacy param name - raw.plot_psd(ax=axes) + with pytest.warns(FutureWarning, match="amplitude='auto'"): + raw.plot_psd(ax=axes) def _check_spectrum_equivalent(spect1, spect2, tmp_path): diff --git a/mne/time_frequency/tests/test_stockwell.py b/mne/time_frequency/tests/test_stockwell.py index 96b2c064801..54d71b907ed 100644 --- 
a/mne/time_frequency/tests/test_stockwell.py +++ b/mne/time_frequency/tests/test_stockwell.py @@ -29,7 +29,7 @@ ) from mne.utils import _record_warnings -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" raw_ctf_fname = base_dir / "test_ctf_raw.fif" @@ -87,7 +87,7 @@ def test_stockwell_core(): width = 0.5 freqs = fftpack.fftfreq(len(pulse), 1.0 / sfreq) fmin, fmax = 1.0, 100.0 - start_f, stop_f = [np.abs(freqs - f).argmin() for f in (fmin, fmax)] + start_f, stop_f = (np.abs(freqs - f).argmin() for f in (fmin, fmax)) W = _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width) st_pulse = _st(pulse, start_f, W) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 33e82d5a126..cedc13a479b 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -8,25 +8,36 @@ import matplotlib.pyplot as plt import numpy as np import pytest -from numpy.testing import assert_allclose, assert_array_equal, assert_equal -from scipy.signal import morlet2 +from matplotlib.collections import PathCollection +from numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, + assert_equal, +) import mne from mne import ( Epochs, EpochsArray, - Info, - Transform, create_info, pick_types, read_events, ) +from mne.epochs import equalize_epoch_counts from mne.io import read_raw_fif -from mne.tests.test_epochs import assert_metadata_equal -from mne.time_frequency import tfr_array_morlet, tfr_array_multitaper -from mne.time_frequency.tfr import ( +from mne.time_frequency import ( AverageTFR, + AverageTFRArray, + EpochsSpectrum, EpochsTFR, + EpochsTFRArray, + RawTFR, + RawTFRArray, + tfr_array_morlet, + tfr_array_multitaper, +) +from mne.time_frequency.tfr import ( _compute_tfr, _make_dpss, combine_tfr, @@ -39,13 +50,42 @@ write_tfrs, ) from mne.utils import catch_logging, grand_average -from mne.viz.utils import _fake_click, _fake_keypress, _fake_scroll +from mne.utils._testing import _get_suptitle +from mne.viz.utils import ( + _channel_type_prettyprint, + _fake_click, + _fake_keypress, + _fake_scroll, +) + +from .test_spectrum import _get_inst -data_path = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" event_fname = data_path / "test-eve.fif" raw_ctf_fname = data_path / "test_ctf_raw.fif" +freqs_linspace = np.linspace(20, 40, num=5) +freqs_unsorted_list = [26, 33, 41, 20] +mag_names = [f"MEG 01{n}1" for n in (1, 2, 3)] + +parametrize_morlet_multitaper = pytest.mark.parametrize( + "method", ("morlet", "multitaper") +) +parametrize_power_phase_complex = pytest.mark.parametrize( + "output", ("power", "phase", "complex") +) +parametrize_inst_and_ch_type = pytest.mark.parametrize( + "inst,ch_type", + ( + pytest.param("raw_tfr", "mag"), + pytest.param("raw_tfr", "grad"), + pytest.param("epochs_tfr", "mag"), # no grad pairs in epochs fixture + pytest.param("average_tfr", "mag"), + pytest.param("average_tfr", "grad"), + ), +) + def test_tfr_ctf(): """Test that TFRs can be calculated on CTF data.""" @@ -57,6 +97,15 @@ def test_tfr_ctf(): method(epochs, [10], 1) # smoke test +# Copied from SciPy before it was removed +def _morlet2(M, s, w=5): + x = np.arange(0, M) - (M - 1.0) / 2 + x = x / s + wavelet = np.exp(1j * w * x) * np.exp(-0.5 * x**2) * np.pi ** (-0.25) + output = 
np.sqrt(1 / s) * wavelet + return output + + @pytest.mark.parametrize("sfreq", [1000.0, 100 + np.pi]) @pytest.mark.parametrize("freq", [10.0, np.pi]) @pytest.mark.parametrize("n_cycles", [7, 2]) @@ -77,7 +126,7 @@ def test_morlet(sfreq, freq, n_cycles): M = len(W) w = n_cycles s = w * sfreq / (2 * freq * np.pi) # from SciPy docs - Ws = morlet2(M, s, w) * np.sqrt(2) + Ws = _morlet2(M, s, w) * np.sqrt(2) assert_allclose(W, Ws) # Check FWHM @@ -88,7 +137,7 @@ def test_morlet(sfreq, freq, n_cycles): assert_allclose(fwhm_formula, fwhm_empirical, atol=3 / sfreq) -def test_time_frequency(): +def test_tfr_morlet(): """Test time-frequency transform (PSD and ITC).""" # Set parameters event_id = 1 @@ -125,7 +174,8 @@ def test_time_frequency(): # Now compute evoked evoked = epochs.average() - pytest.raises(ValueError, tfr_morlet, evoked, freqs, 1.0, return_itc=True) + with pytest.raises(ValueError, match="Inter-trial coherence is not supported with"): + tfr_morlet(evoked, freqs, n_cycles=1.0, return_itc=True) power, itc = tfr_morlet( epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True, return_itc=True ) @@ -519,208 +569,205 @@ def test_tfr_multitaper(): tfr_multitaper(epochs, freqs=np.arange(-4, -1), n_cycles=7) -def test_crop(): - """Test TFR cropping.""" - data = np.zeros((3, 4, 5)) - times = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) - freqs = np.array([0.10, 0.20, 0.30, 0.40]) - info = mne.create_info( - ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] - ) - tfr = AverageTFR( - info, - data=data, - times=times, - freqs=freqs, - nave=20, - comment="test", - method="crazy-tfr", - ) +@pytest.mark.parametrize( + "method,freqs", + ( + pytest.param("morlet", freqs_linspace, id="morlet"), + pytest.param("multitaper", freqs_linspace, id="multitaper"), + pytest.param("stockwell", freqs_linspace[[0, -1]], id="stockwell"), + ), +) +@pytest.mark.parametrize("decim", (4, slice(0, 200), slice(1, 200, 3))) +def test_tfr_decim_and_shift_time(epochs, method, freqs, decim): + """Test TFR decimation; slices must be long-ish to be longer than the wavelets.""" + tfr = epochs.compute_tfr(method, freqs=freqs, decim=decim) + if not isinstance(decim, slice): + decim = slice(None, None, decim) + # check n_times + want = len(range(*decim.indices(len(epochs.times)))) + assert tfr.shape[-1] == want + # Check that decim changes sfreq + assert tfr.sfreq == epochs.info["sfreq"] / (decim.step or 1) + # check after-the-fact decimation. 
The mixin .decimate method doesn't allow slices + if isinstance(decim, int): + tfr2 = epochs.compute_tfr(method, freqs=freqs, decim=1) + tfr2.decimate(decim) + assert tfr == tfr2 + # test .shift_time() too + shift = -0.137 + data, times, freqs = tfr.get_data(return_times=True, return_freqs=True) + tfr.shift_time(shift, relative=True) + assert_allclose(times + shift, tfr.times, rtol=0, atol=0.5 / tfr.sfreq) + # shift time should only affect times: + assert_array_equal(data, tfr.get_data()) + assert_array_equal(freqs, tfr.freqs) + + +@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) +def test_tfr_io(inst, average_tfr, request, tmp_path): + """Test TFR I/O.""" + pytest.importorskip("h5io") + pd = pytest.importorskip("pandas") - tfr.crop(tmin=0.2) - assert_array_equal(tfr.times, [0.2, 0.3, 0.4, 0.5]) - assert tfr.data.ndim == 3 - assert tfr.data.shape[-1] == 4 - - tfr.crop(fmax=0.3) - assert_array_equal(tfr.freqs, [0.1, 0.2, 0.3]) - assert tfr.data.ndim == 3 - assert tfr.data.shape[-2] == 3 - - tfr.crop(tmin=0.3, tmax=0.4, fmin=0.1, fmax=0.2) - assert_array_equal(tfr.times, [0.3, 0.4]) - assert tfr.data.ndim == 3 - assert tfr.data.shape[-1] == 2 - assert_array_equal(tfr.freqs, [0.1, 0.2]) - assert tfr.data.shape[-2] == 2 - - -def test_decim_shift_time(): - """Test TFR decimation and shift_time.""" - data = np.zeros((3, 3, 3, 1000)) - times = np.linspace(0, 1, 1000) - freqs = np.array([0.10, 0.20, 0.30]) - info = mne.create_info( - ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fname = tmp_path / "temp_tfr.hdf5" + # test .save() method + tfr.save(fname, overwrite=True) + assert read_tfrs(fname) == tfr + # test save single TFR with write_tfrs() + write_tfrs(fname, tfr, overwrite=True) + assert read_tfrs(fname) == tfr + # test save multiple TFRs with write_tfrs() + tfr2 = tfr.copy() + tfr2._data = np.zeros_like(tfr._data) + write_tfrs(fname, [tfr, tfr2], overwrite=True) + tfr_list = read_tfrs(fname) + assert tfr_list[0] == tfr + assert tfr_list[1] == tfr2 + # test condition-related errors + if isinstance(tfr, AverageTFR): + # auto-generated keys: first TFR has comment, so `0` not assigned + tfr2.comment = None + write_tfrs(fname, [tfr, tfr2], overwrite=True) + with pytest.raises(ValueError, match='Cannot find condition "0" in this'): + read_tfrs(fname, condition=0) + # second TFR had no comment, so should get auto-comment `1` assigned + read_tfrs(fname, condition=1) + return + else: + with pytest.raises(NotImplementedError, match="condition is only supported"): + read_tfrs(fname, condition="foo") + # the rest we only do for EpochsTFR (no need to parametrize) + if isinstance(tfr, RawTFR): + return + # make sure everything still works if there's metadata + tfr.metadata = pd.DataFrame(dict(foo=range(tfr.shape[0])), index=tfr.selection) + # test old-style meas date + sec_microsec_tuple = (1, 2) + with tfr.info._unlock(): + tfr.info["meas_date"] = sec_microsec_tuple + tfr.save(fname, overwrite=True) + tfr_loaded = read_tfrs(fname) + want = datetime.datetime( + year=1970, + month=1, + day=1, + hour=0, + minute=0, + second=sec_microsec_tuple[0], + microsecond=sec_microsec_tuple[1], + tzinfo=datetime.timezone.utc, ) - with info._unlock(): - info["lowpass"] = 100 - tfr = EpochsTFR(info, data=data, times=times, freqs=freqs) - tfr_ave = tfr.average() - assert_allclose(tfr.times, tfr_ave.times) - assert not hasattr(tfr_ave, "first") - tfr_ave.decimate(3) - assert not hasattr(tfr_ave, "first") - 
tfr.decimate(3) - assert tfr.times.size == 1000 // 3 + 1 - assert tfr.data.shape == ((3, 3, 3, 1000 // 3 + 1)) - tfr_ave_2 = tfr.average() - assert not hasattr(tfr_ave_2, "first") - assert_allclose(tfr.times, tfr_ave.times) - assert_allclose(tfr.times, tfr_ave_2.times) - assert_allclose(tfr_ave_2.data, tfr_ave.data) - tfr.shift_time(-0.1, relative=True) - tfr_ave.shift_time(-0.1, relative=True) - tfr_ave_3 = tfr.average() - assert_allclose(tfr_ave_3.times, tfr_ave.times) - assert_allclose(tfr_ave_3.data, tfr_ave.data) - assert_allclose(tfr_ave_2.data, tfr_ave_3.data) # data unchanged - - -def test_io(tmp_path): - """Test TFR IO capacities.""" - pd = pytest.importorskip("pandas") - pytest.importorskip("h5io") + assert tfr_loaded.info["meas_date"] == want + with tfr.info._unlock(): + tfr.info["meas_date"] = want + assert tfr_loaded == tfr + # test overwrite + with pytest.raises(OSError, match="Destination file exists."): + tfr.save(fname, overwrite=False) - fname = tmp_path / "test-tfr.h5" - data = np.zeros((3, 2, 3)) - times = np.array([0.1, 0.2, 0.3]) - freqs = np.array([0.10, 0.20]) - info = mne.create_info( - ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] - ) - with info._unlock(check_after=True): - info["meas_date"] = datetime.datetime( - year=2020, month=2, day=5, tzinfo=datetime.timezone.utc - ) - tfr = AverageTFR( - info, - data=data, - times=times, - freqs=freqs, - nave=20, - comment="test", - method="crazy-tfr", +def test_roundtrip_from_legacy_func(epochs, tmp_path): + """Test save/load with TFRs generated by legacy method (gh-12512).""" + pytest.importorskip("h5io") + + fname = tmp_path / "temp_tfr.hdf5" + tfr = tfr_morlet( + epochs, freqs=freqs_linspace, n_cycles=7, average=True, return_itc=False ) - tfr.save(fname) - tfr2 = read_tfrs(fname, condition="test") - assert isinstance(tfr2.info, Info) - assert isinstance(tfr2.info["dev_head_t"], Transform) - - assert_array_equal(tfr.data, tfr2.data) - assert_array_equal(tfr.times, tfr2.times) - assert_array_equal(tfr.freqs, tfr2.freqs) - assert_equal(tfr.comment, tfr2.comment) - assert_equal(tfr.nave, tfr2.nave) - - pytest.raises(OSError, tfr.save, fname) - - tfr.comment = None - # test old meas_date - with info._unlock(): - info["meas_date"] = (1, 2) tfr.save(fname, overwrite=True) - assert_equal(read_tfrs(fname, condition=0).comment, tfr.comment) - tfr.comment = "test-A" - tfr2.comment = "test-B" - - fname = tmp_path / "test2-tfr.h5" - write_tfrs(fname, [tfr, tfr2]) - tfr3 = read_tfrs(fname, condition="test-A") - assert_equal(tfr.comment, tfr3.comment) - - assert isinstance(tfr.info, mne.Info) - - tfrs = read_tfrs(fname, condition=None) - assert_equal(len(tfrs), 2) - tfr4 = tfrs[1] - assert_equal(tfr2.comment, tfr4.comment) - - pytest.raises(ValueError, read_tfrs, fname, condition="nonono") - # Test save of EpochsTFR. 
-    n_events = 5
-    data = np.zeros((n_events, 3, 2, 3))
-
-    # create fake metadata
-    rng = np.random.RandomState(42)
-    rt = np.round(rng.uniform(size=(n_events,)), 3)
-    trialtypes = np.array(["face", "place"])
-    trial = trialtypes[(rng.uniform(size=(n_events,)) > 0.5).astype(int)]
-    meta = pd.DataFrame(dict(RT=rt, Trial=trial))
-    # fake events and event_id
-    events = np.zeros([n_events, 3])
-    events[:, 0] = np.arange(n_events)
-    events[:, 2] = np.ones(n_events)
-    event_id = {"a/b": 1}
-    # fake selection
-    n_dropped_epochs = 3
-    selection = np.arange(n_events + n_dropped_epochs)[n_dropped_epochs:]
-    drop_log = tuple(
-        [("IGNORED",) for i in range(n_dropped_epochs)] + [() for i in range(n_events)]
+    assert read_tfrs(fname) == tfr
+
+
+def test_raw_tfr_init(raw):
+    """Test the RawTFR and RawTFRArray constructors."""
+    one = RawTFR(inst=raw, method="morlet", freqs=freqs_linspace)
+    two = RawTFRArray(one.info, one.data, one.times, one.freqs, method="morlet")
+    # some attributes we know won't match:
+    for attr in ("_data_type", "_inst_type"):
+        assert getattr(one, attr) != getattr(two, attr)
+        delattr(one, attr)
+        delattr(two, attr)
+    assert one == two
+    # test RawTFR.__getitem__
+    data = one[:5]
+    assert data.shape == (5,) + one.shape[1:]
+    # test missing method/freqs
+    with pytest.raises(ValueError, match="RawTFR got unsupported parameter value"):
+        RawTFR(inst=raw)
+
+
+def test_average_tfr_init(full_evoked):
+    """Test the AverageTFR and AverageTFRArray constructors."""
+    one = AverageTFR(inst=full_evoked, method="morlet", freqs=freqs_linspace)
+    two = AverageTFRArray(
+        one.info,
+        one.data,
+        one.times,
+        one.freqs,
+        method="morlet",
+        comment=one.comment,
+        nave=one.nave,
     )
-
-    tfr = EpochsTFR(
-        info,
-        data=data,
-        times=times,
-        freqs=freqs,
-        comment="test",
-        method="crazy-tfr",
-        events=events,
-        event_id=event_id,
-        selection=selection,
-        drop_log=drop_log,
-        metadata=meta,
-    )
-    fname_save = fname
-    tfr.save(fname_save, True)
-    fname_write = tmp_path / "test3-tfr.h5"
-    write_tfrs(fname_write, tfr, overwrite=True)
-    for fname in [fname_save, fname_write]:
-        read_tfr = read_tfrs(fname)[0]
-        assert_array_equal(tfr.data, read_tfr.data)
-        assert_metadata_equal(tfr.metadata, read_tfr.metadata)
-        assert_array_equal(tfr.events, read_tfr.events)
-        assert tfr.event_id == read_tfr.event_id
-        assert_array_equal(tfr.selection, read_tfr.selection)
-        assert tfr.drop_log == read_tfr.drop_log
-    with pytest.raises(NotImplementedError, match="condition not supported"):
-        tfr = read_tfrs(fname, condition="a")
-
-
-def test_init_EpochsTFR():
+    # some attributes we know won't match, otherwise should be identical
+    assert one._data_type != two._data_type
+    one._data_type = two._data_type
+    assert one == two
+    # test missing method, bad freqs
+    with pytest.raises(ValueError, match="AverageTFR got unsupported parameter value"):
+        AverageTFR(inst=full_evoked)
+    with pytest.raises(ValueError, match='must be a length-2 iterable or "auto"'):
+        AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace)
+
+
+def test_epochstfr_init_errors(epochs_tfr):
     """Test __init__ for EpochsTFR."""
-    # Create fake data:
-    data = np.zeros((3, 3, 3, 3))
-    times = np.array([0.1, 0.2, 0.3])
-    freqs = np.array([0.10, 0.20, 0.30])
-    info = mne.create_info(
-        ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"]
-    )
-    data_x = data[:, :, :, 0]
-    with pytest.raises(ValueError, match="data should be 4d. Got 3"):
-        tfr = EpochsTFR(info, data=data_x, times=times, freqs=freqs)
-    data_x = data[:, :-1, :, :]
-    with pytest.raises(ValueError, match="channels and data size don't"):
-        tfr = EpochsTFR(info, data=data_x, times=times, freqs=freqs)
-    times_x = times[:-1]
-    with pytest.raises(ValueError, match="times and data size don't match"):
-        tfr = EpochsTFR(info, data=data, times=times_x, freqs=freqs)
-    freqs_x = freqs[:-1]
-    with pytest.raises(ValueError, match="frequencies and data size don't"):
-        tfr = EpochsTFR(info, data=data, times=times_x, freqs=freqs_x)
-    del tfr
+    state = epochs_tfr.__getstate__()
+    with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"):
+        EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0]))
+    with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"):
+        EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1]))
+    with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"):
+        EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1]))
+    with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"):
+        EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1]))
+
+
+@pytest.mark.parametrize("inst", ("epochs_tfr", "average_tfr"))
+def test_tfr_init_deprecation(inst, average_tfr, request):
+    """Check for the deprecation warning message (not needed for RawTFR, it's new)."""
+    tfr = _get_inst(inst, request, average_tfr=average_tfr)
+    kwargs = dict(info=tfr.info, data=tfr.data, times=tfr.times, freqs=tfr.freqs)
+    Klass = tfr.__class__
+    with pytest.warns(FutureWarning, match='"info", "data", "times" are deprecat'):
+        Klass(**kwargs)
+    with pytest.raises(ValueError, match="Do not pass `inst` alongside deprecated"):
+        with pytest.warns(FutureWarning, match='"info", "data", "times" are deprecat'):
+            Klass(**kwargs, inst="foo")
+
+
+@pytest.mark.parametrize(
+    "method,freqs,match",
+    (
+        ("morlet", None, "EpochsTFR got unsupported parameter value freqs=None."),
+        (None, freqs_linspace, "got unsupported parameter value method=None."),
+        (None, None, "got unsupported parameter values method=None and freqs=None."),
+    ),
+)
+def test_compute_tfr_init_errors(epochs, method, freqs, match):
+    """Test that method and freqs are always passed (if not using __setstate__)."""
+    with pytest.raises(ValueError, match=match):
+        epochs.compute_tfr(method=method, freqs=freqs)
+
+
+def test_equalize_epochs_tfr_counts(epochs_tfr):
+    """Test equalize_epoch_counts for EpochsTFR."""
+    # make the fixture have 3 epochs instead of 1
+    epochs_tfr._data = np.vstack((epochs_tfr._data, epochs_tfr._data, epochs_tfr._data))
+    tfr2 = epochs_tfr.copy()
+    tfr2 = tfr2[:-1]
+    equalize_epoch_counts([epochs_tfr, tfr2])
+    assert epochs_tfr.shape == tfr2.shape


 def test_dB_computation():
@@ -734,9 +781,9 @@ def test_dB_computation():
         ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"]
     )
     kwargs = dict(times=times, freqs=freqs, nave=20, comment="test", method="crazy-tfr")
-    tfr = AverageTFR(info, data=data, **kwargs)
-    complex_tfr = AverageTFR(info, data=complex_data, **kwargs)
-    plot_kwargs = dict(dB=True, combine="mean", vmin=0, vmax=7)
+    tfr = AverageTFRArray(info=info, data=data, **kwargs)
+    complex_tfr = AverageTFRArray(info=info, data=complex_data, **kwargs)
+    plot_kwargs = dict(dB=True, combine="mean", vlim=(0, 7))
     fig1 = tfr.plot(**plot_kwargs)[0]
     fig2 = complex_tfr.plot(**plot_kwargs)[0]
     # since we're fixing vmin/vmax, equal colors should mean ~equal input data
@@ -754,8 +801,8 @@ def test_plot():
     info = mne.create_info(
         ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"]
     )
-    tfr = AverageTFR(
-        info,
+    tfr = AverageTFRArray(
+        info=info,
         data=data,
         times=times,
         freqs=freqs,
@@ -764,88 +811,6 @@ def test_plot():
         method="crazy-tfr",
     )

-    # test title=auto, combine=None, and correct length of figure list
-    picks = [1, 2]
-    figs = tfr.plot(
-        picks, title="auto", colorbar=False, mask=np.ones(tfr.data.shape[1:], bool)
-    )
-    assert len(figs) == len(picks)
-    assert "MEG" in figs[0].texts[0].get_text()
-    plt.close("all")
-
-    # test combine and title keyword
-    figs = tfr.plot(
-        picks,
-        title="title",
-        colorbar=False,
-        combine="rms",
-        mask=np.ones(tfr.data.shape[1:], bool),
-    )
-    assert len(figs) == 1
-    assert figs[0].texts[0].get_text() == "title"
-    figs = tfr.plot(
-        picks,
-        title="auto",
-        colorbar=False,
-        combine="mean",
-        mask=np.ones(tfr.data.shape[1:], bool),
-    )
-    assert len(figs) == 1
-    assert figs[0].texts[0].get_text() == "Mean of 2 sensors"
-    figs = tfr.plot(
-        picks,
-        title="auto",
-        colorbar=False,
-        combine=lambda x: x.mean(axis=0),
-        mask=np.ones(tfr.data.shape[1:], bool),
-    )
-    assert len(figs) == 1
-
-    with pytest.raises(ValueError, match="Invalid value for the 'combine'"):
-        tfr.plot(
-            picks,
-            colorbar=False,
-            combine="something",
-            mask=np.ones(tfr.data.shape[1:], bool),
-        )
-    with pytest.raises(RuntimeError, match="must operate on a single"):
-        tfr.plot(picks, combine=lambda x, y: x.mean(axis=0))
-    with pytest.raises(RuntimeError, match=re.escape("of shape (n_freqs, n_times).")):
-        tfr.plot(picks, combine=lambda x: x.mean(axis=0, keepdims=True))
-    with pytest.raises(
-        RuntimeError,
-        match=re.escape("return a numpy array of shape (n_freqs, n_times)."),
-    ):
-        tfr.plot(picks, combine=lambda x: 101)
-
-    plt.close("all")
-
-    # test axes argument - first with list of axes
-    ax = plt.subplot2grid((2, 2), (0, 0))
-    ax2 = plt.subplot2grid((2, 2), (0, 1))
-    ax3 = plt.subplot2grid((2, 2), (1, 0))
-    figs = tfr.plot(picks=[0, 1, 2], axes=[ax, ax2, ax3])
-    assert len(figs) == len([ax, ax2, ax3])
-    # and as a single axes
-    figs = tfr.plot(picks=[0], axes=ax)
-    assert len(figs) == 1
-    plt.close("all")
-    # and invalid inputs
-    with pytest.raises(ValueError, match="axes must be None"):
-        tfr.plot(picks, colorbar=False, axes={}, mask=np.ones(tfr.data.shape[1:], bool))
-
-    # different number of axes and picks should throw a RuntimeError
-    with pytest.raises(RuntimeError, match="There must be an axes"):
-        tfr.plot(
-            picks=[0],
-            colorbar=False,
-            axes=[ax, ax2],
-            mask=np.ones(tfr.data.shape[1:], bool),
-        )
-
-    tfr.plot_topo(picks=[1, 2])
-    plt.close("all")
-
     # interactive mode on by default
     fig = tfr.plot(picks=[1], cmap="RdBu_r")[0]
     _fake_keypress(fig, "up")
@@ -876,65 +841,76 @@ def test_plot():
     plt.close("all")


-def test_plot_joint():
-    """Test TFR joint plotting."""
-    raw = read_raw_fif(raw_fname)
-    times = np.linspace(-0.1, 0.1, 200)
-    n_freqs = 3
-    nave = 1
-    rng = np.random.RandomState(42)
-    data = rng.randn(len(raw.ch_names), n_freqs, len(times))
-    tfr = AverageTFR(raw.info, data, times, np.arange(n_freqs), nave)
-
-    topomap_args = {"res": 8, "contours": 0, "sensors": False}
-
-    for combine in ("mean", "rms", lambda x: x.mean(axis=0)):
-        with catch_logging() as log:
-            tfr.plot_joint(
-                title="auto",
-                colorbar=True,
-                combine=combine,
-                topomap_args=topomap_args,
-                verbose="debug",
-            )
-        plt.close("all")
-        log = log.getvalue()
-        assert "Plotting topomap for grad data" in log
-
-    # check various timefreqs
-    for timefreqs in (
-        {
-            (tfr.times[0], tfr.freqs[1]): (0.1, 0.5),
-            (tfr.times[-1], tfr.freqs[-1]): (0.2, 0.6),
-        },
-        [(tfr.times[1], tfr.freqs[1])],
-    ):
-        tfr.plot_joint(timefreqs=timefreqs, topomap_args=topomap_args)
-        plt.close("all")
-
-    # test bad timefreqs
-    timefreqs = (
-        [(-100, 1)],
-        tfr.times[1],
-        [1],
-        [(tfr.times[1], tfr.freqs[1], tfr.freqs[1])],
+@pytest.mark.parametrize(
+    "timefreqs,title,combine",
+    (
+        pytest.param(
+            {(0.33, 23): (0, 0), (0.25, 30): (0.1, 2)},
+            "0.25 ± 0.05 s,\n30.0 ± 1.0 Hz",
+            "mean",
+            id="dict,mean",
+        ),
+        pytest.param([(0.25, 30)], "0.25 s,\n30.0 Hz", "rms", id="list,rms"),
+        pytest.param(None, None, lambda x: x.mean(axis=0), id="none,lambda"),
+    ),
+)
+@parametrize_inst_and_ch_type
+def test_tfr_plot_joint(
+    inst, ch_type, combine, timefreqs, title, full_average_tfr, request
+):
+    """Test {Raw,Epochs,Average}TFR.plot_joint()."""
+    tfr = _get_inst(inst, request, average_tfr=full_average_tfr)
+    with catch_logging() as log:
+        fig = tfr.plot_joint(
+            picks=ch_type,
+            timefreqs=timefreqs,
+            combine=combine,
+            topomap_args=dict(res=8, contours=0, sensors=False),  # for speed
+            verbose="debug",
+        )
+    assert f"Plotting topomap for {ch_type} data" in log.getvalue()
+    # check for correct number of axes
+    n_topomaps = 1 if timefreqs is None else len(timefreqs)
+    assert len(fig.axes) == n_topomaps + 2  # n_topomaps + 1 image + 1 colorbar
+    # title varies by `ch_type` when `timefreqs=None`, so we don't test that here
+    if title is not None:
+        assert fig.axes[0].get_title() == title
+    # test interactivity
+    ax = [ax for ax in fig.axes if ax.get_xlabel() == "Time (s)"][0]
+    kw = dict(fig=fig, ax=ax, xform="ax")
+    _fake_click(**kw, kind="press", point=(0.4, 0.4))
+    _fake_click(**kw, kind="motion", point=(0.5, 0.5))
+    _fake_click(**kw, kind="release", point=(0.6, 0.6))
+    # make sure we actually got a pop-up figure, and it has a plausible title
+    fignums = plt.get_fignums()
+    assert len(fignums) == 2
+    popup_fig = plt.figure(fignums[-1])
+    assert re.match(
+        r"-?\d{1,2}\.\d{3} - -?\d{1,2}\.\d{3} s,\n\d{1,2}\.\d{2} - \d{1,2}\.\d{2} Hz",
+        _get_suptitle(popup_fig),
    )
-    for these_timefreqs in timefreqs:
-        pytest.raises(ValueError, tfr.plot_joint, these_timefreqs)

-    # test that the object is not internally modified
-    tfr_orig = tfr.copy()
-    tfr.plot_joint(
-        baseline=(0, None), exclude=[tfr.ch_names[0]], topomap_args=topomap_args
-    )
-    plt.close("all")
-    assert_array_equal(tfr.data, tfr_orig.data)
-    assert set(tfr.ch_names) == set(tfr_orig.ch_names)
-    assert set(tfr.times) == set(tfr_orig.times)

-    # test tfr with picked channels
-    tfr.pick(tfr.ch_names[:-1])
-    tfr.plot_joint(title="auto", colorbar=True, topomap_args=topomap_args)
+@pytest.mark.parametrize(
+    "match,timefreqs,topomap_args",
+    (
+        (r"Requested time point \(-88.000 s\) exceeds the range of", [(-88, 1)], None),
+        (r"Requested frequency \(99.0 Hz\) exceeds the range of", [(0.0, 99)], None),
+        ("list of tuple pairs, or a dict of such tuple pairs, not 0", [0.0], None),
+        ("does not match the channel type present in", None, dict(ch_type="eeg")),
+    ),
+)
+def test_tfr_plot_joint_errors(full_average_tfr, match, timefreqs, topomap_args):
+    """Test AverageTFR.plot_joint() error messages."""
+    with pytest.raises(ValueError, match=match):
+        full_average_tfr.plot_joint(timefreqs=timefreqs, topomap_args=topomap_args)
+
+
+def test_tfr_plot_joint_doesnt_modify(full_average_tfr):
+    """Test that the object is unchanged after plot_joint()."""
+    tfr = full_average_tfr.copy()
+    full_average_tfr.plot_joint()
+    assert tfr == full_average_tfr


 def test_add_channels():
@@ -947,8 +923,8 @@ def test_add_channels():
         1000.0,
         ["mag", "mag", "mag", "eeg", "eeg", "stim"],
     )
-    tfr = AverageTFR(
-        info,
+    tfr = AverageTFRArray(
+        info=info,
         data=data,
         times=times,
         freqs=freqs,
@@ -1168,13 +1144,12 @@ def test_averaging_epochsTFR():
         avgpower = power.average(method=method)
         assert_array_equal(func(power.data, axis=0), avgpower.data)
     with pytest.raises(
-        RuntimeError, match="You passed a function that " "resulted in data"
+        RuntimeError, match=r"EpochsTFR.average\(\) got .* shape \(\), but it should be"
    ):
        power.average(method=np.mean)


-@pytest.mark.parametrize("copy", [True, False])
-def test_averaging_freqsandtimes_epochsTFR(copy):
+def test_averaging_freqsandtimes_epochsTFR():
     """Test that EpochsTFR averaging freqs methods work."""
     # Setup for reading the raw data
     event_id = 1
@@ -1209,138 +1184,60 @@ def test_averaging_freqsandtimes_epochsTFR():
         return_itc=False,
     )

-    # Test average methods for freqs and times
-    for idx, (func, method) in enumerate(
-        zip(
-            [np.mean, np.median, np.mean, np.mean],
-            [
-                "mean",
-                "median",
-                lambda x: np.mean(x, axis=2),
-                lambda x: np.mean(x, axis=3),
-            ],
-        )
+    # Test averaging over freqs
+    kwargs = dict(dim="freqs", copy=True)
+    for method, func in zip(
+        ("mean", "median", lambda x: np.mean(x, axis=2)), (np.mean, np.median, np.mean)
     ):
-        if idx == 3:
-            with pytest.raises(RuntimeError, match="You passed a function"):
-                avgpower = power.copy().average(method=method, dim="freqs", copy=copy)
-            continue
-        avgpower = power.copy().average(method=method, dim="freqs", copy=copy)
-        assert_array_equal(func(power.data, axis=2, keepdims=True), avgpower.data)
-        assert avgpower.freqs == np.mean(power.freqs)
+        avgpower = power.average(method=method, **kwargs)
+        assert_array_equal(avgpower.data, func(power.data, axis=2, keepdims=True))
+        assert_array_equal(avgpower.freqs, func(power.freqs, keepdims=True))
         assert isinstance(avgpower, EpochsTFR)
-
-        # average over epochs
-        avgpower = avgpower.average()
+        avgpower = avgpower.average()  # average over epochs
         assert isinstance(avgpower, AverageTFR)
-
-    # Test average methods for freqs and times
-    for idx, (func, method) in enumerate(
-        zip(
-            [np.mean, np.median, np.mean, np.mean],
-            [
-                "mean",
-                "median",
-                lambda x: np.mean(x, axis=3),
-                lambda x: np.mean(x, axis=2),
-            ],
-        )
+    with pytest.raises(RuntimeError, match=r"shape \(1, 2, 3\), but it should"):
+        # collapsing wrong axis (time instead of freq)
+        avgpower = power.average(method=lambda x: np.mean(x, axis=3), **kwargs)
+
+    # Test averaging over times
+    kwargs = dict(dim="times", copy=False)
+    for method, func in zip(
+        ("mean", "median", lambda x: np.mean(x, axis=3)), (np.mean, np.median, np.mean)
     ):
-        if idx == 3:
-            with pytest.raises(RuntimeError, match="You passed a function"):
-                avgpower = power.copy().average(method=method, dim="times", copy=copy)
-            continue
-        avgpower = power.copy().average(method=method, dim="times", copy=copy)
-        assert_array_equal(func(power.data, axis=-1, keepdims=True), avgpower.data)
-        assert avgpower.times == np.mean(power.times)
-        assert isinstance(avgpower, EpochsTFR)
-
-        # average over epochs
-        avgpower = avgpower.average()
-        assert isinstance(avgpower, AverageTFR)
+        avgpower = power.average(method=method, **kwargs)
+        assert_array_equal(avgpower.data, func(power.data, axis=-1, keepdims=False))
+        assert isinstance(avgpower, EpochsSpectrum)
+    with pytest.raises(RuntimeError, match=r"shape \(1, 2, 420\), but it should"):
+        # collapsing wrong axis (freq instead of time)
+        avgpower = power.average(method=lambda x: np.mean(x, axis=2), **kwargs)


-def test_getitem_epochsTFR():
-    """Test GetEpochsMixin in the context of EpochsTFR."""
+@pytest.mark.parametrize("n_drop", (0, 2))
+def test_epochstfr_getitem(epochs_full, n_drop):
+    """Test EpochsTFR.__getitem__()."""
     pd = pytest.importorskip("pandas")
-
-    # Setup for reading the raw data and select a few trials
-    raw = read_raw_fif(raw_fname)
-    events = read_events(event_fname)
-    # create fake data, test with and without dropping epochs
-    for n_drop_epochs in [0, 2]:
-        n_events = 12
-        # create fake metadata
-        rng = np.random.RandomState(42)
-        rt = rng.uniform(size=(n_events,))
-        trialtypes = np.array(["face", "place"])
-        trial = trialtypes[(rng.uniform(size=(n_events,)) > 0.5).astype(int)]
-        meta = pd.DataFrame(dict(RT=rt, Trial=trial))
-        event_id = dict(a=1, b=2, c=3, d=4)
-        epochs = Epochs(
-            raw, events[:n_events], event_id=event_id, metadata=meta, decim=1
-        )
-        epochs.drop(np.arange(n_drop_epochs))
-        n_events -= n_drop_epochs
-
-        freqs = np.arange(12.0, 17.0, 2.0)  # define frequencies of interest
-        n_cycles = freqs / 2.0  # 0.5 second time windows for all frequencies
-
-        # Choose time x (full) bandwidth product
-        time_bandwidth = 4.0
-        # With 0.5 s time windows, this gives 8 Hz smoothing
-        kwargs = dict(
-            freqs=freqs,
-            n_cycles=n_cycles,
-            use_fft=True,
-            time_bandwidth=time_bandwidth,
-            return_itc=False,
-            average=False,
-            n_jobs=None,
-        )
-        power = tfr_multitaper(epochs, **kwargs)
-
-        # Check that power and epochs metadata is the same
-        assert_metadata_equal(epochs.metadata, power.metadata)
-        assert_metadata_equal(epochs[::2].metadata, power[::2].metadata)
-        assert_metadata_equal(epochs["RT < .5"].metadata, power["RT < .5"].metadata)
-        assert_array_equal(epochs.selection, power.selection)
-        assert epochs.drop_log == power.drop_log
-
-        # Check that get power is functioning
-        assert_array_equal(power[3:6].data, power.data[3:6])
-        assert_array_equal(power[3:6].events, power.events[3:6])
-        assert_array_equal(epochs.selection[3:6], power.selection[3:6])
-
-        indx_check = power.metadata["Trial"] == "face"
-        try:
-            indx_check = indx_check.to_numpy()
-        except Exception:
-            pass  # older Pandas
-        indx_check = indx_check.nonzero()
-        assert_array_equal(power['Trial == "face"'].events, power.events[indx_check])
-        assert_array_equal(power['Trial == "face"'].data, power.data[indx_check])
-
-        # Check that the wrong Key generates a Key Error for Metadata search
-        with pytest.raises(KeyError):
-            power['Trialz == "place"']
-
-        # Test length function
-        assert len(power) == n_events
-        assert len(power[3:6]) == 3
-
-        # Test iteration function
-        for ind, power_ep in enumerate(power):
-            assert_array_equal(power_ep, power.data[ind])
-            if ind == 5:
-                break
-
-        # Test that current state is maintained
-        assert_array_equal(power.next(), power.data[ind + 1])
-
-    # Check decim affects sfreq
-    power_decim = tfr_multitaper(epochs, decim=2, **kwargs)
-    assert power.info["sfreq"] / 2.0 == power_decim.info["sfreq"]
+    from pandas.testing import assert_frame_equal
+
+    epochs_full.metadata = pd.DataFrame(dict(foo=list("aaaabbb"), bar=np.arange(7)))
+    epochs_full.drop(np.arange(n_drop))
+    tfr = epochs_full.compute_tfr(method="morlet", freqs=freqs_linspace)
+    # check that various attributes are preserved
+    assert_frame_equal(tfr.metadata, epochs_full.metadata)
+    assert epochs_full.drop_log == tfr.drop_log
+    for attr in ("events", "selection", "times"):
+        assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr))
+    # test pandas query
+    foo_a = tfr["foo == 'a'"]
+    bar_3 = tfr["bar 
<= 3"] + assert foo_a == bar_3 + assert foo_a.shape[0] == 4 - n_drop + # test integer and slice + subset_ints = tfr[[0, 1, 2]] + subset_slice = tfr[:3] + assert subset_ints == subset_slice + # test iteration + for ix, epo in enumerate(tfr): + assert_array_equal(tfr[ix].data, epo.data.obj[np.newaxis]) def test_to_data_frame(): @@ -1362,8 +1259,13 @@ def test_to_data_frame(): events[:, 2] = np.arange(5, 5 + n_epos) event_id = {k: v for v, k in zip(events[:, 2], ["ha", "he", "hu"])} info = mne.create_info(ch_names, srate, ch_types) - tfr = mne.time_frequency.EpochsTFR( - info, data, times, freqs, events=events, event_id=event_id + tfr = EpochsTFRArray( + info=info, + data=data, + times=times, + freqs=freqs, + events=events, + event_id=event_id, ) # test index checking with pytest.raises(ValueError, match="options. Valid index options are"): @@ -1446,14 +1348,19 @@ def test_to_data_frame_index(index): events[:, 2] = np.arange(5, 8) event_id = {k: v for v, k in zip(events[:, 2], ["ha", "he", "hu"])} info = mne.create_info(ch_names, 1000.0, ch_types) - tfr = mne.time_frequency.EpochsTFR( - info, data, times, freqs, events=events, event_id=event_id + tfr = EpochsTFRArray( + info=info, + data=data, + times=times, + freqs=freqs, + events=events, + event_id=event_id, ) df = tfr.to_data_frame(picks=[0, 2, 3], index=index) # test index order/hierarchy preservation if not isinstance(index, list): index = [index] - assert df.index.names == index + assert list(df.index.names) == index # test that non-indexed data were present as columns non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index)) if len(non_index): @@ -1471,17 +1378,339 @@ def test_to_data_frame_time_format(time_format): n_freqs = 5 n_times = 6 data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) + times = np.arange(6, dtype=float) freqs = np.arange(5) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 8) event_id = {k: v for v, k in zip(events[:, 2], ["ha", "he", "hu"])} info = mne.create_info(ch_names, 1000.0, ch_types) - tfr = mne.time_frequency.EpochsTFR( - info, data, times, freqs, events=events, event_id=event_id + tfr = EpochsTFRArray( + info=info, + data=data, + times=times, + freqs=freqs, + events=events, + event_id=event_id, ) # test time_format df = tfr.to_data_frame(time_format=time_format) dtypes = {None: np.float64, "ms": np.int64, "timedelta": pd.Timedelta} assert isinstance(df["time"].iloc[0], dtypes[time_format]) + + +@parametrize_morlet_multitaper +@parametrize_power_phase_complex +@pytest.mark.parametrize("picks", ("mag", mag_names, [2, 5, 8])) # all 3 equivalent +def test_raw_compute_tfr(raw, method, output, picks, tmp_path): + """Test Raw.compute_tfr() and picks handling.""" + full_tfr = raw.compute_tfr(method, output=output, freqs=freqs_linspace) + pick_tfr = raw.compute_tfr(method, output=output, freqs=freqs_linspace, picks=picks) + assert isinstance(pick_tfr, RawTFR), type(pick_tfr) + # ↓↓↓ can't use [2,5,8] because ch0 is IAS, so indices change between raw and TFR + want = full_tfr.get_data(picks=mag_names) + got = pick_tfr.get_data() + assert_array_equal(want, got) + # make sure save/load works for phase/complex data + if output in ("phase", "complex"): + pytest.importorskip("h5io") + fname = tmp_path / "temp_tfr.hdf5" + full_tfr.save(fname, overwrite=True) + assert read_tfrs(fname) == full_tfr + + +@parametrize_morlet_multitaper +@parametrize_power_phase_complex +@pytest.mark.parametrize("freqs", (freqs_linspace, 
freqs_unsorted_list)) +def test_evoked_compute_tfr(full_evoked, method, output, freqs): + """Test Evoked.compute_tfr(), with a few different ways of specifying freqs.""" + tfr = full_evoked.compute_tfr(method, freqs, output=output) + assert isinstance(tfr, AverageTFR), type(tfr) + assert tfr.nave == full_evoked.nave + assert tfr.comment == full_evoked.comment + + +@parametrize_morlet_multitaper +@pytest.mark.parametrize( + "average,return_itc,dim,want_class", + ( + pytest.param(True, False, None, None, id="average,no_itc"), + pytest.param(True, True, None, None, id="average,itc"), + pytest.param(False, False, "freqs", EpochsTFR, id="no_average,agg_freqs"), + pytest.param(False, False, "epochs", AverageTFR, id="no_average,agg_epochs"), + pytest.param(False, False, "times", EpochsSpectrum, id="no_average,agg_times"), + ), +) +def test_epochs_compute_tfr_average_itc( + epochs, method, average, return_itc, dim, want_class +): + """Test Epochs.compute_tfr(), averaging (at call time and afterward), and ITC.""" + tfr = epochs.compute_tfr( + method, freqs=freqs_linspace, average=average, return_itc=return_itc + ) + if return_itc: + tfr, itc = tfr + assert isinstance(itc, AverageTFR), type(itc) + # for single-epoch input, ITC should be (nearly) unity + assert_array_almost_equal(itc.get_data(), 1.0, decimal=15) + # if not averaging initially, make sure the post-facto .average() works too + if average: + assert isinstance(tfr, AverageTFR), type(tfr) + assert tfr.nave == 1 + assert tfr.comment == "1" + else: + assert isinstance(tfr, EpochsTFR), type(tfr) + avg = tfr.average(dim=dim) + assert isinstance(avg, want_class), type(avg) + if dim == "epochs": + assert avg.nave == len(epochs) + assert avg.comment.startswith(f"mean of {len(epochs)} EpochsTFR") + + +def test_epochs_vs_evoked_compute_tfr(epochs): + """Compare result of averaging before or after the TFR computation. + + This is mostly a test of object structure / attribute preservation. In normal cases, + the data should not match: + - epochs.compute_tfr().average() is average of squared magnitudes + - epochs.average().compute_tfr() is squared magnitude of average + But the `epochs` fixture has only one epoch, so here data should be identical too. + + The three things that will always end up different are `._comment`, `._inst_type`, + and `._data_type`, so we ignore those here. 
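+
+    Sketch of the general multi-epoch case (``x`` here stands for hypothetical
+    complex TFR coefficients): averaging first yields ``|mean(x)| ** 2`` per
+    time-frequency bin, while averaging afterward yields ``mean(|x| ** 2)``;
+    the two differ whenever phase varies across epochs.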
+ """ + avg_first = epochs.average().compute_tfr(method="morlet", freqs=freqs_linspace) + avg_second = epochs.compute_tfr(method="morlet", freqs=freqs_linspace).average() + for attr in ("_comment", "_inst_type", "_data_type"): + assert getattr(avg_first, attr) != getattr(avg_second, attr) + delattr(avg_first, attr) + delattr(avg_second, attr) + assert avg_first == avg_second + + +morlet_kw = dict(n_cycles=freqs_linspace / 4, use_fft=False, zero_mean=True) +mt_kw = morlet_kw | dict(zero_mean=False, time_bandwidth=6) +stockwell_kw = dict(n_fft=1024, width=2) + + +@pytest.mark.parametrize( + "method,freqs,method_kw", + ( + pytest.param("morlet", freqs_linspace, morlet_kw, id="morlet-nondefaults"), + pytest.param("multitaper", freqs_linspace, mt_kw, id="multitaper-nondefaults"), + pytest.param("stockwell", "auto", stockwell_kw, id="stockwell-nondefaults"), + ), +) +def test_epochs_compute_tfr_method_kw(epochs, method, freqs, method_kw): + """Test Epochs.compute_tfr(**method_kw).""" + tfr = epochs.compute_tfr(method, freqs=freqs, average=True, **method_kw) + assert isinstance(tfr, AverageTFR), type(tfr) + + +@pytest.mark.parametrize( + "freqs", + (pytest.param("auto", id="freqauto"), pytest.param([20, 41], id="fminfmax")), +) +@pytest.mark.parametrize("return_itc", (False, True)) +def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): + """Test Epochs.compute_tfr(method="stockwell").""" + tfr = epochs.compute_tfr("stockwell", freqs, return_itc=return_itc) + if return_itc: + tfr, itc = tfr + assert isinstance(itc, AverageTFR) + # for single-epoch input, ITC should be (nearly) unity + assert_array_almost_equal(itc.get_data(), 1.0, decimal=15) + assert isinstance(tfr, AverageTFR) + assert tfr.comment == "1" + + +@pytest.mark.parametrize("copy", (False, True)) +def test_epochstfr_iter_evoked(epochs_tfr, copy): + """Test EpochsTFR.iter_evoked().""" + avgs = list(epochs_tfr.iter_evoked(copy=copy)) + assert len(avgs) == len(epochs_tfr) + assert all(avg.nave == 1 for avg in avgs) + assert avgs[0].comment == str(epochs_tfr.events[0, -1]) + + +def test_tfr_proj(epochs): + """Test `compute_tfr(proj=True)`.""" + epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) + + +def test_tfr_copy(average_tfr): + """Test BaseTFR.copy() method.""" + tfr_copy = average_tfr.copy() + # check that info is independent + tfr_copy.info["bads"] = tfr_copy.ch_names + assert average_tfr.info["bads"] == [] + # check that data is independent + tfr_copy.data = np.inf + assert np.isfinite(average_tfr.get_data()).all() + + +@pytest.mark.parametrize( + "mode", ("mean", "ratio", "logratio", "percent", "zscore", "zlogratio") +) +def test_tfr_apply_baseline(average_tfr, mode): + """Test TFR baselining.""" + average_tfr.apply_baseline((-0.1, -0.05), mode=mode) + + +def test_tfr_arithmetic(epochs): + """Test TFR arithmetic operations.""" + tfr, itc = epochs.compute_tfr( + "morlet", freqs=freqs_linspace, average=True, return_itc=True + ) + itc_copy = itc.copy() + # addition / subtraction of objects + double = tfr + tfr + double -= tfr + assert tfr == double + itc_copy += tfr + assert itc == itc_copy - tfr + # multiplication / division with scalars + bigger_itc = itc * 23 + assert_array_almost_equal(itc.data, (bigger_itc / 23).data, decimal=15) + # multiplication / division with arrays + arr = np.full_like(itc.data, 23) + assert_array_equal(bigger_itc.data, (itc * arr).data) + # in-place multiplication/division + bigger_itc *= 2 + bigger_itc /= 46 + assert_array_almost_equal(itc.data, bigger_itc.data, decimal=15) 
+    # check errors
+    with pytest.raises(RuntimeError, match="types do not match"):
+        tfr + epochs
+    with pytest.raises(RuntimeError, match="times do not match"):
+        tfr + tfr.copy().crop(tmax=0.2)
+    with pytest.raises(RuntimeError, match="freqs do not match"):
+        tfr + tfr.copy().crop(fmax=33)
+
+
+def test_tfr_repr_html(epochs_tfr):
+    """Test TFR._repr_html_()."""
+    result = epochs_tfr._repr_html_(caption="Foo")
+    for heading in ("Data type", "Data source", "Estimation method"):
+        assert f"<th>{heading}</th>" in result
+    for data in ("Power Estimates", "Epochs", "morlet"):
+        assert f"<td>{data}</td>" in result
+
+
+@pytest.mark.parametrize(
+    "picks,combine",
+    (
+        pytest.param("mag", "mean", id="mean_of_mags"),
+        pytest.param("grad", "rms", id="rms_of_grads"),
+        pytest.param([1], "mean", id="single_channel"),
+        pytest.param([1, 2], None, id="two_separate_channels"),
+    ),
+)
+def test_tfr_plot_combine(epochs_tfr, picks, combine):
+    """Test TFR.plot() picks, combine, and title="auto".
+
+    No need to parametrize over {Raw,Epochs,Evoked}TFR, the code path is shared.
+    """
+    fig = epochs_tfr.plot(picks=picks, combine=combine, title="auto")
+    assert len(fig) == (1 if isinstance(picks, str) else len(picks))
+    # test `title="auto"`
+    for ix, _fig in enumerate(fig):
+        if isinstance(picks, str):
+            ch_type = _channel_type_prettyprint[picks]
+            want = rf"{'RMS' if combine == 'rms' else 'Mean'} of \d{{1,3}} {ch_type}s"
+        else:
+            want = epochs_tfr.ch_names[picks[ix]]
+        assert re.search(want, _get_suptitle(_fig))
+
+
+def test_tfr_plot_extras(epochs_tfr):
+    """Test other options of TFR.plot()."""
+    # test mask and custom title
+    picks = [1]
+    mask = np.ones(epochs_tfr.data.shape[2:], bool)
+    fig = epochs_tfr.plot(picks=picks, mask=mask, title="Foo")
+    assert _get_suptitle(fig[0]) == "Foo"
+    mask = np.ones(epochs_tfr.data.shape[1:], bool)
+    with pytest.raises(ValueError, match="mask must have the same shape as the data"):
+        epochs_tfr.plot(picks=picks, mask=mask)
+    # test combine-related errors
+    with pytest.raises(ValueError, match='"combine" must be None, a callable, or one'):
+        epochs_tfr.plot(picks=picks, combine="foo")
+    with pytest.raises(RuntimeError, match="Wrong type yielded by callable"):
+        epochs_tfr.plot(picks=picks, combine=lambda x: 777)
+    with pytest.raises(RuntimeError, match="Wrong shape yielded by callable"):
+        epochs_tfr.plot(picks=picks, combine=lambda x: np.array([777]))
+    with pytest.raises(ValueError, match="wrong with the callable passed to 'combine'"):
+        epochs_tfr.plot(picks=picks, combine=lambda x, y: x.mean(axis=0))
+    # test custom Axes
+    fig, axs = plt.subplots(1, 5)
+    fig2 = epochs_tfr.plot(picks=[1, 2], combine=lambda x: x.mean(axis=0), axes=axs[0])
+    fig3 = epochs_tfr.plot(picks=[1, 2, 3], axes=axs[1:-1])
+    fig4 = epochs_tfr.plot(picks=[1], axes=axs[-1:].tolist())
+    for _fig in fig2 + fig3 + fig4:
+        assert fig == _fig
+    with pytest.raises(ValueError, match="axes must be None"):
+        epochs_tfr.plot(picks=picks, axes={})
+    with pytest.raises(RuntimeError, match="must be one axes for each picked channel"):
+        epochs_tfr.plot(picks=[1, 2], axes=axs[-1:])
+    # test singleton check by faking having 2 epochs
+    epochs_tfr._data = np.vstack((epochs_tfr._data, epochs_tfr._data))
+    with pytest.raises(NotImplementedError, match=r"Cannot call plot\(\) from"):
+        epochs_tfr.plot()
+
+
+def test_tfr_plot_interactivity(epochs_tfr):
+    """Test interactivity of TFR.plot()."""
+    fig = epochs_tfr.plot(picks="mag", combine="mean")[0]
+    assert len(plt.get_fignums()) == 1
+    # press and release in same spot (should do
nothing) + kw = dict(fig=fig, ax=fig.axes[0], xform="ax") + _fake_click(**kw, point=(0.5, 0.5), kind="press") + _fake_click(**kw, point=(0.5, 0.5), kind="motion") + _fake_click(**kw, point=(0.5, 0.5), kind="release") + assert len(plt.get_fignums()) == 1 + # click and drag (should create popup topomap) + _fake_click(**kw, point=(0.4, 0.4), kind="press") + _fake_click(**kw, point=(0.5, 0.5), kind="motion") + _fake_click(**kw, point=(0.6, 0.6), kind="release") + assert len(plt.get_fignums()) == 2 + + +@parametrize_inst_and_ch_type +def test_tfr_plot_topo(inst, ch_type, average_tfr, request): + """Test {Raw,Epochs,Average}TFR.plot_topo().""" + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fig = tfr.plot_topo(picks=ch_type) + assert fig is not None + + +@parametrize_inst_and_ch_type +def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): + """Test {Raw,Epochs,Average}TFR.plot_topomap().""" + tfr = _get_inst(inst, request, average_tfr=full_average_tfr) + fig = tfr.plot_topomap(ch_type=ch_type) + # fake a click-drag-release to select all sensors & generate a pop-up TFR image + ax = fig.axes[0] + pts = [ + coll.get_offsets() + for coll in ax.collections + if isinstance(coll, PathCollection) + ][0] + # sometimes sensors are outside axes; make sure our click starts inside axes + lims = np.vstack((ax.get_xlim(), ax.get_ylim())) + pad = np.diff(lims, axis=1).ravel() / 100 + start = np.clip(pts.min(axis=0) - pad, *(lims.min(axis=1) + pad)) + stop = np.clip(pts.max(axis=0) + pad, *(lims.max(axis=1) - pad)) + kw = dict(fig=fig, ax=ax, xform="data") + _fake_click(**kw, kind="press", point=tuple(start)) + # ↓↓↓ possible bug? using (start+stop)/2 for the motion event causes the motion + # ↓↓↓ event (not release event) coords to propagate → fails to select sensors + _fake_click(**kw, kind="motion", point=tuple(stop)) + _fake_click(**kw, kind="release", point=tuple(stop)) + # make sure we actually got a pop-up figure, and it has a plausible title + fignums = plt.get_fignums() + assert len(fignums) == 2 + popup_fig = plt.figure(fignums[-1]) + assert re.match( + rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() + ) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index ce547568232..8f8599f757c 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -11,19 +11,17 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
+import inspect
 from copy import deepcopy
 from functools import partial
 
+import matplotlib.pyplot as plt
 import numpy as np
 from scipy.fft import fft, ifft
 from scipy.signal import argrelmax
 
 from .._fiff.meas_info import ContainsMixin, Info
-from .._fiff.pick import (
-    _picks_to_idx,
-    channel_type,
-    pick_info,
-)
+from .._fiff.pick import _picks_to_idx, pick_info
 from ..baseline import _check_baseline, rescale
 from ..channels.channels import UpdateChannelsMixin
 from ..channels.layout import _find_topomap_coords, _merge_ch_data, _pair_grad_sensors
@@ -37,27 +35,35 @@
     _build_data_frame,
     _check_combine,
     _check_event_id,
+    _check_fname,
+    _check_method_kwargs,
     _check_option,
     _check_pandas_index_arguments,
     _check_pandas_installed,
     _check_time_format,
     _convert_times,
+    _ensure_events,
     _freq_mask,
-    _gen_events,
     _import_h5io_funcs,
     _is_numeric,
+    _pl,
     _prepare_read_metadata,
     _prepare_write_metadata,
     _time_mask,
     _validate_type,
     check_fname,
+    copy_doc,
     copy_function_doc_to_method_doc,
     fill_doc,
+    legacy,
     logger,
+    object_diff,
+    repr_html,
     sizeof_fmt,
     verbose,
     warn,
 )
+from ..utils.spectrum import _get_instance_type_string
 from ..viz.topo import _imshow_tfr, _imshow_tfr_unified, _plot_topo
 from ..viz.topomap import (
     _add_colorbar,
@@ -67,6 +73,7 @@
     plot_topomap,
 )
 from ..viz.utils import (
+    _make_combine_callable,
     _prepare_joint_axes,
     _set_title_multiple_electrodes,
     _setup_cmap,
@@ -75,7 +82,8 @@
     figure_nobar,
     plt_show,
 )
-from .multitaper import dpss_windows
+from .multitaper import dpss_windows, tfr_array_multitaper
+from .spectrum import EpochsSpectrum
 
 
 @fill_doc
@@ -124,13 +132,11 @@ def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False):
     Examples
     --------
     Let's show a simple example of the relationship between ``n_cycles`` and
-    the FWHM using :func:`mne.time_frequency.fwhm`, as well as the equivalent
-    call using :func:`scipy.signal.morlet2`:
+    the FWHM using :func:`mne.time_frequency.fwhm`:
 
     .. plot::
 
         import numpy as np
-        from scipy.signal import morlet2 as sp_morlet
         import matplotlib.pyplot as plt
 
         from mne.time_frequency import morlet, fwhm
 
@@ -139,24 +145,15 @@ def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False):
         wavelet = morlet(sfreq=sfreq, freqs=freq, n_cycles=n_cycles)
         M, w = len(wavelet), n_cycles  # convert to SciPy convention
         s = w * sfreq / (2 * freq * np.pi)  # from SciPy docs
-        wavelet_sp = sp_morlet(M, s, w) * np.sqrt(2)  # match our normalization
 
         _, ax = plt.subplots(layout="constrained")
-        colors = {
-            ('MNE', 'real'): '#66CCEE',
-            ('SciPy', 'real'): '#4477AA',
-            ('MNE', 'imag'): '#EE6677',
-            ('SciPy', 'imag'): '#AA3377',
-        }
-        lw = dict(MNE=2, SciPy=4)
-        zorder = dict(MNE=5, SciPy=4)
+        colors = dict(real="#66CCEE", imag="#EE6677")
         t = np.arange(-M // 2 + 1, M // 2 + 1) / sfreq
-        for name, w in (('MNE', wavelet), ('SciPy', wavelet_sp)):
-            for kind in ('real', 'imag'):
-                ax.plot(t, getattr(w, kind), label=f'{name} {kind}',
-                        lw=lw[name], color=colors[(name, kind)],
-                        zorder=zorder[name])
-        ax.plot(t, np.abs(wavelet), label=f'MNE abs', color='k', lw=1., zorder=6)
+        for kind in ('real', 'imag'):
+            ax.plot(
+                t, getattr(wavelet, kind), label=kind, color=colors[kind],
+            )
+        ax.plot(t, np.abs(wavelet), label='abs', color='k', lw=1., zorder=6)
 
         half_max = np.max(np.abs(wavelet)) / 2.
         ax.plot([-this_fwhm / 2., this_fwhm / 2.], [half_max, half_max],
                color='k', linestyle='-', label='FWHM', zorder=6)
@@ -239,7 +236,14 @@
     return n_cycles * np.sqrt(2 * np.log(2)) / (np.pi * freq)
 
 
-def _make_dpss(sfreq, freqs, n_cycles=7.0, time_bandwidth=4.0, zero_mean=False):
+def _make_dpss(
+    sfreq,
+    freqs,
+    n_cycles=7.0,
+    time_bandwidth=4.0,
+    zero_mean=False,
+    return_weights=False,
+):
     """Compute DPSS tapers for the given frequency range.
 
     Parameters
@@ -257,6 +261,8 @@ def _make_dpss(sfreq, freqs, n_cycles=7.0, time_bandwidth=4.0, zero_mean=False):
         Default is 4.0, giving 3 good tapers.
     zero_mean : bool | None, default False
         Make sure the wavelet has a mean of zero.
+    return_weights : bool
+        Whether to return the concentration weights.
 
     Returns
     -------
@@ -304,7 +310,8 @@ def _make_dpss(sfreq, freqs, n_cycles=7.0, time_bandwidth=4.0, zero_mean=False):
             Wm.append(Wk)
 
         Ws.append(Wm)
-
+    if return_weights:
+        return Ws, conc
     return Ws
@@ -360,7 +367,7 @@ def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True):
         The time-frequency transform of the signals.
     """
     _check_option("mode", mode, ["same", "valid", "full"])
-    decim = _check_decim(decim)
+    decim = _ensure_slice(decim)
     X = np.asarray(X)
 
     # Precompute wavelets for given frequency range to save time
@@ -426,6 +433,7 @@ def _compute_tfr(
     decim=1,
     output="complex",
     n_jobs=None,
+    *,
     verbose=None,
 ):
     """Compute time-frequency transforms.
 
@@ -490,15 +498,14 @@
        ``'phase'`` results in shape of ``out`` being ``(n_epochs, n_chans,
        n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the
        real values in the ``output`` contain average power and the imaginary
-       values contain the inter-trial coherence:
-       ``out = avg_power + i * ITC``.
+       values contain the ITC: ``out = avg_power + i * itc``.
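+
+       For example, average power and ITC can be recovered from such an
+       output with ``avg_power, itc = out.real, out.imag``, which is how the
+       calling code unpacks the result.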
""" # Check data epoch_data = np.asarray(epoch_data) if epoch_data.ndim != 3: raise ValueError( "epoch_data must be of shape (n_epochs, n_chans, " - "n_times), got %s" % (epoch_data.shape,) + f"n_times), got {epoch_data.shape}" ) # Check params @@ -514,11 +521,11 @@ def _compute_tfr( output, ) - decim = _check_decim(decim) + decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): raise ValueError( "Cannot compute freq above Nyquist freq of the data " - "(%0.1f Hz), got %0.1f Hz" % (sfreq / 2.0, freqs.max()) + f"({sfreq / 2.0:0.1f} Hz), got {freqs.max():0.1f} Hz" ) # We decimate *after* decomposition, so we need to create our kernels @@ -698,7 +705,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): dtype = np.complex128 # Init outputs - decim = _check_decim(decim) + decim = _ensure_slice(decim) n_tapers = len(Ws) n_epochs, n_times = X[:, decim].shape n_freqs = len(Ws[0]) @@ -790,7 +797,7 @@ def cwt(X, Ws, use_fft=True, mode="same", decim=1): def _cwt_array(X, Ws, nfft, mode, decim, use_fft): - decim = _check_decim(decim) + decim = _ensure_slice(decim) coefs = _cwt_gen(X, Ws, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft) n_signals, n_times = X[:, decim].shape @@ -802,85 +809,31 @@ def _cwt_array(X, Ws, nfft, mode, decim, use_fft): def _tfr_aux( - method, inst, freqs, decim, return_itc, picks, average, output=None, **tfr_params + method, inst, freqs, decim, return_itc, picks, average, output, **tfr_params ): from ..epochs import BaseEpochs - """Help reduce redundancy between tfr_morlet and tfr_multitaper.""" - decim = _check_decim(decim) - data = _get_data(inst, return_itc) - info = inst.info.copy() # make a copy as sfreq can be altered - - info, data = _prepare_picks(info, data, picks, axis=1) - del picks - - if average: - if output == "complex": - raise ValueError('output must be "power" if average=True') - if return_itc: - output = "avg_power_itc" - else: - output = "avg_power" - else: - output = "power" if output is None else output - if return_itc: - raise ValueError( - "Inter-trial coherence is not supported" " with average=False" - ) - - out = _compute_tfr( - data, - freqs, - info["sfreq"], + kwargs = dict( method=method, - output=output, + freqs=freqs, + picks=picks, decim=decim, + output=output, **tfr_params, ) - times = inst.times[decim].copy() - with info._unlock(): - info["sfreq"] /= decim.step - - if average: - if return_itc: - power, itc = out.real, out.imag - else: - power = out - nave = len(data) - out = AverageTFR(info, power, times, freqs, nave, method="%s-power" % method) - if return_itc: - out = ( - out, - AverageTFR(info, itc, times, freqs, nave, method="%s-itc" % method), - ) - else: - power = out - if isinstance(inst, BaseEpochs): - meta = deepcopy(inst._metadata) - evs = deepcopy(inst.events) - ev_id = deepcopy(inst.event_id) - selection = deepcopy(inst.selection) - drop_log = deepcopy(inst.drop_log) - else: - # if the input is of class Evoked - meta = evs = ev_id = selection = drop_log = None - - out = EpochsTFR( - info, - power, - times, - freqs, - method="%s-power" % method, - events=evs, - event_id=ev_id, - selection=selection, - drop_log=drop_log, - metadata=meta, - ) - - return out - - + if isinstance(inst, BaseEpochs): + kwargs.update(average=average, return_itc=return_itc) + elif average: + logger.info("inst is Evoked, setting `average=False`") + average = False + if average and output == "complex": + raise ValueError('output must be "power" if average=True') + if not average and return_itc: + raise 
ValueError("Inter-trial coherence is not supported with average=False") + return inst.compute_tfr(**kwargs) + + +@legacy(alt='.compute_tfr(method="morlet")') @verbose def tfr_morlet( inst, @@ -906,7 +859,7 @@ def tfr_morlet( ---------- inst : Epochs | Evoked The epochs or evoked object. - %(freqs_tfr)s + %(freqs_tfr_array)s %(n_cycles_tfr)s use_fft : bool, default False The fft based convolution or not. @@ -973,16 +926,17 @@ def tfr_morlet( @verbose def tfr_array_morlet( - epoch_data, + data, sfreq, freqs, n_cycles=7.0, - zero_mean=False, + zero_mean=None, use_fft=True, decim=1, output="complex", n_jobs=None, verbose=None, + epoch_data=None, ): """Compute Time-Frequency Representation (TFR) using Morlet wavelets. @@ -991,14 +945,19 @@ def tfr_array_morlet( Parameters ---------- - epoch_data : array of shape (n_epochs, n_channels, n_times) + data : array of shape (n_epochs, n_channels, n_times) The epochs. sfreq : float | int Sampling frequency of the data. - %(freqs_tfr)s + %(freqs_tfr_array)s %(n_cycles_tfr)s - zero_mean : bool + zero_mean : bool | None If True, make sure the wavelets have a mean of zero. default False. + + .. versionchanged:: 1.8 + The default will change from ``zero_mean=False`` in 1.6 to ``True`` in + 1.8, and (if not set explicitly) will raise a ``FutureWarning`` in 1.7. + use_fft : bool Use the FFT for convolutions or not. default True. %(decim_tfr)s @@ -1015,11 +974,15 @@ def tfr_array_morlet( The number of epochs to process at the same time. The parallelization is implemented across channels. Default 1. %(verbose)s + epoch_data : None + Deprecated parameter for providing epoched data as of 1.7, will be replaced with + the ``data`` parameter in 1.8. New code should use the ``data`` parameter. If + ``epoch_data`` is not ``None``, a warning will be raised. Returns ------- out : array - Time frequency transform of epoch_data. + Time frequency transform of ``data``. - if ``output in ('complex', 'phase', 'power')``, array of shape ``(n_epochs, n_chans, n_freqs, n_times)`` @@ -1049,8 +1012,22 @@ def tfr_array_morlet( ---------- .. footbibliography:: """ + if zero_mean is None: + warn( + "The default value of `zero_mean` will change from `False` to `True` " + "in version 1.8. Set the value explicitly to avoid this warning.", + FutureWarning, + ) + zero_mean = False + if epoch_data is not None: + warn( + "The parameter for providing data will be switched from `epoch_data` to " + "`data` in 1.8. Use the `data` parameter to avoid this warning.", + FutureWarning, + ) + return _compute_tfr( - epoch_data=epoch_data, + epoch_data=data, freqs=freqs, sfreq=sfreq, method="morlet", @@ -1065,6 +1042,7 @@ def tfr_array_morlet( ) +@legacy(alt='.compute_tfr(method="multitaper")') @verbose def tfr_multitaper( inst, @@ -1082,15 +1060,15 @@ def tfr_multitaper( ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. - Same computation as `~mne.time_frequency.tfr_array_multitaper`, but - operates on `~mne.Epochs` or `~mne.Evoked` objects instead of + Same computation as :func:`~mne.time_frequency.tfr_array_multitaper`, but + operates on :class:`~mne.Epochs` or :class:`~mne.Evoked` objects instead of :class:`NumPy arrays `. Parameters ---------- inst : Epochs | Evoked The epochs or evoked object. - %(freqs_tfr)s + %(freqs_tfr_array)s %(n_cycles_tfr)s %(time_bandwidth_tfr)s use_fft : bool, default True @@ -1128,6 +1106,9 @@ def tfr_multitaper( .. 
versionadded:: 0.9.0 """ + from ..epochs import EpochsArray + from ..evoked import Evoked + tfr_params = dict( n_cycles=n_cycles, n_jobs=n_jobs, @@ -1135,23 +1116,580 @@ def tfr_multitaper( zero_mean=True, time_bandwidth=time_bandwidth, ) + if isinstance(inst, Evoked) and not average: + # convert AverageTFR to EpochsTFR for backwards compatibility + inst = EpochsArray(inst.data[np.newaxis], inst.info, tmin=inst.tmin, proj=False) return _tfr_aux( - "multitaper", inst, freqs, decim, return_itc, picks, average, **tfr_params + method="multitaper", + inst=inst, + freqs=freqs, + decim=decim, + return_itc=return_itc, + picks=picks, + average=average, + output="power", + **tfr_params, ) # TFR(s) class -class _BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin, ExtendedTimeMixin): - """Base TFR class.""" +@fill_doc +class BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin, ExtendedTimeMixin): + """Base class for RawTFR, EpochsTFR, and AverageTFR (for type checking only). + + .. note:: + This class should not be instantiated directly; it is provided in the public API + only for type-checking purposes (e.g., ``isinstance(my_obj, BaseTFR)``). To + create TFR objects, use the ``.compute_tfr()`` methods on :class:`~mne.io.Raw`, + :class:`~mne.Epochs`, or :class:`~mne.Evoked`, or use the constructors listed + below under "See Also". + + Parameters + ---------- + inst : instance of Raw, Epochs, or Evoked + The data from which to compute the time-frequency representation. + %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(decim_tfr)s + %(n_jobs)s + %(reject_by_annotation_tfr)s + %(verbose)s + %(method_kw_tfr)s + + See Also + -------- + mne.time_frequency.RawTFR + mne.time_frequency.RawTFRArray + mne.time_frequency.EpochsTFR + mne.time_frequency.EpochsTFRArray + mne.time_frequency.AverageTFR + mne.time_frequency.AverageTFRArray + """ + + def __init__( + self, + inst, + method, + freqs, + tmin, + tmax, + picks, + proj, + *, + decim, + n_jobs, + reject_by_annotation=None, + verbose=None, + **method_kw, + ): + from ..epochs import BaseEpochs + from ._stockwell import tfr_array_stockwell - def __init__(self): - self.baseline = None + # triage reading from file + if isinstance(inst, dict): + self.__setstate__(inst) + return + if method is None or freqs is None: + problem = [ + f"{k}=None" + for k, v in dict(method=method, freqs=freqs).items() + if v is None + ] + # TODO when py3.11 is min version, replace if/elif/else block with + # classname = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0] + _varnames = inspect.currentframe().f_back.f_code.co_varnames + if "BaseRaw" in _varnames: + classname = "RawTFR" + elif "Evoked" in _varnames: + classname = "AverageTFR" + else: + assert "BaseEpochs" in _varnames and "Evoked" not in _varnames + classname = "EpochsTFR" + # end TODO + raise ValueError( + f'{classname} got unsupported parameter value{_pl(problem)} ' + f'{" and ".join(problem)}.' 
+ ) + # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) + if method == "morlet": + method_kw.setdefault("zero_mean", True) + # check method + valid_methods = ["morlet", "multitaper"] + if isinstance(inst, BaseEpochs): + valid_methods.append("stockwell") + method = _check_option("method", method, valid_methods) + # for stockwell, `tmin, tmax` already added to `method_kw` by calling method, + # and `freqs` vector has been pre-computed + if method != "stockwell": + method_kw.update(freqs=freqs) + # ↓↓↓ if constructor called directly, prevents key error + method_kw.setdefault("output", "power") + self._freqs = np.asarray(freqs, dtype=np.float64) + del freqs + # check validity of kwargs manually to save compute time if any are invalid + tfr_funcs = dict( + morlet=tfr_array_morlet, + multitaper=tfr_array_multitaper, + stockwell=tfr_array_stockwell, + ) + _check_method_kwargs(tfr_funcs[method], method_kw, msg=f'TFR method "{method}"') + self._tfr_func = partial(tfr_funcs[method], **method_kw) + # apply proj if desired + if proj: + inst = inst.copy().apply_proj() + self.inst = inst + + # prep picks and add the info object. bads and non-data channels are dropped by + # _picks_to_idx() so we update the info accordingly: + self._picks = _picks_to_idx(inst.info, picks, "data", with_ref_meg=False) + self.info = pick_info(inst.info, sel=self._picks, copy=True) + # assign some attributes + self._method = method + self._inst_type = type(inst) + self._baseline = None + self.preload = True # needed for __getitem__, never False for TFRs + # self._dims may also get updated by child classes + self._dims = ["channel", "freq", "time"] + self._needs_taper_dim = method == "multitaper" and method_kw["output"] in ( + "complex", + "phase", + ) + if self._needs_taper_dim: + self._dims.insert(1, "taper") + self._dims = tuple(self._dims) + # get the instance data. + time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq) + get_instance_data_kw = dict(time_mask=time_mask) + if reject_by_annotation is not None: + get_instance_data_kw.update(reject_by_annotation=reject_by_annotation) + data = self._get_instance_data(**get_instance_data_kw) + # compute the TFR + self._decim = _ensure_slice(decim) + self._raw_times = inst.times[time_mask] + self._compute_tfr(data, n_jobs, verbose) + self._update_epoch_attributes() + # "apply" decim to the rest of the object (data is decimated in _compute_tfr) + with self.info._unlock(): + self.info["sfreq"] /= self._decim.step + _decim_times = inst.times[self._decim] + _decim_time_mask = _time_mask(_decim_times, tmin, tmax, sfreq=self.sfreq) + self._raw_times = _decim_times[_decim_time_mask].copy() + self._set_times(self._raw_times) self._decim = 1 + # record data type (for repr and html_repr). ITC handled in the calling method. + if method == "stockwell": + self._data_type = "Power Estimates" + else: + data_types = dict( + power="Power Estimates", + avg_power="Average Power Estimates", + avg_power_itc="Average Power Estimates", + phase="Phase", + complex="Complex Amplitude", + ) + self._data_type = data_types[method_kw["output"]] + # check for correct shape and bad values. 
`tfr_array_stockwell` doesn't take kw
+        # `output` so it may be missing here, so use `.get()`
+        negative_ok = method_kw.get("output", "") in ("complex", "phase")
+        self._check_values(negative_ok=negative_ok)
+        # we don't need these anymore, and they make save/load harder
+        del self._picks
+        del self._tfr_func
+        del self._needs_taper_dim
+        del self._shape  # calculated from self._data henceforth
+        del self.inst  # save memory
+
+    def __abs__(self):
+        """Return the absolute value."""
+        tfr = self.copy()
+        tfr.data = np.abs(tfr.data)
+        return tfr
+
+    @fill_doc
+    def __add__(self, other):
+        """Add two TFR instances.
+
+        %(__add__tfr)s
+        """
+        self._check_compatibility(other)
+        out = self.copy()
+        out.data += other.data
+        return out
+
+    @fill_doc
+    def __iadd__(self, other):
+        """Add a TFR instance to another, in-place.
+
+        %(__iadd__tfr)s
+        """
+        self._check_compatibility(other)
+        self.data += other.data
+        return self
+
+    @fill_doc
+    def __sub__(self, other):
+        """Subtract two TFR instances.
+
+        %(__sub__tfr)s
+        """
+        self._check_compatibility(other)
+        out = self.copy()
+        out.data -= other.data
+        return out
+
+    @fill_doc
+    def __isub__(self, other):
+        """Subtract a TFR instance from another, in-place.
+
+        %(__isub__tfr)s
+        """
+        self._check_compatibility(other)
+        self.data -= other.data
+        return self
+
+    @fill_doc
+    def __mul__(self, num):
+        """Multiply a TFR instance by a scalar.
+
+        %(__mul__tfr)s
+        """
+        out = self.copy()
+        out.data *= num
+        return out
+
+    @fill_doc
+    def __imul__(self, num):
+        """Multiply a TFR instance by a scalar, in-place.
+
+        %(__imul__tfr)s
+        """
+        self.data *= num
+        return self
+
+    @fill_doc
+    def __truediv__(self, num):
+        """Divide a TFR instance by a scalar.
+
+        %(__truediv__tfr)s
+        """
+        out = self.copy()
+        out.data /= num
+        return out
+
+    @fill_doc
+    def __itruediv__(self, num):
+        """Divide a TFR instance by a scalar, in-place.
+
+        %(__itruediv__tfr)s
+        """
+        self.data /= num
+        return self
+
+    def __eq__(self, other):
+        """Test equivalence of two TFR instances."""
+        return object_diff(vars(self), vars(other)) == ""
+
+    def __getstate__(self):
+        """Prepare object for serialization."""
+        return dict(
+            method=self.method,
+            data=self._data,
+            sfreq=self.sfreq,
+            dims=self._dims,
+            freqs=self.freqs,
+            times=self.times,
+            inst_type_str=_get_instance_type_string(self),
+            data_type=self._data_type,
+            info=self.info,
+            baseline=self._baseline,
+            decim=self._decim,
+        )
+
+    def __setstate__(self, state):
+        """Unpack from serialized format."""
+        from ..epochs import Epochs
+        from ..evoked import Evoked
+        from ..io import Raw
+
+        defaults = dict(
+            method="unknown",
+            dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :],
+            baseline=None,
+            decim=1,
+            data_type="TFR",
+            inst_type_str="Unknown",
+        )
+        defaults.update(**state)
+        self._method = defaults["method"]
+        self._data = defaults["data"]
+        self._freqs = np.asarray(defaults["freqs"], dtype=np.float64)
+        self._dims = defaults["dims"]
+        self._raw_times = np.asarray(defaults["times"], dtype=np.float64)
+        self._baseline = defaults["baseline"]
+        self.info = Info(**defaults["info"])
+        self._data_type = defaults["data_type"]
+        self._decim = defaults["decim"]
+        self.preload = True
+        self._set_times(self._raw_times)
+        # Handle instance type.
Prior to gh-11282, Raw was not a possibility so if + # `inst_type_str` is missing it must be Epochs or Evoked + unknown_class = Epochs if "epoch" in self._dims else Evoked + inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) + self._inst_type = inst_types[defaults["inst_type_str"]] + # sanity check data/freqs/times/info agreement + self._check_state() + + def __repr__(self): + """Build string representation of the TFR object.""" + inst_type_str = _get_instance_type_string(self) + nave = f" (nave={self.nave})" if hasattr(self, "nave") else "" + # shape & dimension names + dims = " × ".join( + [f"{size} {dim}s" for size, dim in zip(self.shape, self._dims)] + ) + freq_range = f"{self.freqs[0]:0.1f} - {self.freqs[-1]:0.1f} Hz" + time_range = f"{self.times[0]:0.2f} - {self.times[-1]:0.2f} s" + return ( + f"<{self._data_type} from {inst_type_str}{nave}, " + f"{self.method} method | {dims}, {freq_range}, {time_range}, " + f"{sizeof_fmt(self._size)}>" + ) + + @repr_html + def _repr_html_(self, caption=None): + """Build HTML representation of the TFR object.""" + from ..html_templates import _get_html_template + + inst_type_str = _get_instance_type_string(self) + nave = getattr(self, "nave", 0) + t = _get_html_template("repr", "tfr.html.jinja") + t = t.render(tfr=self, inst_type=inst_type_str, nave=nave, caption=caption) + return t + + def _check_compatibility(self, other): + """Check compatibility of two TFR instances, in preparation for arithmetic.""" + operation = inspect.currentframe().f_back.f_code.co_name.strip("_") + if operation.startswith("i"): + operation = operation[1:] + msg = f"Cannot {operation} the two TFR instances: {{}} do not match{{}}." + extra = "" + if not isinstance(other, type(self)): + problem = "types" + extra = f" (self is {type(self)}, other is {type(other)})" + elif not self.times.shape == other.times.shape or np.any( + self.times != other.times + ): + problem = "times" + elif not self.freqs.shape == other.freqs.shape or np.any( + self.freqs != other.freqs + ): + problem = "freqs" + else: # should be OK + return + raise RuntimeError(msg.format(problem, extra)) + + def _check_state(self): + """Check data/freqs/times/info agreement during __setstate__.""" + msg = "{} axis of data ({}) doesn't match {} attribute ({})" + n_chan_info = len(self.info["chs"]) + n_chan = self._data.shape[self._dims.index("channel")] + n_freq = self._data.shape[self._dims.index("freq")] + n_time = self._data.shape[self._dims.index("time")] + if n_chan_info != n_chan: + msg = msg.format("Channel", n_chan, "info", n_chan_info) + elif n_freq != len(self.freqs): + msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) + elif n_time != len(self.times): + msg = msg.format("Time", n_time, "times", self.times.size) + else: + return + raise ValueError(msg) + + def _check_values(self, negative_ok=False): + """Check TFR results for correct shape and bad values.""" + assert len(self._dims) == self._data.ndim + assert self._data.shape == self._shape + # Check for implausible power values: take min() across all but the channel axis + # TODO: should this be more fine-grained (report "chan X in epoch Y")? 
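+        # e.g., with dims ("epoch", "channel", "freq", "time"), ch_dim is 1 and
+        # the min() below runs over axes (0, 2, 3): one minimum per channel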
+ ch_dim = self._dims.index("channel") + dims = np.arange(self._data.ndim).tolist() + dims.pop(ch_dim) + negative_values = self._data.min(axis=tuple(dims)) < 0 + if negative_values.any() and not negative_ok: + chs = np.array(self.ch_names)[negative_values].tolist() + s = _pl(negative_values.sum()) + warn( + f"Negative value in time-frequency decomposition for channel{s} " + f'{", ".join(chs)}', + UserWarning, + ) + + def _compute_tfr(self, data, n_jobs, verbose): + result = self._tfr_func( + data, + self.sfreq, + decim=self._decim, + n_jobs=n_jobs, + verbose=verbose, + ) + # assign ._data and maybe ._itc + # tfr_array_stockwell always returns ITC (sometimes it's None) + if self.method == "stockwell": + self._data, self._itc, freqs = result + assert np.array_equal(self._freqs, freqs) + elif self._tfr_func.keywords.get("output", "").endswith("_itc"): + self._data, self._itc = result.real, result.imag + else: + self._data = result + # remove fake "epoch" dimension + if self.method != "stockwell" and _get_instance_type_string(self) != "Epochs": + self._data = np.squeeze(self._data, axis=0) + + # this is *expected* shape, it gets asserted later in _check_values() + # (and then deleted afterwards) + expected_shape = [ + len(self.ch_names), + len(self.freqs), + len(self._raw_times[self._decim]), # don't use self.times, not set yet + ] + # deal with the "taper" dimension + if self._needs_taper_dim: + expected_shape.insert(1, self._data.shape[1]) + self._shape = tuple(expected_shape) + + @verbose + def _onselect( + self, + eclick, + erelease, + picks=None, + exclude="bads", + combine="mean", + baseline=None, + mode=None, + cmap=None, + source_plot_joint=False, + topomap_args=None, + verbose=None, + ): + """Respond to rectangle selector in TFR image plots with a topomap plot.""" + if abs(eclick.x - erelease.x) < 0.1 or abs(eclick.y - erelease.y) < 0.1: + return + t_range = (min(eclick.xdata, erelease.xdata), max(eclick.xdata, erelease.xdata)) + f_range = (min(eclick.ydata, erelease.ydata), max(eclick.ydata, erelease.ydata)) + # snap to nearest measurement point + t_idx = np.abs(self.times - np.atleast_2d(t_range).T).argmin(axis=1) + f_idx = np.abs(self.freqs - np.atleast_2d(f_range).T).argmin(axis=1) + tmin, tmax = self.times[t_idx] + fmin, fmax = self.freqs[f_idx] + # immutable → mutable default + if topomap_args is None: + topomap_args = dict() + topomap_args.setdefault("cmap", cmap) + topomap_args.setdefault("vlim", (None, None)) + # figure out which channel types we're dealing with + types = list() + if "eeg" in self: + types.append("eeg") + if "mag" in self: + types.append("mag") + if "grad" in self: + grad_picks = _pair_grad_sensors( + self.info, topomap_coords=False, raise_error=False + ) + if len(grad_picks) > 1: + types.append("grad") + elif len(types) == 0: + logger.info( + "Need at least 2 gradiometer pairs to plot a gradiometer topomap." + ) + return # Don't draw a figure for nothing. 
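+        # figure_nobar() returns a figure with the toolbar (and its keyboard
+        # shortcuts) suppressed, so stray keypresses can't disturb the pop-up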
+ + fig = figure_nobar() + t_range = f"{tmin:.3f}" if tmin == tmax else f"{tmin:.3f} - {tmax:.3f}" + f_range = f"{fmin:.2f}" if fmin == fmax else f"{fmin:.2f} - {fmax:.2f}" + fig.suptitle(f"{t_range} s,\n{f_range} Hz") + + if source_plot_joint: + ax = fig.add_subplot() + data, times, freqs = self.get_data( + picks=picks, exclude=exclude, return_times=True, return_freqs=True + ) + # merge grads before baselining (makes ERDs visible) + ch_types = np.array(self.get_channel_types(unique=True)) + ch_type = ch_types.item() # will error if there are more than one + data, pos = _merge_if_grads( + data=data, + info=self.info, + ch_type=ch_type, + sphere=topomap_args.get("sphere"), + combine=combine, + ) + # baseline and crop + data, *_ = _prep_data_for_plot( + data, + times, + freqs, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + verbose=verbose, + ) + # average over times and freqs + data = data.mean((-2, -1)) + + im, _ = plot_topomap(data, pos, axes=ax, show=False, **topomap_args) + _add_colorbar(ax, im, topomap_args["cmap"], title="AU") + plt_show(fig=fig) + else: + for idx, ch_type in enumerate(types): + ax = fig.add_subplot(1, len(types), idx + 1) + plot_tfr_topomap( + self, + ch_type=ch_type, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + axes=ax, + **topomap_args, + ) + ax.set_title(ch_type) + + def _update_epoch_attributes(self): + # overwritten in EpochsTFR; adds things needed for to_data_frame and __getitem__ + pass + + @property + def _detrend_picks(self): + """Provide compatibility with __iter__.""" + return list() + + @property + def baseline(self): + """Start and end of the baseline period (in seconds).""" + return self._baseline + + @property + def ch_names(self): + """The channel names.""" + return self.info["ch_names"] @property def data(self): + """The time-frequency-resolved power estimates.""" return self._data @data.setter @@ -1159,9 +1697,29 @@ def data(self, data): self._data = data @property - def ch_names(self): - """Channel names.""" - return self.info["ch_names"] + def freqs(self): + """The frequencies at which power estimates were computed.""" + return self._freqs + + @property + def method(self): + """The method used to compute the time-frequency power estimates.""" + return self._method + + @property + def sfreq(self): + """Sampling frequency of the data.""" + return self.info["sfreq"] + + @property + def shape(self): + """Data shape.""" + return self._data.shape + + @property + def times(self): + """The time points present in the data (in seconds).""" + return self._times_readonly @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): @@ -1169,10 +1727,7 @@ def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): Parameters ---------- - tmin : float | None - Start time of selection in seconds. - tmax : float | None - End time of selection in seconds. + %(tmin_tmax_psd)s fmin : float | None Lowest frequency of selection in Hz. @@ -1185,7 +1740,7 @@ def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): Returns ------- - inst : instance of AverageTFR + %(inst_tfr)s The modified instance. 
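+
+        Examples
+        --------
+        A hypothetical sketch: restrict a TFR instance to 0-0.5 s and the
+        8-30 Hz band::
+
+            tfr.crop(tmin=0.0, tmax=0.5, fmin=8.0, fmax=30.0)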
""" super().crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax) @@ -1197,7 +1752,7 @@ def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): else: freq_mask = slice(None) - self.freqs = self.freqs[freq_mask] + self._freqs = self.freqs[freq_mask] # Deal with broadcasting (boolean arrays do not broadcast, but indices # do, so we need to convert freq_mask to make use of broadcasting) if isinstance(freq_mask, np.ndarray): @@ -1206,12 +1761,12 @@ def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): return self def copy(self): - """Return a copy of the instance. + """Return copy of the TFR instance. Returns ------- - copy : instance of EpochsTFR | instance of AverageTFR - A copy of the instance. + %(inst_tfr)s + A copy of the object. """ return deepcopy(self) @@ -1221,14 +1776,9 @@ def apply_baseline(self, baseline, mode="mean", verbose=None): Parameters ---------- - baseline : array-like, shape (2,) - The time interval to apply rescaling / baseline correction. - If None do not apply it. If baseline is (a, b) - the interval is between "a (s)" and "b (s)". - If a is None the beginning of the data is used - and if b is None then b is set to the end of the interval. - If baseline is equal to (None, None) all the time - interval is used. + %(baseline_rescale)s + + How baseline is computed is determined by the ``mode`` parameter. mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' Perform baseline correction by @@ -1247,522 +1797,313 @@ def apply_baseline(self, baseline, mode="mean", verbose=None): Returns ------- - inst : instance of AverageTFR + %(inst_tfr)s The modified instance. - """ # noqa: E501 - self.baseline = _check_baseline( - baseline, times=self.times, sfreq=self.info["sfreq"] - ) - rescale(self.data, self.times, self.baseline, mode, copy=False) + """ + self._baseline = _check_baseline(baseline, times=self.times, sfreq=self.sfreq) + rescale(self.data, self.times, self.baseline, mode, copy=False, verbose=verbose) return self - @verbose - def save(self, fname, overwrite=False, *, verbose=None): - """Save TFR object to hdf5 file. + @fill_doc + def get_data( + self, + picks=None, + exclude="bads", + fmin=None, + fmax=None, + tmin=None, + tmax=None, + return_times=False, + return_freqs=False, + ): + """Get time-frequency data in NumPy array format. Parameters ---------- - fname : path-like - The file name, which should end with ``-tfr.h5``. - %(overwrite)s - %(verbose)s + %(picks_good_data_noref)s + %(exclude_spectrum_get_data)s + %(fmin_fmax_tfr)s + %(tmin_tmax_psd)s + return_times : bool + Whether to return the time values for the requested time range. + Default is ``False``. + return_freqs : bool + Whether to return the frequency bin values for the requested + frequency range. Default is ``False``. - See Also - -------- - read_tfrs, write_tfrs - """ - write_tfrs(fname, self, overwrite=overwrite) + Returns + ------- + data : array + The requested data in a NumPy array. + times : array + The time values for the requested data range. Only returned if + ``return_times`` is ``True``. + freqs : array + The frequency values for the requested data range. Only returned if + ``return_freqs`` is ``True``. - @verbose - def to_data_frame( - self, - picks=None, - index=None, - long_format=False, - time_format=None, - *, - verbose=None, - ): - """Export data in tabular structure as a pandas DataFrame. - - Channels are converted to columns in the DataFrame. 
By default, - additional columns ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` (epoch event description) are added, unless ``index`` - is not ``None`` (in which case the columns specified in ``index`` will - be used to form the DataFrame's index instead). ``'epoch'``, and - ``'condition'`` are not supported for ``AverageTFR``. - - Parameters - ---------- - %(picks_all)s - %(index_df_epo)s - Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` - for ``AverageTFR``. - Defaults to ``None``. - %(long_format_df_epo)s - %(time_format_df)s - - .. versionadded:: 0.23 - %(verbose)s - - Returns - ------- - %(df_return)s + Notes + ----- + Returns a copy of the underlying data (not a view). """ - # check pandas once here, instead of in each private utils function - pd = _check_pandas_installed() # noqa - # arg checking - valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): - valid_index_args.extend(["epoch", "condition"]) - valid_time_formats = ["ms", "timedelta"] - index = _check_pandas_index_arguments(index, valid_index_args) - time_format = _check_time_format(time_format, valid_time_formats) - # get data - times = self.times - picks = _picks_to_idx(self.info, picks, "all", exclude=()) - if isinstance(self, EpochsTFR): - data = self.data[:, picks, :, :] - else: - data = self.data[np.newaxis, picks] # add singleton "epochs" axis - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, 1, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) - # prepare extra columns / multiindex - mindex = list() - times = np.tile(times, n_epochs * n_freqs) - times = _convert_times(self, times, time_format) - mindex.append(("time", times)) - freqs = self.freqs - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) - mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) - rev_event_id = {v: k for k, v in self.event_id.items()} - conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) - assert all(len(mdx) == len(mindex[0]) for mdx in mindex) - # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] - df = _build_data_frame( - self, data, picks, long_format, mindex, index, default_index=default_index + tmin = self.times[0] if tmin is None else tmin + tmax = self.times[-1] if tmax is None else tmax + fmin = 0 if fmin is None else fmin + fmax = np.inf if fmax is None else fmax + picks = _picks_to_idx( + self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False ) - return df - - -@fill_doc -class AverageTFR(_BaseTFR): - """Container for Time-Frequency data. - - Can for example store induced power at sensor level or inter-trial - coherence. - - Parameters - ---------- - %(info_not_none)s - data : ndarray, shape (n_channels, n_freqs, n_times) - The data. - times : ndarray, shape (n_times,) - The time values in seconds. - freqs : ndarray, shape (n_freqs,) - The frequencies in Hz. - nave : int - The number of averaged TFRs. - comment : str | None, default None - Comment on the data, e.g., the experimental condition. - method : str | None, default None - Comment on the method used to compute the data, e.g., morlet wavelet. 
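A short sketch of the tabular export implemented above (``epo_tfr`` is an
assumed ``EpochsTFR`` instance; the column set comes from the docstring)::

    df = epo_tfr.to_data_frame(time_format="ms")
    # columns: 'time', 'freq', 'epoch', 'condition', plus one column per channel
    df = epo_tfr.to_data_frame(index=["condition", "epoch"])
    # 'condition' and 'epoch' now form the DataFrame index instead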
- %(verbose)s - - Attributes - ---------- - %(info_not_none)s - ch_names : list - The names of the channels. - nave : int - Number of averaged epochs. - data : ndarray, shape (n_channels, n_freqs, n_times) - The data array. - times : ndarray, shape (n_times,) - The time values in seconds. - freqs : ndarray, shape (n_freqs,) - The frequencies in Hz. - comment : str - Comment on dataset. Can be the condition. - method : str | None, default None - Comment on the method used to compute the data, e.g., morlet wavelet. - """ - - @verbose - def __init__( - self, info, data, times, freqs, nave, comment=None, method=None, verbose=None - ): # noqa: D102 - super().__init__() - self.info = info - if data.ndim != 3: - raise ValueError("data should be 3d. Got %d." % data.ndim) - n_channels, n_freqs, n_times = data.shape - if n_channels != len(info["chs"]): - raise ValueError( - "Number of channels and data size don't match" - " (%d != %d)." % (n_channels, len(info["chs"])) - ) - if n_freqs != len(freqs): - raise ValueError( - "Number of frequencies and data size don't match" - " (%d != %d)." % (n_freqs, len(freqs)) - ) - if n_times != len(times): - raise ValueError( - "Number of times and data size don't match" - " (%d != %d)." % (n_times, len(times)) - ) - self.data = data - self._set_times(np.array(times, dtype=float)) - self._raw_times = self.times.copy() - self.freqs = np.array(freqs, dtype=float) - self.nave = nave - self.comment = comment - self.method = method - self.preload = True + fmin_idx = np.searchsorted(self.freqs, fmin) + fmax_idx = np.searchsorted(self.freqs, fmax, side="right") + tmin_idx = np.searchsorted(self.times, tmin) + tmax_idx = np.searchsorted(self.times, tmax, side="right") + freq_picks = np.arange(fmin_idx, fmax_idx) + time_picks = np.arange(tmin_idx, tmax_idx) + freq_axis = self._dims.index("freq") + time_axis = self._dims.index("time") + chan_axis = self._dims.index("channel") + # normally there's a risk of np.take reducing array dimension if there + # were only one channel or frequency selected, but `_picks_to_idx` + # and np.arange both always return arrays, so we're safe; the result + # will always have the same `ndim` as it started with. + data = ( + self._data.take(picks, chan_axis) + .take(freq_picks, freq_axis) + .take(time_picks, time_axis) + ) + out = [data] + if return_times: + times = self._raw_times[tmin_idx:tmax_idx] + out.append(times) + if return_freqs: + freqs = self._freqs[fmin_idx:fmax_idx] + out.append(freqs) + if not return_times and not return_freqs: + return out[0] + return tuple(out) @verbose def plot( self, picks=None, - baseline=None, - mode="mean", + *, + exclude=(), tmin=None, tmax=None, - fmin=None, - fmax=None, + fmin=0.0, + fmax=np.inf, + baseline=None, + mode="mean", + dB=False, + combine=None, + layout=None, # TODO deprecate? not used in orig implementation either + yscale="auto", vmin=None, vmax=None, - cmap="RdBu_r", - dB=False, + vlim=(None, None), + cnorm=None, + cmap=None, colorbar=True, - show=True, - title=None, - axes=None, - layout=None, - yscale="auto", + title=None, # don't deprecate this one; has (useful) option title="auto" mask=None, mask_style=None, mask_cmap="Greys", mask_alpha=0.1, - combine=None, - exclude=[], - cnorm=None, + axes=None, + show=True, verbose=None, ): - """Plot TFRs as a two-dimensional image(s). + """Plot TFRs as two-dimensional time-frequency images. Parameters ---------- %(picks_good_data)s - baseline : None (default) or tuple, shape (2,) - The time interval to apply baseline correction. 
- If None do not apply it. If baseline is (a, b) - the interval is between "a (s)" and "b (s)". - If a is None the beginning of the data is used - and if b is None then b is set to the end of the interval. - If baseline is equal to (None, None) all the time - interval is used. - mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' - Perform baseline correction by + %(exclude_spectrum_plot)s + %(tmin_tmax_psd)s + %(fmin_fmax_tfr)s + %(baseline_rescale)s - - subtracting the mean of baseline values ('mean') (default) - - dividing by the mean of baseline values ('ratio') - - dividing by the mean of baseline values and taking the log - ('logratio') - - subtracting the mean of baseline values followed by dividing by - the mean of baseline values ('percent') - - subtracting the mean of baseline values and dividing by the - standard deviation of baseline values ('zscore') - - dividing by the mean of baseline values, taking the log, and - dividing by the standard deviation of log baseline values - ('zlogratio') + How baseline is computed is determined by the ``mode`` parameter. + %(mode_tfr_plot)s + %(dB_spectrum_plot)s + %(combine_tfr_plot)s - tmin : None | float - The first time instant to display. If None the first time point - available is used. Defaults to None. - tmax : None | float - The last time instant to display. If None the last time point - available is used. Defaults to None. - fmin : None | float - The first frequency to display. If None the first frequency - available is used. Defaults to None. - fmax : None | float - The last frequency to display. If None the last frequency - available is used. Defaults to None. - vmin : float | None - The minimum value an the color scale. If vmin is None, the data - minimum value is used. Defaults to None. - vmax : float | None - The maximum value an the color scale. If vmax is None, the data - maximum value is used. Defaults to None. - cmap : matplotlib colormap | 'interactive' | (colormap, bool) - The colormap to use. If tuple, the first value indicates the - colormap to use and the second value is a boolean defining - interactivity. In interactive mode the colors are adjustable by - clicking and dragging the colorbar with left and right mouse - button. Left mouse button moves the scale up and down and right - mouse button adjusts the range. Hitting space bar resets the range. - Up and down arrows can be used to change the colormap. If - 'interactive', translates to ('RdBu_r', True). Defaults to - 'RdBu_r'. - - .. warning:: Interactive mode works smoothly only for a small - amount of images. - - dB : bool - If True, 10*log10 is applied to the data to get dB. - Defaults to False. - colorbar : bool - If true, colorbar will be added to the plot. Defaults to True. - show : bool - Call pyplot.show() at the end. Defaults to True. - title : str | 'auto' | None - String for ``title``. Defaults to None (blank/no title). If - 'auto', and ``combine`` is None, the title for each figure - will be the channel name. If 'auto' and ``combine`` is not None, - ``title`` states how many channels were combined into that figure - and the method that was used for ``combine``. If str, that String - will be the title for each figure. - axes : instance of Axes | list | None - The axes to plot to. If list, the list must be a list of Axes of - the same length as ``picks``. If instance of Axes, there must be - only one channel plotted. If ``combine`` is not None, ``axes`` - must either be an instance of Axes, or a list of length 1. 
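A minimal sketch of the ``axes`` contract described above (assumes ``tfr`` has
at least three channels; all names are illustrative)::

    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(1, 3)
    # with combine=None (the default), one Axes is required per picked channel
    tfr.plot(picks=tfr.ch_names[:3], axes=list(axs), show=False)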
- layout : Layout | None - Layout instance specifying sensor positions. Used for interactive - plotting of topographies on rectangle selection. If possible, the - correct layout is inferred from the data. - yscale : 'auto' (default) | 'linear' | 'log' - The scale of y (frequency) axis. 'linear' gives linear y axis, - 'log' leads to log-spaced y axis and 'auto' detects if frequencies - are log-spaced and only then sets the y axis to 'log'. + .. versionchanged:: 1.3 + Added support for ``callable``. + %(layout_spectrum_plot_topo)s + %(yscale_tfr_plot)s .. versionadded:: 0.14.0 - mask : ndarray | None - An array of booleans of the same shape as the data. Entries of the - data that correspond to False in the mask are plotted - transparently. Useful for, e.g., masking for statistical - significance. + %(vmin_vmax_tfr_plot)s + %(vlim_tfr_plot)s + %(cnorm)s + + .. versionadded:: 0.24 + %(cmap_topomap)s + %(colorbar)s + %(title_tfr_plot)s + %(mask_tfr_plot)s .. versionadded:: 0.16.0 - mask_style : None | 'both' | 'contour' | 'mask' - If ``mask`` is not None: if ``'contour'``, a contour line is drawn - around the masked areas (``True`` in ``mask``). If ``'mask'``, - entries not ``True`` in ``mask`` are shown transparently. If - ``'both'``, both a contour and transparency are used. - If ``None``, defaults to ``'both'`` if ``mask`` is not None, and is - ignored otherwise. + %(mask_style_tfr_plot)s .. versionadded:: 0.17 - mask_cmap : matplotlib colormap | (colormap, bool) | 'interactive' - The colormap chosen for masked parts of the image (see below), if - ``mask`` is not ``None``. If None, ``cmap`` is reused. Defaults to - ``'Greys'``. Not interactive. Otherwise, as ``cmap``. + %(mask_cmap_tfr_plot)s .. versionadded:: 0.17 - mask_alpha : float - A float between 0 and 1. If ``mask`` is not None, this sets the - alpha level (degree of transparency) for the masked-out segments. - I.e., if 0, masked-out segments are not visible at all. - Defaults to 0.1. + %(mask_alpha_tfr_plot)s .. versionadded:: 0.16.0 - combine : 'mean' | 'rms' | callable | None - Type of aggregation to perform across selected channels. If - None, plot one figure per selected channel. If a function, it must - operate on an array of shape ``(n_channels, n_freqs, n_times)`` and - return an array of shape ``(n_freqs, n_times)``. - - .. versionchanged:: 1.3 - Added support for ``callable``. - exclude : list of str | 'bads' - Channels names to exclude from being shown. If 'bads', the - bad channels are excluded. Defaults to an empty list. - %(cnorm)s - - .. versionadded:: 0.24 + %(axes_tfr_plot)s + %(show)s %(verbose)s Returns ------- figs : list of instances of matplotlib.figure.Figure A list of figures containing the time-frequency power. 
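The ``combine`` contract above admits any callable mapping an array of shape
``(n_channels, n_freqs, n_times)`` to ``(n_freqs, n_times)``; a hedged sketch
(``tfr`` is an assumed instance)::

    import numpy as np

    def median_combine(data):
        """Collapse the channel axis with a median instead of mean/rms."""
        return np.median(data, axis=0)

    figs = tfr.plot(picks="grad", combine=median_combine, title="auto", show=False)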
- """ # noqa: E501 - return self._plot( - picks=picks, - baseline=baseline, - mode=mode, + """ + # deprecations + vlim = _warn_deprecated_vmin_vmax(vlim, vmin, vmax) + # the rectangle selector plots topomaps, which needs all channels uncombined, + # so we keep a reference to that state here, and (because the topomap plotting + # function wants an AverageTFR) update it with `comment` and `nave` values in + # case we started out with a singleton EpochsTFR or RawTFR + initial_state = self.__getstate__() + initial_state.setdefault("comment", "") + initial_state.setdefault("nave", 1) + # `_picks_to_idx` also gets done inside `get_data()`` below, but we do it here + # because we need the indices later + idx_picks = _picks_to_idx( + self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False + ) + pick_names = np.array(self.ch_names)[idx_picks].tolist() # for titles + ch_types = self.get_channel_types(idx_picks) + # get data arrays + data, times, freqs = self.get_data( + picks=idx_picks, exclude=(), return_times=True, return_freqs=True + ) + # pass tmin/tmax here ↓↓↓, not here ↑↑↑; we want to crop *after* baselining + data, times, freqs = _prep_data_for_plot( + data, + times, + freqs, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, - vmax=vmax, - cmap=cmap, + baseline=baseline, + mode=mode, dB=dB, - colorbar=colorbar, - show=show, - title=title, - axes=axes, - layout=layout, - yscale=yscale, - mask=mask, - mask_style=mask_style, - mask_cmap=mask_cmap, - mask_alpha=mask_alpha, - combine=combine, - exclude=exclude, - cnorm=cnorm, verbose=verbose, ) - - @verbose - def _plot( - self, - picks=None, - baseline=None, - mode="mean", - tmin=None, - tmax=None, - fmin=None, - fmax=None, - vmin=None, - vmax=None, - cmap="RdBu_r", - dB=False, - colorbar=True, - show=True, - title=None, - axes=None, - layout=None, - yscale="auto", - mask=None, - mask_style=None, - mask_cmap="Greys", - mask_alpha=0.25, - combine=None, - exclude=None, - copy=True, - source_plot_joint=False, - topomap_args=dict(), - ch_type=None, - cnorm=None, - verbose=None, - ): - """Plot TFRs as a two-dimensional image(s). - - See self.plot() for parameters description. - """ - import matplotlib.pyplot as plt - - # channel selection - # simply create a new tfr object(s) with the desired channel selection - tfr = _preproc_tfr_instance( - self, - picks, - tmin, - tmax, - fmin, - fmax, - vmin, - vmax, - dB, - mode, - baseline, - exclude, - copy, + # shape + ch_axis = self._dims.index("channel") + freq_axis = self._dims.index("freq") + time_axis = self._dims.index("time") + want_shape = list(self.shape) + want_shape[ch_axis] = len(idx_picks) if combine is None else 1 + want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping + want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping + want_shape = tuple(want_shape) + # combine + combine_was_none = combine is None + combine = _make_combine_callable( + combine, axis=ch_axis, valid=("mean", "rms"), keepdims=True ) - del picks - - data = tfr.data - n_picks = len(tfr.ch_names) if combine is None else 1 - - # combine picks - _validate_type(combine, (None, str, "callable")) - if isinstance(combine, str): - _check_option("combine", combine, ("mean", "rms")) - if combine == "mean": - data = data.mean(axis=0, keepdims=True) - elif combine == "rms": - data = np.sqrt((data**2).mean(axis=0, keepdims=True)) - elif combine is not None: # callable - # It must operate on (n_channels, n_freqs, n_times) and return - # (n_freqs, n_times). 
Operates on a copy in-case 'combine' does - # some in-place operations. - try: - data = combine(data.copy()) - except TypeError: - raise RuntimeError( - "A callable 'combine' must operate on a single argument, " - "a numpy array of shape (n_channels, n_freqs, n_times)." - ) - if not isinstance(data, np.ndarray) or data.shape != tfr.data.shape[1:]: - raise RuntimeError( - "A callable 'combine' must return a numpy array of shape " - "(n_freqs, n_times)." - ) - # keep initial dimensions + try: + data = combine(data) # no need to copy; get_data() never returns a view + except Exception as e: + msg = ( + "Something went wrong with the callable passed to 'combine'; see " + "traceback." + ) + raise ValueError(msg) from e + # call succeeded, check type and shape + mismatch = False + if not isinstance(data, np.ndarray): + mismatch = "type" + extra = "" + elif data.shape not in (want_shape, want_shape[1:]): + mismatch = "shape" + extra = f" of shape {data.shape}" + if mismatch: + raise RuntimeError( + f"Wrong {mismatch} yielded by callable passed to 'combine'. Make sure " + "your function takes a single argument (an array of shape " + "(n_channels, n_freqs, n_times)) and returns an array of shape " + f"(n_freqs, n_times); yours yielded: {type(data)}{extra}." + ) + # restore singleton collapsed axis (removed by user-provided callable): + # (n_freqs, n_times) → (1, n_freqs, n_times) + if data.shape == (len(freqs), len(times)): data = data[np.newaxis] - # figure overhead - # set plot dimension - tmin, tmax = tfr.times[[0, -1]] - if vmax is None: - vmax = np.abs(data).max() - if vmin is None: - vmin = -np.abs(data).max() - - # set colorbar - cmap = _setup_cmap(cmap) - - # make sure there are as many axes as there will be channels to plot - if isinstance(axes, list) or isinstance(axes, np.ndarray): - figs_and_axes = [(ax.get_figure(), ax) for ax in axes] + assert data.shape == want_shape + # cmap handling. power may be negative depending on baseline strategy so set + # `norm` empirically — but only if user didn't set limits explicitly. + norm = False if vlim == (None, None) else data.min() >= 0.0 + vmin, vmax = _setup_vmin_vmax(data, *vlim, norm=norm) + cmap = _setup_cmap(cmap, norm=norm) + # prepare figure(s) + if axes is None: + figs = [plt.figure(layout="constrained") for _ in range(data.shape[0])] + axes = [fig.add_subplot() for fig in figs] elif isinstance(axes, plt.Axes): - figs_and_axes = [(ax.get_figure(), ax) for ax in [axes]] - elif axes is None: - figs = [plt.figure(layout="constrained") for i in range(n_picks)] - figs_and_axes = [(fig, fig.add_subplot(111)) for fig in figs] + figs = [axes.get_figure()] + axes = [axes] + elif isinstance(axes, np.ndarray): # allow plotting into a grid of axes + figs = [ax.get_figure() for ax in axes.flat] + elif hasattr(axes, "__iter__") and len(axes): + figs = [ax.get_figure() for ax in axes] else: - raise ValueError("axes must be None, plt.Axes, or list " "of plt.Axes.") - if len(figs_and_axes) != n_picks: - raise RuntimeError("There must be an axes for each picked " "channel.") - - for idx in range(n_picks): - fig = figs_and_axes[idx][0] - ax = figs_and_axes[idx][1] - onselect_callback = partial( - tfr._onselect, + raise ValueError( + f"axes must be None, Axes, or list/array of Axes, got {type(axes)}" + ) + if len(axes) != data.shape[0]: + raise RuntimeError( + f"Mismatch between picked channels ({data.shape[0]}) and axes " + f"({len(axes)}); there must be one axes for each picked channel." + ) + # check if we're being called from within plot_joint(). 
If so, get the + # `topomap_args` from the calling context and pass it to the onselect handler. + # (we need 2 `f_back` here because of the verbose decorator) + calling_frame = inspect.currentframe().f_back.f_back + source_plot_joint = calling_frame.f_code.co_name == "plot_joint" + topomap_args = ( + dict() + if not source_plot_joint + else calling_frame.f_locals.get("topomap_args", dict()) + ) + # plot + for ix, _fig in enumerate(figs): + # restrict the onselect instance to the channel type of the picks used in + # the image plot + uniq_types = np.unique(ch_types) + ch_type = None if len(uniq_types) > 1 else uniq_types.item() + this_tfr = AverageTFR(inst=initial_state).pick(ch_type, verbose=verbose) + _onselect_callback = partial( + this_tfr._onselect, + picks=None, # already restricted the picks in `this_tfr` + exclude=(), + baseline=baseline, + mode=mode, cmap=cmap, source_plot_joint=source_plot_joint, - topomap_args={ - k: v - for k, v in topomap_args.items() - if k not in {"vmin", "vmax", "cmap", "axes"} - }, + topomap_args=topomap_args, ) + # draw the image plot _imshow_tfr( - ax, - 0, - tmin, - tmax, - vmin, - vmax, - onselect_callback, + ax=axes[ix], + tfr=data[[ix]], + ch_idx=0, + tmin=times[0], + tmax=times[-1], + vmin=vmin, + vmax=vmax, + onselect=_onselect_callback, ylim=None, - tfr=data[idx : idx + 1], - freq=tfr.freqs, + freq=freqs, x_label="Time (s)", y_label="Frequency (Hz)", colorbar=colorbar, @@ -1774,123 +2115,83 @@ def _plot( mask_alpha=mask_alpha, cnorm=cnorm, ) - + # handle title. automatic title is: + # f"{Baselined} {power} ({ch_name})" or + # f"{Baselined} {power} ({combination} of {N} {ch_type}s)" if title == "auto": - if len(tfr.info["ch_names"]) == 1 or combine is None: - subtitle = tfr.info["ch_names"][idx] - else: - subtitle = _set_title_multiple_electrodes( - None, combine, tfr.info["ch_names"], all_=True, ch_type=ch_type + if combine_was_none: # one plot per channel + which_chs = pick_names[ix] + elif len(pick_names) == 1: # there was only one pick anyway + which_chs = pick_names[0] + else: # one plot for all chs combined + which_chs = _set_title_multiple_electrodes( + None, combine, pick_names, all_=True, ch_type=ch_type ) + _prefix = "Power" if baseline is None else "Baselined power" + _title = f"{_prefix} ({which_chs})" else: - subtitle = title - fig.suptitle(subtitle) - + _title = title + _fig.suptitle(_title) plt_show(show) - return [fig for (fig, ax) in figs_and_axes] + return figs @verbose def plot_joint( self, + *, timefreqs=None, picks=None, - baseline=None, - mode="mean", + exclude=(), + combine="mean", tmin=None, tmax=None, fmin=None, fmax=None, + baseline=None, + mode="mean", + dB=False, + yscale="auto", vmin=None, vmax=None, - cmap="RdBu_r", - dB=False, + vlim=(None, None), + cnorm=None, + cmap=None, colorbar=True, + title=None, # TODO consider deprecating this one, or adding an "auto" option show=True, - title=None, - yscale="auto", - combine="mean", - exclude=[], topomap_args=None, image_args=None, verbose=None, ): - """Plot TFRs as a two-dimensional image with topomaps. + """Plot TFRs as a two-dimensional image with topomap highlights. Parameters ---------- - timefreqs : None | list of tuple | dict of tuple - The time-frequency point(s) for which topomaps will be plotted. - See Notes. + %(timefreqs)s %(picks_good_data)s - baseline : None (default) or tuple of length 2 - The time interval to apply baseline correction. - If None do not apply it. If baseline is (a, b) - the interval is between "a (s)" and "b (s)". 
- If a is None, the beginning of the data is used. - If b is None, then b is set to the end of the interval. - If baseline is equal to (None, None), the entire time - interval is used. - mode : None | str - If str, must be one of 'ratio', 'zscore', 'mean', 'percent', - 'logratio' and 'zlogratio'. - Do baseline correction with ratio (power is divided by mean - power during baseline) or zscore (power is divided by standard - deviation of power during baseline after subtracting the mean, - power = [power - mean(power_baseline)] / std(power_baseline)), - mean simply subtracts the mean power, percent is the same as - applying ratio then mean, logratio is the same as mean but then - rendered in log-scale, zlogratio is the same as zscore but data - is rendered in log-scale first. - If None no baseline correction is applied. - %(tmin_tmax_psd)s - %(fmin_fmax_psd)s - vmin : float | None - The minimum value of the color scale for the image (for - topomaps, see ``topomap_args``). If vmin is None, the data - absolute minimum value is used. - vmax : float | None - The maximum value of the color scale for the image (for - topomaps, see ``topomap_args``). If vmax is None, the data - absolute maximum value is used. - cmap : matplotlib colormap - The colormap to use. - dB : bool - If True, 10*log10 is applied to the data to get dB. - colorbar : bool - If true, colorbar will be added to the plot (relating to the - topomaps). For user defined axes, the colorbar cannot be drawn. - Defaults to True. - show : bool - Call pyplot.show() at the end. - title : str | None - String for title. Defaults to None (blank/no title). - yscale : 'auto' (default) | 'linear' | 'log' - The scale of y (frequency) axis. 'linear' gives linear y axis, - 'log' leads to log-spaced y axis and 'auto' detects if frequencies - are log-spaced and only then sets the y axis to 'log'. - combine : 'mean' | 'rms' | callable - Type of aggregation to perform across selected channels. If a - function, it must operate on an array of shape - ``(n_channels, n_freqs, n_times)`` and return an array of shape - ``(n_freqs, n_times)``. + %(exclude_psd)s + Default is an empty :class:`tuple` which includes all channels. + %(combine_tfr_plot_joint)s .. versionchanged:: 1.3 - Added support for ``callable``. - exclude : list of str | 'bads' - Channels names to exclude from being shown. If 'bads', the - bad channels are excluded. Defaults to an empty list, i.e., ``[]``. - topomap_args : None | dict - A dict of ``kwargs`` that are forwarded to - :func:`mne.viz.plot_topomap` to style the topomaps. ``axes`` and - ``show`` are ignored. If ``times`` is not in this dict, automatic - peak detection is used. Beyond that, if ``None``, no customizable - arguments will be passed. - Defaults to ``None``. - image_args : None | dict - A dict of ``kwargs`` that are forwarded to :meth:`AverageTFR.plot` - to style the image. ``axes`` and ``show`` are ignored. Beyond that, - if ``None``, no customizable arguments will be passed. - Defaults to ``None``. + Added support for ``callable``. + %(tmin_tmax_psd)s + %(fmin_fmax_tfr)s + %(baseline_rescale)s + + How baseline is computed is determined by the ``mode`` parameter. 
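A hedged one-liner showing the ``baseline``/``mode`` pair in context (assumes
``tfr`` covers a pre-stimulus interval reaching back to -0.5 s)::

    tfr.plot_joint(baseline=(-0.5, 0), mode="percent", timefreqs=((0.8, 12.0),))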
+ %(mode_tfr_plot)s + %(dB_tfr_plot_topo)s + %(yscale_tfr_plot)s + %(vmin_vmax_tfr_plot)s + %(vlim_tfr_plot_joint)s + %(cnorm)s + %(cmap_tfr_plot_topo)s + %(colorbar_tfr_plot_joint)s + %(title_none)s + %(show)s + %(topomap_args)s + %(image_args)s %(verbose)s Returns @@ -1900,68 +2201,37 @@ def plot_joint( Notes ----- - ``timefreqs`` has three different modes: tuples, dicts, and auto. - For (list of) tuple(s) mode, each tuple defines a pair - (time, frequency) in s and Hz on the TFR plot. For example, to - look at 10 Hz activity 1 second into the epoch and 3 Hz activity - 300 msec into the epoch, :: - - timefreqs=((1, 10), (.3, 3)) - - If provided as a dictionary, (time, frequency) tuples are keys and - (time_window, frequency_window) tuples are the values - indicating the - width of the windows (centered on the time and frequency indicated by - the key) to be averaged over. For example, :: - - timefreqs={(1, 10): (0.1, 2)} - - would translate into a window that spans 0.95 to 1.05 seconds, as - well as 9 to 11 Hz. If None, a single topomap will be plotted at the - absolute peak across the time-frequency representation. + %(notes_timefreqs_tfr_plot_joint)s .. versionadded:: 0.16.0 - """ # noqa: E501 + """ + from matplotlib import ticker from matplotlib.patches import ConnectionPatch - ##################################### - # Handle channels (picks and types) # - ##################################### - - # it would be nicer to let this happen in self._plot, - # but we need it here to do the loop over the remaining channel - # types in case a user supplies `picks` that pre-select only one - # channel type. - # Nonetheless, it should be refactored for code reuse. - copy = any(var is not None for var in (exclude, picks, baseline)) - tfr = self - if copy: - tfr = tfr.copy() - picks = "data" if picks is None else picks - tfr.pick(picks, exclude=() if exclude is None else exclude) - del picks - ch_types = tfr.info.get_channel_types(unique=True) - - # if multiple sensor types: one plot per channel type, recursive call - if len(ch_types) > 1: - logger.info( - "Multiple channel types selected, returning one " "figure per type." - ) + # deprecations + vlim = _warn_deprecated_vmin_vmax(vlim, vmin, vmax) + # handle recursion + picks = _picks_to_idx( + self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False + ) + all_ch_types = np.array(self.get_channel_types()) + uniq_ch_types = sorted(set(all_ch_types[picks])) + if len(uniq_ch_types) > 1: + msg = "Multiple channel types selected, returning one figure per type." + logger.info(msg) figs = list() - for this_type in ch_types: # pick corresponding channel type - type_picks = [ - idx - for idx in range(tfr.info["nchan"]) - if channel_type(tfr.info, idx) == this_type - ] - tf_ = tfr.copy().pick(type_picks) - if len(tf_.info.get_channel_types(unique=True)) > 1: - raise RuntimeError( - "Possibly infinite loop due to channel selection " - "problem. This should never happen! Please check " - "your channel types." 
- ) + for this_type in uniq_ch_types: + this_picks = np.intersect1d( + picks, + np.nonzero(np.isin(all_ch_types, this_type))[0], + assume_unique=True, + ) + # TODO might be nice to not "copy first, then pick"; alternative might + # be to subset the data with `this_picks` and then construct the "copy" + # using __getstate__ and __setstate__ + _tfr = self.copy().pick(this_picks) figs.append( - tf_.plot_joint( + _tfr.plot_joint( timefreqs=timefreqs, picks=None, baseline=baseline, @@ -1970,8 +2240,7 @@ def plot_joint( tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, - vmax=vmax, + vlim=vlim, cmap=cmap, dB=dB, colorbar=colorbar, @@ -1979,207 +2248,181 @@ def plot_joint( title=title, yscale=yscale, combine=combine, - exclude=None, + exclude=(), topomap_args=topomap_args, verbose=verbose, ) ) return figs else: - ch_type = ch_types.pop() - - # Handle timefreqs - timefreqs = _get_timefreqs(tfr, timefreqs) - n_timefreqs = len(timefreqs) - - if topomap_args is None: - topomap_args = dict() - topomap_args_pass = { - k: v - for k, v in topomap_args.items() - if k not in ("axes", "show", "colorbar") - } - topomap_args_pass["outlines"] = topomap_args.get("outlines", "head") - topomap_args_pass["contours"] = topomap_args.get("contours", 6) - topomap_args_pass["ch_type"] = ch_type - - ############## - # Image plot # - ############## - - fig, tf_ax, map_ax = _prepare_joint_axes(n_timefreqs) - - cmap = _setup_cmap(cmap) - - # image plot - # we also use this to baseline and truncate (times and freqs) - # (a copy of) the instance - if image_args is None: - image_args = dict() - fig = tfr._plot( - picks=None, - baseline=baseline, - mode=mode, + ch_type = uniq_ch_types[0] + + # handle defaults + _validate_type(combine, ("str", "callable"), item_name="combine") # no `None` + image_args = dict() if image_args is None else image_args + topomap_args = dict() if topomap_args is None else topomap_args.copy() + # make sure if topomap_args["ch_type"] is set, it matches what is in `self.info` + topomap_args.setdefault("ch_type", ch_type) + if topomap_args["ch_type"] != ch_type: + raise ValueError( + f"topomap_args['ch_type'] is {topomap_args['ch_type']} which does not " + f"match the channel type present in the object ({ch_type})." + ) + # some necessary defaults + topomap_args.setdefault("outlines", "head") + topomap_args.setdefault("contours", 6) + # don't pass these: + topomap_args.pop("axes", None) + topomap_args.pop("show", None) + topomap_args.pop("colorbar", None) + + # get the time/freq limits of the image plot, to make sure requested annotation + # times/freqs are in range + _, times, freqs = self.get_data( + picks=picks, + exclude=(), tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, - vmax=vmax, - cmap=cmap, + return_times=True, + return_freqs=True, + ) + # validate requested annotation times and freqs + timefreqs = _get_timefreqs(self, timefreqs) + valid_timefreqs = dict() + while timefreqs: + (_time, _freq), (t_win, f_win) = timefreqs.popitem() + # convert to half-windows + t_win /= 2 + f_win /= 2 + # make sure the times / freqs are in-bounds + msg = ( + "Requested {} exceeds the range of the data ({}). Choose different " + "`timefreqs`." 
+ ) + if (times > _time).all() or (times < _time).all(): + _var = f"time point ({_time:0.3f} s)" + _range = f"{times[0]:0.3f} - {times[-1]:0.3f} s" + raise ValueError(msg.format(_var, _range)) + elif (freqs > _freq).all() or (freqs < _freq).all(): + _var = f"frequency ({_freq:0.1f} Hz)" + _range = f"{freqs[0]:0.1f} - {freqs[-1]:0.1f} Hz" + raise ValueError(msg.format(_var, _range)) + # snap the times/freqs to the nearest point we have an estimate for, and + # store the validated points + if t_win == 0: + _time = times[np.argmin(np.abs(times - _time))] + if f_win == 0: + _freq = freqs[np.argmin(np.abs(freqs - _freq))] + valid_timefreqs[(_time, _freq)] = (t_win, f_win) + + # prep data for topomaps (unlike image plot, must include all channels of the + # current ch_type). Don't pass tmin/tmax here (crop later after baselining) + topomap_picks = _picks_to_idx(self.info, ch_type) + data, times, freqs = self.get_data( + picks=topomap_picks, exclude=(), return_times=True, return_freqs=True + ) + # merge grads before baselining (makes ERDS visible) + info = pick_info(self.info, sel=topomap_picks, copy=True) + data, pos = _merge_if_grads( + data=data, + info=info, + ch_type=ch_type, + sphere=topomap_args.get("sphere"), + combine=combine, + ) + # loop over intended topomap locations, to find one vlim that works for all. + tf_array = np.array(list(valid_timefreqs)) # each row is [time, freq] + tf_array = tf_array[tf_array[:, 0].argsort()] # sort by time + _vmin, _vmax = (np.inf, -np.inf) + topomap_arrays = list() + topomap_titles = list() + for _time, _freq in tf_array: + # reduce data to the range of interest in the TF plane (i.e., finally crop) + t_win, f_win = valid_timefreqs[(_time, _freq)] + _tmin, _tmax = np.array([-1, 1]) * t_win + _time + _fmin, _fmax = np.array([-1, 1]) * f_win + _freq + _data, *_ = _prep_data_for_plot( + data, + times, + freqs, + tmin=_tmin, + tmax=_tmax, + fmin=_fmin, + fmax=_fmax, + baseline=baseline, + mode=mode, + verbose=verbose, + ) + _data = _data.mean(axis=(-1, -2)) # avg over times and freqs + topomap_arrays.append(_data) + _vmin = min(_data.min(), _vmin) + _vmax = max(_data.max(), _vmax) + # construct topopmap subplot title + t_pm = "" if t_win == 0 else f" ± {t_win:0.2f}" + f_pm = "" if f_win == 0 else f" ± {f_win:0.1f}" + _title = f"{_time:0.2f}{t_pm} s,\n{_freq:0.1f}{f_pm} Hz" + topomap_titles.append(_title) + # handle cmap. Power may be negative depending on baseline strategy so set + # `norm` empirically. vmin/vmax will be handled separately within the `plot()` + # call for the image plot. + norm = np.min(topomap_arrays) >= 0.0 + cmap = _setup_cmap(cmap, norm=norm) + topomap_args.setdefault("cmap", cmap[0]) # prevent interactive cbar + # finalize topomap vlims and compute contour locations. + # By passing `data=None` here ↓↓↓↓ we effectively assert vmin & vmax aren't None + _vlim = _setup_vmin_vmax(data=None, vmin=_vmin, vmax=_vmax, norm=norm) + topomap_args.setdefault("vlim", _vlim) + locator, topomap_args["contours"] = _set_contour_locator( + *topomap_args["vlim"], topomap_args["contours"] + ) + # initialize figure and do the image plot. `self.plot()` needed to wait to be + # called until after `topomap_args` was fully populated --- we don't pass the + # dict through to `self.plot()` explicitly here, but we do "reach back" and get + # it if it's needed by the interactive rectangle selector. 
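        # Illustrative sketch (annotation, not library code): dict-valued
        # `timefreqs` entries give *full* window widths centered on the
        # (time, freq) key -- hence the halving into t_win/f_win above. E.g.::
        #
        #     tfr.plot_joint(timefreqs={(1.0, 10.0): (0.2, 4.0)})
        #
        # averages each topomap over 0.9-1.1 s and 8-12 Hz, while a bare
        # (time, freq) tuple snaps to the nearest computed point instead.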
+ fig, image_ax, topomap_axes = _prepare_joint_axes(len(valid_timefreqs)) + fig = self.plot( + picks=picks, + exclude=(), + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, dB=dB, + combine=combine, + yscale=yscale, + vlim=vlim, + cnorm=cnorm, + cmap=cmap, colorbar=False, - show=False, title=title, - axes=tf_ax, - yscale=yscale, - combine=combine, - exclude=None, - copy=False, - source_plot_joint=True, - topomap_args=topomap_args_pass, - ch_type=ch_type, + # mask, mask_style, mask_cmap, mask_alpha + axes=image_ax, + show=False, + verbose=verbose, **image_args, - )[0] - - # set and check time and freq limits ... - # can only do this after the tfr plot because it may change these - # parameters - tmax, tmin = tfr.times.max(), tfr.times.min() - fmax, fmin = tfr.freqs.max(), tfr.freqs.min() - for time, freq in timefreqs.keys(): - if not (tmin <= time <= tmax): - error_value = "time point (" + str(time) + " s)" - elif not (fmin <= freq <= fmax): - error_value = "frequency (" + str(freq) + " Hz)" - else: - continue - raise ValueError( - "Requested " + error_value + " exceeds the range" - "of the data. Choose different `timefreqs`." - ) - - ############ - # Topomaps # - ############ - - titles, all_data, all_pos, vlims = [], [], [], [] - - # the structure here is a bit complicated to allow aggregating vlims - # over all topomaps. First, one loop over all timefreqs to collect - # vlims. Then, find the max vlims and in a second loop over timefreqs, - # do the actual plotting. - timefreqs_array = np.array([np.array(keys) for keys in timefreqs]) - order = timefreqs_array[:, 0].argsort() # sort by time - - for ii, (time, freq) in enumerate(timefreqs_array[order]): - avg = timefreqs[(time, freq)] - # set up symmetric windows - time_half_range, freq_half_range = avg / 2.0 - - if time_half_range == 0: - time = tfr.times[np.argmin(np.abs(tfr.times - time))] - if freq_half_range == 0: - freq = tfr.freqs[np.argmin(np.abs(tfr.freqs - freq))] - - if (time_half_range == 0) and (freq_half_range == 0): - sub_map_title = "(%.2f s,\n%.1f Hz)" % (time, freq) - else: - sub_map_title = "(%.1f \u00B1 %.1f s,\n%.1f \u00B1 %.1f Hz)" % ( - time, - time_half_range, - freq, - freq_half_range, - ) - - tmin = time - time_half_range - tmax = time + time_half_range - fmin = freq - freq_half_range - fmax = freq + freq_half_range - - data = tfr.data - - # merging grads here before rescaling makes ERDs visible - - sphere = topomap_args.get("sphere") - if ch_type == "grad": - picks = _pair_grad_sensors(tfr.info, topomap_coords=False) - pos = _find_topomap_coords(tfr.info, picks=picks[::2], sphere=sphere) - method = combine if isinstance(combine, str) else "rms" - data, _ = _merge_ch_data(data[picks], ch_type, [], method=method) - del picks, method - else: - pos, _ = _get_pos_outlines(tfr.info, None, sphere) - del sphere - - all_pos.append(pos) - - data, times, freqs, _, _ = _preproc_tfr( - data, - tfr.times, - tfr.freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - None, - tfr.info["sfreq"], - ) - - vlims.append(np.abs(data).max()) - titles.append(sub_map_title) - all_data.append(data) - new_t = tfr.times[np.abs(tfr.times - np.median([times])).argmin()] - new_f = tfr.freqs[np.abs(tfr.freqs - np.median([freqs])).argmin()] - timefreqs_array[ii] = (new_t, new_f) - - # passing args to the topomap calls - max_lim = max(vlims) - _vlim = list(topomap_args.get("vlim", (None, None))) - # fall back on ± max_lim - for sign, index in zip((-1, 1), (0, 1)): - if _vlim[index] is None: - 
_vlim[index] = sign * max_lim - topomap_args_pass["vlim"] = tuple(_vlim) - locator, contours = _set_contour_locator(*_vlim, topomap_args_pass["contours"]) - topomap_args_pass["contours"] = contours - - for ax, title, data, pos in zip(map_ax, titles, all_data, all_pos): + )[0] # [0] because `.plot()` always returns a list + # now, actually plot the topomaps + for ax, title, _data in zip(topomap_axes, topomap_titles, topomap_arrays): ax.set_title(title) - plot_topomap( - data.mean(axis=(-1, -2)), - pos, - cmap=cmap[0], - axes=ax, - show=False, - **topomap_args_pass, - ) - - ############# - # Finish up # - ############# + plot_topomap(_data, pos, axes=ax, show=False, **topomap_args) + # draw colorbar if colorbar: - from matplotlib import ticker - cbar = fig.colorbar(ax.images[0]) - if locator is None: - locator = ticker.MaxNLocator(nbins=5) - cbar.locator = locator + cbar.locator = ticker.MaxNLocator(nbins=5) if locator is None else locator cbar.update_ticks() - - # draw the connection lines between time series and topoplots - for (time_, freq_), map_ax_ in zip(timefreqs_array, map_ax): + # draw the connection lines between time-frequency image and topoplots + for (time_, freq_), topo_ax in zip(tf_array, topomap_axes): con = ConnectionPatch( xyA=[time_, freq_], xyB=[0.5, 0], coordsA="data", coordsB="axes fraction", - axesA=tf_ax, - axesB=map_ax_, + axesA=image_ax, + axesB=topo_ax, color="grey", linestyle="-", linewidth=1.5, @@ -2192,108 +2435,6 @@ def plot_joint( plt_show(show) return fig - @verbose - def _onselect( - self, - eclick, - erelease, - baseline=None, - mode=None, - cmap=None, - source_plot_joint=False, - topomap_args=None, - verbose=None, - ): - """Handle rubber band selector in channel tfr.""" - if abs(eclick.x - erelease.x) < 0.1 or abs(eclick.y - erelease.y) < 0.1: - return - tmin = round(min(eclick.xdata, erelease.xdata), 5) # s - tmax = round(max(eclick.xdata, erelease.xdata), 5) - fmin = round(min(eclick.ydata, erelease.ydata), 5) # Hz - fmax = round(max(eclick.ydata, erelease.ydata), 5) - tmin = min(self.times, key=lambda x: abs(x - tmin)) # find closest - tmax = min(self.times, key=lambda x: abs(x - tmax)) - fmin = min(self.freqs, key=lambda x: abs(x - fmin)) - fmax = min(self.freqs, key=lambda x: abs(x - fmax)) - if tmin == tmax or fmin == fmax: - logger.info( - "The selected area is too small. " - "Select a larger time-frequency window." - ) - return - - types = list() - if "eeg" in self: - types.append("eeg") - if "mag" in self: - types.append("mag") - if "grad" in self: - if ( - len( - _pair_grad_sensors( - self.info, topomap_coords=False, raise_error=False - ) - ) - >= 2 - ): - types.append("grad") - elif len(types) == 0: - return # Don't draw a figure for nothing. 
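The colorbar tick handling in the new ``plot_joint`` code above leans on
matplotlib's ``MaxNLocator``; a self-contained sketch of that pattern
(illustrative data only)::

    import matplotlib.pyplot as plt
    from matplotlib import ticker

    fig, ax = plt.subplots()
    im = ax.imshow([[0.0, 1.0], [2.0, 3.0]])
    cbar = fig.colorbar(im)
    cbar.locator = ticker.MaxNLocator(nbins=5)  # at most 5 "nice" tick locations
    cbar.update_ticks()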
- - fig = figure_nobar() - fig.suptitle( - "{:.2f} s - {:.2f} s, {:.2f} Hz - {:.2f} Hz".format(tmin, tmax, fmin, fmax), - y=0.04, - ) - - if source_plot_joint: - ax = fig.add_subplot(111) - data = _preproc_tfr( - self.data, - self.times, - self.freqs, - tmin, - tmax, - fmin, - fmax, - None, - None, - None, - None, - None, - self.info["sfreq"], - )[0] - data = data.mean(-1).mean(-1) - vmax = np.abs(data).max() - im, _ = plot_topomap( - data, - self.info, - vlim=(-vmax, vmax), - cmap=cmap[0], - axes=ax, - show=False, - **topomap_args, - ) - _add_colorbar(ax, im, cmap, title="AU", pad=0.1) - fig.show() - else: - for idx, ch_type in enumerate(types): - ax = fig.add_subplot(1, len(types), idx + 1) - plot_tfr_topomap( - self, - ch_type=ch_type, - tmin=tmin, - tmax=tmax, - fmin=fmin, - fmax=fmax, - baseline=baseline, - mode=mode, - cmap=None, - vlim=(None, None), - axes=ax, - ) - ax.set_title(ch_type) - @verbose def plot_topo( self, @@ -2304,11 +2445,11 @@ def plot_topo( tmax=None, fmin=None, fmax=None, - vmin=None, + vmin=None, # TODO deprecate in favor of `vlim` (needs helper func refactor) vmax=None, layout=None, cmap="RdBu_r", - title=None, + title=None, # don't deprecate; topo titles aren't standard (color, size, just.) dB=False, colorbar=True, layout_scale=0.945, @@ -2320,88 +2461,38 @@ def plot_topo( yscale="auto", verbose=None, ): - """Plot TFRs in a topography with images. + """Plot a TFR image for each channel in a sensor layout arrangement. Parameters ---------- %(picks_good_data)s - baseline : None (default) or tuple of length 2 - The time interval to apply baseline correction. - If None do not apply it. If baseline is (a, b) - the interval is between "a (s)" and "b (s)". - If a is None the beginning of the data is used - and if b is None then b is set to the end of the interval. - If baseline is equal to (None, None) all the time - interval is used. - mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' - Perform baseline correction by - - - subtracting the mean of baseline values ('mean') - - dividing by the mean of baseline values ('ratio') - - dividing by the mean of baseline values and taking the log - ('logratio') - - subtracting the mean of baseline values followed by dividing by - the mean of baseline values ('percent') - - subtracting the mean of baseline values and dividing by the - standard deviation of baseline values ('zscore') - - dividing by the mean of baseline values, taking the log, and - dividing by the standard deviation of log baseline values - ('zlogratio') + %(baseline_rescale)s - tmin : None | float - The first time instant to display. If None the first time point - available is used. - tmax : None | float - The last time instant to display. If None the last time point - available is used. - fmin : None | float - The first frequency to display. If None the first frequency - available is used. - fmax : None | float - The last frequency to display. If None the last frequency - available is used. - vmin : float | None - The minimum value of the color scale. If vmin is None, the data - minimum value is used. - vmax : float | None - The maximum value of the color scale. If vmax is None, the data - maximum value is used. - layout : Layout | None - Layout instance specifying sensor positions. If possible, the - correct layout is inferred from the data. - cmap : matplotlib colormap | str - The colormap to use. Defaults to 'RdBu_r'. - title : str - Title of the figure. - dB : bool - If True, 10*log10 is applied to the data to get dB. 
- colorbar : bool - If true, colorbar will be added to the plot. - layout_scale : float - Scaling factor for adjusting the relative size of the layout - on the canvas. - show : bool - Call pyplot.show() at the end. - border : str - Matplotlib borders style to be used for each sensor plot. - fig_facecolor : color - The figure face color. Defaults to black. - fig_background : None | array - A background image for the figure. This must be a valid input to - `matplotlib.pyplot.imshow`. Defaults to None. - font_color : color - The color of tick labels in the colorbar. Defaults to white. - yscale : 'auto' (default) | 'linear' | 'log' - The scale of y (frequency) axis. 'linear' gives linear y axis, - 'log' leads to log-spaced y axis and 'auto' detects if frequencies - are log-spaced and only then sets the y axis to 'log'. + How baseline is computed is determined by the ``mode`` parameter. + %(mode_tfr_plot)s + %(tmin_tmax_psd)s + %(fmin_fmax_tfr)s + %(vmin_vmax_tfr_plot_topo)s + %(layout_spectrum_plot_topo)s + %(cmap_tfr_plot_topo)s + %(title_none)s + %(dB_tfr_plot_topo)s + %(colorbar)s + %(layout_scale)s + %(show)s + %(border_topo)s + %(fig_facecolor)s + %(fig_background)s + %(font_color)s + %(yscale_tfr_plot)s %(verbose)s Returns ------- fig : matplotlib.figure.Figure The figure containing the topography. - """ # noqa: E501 + """ + # convenience vars times = self.times.copy() freqs = self.freqs data = self.data @@ -2410,6 +2501,8 @@ def plot_topo( info, data = _prepare_picks(info, data, picks, axis=0) del picks + # TODO this is the only remaining call to _preproc_tfr; should be refactored + # (to use _prep_data_for_plot?) data, times, freqs, vmin, vmax = _preproc_tfr( data, times, @@ -2458,22 +2551,1106 @@ def plot_topo( vmin=vmin, vmax=vmax, cmap=cmap, - layout_scale=layout_scale, + layout_scale=layout_scale, + title=title, + border=border, + x_label="Time (s)", + y_label="Frequency (Hz)", + fig_facecolor=fig_facecolor, + font_color=font_color, + unified=True, + img=True, + ) + + add_background_image(fig, fig_background) + plt_show(show) + return fig + + @copy_function_doc_to_method_doc(plot_tfr_topomap) + def plot_topomap( + self, + tmin=None, + tmax=None, + fmin=0.0, + fmax=np.inf, + *, + ch_type=None, + baseline=None, + mode="mean", + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=2, + cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%1.1e", + units=None, + axes=None, + show=True, + ): + return plot_tfr_topomap( + self, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + ch_type=ch_type, + baseline=baseline, + mode=mode, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + show=show, + ) + + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save time-frequency data to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + Path of file to save to. 
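A hedged save sketch (``tfr`` assumed; per the ``check_fname`` call in the
method body, the file name must end in ``.h5`` or ``.hdf5``, and reading back
goes through ``read_tfrs`` as referenced earlier in this module)::

    tfr.save("sub01_power-tfr.h5", overwrite=True)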
+ %(overwrite)s + %(verbose)s + + See Also + -------- + mne.time_frequency.read_spectrum + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, "time-frequency object", (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + out = self.__getstate__() + if "metadata" in out: + out["metadata"] = _prepare_write_metadata(out["metadata"]) + write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace") + + @verbose + def to_data_frame( + self, + picks=None, + index=None, + long_format=False, + time_format=None, + *, + verbose=None, + ): + """Export data in tabular structure as a pandas DataFrame. + + Channels are converted to columns in the DataFrame. By default, + additional columns ``'time'``, ``'freq'``, ``'epoch'``, and + ``'condition'`` (epoch event description) are added, unless ``index`` + is not ``None`` (in which case the columns specified in ``index`` will + be used to form the DataFrame's index instead). ``'epoch'``, and + ``'condition'`` are not supported for ``AverageTFR``. + + Parameters + ---------- + %(picks_all)s + %(index_df_epo)s + Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and + ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` + for ``AverageTFR``. + Defaults to ``None``. + %(long_format_df_epo)s + %(time_format_df)s + + .. versionadded:: 0.23 + %(verbose)s + + Returns + ------- + %(df_return)s + """ + # check pandas once here, instead of in each private utils function + pd = _check_pandas_installed() # noqa + # arg checking + valid_index_args = ["time", "freq"] + if isinstance(self, EpochsTFR): + valid_index_args.extend(["epoch", "condition"]) + valid_time_formats = ["ms", "timedelta"] + index = _check_pandas_index_arguments(index, valid_index_args) + time_format = _check_time_format(time_format, valid_time_formats) + # get data + picks = _picks_to_idx(self.info, picks, "all", exclude=()) + data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) + axis = self._dims.index("channel") + if not isinstance(self, EpochsTFR): + data = data[np.newaxis] # add singleton "epochs" axis + axis += 1 + n_epochs, n_picks, n_freqs, n_times = data.shape + # reshape to (epochs*freqs*times) x signals + data = np.moveaxis(data, axis, -1) + data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + # prepare extra columns / multiindex + mindex = list() + times = _convert_times(times, time_format, self.info["meas_date"]) + times = np.tile(times, n_epochs * n_freqs) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + mindex.append(("time", times)) + mindex.append(("freq", freqs)) + if isinstance(self, EpochsTFR): + mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + rev_event_id = {v: k for k, v in self.event_id.items()} + conditions = [rev_event_id[k] for k in self.events[:, 2]] + mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) + # build DataFrame + if isinstance(self, EpochsTFR): + default_index = ["condition", "epoch", "freq", "time"] + else: + default_index = ["freq", "time"] + df = _build_data_frame( + self, data, picks, long_format, mindex, index, default_index=default_index + ) + return df + + +@fill_doc +class AverageTFR(BaseTFR): + """Data object for spectrotemporal representations of averaged data. + + .. 
warning:: The preferred means of creating AverageTFR objects is via the + instance methods :meth:`mne.Epochs.compute_tfr` and + :meth:`mne.Evoked.compute_tfr`, or via + :meth:`mne.time_frequency.EpochsTFR.average`. Direct class + instantiation is discouraged. + + Parameters + ---------- + %(info_not_none)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead, or + use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. + data : ndarray, shape (n_channels, n_freqs, n_times) + The data. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead, or + use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. + times : ndarray, shape (n_times,) + The time values in seconds. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead and + (optionally) use ``tmin`` and ``tmax`` to restrict the time domain; or use + :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. + freqs : ndarray, shape (n_freqs,) + The frequencies in Hz. + nave : int + The number of averaged TFRs. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead; + ``nave`` will be inferred automatically. Or, use + :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. + inst : instance of Evoked | instance of Epochs | dict + The data from which to compute the time-frequency representation. Passing a + :class:`dict` will create the AverageTFR using the ``__setstate__`` interface + and is not recommended for typical use cases. + %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(decim_tfr)s + %(comment_averagetfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_tfr)s + + Attributes + ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(comment_averagetfr_attr)s + %(freqs_tfr_attr)s + %(info_not_none)s + %(method_tfr_attr)s + %(nave_tfr_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s + + See Also + -------- + RawTFR + EpochsTFR + AverageTFRArray + mne.Evoked.compute_tfr + mne.time_frequency.EpochsTFR.average + + Notes + ----- + The old API (prior to version 1.7) was:: + + AverageTFR(info, data, times, freqs, nave, comment=None, method=None) + + That API is still available via :class:`~mne.time_frequency.AverageTFRArray` for + cases where the data are precomputed or do not originate from MNE-Python objects. + The preferred new API uses instance methods:: + + evoked.compute_tfr(method, freqs, ...) + epochs.compute_tfr(method, freqs, average=True, ...) + + The new API also supports AverageTFR instantiation from a :class:`dict`, but this + is primarily for save/load and internal purposes, and wraps ``__setstate__``. + During the transition from the old to the new API, it may be expedient to use + :class:`~mne.time_frequency.AverageTFRArray` as a "quick-fix" approach to updating + scripts under active development. + + References + ---------- + .. footbibliography:: + """ + + def __init__( + self, + info=None, + data=None, + times=None, + freqs=None, + nave=None, + *, + inst=None, + method=None, + tmin=None, + tmax=None, + picks=None, + proj=False, + decim=1, + comment=None, + n_jobs=None, + verbose=None, + **method_kw, + ): + from ..epochs import BaseEpochs + from ..evoked import Evoked + from ._stockwell import _check_input_st, _compute_freqs_st + + # deprecations. 
TODO remove after 1.7 release + depr_params = dict(info=info, data=data, times=times, nave=nave) + bad_params = list() + for name, param in depr_params.items(): + if param is not None: + bad_params.append(name) + if len(bad_params): + _s = _pl(bad_params) + is_are = _pl(bad_params, "is", "are") + bad_params_list = '", "'.join(bad_params) + warn( + f'Parameter{_s} "{bad_params_list}" {is_are} deprecated and will be ' + "removed in version 1.8. For a quick fix, use ``AverageTFRArray`` with " + "the same parameters. For a long-term fix, see the docstring notes.", + FutureWarning, + ) + if inst is not None: + raise ValueError( + "Do not pass `inst` alongside deprecated params " + f'"{bad_params_list}"; see docstring of AverageTFR for guidance.' + ) + inst = depr_params | dict(freqs=freqs, method=method, comment=comment) + # end TODO ↑↑↑↑↑↑ + + # dict is allowed for __setstate__ compatibility, and Epochs.compute_tfr() can + # return an AverageTFR depending on its parameters, so Epochs input is allowed + _validate_type( + inst, (BaseEpochs, Evoked, dict), "object passed to AverageTFR constructor" + ) + # stockwell API is very different from multitaper/morlet + if method == "stockwell" and not isinstance(inst, dict): + if isinstance(freqs, str) and freqs == "auto": + fmin, fmax = None, None + elif len(freqs) == 2: + fmin, fmax = freqs + else: + raise ValueError( + "for Stockwell method, freqs must be a length-2 iterable " + f'or "auto", got {freqs}.' + ) + method_kw.update(fmin=fmin, fmax=fmax) + # Compute freqs. We need a couple lines of code dupe here (also in + # BaseTFR.__init__) to get the subset of times to pass to _check_input_st() + _mask = _time_mask(inst.times, tmin, tmax, sfreq=inst.info["sfreq"]) + _times = inst.times[_mask].copy() + _, default_nfft, _ = _check_input_st(_times, None) + n_fft = method_kw.get("n_fft", default_nfft) + *_, freqs = _compute_freqs_st(fmin, fmax, n_fft, inst.info["sfreq"]) + + # use Evoked.comment or str(Epochs.event_id) as the default comment... + if comment is None: + comment = getattr(inst, "comment", ",".join(getattr(inst, "event_id", ""))) + # ...but don't overwrite if it's coming in with a comment already set + if isinstance(inst, dict): + inst.setdefault("comment", comment) + else: + self._comment = getattr(self, "_comment", comment) + super().__init__( + inst, + method, + freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + + def __getstate__(self): + """Prepare AverageTFR object for serialization.""" + out = super().__getstate__() + out.update(nave=self.nave, comment=self.comment) + # NOTE: self._itc should never exist in the instance returned to the user; it + # is temporarily present in the output from the tfr_array_* function, and is + # split out into a separate AverageTFR object (and deleted from the object + # holding power estimates) before those objects are passed back to the user. + # The following lines are there because we make use of __getstate__ to achieve + # that splitting of objects. 
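        # Illustrative sketch (annotation, not library code): the state dict
        # built here is also the payload for dict-based construction, which
        # routes through __setstate__. Assuming `avg_tfr` is an AverageTFR::
        #
        #     state = avg_tfr.__getstate__()
        #     clone = AverageTFR(inst=state)
        #     assert clone.nave == avg_tfr.nave and clone.comment == avg_tfr.comment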
+ if hasattr(self, "_itc"): + out.update(itc=self._itc) + return out + + def __setstate__(self, state): + """Unpack AverageTFR from serialized format.""" + super().__setstate__(state) + self._comment = state.get("comment", "") + self._nave = state.get("nave", 1) + + @property + def comment(self): + return self._comment + + @comment.setter + def comment(self, comment): + self._comment = comment + + @property + def nave(self): + return self._nave + + @nave.setter + def nave(self, nave): + self._nave = nave + + def _get_instance_data(self, time_mask): + # AverageTFRs can be constructed from Epochs data, so we triage shape here. + # Evoked data get a fake singleton "epoch" axis prepended + dim = slice(None) if _get_instance_type_string(self) == "Epochs" else np.newaxis + data = self.inst.get_data(picks=self._picks)[dim, :, time_mask] + self._nave = getattr(self.inst, "nave", data.shape[0]) + return data + + +@fill_doc +class AverageTFRArray(AverageTFR): + """Data object for *precomputed* spectrotemporal representations of averaged data. + + Parameters + ---------- + %(info_not_none)s + %(data_tfr)s + %(times)s + %(freqs_tfr_array)s + nave : int + The number of averaged TFRs. + %(comment_averagetfr_attr)s + %(method_tfr_array)s + + Attributes + ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(comment_averagetfr_attr)s + %(freqs_tfr_attr)s + %(info_not_none)s + %(method_tfr_attr)s + %(nave_tfr_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s + + See Also + -------- + AverageTFR + EpochsTFRArray + mne.Epochs.compute_tfr + mne.Evoked.compute_tfr + """ + + def __init__( + self, info, data, times, freqs, *, nave=None, comment=None, method=None + ): + state = dict(info=info, data=data, times=times, freqs=freqs) + for name, optional in dict(nave=nave, comment=comment, method=method).items(): + if optional is not None: + state[name] = optional + self.__setstate__(state) + + +@fill_doc +class EpochsTFR(BaseTFR, GetEpochsMixin): + """Data object for spectrotemporal representations of epoched data. + + .. important:: + The preferred means of creating EpochsTFR objects from :class:`~mne.Epochs` + objects is via the instance method :meth:`~mne.Epochs.compute_tfr`. + To create an EpochsTFR object from pre-computed data (i.e., a NumPy array) use + :class:`~mne.time_frequency.EpochsTFRArray`. + + Parameters + ---------- + %(info_not_none)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + data : ndarray, shape (n_channels, n_freqs, n_times) + The data. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + times : ndarray, shape (n_times,) + The time values in seconds. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead and + (optionally) use ``tmin`` and ``tmax`` to restrict the time domain; or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + %(freqs_tfr_epochs)s + inst : instance of Epochs + The data from which to compute the time-frequency representation. + %(method_tfr_epochs)s + %(comment_tfr_attr)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(decim_tfr)s + %(events_epochstfr)s + + .. 
deprecated:: 1.7
+           Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
+           :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
+    %(event_id_epochstfr)s
+
+        .. deprecated:: 1.7
+           Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
+           :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
+    selection : array
+        List of indices of selected events (not dropped or ignored etc.). For
+        example, if the original event array had 4 events and the second event
+        has been dropped, this attribute would be np.array([0, 2, 3]).
+
+        .. deprecated:: 1.7
+           Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
+           :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
+    drop_log : tuple of tuple
+        A tuple of the same length as the event array used to initialize the
+        ``EpochsTFR`` object. If the i-th original event is still part of the
+        selection, drop_log[i] will be an empty tuple; otherwise it will be
+        a tuple of the reasons the event is no longer in the selection, e.g.:
+
+        - ``'IGNORED'``
+            If it isn't part of the current subset defined by the user
+        - ``'NO_DATA'`` or ``'TOO_SHORT'``
+            If the epoch didn't contain enough data, or the names of the
+            channels that exceeded the amplitude threshold
+        - ``'EQUALIZED_COUNTS'``
+            See :meth:`~mne.Epochs.equalize_event_counts`
+        - ``'USER'``
+            For user-defined reasons (see :meth:`~mne.Epochs.drop`).
+
+        .. deprecated:: 1.7
+           Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
+           :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
+    %(metadata_epochstfr)s
+
+        .. deprecated:: 1.7
+           Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
+           :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
+    %(n_jobs)s
+    %(verbose)s
+    %(method_kw_tfr)s
+
+    Attributes
+    ----------
+    %(baseline_tfr_attr)s
+    %(ch_names_tfr_attr)s
+    %(comment_tfr_attr)s
+    %(drop_log)s
+    %(event_id_attr)s
+    %(events_attr)s
+    %(freqs_tfr_attr)s
+    %(info_not_none)s
+    %(metadata_attr)s
+    %(method_tfr_attr)s
+    %(selection_attr)s
+    %(sfreq_tfr_attr)s
+    %(shape_tfr_attr)s
+
+    See Also
+    --------
+    mne.Epochs.compute_tfr
+    RawTFR
+    AverageTFR
+    EpochsTFRArray
+
+    References
+    ----------
+    .. footbibliography::
+    """
+
+    def __init__(
+        self,
+        info=None,
+        data=None,
+        times=None,
+        freqs=None,
+        *,
+        inst=None,
+        method=None,
+        comment=None,
+        tmin=None,
+        tmax=None,
+        picks=None,
+        proj=False,
+        decim=1,
+        events=None,
+        event_id=None,
+        selection=None,
+        drop_log=None,
+        metadata=None,
+        n_jobs=None,
+        verbose=None,
+        **method_kw,
+    ):
+        from ..epochs import BaseEpochs
+
+        # deprecations. TODO remove after 1.7 release
+        depr_params = dict(info=info, data=data, times=times, comment=comment)
+        bad_params = list()
+        for name, param in depr_params.items():
+            if param is not None:
+                bad_params.append(name)
+        if len(bad_params):
+            _s = _pl(bad_params)
+            is_are = _pl(bad_params, "is", "are")
+            bad_params_list = '", "'.join(bad_params)
+            warn(
+                f'Parameter{_s} "{bad_params_list}" {is_are} deprecated and will be '
+                "removed in version 1.8. For a quick fix, use ``EpochsTFRArray`` with "
+                "the same parameters. For a long-term fix, see the docstring notes.",
+                FutureWarning,
+            )
+            if inst is not None:
+                raise ValueError(
+                    "Do not pass `inst` alongside deprecated params "
+                    f'"{bad_params_list}"; see docstring of EpochsTFR for guidance.'
+                )
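+            # Editor's sketch of the migration named in the warning above
+            # (hypothetical variables, mirroring the old positional signature):
+            #     EpochsTFR(info, data, times, freqs)       # deprecated
+            #     EpochsTFRArray(info, data, times, freqs)  # quick fix
+            #     epochs.compute_tfr(method, freqs)         # long-term fix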
+        # sensible defaults are created in __setstate__ so only pass these through
+        # if they're user-specified
+        optional = dict(
+            freqs=freqs,
+            method=method,
+            events=events,
+            event_id=event_id,
+            selection=selection,
+            drop_log=drop_log,
+            metadata=metadata,
+        )
+        optional_params = {
+            key: val for key, val in optional.items() if val is not None
+        }
+        inst = depr_params | optional_params
+        # end TODO ↑↑↑↑↑↑
+
+        # dict is allowed for __setstate__ compatibility
+        _validate_type(
+            inst, (BaseEpochs, dict), "object passed to EpochsTFR constructor", "Epochs"
+        )
+        super().__init__(
+            inst,
+            method,
+            freqs,
+            tmin=tmin,
+            tmax=tmax,
+            picks=picks,
+            proj=proj,
+            decim=decim,
+            n_jobs=n_jobs,
+            verbose=verbose,
+            **method_kw,
+        )
+
+    @fill_doc
+    def __getitem__(self, item):
+        """Subselect epochs from an EpochsTFR.
+
+        Parameters
+        ----------
+        %(item)s
+            Access options are the same as for :class:`~mne.Epochs` objects, see the
+            docstring Notes section of :meth:`mne.Epochs.__getitem__` for explanation.
+
+        Returns
+        -------
+        %(getitem_epochstfr_return)s
+        """
+        return super().__getitem__(item)
+
+    def __getstate__(self):
+        """Prepare EpochsTFR object for serialization."""
+        out = super().__getstate__()
+        out.update(
+            metadata=self._metadata,
+            drop_log=self.drop_log,
+            event_id=self.event_id,
+            events=self.events,
+            selection=self.selection,
+            raw_times=self._raw_times,
+        )
+        return out
+
+    def __setstate__(self, state):
+        """Unpack EpochsTFR from serialized format."""
+        if state["data"].ndim != 4:
+            raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.")
+        super().__setstate__(state)
+        self._metadata = state.get("metadata", None)
+        n_epochs = self.shape[0]
+        n_times = self.shape[-1]
+        fake_samps = np.linspace(
+            n_times, n_times * (n_epochs + 1), n_epochs, dtype=int, endpoint=False
+        )
+        fake_events = np.dstack(
+            (fake_samps, np.zeros_like(fake_samps), np.ones_like(fake_samps))
+        ).squeeze(axis=0)
+        self.events = state.get("events", _ensure_events(fake_events))
+        self.event_id = state.get("event_id", _check_event_id(None, self.events))
+        self.drop_log = state.get("drop_log", tuple())
+        self.selection = state.get("selection", np.arange(n_epochs))
+        self._bad_dropped = True  # always true, needed for `equalize_event_counts()`
+
+    def __next__(self, return_event_id=False):
+        """Iterate over EpochsTFR objects.
+
+        NOTE: __iter__() and _stop_iter() are defined by the GetEpochsMixin.
+
+        Parameters
+        ----------
+        return_event_id : bool
+            If ``True``, return both the EpochsTFR data and its associated ``event_id``.
+
+        Returns
+        -------
+        epoch : array of shape (n_channels, n_freqs, n_times)
+            The single-epoch time-frequency data.
+        event_id : int
+            The integer event id associated with the epoch. Only returned if
+            ``return_event_id`` is ``True``.
+        """
+        if self._current >= len(self._data):
+            self._stop_iter()
+        epoch = self._data[self._current]
+        event_id = self.events[self._current][-1]
+        self._current += 1
+        if return_event_id:
+            return epoch, event_id
+        return epoch
+
+    def _check_singleton(self):
+        """Check if self contains only one Epoch, and return it as an AverageTFR."""
+        if self.shape[0] > 1:
+            calling_func = inspect.currentframe().f_back.f_code.co_name
+            raise NotImplementedError(
+                f"Cannot call {calling_func}() from EpochsTFR with multiple epochs; "
+                "please subselect a single epoch before plotting."
+            )
+        return list(self.iter_evoked())[0]
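+
+    # Editor's note: the plot/plot_topo/plot_joint/plot_topomap methods below all
+    # funnel through _check_singleton(), so a minimal sketch is (assuming
+    # ``epochs_tfr`` is an existing EpochsTFR with MEG channels):
+    #     epochs_tfr[:1].plot(picks="mag")  # subselect one epoch, then plot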
+
+    def _get_instance_data(self, time_mask):
+        return self.inst.get_data(picks=self._picks)[:, :, time_mask]
+
+    def _update_epoch_attributes(self):
+        # adjust dims and shape
+        if self.method != "stockwell":  # stockwell consumes epochs dimension
+            self._dims = ("epoch",) + self._dims
+            self._shape = (len(self.inst),) + self._shape
+        # we need these for to_data_frame()
+        self.event_id = self.inst.event_id.copy()
+        self.events = self.inst.events.copy()
+        self.selection = self.inst.selection.copy()
+        # we need these for __getitem__()
+        self.drop_log = deepcopy(self.inst.drop_log)
+        self._metadata = self.inst.metadata
+        # we need this for compatibility with equalize_event_counts()
+        self._bad_dropped = True
+
+    def average(self, method="mean", *, dim="epochs", copy=False):
+        """Aggregate the EpochsTFR across epochs, frequencies, or times.
+
+        Parameters
+        ----------
+        method : "mean" | "median" | callable
+            How to aggregate the data across the given ``dim``. If callable,
+            must take a :class:`NumPy array <numpy.ndarray>` of shape
+            ``(n_epochs, n_channels, n_freqs, n_times)`` and return an array
+            with one fewer dimensions (which dimension is collapsed depends on
+            the value of ``dim``). Default is ``"mean"``.
+        dim : "epochs" | "freqs" | "times"
+            The dimension along which to combine the data.
+        copy : bool
+            Whether to return a copy of the modified instance, or modify in place.
+            Ignored when ``dim="epochs"`` or ``"times"`` because those options return
+            different types (:class:`~mne.time_frequency.AverageTFR` and
+            :class:`~mne.time_frequency.EpochsSpectrum`, respectively).
+
+        Returns
+        -------
+        tfr : instance of EpochsTFR | AverageTFR | EpochsSpectrum
+            The aggregated TFR object.
+
+        Notes
+        -----
+        Passing in ``np.median`` is considered unsafe for complex data; pass
+        the string ``"median"`` instead to compute the *marginal* median
+        (i.e. the median of the real and imaginary components separately).
+        See discussion here:
+
+        https://github.com/scipy/scipy/pull/12676#issuecomment-783370228
+        """
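+        # Editor's sketch of the three return types documented above (assuming
+        # ``epochs_tfr`` is an existing EpochsTFR instance):
+        #     epochs_tfr.average()                       # -> AverageTFR
+        #     epochs_tfr.average("median", dim="freqs")  # -> EpochsTFR (in place)
+        #     epochs_tfr.average(dim="times")            # -> EpochsSpectrum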
+        _check_option("dim", dim, ("epochs", "freqs", "times"))
+        axis = self._dims.index(dim[:-1])  # self._dims entries aren't plural
+
+        func = _check_combine(mode=method, axis=axis)
+        data = func(self.data)
+
+        n_epochs, n_channels, n_freqs, n_times = self.data.shape
+        freqs, times = self.freqs, self.times
+        if dim == "epochs":
+            expected_shape = self._data.shape[1:]
+        elif dim == "freqs":
+            expected_shape = (n_epochs, n_channels, n_times)
+            freqs = np.mean(self.freqs, keepdims=True)
+        elif dim == "times":
+            expected_shape = (n_epochs, n_channels, n_freqs)
+            times = np.mean(self.times, keepdims=True)
+
+        if data.shape != expected_shape:
+            raise RuntimeError(
+                "EpochsTFR.average() got a method that resulted in data of shape "
+                f"{data.shape}, but it should be {expected_shape}."
+            )
+        state = self.__getstate__()
+        # restore singleton freqs axis (not necessary for epochs/times: class changes)
+        if dim == "freqs":
+            data = np.expand_dims(data, axis=axis)
+        else:
+            state["dims"] = (*state["dims"][:axis], *state["dims"][axis + 1 :])
+        state["data"] = data
+        state["info"] = deepcopy(self.info)
+        state["freqs"] = freqs
+        state["times"] = times
+        if dim == "epochs":
+            state["inst_type_str"] = "Evoked"
+            state["nave"] = n_epochs
+            state["comment"] = f"{method} of {n_epochs} EpochsTFR{_pl(n_epochs)}"
+            out = AverageTFR(inst=state)
+            out._data_type = "Average Power"
+            return out
+
+        elif dim == "times":
+            return EpochsSpectrum(
+                state,
+                method=None,
+                fmin=None,
+                fmax=None,
+                tmin=None,
+                tmax=None,
+                picks=None,
+                exclude=None,
+                proj=None,
+                remove_dc=None,
+                n_jobs=None,
+            )
+        # ↓↓↓ these two are for dim == "freqs"
+        elif copy:
+            return EpochsTFR(inst=state, method=None, freqs=None)
+        else:
+            self._data = np.expand_dims(data, axis=axis)
+            self._freqs = freqs
+            return self
+
+    @verbose
+    def drop(self, indices, reason="USER", verbose=None):
+        """Drop epochs based on indices or boolean mask.
+
+        .. note:: The indices refer to the current set of undropped epochs
+                  rather than the complete set of dropped and undropped epochs.
+                  They are therefore not necessarily consistent with any
+                  external indices (e.g., behavioral logs). To drop epochs
+                  based on external criteria, do not use the ``preload=True``
+                  flag when constructing an Epochs object, and call this
+                  method before calling the :meth:`mne.Epochs.drop_bad` or
+                  :meth:`mne.Epochs.load_data` methods.
+
+        Parameters
+        ----------
+        indices : array of int or bool
+            Set epochs to remove by specifying indices to remove or a boolean
+            mask to apply (where True values get removed). Events are
+            correspondingly modified.
+        reason : str
+            Reason for dropping the epochs ('ECG', 'timeout', 'blink', etc.).
+            Default: 'USER'.
+        %(verbose)s
+
+        Returns
+        -------
+        epochs : instance of Epochs or EpochsTFR
+            The epochs with indices dropped. Operates in-place.
+        """
+        from ..epochs import BaseEpochs
+
+        BaseEpochs.drop(self, indices=indices, reason=reason, verbose=verbose)
+
+        return self
+
+    def iter_evoked(self, copy=False):
+        """Iterate over EpochsTFR to yield a sequence of AverageTFR objects.
+
+        The AverageTFR objects will each contain a single epoch (i.e., no averaging is
+        performed). This method resets the EpochsTFR instance's iteration state to the
+        first epoch.
+
+        Parameters
+        ----------
+        copy : bool
+            Whether to yield copies of the data and measurement info, or views/pointers.
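+
+        Examples
+        --------
+        A minimal sketch (assuming ``epochs_tfr`` is an existing EpochsTFR)::
+
+            for average_tfr in epochs_tfr.iter_evoked():
+                print(average_tfr.nave)  # always 1: each object holds one epoch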
+ """ + self.__iter__() + state = self.__getstate__() + state["inst_type_str"] = "Evoked" + state["dims"] = state["dims"][1:] # drop "epochs" + + while True: + try: + data, event_id = self.__next__(return_event_id=True) + except StopIteration: + break + if copy: + state["info"] = deepcopy(self.info) + state["data"] = data.copy() + else: + state["data"] = data + state["nave"] = 1 + yield AverageTFR(inst=state, method=None, freqs=None, comment=str(event_id)) + + @verbose + @copy_doc(BaseTFR.plot) + def plot( + self, + picks=None, + *, + exclude=(), + tmin=None, + tmax=None, + fmin=None, + fmax=None, + baseline=None, + mode="mean", + dB=False, + combine=None, + layout=None, # TODO deprecate; not used in orig implementation + yscale="auto", + vmin=None, + vmax=None, + vlim=(None, None), + cnorm=None, + cmap=None, + colorbar=True, + title=None, # don't deprecate this one; has (useful) option title="auto" + mask=None, + mask_style=None, + mask_cmap="Greys", + mask_alpha=0.1, + axes=None, + show=True, + verbose=None, + ): + singleton_epoch = self._check_singleton() + return singleton_epoch.plot( + picks=picks, + exclude=exclude, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + dB=dB, + combine=combine, + layout=layout, + yscale=yscale, + vmin=vmin, + vmax=vmax, + vlim=vlim, + cnorm=cnorm, + cmap=cmap, + colorbar=colorbar, + title=title, + mask=mask, + mask_style=mask_style, + mask_cmap=mask_cmap, + mask_alpha=mask_alpha, + axes=axes, + show=show, + verbose=verbose, + ) + + @verbose + @copy_doc(BaseTFR.plot_topo) + def plot_topo( + self, + picks=None, + baseline=None, + mode="mean", + tmin=None, + tmax=None, + fmin=None, + fmax=None, + vmin=None, # TODO deprecate in favor of `vlim` (needs helper func refactor) + vmax=None, + layout=None, + cmap=None, + title=None, # don't deprecate; topo titles aren't standard (color, size, just.) 
+ dB=False, + colorbar=True, + layout_scale=0.945, + show=True, + border="none", + fig_facecolor="k", + fig_background=None, + font_color="w", + yscale="auto", + verbose=None, + ): + singleton_epoch = self._check_singleton() + return singleton_epoch.plot_topo( + picks=picks, + baseline=baseline, + mode=mode, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + vmin=vmin, + vmax=vmax, + layout=layout, + cmap=cmap, + title=title, + dB=dB, + colorbar=colorbar, + layout_scale=layout_scale, + show=show, + border=border, + fig_facecolor=fig_facecolor, + fig_background=fig_background, + font_color=font_color, + yscale=yscale, + verbose=verbose, + ) + + @verbose + @copy_doc(BaseTFR.plot_joint) + def plot_joint( + self, + *, + timefreqs=None, + picks=None, + exclude=(), + combine="mean", + tmin=None, + tmax=None, + fmin=None, + fmax=None, + baseline=None, + mode="mean", + dB=False, + yscale="auto", + vmin=None, + vmax=None, + vlim=(None, None), + cnorm=None, + cmap=None, + colorbar=True, + title=None, + show=True, + topomap_args=None, + image_args=None, + verbose=None, + ): + singleton_epoch = self._check_singleton() + return singleton_epoch.plot_joint( + timefreqs=timefreqs, + picks=picks, + exclude=exclude, + combine=combine, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + dB=dB, + yscale=yscale, + vmin=vmin, + vmax=vmax, + vlim=vlim, + cnorm=cnorm, + cmap=cmap, + colorbar=colorbar, title=title, - border=border, - x_label="Time (s)", - y_label="Frequency (Hz)", - fig_facecolor=fig_facecolor, - font_color=font_color, - unified=True, - img=True, + show=show, + topomap_args=topomap_args, + image_args=image_args, + verbose=verbose, ) - add_background_image(fig, fig_background) - plt_show(show) - return fig - - @copy_function_doc_to_method_doc(plot_tfr_topomap) + @copy_doc(BaseTFR.plot_topomap) def plot_topomap( self, tmin=None, @@ -2505,8 +3682,8 @@ def plot_topomap( axes=None, show=True, ): - return plot_tfr_topomap( - self, + singleton_epoch = self._check_singleton() + return singleton_epoch.plot_topomap( tmin=tmin, tmax=tmax, fmin=fmin, @@ -2536,160 +3713,55 @@ def plot_topomap( show=show, ) - def _check_compat(self, tfr): - """Check that self and tfr have the same time-frequency ranges.""" - assert np.all(tfr.times == self.times) - assert np.all(tfr.freqs == self.freqs) - - def __add__(self, tfr): # noqa: D105 - """Add instances.""" - self._check_compat(tfr) - out = self.copy() - out.data += tfr.data - return out - - def __iadd__(self, tfr): # noqa: D105 - self._check_compat(tfr) - self.data += tfr.data - return self - - def __sub__(self, tfr): # noqa: D105 - """Subtract instances.""" - self._check_compat(tfr) - out = self.copy() - out.data -= tfr.data - return out - - def __isub__(self, tfr): # noqa: D105 - self._check_compat(tfr) - self.data -= tfr.data - return self - - def __truediv__(self, a): # noqa: D105 - """Divide instances.""" - out = self.copy() - out /= a - return out - - def __itruediv__(self, a): # noqa: D105 - self.data /= a - return self - - def __mul__(self, a): - """Multiply source instances.""" - out = self.copy() - out *= a - return out - - def __imul__(self, a): # noqa: D105 - self.data *= a - return self - - def __repr__(self): # noqa: D105 - s = "time : [%f, %f]" % (self.times[0], self.times[-1]) - s += ", freq : [%f, %f]" % (self.freqs[0], self.freqs[-1]) - s += ", nave : %d" % self.nave - s += ", channels : %d" % self.data.shape[0] - s += ", ~%s" % (sizeof_fmt(self._size),) - return "" % s - @fill_doc -class EpochsTFR(_BaseTFR, 
GetEpochsMixin): - """Container for Time-Frequency data on epochs. - - Can for example store induced power at sensor level. +class EpochsTFRArray(EpochsTFR): + """Data object for *precomputed* spectrotemporal representations of epoched data. Parameters ---------- %(info_not_none)s - data : ndarray, shape (n_epochs, n_channels, n_freqs, n_times) - The data. - times : ndarray, shape (n_times,) - The time values in seconds. - freqs : ndarray, shape (n_freqs,) - The frequencies in Hz. - comment : str | None, default None - Comment on the data, e.g., the experimental condition. - method : str | None, default None - Comment on the method used to compute the data, e.g., morlet wavelet. - events : ndarray, shape (n_events, 3) | None - The events as stored in the Epochs class. If None (default), all event - values are set to 1 and event time-samples are set to range(n_epochs). - event_id : dict | None - Example: dict(auditory=1, visual=3). They keys can be used to access - associated events. If None, all events will be used and a dict is - created with string integer names corresponding to the event id - integers. - selection : iterable | None - Iterable of indices of selected epochs. If ``None``, will be - automatically generated, corresponding to all non-zero events. - - .. versionadded:: 0.23 - drop_log : tuple | None - Tuple of tuple of strings indicating which epochs have been marked to - be ignored. - - .. versionadded:: 0.23 - metadata : instance of pandas.DataFrame | None - A :class:`pandas.DataFrame` containing pertinent information for each - trial. See :class:`mne.Epochs` for further details. - %(verbose)s + %(data_tfr)s + %(times)s + %(freqs_tfr_array)s + %(comment_tfr_attr)s + %(method_tfr_array)s + %(events_epochstfr)s + %(event_id_epochstfr)s + %(selection)s + %(drop_log)s + %(metadata_epochstfr)s Attributes ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(comment_tfr_attr)s + %(drop_log)s + %(event_id_attr)s + %(events_attr)s + %(freqs_tfr_attr)s %(info_not_none)s - ch_names : list - The names of the channels. - data : ndarray, shape (n_epochs, n_channels, n_freqs, n_times) - The data array. - times : ndarray, shape (n_times,) - The time values in seconds. - freqs : ndarray, shape (n_freqs,) - The frequencies in Hz. - comment : string - Comment on dataset. Can be the condition. - method : str | None, default None - Comment on the method used to compute the data, e.g., morlet wavelet. - events : ndarray, shape (n_events, 3) | None - Array containing sample information as event_id - event_id : dict | None - Names of conditions correspond to event_ids - selection : array - List of indices of selected events (not dropped or ignored etc.). For - example, if the original event array had 4 events and the second event - has been dropped, this attribute would be np.array([0, 2, 3]). - drop_log : tuple of tuple - A tuple of the same length as the event array used to initialize the - ``EpochsTFR`` object. If the i-th original event is still part of the - selection, drop_log[i] will be an empty tuple; otherwise it will be - a tuple of the reasons the event is not longer in the selection, e.g.: - - - ``'IGNORED'`` - If it isn't part of the current subset defined by the user - - ``'NO_DATA'`` or ``'TOO_SHORT'`` - If epoch didn't contain enough data names of channels that - exceeded the amplitude threshold - - ``'EQUALIZED_COUNTS'`` - See :meth:`~mne.Epochs.equalize_event_counts` - - ``'USER'`` - For user-defined reasons (see :meth:`~mne.Epochs.drop`). 
+ %(metadata_attr)s + %(method_tfr_attr)s + %(selection_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s - metadata : pandas.DataFrame, shape (n_events, n_cols) | None - DataFrame containing pertinent information for each trial - - Notes - ----- - .. versionadded:: 0.13.0 + See Also + -------- + AverageTFR + mne.Epochs.compute_tfr + mne.Evoked.compute_tfr """ - @verbose def __init__( self, info, data, times, freqs, + *, comment=None, method=None, events=None, @@ -2697,170 +3769,204 @@ def __init__( selection=None, drop_log=None, metadata=None, - verbose=None, ): - # noqa: D102 - super().__init__() - self.info = info - if data.ndim != 4: - raise ValueError("data should be 4d. Got %d." % data.ndim) - n_epochs, n_channels, n_freqs, n_times = data.shape - if n_channels != len(info["chs"]): - raise ValueError( - "Number of channels and data size don't match" - " (%d != %d)." % (n_channels, len(info["chs"])) - ) - if n_freqs != len(freqs): - raise ValueError( - "Number of frequencies and data size don't match" - " (%d != %d)." % (n_freqs, len(freqs)) - ) - if n_times != len(times): - raise ValueError( - "Number of times and data size don't match" - " (%d != %d)." % (n_times, len(times)) - ) - if events is None: - n_epochs = len(data) - events = _gen_events(n_epochs) - if selection is None: - n_epochs = len(data) - selection = np.arange(n_epochs) - if drop_log is None: - n_epochs_prerejection = max(len(events), max(selection) + 1) - drop_log = tuple( - () if k in selection else ("IGNORED",) - for k in range(n_epochs_prerejection) - ) - else: - drop_log = drop_log - # check consistency: - assert len(selection) == len(events) - assert len(drop_log) >= len(events) - assert len(selection) == sum((len(dl) == 0 for dl in drop_log)) - event_id = _check_event_id(event_id, events) - self.data = data - self._set_times(np.array(times, dtype=float)) - self._raw_times = self.times.copy() # needed for decimate - self.freqs = np.array(freqs, dtype=float) - self.events = events - self.event_id = event_id - self.selection = selection - self.drop_log = drop_log - self.comment = comment - self.method = method - self.preload = True - self.metadata = metadata + state = dict(info=info, data=data, times=times, freqs=freqs) + optional = dict( + comment=comment, + method=method, + events=events, + event_id=event_id, + selection=selection, + drop_log=drop_log, + metadata=metadata, + ) + for name, value in optional.items(): + if value is not None: + state[name] = value + self.__setstate__(state) - @property - def _detrend_picks(self): - return list() - def __repr__(self): # noqa: D105 - s = "time : [%f, %f]" % (self.times[0], self.times[-1]) - s += ", freq : [%f, %f]" % (self.freqs[0], self.freqs[-1]) - s += ", epochs : %d" % self.data.shape[0] - s += ", channels : %d" % self.data.shape[1] - s += ", ~%s" % (sizeof_fmt(self._size),) - return "" % s +@fill_doc +class RawTFR(BaseTFR): + """Data object for spectrotemporal representations of continuous data. + + .. warning:: The preferred means of creating RawTFR objects from + :class:`~mne.io.Raw` objects is via the instance method + :meth:`~mne.io.Raw.compute_tfr`. Direct class instantiation + is not supported. - def __abs__(self): - """Take the absolute value.""" - epochs = self.copy() - epochs.data = np.abs(self.data) - return epochs + Parameters + ---------- + inst : instance of Raw + The data from which to compute the time-frequency representation. 
+    %(method_tfr)s
+    %(freqs_tfr)s
+    %(tmin_tmax_psd)s
+    %(picks_good_data_noref)s
+    %(proj_psd)s
+    %(reject_by_annotation_tfr)s
+    %(decim_tfr)s
+    %(n_jobs)s
+    %(verbose)s
+    %(method_kw_tfr)s
+
+    Attributes
+    ----------
+    ch_names : list
+        The channel names.
+    freqs : array
+        Frequencies at which the amplitude, power, or Fourier coefficients
+        have been computed.
+    %(info_not_none)s
+    method : str
+        The method used to compute the spectra (``'morlet'``, ``'multitaper'``
+        or ``'stockwell'``).
+
+    See Also
+    --------
+    mne.io.Raw.compute_tfr
+    EpochsTFR
+    AverageTFR
+
+    References
+    ----------
+    .. footbibliography::
+    """
+
+    def __init__(
+        self,
+        inst,
+        method=None,
+        freqs=None,
+        *,
+        tmin=None,
+        tmax=None,
+        picks=None,
+        proj=False,
+        reject_by_annotation=False,
+        decim=1,
+        n_jobs=None,
+        verbose=None,
+        **method_kw,
+    ):
+        from ..io import BaseRaw
+
+        # dict is allowed for __setstate__ compatibility
+        _validate_type(
+            inst, (BaseRaw, dict), "object passed to RawTFR constructor", "Raw"
+        )
+        super().__init__(
+            inst,
+            method,
+            freqs,
+            tmin=tmin,
+            tmax=tmax,
+            picks=picks,
+            proj=proj,
+            reject_by_annotation=reject_by_annotation,
+            decim=decim,
+            n_jobs=n_jobs,
+            verbose=verbose,
+            **method_kw,
+        )

-    def average(self, method="mean", dim="epochs", copy=False):
-        """Average the data across epochs.
+    def __getitem__(self, item):
+        """Get RawTFR data.

         Parameters
         ----------
-        method : str | callable
-            How to combine the data. If "mean"/"median", the mean/median
-            are returned. Otherwise, must be a callable which, when passed
-            an array of shape (n_epochs, n_channels, n_freqs, n_time)
-            returns an array of shape (n_channels, n_freqs, n_time).
-            Note that due to file type limitations, the kind for all
-            these will be "average".
-        dim : 'epochs' | 'freqs' | 'times'
-            The dimension along which to combine the data.
-        copy : bool
-            Whether to return a copy of the modified instance,
-            or modify in place. Ignored when ``dim='epochs'``
-            because a new instance must be returned.
+        item : int | slice | array-like
+            Indexing is similar to a :class:`NumPy array <numpy.ndarray>`; see
+            Notes.

         Returns
         -------
-        ave : instance of AverageTFR | EpochsTFR
-            The averaged data.
+        %(getitem_tfr_return)s

         Notes
         -----
-        Passing in ``np.median`` is considered unsafe when there is complex
-        data because NumPy doesn't compute the marginal median. Numpy currently
-        sorts the complex values by real part and return whatever value is
-        computed. Use with caution. We use the marginal median in the
-        complex case (i.e. the median of each component separately) if
-        one passes in ``median``. See a discussion in scipy:
+        The last axis is always time, the next-to-last axis is always
+        frequency, and the first axis is always channel. If
+        ``method='multitaper'`` and ``output='complex'`` then the second axis
+        will be taper index.

-        https://github.com/scipy/scipy/pull/12676#issuecomment-783370228
+        Integer-, list-, and slice-based indexing is possible:
+
+        - ``raw_tfr[[0, 2]]`` gives the whole time-frequency plane for the
+          first and third channels.
+        - ``raw_tfr[..., :3, :]`` gives the first 3 frequency bins and all
+          times for all channels (and tapers, if present).
+        - ``raw_tfr[..., :100]`` gives the first 100 time samples in all
+          frequency bins for all channels (and tapers).
+        - ``raw_tfr[(4, 7)]`` is the same as ``raw_tfr[4, 7]``.
+
+        .. 
note::
+
+            Unlike :class:`~mne.io.Raw` objects (which return a tuple of the
+            requested data values and the corresponding times), accessing
+            :class:`~mne.time_frequency.RawTFR` values via subscript does
+            **not** return the corresponding frequency bin values. If you need
+            them, use ``RawTFR.freqs[freq_indices]`` or
+            ``RawTFR.get_data(..., return_freqs=True)``.
         """
-        _check_option("dim", dim, ("epochs", "freqs", "times"))
-        axis = dict(epochs=0, freqs=2, times=self.data.ndim - 1)[dim]
+        from ..io import BaseRaw

-        # return a lambda function for computing a combination metric
-        # over epochs
-        func = _check_combine(mode=method, axis=axis)
-        data = func(self.data)
+        self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self)
+        return BaseRaw._getitem(self, item, return_times=False)

-        n_epochs, n_channels, n_freqs, n_times = self.data.shape
-        freqs, times = self.freqs, self.times
+    def _get_instance_data(self, time_mask, reject_by_annotation):
+        start, stop = np.where(time_mask)[0][[0, -1]]
+        rba = "NaN" if reject_by_annotation else None
+        data = self.inst.get_data(
+            self._picks, start, stop + 1, reject_by_annotation=rba
+        )
+        # prepend a singleton "epochs" axis
+        return data[np.newaxis]

-        if dim == "freqs":
-            freqs = np.mean(self.freqs, keepdims=True)
-            n_freqs = 1
-        elif dim == "times":
-            times = np.mean(self.times, keepdims=True)
-            n_times = 1
-        if dim == "epochs":
-            expected_shape = self._data.shape[1:]
-        else:
-            expected_shape = (n_epochs, n_channels, n_freqs, n_times)
-            data = np.expand_dims(data, axis=axis)
-        if data.shape != expected_shape:
-            raise RuntimeError(
-                f"You passed a function that resulted in data of shape "
-                f"{data.shape}, but it should be {expected_shape}."
-            )

+@fill_doc
+class RawTFRArray(RawTFR):
+    """Data object for *precomputed* spectrotemporal representations of continuous data. 
- if dim == "epochs": - return AverageTFR( - info=self.info.copy(), - data=data, - times=times, - freqs=freqs, - nave=self.data.shape[0], - method=self.method, - comment=self.comment, - ) - elif copy: - return EpochsTFR( - info=self.info.copy(), - data=data, - times=times, - freqs=freqs, - method=self.method, - comment=self.comment, - metadata=self.metadata, - events=self.events, - event_id=self.event_id, - ) - else: - self.data = data - self._set_times(times) - self.freqs = freqs - return self + Parameters + ---------- + %(info_not_none)s + %(data_tfr)s + %(times)s + %(freqs_tfr_array)s + %(method_tfr_array)s + + Attributes + ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(freqs_tfr_attr)s + %(info_not_none)s + %(method_tfr_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s + + See Also + -------- + RawTFR + mne.io.Raw.compute_tfr + EpochsTFRArray + AverageTFRArray + """ + + def __init__( + self, + info, + data, + times, + freqs, + *, + method=None, + ): + state = dict(info=info, data=data, times=times, freqs=freqs) + if method is not None: + state["method"] = method + self.__setstate__(state) def combine_tfr(all_tfr, weights="nave"): @@ -2905,10 +4011,10 @@ def combine_tfr(all_tfr, weights="nave"): ch_names = tfr.ch_names for t_ in all_tfr[1:]: assert t_.ch_names == ch_names, ValueError( - "%s and %s do not contain " "the same channels" % (tfr, t_) + f"{tfr} and {t_} do not contain the same channels" ) assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ValueError( - "%s and %s do not contain the same time instants" % (tfr, t_) + f"{tfr} and {t_} do not contain the same time instants" ) # use union of bad channels @@ -2924,6 +4030,7 @@ def combine_tfr(all_tfr, weights="nave"): # Utils +# ↓↓↓↓↓↓↓↓↓↓↓ this is still used in _stockwell.py def _get_data(inst, return_itc): """Get data from Epochs or Evoked instance as epochs x ch x time.""" from ..epochs import BaseEpochs @@ -3017,8 +4124,7 @@ def _preproc_tfr( return data, times, freqs, vmin, vmax -# TODO: Name duplication with mne/utils/mixin.py -def _check_decim(decim): +def _ensure_slice(decim): """Aux function checking the decim parameter.""" _validate_type(decim, ("int-like", slice), "decim") if not isinstance(decim, slice): @@ -3040,10 +4146,11 @@ def write_tfrs(fname, tfr, overwrite=False, *, verbose=None): ---------- fname : path-like The file name, which should end with ``-tfr.h5``. - tfr : AverageTFR | list of AverageTFR | EpochsTFR - The TFR dataset, or list of TFR datasets, to save in one file. - Note. If .comment is not None, a name will be generated on the fly, - based on the order in which the TFR objects are passed. + tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR + The (list of) TFR object(s) to save in one file. If ``tfr.comment`` is ``None``, + a sequential numeric string name will be generated on the fly, based on the + order in which the TFR objects are passed. This can be used to selectively load + single TFR objects from the file later. %(overwrite)s %(verbose)s @@ -3054,92 +4161,116 @@ def write_tfrs(fname, tfr, overwrite=False, *, verbose=None): Notes ----- .. 
versionadded:: 0.9.0
-    """
+    """  # noqa: E501
     _, write_hdf5 = _import_h5io_funcs()
     out = []
     if not isinstance(tfr, (list, tuple)):
         tfr = [tfr]
     for ii, tfr_ in enumerate(tfr):
-        comment = ii if tfr_.comment is None else tfr_.comment
-        out.append(_prepare_write_tfr(tfr_, condition=comment))
+        comment = ii if getattr(tfr_, "comment", None) is None else tfr_.comment
+        state = tfr_.__getstate__()
+        if "metadata" in state:
+            state["metadata"] = _prepare_write_metadata(state["metadata"])
+        out.append((comment, state))
     write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace")


-def _prepare_write_tfr(tfr, condition):
-    """Aux function."""
-    attributes = dict(
-        times=tfr.times,
-        freqs=tfr.freqs,
-        data=tfr.data,
-        info=tfr.info,
-        comment=tfr.comment,
-        method=tfr.method,
-    )
-    if hasattr(tfr, "nave"):  # if AverageTFR
-        attributes["nave"] = tfr.nave
-    elif hasattr(tfr, "events"):  # if EpochsTFR
-        attributes["events"] = tfr.events
-        attributes["event_id"] = tfr.event_id
-        attributes["selection"] = tfr.selection
-        attributes["drop_log"] = tfr.drop_log
-        attributes["metadata"] = _prepare_write_metadata(tfr.metadata)
-    return condition, attributes
-
-
 @verbose
 def read_tfrs(fname, condition=None, *, verbose=None):
-    """Read TFR datasets from hdf5 file.
+    """Load a TFR object from disk.

     Parameters
     ----------
     fname : path-like
-        The file name, which should end with -tfr.h5 .
+        Path to a TFR file in HDF5 format.
     condition : int or str | list of int or str | None
-        The condition to load. If None, all conditions will be returned.
-        Defaults to None.
+        The condition to load. If ``None``, all conditions will be returned.
+        Defaults to ``None``.
     %(verbose)s

     Returns
     -------
-    tfr : AverageTFR | list of AverageTFR | EpochsTFR
-        Depending on ``condition`` either the TFR object or a list of multiple
-        TFR objects.
+    tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR
+        The loaded time-frequency object.

     See Also
     --------
+    mne.time_frequency.RawTFR.save
+    mne.time_frequency.EpochsTFR.save
+    mne.time_frequency.AverageTFR.save
     write_tfrs

     Notes
     -----
     .. versionadded:: 0.9.0
-    """
-    check_fname(fname, "tfr", ("-tfr.h5", "_tfr.h5"))
+    """  # noqa: E501
     read_hdf5, _ = _import_h5io_funcs()
+    fname = _check_fname(fname=fname, overwrite="read", must_exist=False)
+    valid_fnames = tuple(
+        f"{sep}tfr.{ext}" for sep in ("-", "_") for ext in ("h5", "hdf5")
+    )
+    check_fname(fname, "tfr", valid_fnames)
+    logger.info(f"Reading {fname} ...")
+    hdf5_dict = read_hdf5(fname, title="mnepython", slash="replace")
+    # single TFR from TFR.save()
+    if "inst_type_str" in hdf5_dict:
+        if "epoch" in hdf5_dict["dims"]:
+            Klass = EpochsTFR
+        elif "nave" in hdf5_dict:
+            Klass = AverageTFR
+        else:
+            Klass = RawTFR
+        out = Klass(inst=hdf5_dict)
+        if getattr(out, "metadata", None) is not None:
+            out.metadata = _prepare_read_metadata(out.metadata)
+        return out
+    # maybe multiple TFRs from write_tfrs()
+    return _read_multiple_tfrs(hdf5_dict, condition=condition, verbose=verbose)

-    logger.info("Reading %s ..."
% fname)
-    tfr_data = read_hdf5(fname, title="mnepython", slash="replace")
-    for k, tfr in tfr_data:
+
+@verbose
+def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None):
+    """Read (possibly multiple) TFR datasets from an h5 file written by write_tfrs()."""
+    out = list()
+    keys = list()
+    # tfr_data is a list of (comment, tfr_dict) tuples
+    for key, tfr in tfr_data:
+        keys.append(str(key))  # auto-assigned keys are ints
+        is_epochs = tfr["data"].ndim == 4
+        is_average = "nave" in tfr
+        if condition is not None:
+            if not is_average:
+                raise NotImplementedError(
+                    "condition is only supported when reading AverageTFRs."
+                )
+            if key != condition:
+                continue
+        tfr = dict(tfr)
         tfr["info"] = Info(tfr["info"])
         tfr["info"]._check_consistency()
         if "metadata" in tfr:
             tfr["metadata"] = _prepare_read_metadata(tfr["metadata"])
-        is_average = "nave" in tfr
-    if condition is not None:
-        if not is_average:
-            raise NotImplementedError(
-                "condition not supported when reading " "EpochsTFR."
-            )
-        tfr_dict = dict(tfr_data)
-        if condition not in tfr_dict:
-            keys = ["%s" % k for k in tfr_dict]
-            raise ValueError(
-                'Cannot find condition ("{}") in this file. '
-                'The file contains "{}""'.format(condition, " or ".join(keys))
+        # additional keys needed for TFR __setstate__
+        defaults = dict(baseline=None, data_type="Power Estimates")
+        if is_epochs:
+            Klass = EpochsTFR
+            defaults.update(
+                inst_type_str="Epochs", dims=("epoch", "channel", "freq", "time")
             )
-        out = AverageTFR(**tfr_dict[condition])
-    else:
-        inst = AverageTFR if is_average else EpochsTFR
-        out = [inst(**d) for d in list(zip(*tfr_data))[1]]
+        elif is_average:
+            Klass = AverageTFR
+            defaults.update(inst_type_str="Evoked", dims=("channel", "freq", "time"))
+        else:
+            Klass = RawTFR
+            defaults.update(inst_type_str="Raw", dims=("channel", "freq", "time"))
+        out.append(Klass(inst=defaults | tfr))
+    if len(out) == 0:
+        raise ValueError(
+            f'Cannot find condition "{condition}" in this file. '
+            f'The file contains conditions {", ".join(keys)}.'
+        )
+    if len(out) == 1:
+        out = out[0]
     return out
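+
+
+# Editor's sketch of a save/load round trip (hypothetical file name; ``tfr`` is
+# any RawTFR, EpochsTFR, or AverageTFR instance):
+#     tfr.save("stim-tfr.h5", overwrite=True)
+#     tfr_loaded = read_tfrs("stim-tfr.h5")  # single objects are returned directly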
@@ -3148,28 +4279,28 @@ def _get_timefreqs(tfr, timefreqs):
     # Input check
     timefreq_error_msg = (
         "Supplied `timefreqs` are somehow malformed. Please supply None, "
-        "a list of tuple pairs, or a dict of such tuple pairs, not: "
+        "a list of tuple pairs, or a dict of such tuple pairs, not {}"
     )
     if isinstance(timefreqs, dict):
         for k, v in timefreqs.items():
             for item in (k, v):
-                if len(item) != 2 or any((not _is_numeric(n) for n in item)):
-                    raise ValueError(timefreq_error_msg, item)
+                if len(item) != 2 or any(not _is_numeric(n) for n in item):
+                    raise ValueError(timefreq_error_msg.format(item))
     elif timefreqs is not None:
         if not hasattr(timefreqs, "__len__"):
-            raise ValueError(timefreq_error_msg, timefreqs)
-        if len(timefreqs) == 2 and all((_is_numeric(v) for v in timefreqs)):
+            raise ValueError(timefreq_error_msg.format(timefreqs))
+        if len(timefreqs) == 2 and all(_is_numeric(v) for v in timefreqs):
             timefreqs = [tuple(timefreqs)]  # stick a pair of numbers in a list
         else:
             for item in timefreqs:
                 if (
                     hasattr(item, "__len__")
                     and len(item) == 2
-                    and all((_is_numeric(n) for n in item))
+                    and all(_is_numeric(n) for n in item)
                 ):
                     pass
                 else:
-                    raise ValueError(timefreq_error_msg, item)
+                    raise ValueError(timefreq_error_msg.format(item))

     # If None, automatic identification of max peak
     else:
@@ -3196,59 +4327,66 @@
     return timefreqs


-def _preproc_tfr_instance(
-    tfr,
-    picks,
-    tmin,
-    tmax,
-    fmin,
-    fmax,
-    vmin,
-    vmax,
-    dB,
-    mode,
-    baseline,
-    exclude,
-    copy=True,
-):
-    """Baseline and truncate (times and freqs) a TFR instance."""
-    tfr = tfr.copy() if copy else tfr
-
-    exclude = None if picks is None else exclude
-    picks = _picks_to_idx(tfr.info, picks, exclude="bads")
-    pick_names = [tfr.info["ch_names"][pick] for pick in picks]
-    tfr.pick(pick_names)
-
-    if exclude == "bads":
-        exclude = [ch for ch in tfr.info["bads"] if ch in tfr.info["ch_names"]]
-    if exclude is not None:
-        tfr.drop_channels(exclude)
-
-    data, times, freqs, _, _ = _preproc_tfr(
-        tfr.data,
-        tfr.times,
-        tfr.freqs,
-        tmin,
-        tmax,
-        fmin,
-        fmax,
-        mode,
-        baseline,
-        vmin,
-        vmax,
-        dB,
-        tfr.info["sfreq"],
-        copy=False,
-    )
-
-    tfr._set_times(times)
-    tfr.freqs = freqs
-    tfr.data = data
-
-    return tfr
-
-
 def _check_tfr_complex(tfr, reason="source space estimation"):
     """Check that time-frequency epochs or average data is complex."""
     if not np.iscomplexobj(tfr.data):
         raise RuntimeError(f"Time-frequency data must be complex for {reason}")
+
+
+def _merge_if_grads(data, info, ch_type, sphere, combine=None):
+    if ch_type == "grad":
+        grad_picks = _pair_grad_sensors(info, topomap_coords=False)
+        pos = _find_topomap_coords(info, picks=grad_picks[::2], sphere=sphere)
+        grad_method = combine if isinstance(combine, str) else "rms"
+        data, _ = _merge_ch_data(data[grad_picks], ch_type, [], method=grad_method)
+    else:
+        pos, _ = _get_pos_outlines(info, picks=ch_type, sphere=sphere)
+    return data, pos
+
+
+@verbose
+def _prep_data_for_plot(
+    data,
+    times,
+    freqs,
+    *,
+    tmin=None,
+    tmax=None,
+    fmin=None,
+    fmax=None,
+    baseline=None,
+    mode=None,
+    dB=False,
+    verbose=None,
+):
+    # baseline
+    copy = baseline is not None
+    data = rescale(data, times, baseline, mode, copy=copy, verbose=verbose)
+    # crop times
+    time_mask = np.nonzero(_time_mask(times, tmin, tmax))[0]
+    times = times[time_mask]
+    # crop freqs
+    freq_mask = np.nonzero(_time_mask(freqs, fmin, fmax))[0]
+    freqs = freqs[freq_mask]
+    # crop data
+    data = data[..., freq_mask, :][..., time_mask]
+    # complex amplitude → real power; real-valued data is already power (or ITC)
+    if np.iscomplexobj(data):
+        data = (data * data.conj()).real
+    if dB:
+        data = 10 * np.log10(data)
+    return data, times, freqs
+
+
+def 
_warn_deprecated_vmin_vmax(vlim, vmin, vmax): + if vmin is not None or vmax is not None: + warning = "Parameters `vmin` and `vmax` are deprecated, use `vlim` instead." + if vlim[0] is None and vlim[1] is None: + vlim = (vmin, vmax) + else: + warning += ( + " You've also provided a (non-default) value for `vlim`, " + "so `vmin` and `vmax` will be ignored." + ) + warn(warning, FutureWarning) + return vlim diff --git a/mne/transforms.py b/mne/transforms.py index f0efd287f40..7a3875ef56c 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -111,8 +111,8 @@ class Transform(dict): ``'ctf_meg'``, ``'unknown'``. """ - def __init__(self, fro, to, trans=None): # noqa: D102 - super(Transform, self).__init__() + def __init__(self, fro, to, trans=None): + super().__init__() # we could add some better sanity checks here fro = _to_const(fro) to = _to_const(to) @@ -223,8 +223,7 @@ def _print_coord_trans( scale = 1000.0 if (ti != 3 and units != "mm") else 1.0 text = " mm" if ti != 3 else "" log_func( - " % 8.6f % 8.6f % 8.6f %7.2f%s" - % (tt[0], tt[1], tt[2], scale * tt[3], text) + f" {tt[0]:8.6f} {tt[1]:8.6f} {tt[2]:8.6f} {scale * tt[3]:7.2f}{text}" ) @@ -237,13 +236,9 @@ def _find_trans(subject, subjects_dir=None): trans_fnames = glob.glob(str(subjects_dir / subject / "*-trans.fif")) if len(trans_fnames) < 1: - raise RuntimeError( - "Could not find the transformation for " "{subject}".format(subject=subject) - ) + raise RuntimeError(f"Could not find the transformation for {subject}") elif len(trans_fnames) > 1: - raise RuntimeError( - "Found multiple transformations for " "{subject}".format(subject=subject) - ) + raise RuntimeError(f"Found multiple transformations for {subject}") return Path(trans_fnames[0]) @@ -666,7 +661,7 @@ def transform_surface_to(surf, dest, trans, copy=False): if isinstance(dest, str): if dest not in _str_to_frame: raise KeyError( - 'dest must be one of %s, not "%s"' % (list(_str_to_frame.keys()), dest) + f'dest must be one of {list(_str_to_frame.keys())}, not "{dest}"' ) dest = _str_to_frame[dest] # convert to integer if surf["coord_frame"] == dest: @@ -1022,7 +1017,7 @@ def transform(self, pts, verbose=None): dest : shape (n_transform, 3) The transformed points. 
""" - logger.info("Transforming %s points" % (len(pts),)) + logger.info(f"Transforming {len(pts)} points") assert pts.shape[1] == 3 # for memory reasons, we should do this in ~100 MB chunks out = np.zeros_like(pts) @@ -1143,11 +1138,8 @@ def fit( dest_center = _fit_sphere(hsp, disp=False)[1] destination = destination - dest_center logger.info( - " Using centers %s -> %s" - % ( - np.array_str(src_center, None, 3), - np.array_str(dest_center, None, 3), - ) + " Using centers {np.array_str(src_center, None, 3)} -> " + "{np.array_str(dest_center, None, 3)}" ) self._fit_params = dict( n_src=len(source), @@ -1355,6 +1347,28 @@ def _quat_to_affine(quat): return affine +def _affine_to_quat(affine): + assert affine.shape[-2:] == (4, 4) + return np.concatenate( + [rot_to_quat(affine[..., :3, :3]), affine[..., :3, 3]], + axis=-1, + ) + + +def _angle_dist_between_rigid(a, b=None, *, angle_units="rad", distance_units="m"): + a = _affine_to_quat(a) + b = np.zeros(6) if b is None else _affine_to_quat(b) + ang = _angle_between_quats(a[..., :3], b[..., :3]) + dist = np.linalg.norm(a[..., 3:] - b[..., 3:], axis=-1) + assert isinstance(angle_units, str) and angle_units in ("rad", "deg") + if angle_units == "deg": + ang = np.rad2deg(ang) + assert isinstance(distance_units, str) and distance_units in ("m", "mm") + if distance_units == "mm": + dist *= 1e3 + return ang, dist + + def _angle_between_quats(x, y=None): """Compute the ang between two quaternions w/3-element representations.""" # z = conj(x) * y @@ -1554,7 +1568,7 @@ def read_ras_mni_t(subject, subjects_dir=None): def _read_fs_xfm(fname): """Read a Freesurfer transform from a .xfm file.""" assert fname.endswith(".xfm") - with open(fname, "r") as fid: + with open(fname) as fid: logger.debug("Reading FreeSurfer talairach.xfm file:\n%s" % fname) # read lines until we get the string 'Linear_Transform', which precedes @@ -1563,7 +1577,7 @@ def _read_fs_xfm(fname): for li, line in enumerate(fid): if li == 0: kind = line.strip() - logger.debug("Found: %r" % (kind,)) + logger.debug(f"Found: {repr(kind)}") if line[: len(comp)] == comp: # we have the right line, so don't read any more break @@ -1843,10 +1857,7 @@ def _compute_volume_registration( # report some useful information if step in ("translation", "rigid"): - dist = np.linalg.norm(reg_affine[:3, 3]) - angle = np.rad2deg( - _angle_between_quats(np.zeros(3), rot_to_quat(reg_affine[:3, :3])) - ) + angle, dist = _angle_dist_between_rigid(reg_affine, angle_units="deg") logger.info(f" Translation: {dist:6.1f} mm") if step == "rigid": logger.info(f" Rotation: {angle:6.1f}°") diff --git a/mne/utils/__init__.pyi b/mne/utils/__init__.pyi index 42694921f00..e22d8f6166c 100644 --- a/mne/utils/__init__.pyi +++ b/mne/utils/__init__.pyi @@ -32,7 +32,7 @@ __all__ = [ "_check_depth", "_check_dict_keys", "_check_dt", - "_check_edflib_installed", + "_check_edfio_installed", "_check_eeglabio_installed", "_check_event_id", "_check_fname", @@ -41,6 +41,7 @@ __all__ = [ "_check_if_nan", "_check_info_inv", "_check_integer_or_list", + "_check_method_kwargs", "_check_on_missing", "_check_one_ch_type", "_check_option", @@ -230,7 +231,7 @@ from .check import ( _check_compensation_grade, _check_depth, _check_dict_keys, - _check_edflib_installed, + _check_edfio_installed, _check_eeglabio_installed, _check_event_id, _check_fname, @@ -239,6 +240,7 @@ from .check import ( _check_if_nan, _check_info_inv, _check_integer_or_list, + _check_method_kwargs, _check_on_missing, _check_one_ch_type, _check_option, diff --git a/mne/utils/_bunch.py 
b/mne/utils/_bunch.py index 0fdac59139f..26cc4e6b17a 100644 --- a/mne/utils/_bunch.py +++ b/mne/utils/_bunch.py @@ -15,7 +15,7 @@ class Bunch(dict): """Dictionary-like object that exposes its keys as attributes.""" - def __init__(self, **kwargs): # noqa: D102 + def __init__(self, **kwargs): dict.__init__(self, kwargs) self.__dict__ = self @@ -63,7 +63,7 @@ def __new__(cls, name, val): # noqa: D102,D105 return out def __str__(self): # noqa: D105 - return "%s (%s)" % (str(self.__class__.mro()[-2](self)), self._name) + return f"{str(self.__class__.mro()[-2](self))} ({self._name})" __repr__ = __str__ diff --git a/mne/utils/_logging.py b/mne/utils/_logging.py index 1dcb1a5e8a6..f4546e5e7d8 100644 --- a/mne/utils/_logging.py +++ b/mne/utils/_logging.py @@ -159,7 +159,7 @@ class use_log_level: This message will be printed! """ - def __init__(self, verbose=None, *, add_frames=None): # noqa: D102 + def __init__(self, verbose=None, *, add_frames=None): self._level = verbose self._add_frames = add_frames self._old_frames = _filter.add_frames diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 999d6242695..f0e76c70e8a 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -50,7 +50,7 @@ def __new__(self): # noqa: D105 new = str.__new__(self, tempfile.mkdtemp(prefix="tmp_mne_tempdir_")) return new - def __init__(self): # noqa: D102 + def __init__(self): self._path = self.__str__() def __del__(self): # noqa: D105 @@ -121,7 +121,7 @@ def run_command_if_main(): class ArgvSetter: """Temporarily set sys.argv.""" - def __init__(self, args=(), disable_stdout=True, disable_stderr=True): # noqa: D102 + def __init__(self, args=(), disable_stdout=True, disable_stderr=True): self.argv = list(("python",) + args) self.stdout = ClosingStringIO() if disable_stdout else sys.stdout self.stderr = ClosingStringIO() if disable_stderr else sys.stderr @@ -243,16 +243,13 @@ def _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind="MEG"): snr = snrs.min() bad_count = (snrs < min_tol).sum() msg = " (%s)" % msg if msg != "" else msg - assert bad_count == 0, "SNR (worst %0.2f) < %0.2f for %s/%s " "channels%s" % ( - snr, - min_tol, - bad_count, - len(picks), - msg, + assert bad_count == 0, ( + f"SNR (worst {snr:0.2f}) < {min_tol:0.2f} " + f"for {bad_count}/{len(picks)} channels{msg}" ) # median tol snr = np.median(snrs) - assert snr >= med_tol, "%s SNR median %0.2f < %0.2f%s" % (kind, snr, med_tol, msg) + assert snr >= med_tol, f"{kind} SNR median {snr:0.2f} < {med_tol:0.2f}{msg}" def assert_meg_snr( @@ -296,7 +293,7 @@ def assert_snr(actual, desired, tol): """Assert actual and desired arrays are within some SNR tolerance.""" with np.errstate(divide="ignore"): # allow infinite snr = linalg.norm(desired, ord="fro") / linalg.norm(desired - actual, ord="fro") - assert snr >= tol, "%f < %f" % (snr, tol) + assert snr >= tol, f"{snr} < {tol}" def assert_stcs_equal(stc1, stc2): @@ -344,7 +341,7 @@ def assert_dig_allclose(info_py, info_bin, limit=None): d_bin["r"], rtol=1e-5, atol=1e-5, - err_msg="Failure on %s:\n%s\n%s" % (ii, d_py["r"], d_bin["r"]), + err_msg=f"Failure on {ii}:\n{d_py['r']}\n{d_bin['r']}", ) if any(d["kind"] == FIFF.FIFFV_POINT_EXTRA for d in dig_py) and info_py is not None: r_bin, o_head_bin, o_dev_bin = fit_sphere_to_headshape( @@ -368,3 +365,13 @@ def _click_ch_name(fig, ch_index=0, button=1): x = bbox.intervalx.mean() y = bbox.intervaly.mean() _fake_click(fig, fig.mne.ax_main, (x, y), xform="pix", button=button) + + +def _get_suptitle(fig): + """Get fig suptitle (shim for matplotlib 
< 3.8.0).""" + # TODO: obsolete when minimum MPL version is 3.8 + if check_version("matplotlib", "3.8"): + return fig.get_suptitle() + else: + # unreliable hack; should work in most tests as we rarely use `sup_{x,y}label` + return fig.texts[0].get_text() diff --git a/mne/utils/check.py b/mne/utils/check.py index eb8e14de256..80d87cafd2b 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -8,14 +8,13 @@ import operator import os import re -from builtins import input # no-op here but facilitates testing +from builtins import input # noqa: UP029 from difflib import get_close_matches from importlib import import_module -from importlib.metadata import version +from inspect import signature from pathlib import Path import numpy as np -from packaging.version import parse from ..defaults import HEAD_SIZE_DEFAULT, _handle_default from ..fixes import _compare_version, _median_complex @@ -67,14 +66,14 @@ def check_fname(fname, filetype, endings, endings_err=()): if len(endings_err) > 0 and not fname.endswith(endings_err): print_endings = " or ".join([", ".join(endings_err[:-1]), endings_err[-1]]) raise OSError( - "The filename (%s) for file type %s must end with %s" - % (fname, filetype, print_endings) + f"The filename ({fname}) for file type {filetype} must end " + f"with {print_endings}" ) print_endings = " or ".join([", ".join(endings[:-1]), endings[-1]]) if not fname.endswith(endings): warn( - "This filename (%s) does not conform to MNE naming conventions. " - "All %s files should end with %s" % (fname, filetype, print_endings) + f"This filename ({fname}) does not conform to MNE naming conventions. " + f"All {filetype} files should end with {print_endings}" ) @@ -232,11 +231,29 @@ def _check_fname( name="File", need_dir=False, *, + check_bids_split=False, verbose=None, ): """Check for file existence, and return its absolute path.""" _validate_type(fname, "path-like", name) - fname = Path(fname).expanduser().absolute() + # special case for MNE-BIDS, check split + fname_path = Path(fname) + if check_bids_split: + try: + from mne_bids import BIDSPath + except Exception: + pass + else: + if isinstance(fname, BIDSPath) and fname.split is not None: + raise ValueError( + f"Passing a BIDSPath {name} with `{fname.split=}` is unsafe as it " + "can unexpectedly lead to invalid BIDS split naming. Explicitly " + f"set `{name}.split = None` to avoid ambiguity. If you want the " + f"old misleading split naming, you can pass `str({name})`." + ) + + fname = fname_path.expanduser().absolute() + del fname_path if fname.exists(): if not overwrite: @@ -296,19 +313,20 @@ def _check_preload(inst, msg): """Ensure data are preloaded.""" from ..epochs import BaseEpochs from ..evoked import Evoked - from ..time_frequency import _BaseTFR + from ..source_estimate import _BaseSourceEstimate + from ..time_frequency import BaseTFR from ..time_frequency.spectrum import BaseSpectrum - if isinstance(inst, (_BaseTFR, Evoked, BaseSpectrum)): + if isinstance(inst, (BaseTFR, Evoked, BaseSpectrum, _BaseSourceEstimate)): pass else: name = "epochs" if isinstance(inst, BaseEpochs) else "raw" if not inst.preload: raise RuntimeError( "By default, MNE does not load data into main memory to " - "conserve resources. " + msg + " requires %s data to be " + "conserve resources. " + msg + f" requires {name} data to be " "loaded. Use preload=True (or string) in the constructor or " - "%s.load_data()." % (name, name) + f"{name}.load_data()." 
) if name == "epochs": inst._handle_empty("raise", msg) @@ -342,8 +360,8 @@ def _check_compensation_grade(info1, info2, name1, name2="data", ch_names=None): # perform check if grade1 != grade2: raise RuntimeError( - "Compensation grade of %s (%s) and %s (%s) do not match" - % (name1, grade1, name2, grade2) + f"Compensation grade of {name1} ({grade1}) and {name2} ({grade2}) " + "do not match" ) @@ -368,7 +386,6 @@ def indent(x): # Mapping import namespaces to their pypi package name pip_name = dict( sklearn="scikit-learn", - EDFlib="EDFlib-Python", mne_bids="mne-bids", mne_nirs="mne-nirs", mne_features="mne-features", @@ -411,21 +428,9 @@ def _check_eeglabio_installed(strict=True): return _soft_import("eeglabio", "exporting to EEGLab", strict=strict) -def _check_edflib_installed(strict=True): +def _check_edfio_installed(strict=True): """Aux function.""" - out = _soft_import("EDFlib", "exporting to EDF", strict=strict) - if out: - # EDFlib-Python 1.0.7 is not compatible with NumPy 2.0 - # https://gitlab.com/Teuniz/EDFlib-Python/-/issues/10 - ver = version("EDFlib-Python") - if parse(ver) <= parse("1.0.7") and parse(np.__version__).major >= 2: - if strict: # pragma: no cover - raise RuntimeError( - f"EDFlib version={ver} is not compatible with NumPy 2.0, consider " - "upgrading EDFlib-Python" - ) - out = False - return out + return _soft_import("edfio", "exporting to EDF", strict=strict) def _check_pybv_installed(strict=True): @@ -446,8 +451,8 @@ def _check_pandas_index_arguments(index, valid): index = [index] if not isinstance(index, list): raise TypeError( - "index must be `None` or a string or list of strings," - " got type {}.".format(type(index)) + "index must be `None` or a string or list of strings, got type " + f"{type(index)}." ) invalid = set(index) - set(valid) if invalid: @@ -467,8 +472,8 @@ def _check_time_format(time_format, valid, meas_date=None): if time_format not in valid and time_format is not None: valid_str = '", "'.join(valid) raise ValueError( - '"{}" is not a valid time format. Valid options are ' - '"{}" and None.'.format(time_format, valid_str) + f'"{time_format}" is not a valid time format. Valid options are ' + f'"{valid_str}" and None.' 
) # allow datetime only if meas_date available if time_format == "datetime" and meas_date is None: @@ -664,10 +669,11 @@ def _path_like(item): def _check_if_nan(data, msg=" to be plotted"): """Raise if any of the values are NaN.""" if not np.isfinite(data).all(): - raise ValueError("Some of the values {} are NaN.".format(msg)) + raise ValueError(f"Some of the values {msg} are NaN.") -def _check_info_inv(info, forward, data_cov=None, noise_cov=None): +@verbose +def _check_info_inv(info, forward, data_cov=None, noise_cov=None, verbose=None): """Return good channels common to forward model and covariance matrices.""" from .._fiff.pick import pick_types @@ -711,6 +717,19 @@ def _check_info_inv(info, forward, data_cov=None, noise_cov=None): if noise_cov is not None: ch_names = _compare_ch_names(ch_names, noise_cov.ch_names, noise_cov["bads"]) + # inform about excluding any channels apart from bads and reference + all_bads = info["bads"] + ref_chs + if data_cov is not None: + all_bads += data_cov["bads"] + if noise_cov is not None: + all_bads += noise_cov["bads"] + dropped_nonbads = set(info["ch_names"]) - set(ch_names) - set(all_bads) + if dropped_nonbads: + logger.info( + f"Excluding {len(dropped_nonbads)} channel(s) missing from the " + "provided forward operator and/or covariance matrices" + ) + picks = [info["ch_names"].index(k) for k in ch_names if k in info["ch_names"]] return picks @@ -746,9 +765,7 @@ def _check_rank(rank): _validate_type(rank, (None, dict, str), "rank") if isinstance(rank, str): if rank not in ["full", "info"]: - raise ValueError( - 'rank, if str, must be "full" or "info", ' "got %s" % (rank,) - ) + raise ValueError(f'rank, if str, must be "full" or "info", got {rank}') return rank @@ -765,7 +782,13 @@ def _check_one_ch_type(method, info, forward, data_cov=None, noise_cov=None): info_pick = info else: _validate_type(noise_cov, [None, Covariance], "noise_cov") - picks = _check_info_inv(info, forward, data_cov=data_cov, noise_cov=noise_cov) + picks = _check_info_inv( + info, + forward, + data_cov=data_cov, + noise_cov=noise_cov, + verbose=_verbose_safe_false(), + ) info_pick = pick_info(info, picks) ch_types = [_contains_ch_type(info_pick, tt) for tt in ("mag", "grad", "eeg")] if sum(ch_types) > 1: @@ -892,6 +915,7 @@ def _check_all_same_channel_names(instances): def _check_combine(mode, valid=("mean", "median", "std"), axis=0): + # XXX TODO Possibly de-duplicate with _make_combine_callable of mne/viz/utils.py if mode == "mean": def fun(data): @@ -913,7 +937,7 @@ def fun(data): raise ValueError( "Combine option must be " + ", ".join(valid) - + " or callable, got %s (type %s)." % (mode, type(mode)) + + f" or callable, got {mode} (type {type(mode)})." 
) return fun @@ -926,7 +950,7 @@ def _check_src_normal(pick_ori, src): raise RuntimeError( "Normal source orientation is supported only for " "surface or discrete SourceSpaces, got type " - "%s" % (src.kind,) + f"{src.kind}" ) @@ -1068,7 +1092,7 @@ def _check_sphere(sphere, info=None, sphere_units="m"): raise ValueError( "sphere, if a ConductorModel, must be spherical " "with multiple layers, not a BEM or single-layer " - "sphere (got %s)" % (sphere,) + f"sphere (got {sphere})" ) sphere = tuple(sphere["r0"]) + (sphere["layers"][0]["rad"],) sphere_units = "m" @@ -1078,7 +1102,7 @@ def _check_sphere(sphere, info=None, sphere_units="m"): if sphere.shape != (4,): raise ValueError( "sphere must be float or 1D array of shape (4,), got " - "array-like of shape %s" % (sphere.shape,) + f"array-like of shape {sphere.shape}" ) _check_option("sphere_units", sphere_units, ("m", "mm")) if sphere_units == "mm": @@ -1147,9 +1171,9 @@ def _suggest(val, options, cutoff=0.66): if len(options) == 0: return "" elif len(options) == 1: - return " Did you mean %r?" % (options[0],) + return f" Did you mean {repr(options[0])}?" else: - return " Did you mean one of %r?" % (options,) + return f" Did you mean one of {repr(options)}?" def _check_on_missing(on_missing, name="on_missing", *, extras=()): @@ -1220,7 +1244,21 @@ def _import_nibabel(why="use MRI files"): try: import nibabel as nib except ImportError as exp: - raise exp.__class__( - "nibabel is required to %s, got:\n%s" % (why, exp) - ) from None + raise exp.__class__(f"nibabel is required to {why}, got:\n{exp}") from None return nib + + +def _check_method_kwargs(func, kwargs, msg=None): + """Ensure **kwargs are compatible with the function they're passed to.""" + from .misc import _pl + + valid = list(signature(func).parameters) + is_invalid = np.isin(list(kwargs), valid, invert=True) + if is_invalid.any(): + invalid_kw = np.array(list(kwargs))[is_invalid].tolist() + s = _pl(invalid_kw) + if msg is None: + msg = f'function "{func}"' + raise TypeError( + f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} ' f"for {msg}." + ) diff --git a/mne/utils/config.py b/mne/utils/config.py index fe4bc7079a4..9fab1015040 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -10,7 +10,6 @@ import os import os.path as op import platform -import re import shutil import subprocess import sys @@ -218,7 +217,7 @@ def set_memmap_min_size(memmap_min_size): def _load_config(config_path, raise_error=False): """Safely load a config file.""" - with open(config_path, "r") as fid: + with open(config_path) as fid: try: config = json.load(fid) except ValueError: @@ -329,9 +328,9 @@ def get_config(key=None, default=None, raise_error=False, home_dir=None, use_env "for a permanent one" % key ) raise KeyError( - 'Key "%s" not found in %s' - "the mne-python config file (%s). " - "Try %s%s.%s" % (key, loc_env, config_path, meth_env, meth_file, extra_env) + f'Key "{key}" not found in {loc_env}' + f"the mne-python config file ({config_path}). " + f"Try {meth_env}{meth_file}.{extra_env}" ) else: return config.get(key, default) @@ -626,16 +625,6 @@ def sys_info( _validate_type(check_version, (bool, "numeric"), "check_version") ljust = 24 if dependencies == "developer" else 21 platform_str = platform.platform() - if platform.system() == "Darwin" and sys.version_info[:2] < (3, 8): - # platform.platform() in Python < 3.8 doesn't call - # platform.mac_ver() if we're on Darwin, so we don't get a nice macOS - # version number. Therefore, let's do this manually here. 
- macos_ver = platform.mac_ver()[0] - macos_architecture = re.findall("Darwin-.*?-(.*)", platform_str) - if macos_architecture: - macos_architecture = macos_architecture[0] - platform_str = f"macOS-{macos_ver}-{macos_architecture}" - del macos_ver, macos_architecture out = partial(print, end="", file=fid) out("Platform".ljust(ljust) + platform_str + "\n") @@ -660,8 +649,6 @@ def sys_info( "numpy", "scipy", "matplotlib", - "pooch", - "jinja2", "", "# Numerical (optional)", "sklearn", @@ -672,6 +659,8 @@ def sys_info( "openmeeg", "cupy", "pandas", + "h5io", + "h5py", "", "# Visualization (optional)", "pyvista", @@ -695,6 +684,11 @@ def sys_info( "mne-connectivity", "mne-icalabel", "mne-bids-pipeline", + "neo", + "eeglabio", + "edfio", + "mffpy", + "pybv", "", ) if dependencies == "developer": @@ -702,15 +696,28 @@ def sys_info( "# Testing", "pytest", "nbclient", + "statsmodels", "numpydoc", "flake8", "pydocstyle", + "nitime", + "imageio", + "imageio-ffmpeg", + "snirf", "", "# Documentation", "sphinx", "sphinx-gallery", "pydata-sphinx-theme", "", + "# Infrastructure", + "decorator", + "jinja2", + # "lazy-loader", + "packaging", + "pooch", + "tqdm", + "", ) try: unicode = unicode and (sys.stdout.encoding.lower().startswith("utf")) diff --git a/mne/utils/dataframe.py b/mne/utils/dataframe.py index 599a2f88165..95618c614fa 100644 --- a/mne/utils/dataframe.py +++ b/mne/utils/dataframe.py @@ -10,6 +10,7 @@ from ..defaults import _handle_default from ._logging import logger, verbose +from .check import check_version @verbose @@ -17,7 +18,7 @@ def _set_pandas_dtype(df, columns, dtype, verbose=None): """Try to set the right columns to dtype.""" for column in columns: df[column] = df[column].astype(dtype) - logger.info('Converting "%s" to "%s"...' % (column, dtype)) + logger.info(f'Converting "{column}" to "{dtype}"...') def _scale_dataframe_data(inst, data, picks, scalings): @@ -35,7 +36,7 @@ def _scale_dataframe_data(inst, data, picks, scalings): return data -def _convert_times(inst, times, time_format): +def _convert_times(times, time_format, meas_date=None, first_time=0): """Convert vector of time in seconds to ms, datetime, or timedelta.""" # private function; pandas already checked in calling function from pandas import to_timedelta @@ -45,14 +46,22 @@ def _convert_times(inst, times, time_format): elif time_format == "timedelta": times = to_timedelta(times, unit="s") elif time_format == "datetime": - times = to_timedelta(times + inst.first_time, unit="s") + inst.info["meas_date"] + times = to_timedelta(times + first_time, unit="s") + meas_date return times def _inplace(df, method, **kwargs): - """Handle transition: inplace=True (pandas <1.5) → copy=False (>=1.5).""" + # Handle transition: inplace=True (pandas <1.5) → copy=False (>=1.5) + # and 3.0 warning: + # E DeprecationWarning: The copy keyword is deprecated and will be removed in a + # future version. Copy-on-Write is active in pandas since 3.0 which utilizes a + # lazy copy mechanism that defers copies until necessary. Use .copy() to make + # an eager copy if necessary. 
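+    # For example, _inplace(df, "set_index", keys="time") resolves to
+    # df.set_index(keys="time") on pandas >= 3.0 (Copy-on-Write already defers
+    # copies), to df.set_index(keys="time", copy=False) on 1.5 <= pandas < 3.0,
+    # and to df.set_index(keys="time", inplace=True) on older versions.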
_meth = getattr(df, method) # used for set_index() and rename() - if "copy" in signature(_meth).parameters: + + if check_version("pandas", "3.0"): + return _meth(**kwargs) + elif "copy" in signature(_meth).parameters: return _meth(**kwargs, copy=False) else: _meth(**kwargs, inplace=True) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index d8eb668ae04..f29ff9508a5 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -64,42 +64,87 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # A -docdict[ - "accept" -] = """ +tfr_arithmetics_return_template = """ +Returns +------- +tfr : instance of RawTFR | instance of EpochsTFR | instance of AverageTFR + {} +""" + +tfr_add_sub_template = """ +Parameters +---------- +other : instance of RawTFR | instance of EpochsTFR | instance of AverageTFR + The TFR instance to {}. Must have the same type as ``self``, and matching + ``.times`` and ``.freqs`` attributes. + +{} +""" + +tfr_mul_truediv_template = """ +Parameters +---------- +num : int | float + The number to {} by. + +{} +""" + +tfr_arithmetics_return = tfr_arithmetics_return_template.format( + "A new TFR instance, of the same type as ``self``." +) +tfr_inplace_arithmetics_return = tfr_arithmetics_return_template.format( + "The modified TFR instance." +) + +docdict["__add__tfr"] = tfr_add_sub_template.format("add", tfr_arithmetics_return) +docdict["__iadd__tfr"] = tfr_add_sub_template.format( + "add", tfr_inplace_arithmetics_return +) +docdict["__imul__tfr"] = tfr_mul_truediv_template.format( + "multiply", tfr_inplace_arithmetics_return +) +docdict["__isub__tfr"] = tfr_add_sub_template.format( + "subtract", tfr_inplace_arithmetics_return +) +docdict["__itruediv__tfr"] = tfr_mul_truediv_template.format( + "divide", tfr_inplace_arithmetics_return +) +docdict["__mul__tfr"] = tfr_mul_truediv_template.format( + "multiply", tfr_arithmetics_return +) +docdict["__sub__tfr"] = tfr_add_sub_template.format("subtract", tfr_arithmetics_return) +docdict["__truediv__tfr"] = tfr_mul_truediv_template.format( + "divide", tfr_arithmetics_return +) + + +docdict["accept"] = """ accept : bool If True (default False), accept the license terms of this dataset. """ -docdict[ - "add_ch_type_export_params" -] = """ +docdict["add_ch_type_export_params"] = """ add_ch_type : bool Whether to incorporate the channel type into the signal label (e.g. whether to store channel "Fz" as "EEG Fz"). Only used for EDF format. Default is ``False``. """ -docdict[ - "add_data_kwargs" -] = """ +docdict["add_data_kwargs"] = """ add_data_kwargs : dict | None Additional arguments to brain.add_data (e.g., ``dict(time_label_size=10)``). """ -docdict[ - "add_frames" -] = """ +docdict["add_frames"] = """ add_frames : int | None If int, enable (>=1) or disable (0) the printing of stack frame information using formatting. Default (None) does not change the formatting. This can add overhead so is meant only for debugging. """ -docdict[ - "adjacency_clust" -] = """ +docdict["adjacency_clust"] = """ adjacency : scipy.sparse.spmatrix | None | False Defines adjacency between locations in the data, where "locations" can be spatial vertices, frequency bins, time points, etc. 
For spatial vertices @@ -155,25 +200,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["adjacency_clust"].format(**st).format(**groups) ) -docdict[ - "adjust_dig_chpi" -] = """ +docdict["adjust_dig_chpi"] = """ adjust_dig : bool If True, adjust the digitization locations used for fitting based on the positions localized at the start of the file. """ -docdict[ - "agg_fun_psd_topo" -] = """ +docdict["agg_fun_psd_topo"] = """ agg_fun : callable The function used to aggregate over frequencies. Defaults to :func:`numpy.sum` if ``normalize=True``, else :func:`numpy.mean`. """ -docdict[ - "align_view" -] = """ +docdict["align_view"] = """ align : bool If True, consider view arguments relative to canonical MRI directions (closest to MNI for the subject) rather than native MRI @@ -181,16 +220,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): have large rotations). """ -docdict[ - "allow_2d" -] = """ +docdict["allow_2d"] = """ allow_2d : bool If True, allow 2D data as input (i.e. n_samples, n_features). """ -docdict[ - "allow_empty_eltc" -] = """ +docdict["allow_empty_eltc"] = """ allow_empty : bool | str ``False`` (default) will emit an error if there are labels that have no vertices in the source estimate. ``True`` and ``'ignore'`` will return @@ -202,16 +237,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Support for "ignore". """ -docdict[ - "alpha" -] = """ +docdict["alpha"] = """ alpha : float in [0, 1] Alpha level to control opacity. """ -docdict[ - "anonymize_info_notes" -] = """ +docdict["anonymize_info_notes"] = """ Removes potentially identifying information if it exists in ``info``. Specifically for each of the following we use: @@ -240,11 +271,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # raw/epochs/evoked apply_function method # apply_function method summary applyfun_summary = """\ -The function ``fun`` is applied to the channels defined in ``picks``. +The function ``fun`` is applied to the channels or vertices defined in ``picks``. The {} object's data is modified in-place. If the function returns a different data type (e.g. :py:obj:`numpy.complex128`) it must be specified using the ``dtype`` parameter, which causes the data type of **all** the data -to change (even if the function is only applied to channels in ``picks``).{} +to change (even if the function is only applied to channels/vertices in ``picks``).{} .. note:: If ``n_jobs`` > 1, more memory is required as ``len(picks) * n_times`` additional time points need to @@ -260,17 +291,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["applyfun_summary_epochs"] = applyfun_summary.format("epochs", applyfun_preload) docdict["applyfun_summary_evoked"] = applyfun_summary.format("evoked", "") docdict["applyfun_summary_raw"] = applyfun_summary.format("raw", applyfun_preload) +docdict["applyfun_summary_stc"] = applyfun_summary.format("source estimate", "") -docdict[ - "area_alpha_plot_psd" -] = """\ +docdict["area_alpha_plot_psd"] = """\ area_alpha : float Alpha for the area. """ -docdict[ - "area_mode_plot_psd" -] = """\ +docdict["area_mode_plot_psd"] = """\ area_mode : str | None Mode for plotting area. If 'std', the mean +/- 1 STD (across channels) will be plotted. If 'range', the min and max (across channels) will be @@ -278,18 +306,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): If None, no area will be plotted. If average=False, no area is plotted. 
""" -docdict[ - "aseg" -] = """ +docdict["aseg"] = """ aseg : str The anatomical segmentation file. Default ``aparc+aseg``. This may be any anatomical segmentation file in the mri subdirectory of the Freesurfer subject directory. """ -docdict[ - "average_plot_evoked_topomap" -] = """ +docdict["average_plot_evoked_topomap"] = """ average : float | array-like of float, shape (n_times,) | None The time window (in seconds) around a given time point to be used for averaging. For example, 0.2 would translate into a time window that @@ -303,9 +327,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Support for ``array-like`` input. """ -docdict[ - "average_plot_psd" -] = """\ +docdict["average_plot_psd"] = """\ average : bool If False, the PSDs of all channels is displayed. No averaging is done and parameters area_mode and area_alpha are ignored. When @@ -313,9 +335,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): drag) to plot a topomap. """ -docdict[ - "average_psd" -] = """\ +docdict["average_psd"] = """\ average : str | None How to average the segments. If ``mean`` (default), calculate the arithmetic mean. If ``median``, calculate the median, corrected for @@ -323,9 +343,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): segments. """ -docdict[ - "average_tfr" -] = """ +docdict["average_tfr"] = """ average : bool, default True If ``False`` return an `EpochsTFR` containing separate TFRs for each epoch. If ``True`` return an `AverageTFR` containing the average of all @@ -340,57 +358,76 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ _axes_base = """\ -{} : instance of Axes | {}None - The axes to plot to. If ``None``, a new :class:`~matplotlib.figure.Figure` - will be created{}. {}Default is ``None``. -""" -_axes_num = ( - "If :class:`~matplotlib.axes.Axes` are provided (either as a " - "single instance or a :class:`list` of axes), the number of axes " - "provided must {}." -) +{param} : instance of Axes | {allowed}None + The axes to plot into. If ``None``, a new :class:`~matplotlib.figure.Figure` + will be created{created}. {list_extra}{extra}Default is ``None``. +""" _axes_list = _axes_base.format( - "{}", "list of Axes | ", " with the correct number of axes", _axes_num + param="{param}", + allowed="list of Axes | ", + created=" with the correct number of axes", + list_extra="""If :class:`~matplotlib.axes.Axes` + are provided (either as a single instance or a :class:`list` of axes), + the number of axes provided must {must}. """, + extra="{extra}", +) +_match_chtypes_present_in = "match the number of channel types present in the {}object." +docdict["ax_plot_psd"] = _axes_list.format( + param="ax", must=_match_chtypes_present_in.format(""), extra="" +) +docdict["axes_cov_plot_topomap"] = _axes_list.format( + param="axes", must="be length 1", extra="" ) -_ch_types_present = "match the number of channel types present in the {}" "object." -docdict["ax_plot_psd"] = _axes_list.format("ax", _ch_types_present.format("")) -docdict["axes_cov_plot_topomap"] = _axes_list.format("axes", "be length 1") docdict["axes_evoked_plot_topomap"] = _axes_list.format( - "axes", "match the number of ``times`` provided (unless ``times`` is ``None``)" + param="axes", + must="match the number of ``times`` provided (unless ``times`` is ``None``)", + extra="", ) -docdict[ - "axes_montage" -] = """ +docdict["axes_montage"] = """ axes : instance of Axes | instance of Axes3D | None Axes to draw the sensors to. 
If ``kind='3d'``, axes must be an instance - of Axes3D. If None (default), a new axes will be created.""" + of Axes3D. If None (default), a new axes will be created. +""" docdict["axes_plot_projs_topomap"] = _axes_list.format( - "axes", "match the number of projectors" + param="axes", + must="match the number of projectors", + extra="", +) +docdict["axes_plot_topomap"] = _axes_base.format( + param="axes", + allowed="", + created="", + list_extra="", + extra="", ) -docdict["axes_plot_topomap"] = _axes_base.format("axes", "", "", "") docdict["axes_spectrum_plot"] = _axes_list.format( - "axes", _ch_types_present.format(":class:`~mne.time_frequency.Spectrum`") + param="axes", + must=_match_chtypes_present_in.format(":class:`~mne.time_frequency.Spectrum` "), + extra="", ) docdict["axes_spectrum_plot_topo"] = _axes_list.format( - "axes", - "be length 1 (for efficiency, subplots for each channel are simulated " + param="axes", + must="be length 1 (for efficiency, subplots for each channel are simulated " "within a single :class:`~matplotlib.axes.Axes` object)", + extra="", ) docdict["axes_spectrum_plot_topomap"] = _axes_list.format( - "axes", "match the length of ``bands``" + param="axes", must="match the length of ``bands``", extra="" +) +docdict["axes_tfr_plot"] = _axes_list.format( + param="axes", + must="match the number of picks", + extra="""If ``combine`` is not None, + ``axes`` must either be an instance of Axes, or a list of length 1. """, ) -docdict[ - "axis_facecolor" -] = """\ +docdict["axis_facecolor"] = """\ axis_facecolor : str | tuple A matplotlib-compatible color to use for the axis background. Defaults to black. """ -docdict[ - "azimuth" -] = """ +docdict["azimuth"] = """ azimuth : float The azimuthal angle of the camera rendering the view in degrees. """ @@ -398,17 +435,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # B -docdict[ - "bad_condition_maxwell_cond" -] = """ +docdict["bad_condition_maxwell_cond"] = """ bad_condition : str How to deal with ill-conditioned SSS matrices. Can be ``"error"`` (default), ``"warning"``, ``"info"``, or ``"ignore"``. """ -docdict[ - "bands_psd_topo" -] = """ +docdict["bands_psd_topo"] = """ bands : None | dict | list of tuple The frequencies or frequency ranges to plot. If a :class:`dict`, keys will be used as subplot titles and values should be either a single frequency @@ -431,9 +464,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Allow passing a dict and discourage passing tuples. """ -docdict[ - "base_estimator" -] = """ +docdict["base_estimator"] = """ base_estimator : object The base estimator to iteratively fit on a subset of the dataset. """ @@ -445,16 +476,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): If a tuple ``(a, b)``, the interval is between ``a`` and ``b`` (in seconds), including the endpoints. If ``a`` is ``None``, the **beginning** of the data is used; and if ``b`` - is ``None``, it is set to the **end** of the interval. + is ``None``, it is set to the **end** of the data. If ``(None, None)``, the entire time interval is used. - .. note:: The baseline ``(a, b)`` includes both endpoints, i.e. all - timepoints ``t`` such that ``a <= t <= b``. + .. note:: + The baseline ``(a, b)`` includes both endpoints, i.e. all timepoints ``t`` + such that ``a <= t <= b``. 
""" -docdict[ - "baseline_epochs" -] = f"""{_baseline_rescale_base} +docdict["baseline_epochs"] = f"""{_baseline_rescale_base} Correction is applied **to each epoch and channel individually** in the following way: @@ -463,9 +493,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ -docdict[ - "baseline_evoked" -] = f"""{_baseline_rescale_base} +docdict["baseline_evoked"] = f"""{_baseline_rescale_base} Correction is applied **to each channel individually** in the following way: @@ -474,9 +502,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ -docdict[ - "baseline_report" -] = f"""{_baseline_rescale_base} +docdict["baseline_report"] = f"""{_baseline_rescale_base} Correction is applied in the following way **to each channel:** 1. Calculate the mean signal of the baseline period. @@ -487,9 +513,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["baseline_rescale"] = _baseline_rescale_base -docdict[ - "baseline_stc" -] = f"""{_baseline_rescale_base} +docdict["baseline_stc"] = f"""{_baseline_rescale_base} Correction is applied **to each source individually** in the following way: @@ -505,47 +529,44 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ -docdict[ - "block" -] = """\ +docdict["baseline_tfr_attr"] = """ +baseline : array-like, shape (2,) + The start and end times of the baseline period, in seconds.""" + + +docdict["block"] = """\ block : bool Whether to halt program execution until the figure is closed. May not work on all systems / platforms. Defaults to ``False``. """ -docdict[ - "border_topomap" -] = """ +docdict["border_topo"] = """ +border : str + Matplotlib border style to be used for each sensor plot. +""" +docdict["border_topomap"] = """ border : float | 'mean' Value to extrapolate to on the topomap borders. If ``'mean'`` (default), then each extrapolated point has the average value of its neighbours. """ -docdict[ - "brain_kwargs" -] = """ +docdict["brain_kwargs"] = """ brain_kwargs : dict | None Additional arguments to the :class:`mne.viz.Brain` constructor (e.g., ``dict(silhouette=True)``). """ -docdict[ - "brain_update" -] = """ +docdict["brain_update"] = """ update : bool Force an update of the plot. Defaults to True. """ -docdict[ - "browser" -] = """ +docdict["browser"] = """ fig : matplotlib.figure.Figure | mne_qt_browser.figure.MNEQtBrowser Browser instance. """ -docdict[ - "buffer_size_clust" -] = """ +docdict["buffer_size_clust"] = """ buffer_size : int | None Block size to use when computing test statistics. This can significantly reduce memory usage when ``n_jobs > 1`` and memory sharing between @@ -554,9 +575,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): a small block of locations at a time. """ -docdict[ - "by_event_type" -] = """ +docdict["by_event_type"] = """ by_event_type : bool When ``False`` (the default) all epochs are processed together and a single :class:`~mne.Evoked` object is returned. When ``True``, epochs are first @@ -571,18 +590,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # C -docdict[ - "calibration_maxwell_cal" -] = """ +docdict["calibration_maxwell_cal"] = """ calibration : str | None Path to the ``'.dat'`` file with fine calibration coefficients. File can have 1D or 3D gradiometer imbalance correction. This file is machine/site-specific. 
""" -docdict[ - "cbar_fmt_topomap" -] = """\ +docdict["cbar_fmt_topomap"] = """\ cbar_fmt : str Formatting string for colorbar tick labels. See :ref:`formatspec` for details. @@ -595,17 +610,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ ) -docdict[ - "center" -] = """ +docdict["center"] = """ center : float or None If not None, center of a divergent colormap, changes the meaning of fmin, fmax and fmid. """ -docdict[ - "ch_name_ecg" -] = """ +docdict["ch_name_ecg"] = """ ch_name : None | str The name of the channel to use for ECG peak detection. If ``None`` (default), ECG channel is used if present. If ``None`` and @@ -614,9 +625,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): MEG channels. """ -docdict[ - "ch_name_eog" -] = """ +docdict["ch_name_eog"] = """ ch_name : str | list of str | None The name of the channel(s) to use for EOG peak detection. If a string, can be an arbitrary channel. This doesn't have to be a channel of @@ -628,9 +637,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): If ``None`` (default), use the channel(s) in ``raw`` with type ``eog``. """ -docdict[ - "ch_names_annot" -] = """ +docdict["ch_names_annot"] = """ ch_names : list | None List of lists of channel names associated with the annotations. Empty entries are assumed to be associated with no specific channel, @@ -643,10 +650,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): description=['Start', 'BAD_flux', 'BAD_noise'], ch_names=[[], ['MEG0111', 'MEG2563'], ['MEG1443']]) """ +docdict["ch_names_tfr_attr"] = """ +ch_names : list + The channel names.""" -docdict[ - "ch_type_set_eeg_reference" -] = """ +docdict["ch_type_set_eeg_reference"] = """ ch_type : list of str | str The name of the channel type to apply the reference to. Valid channel types are ``'auto'``, ``'eeg'``, ``'ecog'``, ``'seeg'``, @@ -685,34 +693,26 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["channel_wise_applyfun_epo"] = chwise.format("in each epoch ", "epochs and ") -docdict[ - "check_disjoint_clust" -] = """ +docdict["check_disjoint_clust"] = """ check_disjoint : bool Whether to check if the connectivity matrix can be separated into disjoint sets before clustering. This may lead to faster clustering, especially if the second dimension of ``X`` (usually the "time" dimension) is large. """ -docdict[ - "chpi_amplitudes" -] = """ +docdict["chpi_amplitudes"] = """ chpi_amplitudes : dict The time-varying cHPI coil amplitudes, with entries "times", "proj", and "slopes". """ -docdict[ - "chpi_locs" -] = """ +docdict["chpi_locs"] = """ chpi_locs : dict The time-varying cHPI coils locations, with entries "times", "rrs", "moments", and "gofs". """ -docdict[ - "clim" -] = """ +docdict["clim"] = """ clim : str | dict Colorbar properties specification. If 'auto', set clim automatically based on data percentiles. If dict, should contain: @@ -731,9 +731,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): only divergent colormaps should be used with ``pos_lims``. """ -docdict[ - "clim_onesided" -] = """ +docdict["clim_onesided"] = """ clim : str | dict Colorbar properties specification. If 'auto', set clim automatically based on data percentiles. If dict, should contain: @@ -747,17 +745,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``pos_lims``, as the surface plot must show the magnitude. 
""" -docdict[ - "cmap" -] = """ -cmap : matplotlib colormap | str | None - The :class:`~matplotlib.colors.Colormap` to use. Defaults to ``None``, which - will use the matplotlib default colormap. +_cmap_template = """ +cmap : matplotlib colormap | str{allowed} + The :class:`~matplotlib.colors.Colormap` to use. If a :class:`str`, must be a + valid Matplotlib colormap name. Default is {default}. """ - -docdict[ - "cmap_topomap" -] = """ +docdict["cmap"] = _cmap_template.format( + allowed=" | None", + default="``None``, which will use the Matplotlib default colormap", +) +docdict["cmap_tfr_plot_topo"] = _cmap_template.format( + allowed="", default='``"RdBu_r"``' +) +docdict["cmap_topomap"] = """\ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None Colormap to use. If :class:`tuple`, the first value indicates the colormap to use and the second value is a boolean defining interactivity. In @@ -774,17 +774,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): 2 topomaps. """ -docdict[ - "cmap_topomap_simple" -] = """ +docdict["cmap_topomap_simple"] = """ cmap : matplotlib colormap | None Colormap to use. If None, 'Reds' is used for all positive data, otherwise defaults to 'RdBu_r'. """ -docdict[ - "cnorm" -] = """ +docdict["cnorm"] = """ cnorm : matplotlib.colors.Normalize | None How to normalize the colormap. If ``None``, standard linear normalization is performed. If not ``None``, ``vmin`` and ``vmax`` will be ignored. @@ -793,57 +789,126 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): :ref:`the ERDs example` for an example of its use. """ -docdict[ - "color_matplotlib" -] = """ +docdict["color_matplotlib"] = """ color : color A list of anything matplotlib accepts: string, RGB, hex, etc. """ -docdict[ - "color_plot_psd" -] = """\ +docdict["color_plot_psd"] = """\ color : str | tuple A matplotlib-compatible color to use. Has no effect when spatial_colors=True. """ -docdict[ - "color_spectrum_plot_topo" -] = """\ +docdict["color_spectrum_plot_topo"] = """\ color : str | tuple A matplotlib-compatible color to use for the curves. Defaults to white. """ -docdict[ - "colorbar_topomap" -] = """ +docdict["colorbar"] = """\ +colorbar : bool + Whether to add a colorbar to the plot. Default is ``True``. +""" +docdict["colorbar_tfr_plot_joint"] = """ +colorbar : bool + Whether to add a colorbar to the plot (for the topomap annotations). Not compatible + with user-defined ``axes``. Default is ``True``. +""" +docdict["colorbar_topomap"] = """ colorbar : bool Plot a colorbar in the rightmost column of the figure. """ -docdict[ - "colormap" -] = """ +docdict["colormap"] = """ colormap : str | np.ndarray of float, shape(n_colors, 3 | 4) Name of colormap to use or a custom look up table. If array, must be (n x 3) or (n x 4) array for with RGB or RGBA values between 0 and 255. """ -docdict[ - "combine" -] = """ -combine : None | str | callable - How to combine information across channels. If a :class:`str`, must be - one of 'mean', 'median', 'std' (standard deviation) or 'gfp' (global - field power). -""" +_combine_template = """ +combine : 'mean' | {literals} | callable{none} + How to aggregate across channels. {none_sentence}If a string, + ``"mean"`` uses :func:`numpy.mean`, {other_string}. + If :func:`callable`, it must operate on an :class:`array ` + of shape ``({shape})`` and return an array of shape + ``({return_shape})``. {example}{notes}Defaults to {default}. 
+""" +_example = """For example:: + + combine = lambda data: np.median(data, axis=1) + + """ # ← the 4 trailing spaces are intentional here! +_median_std_gfp = """``"median"`` computes the `marginal median + `__, ``"std"`` + uses :func:`numpy.std`, and ``"gfp"`` computes global field power + for EEG channels and RMS amplitude for MEG channels""" +_none_default = dict(none=" | None", default="``None``") +docdict["combine_plot_compare_evokeds"] = _combine_template.format( + literals="'median' | 'std' | 'gfp'", + **_none_default, + none_sentence="""If ``None``, channels are combined by + computing GFP/RMS, unless ``picks`` is a single channel (not channel type) + or ``axes="topo"``, in which cases no combining is performed. """, + other_string=_median_std_gfp, + shape="n_evokeds, n_channels, n_times", + return_shape="n_evokeds, n_times", + example=_example, + notes="", +) +docdict["combine_plot_epochs_image"] = _combine_template.format( + literals="'median' | 'std' | 'gfp'", + **_none_default, + none_sentence="""If ``None``, channels are combined by + computing GFP/RMS, unless ``group_by`` is also ``None`` and ``picks`` is a + list of specific channels (not channel types), in which case no combining + is performed and each channel gets its own figure. """, + other_string=_median_std_gfp, + shape="n_epochs, n_channels, n_times", + return_shape="n_epochs, n_times", + example=_example, + notes="See Notes for further details. ", +) +docdict["combine_tfr_plot"] = _combine_template.format( + literals="'rms'", + **_none_default, + none_sentence="If ``None``, plot one figure per selected channel. ", + shape="n_channels, n_freqs, n_times", + return_shape="n_freqs, n_times", + other_string='``"rms"`` computes the root-mean-square', + example="", + notes="", +) +docdict["combine_tfr_plot_joint"] = _combine_template.format( + literals="'rms'", + none="", + none_sentence="", + shape="n_channels, n_freqs, n_times", + return_shape="n_freqs, n_times", + other_string='``"rms"`` computes the root-mean-square', + example="", + notes="", + default='``"mean"``', +) -docdict[ - "compute_proj_ecg" -] = """This function will: +_comment_template = """ +comment : str{or_none} + Comment on the data, e.g., the experimental condition(s){avgd}.{extra}""" +docdict["comment_averagetfr"] = _comment_template.format( + or_none=" | None", + avgd="averaged", + extra="""Default is ``None`` + which is replaced with ``inst.comment`` (for :class:`~mne.Evoked` instances) + or a comma-separated string representation of the keys in ``inst.event_id`` + (for :class:`~mne.Epochs` instances).""", +) +docdict["comment_averagetfr_attr"] = _comment_template.format( + or_none="", avgd=" averaged", extra="" +) +docdict["comment_tfr_attr"] = _comment_template.format(or_none="", avgd="", extra="") + +docdict["compute_proj_ecg"] = """This function will: #. Filter the ECG data channel. @@ -858,9 +923,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): #. Calculate SSP projection vectors on that data to capture the artifacts.""" -docdict[ - "compute_proj_eog" -] = """This function will: +docdict["compute_proj_eog"] = """This function will: #. Filter the EOG data channel. @@ -876,18 +939,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): #. 
Calculate SSP projection vectors on that data to capture the artifacts.""" -docdict[ - "compute_ssp" -] = """This function aims to find those SSP vectors that +docdict["compute_ssp"] = """This function aims to find those SSP vectors that will project out the ``n`` most prominent signals from the data for each specified sensor type. Consequently, if the provided input data contains high levels of noise, the produced SSP vectors can then be used to eliminate that noise from the data. """ -docdict[ - "contours_topomap" -] = """ +docdict["contours_topomap"] = """ contours : int | array-like The number of contour lines to draw. If ``0``, no contours will be drawn. If a positive integer, that number of contour levels are chosen using the @@ -898,9 +957,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): corresponding to the contour levels. Default is ``6``. """ -docdict[ - "coord_frame_maxwell" -] = """ +docdict["coord_frame_maxwell"] = """ coord_frame : str The coordinate frame that the ``origin`` is specified in, either ``'meg'`` or ``'head'``. For empty-room recordings that do not have @@ -908,17 +965,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): frame should be used. """ -docdict[ - "copy_df" -] = """ +docdict["copy_df"] = """ copy : bool If ``True``, data will be copied. Otherwise data may be modified in place. Defaults to ``True``. """ -docdict[ - "create_ecg_epochs" -] = """This function will: +docdict["create_ecg_epochs"] = """This function will: #. Filter the ECG data channel. @@ -927,9 +980,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): #. Create `~mne.Epochs` around the R wave peaks, capturing the heartbeats. """ -docdict[ - "create_eog_epochs" -] = """This function will: +docdict["create_eog_epochs"] = """This function will: #. Filter the EOG data channel. @@ -939,9 +990,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): #. Create `~mne.Epochs` around the eyeblinks. """ -docdict[ - "cross_talk_maxwell" -] = """ +docdict["cross_talk_maxwell"] = """ cross_talk : str | None Path to the FIF file with cross-talk correction information. """ @@ -949,15 +998,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # D -_dB = """\ +_dB = """ dB : bool Whether to plot on a decibel-like scale. If ``True``, plots - 10 × log₁₀(spectral power){}.{} + 10 × log₁₀({quantity}){caveat}.{extra} """ +_ignored_if_normalize = " Ignored if ``normalize=True``." +_psd = "spectral power" -docdict[ - "dB_plot_psd" -] = """\ +docdict["dB_plot_psd"] = """\ dB : bool Plot Power Spectral Density (PSD), in units (amplitude**2/Hz (dB)) if ``dB=True``, and ``estimate='power'`` or ``estimate='auto'``. Plot PSD @@ -968,14 +1017,23 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``dB=True`` and ``estimate='amplitude'``. """ docdict["dB_plot_topomap"] = _dB.format( - " following the application of ``agg_fun``", " Ignored if ``normalize=True``." 
+ quantity=_psd, + caveat=" following the application of ``agg_fun``", + extra=_ignored_if_normalize, +) +docdict["dB_spectrum_plot"] = _dB.format(quantity=_psd, caveat="", extra="") +docdict["dB_spectrum_plot_topo"] = _dB.format( + quantity=_psd, caveat="", extra=_ignored_if_normalize ) -docdict["dB_spectrum_plot"] = _dB.format("", "") -docdict["dB_spectrum_plot_topo"] = _dB.format("", " Ignored if ``normalize=True``.") +docdict["dB_tfr_plot_topo"] = _dB.format(quantity="data", caveat="", extra="") + +_data_template = """ +data : ndarray, shape ({}) + The data. +""" +docdict["data_tfr"] = _data_template.format("n_channels, n_freqs, n_times") -docdict[ - "daysback_anonymize_info" -] = """ +docdict["daysback_anonymize_info"] = """ daysback : int | None Number of days to subtract from all dates. If ``None`` (default), the acquisition date, ``info['meas_date']``, @@ -983,15 +1041,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``info['meas_date']`` is ``None`` (i.e., no acquisition date has been set). """ -docdict[ - "dbs" -] = """ +docdict["dbs"] = """ dbs : bool If True (default), show DBS (deep brain stimulation) electrodes. """ -docdict[ - "decim" -] = """ +docdict["decim"] = """ decim : int Factor by which to subsample the data. @@ -1002,9 +1056,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): may occur. """ -docdict[ - "decim_notes" -] = """ +docdict["decim_notes"] = """ For historical reasons, ``decim`` / "decimation" refers to simply subselecting samples from a given signal. This contrasts with the broader signal processing literature, where decimation is defined as (quoting @@ -1024,24 +1076,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``inst.decimate(4)``. """ -docdict[ - "decim_tfr" -] = """ -decim : int | slice, default 1 - To reduce memory usage, decimation factor after time-frequency - decomposition. +docdict["decim_tfr"] = """ +decim : int | slice + Decimation factor, applied *after* time-frequency decomposition. - - if `int`, returns ``tfr[..., ::decim]``. - - if `slice`, returns ``tfr[..., decim]``. + - if :class:`int`, returns ``tfr[..., ::decim]`` (keep only every Nth + sample along the time axis). + - if :class:`slice`, returns ``tfr[..., decim]`` (keep only the specified + slice along the time axis). .. note:: Decimation is done after convolutions and may create aliasing artifacts. """ -docdict[ - "depth" -] = """ +docdict["depth"] = """ depth : None | float | dict How to weight (or normalize) the forward using a depth prior. If float (default 0.8), it acts as the depth weighting exponent (``exp``) @@ -1054,9 +1103,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Depth bias ignored for ``method='eLORETA'``. """ -docdict[ - "destination_maxwell_dest" -] = """ +docdict["destination_maxwell_dest"] = """ destination : path-like | array-like, shape (3,) | None The destination location for the head. Can be ``None``, which will not change the head position, or a path to a FIF file @@ -1067,9 +1114,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): head location). """ -docdict[ - "detrend_epochs" -] = """ +docdict["detrend_epochs"] = """ detrend : int | None If 0 or 1, the data channels (MEG and EEG) will be detrended when loaded. 0 is a constant (DC) detrend, 1 is a linear detrend. None @@ -1080,17 +1125,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (will yield equivalent results but be slower). 
""" -docdict[ - "df_return" -] = """ +docdict["df_return"] = """ df : instance of pandas.DataFrame A dataframe suitable for usage with other statistical/plotting/analysis packages. """ -docdict[ - "dig_kinds" -] = """ +docdict["dig_kinds"] = """ dig_kinds : list of str | str Kind of digitization points to use in the fitting. These can be any combination of ('cardinal', 'hpi', 'eeg', 'extra'). Can also @@ -1099,9 +1140,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): 'eeg' points. """ -docdict[ - "dipole" -] = """ +docdict["dipole"] = """ dipole : instance of Dipole | list of Dipole Dipole object containing position, orientation and amplitude of one or more dipoles. Multiple simultaneous dipoles may be defined by @@ -1112,9 +1151,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Added support for a list of :class:`mne.Dipole` instances. """ -docdict[ - "distance" -] = """ +docdict["distance"] = """ distance : float | "auto" | None The distance from the camera rendering the view to the focalpoint in plot units (either m or mm). If "auto", the bounds of visible objects will be @@ -1124,17 +1161,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``None`` will no longer change the distance, use ``"auto"`` instead. """ -docdict[ - "drop_log" -] = """ +docdict["drop_log"] = """ drop_log : tuple | None Tuple of tuple of strings indicating which epochs have been marked to - be ignored. -""" + be ignored.""" -docdict[ - "dtype_applyfun" -] = """ +docdict["dtype_applyfun"] = """ dtype : numpy.dtype Data type to use after applying the function. If None (default) the data type is not modified. @@ -1143,16 +1175,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # E -docdict[ - "ecog" -] = """ +docdict["ecog"] = """ ecog : bool If True (default), show ECoG sensors. """ -docdict[ - "edf_resamp_note" -] = """ +docdict["edf_resamp_note"] = """ :class:`mne.io.Raw` only stores signals with matching sampling frequencies. Therefore, if mixed sampling frequency signals are requested, all signals are upsampled to the highest loaded sampling frequency. In this case, using @@ -1160,9 +1188,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): slices of the signal are requested. """ -docdict[ - "eeg" -] = """ +docdict["eeg"] = """ eeg : bool | str | list | dict String options are: @@ -1180,20 +1206,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Added support for specifying alpha values as a dict. """ -docdict[ - "elevation" -] = """ +docdict["elevation"] = """ elevation : float The The zenith angle of the camera rendering the view in degrees. """ -docdict[ - "eltc_mode_notes" -] = """ +docdict["eltc_mode_notes"] = """ Valid values for ``mode`` are: - ``'max'`` - Maximum value across vertices at each time point within each label. + Maximum absolute value across vertices at each time point within each label. - ``'mean'`` Average across vertices at each time point within each label. 
Ignores orientation of sources for standard source estimates, which varies @@ -1203,7 +1225,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): - ``'mean_flip'`` Finds the dominant direction of source space normal vector orientations within each label, applies a sign-flip to time series at vertices whose - orientation is more than 180° different from the dominant direction, and + orientation is more than 90° different from the dominant direction, and then averages across vertices at each time point within each label. - ``'pca_flip'`` Applies singular value decomposition to the time courses within each label, @@ -1228,32 +1250,24 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``'max'``, and ``'auto'``. """ -docdict[ - "emit_warning" -] = """ +docdict["emit_warning"] = """ emit_warning : bool Whether to emit warnings when cropping or omitting annotations. """ -docdict[ - "encoding_edf" -] = """ +docdict["encoding_edf"] = """ encoding : str Encoding of annotations channel(s). Default is "utf8" (the only correct encoding according to the EDF+ standard). """ -docdict[ - "epochs_preload" -] = """ +docdict["epochs_preload"] = """ Load all epochs from disk when creating the object or wait before accessing each epoch (more memory efficient but can be slower). """ -docdict[ - "epochs_reject_tmin_tmax" -] = """ +docdict["epochs_reject_tmin_tmax"] = """ reject_tmin, reject_tmax : float | None Start and end of the time window used to reject epochs based on peak-to-peak (PTP) amplitudes as specified via ``reject`` and ``flat``. @@ -1264,27 +1278,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): both, ``reject`` and ``flat``. """ -docdict[ - "epochs_tmin_tmax" -] = """ +docdict["epochs_tmin_tmax"] = """ tmin, tmax : float Start and end time of the epochs in seconds, relative to the time-locked event. The closest or matching samples corresponding to the start and end time are included. Defaults to ``-0.2`` and ``0.5``, respectively. """ -docdict[ - "estimate_plot_psd" -] = """\ +docdict["estimate_plot_psd"] = """\ estimate : str, {'auto', 'power', 'amplitude'} Can be "power" for power spectral density (PSD), "amplitude" for amplitude spectrum density (ASD), or "auto" (default), which uses "power" when dB is True and "amplitude" otherwise. """ -docdict[ - "event_color" -] = """ +docdict["event_color"] = """ event_color : color object | dict | None Color(s) to use for :term:`events`. To show all :term:`events` in the same color, pass any matplotlib-compatible color. To color events differently, @@ -1294,27 +1302,33 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): color cycle. """ -docdict[ - "event_id" -] = """ -event_id : int | list of int | dict | None +docdict["event_id"] = """ +event_id : int | list of int | dict | str | list of str | None The id of the :term:`events` to consider. If dict, the keys can later be used to access associated :term:`events`. Example: dict(auditory=1, visual=3). If int, a dict will be created with the id as - string. If a list, all :term:`events` with the IDs specified in the list - are used. If None, all :term:`events` will be used and a dict is created + string. If a list of int, all :term:`events` with the IDs specified in the list + are used. If a str or list of str, ``events`` must be ``None`` to use annotations + and then the IDs must be the name(s) of the annotations to use. 
+    If None, all :term:`events` will be used and a dict is created
     with string integer names corresponding to the event id integers."""
-
-docdict[
-    "event_id_ecg"
-] = """
+_event_id_template = """
+event_id : dict{or_none}
+    Mapping from condition descriptions (strings) to integer event codes.{extra}"""
+docdict["event_id_attr"] = _event_id_template.format(or_none="", extra="")
+docdict["event_id_ecg"] = """
 event_id : int
     The index to assign to found ECG events.
 """
+docdict["event_id_epochstfr"] = _event_id_template.format(
+    or_none=" | None",
+    extra="""If ``None``,
+    all events in ``events`` will be included, and the ``event_id`` attribute
+    will be a :class:`dict` mapping a string version of each integer event ID
+    to the corresponding integer.""",
+)

-docdict[
-    "event_repeated_epochs"
-] = """
+docdict["event_repeated_epochs"] = """
event_repeated : str
    How to handle duplicates in ``events[:, 0]``. Can be ``'error'`` (default),
    to raise an error, 'drop' to only retain the row occurring
@@ -1324,27 +1338,30 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
    .. versionadded:: 0.19
"""

-docdict[
-    "events"
-] = """
-events : array of int, shape (n_events, 3)
-    The array of :term:`events`. The first column contains the event time in
-    samples, with :term:`first_samp` included. The third column contains the
-    event id.
-
-docdict[
-    "events_epochs"
-] = """
-events : array of int, shape (n_events, 3)
-    The array of :term:`events`. The first column contains the event time in
-    samples, with :term:`first_samp` included. The third column contains the
-    event id.
-    If some events don't match the events of interest as specified by
-    ``event_id``, they will be marked as ``IGNORED`` in the drop log.
-
-docdict[
-    "evoked_by_event_type_returns"
-] = """
+_events_template = """
+events : ndarray of int, shape (n_events, 3){or_none}
+    The identity and timing of experimental events, around which the epochs were
+    created. See :term:`events` for more information.{extra}
+"""
+docdict["events"] = _events_template.format(or_none="", extra="")
+docdict["events_attr"] = """
+events : ndarray of int, shape (n_events, 3)
+    The events array."""
+docdict["events_epochs"] = _events_template.format(
+    or_none="",
+    extra="""Events that don't match
+    the events of interest as specified by ``event_id`` will be marked as
+    ``IGNORED`` in the drop log.""",
+)
+docdict["events_epochstfr"] = _events_template.format(
+    or_none=" | None",
+    extra="""If ``None``, all integer
+    event codes are set to ``1`` (i.e., all epochs are assumed to be of the same
+    type) and their corresponding sample numbers are set as arbitrary, equally
+    spaced sample numbers with a step size of ``len(times)``.""",
+)
+
+docdict["evoked_by_event_type_returns"] = """
evoked : instance of Evoked | list of Evoked
    The averaged epochs.
    When ``by_event_type=True`` was specified, a list is returned containing a
@@ -1353,18 +1370,23 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
    dictionary.
"""

-docdict[
-    "exclude_clust"
-] = """
+docdict["exclude_after_unique"] = """
+exclude_after_unique : bool
+    If True, the channels to exclude are searched for after the channel names
+    have been made unique. This is useful for excluding channels that have been
+    made unique by adding a suffix. If False, the original names are checked.
+
+    .. versionchanged:: 1.7
+"""
+
+docdict["exclude_clust"] = """
exclude : bool array or None
    Mask to apply to the data to exclude certain points from
    clustering (e.g., medial wall vertices).
Should be the same shape as ``X``. If ``None``, no points are excluded. """ -docdict[ - "exclude_frontal" -] = """ +docdict["exclude_frontal"] = """ exclude_frontal : bool If True, exclude points that have both negative Z values (below the nasion) and positive Y values (in front of the LPA/RPA). @@ -1383,30 +1405,27 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): " from being drawn", "spectrum." ) -docdict[ - "export_edf_note" -] = """ -For EDF exports, only channels measured in Volts are allowed; in MNE-Python -this means channel types 'eeg', 'ecog', 'seeg', 'emg', 'eog', 'ecg', 'dbs', -'bio', and 'misc'. 'stim' channels are dropped. Although this function -supports storing channel types in the signal label (e.g. ``EEG Fz`` or -``MISC E``), other software may not support this (optional) feature of -the EDF standard. - -If ``add_ch_type`` is True, then channel types are written based on what -they are currently set in MNE-Python. One should double check that all -their channels are set correctly. You can call -:attr:`raw.set_channel_types ` to set -channel types. - -In addition, EDF does not support storing a montage. You will need -to store the montage separately and call :attr:`raw.set_montage() -`. -""" - -docdict[ - "export_eeglab_note" -] = """ +docdict["export_edf_note"] = """ +Although this function supports storing channel types in the signal label (e.g. +``EEG Fz`` or ``MISC E``), other software may not support this (optional) feature of the +EDF standard. + +If ``add_ch_type`` is True, then channel types are written based on what they are +currently set in MNE-Python. One should double check that all their channels are set +correctly. You can call :meth:`mne.io.Raw.set_channel_types` to set channel types. + +In addition, EDF does not support storing a montage. You will need to store the montage +separately and call :meth:`mne.io.Raw.set_montage`. + +The physical range of the signals is determined by signal type by default +(``physical_range="auto"``). However, if individual channel ranges vary significantly +due to the presence of e.g. drifts/offsets/biases, setting +``physical_range="channelwise"`` might be more appropriate. This will ensure a maximum +resolution for each individual channel, but some tools might not be able to handle this +appropriately (even though channel-wise ranges are covered by the EDF standard). +""" + +docdict["export_eeglab_note"] = """ For EEGLAB exports, channel locations are expanded to full EEGLAB format. For more details see :func:`eeglabio.utils.cart_to_eeglab`. """ @@ -1416,59 +1435,39 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): from the filename extension. 
See supported formats above for more information.""" -docdict[ - "export_fmt_params_epochs" -] = """ +docdict["export_fmt_params_epochs"] = f""" fmt : 'auto' | 'eeglab' - {} -""".format( - _export_fmt_params_base -) + {_export_fmt_params_base} +""" -docdict[ - "export_fmt_params_evoked" -] = """ +docdict["export_fmt_params_evoked"] = f""" fmt : 'auto' | 'mff' - {} -""".format( - _export_fmt_params_base -) + {_export_fmt_params_base} +""" -docdict[ - "export_fmt_params_raw" -] = """ +docdict["export_fmt_params_raw"] = f""" fmt : 'auto' | 'brainvision' | 'edf' | 'eeglab' - {} -""".format( - _export_fmt_params_base -) + {_export_fmt_params_base} +""" -docdict[ - "export_fmt_support_epochs" -] = """\ +docdict["export_fmt_support_epochs"] = """\ Supported formats: - EEGLAB (``.set``, uses :mod:`eeglabio`) """ -docdict[ - "export_fmt_support_evoked" -] = """\ +docdict["export_fmt_support_evoked"] = """\ Supported formats: - MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) """ -docdict[ - "export_fmt_support_raw" -] = """\ +docdict["export_fmt_support_raw"] = """\ Supported formats: - BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) - EEGLAB (``.set``, uses :mod:`eeglabio`) - - EDF (``.edf``, uses `EDFlib-Python `_) + - EDF (``.edf``, uses `edfio `_) """ # noqa: E501 -docdict[ - "export_warning" -] = """\ +docdict["export_warning"] = """\ .. warning:: Since we are exporting to external formats, there's no guarantee that all the info will be preserved in the external format. See Notes for details. @@ -1488,9 +1487,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["export_warning_note_raw"] = _export_warning_note_base.format("io.Raw") -docdict[ - "ext_order_chpi" -] = """ +docdict["ext_order_chpi"] = """ ext_order : int The external order for SSS-like interference suppression. The SSS bases are used as projection vectors during fitting. @@ -1500,16 +1497,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): detection of true HPI signals. """ -docdict[ - "ext_order_maxwell" -] = """ +docdict["ext_order_maxwell"] = """ ext_order : int Order of external component of spherical expansion. """ -docdict[ - "extended_proj_maxwell" -] = """ +docdict["extended_proj_maxwell"] = """ extended_proj : list The empty-room projection vectors used to extend the external SSS basis (i.e., use eSSS). @@ -1517,9 +1510,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.21 """ -docdict[ - "extrapolate_topomap" -] = """ +docdict["extrapolate_topomap"] = """ extrapolate : str Options: @@ -1538,18 +1529,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): the head circle. """ -docdict[ - "eyelink_apply_offsets" -] = """ +docdict["eyelink_apply_offsets"] = """ apply_offsets : bool (default False) Adjusts the onset time of the :class:`~mne.Annotations` created from Eyelink experiment messages, if offset values exist in the ASCII file. If False, any offset-like values will be prepended to the annotation description. """ -docdict[ - "eyelink_create_annotations" -] = """ +docdict["eyelink_create_annotations"] = """ create_annotations : bool | list (default True) Whether to create :class:`~mne.Annotations` from ocular events (blinks, fixations, saccades) and experiment messages. If a list, must @@ -1558,24 +1545,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): experiment messages.
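Pulling the Eyelink reader options above together, a hedged sketch (the file name is hypothetical)::

    import mne

    raw = mne.io.read_raw_eyelink(
        "recording.asc",
        create_annotations=["blinks", "messages"],  # skip fixations/saccades
        find_overlaps=True,       # merge binocular events that co-occur
        overlap_threshold=0.05,   # within 50 ms at start and stop
    )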
""" -docdict[ - "eyelink_find_overlaps" -] = """ +docdict["eyelink_find_overlaps"] = """ find_overlaps : bool (default False) Combine left and right eye :class:`mne.Annotations` (blinks, fixations, saccades) if their start times and their stop times are both not separated by more than overlap_threshold. """ -docdict[ - "eyelink_fname" -] = """ +docdict["eyelink_fname"] = """ fname : path-like Path to the eyelink file (``.asc``).""" -docdict[ - "eyelink_overlap_threshold" -] = """ +docdict["eyelink_overlap_threshold"] = """ overlap_threshold : float (default 0.05) Time in seconds. Threshold of allowable time-gap between both the start and stop times of the left and right eyes. If the gap is larger than the threshold, @@ -1591,9 +1572,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # F -docdict[ - "f_power_clust" -] = """ +docdict["f_power_clust"] = """ t_power : float Power to raise the statistical values (usually F-values) by before summing (sign will be retained). Note that ``t_power=0`` will give a @@ -1601,9 +1580,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): by its statistical score. """ -docdict[ - "fiducials" -] = """ +docdict["fiducials"] = """ fiducials : list | dict | str The fiducials given in the MRI (surface RAS) coordinate system. If a dictionary is provided, it must contain the **keys** @@ -1618,17 +1595,17 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): and if absent, falls back to ``'estimated'``. """ -docdict[ - "fig_facecolor" -] = """\ +docdict["fig_background"] = """ +fig_background : None | array + A background image for the figure. This must be a valid input to + :func:`matplotlib.pyplot.imshow`. Defaults to ``None``. +""" +docdict["fig_facecolor"] = """ fig_facecolor : str | tuple - A matplotlib-compatible color to use for the figure background. - Defaults to black. + A matplotlib-compatible color to use for the figure background. Defaults to black. """ -docdict[ - "filter_length" -] = """ +docdict["filter_length"] = """ filter_length : str | int Length of the FIR filter to use (if applicable): @@ -1645,16 +1622,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): this should not be used. """ -docdict[ - "filter_length_ecg" -] = """ +docdict["filter_length_ecg"] = """ filter_length : str | int | None Number of taps to use for filtering. """ -docdict[ - "filter_length_notch" -] = """ +docdict["filter_length_notch"] = """ filter_length : str | int Length of the FIR filter to use (if applicable): @@ -1679,9 +1652,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The default in 0.21 is None, but this will change to ``'10s'`` in 0.22. """ -docdict[ - "fir_design" -] = """ +docdict["fir_design"] = """ fir_design : str Can be "firwin" (default) to use :func:`scipy.signal.firwin`, or "firwin2" to use :func:`scipy.signal.firwin2`. "firwin" uses @@ -1691,9 +1662,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.15 """ -docdict[ - "fir_window" -] = """ +docdict["fir_window"] = """ fir_window : str The window to use in FIR design, can be "hamming" (default), "hann" (default in 0.13), or "blackman". @@ -1708,9 +1677,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): is smaller than this threshold, the epoch will be dropped. 
If ``None`` then no rejection is performed based on flatness of the signal.""" -docdict[ - "flat" -] = f""" +docdict["flat"] = f""" flat : dict | None {_flat_common} @@ -1718,11 +1685,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): quality, pass the ``reject_tmin`` and ``reject_tmax`` parameters. """ -docdict[ - "flat_drop_bad" -] = f""" +docdict["flat_drop_bad"] = """ flat : dict | str | None -{_flat_common} + Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP) + or a custom function. Valid **keys** can be any channel type present + in the object. If using PTP, **values** are floats that set the minimum + acceptable PTP. If the PTP is smaller than this threshold, the epoch + will be dropped. If ``None`` then no rejection is performed based on + flatness of the signal. If a custom function is used, then ``flat`` can be + used to reject epochs based on any criteria (including maxima and + minima). If ``'existing'``, then the flat parameters set during epoch creation are used. """ @@ -1736,10 +1708,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ) docdict["fmin_fmax_psd_topo"] = _fmin_fmax.format("``fmin=0, fmax=100``.") +docdict["fmin_fmax_tfr"] = _fmin_fmax.format( + """``None`` + which is equivalent to ``fmin=0, fmax=np.inf`` (spans all frequencies + present in the data).""" +) -docdict[ - "fmin_fmid_fmax" -] = """ +docdict["fmin_fmid_fmax"] = """ fmin : float Minimum value in colormap (uses real fmin if None). fmid : float @@ -1749,33 +1724,25 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Maximum value in colormap (uses real max if None). """ -docdict[ - "fname_epochs" -] = """ +docdict["fname_epochs"] = """ fname : path-like | file-like The epochs to load. If a filename, should end with ``-epo.fif`` or ``-epo.fif.gz``. If a file-like object, preloading must be used. """ -docdict[ - "fname_export_params" -] = """ +docdict["fname_export_params"] = """ fname : str Name of the output file. """ -docdict[ - "fname_fwd" -] = """ +docdict["fname_fwd"] = """ fname : path-like File name to save the forward solution to. It should end with ``-fwd.fif`` or ``-fwd.fif.gz`` to save to FIF, or ``-fwd.h5`` to save to HDF5. """ -docdict[ - "fnirs" -] = """ +docdict["fnirs"] = """ fnirs : str | list | dict | bool | None Can be "channels", "pairs", "detectors", and/or "sources" to show the fNIRS channel locations, optode locations, or line between @@ -1788,34 +1755,46 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Added support for specifying alpha values as a dict. """ -docdict[ - "focalpoint" -] = """ +docdict["focalpoint"] = """ focalpoint : tuple, shape (3,) | str | None The focal point of the camera rendering the view: (x, y, z) in plot units (either m or mm). When ``"auto"``, it is set to the center of mass of the visible bounds. """ -docdict[ - "forward_set_eeg_reference" -] = """ +docdict["font_color"] = """ +font_color : color + The color of tick labels in the colorbar. Defaults to white. +""" + +docdict["forward_set_eeg_reference"] = """ forward : instance of Forward | None Forward solution to use. Only used with ``ref_channels='REST'``. .. versionadded:: 0.21 """ +_freqs_tfr_template = """ +freqs : array-like |{auto} None + The frequencies at which to compute the power estimates. + {stockwell} be an array of shape (n_freqs,). ``None`` (the + default) only works when using ``__setstate__`` and will raise an error otherwise.
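A minimal sketch of PTP-based ``flat`` rejection as described above (assumes ``epochs`` already exists; the threshold value is hypothetical)::

    # drop epochs whose EEG peak-to-peak amplitude never exceeds 1 µV
    epochs.drop_bad(flat=dict(eeg=1e-6))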
+""" +docdict["freqs_tfr"] = _freqs_tfr_template.format(auto="", stockwell="Must") +docdict["freqs_tfr_array"] = """ +freqs : ndarray, shape (n_freqs,) + The frequencies in Hz. +""" +docdict["freqs_tfr_attr"] = """ +freqs : array + Frequencies at which power has been computed.""" +docdict["freqs_tfr_epochs"] = _freqs_tfr_template.format( + auto=" 'auto' | ", + stockwell="""If ``method='stockwell'`` this must be a length 2 iterable specifying lowest + and highest frequencies, or ``'auto'`` (to use all available frequencies). + For other methods, must""", # noqa E501 +) -docdict[ - "freqs_tfr" -] = """ -freqs : array of float, shape (n_freqs,) - The frequencies of interest in Hz. -""" - -docdict[ - "fullscreen" -] = """ +docdict["fullscreen"] = """ fullscreen : bool Whether to start in fullscreen (``True``) or windowed mode (``False``). @@ -1827,6 +1806,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): fun has to be a timeseries (:class:`numpy.ndarray`). The function must operate on an array of shape ``(n_times,)`` {}. The function must return an :class:`~numpy.ndarray` shaped like its input. + + .. note:: + If ``channel_wise=True``, one can optionally access the index and/or the + name of the currently processed channel within the applied function. + This can enable tailored computations for different channels. + To use this feature, add ``ch_idx`` and/or ``ch_name`` as + additional argument(s) to your function definition. """ docdict["fun_applyfun"] = applyfun_fun_base.format( " if ``channel_wise=True`` and ``(len(picks), n_times)`` otherwise" @@ -1834,18 +1820,17 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["fun_applyfun_evoked"] = applyfun_fun_base.format( " because it will apply channel-wise" ) +docdict["fun_applyfun_stc"] = applyfun_fun_base.format( + " because it will apply vertex-wise" +) -docdict[ - "fwd" -] = """ +docdict["fwd"] = """ fwd : instance of Forward The forward solution. If present, the orientations of the dipoles present in the forward solution are displayed. """ -docdict[ - "fwhm_morlet_notes" -] = r""" +docdict["fwhm_morlet_notes"] = r""" Convolution of a signal with a Morlet wavelet will impose temporal smoothing that is determined by the duration of the wavelet. In MNE-Python, the duration of the wavelet is determined by the ``sigma`` parameter, which gives the @@ -1879,9 +1864,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # G -docdict[ - "get_peak_parameters" -] = """ +docdict["get_peak_parameters"] = """ tmin : float | None The minimum point in time to be considered for peak getting. tmax : float | None @@ -1899,21 +1882,30 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (False, default). """ -_getitem_base = """\ +_getitem_spectrum_base = """ data : ndarray The selected spectral data. Shape will be - ``({}n_channels, n_freqs)`` for normal power spectra, - ``({}n_channels, n_freqs, n_segments)`` for unaggregated - Welch estimates, or ``({}n_channels, n_tapers, n_freqs)`` + ``({n_epo}n_channels, n_freqs)`` for normal power spectra, + ``({n_epo}n_channels, n_freqs, n_segments)`` for unaggregated + Welch estimates, or ``({n_epo}n_channels, n_tapers, n_freqs)`` for unaggregated multitaper estimates. 
""" -_fill_epochs = ["n_epochs, "] * 3 -docdict["getitem_epochspectrum_return"] = _getitem_base.format(*_fill_epochs) -docdict["getitem_spectrum_return"] = _getitem_base.format("", "", "") +_getitem_tfr_base = """ +data : ndarray + The selected time-frequency data. Shape will be + ``({n_epo}n_channels, n_freqs, n_times)`` for Morlet, Stockwell, and aggregated + (``output='power'``) multitaper methods, or + ``({n_epo}n_channels, n_tapers, n_freqs, n_times)`` for unaggregated + (``output='complex'``) multitaper method. +""" +n_epo = "n_epochs, " +docdict["getitem_epochspectrum_return"] = _getitem_spectrum_base.format(n_epo=n_epo) +docdict["getitem_epochstfr_return"] = _getitem_tfr_base.format(n_epo=n_epo) +docdict["getitem_spectrum_return"] = _getitem_spectrum_base.format(n_epo="") +docdict["getitem_tfr_return"] = _getitem_tfr_base.format(n_epo="") + -docdict[ - "group_by_browse" -] = """ +docdict["group_by_browse"] = """ group_by : str How to group channels. ``'type'`` groups by channel type, ``'original'`` plots in the order of ch_names, ``'selection'`` uses @@ -1929,17 +1921,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # H -docdict[ - "h_freq" -] = """ +docdict["h_freq"] = """ h_freq : float | None For FIR filters, the upper pass-band edge; for IIR filters, the upper cutoff frequency. If None the data are only high-passed. """ -docdict[ - "h_trans_bandwidth" -] = """ +docdict["h_trans_bandwidth"] = """ h_trans_bandwidth : float | str Width of the transition band at the high cut-off frequency in Hz (low pass or cutoff 2 in bandpass). Can be "auto" @@ -1950,9 +1938,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Only used for ``method='fir'``. """ -docdict[ - "head_pos" -] = """ +docdict["head_pos"] = """ head_pos : None | path-like | dict | tuple | array Path to the position estimates file. Should be in the format of the files produced by MaxFilter. If dict, keys should @@ -1964,26 +1950,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): :func:`mne.chpi.read_head_pos`. """ -docdict[ - "head_pos_maxwell" -] = """ +docdict["head_pos_maxwell"] = """ head_pos : array | None If array, movement compensation will be performed. The array should be of shape (N, 10), holding the position parameters as returned by e.g. ``read_head_pos``. """ -docdict[ - "head_source" -] = """ +docdict["head_source"] = """ head_source : str | list of str Head source(s) to use. See the ``source`` option of :func:`mne.get_head_surf` for more information. """ -docdict[ - "hitachi_fname" -] = """ +docdict["hitachi_fname"] = """ fname : list | str Path(s) to the Hitachi CSV file(s). This should only be a list for multiple probes that were acquired simultaneously. @@ -1992,9 +1972,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Added support for list-of-str. """ -docdict[ - "hitachi_notes" -] = """ +docdict["hitachi_notes"] = """ Hitachi does not encode their channel positions, so you will need to create a suitable mapping using :func:`mne.channels.make_standard_montage` or :func:`mne.channels.make_dig_montage` like (for a 3x5/ETG-7000 example): @@ -2049,9 +2027,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # I -docdict[ - "idx_pctf" -] = """ +docdict["idx_pctf"] = """ idx : list of int | list of Label Source for indices for which to compute PSFs or CTFs. If mode is None, PSFs/CTFs will be returned for all indices. 
If mode is not None, the @@ -2065,27 +2041,27 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): specified labels. """ -docdict[ - "ignore_ref_maxwell" -] = """ +docdict["ignore_ref_maxwell"] = """ ignore_ref : bool If True, do not include reference channels in compensation. This option should be True for KIT files, since Maxwell filtering with reference channels is not currently supported. """ -docdict[ - "iir_params" -] = """ +docdict["iir_params"] = """ iir_params : dict | None Dictionary of parameters to use for IIR filtering. If ``iir_params=None`` and ``method="iir"``, 4th order Butterworth will be used. For more information, see :func:`mne.filter.construct_iir_filter`. """ -docdict[ - "image_format_report" -] = """ +docdict["image_args"] = """ +image_args : dict | None + Keyword arguments to pass to :meth:`mne.time_frequency.AverageTFR.plot`. ``axes`` + and ``show`` are ignored. Defaults to ``None`` (i.e., an empty :class:`dict`). +""" + +docdict["image_format_report"] = """ image_format : 'png' | 'svg' | 'gif' | None The image format to be used for the report, can be ``'png'``, ``'svg'``, or ``'gif'``. @@ -2093,9 +2069,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): instantiation. """ -docdict[ - "image_interp_topomap" -] = """ +docdict["image_interp_topomap"] = """ image_interp : str The image interpolation to be used. Options are ``'cubic'`` (default) to use :class:`scipy.interpolate.CloughTocher2DInterpolator`, @@ -2103,9 +2077,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``'linear'`` to use :class:`scipy.interpolate.LinearNDInterpolator`. """ -docdict[ - "include_tmax" -] = """ +docdict["include_tmax"] = """ include_tmax : bool If True (default), include tmax. If False, exclude tmax (similar to how Python indexing typically works). @@ -2139,39 +2111,33 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): "sensors and methods of measurement." ) -docdict[ - "info" -] = f""" +docdict["info"] = f""" info : mne.Info | None {_info_base} """ -docdict[ - "info_not_none" -] = f""" +docdict["info_not_none"] = f""" info : mne.Info {_info_base} """ -docdict[ - "info_str" -] = f""" +docdict["info_str"] = f""" info : mne.Info | path-like {_info_base} If ``path-like``, it should be a :class:`str` or :class:`pathlib.Path` to a file with measurement information (e.g. :class:`mne.io.Raw`). """ -docdict[ - "int_order_maxwell" -] = """ +docdict["inst_tfr"] = """ +inst : instance of RawTFR, EpochsTFR, or AverageTFR +""" + +docdict["int_order_maxwell"] = """ int_order : int Order of internal component of spherical expansion. """ -docdict[ - "interaction_scene" -] = """ +docdict["interaction_scene"] = """ interaction : 'trackball' | 'terrain' How interactions with the scene via an input device (e.g., mouse or trackpad) modify the camera position. If ``'terrain'``, one axis is @@ -2181,9 +2147,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): some axes. """ -docdict[ - "interaction_scene_none" -] = """ +docdict["interaction_scene_none"] = """ interaction : 'trackball' | 'terrain' | None How interactions with the scene via an input device (e.g., mouse or trackpad) modify the camera position. If ``'terrain'``, one axis is @@ -2195,27 +2159,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): used.
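For instance, the ``include_tmax`` behavior described above (a sketch; assumes ``raw`` exists)::

    a = raw.copy().crop(tmin=0.0, tmax=10.0)                      # keeps t == 10.0
    b = raw.copy().crop(tmin=0.0, tmax=10.0, include_tmax=False)  # tmin <= t < tmax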
""" -docdict[ - "interp" -] = """ +docdict["interp"] = """ interp : str Either ``'hann'``, ``'cos2'`` (default), ``'linear'``, or ``'zero'``, the type of forward-solution interpolation to use between forward solutions at different head positions. """ -docdict[ - "interpolation_brain_time" -] = """ +docdict["interpolation_brain_time"] = """ interpolation : str | None Interpolation method (:class:`scipy.interpolate.interp1d` parameter). Must be one of ``'linear'``, ``'nearest'``, ``'zero'``, ``'slinear'``, ``'quadratic'`` or ``'cubic'``. """ -docdict[ - "inversion_bf" -] = """ +docdict["inversion_bf"] = """ inversion : 'single' | 'matrix' This determines how the beamformer deals with source spaces in "free" orientation. Such source spaces define three orthogonal dipoles at each @@ -2229,12 +2187,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Defaults to ``'matrix'``. """ +docdict["item"] = """ +item : int | slice | array-like | str +""" + # %% # J -docdict[ - "joint_set_eeg_reference" -] = """ +docdict["joint_set_eeg_reference"] = """ joint : bool How to handle list-of-str ``ch_type``. If False (default), one projector is created per channel type. If True, one projector is created across @@ -2246,9 +2206,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # K -docdict[ - "keep_his_anonymize_info" -] = """ +docdict["keep_his_anonymize_info"] = """ keep_his : bool If ``True``, ``his_id`` of ``subject_info`` will **not** be overwritten. Defaults to ``False``. @@ -2257,35 +2215,33 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): anonymized. Use with caution. """ -docdict[ - "kit_elp" -] = """ +docdict["kit_badcoils"] = """ +bad_coils : array-like of int | None + Indices of (up to two) bad marker coils to be removed. + These marker coils must be present in the elp and mrk files. +""" + +docdict["kit_elp"] = """ elp : path-like | array of shape (8, 3) | None Digitizer points representing the location of the fiducials and the marker coils with respect to the digitized head shape, or path to a file containing these points. """ -docdict[ - "kit_hsp" -] = """ +docdict["kit_hsp"] = """ hsp : path-like | array of shape (n_points, 3) | None Digitizer head shape points, or path to head shape file. If more than 10,000 points are in the head shape, they are automatically decimated. """ -docdict[ - "kit_mrk" -] = """ +docdict["kit_mrk"] = """ mrk : path-like | array of shape (5, 3) | list | None Marker points representing the location of the marker coils with respect to the MEG sensors, or path to a marker file. If list, all of the markers will be averaged together. """ -docdict[ - "kit_slope" -] = r""" +docdict["kit_slope"] = r""" slope : ``'+'`` | ``'-'`` How to interpret values on KIT trigger channels when synthesizing a Neuromag-style stim channel. With ``'+'``\, a positive slope (low-to-high) @@ -2293,9 +2249,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): is interpreted as an event. """ -docdict[ - "kit_stim" -] = r""" +docdict["kit_stim"] = r""" stim : list of int | ``'<'`` | ``'>'`` | None Channel-value correspondence when converting KIT trigger channels to a Neuromag-style stim channel. For ``'<'``\, the largest values are @@ -2305,25 +2259,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): generated. """ -docdict[ - "kit_stimcode" -] = """ +docdict["kit_stimcode"] = """ stim_code : ``'binary'`` | ``'channel'`` How to decode trigger values from stim channels. 
``'binary'`` reads stim channel events as binary code, ``'channel'`` encodes the channel number. """ -docdict[ - "kit_stimthresh" -] = """ +docdict["kit_stimthresh"] = """ stimthresh : float | None The threshold level for accepting voltage changes in KIT trigger channels as a trigger event. If None, stim must also be set to None. """ -docdict[ - "kwargs_fun" -] = """ +docdict["kwargs_fun"] = """ **kwargs : dict Additional keyword arguments to pass to ``fun``. """ @@ -2331,26 +2279,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # L -docdict[ - "l_freq" -] = """ +docdict["l_freq"] = """ l_freq : float | None For FIR filters, the lower pass-band edge; for IIR filters, the lower cutoff frequency. If None the data are only low-passed. """ -docdict[ - "l_freq_ecg_filter" -] = """ +docdict["l_freq_ecg_filter"] = """ l_freq : float Low pass frequency to apply to the ECG channel while finding events. h_freq : float High pass frequency to apply to the ECG channel while finding events. """ -docdict[ - "l_trans_bandwidth" -] = """ +docdict["l_trans_bandwidth"] = """ l_trans_bandwidth : float | str Width of the transition band at the low cut-off frequency in Hz (high pass or cutoff 1 in bandpass). Can be "auto" @@ -2361,16 +2303,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Only used for ``method='fir'``. """ -docdict[ - "label_tc_el_returns" -] = """ +docdict["label_tc_el_returns"] = """ label_tc : array | list (or generator) of array, shape (n_labels[, n_orient], n_times) Extracted time course for each label and source estimate. """ -docdict[ - "labels_eltc" -] = """ +docdict["labels_eltc"] = """ labels : Label | BiHemiLabel | list | tuple | str If using a surface or mixed source space, this should be the :class:`~mne.Label`'s for which to extract the time course. @@ -2386,19 +2324,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionchanged:: 0.21.0 Support for volume source estimates. """ - -docdict[ - "layout_spectrum_plot_topo" -] = """\ +docdict["layout_scale"] = """ +layout_scale : float + Scaling factor for adjusting the relative size of the layout on the canvas. +""" +docdict["layout_spectrum_plot_topo"] = """\ layout : instance of Layout | None Layout instance specifying sensor positions (does not need to be specified for Neuromag data). If ``None`` (default), the layout is - inferred from the data. + inferred from the data (if possible). """ -docdict[ - "line_alpha_plot_psd" -] = """\ +docdict["line_alpha_plot_psd"] = """\ line_alpha : float | None Alpha for the PSD line. Can be None (default) to use 1.0 when ``average=True`` and 0.1 when ``average=False``. @@ -2425,9 +2362,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["long_format_df_spe"] = _long_format_df_base.format(*spe) docdict["long_format_df_stc"] = _long_format_df_base.format(*stc) -docdict[ - "loose" -] = """ +docdict["loose"] = """ loose : float | 'auto' | dict Value that weights the source variances of the dipole components that are parallel (tangential) to the cortical surface. Can be: @@ -2446,9 +2381,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # M -docdict[ - "mag_scale_maxwell" -] = """ +docdict["mag_scale_maxwell"] = """ mag_scale : float | str The magnetometer scale-factor used to bring the magnetometers to approximately the same order of magnitude as the gradiometers @@ -2458,9 +2391,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): 59.5 for VectorView). """ -docdict[ - "mapping_rename_channels_duplicates" -] = """ +docdict["mapping_rename_channels_duplicates"] = """ mapping : dict | callable A dictionary mapping the old channel to a new channel name e.g. ``{'EEG061' : 'EEG161'}``. Can also be a callable function @@ -2482,17 +2413,25 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): with the parameters given in ``mask_params``. Defaults to ``None``, equivalent to an array of all ``False`` elements. """ - +docdict["mask_alpha_tfr_plot"] = """ +mask_alpha : float + Relative opacity of the masked region versus the unmasked region, given as a + :class:`float` between 0 and 1 (i.e., 0 means masked areas are not visible at all). + Defaults to ``0.1``. +""" +docdict["mask_cmap_tfr_plot"] = """ +mask_cmap : matplotlib colormap | str | None + Colormap to use for masked areas of the plot. If a :class:`str`, must be a valid + Matplotlib colormap name. If None, ``cmap`` is used for both masked and unmasked + areas. Ignored if ``mask`` is ``None``. Default is ``'Greys'``. +""" docdict["mask_evoked_topomap"] = _mask_base.format( shape="(n_channels, n_times)", shape_appendix="-time combinations", example=" (useful for, e.g. marking which channels at which times a " "statistical test of the data reaches significance)", ) - -docdict[ - "mask_params_topomap" -] = """ +docdict["mask_params_topomap"] = """ mask_params : dict | None Additional plotting parameters for plotting significant sensors. Default (None) equals:: @@ -2500,18 +2439,30 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): dict(marker='o', markerfacecolor='w', markeredgecolor='k', linewidth=0, markersize=4) """ - docdict["mask_patterns_topomap"] = _mask_base.format( shape="(n_channels, n_patterns)", shape_appendix="-pattern combinations", example="" ) - +docdict["mask_style_tfr_plot"] = """ +mask_style : None | 'both' | 'contour' | 'mask' + How to distinguish the masked/unmasked regions of the plot. If ``"contour"``, a + line is drawn around the areas where the mask is ``True``. If ``"mask"``, areas + where the mask is ``False`` will be (partially) transparent, as determined by + ``mask_alpha``. If ``"both"``, both a contour and transparency are used. Default is + ``None``, which is silently ignored if ``mask`` is ``None`` and is interpreted like + ``"both"`` otherwise. +""" +docdict["mask_tfr_plot"] = """ +mask : ndarray | None + An :class:`array ` of :class:`boolean ` values, of the same + shape as the data. Data that corresponds to ``False`` entries in the mask are + plotted differently, as determined by ``mask_style``, ``mask_alpha``, and + ``mask_cmap``. Useful for, e.g., highlighting areas of statistical significance. +""" docdict["mask_topomap"] = _mask_base.format( shape="(n_channels,)", shape_appendix="(s)", example="" ) -docdict[ - "match_alias" -] = """ +docdict["match_alias"] = """ match_alias : bool | dict Whether to use a lookup table to match unrecognized channel location names to their known aliases. If True, uses the mapping in @@ -2522,18 +2473,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.23 """ -docdict[ - "match_case" -] = """ +docdict["match_case"] = """ match_case : bool If True (default), channel name matching will be case sensitive. .. versionadded:: 0.20 """ -docdict[ - "max_dist_ieeg" -] = """ +docdict["max_dist_ieeg"] = """ max_dist : float The maximum distance to project a sensor to the pial surface in meters. Sensors that are greater than this distance from the pial surface will @@ -2541,17 +2488,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): flat brain. """ -docdict[ - "max_iter_multitaper" -] = """ +docdict["max_iter_multitaper"] = """ max_iter : int Maximum number of iterations to reach convergence when combining the tapered spectra with adaptive weights (see argument ``adaptive``). This argument has no effect if ``adaptive`` is set to ``False``.""" -docdict[ - "max_step_clust" -] = """ +docdict["max_step_clust"] = """ max_step : int Maximum distance between samples along the second axis of ``X`` to be considered adjacent (typically the second axis is the "time" dimension). @@ -2562,9 +2505,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): :func:`mne.stats.combine_adjacency`). """ -docdict[ - "measure" -] = """ +docdict["measure"] = """ measure : 'zscore' | 'correlation' Which method to use for finding outliers among the components: @@ -2577,9 +2518,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.21""" -docdict[ - "meg" -] = """ +docdict["meg"] = """ meg : str | list | dict | bool | None Can be "helmet", "sensors" or "ref" to show the MEG helmet, sensors or reference sensors respectively, or a combination like @@ -2591,33 +2530,48 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Added support for specifying alpha values as a dict. """ -docdict[ - "metadata_epochs" -] = """ +_metadata_attr_template = """ metadata : instance of pandas.DataFrame | None - A :class:`pandas.DataFrame` specifying metadata about each epoch. - If given, ``len(metadata)`` must equal ``len(events)``. The DataFrame - may only contain values of type (str | int | float | bool). - If metadata is given, then pandas-style queries may be used to select - subsets of data, see :meth:`mne.Epochs.__getitem__`. - When a subset of the epochs is created in this (or any other - supported) manner, the metadata object is subsetted accordingly, and - the row indices will be modified to match ``epochs.selection``. - - .. versionadded:: 0.16 -""" + A :class:`pandas.DataFrame` specifying metadata about each epoch{or_none}.{extra} +""" +_metadata_template = _metadata_attr_template.format( + or_none="", + extra=""" + If not ``None``, ``len(metadata)`` must equal ``len(events)``. For + save/load compatibility, the :class:`~pandas.DataFrame` may only contain + :class:`str`, :class:`int`, :class:`float`, and :class:`bool` values. + If not ``None``, then pandas-style queries may be used to select + subsets of data, see :meth:`mne.Epochs.__getitem__`. When the {obj} object + is subsetted, the metadata is subsetted accordingly, and the row indices + will be modified to match ``{obj}.selection``.""",
)
docdict["metadata_attr"] = _metadata_attr_template.format( + or_none=" (or ``None``)", extra="" +) +docdict["metadata_epochs"] = _metadata_template.format(obj="Epochs") +docdict["metadata_epochstfr"] = _metadata_template.format(obj="EpochsTFR") -docdict[ - "method_fir" -] = """ +docdict["method_fir"] = """ method : str ``'fir'`` will use overlap-add FIR filtering, ``'iir'`` will use IIR forward-backward filtering (via :func:`~scipy.signal.filtfilt`). """ -docdict[ - "method_kw_psd" -] = """\ +_method_kw_tfr_template = """ +**method_kw + Additional keyword arguments passed to the spectrotemporal estimation function + (e.g., ``n_cycles, use_fft, zero_mean`` for Morlet method{stockwell} + or ``n_cycles, use_fft, zero_mean, time_bandwidth`` for multitaper method). + See :func:`~mne.time_frequency.tfr_array_morlet`{stockwell_crossref} + and :func:`~mne.time_frequency.tfr_array_multitaper` for additional details. +""" + +docdict["method_kw_epochs_tfr"] = _method_kw_tfr_template.format( + stockwell=", ``n_fft, width`` for Stockwell method,", + stockwell_crossref=", :func:`~mne.time_frequency.tfr_array_stockwell`,", +) + +docdict["method_kw_psd"] = """\ **method_kw Additional keyword arguments passed to the spectral estimation function (e.g., ``n_fft, n_overlap, n_per_seg, average, window`` @@ -2627,7 +2581,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): :func:`~mne.time_frequency.psd_array_multitaper` for details. """ -_method_psd = r""" +docdict["method_kw_tfr"] = _method_kw_tfr_template.format( + stockwell="", stockwell_crossref="" +) + +_method_psd = """ method : ``'welch'`` | ``'multitaper'``{} Spectral estimation method. ``'welch'`` uses Welch's method :footcite:p:`Welch1967`, ``'multitaper'`` uses DPSS @@ -2643,16 +2601,42 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["method_psd"] = _method_psd.format("", "") docdict["method_psd_auto"] = _method_psd.format(" | ``'auto'``", "") -docdict[ - "mode_eltc" -] = """ +docdict["method_resample"] = """ +method : str + Resampling method to use. Can be ``"fft"`` (default) or ``"polyphase"`` + to use FFT-based or polyphase FIR resampling, respectively. These wrap to + :func:`scipy.signal.resample` and :func:`scipy.signal.resample_poly`, respectively. +""" + +_method_tfr_template = """ +method : ``'morlet'`` | ``'multitaper'``{literals} | None + Spectrotemporal power estimation method. ``'morlet'`` uses Morlet wavelets, + ``'multitaper'`` uses DPSS tapers :footcite:p:`Slepian1978`{cites}. ``None`` (the + default) only works when using ``__setstate__`` and will raise an error otherwise. +""" +docdict["method_tfr"] = _method_tfr_template.format(literals="", cites="") +docdict["method_tfr_array"] = """ +method : str | None + Comment on the method used to compute the data, e.g., ``"hilbert"``. + Default is ``None``. +""" +docdict["method_tfr_attr"] = """ +method : str + The method used to compute the spectra (e.g., ``"morlet"``, ``"multitaper"`` + or ``"stockwell"``). +""" +docdict["method_tfr_epochs"] = _method_tfr_template.format( + literals=" | ``'stockwell'``", + cites=", and ``'stockwell'`` uses the S-transform " + ":footcite:p:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`", +) + +docdict["mode_eltc"] = """ mode : str Extraction mode, see Notes.
""" -docdict[ - "mode_pctf" -] = """ +docdict["mode_pctf"] = """ mode : None | 'mean' | 'max' | 'svd' Compute summary of PSFs/CTFs across all indices specified in 'idx'. Can be: @@ -2666,9 +2650,24 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): n_comp first SVD components. """ -docdict[ - "montage" -] = """ +docdict["mode_tfr_plot"] = """ +mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' + Perform baseline correction by + + - subtracting the mean of baseline values ('mean') (default) + - dividing by the mean of baseline values ('ratio') + - dividing by the mean of baseline values and taking the log + ('logratio') + - subtracting the mean of baseline values followed by dividing by + the mean of baseline values ('percent') + - subtracting the mean of baseline values and dividing by the + standard deviation of baseline values ('zscore') + - dividing by the mean of baseline values, taking the log, and + dividing by the standard deviation of log baseline values + ('zlogratio') +""" + +docdict["montage"] = """ montage : None | str | DigMontage A montage containing channel positions. If a string or :class:`~mne.channels.DigMontage` is @@ -2682,9 +2681,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["montage_types"] = """EEG/sEEG/ECoG/DBS/fNIRS""" -docdict[ - "montage_units" -] = """ +docdict["montage_units"] = """ montage_units : str Units that channel positions are represented in. Defaults to "mm" (millimeters), but can be any prefix + "m" combination (including just @@ -2693,22 +2690,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 1.3 """ -docdict[ - "morlet_reference" -] = """ +docdict["morlet_reference"] = """ The Morlet wavelets follow the formulation in :footcite:t:`Tallon-BaudryEtAl1997`. """ -docdict[ - "moving" -] = """ +docdict["moving"] = """ moving : instance of SpatialImage The image to morph ("from" volume). """ -docdict[ - "mri_resolution_eltc" -] = """ +docdict["mri_resolution_eltc"] = """ mri_resolution : bool If True (default), the volume source space will be upsampled to the original MRI resolution via trilinear interpolation before the atlas values @@ -2722,17 +2713,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # N -docdict[ - "n_comp_pctf_n" -] = """ +docdict["n_comp_pctf_n"] = """ n_comp : int Number of PSF/CTF components to return for mode='max' or mode='svd'. Default n_comp=1. """ -docdict[ - "n_cycles_tfr" -] = """ +docdict["n_cycles_tfr"] = """ n_cycles : int | array of int, shape (n_freqs,) Number of cycles in the wavelet, either a fixed number or one per frequency. The number of cycles ``n_cycles`` and the frequencies of @@ -2741,9 +2728,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): and about time and frequency smoothing. """ -docdict[ - "n_jobs" -] = """\ +docdict["n_jobs"] = """\ n_jobs : int | None The number of jobs to run in parallel. If ``-1``, it is set to the number of CPU cores. Requires the :mod:`joblib` package. @@ -2753,25 +2738,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): value for ``n_jobs``. """ -docdict[ - "n_jobs_cuda" -] = """ +docdict["n_jobs_cuda"] = """ n_jobs : int | str Number of jobs to run in parallel. Can be ``'cuda'`` if ``cupy`` is installed properly. """ -docdict[ - "n_jobs_fir" -] = """ +docdict["n_jobs_fir"] = """ n_jobs : int | str Number of jobs to run in parallel. 
Can be ``'cuda'`` if ``cupy`` is installed properly and ``method='fir'``. """ -docdict[ - "n_pca_components_apply" -] = """ +docdict["n_pca_components_apply"] = """ n_pca_components : int | float | None The number of PCA components to be kept, either absolute (int) or fraction of the explained variance (float). If None (default), @@ -2779,24 +2758,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): in 0.23 all components will be used. """ -docdict[ - "n_permutations_clust_all" -] = """ +docdict["n_permutations_clust_all"] = """ n_permutations : int | 'all' The number of permutations to compute. Can be 'all' to perform an exact test. """ -docdict[ - "n_permutations_clust_int" -] = """ +docdict["n_permutations_clust_int"] = """ n_permutations : int The number of permutations to compute. """ -docdict[ - "n_proj_vectors" -] = """ +docdict["n_proj_vectors"] = """ n_grad : int | float between ``0`` and ``1`` Number of vectors for gradiometers. Either an integer or a float between 0 and 1 to select the number of vectors to explain the cumulative variance greater than @@ -2811,18 +2784,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``n_eeg``. """ -docdict[ - "names_topomap" -] = """\ +docdict["names_topomap"] = """\ names : None | list Labels for the sensors. If a :class:`list`, labels should correspond to the order of channels in ``data``. If ``None`` (default), no channel names are plotted. """ -docdict[ - "nirx_notes" -] = """ +docdict["nave_tfr_attr"] = """ +nave : int + The number of epochs that were averaged to yield the result. This may reflect + epochs averaged *before* time-frequency analysis (as in + ``epochs.average(...).compute_tfr(...)``) or *after* time-frequency analysis (as + in ``epochs.compute_tfr(...).average(...)``). +""" +docdict["nirx_notes"] = """ This function has only been tested with NIRScout and NIRSport devices, and with the NIRStar software version 15 and above and Aurora software 2021 and above. @@ -2836,9 +2812,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): saturated data. """ -docdict[ - "niter" -] = """ +docdict["niter"] = """ niter : dict | tuple | None For each phase of the volume registration, ``niter`` is the number of iterations per successive stage of optimization. If a tuple is @@ -2857,9 +2831,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): sdr=(5, 5, 3)) """ -docdict[ - "norm_pctf" -] = """ +docdict["norm_pctf"] = """ norm : None | 'max' | 'norm' Whether and how to normalise the PSFs and CTFs. This will be applied before computing summaries as specified in 'mode'. @@ -2870,24 +2842,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): * 'norm' : Normalize to maximum norm across all PSFs/CTFs. """ -docdict[ - "normalization" -] = """normalization : 'full' | 'length' +docdict["normalization"] = """normalization : 'full' | 'length' Normalization strategy. If "full", the PSD will be normalized by the sampling rate as well as the length of the signal (as in :ref:`Nitime `). Default is ``'length'``.""" -docdict[ - "normalize_psd_topo" -] = """ +docdict["normalize_psd_topo"] = """ normalize : bool If True, each band will be divided by the total power. Defaults to False. """ -docdict[ - "notes_2d_backend" -] = """\ +docdict["notes_2d_backend"] = """\ MNE-Python provides two different backends for browsing plots (i.e., :meth:`raw.plot()`, :meth:`epochs.plot()`, and :meth:`ica.plot_sources()`). 
One is @@ -2915,9 +2881,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["notes_plot_*_psd_func"] = _notes_plot_psd.format("function") docdict["notes_plot_psd_meth"] = _notes_plot_psd.format("method") -docdict[ - "notes_spectrum_array" -] = """ +docdict["notes_spectrum_array"] = """ It is assumed that the data passed in represent spectral *power* (not amplitude, phase, model coefficients, etc) and downstream methods (such as :meth:`~mne.time_frequency.SpectrumArray.plot`) assume power data. If you pass in @@ -2925,27 +2889,45 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): other things may also not work or be incorrect). """ -docdict[ - "notes_tmax_included_by_default" -] = """ +docdict["notes_timefreqs_tfr_plot_joint"] = """ +``timefreqs`` has three different modes: tuples, dicts, and auto. For (list of) tuple(s) +mode, each tuple defines a pair (time, frequency) in s and Hz on the TFR plot. +For example, to look at 10 Hz activity 1 second into the epoch and 3 Hz activity 300 ms +into the epoch, :: + + timefreqs=((1, 10), (.3, 3)) + +If provided as a dictionary, (time, frequency) tuples are keys and (time_window, +frequency_window) tuples are the values — indicating the width of the windows (centered +on the time and frequency indicated by the key) to be averaged over. For example, :: + + timefreqs={(1, 10): (0.1, 2)} + +would translate into a window that spans 0.95 to 1.05 seconds and 9 to 11 Hz. If +``None``, a single topomap will be plotted at the absolute peak across the +time-frequency representation. +""" + +docdict["notes_tmax_included_by_default"] = """ Unlike Python slices, MNE time intervals by default include **both** their end points; ``crop(tmin, tmax)`` returns the interval ``tmin <= t <= tmax``. Pass ``include_tmax=False`` to specify the half-open interval ``tmin <= t < tmax`` instead. """ -docdict[ - "npad" -] = """ +docdict["npad"] = """ npad : int | str - Amount to pad the start and end of the data. - Can also be ``"auto"`` to use a padding that will result in - a power-of-two size (can be much faster). + Amount to pad the start and end of the data. Can also be ``"auto"`` to use a padding + that will result in a power-of-two size (can be much faster). """ -docdict[ - "nrows_ncols_ica_components" -] = """ +docdict["npad_resample"] = ( + docdict["npad"] + + """ + Only used when ``method="fft"``. +""" +) +docdict["nrows_ncols_ica_components"] = """ nrows, ncols : int | 'auto' The number of rows and columns of topographies to plot. If both ``nrows`` and ``ncols`` are ``'auto'``, will plot up to 20 components in a 5×4 grid, @@ -2956,9 +2938,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``nrows='auto', ncols='auto'``. """ -docdict[ - "nrows_ncols_topomap" -] = """ +docdict["nrows_ncols_topomap"] = """ nrows, ncols : int | 'auto' The number of rows and columns of topographies to plot. If either ``nrows`` or ``ncols`` is ``'auto'``, the necessary number will be inferred. Defaults @@ -2968,9 +2948,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # O -docdict[ - "offset_decim" -] = """ +docdict["offset_decim"] = """ offset : int Apply an offset to where the decimation starts relative to the sample corresponding to t=0. The offset is in samples at the @@ -2979,9 +2957,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. 
versionadded:: 0.12 """ -docdict[ - "on_baseline_ica" -] = """ +docdict["on_baseline_ica"] = """ on_baseline : str How to handle baseline-corrected epochs or evoked data. Can be ``'raise'`` to raise an error, ``'warn'`` (default) to emit a @@ -2990,9 +2966,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 1.2 """ -docdict[ - "on_defects" -] = """ +docdict["on_defects"] = """ on_defects : 'raise' | 'warn' | 'ignore' What to do if the surface is found to have topological defects. Can be ``'raise'`` (default) to raise an error, ``'warn'`` to emit a @@ -3003,9 +2977,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): fail irrespective of this parameter. """ -docdict[ - "on_header_missing" -] = """ +docdict["on_header_missing"] = """ on_header_missing : str Can be ``'raise'`` (default) to raise an error, ``'warn'`` to emit a warning, or ``'ignore'`` to ignore when the FastSCAN header is missing. @@ -3018,9 +2990,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): warning, or ``'ignore'`` to ignore when""" -docdict[ - "on_mismatch_info" -] = f""" +docdict["on_mismatch_info"] = f""" on_mismatch : 'raise' | 'warn' | 'ignore' {_on_missing_base} the device-to-head transformation differs between instances. @@ -3028,27 +2998,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.24 """ -docdict[ - "on_missing_ch_names" -] = f""" +docdict["on_missing_ch_names"] = f""" on_missing : 'raise' | 'warn' | 'ignore' {_on_missing_base} entries in ch_names are not present in the raw instance. .. versionadded:: 0.23.0 """ -docdict[ - "on_missing_chpi" -] = f""" +docdict["on_missing_chpi"] = f""" on_missing : 'raise' | 'warn' | 'ignore' {_on_missing_base} no cHPI information can be found. If ``'ignore'`` or ``'warn'``, all return values will be empty arrays or ``None``. If ``'raise'``, an exception will be raised. """ -docdict[ - "on_missing_epochs" -] = """ +docdict["on_missing_epochs"] = """ on_missing : 'raise' | 'warn' | 'ignore' What to do if one or several event ids are not found in the recording. Valid keys are 'raise' | 'warn' | 'ignore' @@ -3060,9 +3024,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): automatically generated irrespective of this parameter. """ -docdict[ - "on_missing_events" -] = f""" +docdict["on_missing_events"] = f""" on_missing : 'raise' | 'warn' | 'ignore' {_on_missing_base} event numbers from ``event_id`` are missing from :term:`events`. When numbers from :term:`events` are missing from @@ -3072,32 +3034,24 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.21 """ -docdict[ - "on_missing_fiducials" -] = f""" +docdict["on_missing_fiducials"] = f""" on_missing : 'raise' | 'warn' | 'ignore' {_on_missing_base} some necessary fiducial points are missing. """ -docdict[ - "on_missing_fwd" -] = f""" +docdict["on_missing_fwd"] = f""" on_missing : 'raise' | 'warn' | 'ignore' {_on_missing_base} ``stc`` has vertices that are not in ``fwd``. """ -docdict[ - "on_missing_montage" -] = f""" +docdict["on_missing_montage"] = f""" on_missing : 'raise' | 'warn' | 'ignore' {_on_missing_base} channels have missing coordinates. .. versionadded:: 0.20.1 """ -docdict[ - "on_rank_mismatch" -] = """ +docdict["on_rank_mismatch"] = """ on_rank_mismatch : str If an explicit MEG value is passed, what to do when it does not match an empirically computed rank (only used for covariances). 
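A sketch of the ``on_missing`` behavior during epoching (assumes ``raw`` and ``events`` exist; the event code 99 is hypothetical)::

    import mne

    # warn instead of raising if code 99 never occurs in `events`
    epochs = mne.Epochs(raw, events, event_id={"rare": 99}, on_missing="warn")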
@@ -3107,30 +3061,24 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.23 """ -docdict[ - "on_split_missing" -] = f""" +docdict["on_split_missing"] = f""" on_split_missing : str {_on_missing_base} split file is missing. .. versionadded:: 0.22 """ -docdict[ - "ordered" -] = """ +docdict["ordered"] = """ ordered : bool - If True (default False), ensure that the order of the channels in + If True (default), ensure that the order of the channels in the modified instance matches the order of ``ch_names``. .. versionadded:: 0.20.0 - .. versionchanged:: 1.5 - The default changed from False in 1.4 to True in 1.5. + .. versionchanged:: 1.7 + The default changed from False in 1.6 to True in 1.7. """ -docdict[ - "origin_maxwell" -] = """ +docdict["origin_maxwell"] = """ origin : array-like, shape (3,) | str Origin of internal and external multipolar moment space in meters. The default is ``'auto'``, which means ``(0., 0., 0.)`` when @@ -3142,9 +3090,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): options or specifying the origin manually. """ -docdict[ - "out_type_clust" -] = """ +docdict["out_type_clust"] = """ out_type : 'mask' | 'indices' Output format of clusters within a list. If ``'mask'``, returns a list of boolean arrays, @@ -3157,9 +3103,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Default is ``'indices'``. """ -docdict[ - "outlines_topomap" -] = """ +docdict["outlines_topomap"] = """ outlines : 'head' | dict | None The outlines to be drawn. If 'head', the default head scheme will be drawn. If dict, each key refers to a tuple of x and y positions, the values @@ -3170,9 +3114,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Defaults to 'head'. """ -docdict[ - "overview_mode" -] = """ +docdict["output_compute_tfr"] = """ +output : str + What kind of estimate to return. Allowed values are ``"complex"``, ``"phase"``, + and ``"power"``. Default is ``"power"``. +""" + +docdict["overview_mode"] = """ overview_mode : str | None Can be "channels", "empty", or "hidden" to set the overview bar mode for the ``'qt'`` backend. If None (default), the config option @@ -3180,9 +3128,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): if it's not found. """ -docdict[ - "overwrite" -] = """ +docdict["overwrite"] = """ overwrite : bool If True (default False), overwrite the destination file if it exists. @@ -3192,70 +3138,92 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # P _pad_base = """ -pad : str - The type of padding to use. Supports all :func:`numpy.pad` ``mode`` - options. Can also be ``"reflect_limited"``, which pads with a - reflected version of each vector mirrored on the first and last values + all :func:`numpy.pad` ``mode`` options. Can also be ``"reflect_limited"``, which + pads with a reflected version of each vector mirrored on the first and last values of the vector, followed by zeros. """ -docdict["pad"] = _pad_base - docdict["pad_fir"] = ( - _pad_base - + """ + """ +pad : str + The type of padding to use. Supports """ + + _pad_base + + """\ Only used for ``method='fir'``. """ ) -docdict[ - "pca_vars_pctf" -] = """ +docdict["pad_resample"] = ( # used when default is not "auto" + """ +pad : str + The type of padding to use. When ``method="fft"``, supports """ + + _pad_base + + """\ + When ``method="polyphase"``, supports all modes of :func:`scipy.signal.upfirdn`. 
+""" +) + +docdict["pad_resample_auto"] = ( # used when default is "auto" + docdict["pad_resample"] + + """\ + The default ("auto") means ``'reflect_limited'`` for ``method='fft'`` and + ``'reflect'`` for ``method='polyphase'``. +""" +) +docdict["pca_vars_pctf"] = """ pca_vars : array, shape (n_comp,) | list of array The explained variances of the first n_comp SVD components across the PSFs/CTFs for the specified vertices. Arrays for multiple labels are returned as list. Only returned if ``mode='svd'`` and ``return_pca_vars=True``. """ -docdict[ - "per_sample_metric" -] = """ +docdict["per_sample_metric"] = """ per_sample : bool If True the metric is computed for each sample separately. If False, the metric is spatio-temporal. """ -docdict[ - "phase" -] = """ +docdict["phase"] = """ phase : str Phase of the filter. - When ``method='fir'``, symmetric linear-phase FIR filters are constructed, - and if ``phase='zero'`` (default), the delay of this filter is compensated - for, making it non-causal. If ``phase='zero-double'``, - then this filter is applied twice, once forward, and once backward - (also making it non-causal). If ``'minimum'``, then a minimum-phase filter - will be constructed and applied, which is causal but has weaker stop-band - suppression. - When ``method='iir'``, ``phase='zero'`` (default) or - ``phase='zero-double'`` constructs and applies IIR filter twice, once - forward, and once backward (making it non-causal) using - :func:`~scipy.signal.filtfilt`. - If ``phase='forward'``, it constructs and applies forward IIR filter using + When ``method='fir'``, symmetric linear-phase FIR filters are constructed + with the following behaviors when ``method="fir"``: + + ``"zero"`` (default) + The delay of this filter is compensated for, making it non-causal. + ``"minimum"`` + A minimum-phase filter will be constructed by decomposing the zero-phase filter + into a minimum-phase and all-pass systems, and then retaining only the + minimum-phase system (of the same length as the original zero-phase filter) + via :func:`scipy.signal.minimum_phase`. + ``"zero-double"`` + *This is a legacy option for compatibility with MNE <= 0.13.* + The filter is applied twice, once forward, and once backward + (also making it non-causal). + ``"minimum-half"`` + *This is a legacy option for compatibility with MNE <= 1.6.* + A minimum-phase filter will be reconstructed from the zero-phase filter with + half the length of the original filter. + + When ``method='iir'``, ``phase='zero'`` (default) or equivalently ``'zero-double'`` + constructs and applies IIR filter twice, once forward, and once backward (making it + non-causal) using :func:`~scipy.signal.filtfilt`; ``phase='forward'`` will apply + the filter once in the forward (causal) direction using :func:`~scipy.signal.lfilter`. .. versionadded:: 0.13 + .. versionchanged:: 1.7 + + The behavior for ``phase="minimum"`` was fixed to use a filter of the requested + length and improved suppression. """ -docdict[ - "physical_range_export_params" -] = """ +docdict["physical_range_export_params"] = """ physical_range : str | tuple - The physical range of the data. If 'auto' (default), then - it will infer the physical min and max from the data itself, - taking the minimum and maximum values per channel type. - If it is a 2-tuple of minimum and maximum limit, then those - physical ranges will be used. Only used for exporting EDF files. + The physical range of the data. 
If 'auto' (default), the physical range is inferred + from the data, taking the minimum and maximum values per channel type. If + 'channelwise', the range will be defined per channel. If a tuple of minimum and + maximum, this manual physical range will be used. Only used for exporting EDF files. """ _pick_ori_novec = """ @@ -3283,9 +3251,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ ) -docdict[ - "pick_ori_bf" -] = """ +docdict["pick_ori_bf"] = """ pick_ori : None | str For forward solutions with fixed orientation, None (default) must be used and a scalar beamformer is computed. For free-orientation forward @@ -3307,9 +3273,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): + _pick_ori_novec ) -docdict[ - "pick_types_params" -] = """ +docdict["pick_types_params"] = """ meg : bool | str If True include MEG channels. If string it can be 'mag', 'grad', 'planar1' or 'planar2' to select only magnetometers, all @@ -3423,9 +3387,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): f"{picks_base} good data channels {noref}" ) docdict["picks_header"] = _picks_header -docdict[ - "picks_ica" -] = """ +docdict["picks_ica"] = """ picks : int | list of int | slice | None Indices of the independent components (ICs) to visualize. If an integer, represents the index of the IC to pick. @@ -3434,24 +3396,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): IC: ``ICA001``. ``None`` will pick all independent components in the order fitted. """ -docdict[ - "picks_nostr" -] = f"""picks : list | slice | None +docdict["picks_nostr"] = f"""picks : list | slice | None {_picks_desc} {_picks_int} None (default) will pick all channels. {reminder_nostr}""" -docdict[ - "picks_plot_projs_joint_trace" -] = f"""\ +docdict["picks_plot_projs_joint_trace"] = f"""\ picks_trace : {_picks_types} Channels to show alongside the projected time courses. Typically these are the ground-truth channels for an artifact (e.g., ``'eog'`` or ``'ecg'``). {_picks_int} {_picks_str} no channels. """ -docdict[ - "pipeline" -] = """ +docdict["pipeline"] = """ pipeline : str | tuple The volume registration steps to perform (a ``str`` for a single step, or ``tuple`` for a set of sequential steps). The following steps can be @@ -3490,9 +3446,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): the SDR step. """ -docdict[ - "plot_psd_doc" -] = """\ +docdict["plot_psd_doc"] = """\ Plot power or amplitude spectra. Separate plots are drawn for each channel type. When the data have been @@ -3512,16 +3466,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["pos_topomap"] = _pos_topomap.format(" | instance of Info") docdict["pos_topomap_psd"] = _pos_topomap.format("") -docdict[ - "position" -] = """ +docdict["position"] = """ position : int The position for the progress bar. """ -docdict[ - "precompute" -] = """ +docdict["precompute"] = """ precompute : bool | str Whether to load all data (not just the visible portion) into RAM and apply preprocessing (e.g., projectors) to the full data array in a separate @@ -3536,9 +3486,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Support for the MNE_BROWSER_PRECOMPUTE config variable. """ -docdict[ - "preload" -] = """ +docdict["preload"] = """ preload : bool or str (default False) Preload data into memory for data manipulation and faster indexing. 
If True, the data will be preloaded into memory (fast, requires @@ -3546,9 +3494,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): file name of a memory-mapped file which is used to store the data on the hard drive (slower, requires less memory).""" -docdict[ - "preload_concatenate" -] = """ +docdict["preload_concatenate"] = """ preload : bool, str, or None (default None) Preload data into memory for data manipulation and faster indexing. If True, the data will be preloaded into memory (fast, requires @@ -3559,9 +3505,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): of the instances passed in. """ -docdict[ - "proj_epochs" -] = """ +docdict["proj_epochs"] = """ proj : bool | 'delayed' Apply SSP projection vectors. If proj is 'delayed' and reject is not None the single epochs will be projected before the rejection @@ -3575,9 +3519,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): recommended value if SSPs are not used for cleaning the data. """ -docdict[ - "proj_plot" -] = """ +docdict["proj_plot"] = """ proj : bool | 'interactive' | 'reconstruct' If true SSP projections are applied before display. If 'interactive', a check box for reversible selection of SSP projection vectors will @@ -3589,17 +3531,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Support for 'reconstruct' was added. """ -docdict[ - "proj_psd" -] = """\ +docdict["proj_psd"] = """\ proj : bool Whether to apply SSP projection vectors before spectral estimation. Default is ``False``. """ -docdict[ - "projection_set_eeg_reference" -] = """ +docdict["projection_set_eeg_reference"] = """ projection : bool If ``ref_channels='average'`` this argument specifies if the average reference should be computed as a projection (True) or not @@ -3611,16 +3549,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): must be set to ``False`` (the default in this case). """ -docdict[ - "projs" -] = """ +docdict["projs"] = """ projs : list of Projection List of computed projection vectors. """ -docdict[ - "projs_report" -] = """ +docdict["projs_report"] = """ projs : bool | None Whether to add SSP projector plots if projectors are present in the data. If ``None``, use ``projs`` from `~mne.Report` creation. @@ -3629,9 +3563,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # R -docdict[ - "random_state" -] = """ +docdict["random_state"] = """ random_state : None | int | instance of ~numpy.random.RandomState A seed for the NumPy random number generator (RNG). If ``None`` (default), the seed will be obtained from the operating system @@ -3690,24 +3622,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["rank_info"] = _rank_base + "\n The default is ``'info'``." docdict["rank_none"] = _rank_base + "\n The default is ``None``." -docdict[ - "raw_epochs" -] = """ +docdict["raw_epochs"] = """ raw : Raw object An instance of `~mne.io.Raw`. """ -docdict[ - "raw_sfreq" -] = """ +docdict["raw_sfreq"] = """ raw_sfreq : float The original Raw object sampling rate. If None, then it is set to ``info['sfreq']``. """ -docdict[ - "reduce_rank" -] = """ +docdict["reduce_rank"] = """ reduce_rank : bool If True, the rank of the denominator of the beamformer formula (i.e., during pseudo-inversion) will be reduced by one for each spatial location. 
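The ``projection`` flag documented above changes *when* the average reference
takes effect: as a revocable projector versus an immediate subtraction from the
data. A minimal sketch of the two modes (an illustration only, assuming the MNE
sample dataset is available locally)::

    import mne

    sample_dir = mne.datasets.sample.data_path() / "MEG" / "sample"
    raw = mne.io.read_raw_fif(sample_dir / "sample_audvis_raw.fif", preload=True)
    # projection=True stores a reversible average-reference projector;
    # the data themselves are untouched until the projector is applied
    raw_proj = raw.copy().set_eeg_reference("average", projection=True)
    # projection=False subtracts the average from the EEG data immediately
    raw_direct = raw.copy().set_eeg_reference("average", projection=False)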
@@ -3719,18 +3645,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``pick='max_power'`` with weight normalization). """ -docdict[ - "ref_channels" -] = """ +docdict["ref_channels"] = """ ref_channels : str | list of str Name of the electrode(s) which served as the reference in the recording. If a name is provided, a corresponding channel is added and its data is set to 0. This is useful for later re-referencing. """ -docdict[ - "ref_channels_set_eeg_reference" -] = """ +docdict["ref_channels_set_eeg_reference"] = """ ref_channels : list of str | str Can be: @@ -3742,16 +3664,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): the data """ -docdict[ - "reg_affine" -] = """ +docdict["reg_affine"] = """ reg_affine : ndarray of float, shape (4, 4) The affine that registers one volume to another. """ -docdict[ - "regularize_maxwell_reg" -] = """ +docdict["regularize_maxwell_reg"] = """ regularize : str | None Basis regularization type, must be ``"in"`` or None. ``"in"`` is the same algorithm as the ``-regularize in`` option in @@ -3768,18 +3686,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["reject_by_annotation_all"] = _reject_by_annotation_base -docdict[ - "reject_by_annotation_epochs" -] = """ +docdict["reject_by_annotation_epochs"] = """ reject_by_annotation : bool Whether to reject based on annotations. If ``True`` (default), epochs overlapping with segments whose description begins with ``'bad'`` are rejected. If ``False``, no rejection based on annotations is performed. """ -docdict[ - "reject_by_annotation_psd" -] = """\ +docdict["reject_by_annotation_psd"] = """\ reject_by_annotation : bool Whether to omit bad spans of data before spectral estimation. If ``True``, spans with annotations whose description begins with @@ -3793,6 +3707,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ ) +docdict["reject_by_annotation_tfr"] = """ +reject_by_annotation : bool + Whether to omit bad spans of data before spectrotemporal power + estimation. If ``True``, spans with annotations whose description + begins with ``bad`` will be represented with ``np.nan`` in the + time-frequency representation. +""" + _reject_common = """\ Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP), i.e. the absolute difference between the lowest and the highest signal @@ -3817,18 +3739,53 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): difference will be preserved. """ -docdict[ - "reject_drop_bad" -] = f""" +docdict["reject_drop_bad"] = """\ reject : dict | str | None -{_reject_common} + Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP) + or custom functions. Peak-to-peak signal amplitude is defined as + the absolute difference between the lowest and the highest signal + value. In each individual epoch, the PTP is calculated for every channel. + If the PTP of any one channel exceeds the rejection threshold, the + respective epoch will be dropped. + + The dictionary keys correspond to the different channel types; valid + **keys** can be any channel type present in the object. 
+
+    Example::
+
+        reject = dict(grad=4000e-13,  # unit: T / m (gradiometers)
+                      mag=4e-12,      # unit: T (magnetometers)
+                      eeg=40e-6,      # unit: V (EEG channels)
+                      eog=250e-6      # unit: V (EOG channels)
+                      )
+
+    Custom rejection criteria can also be used by passing a callable,
+    e.g., to check for 99th percentile of absolute values of any channel
+    across time being bigger than :unit:`1 mV`. The callable must return a
+    ``(good, reason)`` tuple: ``good`` must be :class:`bool` and ``reason``
+    must be :class:`str`, :class:`list`, or :class:`tuple` where each entry
+    is a :class:`str`::
+
+        reject = dict(
+            eeg=lambda x: (
+                (np.percentile(np.abs(x), 99, axis=1) > 1e-3).any(),
+                "signal > 1 mV somewhere",
+            )
+        )
+
+    .. note:: If rejection is based on a signal **difference**
+              calculated for each channel separately, applying baseline
+              correction does not affect the rejection procedure, as the
+              difference will be preserved.
+
+    .. note:: If ``reject`` is a callable, then **any** criteria can be
+              used to reject epochs (including maxima and minima).
+
     If ``reject`` is ``None``, no rejection is performed. If ``'existing'``
     (default), then the rejection parameters set at instantiation are used.
-"""
+"""  # noqa: E501

-docdict[
-    "reject_epochs"
-] = f"""
+docdict["reject_epochs"] = f"""
 reject : dict | None
 {_reject_common}
     .. note:: To constrain the time period used for estimation of signal
@@ -3837,17 +3794,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
     If ``reject`` is ``None`` (default), no rejection is performed.
 """

-docdict[
-    "remove_dc"
-] = """
+docdict["remove_dc"] = """
 remove_dc : bool
     If ``True``, the mean is subtracted from each segment before computing
     its spectrum.
 """

-docdict[
-    "replace_report"
-] = """
+docdict["replace_report"] = """
 replace : bool
     If ``True``, content already present that has the same ``title`` and
     ``section`` will be replaced. Defaults to ``False``, which will cause
@@ -3855,16 +3808,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
     already exists.
 """

-docdict[
-    "res_topomap"
-] = """
+docdict["res_topomap"] = """
 res : int
     The resolution of the topomap image (number of pixels along each side).
 """

-docdict[
-    "return_pca_vars_pctf"
-] = """
+docdict["return_pca_vars_pctf"] = """
 return_pca_vars : bool
     Whether or not to return the explained variances across the specified
     vertices for individual SVD components. This is only valid if
@@ -3872,9 +3821,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
     Default return_pca_vars=False.
 """

-docdict[
-    "roll"
-] = """
+docdict["roll"] = """
 roll : float | None
     The roll of the camera rendering the view in degrees.
 """

@@ -3882,9 +3829,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
 # %%
 # S

-docdict[
-    "saturated"
-] = """saturated : str
+docdict["saturated"] = """saturated : str
     Replace saturated segments of data with NaNs, can be:

     ``"ignore"``
         The measured data is returned, even if it contains measurements
         while the amplifier was saturated.
     ``"nan"``
         The returned data will contain NaNs during time segments
         when the amplifier was saturated.
     ``"annotate"`` (default)
         The returned data will contain annotations specifying the
         saturated segments.

     .. versionadded:: 0.24
 """

-docdict[
-    "scalings"
-] = """
+docdict["scalings"] = """
 scalings : 'auto' | dict | None
     Scaling factors for the traces. If a dictionary where any value is
     ``'auto'``, the scaling factor is set to match the 99.5th
@@ -3926,26 +3869,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
     positive direction and 20 µV in the negative direction).
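The ``scalings`` values above describe a half-range: each value is the
deflection allowed in *each* direction around a trace's baseline, not the full
span. A short sketch of a custom setting (assuming an existing preloaded
``raw``)::

    # each EEG trace may deflect up to ±20 µV, each EOG trace up to ±150 µV
    fig = raw.plot(scalings=dict(eeg=20e-6, eog=150e-6))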
""" -docdict[ - "scalings_df" -] = """ +docdict["scalings_df"] = """ scalings : dict | None Scaling factor applied to the channels picked. If ``None``, defaults to ``dict(eeg=1e6, mag=1e15, grad=1e13)`` — i.e., converts EEG to µV, magnetometers to fT, and gradiometers to fT/cm. """ -docdict[ - "scalings_topomap" -] = """ +docdict["scalings_topomap"] = """ scalings : dict | float | None The scalings of the channel types to be applied for plotting. If None, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``. """ -docdict[ - "scoring" -] = """ +docdict["scoring"] = """ scoring : callable | str | None Score function (or loss function) with signature ``score_func(y, y_pred, **kwargs)``. @@ -3955,17 +3892,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``scoring=sklearn.metrics.roc_auc_score``). """ -docdict[ - "sdr_morph" -] = """ +docdict["sdr_morph"] = """ sdr_morph : instance of dipy.align.DiffeomorphicMap The class that applies the the symmetric diffeomorphic registration (SDR) morph. """ -docdict[ - "section_report" -] = """ +docdict["section_report"] = """ section : str | None The name of the section (or content block) to add the content to. This feature is useful for grouping multiple related content elements @@ -3977,9 +3910,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 1.1 """ -docdict[ - "seed" -] = """ +docdict["seed"] = """ seed : None | int | instance of ~numpy.random.RandomState A seed for the NumPy random number generator (RNG). If ``None`` (default), the seed will be obtained from the operating system @@ -3989,24 +3920,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): the RNG with a defined state. """ -docdict[ - "seeg" -] = """ +docdict["seeg"] = """ seeg : bool If True (default), show sEEG electrodes. """ -docdict[ - "selection" -] = """ +docdict["selection"] = """ selection : iterable | None Iterable of indices of selected epochs. If ``None``, will be automatically generated, corresponding to all non-zero events. """ +docdict["selection_attr"] = """ +selection : ndarray + Array of indices of *selected* epochs (i.e., epochs that were not rejected, dropped, + or ignored).""" -docdict[ - "sensor_colors" -] = """ +docdict["sensor_colors"] = """ sensor_colors : array-like of color | dict | None Colors to use for the sensor glyphs. Can be None (default) to use default colors. A dict should provide the colors (values) for each channel type (keys), e.g.:: @@ -4020,9 +3949,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): shape ``(n_eeg, 3)`` or ``(n_eeg, 4)``. """ -docdict[ - "sensors_topomap" -] = """ +docdict["sensors_topomap"] = """ sensors : bool | str Whether to add markers for sensor locations. If :class:`str`, should be a valid matplotlib format string (e.g., ``'r+'`` for red plusses, see the @@ -4030,9 +3957,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): default), black circles will be used. """ -docdict[ - "set_eeg_reference_see_also_notes" -] = """ +docdict["set_eeg_reference_see_also_notes"] = """ See Also -------- mne.set_bipolar_reference : Convenience function for creating bipolar @@ -4082,16 +4007,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. 
footbibliography:: """ -docdict[ - "show" -] = """\ +docdict["sfreq_tfr_attr"] = """ +sfreq : int | float + The sampling frequency (read from ``info``).""" +docdict["shape_tfr_attr"] = """ +shape : tuple of int + The shape of the data.""" + +docdict["show"] = """\ show : bool Show the figure if ``True``. """ -docdict[ - "show_names_topomap" -] = """ +docdict["show_names_topomap"] = """ show_names : bool | callable If ``True``, show channel names next to each sensor marker. If callable, channel names will be formatted using the callable; e.g., to @@ -4100,18 +4028,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): non-masked sensor names will be shown. """ -docdict[ - "show_scalebars" -] = """ +docdict["show_scalebars"] = """ show_scalebars : bool Whether to show scale bars when the plot is initialized. Can be toggled after initialization by pressing :kbd:`s` while the plot window is focused. Default is ``True``. """ -docdict[ - "show_scrollbars" -] = """ +docdict["show_scrollbars"] = """ show_scrollbars : bool Whether to show scrollbars when the plot is initialized. Can be toggled after initialization by pressing :kbd:`z` ("zen mode") while the plot @@ -4120,9 +4044,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.19.0 """ -docdict[ - "show_traces" -] = """ +docdict["show_traces"] = """ show_traces : bool | str | float If True, enable interactive picking of a point on the surface of the brain and plot its time course. @@ -4136,16 +4058,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.20.0 """ -docdict[ - "size_topomap" -] = """ +docdict["size_topomap"] = """ size : float Side length of each subplot in inches. """ -docdict[ - "skip_by_annotation" -] = """ +docdict["skip_by_annotation"] = """ skip_by_annotation : str | list of str If a string (or list of str), any annotation segment that begins with the given string will not be included in filtering, and @@ -4157,9 +4075,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): To disable, provide an empty list. Only used if ``inst`` is raw. """ -docdict[ - "skip_by_annotation_maxwell" -] = """ +docdict["skip_by_annotation_maxwell"] = """ skip_by_annotation : str | list of str If a string (or list of str), any annotation segment that begins with the given string will not be included in filtering, and @@ -4171,24 +4087,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): To disable, provide an empty list. """ -docdict[ - "smooth" -] = """ +docdict["smooth"] = """ smooth : float in [0, 1) The smoothing factor to be applied. Default 0 is no smoothing. """ -docdict[ - "spatial_colors_psd" -] = """\ +docdict["spatial_colors_psd"] = """\ spatial_colors : bool Whether to color spectrum lines by channel location. Ignored if ``average=True``. """ -docdict[ - "sphere_topomap_auto" -] = f"""\ +docdict["sphere_topomap_auto"] = f"""\ sphere : float | array-like | instance of ConductorModel | None | 'auto' | 'eeglab' The sphere parameters to use for the head outline. Can be array-like of shape (4,) to give the X/Y/Z origin and radius in meters, or a single float @@ -4205,17 +4115,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionchanged:: 1.1 Added ``'eeglab'`` option. """ -docdict[ - "splash" -] = """ +docdict["splash"] = """ splash : bool If True (default), a splash screen is shown during the application startup. Only applicable to the ``qt`` backend. 
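The ``skip_by_annotation`` entries above are easiest to read next to a
concrete call; a sketch, assuming a preloaded ``raw`` whose acquisition skips
are annotated::

    # filter each segment between 'edge'/'bad_acq_skip' annotations separately,
    # so the filter never runs across a discontinuity (the documented default)
    raw.filter(l_freq=1.0, h_freq=40.0, skip_by_annotation=("edge", "bad_acq_skip"))
    # disable the behavior entirely and filter straight through
    raw.filter(l_freq=1.0, h_freq=40.0, skip_by_annotation=[])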
""" -docdict[ - "split_naming" -] = """ +docdict["split_naming"] = """ split_naming : 'neuromag' | 'bids' When splitting files, append a filename partition with the appropriate naming schema: for ``'neuromag'``, a split file ``fname.fif`` will be named @@ -4223,16 +4129,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): it will be named ``fname_split-01.fif``, ``fname_split-02.fif``, etc. """ -docdict[ - "src_eltc" -] = """ +docdict["src_eltc"] = """ src : instance of SourceSpaces The source spaces for the source time courses. """ -docdict[ - "src_volume_options" -] = """ +docdict["src_volume_options"] = """ src : instance of SourceSpaces | None The source space corresponding to the source estimate. Only necessary if the STC is a volume or mixed source estimate. @@ -4263,9 +4165,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): entry. """ -docdict[ - "st_fixed_maxwell_only" -] = """ +docdict["st_fixed_maxwell_only"] = """ st_fixed : bool If True (default), do tSSS using the median head position during the ``st_duration`` window. This is the default behavior of MaxFilter @@ -4287,9 +4187,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.12 """ -docdict[ - "standardize_names" -] = """ +docdict["standardize_names"] = """ standardize_names : bool If True, standardize MEG and EEG channel names to be ``'MEG ###'`` and ``'EEG ###'``. If False (default), native @@ -4307,48 +4205,36 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["stat_fun_clust_t"] = _stat_fun_clust_base.format("ttest_1samp_no_p") -docdict[ - "static" -] = """ +docdict["static"] = """ static : instance of SpatialImage The image to align with ("to" volume). """ -docdict[ - "stc_est_metric" -] = """ +docdict["stc_est_metric"] = """ stc_est : instance of (Vol|Mixed)SourceEstimate The source estimates containing estimated values e.g. obtained with a source imaging method. """ -docdict[ - "stc_metric" -] = """ +docdict["stc_metric"] = """ metric : float | array, shape (n_times,) The metric. float if per_sample is False, else array with the values computed for each time point. """ -docdict[ - "stc_plot_kwargs_report" -] = """ +docdict["stc_plot_kwargs_report"] = """ stc_plot_kwargs : dict Dictionary of keyword arguments to pass to :class:`mne.SourceEstimate.plot`. Only used when plotting in 3D mode. """ -docdict[ - "stc_true_metric" -] = """ +docdict["stc_true_metric"] = """ stc_true : instance of (Vol|Mixed)SourceEstimate The source estimates containing correct values. """ -docdict[ - "stcs_pctf" -] = """ +docdict["stcs_pctf"] = """ stcs : instance of SourceEstimate | list of instances of SourceEstimate The PSFs or CTFs as STC objects. All PSFs/CTFs will be returned as successive samples in STC objects, in the order they are specified @@ -4364,9 +4250,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): a VectorSourceEstimate object. """ -docdict[ - "std_err_by_event_type_returns" -] = """ +docdict["std_err_by_event_type_returns"] = """ std_err : instance of Evoked | list of Evoked The standard error over epochs. When ``by_event_type=True`` was specified, a list is returned containing a @@ -4375,9 +4259,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): dictionary. 
""" -docdict[ - "step_down_p_clust" -] = """ +docdict["step_down_p_clust"] = """ step_down_p : float To perform a step-down-in-jumps test, pass a p-value for clusters to exclude from each successive iteration. Default is zero, perform no @@ -4386,48 +4268,36 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): but costs computation time. """ -docdict[ - "subject" -] = """ +docdict["subject"] = """ subject : str The FreeSurfer subject name. """ -docdict[ - "subject_label" -] = """ +docdict["subject_label"] = """ subject : str | None Subject which this label belongs to. Should only be specified if it is not specified in the label. """ -docdict[ - "subject_none" -] = """ +docdict["subject_none"] = """ subject : str | None The FreeSurfer subject name. """ -docdict[ - "subject_optional" -] = """ +docdict["subject_optional"] = """ subject : str The FreeSurfer subject name. While not necessary, it is safer to set the subject parameter to avoid analysis errors. """ -docdict[ - "subjects_dir" -] = """ +docdict["subjects_dir"] = """ subjects_dir : path-like | None The path to the directory containing the FreeSurfer subjects reconstructions. If ``None``, defaults to the ``SUBJECTS_DIR`` environment variable. """ -docdict[ - "surface" -] = """surface : str +docdict["surface"] = """surface : str The surface along which to do the computations, defaults to ``'white'`` (the gray-white matter boundary). """ @@ -4435,9 +4305,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # T -docdict[ - "t_power_clust" -] = """ +docdict["t_power_clust"] = """ t_power : float Power to raise the statistical values (usually t-values) by before summing (sign will be retained). Note that ``t_power=0`` will give a @@ -4445,24 +4313,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): by its statistical score. """ -docdict[ - "t_window_chpi_t" -] = """ +docdict["t_window_chpi_t"] = """ t_window : float Time window to use to estimate the amplitudes, default is 0.2 (200 ms). """ -docdict[ - "tags_report" -] = """ +docdict["tags_report"] = """ tags : array-like of str | str Tags to add for later interactive filtering. Must not contain spaces. """ -docdict[ - "tail_clust" -] = """ +docdict["tail_clust"] = """ tail : int If tail is 1, the statistic is thresholded above threshold. If tail is -1, the statistic is thresholded below threshold. @@ -4470,9 +4332,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): the distribution. """ -docdict[ - "temporal_window_tfr_intro" -] = """ +docdict["temporal_window_tfr_intro"] = """ In spectrotemporal analysis (as with traditional fourier methods), the temporal and spectral resolution are interrelated: longer temporal windows allow more precise frequency estimates; shorter temporal windows "smear" @@ -4492,9 +4352,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): multitapers and wavelets `_. """ # noqa: E501 -docdict[ - "temporal_window_tfr_morlet_notes" -] = r""" +docdict["temporal_window_tfr_morlet_notes"] = r""" In MNE-Python, the length of the Morlet wavelet is affected by the arguments ``freqs`` and ``n_cycles``, which define the frequencies of interest and the number of cycles, respectively. For the time-frequency representation, @@ -4512,9 +4370,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): For more information on the Morlet wavelet, see :func:`mne.time_frequency.morlet`. 
""" -docdict[ - "temporal_window_tfr_multitaper_notes" -] = r""" +docdict["temporal_window_tfr_multitaper_notes"] = r""" In MNE-Python, the multitaper temporal window length is defined by the arguments ``freqs`` and ``n_cycles``, respectively defining the frequencies of interest and the number of cycles: :math:`T = \frac{\mathtt{n\_cycles}}{\mathtt{freqs}}` @@ -4532,33 +4388,23 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): theme : str | path-like Can be "auto", "light", or "dark" or a path-like to a custom stylesheet. For Dark-Mode and automatic Dark-Mode-Detection, - :mod:`qdarkstyle` and + `qdarkstyle `__ and `darkdetect `__, respectively, are required.\ If None (default), the config option {config_option} will be used, defaulting to "auto" if it's not found.\ """ -docdict[ - "theme_3d" -] = """ +docdict["theme_3d"] = """ {theme} -""".format( - theme=_theme.format(config_option="MNE_3D_OPTION_THEME") -) +""".format(theme=_theme.format(config_option="MNE_3D_OPTION_THEME")) -docdict[ - "theme_pg" -] = """ +docdict["theme_pg"] = """ {theme} Only supported by the ``'qt'`` backend. -""".format( - theme=_theme.format(config_option="MNE_BROWSER_THEME") -) +""".format(theme=_theme.format(config_option="MNE_BROWSER_THEME")) -docdict[ - "thresh" -] = """ +docdict["thresh"] = """ thresh : None or float Not supported yet. If not None, values below thresh will not be visible. @@ -4582,9 +4428,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): f_test = ("an F-threshold", "an F-statistic") docdict["threshold_clust_f"] = _threshold_clust_base.format(*f_test) -docdict[ - "threshold_clust_f_notes" -] = """ +docdict["threshold_clust_f_notes"] = """ For computing a ``threshold`` based on a p-value, use the conversion from :meth:`scipy.stats.rv_continuous.ppf`:: @@ -4597,9 +4441,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): t_test = ("a t-threshold", "a t-statistic") docdict["threshold_clust_t"] = _threshold_clust_base.format(*t_test) -docdict[ - "threshold_clust_t_notes" -] = """ +docdict["threshold_clust_t_notes"] = """ For computing a ``threshold`` based on a p-value, use the conversion from :meth:`scipy.stats.rv_continuous.ppf`:: @@ -4611,9 +4453,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): For testing the lower tail (``tail=-1``), don't subtract ``pval`` from 1. """ -docdict[ - "time_bandwidth_tfr" -] = """ +docdict["time_bandwidth_tfr"] = """ time_bandwidth : float ``≥ 2.0`` Product between the temporal window length (in seconds) and the *full* frequency bandwidth (in Hz). This product can be seen as the surface of the @@ -4621,9 +4461,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (thus the frequency resolution) and the number of good tapers. See notes for additional information.""" -docdict[ - "time_bandwidth_tfr_notes" -] = r""" +docdict["time_bandwidth_tfr_notes"] = r""" In MNE-Python's multitaper functions, the frequency bandwidth is additionally affected by the parameter ``time_bandwidth``. The ``n_cycles`` parameter determines the temporal window length based on the @@ -4659,9 +4497,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): example above, the half-frequency bandwidth is 2 Hz. """ -docdict[ - "time_format" -] = """ +docdict["time_format"] = """ time_format : 'float' | 'clock' Style of time labels on the horizontal axis. If ``'float'``, labels will be number of seconds from the start of the recording. 
If ``'clock'``, @@ -4689,9 +4525,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ) docdict["time_format_df_raw"] = _time_format_df_base.format(_raw_tf) -docdict[ - "time_label" -] = """ +docdict["time_label"] = """ time_label : str | callable | None Format of the time label (a format string, a function that maps floating point time values to strings, or None for no label). The @@ -4699,69 +4533,66 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): is more than one time point. """ -docdict[ - "time_unit" -] = """\ +docdict["time_unit"] = """\ time_unit : str The units for the time axis, can be "s" (default) or "ms". """ -docdict[ - "time_viewer_brain_screenshot" -] = """ +docdict["time_viewer_brain_screenshot"] = """ time_viewer : bool If True, include time viewer traces. Only used if ``time_viewer=True`` and ``separate_canvas=False``. """ -docdict[ - "title_none" -] = """ +docdict["timefreqs"] = """ +timefreqs : None | list of tuple | dict of tuple + The time-frequency point(s) for which topomaps will be plotted. See Notes. +""" + +docdict["times"] = """ +times : ndarray, shape (n_times,) + The time values in seconds. +""" + +docdict["title_none"] = """ title : str | None The title of the generated figure. If ``None`` (default), no title is displayed. """ - -docdict[ - "tmax_raw" -] = """ +docdict["title_tfr_plot"] = """ +title : str | 'auto' | None + Title for the plot. If ``"auto"``, will use the channel name (if ``combine`` is + ``None``) or state the number and method of combined channels used to generate the + plot. If ``None``, no title is shown. Default is ``None``. +""" +docdict["tmax_raw"] = """ tmax : float End time of the raw data to use in seconds (cannot exceed data duration). """ -docdict[ - "tmin" -] = """ +docdict["tmin"] = """ tmin : scalar Time point of the first sample in data. """ -docdict[ - "tmin_epochs" -] = """ +docdict["tmin_epochs"] = """ tmin : float Start time before event. If nothing provided, defaults to 0. """ -docdict[ - "tmin_raw" -] = """ +docdict["tmin_raw"] = """ tmin : float Start time of the raw data to use in seconds (must be >= 0). """ -docdict[ - "tmin_tmax_psd" -] = """\ +docdict["tmin_tmax_psd"] = """\ tmin, tmax : float | None First and last times to include, in seconds. ``None`` uses the first or last time present in the data. Default is ``tmin=None, tmax=None`` (all times). """ -docdict[ - "tol_kind_rank" -] = """ +docdict["tol_kind_rank"] = """ tol_kind : str Can be: "absolute" (default) or "relative". Only used if ``tol`` is a float, because when ``tol`` is a string the mode is implicitly relative. @@ -4780,9 +4611,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.21.0 """ -docdict[ - "tol_rank" -] = """ +docdict["tol_rank"] = """ tol : float | 'auto' Tolerance for singular values to consider non-zero in calculating the rank. The singular values are calculated @@ -4791,38 +4620,38 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): same thresholding as :func:`scipy.linalg.orth`. """ -docdict[ - "topomap_kwargs" -] = """ -topomap_kwargs : dict | None - Keyword arguments to pass to the topomap-generating functions. +_topomap_args_template = """ +{param} : dict | None + Keyword arguments to pass to {func}.{extra} """ +docdict["topomap_args"] = _topomap_args_template.format( + param="topomap_args", + func=":func:`mne.viz.plot_topomap`", + extra=" ``axes`` and ``show`` are ignored. 
If ``times`` is not in this dict, " + "automatic peak detection is used. Beyond that, if ``None``, no customizable " + "arguments will be passed. Defaults to ``None`` (i.e., an empty :class:`dict`).", +) +docdict["topomap_kwargs"] = _topomap_args_template.format( + param="topomap_kwargs", func="the topomap-generating functions", extra="" +) _trans_base = """\ If str, the path to the head<->MRI transform ``*-trans.fif`` file produced during coregistration. Can also be ``'fsaverage'`` to use the built-in fsaverage transformation.""" -docdict[ - "trans" -] = f""" +docdict["trans"] = f""" trans : path-like | dict | instance of Transform | ``"fsaverage"`` | None {_trans_base} If trans is None, an identity matrix is assumed. """ -docdict[ - "trans_not_none" -] = """ +docdict["trans_not_none"] = f""" trans : str | dict | instance of Transform - %s -""" % ( - _trans_base, -) + {_trans_base} +""" -docdict[ - "transparent" -] = """ +docdict["transparent"] = """ transparent : bool | None If True: use a linear transparency between fmin and fmid and make values below fmin fully transparent (symmetrically for @@ -4830,17 +4659,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): type. """ -docdict[ - "tstart_ecg" -] = """ +docdict["tstart_ecg"] = """ tstart : float Start ECG detection after ``tstart`` seconds. Useful when the beginning of the run is noisy. """ -docdict[ - "tstep" -] = """ +docdict["tstep"] = """ tstep : scalar Time step between successive samples in data. """ @@ -4848,18 +4673,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # U -docdict[ - "ui_event_name_source" -] = """ +docdict["ui_event_name_source"] = """ name : str The name of the event (same as its class name but in snake_case). source : matplotlib.figure.Figure | Figure3D The figure that published the event. """ -docdict[ - "uint16_codec" -] = """ +docdict["uint16_codec"] = """ uint16_codec : str | None If your set file contains non-ascii characters, sometimes reading it may fail and give rise to error message stating that "buffer is @@ -4868,9 +4689,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): can therefore help you solve this problem. """ -docdict[ - "units" -] = """ +docdict["units"] = """ units : str | dict | None Specify the unit(s) that the data should be returned in. If ``None`` (default), the data is returned in the @@ -4889,9 +4708,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): channel-type-specific default unit. """ -docdict[ - "units_edf_bdf_io" -] = """ +docdict["units_edf_bdf_io"] = """ units : dict | str The units of the channels as stored in the file. This argument is useful only if the units are missing from the original file. @@ -4910,17 +4727,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): "dict | ", "and ``scalings=None`` the unit is automatically determined, otherwise " ) -docdict[ - "use_cps" -] = """ +docdict["use_cps"] = """ use_cps : bool Whether to use cortical patch statistics to define normal orientations for surfaces (default True). """ -docdict[ - "use_cps_restricted" -] = """ +docdict["use_cps_restricted"] = """ use_cps : bool Whether to use cortical patch statistics to define normal orientations for surfaces (default True). @@ -4929,9 +4742,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): not in surface orientation, and ``pick_ori='normal'``. 
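The ``units`` parameter documented above accepts either one string for all
requested channels or a per-channel-type dict; a minimal sketch (assuming an
existing ``raw``)::

    eeg_uv = raw.get_data(picks="eeg", units="uV")            # all picks in µV
    mixed = raw.get_data(units=dict(eeg="uV", grad="fT/cm"))  # per-type units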
""" -docdict[ - "use_opengl" -] = """ +docdict["use_opengl"] = """ use_opengl : bool | None Whether to use OpenGL when rendering the plot (requires ``pyopengl``). May increase performance, but effect is dependent on system CPU and @@ -4946,9 +4757,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # V -docdict[ - "vector_pctf" -] = """ +docdict["vector_pctf"] = """ vector : bool Whether to return PSF/CTF as vector source estimate (3 values per location) or source estimate object (1 intensity value per location). @@ -4958,9 +4767,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 1.2 """ -docdict[ - "verbose" -] = """ +docdict["verbose"] = """ verbose : bool | str | int | None Control verbosity of the logging output. If ``None``, use the default verbosity level. See the :ref:`logging documentation ` and @@ -4968,17 +4775,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): argument. """ -docdict[ - "vertices_volume" -] = """ +docdict["vertices_volume"] = """ vertices : list of array of int The indices of the dipoles in the source space. Should be a single array of shape (n_dipoles,) unless there are subvolumes. """ -docdict[ - "view" -] = """ +docdict["view"] = """ view : str | None The name of the view to show (e.g. "lateral"). Other arguments take precedence and modify the camera starting from the ``view``. @@ -4986,71 +4789,104 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): string shortcut options. """ -docdict[ - "view_layout" -] = """ +docdict["view_layout"] = """ view_layout : str Can be "vertical" (default) or "horizontal". When using "horizontal" mode, the PyVista backend must be used and hemi cannot be "split". """ -docdict[ - "views" -] = """ +docdict["views"] = """ views : str | list View to use. Using multiple views (list) is not supported for mpl backend. See :meth:`Brain.show_view ` for valid string options. """ -_vlim = """ -vlim : tuple of length 2{} - Colormap limits to use. If a :class:`tuple` of floats, specifies the - lower and upper bounds of the colormap (in that order); providing - ``None`` for either entry will set the corresponding boundary at the - min/max of the data{}. {}{}{}Defaults to ``(None, None)``. -""" -_vlim_joint = _vlim.format( - " | 'joint'", - " (separately for each {0})", - "{1}", - "If ``vlim='joint'``, will compute the colormap limits jointly across " - "all {0}s of the same channel type, using the min/max of the data for " - "that channel type. ", - "{2}", +_vlim = """\ +vlim : tuple of length 2{joint_param} + Lower and upper bounds of the colormap, typically a numeric value in the same + units as the data. {callable} + If both entries are ``None``, the bounds are set at {bounds}. + Providing ``None`` for just one entry will set the corresponding boundary at the + min/max of the data. {extra}Defaults to ``(None, None)``. +""" +_joint_param = ' | "joint"' +_callable_sentence = """Elements of the :class:`tuple` may also be callable functions + which take in a :class:`NumPy array ` and return a scalar. 
+""" +_bounds_symmetric = """± the maximum absolute value + of the data (yielding a colormap with midpoint at 0)""" +_bounds_minmax = "``(min(data), max(data))``" +_bounds_norm = "``(0, max(abs(data)))``" +_bounds_contingent = f"""{_bounds_symmetric}, or {_bounds_norm} + if the (possibly baselined) data are all-positive""" +_joint_sentence = """If ``vlim="joint"``, will compute the colormap limits + jointly across all {what}s of the same channel type (instead of separately + for each {what}), using the min/max of the data for that channel type. + {joint_extra}""" + +docdict["vlim_plot_topomap"] = _vlim.format( + joint_param="", callable="", bounds=_bounds_minmax, extra="" +) +docdict["vlim_plot_topomap_proj"] = _vlim.format( + joint_param=_joint_param, + callable=_callable_sentence, + bounds=_bounds_contingent, + extra=_joint_sentence.format( + what="projector", + joint_extra='If vlim is ``"joint"``, ``info`` must not be ``None``. ', + ), ) -_vlim_callable = ( - "Elements of the :class:`tuple` may also be callable functions which " - "take in a :class:`NumPy array ` and return a scalar. " +docdict["vlim_plot_topomap_psd"] = _vlim.format( + joint_param=_joint_param, + callable=_callable_sentence, + bounds=_bounds_contingent, + extra=_joint_sentence.format(what="topomap", joint_extra=""), ) - -docdict["vlim_plot_topomap"] = _vlim.format("", "", "", "", "") -docdict["vlim_plot_topomap_proj"] = _vlim_joint.format( - "projector", - _vlim_callable, - "If vlim is ``'joint'``, ``info`` must not be ``None``. ", +docdict["vlim_tfr_plot"] = _vlim.format( + joint_param="", callable="", bounds=_bounds_contingent, extra="" +) +docdict["vlim_tfr_plot_joint"] = _vlim.format( + joint_param="", + callable="", + bounds=_bounds_contingent, + extra="""To specify the colormap separately for the topomap annotations, + see ``topomap_args``. """, ) -docdict["vlim_plot_topomap_psd"] = _vlim_joint.format("topomap", _vlim_callable, "") -docdict[ - "vmin_vmax_topomap" -] = """ -vmin, vmax : float | callable | None +_vmin_vmax_template = """ +vmin, vmax : float | {allowed}None Lower and upper bounds of the colormap, in the same units as the data. - If ``vmin`` and ``vmax`` are both ``None``, they are set at ± the - maximum absolute value of the data (yielding a colormap with midpoint - at 0). If only one of ``vmin``, ``vmax`` is ``None``, will use - ``min(data)`` or ``max(data)``, respectively. If callable, should - accept a :class:`NumPy array ` of data and return a - float. + If ``vmin`` and ``vmax`` are both ``None``, the bounds are set at + {bounds}. If only one of ``vmin``, ``vmax`` is ``None``, will use + ``min(data)`` or ``max(data)``, respectively.{extra} """ +docdict["vmin_vmax_tfr_plot"] = """ +vmin, vmax : float | None + Lower and upper bounds of the colormap. See ``vlim``. + + .. deprecated:: 1.7 + ``vmin`` and ``vmax`` will be removed in version 1.8. + Use ``vlim`` parameter instead. 
+""" +# ↓↓↓ this one still used, needs helper func refactor before we can migrate to `vlim` +docdict["vmin_vmax_tfr_plot_topo"] = _vmin_vmax_template.format( + allowed="", bounds=_bounds_symmetric, extra="" +) +# ↓↓↓ this one still used in Evoked.animate_topomap(), should migrate to `vlim` +docdict["vmin_vmax_topomap"] = _vmin_vmax_template.format( + allowed="callable | ", + bounds=_bounds_symmetric, + extra=""" If callable, should accept + a :class:`NumPy array ` of data and return a :class:`float`.""", +) + + # %% # W -docdict[ - "weight_norm" -] = """ +docdict["weight_norm"] = """ weight_norm : str | None Can be: @@ -5082,27 +4918,25 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. """ -docdict[ - "window_psd" -] = """\ +docdict["window_psd"] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. """ -docdict[ - "window_resample" -] = """ +docdict["window_resample"] = """ window : str | tuple - Frequency-domain window to use in resampling. - See :func:`scipy.signal.resample`. + When ``method="fft"``, this is the *frequency-domain* window to use in resampling, + and should be the same length as the signal; see :func:`scipy.signal.resample` + for details. When ``method="polyphase"``, this is the *time-domain* linear-phase + window to use after upsampling the signal; see :func:`scipy.signal.resample_poly` + for details. The default ``"auto"`` will use ``"boxcar"`` for ``method="fft"`` and + ``("kaiser", 5.0)`` for ``method="polyphase"``. """ # %% # X -docdict[ - "xscale_plot_psd" -] = """\ +docdict["xscale_plot_psd"] = """\ xscale : 'linear' | 'log' Scale of the frequency axis. Default is ``'linear'``. """ @@ -5110,6 +4944,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # Y +docdict["yscale_tfr_plot"] = """ +yscale : 'auto' | 'linear' | 'log' + The scale of the y (frequency) axis. 'linear' gives linear y axis, 'log' gives + log-spaced y axis and 'auto' detects if frequencies are log-spaced and if so sets + the y axis to 'log'. Default is 'auto'. +""" + # %% # Z @@ -5161,7 +5002,7 @@ def fill_doc(f): except (TypeError, ValueError, KeyError) as exp: funcname = f.__name__ funcname = docstring.split("\n")[0] if funcname is None else funcname - raise RuntimeError("Error documenting %s:\n%s" % (funcname, str(exp))) + raise RuntimeError(f"Error documenting {funcname}:\n{str(exp)}") return f @@ -5181,12 +5022,12 @@ def copy_doc(source): Parameters ---------- source : function - Function to copy the docstring from + Function to copy the docstring from. Returns ------- wrapper : function - The decorated function + The decorated function. 
Examples -------- @@ -5463,11 +5304,7 @@ def linkcode_resolve(domain, info): kind = "main" else: kind = "maint/%s" % (".".join(mne.__version__.split(".")[:2])) - return "http://github.com/mne-tools/mne-python/blob/%s/mne/%s%s" % ( - kind, - fn, - linespec, - ) + return f"http://github.com/mne-tools/mne-python/blob/{kind}/mne/{fn}{linespec}" def open_docs(kind=None, version=None): @@ -5499,13 +5336,13 @@ def open_docs(kind=None, version=None): if version is None: version = get_config("MNE_DOCS_VERSION", "stable") _check_option("version", version, ["stable", "dev"]) - webbrowser.open_new_tab("https://mne.tools/%s/%s" % (version, kind)) + webbrowser.open_new_tab(f"https://mne.tools/{version}/{kind}") class _decorator: """Inject code or modify the docstring of a class, method, or function.""" - def __init__(self, extra): # noqa: D102 + def __init__(self, extra): self.kind = self.__class__.__name__ self.extra = extra self.msg = f"NOTE: {{}}() is a {self.kind} {{}}. {self.extra}." @@ -5625,7 +5462,7 @@ class legacy(_decorator): and in a sphinx warning box in the docstring. """ - def __init__(self, alt, extra=""): # noqa: D102 + def __init__(self, alt, extra=""): period = ". " if len(extra) else "" extra = f"New code should use {alt}{period}{extra}" super().__init__(extra=extra) @@ -5693,7 +5530,7 @@ def _docformat(docstring, docdict=None, funcname=None): try: return docstring % indented except (TypeError, ValueError, KeyError) as exp: - raise RuntimeError("Error documenting %s:\n%s" % (funcname, str(exp))) + raise RuntimeError(f"Error documenting {funcname}:\n{str(exp)}") def _indentcount_lines(lines): diff --git a/mne/utils/misc.py b/mne/utils/misc.py index 05d856c0226..a86688ca2a7 100644 --- a/mne/utils/misc.py +++ b/mne/utils/misc.py @@ -14,6 +14,7 @@ import traceback import weakref from contextlib import ExitStack, contextmanager +from importlib.resources import files from math import log from queue import Empty, Queue from string import Formatter @@ -26,11 +27,9 @@ from ._logging import logger, verbose, warn from .check import _check_option, _validate_type -# TODO: remove try/except when our min version is py 3.9 -try: - from importlib.resources import files -except ImportError: - from importlib_resources import files + +def _identity_function(x): + return x # TODO: no longer needed when py3.9 is minimum supported version @@ -161,20 +160,7 @@ def run_subprocess(command, return_code=False, verbose=None, *args, **kwargs): break else: out = out.decode("utf-8") - # Strip newline at end of the string, otherwise we'll end - # up with two subsequent newlines (as the logger adds one) - # - # XXX Once we drop support for Python <3.9, uncomment the - # following line and remove the if/else block below. - # - # log_out = out.removesuffix('\n') - if sys.version_info[:2] >= (3, 9): - log_out = out.removesuffix("\n") - elif out.endswith("\n"): - log_out = out[:-1] - else: - log_out = out - + log_out = out.removesuffix("\n") logger.info(log_out) all_out += out @@ -185,19 +171,7 @@ def run_subprocess(command, return_code=False, verbose=None, *args, **kwargs): break else: err = err.decode("utf-8") - # Strip newline at end of the string, otherwise we'll end - # up with two subsequent newlines (as the logger adds one) - # - # XXX Once we drop support for Python <3.9, uncomment the - # following line and remove the if/else block below. 
- # - # err_out = err.removesuffix('\n') - if sys.version_info[:2] >= (3, 9): - err_out = err.removesuffix("\n") - elif err.endswith("\n"): - err_out = err[:-1] - else: - err_out = err + err_out = err.removesuffix("\n") # Leave this as logger.warning rather than warn(...) to # mirror the logger.info above for stdout. This function @@ -295,9 +269,8 @@ def running_subprocess(command, after="wait", verbose=None, *args, **kwargs): def _clean_names(names, remove_whitespace=False, before_dash=True): """Remove white-space on topo matching. - This function handles different naming - conventions for old VS new VectorView systems (`remove_whitespace`). - Also it allows to remove system specific parts in CTF channel names + This function handles different naming conventions for old VS new VectorView systems + (`remove_whitespace`) and removes system specific parts in CTF channel names (`before_dash`). Usage @@ -307,7 +280,6 @@ def _clean_names(names, remove_whitespace=False, before_dash=True): # for CTF ch_names = _clean_names(epochs.ch_names, before_dash=True) - """ cleaned = [] for name in names: @@ -318,7 +290,10 @@ def _clean_names(names, remove_whitespace=False, before_dash=True): if name.endswith("_v"): name = name[:-2] cleaned.append(name) - + if len(set(cleaned)) != len(names): + # this was probably not a VectorView or CTF dataset, and we now broke the + # dataset by creating duplicates, so let's use the original channel names. + return names return cleaned diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index c90121fdfbb..793e399a69f 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -80,13 +80,13 @@ def __getitem__(self, item): Parameters ---------- - item : slice, array-like, str, or list - See below for use cases. + item : int | slice | array-like | str + See Notes for use cases. Returns ------- epochs : instance of Epochs - See below for use cases. + The subset of epochs. Notes ----- @@ -178,7 +178,7 @@ def _getitem( ---------- item: slice, array-like, str, or list see `__getitem__` for details. 
- reason: str + reason: str, list/tuple of str entry in `drop_log` for unselected epochs copy: bool return a copy of the current object @@ -197,10 +197,9 @@ def _getitem( `Epochs` or tuple(Epochs, np.ndarray) if `return_indices` is True subset of epochs (and optionally array with kept epoch indices) """ - data = self._data - self._data = None inst = self.copy() if copy else self - self._data = inst._data = data + if self._data is not None: + np.copyto(inst._data, self._data, casting="no") del self select = inst._item_to_select(item) @@ -209,8 +208,15 @@ def _getitem( key_selection = inst.selection[select] drop_log = list(inst.drop_log) if reason is not None: - for k in np.setdiff1d(inst.selection, key_selection): - drop_log[k] = (reason,) + _validate_type(reason, (list, tuple, str), "reason") + if isinstance(reason, (list, tuple)): + for r in reason: + _validate_type(r, str, r) + if isinstance(reason, str): + reason = (reason,) + reason = tuple(reason) + for idx in np.setdiff1d(inst.selection, key_selection): + drop_log[idx] = reason inst.drop_log = tuple(drop_log) inst.selection = key_selection del drop_log @@ -281,7 +287,7 @@ def _keys_to_idx(self, keys): except Exception as exp: msg += ( " The epochs.metadata Pandas query did not " - "yield any results: %s" % (exp.args[0],) + f"yield any results: {exp.args[0]}" ) else: return vals @@ -446,7 +452,7 @@ def metadata(self, metadata, verbose=None): action += " existing" else: action = "Not setting" if metadata is None else "Adding" - logger.info("%s metadata%s" % (action, n_col)) + logger.info(f"{action} metadata{n_col}") self._metadata = metadata @@ -674,10 +680,10 @@ def decimate(self, decim, offset=0, *, verbose=None): # appropriately filtered to avoid aliasing from ..epochs import BaseEpochs from ..evoked import Evoked - from ..time_frequency import AverageTFR, EpochsTFR + from ..time_frequency import BaseTFR # This should be the list of classes that inherit - _validate_type(self, (BaseEpochs, Evoked, EpochsTFR, AverageTFR), "inst") + _validate_type(self, (BaseEpochs, Evoked, BaseTFR), "inst") decim, offset, new_sfreq = _check_decim( self.info, decim, offset, check_filter=not hasattr(self, "freqs") ) @@ -748,7 +754,7 @@ def _prepare_write_metadata(metadata): """Convert metadata to JSON for saving.""" if metadata is not None: if not isinstance(metadata, list): - metadata = metadata.to_json(orient="records") + metadata = metadata.reset_index().to_json(orient="records") else: # Pandas DataFrame metadata = json.dumps(metadata) assert isinstance(metadata, str) @@ -765,5 +771,7 @@ def _prepare_read_metadata(metadata): assert isinstance(metadata, list) if pd: metadata = pd.DataFrame.from_records(metadata) + if "index" in metadata.columns: + metadata.set_index("index", inplace=True) assert isinstance(metadata, pd.DataFrame) return metadata diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 64bc4515f93..2f09689917b 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -29,7 +29,12 @@ svd_flip, ) from ._logging import logger, verbose, warn -from .check import _ensure_int, _validate_type, check_random_state +from .check import ( + _check_pandas_installed, + _ensure_int, + _validate_type, + check_random_state, +) from .docs import fill_doc from .misc import _empty_hash @@ -255,9 +260,9 @@ def _get_inst_data(inst): from ..epochs import BaseEpochs from ..evoked import Evoked from ..io import BaseRaw - from ..time_frequency.tfr import _BaseTFR + from ..time_frequency.tfr import BaseTFR - _validate_type(inst, (BaseRaw, 
BaseEpochs, Evoked, _BaseTFR), "Instance") + _validate_type(inst, (BaseRaw, BaseEpochs, Evoked, BaseTFR), "Instance") if not inst.preload: inst.load_data() return inst._data @@ -321,7 +326,7 @@ def _apply_scaling_array(data, picks_list, scalings, verbose=None): """Scale data type-dependently for estimation.""" scalings = _check_scaling_inputs(data, picks_list, scalings) if isinstance(scalings, dict): - logger.debug(" Scaling using mapping %s." % (scalings,)) + logger.debug(f" Scaling using mapping {scalings}.") picks_dict = dict(picks_list) scalings = [(picks_dict[k], v) for k, v in scalings.items() if k in picks_dict] for idx, scaling in scalings: @@ -493,16 +498,15 @@ def _time_mask( assert include_tmax # can only be used when sfreq is known if raise_error and tmin > tmax: raise ValueError( - "tmin (%s) must be less than or equal to tmax (%s)" % (orig_tmin, orig_tmax) + f"tmin ({orig_tmin}) must be less than or equal to tmax ({orig_tmax})" ) mask = times >= tmin mask &= times <= tmax if raise_error and not mask.any(): extra = "" if include_tmax else "when include_tmax=False " raise ValueError( - "No samples remain when using tmin=%s and tmax=%s %s" - "(original time bounds are [%s, %s])" - % (orig_tmin, orig_tmax, extra, times[0], times[-1]) + f"No samples remain when using tmin={orig_tmin} and tmax={orig_tmax} " + f"{extra}(original time bounds are [{times[0]}, {times[-1]}])" ) return mask @@ -525,15 +529,14 @@ def _freq_mask(freqs, sfreq, fmin=None, fmax=None, raise_error=True): fmax = int(round(fmax * sfreq)) / sfreq + 0.5 / sfreq if raise_error and fmin > fmax: raise ValueError( - "fmin (%s) must be less than or equal to fmax (%s)" % (orig_fmin, orig_fmax) + f"fmin ({orig_fmin}) must be less than or equal to fmax ({orig_fmax})" ) mask = freqs >= fmin mask &= freqs <= fmax if raise_error and not mask.any(): raise ValueError( - "No frequencies remain when using fmin=%s and " - "fmax=%s (original frequency bounds are [%s, %s])" - % (orig_fmin, orig_fmax, freqs[0], freqs[-1]) + f"No frequencies remain when using fmin={orig_fmin} and fmax={orig_fmax} " + f"(original frequency bounds are [{freqs[0]}, {freqs[-1]}])" ) return mask @@ -683,7 +686,7 @@ def object_hash(x, h=None): for xx in x: object_hash(xx, h) else: - raise RuntimeError("unsupported type: %s (%s)" % (type(x), x)) + raise RuntimeError(f"unsupported type: {type(x)} ({x})") return int(h.hexdigest(), 16) @@ -733,7 +736,7 @@ def object_size(x, memo=None): elif sparse.isspmatrix_csc(x) or sparse.isspmatrix_csr(x): size = sum(sys.getsizeof(xx) for xx in [x, x.data, x.indices, x.indptr]) else: - raise RuntimeError("unsupported type: %s (%s)" % (type(x), x)) + raise RuntimeError(f"unsupported type: {type(x)} ({x})") memo[id_] = size return size @@ -778,6 +781,7 @@ def object_diff(a, b, pre="", *, allclose=False): diffs : str A string representation of the differences. 
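A usage sketch of ``object_diff`` may help here; note that comparing
DataFrames relies on the pandas branch added a few lines further down in this
hunk, which reports a single mismatch line instead of recursing::

    import pandas as pd
    from mne.utils import object_diff

    a = pd.DataFrame({"x": [1, 2]})
    b = pd.DataFrame({"x": [1, 3]})
    print(object_diff(a, b))  # with this patch: '... DataFrame mismatch'
    print(object_diff(a, a))  # empty string means "no differences"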
""" + pd = _check_pandas_installed(strict=False) out = "" if type(a) != type(b): # Deal with NamedInt and NamedFloat @@ -804,16 +808,16 @@ def object_diff(a, b, pre="", *, allclose=False): ) elif isinstance(a, (list, tuple)): if len(a) != len(b): - out += pre + " length mismatch (%s, %s)\n" % (len(a), len(b)) + out += pre + f" length mismatch ({len(a)}, {len(b)})\n" else: for ii, (xx1, xx2) in enumerate(zip(a, b)): out += object_diff(xx1, xx2, pre + "[%s]" % ii, allclose=allclose) elif isinstance(a, float): if not _array_equal_nan(a, b, allclose): - out += pre + " value mismatch (%s, %s)\n" % (a, b) + out += pre + f" value mismatch ({a}, {b})\n" elif isinstance(a, (str, int, bytes, np.generic)): if a != b: - out += pre + " value mismatch (%s, %s)\n" % (a, b) + out += pre + f" value mismatch ({a}, {b})\n" elif a is None: if b is not None: out += pre + " left is None, right is not (%s)\n" % (b) @@ -830,18 +834,22 @@ def object_diff(a, b, pre="", *, allclose=False): # sparsity and sparse type of b vs a already checked above by type() if b.shape != a.shape: out += pre + ( - " sparse matrix a and b shape mismatch" - "(%s vs %s)" % (a.shape, b.shape) + " sparse matrix a and b shape mismatch" f"({a.shape} vs {b.shape})" ) else: c = a - b c.eliminate_zeros() if c.nnz > 0: out += pre + (" sparse matrix a and b differ on %s " "elements" % c.nnz) + elif pd and isinstance(a, pd.DataFrame): + try: + pd.testing.assert_frame_equal(a, b) + except AssertionError: + out += pre + " DataFrame mismatch\n" elif hasattr(a, "__getstate__") and a.__getstate__() is not None: out += object_diff(a.__getstate__(), b.__getstate__(), pre, allclose=allclose) else: - raise RuntimeError(pre + ": unsupported type %s (%s)" % (type(a), a)) + raise RuntimeError(pre + f": unsupported type {type(a)} ({a})") return out @@ -883,16 +891,16 @@ def _fit(self, X): ) elif not 0 <= n_components <= min(n_samples, n_features): raise ValueError( - "n_components=%r must be between 0 and " - "min(n_samples, n_features)=%r with " - "svd_solver='full'" % (n_components, min(n_samples, n_features)) + f"n_components={repr(n_components)} must be between 0 and " + f"min(n_samples, n_features)={repr(min(n_samples, n_features))} with " + "svd_solver='full'" ) elif n_components >= 1: if not isinstance(n_components, (numbers.Integral, np.integer)): raise ValueError( - "n_components=%r must be of type int " - "when greater than or equal to 1, " - "was of type=%r" % (n_components, type(n_components)) + f"n_components={repr(n_components)} must be of type int " + f"when greater than or equal to 1, " + f"was of type={repr(type(n_components))}" ) self.mean_ = np.mean(X, axis=0) @@ -938,7 +946,7 @@ def _fit(self, X): def _mask_to_onsets_offsets(mask): """Group boolean mask into contiguous onset:offset pairs.""" - assert mask.dtype == bool and mask.ndim == 1 + assert mask.dtype == np.dtype(bool) and mask.ndim == 1 mask = mask.astype(int) diff = np.diff(mask) onsets = np.where(diff > 0)[0] + 1 @@ -1051,7 +1059,7 @@ def _check_dt(dt): or dt.tzinfo is None or dt.tzinfo is not timezone.utc ): - raise ValueError("Date must be datetime object in UTC: %r" % (dt,)) + raise ValueError(f"Date must be datetime object in UTC: {repr(dt)}") def _dt_to_stamp(inp_date): @@ -1102,7 +1110,7 @@ def restore(self, val): try: idx = self.popped.pop(val) except KeyError: - warn("Could not find value: %s" % (val,)) + warn(f"Could not find value: {val}") else: loc = np.searchsorted(self.indices, idx) self.indices.insert(loc, idx) diff --git a/mne/utils/progressbar.py 
b/mne/utils/progressbar.py index 94f595dd441..b1938c2fac3 100644 --- a/mne/utils/progressbar.py +++ b/mne/utils/progressbar.py @@ -55,7 +55,7 @@ def __init__( *, which_tqdm=None, **kwargs, - ): # noqa: D102 + ): # The following mimics this, but with configurable module to use # from ..externals.tqdm import auto import tqdm @@ -137,8 +137,7 @@ def update_with_increment_value(self, increment_value): def __iter__(self): """Iterate to auto-increment the pbar with 1.""" - for x in self._tqdm: - yield x + yield from self._tqdm def subset(self, idx): """Make a joblib-friendly index subset updater. @@ -188,7 +187,7 @@ def __del__(self): class _UpdateThread(Thread): def __init__(self, pb): - super(_UpdateThread, self).__init__(daemon=True) + super().__init__(daemon=True) self._mne_run = True self._mne_pb = pb diff --git a/mne/utils/spectrum.py b/mne/utils/spectrum.py index 5abcb7e3378..67a68b344a7 100644 --- a/mne/utils/spectrum.py +++ b/mne/utils/spectrum.py @@ -1,3 +1,5 @@ +"""Utility functions for spectral and spectrotemporal analysis.""" + # License: BSD-3-Clause # Copyright the MNE-Python contributors. from inspect import currentframe, getargvalues, signature @@ -5,6 +7,26 @@ from ..utils import warn +def _get_instance_type_string(inst): + """Get string representation of the originating instance type.""" + from ..epochs import BaseEpochs + from ..evoked import Evoked, EvokedArray + from ..io import BaseRaw + + parent_classes = inst._inst_type.__bases__ + if BaseRaw in parent_classes: + inst_type_str = "Raw" + elif BaseEpochs in parent_classes: + inst_type_str = "Epochs" + elif inst._inst_type in (Evoked, EvokedArray): + inst_type_str = "Evoked" + else: + raise RuntimeError( + f"Unknown instance type {inst._inst_type} in {type(inst).__name__}" + ) + return inst_type_str + + def _pop_with_fallback(mapping, key, fallback_fun): """Pop from a dict and fallback to a function parameter's default value.""" fallback = signature(fallback_fun).parameters[key].default diff --git a/mne/utils/tests/test_check.py b/mne/utils/tests/test_check.py index 4f5f6d5416b..4ec7450df99 100644 --- a/mne/utils/tests/test_check.py +++ b/mne/utils/tests/test_check.py @@ -27,9 +27,11 @@ _check_subject, _on_missing, _path_like, + _record_warnings, _safe_input, _suggest, _validate_type, + catch_logging, check_fname, check_random_state, check_version, @@ -141,12 +143,12 @@ def test_check_info_inv(): assert [1, 2] not in picks # covariance matrix data_cov_bads = data_cov.copy() - data_cov_bads["bads"] = data_cov_bads.ch_names[0] + data_cov_bads["bads"] = [data_cov_bads.ch_names[0]] picks = _check_info_inv(epochs.info, forward, data_cov=data_cov_bads) assert 0 not in picks # noise covariance matrix noise_cov_bads = noise_cov.copy() - noise_cov_bads["bads"] = noise_cov_bads.ch_names[1] + noise_cov_bads["bads"] = [noise_cov_bads.ch_names[1]] picks = _check_info_inv(epochs.info, forward, noise_cov=noise_cov_bads) assert 1 not in picks @@ -164,10 +166,16 @@ def test_check_info_inv(): noise_cov = pick_channels_cov( noise_cov, include=[noise_cov.ch_names[ii] for ii in range(7, 12)] ) - picks = _check_info_inv( - epochs.info, forward, noise_cov=noise_cov, data_cov=data_cov - ) - assert list(range(7, 10)) == picks + with catch_logging() as log: + picks = _check_info_inv( + epochs.info, forward, noise_cov=noise_cov, data_cov=data_cov, verbose=True + ) + assert list(range(7, 10)) == picks + + # make sure to inform the user that 7 channels were dropped + # (there are 10 channels in epochs but only 3 were picked) + log = log.getvalue() + 
assert "Excluding 7 channel(s) missing" in log def test_check_option(): @@ -361,7 +369,7 @@ def test_check_sphere_verbose(): info = mne.io.read_info(fname_raw) with info._unlock(): info["dig"] = info["dig"][:20] - with pytest.warns(RuntimeWarning, match="may be inaccurate"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="may be inaccurate"): _check_sphere("auto", info) with mne.use_log_level("error"): _check_sphere("auto", info) diff --git a/mne/utils/tests/test_config.py b/mne/utils/tests/test_config.py index ffae55ad08a..e0155638b0d 100644 --- a/mne/utils/tests/test_config.py +++ b/mne/utils/tests/test_config.py @@ -100,17 +100,35 @@ def test_config(tmp_path): pytest.raises(TypeError, _get_stim_channel, [1], None) -def test_sys_info(): +def test_sys_info_basic(): """Test info-showing utility.""" out = ClosingStringIO() sys_info(fid=out, check_version=False) out = out.getvalue() assert "numpy" in out + # replace all in-line whitespace with single space + out = "\n".join(" ".join(o.split()) for o in out.splitlines()) if platform.system() == "Darwin": - assert "Platform macOS-" in out + assert "Platform macOS-" in out elif platform.system() == "Linux": - assert "Platform Linux" in out + assert "Platform Linux" in out + + +def test_sys_info_complete(): + """Test that sys_info is sufficiently complete.""" + tomllib = pytest.importorskip("tomllib") # python 3.11+ + pyproject = Path(__file__).parents[3] / "pyproject.toml" + if not pyproject.is_file(): + pytest.skip("Does not appear to be a dev installation") + out = ClosingStringIO() + sys_info(fid=out, check_version=False, dependencies="developer") + out = out.getvalue() + pyproject = tomllib.loads(pyproject.read_text("utf-8")) + deps = pyproject["project"]["optional-dependencies"]["test_extra"] + for dep in deps: + dep = dep.split("[")[0].split(">")[0] + assert f" {dep}" in out, f"Missing in dev config: {dep}" def test_sys_info_qt_browser(): diff --git a/mne/utils/tests/test_docs.py b/mne/utils/tests/test_docs.py index 0fd13aa25a5..c5ab49d3167 100644 --- a/mne/utils/tests/test_docs.py +++ b/mne/utils/tests/test_docs.py @@ -122,12 +122,12 @@ def m1(): def test_copy_function_doc_to_method_doc(): """Test decorator for re-using function docstring as method docstrings.""" - def f1(object, a, b, c): + def f1(obj, a, b, c): """Docstring for f1. Parameters ---------- - object : object + obj : object Some object. This description also has blank lines in it. @@ -138,7 +138,7 @@ def f1(object, a, b, c): """ pass - def f2(object): + def f2(obj): """Docstring for f2. Parameters @@ -152,7 +152,7 @@ def f2(object): """ pass - def f3(object): + def f3(obj): """Docstring for f3. Parameters @@ -162,11 +162,11 @@ def f3(object): """ pass - def f4(object): + def f4(obj): """Docstring for f4.""" pass - def f5(object): # noqa: D410, D411, D414 + def f5(obj): # noqa: D410, D411, D414 """Docstring for f5. 
Parameters diff --git a/mne/utils/tests/test_logging.py b/mne/utils/tests/test_logging.py index 81613749aaf..25668a1de37 100644 --- a/mne/utils/tests/test_logging.py +++ b/mne/utils/tests/test_logging.py @@ -26,7 +26,7 @@ ) from mne.utils._logging import _frame_info -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" fname_raw = base_dir / "test_raw.fif" fname_evoked = base_dir / "test-ave.fif" fname_log = base_dir / "test-ave.log" @@ -63,7 +63,7 @@ def test_frame_info(capsys, monkeypatch): def test_how_to_deal_with_warnings(): """Test filter some messages out of warning records.""" - with pytest.warns(UserWarning, match="bb") as w: + with pytest.warns(Warning, match="(bb|aa) warning") as w: warnings.warn("aa warning", UserWarning) warnings.warn("bb warning", UserWarning) warnings.warn("bb warning", RuntimeWarning) @@ -73,7 +73,7 @@ def test_how_to_deal_with_warnings(): assert len(w) == 1 -def clean_lines(lines=[]): +def clean_lines(lines=()): """Scrub filenames for checking logging output (in test_logging).""" return [line if "Reading " not in line else "Reading test file" for line in lines] @@ -84,11 +84,11 @@ def test_logging_options(tmp_path): with pytest.raises(ValueError, match="Invalid value for the 'verbose"): set_log_level("foo") test_name = tmp_path / "test.log" - with open(fname_log, "r") as old_log_file: + with open(fname_log) as old_log_file: # [:-1] used to strip an extra "No baseline correction applied" old_lines = clean_lines(old_log_file.readlines()) old_lines.pop(-1) - with open(fname_log_2, "r") as old_log_file_2: + with open(fname_log_2) as old_log_file_2: old_lines_2 = clean_lines(old_log_file_2.readlines()) old_lines_2.pop(14) old_lines_2.pop(-1) @@ -112,7 +112,7 @@ def test_logging_options(tmp_path): assert fid.readlines() == [] # SHOULD print evoked = read_evokeds(fname_evoked, condition=1, verbose=True) - with open(test_name, "r") as new_log_file: + with open(test_name) as new_log_file: new_lines = clean_lines(new_log_file.readlines()) assert new_lines == old_lines set_log_file(None) # Need to do this to close the old file @@ -131,7 +131,7 @@ def test_logging_options(tmp_path): assert fid.readlines() == [] # SHOULD print evoked = read_evokeds(fname_evoked, condition=1) - with open(test_name, "r") as new_log_file: + with open(test_name) as new_log_file: new_lines = clean_lines(new_log_file.readlines()) assert new_lines == old_lines # check to make sure appending works (and as default, raises a warning) @@ -139,7 +139,7 @@ def test_logging_options(tmp_path): with pytest.warns(RuntimeWarning, match="appended to the file"): set_log_file(test_name) evoked = read_evokeds(fname_evoked, condition=1) - with open(test_name, "r") as new_log_file: + with open(test_name) as new_log_file: new_lines = clean_lines(new_log_file.readlines()) assert new_lines == old_lines_2 @@ -148,7 +148,7 @@ def test_logging_options(tmp_path): # this line needs to be called to actually do some logging evoked = read_evokeds(fname_evoked, condition=1) del evoked - with open(test_name, "r") as new_log_file: + with open(test_name) as new_log_file: new_lines = clean_lines(new_log_file.readlines()) assert new_lines == old_lines with catch_logging() as log: diff --git a/mne/utils/tests/test_misc.py b/mne/utils/tests/test_misc.py index 06b29964dd1..4168101fab3 100644 --- a/mne/utils/tests/test_misc.py +++ b/mne/utils/tests/test_misc.py @@ -8,7 +8,7 @@ import pytest import mne -from mne.utils import catch_logging, 
run_subprocess, sizeof_fmt +from mne.utils import _clean_names, catch_logging, run_subprocess, sizeof_fmt def test_sizeof_fmt(): @@ -144,3 +144,16 @@ def remove_traceback(log): other = stdout assert std == want assert other == "" + + +def test_clean_names(): + """Test cleaning names on OPM dataset. + + This channel name list is a subset from a user OPM dataset reported on the forum + https://mne.discourse.group/t/error-when-trying-to-plot-projectors-ssp/8456 + where the function _clean_names ended up creating a duplicate channel name L108_bz. + """ + ch_names = ["R305_bz-s2", "L108_bz-s77", "R112_bz-s109", "L108_bz-s110"] + ch_names_clean = _clean_names(ch_names, before_dash=True) + assert ch_names == ch_names_clean + assert len(set(ch_names_clean)) == len(ch_names_clean) diff --git a/mne/utils/tests/test_numerics.py b/mne/utils/tests/test_numerics.py index 12f366776ea..40560d42cc1 100644 --- a/mne/utils/tests/test_numerics.py +++ b/mne/utils/tests/test_numerics.py @@ -45,7 +45,7 @@ ) from mne.utils.numerics import _LRU_CACHE_MAXSIZES, _LRU_CACHES -base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" +base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" fname_raw = base_dir / "test_raw.fif" ave_fname = base_dir / "test-ave.fif" cov_fname = base_dir / "test-cov.fif" @@ -318,7 +318,7 @@ def test_object_size(): (200, 900, sparse.eye(20, format="csr")), ): size = object_size(obj) - assert lower < size < upper, "%s < %s < %s:\n%s" % (lower, size, upper, obj) + assert lower < size < upper, f"{lower} < {size} < {upper}:\n{obj}" # views work properly x = dict(a=1) assert object_size(x) < 1000 @@ -450,7 +450,7 @@ def test_pca(n_components, whiten): assert_array_equal(X, X_orig) X_mne = pca_mne.fit_transform(X) assert_array_equal(X, X_orig) - assert_allclose(X_skl, X_mne) + assert_allclose(X_skl, X_mne * np.sign(np.sum(X_skl * X_mne, axis=0))) assert pca_mne.n_components_ == pca_skl.n_components_ for key in ( "mean_", @@ -459,6 +459,10 @@ def test_pca(n_components, whiten): "explained_variance_ratio_", ): val_skl, val_mne = getattr(pca_skl, key), getattr(pca_mne, key) + if key == "components_": + val_mne = val_mne * np.sign( + np.sum(val_skl * val_mne, axis=1, keepdims=True) + ) assert_allclose(val_skl, val_mne) if isinstance(n_components, float): assert pca_mne.n_components_ == n_dim - 1 diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 5f68fd0a46e..3bee0c4fd29 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -198,7 +198,7 @@ def plot_head_positions( if p.ndim != 2 or p.shape[1] != 10: raise ValueError( "pos (or each entry in pos if a list) must be " - "dimension (N, 10), got %s" % (p.shape,) + f"dimension (N, 10), got {p.shape}" ) if ii > 0: # concatenation p[:, 0] += pos[ii - 1][-1, 0] - p[0, 0] @@ -233,7 +233,7 @@ def plot_head_positions( else: axes = np.array(axes) if axes.shape != (3, 2): - raise ValueError("axes must have shape (3, 2), got %s" % (axes.shape,)) + raise ValueError(f"axes must have shape (3, 2), got {axes.shape}") fig = axes[0, 0].figure labels = ["xyz", ("$q_1$", "$q_2$", "$q_3$")] @@ -1793,12 +1793,11 @@ def _process_clim(clim, colormap, transparent, data=0.0, allow_pos_lims=True): key = "lims" clim = {"kind": "percent", key: [96, 97.5, 99.95]} if not isinstance(clim, dict): - raise ValueError('"clim" must be "auto" or dict, got %s' % (clim,)) + raise ValueError(f'"clim" must be "auto" or dict, got {clim}') if ("lims" in clim) + ("pos_lims" in clim) != 1: raise ValueError( - "Exactly one of lims and pos_lims must be specified " - "in clim, got 
%s" % (clim,) + "Exactly one of lims and pos_lims must be specified " f"in clim, got {clim}" ) if "pos_lims" in clim and not allow_pos_lims: raise ValueError('Cannot use "pos_lims" for clim, use "lims" ' "instead") @@ -1806,17 +1805,17 @@ def _process_clim(clim, colormap, transparent, data=0.0, allow_pos_lims=True): ctrl_pts = np.array(clim["pos_lims" if diverging else "lims"], float) ctrl_pts = np.array(ctrl_pts, float) if ctrl_pts.shape != (3,): - raise ValueError("clim has shape %s, it must be (3,)" % (ctrl_pts.shape,)) + raise ValueError(f"clim has shape {ctrl_pts.shape}, it must be (3,)") if (np.diff(ctrl_pts) < 0).any(): raise ValueError( - "colormap limits must be monotonically " "increasing, got %s" % (ctrl_pts,) + f"colormap limits must be monotonically increasing, got {ctrl_pts}" ) clim_kind = clim.get("kind", "percent") _check_option("clim['kind']", clim_kind, ["value", "values", "percent"]) if clim_kind == "percent": perc_data = np.abs(data) if diverging else data ctrl_pts = np.percentile(perc_data, ctrl_pts) - logger.info("Using control points %s" % (ctrl_pts,)) + logger.info(f"Using control points {ctrl_pts}") assert len(ctrl_pts) == 3 clim = dict(kind="value") clim["pos_lims" if diverging else "lims"] = ctrl_pts @@ -2187,8 +2186,7 @@ def link_brains(brains, time=True, camera=False, colorbar=True, picking=False): if _get_3d_backend() != "pyvistaqt": raise NotImplementedError( - "Expected 3d backend is pyvistaqt but" - " {} was given.".format(_get_3d_backend()) + f"Expected 3d backend is pyvistaqt but {_get_3d_backend()} was given." ) from ._brain import Brain, _LinkViewer @@ -2198,9 +2196,7 @@ def link_brains(brains, time=True, camera=False, colorbar=True, picking=False): raise ValueError("The collection of brains is empty.") for brain in brains: if not isinstance(brain, Brain): - raise TypeError( - "Expected type is Brain but" " {} was given.".format(type(brain)) - ) + raise TypeError("Expected type is Brain but" f" {type(brain)} was given.") # enable time viewer if necessary brain.setup_time_viewer() subjects = [brain._subject for brain in brains] @@ -2520,7 +2516,7 @@ def _plot_stc( if overlay_alpha == 0: smoothing_steps = 1 # Disable smoothing to save time. 
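[Editor's note] For readers tracking the `_process_clim` rewrite above, the validated contract is: exactly one of `lims`/`pos_lims`, three non-decreasing control points, and a `kind` among "value", "values", or "percent". A sketch of a dict that satisfies those checks, mirroring the "auto" default shown above:

    import numpy as np

    clim = dict(kind="percent", lims=[96, 97.5, 99.95])  # the "auto" default above

    assert ("lims" in clim) + ("pos_lims" in clim) == 1  # exactly one of the two keys
    ctrl_pts = np.array(clim["lims"], float)
    assert ctrl_pts.shape == (3,)                        # three control points
    assert not (np.diff(ctrl_pts) < 0).any()             # monotonically increasing
    assert clim.get("kind", "percent") in ("value", "values", "percent")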
- title = subject if len(hemis) > 1 else "%s - %s" % (subject, hemis[0]) + title = subject if len(hemis) > 1 else f"{subject} - {hemis[0]}" kwargs = { "subject": subject, "hemi": hemi, @@ -2936,7 +2932,7 @@ def _onclick(event, params, verbose=None): time_sl = slice(0, None) else: initial_time = float(initial_time) - logger.info("Fixing initial time: %s s" % (initial_time,)) + logger.info(f"Fixing initial time: {initial_time} s") initial_time = np.argmin(np.abs(stc.times - initial_time)) time_sl = slice(initial_time, initial_time + 1) if initial_pos is None: # find max pos and (maybe) time @@ -2949,10 +2945,10 @@ def _onclick(event, params, verbose=None): if initial_pos.shape != (3,): raise ValueError( "initial_pos must be float ndarray with shape " - "(3,), got shape %s" % (initial_pos.shape,) + f"(3,), got shape {initial_pos.shape}" ) initial_pos *= 1000 - logger.info("Fixing initial position: %s mm" % (initial_pos.tolist(),)) + logger.info(f"Fixing initial position: {initial_pos.tolist()} mm") loc_idx = _cut_coords_to_idx(initial_pos, img) if initial_time is not None: # time also specified time_idx = time_sl.start @@ -4000,14 +3996,10 @@ def _plot_dipole( coord_frame_name = "Head" if coord_frame == "head" else "MRI" if title is None: - title = "Dipole #%s / %s @ %.3fs, GOF: %.1f%%, %.1fnAm\n%s: " % ( - idx + 1, - len(dipole.times), - dipole.times[idx], - dipole.gof[idx], - dipole.amplitude[idx] * 1e9, - coord_frame_name, - ) + "(%0.1f, %0.1f, %0.1f) mm" % tuple(xyz[idx]) + title = f"Dipole #{idx + 1} / {len(dipole.times)} @ {dipole.times[idx]:.3f}s, " + f"GOF: {dipole.gof[idx]:.1f}%, {dipole.amplitude[idx] * 1e9:.1f}nAm\n" + f"{coord_frame_name}: " + f"({xyz[idx][0]:0.1f}, {xyz[idx][1]:0.1f}, " + f"{xyz[idx][2]:0.1f}) mm" ax.get_figure().suptitle(title) diff --git a/mne/viz/_3d_overlay.py b/mne/viz/_3d_overlay.py index 48baff23d1e..203cb686360 100644 --- a/mne/viz/_3d_overlay.py +++ b/mne/viz/_3d_overlay.py @@ -150,11 +150,7 @@ def remove_overlay(self, names): def _apply(self): if self._current_colors is None or self._renderer is None: return - self._renderer._set_mesh_scalars( - mesh=self._polydata, - scalars=self._current_colors, - name=self._default_scalars_name, - ) + self._polydata[self._default_scalars_name] = self._current_colors def update(self, colors=None): if colors is not None and self._cached_colors is not None: diff --git a/mne/viz/__init__.pyi b/mne/viz/__init__.pyi index dfebec1f5dc..c58ad7d0e54 100644 --- a/mne/viz/__init__.pyi +++ b/mne/viz/__init__.pyi @@ -18,6 +18,7 @@ __all__ = [ "compare_fiff", "concatenate_images", "create_3d_figure", + "eyetracking", "get_3d_backend", "get_brain_class", "get_browser_backend", @@ -86,7 +87,7 @@ __all__ = [ "use_3d_backend", "use_browser_backend", ] -from . import _scraper, backends, ui_events +from . 
import _scraper, backends, eyetracking, ui_events from ._3d import ( link_brains, plot_alignment, diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 4725e8664aa..da5ca5c3cd1 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -1573,7 +1573,7 @@ def plot_time_course(self, hemi, vertex_id, color, update=True): mni = " MNI: " + ", ".join("%5.1f" % m for m in mni) else: mni = "" - label = "{}:{}{}".format(hemi_str, str(vertex_id).ljust(6), mni) + label = f"{hemi_str}:{str(vertex_id).ljust(6)}{mni}" act_data, smooth = self.act_data_smooth[hemi] if smooth is not None: act_data = smooth[vertex_id].dot(act_data)[0] @@ -1880,8 +1880,8 @@ def add_data( time = np.asarray(time) if time.shape != (array.shape[-1],): raise ValueError( - "time has shape %s, but need shape %s " - "(array.shape[-1])" % (time.shape, (array.shape[-1],)) + f"time has shape {time.shape}, but need shape " + f"{(array.shape[-1],)} (array.shape[-1])" ) self._data["time"] = time @@ -1907,8 +1907,8 @@ def add_data( if array.ndim == 3: if array.shape[1] != 3: raise ValueError( - "If array has 3 dimensions, array.shape[1] " - "must equal 3, got %s" % (array.shape[1],) + "If array has 3 dimensions, array.shape[1] must equal 3, got " + f"{array.shape[1]}" ) fmin, fmid, fmax = _update_limits(fmin, fmid, fmax, center, array) if colormap == "auto": @@ -1921,13 +1921,13 @@ def add_data( elif isinstance(smoothing_steps, int): if smoothing_steps < 0: raise ValueError( - "Expected value of `smoothing_steps` is" - " positive but {} was given.".format(smoothing_steps) + "Expected value of `smoothing_steps` is positive but " + f"{smoothing_steps} was given." ) else: raise TypeError( - "Expected type of `smoothing_steps` is int or" - " NoneType but {} was given.".format(type(smoothing_steps)) + "Expected type of `smoothing_steps` is int or NoneType but " + f"{type(smoothing_steps)} was given." ) self._data["stc"] = stc @@ -2183,8 +2183,6 @@ def add_label( borders=False, hemi=None, subdir=None, - *, - reset_camera=None, ): """Add an ROI label to the image. @@ -2216,8 +2214,6 @@ def add_label( label directory rather than in the label directory itself (e.g. for ``$SUBJECTS_DIR/$SUBJECT/label/aparc/lh.cuneus.label`` ``brain.add_label('cuneus', subdir='aparc')``). - reset_camera : bool - Deprecated. Use :meth:`show_view` instead. 
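[Editor's note] The `reset_camera` parameter removed just above had been deprecated (slated for removal in 1.7, per the warning text) with `show_view` named as the replacement. A hedged migration sketch for downstream code; the label name and keyword choices here are illustrative placeholders:

    def add_label_and_reframe(brain, name="cuneus"):
        # Before: brain.add_label(name, reset_camera=True)  (parameter removed above)
        # After: add the label, then reframe the scene explicitly.
        brain.add_label(name, borders=True)
        brain.show_view(distance="auto")  # auto-fit stands in for reset_camera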
Notes ----- @@ -2324,12 +2320,6 @@ def add_label( keep_idx = np.unique(keep_idx) show[keep_idx] = 1 scalars *= show - if reset_camera is not None: - warn( - "reset_camera is deprecated and will be removed in 1.7, " - "use show_view instead", - FutureWarning, - ) for _, _, v in self._iter_views(hemi): mesh = self._layered_meshes[hemi] mesh.add_overlay( @@ -2927,6 +2917,8 @@ def add_text( name = text if name is None else name if "text" in self._actors and name in self._actors["text"]: raise ValueError(f"Text with the name {name} already exists") + if color is None: + color = self._fg_color for ri, ci, _ in self._iter_views("vol"): if (row is None or row == ri) and (col is None or col == ci): actor = self._renderer.text2d( @@ -3470,9 +3462,9 @@ def set_data_smoothing(self, n_steps): vertices = hemi_data["vertices"] if vertices is None: raise ValueError( - "len(data) < nvtx (%s < %s): the vertices " + f"len(data) < nvtx ({len(hemi_data)} < " + f"{self.geo[hemi].x.shape[0]}): the vertices " "parameter must not be None" - % (len(hemi_data), self.geo[hemi].x.shape[0]) ) morph_n_steps = "nearest" if n_steps == -1 else n_steps with use_log_level(False): @@ -3942,8 +3934,8 @@ def _make_movie_frames( tmin = self._times[0] elif tmin < self._times[0]: raise ValueError( - "tmin=%r is smaller than the first time point " - "(%r)" % (tmin, self._times[0]) + f"tmin={repr(tmin)} is smaller than the first time point " + f"({repr(self._times[0])})" ) # find indexes at which to create frames @@ -3951,8 +3943,8 @@ def _make_movie_frames( tmax = self._times[-1] elif tmax > self._times[-1]: raise ValueError( - "tmax=%r is greater than the latest time point " - "(%r)" % (tmax, self._times[-1]) + f"tmax={repr(tmax)} is greater than the latest time point " + f"({repr(self._times[-1])})" ) n_frames = floor((tmax - tmin) * time_dilation * framerate) times = np.arange(n_frames, dtype=float) @@ -3964,7 +3956,7 @@ def _make_movie_frames( if n_times == 0: raise ValueError("No time points selected") - logger.debug("Save movie for time points/samples\n%s\n%s" % (times, time_idx)) + logger.debug(f"Save movie for time points/samples\n{times}\n{time_idx}") # Sometimes the first screenshot is rendered with a different # resolution on OS X self.screenshot(time_viewer=time_viewer) @@ -4135,9 +4127,9 @@ def _update_limits(fmin, fmid, fmax, center, array): fmid = (fmin + fmax) / 2.0 if fmin >= fmid: - raise RuntimeError("min must be < mid, got %0.4g >= %0.4g" % (fmin, fmid)) + raise RuntimeError(f"min must be < mid, got {fmin:0.4g} >= {fmid:0.4g}") if fmid >= fmax: - raise RuntimeError("mid must be < max, got %0.4g >= %0.4g" % (fmid, fmax)) + raise RuntimeError(f"mid must be < max, got {fmid:0.4g} >= {fmax:0.4g}") return fmin, fmid, fmax diff --git a/mne/viz/_brain/colormap.py b/mne/viz/_brain/colormap.py index 0567e352252..31c42456995 100644 --- a/mne/viz/_brain/colormap.py +++ b/mne/viz/_brain/colormap.py @@ -117,9 +117,7 @@ def calculate_lut(lut_table, alpha, fmin, fmid, fmax, center=None, transparent=T Color map with transparency channel. 
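[Editor's note] The `fmin < fmid < fmax` checks reformatted above guard the colormap control points, with `_update_limits` defaulting the midpoint when unset. A compact sketch of that behavior, simplified from the function shown above:

    def update_limits_sketch(fmin, fmax, fmid=None):
        # Default fmid to the average of the extremes, then enforce strict ordering.
        if fmid is None:
            fmid = (fmin + fmax) / 2.0
        if fmin >= fmid:
            raise RuntimeError(f"min must be < mid, got {fmin:0.4g} >= {fmid:0.4g}")
        if fmid >= fmax:
            raise RuntimeError(f"mid must be < max, got {fmid:0.4g} >= {fmax:0.4g}")
        return fmin, fmid, fmax

    print(update_limits_sketch(2.0, 6.0))  # -> (2.0, 4.0, 6.0)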
""" if not fmin <= fmid <= fmax: - raise ValueError( - "Must have fmin (%s) <= fmid (%s) <= fmax (%s)" % (fmin, fmid, fmax) - ) + raise ValueError(f"Must have fmin ({fmin}) <= fmid ({fmid}) <= fmax ({fmax})") lut_table = create_lut(lut_table) assert lut_table.dtype.kind == "i" divergent = center is not None diff --git a/mne/viz/_brain/surface.py b/mne/viz/_brain/surface.py index f4625c5f019..7f17cebf718 100644 --- a/mne/viz/_brain/surface.py +++ b/mne/viz/_brain/surface.py @@ -120,9 +120,7 @@ def load_geometry(self): None """ if self.surf == "flat": # special case - fname = path.join( - self.data_path, "surf", "%s.%s" % (self.hemi, "cortex.patch.flat") - ) + fname = path.join(self.data_path, "surf", f"{self.hemi}.cortex.patch.flat") _check_fname( fname, overwrite="read", must_exist=True, name="flatmap surface file" ) @@ -184,7 +182,7 @@ def load_curvature(self): else: self.curv = None self.bin_curv = None - color = np.ones((self.coords.shape[0])) + color = np.ones(self.coords.shape[0]) # morphometry (curvature) normalization in order to get gray cortex # TODO: delete self.grey_curv after cortex parameter # will be fully supported diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index f233f83389d..ea470937300 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -1362,7 +1362,7 @@ def _create_testing_brain( rng = np.random.RandomState(0) vertices = [s["vertno"] for s in sample_src] n_verts = sum(len(v) for v in vertices) - stc_data = np.zeros((n_verts * n_time)) + stc_data = np.zeros(n_verts * n_time) stc_size = stc_data.size stc_data[(rng.rand(stc_size // 20) * stc_size).astype(int)] = rng.rand( stc_data.size // 20 diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 9835afa4e2b..da19372d8bc 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -1847,7 +1847,7 @@ def _draw_one_scalebar(self, x, y, ch_type): color = "#AA3377" # purple kwargs = dict(color=color, zorder=self.mne.zorder["scalebar"]) if ch_type == "time": - label = f"{self.mne.boundary_times[1]/2:.2f} s" + label = f"{self.mne.boundary_times[1] / 2:.2f} s" text = self.mne.ax_main.text( x[0] + 0.015, y[1] - 0.05, diff --git a/mne/viz/backends/_abstract.py b/mne/viz/backends/_abstract.py index 23cb65c6c44..c31023401ed 100644 --- a/mne/viz/backends/_abstract.py +++ b/mne/viz/backends/_abstract.py @@ -166,11 +166,11 @@ def mesh( The scalar valued associated to the vertices. vmin : float | None vmin is used to scale the colormap. - If None, the min of the data will be used + If None, the min of the data will be used. vmax : float | None vmax is used to scale the colormap. - If None, the max of the data will be used - colormap : + If None, the max of the data will be used. + colormap : str | np.ndarray | matplotlib.colors.Colormap | None The colormap to use. interpolate_before_map : Enabling makes for a smoother scalars display. Default is True. @@ -225,17 +225,17 @@ def contour( The opacity of the contour. vmin : float | None vmin is used to scale the colormap. - If None, the min of the data will be used + If None, the min of the data will be used. vmax : float | None vmax is used to scale the colormap. - If None, the max of the data will be used - colormap : + If None, the max of the data will be used. + colormap : str | np.ndarray | matplotlib.colors.Colormap | None The colormap to use. normalized_colormap : bool Specify if the values of the colormap are between 0 and 1. 
kind : 'line' | 'tube' The type of the primitives to use to display the contours. - color : + color : tuple | str The color of the mesh as a tuple (red, green, blue) of float values between 0 and 1 or a valid color name (i.e. 'white' or 'w'). @@ -270,11 +270,11 @@ def surface( The opacity of the surface. vmin : float | None vmin is used to scale the colormap. - If None, the min of the data will be used + If None, the min of the data will be used. vmax : float | None vmax is used to scale the colormap. - If None, the max of the data will be used - colormap : + If None, the max of the data will be used. + colormap : str | np.ndarray | matplotlib.colors.Colormap | None The colormap to use. scalars : ndarray, shape (n_vertices,) The scalar valued associated to the vertices. @@ -354,11 +354,11 @@ def tube( The optional scalar data to use. vmin : float | None vmin is used to scale the colormap. - If None, the min of the data will be used + If None, the min of the data will be used. vmax : float | None vmax is used to scale the colormap. - If None, the max of the data will be used - colormap : + If None, the max of the data will be used. + colormap : str | np.ndarray | matplotlib.colors.Colormap | None The colormap to use. opacity : float The opacity of the tube(s). @@ -446,7 +446,7 @@ def quiver3d( The optional scalar data to use. backface_culling : bool If True, enable backface culling on the quiver. - colormap : + colormap : str | np.ndarray | matplotlib.colors.Colormap | None The colormap to use. vmin : float | None vmin is used to scale the colormap. @@ -518,15 +518,15 @@ def scalarbar(self, source, color="white", title=None, n_labels=4, bgcolor=None) Parameters ---------- - source : + source The object of the scene used for the colormap. - color : + color : tuple | str The color of the label text. title : str | None The title of the scalar bar. n_labels : int | None The number of labels to display on the scalar bar. - bgcolor : + bgcolor : tuple | str The color of the background when there is transparency. """ pass @@ -549,8 +549,6 @@ def set_camera( distance=None, focalpoint=None, roll=None, - *, - reset_camera=None, ): """Configure the camera of the scene. @@ -566,16 +564,9 @@ def set_camera( The focal point of the camera: (x, y, z). roll : float The rotation of the camera along its axis. - reset_camera : bool - Deprecated, used ``distance="auto"`` instead. """ pass - @abstractclassmethod - def reset_camera(self): - """Reset the camera properties.""" - pass - @abstractclassmethod def screenshot(self, mode="rgb", filename=None): """Take a screenshot of the scene. 
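[Editor's note] With the deprecated `reset_camera` method and keyword dropped from the abstract renderer above, auto-framing now goes through `set_camera` alone, matching the deprecation text removed later in this diff ("use set_camera(distance='auto') instead"). A one-line sketch, assuming a `renderer` instance implementing this API:

    def reframe(renderer):
        # Old: renderer.reset_camera()  (abstract method removed above)
        # New: let set_camera auto-fit the scene.
        renderer.set_camera(distance="auto")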
@@ -1145,7 +1136,7 @@ def _dock_add_file_button( desc, func, *, - filter=None, + filter_=None, initial_directory=None, save=False, is_directory=False, @@ -1218,7 +1209,7 @@ def _dialog_create( callback, *, icon="Warning", - buttons=[], + buttons=(), modal=True, window=None, ): diff --git a/mne/viz/backends/_notebook.py b/mne/viz/backends/_notebook.py index 6a9e5a6cf8f..4601ef1fc6a 100644 --- a/mne/viz/backends/_notebook.py +++ b/mne/viz/backends/_notebook.py @@ -976,7 +976,7 @@ def _dialog_create( callback, *, icon="Warning", - buttons=[], + buttons=(), modal=True, window=None, ): @@ -1202,7 +1202,7 @@ def _dock_add_file_button( desc, func, *, - filter=None, + filter_=None, initial_directory=None, save=False, is_directory=False, diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index b8fca4c995a..b94163b2ec8 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -26,8 +26,6 @@ _check_option, _require_version, _validate_type, - copy_base_doc_to_subclass_doc, - deprecated, warn, ) from ._abstract import Figure3D, _AbstractRenderer @@ -110,7 +108,6 @@ def _init( off_screen=False, notebook=False, splash=False, - multi_samples=None, ): self._plotter = plotter self.display = None @@ -125,7 +122,6 @@ def _init( self.store["shape"] = shape self.store["off_screen"] = off_screen self.store["border"] = False - self.store["multi_samples"] = multi_samples self.store["line_smoothing"] = True self.store["polygon_smoothing"] = True self.store["point_smoothing"] = True @@ -195,7 +191,6 @@ def visible(self, state): self.plotter.render() -@copy_base_doc_to_subclass_doc class _PyVistaRenderer(_AbstractRenderer): """Class managing rendering scene. @@ -237,12 +232,12 @@ def __init__( notebook=notebook, smooth_shading=smooth_shading, splash=splash, - multi_samples=multi_samples, ) self.font_family = "arial" self.tube_n_sides = 20 self.antialias = _get_3d_option("antialias") self.depth_peeling = _get_3d_option("depth_peeling") + self.multi_samples = multi_samples self.smooth_shading = smooth_shading if isinstance(fig, int): saved_fig = _FIGURES.get(fig) @@ -261,15 +256,13 @@ def __init__( if pyvista.OFF_SCREEN: self.figure.store["off_screen"] = True - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - # pyvista theme may enable depth peeling by default so - # we disable it initially to better control the value afterwards - with _disabled_depth_peeling(): - self.plotter = self.figure._build() - self._hide_axes() - self._toggle_antialias() - self._enable_depth_peeling() + # pyvista theme may enable depth peeling by default so + # we disable it initially to better control the value afterwards + with _disabled_depth_peeling(): + self.plotter = self.figure._build() + self._hide_axes() + self._toggle_antialias() + self._enable_depth_peeling() # FIX: https://github.com/pyvista/pyvistaqt/pull/68 if not hasattr(self.plotter, "iren"): @@ -312,9 +305,7 @@ def _loc_to_index(self, loc): def subplot(self, x, y): x = np.max([0, np.min([x, self.figure._nrows - 1])]) y = np.max([0, np.min([y, self.figure._ncols - 1])]) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - self.plotter.subplot(x, y) + self.plotter.subplot(x, y) def scene(self): return self.figure @@ -375,58 +366,56 @@ def polydata( ): from matplotlib.colors import to_rgba_array - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - rgba = False - if color is not None: - # See if we need to 
convert or not - check_color = to_rgba_array(color) - if len(check_color) == mesh.n_points: - scalars = (check_color * 255).astype("ubyte") - color = None - rgba = True - if isinstance(colormap, np.ndarray): - if colormap.dtype == np.uint8: - colormap = colormap.astype(np.float64) / 255.0 - from matplotlib.colors import ListedColormap - - colormap = ListedColormap(colormap) - if normals is not None: - mesh.point_data["Normals"] = normals - mesh.GetPointData().SetActiveNormals("Normals") - else: - _compute_normals(mesh) - smooth_shading = self.smooth_shading - if representation == "wireframe": - smooth_shading = False # never use smooth shading for wf - rgba = kwargs.pop("rgba", rgba) - actor = _add_mesh( - plotter=self.plotter, - mesh=mesh, - color=color, - scalars=scalars, - edge_color=color, - opacity=opacity, - cmap=colormap, - backface_culling=backface_culling, - rng=[vmin, vmax], - show_scalar_bar=False, - rgba=rgba, - smooth_shading=smooth_shading, - interpolate_before_map=interpolate_before_map, - style=representation, - line_width=line_width, - **kwargs, - ) + rgba = False + if color is not None: + # See if we need to convert or not + check_color = to_rgba_array(color) + if len(check_color) == mesh.n_points: + scalars = (check_color * 255).astype("ubyte") + color = None + rgba = True + if isinstance(colormap, np.ndarray): + if colormap.dtype == np.uint8: + colormap = colormap.astype(np.float64) / 255.0 + from matplotlib.colors import ListedColormap + + colormap = ListedColormap(colormap) + if normals is not None: + mesh.point_data["Normals"] = normals + mesh.GetPointData().SetActiveNormals("Normals") + else: + _compute_normals(mesh) + smooth_shading = self.smooth_shading + if representation == "wireframe": + smooth_shading = False # never use smooth shading for wf + rgba = kwargs.pop("rgba", rgba) + actor = _add_mesh( + plotter=self.plotter, + mesh=mesh, + color=color, + scalars=scalars, + edge_color=color, + opacity=opacity, + cmap=colormap, + backface_culling=backface_culling, + rng=[vmin, vmax], + show_scalar_bar=False, + rgba=rgba, + smooth_shading=smooth_shading, + interpolate_before_map=interpolate_before_map, + style=representation, + line_width=line_width, + **kwargs, + ) - if polygon_offset is not None: - mapper = actor.GetMapper() - mapper.SetResolveCoincidentTopologyToPolygonOffset() - mapper.SetRelativeCoincidentTopologyPolygonOffsetParameters( - polygon_offset, polygon_offset - ) + if polygon_offset is not None: + mapper = actor.GetMapper() + mapper.SetResolveCoincidentTopologyToPolygonOffset() + mapper.SetRelativeCoincidentTopologyPolygonOffsetParameters( + polygon_offset, polygon_offset + ) - return actor, mesh + return actor, mesh def mesh( self, @@ -449,11 +438,9 @@ def mesh( polygon_offset=None, **kwargs, ): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - vertices = np.c_[x, y, z].astype(float) - triangles = np.c_[np.full(len(triangles), 3), triangles] - mesh = PolyData(vertices, triangles) + vertices = np.c_[x, y, z].astype(float) + triangles = np.c_[np.full(len(triangles), 3), triangles] + mesh = PolyData(vertices, triangles) return self.polydata( mesh=mesh, color=color, @@ -485,33 +472,31 @@ def contour( kind="line", color=None, ): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - if colormap is not None: - colormap = _get_colormap_from_array(colormap, normalized_colormap) - vertices = np.array(surface["rr"]) - triangles = np.array(surface["tris"]) - n_triangles = 
len(triangles) - triangles = np.c_[np.full(n_triangles, 3), triangles] - mesh = PolyData(vertices, triangles) - mesh.point_data["scalars"] = scalars - contour = mesh.contour(isosurfaces=contours) - line_width = width - if kind == "tube": - contour = contour.tube(radius=width, n_sides=self.tube_n_sides) - line_width = 1.0 - actor = _add_mesh( - plotter=self.plotter, - mesh=contour, - show_scalar_bar=False, - line_width=line_width, - color=color, - rng=[vmin, vmax], - cmap=colormap, - opacity=opacity, - smooth_shading=self.smooth_shading, - ) - return actor, contour + if colormap is not None: + colormap = _get_colormap_from_array(colormap, normalized_colormap) + vertices = np.array(surface["rr"]) + triangles = np.array(surface["tris"]) + n_triangles = len(triangles) + triangles = np.c_[np.full(n_triangles, 3), triangles] + mesh = PolyData(vertices, triangles) + mesh.point_data["scalars"] = scalars + contour = mesh.contour(isosurfaces=contours) + line_width = width + if kind == "tube": + contour = contour.tube(radius=width, n_sides=self.tube_n_sides) + line_width = 1.0 + actor = _add_mesh( + plotter=self.plotter, + mesh=contour, + show_scalar_bar=False, + line_width=line_width, + color=color, + rng=[vmin, vmax], + cmap=colormap, + opacity=opacity, + smooth_shading=self.smooth_shading, + ) + return actor, contour def surface( self, @@ -526,13 +511,11 @@ def surface( backface_culling=False, polygon_offset=None, ): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - normals = surface.get("nn", None) - vertices = np.array(surface["rr"]) - triangles = np.array(surface["tris"]) - triangles = np.c_[np.full(len(triangles), 3), triangles] - mesh = PolyData(vertices, triangles) + normals = surface.get("nn", None) + vertices = np.array(surface["rr"]) + triangles = np.array(surface["tris"]) + triangles = np.c_[np.full(len(triangles), 3), triangles] + mesh = PolyData(vertices, triangles) colormap = _get_colormap_from_array(colormap, normalized_colormap) if scalars is not None: mesh.point_data["scalars"] = scalars @@ -567,26 +550,24 @@ def sphere( return None, None _check_option("center.ndim", center.ndim, (1, 2)) _check_option("center.shape[-1]", center.shape[-1], (3,)) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - sphere = vtkSphereSource() - sphere.SetThetaResolution(resolution) - sphere.SetPhiResolution(resolution) - if radius is not None: - sphere.SetRadius(radius) - sphere.Update() - geom = sphere.GetOutput() - mesh = PolyData(center) - glyph = mesh.glyph(orient=False, scale=False, factor=factor, geom=geom) - actor = _add_mesh( - self.plotter, - mesh=glyph, - color=color, - opacity=opacity, - backface_culling=backface_culling, - smooth_shading=self.smooth_shading, - ) - return actor, glyph + sphere = vtkSphereSource() + sphere.SetThetaResolution(resolution) + sphere.SetPhiResolution(resolution) + if radius is not None: + sphere.SetRadius(radius) + sphere.Update() + geom = sphere.GetOutput() + mesh = PolyData(center) + glyph = mesh.glyph(orient=False, scale=False, factor=factor, geom=geom) + actor = _add_mesh( + self.plotter, + mesh=glyph, + color=color, + opacity=opacity, + backface_culling=backface_culling, + smooth_shading=self.smooth_shading, + ) + return actor, glyph def tube( self, @@ -602,30 +583,28 @@ def tube( reverse_lut=False, opacity=None, ): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - cmap = _get_colormap_from_array(colormap, 
normalized_colormap) - for pointa, pointb in zip(origin, destination): - line = Line(pointa, pointb) - if scalars is not None: - line.point_data["scalars"] = scalars[0, :] - scalars = "scalars" - color = None - else: - scalars = None - tube = line.tube(radius, n_sides=self.tube_n_sides) - actor = _add_mesh( - plotter=self.plotter, - mesh=tube, - scalars=scalars, - flip_scalars=reverse_lut, - rng=[vmin, vmax], - color=color, - show_scalar_bar=False, - cmap=cmap, - smooth_shading=self.smooth_shading, - opacity=opacity, - ) + cmap = _get_colormap_from_array(colormap, normalized_colormap) + for pointa, pointb in zip(origin, destination): + line = Line(pointa, pointb) + if scalars is not None: + line.point_data["scalars"] = scalars[0, :] + scalars = "scalars" + color = None + else: + scalars = None + tube = line.tube(radius, n_sides=self.tube_n_sides) + actor = _add_mesh( + plotter=self.plotter, + mesh=tube, + scalars=scalars, + flip_scalars=reverse_lut, + rng=[vmin, vmax], + color=color, + show_scalar_bar=False, + cmap=cmap, + smooth_shading=self.smooth_shading, + opacity=opacity, + ) return actor, tube def quiver3d( @@ -661,85 +640,83 @@ def quiver3d( _validate_type(scale_mode, str, "scale_mode") scale_map = dict(none=False, scalar="scalars", vector="vec") _check_option("scale_mode", scale_mode, list(scale_map)) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - factor = scale - vectors = np.c_[u, v, w] - points = np.vstack(np.c_[x, y, z]) - n_points = len(points) - cell_type = np.full(n_points, VTK_VERTEX) - cells = np.c_[np.full(n_points, 1), range(n_points)] - args = (cells, cell_type, points) - grid = UnstructuredGrid(*args) - if scalars is None: - scalars = np.ones((n_points,)) - mesh_scalars = None - else: - mesh_scalars = "scalars" - grid.point_data["scalars"] = np.array(scalars, float) - grid.point_data["vec"] = vectors - if mode == "2darrow": - return _arrow_glyph(grid, factor), grid - elif mode == "arrow": - alg = _glyph(grid, orient="vec", scalars="scalars", factor=factor) - mesh = pyvista.wrap(alg.GetOutput()) + factor = scale + vectors = np.c_[u, v, w] + points = np.vstack(np.c_[x, y, z]) + n_points = len(points) + cell_type = np.full(n_points, VTK_VERTEX) + cells = np.c_[np.full(n_points, 1), range(n_points)] + args = (cells, cell_type, points) + grid = UnstructuredGrid(*args) + if scalars is None: + scalars = np.ones((n_points,)) + mesh_scalars = None + else: + mesh_scalars = "scalars" + grid.point_data["scalars"] = np.array(scalars, float) + grid.point_data["vec"] = vectors + if mode == "2darrow": + return _arrow_glyph(grid, factor), grid + elif mode == "arrow": + alg = _glyph(grid, orient="vec", scalars="scalars", factor=factor) + mesh = pyvista.wrap(alg.GetOutput()) + else: + tr = None + if mode == "cone": + glyph = vtkConeSource() + glyph.SetCenter(0.5, 0, 0) + if glyph_radius is not None: + glyph.SetRadius(glyph_radius) + elif mode == "cylinder": + glyph = vtkCylinderSource() + if glyph_radius is not None: + glyph.SetRadius(glyph_radius) + elif mode == "oct": + glyph = vtkPlatonicSolidSource() + glyph.SetSolidTypeToOctahedron() else: - tr = None - if mode == "cone": - glyph = vtkConeSource() - glyph.SetCenter(0.5, 0, 0) - if glyph_radius is not None: - glyph.SetRadius(glyph_radius) - elif mode == "cylinder": - glyph = vtkCylinderSource() - if glyph_radius is not None: - glyph.SetRadius(glyph_radius) - elif mode == "oct": - glyph = vtkPlatonicSolidSource() - glyph.SetSolidTypeToOctahedron() - else: - assert mode == "sphere", mode 
# guaranteed above - glyph = vtkSphereSource() - if mode == "cylinder": - if glyph_height is not None: - glyph.SetHeight(glyph_height) - if glyph_center is not None: - glyph.SetCenter(glyph_center) - if glyph_resolution is not None: - glyph.SetResolution(glyph_resolution) + assert mode == "sphere", mode # guaranteed above + glyph = vtkSphereSource() + if mode == "cylinder": + if glyph_height is not None: + glyph.SetHeight(glyph_height) + if glyph_center is not None: + glyph.SetCenter(glyph_center) + if glyph_resolution is not None: + glyph.SetResolution(glyph_resolution) + tr = vtkTransform() + tr.RotateWXYZ(90, 0, 0, 1) + elif mode == "oct": + if solid_transform is not None: + assert solid_transform.shape == (4, 4) tr = vtkTransform() - tr.RotateWXYZ(90, 0, 0, 1) - elif mode == "oct": - if solid_transform is not None: - assert solid_transform.shape == (4, 4) - tr = vtkTransform() - tr.SetMatrix(solid_transform.astype(np.float64).ravel()) - if tr is not None: - # fix orientation - glyph.Update() - trp = vtkTransformPolyDataFilter() - trp.SetInputData(glyph.GetOutput()) - trp.SetTransform(tr) - glyph = trp + tr.SetMatrix(solid_transform.astype(np.float64).ravel()) + if tr is not None: + # fix orientation glyph.Update() - geom = glyph.GetOutput() - mesh = grid.glyph( - orient="vec", - scale=scale_map[scale_mode], - factor=factor, - geom=geom, - ) - actor = _add_mesh( - self.plotter, - mesh=mesh, - color=color, - opacity=opacity, - scalars=mesh_scalars if colormap is not None else None, - colormap=colormap, - show_scalar_bar=False, - backface_culling=backface_culling, - clim=clim, + trp = vtkTransformPolyDataFilter() + trp.SetInputData(glyph.GetOutput()) + trp.SetTransform(tr) + glyph = trp + glyph.Update() + geom = glyph.GetOutput() + mesh = grid.glyph( + orient="vec", + scale=scale_map[scale_mode], + factor=factor, + geom=geom, ) + actor = _add_mesh( + self.plotter, + mesh=mesh, + color=color, + opacity=opacity, + scalars=mesh_scalars if colormap is not None else None, + colormap=colormap, + show_scalar_bar=False, + backface_culling=backface_culling, + clim=clim, + ) return actor, mesh def text2d( @@ -747,42 +724,37 @@ def text2d( ): size = 14 if size is None else size position = (x_window, y_window) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - actor = self.plotter.add_text( - text, position=position, font_size=size, color=color, viewport=True - ) - if isinstance(justification, str): - if justification == "left": - actor.GetTextProperty().SetJustificationToLeft() - elif justification == "center": - actor.GetTextProperty().SetJustificationToCentered() - elif justification == "right": - actor.GetTextProperty().SetJustificationToRight() - else: - raise ValueError( - "Expected values for `justification`" - "are `left`, `center` or `right` but " - "got {} instead.".format(justification) - ) + actor = self.plotter.add_text( + text, position=position, font_size=size, color=color, viewport=True + ) + if isinstance(justification, str): + if justification == "left": + actor.GetTextProperty().SetJustificationToLeft() + elif justification == "center": + actor.GetTextProperty().SetJustificationToCentered() + elif justification == "right": + actor.GetTextProperty().SetJustificationToRight() + else: + raise ValueError( + "Expected values for `justification` are `left`, `center` or " + f"`right` but got {justification} instead." 
+ ) _hide_testing_actor(actor) return actor def text3d(self, x, y, z, text, scale, color="white"): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - kwargs = dict( - points=np.array([x, y, z]).astype(float), - labels=[text], - point_size=scale, - text_color=color, - font_family=self.font_family, - name=text, - shape_opacity=0, - ) - if "always_visible" in signature(self.plotter.add_point_labels).parameters: - kwargs["always_visible"] = True - actor = self.plotter.add_point_labels(**kwargs) + kwargs = dict( + points=np.array([x, y, z]).astype(float), + labels=[text], + point_size=scale, + text_color=color, + font_family=self.font_family, + name=text, + shape_opacity=0, + ) + if "always_visible" in signature(self.plotter.add_point_labels).parameters: + kwargs["always_visible"] = True + actor = self.plotter.add_point_labels(**kwargs) _hide_testing_actor(actor) return actor @@ -801,26 +773,24 @@ def scalarbar( mapper = source.GetMapper() else: mapper = None - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - kwargs = dict( - color=color, - title=title, - n_labels=n_labels, - use_opacity=False, - n_colors=256, - position_x=0.15, - position_y=0.05, - width=0.7, - shadow=False, - bold=True, - label_font_size=22, - font_family=self.font_family, - background_color=bgcolor, - mapper=mapper, - ) - kwargs.update(extra_kwargs) - actor = self.plotter.add_scalar_bar(**kwargs) + kwargs = dict( + color=color, + title=title, + n_labels=n_labels, + use_opacity=False, + n_colors=256, + position_x=0.15, + position_y=0.05, + width=0.7, + shadow=False, + bold=True, + label_font_size=22, + font_family=self.font_family, + background_color=bgcolor, + mapper=mapper, + ) + kwargs.update(extra_kwargs) + actor = self.plotter.add_scalar_bar(**kwargs) _hide_testing_actor(actor) return actor @@ -843,7 +813,6 @@ def set_camera( *, rigid=None, update=True, - reset_camera=None, ): _set_3d_view( self.figure, @@ -852,18 +821,10 @@ def set_camera( distance=distance, focalpoint=focalpoint, roll=roll, - reset_camera=reset_camera, rigid=rigid, update=update, ) - @deprecated( - "reset_camera is deprecated and will be removed in 1.7, use " - "set_camera(distance='auto') instead" - ) - def reset_camera(self): - self.plotter.reset_camera() - def screenshot(self, mode="rgb", filename=None): return _take_3d_screenshot(figure=self.figure, mode=mode, filename=filename) @@ -892,7 +853,10 @@ def _toggle_antialias(self): plotter.disable_anti_aliasing() else: if not bad_system: - plotter.enable_anti_aliasing(aa_type="msaa") + plotter.enable_anti_aliasing( + aa_type="msaa", + multi_samples=self.multi_samples, + ) def remove_mesh(self, mesh_data): actor, _ = mesh_data @@ -931,13 +895,6 @@ def _update_picking_callback( self.plotter.picker.AddObserver(vtkCommand.EndPickEvent, on_pick) self.plotter.picker.SetVolumeOpacityIsovalue(0.0) - def _set_mesh_scalars(self, mesh, scalars, name): - # Catch: FutureWarning: Conversion of the second argument of - # issubdtype from `complex` to `np.complexfloating` is deprecated. 
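[Editor's note] The `_set_mesh_scalars` helper deleted here only wrapped a scalar assignment in a FutureWarning filter; callers now assign to the mesh directly, as the `_3d_overlay.py` hunk earlier in this diff does. A tiny PyVista sketch of the direct form, with made-up geometry:

    import numpy as np
    import pyvista

    mesh = pyvista.Sphere()
    # Direct assignment replaces the removed _set_mesh_scalars wrapper.
    mesh.point_data["Data"] = np.linspace(0.0, 1.0, mesh.n_points)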
- with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - mesh.point_data[name] = scalars - def _set_colormap_range( self, actor, ctable, scalar_bar, rng=None, background_color=None ): @@ -1190,7 +1147,6 @@ def _set_3d_view( focalpoint=None, distance=None, roll=None, - reset_camera=None, rigid=None, update=True, ): @@ -1201,14 +1157,6 @@ def _set_3d_view( # camera slides along the vector defined from camera position to focal point until # all of the actors can be seen (quoting PyVista's docs) - if reset_camera is not None: - reset_camera = False - warn( - "reset_camera is deprecated and will be removed in 1.7, use " - "distance='auto' instead", - FutureWarning, - ) - # Figure out our current parameters in the transformed space _, phi, theta = _get_user_camera_direction(figure.plotter, rigid) @@ -1258,9 +1206,7 @@ def _set_3d_view( def _set_3d_title(figure, title, size=16): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - figure.plotter.add_text(title, font_size=size, color="white", name="title") + figure.plotter.add_text(title, font_size=size, color="white", name="title") figure.plotter.update() _process_events(figure.plotter) @@ -1270,26 +1216,22 @@ def _check_3d_figure(figure): def _close_3d_figure(figure): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - # copy the plotter locally because figure.plotter is modified - plotter = figure.plotter - # close the window - plotter.close() # additional cleaning following signal_close - _process_events(plotter) - # free memory and deregister from the scraper - plotter.deep_clean() # remove internal references - _ALL_PLOTTERS.pop(plotter._id_name, None) - _process_events(plotter) + # copy the plotter locally because figure.plotter is modified + plotter = figure.plotter + # close the window + plotter.close() # additional cleaning following signal_close + _process_events(plotter) + # free memory and deregister from the scraper + plotter.deep_clean() # remove internal references + _ALL_PLOTTERS.pop(plotter._id_name, None) + _process_events(plotter) def _take_3d_screenshot(figure, mode="rgb", filename=None): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - _process_events(figure.plotter) - return figure.plotter.screenshot( - transparent_background=(mode == "rgba"), filename=filename - ) + _process_events(figure.plotter) + return figure.plotter.screenshot( + transparent_background=(mode == "rgba"), filename=filename + ) def _process_events(plotter): diff --git a/mne/viz/backends/_qt.py b/mne/viz/backends/_qt.py index 3f7f28abc1b..6e59c2b6c20 100644 --- a/mne/viz/backends/_qt.py +++ b/mne/viz/backends/_qt.py @@ -112,6 +112,7 @@ _take_3d_screenshot, # noqa: F401 ) from ._utils import ( + _ICONS_PATH, _init_mne_qtapp, _qt_app_exec, _qt_detect_theme, @@ -276,13 +277,13 @@ def __init__(self, value, callback, icon=None): self.setText(value) self.released.connect(callback) if icon: - self.setIcon(QIcon.fromTheme(icon)) + self.setIcon(_qicon(icon)) def _click(self): self.click() def _set_icon(self, icon): - self.setIcon(QIcon.fromTheme(icon)) + self.setIcon(_qicon(icon)) class _Slider(QSlider, _AbstractSlider, _Widget, metaclass=_BaseWidget): @@ -474,16 +475,16 @@ def __init__(self, value, rng, callback): self._slider.valueChanged.connect(callback) self._nav_hbox = QHBoxLayout() self._play_button = QPushButton() - self._play_button.setIcon(QIcon.fromTheme("play")) + 
self._play_button.setIcon(_qicon("play")) self._nav_hbox.addWidget(self._play_button) self._pause_button = QPushButton() - self._pause_button.setIcon(QIcon.fromTheme("pause")) + self._pause_button.setIcon(_qicon("pause")) self._nav_hbox.addWidget(self._pause_button) self._reset_button = QPushButton() - self._reset_button.setIcon(QIcon.fromTheme("reset")) + self._reset_button.setIcon(_qicon("reset")) self._nav_hbox.addWidget(self._reset_button) self._loop_button = QPushButton() - self._loop_button.setIcon(QIcon.fromTheme("restore")) + self._loop_button.setIcon(_qicon("restore")) self._loop_button.setStyleSheet("background-color : lightgray;") self._loop_button._checked = True @@ -930,7 +931,7 @@ def _dialog_create( callback, *, icon="Warning", - buttons=[], + buttons=(), modal=True, window=None, ): @@ -1205,7 +1206,7 @@ def _dock_add_file_button( desc, func, *, - filter=None, + filter_=None, initial_directory=None, save=False, is_directory=False, @@ -1226,11 +1227,11 @@ def callback(): ) elif save: name = QFileDialog.getSaveFileName( - parent=self._window, directory=initial_directory, filter=filter + parent=self._window, directory=initial_directory, filter=filter_ ) else: name = QFileDialog.getOpenFileName( - parent=self._window, directory=initial_directory, filter=filter + parent=self._window, directory=initial_directory, filter=filter_ ) name = name[0] if isinstance(name, tuple) else name # handle the cancel button @@ -1494,18 +1495,18 @@ def closeEvent(event): self._window.closeEvent = closeEvent def _window_load_icons(self): - self._icons["help"] = QIcon.fromTheme("help") - self._icons["play"] = QIcon.fromTheme("play") - self._icons["pause"] = QIcon.fromTheme("pause") - self._icons["reset"] = QIcon.fromTheme("reset") - self._icons["scale"] = QIcon.fromTheme("scale") - self._icons["clear"] = QIcon.fromTheme("clear") - self._icons["movie"] = QIcon.fromTheme("movie") - self._icons["restore"] = QIcon.fromTheme("restore") - self._icons["screenshot"] = QIcon.fromTheme("screenshot") - self._icons["visibility_on"] = QIcon.fromTheme("visibility_on") - self._icons["visibility_off"] = QIcon.fromTheme("visibility_off") - self._icons["folder"] = QIcon.fromTheme("folder") + self._icons["help"] = _qicon("help") + self._icons["play"] = _qicon("play") + self._icons["pause"] = _qicon("pause") + self._icons["reset"] = _qicon("reset") + self._icons["scale"] = _qicon("scale") + self._icons["clear"] = _qicon("clear") + self._icons["movie"] = _qicon("movie") + self._icons["restore"] = _qicon("restore") + self._icons["screenshot"] = _qicon("screenshot") + self._icons["visibility_on"] = _qicon("visibility_on") + self._icons["visibility_off"] = _qicon("visibility_off") + self._icons["folder"] = _qicon("folder") def _window_clean(self): self.figure._plotter = None @@ -1844,3 +1845,10 @@ def _testing_context(interactive): finally: pyvista.OFF_SCREEN = orig_offscreen renderer.MNE_3D_BACKEND_TESTING = orig_testing + + +def _qicon(name): + # Get icon from theme with a file fallback + return QIcon.fromTheme( + name, QIcon(str(_ICONS_PATH / "light" / "actions" / f"{name}.svg")) + ) diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index d613a909f67..25e87fcff22 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -33,6 +33,7 @@ "notebook", ) ALLOWED_QUIVER_MODES = ("2darrow", "arrow", "cone", "cylinder", "sphere", "oct") +_ICONS_PATH = Path(__file__).parents[2] / "icons" def _get_colormap_from_array( @@ -68,13 +69,12 @@ def _check_color(color): raise ValueError("Values out of 
range [0.0, 1.0].") else: raise TypeError( - "Expected data type is `np.int64`, `np.int32`, or " - "`np.float64` but {} was given.".format(np_color.dtype) + "Expected data type is `np.int64`, `np.int32`, or `np.float64` but " + f"{np_color.dtype} was given." ) else: raise TypeError( - "Expected type is `str` or iterable but " - "{} was given.".format(type(color)) + f"Expected type is `str` or iterable but {type(color)} was given." ) return color @@ -90,9 +90,9 @@ def _alpha_blend_background(ctable, background_color): def _qt_init_icons(): from qtpy.QtGui import QIcon - icons_path = str(Path(__file__).parents[2] / "icons") - QIcon.setThemeSearchPaths([icons_path]) - return icons_path + QIcon.setThemeSearchPaths([str(_ICONS_PATH)] + QIcon.themeSearchPaths()) + QIcon.setFallbackThemeName("light") + return str(_ICONS_PATH) @contextmanager @@ -150,7 +150,8 @@ def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False): bundle = NSBundle.mainBundle() info = bundle.localizedInfoDictionary() or bundle.infoDictionary() - info["CFBundleName"] = app_name + if "CFBundleName" not in info: + info["CFBundleName"] = app_name except ModuleNotFoundError: pass @@ -181,7 +182,11 @@ def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False): if enable_icon or splash: icons_path = _qt_init_icons() - if enable_icon and app.windowIcon().cacheKey() != _QT_ICON_KEYS["app"]: + if ( + enable_icon + and app.windowIcon().cacheKey() != _QT_ICON_KEYS["app"] + and app.windowIcon().isNull() # don't overwrite existing icon (e.g. MNELAB) + ): # Set icon kind = "bigsur_" if platform.mac_ver()[0] >= "10.16" else "default_" icon = QIcon(f"{icons_path}/mne_{kind}icon.png") @@ -274,89 +279,45 @@ def _qt_detect_theme(): def _qt_get_stylesheet(theme): _validate_type(theme, ("path-like",), "theme") theme = str(theme) - orig_theme = theme - system_theme = None - stylesheet = "" - extra_msg = "" - if theme == "auto": - theme = system_theme = _qt_detect_theme() - if theme in ("dark", "light"): - if system_theme is None: - system_theme = _qt_detect_theme() - qt_version, api = _check_qt_version(return_api=True) - # On macOS, we shouldn't need to set anything when the requested theme - # matches that of the current OS state - if sys.platform == "darwin": - extra_msg = f"when in {system_theme} mode on macOS" - # But before 5.13, we need to patch some mistakes - if sys.platform == "darwin" and theme == system_theme: - if theme == "dark" and _compare_version(qt_version, "<", "5.13"): - # Taken using "Digital Color Meter" on macOS 12.2.1 looking at - # Meld, and also adapting (MIT-licensed) - # https://github.com/ColinDuquesnoy/QDarkStyleSheet/blob/master/qdarkstyle/dark/style.qss # noqa: E501 - # Something around rgb(51, 51, 51) worked as the bgcolor here, - # but it's easy enough just to set it transparent and inherit - # the bgcolor of the window (which is the same). We also take - # the separator images from QDarkStyle (MIT). 
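
With `_ICONS_PATH` now shared between `_qt_init_icons` and the `_qicon` helper added in the previous file, icon lookup prefers the desktop icon theme and falls back to the bundled SVGs via Qt's two-argument `QIcon.fromTheme(name, fallback)`. The same pattern in isolation, with an illustrative icons directory:

    from pathlib import Path

    from qtpy.QtGui import QIcon

    ICONS = Path("icons")  # stand-in for mne/icons (_ICONS_PATH above)

    def themed_icon(name):
        # Qt consults the current icon theme first; when the theme has no
        # entry for `name`, the explicit file-based QIcon is used instead.
        fallback = QIcon(str(ICONS / "light" / "actions" / f"{name}.svg"))
        return QIcon.fromTheme(name, fallback)
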
-                icons_path = _qt_init_icons()
-                stylesheet = """\
-QStatusBar {
-  border: 1px solid rgb(76, 76, 75);
-  background: transparent;
-}
-QStatusBar QLabel {
-  background: transparent;
-}
-QToolBar {
-  background-color: transparent;
-  border-bottom: 1px solid rgb(99, 99, 99);
-}
-QToolBar::separator:horizontal {
-  width: 16px;
-  image: url("%(icons_path)s/toolbar_separator_horizontal@2x.png");
-}
-QToolBar::separator:vertical {
-  height: 16px;
-  image: url("%(icons_path)s/toolbar_separator_vertical@2x.png");
-}
-QToolBar::handle:horizontal {
-  width: 16px;
-  image: url("%(icons_path)s/toolbar_move_horizontal@2x.png");
-}
-QToolBar::handle:vertical {
-  height: 16px;
-  image: url("%(icons_path)s/toolbar_move_vertical@2x.png");
-}
-""" % dict(
-                    icons_path=icons_path
-                )
+    stylesheet = ""  # no stylesheet
+    if theme in ("auto", "dark", "light"):
+        if theme == "auto":
+            return stylesheet
+        assert theme in ("dark", "light")
+        system_theme = _qt_detect_theme()
+        if theme == system_theme:
+            return stylesheet
+        _, api = _check_qt_version(return_api=True)
+        # On macOS or Qt 6, we shouldn't need to set anything when the requested
+        # theme matches that of the current OS state
+        try:
+            import qdarkstyle
+        except ModuleNotFoundError:
+            logger.info(
+                f'To use {theme} mode when in {system_theme} mode, "qdarkstyle" has '
+                "to be installed! You can install it with:\n"
+                "pip install qdarkstyle\n"
+            )
         else:
-            # Here we are on non-macOS (or on macOS but our sys theme does not
-            # match the requested theme)
-            if api in ("PySide6", "PyQt6"):
-                if orig_theme != "auto" and not (theme == system_theme == "light"):
-                    warn(
-                        f"Setting theme={repr(theme)} is not yet supported "
-                        f"for {api} in qdarkstyle, it will be ignored"
-                    )
+            if api in ("PySide6", "PyQt6") and _compare_version(
+                qdarkstyle.__version__, "<", "3.2.3"
+            ):
+                warn(
+                    f"Setting theme={repr(theme)} is not supported for {api} in "
+                    f"qdarkstyle {qdarkstyle.__version__}, it will be ignored. "
+                    "Consider upgrading qdarkstyle to >=3.2.3."
+                )
             else:
-                try:
-                    import qdarkstyle
-                except ModuleNotFoundError:
-                    logger.info(
-                        f'To use {theme} mode{extra_msg}, "qdarkstyle" has to '
-                        "be installed! You can install it with:\n"
-                        "pip install qdarkstyle\n"
-                    )
-                else:
-                    klass = getattr(
+                stylesheet = qdarkstyle.load_stylesheet(
+                    getattr(
                         getattr(qdarkstyle, theme).palette,
                         f"{theme.capitalize()}Palette",
                     )
-                    stylesheet = qdarkstyle.load_stylesheet(klass)
+                )
+        return stylesheet
     else:
         try:
-            file = open(theme, "r")
+            file = open(theme)
         except OSError:
             warn(
                 "Requested theme file not found, will use light instead: "
@@ -365,8 +326,7 @@ def _qt_get_stylesheet(theme):
         else:
             with file as fid:
                 stylesheet = fid.read()
-
-    return stylesheet
+    return stylesheet


 def _should_raise_window():
diff --git a/mne/viz/backends/renderer.py b/mne/viz/backends/renderer.py
index 1bb396d165c..faa209454e1 100644
--- a/mne/viz/backends/renderer.py
+++ b/mne/viz/backends/renderer.py
@@ -264,8 +264,6 @@ def set_3d_view(
     focalpoint=None,
     distance=None,
     roll=None,
-    *,
-    reset_camera=None,
 ):
     """Configure the view of the given scene.

@@ -278,8 +276,6 @@
     %(focalpoint)s
     %(distance)s
     %(roll)s
-    reset_camera : bool
-        Deprecated, use ``distance="auto"`` instead.
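
In the `_qt_get_stylesheet` rewrite above, the removed `klass = getattr(...)` indirection becomes a single call that resolves `qdarkstyle.dark.palette.DarkPalette` / `qdarkstyle.light.palette.LightPalette` by name. A standalone sketch of the equivalent call, assuming qdarkstyle >= 3.2.3 and using its documented keyword form:

    import qdarkstyle

    theme = "dark"  # or "light"
    palette = getattr(getattr(qdarkstyle, theme).palette, f"{theme.capitalize()}Palette")
    stylesheet = qdarkstyle.load_stylesheet(palette=palette)
    # app.setStyleSheet(stylesheet)  # would apply it to a running QApplication
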
""" backend._set_3d_view( figure=figure, @@ -288,7 +284,6 @@ def set_3d_view( focalpoint=focalpoint, distance=distance, roll=roll, - reset_camera=reset_camera, ) @@ -396,7 +391,7 @@ def _enable_time_interaction( current_time_func, times, init_playback_speed=0.01, - playback_speed_range=[0.01, 0.1], + playback_speed_range=(0.01, 0.1), ): from ..ui_events import ( PlaybackSpeed, diff --git a/mne/viz/backends/tests/test_renderer.py b/mne/viz/backends/tests/test_renderer.py index ec62a1b6748..b20bb6e4865 100644 --- a/mne/viz/backends/tests/test_renderer.py +++ b/mne/viz/backends/tests/test_renderer.py @@ -194,7 +194,7 @@ def test_renderer(renderer, monkeypatch): "-uc", "import mne; mne.viz.create_3d_figure((800, 600), show=True); " "backend = mne.viz.get_3d_backend(); " - "assert backend == %r, backend" % (backend,), + f"assert backend == {repr(backend)}, backend", ] monkeypatch.setenv("MNE_3D_BACKEND", backend) run_subprocess(cmd) diff --git a/mne/viz/backends/tests/test_utils.py b/mne/viz/backends/tests/test_utils.py index 196eb030cea..26636004026 100644 --- a/mne/viz/backends/tests/test_utils.py +++ b/mne/viz/backends/tests/test_utils.py @@ -8,7 +8,6 @@ import platform from colorsys import rgb_to_hls -from contextlib import nullcontext import numpy as np import pytest @@ -66,19 +65,7 @@ def test_theme_colors(pg_backend, theme, monkeypatch, tmp_path): monkeypatch.setattr(darkdetect, "theme", lambda: "light") raw = RawArray(np.zeros((1, 1000)), create_info(1, 1000.0, "eeg")) _, api = _check_qt_version(return_api=True) - if api in ("PyQt6", "PySide6"): - if theme == "dark": # we force darkdetect to say the sys is light - ctx = pytest.warns(RuntimeWarning, match="not yet supported") - else: - ctx = nullcontext() - return_early = True - else: - ctx = nullcontext() - return_early = False - with ctx: - fig = raw.plot(theme=theme) - if return_early: - return # we could add a ton of conditionals below, but KISS + fig = raw.plot(theme=theme) is_dark = _qt_is_dark(fig) # on Darwin these checks get complicated, so don't bother for now if platform.system() == "Darwin": diff --git a/mne/viz/circle.py b/mne/viz/circle.py index 7f9d16ecc54..b19130b3bff 100644 --- a/mne/viz/circle.py +++ b/mne/viz/circle.py @@ -7,7 +7,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. - from functools import partial from itertools import cycle @@ -97,7 +96,13 @@ def circular_layout( def _plot_connectivity_circle_onpick( - event, fig=None, ax=None, indices=None, n_nodes=0, node_angles=None, ylim=[9, 10] + event, + fig=None, + ax=None, + indices=None, + n_nodes=0, + node_angles=None, + ylim=(9, 10), ): """Isolate connections around a single node when user left clicks a node. diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 20dbeed142c..9871a0c2647 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -145,19 +145,7 @@ def plot_epochs_image( ``overlay_times`` should be ordered to correspond with the :class:`~mne.Epochs` object (i.e., ``overlay_times[0]`` corresponds to ``epochs[0]``, etc). - %(combine)s - If callable, the callable must accept one positional input (data of - shape ``(n_epochs, n_channels, n_times)``) and return an - :class:`array ` of shape ``(n_epochs, n_times)``. 
For - example:: - - combine = lambda data: np.median(data, axis=1) - - If ``combine`` is ``None``, channels are combined by computing GFP, - unless ``group_by`` is also ``None`` and ``picks`` is a list of - specific channels (not channel types), in which case no combining is - performed and each channel gets its own figure. See Notes for further - details. Defaults to ``None``. + %(combine_plot_epochs_image)s group_by : None | dict Specifies which channels are aggregated into a single figure, with aggregation method determined by the ``combine`` parameter. If not @@ -286,8 +274,8 @@ def plot_epochs_image( if len(set(this_ch_type)) > 1: types = ", ".join(set(this_ch_type)) raise ValueError( - 'Cannot combine sensors of different types; "{}" ' - "contains types {}.".format(this_group, types) + f'Cannot combine sensors of different types; "{this_group}" contains ' + f"types {types}." ) # now we know they're all the same type... group_by[this_group] = dict( @@ -297,8 +285,8 @@ def plot_epochs_image( # are they trying to combine a single channel? if len(these_picks) < 2 and combine_given: warn( - 'Only one channel in group "{}"; cannot combine by method ' - '"{}".'.format(this_group, combine) + f'Only one channel in group "{this_group}"; cannot combine by method ' + f'"{combine}".' ) # check for compatible `fig` / `axes`; instantiate figs if needed; add @@ -437,13 +425,12 @@ def _validate_fig_and_axes(fig, axes, group_by, evoked, colorbar, clear=False): n_axes = 1 + int(evoked) + int(colorbar) ax_names = ("image", "evoked", "colorbar") ax_names = np.array(ax_names)[np.where([True, evoked, colorbar])] - prefix = "Since evoked={} and colorbar={}, ".format(evoked, colorbar) + prefix = f"Since evoked={evoked} and colorbar={colorbar}, " # got both fig and axes if fig is not None and axes is not None: raise ValueError( - 'At least one of "fig" or "axes" must be None; got ' - "fig={}, axes={}.".format(fig, axes) + f'At least one of "fig" or "axes" must be None; got fig={fig}, axes={axes}.' ) # got fig=None and axes=None: make fig(s) and axes @@ -468,8 +455,7 @@ def _validate_fig_and_axes(fig, axes, group_by, evoked, colorbar, clear=False): # `plot_image`, be forgiving of presence/absence of sensor inset axis. if len(fig.axes) not in (n_axes, n_axes + 1): raise ValueError( - '{}"fig" must contain {} axes, got {}.' - "".format(prefix, n_axes, len(fig.axes)) + f'{prefix}"fig" must contain {n_axes} axes, got {len(fig.axes)}.' ) if len(list(group_by)) != 1: raise ValueError( @@ -498,8 +484,7 @@ def _validate_fig_and_axes(fig, axes, group_by, evoked, colorbar, clear=False): if isinstance(axes, list): if len(axes) != n_axes: raise ValueError( - '{}"axes" must be length {}, got {}.' - "".format(prefix, n_axes, len(axes)) + f'{prefix}"axes" must be length {n_axes}, got {len(axes)}.' ) # for list of axes to work, must be only one group if len(list(group_by)) != 1: @@ -518,14 +503,14 @@ def _validate_fig_and_axes(fig, axes, group_by, evoked, colorbar, clear=False): # group_by dict and the user won't have known what keys we chose. if set(axes) != set(group_by): raise ValueError( - 'If "axes" is a dict its keys ({}) must match ' - 'the keys in "group_by" ({}).'.format(list(axes), list(group_by)) + f'If "axes" is a dict its keys ({list(axes)}) must match the keys in ' + f'"group_by" ({list(group_by)}).' 
             )
         for this_group, this_axes_list in axes.items():
             if len(this_axes_list) != n_axes:
                 raise ValueError(
-                    '{}each value in "axes" must be a list of {} '
-                    "axes, got {}.".format(prefix, n_axes, len(this_axes_list))
+                    f'{prefix}each value in "axes" must be a list of {n_axes} axes, got'
+                    f" {len(this_axes_list)}."
                 )
             # NB: next line assumes all axes in each list are in same figure
             group_by[this_group]["fig"] = this_axes_list[0].get_figure()
@@ -656,10 +641,9 @@ def _plot_epochs_image(

     # draw the colorbar
     if colorbar:
-        from matplotlib.pyplot import colorbar as cbar
-
         if "colorbar" in ax:  # axes supplied by user
-            this_colorbar = cbar(im, cax=ax["colorbar"])
+            cax = ax["colorbar"]
+            this_colorbar = cax.figure.colorbar(im, cax=cax)
             this_colorbar.ax.set_ylabel(unit, rotation=270, labelpad=12)
         else:  # we created them
             this_colorbar = fig.colorbar(im, ax=ax_im)
diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py
index 11a229d80d1..5883dfaf5f5 100644
--- a/mne/viz/evoked.py
+++ b/mne/viz/evoked.py
@@ -193,7 +193,7 @@ def _line_plot_onselect(
             method = "mean" if psd else "rms"
             this_data, _ = _merge_ch_data(this_data, ch_type, [], method=method)
-            title = "%s %s" % (ch_type, method.upper())
+            title = f"{ch_type} {method.upper()}"
         else:
             title = ch_type
             this_data = np.average(this_data, axis=1)
@@ -213,7 +213,7 @@
     )

     unit = "Hz" if psd else time_unit
-    fig.suptitle("Average over %.2f%s - %.2f%s" % (xmin, unit, xmax, unit), y=0.1)
+    fig.suptitle(f"Average over {xmin:.2f}{unit} - {xmax:.2f}{unit}", y=0.1)
     plt_show()
     if text is not None:
         text.set_visible(False)
@@ -390,9 +390,9 @@ def _plot_evoked(
             ax.set_xlabel("")
         ims = [ax.images[0] for ax in axes.values()]
         clims = np.array([im.get_clim() for im in ims])
-        min, max = clims.min(), clims.max()
+        min_, max_ = clims.min(), clims.max()
         for im in ims:
-            im.set_clim(min, max)
+            im.set_clim(min_, max_)
         figs = [ax.get_figure() for ax in axes.values()]
         if len(set(figs)) == 1:
             return figs[0]
@@ -628,7 +628,7 @@ def _plot_lines(
             if this_type in _DATA_CH_TYPES_SPLIT:
                 logger.info(
                     "Need more than one channel to make "
-                    "topography for %s. Disabling interactivity." % (this_type,)
+                    f"topography for {this_type}. Disabling interactivity."
                 )
                 selectables[type_idx] = False
@@ -1171,7 +1171,7 @@ def plot_evoked_topo(
     scalings=None,
     title=None,
     proj=False,
-    vline=[0.0],
+    vline=(0.0,),
     fig_background=None,
     merge_grads=False,
     legend=True,
@@ -1218,7 +1218,7 @@
         If true SSP projections are applied before display. If 'interactive',
         a check box for reversible selection of SSP projection vectors will be
         shown.
-    vline : list of float | None
+    vline : list of float | float | None
         The values at which to show a vertical line.
     fig_background : None | ndarray
         A background image for the figure.
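
The `combine` documentation consolidated into `%(combine_plot_epochs_image)s` above keeps the contract the removed text spelled out: a callable takes data of shape `(n_epochs, n_channels, n_times)` and returns `(n_epochs, n_times)`. A sketch of a custom reducer (the `epochs` object is assumed):

    import numpy as np

    # Median over the channel axis, matching the example in the old docstring.
    combine = lambda data: np.median(data, axis=1)

    # Hypothetical call:
    # epochs.plot_image(picks="eeg", combine=combine)
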
This must work with a call to @@ -1481,7 +1481,7 @@ def plot_evoked_image( def _plot_update_evoked(params, bools): """Update the plot evoked lines.""" - picks, evoked = [params[k] for k in ("picks", "evoked")] + picks, evoked = (params[k] for k in ("picks", "evoked")) projs = [ proj for ii, proj in enumerate(params["projs"]) if ii in np.where(bools)[0] ] @@ -1695,10 +1695,10 @@ def whitened_gfp(x, rank=None): for ch, sub_picks in picks_list: this_rank = rank_[ch] - title = "{0} ({2}{1})".format( + title = "{} ({}{})".format( titles_[ch] if n_columns > 1 else ch, - this_rank, "rank " if n_columns > 1 else "", + this_rank, ) label = noise_cov.get("method", "empirical") @@ -1868,7 +1868,7 @@ def plot_evoked_joint( from matplotlib.patches import ConnectionPatch if ts_args is not None and not isinstance(ts_args, dict): - raise TypeError("ts_args must be dict or None, got type %s" % (type(ts_args),)) + raise TypeError(f"ts_args must be dict or None, got type {type(ts_args)}") ts_args = dict() if ts_args is None else ts_args.copy() ts_args["time_unit"], _ = _check_time_unit( ts_args.get("time_unit", "s"), evoked.times @@ -1878,7 +1878,7 @@ def plot_evoked_joint( got_axes = False illegal_args = {"show", "times", "exclude"} for args in (ts_args, topomap_args): - if any((x in args for x in illegal_args)): + if any(x in args for x in illegal_args): raise ValueError( "Don't pass any of {} as *_args.".format(", ".join(list(illegal_args))) ) @@ -2106,8 +2106,8 @@ def _validate_style_keys_pce(styles, conditions, tags): styles = deepcopy(styles) if not set(styles).issubset(tags.union(conditions)): raise ValueError( - 'The keys in "styles" ({}) must match the keys in ' - '"evokeds" ({}).'.format(list(styles), conditions) + f'The keys in "styles" ({list(styles)}) must match the keys in ' + f'"evokeds" ({conditions}).' ) # make sure all the keys are in there for cond in conditions: @@ -2145,26 +2145,20 @@ def _validate_colors_pce(colors, cmap, conditions, tags): if isinstance(colors, (list, tuple, np.ndarray)): if len(conditions) > len(colors): raise ValueError( - "Trying to plot {} conditions, but there are only" - " {} colors{}. Please specify colors manually.".format( - len(conditions), len(colors), err_suffix - ) + f"Trying to plot {len(conditions)} conditions, but there are only " + f"{len(colors)} colors{err_suffix}. Please specify colors manually." ) colors = dict(zip(conditions, colors)) # should be a dict by now... if not isinstance(colors, dict): raise TypeError( - '"colors" must be a dict, list, or None; got {}.'.format( - type(colors).__name__ - ) + f'"colors" must be a dict, list, or None; got {type(colors).__name__}.' ) # validate color dict keys if not set(colors).issubset(tags.union(conditions)): raise ValueError( - 'If "colors" is a dict its keys ({}) must ' - 'match the keys/conditions in "evokeds" ({}).'.format( - list(colors), conditions - ) + f'If "colors" is a dict its keys ({list(colors)}) must match the ' + f'keys/conditions in "evokeds" ({conditions}).' ) # validate color dict values color_vals = list(colors.values()) @@ -2218,25 +2212,21 @@ def _validate_linestyles_pce(linestyles, conditions, tags): if isinstance(linestyles, (list, tuple, np.ndarray)): if len(conditions) > len(linestyles): raise ValueError( - "Trying to plot {} conditions, but there are " - "only {} linestyles. Please specify linestyles " - "manually.".format(len(conditions), len(linestyles)) + f"Trying to plot {len(conditions)} conditions, but there are only " + f"{len(linestyles)} linestyles. 
Please specify linestyles manually." ) linestyles = dict(zip(conditions, linestyles)) # should be a dict by now... if not isinstance(linestyles, dict): raise TypeError( - '"linestyles" must be a dict, list, or None; got {}.'.format( - type(linestyles).__name__ - ) + '"linestyles" must be a dict, list, or None; got ' + f"{type(linestyles).__name__}." ) # validate linestyle dict keys if not set(linestyles).issubset(tags.union(conditions)): raise ValueError( - 'If "linestyles" is a dict its keys ({}) must ' - 'match the keys/conditions in "evokeds" ({}).'.format( - list(linestyles), conditions - ) + f'If "linestyles" is a dict its keys ({list(linestyles)}) must match the ' + f'keys/conditions in "evokeds" ({conditions}).' ) # normalize linestyle values (so we can accurately count unique linestyles # later). See https://github.com/matplotlib/matplotlib/blob/master/matplotlibrc.template#L131-L133 # noqa @@ -2475,8 +2465,7 @@ def _draw_axes_pce( ybounds = _trim_ticks(ax.get_yticks(), ymin, ymax)[[0, -1]] else: raise ValueError( - '"truncate_yaxis" must be bool or ' - '"auto", got {}'.format(truncate_yaxis) + f'"truncate_yaxis" must be bool or "auto", got {truncate_yaxis}' ) _setup_ax_spines( ax, @@ -2494,14 +2483,22 @@ def _draw_axes_pce( ) -def _get_data_and_ci(evoked, combine, combine_func, picks, scaling=1, ci_fun=None): +def _get_data_and_ci( + evoked, combine, combine_func, ch_type, picks, scaling=1, ci_fun=None +): """Compute (sensor-aggregated, scaled) time series and possibly CI.""" picks = np.array(picks).flatten() # apply scalings data = np.array([evk.data[picks] * scaling for evk in evoked]) # combine across sensors if combine is not None: - logger.info('combining channels using "{}"'.format(combine)) + if combine == "gfp" and ch_type == "eeg": + msg = f"GFP ({ch_type} channels)" + elif combine == "gfp" and ch_type in ("mag", "grad"): + msg = f"RMS ({ch_type} channels)" + else: + msg = f'"{combine}"' + logger.info(f"combining channels using {msg}") data = combine_func(data) # get confidence band if ci_fun is not None: @@ -2529,9 +2526,7 @@ def _get_ci_function_pce(ci, do_topo=False): return partial(_ci, ci=ci, method=method) else: raise TypeError( - '"ci" must be None, bool, float or callable, got {}'.format( - type(ci).__name__ - ) + f'"ci" must be None, bool, float or callable, got {type(ci).__name__}' ) @@ -2563,7 +2558,7 @@ def _plot_compare_evokeds( ax.set_title(title) -def _title_helper_pce(title, picked_types, picks, ch_names, combine): +def _title_helper_pce(title, picked_types, picks, ch_names, ch_type, combine): """Format title for plot_compare_evokeds.""" if title is None: title = ( @@ -2574,9 +2569,13 @@ def _title_helper_pce(title, picked_types, picks, ch_names, combine): # add the `combine` modifier do_combine = picked_types or len(ch_names) > 1 if title is not None and len(title) and isinstance(combine, str) and do_combine: - _comb = combine.upper() if combine == "gfp" else combine - _comb = "std. dev." if _comb == "std" else _comb - title += " ({})".format(_comb) + if combine == "gfp": + _comb = "RMS" if ch_type in ("mag", "grad") else "GFP" + elif combine == "std": + _comb = "std. dev." + else: + _comb = combine + title += f" ({_comb})" return title @@ -2756,18 +2755,7 @@ def plot_compare_evokeds( value of the ``combine`` parameter. Defaults to ``None``. show : bool Whether to show the figure. Defaults to ``True``. 
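
The relabeling in `_get_data_and_ci` and `_title_helper_pce` above reflects that the same channel-combining reduction is conventionally called GFP for EEG but RMS for magnetometers/gradiometers. A numeric sketch of the two reductions as commonly defined (an assumption stated here, not code from this diff):

    import numpy as np

    rng = np.random.default_rng(0)
    data = rng.standard_normal((5, 32, 100))  # (n_evokeds, n_channels, n_times)

    gfp = data.std(axis=1)                 # EEG: standard deviation across channels
    rms = np.sqrt((data**2).mean(axis=1))  # MEG: root-mean-square across channels
    assert gfp.shape == rms.shape == (5, 100)
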
- %(combine)s - If callable, the callable must accept one positional input (data of - shape ``(n_evokeds, n_channels, n_times)``) and return an - :class:`array ` of shape ``(n_epochs, n_times)``. For - example:: - - combine = lambda data: np.median(data, axis=1) - - If ``combine`` is ``None``, channels are combined by computing GFP, - unless ``picks`` is a single channel (not channel type) or - ``axes='topo'``, in which cases no combining is performed. Defaults to - ``None``. + %(combine_plot_compare_evokeds)s %(sphere_topomap_auto)s %(time_unit)s @@ -2862,7 +2850,7 @@ def plot_compare_evokeds( if not isinstance(evokeds, dict): raise TypeError( '"evokeds" must be a dict, list, or instance of ' - "mne.Evoked; got {}".format(type(evokeds).__name__) + f"mne.Evoked; got {type(evokeds).__name__}" ) evokeds = deepcopy(evokeds) # avoid modifying dict outside function scope for cond, evoked in evokeds.items(): @@ -2900,6 +2888,8 @@ def plot_compare_evokeds( "misc", # from ICA "emg", "ref_meg", + "eyegaze", + "pupil", ) ch_types = [ t for t in info.get_channel_types(picks=picks, unique=True) if t in all_types @@ -2915,20 +2905,27 @@ def plot_compare_evokeds( # cannot combine a single channel if (len(picks) < 2) and combine is not None: warn( - 'Only {} channel in "picks"; cannot combine by method "{}".'.format( - len(picks), combine - ) + f'Only {len(picks)} channel in "picks"; cannot combine by method ' + f'"{combine}".' ) # `combine` defaults to GFP unless picked a single channel or axes='topo' do_topo = isinstance(axes, str) and axes == "topo" if combine is None and len(picks) > 1 and not do_topo: combine = "gfp" # convert `combine` into callable (if None or str) - combine_func = _make_combine_callable(combine) + combine_funcs = { + ch_type: _make_combine_callable(combine, ch_type=ch_type) + for ch_type in ch_types + } # title title = _title_helper_pce( - title, picked_types, picks=orig_picks, ch_names=ch_names, combine=combine + title, + picked_types, + picks=orig_picks, + ch_names=ch_names, + ch_type=ch_types[0] if len(ch_types) == 1 else None, + combine=combine, ) topo_disp_title = False # setup axes @@ -2953,9 +2950,7 @@ def plot_compare_evokeds( _validate_if_list_of_axes(axes, obligatory_len=len(ch_types)) if len(ch_types) > 1: - logger.info( - "Multiple channel types selected, returning one figure " "per type." - ) + logger.info("Multiple channel types selected, returning one figure per type.") figs = list() for ch_type, ax in zip(ch_types, axes): _picks = picks_by_type[ch_type] @@ -2964,7 +2959,12 @@ def plot_compare_evokeds( # don't pass `combine` here; title will run through this helper # function a second time & it will get added then _title = _title_helper_pce( - title, picked_types, picks=_picks, ch_names=_ch_names, combine=None + title, + picked_types, + picks=_picks, + ch_names=_ch_names, + ch_type=ch_type, + combine=None, ) figs.extend( plot_compare_evokeds( @@ -3002,12 +3002,18 @@ def plot_compare_evokeds( colorbar_ticks, ) = _handle_styles_pce(styles, linestyles, colors, cmap, conditions) # From now on there is only 1 channel type - assert len(ch_types) == 1 + if not len(ch_types): + got_idx = _picks_to_idx(info, picks=orig_picks) + got = np.unique(np.array(info.get_channel_types())[got_idx]).tolist() + raise RuntimeError( + f"No valid channel type(s) provided. Got {got}. Valid channel types are:" + f"\n{all_types}." 
+ ) ch_type = ch_types[0] # some things that depend on ch_type: units = _handle_default("units")[ch_type] scalings = _handle_default("scalings")[ch_type] - + combine_func = combine_funcs[ch_type] # prep for topo pos_picks = picks # need this version of picks for sensor location inset info = pick_info(info, sel=picks, copy=True) @@ -3140,6 +3146,7 @@ def click_func( this_evokeds, combine, c_func, + ch_type=ch_type, picks=_picks, scaling=scalings, ci_fun=ci_fun, diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index dd691fccf3c..3ce9c6756e2 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -2,6 +2,7 @@ author: Marijn van Vliet """ + # License: BSD-3-Clause # Copyright the MNE-Python contributors. from functools import partial @@ -380,6 +381,10 @@ def _configure_dock(self): if self._show_density: r._dock_add_label(value="max value", align=True, layout=layout) + @_auto_weakref + def _callback(vmax, kind, scaling): + self.set_vmax(vmax / scaling, kind=kind) + for surf_map in self._surf_maps: if surf_map["map_kind"] == "meg": scaling = DEFAULTS["scalings"]["grad"] @@ -388,32 +393,28 @@ def _configure_dock(self): rng = [0, np.max(np.abs(surf_map["data"])) * scaling] hlayout = r._dock_add_layout(vertical=False) - @_auto_weakref - def _callback(vmax, type, scaling): - self.set_vmax(vmax / scaling, type=type) - - self._widgets[ - f"vmax_slider_{surf_map['map_kind']}" - ] = r._dock_add_slider( - name=surf_map["map_kind"].upper(), - value=surf_map["map_vmax"] * scaling, - rng=rng, - callback=partial( - _callback, type=surf_map["map_kind"], scaling=scaling - ), - double=True, - layout=hlayout, + self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = ( + r._dock_add_slider( + name=surf_map["map_kind"].upper(), + value=surf_map["map_vmax"] * scaling, + rng=rng, + callback=partial( + _callback, kind=surf_map["map_kind"], scaling=scaling + ), + double=True, + layout=hlayout, + ) ) - self._widgets[ - f"vmax_spin_{surf_map['map_kind']}" - ] = r._dock_add_spin_box( - name="", - value=surf_map["map_vmax"] * scaling, - rng=rng, - callback=partial( - _callback, type=surf_map["map_kind"], scaling=scaling - ), - layout=hlayout, + self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = ( + r._dock_add_spin_box( + name="", + value=surf_map["map_vmax"] * scaling, + rng=rng, + callback=partial( + _callback, kind=surf_map["map_kind"], scaling=scaling + ), + layout=hlayout, + ) ) r._layout_add_widget(layout, hlayout) @@ -473,15 +474,15 @@ def _on_colormap_range(self, event): if self._show_density: surf_map["mesh"].update_overlay(name="field", rng=[vmin, vmax]) # Update the GUI widgets - if type == "meg": + if kind == "meg": scaling = DEFAULTS["scalings"]["grad"] else: scaling = DEFAULTS["scalings"]["eeg"] with disable_ui_events(self): - widget = self._widgets.get(f"vmax_slider_{type}", None) + widget = self._widgets.get(f"vmax_slider_{kind}", None) if widget is not None: widget.set_value(vmax * scaling) - widget = self._widgets.get(f"vmax_spin_{type}", None) + widget = self._widgets.get(f"vmax_spin_{kind}", None) if widget is not None: widget.set_value(vmax * scaling) @@ -541,28 +542,28 @@ def set_contours(self, n_contours): ), ) - def set_vmax(self, vmax, type="meg"): + def set_vmax(self, vmax, kind="meg"): """Change the color range of the density maps. Parameters ---------- vmax : float The new maximum value of the color range. - type : 'meg' | 'eeg' + kind : 'meg' | 'eeg' Which field map to apply the new color range to. 
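
The `_configure_dock` refactor above hoists `_callback` out of the loop, renames its `type` argument (which shadowed the builtin) to `kind`, and binds the per-map arguments with `functools.partial`. The binding pattern in isolation:

    from functools import partial

    def _callback(vmax, kind, scaling):
        # stand-in for self.set_vmax(vmax / scaling, kind=kind)
        print(f"set_vmax({vmax / scaling}, kind={kind!r})")

    # kind/scaling are fixed per field map; each widget later supplies only vmax.
    meg_cb = partial(_callback, kind="meg", scaling=1e13)
    meg_cb(2.0)  # prints: set_vmax(2e-13, kind='meg')
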
""" - _check_option("type", type, ["eeg", "meg"]) + _check_option("type", kind, ["eeg", "meg"]) for surf_map in self._surf_maps: - if surf_map["map_kind"] == type: + if surf_map["map_kind"] == kind: publish( self, ColormapRange( - kind=f"field_strength_{type}", + kind=f"field_strength_{kind}", fmin=-vmax, fmax=vmax, ), ) - break + break else: raise ValueError(f"No {type.upper()} field map currently shown.") @@ -571,4 +572,4 @@ def _rescale(self): for surf_map in self._surf_maps: current_data = surf_map["data_interp"](self._current_time) vmax = float(np.max(current_data)) - self.set_vmax(vmax, type=surf_map["map_kind"]) + self.set_vmax(vmax, kind=surf_map["map_kind"]) diff --git a/mne/viz/eyetracking/heatmap.py b/mne/viz/eyetracking/heatmap.py index 8cb44ac4931..e6e6832084e 100644 --- a/mne/viz/eyetracking/heatmap.py +++ b/mne/viz/eyetracking/heatmap.py @@ -6,16 +6,18 @@ import numpy as np from scipy.ndimage import gaussian_filter -from ...utils import _ensure_int, _validate_type, fill_doc, logger +from ..._fiff.constants import FIFF +from ...utils import _validate_type, fill_doc, logger from ..utils import plt_show @fill_doc def plot_gaze( epochs, - width, - height, *, + calibration=None, + width=None, + height=None, sigma=25, cmap=None, alpha=1.0, @@ -29,14 +31,17 @@ def plot_gaze( ---------- epochs : instance of Epochs The :class:`~mne.Epochs` object containing eyegaze channels. + calibration : instance of Calibration | None + An instance of Calibration with information about the screen size, distance, + and resolution. If ``None``, you must provide a width and height. width : int - The width dimension of the plot canvas. For example, if the eyegaze data units - are pixels, and the participant screen resolution was 1920x1080, then the width - should be 1920. + The width dimension of the plot canvas, only valid if eyegaze data are in + pixels. For example, if the participant screen resolution was 1920x1080, then + the width should be 1920. height : int - The height dimension of the plot canvas. For example, if the eyegaze data units - are pixels, and the participant screen resolution was 1920x1080, then the height - should be 1080. + The height dimension of the plot canvas, only valid if eyegaze data are in + pixels. For example, if the participant screen resolution was 1920x1080, then + the height should be 1080. sigma : float | None The amount of Gaussian smoothing applied to the heatmap data (standard deviation in pixels). If ``None``, no smoothing is applied. Default is 25. @@ -59,17 +64,22 @@ def plot_gaze( from mne import BaseEpochs from mne._fiff.pick import _picks_to_idx + from ...preprocessing.eyetracking.utils import ( + _check_calibration, + get_screen_visual_angle, + ) + _validate_type(epochs, BaseEpochs, "epochs") _validate_type(alpha, "numeric", "alpha") _validate_type(sigma, ("numeric", None), "sigma") - width = _ensure_int(width, "width") - height = _ensure_int(height, "height") + # Get the gaze data pos_picks = _picks_to_idx(epochs.info, "eyegaze") gaze_data = epochs.get_data(picks=pos_picks) gaze_ch_loc = np.array([epochs.info["chs"][idx]["loc"] for idx in pos_picks]) x_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == -1)[0], :] y_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == 1)[0], :] + unit = epochs.info["chs"][pos_picks[0]]["unit"] # assumes all units are the same if x_data.shape[1] > 1: # binocular recording. Average across eyes logger.info("Detected binocular recording. 
Averaging positions across eyes.") @@ -77,13 +87,53 @@ def plot_gaze( y_data = np.nanmean(y_data, axis=1) canvas = np.vstack((x_data.flatten(), y_data.flatten())) # shape (2, n_samples) + # Check that we have the right inputs + if calibration is not None: + if width is not None or height is not None: + raise ValueError( + "If a calibration is provided, you cannot provide a width or height" + " to plot heatmaps. Please provide only the calibration object." + ) + _check_calibration(calibration) + if unit == FIFF.FIFF_UNIT_PX: + width, height = calibration["screen_resolution"] + elif unit == FIFF.FIFF_UNIT_RAD: + width, height = calibration["screen_size"] + else: + raise ValueError( + f"Invalid unit type: {unit}. gaze data Must be pixels or radians." + ) + else: + if width is None or height is None: + raise ValueError( + "If no calibration is provided, you must provide a width and height" + " to plot heatmaps." + ) + # Create 2D histogram - # Bin into image-like format + # We need to set the histogram bins & bounds, and imshow extent, based on the units + if unit == FIFF.FIFF_UNIT_PX: # pixel on screen + _range = [[0, height], [0, width]] + bins_x, bins_y = width, height + extent = [0, width, height, 0] + elif unit == FIFF.FIFF_UNIT_RAD: # radians of visual angle + if not calibration: + raise ValueError( + "If gaze data are in Radians, you must provide a" + " calibration instance to plot heatmaps." + ) + width, height = get_screen_visual_angle(calibration) + x_range = [-width / 2, width / 2] + y_range = [-height / 2, height / 2] + _range = [y_range, x_range] + extent = (x_range[0], x_range[1], y_range[0], y_range[1]) + bins_x, bins_y = calibration["screen_resolution"] + hist, _, _ = np.histogram2d( canvas[1, :], canvas[0, :], - bins=(height, width), - range=[[0, height], [0, width]], + bins=(bins_y, bins_x), + range=_range, ) # Convert density from samples to seconds hist /= epochs.info["sfreq"] @@ -99,6 +149,7 @@ def plot_gaze( alpha=alpha, vmin=vlim[0], vmax=vlim[1], + extent=extent, axes=axes, show=show, ) @@ -108,10 +159,12 @@ def _plot_heatmap_array( data, width, height, + *, cmap=None, alpha=None, vmin=None, vmax=None, + extent=None, axes=None, show=True, ): @@ -136,7 +189,8 @@ def _plot_heatmap_array( alphas = 1 if alpha is None else alpha vmin = np.nanmin(data) if vmin is None else vmin vmax = np.nanmax(data) if vmax is None else vmax - extent = [0, width, height, 0] # origin is the top left of the screen + if extent is None: + extent = [0, width, height, 0] # Plot heatmap im = ax.imshow( diff --git a/mne/viz/eyetracking/tests/test_heatmap.py b/mne/viz/eyetracking/tests/test_heatmap.py index a088c1dc7fe..0f0b0bfc4d5 100644 --- a/mne/viz/eyetracking/tests/test_heatmap.py +++ b/mne/viz/eyetracking/tests/test_heatmap.py @@ -4,33 +4,57 @@ # Copyright the MNE-Python contributors. 
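
Usage implied by the new `plot_gaze` signature above, with placeholder `epochs` and `cal` objects (pixel-unit data may skip the calibration; radian data requires one):

    import mne

    # Pixel-unit gaze data, no Calibration: pass the screen resolution.
    # mne.viz.eyetracking.plot_gaze(epochs, width=1920, height=1080)

    # With a Calibration (pixels or radians): pass only the calibration.
    # mne.viz.eyetracking.plot_gaze(epochs, calibration=cal)
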
import matplotlib.pyplot as plt -import numpy as np import pytest import mne +from mne._fiff.constants import FIFF -@pytest.mark.parametrize("axes", [None, True]) -def test_plot_heatmap(axes): +@pytest.mark.parametrize("axes, unit", [(None, "px"), (True, "rad")]) +def test_plot_heatmap(eyetrack_raw, eyetrack_cal, axes, unit): """Test plot_gaze.""" - # Create a toy epochs instance - info = info = mne.create_info( - ch_names=["xpos", "ypos"], sfreq=100, ch_types="eyegaze" - ) - # simulate a steady fixation at the center of the screen - width, height = (1920, 1080) - shape = (1, 100) # x or y, time - data = np.vstack([np.full(shape, width / 2), np.full(shape, height / 2)]) - epochs = mne.EpochsArray(data[None, ...], info) - epochs.info["chs"][0]["loc"][4] = -1 - epochs.info["chs"][1]["loc"][4] = 1 + epochs = mne.make_fixed_length_epochs(eyetrack_raw, duration=1.0) + epochs.load_data() + width, height = eyetrack_cal["screen_resolution"] # 1920, 1080 + if unit == "rad": + mne.preprocessing.eyetracking.convert_units(epochs, eyetrack_cal, to="radians") if axes: axes = plt.subplot() - fig = mne.viz.eyetracking.plot_gaze( - epochs, width=width, height=height, axes=axes, cmap="Greys", sigma=None - ) + + # First check that we raise errors when we should + with pytest.raises(ValueError, match="If no calibration is provided"): + mne.viz.eyetracking.plot_gaze(epochs) + + with pytest.raises(ValueError, match="If a calibration is provided"): + mne.viz.eyetracking.plot_gaze( + epochs, width=width, height=height, calibration=eyetrack_cal + ) + + with pytest.raises(ValueError, match="Invalid unit"): + ep_bad = epochs.copy() + ep_bad.info["chs"][0]["unit"] = FIFF.FIFF_UNIT_NONE + mne.viz.eyetracking.plot_gaze(ep_bad, calibration=eyetrack_cal) + + # raise an error if no calibration object is provided for radian data + if unit == "rad": + with pytest.raises(ValueError, match="If gaze data are in Radians"): + mne.viz.eyetracking.plot_gaze(epochs, axes=axes, width=1, height=1) + + # Now check that we get the expected output + if unit == "px": + fig = mne.viz.eyetracking.plot_gaze( + epochs, width=width, height=height, axes=axes, cmap="Greys", sigma=None + ) + elif unit == "rad": + fig = mne.viz.eyetracking.plot_gaze( + epochs, + calibration=eyetrack_cal, + axes=axes, + cmap="Greys", + sigma=None, + ) img = fig.axes[0].images[0].get_array() # We simulated a 2D histogram where only the central pixel (960, 540) was active - assert img.T[width // 2, height // 2] == 1 # central pixel is active - assert np.sum(img) == 1 # only the central pixel should be active + # so regardless of the unit, we should have a heatmap with the central bin active + assert img.T[width // 2, height // 2] == 1 diff --git a/mne/viz/ica.py b/mne/viz/ica.py index dcd585c37fe..1ec18fde1da 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -855,8 +855,18 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, ica, labels=No lines[-1].set_pickradius(3.0) ax.set(title=title, xlim=times[[0, -1]], xlabel="Time (ms)", ylabel="(NA)") - if len(exclude) > 0: - plt.legend(loc="best") + leg_lines_labels = list( + zip( + *[ + (line, label) + for line, label in zip(lines, exclude_labels) + if label is not None + ] + ) + ) + if len(leg_lines_labels): + leg_lines, leg_labels = leg_lines_labels + ax.legend(leg_lines, leg_labels, loc="best") texts.append( ax.text( diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 3d8a9469620..49b01ed6b16 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -130,7 +130,7 @@ def plot_cov( fig_cov : instance of 
matplotlib.figure.Figure The covariance plot. fig_svd : instance of matplotlib.figure.Figure | None - The SVD spectra plot of the covariance. + The SVD plot of the covariance (i.e., the eigenvalues or "matrix spectrum"). See Also -------- @@ -869,7 +869,7 @@ def plot_events( continue y = np.full(count, idx + 1 if equal_spacing else events[ev_mask, 2][0]) if event_id is not None: - event_label = "%s (%s)" % (event_id_rev[ev], count) + event_label = f"{event_id_rev[ev]} ({count})" else: event_label = "N=%d" % (count,) labels.append(event_label) @@ -1025,7 +1025,7 @@ def _get_flim(flim, fscale, freq, sfreq=None): def _check_fscale(fscale): """Check for valid fscale.""" if not isinstance(fscale, str) or fscale not in ("log", "linear"): - raise ValueError('fscale must be "log" or "linear", got %s' % (fscale,)) + raise ValueError(f'fscale must be "log" or "linear", got {fscale}') _DEFAULT_ALIM = (-80, 10) @@ -1340,7 +1340,7 @@ def plot_ideal_filter( if freq[0] != 0: raise ValueError( "freq should start with DC (zero) and end with " - "Nyquist, but got %s for DC" % (freq[0],) + f"Nyquist, but got {freq[0]} for DC" ) freq = np.array(freq) # deal with semilogx problems @ x=0 @@ -1411,8 +1411,8 @@ def _handle_event_colors(color_dict, unique_events, event_id): if len(unassigned): unassigned_str = ", ".join(str(e) for e in unassigned) warn( - "Color was not assigned for event%s %s. Default colors will " - "be used." % (_pl(unassigned), unassigned_str) + f"Color was not assigned for event{_pl(unassigned)} {unassigned_str}. " + "Default colors will be used." ) default_colors.update(custom_colors) return default_colors @@ -1535,7 +1535,7 @@ def plot_csd( ax.set_xticks([]) ax.set_yticks([]) if csd._is_sum: - ax.set_title("%.1f-%.1f Hz." % (np.min(freq), np.max(freq))) + ax.set_title(f"{np.min(freq):.1f}-{np.max(freq):.1f} Hz.") else: ax.set_title("%.1f Hz." % freq) diff --git a/mne/viz/montage.py b/mne/viz/montage.py index 18ff3e1c2d7..935a306e0d9 100644 --- a/mne/viz/montage.py +++ b/mne/viz/montage.py @@ -1,6 +1,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. """Functions to plot EEG sensor montages or digitizer montages.""" + from copy import deepcopy import numpy as np @@ -72,9 +73,9 @@ def plot_montage( n_chans = pos.shape[0] n_dupes = dupes.shape[0] idx = np.setdiff1d(np.arange(len(pos)), dupes[:, 1]).tolist() - logger.info("{} duplicate electrode labels found:".format(n_dupes)) + logger.info(f"{n_dupes} duplicate electrode labels found:") logger.info(", ".join([ch_names[d[0]] + "/" + ch_names[d[1]] for d in dupes])) - logger.info("Plotting {} unique labels.".format(n_chans - n_dupes)) + logger.info(f"Plotting {n_chans - n_dupes} unique labels.") ch_names = [ch_names[i] for i in idx] ch_pos = dict(zip(ch_names, pos[idx, :])) # XXX: this might cause trouble if montage was originally in head diff --git a/mne/viz/raw.py b/mne/viz/raw.py index 65bfb08604e..dd90352d0cc 100644 --- a/mne/viz/raw.py +++ b/mne/viz/raw.py @@ -11,7 +11,7 @@ import numpy as np -from .._fiff.pick import pick_channels, pick_types +from .._fiff.pick import _picks_to_idx, pick_channels, pick_types from ..defaults import _handle_default from ..filter import create_filter from ..utils import _check_option, _get_stim_channel, _validate_type, legacy, verbose @@ -63,6 +63,7 @@ def plot_raw( time_format="float", precompute=None, use_opengl=None, + picks=None, *, theme=None, overview_mode=None, @@ -192,6 +193,7 @@ def plot_raw( %(time_format)s %(precompute)s %(use_opengl)s + %(picks_all)s %(theme_pg)s .. 
versionadded:: 1.0 @@ -310,7 +312,9 @@ def plot_raw( # determine trace order ch_names = np.array(raw.ch_names) ch_types = np.array(raw.get_channel_types()) - order = _get_channel_plotting_order(order, ch_types) + + picks = _picks_to_idx(info, picks, none="all", exclude=()) + order = _get_channel_plotting_order(order, ch_types, picks=picks) n_channels = min(info["nchan"], n_channels, len(order)) # adjust order based on channel selection, if needed selections = None diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 54ebf0fcd83..5109becb645 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -36,6 +36,7 @@ from mne._fiff.constants import FIFF from mne.bem import read_bem_solution, read_bem_surfaces from mne.datasets import testing +from mne.defaults import DEFAULTS from mne.io import read_info, read_raw_bti, read_raw_ctf, read_raw_kit, read_raw_nirx from mne.minimum_norm import apply_inverse from mne.source_estimate import _BaseVolSourceEstimate @@ -66,7 +67,7 @@ ctf_fname = data_dir / "CTF" / "testdata_ctf.ds" nirx_fname = data_dir / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" -io_dir = Path(__file__).parent.parent.parent / "io" +io_dir = Path(__file__).parents[2] / "io" base_dir = io_dir / "tests" / "data" evoked_fname = base_dir / "test-ave.fif" @@ -125,11 +126,11 @@ def test_plot_sparse_source_estimates(renderer_interactive, brain_gc): vertices = [s["vertno"] for s in sample_src] n_time = 5 n_verts = sum(len(v) for v in vertices) - stc_data = np.zeros((n_verts * n_time)) + stc_data = np.zeros(n_verts * n_time) stc_size = stc_data.size - stc_data[ - (np.random.rand(stc_size // 20) * stc_size).astype(int) - ] = np.random.RandomState(0).rand(stc_data.size // 20) + stc_data[(np.random.rand(stc_size // 20) * stc_size).astype(int)] = ( + np.random.RandomState(0).rand(stc_data.size // 20) + ) stc_data.shape = (n_verts, n_time) stc = SourceEstimate(stc_data, vertices, 1, 1) @@ -196,8 +197,16 @@ def test_plot_evoked_field(renderer): assert isinstance(fig, EvokedField) fig._rescale() fig.set_time(0.05) + assert fig._current_time == 0.05 fig.set_contours(10) - fig.set_vmax(2) + assert fig._n_contours == 10 + assert fig._widgets["contours"].get_value() == 10 + fig.set_vmax(2e-12, kind="meg") + assert fig._surf_maps[1]["contours"][-1] == 2e-12 + assert ( + fig._widgets["vmax_slider_meg"].get_value() + == DEFAULTS["scalings"]["grad"] * 2e-12 + ) fig = evoked.plot_field(maps, time_viewer=False) assert isinstance(fig, Figure3D) @@ -748,7 +757,7 @@ def test_process_clim_plot(renderer_interactive, brain_gc): vertices = [s["vertno"] for s in sample_src] n_time = 5 n_verts = sum(len(v) for v in vertices) - stc_data = np.random.RandomState(0).rand((n_verts * n_time)) + stc_data = np.random.RandomState(0).rand(n_verts * n_time) stc_data.shape = (n_verts, n_time) stc = SourceEstimate(stc_data, vertices, 1, 1, "sample") @@ -870,7 +879,7 @@ def test_stc_mpl(): vertices = [s["vertno"] for s in sample_src] n_time = 5 n_verts = sum(len(v) for v in vertices) - stc_data = np.ones((n_verts * n_time)) + stc_data = np.ones(n_verts * n_time) stc_data.shape = (n_verts, n_time) stc = SourceEstimate(stc_data, vertices, 1, 1, "sample") stc.plot( @@ -1198,11 +1207,11 @@ def test_link_brains(renderer_interactive): vertices = [s["vertno"] for s in sample_src] n_time = 5 n_verts = sum(len(v) for v in vertices) - stc_data = np.zeros((n_verts * n_time)) + stc_data = np.zeros(n_verts * n_time) stc_size = stc_data.size - stc_data[ - (np.random.rand(stc_size // 20) * 
stc_size).astype(int) - ] = np.random.RandomState(0).rand(stc_data.size // 20) + stc_data[(np.random.rand(stc_size // 20) * stc_size).astype(int)] = ( + np.random.RandomState(0).rand(stc_data.size // 20) + ) stc_data.shape = (n_verts, n_time) stc = SourceEstimate(stc_data, vertices, 1, 1) diff --git a/mne/viz/tests/test_3d_mpl.py b/mne/viz/tests/test_3d_mpl.py index 2b46a688a13..b006a421494 100644 --- a/mne/viz/tests/test_3d_mpl.py +++ b/mne/viz/tests/test_3d_mpl.py @@ -89,7 +89,7 @@ def test_plot_volume_source_estimates( log = log.getvalue() want_str = "t = %0.3f s" % want_t assert want_str in log, (want_str, init_t) - want_str = "(%0.1f, %0.1f, %0.1f) mm" % want_p + want_str = f"({want_p[0]:0.1f}, {want_p[1]:0.1f}, {want_p[2]:0.1f}) mm" assert want_str in log, (want_str, init_p) for ax_idx in [0, 2, 3, 4]: _fake_click(fig, fig.axes[ax_idx], (0.3, 0.5)) diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index cf1c07b7d85..9679a787277 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -17,6 +17,7 @@ from mne import Epochs, EpochsArray, create_info from mne.datasets import testing from mne.event import make_fixed_length_events +from mne.utils import _record_warnings from mne.viz import plot_drop_log @@ -52,13 +53,13 @@ def test_plot_epochs_basic(epochs, epochs_full, noise_cov_io, capsys, browser_ba browser_backend._close_all() # add a channel to cov['bads'] noise_cov_io["bads"] = [epochs.ch_names[1]] - with pytest.warns(RuntimeWarning, match="projection"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="projection"): epochs.plot(noise_cov=noise_cov_io) browser_backend._close_all() # have a data channel missing from the covariance noise_cov_io["names"] = noise_cov_io["names"][:306] noise_cov_io["data"] = noise_cov_io["data"][:306][:306] - with pytest.warns(RuntimeWarning, match="projection"): + with _record_warnings(), pytest.warns(RuntimeWarning, match="projection"): epochs.plot(noise_cov=noise_cov_io) browser_backend._close_all() # other options @@ -300,7 +301,10 @@ def test_plot_epochs_image(epochs): picks=[0, 1], order=lambda times, data: np.arange(len(data))[::-1] ) # test warning - with pytest.warns(RuntimeWarning, match="Only one channel in group"): + with ( + _record_warnings(), + pytest.warns(RuntimeWarning, match="Only one channel in group"), + ): epochs.plot_image(picks=[1], combine="mean") # group_by should be a dict with pytest.raises(TypeError, match="dict or None"): @@ -396,9 +400,9 @@ def test_plot_psd_epochs(epochs): """Test plotting epochs psd (+topomap).""" spectrum = epochs.compute_psd() old_defaults = dict(picks="data", exclude="bads") - spectrum.plot(average=True, spatial_colors=False, **old_defaults) - spectrum.plot(average=False, spatial_colors=True, **old_defaults) - spectrum.plot(average=False, spatial_colors=False, **old_defaults) + spectrum.plot(average=True, amplitude=False, spatial_colors=False, **old_defaults) + spectrum.plot(average=False, amplitude=False, spatial_colors=True, **old_defaults) + spectrum.plot(average=False, amplitude=False, spatial_colors=False, **old_defaults) # test plot_psd_topomap errors with pytest.raises(RuntimeError, match="No frequencies in band"): spectrum.plot_topomap(bands=dict(foo=(0, 0.01))) @@ -418,7 +422,7 @@ def test_plot_psd_epochs(epochs): err_str = "for channel %s" % epochs.ch_names[2] epochs.get_data(copy=False)[0, 2, :] = 0 for dB in [True, False]: - with pytest.warns(UserWarning, match=err_str): + with _record_warnings(), pytest.warns(UserWarning, 
match=err_str):
            epochs.compute_psd().plot(dB=dB)
@@ -492,12 +496,12 @@ def test_plot_psd_epochs_ctf(raw_ctf):
     epochs = Epochs(raw_ctf, evts, preload=True)
     old_defaults = dict(picks="data", exclude="bads")
     # EEG060 is flat in this dataset
-    with pytest.warns(UserWarning, match="for channel EEG060"):
+    with _record_warnings(), pytest.warns(UserWarning, match="for channel EEG060"):
         spectrum = epochs.compute_psd()
     for dB in [True, False]:
         spectrum.plot(dB=dB)
     spectrum.drop_channels(["EEG060"])
-    spectrum.plot(spatial_colors=False, average=False, **old_defaults)
+    spectrum.plot(spatial_colors=False, average=False, amplitude=False, **old_defaults)
     with pytest.raises(RuntimeError, match="No frequencies in band"):
         spectrum.plot_topomap(bands=[(0, 0.01, "foo")])
     spectrum.plot_topomap()
diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py
index b44a33385b2..e177df6a9b8 100644
--- a/mne/viz/tests/test_evoked.py
+++ b/mne/viz/tests/test_evoked.py
@@ -34,11 +34,11 @@
 from mne.datasets import testing
 from mne.io import read_raw_fif
 from mne.stats.parametric import _parametric_ci
-from mne.utils import catch_logging
+from mne.utils import _record_warnings, catch_logging
 from mne.viz import plot_compare_evokeds, plot_evoked_white
 from mne.viz.utils import _fake_click, _get_cmap
-base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
+base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
 evoked_fname = base_dir / "test-ave.fif"
 raw_fname = base_dir / "test_raw.fif"
 raw_sss_fname = base_dir / "test_chpi_raw_sss.fif"
@@ -119,7 +119,7 @@ def test_plot_evoked_cov():
     epochs = Epochs(raw, events, picks=default_picks)
     cov = compute_covariance(epochs)
     evoked_sss = epochs.average()
-    with pytest.warns(RuntimeWarning, match="relative scaling"):
+    with _record_warnings(), pytest.warns(RuntimeWarning, match="relative scaling"):
         evoked_sss.plot(noise_cov=cov, time_unit="s")
     plt.close("all")
@@ -333,7 +333,7 @@ def test_plot_evoked_image():
         mask=np.ones(evoked.data.shape).astype(bool),
         time_unit="s",
     )
-    with pytest.warns(RuntimeWarning, match="not adding contour"):
+    with _record_warnings(), pytest.warns(RuntimeWarning, match="not adding contour"):
         evoked.plot_image(picks=[1, 2], mask=None, mask_style="both", time_unit="s")
     with pytest.raises(ValueError, match="must have the same shape"):
         evoked.plot_image(mask=evoked.data[1:, 1:] > 0, time_unit="s")
@@ -402,28 +402,41 @@ def test_plot_white():
     evoked_sss.plot_white(cov, time_unit="s")
+@pytest.mark.parametrize(
+    "combine,vlines,title,picks",
+    (
+        pytest.param(None, [0.1, 0.2], "MEG 0113", "MEG 0113", id="singlepick"),
+        pytest.param("mean", [], "(mean)", "mag", id="mag-mean"),
+        pytest.param("gfp", "auto", "(GFP)", "eeg", id="eeg-gfp"),
+        pytest.param(None, "auto", "(RMS)", ["MEG 0113", "MEG 0112"], id="meg-rms"),
+        pytest.param(
+            "std", "auto", "(std. dev.)", ["MEG 0113", "MEG 0112"], id="meg-std"
+        ),
+        pytest.param(
+            lambda x: np.min(x, axis=1), "auto", "MEG 0112", [0, 1], id="intpicks"
+        ),
+    ),
+)
+def test_plot_compare_evokeds_title(evoked, picks, vlines, combine, title):
+    """Test title generation by plot_compare_evokeds()."""
+    # test picks, combine, and vlines (1-channel pick also shows sensor inset)
+    fig = plot_compare_evokeds(evoked, picks=picks, vlines=vlines, combine=combine)
+    assert fig[0].axes[0].get_title().endswith(title)
+
+
 @pytest.mark.slowtest  # slow on Azure
-def test_plot_compare_evokeds():
+def test_plot_compare_evokeds(evoked):
     """Test plot_compare_evokeds."""
-    evoked = _get_epochs().average()
     # test defaults
     figs = plot_compare_evokeds(evoked)
     assert len(figs) == 3
-    # test picks, combine, and vlines (1-channel pick also shows sensor inset)
-    picks = ["MEG 0113", "mag"] + 2 * [["MEG 0113", "MEG 0112"]] + [[0, 1]]
-    vlines = [[0.1, 0.2], []] + 3 * ["auto"]
-    combine = [None, "mean", "std", None, lambda x: np.min(x, axis=1)]
-    title = ["MEG 0113", "(mean)", "(std. dev.)", "(GFP)", "MEG 0112"]
-    for _p, _v, _c, _t in zip(picks, vlines, combine, title):
-        fig = plot_compare_evokeds(evoked, picks=_p, vlines=_v, combine=_c)
-        assert fig[0].axes[0].get_title().endswith(_t)
     # test passing more than one evoked
     red, blue = evoked.copy(), evoked.copy()
     red.comment = red.comment + "*" * 100
     red.data *= 1.5
     blue.data /= 1.5
     evoked_dict = {"aud/l": blue, "aud/r": red, "vis": evoked}
-    huge_dict = {"cond{}".format(i): ev for i, ev in enumerate([evoked] * 11)}
+    huge_dict = {f"cond{i}": ev for i, ev in enumerate([evoked] * 11)}
     plot_compare_evokeds(evoked_dict)  # dict
     plot_compare_evokeds([[red, evoked], [blue, evoked]])  # list of lists
     figs = plot_compare_evokeds({"cond": [blue, red, evoked]})  # dict of list
@@ -438,6 +451,17 @@ def test_plot_compare_evokeds():
         yvals = line.get_ydata()
         assert (yvals < ylim[1]).all()
         assert (yvals > ylim[0]).all()
+    # test plotting eyetracking data
+    plt.close("all")  # close the previous figures as to avoid a too many figs warning
+    info_tmp = mne.create_info(["pupil_left"], evoked.info["sfreq"], ["pupil"])
+    evoked_et = mne.EvokedArray(np.ones_like(evoked.times).reshape(1, -1), info_tmp)
+    figs = plot_compare_evokeds(evoked_et, show_sensors=False)
+    assert len(figs) == 1
+    # test plotting only invalid channel types
+    info_tmp = mne.create_info(["ias"], evoked.info["sfreq"], ["ias"])
+    ev_invalid = mne.EvokedArray(np.ones_like(evoked.times).reshape(1, -1), info_tmp)
+    with pytest.raises(RuntimeError, match="No valid"):
+        plot_compare_evokeds(ev_invalid, picks="all")
     plt.close("all")
     # test other CI args
diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py
index 421d844e127..39d4b616431 100644
--- a/mne/viz/tests/test_ica.py
+++ b/mne/viz/tests/test_ica.py
@@ -26,7 +26,7 @@
 from mne.viz.ica import _create_properties_layout, plot_ica_properties
 from mne.viz.utils import _fake_click, _fake_keypress
-base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
+base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
 evoked_fname = base_dir / "test-ave.fif"
 raw_fname = base_dir / "test_raw.fif"
 cov_fname = base_dir / "test-cov.fif"
@@ -157,7 +157,7 @@ def test_plot_ica_properties():
     )
     ica = ICA(noise_cov=read_cov(cov_fname), n_components=2, max_iter=1, random_state=0)
-    with pytest.warns(RuntimeWarning, match="projection"):
+    with _record_warnings(), pytest.warns(RuntimeWarning, match="projection"):
         ica.fit(raw)
     # test _create_properties_layout
@@ -240,7 +240,7 @@ def test_plot_ica_properties():
     # Test handling of zeros
     ica = ICA(random_state=0, max_iter=1)
     epochs.pick(pick_names)
-    with pytest.warns(UserWarning, match="did not converge"):
+    with _record_warnings(), pytest.warns(UserWarning, match="did not converge"):
         ica.fit(epochs)
     epochs._data[0] = 0  # Usually UserWarning: Infinite value .* for epo
@@ -254,7 +254,7 @@ def test_plot_ica_properties():
     raw_annot.pick(np.arange(10))
     raw_annot.del_proj()
-    with pytest.warns(UserWarning, match="did not converge"):
+    with _record_warnings(), pytest.warns(UserWarning, match="did not converge"):
         ica.fit(raw_annot)
     # drop bad data segments
     fig = ica.plot_properties(raw_annot, picks=[0, 1], **topoargs)
@@ -362,12 +362,15 @@ def test_plot_ica_sources(raw_orig, browser_backend, monkeypatch):
     ica.plot_sources(epochs)
     ica.plot_sources(epochs.average())
     evoked = epochs.average()
+    ica.exclude = [0]
     fig = ica.plot_sources(evoked)
     # Test a click
     ax = fig.get_axes()[0]
     line = ax.lines[0]
     _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], "data")
     _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], "data")
+    leg = ax.get_legend()
+    assert len(leg.get_texts()) == len(ica.exclude) == 1
     # plot with bad channels excluded
     ica.exclude = [0]
diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py
index 49d3e7219bb..aa0fa0f1959 100644
--- a/mne/viz/tests/test_misc.py
+++ b/mne/viz/tests/test_misc.py
@@ -30,6 +30,7 @@
 from mne.io import read_raw_fif
 from mne.minimum_norm import read_inverse_operator
 from mne.time_frequency import CrossSpectralDensity
+from mne.utils import _record_warnings
 from mne.viz import (
     plot_bem,
     plot_chpi_snr,
@@ -51,7 +52,7 @@
 evoked_fname = data_path / "MEG" / "sample" / "sample_audvis-ave.fif"
 dip_fname = data_path / "MEG" / "sample" / "sample_audvis_trunc_set1.dip"
 chpi_fif_fname = data_path / "SSS" / "test_move_anon_raw.fif"
-base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
+base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
 raw_fname = base_dir / "test_raw.fif"
 cov_fname = base_dir / "test-cov.fif"
 event_fname = base_dir / "test-eve.fif"
@@ -214,7 +215,10 @@ def test_plot_events():
     assert fig.axes[0].get_legend() is not None
     with pytest.warns(RuntimeWarning, match="Color was not assigned"):
         plot_events(events, raw.info["sfreq"], raw.first_samp, color=color)
-    with pytest.warns(RuntimeWarning, match=r"vent \d+ missing from event_id"):
+    with (
+        _record_warnings(),
+        pytest.warns(RuntimeWarning, match=r"vent \d+ missing from event_id"),
+    ):
         plot_events(
             events,
             raw.info["sfreq"],
@@ -223,7 +227,7 @@ def test_plot_events():
             color=color,
         )
     multimatch = r"event \d+ missing from event_id|in the color dict but is"
-    with pytest.warns(RuntimeWarning, match=multimatch):
+    with _record_warnings(), pytest.warns(RuntimeWarning, match=multimatch):
         plot_events(
             events,
             raw.info["sfreq"],
@@ -243,7 +247,10 @@ def test_plot_events():
             on_missing="ignore",
         )
     extra_id = {"aud_l": 1, "missing": 111}
-    with pytest.warns(RuntimeWarning, match="from event_id is not present in"):
+    with (
+        _record_warnings(),
+        pytest.warns(RuntimeWarning, match="from event_id is not present in"),
+    ):
         plot_events(
             events,
             raw.info["sfreq"],
@@ -251,7 +258,7 @@ def test_plot_events():
             event_id=extra_id,
             on_missing="warn",
         )
-    with pytest.warns(RuntimeWarning, match="event 2 missing"):
+    with _record_warnings(), pytest.warns(RuntimeWarning, match="event 2 missing"):
         plot_events(
             events,
             raw.info["sfreq"],
diff --git a/mne/viz/tests/test_montage.py b/mne/viz/tests/test_montage.py
index 0a95cdbbb55..332ca82a6a4 100644
--- a/mne/viz/tests/test_montage.py
+++ b/mne/viz/tests/test_montage.py
@@ -15,12 +15,12 @@
 from mne.channels import make_dig_montage, make_standard_montage, read_dig_fif
-p_dir = Path(__file__).parent.parent.parent / "io" / "kit" / "tests" / "data"
+p_dir = Path(__file__).parents[2] / "io" / "kit" / "tests" / "data"
 elp = p_dir / "test_elp.txt"
 hsp = p_dir / "test_hsp.txt"
 hpi = p_dir / "test_mrk.sqd"
 point_names = ["nasion", "lpa", "rpa", "1", "2", "3", "4", "5"]
-io_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
+io_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
 fif_fname = io_dir / "test_raw.fif"
diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py
index 89619d36e2f..031f3d34392 100644
--- a/mne/viz/tests/test_raw.py
+++ b/mne/viz/tests/test_raw.py
@@ -11,7 +11,7 @@
 import numpy as np
 import pytest
 from matplotlib import backend_bases
-from numpy.testing import assert_allclose
+from numpy.testing import assert_allclose, assert_array_equal
 from mne import Annotations, create_info, pick_types
 from mne._fiff.pick import _DATA_CH_TYPES_ORDER_DEFAULT, _PICK_TYPES_DATA_DICT
@@ -541,6 +541,7 @@ def test_plot_raw_traces(raw, events, browser_backend):
     ismpl = browser_backend.name == "matplotlib"
     with raw.info._unlock():
         raw.info["lowpass"] = 10.0  # allow heavy decim during plotting
+    assert raw.info["bads"] == []
     fig = raw.plot(
         events=events, order=[1, 7, 5, 2, 3], n_channels=3, group_by="original"
     )
@@ -623,6 +624,30 @@ def test_plot_raw_traces(raw, events, browser_backend):
         raw.plot(event_color={"foo": "r"})
     plot_raw(raw, events=events, event_color={-1: "r", 998: "b"})
+    # gh-12547
+    raw.info["bads"] = raw.ch_names[1:2]
+    picks = [1, 7, 5, 2, 3]
+    fig = raw.plot(events=events, order=picks, group_by="original")
+    assert_array_equal(fig.mne.picks, picks)
+
+
+def test_plot_raw_picks(raw, browser_backend):
+    """Test functionality of picks and order arguments."""
+    with raw.info._unlock():
+        raw.info["lowpass"] = 10.0  # allow heavy decim during plotting
+
+    fig = raw.plot(picks=["MEG 0112"])
+    assert len(fig.mne.traces) == 1
+
+    fig = raw.plot(picks=["meg"])
+    assert len(fig.mne.traces) == len(raw.get_channel_types(picks="meg"))
+
+    fig = raw.plot(order=[4, 3])
+    assert_array_equal(fig.mne.ch_order, np.array([4, 3]))
+
+    fig = raw.plot(picks=[4, 3])
+    assert_array_equal(fig.mne.ch_order, np.array([3, 4]))
+
 @pytest.mark.parametrize("group_by", ("position", "selection"))
 def test_plot_raw_groupby(raw, browser_backend, group_by):
@@ -938,29 +963,33 @@ def test_plot_raw_psd(raw, raw_orig):
     spectrum = raw.compute_psd()
     # deprecation change handler
     old_defaults = dict(picks="data", exclude="bads")
-    fig = spectrum.plot(average=False)
+    fig = spectrum.plot(average=False, amplitude=False)
     # normal mode
-    fig = spectrum.plot(average=False, **old_defaults)
+    fig = spectrum.plot(average=False, amplitude=False, **old_defaults)
     fig.canvas.callbacks.process(
         "resize_event", backend_bases.ResizeEvent("resize_event", fig.canvas)
     )
     # specific mode
     picks = pick_types(spectrum.info, meg="mag", eeg=False)[:4]
-    spectrum.plot(picks=picks, ci="range", spatial_colors=True, exclude="bads")
-    raw.compute_psd(tmax=20.0).plot(color="yellow", dB=False, alpha=0.4, **old_defaults)
+    spectrum.plot(
+        picks=picks, ci="range", spatial_colors=True, exclude="bads", amplitude=False
+    )
+    raw.compute_psd(tmax=20.0).plot(
+        color="yellow", dB=False, alpha=0.4, amplitude=True, **old_defaults
+    )
     plt.close("all")
     # one axes supplied
     ax = plt.axes()
-    spectrum.plot(picks=picks, axes=ax, average=True, exclude="bads")
+    spectrum.plot(picks=picks, axes=ax, average=True, exclude="bads", amplitude=False)
     plt.close("all")
     # two axes supplied
     _, axs = plt.subplots(2)
-    spectrum.plot(axes=axs, average=True, **old_defaults)
+    spectrum.plot(axes=axs, average=True, amplitude=False, **old_defaults)
     plt.close("all")
     # need 2, got 1
     ax = plt.axes()
     with pytest.raises(ValueError, match="of length 2.*the length is 1"):
-        spectrum.plot(axes=ax, average=True, **old_defaults)
+        spectrum.plot(axes=ax, average=True, amplitude=False, **old_defaults)
     plt.close("all")
     # topo psd
     ax = plt.subplot()
@@ -969,7 +998,10 @@ def test_plot_raw_psd(raw, raw_orig):
     # with channel information not available
     for idx in range(len(raw.info["chs"])):
         raw.info["chs"][idx]["loc"] = np.zeros(12)
-    with pytest.warns(RuntimeWarning, match="locations not available"):
+    with (
+        _record_warnings(),
+        pytest.warns(RuntimeWarning, match="locations not available"),
+    ):
         raw.compute_psd().plot(spatial_colors=True, average=False)
     # with a flat channel
     raw[5, :] = 0
@@ -981,14 +1013,13 @@ def test_plot_raw_psd(raw, raw_orig):
     # check grad axes
     title = fig.axes[0].get_title()
     ylabel = fig.axes[0].get_ylabel()
-    ends_dB = ylabel.endswith("mathrm{(dB)}$")
     unit = r"fT/cm/\sqrt{Hz}" if amplitude else "(fT/cm)²/Hz"
     assert title == "Gradiometers", title
     assert unit in ylabel, ylabel
     if dB:
-        assert ends_dB, ylabel
+        assert "dB" in ylabel
     else:
-        assert not ends_dB, ylabel
+        assert "dB" not in ylabel
     # check mag axes
     title = fig.axes[1].get_title()
     ylabel = fig.axes[1].get_ylabel()
@@ -1006,8 +1037,8 @@ def test_plot_raw_psd(raw, raw_orig):
     raw = raw_orig.crop(0, 1)
     picks = pick_types(raw.info, meg=True)
     spectrum = raw.compute_psd(picks=picks)
-    spectrum.plot(average=False, **old_defaults)
-    spectrum.plot(average=True, **old_defaults)
+    spectrum.plot(average=False, amplitude=False, **old_defaults)
+    spectrum.plot(average=True, amplitude=False, **old_defaults)
     plt.close("all")
     raw.set_channel_types(
         {
@@ -1018,7 +1049,7 @@ def test_plot_raw_psd(raw, raw_orig):
         },
        verbose="error",
     )
-    fig = raw.compute_psd().plot(**old_defaults)
+    fig = raw.compute_psd().plot(amplitude=False, **old_defaults)
     assert len(fig.axes) == 10
     plt.close("all")
@@ -1029,7 +1060,7 @@ def test_plot_raw_psd(raw, raw_orig):
     raw = RawArray(data, info)
     picks = pick_types(raw.info, misc=True)
     spectrum = raw.compute_psd(picks=picks, n_fft=n_fft)
-    spectrum.plot(spatial_colors=False, picks=picks, exclude="bads")
+    spectrum.plot(spatial_colors=False, picks=picks, exclude="bads", amplitude=False)
     plt.close("all")
diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py
index fc421136c94..344572dcfc9 100644
--- a/mne/viz/tests/test_topo.py
+++ b/mne/viz/tests/test_topo.py
@@ -18,7 +18,7 @@
 from mne import Epochs, compute_proj_evoked, read_cov, read_events
 from mne.channels import read_layout
 from mne.io import read_raw_fif
-from mne.time_frequency.tfr import AverageTFR
+from mne.time_frequency.tfr import AverageTFRArray
 from mne.utils import _record_warnings
 from mne.viz import (
     _get_presser,
@@ -30,7 +30,7 @@
 from mne.viz.topo import _imshow_tfr, _plot_update_evoked_topo_proj, iter_topography
 from mne.viz.utils import _fake_click
-base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
+base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
 evoked_fname = base_dir / "test-ave.fif"
 raw_fname = base_dir / "test_raw.fif"
 event_name = base_dir / "test-eve.fif"
@@ -309,23 +309,25 @@ def test_plot_tfr_topo():
     data = np.random.RandomState(0).randn(
         len(epochs.ch_names), n_freqs, len(epochs.times)
     )
-    tfr = AverageTFR(epochs.info, data, epochs.times, np.arange(n_freqs), nave)
-    plt.close("all")
-    fig = tfr.plot_topo(
-        baseline=(None, 0), mode="ratio", title="Average power", vmin=0.0, vmax=14.0
+    tfr = AverageTFRArray(
+        info=epochs.info,
+        data=data,
+        times=epochs.times,
+        freqs=np.arange(n_freqs),
+        nave=nave,
     )
+    plt.close("all")
+    fig = tfr.plot_topo(baseline=(None, 0), mode="ratio", vmin=0.0, vmax=14.0)
     # test complex
     tfr.data = tfr.data * (1 + 1j)
     plt.close("all")
-    fig = tfr.plot_topo(
-        baseline=(None, 0), mode="ratio", title="Average power", vmin=0.0, vmax=14.0
-    )
+    fig = tfr.plot_topo(baseline=(None, 0), mode="ratio", vmin=0.0, vmax=14.0)
     # test opening tfr by clicking
     num_figures_before = len(plt.get_fignums())
     # could use np.reshape(fig.axes[-1].images[0].get_extent(), (2, 2)).mean(1)
-    with pytest.warns(RuntimeWarning, match="not masking"):
+    with _record_warnings(), pytest.warns(RuntimeWarning, match="not masking"):
         _fake_click(fig, fig.axes[0], (0.08, 0.65))
     assert num_figures_before + 1 == len(plt.get_fignums())
     plt.close("all")
@@ -335,21 +337,30 @@ def test_plot_tfr_topo():
     # nonuniform freqs
     freqs = np.logspace(*np.log10([3, 10]), num=3)
-    tfr = AverageTFR(epochs.info, data, epochs.times, freqs, nave)
-    fig = tfr.plot([4], baseline=(None, 0), mode="mean", vmax=14.0, show=False)
+    tfr = AverageTFRArray(
+        info=epochs.info, data=data, times=epochs.times, freqs=freqs, nave=nave
+    )
+    fig = tfr.plot([4], baseline=(None, 0), mode="mean", vlim=(None, 14.0), show=False)
     assert fig[0].axes[0].get_yaxis().get_scale() == "log"
     # one timesample
-    tfr = AverageTFR(epochs.info, data[:, :, [0]], epochs.times[[1]], freqs, nave)
+    tfr = AverageTFRArray(
+        info=epochs.info,
+        data=data[:, :, [0]],
+        times=epochs.times[[1]],
+        freqs=freqs,
+        nave=nave,
+    )
+    with _record_warnings():  # matplotlib equal left/right
-        tfr.plot([4], baseline=None, vmax=14.0, show=False, yscale="linear")
+        tfr.plot([4], baseline=None, vlim=(None, 14.0), show=False, yscale="linear")
     # one frequency bin, log scale required: as it doesn't make sense
     # to plot log scale for one value, we test whether yscale is set to linear
     vmin, vmax = 0.0, 2.0
     fig, ax = plt.subplots()
     tmin, tmax = epochs.times[0], epochs.times[-1]
-    with pytest.warns(RuntimeWarning, match="not masking"):
+    with _record_warnings(), pytest.warns(RuntimeWarning, match="not masking"):
         _imshow_tfr(
             ax,
             3,
@@ -372,7 +383,7 @@ def test_plot_tfr_topo():
     # ValueError when freq[0] == 0 and yscale == 'log'
     these_freqs = freqs[:3].copy()
     these_freqs[0] = 0
-    with pytest.warns(RuntimeWarning, match="not masking"):
+    with _record_warnings(), pytest.warns(RuntimeWarning, match="not masking"):
         pytest.raises(
             ValueError,
             _imshow_tfr,
diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py
index 33ae3cc645c..3ac6bb108a2 100644
--- a/mne/viz/tests/test_topomap.py
+++ b/mne/viz/tests/test_topomap.py
@@ -44,7 +44,7 @@
 from mne.datasets import testing
 from mne.io import RawArray, read_info, read_raw_fif
 from mne.preprocessing import compute_bridged_electrodes
-from mne.time_frequency.tfr import AverageTFR
+from mne.time_frequency.tfr import AverageTFRArray
 from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap
 from mne.viz.tests.test_raw import _proj_status
 from mne.viz.topomap import (
@@ -63,7 +63,7 @@
 ecg_fname = data_dir / "MEG" / "sample" / "sample_audvis_ecg-proj.fif"
 triux_fname = data_dir / "SSS" / "TRIUX" / "triux_bmlhus_erm_raw.fif"
-base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
+base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
 evoked_fname = base_dir / "test-ave.fif"
 raw_fname = base_dir / "test_raw.fif"
 event_name = base_dir / "test-eve.fif"
@@ -578,13 +578,21 @@ def test_plot_tfr_topomap():
     data = rng.randn(len(picks), n_freqs, len(times))
     # test complex numbers
-    tfr = AverageTFR(info, data * (1 + 1j), times, np.arange(n_freqs), nave)
+    tfr = AverageTFRArray(
+        info=info,
+        data=data * (1 + 1j),
+        times=times,
+        freqs=np.arange(n_freqs),
+        nave=nave,
+    )
     tfr.plot_topomap(
         ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0
     )
     # test real numbers
-    tfr = AverageTFR(info, data, times, np.arange(n_freqs), nave)
+    tfr = AverageTFRArray(
+        info=info, data=data, times=times, freqs=np.arange(n_freqs), nave=nave
+    )
     tfr.plot_topomap(
         ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0
     )
diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py
index f0679563da3..cb9e40b583c 100644
--- a/mne/viz/tests/test_utils.py
+++ b/mne/viz/tests/test_utils.py
@@ -30,7 +30,7 @@
     concatenate_images,
 )
-base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
+base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
 raw_fname = base_dir / "test_raw.fif"
 cov_fname = base_dir / "test-cov.fif"
 ev_fname = base_dir / "test_raw-eve.fif"
diff --git a/mne/viz/topo.py b/mne/viz/topo.py
index 9c3f7c5bd75..11f6695e834 100644
--- a/mne/viz/topo.py
+++ b/mne/viz/topo.py
@@ -16,7 +16,7 @@
 from .._fiff.pick import channel_type, pick_types
 from ..defaults import _handle_default
-from ..utils import Bunch, _check_option, _clean_names, _to_rgb, fill_doc
+from ..utils import Bunch, _check_option, _clean_names, _is_numeric, _to_rgb, fill_doc
 from .utils import (
     DraggableColorbar,
     _check_cov,
@@ -428,7 +428,6 @@ def _imshow_tfr(
     cnorm=None,
 ):
     """Show time-frequency map as two-dimensional image."""
-    from matplotlib import pyplot as plt
     from matplotlib.widgets import RectangleSelector
     _check_option("yscale", yscale, ["auto", "linear", "log"])
@@ -460,7 +459,7 @@ def _imshow_tfr(
         if isinstance(colorbar, DraggableColorbar):
             cbar = colorbar.cbar  # this happens with multiaxes case
         else:
-            cbar = plt.colorbar(mappable=img, ax=ax)
+            cbar = ax.get_figure().colorbar(mappable=img, ax=ax)
         if interactive_cmap:
             ax.CB = DraggableColorbar(cbar, img, kind="tfr_image", ch_type=None)
     ax.RS = RectangleSelector(ax, onselect=onselect)  # reference must be kept
@@ -555,7 +554,7 @@ def _format_coord(x, y, labels, ax):
         if "(" in xlabel and ")" in xlabel
         else "s"
     )
-    timestr = "%6.3f %s: " % (x, xunit)
+    timestr = f"{x:6.3f} {xunit}: "
     if not nearby:
         return "%s Nothing here" % timestr
     labels = [""] * len(nearby) if labels is None else labels
@@ -574,11 +573,9 @@ def _format_coord(x, y, labels, ax):
     s = timestr
     for data_, label, tvec in nearby_data:
         idx = np.abs(tvec - x).argmin()
-        s += "%7.2f %s" % (data_[ch_idx, idx], yunit)
+        s += f"{data_[ch_idx, idx]:7.2f} {yunit}"
         if trunc_labels:
-            label = (
-                label if len(label) <= 10 else "%s..%s" % (label[:6], label[-2:])
-            )
+            label = label if len(label) <= 10 else f"{label[:6]}..{label[-2:]}"
         s += " [%s] " % label if label else " "
     return s
@@ -631,10 +628,14 @@ def _rm_cursor(event):
     else:
         ax.set_ylabel(y_label)
-    if vline:
-        plt.axvline(vline, color=hvline_color, linewidth=1.0, linestyle="--")
-    if hline:
-        plt.axhline(hline, color=hvline_color, linewidth=1.0, zorder=10)
+    if vline is not None:
+        vline = [vline] if _is_numeric(vline) else vline
+        for vline_ in vline:
+            plt.axvline(vline_, color=hvline_color, linewidth=1.0, linestyle="--")
+    if hline is not None:
+        hline = [hline] if _is_numeric(hline) else hline
+        for hline_ in hline:
+            plt.axhline(hline_, color=hvline_color, linewidth=1.0, zorder=10)
     if colorbar:
         plt.colorbar()
diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py
index 0c2f6f273b0..5a6eac4f1ab 100644
--- a/mne/viz/topomap.py
+++ b/mne/viz/topomap.py
@@ -183,9 +183,9 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None):
         # Modify the nirs channel names to indicate they are to be merged
         # New names will have the form S1_D1xS2_D2
         # More than two channels can overlap and be merged
-        for set in overlapping_channels:
-            idx = ch_names.index(set[0][:-4])
-            new_name = "x".join(s[:-4] for s in set)
+        for set_ in overlapping_channels:
+            idx = ch_names.index(set_[0][:-4])
+            new_name = "x".join(s[:-4] for s in set_)
             ch_names[idx] = new_name
     pos = np.array(pos)[:, :2]  # 2D plot, otherwise interpolation bugs
@@ -306,12 +306,12 @@ def _add_colorbar(
     cmap,
     *,
     title=None,
-    format=None,
+    format_=None,
     kind=None,
     ch_type=None,
 ):
     """Add a colorbar to an axis."""
-    cbar = ax.figure.colorbar(im, format=format, shrink=0.6)
+    cbar = ax.figure.colorbar(im, format=format_, shrink=0.6)
     if cmap is not None and cmap[1]:
         ax.CB = DraggableColorbar(cbar, im, kind, ch_type)
     cax = cbar.ax
@@ -376,8 +376,7 @@ def plot_projs_topomap(
     %(info_not_none)s Must be associated with the channels in the projectors.
         .. versionchanged:: 0.20
-            The positional argument ``layout`` was deprecated and replaced
-            by ``info``.
+            The positional argument ``layout`` was replaced by ``info``.
     %(sensors_topomap)s
     %(show_names_topomap)s
@@ -598,7 +597,7 @@ def _plot_projs_topomap(
             im,
             cmap,
             title=units,
-            format=cbar_fmt,
+            format_=cbar_fmt,
             kind="projs_topomap",
             ch_type=_ch_type,
         )
@@ -913,6 +912,7 @@ def _topomap_plot_sensors(pos_x, pos_y, sensors, ax):
 def _get_pos_outlines(info, picks, sphere, to_sphere=True):
     from ..channels.layout import _find_topomap_coords
+
     picks = _picks_to_idx(info, picks, "all", exclude=(), allow_empty=False)
     ch_type = _get_plot_ch_type(pick_info(_simplify_info(info), picks), None)
     orig_sphere = sphere
     sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type)
@@ -1219,9 +1219,8 @@ def _plot_topomap(
         raise ValueError("Multiple channel types in Info structure. " + info_help)
     elif len(pos["chs"]) != data.shape[0]:
         raise ValueError(
-            "Number of channels in the Info object (%s) and "
-            "the data array (%s) do not match. " % (len(pos["chs"]), data.shape[0])
-            + info_help
+            f"Number of channels in the Info object ({len(pos['chs'])}) and the "
+            f"data array ({data.shape[0]}) do not match." + info_help
         )
     else:
         ch_type = ch_type.pop()
@@ -1252,9 +1251,9 @@ def _plot_topomap(
         )
     if pos.ndim != 2:
         error = (
-            "{ndim}D array supplied as electrode positions, where a 2D "
-            "array was expected"
-        ).format(ndim=pos.ndim)
+            f"{pos.ndim}D array supplied as electrode positions, where a 2D array was "
+            "expected"
+        )
         raise ValueError(error + " " + pos_help)
     elif pos.shape[1] == 3:
         error = (
@@ -1273,7 +1272,7 @@ def _plot_topomap(
     if len(data) != len(pos):
         raise ValueError(
             "Data and pos need to be of same length. Got data of "
-            "length %s, pos of length %s" % (len(data), len(pos))
+            f"length {len(data)}, pos of length {len(pos)}"
         )
     norm = min(data) >= 0
@@ -1472,7 +1471,7 @@ def _plot_ica_topomap(
             im,
             cmap,
             title="AU",
-            format="%3.2f",
+            format_="%3.2f",
             kind="ica_topomap",
             ch_type=ch_type,
         )
@@ -1711,7 +1710,7 @@ def plot_ica_components(
                 im,
                 cmap,
                 title="AU",
-                format=cbar_fmt,
+                format_=cbar_fmt,
                 kind="ica_comp_topomap",
                 ch_type=ch_type,
             )
@@ -1893,7 +1892,6 @@ def plot_tfr_topomap(
         tfr, ch_type, sphere=sphere
     )
     outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
-
     data = tfr.data[picks, :, :]
     # merging grads before rescaling makes ERDs visible
@@ -1912,7 +1910,6 @@ def plot_tfr_topomap(
         itmin = idx[0]
         if tmax is not None:
             itmax = idx[-1] + 1
-
     # crop freqs
     ifmin, ifmax = None, None
     idx = np.where(_time_mask(tfr.freqs, fmin, fmax))[0]
@@ -1920,8 +1917,7 @@ def plot_tfr_topomap(
         ifmax = idx[-1] + 1
     data = data[:, ifmin:ifmax, itmin:itmax]
-    data = np.mean(np.mean(data, axis=2), axis=1)[:, np.newaxis]
-
+    data = data.mean(axis=(1, 2))[:, np.newaxis]
     norm = False if np.min(data) < 0 else True
     vlim = _setup_vmin_vmax(data, *vlim, norm)
     cmap = _setup_cmap(cmap, norm=norm)
@@ -1991,7 +1987,7 @@ def plot_tfr_topomap(
             im,
             cmap,
             title=units,
-            format=cbar_fmt,
+            format_=cbar_fmt,
             kind="tfr_topomap",
             ch_type=ch_type,
         )
@@ -2563,7 +2559,7 @@ def _plot_topomap_multi_cbar(
     )
     if colorbar:
-        cbar, cax = _add_colorbar(ax, im, cmap, title=None, format=cbar_fmt)
+        cbar, cax = _add_colorbar(ax, im, cmap, title=None, format_=cbar_fmt)
         cbar.set_ticks(_vlim)
         if unit is not None:
             cbar.ax.set_ylabel(unit, fontsize=8)
@@ -3158,9 +3154,9 @@ def _animate(frame, ax, ax_line, params):
     time_idx = params["frames"][frame]
     if params["time_unit"] == "ms":
-        title = "%6.0f ms" % (params["times"][frame] * 1e3,)
+        title = f"{params['times'][frame] * 1e3:6.0f} ms"
     else:
-        title = "%6.3f s" % (params["times"][frame],)
+        title = f"{params['times'][frame]:6.3f} s"
     if params["blit"]:
         text = params["text"]
     else:
@@ -3452,10 +3448,10 @@ def _plot_corrmap(
     for ii, data_, ax, subject, idx in zip(picks, data, axes, subjs, indices):
         if template:
-            ttl = "Subj. {}, {}".format(subject, ica._ica_names[idx])
+            ttl = f"Subj. {subject}, {ica._ica_names[idx]}"
             ax.set_title(ttl, fontsize=12)
         else:
-            ax.set_title("Subj. {}".format(subject))
+            ax.set_title(f"Subj. {subject}")
         if merge_channels:
             data_, _ = _merge_ch_data(data_, ch_type, [])
         _vlim = _setup_vmin_vmax(data_, None, None)
diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py
index 231776c9165..adad59c4be0 100644
--- a/mne/viz/ui_events.py
+++ b/mne/viz/ui_events.py
@@ -9,13 +9,14 @@
 Authors: Marijn van Vliet
 """
+
 # License: BSD-3-Clause
 # Copyright the MNE-Python contributors.
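An aside on the `plot_tfr_topomap` change above: NumPy reductions accept a tuple of axes, so collapsing frequencies and times in a single `mean` call is numerically equivalent to the removed nested means. A minimal sketch (made-up shapes, not MNE API) verifying the equivalence:

import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((5, 4, 3))  # hypothetical (n_channels, n_freqs, n_times)

old = np.mean(np.mean(data, axis=2), axis=1)[:, np.newaxis]  # removed form
new = data.mean(axis=(1, 2))[:, np.newaxis]                  # replacement form
np.testing.assert_allclose(old, new)  # equal up to floating-point roundoff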
 import contextlib
 import re
 import weakref
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import Optional, Union
 from matplotlib.colors import Colormap
@@ -205,7 +206,7 @@ class Contours(UIEvent):
     """
     kind: str
-    contours: List[str]
+    contours: list[str]
 def _get_event_channel(fig):
diff --git a/mne/viz/utils.py b/mne/viz/utils.py
index 4223bafad6c..5d2f2d95617 100644
--- a/mne/viz/utils.py
+++ b/mne/viz/utils.py
@@ -46,6 +46,7 @@
 )
 from .._fiff.proj import Projection, setup_proj
 from ..defaults import _handle_default
+from ..fixes import _median_complex
 from ..rank import compute_rank
 from ..transforms import apply_trans
 from ..utils import (
@@ -65,6 +66,7 @@
     verbose,
     warn,
 )
+from ..utils.misc import _identity_function
 from .ui_events import ColormapRange, publish, subscribe
 _channel_type_prettyprint = {
@@ -243,12 +245,12 @@ def _validate_if_list_of_axes(axes, obligatory_len=None, name="axes"):
     )
-def mne_analyze_colormap(limits=[5, 10, 15], format="vtk"):
+def mne_analyze_colormap(limits=(5, 10, 15), format="vtk"):  # noqa: A002
     """Return a colormap similar to that used by mne_analyze.
     Parameters
     ----------
-    limits : list (or array) of length 3 or 6
+    limits : array-like of length 3 or 6
         Bounds for the colormap, which will be mirrored across zero if length 3,
         or completely specified (and potentially asymmetric) if length 6.
     format : str
@@ -457,8 +459,7 @@ def _prepare_trellis(
     naxes = ncols * nrows
     if naxes < n_cells:
         raise ValueError(
-            "Cannot plot {} axes in a {} by {} "
-            "figure.".format(n_cells, nrows, ncols)
+            f"Cannot plot {n_cells} axes in a {nrows} by {ncols} figure."
         )
     width = size * ncols
@@ -1416,7 +1417,7 @@ def _compute_scalings(scalings, inst, remove_dc=False, duration=10):
             time_middle = np.mean(inst.times)
             tmin = np.clip(time_middle - n_secs / 2.0, inst.times.min(), None)
             tmax = np.clip(time_middle + n_secs / 2.0, None, inst.times.max())
-            smin, smax = [int(round(x * inst.info["sfreq"])) for x in (tmin, tmax)]
+            smin, smax = (int(round(x * inst.info["sfreq"])) for x in (tmin, tmax))
             data = inst._read_segment(smin, smax)
         elif isinstance(inst, BaseEpochs):
             # Load a random subset of epochs up to 100mb in size
@@ -1784,15 +1785,7 @@ def _get_color_list(annotations=False):
     from matplotlib import rcParams
     color_cycle = rcParams.get("axes.prop_cycle")
-
-    if not color_cycle:
-        # Use deprecated color_cycle to avoid KeyErrors in environments
-        # with Python 2.7 and Matplotlib < 1.5
-        # this will already be a list
-        colors = rcParams.get("axes.color_cycle")
-    else:
-        # we were able to use the prop_cycle. Now just convert to list
-        colors = color_cycle.by_key()["color"]
+    colors = color_cycle.by_key()["color"]
     # If we want annotations, red is reserved ... remove if present. This
     # checks for the reddish color in MPL dark background style, normal style,
@@ -1987,9 +1980,7 @@ def _handle_decim(info, decim, lowpass):
         decim = max(int(info["sfreq"] / (lp * 3) + 1e-6), 1)
     decim = _ensure_int(decim, "decim", must_be='an int or "auto"')
     if decim <= 0:
-        raise ValueError(
-            'decim must be "auto" or a positive integer, got %s' % (decim,)
-        )
+        raise ValueError(f'decim must be "auto" or a positive integer, got {decim}')
     decim = _check_decim(info, decim, 0)[0]
     data_picks = _pick_data_channels(info, exclude=())
     return decim, data_picks
@@ -2147,20 +2138,27 @@ def _set_title_multiple_electrodes(
     ch_type = _channel_type_prettyprint.get(ch_type, ch_type)
     if ch_type is None:
         ch_type = "sensor"
-    if len(ch_names) > 1:
-        ch_type += "s"
-    combine = combine.capitalize() if isinstance(combine, str) else "Combination"
+    ch_type = f"{ch_type}{_pl(ch_names)}"
+    if hasattr(combine, "func"):  # functools.partial
+        combine = combine.func
+    if callable(combine):
+        combine = getattr(combine, "__name__", str(combine))
+    if not isinstance(combine, str):
+        combine = "Combination"
+    # mean → Mean, but avoid RMS → Rms and GFP → Gfp
+    if combine[0].islower():
+        combine = combine.capitalize()
     if all_:
         title = f"{combine} of {len(ch_names)} {ch_type}"
     elif len(ch_names) > max_chans and combine != "gfp":
-        logger.info("More than %i channels, truncating title ...", max_chans)
+        logger.info(f"More than {max_chans} channels, truncating title ...")
         title += f", ...\n({combine} of {len(ch_names)} {ch_type})"
     return title
 def _check_time_unit(time_unit, times):
     if not isinstance(time_unit, str):
-        raise TypeError("time_unit must be str, got %s" % (type(time_unit),))
+        raise TypeError(f"time_unit must be str, got {type(time_unit)}")
     if time_unit == "s":
         pass
     elif time_unit == "ms":
@@ -2222,7 +2220,7 @@ def _plot_masked_image(
         if mask.shape != data.shape:
             raise ValueError(
                 "The mask must have the same shape as the data, "
-                "i.e., %s, not %s" % (data.shape, mask.shape)
+                f"i.e., {data.shape}, not {mask.shape}"
             )
     if draw_contour and yscale == "log":
         warn("Cannot draw contours with linear yscale yet ...")
@@ -2329,7 +2327,7 @@ def _plot_masked_image(
             t_end = ", all points masked)"
         else:
             fraction = 1 - (np.float64(mask.sum()) / np.float64(mask.size))
-            t_end = ", %0.3g%% of points masked)" % (fraction * 100,)
+            t_end = f", {fraction * 100:0.3g}% of points masked)"
     else:
         t_end = ")"
@@ -2337,31 +2335,69 @@
 @fill_doc
-def _make_combine_callable(combine):
+def _make_combine_callable(
+    combine,
+    *,
+    axis=1,
+    valid=("mean", "median", "std", "gfp"),
+    ch_type=None,
+    keepdims=False,
+):
     """Convert None or string values of ``combine`` into callables.
     Params
     ------
-    %(combine)s
-        If callable, the callable must accept one positional input (data of
-        shape ``(n_epochs, n_channels, n_times)`` or ``(n_evokeds, n_channels,
-        n_times)``) and return an :class:`array <numpy.ndarray>` of shape
-        ``(n_epochs, n_times)`` or ``(n_evokeds, n_times)``.
+    combine : None | str | callable
+        If callable, the callable must accept one positional input (a numpy array) and
+        return an array with one fewer dimensions (the missing dimension's position is
+        given by ``axis``).
+    axis : int
+        Axis of data array across which to combine. May vary depending on data
+        context; e.g., if data are time-domain sensor traces or TFRs, continuous
+        or epoched, etc.
+    valid : tuple
+        Valid string values for built-in combine methods
+        (may vary for, e.g., combining TFRs versus time-domain signals).
+    ch_type : str
+        Channel type. Affects whether "gfp" is allowed as a synonym for "rms".
+    keepdims : bool
+        Whether to retain the singleton dimension after collapsing across it.
     """
+    kwargs = dict(axis=axis, keepdims=keepdims)
     if combine is None:
-        combine = partial(np.squeeze, axis=1)
+        combine = _identity_function if keepdims else partial(np.squeeze, axis=axis)
     elif isinstance(combine, str):
         combine_dict = {
-            key: partial(getattr(np, key), axis=1) for key in ("mean", "median", "std")
+            key: partial(getattr(np, key), **kwargs)
+            for key in valid
+            if getattr(np, key, None) is not None
         }
-        combine_dict["gfp"] = lambda data: np.sqrt((data**2).mean(axis=1))
+        # marginal median that is safe for complex values:
+        if "median" in valid:
+            combine_dict["median"] = partial(_median_complex, axis=axis)
+
+        # RMS and GFP; if GFP requested for MEG channels, will use RMS anyway
+        def _rms(data):
+            return np.sqrt((data**2).mean(**kwargs))
+
+        def _gfp(data):
+            return data.std(axis=axis, ddof=0)
+
+        # make them play nice with _set_title_multiple_electrodes()
+        _rms.__name__ = "RMS"
+        _gfp.__name__ = "GFP"
+        if "rms" in valid:
+            combine_dict["rms"] = _rms
+        if "gfp" in valid and ch_type == "eeg":
+            combine_dict["gfp"] = _gfp
+        elif "gfp" in valid:
+            combine_dict["gfp"] = _rms
         try:
             combine = combine_dict[combine]
         except KeyError:
             raise ValueError(
-                '"combine" must be None, a callable, or one of '
-                '"mean", "median", "std", or "gfp"; got {}'
-                "".format(combine)
+                f'"combine" must be None, a callable, or one of "{", ".join(valid)}"; '
+                f'got {combine}'
             )
     return combine
@@ -2369,29 +2405,12 @@ def _make_combine_callable(combine):
 def _convert_psds(
     psds, dB, estimate, scaling, unit, ch_names=None, first_dim="channel"
 ):
-    """Convert PSDs to dB (if necessary) and appropriate units.
-
-    The following table summarizes the relationship between the value of
-    parameters ``dB`` and ``estimate``, and the type of plot and corresponding
-    units.
-
-    | dB    | estimate    | plot | units             |
-    |-------+-------------+------+-------------------|
-    | True  | 'power'     | PSD  | amp**2/Hz (dB)    |
-    | True  | 'amplitude' | ASD  | amp/sqrt(Hz) (dB) |
-    | True  | 'auto'      | PSD  | amp**2/Hz (dB)    |
-    | False | 'power'     | PSD  | amp**2/Hz         |
-    | False | 'amplitude' | ASD  | amp/sqrt(Hz)      |
-    | False | 'auto'      | ASD  | amp/sqrt(Hz)      |
-
-    where amp are the units corresponding to the variable, as specified by
-    ``unit``.
-    """
+    """Convert PSDs to dB (if necessary) and appropriate units."""
     _check_option("first_dim", first_dim, ["channel", "epoch"])
     where = np.where(psds.min(1) <= 0)[0]
     if len(where) > 0:
-        # Construct a helpful error message, depending on whether the first
-        # dimension of `psds` are channels or epochs.
+        # Construct a helpful error message, depending on whether the first dimension of
+        # `psds` corresponds to channels or epochs.
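The rewritten `_make_combine_callable` above resolves string values of ``combine`` to NumPy reductions over ``axis``; notably, "gfp" now means the baseline-free standard deviation only for EEG and falls back to RMS for other channel types. A minimal sketch of the two aggregations it constructs, using made-up data and the default axis=1:

import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((2, 3, 100))  # hypothetical (n_epochs, n_channels, n_times)

rms = np.sqrt((data**2).mean(axis=1))  # what "rms" (and "gfp" for MEG) resolves to
gfp = data.std(axis=1, ddof=0)         # true GFP, used when ch_type == "eeg"
assert rms.shape == gfp.shape == (2, 100)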
         if dB:
             bad_value = "Infinite"
         else:
@@ -2413,16 +2432,18 @@
     if estimate == "amplitude":
         np.sqrt(psds, out=psds)
         psds *= scaling
-        ylabel = r"$\mathrm{%s/\sqrt{Hz}}$" % unit
+        ylabel = rf"$\mathrm{{{unit}/\sqrt{{Hz}}}}$"
     else:
         psds *= scaling * scaling
         if "/" in unit:
-            unit = "(%s)" % unit
-        ylabel = r"$\mathrm{%s²/Hz}$" % unit
+            unit = f"({unit})"
+        ylabel = rf"$\mathrm{{{unit}²/Hz}}$"
     if dB:
         np.log10(np.maximum(psds, np.finfo(float).tiny), out=psds)
         psds *= 10
-        ylabel += r"$\ \mathrm{(dB)}$"
+        ylabel = r"$\mathrm{dB}\ $" + ylabel
+    ylabel = "Power (" + ylabel if estimate == "power" else "Amplitude (" + ylabel
+    ylabel += ")"
     return ylabel
diff --git a/pyproject.toml b/pyproject.toml
index 8c172fcac70..5a2dbce91a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,7 @@
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
+
 [project]
 name = "mne"
 description = "MNE-Python project for MEG and EEG data analysis."
@@ -8,7 +12,7 @@ authors = [
 maintainers = [{ name = "Dan McCloy", email = "dan@mccloy.info" }]
 license = { text = "BSD-3-Clause" }
 readme = { file = "README.rst", content-type = "text/x-rst" }
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 keywords = [
     "neuroscience",
     "neuroimaging",
@@ -41,9 +45,7 @@ dependencies = [
     "decorator",
     "packaging",
     "jinja2",
-    "importlib_resources>=5.10.2; python_version<'3.9'",
     "lazy_loader>=0.3",
-    "defusedxml",
 ]
 [project.optional-dependencies]
 hdf5 = ["h5io", "pymatreader"]
 full = [
     "mne[hdf5]",
     "qtpy",
-    "PyQt6",
+    "PyQt6!=6.6.1",
+    "PyQt6-Qt6!=6.6.1,!=6.6.2,!=6.6.3",
     "pyobjc-framework-Cocoa>=5.2.0; platform_system=='Darwin'",
     "sip",
     "scikit-learn",
@@ -69,10 +72,9 @@ full = [
     "numba",
     "h5py",
     "pandas",
-    "numexpr",
+    "pyarrow",  # only needed to avoid a deprecation warning in pandas
     "jupyter",
     "python-picard",
-    "statsmodels",
     "joblib",
     "psutil",
     "dipy",
@@ -93,26 +95,32 @@ full = [
     "trame-vuetify",
     "mne-qt-browser",
     "darkdetect",
-    "qdarkstyle",
+    "qdarkstyle!=3.2.2",
     "threadpoolctl",
+    # duplicated in test_extra:
+    "statsmodels",
+    "eeglabio",
+    "edfio>=0.2.1",
+    "pybv",
+    "snirf",
+    "defusedxml",
+    "neo",
 ]
 # Dependencies for running the test infrastructure
 test = [
-    "pytest",
+    "pytest>=8.0.0rc2",
     "pytest-cov",
     "pytest-timeout",
-    "pytest-harvest",
     "pytest-qt",
     "ruff",
     "numpydoc",
     "codespell",
-    "check-manifest",
     "tomli; python_version<'3.11'",
     "twine",
     "wheel",
     "pre-commit",
-    "black",
+    "mypy",
 ]
 # Dependencies for being able to run additional tests (rare/CIs/advanced devs)
@@ -121,21 +129,25 @@ test_extra = [
     "nitime",
     "nbclient",
     "sphinx-gallery",
+    "statsmodels",
     "eeglabio",
-    "EDFlib-Python",
+    "edfio>=0.2.1",
     "pybv",
     "imageio>=2.6.1",
     "imageio-ffmpeg>=0.4.1",
     "snirf",
+    "neo",
+    "mne-bids",
 ]
-# Dependencies for building the docuemntation
+# Dependencies for building the documentation
 doc = [
-    "sphinx>=6",
+    "sphinx>=6,<7.3",
     "numpydoc",
-    "pydata_sphinx_theme==0.13.3",
+    "pydata_sphinx_theme==0.15.2",
     "sphinx-gallery",
     "sphinxcontrib-bibtex>=2.5",
+    "sphinxcontrib-towncrier",
     "memory_profiler",
     "neo",
     "seaborn!=0.11.2",
@@ -153,63 +165,39 @@ doc = [
     "ipython!=8.7.0",
     "selenium",
 ]
-dev = ["mne[test,doc]"]
+dev = ["mne[test,doc]", "rcssmin"]
 [project.urls]
 Homepage = "https://mne.tools/"
-Download = "https://pypi.org/project/scikit-learn/#files"
+Download = "https://pypi.org/project/mne/#files"
 "Bug Tracker" = "https://github.com/mne-tools/mne-python/issues/"
 Documentation = "https://mne.tools/"
 Forum = "https://mne.discourse.group/"
 "Source Code" = "https://github.com/mne-tools/mne-python/"
-[build-system]
-requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[tool.setuptools.packages.find]
-where = ["."]
-include = ["mne*"]
-namespaces = false
-
-[tool.setuptools_scm]
-write_to = "mne/_version.py"
-version_scheme = "release-branch-semver"
+[tool.hatch.build]
+exclude = [
+    "/.*",
+    "/*.yml",
+    "/*.yaml",
+    "/*.toml",
+    "/*.txt",
+    "/mne/**/tests",
+    "/logo",
+    "/doc",
+    "/tools",
+    "/tutorials",
+    "/examples",
+    "/codemeta.json",
+    "/ignore_words.txt",
+    "/Makefile",
+    "/CITATION.cff",
+    "/CONTRIBUTING.md",
+]  # tracked by git, but we don't want to ship those files
-[options]
-zip_safe = false  # the package can run out of an .egg file
-include_package_data = true
-
-[tool.setuptools.package-data]
-"mne" = [
-    "data/eegbci_checksums.txt",
-    "data/*.sel",
-    "data/icos.fif.gz",
-    "data/coil_def*.dat",
-    "data/helmets/*.fif.gz",
-    "data/FreeSurferColorLUT.txt",
-    "data/image/*gif",
-    "data/image/*lout",
-    "data/fsaverage/*.fif",
-    "channels/data/layouts/*.lout",
-    "channels/data/layouts/*.lay",
-    "channels/data/montages/*.sfp",
-    "channels/data/montages/*.txt",
-    "channels/data/montages/*.elc",
-    "channels/data/neighbors/*.mat",
-    "datasets/sleep_physionet/SHA1SUMS",
-    "datasets/_fsaverage/*.txt",
-    "datasets/_infant/*.txt",
-    "datasets/_phantom/*.txt",
-    "html/*.js",
-    "html/*.css",
-    "html_templates/repr/*.jinja",
-    "html_templates/report/*.jinja",
-    "icons/*.svg",
-    "icons/*.png",
-    "io/artemis123/resources/*.csv",
-    "io/edf/gdf_encodes.txt",
-]
+[tool.hatch.version]
+source = "vcs"
+raw-options = { version_scheme = "release-branch-semver" }
 [tool.codespell]
 ignore-words = "ignore_words.txt"
@@ -217,15 +205,17 @@
 builtin = "clear,rare,informal,names,usage"
 skip = "doc/references.bib"
 [tool.ruff]
-select = ["E", "F", "W", "D", "I"]
 exclude = ["__init__.py", "constants.py", "resources.py"]
+
+[tool.ruff.lint]
+select = ["A", "B006", "D", "E", "F", "I", "W", "UP", "UP031"]
 ignore = [
     "D100",  # Missing docstring in public module
     "D104",  # Missing docstring in public package
     "D413",  # Missing blank line after last section
 ]
-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
 convention = "numpy"
 ignore-decorators = [
     "property",
@@ -235,7 +225,7 @@ ignore-decorators = [
     "mne.utils.deprecated",
 ]
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 "tutorials/time-freq/10_spectrum_class.py" = [
     "E501",  # line too long
 ]
@@ -251,9 +241,15 @@ ignore-decorators = [
 "examples/*/*.py" = [
     "D205",  # 1 blank line required between summary line and description
 ]
+"examples/preprocessing/eeg_bridging.py" = [
+    "E501",  # line too long
+]
 [tool.pytest.ini_options]
-addopts = """--durations=20 --doctest-modules -ra --cov-report= --tb=short \
+# -r f (failed), E (error), s (skipped), x (xfail), X (xpassed), w (warnings)
+# don't put in xfail for pytest 8.0+ because then it prints the tracebacks,
+# which look like real errors
+addopts = """--durations=20 --doctest-modules -rfEXs --cov-report= --tb=short \
     --cov-branch --doctest-ignore-import-errors --junit-xml=junit-results.xml \
     --ignore=doc --ignore=logo --ignore=examples --ignore=tutorials \
     --ignore=mne/gui/_*.py --ignore=mne/icons --ignore=tools \
@@ -261,9 +257,6 @@ addopts = """--durations=20 --doctest-modules -rfEXs --cov-report= --tb=short \
     --color=yes --capture=sys"""
 junit_family = "xunit2"
-[tool.black]
-exclude = "(dist/)|(build/)|(.*\\.ipynb)"
-
 [tool.bandit.assert_used]
 skips = ["*/test_*.py"]  # assert statements are good practice with pytest
@@ -308,5 +301,79 @@ ignore_directives = [
     "toctree",
     "rst-class",
     "tab-set",
+    "towncrier-draft-entries",
 ]
 ignore_messages = "^.*(Unknown target name|Undefined substitution referenced)[^`]*$"
+
+[tool.mypy]
+ignore_errors = true
+scripts_are_modules = true
+strict = false
+
+[[tool.mypy.overrides]]
+module = ['mne.annotations', 'mne.epochs', 'mne.evoked', 'mne.io']
+ignore_errors = false
+# Ignore "attr-defined" until we fix stuff like:
+# - BunchConstNamed: '"BunchConstNamed" has no attribute "FIFFB_EVOKED"'
+# - Missing __all__: 'Module "mne.io.snirf" does not explicitly export attribute "read_raw_snirf"'
+# Ignore "no-untyped-call" until we fix stuff like:
+# - 'Call to untyped function "end_block" in typed context'
+# Ignore "no-untyped-def" until we fix stuff like:
+# - 'Function is missing a type annotation'
+# Ignore "misc" until we fix stuff like:
+# - 'Cannot determine type of "_projector" in base class "ProjMixin"'
+# Ignore "assignment" until we fix stuff like:
+# - 'Incompatible types in assignment (expression has type "tuple[str, ...]", variable has type "str")'
+# Ignore "operator" until we fix stuff like:
+# - Unsupported operand types for - ("None" and "int")
+disable_error_code = [
+    'attr-defined',
+    'no-untyped-call',
+    'no-untyped-def',
+    'misc',
+    'assignment',
+    'operator',
+]
+
+[tool.towncrier]
+package = "mne"
+directory = "doc/changes/devel/"
+filename = "doc/changes/devel.rst"
+title_format = "{version} ({project_date})"
+issue_format = "`#{issue} <https://github.com/mne-tools/mne-python/pull/{issue}>`__"
+
+[[tool.towncrier.type]]
+directory = "notable"
+name = "Notable changes"
+showcontent = true
+
+[[tool.towncrier.type]]
+directory = "dependency"
+name = "Dependencies"
+showcontent = true
+
+[[tool.towncrier.type]]
+directory = "bugfix"
+name = "Bugfixes"
+showcontent = true
+
+[[tool.towncrier.type]]
+directory = "apichange"
+name = "API changes by deprecation"
+showcontent = true
+
+[[tool.towncrier.type]]
+directory = "newfeature"
+name = "New features"
+showcontent = true
+
+[[tool.towncrier.type]]
+directory = "other"
+name = "Other changes"
+showcontent = true
+
+[tool.changelog-bot]
+[tool.changelog-bot.towncrier_changelog]
+enabled = true
+verify_pr_number = true
+changelog_skip_label = "no-changelog-entry-needed"
diff --git a/tools/azure_dependencies.sh b/tools/azure_dependencies.sh
index d27c10d8224..cf4dd4726b9 100755
--- a/tools/azure_dependencies.sh
+++ b/tools/azure_dependencies.sh
@@ -1,38 +1,14 @@
-#!/bin/bash -ef
+#!/bin/bash
+set -eo pipefail
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 STD_ARGS="--progress-bar off --upgrade"
-python -m pip install $STD_ARGS pip setuptools wheel packaging setuptools_scm
+python -m pip install $STD_ARGS pip setuptools wheel
 if [ "${TEST_MODE}" == "pip" ]; then
-    python -m pip install --only-binary="numba,llvmlite,numpy,scipy,vtk" -e .[test,full]
+    python -m pip install $STD_ARGS --only-binary="numba,llvmlite,numpy,scipy,vtk" -e .[test,full]
 elif [ "${TEST_MODE}" == "pip-pre" ]; then
-    STD_ARGS="$STD_ARGS --pre"
-    python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://www.riverbankcomputing.com/pypi/simple" PyQt6 PyQt6-sip PyQt6-Qt6
-    echo "Numpy etc."
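Under the [tool.towncrier] configuration above, a changelog entry is a small RST fragment whose filename encodes the PR number and one of the configured change types. A hypothetical illustration (the PR number and type are made up for the example):

from pathlib import Path

pr_num, kind = 12345, "bugfix"  # kind must match a [[tool.towncrier.type]] directory
fragment = Path("doc/changes/devel") / f"{pr_num}.{kind}.rst"
print(fragment)  # doc/changes/devel/12345.bugfix.rst; collected into doc/changes/devel.rst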
-    # See github_actions_dependencies.sh for comments
-    python -m pip install $STD_ARGS --only-binary "numpy" numpy
-    python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" "scipy>=1.12.0.dev0" scikit-learn matplotlib pandas statsmodels
-    echo "dipy"
-    python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://pypi.anaconda.org/scipy-wheels-nightly/simple" dipy
-    echo "h5py"
-    python -m pip install $STD_ARGS --only-binary ":all:" -f "https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com" h5py
-    echo "vtk"
-    python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://wheels.vtk.org" vtk
-    echo "nilearn and openmeeg"
-    python -m pip install $STD_ARGS git+https://github.com/nilearn/nilearn
-    python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://test.pypi.org/simple" openmeeg
-    echo "pyvista/pyvistaqt"
-    python -m pip install --progress-bar off git+https://github.com/pyvista/pyvista
-    python -m pip install --progress-bar off git+https://github.com/pyvista/pyvistaqt
-    echo "misc"
-    python -m pip install $STD_ARGS imageio-ffmpeg xlrd mffpy python-picard pillow traitlets pybv eeglabio
-    echo "nibabel with workaround"
-    python -m pip install --progress-bar off git+https://github.com/nipy/nibabel.git
-    echo "joblib"
-    python -m pip install --progress-bar off git+https://github.com/joblib/joblib@master
-    echo "EDFlib-Python"
-    python -m pip install $STD_ARGS git+https://gitlab.com/Teuniz/EDFlib-Python@master
-    ./tools/check_qt_import.sh PyQt6
-    python -m pip install $STD_ARGS -e .[hdf5,test]
+    ${SCRIPT_DIR}/install_pre_requirements.sh
+    python -m pip install $STD_ARGS --pre -e .[test]
 else
     echo "Unknown run type ${TEST_MODE}"
     exit 1
diff --git a/tools/circleci_bash_env.sh b/tools/circleci_bash_env.sh
index fb5e471c9fd..55cdb2e157c 100755
--- a/tools/circleci_bash_env.sh
+++ b/tools/circleci_bash_env.sh
@@ -17,6 +17,7 @@ source tools/get_minimal_commands.sh
 echo "export MNE_3D_BACKEND=pyvistaqt" >> $BASH_ENV
 echo "export MNE_BROWSER_BACKEND=qt" >> $BASH_ENV
 echo "export MNE_BROWSER_PRECOMPUTE=false" >> $BASH_ENV
+echo "export MNE_ADD_CONTRIBUTOR_IMAGE=true" >> $BASH_ENV
 echo "export PATH=~/.local/bin/:$PATH" >> $BASH_ENV
 echo "export DISPLAY=:99" >> $BASH_ENV
 echo "source ~/python_env/bin/activate" >> $BASH_ENV
diff --git a/tools/generate_codemeta.py b/tools/generate_codemeta.py
index 9e697cecc55..a1c1fac77b4 100644
--- a/tools/generate_codemeta.py
+++ b/tools/generate_codemeta.py
@@ -44,6 +44,7 @@
     "De Santis",
     "Dupré la Tour",
     "de la Torre",
+    "de Jong",
    "de Montalivet",
     "van den Bosch",
     "Van den Bossche",
diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh
index 2a4b90bb910..36a0e16d4fb 100755
--- a/tools/github_actions_dependencies.sh
+++ b/tools/github_actions_dependencies.sh
@@ -2,12 +2,17 @@
 set -o pipefail
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 STD_ARGS="--progress-bar off --upgrade"
+INSTALL_ARGS="-e"
 INSTALL_KIND="test_extra,hdf5"
 if [ ! -z "$CONDA_ENV" ]; then
     echo "Uninstalling MNE for CONDA_ENV=${CONDA_ENV}"
-    conda remove -c conda-forge --force -yq mne
+    conda remove -c conda-forge --force -yq mne-base
     python -m pip uninstall -y mne
+    if [[ "${RUNNER_OS}" != "Windows" ]]; then
+        INSTALL_ARGS=""
+    fi
 elif [ ! -z "$CONDA_DEPENDENCIES" ]; then
     echo "Using Mamba to install CONDA_DEPENDENCIES=${CONDA_DEPENDENCIES}"
     mamba install -y $CONDA_DEPENDENCIES
@@ -15,50 +20,12 @@ elif [ ! -z "$CONDA_DEPENDENCIES" ]; then
     STD_ARGS="--progress-bar off"
     INSTALL_KIND="test"
 else
-    echo "Install pip-pre dependencies"
     test "${MNE_CI_KIND}" == "pip-pre"
     STD_ARGS="$STD_ARGS --pre"
-    python -m pip install $STD_ARGS pip
-    echo "Numpy"
-    pip uninstall -yq numpy
-    echo "PyQt6"
-    pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 --extra-index-url https://www.riverbankcomputing.com/pypi/simple PyQt6
-    echo "NumPy/SciPy/pandas etc."
-    # As of 2023/11/20 no NumPy 2.0 because it requires everything using its ABI to
-    # compile against 2.0, and h5py isn't (and probably not VTK either)
-    pip install $STD_ARGS --only-binary "numpy" --default-timeout=60 numpy
-    pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 --extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" scipy scikit-learn matplotlib pillow pandas statsmodels
-    echo "dipy"
-    pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 --extra-index-url "https://pypi.anaconda.org/scipy-wheels-nightly/simple" dipy
-    echo "H5py"
-    pip install $STD_ARGS --only-binary ":all:" -f "https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com" h5py
-    echo "OpenMEEG"
-    pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://test.pypi.org/simple" openmeeg
-    # No Numba because it forces an old NumPy version
-    echo "nilearn and openmeeg"
-    pip install $STD_ARGS git+https://github.com/nilearn/nilearn
-    pip install $STD_ARGS openmeeg
-    echo "VTK"
-    pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://wheels.vtk.org" vtk
-    python -c "import vtk"
-    echo "PyVista"
-    pip install $STD_ARGS git+https://github.com/pyvista/pyvista
-    echo "pyvistaqt"
-    pip install $STD_ARGS git+https://github.com/pyvista/pyvistaqt
-    echo "imageio-ffmpeg, xlrd, mffpy, python-picard"
-    pip install $STD_ARGS imageio-ffmpeg xlrd mffpy python-picard patsy traitlets pybv eeglabio
-    echo "mne-qt-browser"
-    pip install $STD_ARGS git+https://github.com/mne-tools/mne-qt-browser
-    echo "nibabel with workaround"
-    pip install $STD_ARGS git+https://github.com/nipy/nibabel.git
-    echo "joblib"
-    pip install $STD_ARGS git+https://github.com/joblib/joblib@master
-    echo "EDFlib-Python"
-    pip install $STD_ARGS git+https://gitlab.com/Teuniz/EDFlib-Python@master
-    # Until Pandas is fixed, make sure we didn't install it
-    ! python -c "import pandas"
+    ${SCRIPT_DIR}/install_pre_requirements.sh
+    INSTALL_KIND="test_extra"
 fi
 echo ""
 echo "Installing test dependencies using pip"
-python -m pip install $STD_ARGS -e .[$INSTALL_KIND]
+python -m pip install $STD_ARGS $INSTALL_ARGS .[$INSTALL_KIND]
diff --git a/tools/github_actions_install.sh b/tools/github_actions_install.sh
deleted file mode 100755
index f52c193d773..00000000000
--- a/tools/github_actions_install.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-#!/bin/bash
-
-set -eo pipefail
-
-pip install -ve .
diff --git a/tools/github_actions_test.sh b/tools/github_actions_test.sh
index f218197bda6..78cc063d016 100755
--- a/tools/github_actions_test.sh
+++ b/tools/github_actions_test.sh
@@ -12,6 +12,17 @@ if [ "${MNE_CI_KIND}" == "notebook" ]; then
 else
     USE_DIRS="mne/"
 fi
+JUNIT_PATH="junit-results.xml"
+if [[ ! -z "$CONDA_ENV" ]] && [[ "${RUNNER_OS}" != "Windows" ]]; then
+    JUNIT_PATH="$(pwd)/${JUNIT_PATH}"
+    # Use the installed version after adding all (excluded) test files
+    cd ..
+    INSTALL_PATH=$(python -c "import mne, pathlib; print(str(pathlib.Path(mne.__file__).parents[1]))")
+    echo "Copying tests from $(pwd)/mne-python/mne/ to ${INSTALL_PATH}/mne/"
+    rsync -a --partial --progress --prune-empty-dirs --exclude="*.pyc" --include="**/" --include="**/tests/*" --include="**/tests/data/**" --exclude="**" ./mne-python/mne/ ${INSTALL_PATH}/mne/
+    cd $INSTALL_PATH
+    echo "Executing from $(pwd)"
+fi
 set -x
-pytest -m "${CONDITION}" --tb=short --cov=mne --cov-report xml -vv ${USE_DIRS}
+pytest -m "${CONDITION}" --tb=short --cov=mne --cov-report xml --color=yes --junit-xml=$JUNIT_PATH -vv ${USE_DIRS}
 set +x
diff --git a/tools/install_pre_requirements.sh b/tools/install_pre_requirements.sh
new file mode 100755
index 00000000000..47c7087ac8d
--- /dev/null
+++ b/tools/install_pre_requirements.sh
@@ -0,0 +1,84 @@
+#!/bin/bash
+
+set -eo pipefail
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+PLATFORM=$(python -c 'import platform; print(platform.system())')
+
+echo "Installing pip-pre dependencies on ${PLATFORM}"
+STD_ARGS="--progress-bar off --upgrade --pre"
+
+# Dependencies of scientific-python-nightly-wheels are installed here so that
+# we can use strict --index-url (instead of --extra-index-url) below
+python -m pip install $STD_ARGS pip setuptools packaging \
+    threadpoolctl cycler fonttools kiwisolver pyparsing pillow python-dateutil \
+    patsy pytz tzdata nibabel tqdm trx-python joblib numexpr
+echo "PyQt6"
+# Now broken in latest release and in the pre release:
+# pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 --extra-index-url https://www.riverbankcomputing.com/pypi/simple -r $SCRIPT_DIR/pyqt6_requirements.txt
+python -m pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 -r $SCRIPT_DIR/pyqt6_requirements.txt
+echo "NumPy/SciPy/pandas etc."
+python -m pip uninstall -yq numpy
+# No pyarrow yet https://github.com/apache/arrow/issues/40216
+# No h5py (and thus dipy) yet until they improve/refactor their wheel building infrastructure for Windows
+OTHERS=""
+if [[ "${PLATFORM}" == "Linux" ]]; then
+    OTHERS="h5py dipy"
+fi
+python -m pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 \
+    --index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \
+    "numpy>=2.1.0.dev0" "scipy>=1.14.0.dev0" "scikit-learn>=1.5.dev0" \
+    matplotlib pandas statsmodels \
+    $OTHERS
+
+# No Numba because it forces an old NumPy version
+
+echo "OpenMEEG"
+python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://test.pypi.org/simple" "openmeeg>=2.6.0.dev4"
+
+echo "nilearn"
+python -m pip install $STD_ARGS git+https://github.com/nilearn/nilearn
+
+echo "VTK"
+python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://wheels.vtk.org" vtk
+python -c "import vtk"
+
+echo "PyVista"
+python -m pip install $STD_ARGS git+https://github.com/pyvista/pyvista
+
+echo "picard"
+python -m pip install $STD_ARGS git+https://github.com/pierreablin/picard
+
+echo "pyvistaqt"
+pip install $STD_ARGS git+https://github.com/pyvista/pyvistaqt
+
+echo "imageio-ffmpeg, xlrd, mffpy"
+pip install $STD_ARGS imageio-ffmpeg xlrd mffpy traitlets pybv eeglabio
+
+echo "mne-qt-browser"
+pip install $STD_ARGS git+https://github.com/mne-tools/mne-qt-browser
+
+echo "nibabel"
+pip install $STD_ARGS git+https://github.com/nipy/nibabel
+
+echo "joblib"
+pip install $STD_ARGS git+https://github.com/joblib/joblib
+
+echo "edfio"
+pip install $STD_ARGS git+https://github.com/the-siesta-group/edfio
+
+if [[ "${PLATFORM}" == "Linux" ]]; then
+    echo "h5io"
+    pip install $STD_ARGS git+https://github.com/h5io/h5io
+
+    echo "pysnirf2"
+    pip install $STD_ARGS git+https://github.com/BUNPC/pysnirf2
+fi
+
+# Make sure we're on a NumPy 2.0 variant
+echo "Checking NumPy version"
+python -c "import numpy as np; assert np.__version__[0] == '2', np.__version__"
+
+# And that Qt works
+echo "Checking Qt"
+${SCRIPT_DIR}/check_qt_import.sh PyQt6
diff --git a/tools/pyqt6_requirements.txt b/tools/pyqt6_requirements.txt
new file mode 100644
index 00000000000..26ec8315141
--- /dev/null
+++ b/tools/pyqt6_requirements.txt
@@ -0,0 +1,2 @@
+PyQt6!=6.6.1
+PyQt6-Qt6!=6.6.1,!=6.6.2,!=6.6.3
diff --git a/tools/setup_xvfb.sh b/tools/setup_xvfb.sh
index a5c55d0819b..d22f8e2b7ac 100755
--- a/tools/setup_xvfb.sh
+++ b/tools/setup_xvfb.sh
@@ -11,5 +11,5 @@ done
 # This also includes the libraries necessary for PyQt5/PyQt6
 sudo apt update
-sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 libxcb-cursor0
+sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 libxcb-cursor0 libxml2
 /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset
diff --git a/tutorials/clinical/20_seeg.py b/tutorials/clinical/20_seeg.py
index cce5f4a089a..6166001c075 100644
--- a/tutorials/clinical/20_seeg.py
+++ b/tutorials/clinical/20_seeg.py
@@ -58,8 +58,7 @@
 raw = mne.io.read_raw(misc_path / "seeg" / "sample_seeg_ieeg.fif")
-events, event_id = mne.events_from_annotations(raw)
-epochs = mne.Epochs(raw, events, event_id, detrend=1, baseline=None)
+epochs = mne.Epochs(raw, detrend=1, baseline=None)
 epochs = epochs["Response"][0]  # just process one epoch of data for speed
@@ -213,8 +212,14 @@
 evoked = epochs.average()
 stc = mne.stc_near_sensors(
-    evoked, trans, "fsaverage", subjects_dir=subjects_dir, src=vol_src, verbose="error"
-)  # ignore missing electrode warnings
+    evoked,
+    trans,
+    "fsaverage",
+    subjects_dir=subjects_dir,
+    src=vol_src,
+    surface=None,
+    verbose="error",
+)
 stc = abs(stc)  # just look at magnitude
 clim = dict(kind="value", lims=np.percentile(abs(evoked.data), [10, 50, 75]))
diff --git a/tutorials/clinical/30_ecog.py b/tutorials/clinical/30_ecog.py
index 2ccc2d6cb91..d568d3b1bb4 100644
--- a/tutorials/clinical/30_ecog.py
+++ b/tutorials/clinical/30_ecog.py
@@ -100,15 +100,11 @@
 # at the posterior commissure)
 raw.set_montage(montage)
-# Find the annotated events
-events, event_id = mne.events_from_annotations(raw)
-
 # Make a 25 second epoch that spans before and after the seizure onset
 epoch_length = 25  # seconds
 epochs = mne.Epochs(
     raw,
-    events,
-    event_id=event_id["onset"],
+    event_id="onset",
     tmin=13,
     tmax=13 + epoch_length,
     baseline=None,
diff --git a/tutorials/clinical/60_sleep.py b/tutorials/clinical/60_sleep.py
index 020d00bab7e..b25776d7435 100644
--- a/tutorials/clinical/60_sleep.py
+++ b/tutorials/clinical/60_sleep.py
@@ -75,7 +75,11 @@
 [alice_files, bob_files] = fetch_data(subjects=[ALICE, BOB], recording=[1])
 raw_train = mne.io.read_raw_edf(
-    alice_files[0], stim_channel="Event marker", infer_types=True, preload=True
+    alice_files[0],
+    stim_channel="Event marker",
+    infer_types=True,
+    preload=True,
+    verbose="error",  # ignore issues with stored filter settings
 )
 annot_train = mne.read_annotations(alice_files[1])
@@ -172,7 +176,11 @@
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 raw_test = mne.io.read_raw_edf(
-    bob_files[0], stim_channel="Event marker", infer_types=True, preload=True
+    bob_files[0],
+    stim_channel="Event marker",
+    infer_types=True,
+    preload=True,
+    verbose="error",
 )
 annot_test = mne.read_annotations(bob_files[1])
 annot_test.crop(annot_test[1]["onset"] - 30 * 60, annot_test[-2]["onset"] + 30 * 60)
@@ -219,6 +227,7 @@
     axes=ax,
     show=False,
     average=True,
+    amplitude=False,
     spatial_colors=False,
     picks="data",
     exclude="bads",
@@ -306,7 +315,7 @@ def eeg_power_band(epochs):
 y_test = epochs_test.events[:, 2]
 acc = accuracy_score(y_test, y_pred)
-print("Accuracy score: {}".format(acc))
+print(f"Accuracy score: {acc}")
 ##############################################################################
 # In short, yes. We can predict Bob's sleeping stages based on Alice's data.
diff --git a/tutorials/epochs/10_epochs_overview.py b/tutorials/epochs/10_epochs_overview.py
index 54a99f9f149..7726bc754a2 100644
--- a/tutorials/epochs/10_epochs_overview.py
+++ b/tutorials/epochs/10_epochs_overview.py
@@ -314,9 +314,7 @@
 shorter_epochs = epochs.copy().crop(tmin=-0.1, tmax=0.1, include_tmax=True)
 for name, obj in dict(Original=epochs, Cropped=shorter_epochs).items():
-    print(
-        "{} epochs has {} time samples".format(name, obj.get_data(copy=False).shape[-1])
-    )
+    print(f"{name} epochs has {obj.get_data(copy=False).shape[-1]} time samples")
 # %%
 # Cropping removed part of the baseline. When printing the
@@ -370,7 +368,7 @@
 channel_4_6_8 = epochs.get_data(picks=slice(4, 9, 2))
 for name, arr in dict(EOG=eog_data, MEG=meg_data, Slice=channel_4_6_8).items():
-    print("{} contains {} channels".format(name, arr.shape[1]))
+    print(f"{name} contains {arr.shape[1]} channels")
 # %%
 # Note that if your analysis requires repeatedly extracting single epochs from
diff --git a/tutorials/epochs/20_visualize_epochs.py b/tutorials/epochs/20_visualize_epochs.py
index 5fc1a454700..e311b324ee8 100644
--- a/tutorials/epochs/20_visualize_epochs.py
+++ b/tutorials/epochs/20_visualize_epochs.py
@@ -144,7 +144,7 @@
 # :class:`~mne.time_frequency.EpochsSpectrum`'s
 # :meth:`~mne.time_frequency.EpochsSpectrum.plot` method.
-epochs["auditory"].compute_psd().plot(picks="eeg", exclude="bads")
+epochs["auditory"].compute_psd().plot(picks="eeg", exclude="bads", amplitude=False)
 # %%
 # It is also possible to plot spectral power estimates across sensors as a
@@ -245,7 +245,9 @@
 # therefore mask smaller signal fluctuations of interest.
 reject_criteria = dict(
-    mag=3000e-15, grad=3000e-13, eeg=150e-6  # 3000 fT  # 3000 fT/cm
+    mag=3000e-15,  # 3000 fT
+    grad=3000e-13,  # 3000 fT/cm
+    eeg=150e-6,
 )  # 150 µV
 epochs.drop_bad(reject=reject_criteria)
diff --git a/tutorials/epochs/30_epochs_metadata.py b/tutorials/epochs/30_epochs_metadata.py
index 7d5c06871ad..42fa1219b52 100644
--- a/tutorials/epochs/30_epochs_metadata.py
+++ b/tutorials/epochs/30_epochs_metadata.py
@@ -116,14 +116,14 @@
 # MNE-Python will try the traditional method first before falling back on rich
 # metadata querying.
-epochs["solenoid"].compute_psd().plot(picks="data", exclude="bads")
+epochs["solenoid"].compute_psd().plot(picks="data", exclude="bads", amplitude=False)
 # %%
 # One use of the Pandas query string approach is to select specific words for
 # plotting:
 words = ["typhoon", "bungalow", "colossus", "drudgery", "linguist", "solenoid"]
-epochs["WORD in {}".format(words)].plot(n_channels=29, events=True)
+epochs[f"WORD in {words}"].plot(n_channels=29, events=True)
 # %%
 # Notice that in this dataset, each "condition" (A.K.A., each word) occurs only
diff --git a/tutorials/epochs/40_autogenerate_metadata.py b/tutorials/epochs/40_autogenerate_metadata.py
index 8f7f3f5a90e..9e769a5ff5e 100644
--- a/tutorials/epochs/40_autogenerate_metadata.py
+++ b/tutorials/epochs/40_autogenerate_metadata.py
@@ -46,13 +46,11 @@
 # Copyright the MNE-Python contributors.
 # %%
-from pathlib import Path
-
 import matplotlib.pyplot as plt
 import mne
-data_dir = Path(mne.datasets.erp_core.data_path())
+data_dir = mne.datasets.erp_core.data_path()
 infile = data_dir / "ERP-CORE_Subject-001_Task-Flankers_eeg.fif"
 raw = mne.io.read_raw(infile, preload=True)
@@ -88,7 +86,7 @@
 # i.e. starting with stimulus onset and expanding beyond the end of the epoch
 metadata_tmin, metadata_tmax = 0.0, 1.5
-# auto-create metadata
+# auto-create metadata:
 # this also returns a new events array and an event_id dictionary. we'll see
 # later why this is important
 metadata, events, event_id = mne.epochs.make_metadata(
diff --git a/tutorials/evoked/10_evoked_overview.py b/tutorials/evoked/10_evoked_overview.py
index a2513ea1e27..9116bb19ea6 100644
--- a/tutorials/evoked/10_evoked_overview.py
+++ b/tutorials/evoked/10_evoked_overview.py
@@ -336,11 +336,7 @@
     channel, latency, value = trial.get_peak(ch_type="eeg", return_amplitude=True)
     latency = int(round(latency * 1e3))  # convert to milliseconds
     value = int(round(value * 1e6))  # convert to µV
-    print(
-        "Trial {}: peak of {} µV at {} ms in channel {}".format(
-            ix, value, latency, channel
-        )
-    )
+    print(f"Trial {ix}: peak of {value} µV at {latency} ms in channel {channel}")
 # %%
 # .. REFERENCES
diff --git a/tutorials/forward/20_source_alignment.py b/tutorials/forward/20_source_alignment.py
index c1ff697f9ce..f14b556f165 100644
--- a/tutorials/forward/20_source_alignment.py
+++ b/tutorials/forward/20_source_alignment.py
@@ -121,8 +121,8 @@
 )
 dists = mne.dig_mri_distances(raw.info, trans, "sample", subjects_dir=subjects_dir)
 print(
-    "Distance from %s digitized points to head surface: %0.1f mm"
-    % (len(dists), 1000 * np.mean(dists))
+    f"Distance from {len(dists)} digitized points to head surface: "
+    f"{1000 * np.mean(dists):0.1f} mm"
 )
 # %%
diff --git a/tutorials/intro/10_overview.py b/tutorials/intro/10_overview.py
index 20dc532f65a..2c9a68a1baf 100644
--- a/tutorials/intro/10_overview.py
+++ b/tutorials/intro/10_overview.py
@@ -5,12 +5,12 @@
 Overview of MEG/EEG analysis with MNE-Python
 ============================================
-This tutorial covers the basic EEG/MEG pipeline for event-related analysis:
-loading data, epoching, averaging, plotting, and estimating cortical activity
-from sensor data. It introduces the core MNE-Python data structures
-`~mne.io.Raw`, `~mne.Epochs`, `~mne.Evoked`, and `~mne.SourceEstimate`, and
-covers a lot of ground fairly quickly (at the expense of depth). Subsequent
-tutorials address each of these topics in greater detail.
+This tutorial covers the basic EEG/MEG pipeline for event-related analysis: loading
+data, epoching, averaging, plotting, and estimating cortical activity from sensor data.
+It introduces the core MNE-Python data structures `~mne.io.Raw`, `~mne.Epochs`,
+`~mne.Evoked`, and `~mne.SourceEstimate`, and covers a lot of ground fairly quickly (at
+the expense of depth). Subsequent tutorials address each of these topics in greater
+detail.
 We begin by importing the necessary Python modules:
 """
@@ -79,7 +79,7 @@
 # sessions, `~mne.io.Raw.plot` is interactive and allows scrolling, scaling,
 # bad channel marking, annotations, projector toggling, etc.
-raw.compute_psd(fmax=50).plot(picks="data", exclude="bads")
+raw.compute_psd(fmax=50).plot(picks="data", exclude="bads", amplitude=False)
 raw.plot(duration=5, n_channels=30)
 # %%
@@ -309,8 +309,8 @@
 # frequency content.
 frequencies = np.arange(7, 30, 3)
-power = mne.time_frequency.tfr_morlet(
-    aud_epochs, n_cycles=2, return_itc=False, freqs=frequencies, decim=3
+power = aud_epochs.compute_tfr(
+    "morlet", n_cycles=2, return_itc=False, freqs=frequencies, decim=3, average=True
 )
 power.plot(["MEG 1332"])
diff --git a/tutorials/inverse/10_stc_class.py b/tutorials/inverse/10_stc_class.py
index 4330daa41ed..8638b4eaf2a 100644
--- a/tutorials/inverse/10_stc_class.py
+++ b/tutorials/inverse/10_stc_class.py
@@ -118,7 +118,7 @@
 shape = stc.data.shape
-print("The data has %s vertex locations with %s sample points each."
% shape)
+print(f"The data has {shape[0]} vertex locations with {shape[1]} sample points each.")
 
 # %%
 # We see that stc carries 7498 time series of 25 samples length. Those time
@@ -140,7 +140,8 @@
 
 shape_lh = stc.lh_data.shape
 print(
-    "The left hemisphere has %s vertex locations with %s sample points each." % shape_lh
+    f"The left hemisphere has {shape_lh[0]} vertex locations with {shape_lh[1]} sample"
+    " points each."
 )
 
 # %%
diff --git a/tutorials/inverse/20_dipole_fit.py b/tutorials/inverse/20_dipole_fit.py
index f12e5968546..958ff809ede 100644
--- a/tutorials/inverse/20_dipole_fit.py
+++ b/tutorials/inverse/20_dipole_fit.py
@@ -92,8 +92,8 @@
 best_idx = np.argmax(dip.gof)
 best_time = dip.times[best_idx]
 print(
-    "Highest GOF %0.1f%% at t=%0.1f ms with confidence volume %0.1f cm^3"
-    % (dip.gof[best_idx], best_time * 1000, dip.conf["vol"][best_idx] * 100**3)
+    f"Highest GOF {dip.gof[best_idx]:0.1f}% at t={best_time * 1000:0.1f} ms with "
+    f"confidence volume {dip.conf['vol'][best_idx] * 100**3:0.1f} cm^3"
 )
 # remember to create a subplot for the colorbar
 fig, axes = plt.subplots(
@@ -117,8 +117,7 @@
 plot_params["colorbar"] = True
 diff.plot_topomap(time_format="Difference", axes=axes[2:], **plot_params)
 fig.suptitle(
-    "Comparison of measured and predicted fields "
-    "at {:.0f} ms".format(best_time * 1000.0),
+    f"Comparison of measured and predicted fields at {best_time * 1000:.0f} ms",
     fontsize=16,
 )
diff --git a/tutorials/inverse/80_brainstorm_phantom_elekta.py b/tutorials/inverse/80_brainstorm_phantom_elekta.py
index 8184badeda3..ed9a14fc56f 100644
--- a/tutorials/inverse/80_brainstorm_phantom_elekta.py
+++ b/tutorials/inverse/80_brainstorm_phantom_elekta.py
@@ -53,7 +53,9 @@
 # noise (five peaks around 300 Hz). Here, we use only the first 30 seconds
 # to save memory:
 
-raw.compute_psd(tmax=30).plot(average=False, picks="data", exclude="bads")
+raw.compute_psd(tmax=30).plot(
+    average=False, amplitude=False, picks="data", exclude="bads"
+)
 
 # %%
 # Our phantom produces sinusoidal bursts at 20 Hz:
@@ -149,19 +151,19 @@
 )
 
 diffs = 1000 * np.sqrt(np.sum((dip.pos - actual_pos) ** 2, axis=-1))
-print("mean(position error) = %0.1f mm" % (np.mean(diffs),))
+print(f"mean(position error) = {np.mean(diffs):0.1f} mm")
 ax1.bar(event_id, diffs)
 ax1.set_xlabel("Dipole index")
 ax1.set_ylabel("Loc. error (mm)")
 
 angles = np.rad2deg(np.arccos(np.abs(np.sum(dip.ori * actual_ori, axis=1))))
-print("mean(angle error) = %0.1f°" % (np.mean(angles),))
+print(f"mean(angle error) = {np.mean(angles):0.1f}°")
 ax2.bar(event_id, angles)
 ax2.set_xlabel("Dipole index")
 ax2.set_ylabel("Angle error (°)")
 
 amps = actual_amp - dip.amplitude / 1e-9
-print("mean(abs amplitude error) = %0.1f nAm" % (np.mean(np.abs(amps)),))
+print(f"mean(abs amplitude error) = {np.mean(np.abs(amps)):0.1f} nAm")
 ax3.bar(event_id, amps)
 ax3.set_xlabel("Dipole index")
 ax3.set_ylabel("Amplitude error (nAm)")
diff --git a/tutorials/inverse/95_phantom_KIT.py b/tutorials/inverse/95_phantom_KIT.py
index 444ae4635fd..75e0025a9c2 100644
--- a/tutorials/inverse/95_phantom_KIT.py
+++ b/tutorials/inverse/95_phantom_KIT.py
@@ -15,9 +15,8 @@
 # Copyright the MNE-Python contributors.
 
# %% -import matplotlib.pyplot as plt +import mne_bids import numpy as np -from scipy.signal import find_peaks import mne @@ -25,14 +24,33 @@ actual_pos, actual_ori = mne.dipole.get_phantom_dipoles("oyama") actual_pos, actual_ori = actual_pos[:49], actual_ori[:49] # only 49 of 50 dipoles -raw = mne.io.read_raw_kit(data_path / "002_phantom_11Hz_100uA.con") -# cut from ~800 to ~300s for speed, and also at convenient dip stim boundaries -# chosen by examining MISC 017 by eye. -raw.crop(11.5, 302.9).load_data() -raw.filter(None, 40) # 11 Hz stimulation, no need to keep higher freqs +bids_path = mne_bids.BIDSPath( + root=data_path, + subject="01", + task="phantom", + run="01", + datatype="meg", +) +# ignore warning about misc units +raw = mne_bids.read_raw_bids(bids_path).load_data() + +# Let's apply a little bit of preprocessing (temporal filtering and reference +# regression) +picks_artifact = ["MISC 001", "MISC 002", "MISC 003"] +picks = np.r_[ + mne.pick_types(raw.info, meg=True), + mne.pick_channels(raw.info["ch_names"], picks_artifact), +] +raw.filter(None, 40, picks=picks) +mne.preprocessing.regress_artifact( + raw, picks="meg", picks_artifact=picks_artifact, copy=False, proj=False +) plot_scalings = dict(mag=5e-12) # large-amplitude sinusoids raw_plot_kwargs = dict(duration=15, n_channels=50, scalings=plot_scalings) -raw.plot(**raw_plot_kwargs) +events, event_id = mne.events_from_annotations(raw) +raw.plot(events=events, **raw_plot_kwargs) +n_dip = len(event_id) +assert n_dip == 49 # sanity check # %% # We can also look at the power spectral density to see the phantom oscillations at @@ -40,87 +58,16 @@ # boxcar windowing of the 11 Hz sinusoid. spectrum = raw.copy().crop(0, 60).compute_psd(n_fft=10000) -fig = spectrum.plot() +fig = spectrum.plot(amplitude=False) fig.axes[0].set_xlim(0, 50) dip_freq = 11.0 fig.axes[0].axvline(dip_freq, color="r", ls="--", lw=2, zorder=4) # %% -# To find the events, we can look at the MISC channel that recorded the activations. -# Here we use a very simple thresholding approach to find the events. -# The MISC 017 channel holds the dipole activations, which are 2-cycle 11 Hz sinusoidal -# bursts with the initial sinusoidal deflection downward, so we do a little bit of -# signal manipulation to help :func:`~scipy.signal.find_peaks`. 
- -# Figure out events -dip_act, dip_t = raw["MISC 017"] -dip_act = dip_act[0] # 2D to 1D array -dip_act -= dip_act.mean() # remove DC offset -dip_act *= -1 # invert so first deflection is positive -thresh = np.percentile(dip_act, 90) -min_dist = raw.info["sfreq"] / dip_freq * 0.9 # 90% of period, to be safe -peaks = find_peaks(dip_act, height=thresh, distance=min_dist)[0] -assert len(peaks) % 2 == 0 # 2-cycle modulations -peaks = peaks[::2] # take only first peaks of each 2-cycle burst - -fig, ax = plt.subplots(layout="constrained", figsize=(12, 4)) -stop = int(15 * raw.info["sfreq"]) # 15 sec -ax.plot(dip_t[:stop], dip_act[:stop], color="k", lw=1) -ax.axhline(thresh, color="r", ls="--", lw=1) -peak_idx = peaks[peaks < stop] -ax.plot(dip_t[peak_idx], dip_act[peak_idx], "ro", zorder=5, ms=5) -ax.set(xlabel="Time (s)", ylabel="Dipole activation (AU)\n(MISC 017 adjusted)") -ax.set(xlim=dip_t[[0, stop - 1]]) - -# We know that there are 32 dipoles, so mark the first ones as well -n_dip = 49 -assert len(peaks) % n_dip == 0 # we found them all (hopefully) -ax.plot(dip_t[peak_idx[::n_dip]], dip_act[peak_idx[::n_dip]], "bo", zorder=4, ms=10) - -# Knowing we've caught the top of the first cycle of a 11 Hz sinusoid, plot onsets -# with red X's. -onsets = peaks - np.round(raw.info["sfreq"] / dip_freq / 4.0).astype( - int -) # shift to start -onset_idx = onsets[onsets < stop] -ax.plot(dip_t[onset_idx], dip_act[onset_idx], "rx", zorder=5, ms=5) - -# %% -# Given the onsets are now stored in ``peaks``, we can create our events array and plot -# on our raw data. +# Now we can figure out our epoching parameters and epoch the data and plot it. -n_rep = len(peaks) // n_dip -events = np.zeros((len(peaks), 3), int) -events[:, 0] = onsets + raw.first_samp -events[:, 2] = np.tile(np.arange(1, n_dip + 1), n_rep) -raw.plot(events=events, **raw_plot_kwargs) - -# %% -# Now we can figure out our epoching parameters and epoch the data, sanity checking -# some values along the way knowing how the stimulation was done. - -# Sanity check and determine epoching params -deltas = np.diff(events[:, 0], axis=0) -group_deltas = deltas[n_dip - 1 :: n_dip] / raw.info["sfreq"] # gap between 49 and 1 -assert (group_deltas > 0.8).all() -assert (group_deltas < 0.9).all() -others = np.delete(deltas, np.arange(n_dip - 1, len(deltas), n_dip)) # remove 49->1 -others = others / raw.info["sfreq"] -assert (others > 0.25).all() -assert (others < 0.3).all() -tmax = 1 / dip_freq * 2.0 # 2 cycles -tmin = tmax - others.min() -assert tmin < 0 -epochs = mne.Epochs( - raw, - events, - tmin=tmin, - tmax=tmax, - baseline=(None, 0), - decim=10, - picks="data", - preload=True, -) +tmin, tmax = -0.08, 0.18 +epochs = mne.Epochs(raw, tmin=tmin, tmax=tmax, decim=10, picks="data", preload=True) del raw epochs.plot(scalings=plot_scalings) @@ -131,7 +78,7 @@ t_peak = 1.0 / dip_freq / 4.0 data = np.zeros((len(epochs.ch_names), n_dip)) for di in range(n_dip): - data[:, [di]] = epochs[str(di + 1)].average().crop(t_peak, t_peak).data + data[:, [di]] = epochs[f"dip{di + 1:02d}"].average().crop(t_peak, t_peak).data evoked = mne.EvokedArray(data, epochs.info, tmin=0, comment="KIT phantom activations") evoked.plot_joint() @@ -141,22 +88,12 @@ trans = mne.transforms.Transform("head", "mri", np.eye(4)) sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), head_radius=0.08) cov = mne.compute_covariance(epochs, tmax=0, method="empirical") -# We need to correct the ``dev_head_t`` because it's incorrect for these data! 
-# relative to the helmet: hleft, forward, up -translation = mne.transforms.translation(x=0.01, y=-0.015, z=-0.088) -# pitch down (rot about x/R), roll left (rot about y/A), yaw left (rot about z/S) -rotation = mne.transforms.rotation( - x=np.deg2rad(5), - y=np.deg2rad(-1), - z=np.deg2rad(-3), -) -evoked.info["dev_head_t"]["trans"][:] = translation @ rotation dip, residual = mne.fit_dipole(evoked, cov, sphere, n_jobs=None) # %% # Finally let's look at the results. -# sphinx_gallery_thumbnail_number = 7 +# sphinx_gallery_thumbnail_number = 5 print(f"Average amplitude: {np.mean(dip.amplitude) * 1e9:0.1f} nAm") print(f"Average GOF: {np.mean(dip.gof):0.1f}%") diff --git a/tutorials/io/60_ctf_bst_auditory.py b/tutorials/io/60_ctf_bst_auditory.py index dd8d9abadf5..a9d86594669 100644 --- a/tutorials/io/60_ctf_bst_auditory.py +++ b/tutorials/io/60_ctf_bst_auditory.py @@ -105,7 +105,7 @@ for idx in [1, 2]: csv_fname = data_path / "MEG" / "bst_auditory" / f"events_bad_0{idx}.csv" df = pd.read_csv(csv_fname, header=None, names=["onset", "duration", "id", "label"]) - print("Events from run {0}:".format(idx)) + print(f"Events from run {idx}:") print(df) df["onset"] += offset * (idx - 1) @@ -165,10 +165,14 @@ # saving mode we do the filtering at evoked stage, which is not something you # usually would do. if not use_precomputed: - raw.compute_psd(tmax=np.inf, picks="meg").plot(picks="data", exclude="bads") + raw.compute_psd(tmax=np.inf, picks="meg").plot( + picks="data", exclude="bads", amplitude=False + ) notches = np.arange(60, 181, 60) raw.notch_filter(notches, phase="zero-double", fir_design="firwin2") - raw.compute_psd(tmax=np.inf, picks="meg").plot(picks="data", exclude="bads") + raw.compute_psd(tmax=np.inf, picks="meg").plot( + picks="data", exclude="bads", amplitude=False + ) # %% # We also lowpass filter the data at 100 Hz to remove the hf components. 
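
# %%
# As a minimal sketch of the cleanup described above (assuming a generic
# preloaded ``raw`` with 60 Hz mains interference; the exact object used in
# this tutorial depends on ``use_precomputed``), the notch-plus-lowpass
# filtering can be written as:

notches = np.arange(60, 181, 60)  # 60, 120, 180 Hz harmonics
raw.notch_filter(notches, phase="zero-double", fir_design="firwin2")
raw.filter(l_freq=None, h_freq=100.0)  # lowpass at 100 Hz
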
@@ -204,9 +208,7 @@
 onsets = onsets[diffs > min_diff]
 assert len(onsets) == len(events)
 diffs = 1000.0 * (events[:, 0] - onsets) / raw.info["sfreq"]
-print(
-    "Trigger delay removed (μ ± σ): %0.1f ± %0.1f ms" % (np.mean(diffs), np.std(diffs))
-)
+print(f"Trigger delay removed (μ ± σ): {np.mean(diffs):0.1f} ± {np.std(diffs):0.1f} ms")
 events[:, 0] = onsets
 del sound_data, diffs
 
diff --git a/tutorials/machine-learning/30_strf.py b/tutorials/machine-learning/30_strf.py
index a838ae0018c..4d8acad03c2 100644
--- a/tutorials/machine-learning/30_strf.py
+++ b/tutorials/machine-learning/30_strf.py
@@ -170,9 +170,9 @@
 # Create training and testing data
 train, test = np.arange(n_epochs - 1), n_epochs - 1
 X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]
-X_train, X_test, y_train, y_test = [
+X_train, X_test, y_train, y_test = (
     np.rollaxis(ii, -1, 0) for ii in (X_train, X_test, y_train, y_test)
-]
+)
 # Model the simulated data as a function of the spectrogram input
 alphas = np.logspace(-3, 3, 7)
 scores = np.zeros_like(alphas)
diff --git a/tutorials/machine-learning/50_decoding.py b/tutorials/machine-learning/50_decoding.py
index 06d34bd49c8..10fa044281b 100644
--- a/tutorials/machine-learning/50_decoding.py
+++ b/tutorials/machine-learning/50_decoding.py
@@ -145,7 +145,7 @@
 
 # Mean scores across cross-validation splits
 score = np.mean(scores, axis=0)
-print("Spatio-temporal: %0.1f%%" % (100 * score,))
+print(f"Spatio-temporal: {100 * score:0.1f}%")
 
 # %%
 # PSDEstimator
@@ -224,7 +224,7 @@
 csp = CSP(n_components=3, norm_trace=False)
 clf_csp = make_pipeline(csp, LinearModel(LogisticRegression(solver="liblinear")))
 scores = cross_val_multiscore(clf_csp, X, y, cv=5, n_jobs=None)
-print("CSP: %0.1f%%" % (100 * scores.mean(),))
+print(f"CSP: {100 * scores.mean():0.1f}%")
 
 # %%
 # Source power comodulation (SPoC)
diff --git a/tutorials/preprocessing/10_preprocessing_overview.py b/tutorials/preprocessing/10_preprocessing_overview.py
index 483ac653767..d70fa4b4811 100644
--- a/tutorials/preprocessing/10_preprocessing_overview.py
+++ b/tutorials/preprocessing/10_preprocessing_overview.py
@@ -141,7 +141,7 @@
 # use :meth:`~mne.io.Raw.compute_psd` to illustrate.
 
 fig = raw.compute_psd(tmax=np.inf, fmax=250).plot(
-    average=True, picks="data", exclude="bads"
+    average=True, amplitude=False, picks="data", exclude="bads"
 )
 # add some arrows at 60 Hz and its harmonics:
 for ax in fig.axes[1:]:
diff --git a/tutorials/preprocessing/15_handling_bad_channels.py b/tutorials/preprocessing/15_handling_bad_channels.py
index daac97976a5..7ddc36af026 100644
--- a/tutorials/preprocessing/15_handling_bad_channels.py
+++ b/tutorials/preprocessing/15_handling_bad_channels.py
@@ -238,8 +238,9 @@
 fig.suptitle(title, size="xx-large", weight="bold")
 
 # %%
-# Note that we used the ``exclude=[]`` trick in the call to
-# :meth:`~mne.io.Raw.pick_types` to make sure the bad channels were not
+# Note that the :meth:`~mne.io.Raw.pick` method's default
+# arguments include ``exclude=()``, which ensures that bad
+# channels are not
 # automatically dropped from the selection. Here is the corresponding example
 # with the interpolated gradiometer channel; since there are more channels
 # we'll use a more transparent gray color this time:
diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py
index d478255b048..a04005f3532 100644
--- a/tutorials/preprocessing/20_rejecting_bad_data.py
+++ b/tutorials/preprocessing/20_rejecting_bad_data.py
@@ -23,6 +23,8 @@
 
 import os
 
+import numpy as np
+
 import mne
 
 sample_data_folder = mne.datasets.sample.data_path()
@@ -205,8 +207,8 @@
 # %%
 # .. _`tut-reject-epochs-section`:
 #
-# Rejecting Epochs based on channel amplitude
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Rejecting Epochs based on peak-to-peak channel amplitude
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # Besides "bad" annotations, the :class:`mne.Epochs` class constructor has
 # another means of rejecting epochs, based on signal amplitude thresholds for
@@ -328,6 +330,108 @@
 epochs.drop_bad(reject=stronger_reject_criteria)
 print(epochs.drop_log)
 
+# %%
+# .. _`tut-reject-epochs-func-section`:
+#
+# Rejecting Epochs using callables (functions)
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Sometimes it is useful to reject epochs based on criteria other than
+# peak-to-peak amplitudes. For example, we might want to reject epochs
+# based on the maximum or minimum amplitude of a channel.
+# In this case, the `mne.Epochs.drop_bad` method also accepts
+# callables (functions) in the ``reject`` and ``flat`` parameters. This
+# allows us to define functions to reject epochs based on our desired criteria.
+#
+# Let's begin by generating epoch data with large artifacts in one EEG channel
+# in order to demonstrate the versatility of this approach.

+raw.crop(0, 5)
+raw.del_proj()
+chans = raw.info["ch_names"][-5:-1]
+raw.pick(chans)
+data = raw.get_data()
+
+new_data = data
+new_data[0, 180:200] *= 1e3
+new_data[0, 460:580] += 1e-3
+edit_raw = mne.io.RawArray(new_data, raw.info)
+
+# Create fixed-length epochs of 1 second
+events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0)
+epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None)
+epochs.plot(scalings=dict(eeg=50e-5))
+
+# %%
+# As you can see, we have two large artifacts in the first channel: one large
+# spike in amplitude and one sustained increase in amplitude.

+# Let's try to reject the epoch containing the spike in amplitude based on the
+# maximum amplitude of the first channel. Please note that the callable in
+# ``reject`` must return a ``(good, reason)`` tuple, where ``good`` must be a
+# bool and ``reason`` must be a str, list, or tuple where each entry is a str.
+
+epochs = mne.Epochs(
+    edit_raw,
+    events,
+    tmin=0,
+    tmax=1,
+    baseline=None,
+    preload=True,
+)
+
+epochs.drop_bad(
+    reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp"))
+)
+epochs.plot(scalings=dict(eeg=50e-5))
+
+# %%
+# Here, the epoch containing the spike in amplitude was rejected for having a
+# maximum amplitude greater than 1e-2 volts. Notice the use of the ``any()``
+# function to check if any of the channels exceeded the threshold. We could
+# have also used the ``all()`` function to check if all channels exceeded the
+# threshold.

+# Next, let's try to reject the epoch containing the increase in amplitude
+# using the median. 
+
+epochs = mne.Epochs(
+    edit_raw,
+    events,
+    tmin=0,
+    tmax=1,
+    baseline=None,
+    preload=True,
+)
+
+epochs.drop_bad(
+    reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp"))
+)
+epochs.plot(scalings=dict(eeg=50e-5))
+
+# %%
+# Finally, let's try to reject both epochs using a combination of the maximum
+# and median. We'll define a custom function and use boolean operators to
+# combine the two criteria.
+
+
+def reject_criteria(x):
+    max_condition = np.max(x, axis=1) > 1e-2
+    median_condition = np.median(x, axis=1) > 1e-4
+    return ((max_condition.any() or median_condition.any()), ["max amp", "median amp"])
+
+
+epochs = mne.Epochs(
+    edit_raw,
+    events,
+    tmin=0,
+    tmax=1,
+    baseline=None,
+    preload=True,
+)
+
+epochs.drop_bad(reject=dict(eeg=reject_criteria))
+epochs.plot(events=True)
+
 # %%
 # Note that a complementary Python module, the `autoreject package`_, uses
 # machine learning to find optimal rejection criteria, and is designed to
diff --git a/tutorials/preprocessing/25_background_filtering.py b/tutorials/preprocessing/25_background_filtering.py
index 948ab43d76f..c0f56098bad 100644
--- a/tutorials/preprocessing/25_background_filtering.py
+++ b/tutorials/preprocessing/25_background_filtering.py
@@ -148,6 +148,7 @@
 from scipy import signal
 
 import mne
+from mne.fixes import minimum_phase
 from mne.time_frequency.tfr import morlet
 from mne.viz import plot_filter, plot_ideal_filter
 
@@ -168,7 +169,7 @@
 gain = [1, 1, 0, 0]
 
 third_height = np.array(plt.rcParams["figure.figsize"]) * [1, 1.0 / 3.0]
-ax = plt.subplots(1, figsize=third_height)[1]
+ax = plt.subplots(1, figsize=third_height, layout="constrained")[1]
 plot_ideal_filter(freq, gain, ax, title="Ideal %s Hz lowpass" % f_p, flim=flim)
 
 # %%
@@ -249,8 +250,8 @@
 freq = [0.0, f_p, f_s, nyq]
 gain = [1.0, 1.0, 0.0, 0.0]
-ax = plt.subplots(1, figsize=third_height)[1]
-title = "%s Hz lowpass with a %s Hz transition" % (f_p, trans_bandwidth)
+ax = plt.subplots(1, figsize=third_height, layout="constrained")[1]
+title = f"{f_p} Hz lowpass with a {trans_bandwidth} Hz transition"
 plot_ideal_filter(freq, gain, ax, title=title, flim=flim)
 
 # %%
@@ -316,15 +317,15 @@
 # is constant) but small in the pass-band. Unlike zero-phase filters, which
 # require time-shifting backward the output of a linear-phase filtering stage
 # (and thus becoming non-causal), minimum-phase filters do not require any
-# compensation to achieve small delays in the pass-band. Note that as an
-# artifact of the minimum phase filter construction step, the filter does
-# not end up being as steep as the linear/zero-phase version.
+# compensation to achieve small delays in the pass-band.
 #
 # We can construct a minimum-phase filter from our existing linear-phase
-# filter with the :func:`scipy.signal.minimum_phase` function, and note
-# that the falloff is not as steep:
+# filter, and note that the falloff is not as steep. Here we do this with the
+# function ``mne.fixes.minimum_phase()`` to avoid a SciPy bug; once SciPy 1.14.0
+# is released you could directly use
+# :func:`scipy.signal.minimum_phase(..., half=False) <scipy.signal.minimum_phase>`. 
-h_min = signal.minimum_phase(h) +h_min = minimum_phase(h, half=False) plot_filter(h_min, sfreq, freq, gain, "Minimum-phase", **kwargs) # %% @@ -683,7 +684,6 @@ def plot_signal(x, offset): for text in axes[0].get_yticklabels(): text.set(rotation=45, size=8) axes[1].set(xlim=flim, ylim=(-60, 10), xlabel="Frequency (Hz)", ylabel="Magnitude (dB)") -mne.viz.adjust_axes(axes) plt.show() # %% @@ -779,7 +779,7 @@ def plot_signal(x, offset): xlabel = "Time (s)" ylabel = r"Amplitude ($\mu$V)" tticks = [0, 0.5, 1.3, t[-1]] -axes = plt.subplots(2, 2)[1].ravel() +axes = plt.subplots(2, 2, layout="constrained")[1].ravel() for ax, x_f, title in zip( axes, [x_lp_2, x_lp_30, x_hp_2, x_hp_p1], @@ -791,7 +791,6 @@ def plot_signal(x, offset): ylim=ylim, xlim=xlim, xticks=tticks, title=title, xlabel=xlabel, ylabel=ylabel ) -mne.viz.adjust_axes(axes) plt.show() # %% @@ -830,7 +829,7 @@ def plot_signal(x, offset): def baseline_plot(x): - all_axes = plt.subplots(3, 2, layout="constrained")[1] + fig, all_axes = plt.subplots(3, 2, layout="constrained") for ri, (axes, freq) in enumerate(zip(all_axes, [0.1, 0.3, 0.5])): for ci, ax in enumerate(axes): if ci == 0: @@ -846,8 +845,7 @@ def baseline_plot(x): ax.set(title=("No " if ci == 0 else "") + "Baseline Correction") ax.set(xticks=tticks, ylim=ylim, xlim=xlim, xlabel=xlabel) ax.set_ylabel("%0.1f Hz" % freq, rotation=0, horizontalalignment="right") - mne.viz.adjust_axes(axes) - plt.suptitle(title) + fig.suptitle(title) plt.show() diff --git a/tutorials/preprocessing/30_filtering_resampling.py b/tutorials/preprocessing/30_filtering_resampling.py index 530b92741f6..cf9b3335949 100644 --- a/tutorials/preprocessing/30_filtering_resampling.py +++ b/tutorials/preprocessing/30_filtering_resampling.py @@ -78,9 +78,7 @@ duration=60, proj=False, n_channels=len(raw.ch_names), remove_dc=False ) fig.subplots_adjust(top=0.9) - fig.suptitle( - "High-pass filtered at {} Hz".format(cutoff), size="xx-large", weight="bold" - ) + fig.suptitle(f"High-pass filtered at {cutoff} Hz", size="xx-large", weight="bold") # %% # Looks like 0.1 Hz was not quite high enough to fully remove the slow drifts. 
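
# %%
# A quick sketch (reusing the ``raw`` loaded above; the 0.2 Hz value is an
# illustrative choice, not a recommendation from this diff): a slightly higher
# high-pass cutoff should remove the residual drift, at the cost of
# attenuating more genuine slow activity.

raw_highpass = raw.copy().filter(l_freq=0.2, h_freq=None)
raw_highpass.plot(
    duration=60, proj=False, n_channels=len(raw.ch_names), remove_dc=False
)
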
@@ -123,7 +121,7 @@ def add_arrows(axes): - # add some arrows at 60 Hz and its harmonics + """Add some arrows at 60 Hz and its harmonics.""" for ax in axes: freqs = ax.lines[-1].get_xdata() psds = ax.lines[-1].get_ydata() @@ -143,7 +141,9 @@ def add_arrows(axes): ) -fig = raw.compute_psd(fmax=250).plot(average=True, picks="data", exclude="bads") +fig = raw.compute_psd(fmax=250).plot( + average=True, amplitude=False, picks="data", exclude="bads" +) add_arrows(fig.axes[:2]) # %% @@ -159,8 +159,10 @@ def add_arrows(axes): freqs = (60, 120, 180, 240) raw_notch = raw.copy().notch_filter(freqs=freqs, picks=meg_picks) for title, data in zip(["Un", "Notch "], [raw, raw_notch]): - fig = data.compute_psd(fmax=250).plot(average=True, picks="data", exclude="bads") - fig.suptitle("{}filtered".format(title), size="xx-large", weight="bold") + fig = data.compute_psd(fmax=250).plot( + average=True, amplitude=False, picks="data", exclude="bads" + ) + fig.suptitle(f"{title}filtered", size="xx-large", weight="bold") add_arrows(fig.axes[:2]) # %% @@ -178,8 +180,10 @@ def add_arrows(axes): freqs=freqs, picks=meg_picks, method="spectrum_fit", filter_length="10s" ) for title, data in zip(["Un", "spectrum_fit "], [raw, raw_notch_fit]): - fig = data.compute_psd(fmax=250).plot(average=True, picks="data", exclude="bads") - fig.suptitle("{}filtered".format(title), size="xx-large", weight="bold") + fig = data.compute_psd(fmax=250).plot( + average=True, amplitude=False, picks="data", exclude="bads" + ) + fig.suptitle(f"{title}filtered", size="xx-large", weight="bold") add_arrows(fig.axes[:2]) # %% @@ -206,16 +210,59 @@ def add_arrows(axes): # frequency`_ of the desired new sampling rate. This can be clearly seen in the # PSD plot, where a dashed vertical line indicates the filter cutoff; the # original data had an existing lowpass at around 172 Hz (see -# ``raw.info['lowpass']``), and the data resampled from 600 Hz to 200 Hz gets +# ``raw.info['lowpass']``), and the data resampled from ~600 Hz to 200 Hz gets # automatically lowpass filtered at 100 Hz (the `Nyquist frequency`_ for a # target rate of 200 Hz): raw_downsampled = raw.copy().resample(sfreq=200) +# choose n_fft for Welch PSD to make frequency axes similar resolution +n_ffts = [4096, int(round(4096 * 200 / raw.info["sfreq"]))] +fig, axes = plt.subplots(2, 1, sharey=True, layout="constrained", figsize=(10, 6)) +for ax, data, title, n_fft in zip( + axes, [raw, raw_downsampled], ["Original", "Downsampled"], n_ffts +): + fig = data.compute_psd(n_fft=n_fft).plot( + average=True, amplitude=False, picks="data", exclude="bads", axes=ax + ) + ax.set(title=title, xlim=(0, 300)) -for data, title in zip([raw, raw_downsampled], ["Original", "Downsampled"]): - fig = data.compute_psd().plot(average=True, picks="data", exclude="bads") - fig.suptitle(title) - plt.setp(fig.axes, xlim=(0, 300)) +# %% +# By default, MNE-Python resamples using ``method="fft"``, which performs FFT-based +# resampling via :func:`scipy.signal.resample`. While efficient and good for most +# biological signals, it has two main potential drawbacks: +# +# 1. It assumes periodicity of the signal. We try to overcome this with appropriate +# signal padding, but some signal leakage may still occur. +# 2. It treats the entire signal as a single block. This means that in general effects +# are not guaranteed to be localized in time, though in practice they often are. +# +# Alternatively, resampling can be performed using ``method="polyphase"`` instead. 
+# This uses :func:`scipy.signal.resample_poly` under the hood, which in turn utilizes +# a three-step process to resample signals (see :func:`scipy.signal.upfirdn` for +# details). This process guarantees that each resampled output value is only affected by +# input values within a limited range. In other words, output values are guaranteed to +# be a result of a specific set of input values. +# +# In general, using ``method="polyphase"`` can also be faster than ``method="fft"`` in +# cases where the desired sampling rate is an integer factor different from the input +# sampling rate. For example: + +# sphinx_gallery_thumbnail_number = 11 + +n_ffts = [4096, 2048] # factor of 2 smaller n_fft +raw_downsampled_poly = raw.copy().resample( + sfreq=raw.info["sfreq"] / 2.0, + method="polyphase", + verbose=True, +) +fig, axes = plt.subplots(2, 1, sharey=True, layout="constrained", figsize=(10, 6)) +for ax, data, title, n_fft in zip( + axes, [raw, raw_downsampled_poly], ["Original", "Downsampled (polyphase)"], n_ffts +): + data.compute_psd(n_fft=n_fft).plot( + average=True, amplitude=False, picks="data", exclude="bads", axes=ax + ) + ax.set(title=title, xlim=(0, 300)) # %% # Because resampling involves filtering, there are some pitfalls to resampling diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index 6f21840fa30..7c7c872ff70 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -416,11 +416,10 @@ ica.plot_sources(eog_evoked) # %% -# Note that above we used `~mne.preprocessing.ICA.plot_sources` on both -# the original `~mne.io.Raw` instance and also on an -# `~mne.Evoked` instance of the extracted EOG artifacts. This can be -# another way to confirm that `~mne.preprocessing.ICA.find_bads_eog` has -# identified the correct components. +# Note that above we used :meth:`~mne.preprocessing.ICA.plot_sources` on both the +# original :class:`~mne.io.Raw` instance and also on an `~mne.Evoked` instance of the +# extracted EOG artifacts. This can be another way to confirm that +# :meth:`~mne.preprocessing.ICA.find_bads_eog` has identified the correct components. 
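
# %%
# A small sketch (reusing the ``ica`` and ``raw`` objects from above): the
# correlation scores returned by ``find_bads_eog`` can also be visualized
# directly, as yet another check on which components track the EOG channel.

eog_indices, eog_scores = ica.find_bads_eog(raw)
ica.plot_scores(eog_scores)
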
# # # Using a simulated channel to select ICA components @@ -567,7 +566,7 @@ with mne.viz.use_browser_backend("matplotlib"): fig = ica.plot_sources(raw, show_scrollbars=False) fig.subplots_adjust(top=0.9) # make space for title - fig.suptitle("Subject {}".format(index)) + fig.suptitle(f"Subject {index}") # %% # Notice that subjects 2 and 3 each seem to have *two* ICs that reflect ocular diff --git a/tutorials/preprocessing/45_projectors_background.py b/tutorials/preprocessing/45_projectors_background.py index 0b11d168db4..00de570229b 100644 --- a/tutorials/preprocessing/45_projectors_background.py +++ b/tutorials/preprocessing/45_projectors_background.py @@ -372,7 +372,7 @@ def setup_3d_axes(): with mne.viz.use_browser_backend("matplotlib"): fig = mags.plot(butterfly=True, proj=proj) fig.subplots_adjust(top=0.9) - fig.suptitle("proj={}".format(proj), size="xx-large", weight="bold") + fig.suptitle(f"proj={proj}", size="xx-large", weight="bold") # %% # Additional ways of visualizing projectors are covered in the tutorial @@ -443,7 +443,7 @@ def setup_3d_axes(): with mne.viz.use_browser_backend("matplotlib"): fig = data.plot(butterfly=True, proj=True) fig.subplots_adjust(top=0.9) - fig.suptitle("{} ECG projector".format(title), size="xx-large", weight="bold") + fig.suptitle(f"{title} ECG projector", size="xx-large", weight="bold") # %% # When are projectors "applied"? diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index 2f5af536a3d..bc0b9081f64 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -116,7 +116,14 @@ raw.info["bads"] = ["MEG 2443"] spectrum = empty_room_raw.compute_psd() for average in (False, True): - spectrum.plot(average=average, dB=False, xscale="log", picks="data", exclude="bads") + spectrum.plot( + average=average, + dB=False, + amplitude=True, + xscale="log", + picks="data", + exclude="bads", + ) # %% # Creating the empty-room projectors @@ -173,7 +180,7 @@ with mne.viz.use_browser_backend("matplotlib"): fig = raw.plot(proj=True, order=mags, duration=1, n_channels=2) fig.subplots_adjust(top=0.9) # make room for title - fig.suptitle("{} projectors".format(title), size="xx-large", weight="bold") + fig.suptitle(f"{title} projectors", size="xx-large", weight="bold") # %% # The effect is sometimes easier to see on averaged data. 
Here we use an
@@ -340,7 +347,7 @@
     with mne.viz.use_browser_backend("matplotlib"):
         fig = raw.plot(order=artifact_picks, n_channels=len(artifact_picks))
         fig.subplots_adjust(top=0.9)  # make room for title
-        fig.suptitle("{} ECG projectors".format(title), size="xx-large", weight="bold")
+        fig.suptitle(f"{title} ECG projectors", size="xx-large", weight="bold")
 
 # %%
 # Finally, note that above we passed ``reject=None`` to the
@@ -452,7 +459,7 @@
     with mne.viz.use_browser_backend("matplotlib"):
         fig = raw.plot(order=artifact_picks, n_channels=len(artifact_picks))
         fig.subplots_adjust(top=0.9)  # make room for title
-        fig.suptitle("{} EOG projectors".format(title), size="xx-large", weight="bold")
+        fig.suptitle(f"{title} EOG projectors", size="xx-large", weight="bold")
 
 # %%
 # Notice that the small peaks in the first two magnetometer channels (``MEG
diff --git a/tutorials/preprocessing/55_setting_eeg_reference.py b/tutorials/preprocessing/55_setting_eeg_reference.py
index 049e8f31a8b..22e247469ee 100644
--- a/tutorials/preprocessing/55_setting_eeg_reference.py
+++ b/tutorials/preprocessing/55_setting_eeg_reference.py
@@ -27,7 +27,7 @@
 )
 raw = mne.io.read_raw_fif(sample_data_raw_file, verbose=False)
 raw.crop(tmax=60).load_data()
-raw.pick(["EEG 0{:02}".format(n) for n in range(41, 60)])
+raw.pick([f"EEG 0{n:02}" for n in range(41, 60)])
 
 # %%
 # Background
@@ -131,7 +131,7 @@
 # :meth:`~mne.io.Raw.set_eeg_reference` with ``ref_channels='average'``. Just
 # as above, this will not affect any channels marked as "bad", nor will it
 # include bad channels when computing the average. However, it does modify the
-# :class:`~mne.io.Raw` object in-place, so we'll make a copy first so we can
+# :class:`~mne.io.Raw` object in-place, so we'll make a copy first, so we can
 # still go back to the unmodified :class:`~mne.io.Raw` object later:
 
 # sphinx_gallery_thumbnail_number = 4
@@ -176,7 +176,7 @@
     fig = raw.plot(proj=proj, n_channels=len(raw))
     # make room for title
     fig.subplots_adjust(top=0.9)
-    fig.suptitle("{} reference".format(title), size="xx-large", weight="bold")
+    fig.suptitle(f"{title} reference", size="xx-large", weight="bold")
 
 # %%
 # Using an infinite reference (REST)
@@ -199,7 +199,7 @@
     fig = _raw.plot(n_channels=len(raw), scalings=dict(eeg=5e-5))
     # make room for title
     fig.subplots_adjust(top=0.9)
-    fig.suptitle("{} reference".format(title), size="xx-large", weight="bold")
+    fig.suptitle(f"{title} reference", size="xx-large", weight="bold")
 
 # %%
 # Using a bipolar reference
@@ -241,9 +241,13 @@
 # the source modeling is performed. In contrast, applying an average reference
 # by the traditional subtraction method offers no such guarantee.
 #
-# For these reasons, when performing inverse imaging, *MNE-Python will raise
-# a ``ValueError`` if there are EEG channels present and something other than
-# an average reference strategy has been specified*.
+# .. important:: For these reasons, when performing inverse imaging, MNE-Python
+#                will raise a ``ValueError`` if there are EEG channels present
+#                and something other than an average reference projector strategy
+#                has been specified. To ensure correct functioning, consider
+#                calling :meth:`set_eeg_reference(projection=True)
+#                <mne.io.Raw.set_eeg_reference>` to add an average
+#                reference as a projector.
 #
 # .. 
LINKS # diff --git a/tutorials/preprocessing/59_head_positions.py b/tutorials/preprocessing/59_head_positions.py index cd1a454fd7b..37ed574132b 100644 --- a/tutorials/preprocessing/59_head_positions.py +++ b/tutorials/preprocessing/59_head_positions.py @@ -37,7 +37,7 @@ data_path = op.join(mne.datasets.testing.data_path(verbose=True), "SSS") fname_raw = op.join(data_path, "test_move_anon_raw.fif") raw = mne.io.read_raw_fif(fname_raw, allow_maxshield="yes").load_data() -raw.compute_psd().plot(picks="data", exclude="bads") +raw.compute_psd().plot(picks="data", exclude="bads", amplitude=False) # %% # We can use `mne.chpi.get_chpi_info` to retrieve the coil frequencies, diff --git a/tutorials/preprocessing/70_fnirs_processing.py b/tutorials/preprocessing/70_fnirs_processing.py index 8b59c6a31ff..4c211c9a770 100644 --- a/tutorials/preprocessing/70_fnirs_processing.py +++ b/tutorials/preprocessing/70_fnirs_processing.py @@ -157,7 +157,9 @@ raw_haemo_unfiltered = raw_haemo.copy() raw_haemo.filter(0.05, 0.7, h_trans_bandwidth=0.2, l_trans_bandwidth=0.02) for when, _raw in dict(Before=raw_haemo_unfiltered, After=raw_haemo).items(): - fig = _raw.compute_psd().plot(average=True, picks="data", exclude="bads") + fig = _raw.compute_psd().plot( + average=True, amplitude=False, picks="data", exclude="bads" + ) fig.suptitle(f"{when} filtering", weight="bold", size="x-large") # %% @@ -244,7 +246,7 @@ epochs["Tapping"].average().plot_image(axes=axes[:, 1], clim=clims) for column, condition in enumerate(["Control", "Tapping"]): for ax in axes[:, column]: - ax.set_title("{}: {}".format(condition, ax.get_title())) + ax.set_title(f"{condition}: {ax.get_title()}") # %% @@ -344,7 +346,7 @@ for column, condition in enumerate(["Tapping Left", "Tapping Right", "Left-Right"]): for row, chroma in enumerate(["HbO", "HbR"]): - axes[row, column].set_title("{}: {}".format(chroma, condition)) + axes[row, column].set_title(f"{chroma}: {condition}") # %% # Lastly, we can also look at the individual waveforms to see what is diff --git a/tutorials/preprocessing/80_opm_processing.py b/tutorials/preprocessing/80_opm_processing.py index 49a8159d748..8d1642d88b8 100644 --- a/tutorials/preprocessing/80_opm_processing.py +++ b/tutorials/preprocessing/80_opm_processing.py @@ -26,20 +26,20 @@ # %% import matplotlib.pyplot as plt +import nibabel as nib import numpy as np import mne -opm_data_folder = mne.datasets.ucl_opm_auditory.data_path() +subject = "sub-002" +data_path = mne.datasets.ucl_opm_auditory.data_path() opm_file = ( - opm_data_folder - / "sub-002" - / "ses-001" - / "meg" - / "sub-002_ses-001_task-aef_run-001_meg.bin" + data_path / subject / "ses-001" / "meg" / "sub-002_ses-001_task-aef_run-001_meg.bin" ) +subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects" + # For now we are going to assume the device and head coordinate frames are -# identical (even though this is incorrect), so we pass verbose='error' for now +# identical (even though this is incorrect), so we pass verbose='error' raw = mne.io.read_raw_fil(opm_file, verbose="error") raw.crop(120, 210).load_data() # crop for speed @@ -240,7 +240,59 @@ raw, events, tmin=-0.1, tmax=0.4, baseline=(-0.1, 0.0), verbose="error" ) evoked = epochs.average() -evoked.plot() +t_peak = evoked.times[np.argmax(np.std(evoked.copy().pick("meg").data, axis=0))] +fig = evoked.plot() +fig.axes[0].axvline(t_peak, color="red", ls="--", lw=1) + +# %% +# Visualizing coregistration +# -------------------------- +# By design, the sensors in this dataset are already in the scanner 
RAS coordinate
+# frame. We can thus visualize them in the FreeSurfer MRI coordinate frame by computing
+# the transformation between the FreeSurfer MRI coordinate frame and scanner RAS:
+
+mri = nib.load(subjects_dir / "sub-002" / "mri" / "T1.mgz")
+trans = mri.header.get_vox2ras_tkr() @ np.linalg.inv(mri.affine)
+trans[:3, 3] /= 1000.0  # nibabel uses mm, MNE uses m
+trans = mne.transforms.Transform("head", "mri", trans)
+
+bem = subjects_dir / subject / "bem" / f"{subject}-5120-bem-sol.fif"
+src = subjects_dir / subject / "bem" / f"{subject}-oct-6-src.fif"
+mne.viz.plot_alignment(
+    evoked.info,
+    subjects_dir=subjects_dir,
+    subject=subject,
+    trans=trans,
+    surfaces={"head": 0.1, "inner_skull": 0.2, "white": 1.0},
+    meg=["helmet", "sensors"],
+    verbose="error",
+    bem=bem,
+    src=src,
+)
+
+# %%
+# Plotting the inverse
+# --------------------
+# Now we can compute a forward and inverse:
+
+fwd = mne.make_forward_solution(
+    evoked.info,
+    trans=trans,
+    bem=bem,
+    src=src,
+    verbose=True,
+)
+noise_cov = mne.compute_covariance(epochs, tmax=0)
+inv = mne.minimum_norm.make_inverse_operator(evoked.info, fwd, noise_cov, verbose=True)
+stc = mne.minimum_norm.apply_inverse(
+    evoked, inv, 1.0 / 9.0, method="dSPM", verbose=True
+)
+brain = stc.plot(
+    hemi="split",
+    size=(800, 400),
+    initial_time=t_peak,
+    subjects_dir=subjects_dir,
+)
 
 # %%
 # References
diff --git a/tutorials/raw/10_raw_overview.py b/tutorials/raw/10_raw_overview.py
index 31dfbf12325..bf8fe20effd 100644
--- a/tutorials/raw/10_raw_overview.py
+++ b/tutorials/raw/10_raw_overview.py
@@ -142,10 +142,10 @@
 ch_names = raw.ch_names
 n_chan = len(ch_names)  # note: there is no raw.n_channels attribute
 print(
-    "the (cropped) sample data object has {} time samples and {} channels."
-    "".format(n_time_samps, n_chan)
+    f"the (cropped) sample data object has {n_time_samps} time samples and "
+    f"{n_chan} channels."
 )
-print("The last time sample is at {} seconds.".format(time_secs[-1]))
+print(f"The last time sample is at {time_secs[-1]} seconds.")
 print("The first few channel names are {}.".format(", ".join(ch_names[:3])))
 print()  # insert a blank line in the output
@@ -291,9 +291,12 @@
 # inaccurate, you can change the type of any channel with the
 # :meth:`~mne.io.Raw.set_channel_types` method. The method takes a
 # :class:`dictionary <dict>` mapping channel names to types; allowed types are
-# ``ecg, eeg, emg, eog, exci, ias, misc, resp, seeg, dbs, stim, syst, ecog,
-# hbo, hbr``. A common use case for changing channel type is when using frontal
-# EEG electrodes as makeshift EOG channels:
+# ``bio, chpi, csd, dbs, dipole, ecg, ecog, eeg, emg, eog, exci, eyegaze,
+# fnirs_cw_amplitude, fnirs_fd_ac_amplitude, fnirs_fd_phase, fnirs_od, gof,
+# gsr, hbo, hbr, ias, misc, pupil, ref_meg, resp, seeg, stim, syst,
+# temperature`` (see :term:`sensor types` for more information about them). 
+# A common use case for changing channel type is when using frontal EEG +# electrodes as makeshift EOG channels: raw.set_channel_types({"EEG_001": "eog"}) print(raw.copy().pick(picks="eog").ch_names) diff --git a/tutorials/raw/30_annotate_raw.py b/tutorials/raw/30_annotate_raw.py index 8a2a43d4188..99c40506b66 100644 --- a/tutorials/raw/30_annotate_raw.py +++ b/tutorials/raw/30_annotate_raw.py @@ -230,7 +230,7 @@ descr = ann["description"] start = ann["onset"] end = ann["onset"] + ann["duration"] - print("'{}' goes from {} to {}".format(descr, start, end)) + print(f"'{descr}' goes from {start} to {end}") # %% # Note that iterating, indexing and slicing `~mne.Annotations` all diff --git a/tutorials/raw/40_visualize_raw.py b/tutorials/raw/40_visualize_raw.py index 0056d90e413..091f44a1493 100644 --- a/tutorials/raw/40_visualize_raw.py +++ b/tutorials/raw/40_visualize_raw.py @@ -5,13 +5,13 @@ Built-in plotting methods for Raw objects ========================================= -This tutorial shows how to plot continuous data as a time series, how to plot -the spectral density of continuous data, and how to plot the sensor locations -and projectors stored in `~mne.io.Raw` objects. +This tutorial shows how to plot continuous data as a time series, how to plot the +spectral density of continuous data, and how to plot the sensor locations and projectors +stored in `~mne.io.Raw` objects. As usual we'll start by importing the modules we need, loading some -:ref:`example data `, and cropping the `~mne.io.Raw` -object to just 60 seconds before loading it into RAM to save memory: +:ref:`example data `, and cropping the `~mne.io.Raw` object to just 60 +seconds before loading it into RAM to save memory: """ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
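
# %%
# As a quick sketch (using the cropped ``raw`` loaded above; the EEG channel
# type is an illustrative choice), sensor positions promised in the intro can
# be drawn either as a flat topographic map or in 3D before diving into the
# spectral plots below:

raw.plot_sensors(kind="topomap", ch_type="eeg")
raw.plot_sensors(kind="3d", ch_type="eeg")
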
@@ -120,7 +120,7 @@ # object has a :meth:`~mne.time_frequency.Spectrum.plot` method: spectrum = raw.compute_psd() -spectrum.plot(average=True, picks="data", exclude="bads") +spectrum.plot(average=True, picks="data", exclude="bads", amplitude=False) # %% # If the data have been filtered, vertical dashed lines will automatically @@ -134,7 +134,7 @@ # documentation of `~mne.time_frequency.Spectrum.plot` for full details): midline = ["EEG 002", "EEG 012", "EEG 030", "EEG 048", "EEG 058", "EEG 060"] -spectrum.plot(picks=midline, exclude="bads") +spectrum.plot(picks=midline, exclude="bads", amplitude=False) # %% # It is also possible to plot spectral power estimates across sensors as a diff --git a/tutorials/simulation/10_array_objs.py b/tutorials/simulation/10_array_objs.py index a2e94ab1c7a..4367d880207 100644 --- a/tutorials/simulation/10_array_objs.py +++ b/tutorials/simulation/10_array_objs.py @@ -232,4 +232,4 @@ info=info, ) -spectrum.plot(spatial_colors=False) +spectrum.plot(spatial_colors=False, amplitude=False) diff --git a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py index c32af4bcd97..0e7242e96d5 100644 --- a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py +++ b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py @@ -40,7 +40,6 @@ import mne from mne.datasets import sample from mne.stats import permutation_cluster_1samp_test -from mne.time_frequency import tfr_morlet # %% # Set parameters @@ -92,8 +91,8 @@ freqs = np.arange(8, 40, 2) # run the TFR decomposition -tfr_epochs = tfr_morlet( - epochs, +tfr_epochs = epochs.compute_tfr( + "morlet", freqs, n_cycles=4.0, decim=decim, diff --git a/tutorials/stats-sensor-space/50_cluster_between_time_freq.py b/tutorials/stats-sensor-space/50_cluster_between_time_freq.py index 3ced6a82463..0b4078ec883 100644 --- a/tutorials/stats-sensor-space/50_cluster_between_time_freq.py +++ b/tutorials/stats-sensor-space/50_cluster_between_time_freq.py @@ -32,7 +32,6 @@ import mne from mne.datasets import sample from mne.stats import permutation_cluster_test -from mne.time_frequency import tfr_morlet print(__doc__) @@ -104,24 +103,17 @@ decim = 2 freqs = np.arange(7, 30, 3) # define frequencies of interest n_cycles = 1.5 - -tfr_epochs_1 = tfr_morlet( - epochs_condition_1, - freqs, +tfr_kwargs = dict( + method="morlet", + freqs=freqs, n_cycles=n_cycles, decim=decim, return_itc=False, average=False, ) -tfr_epochs_2 = tfr_morlet( - epochs_condition_2, - freqs, - n_cycles=n_cycles, - decim=decim, - return_itc=False, - average=False, -) +tfr_epochs_1 = epochs_condition_1.compute_tfr(**tfr_kwargs) +tfr_epochs_2 = epochs_condition_2.compute_tfr(**tfr_kwargs) tfr_epochs_1.apply_baseline(mode="ratio", baseline=(None, 0)) tfr_epochs_2.apply_baseline(mode="ratio", baseline=(None, 0)) diff --git a/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py b/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py index 202c660575a..19a90decea8 100644 --- a/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py +++ b/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py @@ -36,7 +36,6 @@ import mne from mne.datasets import sample from mne.stats import f_mway_rm, f_threshold_mway_rm, fdr_correction -from mne.time_frequency import tfr_morlet print(__doc__) @@ -105,8 +104,8 @@ # --------------------------------------------- epochs_power = list() for condition in [epochs[k] for k in event_id]: - this_tfr = tfr_morlet( - condition, + this_tfr = condition.compute_tfr( 
+ "morlet", freqs, n_cycles=n_cycles, decim=decim, diff --git a/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py b/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py index c7fdbdb1fc2..2ba8c55bf3d 100644 --- a/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py +++ b/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py @@ -41,7 +41,6 @@ from mne.channels import find_ch_adjacency from mne.datasets import sample from mne.stats import combine_adjacency, spatio_temporal_cluster_test -from mne.time_frequency import tfr_morlet from mne.viz import plot_compare_evokeds # %% @@ -231,7 +230,7 @@ # add new axis for time courses and plot time courses ax_signals = divider.append_axes("right", size="300%", pad=1.2) - title = "Cluster #{0}, {1} sensor".format(i_clu + 1, len(ch_inds)) + title = f"Cluster #{i_clu + 1}, {len(ch_inds)} sensor" if len(ch_inds) > 1: title += "s (mean)" plot_compare_evokeds( @@ -269,9 +268,9 @@ epochs_power = list() for condition in [epochs[k] for k in ("Aud/L", "Vis/L")]: - this_tfr = tfr_morlet( - condition, - freqs, + this_tfr = condition.compute_tfr( + method="morlet", + freqs=freqs, n_cycles=n_cycles, decim=decim, average=False, @@ -385,7 +384,7 @@ # add new axis for spectrogram ax_spec = divider.append_axes("right", size="300%", pad=1.2) - title = "Cluster #{0}, {1} spectrogram".format(i_clu + 1, len(ch_inds)) + title = f"Cluster #{i_clu + 1}, {len(ch_inds)} spectrogram" if len(ch_inds) > 1: title += " (max over channels)" F_obs_plot = F_obs[..., ch_inds].max(axis=-1) diff --git a/tutorials/stats-source-space/60_cluster_rmANOVA_spatiotemporal.py b/tutorials/stats-source-space/60_cluster_rmANOVA_spatiotemporal.py index 0951280e6d6..24c0adc9d35 100644 --- a/tutorials/stats-source-space/60_cluster_rmANOVA_spatiotemporal.py +++ b/tutorials/stats-source-space/60_cluster_rmANOVA_spatiotemporal.py @@ -285,9 +285,7 @@ def stat_fun(*args): inds_t, inds_v = [ (clusters[cluster_ind]) for ii, cluster_ind in enumerate(good_cluster_inds) -][ - 0 -] # first cluster +][0] # first cluster times = np.arange(X[0].shape[1]) * tstep * 1e3 diff --git a/tutorials/time-freq/10_spectrum_class.py b/tutorials/time-freq/10_spectrum_class.py index c5f8f4fd639..9d7eb9fae5d 100644 --- a/tutorials/time-freq/10_spectrum_class.py +++ b/tutorials/time-freq/10_spectrum_class.py @@ -8,9 +8,9 @@ The Spectrum and EpochsSpectrum classes: frequency-domain data ============================================================== -This tutorial shows how to create and visualize frequency-domain -representations of your data, starting from continuous :class:`~mne.io.Raw`, -discontinuous :class:`~mne.Epochs`, or averaged :class:`~mne.Evoked` data. +This tutorial shows how to create and visualize frequency-domain representations of your +data, starting from continuous :class:`~mne.io.Raw`, discontinuous :class:`~mne.Epochs`, +or averaged :class:`~mne.Evoked` data. As usual we'll start by importing the modules we need, and loading our :ref:`sample dataset `: @@ -122,7 +122,7 @@ # (interpolated scalp topography of power, in specific frequency bands). A few # plot options are demonstrated below; see the docstrings for full details. 
-evk_spectrum.plot(picks="data", exclude="bads") +evk_spectrum.plot(picks="data", exclude="bads", amplitude=False) evk_spectrum.plot_topo(color="k", fig_facecolor="w", axis_facecolor="w") # %% diff --git a/tutorials/time-freq/20_sensors_time_frequency.py b/tutorials/time-freq/20_sensors_time_frequency.py index c4981b2b1e0..9175e700041 100644 --- a/tutorials/time-freq/20_sensors_time_frequency.py +++ b/tutorials/time-freq/20_sensors_time_frequency.py @@ -10,7 +10,7 @@ We will use this dataset: :ref:`somato-dataset`. It contains so-called event related synchronizations (ERS) / desynchronizations (ERD) in the beta band. -""" +""" # noqa D400 # Authors: Alexandre Gramfort # Stefan Appelhoff # Richard Höchenberger @@ -24,7 +24,6 @@ import mne from mne.datasets import somato -from mne.time_frequency import tfr_morlet # %% # Set parameters @@ -66,7 +65,9 @@ # %% # Let's first check out all channel types by averaging across epochs. -epochs.compute_psd(fmin=2.0, fmax=40.0).plot(average=True, picks="data", exclude="bads") +epochs.compute_psd(fmin=2.0, fmax=40.0).plot( + average=True, amplitude=False, picks="data", exclude="bads" +) # %% # Now, let's take a look at the spatial distributions of the PSD, averaged @@ -188,14 +189,13 @@ # define frequencies of interest (log-spaced) freqs = np.logspace(*np.log10([6, 35]), num=8) n_cycles = freqs / 2.0 # different number of cycle per frequency -power, itc = tfr_morlet( - epochs, +power, itc = epochs.compute_tfr( + method="morlet", freqs=freqs, n_cycles=n_cycles, - use_fft=True, + average=True, return_itc=True, decim=3, - n_jobs=None, ) # %% @@ -208,7 +208,7 @@ # You can also select a portion in the time-frequency plane to # obtain a topomap for a certain time-frequency region. power.plot_topo(baseline=(-0.5, 0), mode="logratio", title="Average power") -power.plot([82], baseline=(-0.5, 0), mode="logratio", title=power.ch_names[82]) +power.plot(picks=[82], baseline=(-0.5, 0), mode="logratio", title=power.ch_names[82]) fig, axes = plt.subplots(1, 2, figsize=(7, 4), layout="constrained") topomap_kw = dict( diff --git a/tutorials/time-freq/50_ssvep.py b/tutorials/time-freq/50_ssvep.py index 323e8a4fe54..706841fefac 100644 --- a/tutorials/time-freq/50_ssvep.py +++ b/tutorials/time-freq/50_ssvep.py @@ -84,14 +84,12 @@ raw.filter(l_freq=0.1, h_freq=None, fir_design="firwin", verbose=False) # Construct epochs -event_id = {"12hz": 255, "15hz": 155} -events, _ = mne.events_from_annotations(raw, verbose=False) +raw.annotations.rename({"Stimulus/S255": "12hz", "Stimulus/S155": "15hz"}) tmin, tmax = -1.0, 20.0 # in s baseline = None epochs = mne.Epochs( raw, - events=events, - event_id=[event_id["12hz"], event_id["15hz"]], + event_id=["12hz", "15hz"], tmin=tmin, tmax=tmax, baseline=baseline, @@ -356,8 +354,8 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1): # Get indices for the different trial types # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -i_trial_12hz = np.where(epochs.events[:, 2] == event_id["12hz"])[0] -i_trial_15hz = np.where(epochs.events[:, 2] == event_id["15hz"])[0] +i_trial_12hz = np.where(epochs.annotations.description == "12hz")[0] +i_trial_15hz = np.where(epochs.annotations.description == "15hz")[0] # %% # Get indices of EEG channels forming the ROI @@ -424,13 +422,13 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1): mne.viz.plot_topomap(snrs_12hz_chaverage, epochs.info, vlim=(1, None), axes=ax) print("sub 2, 12 Hz trials, SNR at 12 Hz") -print("average SNR (all channels): %f" % 
@@ -356,8 +354,8 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):

 # Get indices for the different trial types
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-i_trial_12hz = np.where(epochs.events[:, 2] == event_id["12hz"])[0]
-i_trial_15hz = np.where(epochs.events[:, 2] == event_id["15hz"])[0]
+i_trial_12hz = np.where(epochs.annotations.description == "12hz")[0]
+i_trial_15hz = np.where(epochs.annotations.description == "15hz")[0]

 # %%
 # Get indices of EEG channels forming the ROI
@@ -424,13 +422,13 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):

 mne.viz.plot_topomap(snrs_12hz_chaverage, epochs.info, vlim=(1, None), axes=ax)
 print("sub 2, 12 Hz trials, SNR at 12 Hz")
-print("average SNR (all channels): %f" % snrs_12hz_chaverage.mean())
-print("average SNR (occipital ROI): %f" % snrs_target.mean())
+print(f"average SNR (all channels): {snrs_12hz_chaverage.mean()}")
+print(f"average SNR (occipital ROI): {snrs_target.mean()}")

 tstat_roi_vs_scalp = ttest_rel(snrs_target.mean(axis=1), snrs_12hz.mean(axis=1))
 print(
-    "12 Hz SNR in occipital ROI is significantly larger than 12 Hz SNR over "
-    "all channels: t = %.3f, p = %f" % tstat_roi_vs_scalp
+    "12 Hz SNR in occipital ROI is significantly larger than 12 Hz SNR over all "
+    f"channels: t = {tstat_roi_vs_scalp[0]:.3f}, p = {tstat_roi_vs_scalp[1]}"
 )

 ##############################################################################
@@ -522,24 +520,24 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
     res["stim_12hz_snrs_12hz"], res["stim_12hz_snrs_15hz"]
 )
 print(
-    "12 Hz Trials: 12 Hz SNR is significantly higher than 15 Hz SNR"
-    ": t = %.3f, p = %f" % tstat_12hz_trial_stim
+    "12 Hz Trials: 12 Hz SNR is significantly higher than 15 Hz SNR: t = "
+    f"{tstat_12hz_trial_stim[0]:.3f}, p = {tstat_12hz_trial_stim[1]}"
 )

 tstat_12hz_trial_1st_harmonic = ttest_rel(
     res["stim_12hz_snrs_24hz"], res["stim_12hz_snrs_30hz"]
 )
 print(
-    "12 Hz Trials: 24 Hz SNR is significantly higher than 30 Hz SNR"
-    ": t = %.3f, p = %f" % tstat_12hz_trial_1st_harmonic
+    "12 Hz Trials: 24 Hz SNR is significantly higher than 30 Hz SNR: t = "
+    f"{tstat_12hz_trial_1st_harmonic[0]:.3f}, p = {tstat_12hz_trial_1st_harmonic[1]}"
 )

 tstat_12hz_trial_2nd_harmonic = ttest_rel(
     res["stim_12hz_snrs_36hz"], res["stim_12hz_snrs_45hz"]
 )
 print(
-    "12 Hz Trials: 36 Hz SNR is significantly higher than 45 Hz SNR"
-    ": t = %.3f, p = %f" % tstat_12hz_trial_2nd_harmonic
+    "12 Hz Trials: 36 Hz SNR is significantly higher than 45 Hz SNR: t = "
+    f"{tstat_12hz_trial_2nd_harmonic[0]:.3f}, p = {tstat_12hz_trial_2nd_harmonic[1]}"
 )

 print()
@@ -547,24 +545,24 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
     res["stim_15hz_snrs_12hz"], res["stim_15hz_snrs_15hz"]
 )
 print(
-    "15 Hz trials: 12 Hz SNR is significantly lower than 15 Hz SNR"
-    ": t = %.3f, p = %f" % tstat_15hz_trial_stim
+    "15 Hz trials: 12 Hz SNR is significantly lower than 15 Hz SNR: t = "
+    f"{tstat_15hz_trial_stim[0]:.3f}, p = {tstat_15hz_trial_stim[1]}"
 )

 tstat_15hz_trial_1st_harmonic = ttest_rel(
     res["stim_15hz_snrs_24hz"], res["stim_15hz_snrs_30hz"]
 )
 print(
-    "15 Hz trials: 24 Hz SNR is significantly lower than 30 Hz SNR"
-    ": t = %.3f, p = %f" % tstat_15hz_trial_1st_harmonic
+    "15 Hz trials: 24 Hz SNR is significantly lower than 30 Hz SNR: t = "
+    f"{tstat_15hz_trial_1st_harmonic[0]:.3f}, p = {tstat_15hz_trial_1st_harmonic[1]}"
 )

 tstat_15hz_trial_2nd_harmonic = ttest_rel(
     res["stim_15hz_snrs_36hz"], res["stim_15hz_snrs_45hz"]
 )
 print(
-    "15 Hz trials: 36 Hz SNR is significantly lower than 45 Hz SNR"
-    ": t = %.3f, p = %f" % tstat_15hz_trial_2nd_harmonic
+    "15 Hz trials: 36 Hz SNR is significantly lower than 45 Hz SNR: t = "
+    f"{tstat_15hz_trial_2nd_harmonic[0]:.3f}, p = {tstat_15hz_trial_2nd_harmonic[1]}"
 )

 ##############################################################################
@@ -604,7 +602,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
 window_snrs = [[]] * len(window_lengths)
 for i_win, win in enumerate(window_lengths):
     # compute spectrogram
-    this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
+    this_spectrum = epochs["12hz"].compute_psd(
         "welch",
         n_fft=int(sfreq * win),
         n_overlap=0,
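Editorial note: in the window-length hunk above, `n_fft` is tied to the window size, so the spectral resolution of the Welch estimate follows the analysis window (the frequency step is `sfreq / n_fft`, i.e. `1 / win`). A rough standalone sketch of that relationship on synthetic epochs; the sampling rate, window lengths, and data are placeholders, not the tutorial's values:

```python
import numpy as np

import mne

# Synthetic epochs: 4 epochs, 1 EEG channel, 20 s at 100 Hz
sfreq = 100.0
info = mne.create_info(["EEG 001"], sfreq=sfreq, ch_types="eeg")
epochs = mne.EpochsArray(np.random.default_rng(0).standard_normal((4, 1, 2000)), info)

for win in (1.0, 2.0, 4.0):  # analysis window lengths in seconds
    spectrum = epochs.compute_psd(
        "welch", n_fft=int(sfreq * win), n_overlap=0, fmin=1.0, fmax=45.0
    )
    psds, freqs = spectrum.get_data(return_freqs=True)
    # the frequency step equals sfreq / n_fft, i.e. 1 / win
    print(f"window = {win} s -> frequency step = {freqs[1] - freqs[0]:.3f} Hz")
```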
@@ -688,7 +686,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):

 for i_win, win in enumerate(window_starts):
     # compute spectrogram
-    this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
+    this_spectrum = epochs["12hz"].compute_psd(
         "welch",
         n_fft=int(sfreq * window_length) - 1,
         n_overlap=0,
diff --git a/tutorials/visualization/10_publication_figure.py b/tutorials/visualization/10_publication_figure.py
index 69edf301eb5..138f9165db1 100644
--- a/tutorials/visualization/10_publication_figure.py
+++ b/tutorials/visualization/10_publication_figure.py
@@ -108,7 +108,7 @@
     axes, [screenshot, cropped_screenshot], ["Before", "After"]
 ):
     ax.imshow(image)
-    ax.set_title("{} cropping".format(title))
+    ax.set_title(f"{title} cropping")

 # %%
 # A lot of figure settings can be adjusted after the figure is created, but
diff --git a/tutorials/visualization/20_ui_events.py b/tutorials/visualization/20_ui_events.py
index ce268e1d8a5..e119b5032c1 100644
--- a/tutorials/visualization/20_ui_events.py
+++ b/tutorials/visualization/20_ui_events.py
@@ -16,6 +16,7 @@
 Since the figures on our website don't have any interaction capabilities, this
 example will only work properly when run in an interactive environment.
 """
+
 # Author: Marijn van Vliet
 #
 # License: BSD-3-Clause
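Editorial note: the remaining string changes in this section are mechanical ports from `%`-interpolation and `str.format` to f-strings, including indexing the `ttest_rel` result explicitly instead of relying on tuple interpolation. A compact before/after sketch with invented values:

```python
import numpy as np
from scipy.stats import ttest_rel

rng = np.random.default_rng(0)
roi_snr = rng.normal(1.4, 0.2, 20)
scalp_snr = rng.normal(1.1, 0.2, 20)
tstat = ttest_rel(roi_snr, scalp_snr)
title = "After"

# Before: positional interpolation, easy to mismatch with its arguments
print("t = %.3f, p = %f" % (tstat[0], tstat[1]))
print("{} cropping".format(title))

# After: each value is named or indexed at the point of use
print(f"t = {tstat[0]:.3f}, p = {tstat[1]}")
print(f"{title} cropping")
```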