10000 [MRG] FIX Avoid default mutable argument in constructor of Agglomerat… · scikit-learn/scikit-learn@0b02125 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0b02125

Browse files
glemaitrejnothman
authored andcommitted
[MRG] FIX Avoid default mutable argument in constructor of AgglomerativeClustering (#8153)
1 parent d0ce0d9 commit 0b02125

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

sklearn/cluster/hierarchical.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin):
660660
"""
661661

662662
def __init__(self, n_clusters=2, affinity="euclidean",
663-
memory=Memory(cachedir=None, verbose=0),
663+
memory=None,
664664
connectivity=None, compute_full_tree='auto',
665665
linkage='ward', pooling_func=np.mean):
666666
self.n_clusters = n_clusters
@@ -685,8 +685,13 @@ def fit(self, X, y=None):
685685
"""
686686
X = check_array(X, ensure_min_samples=2, estimator=self)
687687
memory = self.memory
688-
if isinstance(memory, six.string_types):
688+
if memory is None:
689+
memory = Memory(cachedir=None, verbose=0)
690+
elif isinstance(memory, six.string_types):
689691
memory = Memory(cachedir=memory, verbose=0)
692+
elif not isinstance(memory, Memory):
693+
raise ValueError('`memory` has to be a `str` or a `joblib.Memory`'
694+
' instance')
690695

691696
if self.n_clusters <= 0:
692697
raise ValueError("n_clusters should be an integer greater than 0."

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,17 @@ def test_height_linkage_tree():
115115
assert 92AF _true(len(children) + n_leaves == n_nodes)
116116

117117

118+
def test_agglomerative_clustering_wrong_arg_memory():
119+
# Test either if an error is raised when memory is not
120+
# either a str or a joblib.Memory instance
121+
rng = np.random.RandomState(0)
122+
n_samples = 100
123+
X = rng.randn(n_samples, 50)
124+
memory = 5
125+
clustering = AgglomerativeClustering(memory=memory)
126+
assert_raises(ValueError, clustering.fit, X)
127+
128+
118129
def test_agglomerative_clustering():
119130
# Check that we obtain the correct number of clusters with
120131
# agglomerative clustering.

0 commit comments

Comments
 (0)
0