8000 Add support for new __sklearn_tags__ by stes · Pull Request #205 · AdaptiveMotorControlLab/CEBRA · GitHub
[go: up one dir, main page]

Skip to content

Add support for new __sklearn_tags__ #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 16, 2024
Merged

Add support for new __sklearn_tags__ #205

merged 4 commits into from
Dec 16, 2024

Conversation

stes
Copy link
Member
@stes stes commented Dec 16, 2024

Fix #204

sklearn 1.6.0 was released on Dec 9, 24 and introduced a new mechanism for specifying estimator tags (https://scikit-learn.org/dev/developers/develop.html). This PR adopts CEBRA to comply with this new notation. Older sklearn variants will fall back to the more_tags() functions as recommended in this comment.

Indepedently, I spotted a bug in the inheritance order in the CEBRA class, which was fixed now, as described here.

Finally, since the code is now version dependent and there might be users rolling older sklearn version, I extended the test suite by one case checking with a legacy sklearn version (version 1.4.2 which is roughly one year old) -- this will hopefully cover the most important cases. The majority of tests are with sklearn latest (1.6.0 as of Dec 16, 24).

@stes stes requested a review from MMathisLab December 16, 2024 17:57
Copy link
Member
@MMathisLab MMathisLab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lgtm; but did not directly test

@stes stes merged commit 5f46c32 into main Dec 16, 2024
13 checks passed
@stes stes deleted the stes/fix-sklearn-tags branch December 16, 2024 19:32
@stes stes mentioned this pull request Dec 20, 2024
2 tasks
@Gunnar-Stunnar
Copy link

Was this ever fixed, I am using Cebra v0.4.0 and now receiving this error:

[<ipython-input-13-2e6de4decd48>](https://localhost:8080/#) in train(self, neural_session, continous_sessions)
     44 
     45         # fit decoder
---> 46         emb = self.cebra_posOnly_model.transform(nStack_train)
     47         fullKin = trainAllKin(emb, cStack_train)
     48 

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/_set_output.py](https://localhost:8080/#) in wrapped(self, X, *args, **kwargs)
    317     @wraps(f)
    318     def wrapped(self, X, *args, **kwargs):
--> 319         data_to_wrap = f(self, X, *args, **kwargs)
    320         if isinstance(data_to_wrap, tuple):
    321             # only wrap the first output for cross decomposition

[/usr/local/lib/python3.10/dist-packages/cebra/integrations/sklearn/cebra.py](https://localhost:8080/#) in transform(self, X, session_id)
   1224         """
   1225 
-> 1226         sklearn_utils_validation.check_is_fitted(self, "n_features_")
   1227         model, offset = self._select_model(X, session_id)
   1228 

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py](https://localhost:8080/#) in check_is_fitted(estimator, attributes, msg, all_or_any)
   1749         raise TypeError("%s is not an estimator instance." % (estimator))
   1750 
-> 1751     tags = get_tags(estimator)
   1752 
   1753     if not tags.requires_fit and attributes is None:

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/_tags.py](https://localhost:8080/#) in get_tags(estimator)
    403         for klass in reversed(type(estimator).mro()):
    404             if "__sklearn_tags__" in vars(klass):
--> 405                 sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator)  # type: ignore[attr-defined]
    406                 class_order.append(klass)
    407             elif "_more_tags" in vars(klass):

[/usr/local/lib/python3.10/dist-packages/sklearn/base.py](https://localhost:8080/#) in __sklearn_tags__(self)
    857 
    858     def __sklearn_tags__(self):
--> 859         tags = super().__sklearn_tags__()
    860         tags.transformer_tags = TransformerTags()
    861         return tags

AttributeError: 'super' object has no attribute '__sklearn_tags__'

This is after installing the latest Cebra package with the scikit-learn v1.6.0.

@Gunnar-Stunnar
Copy link

rolling back scikit-learn back to v1.5.2 worked

@stes
Copy link
Member Author
stes commented Dec 26, 2024

Hi @Gunnar-Stunnar , this was merged after the cebra 0.4.0 release. If you install the latest version from git,

pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git

the error should disappear even with sklearn > 1.6.0. In case you give that a try, please let me know if it works!

@Gunnar-Stunnar
Copy link

Looks like it worked!

My logs are now being filed with this error:

/usr/local/lib/python3.10/dist-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
pos: -9.4474 neg:  15.6202 total:  6.1728 temperature:  0.1000: 100%|██████████| 1000/1000 [00:16<00:00, 61.28it/s]
/usr/local/lib/python3.10/dist-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(

@juliagorman
Copy link

I recently installed CEBRA and am still gettitng the first error. I rolled back to scikit-learn back to v1.5.2 and now am getting this log error described above

@MMathisLab
Copy link
Member

You would need to pull from git, but a new version is coming soon!

@juliagorman
Copy link

if i re-install from git, will it change the PyTorch version I currently have installed in my conda environment? Should I just make a new environment to install from git?

@stes
Copy link
Member Author
stes commented Jan 29, 2025

I recently installed CEBRA and am still gettitng the first error. I rolled back to scikit-learn back to v1.5.2 and now am getting this log error described above

are you referring to this log output? This is just a warning message which is safe to ignore, there is no effect with respect to model fitting.

You can configure the warnings package if you want to get rid of the message.

Otherwise, as @MMathisLab noted, we will soon also release a new version of CEBRA properly handling this.

CeliaBenquet pushed a commit to CeliaBenquet/CEBRA that referenced this pull request Apr 23, 2025
* Add support for new __sklearn_tags__

* fix inheritance order

* Add more tests

* fix added test
MMathisLab added a commit that referenced this pull request May 23, 2025
* first proposal for batching in tranform method

* first running version of padding with batched inference

* start tests

* add pad_before_transform to fit function and add support for convolutional models in _transform

* remove print statements

* first passing test

* add support for hybrid models

* rewrite transform in sklearn API

* baseline version of a torch.Datset

* move batching logic outside solver

* move functionality to base file in solver and separate in functions

* add test_select_model for single session

* add checks and test for _process_batch

* add test_select_model for multisession

* make self.num_sessions compatible with single session training

* improve test_batched_transform_singlesession

* make it work with small batches

* make test with multisession work

* change to torch padding

* add argument to sklearn api

* add torch padding to _transform

* convert to torch if numpy array as inputs

* add distinction between pad with data and pad with zeros and modify test accordingly

* differentiate between data padding and zero padding

* remove float16

* change argument position

* clean test

* clean test

* Fix warning

* Improve modularity remove duplicate code and todos

* Add tests to solver

* Remove unused import in solver/utils

* Fix test plot

* Add some coverage

* Fix save/load

* Remove duplicate configure_for in multi dataset

* Make save/load cleaner

* Fix codespell errors

* Fix docs compilation errors

* Fix formatting

* Fix extra docs errors

* Fix offset in docs

* Remove attribute ref

* Add review updates

* apply ruff auto-fixes

* Concatenate last batches for batched inference (#200)

* Concatenate last to batches for batched inference

* Add test case

* Fix linting errors in tests (#188)

* apply auto-fixes

* Fix linting errors in tests/

* Fix version check

* Fix `scikit-learn` reference in conda environment files (#195)

* Add support for new __sklearn_tags__ (#205)

* Add support for new __sklearn_tags__

* fix inheritance order

* Add more tests

* fix added test

* Update workflows to actions/setup-python@v5, actions/cache@v4 (#212)

* Fix deprecation warning force_all_finite -> ensure_all_finite for sklearn>=1.6 (#206)

* Add tests to check legacy model loading (#214)

* Add improved goodness of fit implementation (#190)

* Started implementing improved goodness of fit implementation

* add tests and improve implementation

* Fix examples

* Fix docstring error

* Handle batch size = None for goodness of fit computation

* adapt GoF implementation

* Fix docstring tests

* Update docstring for goodness_of_fit_score

Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com>

* add annotations to goodness_of_fit_history

Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com>

* fix typo

Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com>

* improve err message

Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com>

* make numerical test less conversative

* Add tests for exception handling

* fix tests

---------

Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com>

* Support numpy 2, upgrade tests to support torch 2.6 (#221)

* Drop numpy constraint

* Implement workaround for pytables

* better error message

* pin numpy only for python 3.9

* update dependencies

* Upgrade torch version

* Fix based on python version

* Add support for torch.load with weights_only=True

* Implement safe loading for torch models starting in torch 2.6

* Fix windows specs

* fix docstring

* Revert changes to loading logic

* Release 0.5.0rc1 (#189)

* Make bump_version script runnable on MacOS

* Bump version to 0.5.0rc1

* fix minor formatting issues

* remove commented code

---------

Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu>

* Fix pypi action (#222)

* force packaging upgrade to 24.2 for twine

* Bump version to 0.5.0rc2

* remove universal compatibility option

* revert tag

* adapt files to new wheel name due to py3

* Update base.py (#224)

This is a lazy solution to #223

* Change max consistency value to 100 instead of 99 (#227)

* Change text consistency max from 99 to 100
* Update cebra/integrations/matplotlib.py

---------

Co-authored-by: Mackenzie Mathis <mackenzie.mathis@epfl.ch>
Co-authored-by: Steffen Schneider <steffen@bethgelab.org>

* Update assets.py --> force check for parent dir (#230)

Update assets.py

- mkdir was failing in 0.5.0rc1; attempt to fix

* User docs minor edit (#229)

* user note added to usage.rst

- link added

* Update usage.rst

- more detailed note on the effect of temp.

* Update usage.rst

- add in temp to demo model

- testout put

thanks @stes

* Update docs/source/usage.rst

Co-authored-by: Steffen Schneider <stes@hey.com>

* Update docs/source/usage.rst

Co-authored-by: Steffen Schneider <stes@hey.com>

* Update docs/source/usage.rst

Co-authored-by: Steffen Schneider <stes@hey.com>

---------

Co-authored-by: Steffen Schneider <stes@hey.com>

* General Doc refresher (#232)

* Update installation.rst

- python 3.9+

* Update index.rst

* Update figures.rst

* Update index.rst

-typo fix

* Update usage.rst

- update suggestion on data split

* Update docs/source/usage.rst

Co-authored-by: Steffen Schneider <stes@hey.com>

* Update usage.rst

- indent error fixed

* Update usage.rst

- changed infoNCE to new GoF

* Update usage.rst

- finx numpy() doctest

* Update usage.rst

- small typo fix (label)

* Update usage.rst

---------

Co-authored-by: Steffen Schneider <stes@hey.com>

* render plotly in our docs, show code/doc version (#231)

* Update layout.html (#233)

* Update conf.py (#234)

- adding link to new notebook icon

* Refactoring setup.cfg (#228)

* Home page landing update (#235)

* website refresh

* v0.5.0 (#238)

* Upgrade docs build (#241)

* Improve build setup for docs

* update pydata theme options

* Add README for docs folder

* Fix demo notebook build

* Finish build setup

* update git workflow

* add timeout to workflow

* add timeout also to docs build

* switch build back to sphinx for gh actions

* attempt to fix build workflow

* update to sphinx-build

* fix build workflow

* fix indent error

* fix build system

* revert demos to main

* increase timeout to 30

* Allow indexing of the cebra docs (#242)

* Allow indexing of the cebra docs

* Fix docs workflow

* Fix broken docs coverage workflows (#246)

* Add xCEBRA implementation (AISTATS 2025) (#225)

* Add multiobjective solver and regularized training (#783)

* Add multiobjective solver and regularized training

* Add example for multiobjective training

* Add jacobian regularizer and SAM

* update license headers

* add api draft for multiobjective training

* add all necessary modules to run the complete xcebra pipeline

* add notebooks to reproduce xcebra pipeline

* add first working notebook

* add notebook with hybrid learning

* add notebook with creation of synthetic data

* add notebook with hybrid training

* add plot with R2 for different parts of the embedding

* add new API

* update api wrapper with more checks and messages

* add tests and notebook with new api

* merge xcebra into attribution

* separate xcebra dataset from cebra

* some minor refactoring of cebra dataset

* separate xcebra loader from cebra

* remove xcebra distributions from cebra

* minor refactoring with distributions

* separate xcebra criterions from cebra

* minor refactoring on criterion

* separate xcebra models/criterions/layers from cebra

* refactoring multiobjective

* more refactoring...

* separate xcebra solvers from cebra

* more refactoring

* move xcebra to its own package

* move more files into xcebra package

* more files and remove changes with the registry

* remove unncessary import

* add folder structure

* move back distributions

* add missing init

* remove wrong init

* make loader and dataset run with new imports

* making it run!

* make attribution run

* Run pre-commit

* move xcebra repo one level up

* update gitignore and add __init__ from data

* add init to distributions

* add correct init for attribution pacakge

* add correct init for model package

* fix remaining imports

* fix tests

* add examples back to xcebra repo

* update imports from graphs_xcebra

* add setup.py to create a package

* update imports of graph_xcebra

* update notebooks

* Formatting code for submission

Co-authored-by: Rodrigo Gonzalez <gonlairo@gmail.com>

* move test into xcebra

* Add README

* move distributions back to main package

* clean up examples

* adapt tests

* Add LICENSE

* add train/eval notebook again

* add notebook with clean results

* rm synthetic data

* change name from xcebra to regcl

* change names of modules and adapt imports

* change name from graphs_xcebra to synthetic_data

* Integrate into CEBRA

* Fix remaining imports and make notebook runnable

* Add dependencies, add version flag

* Remove synthetic data files

* reset dockerfile, move vmf

* apply pre-commit

* Update notice

* add some docstrings

* Apply license headers

* add new scd notebook

* add notebook with scd

---------

Co-authored-by: Steffen Schneider <stes@hey.com>

* Fix tests

* bump version

* update dockerfile

* fix progress bar

* remove outdated test

* rename models

* Apply fixes to pass ruff tests

* Fix typos

* Update license headers, fix additional ruff errors

* remove unused comment

* rename regcl in codebase

* change regcl name in dockerfile

* Improve attribution module

* Fix imports name naming

* add basic integration test

* temp disable of binary check

* Add legacy multiobjective model for backward compat

* add synth import back in

* Fix docstrings and type annot in cebra/models/jacobian_regularizer.py

* add xcebra to tests

* add missing cvxpy dep

* fix docstrings

* more docstrings to fix attr error

* Improve build setup for docs

* update pydata theme options

* Add README for docs folder

* Fix demo notebook build

* Finish build setup

* update git workflow

* Move demo notebooks to CEBRA-demos repo

See AdaptiveMotorControlLab/CEBRA-demos#28

* revert unneeded changes in solver

* formatting in solver

* further minimize solver diff

* Revert unneeded updates to the solver

* fix citation

* fix docs build, missing refs

* remove file dependency from xcebra int test

* remove unneeded change in registry

* update gitignore

* update docs

* exclude some assets

* include binary file check again

* add timeout to workflow

* add timeout also to docs build

* switch build back to sphinx for gh actions

* pin sphinx version in setup.cfg

* attempt workflow fix

* attempt to fix build workflow

* update to sphinx-build

* fix build workflow

* fix indent error

* fix build system

* revert demos to main

* adapt workflow for testing

* bump version to 0.6.0rc1

* format imports

* docs writing

* enable build on dev branch

* fix some review comments

* extend multiobjective docs

* Set version to alpha

* make tempdir platform independent

* Remove ratinabox and ephysiopy as deps

* Apply review comments

* Update Makefile

- setting coverage threshold to 80% to not delay good code being made public. In the near future this can be fixed and raised again to 90%.

---------

Co-authored-by: Steffen Schneider <stes@hey.com>
Co-authored-by: Steffen Schneider <steffen.schneider@helmholtz-munich.de>
Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu>

* start tests

* remove print statements

* first passing test

* move functionality to base file in solver and separate in functions

* add test_select_model for multisession

* remove float16

* Improve modularity remove duplicate code and todos

* Add tests to solver

* Fix save/load

* Fix extra docs errors

* Add review updates

* apply ruff auto-fixes

* fix linting errors

* Run isort, ruff, yapf

* Fix gaussian mixture dataset import

* Fix all tests but xcebra tests

* Fix pytorch API usage example

* Make xCEBRA compatible with the batched inference & padding in solver

* Add some tests on transform() with xCEBRA

* Add some docstrings and typings and clean unnecessary changes

* Implement review comments

* Fix sklearn test

* Add name in NOTE

Co-authored-by: Steffen Schneider <steffen@bethgelab.org>

* Implement reviews on tests and typing

* Fix import errors

* Add select_model to aux solvers

* Fix docs error

* Add tests on the private functions in base solver

* Update tests and duplicate code based on review

---------

Co-authored-by: Rodrigo <gonlairo@gmail.com>
Co-authored-by: Steffen Schneider <stes@hey.com>
Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu>
Co-authored-by: Steffen Schneider <steffen.schneider@helmholtz-munich.de>
Co-authored-by: Ícaro <icarosadero@proton.me>
Co-authored-by: Mackenzie Mathis <mackenzie.mathis@epfl.ch>
Co-authored-by: Steffen Schneider <steffen@bethgelab.org>
Co-authored-by: Rodrigo González Laiz <31796689+gonlairo@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Upstream sklearn change causes error ('super' object has no attribute '__sklearn_tags__') in test suite
4 participants
0