Fix fetch_rcv1 · scikit-learn/scikit-learn@281b631 · GitHub
[go: up one dir, main page]

Skip to content

Commit 281b631

Browse files
committed
Fix fetch_rcv1
1 parent efa63b2 commit 281b631

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

sklearn/datasets/rcv1.py

Lines changed: 21 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -166,21 +166,23 @@ def fetch_rcv1(data_home=None, subset='all', download_if_missing=True,
166166

167167
Xy = load_svmlight_files(files, n_features=N_FEATURES)
168168

169-
# delete archives
170-
for f in files:
171-
remove(f.name)
172-
173169
# Training data is before testing data
174170
X = sp.vstack([Xy[8], Xy[0], Xy[2], Xy[4], Xy[6]]).tocsr()
175171
sample_id = np.hstack((Xy[9], Xy[1], Xy[3], Xy[5], Xy[7]))
176172
sample_id = sample_id.astype(np.uint32)
177173

178174
joblib.dump(X, samples_path, compress=9)
179175
joblib.dump(sample_id, sample_id_path, compress=9)
176+
177+
# delete archives
178+
for f in files:
179+
f.close()
180+
remove(f.name)
180181
else:
181182
X = joblib.load(samples_path)
182183
sample_id = joblib.load(sample_id_path)
183184

185+
184186
# load target (y), categories, and sample_id_bis
185187
if download_if_missing and (not exists(sample_topics_path) or
186188
not exists(topics_path)):
@@ -195,20 +197,21 @@ def fetch_rcv1(data_home=None, subset='all', download_if_missing=True,
195197
y = np.zeros((N_SAMPLES, N_CATEGORIES), dtype=np.uint8)
196198
sample_id_bis = np.zeros(N_SAMPLES, dtype=np.int32)
197199
category_names = {}
198-
for line in GzipFile(filename=topics_archive_path, mode='rb'):
199-
line_components = line.decode("ascii").split(u" ")
200-
if len(line_components) == 3:
201-
cat, doc, _ = line_components
202-
if cat not in category_names:
203-
n_cat += 1
204-
category_names[cat] = n_cat
205-
206-
doc = int(doc)
207-
if doc != doc_previous:
208-
doc_previous = doc
209-
n_doc += 1
210-
sample_id_bis[n_doc] = doc
211-
y[n_doc, category_names[cat]] = 1
200+
with GzipFile(filename=topics_archive_path, mode='rb') as f:
201+
for line in f:
202+
line_components = line.decode("ascii").split(u" ")
203+
if len(line_components) == 3:
204+
cat, doc, _ = line_components
205+
if cat not in category_names:
206+
n_cat += 1
207+
category_names[cat] = n_cat
208+
209+
doc = int(doc)
210+
if doc != doc_previous:
211+
doc_previous = doc
212+
n_doc += 1
213+
sample_id_bis[n_doc] = doc
214+
y[n_doc, category_names[cat]] = 1
212215

213216
# delete archive
214217
remove(topics_archive_path)

0 commit comments

Comments (0)
0