8000 dont patch setRNGState when generator is used · google-deepmind/torch-randomkit@ef4ec1e · GitHub
[go: up one dir, main page]

Skip to content

Commit ef4ec1e

Browse files
author
James Kirkpatrick
committed
dont patch setRNGState when generator is used
1 parent bbcdabe commit ef4ec1e

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

luasrc/wrapC.lua

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,21 @@ randomkit.ffi.rk_seed(0, randomkit._state)
174174

175175
-- Extend torch state handling to handle randomkit's state too
176176
local _manualSeed = torch.manualSeed
177-
torch.manualSeed = function(seed)
177+
torch.manualSeed = function(generator, seed)
178+
if seed then
179+
return _manualSeed(generator, seed)
180+
else
181+
seed = generator
182+
end
178183
randomkit.ffi.rk_seed(0, randomkit._state)
179184
return _manualSeed(seed)
180185
end
181186

182187
local _getRNGState = torch.getRNGState
183-
torch.getRNGState = function()
188+
torch.getRNGState = function(generator)
189+
if generator then
190+
return _getRNGState(generator)
191+
end
184192
-- Serialize to string, required to write to file
185193
local clonedState = ffi.string(randomkit._state, ffi.sizeof(randomkit._state))
186194
return {
@@ -190,7 +198,12 @@ torch.getRNGState = function()
190198
end
191199

192200
local _setRNGState = torch.setRNGState
193-
torch.setRNGState = function(state)
201+
torch.setRNGState = function(generator, state)
202+
if state then
203+
return _setRNGState(generator, state)
204+
else
205+
state = generator
206+
end
194207
if not type(state) == 'table' or not state.torch or not state.randomkit then
195208
error('State was not saved with randomkit, cannot set it back')
196209
end

0 commit comments

Comments
 (0)
0