[AUTOCAST] FEAT: Allow passing a `torch.device` object to autocast by guillemc23 · Pull Request #153539 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[AUTOCAST] FEAT: Allow passing a torch.device object to autocast #153539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions test/test_autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,12 @@ def test_invalid_device(self):
assert torch.amp.is_autocast_available(device_type=dev)

def test_non_string_device(self):
"""Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
"""Test that `autocast` now accepts a `torch.device` object for `device_type` and uses its type"""
dev = torch.device("cpu")
msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
torch.autocast(device_type=dev)
# Should work without raising an exception
with torch.autocast(device_type=dev):
x = torch.tensor([1.0])
self.assertTrue(x.dtype == torch.float32)


if __name__ == "__main__":
Expand Down
15 changes: 11 additions & 4 deletions torch/_logging/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import traceback
from collections.abc import Sequence
from typing import Any, Optional
from typing import Any, Optional, TypedDict

import torch._logging._internal

Expand All @@ -16,6 +16,13 @@
DUMPED_FILES: set[str] = set()


class TracebackFrame(TypedDict):
    """Typed shape of one serialized traceback frame.

    Key names deliberately mirror the dicts emitted by
    python/combined_traceback.cpp, as noted in ``from_traceback``.
    """

    # Source line number of the frame. `traceback.FrameSummary.lineno` is
    # Optional[int], so None must be allowed here (mypy otherwise rejects
    # `from_traceback`'s return value as `str | int | None`).
    line: Optional[int]
    # Name of the function/scope the frame executes in.
    name: str
    # NOTE: not the filename string itself — an interned-string id produced
    # by `intern_string`, which returns an int.
    filename: int
    # Raw source text of the line, when available (presumably
    # `FrameSummary.line`, which may be None — TODO confirm against caller).
    loc: Optional[str]


def intern_string(s: Optional[str]) -> int:
if s is None:
return -1
Expand Down Expand Up @@ -47,7 +54,7 @@
)


def from_traceback(tb: Sequence[traceback.FrameSummary]) -> list[dict[str, Any]]:
def from_traceback(tb: Sequence[traceback.FrameSummary]) -> list[TracebackFrame]:
# dict naming convention here coincides with
# python/combined_traceback.cpp
r = [
Expand All @@ -59,10 +66,10 @@
}
for frame in tb
]
return r

Check failure on line 69 in torch/_logging/structured.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [return-value]

Incompatible return value type (got "list[dict[str, str | int | None]]", expected "list[TracebackFrame]")


def get_user_stack(num_frames: int) -> list[dict[str, Any]]:
def get_user_stack(num_frames: int) -> list[TracebackFrame]:
from torch._guards import TracingContext
from torch.utils._traceback import CapturedTraceback

Expand All @@ -85,7 +92,7 @@

def get_framework_stack(
num_frames: int = 25, cpp: bool = False
) -> list[dict[str, Any]]:
) -> list[TracebackFrame]:
"""
Returns the traceback for the user stack and the framework stack
"""
Expand Down
14 changes: 10 additions & 4 deletions torch/amp/autocast_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import collections
import functools
import warnings
from typing import Any, Optional
from typing import Any, Optional, Union

import torch
from torch.types import _dtype
Expand Down Expand Up @@ -202,9 +202,10 @@ def forward(self, x):
(see :ref:`Working with Multiple GPUs<amp-multigpu>`).

Args:
device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and 'hpu'.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
device_type(str or torch.device, required): Device type to use. Possible values are:
'cuda', 'cpu', 'mtia', 'maia', 'xpu', and 'hpu'.
If you pass a device, we will use the device type of that
device.
enabled(bool, optional): Whether autocasting should be enabled in the region.
Default: ``True``
dtype(torch_dtype, optional): Data type for ops run in autocast. It uses the default value
Expand All @@ -222,10 +223,15 @@ def __init__(
enabled: bool = True,
cache_enabled: Optional[bool] = None,
):
if not torch._jit_internal.is_scripting():
if isinstance(device_type, torch.device):
device_type = device_type.type

if not isinstance(device_type, str):
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)

if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
if torch._jit_internal.is_scripting():
Expand Down
Loading
0