10000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f6838d5 commit bb62e9dCopy full SHA for bb62e9d
torch/nn/parallel/parallel_apply.py
@@ -4,7 +4,6 @@
4
import torch
5
from torch._utils import ExceptionWrapper
6
from torch.cuda._utils import _get_device_index
7
-from torch.cuda.amp import autocast
8
9
from ..modules import Module
10
@@ -89,9 +88,9 @@ def _worker(
89
88
if stream is None:
90
stream = torch.cuda.current_stream(device)
91
try:
92
- with torch.cuda.device(device), torch.cuda.stream(stream), autocast(
93
- enabled=autocast_enabled
94
- ):
+ with torch.cuda.device(device), torch.cuda.stream(
+ stream
+ ), torch.amp.autocast("cuda", enabled=autocast_enabled):
95
# this also avoids accidental slicing of `input` if it is a Tensor
96
if not isinstance(input, (list, tuple)):
97
input = (input,)
0 commit comments