8000 [Doc] Add deprecated autocast comments for doc (#126062) · pytorch/pytorch@58378f1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 58378f1

Browse files
guangyeypytorchmergebot
authored andcommitted
[Doc] Add deprecated autocast comments for doc (#126062)
# Motivation We generalize a device-agnostic API `torch.amp.autocast` in [#125103](#125103). After that, - `torch.cpu.amp.autocast(args...)` is completely equivalent to `torch.amp.autocast('cpu', args...)`, and - `torch.cuda.amp.autocast(args...)` is completely equivalent to `torch.amp.autocast('cuda', args...)` no matter in eager mode or JIT mode. Base on this point, we would like to deprecate `torch.cpu.amp.autocast` and `torch.cuda.amp.autocast` to **strongly recommend** developer to use `torch.amp.autocast` that is a device-agnostic API. Pull Request resolved: #126062 Approved by: https://github.com/eqy, https://github.com/albanD
1 parent 08aa704 commit 58378f1

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

test/test_autocast.py

8000
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ def test_generic_autocast(self):
253253
cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
254254
self.assertEqual(generic_autocast_output, cpu_autocast_output)
255255

256+
def test_cpu_autocast_deprecated_warning(self):
257+
with self.assertWarnsRegex(
258+
DeprecationWarning,
259+
r"torch.cpu.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cpu', args...\) instead.",
260+
):
261+
with torch.cpu.amp.autocast():
262+
_ = torch.ones(10)
263+
256264

257265
class CustomLinear(torch.autograd.Function):
258266
@staticmethod

test/test_cuda.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1978,6 +1978,14 @@ def test_autocast_checkpointing(self):
19781978
self.assertTrue(output.dtype is torch.float16)
19791979
output.sum().backward()
19801980

1981+
def test_cuda_autocast_deprecated_warning(self):
1982+
with self.assertWarnsRegex(
1983+
DeprecationWarning,
1984+
r"torch.cuda.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cuda', args...\) instead.",
1985+
):
1986+
with torch.cuda.amp.autocast():
1987+
_ = torch.ones(10)
1988+
19811989
@slowTest
19821990
@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
19831991
@serialTest()

torch/cpu/amp/autocast_mode.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Any
23

34
import torch
@@ -8,7 +9,7 @@
89
class autocast(torch.amp.autocast_mode.autocast):
910
r"""
1011
See :class:`torch.autocast`.
11-
``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)``
12+
``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead.
1213
"""
1314

1415
def __init__(
@@ -22,6 +23,10 @@ def __init__(
2223
self.device = "cpu"
2324
self.fast_dtype = dtype
2425
return
26+
warnings.warn(
27+
"torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.",
28+
DeprecationWarning,
29+
)
2530
super().__init__(
2631
"cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
2732
)

torch/cuda/amp/autocast_mode.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import collections
22
import functools
3+
import warnings
34

45
import torch
56

@@ -17,7 +18,7 @@
1718
class autocast(torch.amp.autocast_mode.autocast):
1819
r"""See :class:`torch.autocast`.
1920
20-
``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
21+
``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead.
2122
"""
2223

2324
def __init__(
@@ -31,6 +32,10 @@ def __init__(
3132
self.device = "cuda"
3233
self.fast_dtype = dtype
3334
return
35+
warnings.warn(
36+
"torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.",
37+
DeprecationWarning,
38+
)
3439
super().__init__(
3540
"cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
3641
)

0 commit comments

Comments
 (0)
0