|
12 | 12 |
|
13 | 13 | from sklearn.utils.deprecation import deprecated
|
14 | 14 | from sklearn.utils.metaestimators import available_if, if_delegate_has_method
|
| 15 | +from sklearn.utils._readonly_array_wrapper import _test_sum |
15 | 16 | from sklearn.utils._testing import (
|
16 | 17 | assert_raises,
|
17 | 18 | assert_warns,
|
@@ -680,30 +681,59 @@ def test_tempmemmap(monkeypatch):
|
680 | 681 | assert registration_counter.nb_calls == 2
|
681 | 682 |
|
682 | 683 |
|
683 |
| -def test_create_memmap_backed_data(monkeypatch): |
| 684 | +@pytest.mark.parametrize("aligned", [False, True]) |
| 685 | +def test_create_memmap_backed_data(monkeypatch, aligned): |
684 | 686 | registration_counter = RegistrationCounter()
|
685 | 687 | monkeypatch.setattr(atexit, "register", registration_counter)
|
686 | 688 |
|
687 | 689 | input_array = np.ones(3)
|
688 |
| - data = create_memmap_backed_data(input_array) |
| 690 | + data = create_memmap_backed_data(input_array, aligned=aligned) |
689 | 691 | check_memmap(input_array, data)
|
690 | 692 | assert registration_counter.nb_calls == 1
|
691 | 693 |
|
692 |
| - data, folder = create_memmap_backed_data(input_array, return_folder=True) |
| 694 | + data, folder = create_memmap_backed_data( |
| 695 | + input_array, return_folder=True, aligned=aligned |
| 696 | + ) |
693 | 697 | check_memmap(input_array, data)
|
694 | 698 | assert folder == os.path.dirname(data.filename)
|
695 | 699 | assert registration_counter.nb_calls == 2
|
696 | 700 |
|
697 | 701 | mmap_mode = "r+"
|
698 |
| - data = create_memmap_backed_data(input_array, mmap_mode=mmap_mode) |
| 702 | + data = create_memmap_backed_data(input_array, mmap_mode=mmap_mode, aligned=aligned) |
699 | 703 | check_memmap(input_array, data, mmap_mode)
|
700 | 704 | assert registration_counter.nb_calls == 3
|
701 | 705 |
|
702 | 706 | input_list = [input_array, input_array + 1, input_array + 2]
|
703 |
| - mmap_data_list = create_memmap_backed_data(input_list) |
704 |
| - for input_array, data in zip(input_list, mmap_data_list): |
705 |
| - check_memmap(input_array, data) |
706 |
| - assert registration_counter.nb_calls == 4 |
| 707 | + if aligned: |
| 708 | + with pytest.raises( |
| 709 | + ValueError, match="If aligned=True, input must be a single numpy array." |
| 710 | + ): |
| 711 | + create_memmap_backed_data(input_list, aligned=True) |
| 712 | + else: |
| 713 | + mmap_data_list = create_memmap_backed_data(input_list, aligned=False) |
| 714 | + for input_array, data in zip(input_list, mmap_data_list): |
| 715 | + check_memmap(input_array, data) |
| 716 | + assert registration_counter.nb_calls == 4 |
| 717 | + |
| 718 | + |
| 719 | +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32, np.int64]) |
| 720 | +def test_memmap_on_contiguous_data(dtype): |
| 721 | + """Test memory mapped array on contiguous memoryview.""" |
| 722 | + x = np.arange(10).astype(dtype) |
| 723 | + assert x.flags["C_CONTIGUOUS"] |
| 724 | + assert x.flags["ALIGNED"] |
| 725 | + |
| 726 | + # _test_sum consumes contiguous arrays |
| 727 | + # def _test_sum(NUM_TYPES[::1] x): |
| 728 | + sum_origin = _test_sum(x) |
| 729 | + |
| 730 | + # now on memory mapped data |
| 731 | + # aligned=True to avoid https://github.com/joblib/joblib/issues/563 |
| 732 | + # without alignment, this can produce segmentation faults, see |
| 733 | + # https://github.com/scikit-learn/scikit-learn/pull/21654 |
| 734 | + x_mmap = create_memmap_backed_data(x, mmap_mode="r+", aligned=True) |
| 735 | + sum_mmap = _test_sum(x_mmap) |
| 736 | + assert sum_mmap == pytest.approx(sum_origin, rel=1e-11) |
707 | 737 |
|
708 | 738 |
|
709 | 739 | @pytest.mark.parametrize(
|
|
0 commit comments