8000 If user provides RNG, spawn it before deepcopying (#6948) · scikit-image/scikit-image@f2745cf · GitHub
[go: up one dir, main page]

Skip to content

Commit f2745cf

Browse files
authored
If user provides RNG, spawn it before deepcopying (#6948)
Otherwise, the user can still draw values from the RNG and change its state. See scikit-learn/scikit-learn#16988 (comment)
1 parent 3998940 commit f2745cf

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

skimage/feature/brief.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class BRIEF(DescriptorExtractor):
4747
For matching across images, the same `rng` should be used to construct
4848
descriptors. To facilitate this:
4949
50-
(a) `rng` defauls to 1
50+
(a) `rng` defaults to 1
5151
(b) Subsequent calls of the ``extract`` method will use the same rng/seed.
5252
sigma : float, optional
5353
Standard deviation of the Gaussian low-pass filter applied to the image
@@ -134,10 +134,18 @@ def __init__(self, descriptor_size=256, patch_size=49,
134134
self.mode = mode
135135
self.sigma = sigma
136136

137-
if rng is None:
138-
self.seed = np.random.SeedSequence()
137+
if isinstance(rng, np.random.Generator):
138+
# Spawn an independent RNG from parent RNG provided by the user.
139+
# This is necessary so that we can safely deepcopy the RNG.
140+
# See https://github.com/scikit-learn/scikit-learn/issues/16988#issuecomment-1518037853
141+
bg = rng._bit_generator
142+
ss = bg._seed_seq
143+
child_ss, = ss.spawn(1)
144+
self.rng = np.random.Generator(type(bg)(child_ss))
145+
elif rng is None:
146+
self.rng = np.random.default_rng(np.random.SeedSequence())
139147
else:
140-
self.seed = rng
148+
self.rng = np.random.default_rng(rng)
141149

142150
self.descriptors = None
143151
self.mask = None
@@ -155,7 +163,8 @@ def extract(self, image, keypoints):
155163
"""
156164
check_nD(image, 2)
157165

158-
rng = np.random.default_rng(copy.deepcopy(self.seed))
166+
# Copy RNG so we can repeatedly call extract with the same random values
167+
rng = copy.deepcopy(self.rng)
159168

160169
image = _prepare_grayscale_input_2D(image)
161170

skimage/feature/tests/test_brief.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import pytest
2+
import copy
3+
24
import numpy as np
35

46
from skimage._shared.testing import assert_array_equal
@@ -80,3 +82,18 @@ def test_border(dtype):
8082

8183
assert extractor.descriptors.shape[0] == 3
8284
assert_array_equal(extractor.mask, (False, True, True, True))
85+
86+
87+
def test_independent_rng():
88+
img = np.zeros((100, 100), dtype=int)
89+
keypoints = np.array([[1, 1], [20, 20], [50, 50], [80, 80]])
90+
91+
rng = np.random.default_rng()
92+
extractor = BRIEF(patch_size=41, rng=rng)
93+
94+
x = copy.deepcopy(extractor.rng).random()
95+
rng.random()
96+
extractor.extract(img, keypoints)
97+
z = copy.deepcopy(extractor.rng).random()
98+
99+
assert x == z

0 commit comments

Comments
 (0)
0