🐛 Describe the bug
When using torch.compile, calling torch.cuda.manual_seed / torch.cuda.manual_seed_all / torch.cuda.random.manual_seed (and the analogous torch.xpu seeding APIs shown below) inside a compiled function does not seem to enforce reproducibility across multiple calls to that function, even with torch._inductor.config.fallback_random = True.
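For reference, the same seeding pattern is reproducible in eager mode (without torch.compile): re-seeding the CUDA generator before each draw yields identical tensors. A minimal sketch of the expected behavior, assuming a CUDA device is available:
import torch

# Eager mode: re-seeding the default CUDA generator makes torch.rand reproducible
torch.cuda.manual_seed(3)
a = torch.rand(4, device='cuda')
torch.cuda.manual_seed(3)
b = torch.rand(4, device='cuda')
print(torch.equal(a, b))  # expected: True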
torch.cuda.manual_seed
Code:
import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    # Set the GPU seed
    torch.cuda.manual_seed(3)
    # Create a random tensor on the GPU.
    # If a CUDA device is available, the tensor will be created on CUDA.
    return torch.rand(4, device='cuda' if torch.cuda.is_available() else 'cpu')

# Call the compiled function twice
print("cuda.is_available:", torch.cuda.is_available())
result1 = foo()
result2 = foo()
print(result1)
print(result2)
Output:
cuda.is_available: True
tensor([0.2501, 0.4582, 0.8599, 0.0313], device='cuda:0')
tensor([0.3795, 0.0543, 0.4973, 0.4942], device='cuda:0')
torch.cuda.manual_seed_all
Code:
import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    # Reset CUDA seeds
    torch.cuda.manual_seed_all(3)
    # Generate a random tensor on the GPU
    return torch.rand(4, device='cuda')

# Call the compiled function twice
result1 = foo()
result2 = foo()
print(result1)
print(result2)
Output:
tensor([0.0901, 0.8324, 0.4412, 0.2539], device='cuda:0')
tensor([0.5561, 0.6098, 0.8558, 0.1980], device='cuda:0')
torch.cuda.random.manual_seed
Code:
import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

# Ensure a CUDA device is available.
if not torch.cuda.is_available():
    print("CUDA is not available on this system.")

@torch.compile
def foo():
    # Reset GPU random seed
    torch.cuda.random.manual_seed(3)
    # Generate a random tensor on GPU
    return torch.rand(4, device='cuda')

# Call the compiled function twice
result1 = foo()
result2 = foo()
print(result1)
print(result2)
Output:
tensor([8.1055e-01, 4.8494e-01, 8.3937e-01, 6.7405e-04], device='cuda:0')
tensor([0.4365, 0.5669, 0.7746, 0.8702], device='cuda:0')
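As a point of comparison (not verified here), the non-determinism appears specific to seeding inside the compiled region. A minimal sketch that re-seeds from outside the compiled function before each call, assuming fallback_random = True makes the compiled torch.rand consume the default CUDA generator state:
import torch
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    # No seeding inside the compiled function
    return torch.rand(4, device='cuda')

# Re-seed from outside the compiled region before each call
torch.cuda.manual_seed(3)
result1 = foo()
torch.cuda.manual_seed(3)
result2 = foo()
print(torch.equal(result1, result2))  # expected: True if the external seed is respected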
torch.xpu.random.set_rng_state_all
Code:
import torch
import torch._inductor.config
from torch.xpu.random import set_rng_state_all

torch._inductor.config.fallback_random = True

def get_fixed_rng_states():
    num_devices = 1
    fixed_state = torch.ByteTensor([42] * 128)
    return [fixed_state for _ in range(num_devices)]

@torch.compile
def foo():
    fixed_states = get_fixed_rng_states()
    set_rng_state_all(fixed_states)
    return torch.rand(4)

result1 = foo()
result2 = foo()
print(result1)
print(result2)
Output:
tensor([0.5937, 0.2101, 0.3331, 0.2723])
tensor([0.4328, 0.0258, 0.5986, 0.3621])
torch.xpu.random.manual_seed_all
Code:
import torch
import torch._inductor.config
from torch.xpu.random import manual_seed_all

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    manual_seed_all(3)
    return torch.rand(4)

result1 = foo()
result2 = foo()
print(result1)
print(result2)
Output:
tensor([0.3189, 0.9985, 0.7599, 0.6258])
tensor([0.4467, 0.3736, 0.8722, 0.9614])
torch.xpu.manual_seed_all
Code:
import torch
from torch.xpu import manual_seed_all  # Import the target API
import torch._inductor.config

torch._inductor.config.fallback_random = True

@torch.compile
def foo():
    manual_seed_all(3)
    return torch.rand(4)

result1 = foo()
result2 = foo()
print(result1)
print(result2)
Output:
tensor([0.2231, 0.9563, 0.8660, 0.9350])
tensor([0.4353, 0.9493, 0.2995, 0.1608])
Versions
torch 2.6.0