@@ -431,6 +431,30 @@ def test_one_hot_encoder_inverse(sparse_, drop):
431
431
assert_raises_regex (ValueError , msg , enc .inverse_transform , X_tr )
432
432
433
433
434
+ @pytest .mark .parametrize ("method" , ['fit' , 'fit_transform' ])
435
+ @pytest .mark .parametrize ("X" , [
436
+ [1 , 2 ],
437
+ np .array ([3. , 4. ])
438
+ ])
439
+ def test_X_is_not_1D (X , method ):
440
+ oh = OneHotEncoder ()
441
+
442
+ msg = ("Expected 2D array, got 1D array instead" )
443
+ with pytest .raises (ValueError , match = msg ):
444
+ getattr (oh , method )(X )
445
+
446
+
447
+ @pytest .mark .parametrize ("method" , ['fit' , 'fit_transform' ])
448
+ def test_X_is_not_1D_pandas (method ):
449
+ pd = pytest .importorskip ('pandas' )
450
+ X = pd .Series ([6 , 3 , 4 , 6 ])
451
+ oh = OneHotEncoder ()
452
+
453
+ msg = ("Expected 2D array, got 1D array instead" )
454
+ with pytest .raises (ValueError , match = msg ):
455
+ getattr (oh , method )(X )
456
+
457
+
434
458
@pytest .mark .parametrize ("X, cat_exp, cat_dtype" , [
435
459
([['abc' , 55 ], ['def' , 55 ]], [['abc' , 'def' ], [55 ]], np .object_ ),
436
460
(np .array ([[1 , 2 ], [3 , 2 ]]), [[1 , 3 ], [2 ]], np .integer ),
@@ -569,8 +593,14 @@ def test_one_hot_encoder_feature_names_unicode():
569
593
@pytest .mark .parametrize ("X" , [np .array ([[1 , np .nan ]]).T ,
570
594
np .array ([['a' , np .nan ]], dtype = object ).T ],
571
595
ids = ['numeric' , 'object' ])
596
+ @pytest .mark .parametrize ("as_data_frame" , [False , True ],
597
+ ids = ['array' , 'dataframe' ])
572
598
@pytest .mark .parametrize ("handle_unknown" , ['error' , 'ignore' ])
573
- def test_one_hot_encoder_raise_missing (X , handle_unknown ):
599
+ def test_one_hot_encoder_raise_missing (X , as_data_frame , handle_unknown ):
600
+ if as_data_frame :
601
+ pd = pytest .importorskip ('pandas' )
602
+ X = pd .DataFrame (X )
603
+
574
604
ohe = OneHotEncoder (categories = 'auto' , handle_unknown = handle_unknown )
575
605
576
606
with pytest .raises (ValueError , match = "Input contains NaN" ):
@@ -579,7 +609,12 @@ def test_one_hot_encoder_raise_missing(X, handle_unknown):
579
609
with pytest .raises (ValueError , match = "Input contains NaN" ):
580
610
ohe .fit_transform (X )
581
611
582
- ohe .fit (X [:1 , :])
612
+ if as_data_frame :
613
+ X_partial = X .iloc [:1 , :]
614
+ else :
615
+ X_partial = X [:1 , :]
616
+
617
+ ohe .fit (X_partial )
583
618
584
619
with pytest .raises (ValueError , match = "Input contains NaN" ):
585
620
ohe .transform (X )
@@ -688,16 +723,18 @@ def test_encoder_dtypes_pandas():
688
723
pd = pytest .importorskip ('pandas' )
689
724
690
725
enc = OneHotEncoder (categories = 'auto' )
691
- exp = np .array ([[1. , 0. , 1. , 0. ], [0. , 1. , 0. , 1. ]], dtype = 'float64' )
726
+ exp = np .array ([[1. , 0. , 1. , 0. , 1. , 0. ],
727
+ [0. , 1. , 0. , 1. , 0. , 1. ]], dtype = 'float64' )
692
728
693
- X = pd .DataFrame ({'A' : [1 , 2 ], 'B' : [3 , 4 ]}, dtype = 'int64' )
729
+ X = pd .DataFrame ({'A' : [1 , 2 ], 'B' : [3 , 4 ], 'C' : [ 5 , 6 ] }, dtype = 'int64' )
694
730
enc .fit (X )
695
731
assert all ([enc .categories_ [i ].dtype == 'int64' for i in range (2 )])
696
732
assert_array_equal (enc .transform (X ).toarray (), exp )
697
733
698
- X = pd .DataFrame ({'A' : [1 , 2 ], 'B' : ['a' , 'b' ]})
734
+ X = pd .DataFrame ({'A' : [1 , 2 ], 'B' : ['a' , 'b' ], 'C' : [3. , 4. ]})
735
+ X_type = [int , object , float ]
699
736
enc .fit (X )
700
- assert all ([enc .categories_ [i ].dtype == 'object' for i in range (2 )])
737
+ assert all ([enc .categories_ [i ].dtype == X_type [ i ] for i in range (3 )])
701
738
assert_array_equal (enc .transform (X ).toarray (), exp )
702
739
703
740
0 commit comments