@@ -110,7 +110,13 @@ def _num_samples(x):
110
110
x = np .asarray (x )
111
111
else :
112
112
raise TypeError ("Expected sequence or array-like, got %r" % x )
113
- return x .shape [0 ] if hasattr (x , 'shape' ) else len (x )
113
+ if hasattr (x , 'shape' ):
114
+ if len (x .shape ) == 0 :
115
+ raise TypeError ("Singleton array %r cannot be considered"
116
+ " a valid collection." % x )
117
+ return x .shape [0 ]
118
+ else :
119
+ return len (x )
114
120
115
121
116
122
def check_consistent_length (* arrays ):
@@ -222,10 +228,11 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, order, copy,
222
228
223
229
224
230
def check_array(array , accept_sparse = None , dtype = None , order = None , copy = False ,
225
- force_all_finite = True , ensure_2d = True , allow_nd = False ):
231
+ force_all_finite = True , ensure_2d = True , allow_nd = False ,
232
+ ensure_min_samples = 1 , ensure_min_features = 1 ):
226
233
"""Input validation on an array, list, sparse matrix or similar.
227
234
228
- By default, the input is converted to an at least 2nd numpy array.
235
+ By default, the input is converted to an at least 2d numpy array.
229
236
230
237
Parameters
231
238
----------
@@ -257,6 +264,16 @@ def check_array(array, accept_sparse=None, dtype=None, order=None, copy=False,
257
264
allow_nd : boolean (default=False)
258
265
Whether to allow X.ndim > 2.
259
266
267
+ ensure_min_samples : int (default=1)
268
+ Make sure that the array has a minimum number of samples in its first
269
+ axis (rows for a 2D array). Setting to 0 disables this check.
270
+
271
+ ensure_min_features : int (default=1)
272
+ Make sure that the 2D array has some minimum number of features
273
+ (columns). The default value of 1 rejects empty datasets.
274
+ This check is only enforced when ``ensure_2d`` is True and
275
+ ``allow_nd`` is False. Setting to 0 disables this check.
276
+
260
277
Returns
261
278
-------
262
279
X_converted : object
@@ -278,12 +295,26 @@ def check_array(array, accept_sparse=None, dtype=None, order=None, copy=False,
278
295
if force_all_finite :
279
296
_assert_all_finite (array )
280
297
298
+ if ensure_min_samples > 0 :
299
+ n_samples = _num_samples (array )
300
+ if n_samples < ensure_min_samples :
301
+ raise ValueError ("Found array with %d sample(s) (shape=%r) while a"
302
+ " minimum of %d is required."
303
+ % (n_samples , array .shape , ensure_min_samples ))
304
+
305
+ if ensure_min_features > 0 and ensure_2d and not allow_nd :
306
+ n_features = array .shape [1 ]
307
+ if n_features < ensure_min_features :
308
+ raise ValueError ("Found array with %d feature(s) (shape=%r) while"
309
+ " a minimum of %d is required."
310
+ % (n_features , array .shape , ensure_min_features ))
281
311
return array
282
312
283
313
284
314
def check_X_y (X , y , accept_sparse = None , dtype = None , order = None , copy = False ,
285
315
force_all_finite = True , ensure_2d = True , allow_nd = False ,
286
- multi_output = False ):
316
+ multi_output = False , ensure_min_samples = 1 ,
317
+ ensure_min_features = 1 ):
287
318
"""Input validation for standard estimators.
288
319
289
320
Checks X and y for consistent length, enforces X 2d and y 1d.
@@ -327,13 +358,24 @@ def check_X_y(X, y, accept_sparse=None, dtype=None, order=None, copy=False,
327
358
Whether to allow 2-d y (array or sparse matrix). If false, y will be
328
359
validated as a vector.
329
360
361
+ ensure_min_samples : int (default=1)
362
+ Make sure that X has a minimum number of samples in its first
363
+ axis (rows for a 2D array). Setting to 0 disables this check.
364
+
365
+ ensure_min_features : int (default=1)
366
+ Make sure that the 2D X has some minimum number of features
367
+ (columns). The default value of 1 rejects empty datasets.
368
+ This check is only enforced when ``ensure_2d`` is True and
369
+ ``allow_nd`` is False. Setting to 0 disables this check.
370
+
330
371
Returns
331
372
-------
332
373
X_converted : object
333
374
The converted and validated X.
334
375
"""
335
376
X = check_array (X , accept_sparse , dtype , order , copy , force_all_finite ,
336
- ensure_2d , allow_nd )
377
+ ensure_2d , allow_nd , ensure_min_samples ,
378
+ ensure_min_features )
337
379
if multi_output :
338
380
y = check_array (y , 'csr' , force_all_finite = True , ensure_2d = False )
339
381
else :
@@ -353,7 +395,7 @@ def column_or_1d(y, warn=False):
353
395
y : array-like
354
396
355
397
warn : boolean, default False
356
- To control display of warnings.
398
+ To control display of warnings.
357
399
358
400
Returns
359
401
-------
@@ -406,6 +448,7 @@ def check_random_state(seed):
406
448
raise ValueError ('%r cannot be used to seed a numpy.random.RandomState'
407
449
' instance' % seed )
408
450
451
+
409
452
def has_fit_parameter (estimator , parameter ):
410
453
"""Checks whether the estimator's fit method supports the given parameter.
411
454
@@ -512,4 +555,4 @@ def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
512
555
attributes = [attributes ]
513
556
514
557
if not all_or_any ([hasattr (estimator , attr ) for attr in attributes ]):
515
- raise NotFittedError (msg % {'name' : type (estimator ).__name__ })
558
+ raise NotFittedError (msg % {'name' : type (estimator ).__name__ })
0 commit comments