8000 Document config_context and transform output (#25289) · jjerphan/scikit-learn@be169bf · GitHub
[go: up one dir, main page]

Skip to content

Commit be169bf

Browse files
ravwojdylajjerphan
authored andcommitted
Document config_context and transform output (scikit-learn#25289)
1 parent 8618ac2 commit be169bf

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

examples/miscellaneous/plot_set_output.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@
8484
set_config(transform_output="pandas")
8585

8686
num_pipe = make_pipeline(SimpleImputer(), StandardScaler())
87+
num_cols = ["age", "fare"]
8788
ct = ColumnTransformer(
8889
(
89-
("numerical", num_pipe, ["age", "fare"]),
90+
("numerical", num_pipe, num_cols),
9091
(
9192
"categorical",
9293
OneHotEncoder(
@@ -114,3 +115,24 @@
114115
# This resets `transform_output` to its default value to avoid impacting other
115116
# examples when generating the scikit-learn documentation
116117
set_config(transform_output="default")
118+
119+
# %%
120+
# When configuring the output type with :func:`config_context` the
121+
# configuration at the time when `transform` or `fit_transform` are
122+
# called is what counts. Setting these only when you construct or fit
123+
# the transformer has no effect.
124+
from sklearn import config_context
125+
126+
scaler = StandardScaler()
127+
scaler.fit(X_train[num_cols])
128+
129+
# %%
130+
with config_context(transform_output="pandas"):
131+
# the output of transform will be a Pandas DataFrame
132+
X_test_scaled = scaler.transform(X_test[num_cols])
133+
X_test_scaled.head()
134+
135+
# %%
136+
# outside of the context manager, the output will be a NumPy array
137+
X_test_scaled = scaler.transform(X_test[num_cols])
138+
X_test_scaled[:5]

0 commit comments

Comments
 (0)
0