8000 add sparse_threshold to make_columns_transformer (#12152) · lithuak/scikit-learn@530900f · GitHub
[go: up one dir, main page]

Skip to content

Commit 530900f

Browse files
datajankoamueller
authored andcommitted
add sparse_threshold to make_columns_transformer (scikit-learn#12152)
fixes issue scikit-learn#12149
1 parent 353b1de commit 530900f

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

sklearn/compose/_column_transformer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,14 @@ def make_column_transformer(*transformers, **kwargs):
689689
non-specified columns will use the ``remainder`` estimator. The
690690
estimator must support `fit` and `transform`.
691691
692+
sparse_threshold : float, default = 0.3
693+
If the transformed output consists of a mix of sparse and dense data,
694+
it will be stacked as a sparse matrix if the density is lower than this
695+
value. Use ``sparse_threshold=0`` to always return dense.
696+
When the transformed output consists of all sparse or all dense data,
697+
the stacked result will be sparse or dense, respectively, and this
698+
keyword will be ignored.
699+
692700
n_jobs : int or None, optional (default=None)
693701
Number of jobs to run in parallel.
694702
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
@@ -725,9 +733,11 @@ def make_column_transformer(*transformers, **kwargs):
725733
"""
726734
n_jobs = kwargs.pop('n_jobs', None)
727735
remainder = kwargs.pop('remainder', 'drop')
736+
sparse_threshold = kwargs.pop('sparse_threshold', 0.3)
728737
if kwargs:
729738
raise TypeError('Unknown keyword arguments: "{}"'
730739
.format(list(kwargs.keys())[0]))
731740
transformer_list = _get_transformer_list(transformers)
732741
return ColumnTransformer(transformer_list, n_jobs=n_jobs,
733-
remainder=remainder)
742+
remainder=remainder,
743+
sparse_threshold=sparse_threshold)

sklearn/compose/tests/test_column_transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,11 +453,13 @@ def test_make_column_transformer_kwargs():
453453
scaler = StandardScaler()
454454
norm = Normalizer()
455455
ct = make_column_transformer(('first', scaler), (['second'], norm),
456-
n_jobs=3, remainder='drop')
456+
n_jobs=3, remainder='drop',
457+
sparse_threshold=0.3)
457458
assert_equal(ct.transformers, make_column_transformer(
458459
('first', scaler), (['second'], norm)).transformers)
459460
assert_equal(ct.n_jobs, 3)
460461
assert_equal(ct.remainder, 'drop')
462+
assert_equal(ct.sparse_threshold, 0.3)
461463
# invalid keyword parameters should raise an error message
462464
assert_raise_message(
463465
TypeError,

0 commit comments

Comments
 (0)
0