diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 68494051041be..99731c7fda599 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -955,6 +955,7 @@ See the :ref:`metrics` section of the user guide for further details. metrics.pairwise_distances metrics.pairwise_distances_argmin metrics.pairwise_distances_argmin_min + metrics.pairwise_distances_chunked .. _mixture_ref: diff --git a/doc/modules/computational_performance.rst b/doc/modules/computational_performance.rst index d66cba212a2dd..ca128c515fb90 100644 --- a/doc/modules/computational_performance.rst +++ b/doc/modules/computational_performance.rst @@ -308,6 +308,27 @@ Debian / Ubuntu. or upgrade to Python 3.4 which has a new version of ``multiprocessing`` that should be immune to this problem. +.. _working_memory: + +Limiting Working Memory +----------------------- + +Some calculations, when implemented using standard numpy vectorized operations, +involve using a large amount of temporary memory. This may potentially exhaust +system memory. Where computations can be performed in fixed-memory chunks, we +attempt to do so, and allow the user to hint at the maximum size of this +working memory (defaulting to 1GB) using :func:`sklearn.set_config` or +:func:`config_context`. The following suggests limiting temporary working +memory to 128 MiB:: + + >>> import sklearn + >>> with sklearn.config_context(working_memory=128): + ... pass # do chunked work here + +An example of a chunked operation adhering to this setting is +:func:`metrics.pairwise_distances_chunked`, which facilitates computing +row-wise reductions of a pairwise distance matrix. + Model Compression ----------------- diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 13de05ada8369..d10400442b387 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -56,6 +56,9 @@ Classifiers and regressors - :class:`dummy.DummyRegressor` now has a ``return_std`` option in its ``predict`` method. The returned standard deviations will be zeros. +- Added :class:`multioutput.RegressorChain` for multi-target + regression. :issue:`9257` by :user:`Kumar Ashutosh `. + - Added :class:`naive_bayes.ComplementNB`, which implements the Complement Naive Bayes classifier described in Rennie et al. (2003). :issue:`8190` by :user:`Michael A. Alcorn `. @@ -115,6 +118,13 @@ Metrics :func:`metrics.roc_auc_score`. :issue:`3273` by :user:`Alexander Niederbühl `. +Misc + +- A new configuration parameter, ``working_memory``, was added to control memory + consumption limits in chunked operations, such as the new + :func:`metrics.pairwise_distances_chunked`. See :ref:`working_memory`. + :issue:`10280` by `Joel Nothman`_ and :user:`Aman Dalmia `. + Enhancements ............ @@ -521,6 +531,12 @@ Metrics due to floating point error in the input. :issue:`9851` by :user:`Hanmin Qin `. +- The ``batch_size`` parameter to :func:`metrics.pairwise_distances_argmin_min` + and :func:`metrics.pairwise_distances_argmin` is deprecated and will be removed in + v0.22. It no longer has any effect, as batch size is determined by the global + ``working_memory`` config. See :ref:`working_memory`. :issue:`10280` by `Joel + Nothman`_ and :user:`Aman Dalmia `.
+ Cluster - Deprecate ``pooling_func`` unused parameter in diff --git a/sklearn/_config.py b/sklearn/_config.py index 1ad53c9a22527..2b8a2e795bf86 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -5,6 +5,7 @@ _global_config = { 'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)), + 'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)) } @@ -19,7 +20,7 @@ def get_config(): return _global_config.copy() -def set_config(assume_finite=None): +def set_config(assume_finite=None, working_memory=None): """Set global scikit-learn configuration Parameters @@ -29,9 +30,17 @@ def set_config(assume_finite=None): saving time, but leading to potential crashes. If False, validation for finiteness will be performed, avoiding error. Global default: False. + + working_memory : int, optional + If set, scikit-learn will attempt to limit the size of temporary arrays + to this number of MiB (per job when parallelised), often saving both + computation time and memory on expensive operations that can be + performed in chunks. Global default: 1024. """ if assume_finite is not None: _global_config['assume_finite'] = assume_finite + if working_memory is not None: + _global_config['working_memory'] = working_memory @contextmanager @@ -46,6 +55,12 @@ def config_context(**new_config): False, validation for finiteness will be performed, avoiding error. Global default: False. + working_memory : int, optional + If set, scikit-learn will attempt to limit the size of temporary arrays + to this number of MiB (per job when parallelised), often saving both + computation time and memory on expensive operations that can be + performed in chunks. Global default: 1024. + Notes ----- All settings, not just those presently modified, will be returned to diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index c98b0e14493c6..846634e7afec4 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -51,6 +51,7 @@ from .pairwise import pairwise_distances_argmin from .pairwise import pairwise_distances_argmin_min from .pairwise import pairwise_kernels +from .pairwise import pairwise_distances_chunked from .regression import explained_variance_score from .regression import mean_absolute_error @@ -106,6 +107,7 @@ 'pairwise_distances_argmin', 'pairwise_distances_argmin_min', 'pairwise_distances_argmin_min', + 'pairwise_distances_chunked', 'pairwise_kernels', 'precision_recall_curve', 'precision_recall_fscore_support', diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 58bd3f1627f22..b4928ed7492f3 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -18,9 +18,10 @@ from scipy.sparse import csr_matrix from scipy.sparse import issparse +from ..utils.validation import _num_samples from ..utils import check_array from ..utils import gen_even_slices -from ..utils import gen_batches +from ..utils import gen_batches, get_chunk_n_rows from ..utils.extmath import row_norms, safe_sparse_dot from ..preprocessing import normalize from ..externals.joblib import Parallel @@ -257,8 +258,14 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, return distances if squared else np.sqrt(distances, out=distances) +def _argmin_min_reduce(dist, start): + indices = dist.argmin(axis=1) + values = dist[np.arange(dist.shape[0]), indices] + return indices, values + + def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean", - batch_size=500, metric_kwargs=None): + batch_size=None, metric_kwargs=None): 
"""Compute minimum distances between one point and a set of points. This function computes for each row in X, the index of the row of Y which @@ -310,11 +317,9 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean", metrics. batch_size : integer - To reduce memory consumption over the naive solution, data are - processed in batches, comprising batch_size rows of X and - batch_size rows of Y. The default value is quite conservative, but - can be changed for fine-tuning. The larger the number, the larger the - memory usage. + .. deprecated:: 0.20 + Deprecated for removal in 0.22. + Use sklearn.set_config(working_memory=...) instead. metric_kwargs : dict, optional Keyword arguments to pass to specified metric function. @@ -333,12 +338,11 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean", sklearn.metrics.pairwise_distances sklearn.metrics.pairwise_distances_argmin """ - dist_func = None - if metric in PAIRWISE_DISTANCE_FUNCTIONS: - dist_func = PAIRWISE_DISTANCE_FUNCTIONS[metric] - elif not callable(metric) and not isinstance(metric, str): - raise ValueError("'metric' must be a string or a callable") - + if batch_size is not None: + warnings.warn("'batch_size' is ignored. It was deprecated in version " + "0.20 and will be removed in version 0.22. " + "Use sklearn.set_config(working_memory=...) instead.", + DeprecationWarning) X, Y = check_pairwise_arrays(X, Y) if metric_kwargs is None: @@ -347,39 +351,11 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean", if axis == 0: X, Y = Y, X - # Allocate output arrays - indices = np.empty(X.shape[0], dtype=np.intp) - values = np.empty(X.shape[0]) - values.fill(np.infty) - - for chunk_x in gen_batches(X.shape[0], batch_size): - X_chunk = X[chunk_x, :] - - for chunk_y in gen_batches(Y.shape[0], batch_size): - Y_chunk = Y[chunk_y, :] - - if dist_func is not None: - if metric == 'euclidean': # special case, for speed - d_chunk = safe_sparse_dot(X_chunk, Y_chunk.T, - dense_output=True) - d_chunk *= -2 - d_chunk += row_norms(X_chunk, squared=True)[:, np.newaxis] - d_chunk += row_norms(Y_chunk, squared=True)[np.newaxis, :] - np.maximum(d_chunk, 0, d_chunk) - else: - d_chunk = dist_func(X_chunk, Y_chunk, **metric_kwargs) - else: - d_chunk = pairwise_distances(X_chunk, Y_chunk, - metric=metric, **metric_kwargs) - - # Update indices and minimum values using chunk - min_indices = d_chunk.argmin(axis=1) - min_values = d_chunk[np.arange(chunk_x.stop - chunk_x.start), - min_indices] - - flags = values[chunk_x] > min_values - indices[chunk_x][flags] = min_indices[flags] + chunk_y.start - values[chunk_x][flags] = min_values[flags] + indices, values = zip(*pairwise_distances_chunked( + X, Y, reduce_func=_argmin_min_reduce, metric=metric, + **metric_kwargs)) + indices = np.concatenate(indices) + values = np.concatenate(values) if metric == "euclidean" and not metric_kwargs.get("squared", False): np.sqrt(values, values) @@ -387,7 +363,7 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean", def pairwise_distances_argmin(X, Y, axis=1, metric="euclidean", - batch_size=500, metric_kwargs=None): + batch_size=None, metric_kwargs=None): """Compute minimum distances between one point and a set of points. This function computes for each row in X, the index of the row of Y which @@ -441,11 +417,9 @@ def pairwise_distances_argmin(X, Y, axis=1, metric="euclidean", metrics. 
batch_size : integer - To reduce memory consumption over the naive solution, data are - processed in batches, comprising batch_size rows of X and - batch_size rows of Y. The default value is quite conservative, but - can be changed for fine-tuning. The larger the number, the larger the - memory usage. + .. deprecated:: 0.20 + Deprecated for removal in 0.22. + Use sklearn.set_config(working_memory=...) instead. metric_kwargs : dict keyword arguments to pass to specified metric function. @@ -463,8 +437,9 @@ def pairwise_distances_argmin(X, Y, axis=1, metric="euclidean", if metric_kwargs is None: metric_kwargs = {} - return pairwise_distances_argmin_min(X, Y, axis, metric, batch_size, - metric_kwargs)[0] + return pairwise_distances_argmin_min(X, Y, axis, metric, + metric_kwargs=metric_kwargs, + batch_size=batch_size)[0] def manhattan_distances(X, Y=None, sum_over_features=True, @@ -928,7 +903,8 @@ def cosine_similarity(X, Y=None, dense_output=True): else: Y_normalized = normalize(Y, copy=True) - K = safe_sparse_dot(X_normalized, Y_normalized.T, dense_output=dense_output) + K = safe_sparse_dot(X_normalized, Y_normalized.T, + dense_output=dense_output) return K @@ -1144,6 +1120,177 @@ def _pairwise_callable(X, Y, metric, **kwds): 'sokalsneath', 'sqeuclidean', 'yule', "wminkowski"] +def _check_chunk_size(reduced, chunk_size): + """Checks chunk is a sequence of expected size or a tuple of same + """ + is_tuple = isinstance(reduced, tuple) + if not is_tuple: + reduced = (reduced,) + if any(isinstance(r, tuple) or not hasattr(r, '__iter__') + for r in reduced): + raise TypeError('reduce_func returned %r. ' + 'Expected sequence(s) of length %d.' % + (reduced if is_tuple else reduced[0], chunk_size)) + if any(_num_samples(r) != chunk_size for r in reduced): + # XXX: we use int(_num_samples...) because sometimes _num_samples + # returns a long in Python 2, even for small numbers. + actual_size = tuple(int(_num_samples(r)) for r in reduced) + raise ValueError('reduce_func returned object of length %s. ' + 'Expected same length as input: %d.' % + (actual_size if is_tuple else actual_size[0], + chunk_size)) + + +def pairwise_distances_chunked(X, Y=None, reduce_func=None, + metric='euclidean', n_jobs=1, + working_memory=None, **kwds): + """Generate a distance matrix chunk by chunk with optional reduction + + In cases where not all of a pairwise distance matrix needs to be stored at + once, this is used to calculate pairwise distances in + ``working_memory``-sized chunks. If ``reduce_func`` is given, it is run + on each chunk and its return values are concatenated into lists, arrays + or sparse matrices. + + Parameters + ---------- + X : array [n_samples_a, n_samples_a] if metric == "precomputed", or, + [n_samples_a, n_features] otherwise + Array of pairwise distances between samples, or a feature array. + + Y : array [n_samples_b, n_features], optional + An optional second feature array. Only allowed if + metric != "precomputed". + + reduce_func : callable, optional + The function which is applied on each chunk of the distance matrix, + reducing it to needed values. ``reduce_func(D_chunk, start)`` + is called repeatedly, where ``D_chunk`` is a contiguous vertical + slice of the pairwise distance matrix, starting at row ``start``. + It should return an array, a list, or a sparse matrix of length + ``D_chunk.shape[0]``, or a tuple of such objects. + + If None, pairwise_distances_chunked returns a generator of vertical + chunks of the distance matrix. 
+ + metric : string, or callable + The metric to use when calculating distance between instances in a + feature array. If metric is a string, it must be one of the options + allowed by scipy.spatial.distance.pdist for its metric parameter, or + a metric listed in pairwise.PAIRWISE_DISTANCE_FUNCTIONS. + If metric is "precomputed", X is assumed to be a distance matrix. + Alternatively, if metric is a callable function, it is called on each + pair of instances (rows) and the resulting value recorded. The callable + should take two arrays from X as input and return a value indicating + the distance between them. + + n_jobs : int + The number of jobs to use for the computation. This works by breaking + down the pairwise matrix into n_jobs even slices and computing them in + parallel. + + If -1 all CPUs are used. If 1 is given, no parallel computing code is + used at all, which is useful for debugging. For n_jobs below -1, + (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one + are used. + + working_memory : int, optional + The sought maximum memory for temporary distance matrix chunks. + When None (default), the value of + ``sklearn.get_config()['working_memory']`` is used. + + `**kwds` : optional keyword parameters + Any further parameters are passed directly to the distance function. + If using a scipy.spatial.distance metric, the parameters are still + metric dependent. See the scipy docs for usage examples. + + Yields + ------ + D_chunk : array or sparse matrix + A contiguous slice of distance matrix, optionally processed by + ``reduce_func``. + + Examples + -------- + Without reduce_func: + + >>> X = np.random.RandomState(0).rand(5, 3) + >>> D_chunk = next(pairwise_distances_chunked(X)) + >>> D_chunk # doctest: +ELLIPSIS + array([[0. ..., 0.29..., 0.41..., 0.19..., 0.57...], + [0.29..., 0. ..., 0.57..., 0.41..., 0.76...], + [0.41..., 0.57..., 0. ..., 0.44..., 0.90...], + [0.19..., 0.41..., 0.44..., 0. ..., 0.51...], + [0.57..., 0.76..., 0.90..., 0.51..., 0. ...]]) + + Retrieve all neighbors and average distance within radius r: + + >>> r = .2 + >>> def reduce_func(D_chunk, start): + ... neigh = [np.flatnonzero(d < r) for d in D_chunk] + ... avg_dist = (D_chunk * (D_chunk < r)).mean(axis=1) + ... return neigh, avg_dist + >>> gen = pairwise_distances_chunked(X, reduce_func=reduce_func) + >>> neigh, avg_dist = next(gen) + >>> neigh + [array([0, 3]), array([1]), array([2]), array([0, 3]), array([4])] + >>> avg_dist # doctest: +ELLIPSIS + array([0.039..., 0. , 0. , 0.039..., 0. ]) + + Where r is defined per sample, we need to make use of ``start``: + + >>> r = [.2, .4, .4, .3, .1] + >>> def reduce_func(D_chunk, start): + ... neigh = [np.flatnonzero(d < r[i]) + ... for i, d in enumerate(D_chunk, start)] + ... return neigh + >>> neigh = next(pairwise_distances_chunked(X, reduce_func=reduce_func)) + >>> neigh + [array([0, 3]), array([0, 1]), array([2]), array([0, 3]), array([4])] + + Force row-by-row generation by reducing ``working_memory``: + + >>> gen = pairwise_distances_chunked(X, reduce_func=reduce_func, + ... working_memory=0) + >>> next(gen) + [array([0, 3])] + >>> next(gen) + [array([0, 1])] + """ + n_samples_X = _num_samples(X) + if metric == 'precomputed': + slices = (slice(0, n_samples_X),) + else: + if Y is None: + Y = X + # We get as many rows as possible within our working_memory budget to + # store len(Y) distances in each row of output. + # + # Note: + # - this will get at least 1 row, even if 1 row of distances will + # exceed working_memory. 
+ # - this does not account for any temporary memory usage while + # calculating distances (e.g. difference of vectors in manhattan + # distance). + chunk_n_rows = get_chunk_n_rows(row_bytes=8 * _num_samples(Y), + max_n_rows=n_samples_X, + working_memory=working_memory) + slices = gen_batches(n_samples_X, chunk_n_rows) + + for sl in slices: + if sl.start == 0 and sl.stop == n_samples_X: + X_chunk = X # enable optimised paths for X is Y + else: + X_chunk = X[sl] + D_chunk = pairwise_distances(X_chunk, Y, metric=metric, + n_jobs=n_jobs, **kwds) + if reduce_func is not None: + chunk_size = D_chunk.shape[0] + D_chunk = reduce_func(D_chunk, sl.start) + _check_chunk_size(D_chunk, chunk_size) + yield D_chunk + + def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds): """ Compute the distance matrix from a vector array X and optional Y. @@ -1186,7 +1333,8 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds): Array of pairwise distances between samples, or a feature array. Y : array [n_samples_b, n_features], optional - An optional second feature array. Only allowed if metric != "precomputed". + An optional second feature array. Only allowed if + metric != "precomputed". metric : string, or callable The metric to use when calculating distance between instances in a @@ -1224,6 +1372,9 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds): See also -------- + pairwise_distances_chunked : performs the same calculation as this function, + but returns a generator of chunks of the distance matrix, in order to + limit memory usage. paired_distances : Computes the distances between corresponding elements of two arrays """ diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 799b3e4fe9bf7..0ef089c7a3619 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -1,11 +1,15 @@ +from types import GeneratorType + import numpy as np from numpy import linalg +import pytest from scipy.sparse import dok_matrix, csr_matrix, issparse from scipy.spatial.distance import cosine, cityblock, minkowski, wminkowski from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_allclose from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_equal @@ -14,6 +18,7 @@ from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_warns from sklearn.utils.testing import ignore_warnings +from sklearn.utils.testing import assert_warns_message from sklearn.externals.six import iteritems @@ -28,6 +33,7 @@ from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_distances from sklearn.metrics.pairwise import pairwise_distances +from sklearn.metrics.pairwise import pairwise_distances_chunked from sklearn.metrics.pairwise import pairwise_distances_argmin_min from sklearn.metrics.pairwise import pairwise_distances_argmin from sklearn.metrics.pairwise import pairwise_kernels @@ -368,10 +374,128 @@ def test_pairwise_distances_argmin_min(): dist_orig_val = dist[dist_orig_ind, range(len(dist_orig_ind))] dist_chunked_ind, dist_chunked_val = pairwise_distances_argmin_min( - X, Y, axis=0, metric="manhattan", batch_size=50) + X, Y, axis=0, metric="manhattan") np.testing.assert_almost_equal(dist_orig_ind, dist_chunked_ind, decimal=7)
np.testing.assert_almost_equal(dist_orig_val, dist_chunked_val, decimal=7) + # Test batch_size deprecation warning + assert_warns_message(DeprecationWarning, "version 0.22", + pairwise_distances_argmin_min, X, Y, batch_size=500, + metric='euclidean') + + +def _reduce_func(dist, start): + return dist[:, :100] + + +def test_pairwise_distances_chunked_reduce(): + rng = np.random.RandomState(0) + X = rng.random_sample((400, 4)) + # Reduced Euclidean distance + S = pairwise_distances(X)[:, :100] + S_chunks = pairwise_distances_chunked(X, None, reduce_func=_reduce_func, + working_memory=2 ** -16) + assert isinstance(S_chunks, GeneratorType) + S_chunks = list(S_chunks) + assert len(S_chunks) > 1 + # atol is for diagonal where S is explicitly zeroed on the diagonal + assert_allclose(np.vstack(S_chunks), S, atol=1e-7) + + +@pytest.mark.parametrize('good_reduce', [ + lambda D, start: list(D), + lambda D, start: np.array(D), + lambda D, start: csr_matrix(D), + lambda D, start: (list(D), list(D)), + lambda D, start: (dok_matrix(D), np.array(D), list(D)), + ]) +def test_pairwise_distances_chunked_reduce_valid(good_reduce): + X = np.arange(10).reshape(-1, 1) + S_chunks = pairwise_distances_chunked(X, None, reduce_func=good_reduce, + working_memory=64) + next(S_chunks) + + +@pytest.mark.parametrize(('bad_reduce', 'err_type', 'message'), [ + (lambda D, s: np.concatenate([D, D[-1:]]), ValueError, + r'length 11\..* input: 10\.'), + (lambda D, s: (D, np.concatenate([D, D[-1:]])), ValueError, + r'length \(10, 11\)\..* input: 10\.'), + (lambda D, s: (D[:9], D), ValueError, + r'length \(9, 10\)\..* input: 10\.'), + (lambda D, s: 7, TypeError, + r'returned 7\. Expected sequence\(s\) of length 10\.'), + (lambda D, s: (7, 8), TypeError, + r'returned \(7, 8\)\. Expected sequence\(s\) of length 10\.'), + (lambda D, s: (np.arange(10), 9), TypeError, + r', 9\)\. Expected sequence\(s\) of length 10\.'), +]) +def test_pairwise_distances_chunked_reduce_invalid(bad_reduce, err_type, + message): + X = np.arange(10).reshape(-1, 1) + S_chunks = pairwise_distances_chunked(X, None, reduce_func=bad_reduce, + working_memory=64) + assert_raises_regexp(err_type, message, next, S_chunks) + + +def check_pairwise_distances_chunked(X, Y, working_memory, metric='euclidean'): + gen = pairwise_distances_chunked(X, Y, working_memory=working_memory, + metric=metric) + assert isinstance(gen, GeneratorType) + blockwise_distances = list(gen) + Y = np.array(X if Y is None else Y) + min_block_mib = len(Y) * 8 * 2 ** -20 + + for block in blockwise_distances: + memory_used = block.nbytes + assert memory_used <= max(working_memory, min_block_mib) * 2 ** 20 + + blockwise_distances = np.vstack(blockwise_distances) + S = pairwise_distances(X, Y, metric=metric) + assert_array_almost_equal(blockwise_distances, S) + + +@ignore_warnings +def test_pairwise_distances_chunked(): + # Test the pairwise_distance helper function. + rng = np.random.RandomState(0) + # Euclidean distance should be equivalent to calling the function. + X = rng.random_sample((400, 4)) + check_pairwise_distances_chunked(X, None, working_memory=1, + metric='euclidean') + # Test small amounts of memory + for power in range(-16, 0): + check_pairwise_distances_chunked(X, None, working_memory=2 ** power, + metric='euclidean') + # X as list + check_pairwise_distances_chunked(X.tolist(), None, working_memory=1, + metric='euclidean') + # Euclidean distance, with Y != X. 
+ Y = rng.random_sample((200, 4)) + check_pairwise_distances_chunked(X, Y, working_memory=1, + metric='euclidean') + check_pairwise_distances_chunked(X.tolist(), Y.tolist(), working_memory=1, + metric='euclidean') + # absurdly large working_memory + check_pairwise_distances_chunked(X, Y, working_memory=10000, + metric='euclidean') + # "cityblock" uses scikit-learn metric, cityblock (function) is + # scipy.spatial. + check_pairwise_distances_chunked(X, Y, working_memory=1, + metric='cityblock') + # Test that a value error is raised if the metric is unknown + assert_raises(ValueError, next, + pairwise_distances_chunked(X, Y, metric="blah")) + + # Test precomputed returns all at once + D = pairwise_distances(X) + gen = pairwise_distances_chunked(D, + working_memory=2 ** -16, + metric='precomputed') + assert isinstance(gen, GeneratorType) + assert next(gen) is D + assert_raises(StopIteration, next, gen) + def test_euclidean_distances(): # Check the pairwise Euclidean distances computation diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py index ace8590b08157..83ee27cccd4b1 100644 --- a/sklearn/neighbors/classification.py +++ b/sklearn/neighbors/classification.py @@ -145,7 +145,6 @@ def predict(self, X): X = check_array(X, accept_sparse='csr') neigh_dist, neigh_ind = self.kneighbors(X) - classes_ = self.classes_ _y = self._y if not self.outputs_2d_: diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index 08731b3c06e94..efaa57f850367 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -3,14 +3,14 @@ def test_config_context(): - assert get_config() == {'assume_finite': False} + assert get_config() == {'assume_finite': False, 'working_memory': 1024} # Not using as a context manager affects nothing config_context(assume_finite=True) assert get_config()['assume_finite'] is False with config_context(assume_finite=True): - assert get_config() == {'assume_finite': True} + assert get_config() == {'assume_finite': True, 'working_memory': 1024} assert get_config()['assume_finite'] is False with config_context(assume_finite=True): @@ -34,7 +34,7 @@ def test_config_context(): assert get_config()['assume_finite'] is True - assert get_config() == {'assume_finite': False} + assert get_config() == {'assume_finite': False, 'working_memory': 1024} # No positional arguments assert_raises(TypeError, config_context, True) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 21dcca4f0764e..e3d1e7faaabd1 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -17,7 +17,7 @@ from ..externals.joblib import cpu_count from ..exceptions import DataConversionWarning from .deprecation import deprecated - +from .. import get_config __all__ = ["murmurhash3_32", "as_float_array", "assert_all_finite", "check_array", @@ -514,3 +514,42 @@ def indices_to_mask(indices, mask_length): mask[indices] = True return mask + + +def get_chunk_n_rows(row_bytes, max_n_rows=None, + working_memory=None): + """Calculates how many rows can be processed within working_memory + + Parameters + ---------- + row_bytes : int + The expected number of bytes of memory that will be consumed + during the processing of each row. + max_n_rows : int, optional + The maximum return value. + working_memory : int or float, optional + The number of rows to fit inside this number of MiB will be returned. + When None (default), the value of + ``sklearn.get_config()['working_memory']`` is used. 
+ + Returns + ------- + int or the value of n_samples + + Warns + ----- + Issues a UserWarning if ``row_bytes`` exceeds ``working_memory`` MiB. + """ + + if working_memory is None: + working_memory = get_config()['working_memory'] + + chunk_n_rows = int(working_memory * (2 ** 20) // row_bytes) + if max_n_rows is not None: + chunk_n_rows = min(chunk_n_rows, max_n_rows) + if chunk_n_rows < 1: + warnings.warn('Could not adhere to working_memory config. ' + 'Currently %.0fMiB, %.0fMiB required.' % + (working_memory, np.ceil(row_bytes * 2 ** -20))) + chunk_n_rows = 1 + return chunk_n_rows diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index fa93bf34fe6bc..1f1efed825c80 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -1,6 +1,7 @@ from itertools import chain, product import warnings +import pytest import numpy as np import scipy.sparse as sp from scipy.linalg import pinv2 @@ -9,7 +10,8 @@ from sklearn.utils.testing import (assert_equal, assert_raises, assert_true, assert_almost_equal, assert_array_equal, SkipTest, assert_raises_regex, - assert_greater_equal, ignore_warnings) + assert_greater_equal, ignore_warnings, + assert_warns_message, assert_no_warnings) from sklearn.utils import check_random_state from sklearn.utils import deprecated from sklearn.utils import resample @@ -18,9 +20,11 @@ from sklearn.utils import safe_indexing from sklearn.utils import shuffle from sklearn.utils import gen_even_slices +from sklearn.utils import get_chunk_n_rows from sklearn.utils.extmath import pinvh from sklearn.utils.arpack import eigsh from sklearn.utils.mocking import MockDataFrame +from sklearn import config_context def test_make_rng(): @@ -274,3 +278,39 @@ def test_gen_even_slices(): slices = gen_even_slices(10, -1) assert_raises_regex(ValueError, "gen_even_slices got n_packs=-1, must be" " >=1", next, slices) + + +@pytest.mark.parametrize( + ('row_bytes', 'max_n_rows', 'working_memory', 'expected', 'warning'), + [(1024, None, 1, 1024, None), + (1024, None, 0.99999999, 1023, None), + (1023, None, 1, 1025, None), + (1025, None, 1, 1023, None), + (1024, None, 2, 2048, None), + (1024, 7, 1, 7, None), + (1024 * 1024, None, 1, 1, None), + (1024 * 1024 + 1, None, 1, 1, + 'Could not adhere to working_memory config. ' + 'Currently 1MiB, 2MiB required.'), + ]) +def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory, + expected, warning): + if warning is not None: + def check_warning(*args, **kw): + return assert_warns_message(UserWarning, warning, *args, **kw) + else: + check_warning = assert_no_warnings + + actual = check_warning(get_chunk_n_rows, + row_bytes=row_bytes, + max_n_rows=max_n_rows, + working_memory=working_memory) + + assert actual == expected + assert type(actual) is type(expected) + with config_context(working_memory=working_memory): + actual = check_warning(get_chunk_n_rows, + row_bytes=row_bytes, + max_n_rows=max_n_rows) + assert actual == expected + assert type(actual) is type(expected)
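Putting the new pieces together, a minimal sketch of how the ``working_memory`` setting and :func:`metrics.pairwise_distances_chunked` are intended to be combined; the ``nearest_neighbor_reduce`` helper, the random data and the sizes below are illustrative assumptions rather than code from this patch::

    import numpy as np

    import sklearn
    from sklearn.metrics import pairwise_distances_chunked

    X = np.random.RandomState(42).rand(1000, 10)

    def nearest_neighbor_reduce(D_chunk, start):
        # Mask each row's self-distance on the diagonal, then keep the index
        # of the closest other sample.  The result has length
        # D_chunk.shape[0], as _check_chunk_size requires.
        D_chunk = D_chunk.copy()
        rows = np.arange(D_chunk.shape[0])
        D_chunk[rows, rows + start] = np.inf
        return D_chunk.argmin(axis=1)

    # Cap temporary distance chunks at ~1 MiB instead of the 1024 MiB
    # default, forcing several chunks for a 1000 x 1000 distance matrix.
    with sklearn.config_context(working_memory=1):
        nn = np.concatenate(list(
            pairwise_distances_chunked(X, reduce_func=nearest_neighbor_reduce)))

    assert nn.shape == (1000,)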
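The chunk height itself comes from the new ``utils.get_chunk_n_rows`` helper: the ``working_memory`` budget in MiB is divided by the bytes needed per row of the distance matrix, capped at ``max_n_rows``, and falls back to a single row (with a ``UserWarning``) when even one row exceeds the budget. A quick arithmetic check, with illustrative numbers::

    # Mirrors the calculation in get_chunk_n_rows; numbers are illustrative.
    n_samples_Y = 1000
    row_bytes = 8 * n_samples_Y          # one float64 distance per column
    working_memory = 1                   # MiB
    chunk_n_rows = int(working_memory * (2 ** 20) // row_bytes)
    print(chunk_n_rows)                  # 131 rows of the distance matrix per chunk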