File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -54,6 +54,21 @@ function myTest.callsetRNGWithGenerator()
54
54
tester :assert (ok , ' Failed to set RNG state' )
55
55
end
56
56
57
+ function myTest .setRNGWithBadPointer ()
58
+ -- To simulate restoring the state from a previous run, we invalidate the
59
+ -- pointer to the main Torch generator in a state that we are passing to
60
+ -- setRNGState, and check that this doesn't break things.
61
+ local state = torch .getRNGState ()
62
+ local x = tonumber (randomkit .binomial (10 , 0.4 ))
63
+ local badState = ffi .cast (" rk_state *" , state .randomkit )
64
+ badState .torch_state = ffi .cast (" THGenerator*" , 0 )
65
+ torch .setRNGState {
66
+ torch = state .torch ,
67
+ randomkit = badState
68
+ }
69
+ tester :asserteq (x , tonumber (randomkit .binomial (10 , 0.4 )))
70
+ end
71
+
57
72
58
73
tester :add (myTest )
59
74
return tester :run ()
Original file line number Diff line number Diff line change @@ -210,6 +210,12 @@ torch.setRNGState = function(generator, state)
210
210
_setRNGState (state .torch )
211
211
-- Deserialize from string
212
212
ffi .copy (randomkit ._state , state .randomkit , ffi .sizeof (randomkit ._state ))
213
+
214
+ -- If the state being set is from a previous run, the pointer to the main
215
+ -- Torch generator will no longer be valid. So we will explicitly update the
216
+ -- pointer to be current.
217
+ randomkit ._state .torch_state =
218
+ ffi .cast (" THGenerator*" , torch .pointer (torch ._gen ))
213
219
end
214
220
215
221
local returnTypeMapping = {
You can’t perform that action at this time.
0 commit comments