BUG: fix GaussianHMM.fit to allow input sequences of different lengths · seckcoder/scikit-learn@d1c6e5a · GitHub
[go: up one dir, main page]

Skip to content

Commit d1c6e5a

Browse files
ronw authored and Fabian Pedregosa committed
BUG: fix GaussianHMM.fit to allow input sequences of different lengths
1 parent 2377694 commit d1c6e5a

File tree

3 files changed, with 37 additions and 5 deletions.

scikits/learn/hmm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,6 @@ def fit(self, obs, n_iter=10, thresh=1e-2, params=string.letters,
327327
small). You can fix this by getting more training data, or
328328
decreasing `covars_prior`.
329329
"""
330-
obs = np.asanyarray(obs)
331-
332330
self._init(obs, init_params)
333331

334332
logprob = []
@@ -679,11 +677,13 @@ def _generate_sample_from_state(self, state):
679677
def _init(self, obs, params='stmc'):
680678
super(GaussianHMM, self)._init(obs, params=params)
681679

682-
if hasattr(self, 'n_features') and self.n_features != obs.shape[2]:
680+
if (hasattr(self, 'n_features')
681+
and self.n_features != obs[0].shape[1]):
683682
raise ValueError('Unexpected number of dimensions, got %s but '
684-
'expected %s' % (obs.shape[2], self.n_features))
683+
'expected %s' % (obs[0].shape[1],
684+
self.n_features))
685685

686-
self.n_features = obs.shape[2]
686+
self.n_features = obs[0].shape[1]
687687

688688
if 'm' in params:
689689
self._means = cluster.KMeans(

scikits/learn/tests/test_hmm.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,16 @@ def test_fit(self, params='stmc', n_iter=15, verbose=False, **kwargs):
330330
% (self.cvtype, params, trainll, np.diff(trainll)))
331331
self.assertTrue(np.all(np.diff(trainll) > -0.5))
332332

333+
def test_fit_works_on_sequences_of_different_length(self):
334+
obs = [np.random.rand(3, self.n_features),
335+
np.random.rand(4, self.n_features),
336+
np.random.rand(5, self.n_features)]
337+
338+
h = hmm.GaussianHMM(self.n_states, self.cvtype)
339+
# This shouldn't raise
340+
# ValueError: setting an array element with a sequence.
341+
h.fit(obs)
342+
333343
def test_fit_with_priors(self, params='stmc', n_iter=10,
334344
verbose=False):
335345
startprob_prior = 10 * self.startprob + 2.0
@@ -612,6 +622,16 @@ def test_fit(self, params='stmwc', n_iter=5, verbose=True, **kwargs):
612622
np.diff(trainll))
613623
self.assertTrue(np.all(np.diff(trainll) > -0.5))
614624

625+
def test_fit_works_on_sequences_of_different_length(self):
626+
obs = [np.random.rand(3, self.n_features),
627+
np.random.rand(4, self.n_features),
628+
np.random.rand(5, self.n_features)]
629+
630+
h = hmm.GMMHMM(self.n_states, cvtype=self.cvtype)
631+
# This shouldn't raise
632+
# ValueError: setting an array element with a sequence.
633+
h.fit(obs)
634+
615635

616636
class TestGMMHMMWithSphericalCovars(TestGMMHMM):
617637
cvtype = 'spherical'

scikits/learn/tests/test_mixture.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,18 @@ def test_GMM_attributes():
169169
assert_raises(ValueError, mixture.GMM, n_states=20, cvtype='badcvtype')
170170

171171

172+
def test_GMM_fit_works_on_sequences_of_different_length():
173+
ndim = 3
174+
obs = [np.random.rand(3, ndim),
175+
np.random.rand(4, ndim),
176+
np.random.rand(5, ndim)]
177+
178+
gmm = mixture.GMM(n_states=1)
179+
# This shouldn't raise
180+
# ValueError: setting an array element with a sequence.
181+
gmm.fit(obs)
182+
183+
172184
class GMMTester():
173185
n_states = 10
174186
n_features = 4

0 commit comments

Comments
 (0)
0