`torch.cuda.manual_seed` ignored · Issue #149621 · pytorch/pytorch · GitHub

torch.cuda.manual_seed ignored #149621
Open
@vwrewsge

Description

🐛 Describe the bug

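When called inside a function decorated with torch.compile, torch.cuda.manual_seed, torch.cuda.manual_seed_all, and torch.cuda.random.manual_seed do not seem to enforce reproducibility. Each snippet below re-seeds with the same value on every call, yet two consecutive calls to the compiled function return different tensors, even with torch._inductor.config.fallback_random = True. The torch.xpu seeding APIs show the same behavior.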

torch.cuda.manual_seed

Code:

import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    # Set the GPU seed
    torch.cuda.manual_seed(3)
    # Create a random tensor on the GPU.
    # If a CUDA device is available, the tensor will be created on CUDA.
    return torch.rand(4, device='cuda' if torch.cuda.is_available() else 'cpu')

# Call the compiled function twice
print("cuda.is_available:", torch.cuda.is_available())
result1 = foo()
result2 = foo()
print(result1)
print(result2)

Output:

cuda.is_available: True
tensor([0.2501, 0.4582, 0.8599, 0.0313], device='cuda:0')
tensor([0.3795, 0.0543, 0.4973, 0.4942], device='cuda:0')
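
For comparison, here is a minimal eager-mode sketch of the same pattern (no torch.compile). The helper names `seeded_rand` and `is_reproducible` are hypothetical and not part of PyTorch; the CPU fallback via torch.manual_seed is an assumption added only so the snippet also runs on machines without a GPU. In eager mode, re-seeding before each draw is expected to give identical tensors, which is the behavior the compiled function above fails to match.

```python
import torch


def seeded_rand(seed: int = 3) -> torch.Tensor:
    """Eager-mode counterpart of `foo`: re-seed, then draw 4 uniform samples."""
    if torch.cuda.is_available():
        # Same call as in the report: seeds the current CUDA device.
        torch.cuda.manual_seed(seed)
        device = "cuda"
    else:
        # Assumption for portability: fall back to the CPU generator.
        torch.manual_seed(seed)
        device = "cpu"
    return torch.rand(4, device=device)


def is_reproducible(fn, calls: int = 2) -> bool:
    """Return True if every call to `fn` yields a tensor equal to the first."""
    first = fn()
    return all(torch.equal(first, fn()) for _ in range(calls - 1))


if __name__ == "__main__":
    print("eager reproducible:", is_reproducible(seeded_rand))
```

Passing the compiled `foo` from the report to the same `is_reproducible` helper gives a one-line check for the regression described here.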

torch.cuda.manual_seed_all

Code:

import torch
import torch._inductor.config
torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    # Reset CUDA seeds
    torch.cuda.manual_seed_all(3)
    # Generate a random tensor on the GPU
    return torch.rand(4, device='cuda')

# Call the compiled function twice
result1 = foo()
result2 = foo()
print(result1)
print(result2)

Output:

tensor([0.0901, 0.8324, 0.4412, 0.2539], device='cuda:0')
tensor([0.5561, 0.6098, 0.8558, 0.1980], device='cuda:0')

torch.cuda.random.manual_seed

Code:

import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

# Ensure a CUDA device is available; the repro below requires one.
if not torch.cuda.is_available():
    raise SystemExit("CUDA is not available on this system.")

@torch.compile
def foo():
    # Reset GPU random seed
    torch.cuda.random.manual_seed(3)
    # Generate a random tensor on GPU
    return torch.rand(4, device='cuda')

# Call the compiled function twice
result1 = foo()
result2 = foo()
print(result1)
print(result2)

Output:

tensor([8.1055e-01, 4.8494e-01, 8.3937e-01, 6.7405e-04], device='cuda:0')
tensor([0.4365, 0.5669, 0.7746, 0.8702], device='cuda:0')

torch.xpu.random.set_rng_state_all

Code:

import torch
import torch._inductor.config
from torch.xpu.random import set_rng_state_all

torch._inductor.config.fallback_random = True

def get_fixed_rng_states():
    # One fixed 128-byte state per device (a single device is assumed here)
    num_devices = 1
    fixed_state = torch.ByteTensor([42] * 128)
    return [fixed_state for _ in range(num_devices)]

@torch.compile
def foo():
    fixed_states = get_fixed_rng_states()
    set_rng_state_all(fixed_states)
    return torch.rand(4)

result1 = foo()
result2 = foo()
print(result1)
print(result2)

Output:

tensor([0.5937, 0.2101, 0.3331, 0.2723])
tensor([0.4328, 0.0258, 0.5986, 0.3621])

torch.xpu.random.manual_seed_all

Code:

import torch
import torch._inductor.config
from torch.xpu.random import manual_seed_all

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    manual_seed_all(3)
    return torch.rand(4)

result1 = foo()
result2 = foo()

print(result1)
print(result2)

Output:

tensor([0.3189, 0.9985, 0.7599, 0.6258])
tensor([0.4467, 0.3736, 0.8722, 0.9614])

torch.xpu.manual_seed_all

Code:

import torch
from torch.xpu import manual_seed_all  # Import the target API
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    manual_seed_all(3)
    return torch.rand(4)

result1 = foo()
result2 = foo()

print(result1)
print(result2)

Output:

tensor([0.2231, 0.9563, 0.8660, 0.9350])
tensor([0.4353, 0.9493, 0.2995, 0.1608])

Versions

torch 2.6.0

cc @pbelevich @chauhang @penguinwu

Metadata

Assignees

No one assigned

    Labels

    module: random (Related to random number generation in PyTorch (rng generator))
    oncall: pt2
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
