-
Notifications
You must be signed in to change notification settings - Fork 24.2k
Fix MaskedTensor
to device ignored mask
#151205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151205
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit b81d648 with merge base 25803d3 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@pytorchbot label "topic: not user facing" |
torch/masked/maskedtensor/core.py
Outdated
device_to = current_device | ||
if len(args) == 1: | ||
arg = args[0] | ||
if isinstance(arg, (torch.device, str, int)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use DeviceLIke from torch._prims_common and maybe use cannonicalize_device below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, do you mean using DeviceLikeType
like this?
def to(self, *args, **kwargs):
current_device = self._masked_data.device
device_to = current_device
if len(args) == 1:
arg = args[0]
if isinstance(arg, DeviceLikeType):
device_to = canonicalize_device(arg)
But mypy seems unhapply about it
>>> Lint for torch/masked/maskedtensor/core.py:
Error (MYPY) [misc]
Parameterized generics cannot be used with class or instance checks
364 | device_to = current_device
365 | if len(args) == 1:
366 | arg = args[0]
>>> 367 | if isinstance(arg, DeviceLikeType):
368 | device_to = canonicalize_device(arg)
369 | elif isinstance(arg, torch.Tensor):
370 | device_to = arg.device
Error (MYPY) [arg-type]
Argument 2 to "isinstance" has incompatible type "<typing special form>";
expected "_ClassInfo"
364 | device_to = current_device
365 | if len(args) == 1:
366 | arg = args[0]
>>> 367 | if isinstance(arg, DeviceLikeType):
368 | device_to = canonicalize_device(arg)
369 | elif isinstance(arg, torch.Tensor):
370 | device_to = arg.device
And found a solution in here python/mypy#12155 (comment)
15fce79
to
eea4367
Compare
MaskedTensor
to device ignored mask`MaskedTensor
to device ignored mask
device_to = current_device | ||
if len(args) == 1: | ||
arg = args[0] | ||
if isinstance(arg, get_args(DeviceLikeType)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This actually makes typing worse unfortunately as it breaks the type narrower so what you had before was probably better.
Also, curious if it's easy to add tests in our existing test framework for this |
Fixes #147140
Changes
to
implementation inMaskedTensor
to support movemask
to target deviceTest Result