10000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7b7dd1f commit eaf33f9Copy full SHA for eaf33f9
torch/nn/parallel/parallel_apply.py
@@ -4,7 +4,6 @@
4
import torch
5
from torch._utils import ExceptionWrapper
6
from torch.cuda._utils import _get_device_index
7
-from torch.cuda.amp import autocast
8
9
from ..modules import Module
10
@@ -89,9 +88,9 @@ def _worker(
89
88
if stream is None:
90
stream = torch.cuda.current_stream(device)
91
try:
92
- with torch.cuda.device(device), torch.cuda.stream(stream), autocast(
93
- enabled=autocast_enabled
94
- ):
+ with torch.cuda.device(device), torch.cuda.stream(
+ stream
+ ), torch.amp.autocast("cuda", enabled=autocast_enabled):
95
# this also avoids accidental slicing of `input` if it is a Tensor
96
if not isinstance(input, (list, tuple)):
97
input = (input,)
0 commit comments