8000 MNT Replace pytest.warns(None) in test_utils (#23137) · thomasjpfan/scikit-learn@a739f6c · GitHub
[go: up one dir, main page]

Skip to content

Commit a739f6c

Browse files
authored
MNT Replace pytest.warns(None) in test_utils (scikit-learn#23137)
1 parent 41d7bd2 commit a739f6c

File tree

1 file changed

+40
-23
lines changed

1 file changed

+40
-23
lines changed

sklearn/utils/tests/test_utils.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -557,27 +557,47 @@ def test_gen_even_slices():
557557

558558

559559
@pytest.mark.parametrize(
560-
("row_bytes", "max_n_rows", "working_memory", "expected", "warn_msg"),
560+
("row_bytes", "max_n_rows", "working_memory", "expected"),
561561
[
562-
(1024, None, 1, 1024, None),
563-
(1024, None, 0.99999999, 1023, None),
564-
(1023, None, 1, 1025, None),
565-
(1025, None, 1, 1023, None),
566-
(1024, None, 2, 2048, None),
567-
(1024, 7, 1, 7, None),
568-
(1024 * 1024, None, 1, 1, None),
569-
(
570-
1024 * 1024 + 1,
571-
None,
572-
1,
573-
1,
574-
"Could not adhere to working_memory config. Currently 1MiB, 2MiB required.",
575-
),
562+
(1024, None, 1, 1024),
563+
(1024, None, 0.99999999, 1023),
564+
(1023, None, 1, 1025),
565+
(1025, None, 1, 1023),
566+
(1024, None, 2, 2048),
567+
(1024, 7, 1, 7),
568+
(1024 * 1024, None, 1, 1),
576569
],
577570
)
578-
def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory, expected, warn_msg):
579-
warning = None if warn_msg is None else UserWarning
580-
with pytest.warns(warning, match=warn_msg) as w:
571+
def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory, expected):
572+
with warnings.catch_warnings():
573+
warnings.simplefilter("error", UserWarning)
574+
actual = get_chunk_n_rows(
575+
row_bytes=row_bytes,
576+
max_n_rows=max_n_rows,
577+
working_memory=working_memory,
578+
)
579+
580+
assert actual == expected
581+
assert type(actual) is type(expected)
582+
with config_context(working_memory=working_memory):
583+
with warnings.catch_warnings():
584+
warnings.simplefilter("error", UserWarning)
585+
actual = get_chunk_n_rows(row_bytes=row_bytes, max_n_rows=max_n_rows)
586+
assert actual == expected
587+
assert type(actual) is type(expected)
588+
589+
590+
def test_get_chunk_n_rows_warns():
591+
"""Check that warning is raised when working_memory is too low."""
592+
row_bytes = 1024 * 1024 + 1
593+
max_n_rows = None
594+
working_memory = 1
595+
expected = 1
596+
597+
warn_msg = (
598+
"Could not adhere to working_memory config. Currently 1MiB, 2MiB required."
599+
)
600+
with pytest.warns(UserWarning, match=warn_msg):
581601
actual = get_chunk_n_rows(
582602
row_bytes=row_bytes,
583603
max_n_rows=max_n_rows,
@@ -586,15 +606,12 @@ def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory, expected, warn_
586606

587607
assert actual == expected
588608
assert type(actual) is type(expected)
589-
if warn_msg is None:
590-
assert len(w) == 0
609+
591610
with config_context(working_memory=working_memory):
592-
with pytest.warns(warning, match=warn_msg) as w:
611+
with pytest.warns(UserWarning, match=warn_msg):
593612
actual = get_chunk_n_rows(row_bytes=row_bytes, max_n_rows=max_n_rows)
594613
assert actual == expected
595614
assert type(actual) is type(expected)
596-
if warn_msg is None:
597-
assert len(w) == 0
598615

599616

600617
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)
0