diff --git a/doc/computing.rst b/doc/computing.rst index 6732b754918b0..8b355f22ec641 100644 --- a/doc/computing.rst +++ b/doc/computing.rst @@ -14,3 +14,4 @@ Computing with scikit-learn computing/scaling_strategies computing/computational_performance computing/parallelism + computing/engine diff --git a/doc/computing/engine.rst b/doc/computing/engine.rst new file mode 100644 index 0000000000000..b27a61fd3a25a --- /dev/null +++ b/doc/computing/engine.rst @@ -0,0 +1,29 @@ +.. Places parent toc into the sidebar + +:parenttoc: True + +.. _engine: + +Computation Engines (experimental) +================================== + +**This API is experimental**, which means that it is subject to change without +any backward compatibility guarantees. + +TODO: explain goals here + +Activating an engine +-------------------- + +TODO: installing third party engine provider packages + +TODO: how to list installed engines + +TODO: how to install a plugin + +Writing a new engine provider +----------------------------- + +TODO: show engine API of a given estimator. + +TODO: give example setup.py with setuptools to define an entrypoint. diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 1427e61c03385..184a91aa7b8e6 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -2,6 +2,26 @@ .. currentmodule:: sklearn + +TODO: move to doc/whats_new/v1.3.rst once it exists in main: + +- |Enhancement| Experimental engine API (no backward compatibility guarantees) + to allow external packages to contribute alternative implementations for + the core computational routines of some selected scikit-learn estimators. + + Currently, the following estimators allow alternative implementations: + + - :class:`~sklearn.cluster.KMeans` (only for the Lloyd algorithm). + - TODO: add more here + + External engine providers include: + + - https://github.com/soda-inria/sklearn-numba-dpex, which provides a KMeans + engine optimized for OpenCL-enabled GPUs. + - TODO: add more here + + :pr:`24497` by :user:`ogrisel`, :user:`fcharras`. + .. _changes_1_2: Version 1.2.0 diff --git a/setup.py b/setup.py index 17d90a45f35ea..d4e70cbfa5f57 100755 --- a/setup.py +++ b/setup.py @@ -606,6 +606,7 @@ def setup_package(): python_requires=python_requires, install_requires=min_deps.tag_to_packages["install"], package_data={"": ["*.pxd"]}, + entry_points={"pytest11": ["sklearn_plugin_testing = sklearn._engine.testing"]}, zip_safe=False, # the package can run out of an .egg file include_package_data=True, extras_require={ diff --git a/sklearn/_config.py b/sklearn/_config.py index e4c398c9c5444..7db1788657104 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -14,6 +14,7 @@ ), "enable_cython_pairwise_dist": True, "array_api_dispatch": False, + "engine_provider": (), "transform_output": "default", } _threadlocal = threading.local() @@ -53,6 +54,7 @@ def set_config( pairwise_dist_chunk_size=None, enable_cython_pairwise_dist=None, array_api_dispatch=None, + engine_provider=None, transform_output=None, ): """Set global scikit-learn configuration @@ -122,6 +124,15 @@ def set_config( .. versionadded:: 1.2 + engine_provider : str or sequence of str, default=None + Enable computational engine implementations provided by third-party + packages to leverage specific hardware platforms, using frameworks or + libraries outside of the usual scikit-learn project dependencies. + + See the :ref:`User Guide <engine>` for more details. + + .. versionadded:: 1.3 + transform_output : str, default=None Configure output of `transform` and `fit_transform`.
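Editor's note: for illustration, here is a minimal usage sketch of the ``engine_provider`` configuration option introduced by this diff. Only ``set_config``/``config_context``/``list_engine_provider_names`` and the KMeans (Lloyd) hook are defined by this changeset; the importable provider name ``sklearn_numba_dpex`` is an assumption based on the package linked in the changelog entry, and any package exposing a ``sklearn_engines`` entry point would be activated the same way::

    # Hedged sketch: assumes a third-party provider package (hypothetically
    # importable as "sklearn_numba_dpex") is installed in the same environment.
    import numpy as np
    import sklearn
    from sklearn._engine import list_engine_provider_names
    from sklearn.cluster import KMeans

    # Providers discovered through the "sklearn_engines" entry-point group.
    print(list_engine_provider_names())

    X = np.random.RandomState(0).rand(1000, 10)

    # Activate a provider for a block of code; estimators with an engine hook
    # (currently KMeans with the Lloyd algorithm) dispatch to it, all other
    # estimators are unaffected.
    with sklearn.config_context(engine_provider="sklearn_numba_dpex"):
        KMeans(n_clusters=10, algorithm="lloyd").fit(X)

    # Or activate one or more providers globally:
    sklearn.set_config(engine_provider=("sklearn_numba_dpex",))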
@@ -155,6 +166,8 @@ def set_config( local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist if array_api_dispatch is not None: local_config["array_api_dispatch"] = array_api_dispatch + if engine_provider is not None: + local_config["engine_provider"] = engine_provider if transform_output is not None: local_config["transform_output"] = transform_output @@ -169,6 +182,7 @@ def config_context( pairwise_dist_chunk_size=None, enable_cython_pairwise_dist=None, array_api_dispatch=None, + engine_provider=None, transform_output=None, ): """Context manager for global scikit-learn configuration. @@ -237,6 +251,15 @@ def config_context( .. versionadded:: 1.2 + engine_provider : str or sequence of str, default=None + Enable computational engine implementations provided by third-party + packages to leverage specific hardware platforms, using frameworks or + libraries outside of the usual scikit-learn project dependencies. + + See the :ref:`User Guide <engine>` for more details. + + .. versionadded:: 1.3 + transform_output : str, default=None Configure output of `transform` and `fit_transform`. @@ -285,6 +308,7 @@ def config_context( pairwise_dist_chunk_size=pairwise_dist_chunk_size, enable_cython_pairwise_dist=enable_cython_pairwise_dist, array_api_dispatch=array_api_dispatch, + engine_provider=engine_provider, transform_output=transform_output, ) diff --git a/sklearn/_engine/__init__.py b/sklearn/_engine/__init__.py new file mode 100644 index 0000000000000..649ad1ba062d3 --- /dev/null +++ b/sklearn/_engine/__init__.py @@ -0,0 +1,4 @@ +from .base import get_engine_classes, list_engine_provider_names + + +__all__ = ["get_engine_classes", "list_engine_provider_names"] diff --git a/sklearn/_engine/base.py b/sklearn/_engine/base.py new file mode 100644 index 0000000000000..8e84432784075 --- /dev/null +++ b/sklearn/_engine/base.py @@ -0,0 +1,115 @@ +from importlib.metadata import entry_points +from importlib import import_module +from functools import lru_cache +import warnings + +from .._config import get_config + +SKLEARN_ENGINES_ENTRY_POINT = "sklearn_engines" + + +class EngineSpec: + + __slots__ = ["name", "provider_name", "module_name", "engine_qualname"] + + def __init__(self, name, provider_name, module_name, engine_qualname): + self.name = name + self.provider_name = provider_name + self.module_name = module_name + self.engine_qualname = engine_qualname + + def get_engine_class(self): + engine = import_module(self.module_name) + for attr in self.engine_qualname.split("."): + engine = getattr(engine, attr) + return engine + + +def _parse_entry_point(entry_point): + module_name, engine_qualname = entry_point.value.split(":") + provider_name = next(iter(module_name.split(".", 1))) + return EngineSpec(entry_point.name, provider_name, module_name, engine_qualname) + + +@lru_cache +def _parse_entry_points(provider_names=None): + specs = [] + all_entry_points = entry_points() + if hasattr(all_entry_points, "select"): + engine_entry_points = all_entry_points.select(group=SKLEARN_ENGINES_ENTRY_POINT) + else: + engine_entry_points = all_entry_points.get(SKLEARN_ENGINES_ENTRY_POINT, ()) + for entry_point in engine_entry_points: + try: + spec = _parse_entry_point(entry_point) + if provider_names is not None and spec.provider_name not in provider_names: + # Skip entry points that do not match the requested provider names.
+ continue + specs.append(spec) + except Exception as e: + # Do not raise an exception in case an invalid package has been + # installed in the same Python env as scikit-learn: just warn and + # skip. + warnings.warn( + f"Invalid {SKLEARN_ENGINES_ENTRY_POINT} entry point" + f" {entry_point.name} with value {entry_point.value}: {e}" + ) + if provider_names is not None: + observed_provider_names = {spec.provider_name for spec in specs} + missing_providers = set(provider_names) - observed_provider_names + if missing_providers: + raise RuntimeError( + "Could not find any provider for the" + f" {SKLEARN_ENGINES_ENTRY_POINT} entry point with name(s):" + f" {', '.join(repr(p) for p in sorted(missing_providers))}" + ) + return specs + + +def list_engine_provider_names(): + """Find the list of sklearn_engine provider names. + + This function only inspects the metadata and should not trigger any module import. + """ + return sorted({spec.provider_name for spec in _parse_entry_points()}) + + +def _get_engine_classes(engine_name, provider_names, engine_specs, default): + specs_by_provider = {} + for spec in engine_specs: + if spec.name != engine_name: + continue + specs_by_provider.setdefault(spec.provider_name, spec) + + for provider_name in provider_names: + spec = specs_by_provider.get(provider_name) + if spec is not None: + # XXX: should we return an instance or the class itself? + yield spec.provider_name, spec.get_engine_class() + + yield "default", default + + +def get_engine_classes(engine_name, default, verbose=False): + provider_names = get_config()["engine_provider"] + if isinstance(provider_names, str): + provider_names = (provider_names,) + elif not isinstance(provider_names, tuple): + # Make sure the provider names are a tuple to make it possible for the + # lru cache to hash them. + provider_names = tuple(provider_names) + if not provider_names: + yield "default", default + return + engine_specs = _parse_entry_points(provider_names=provider_names) + for provider, engine_class in _get_engine_classes( + engine_name=engine_name, + provider_names=provider_names, + engine_specs=engine_specs, + default=default, + ): + if verbose: + print( + f"Trying engine {engine_class.__module__}.{engine_class.__qualname__}." + ) + yield provider, engine_class diff --git a/sklearn/_engine/testing.py b/sklearn/_engine/testing.py new file mode 100644 index 0000000000000..8c553ee06d655 --- /dev/null +++ b/sklearn/_engine/testing.py @@ -0,0 +1,38 @@ +from pytest import xfail, hookimpl + +from sklearn import config_context + +from sklearn.exceptions import NotSupportedByEngineError + + +# TODO: document this pytest plugin + write a tutorial on how to develop a new plugin +# and explain good practices regarding testing against sklearn test modules. +def pytest_addoption(parser): + group = parser.getgroup("Sklearn plugin testing") + group.addoption( + "--sklearn-engine-provider", + action="store", + nargs=1, + type=str, + help="Name of an engine provider for sklearn to activate for all tests.", + ) + + +@hookimpl(hookwrapper=True) +def pytest_pyfunc_call(pyfuncitem): + engine_provider = pyfuncitem.config.getoption("sklearn_engine_provider") + if engine_provider is None: + yield + return + + with config_context(engine_provider=engine_provider): + try: + outcome = yield + outcome.get_result() + except NotSupportedByEngineError: + xfail( + reason=( + "This test covers features that are not supported by the " + f"engine provided by {engine_provider}."
+ ) + ) diff --git a/sklearn/_engine/tests/__init__.py b/sklearn/_engine/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/_engine/tests/test_engines.py b/sklearn/_engine/tests/test_engines.py new file mode 100644 index 0000000000000..93c3962a753a5 --- /dev/null +++ b/sklearn/_engine/tests/test_engines.py @@ -0,0 +1,138 @@ +import re +from collections import namedtuple +import pytest + +from sklearn._engine import list_engine_provider_names +from sklearn._engine import get_engine_classes +from sklearn._engine.base import _parse_entry_point +from sklearn._engine.base import _get_engine_classes +from sklearn._engine.base import EngineSpec +from sklearn._config import config_context + + +class FakeDefaultEngine: + pass + + +class FakeEngine: + pass + + +class FakeEngineHolder: + class NestedFakeEngine: + pass + + +FakeEntryPoint = namedtuple("FakeEntryPoint", ["name", "value"]) + + +def test_parse_entry_point(): + fake_entry_point = FakeEntryPoint( + name="fake_engine", + value="sklearn._engine.tests.test_engines:FakeEngine", + ) + spec = _parse_entry_point(fake_entry_point) + assert spec.name == "fake_engine" + assert spec.provider_name == "sklearn" # or should it be scikit-learn? + assert spec.get_engine_class() is FakeEngine + + +def test_parse_entry_point_for_nested_engine_class(): + fake_entry_point = FakeEntryPoint( + name="nested_fake_engine", + value="sklearn._engine.tests.test_engines:FakeEngineHolder.NestedFakeEngine", + ) + spec = _parse_entry_point(fake_entry_point) + assert spec.name == "nested_fake_engine" + assert spec.provider_name == "sklearn" # or should it be scikit-learn? + assert spec.get_engine_class() is FakeEngineHolder.NestedFakeEngine + + +def test_list_engine_provider_names(): + provider_names = list_engine_provider_names() + for provider_name in provider_names: + assert isinstance(provider_name, str) + + +def test_get_engine_class_with_default(): + # Use config_context with an empty provider tuple to make sure that no provider + # is available for test_missing_engine_name + with config_context(engine_provider=()): + engine_classes = list( + get_engine_classes("test_missing_engine_name", default=FakeEngine) + ) + assert engine_classes == [("default", FakeEngine)] + + +def test_get_engine_class(): + engine_specs = ( + EngineSpec( + "kmeans", "provider3", "sklearn._engine.tests.test_engines", "FakeEngine" + ), + EngineSpec( + "kmeans", + "provider4", + "sklearn._engine.tests.test_engines", + "FakeEngineHolder.NestedFakeEngine", + ), + ) + + engine_class = list( + _get_engine_classes( + engine_name="missing", + provider_names=("provider1", "provider3"), + engine_specs=engine_specs, + default=FakeDefaultEngine, + ) + ) + assert engine_class == [("default", FakeDefaultEngine)] + + engine_class = list( + _get_engine_classes( + engine_name="kmeans", + provider_names=("provider3", "provider4"), + engine_specs=engine_specs, + default=FakeDefaultEngine, + ) + ) + assert engine_class == [ + ("provider3", FakeEngine), + ("provider4", FakeEngineHolder.NestedFakeEngine), + ("default", FakeDefaultEngine), + ] + + engine_class = list( + _get_engine_classes( + engine_name="kmeans", + provider_names=("provider4", "provider3"), + engine_specs=engine_specs, + default=FakeDefaultEngine, + ) + ) + assert engine_class == [ + ("provider4", FakeEngineHolder.NestedFakeEngine), + ("provider3", FakeEngine), + ("default", FakeDefaultEngine), + ] + + engine_specs = engine_specs + ( + EngineSpec( + "kmeans", + "provider1", +
"sklearn.provider1.somewhere", + "OtherEngine", + ), + ) + + # Invalid imports are delayed until they are actually needed. + engine_classes = _get_engine_classes( + engine_name="kmeans", + provider_names=("provider4", "provider3", "provider1"), + engine_specs=engine_specs, + default=FakeDefaultEngine, + ) + + next(engine_classes) + next(engine_classes) + with pytest.raises(ImportError, match=re.escape("sklearn.provider1")): + next(engine_classes) diff --git a/sklearn/cluster/_kmeans.py b/sklearn/cluster/_kmeans.py index 8fac729725b38..de54602e5ebfc 100644 --- a/sklearn/cluster/_kmeans.py +++ b/sklearn/cluster/_kmeans.py @@ -54,7 +54,8 @@ from ._k_means_elkan import init_bounds_sparse from ._k_means_elkan import elkan_iter_chunked_dense from ._k_means_elkan import elkan_iter_chunked_sparse - +from .._engine import get_engine_classes +from .._config import get_config ############################################################################### # Initialization heuristic @@ -256,15 +257,194 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial # K-means batch estimation by EM (expectation maximization) -def _tolerance(X, tol): - """Return a tolerance which is dependent on the dataset.""" - if tol == 0: - return 0 - if sp.issparse(X): - variances = mean_variance_axis(X, axis=0)[1] - else: - variances = np.var(X, axis=0) - return np.mean(variances) * tol +class _IgnoreParam: + pass + + +class KMeansCythonEngine: + """Cython-based implementation of the core k-means routines + + This implementation is meant to be swappable by alternative implementations + in third-party packages via the sklearn_engines entry-point and the + `engine_provider` kwarg of `sklearn.config_context`. + + TODO: see URL for more details. + """ + + def __init__(self, estimator): + self.estimator = estimator + + def accepts(self, X, y=None, sample_weight=None): + # The default engine accepts everything + return True + + def pre_fit(self, X, y=None, sample_weight=None): + estimator = self.estimator + X = estimator._validate_data( + X, + accept_sparse="csr", + dtype=[np.float64, np.float32], + order="C", + copy=estimator.copy_x, + accept_large_sparse=False, + ) + # this sets estimator _algorithm implicitly + # XXX: shall we explose this logic as part of then engine API? + # or is the current API flexible enough? + estimator._check_params_vs_input(X) + + # TODO: delegate rng and sample weight checks to engine + random_state = check_random_state(estimator.random_state) + sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) + + # Also store the number of threads on the estimator to be reused at + # prediction time XXX: shall we wrap engine-specific private fit + # attributes in a predict context dict set as attribute on the + # estimator? 
+ estimator._n_threads = self._n_threads = _openmp_effective_n_threads() + + # Validate init array + init = estimator.init + init_is_array_like = _is_arraylike_not_scalar(init) + if init_is_array_like: + init = check_array(init, dtype=X.dtype, copy=True, order="C") + estimator._validate_center_shape(X, init) + + # subtract mean of X for more accurate distance computations + if not sp.issparse(X): + X_mean = X.mean(axis=0) + # The copy was already done above + X -= X_mean + + if init_is_array_like: + init -= X_mean + + self.X_mean = X_mean + + # precompute squared norms of data points + x_squared_norms = row_norms(X, squared=True) + + if estimator._algorithm == "elkan": + kmeans_single = _kmeans_single_elkan + else: + kmeans_single = _kmeans_single_lloyd + estimator._check_mkl_vcomp(X, X.shape[0]) + + self.x_squared_norms = x_squared_norms + self.kmeans_single_func = kmeans_single + self.random_state = random_state + self.tol = self.scale_tolerance(X, estimator.tol) + self.init = init + return X, y, sample_weight + + def init_centroids(self, X): + # XXX: the actual implementation of the centroids init should also be + # moved to the engine. + return self.estimator._init_centroids( + X, + x_squared_norms=self.x_squared_norms, + init=self.init, + random_state=self.random_state, + ) + + def scale_tolerance(self, X, tol): + """Return a tolerance which is dependent on the dataset.""" + if tol == 0: + return 0 + if sp.issparse(X): + variances = mean_variance_axis(X, axis=0)[1] + else: + variances = np.var(X, axis=0) + return np.mean(variances) * tol + + def unshift_centers(self, X, best_centers): + if not sp.issparse(X): + if not self.estimator.copy_x: + X += self.X_mean + best_centers += self.X_mean + + def is_same_clustering(self, labels, best_labels, n_clusters): + return _is_same_clustering(labels, best_labels, n_clusters) + + def fit(self, X, y=None, sample_weight=None): + best_inertia, best_labels = None, None + + for i in range(self.estimator._n_init): + # Initialize centers for each run, as in the original implementation. + centers_init = self.init_centroids(X) + if self.estimator.verbose: + print("Initialization complete") + + labels, inertia, centers, n_iter_ = self.kmeans_single_func( + X, + sample_weight, + centers_init, + max_iter=self.estimator.max_iter, + tol=self.tol, + n_threads=self._n_threads, + verbose=self.estimator.verbose, + ) + + # determine if these results are the best so far + # we choose a new run if it has a better inertia and the clustering is + # different from the best so far (it's possible that the inertia is + # slightly better even if the clustering is the same with potentially + # permuted labels, due to rounding errors) + if best_inertia is None or ( + inertia < best_inertia + and not self.is_same_clustering( + labels, best_labels, self.estimator.n_clusters + ) + ): + best_labels = labels + best_centers = centers + best_inertia = inertia + best_n_iter = n_iter_ + + # Store the best run so that post_fit can set the fitted attributes. + self.best_labels = best_labels + self.best_centers = best_centers + self.best_inertia = best_inertia + self.best_n_iter = best_n_iter + + def post_fit(self, X, y=None, sample_weight=None): + self.unshift_centers(X, self.best_centers) + + distinct_clusters = len(set(self.best_labels)) + if distinct_clusters < self.estimator.n_clusters: + warnings.warn( + "Number of distinct clusters ({}) found smaller than " + "n_clusters ({}).
Possibly due to duplicate points " + "in X.".format(distinct_clusters, self.estimator.n_clusters), + ConvergenceWarning, + stacklevel=2, + ) + + self.estimator.cluster_centers_ = self.best_centers + self.estimator._n_features_out = self.best_centers.shape[0] + self.estimator.labels_ = self.best_labels + self.estimator.inertia_ = self.best_inertia + self.estimator.n_iter_ = self.best_n_iter + + def pre_predict(self, X, sample_weight): + X = self.estimator._check_test_data(X) + sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) + return X, sample_weight + + def predict(self, X, sample_weight=None): + labels, _ = _labels_inertia_threadpool_limit( + X, + sample_weight, + self.estimator.cluster_centers_, + n_threads=self.estimator._n_threads, + ) + + return labels + + def pre_transform(self, X): + return self.estimator._check_test_data(X) + + def transform(self, X): + return euclidean_distances(X, self.estimator.cluster_centers_) + + def score(self, X, sample_weight): + _, scores = _labels_inertia_threadpool_limit( + X, sample_weight, self.estimator.cluster_centers_, self.estimator._n_threads + ) + return scores @validate_params( @@ -857,9 +1037,6 @@ def _check_params_vs_input(self, X, default_n_init=None): f"n_samples={X.shape[0]} should be >= n_clusters={self.n_clusters}." ) - # tol - self._tol = _tolerance(X, self.tol) - # n-init # TODO(1.4): Remove self._n_init = self.n_init @@ -1373,6 +1550,26 @@ def _check_params_vs_input(self, X): ) self._algorithm = "lloyd" + def _get_engine(self, X, y=None, sample_weight=None, reset=False): + for provider, engine_class in get_engine_classes( + "kmeans", default=KMeansCythonEngine + ): + if hasattr(self, "_engine_provider") and not reset: + if self._engine_provider != provider: + continue + + engine = engine_class(self) + if engine.accepts(X, y=y): + self._engine_provider = provider + return engine + + if hasattr(self, "_engine_provider"): + raise RuntimeError( + "Estimator was previously fitted with the" + f" {self._engine_provider} engine, but it is not available. Currently" + f" configured engines: {get_config()['engine_provider']}" + ) + def _warn_mkl_vcomp(self, n_active_threads): """Warn when vcomp and mkl are both present""" warnings.warn( @@ -1410,102 +1607,146 @@ def fit(self, X, y=None, sample_weight=None): """ self._validate_params() - X = self._validate_data( + engine = self._get_engine(X, y, sample_weight, reset=True) + + if hasattr(engine, "pre_fit"): + X, y, sample_weight = engine.pre_fit( + X, + y=y, + sample_weight=sample_weight, + ) + + engine.fit( X, - accept_sparse="csr", - dtype=[np.float64, np.float32], - order="C", - copy=self.copy_x, - accept_large_sparse=False, + y=y, + sample_weight=sample_weight, ) - self._check_params_vs_input(X) + if hasattr(engine, "post_fit"): + engine.post_fit( + X, + y=y, + sample_weight=sample_weight, + ) - random_state = check_random_state(self.random_state) - sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) - self._n_threads = _openmp_effective_n_threads() + return self - # Validate init array - init = self.init - init_is_array_like = _is_arraylike_not_scalar(init) - if init_is_array_like: - init = check_array(init, dtype=X.dtype, copy=True, order="C") - self._validate_center_shape(X, init) + def predict(self, X, sample_weight=None): + """Predict the closest cluster each sample in X belongs to. 
- # subtract of mean of x for more accurate distance computations - if not sp.issparse(X): - X_mean = X.mean(axis=0) - # The copy was already done above - X -= X_mean + In the vector quantization literature, `cluster_centers_` is called + the code book and each value returned by `predict` is the index of + the closest code in the code book. - if init_is_array_like: - init -= X_mean + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data to predict. - # precompute squared norms of data points - x_squared_norms = row_norms(X, squared=True) + sample_weight : array-like of shape (n_samples,), default=None + The weights for each observation in X. If None, all observations + are assigned equal weight. - if self._algorithm == "elkan": - kmeans_single = _kmeans_single_elkan - else: - kmeans_single = _kmeans_single_lloyd - self._check_mkl_vcomp(X, X.shape[0]) + Returns + ------- + labels : ndarray of shape (n_samples,) + Index of the cluster each sample belongs to. + """ + check_is_fitted(self) + engine = self._get_engine(X) + if hasattr(engine, "pre_predict"): + X, sample_weight = engine.pre_predict(X, sample_weight) - best_inertia, best_labels = None, None + y_pred = engine.predict(X, sample_weight) - for i in range(self._n_init): - # Initialize centers - centers_init = self._init_centroids( - X, x_squared_norms=x_squared_norms, init=init, random_state=random_state - ) - if self.verbose: - print("Initialization complete") + if hasattr(engine, "post_predict"): + engine.post_predict(X, sample_weight) - # run a k-means once - labels, inertia, centers, n_iter_ = kmeans_single( - X, - sample_weight, - centers_init, - max_iter=self.max_iter, - verbose=self.verbose, - tol=self._tol, - n_threads=self._n_threads, - ) + return y_pred - # determine if these results are the best so far - # we chose a new run if it has a better inertia and the clustering is - # different from the best so far (it's possible that the inertia is - # slightly better even if the clustering is the same with potentially - # permuted labels, due to rounding errors) - if best_inertia is None or ( - inertia < best_inertia - and not _is_same_clustering(labels, best_labels, self.n_clusters) - ): - best_labels = labels - best_centers = centers - best_inertia = inertia - best_n_iter = n_iter_ + def fit_transform(self, X, y=None, sample_weight=None): + """Compute clustering and transform X to cluster-distance space. - if not sp.issparse(X): - if not self.copy_x: - X += X_mean - best_centers += X_mean + Equivalent to fit(X).transform(X), but more efficiently implemented. - distinct_clusters = len(set(best_labels)) - if distinct_clusters < self.n_clusters: - warnings.warn( - "Number of distinct clusters ({}) found smaller than " - "n_clusters ({}). Possibly due to duplicate points " - "in X.".format(distinct_clusters, self.n_clusters), - ConvergenceWarning, - stacklevel=2, - ) + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data to transform. - self.cluster_centers_ = best_centers - self._n_features_out = self.cluster_centers_.shape[0] - self.labels_ = best_labels - self.inertia_ = best_inertia - self.n_iter_ = best_n_iter - return self + y : Ignored + Not used, present here for API consistency by convention. + + sample_weight : array-like of shape (n_samples,), default=None + The weights for each observation in X. If None, all observations + are assigned equal weight. 
+ + Returns + ------- + X_new : ndarray of shape (n_samples, n_clusters) + X transformed in the new space. + """ + # XXX pre_transform() is not called because fit() calls pre_fit() + self.fit(X, sample_weight=sample_weight) + engine = self._get_engine(X) + return self._transform(X, engine) + + def transform(self, X): + """Transform X to a cluster-distance space. + + In the new space, each dimension is the distance to the cluster + centers. Note that even if X is sparse, the array returned by + `transform` will typically be dense. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data to transform. + + Returns + ------- + X_new : ndarray of shape (n_samples, n_clusters) + X transformed in the new space. + """ + check_is_fitted(self) + engine = self._get_engine(X) + if hasattr(engine, "pre_transform"): + X = engine.pre_transform(X) + return self._transform(X, engine) + + def _transform(self, X, engine): + """Guts of transform method; no input validation.""" + X_ = engine.transform(X) + if hasattr(engine, "post_transform"): + engine.post_transform(X_) + return X_ + + def score(self, X, y=None, sample_weight=None): + """Opposite of the value of X on the K-means objective. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data. + + y : Ignored + Not used, present here for API consistency by convention. + + sample_weight : array-like of shape (n_samples,), default=None + The weights for each observation in X. If None, all observations + are assigned equal weight. + + Returns + ------- + score : float + Opposite of the value of X on the K-means objective. + """ + check_is_fitted(self) + engine = self._get_engine(X) + + X, sample_weight = engine.pre_predict(X, sample_weight) + + return -engine.score(X, sample_weight) def _mini_batch_step( @@ -1866,6 +2107,15 @@ def __init__( def _check_params_vs_input(self, X): super()._check_params_vs_input(X, default_n_init=3) + if self.tol > 0: + if sp.issparse(X): + variances = mean_variance_axis(X, axis=0)[1] + else: + variances = np.var(X, axis=0) + self._tol = np.mean(variances) * self.tol + else: + self._tol = 0.0 + self._batch_size = min(self.batch_size, X.shape[0]) # init_size diff --git a/sklearn/cluster/tests/test_bicluster.py b/sklearn/cluster/tests/test_bicluster.py index d04e9dba4fade..977c667840483 100644 --- a/sklearn/cluster/tests/test_bicluster.py +++ b/sklearn/cluster/tests/test_bicluster.py @@ -253,7 +253,6 @@ def test_spectralbiclustering_parameter_validation(params, type_err, err_msg): @pytest.mark.parametrize("est", (SpectralBiclustering(), SpectralCoclustering())) def test_n_features_in_(est): - X, _, _ = make_biclusters((3, 3), 3, random_state=0) assert not hasattr(est, "n_features_in_") diff --git a/sklearn/exceptions.py b/sklearn/exceptions.py index d84c1f6b40526..bf368be65a274 100644 --- a/sklearn/exceptions.py +++ b/sklearn/exceptions.py @@ -5,6 +5,7 @@ __all__ = [ "NotFittedError", + "NotSupportedByEngineError", "ConvergenceWarning", "DataConversionWarning", "DataDimensionalityWarning", @@ -38,6 +39,19 @@ class NotFittedError(ValueError, AttributeError): """ +class NotSupportedByEngineError(NotImplementedError): + """External plugins might not support all the combinations of parameters and + input types that the vanilla sklearn implementation otherwise supports. In such + cases, plugins can raise this exception class.
When running the sklearn test modules + using the sklearn pytest plugin, all the unit tests that fail by raising this + exception class will be automatically marked as "xfail"; this makes it possible to + distinguish the tests that fail because they exercise features not supported by the + plugin from the tests that fail because the plugin misbehaves on supported features. + + .. versionadded:: 1.2 + """ + + class ConvergenceWarning(UserWarning): """Custom warning to capture convergence problems diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 3a0a702be3792..dc92711769771 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -758,6 +758,11 @@ class from an array representing our data set and ask who's array([[1], [2]]...) """ + return self._kneighbors( + X=X, n_neighbors=n_neighbors, return_distance=return_distance + ) + + def _kneighbors(self, X=None, n_neighbors=None, return_distance=True): check_is_fitted(self) if n_neighbors is None: diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index b849d28e131a5..ffe0bd04b6c7b 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -19,6 +19,102 @@ from ._base import NeighborsBase, KNeighborsMixin, RadiusNeighborsMixin from ..base import ClassifierMixin from ..utils._param_validation import StrOptions +from .._engine import get_engine_classes +from .._config import get_config + + +class KNeighborsClassifierCythonEngine: + def __init__(self, estimator): + self.estimator = estimator + + def accepts(self, X, y=None): + # The default engine accepts everything + return True + + def fit(self, X, y=None): + return self.estimator._fit(X, y) + + def predict(self, X): + if self.estimator.weights == "uniform": + # In that case, we do not need the distances to perform + # the weighting so we do not compute them. + neigh_ind = self.estimator.kneighbors(X, return_distance=False) + neigh_dist = None + else: + neigh_dist, neigh_ind = self.estimator.kneighbors(X) + + classes_ = self.estimator.classes_ + _y = self.estimator._y + if not self.estimator.outputs_2d_: + _y = self.estimator._y.reshape((-1, 1)) + classes_ = [self.estimator.classes_] + + n_outputs = len(classes_) + n_queries = _num_samples(X) + weights = _get_weights(neigh_dist, self.estimator.weights) + + y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) + for k, classes_k in enumerate(classes_): + if weights is None: + mode, _ = _mode(_y[neigh_ind, k], axis=1) + else: + mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) + + mode = np.asarray(mode.ravel(), dtype=np.intp) + y_pred[:, k] = classes_k.take(mode) + + if not self.estimator.outputs_2d_: + y_pred = y_pred.ravel() + + return y_pred + + def predict_proba(self, X): + if self.estimator.weights == "uniform": + # In that case, we do not need the distances to perform + # the weighting so we do not compute them.
+ neigh_ind = self.estimator.kneighbors(X, return_distance=False) + neigh_dist = None + else: + neigh_dist, neigh_ind = self.estimator.kneighbors(X) + + classes_ = self.estimator.classes_ + _y = self.estimator._y + if not self.estimator.outputs_2d_: + _y = self.estimator._y.reshape((-1, 1)) + classes_ = [self.estimator.classes_] + + n_queries = _num_samples(X) + + weights = _get_weights(neigh_dist, self.estimator.weights) + if weights is None: + weights = np.ones_like(neigh_ind) + + all_rows = np.arange(n_queries) + probabilities = [] + for k, classes_k in enumerate(classes_): + pred_labels = _y[:, k][neigh_ind] + proba_k = np.zeros((n_queries, classes_k.size)) + + # a simple ':' index doesn't work right + for i, idx in enumerate(pred_labels.T): # loop is O(n_neighbors) + proba_k[all_rows, idx] += weights[:, i] + + # normalize 'votes' into real [0,1] probabilities + normalizer = proba_k.sum(axis=1)[:, np.newaxis] + normalizer[normalizer == 0.0] = 1.0 + proba_k /= normalizer + + probabilities.append(proba_k) + + if not self.estimator.outputs_2d_: + probabilities = probabilities[0] + + return probabilities + + def kneighbors(self, X=None, n_neighbors=None, return_distance=True): + return self.estimator._kneighbors( + X=X, n_neighbors=n_neighbors, return_distance=return_distance + ) class KNeighborsClassifier(KNeighborsMixin, ClassifierMixin, NeighborsBase): @@ -212,7 +308,26 @@ def fit(self, X, y): """ self._validate_params() - return self._fit(X, y) + engine = self._get_engine(X, y, reset=True) + + if hasattr(engine, "pre_fit"): + X, y, sample_weight = engine.pre_fit( + X, + y=y, + ) + + engine.fit( + X, + y=y, + ) + + if hasattr(engine, "post_fit"): + engine.post_fit( + X, + y=y, + ) + + return self def predict(self, X): """Predict the class labels for the provided data. @@ -228,36 +343,19 @@ def predict(self, X): y : ndarray of shape (n_queries,) or (n_queries, n_outputs) Class labels for each data sample. """ - if self.weights == "uniform": - # In that case, we do not need the distances to perform - # the weighting so we do not compute them. - neigh_ind = self.kneighbors(X, return_distance=False) - neigh_dist = None - else: - neigh_dist, neigh_ind = self.kneighbors(X) - - classes_ = self.classes_ - _y = self._y - if not self.outputs_2d_: - _y = self._y.reshape((-1, 1)) - classes_ = [self.classes_] + engine = self._get_engine(X) - n_outputs = len(classes_) - n_queries = _num_samples(X) - weights = _get_weights(neigh_dist, self.weights) - - y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) - for k, classes_k in enumerate(classes_): - if weights is None: - mode, _ = _mode(_y[neigh_ind, k], axis=1) - else: - mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) + if hasattr(engine, "pre_predict"): + X = engine.pre_predict( + X, + ) - mode = np.asarray(mode.ravel(), dtype=np.intp) - y_pred[:, k] = classes_k.take(mode) + y_pred = engine.predict(X) - if not self.outputs_2d_: - y_pred = y_pred.ravel() + if hasattr(engine, "post_predict"): + engine.post_predict( + X, + ) return y_pred @@ -277,51 +375,115 @@ def predict_proba(self, X): The class probabilities of the input samples. Classes are ordered by lexicographic order. """ - if self.weights == "uniform": - # In that case, we do not need the distances to perform - # the weighting so we do not compute them. 
- neigh_ind = self.kneighbors(X, return_distance=False) - neigh_dist = None - else: - neigh_dist, neigh_ind = self.kneighbors(X) + engine = self._get_engine(X) - classes_ = self.classes_ - _y = self._y - if not self.outputs_2d_: - _y = self._y.reshape((-1, 1)) - classes_ = [self.classes_] + if hasattr(engine, "pre_predict_proba"): + X = engine.pre_predict_proba( + X, + ) - n_queries = _num_samples(X) + probabilities = engine.predict_proba(X) - weights = _get_weights(neigh_dist, self.weights) - if weights is None: - weights = np.ones_like(neigh_ind) + if hasattr(engine, "post_predict_proba"): + engine.post_predict_proba( + X, + ) - all_rows = np.arange(n_queries) - probabilities = [] - for k, classes_k in enumerate(classes_): - pred_labels = _y[:, k][neigh_ind] - proba_k = np.zeros((n_queries, classes_k.size)) + return probabilities - # a simple ':' index doesn't work right - for i, idx in enumerate(pred_labels.T): # loop is O(n_neighbors) - proba_k[all_rows, idx] += weights[:, i] + def kneighbors(self, X=None, n_neighbors=None, return_distance=True): + """Find the K-neighbors of a point. - # normalize 'votes' into real [0,1] probabilities - normalizer = proba_k.sum(axis=1)[:, np.newaxis] - normalizer[normalizer == 0.0] = 1.0 - proba_k /= normalizer + Returns indices of and distances to the neighbors of each point. - probabilities.append(proba_k) + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_queries, n_features), \ + or (n_queries, n_indexed) if metric == 'precomputed', default=None + The query point or points. + If not provided, neighbors of each indexed point are returned. + In this case, the query point is not considered its own neighbor. - if not self.outputs_2d_: - probabilities = probabilities[0] + n_neighbors : int, default=None + Number of neighbors required for each sample. The default is the + value passed to the constructor. + + return_distance : bool, default=True + Whether or not to return the distances. + + Returns + ------- + neigh_dist : ndarray of shape (n_queries, n_neighbors) + Array representing the lengths to points, only present if + return_distance=True. + + neigh_ind : ndarray of shape (n_queries, n_neighbors) + Indices of the nearest points in the population matrix. + + Examples + -------- + In the following example, we construct a NearestNeighbors + class from an array representing our data set and ask who's + the closest point to [1,1,1] + + >>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]] + >>> from sklearn.neighbors import NearestNeighbors + >>> neigh = NearestNeighbors(n_neighbors=1) + >>> neigh.fit(samples) + NearestNeighbors(n_neighbors=1) + >>> print(neigh.kneighbors([[1., 1., 1.]])) + (array([[0.5]]), array([[2]])) + + As you can see, it returns [[0.5]], and [[2]], which means that the + element is at distance 0.5 and is the third element of samples + (indexes start at 0). You can also query for multiple points: + + >>> X = [[0., 1., 0.], [1., 0., 1.]] + >>> neigh.kneighbors(X, return_distance=False) + array([[1], + [2]]...) 
+ """ + engine = self._get_engine(X) + + if hasattr(engine, "pre_kneighbors"): + X = engine.pre_kneighbors( + X=X, n_neighbors=n_neighbors, return_distance=return_distance + ) + + probabilities = engine.kneighbors( + X=X, n_neighbors=n_neighbors, return_distance=return_distance + ) + + if hasattr(engine, "post_kneighbors"): + engine.post_kneighbors( + X=X, n_neighbors=n_neighbors, return_distance=return_distance + ) return probabilities def _more_tags(self): return {"multilabel": True} + def _get_engine(self, X, y=None, sample_weight=None, reset=False): + for provider, engine_class in get_engine_classes( + "kneigborsclassifier", default=KNeighborsClassifierCythonEngine + ): + if hasattr(self, "_engine_provider") and not reset: + if self._engine_provider != provider: + continue + + engine = engine_class(self) + if engine.accepts(X, y=y): + self._engine_provider = provider + return engine + + if hasattr(self, "_engine_provider"): + raise RuntimeError( + "Estimator was previously fitted with the" + f" {self._engine_provider} engine, but it is not available. Currently" + f" configured engines: {get_config()['engine_provider']}" + ) + class RadiusNeighborsClassifier(RadiusNeighborsMixin, ClassifierMixin, NeighborsBase): """Classifier implementing a vote among neighbors within a given radius. diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index a0b8f29662b69..083d8d85ed94e 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -17,6 +17,7 @@ def test_config_context(): "array_api_dispatch": False, "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, + "engine_provider": (), "transform_output": "default", } @@ -33,6 +34,7 @@ def test_config_context(): "array_api_dispatch": False, "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, + "engine_provider": (), "transform_output": "default", } assert get_config()["assume_finite"] is False @@ -66,6 +68,7 @@ def test_config_context(): "array_api_dispatch": False, "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, + "engine_provider": (), "transform_output": "default", }