8000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b0a835f · commit 3f524bd (Copy full SHA for 3f524bd)
torch/nn/parallel/parallel_apply.py
@@ -4,7 +4,6 @@
4
import torch
5
from torch._utils import ExceptionWrapper
6
from torch.cuda._utils import _get_device_index
7
-from torch.cuda.amp import autocast
8
9
from ..modules import Module
10
@@ -89,9 +88,9 @@ def _worker(
89
88
if stream is None:
90
stream = torch.cuda.current_stream(device)
91
try:
92
- with torch.cuda.device(device), torch.cuda.stream(stream), autocast(
93
- enabled=autocast_enabled
94
- ):
+ with torch.cuda.device(device), torch.cuda.stream(
+ stream
+ ), torch.amp.autocast("cuda", enabled=autocast_enabled):
95
# this also avoids accidental slicing of `input` if it is a Tensor
96
if not isinstance(input, (list, tuple)):
97
input = (input,)
0 commit comments