8000 Fix test_svmlight_format.py::test_dump on ARM · scikit-learn/scikit-learn@50bf0d6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 50bf0d6

Browse files
committed
Fix test_svmlight_format.py::test_dump on ARM
1 parent f8b108d commit 50bf0d6

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

sklearn/datasets/tests/test_svmlight_format.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_dump():
226226
for X in (X_sparse, X_dense, X_sliced):
227227
for y in (y_sparse, y_dense, y_sliced):
228228
for zero_based in (True, False):
229-
for dtype in [np.float32, np.float64, np.int32]:
229+
for dtype in [np.float32, np.float64, np.int32, np.int64]:
230230
f = BytesIO()
231231
# we need to pass a comment to get the version info in;
232232
# LibSVM doesn't grok comments so they're not put in by
@@ -237,7 +237,13 @@ def test_dump():
237237
# when it is sparse
238238
y = y.T
239239

240-
dump_svmlight_file(X.astype(dtype), y, f, comment="test",
240+
# Note: with dtype=np.int32 we are performing unsafe casts,
241+
# where X.astype(dtype) overflows. The result is
242+
# then platform dependent and X_dense.astype(dtype) may be
243+
# different from X_sparse.astype(dtype).asarray().
244+
X_input = X.astype(dtype)
245+
246+
dump_svmlight_file(X_input, y, f, comment="test",
241247
zero_based=zero_based)
242248
f.seek(0)
243249

@@ -257,17 +263,21 @@ def test_dump():
257263
assert_array_equal(X2.sorted_indices().indices, X2.indices)
258264

259265
X2_dense = X2.toarray()
266+
if sp.issparse(X_input):
267+
X_input_dense = X_input.toarray()
268+
else:
269+
X_input_dense = X_input
260270

261271
if dtype == np.float32:
262272
# allow a rounding error at the last decimal place
263273
assert_array_almost_equal(
264-
X_dense.astype(dtype), X2_dense, 4)
274+
X_input_dense, X2_dense, 4)
265275
assert_array_almost_equal(
266276
y_dense.astype(dtype), y2, 4)
267277
else:
268278
# allow a rounding error at the last decimal place
269279
assert_array_almost_equal(
270-
X_dense.astype(dtype), X2_dense, 15)
280+
X_input_dense, X2_dense, 15)
271281
assert_array_almost_equal(
272282
y_dense.astype(dtype), y2, 15)
273283

0 commit comments

Comments
 (0)
0