@@ -719,6 +719,47 @@ def transform(self, X):
719719 trans .transform (df_mixed )
720720
721721
def test_validate_data_cast_to_ndarray():
    """Exercise the `cast_to_ndarray` flag of `_validate_data`.

    With `cast_to_ndarray=True` the pandas inputs are expected to come back
    as NumPy arrays holding the same values; with `cast_to_ndarray=False` the
    very same objects that were passed in are expected to be returned.
    Calling `_validate_data` with neither X nor y is expected to raise a
    `ValueError`.
    """
    pd = pytest.importorskip("pandas")

    class NoOpTransformer(TransformerMixin, BaseEstimator):
        pass

    bunch = datasets.load_iris()
    features = pd.DataFrame(bunch.data, columns=bunch.feature_names)
    target = pd.Series(bunch.target)
    estimator = NoOpTransformer()

    # X only: conversion requested, then pass-through.
    converted_X = estimator._validate_data(features, cast_to_ndarray=True)
    assert isinstance(converted_X, np.ndarray)
    assert_allclose(converted_X, features.to_numpy())

    assert estimator._validate_data(features, cast_to_ndarray=False) is features

    # y only: conversion requested, then pass-through.
    converted_y = estimator._validate_data(y=target, cast_to_ndarray=True)
    assert isinstance(converted_y, np.ndarray)
    assert_allclose(converted_y, target.to_numpy())

    assert estimator._validate_data(y=target, cast_to_ndarray=False) is target

    # X and y together: both outputs are checked.
    converted_X, converted_y = estimator._validate_data(
        features, target, cast_to_ndarray=True
    )
    assert isinstance(converted_X, np.ndarray)
    assert_allclose(converted_X, features.to_numpy())
    assert isinstance(converted_y, np.ndarray)
    assert_allclose(converted_y, target.to_numpy())

    same_X, same_y = estimator._validate_data(
        features, target, cast_to_ndarray=False
    )
    assert same_X is features
    assert same_y is target

    # Neither X nor y supplied: validation has nothing to work on.
    expected_error = "Validation should be done on X, y or both."
    with pytest.raises(ValueError, match=expected_error):
        estimator._validate_data()
761+
762+
722763def test_clone_keeps_output_config ():
723764 """Check that clone keeps the set_output config."""
724765
0 commit comments