8000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ba8dc7d commit 54af3caCopy full SHA for 54af3ca
src/diffusers/utils/torch_utils.py
@@ -38,7 +38,7 @@ def maybe_allow_in_graph(cls):
38
def randn_tensor(
39
shape: Union[Tuple, List],
40
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
41
- device: Optional["torch.device"] = None,
+ device: Optional[Union[str, "torch.device"]] = None,
42
dtype: Optional["torch.dtype"] = None,
43
layout: Optional["torch.layout"] = None,
44
):
@@ -47,6 +47,8 @@ def randn_tensor(
47
is always created on the CPU.
48
"""
49
# device on which tensor is created defaults to device
50
+ if isinstance(device, str):
51
+ device = torch.device(device)
52
rand_device = device
53
batch_size = shape[0]
54
0 commit comments