Avoid autocast deprecation warning in DataParallel (#130660) · pytorch/pytorch@3f524bd · GitHub

Commit 3f524bd

awaelchli and guangyey authored and committed
Avoid autocast deprecation warning in DataParallel (#130660)
Fixes #130659

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Pull Request resolved: #130660
Approved by: https://github.com/guangyey, https://github.com/fegin, https://github.com/albanD
1 parent b0a835f commit 3f524bd

File tree

1 file changed: +3 −4 lines

torch/nn/parallel/parallel_apply.py

Lines changed: 3 additions & 4 deletions
@@ -4,7 +4,6 @@
 import torch
 from torch._utils import ExceptionWrapper
 from torch.cuda._utils import _get_device_index
-from torch.cuda.amp import autocast
 
 from ..modules import Module
 
@@ -89,9 +88,9 @@ def _worker(
         if stream is None:
             stream = torch.cuda.current_stream(device)
         try:
-            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(
-                enabled=autocast_enabled
-            ):
+            with torch.cuda.device(device), torch.cuda.stream(
+                stream
+            ), torch.amp.autocast("cuda", enabled=autocast_enabled):
                 # this also avoids accidental slicing of `input` if it is a Tensor
                 if not isinstance(input, (list, tuple)):
                     input = (input,)

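For context, a minimal standalone sketch (not part of the commit) of the API change the diff applies: the torch.cuda.amp.autocast context manager now emits a deprecation warning, while torch.amp.autocast("cuda", ...) is the equivalent non-deprecated spelling that the patched _worker uses. The tensor names and shapes below are illustrative only, and a CUDA-capable build is assumed.

import torch

device = torch.device("cuda", 0)
x = torch.randn(4, 4, device=device)
w = torch.randn(4, 4, device=device)

# Deprecated spelling (triggers the warning this commit avoids):
#     with torch.cuda.amp.autocast(enabled=True):
#         y = x @ w

# Preferred spelling, matching the patched parallel_apply worker:
with torch.amp.autocast("cuda", enabled=True):
    y = x @ w  # matmul runs in reduced precision under autocast
print(y.dtype)  # torch.float16 with the default CUDA autocast dtype

Both forms enter the same autocast region; only the entry point differs, so the change is behavior-preserving apart from silencing the warning.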