8000 Adjust setRNGState so that it updates the pointer to the generator. · google-deepmind/torch-randomkit@0003434 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0003434

Browse files
committed
Adjust setRNGState so that it updates the pointer to the generator.
Otherwise, restoring a state from a previous run can set the pointer to one that is no longer valid.
1 parent c8be331 commit 0003434

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

luasrc/tests/testWrap.lua

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,21 @@ function myTest.callsetRNGWithGenerator()
5454
tester:assert(ok, 'Failed to set RNG state')
5555
end
5656

57+
function myTest.setRNGWithBadPointer()
58+
-- To simulate restoring the state from a previous run, we invalidate the
59+
-- pointer to the main Torch generator in a state that we are passing to
60+
-- setRNGState, and check that this doesn't break things.
61+
local state = torch.getRNGState()
62+
local x = tonumber(randomkit.binomial(10, 0.4))
63+
local badState = ffi.cast("rk_state *", state.randomkit)
64+
badState.torch_state = ffi.cast("THGenerator*", 0)
65+
torch.setRNGState{
66+
torch = state.torch,
67+
randomkit = badState
68+
}
69+
tester:asserteq(x, tonumber(randomkit.binomial(10, 0.4)))
70+
end
71+
5772

5873
tester:add(myTest)
5974
return tester:run()

luasrc/wrapC.lua

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,12 @@ torch.setRNGState = function(generator, state)
210210
_setRNGState(state.torch)
211211
-- Deserialize from string
212212
ffi.copy(randomkit._state, state.randomkit, ffi.sizeof(randomkit._state))
213+
214+
-- If the state being set is from a previous run, the pointer to the main
215+
-- Torch generator will no longer be valid. So we will explicitly update the
216+
-- pointer to be current.
217+
randomkit._state.torch_state =
218+
ffi.cast("THGenerator*", torch.pointer(torch._gen))
213219
end
214220

215221
local returnTypeMapping = {

0 commit comments

Comments
 (0)
0