@@ -174,13 +174,21 @@ randomkit.ffi.rk_seed(0, randomkit._state)
174
174
175
175
-- Extend torch state handling to handle randomkit's state too
176
176
local _manualSeed = torch .manualSeed
177
- torch .manualSeed = function (seed )
177
+ torch .manualSeed = function (generator , seed )
178
+ if seed then
179
+ return _manualSeed (generator , seed )
180
+ else
181
+ seed = generator
182
+ end
178
183
randomkit .ffi .rk_seed (0 , randomkit ._state )
179
184
return _manualSeed (seed )
180
185
end
181
186
182
187
local _getRNGState = torch .getRNGState
183
- torch .getRNGState = function ()
188
+ torch .getRNGState = function (generator )
189
+ if generator then
190
+ return _getRNGState (generator )
191
+ end
184
192
-- Serialize to string, required to write to file
185
193
local clonedState = ffi .string (randomkit ._state , ffi .sizeof (randomkit ._state ))
186
194
return {
@@ -190,7 +198,12 @@ torch.getRNGState = function()
190
198
end
191
199
192
200
local _setRNGState = torch .setRNGState
193
- torch .setRNGState = function (state )
201
+ torch .setRNGState = function (generator , state )
202
+ if state then
203
+ return _setRNGState (generator , state )
204
+ else
205
+ state = generator
206
+ end
194
207
if not type (state ) == ' table' or not state .torch or not state .randomkit then
195
208
error (' State was not saved with randomkit, cannot set it back' )
196
209
end
0 commit comments