@@ -47,7 +47,7 @@ class BRIEF(DescriptorExtractor):
47
47
For matching across images, the same `rng` should be used to construct
48
48
descriptors. To facilitate this:
49
49
50
- (a) `rng` defauls to 1
50
+ (a) `rng` defaults to 1
51
51
(b) Subsequent calls of the ``extract`` method will use the same rng/seed.
52
52
sigma : float, optional
53
53
Standard deviation of the Gaussian low-pass filter applied to the image
@@ -134,10 +134,18 @@ def __init__(self, descriptor_size=256, patch_size=49,
134
134
self .mode = mode
135
135
self .sigma = sigma
136
136
137
- if rng is None :
138
- self .seed = np .random .SeedSequence ()
137
+ if isinstance (rng , np .random .Generator ):
138
+ # Spawn an independent RNG from parent RNG provided by the user.
139
+ # This is necessary so that we can safely deepcopy the RNG.
140
+ # See https://github.com/scikit-learn/scikit-learn/issues/16988#issuecomment-1518037853
141
+ bg = rng ._bit_generator
142
+ ss = bg ._seed_seq
143
+ child_ss , = ss .spawn (1 )
144
+ self .rng = np .random .Generator (type (bg )(child_ss ))
145
+ elif rng is None :
146
+ self .rng = np .random .default_rng (np .random .SeedSequence ())
139
147
else :
140
- self .seed = rng
148
+ self .rng = np . random . default_rng ( rng )
141
149
142
150
self .descriptors = None
143
151
self .mask = None
@@ -155,7 +163,8 @@ def extract(self, image, keypoints):
155
163
"""
156
164
check_nD (image , 2 )
157
165
158
- rng = np .random .default_rng (copy .deepcopy (self .seed ))
166
+ # Copy RNG so we can repeatedly call extract with the same random values
167
+ rng = copy .deepcopy (self .rng )
159
168
160
169
image = _prepare_grayscale_input_2D (image )
161
170
0 commit comments