torch.cuda.manual_seed ignored #149621

Open

vwrewsge opened this issue Mar 20, 2025 · 3 comments
Labels

module: random - Related to random number generation in PyTorch (rng generator)
oncall: pt2
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@vwrewsge commented Mar 20, 2025

🐛 Describe the bug

When using torch.compile, torch.cuda.manual_seed / torch.cuda.manual_seed_all / torch.cuda.random.manual_seed appear to be ignored: seeding inside a compiled function does not make repeated calls return identical random tensors, even with torch._inductor.config.fallback_random = True.
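
For comparison, the same pattern without torch.compile behaves as expected. A minimal eager-mode sketch (not taken from the report, shown only to illustrate the intended semantics):

import torch

def foo_eager():
    # In eager mode the reseed is respected by the default CUDA generator,
    # so every call returns the same tensor.
    torch.cuda.manual_seed(3)
    return torch.rand(4, device='cuda')

print(torch.equal(foo_eager(), foo_eager()))  # expected: True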

torch.cuda.manual_seed

Code:

import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    # Set the GPU seed
    torch.cuda.manual_seed(3)
    # Create a random tensor on the GPU.
    # If a CUDA device is available, the tensor will be created on CUDA.
    return torch.rand(4, device='cuda' if torch.cuda.is_available() else 'cpu')

# Call the compiled function twice
print("cuda.is_available:", torch.cuda.is_available())
result1 = foo()
result2 = foo()
print(result1)
print(result2)

Output:

cuda.is_available: True
tensor([0.2501, 0.4582, 0.8599, 0.0313], device='cuda:0')
tensor([0.3795, 0.0543, 0.4973, 0.4942], device='cuda:0')

torch.cuda.manual_seed_all

Code:

import torch
import torch._inductor.config
torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    # Reset CUDA seeds
    torch.cuda.manual_seed_all(3)
    # Generate a random tensor on the GPU
    return torch.rand(4, device='cuda')

# Call the compiled function twice
result1 = foo()
result2 = foo()
print(result1)
print(result2)

Output:

tensor([0.0901, 0.8324, 0.4412, 0.2539], device='cuda:0')
tensor([0.5561, 0.6098, 0.8558, 0.1980], device='cuda:0')

torch.cuda.random.manual_seed

Code

import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

# Ensure a CUDA device is available.
if not torch.cuda.is_available():
    print("CUDA is not available on this system.")

@torch.compile
def foo():
    # Reset GPU random seed
    torch.cuda.random.manual_seed(3)
    # Generate a random tensor on GPU
    return torch.rand(4, device='cuda')

# Call the compiled function twice
result1 = foo()
result2 = foo()
print(result1)
print(result2)

Output:

tensor([8.1055e-01, 4.8494e-01, 8.3937e-01, 6.7405e-04], device='cuda:0')
tensor([0.4365, 0.5669, 0.7746, 0.8702], device='cuda:0')

torch.xpu.random.set_rng_state_all

Code:

import torch
import torch._inductor.config
from torch.xpu.random import set_rng_state_all

torch._inductor.config.fallback_random = True

def get_fixed_rng_states():
    num_devices = 1 
    fixed_state = torch.ByteTensor([42] * 128) 
    return [fixed_state for _ in range(num_devices)]

@torch.compile
def foo():
    fixed_states = get_fixed_rng_states()
    set_rng_state_all(fixed_states)
    return torch.rand(4)

result1 = foo()
result2 = foo()
print(result1)
print(result2)

Output:

tensor([0.5937, 0.2101, 0.3331, 0.2723])
tensor([0.4328, 0.0258, 0.5986, 0.3621])

torch.xpu.random.manual_seed_all

Code:

import torch
import torch._inductor.config
from torch.xpu.random import manual_seed_all

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    manual_seed_all(3)
    return torch.rand(4)

result1 = foo()
result2 = foo()

print(result1)
print(result2)

Output:

tensor([0.3189, 0.9985, 0.7599, 0.6258])
tensor([0.4467, 0.3736, 0.8722, 0.9614])

torch.xpu.manual_seed_all

Code:

import torch
from torch.xpu import manual_seed_all  # Import the target API
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    manual_seed_all(3)
    return torch.rand(4)

result1 = foo()
result2 = foo()

print(result1)
print(result2)

Output:

tensor([0.2231, 0.9563, 0.8660, 0.9350])
tensor([0.4353, 0.9493, 0.2995, 0.1608])

Versions

torch 2.6.0

cc @pbelevich @chauhang @penguinwu

@malfet added the module: random and oncall: pt2 labels Mar 20, 2025
@IvanKobzarev
Contributor

The manual_seed op is present in the Dynamo graph, but it is dropped from the AOT and Inductor graphs.

@bdhirsh @eellison Is it expected that we do not have random ops in the AOT and Inductor graphs?

DEBUG: TRACED GRAPH
 ===== pre insert_deferred_runtime_asserts __compiled_fn_1 =====
 <eval_with_key>.0 class GraphModule(torch.nn.Module):
    def forward(self):
         # File: /home/ivankobzarev/task_manual_seed/r.py:8 in foo, code: torch.cuda.manual_seed_all(3)
        manual_seed_all = torch.cuda.random.manual_seed_all(3);  manual_seed_all = None
        
         # File: /home/ivankobzarev/task_manual_seed/r.py:10 in foo, code: return torch.rand(4, device='cuda')
        rand: "f32[4]" = torch.rand(4, device = 'cuda')
        return (rand,)
        
DEBUG: TRACED GRAPH
 ===== __compiled_fn_1 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self):
         # File: /home/ivankobzarev/task_manual_seed/r.py:8 in foo, code: torch.cuda.manual_seed_all(3)
        manual_seed_all = torch.cuda.random.manual_seed_all(3);  manual_seed_all = None
        
         # File: /home/ivankobzarev/task_manual_seed/r.py:10 in foo, code: return torch.rand(4, device='cuda')
        rand: "f32[4][1]cuda:0" = torch.rand(4, device = 'cuda')
        return (rand,)
        
INFO: TRACED GRAPH
 ===== Forward graph 0 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
         # File: /home/ivankobzarev/task_manual_seed/r.py:10 in foo, code: return torch.rand(4, device='cuda')
        rand: "f32[4][1]cuda:0" = torch.ops.aten.rand.default([4], device = device(type='cuda'), pin_memory = False)
        return (rand,)
        
DEBUG: TRACED GRAPH
 ===== tensorify_python_scalars =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
         # File: /home/ivankobzarev/task_manual_seed/r.py:10 in foo, code: return torch.rand(4, device='cuda')
        rand: "f32[4]" = torch.ops.aten.rand.default([4], device = device(type='cuda'), pin_memory = False)
        return (rand,)
        
DEBUG: Output code: 
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
async_compile.wait(globals())
del async_compile
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # Topologically Sorted Source Nodes: [rand], Original ATen: [aten.rand]
        buf0 = torch.ops.aten.rand.default([4], device=device(type='cuda'), pin_memory=False)
        buf1 = buf0
        assert_size_stride(buf1, (4, ), (1, ))
        del buf0
    return (buf1, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    fn = lambda: call([])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
DEBUG: Output code written to: /tmp/torchinductor_ivankobzarev/wh/cwhw72gkmfq2ybuc6xlmtd72m4vnuwt6fb2lo4qxznr373d5sflt.py
tensor([0.0848, 0.8599, 0.4381, 0.6466], device='cuda:0')
tensor([0.8507, 0.2237, 0.4168, 0.4407], device='cuda:0')
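
(Reproduction note, not from the original comment: traces like the ones above can be dumped by enabling the standard Dynamo/AOT/Inductor logging artifacts before running the script, e.g. via torch._logging.set_logs. The exact flags used for these logs are not stated, so this is only a guess.)

import torch
import torch._logging

# Assumption: standard logging artifacts; prints the Dynamo graph code,
# the AOT forward graph, and the generated Inductor output code.
torch._logging.set_logs(graph_code=True, aot_graphs=True, output_code=True)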

@IvanKobzarev added the triaged label Mar 24, 2025
@bdhirsh
Contributor
bdhirsh commented Mar 24, 2025

cc @anijain2305: torch.cuda.manual_seed seems painful to handle properly in the compiled graph. Do you know why we capture it today instead of graph breaking?

@eellison
Contributor

Yeah, we should be graph breaking, and I thought we did; see #109109
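
A possible workaround in the meantime (a sketch based on the repro above, assuming fallback_random stays enabled so the compiled torch.rand falls back to aten.rand and uses the default CUDA generator; not an official recommendation): seed outside the compiled region, where the call runs in eager mode.

import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    # No seeding inside the compiled region.
    return torch.rand(4, device='cuda')

# Reseed in eager mode, outside the compiled function, before each call.
torch.cuda.manual_seed(3)
result1 = foo()
torch.cuda.manual_seed(3)
result2 = foo()
print(torch.equal(result1, result2))  # expected to print True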
