Improving Bunch Class to ensure consistent attributes · scikit-learn/scikit-learn@a4d0c0f · GitHub

Commit a4d0c0f

Improving Bunch Class to ensure consistent attributes
Add __setattr__/__getattr__ methods so that attribute access writes to and reads from the same entry as `bunch[key]`. Add a non-regression test for a bug in fetch_20newsgroups.
1 parent 0fe613e commit a4d0c0f

File tree

3 files changed, +44 -7 lines changed

sklearn/datasets/base.py
27 additions, 3 deletions

@@ -18,19 +18,43 @@
 from os.path import isdir
 from os import listdir
 from os import makedirs
+import re
 
 import numpy as np
 
 from ..utils import check_random_state
 
 
 class Bunch(dict):
-    """Container object for datasets: dictionary-like object that
-    exposes its keys as attributes."""
+    """Container object for datasets
+
+    Dictionary-like object that exposes its keys as attributes.
+
+    >>> b = Bunch(a=1, b=2)
+    >>> b['b']
+    2
+    >>> b.b
+    2
+    >>> b.a = 3
+    >>> b['a']
+    3
+    >>> b.c = 6
+    >>> b['c']
+    6
+
+    """
 
     def __init__(self, **kwargs):
         dict.__init__(self, kwargs)
-        self.__dict__ = self
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+    def __getattr__(self, key):
+        return self[key]
+
+    def __getstate__(self):
+        return self.__dict__
 
 
 def get_data_home(data_home=None):
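
For readers who want to poke at the new behaviour without checking out the branch, here is a minimal, self-contained copy of the class from the diff above, plus a few sanity checks. The remarks in the comments (why __getstate__ is there, the KeyError for missing attributes) are inferences from the code, not statements made by the commit.

# Minimal sketch: the Bunch class copied from the diff above, exercised
# stand-alone (no scikit-learn import needed).


class Bunch(dict):
    """Dictionary-like container that exposes its keys as attributes."""

    def __init__(self, **kwargs):
        dict.__init__(self, kwargs)

    def __setattr__(self, key, value):
        # Attribute writes go into the dict itself, so b.x and b['x']
        # always name the same entry (the instance __dict__ stays empty).
        self[key] = value

    def __getattr__(self, key):
        # Called only when normal attribute lookup fails; fall back to the
        # dict entry.  Note that a missing key surfaces as KeyError here,
        # not AttributeError.
        return self[key]

    def __getstate__(self):
        # Inference: gives pickle an explicit instance state so its probe
        # for '__getstate__' does not have to go through __getattr__ above.
        return self.__dict__


b = Bunch(a=1)
b.b = 2                      # attribute write ...
assert b['b'] == 2           # ... visible through the key
b['a'] = 3                   # key write ...
assert b.a == 3              # ... visible through the attribute
assert b.__dict__ == {}      # nothing is stored on the instance itself

try:
    b.missing
except KeyError:             # not AttributeError -- see __getattr__ above
    pass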

sklearn/datasets/tests/test_20news.py
13 additions, 0 deletions

@@ -38,6 +38,19 @@ def test_20news():
         entry2 = data.data[np.where(data.target == label)[0][0]]
         assert_equal(entry1, entry2)
 
+def test_20news_length_consistency():
+    """Checks the length consistencies within the bunch"""
+    try:
+        data = datasets.fetch_20newsgroups(
+            subset='all', download_if_missing=False, shuffle=False)
+    except IOError:
+        raise SkipTest("Download 20 newsgroups to run this test")
+    # Extract the full dataset
+    data = datasets.fetch_20newsgroups(subset='all')
+    assert_equal(len(data['data']), len(data.data))
+    assert_equal(len(data['target']), len(data.target))
+    assert_equal(len(data['filenames']), len(data.filenames))
+
 
 def test_20news_vectorized():
     # This test is slow.
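
The inconsistency this test guards against is easiest to reproduce with the previous Bunch implementation. The sketch below is a plausible reconstruction, not code from the repository: OldBunch is a hypothetical stand-in for the removed `self.__dict__ = self` version, and the pickle round trip stands in for the on-disk cache that fetch_20newsgroups reads back before filling in the merged 'all' fields.

# Hypothetical reconstruction of the pre-fix failure mode.  'OldBunch'
# mirrors the removed `self.__dict__ = self` implementation; it is not
# part of scikit-learn.
import pickle


class OldBunch(dict):
    def __init__(self, **kwargs):
        dict.__init__(self, kwargs)
        self.__dict__ = self        # keys and attributes share one dict...


b = OldBunch(data=['doc-1', 'doc-2'])

# ...but a pickle round trip (what the 20newsgroups on-disk cache amounts to)
# rebuilds the instance with a *separate* attribute __dict__.
b = pickle.loads(pickle.dumps(b))

b.data = ['doc-1', 'doc-2', 'doc-3']   # attribute write, like the
                                       # `data.data = data_lst` assignment
                                       # replaced in twenty_newsgroups.py below

print(len(b.data), len(b['data']))     # typically "3 2" on CPython: attribute
                                       # and key no longer agree, which is what
                                       # the new length-consistency test catches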

sklearn/datasets/twenty_newsgroups.py
4 additions, 4 deletions

@@ -161,7 +161,7 @@ def fetch_20newsgroups(data_home=None, subset='train', categories=None,
         for the test set, 'all' for both, with shuffled ordering.
 
     data_home: optional, default: None
-        Specify an download and cache folder for the datasets. If None,
+        Specify a download and cache folder for the datasets. If None,
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
     categories: None or collection of string or unicode
@@ -231,9 +231,9 @@ def fetch_20newsgroups(data_home=None, subset='train', categories=None,
             target.extend(data.target)
             filenames.extend(data.filenames)
 
-        data.data = data_lst
-        data.target = np.array(target)
-        data.filenames = np.array(filenames)
+        data['data'] = data_lst
+        data['target'] = np.array(target)
+        data['filenames'] = np.array(filenames)
     else:
         raise ValueError(
            "subset can only be 'train', 'test' or 'all', got '%s'" % subset)
