uint8 scalar tensors cannot be used for integer indexing · Issue #70916 · pytorch/pytorch · GitHub

uint8 scalar tensors cannot be used for integer indexing #70916


Open
Tracked by #58743
pmeier opened this issue Jan 6, 2022 · 4 comments
Labels
module: python array api (Issues related to the Python Array API) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@pmeier
Collaborator
pmeier commented Jan 6, 2022

Integer scalar tensors should behave like integers when used as an index. Tensors of dtype torch.uint8 deviate from that:

```python
import torch

t_1d_single = torch.empty(1)
t_1d_multi = torch.empty(2)

for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
    # Index with a 0-d tensor holding 0 or 1; every integer dtype should select one element.
    print("single", dtype, t_1d_single[torch.tensor(0, dtype=dtype)].shape)
    print("multi1", dtype, t_1d_multi[torch.tensor(0, dtype=dtype)].shape)
    print("multi2", dtype, t_1d_multi[torch.tensor(1, dtype=dtype)].shape)
    print("#" * 50)
```

Output:

```
single torch.uint8 torch.Size([0, 1])
multi1 torch.uint8 torch.Size([0, 2])
multi2 torch.uint8 torch.Size([1, 2])
##################################################
single torch.int8 torch.Size([])
multi1 torch.int8 torch.Size([])
multi2 torch.int8 torch.Size([])
##################################################
single torch.int16 torch.Size([])
multi1 torch.int16 torch.Size([])
multi2 torch.int16 torch.Size([])
##################################################
single torch.int32 torch.Size([])
multi1 torch.int32 torch.Size([])
multi2 torch.int32 torch.Size([])
##################################################
single torch.int64 torch.Size([])
multi1 torch.int64 torch.Size([])
multi2 torch.int64 torch.Size([])
##################################################
```
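The printed shapes are consistent with a 0-d uint8 tensor being read as a boolean mask instead of as an integer: a new leading dimension of size 0 (value 0) or 1 (nonzero value) appears, while the other integer dtypes select an element and drop the dimension. The pure-Python sketch below restates that rule so the table can be checked without PyTorch. `index_shape` and the dtype strings are hypothetical stand-ins rather than PyTorch APIs, and treating `bool` the same way as `uint8` is an assumption about the mask path.

```python
from typing import Tuple

# Hypothetical model of the observed behaviour; this is NOT PyTorch code.
MASK_DTYPES = ("bool", "uint8")  # assumption: 0-d uint8 follows the bool mask path
INTEGER_DTYPES = ("int8", "int16", "int32", "int64")


def index_shape(shape: Tuple[int, ...], dtype: str, value: int) -> Tuple[int, ...]:
    """Shape of t[torch.tensor(value, dtype=dtype)] for a tensor t of `shape`."""
    if dtype in MASK_DTYPES:
        # Read as a 0-d mask: a new leading dimension of size 1 (nonzero)
        # or 0 (zero) is added, and every existing dimension is kept.
        return (1 if value != 0 else 0,) + shape
    if dtype in INTEGER_DTYPES:
        # Read as a Python int: select along dim 0, which removes that dim.
        if not shape:
            raise IndexError("cannot select from a 0-d tensor")
        if not -shape[0] <= value < shape[0]:
            raise IndexError(f"index {value} is out of bounds for size {shape[0]}")
        return shape[1:]
    raise TypeError(f"unsupported index dtype: {dtype}")


for dtype in ("uint8", "int64"):
    print("single", dtype, index_shape((1,), dtype, 0))
    print("multi1", dtype, index_shape((2,), dtype, 0))
    print("multi2", dtype, index_shape((2,), dtype, 1))
```

Under this model the fix requested in the issue amounts to moving "uint8" from `MASK_DTYPES` to `INTEGER_DTYPES`.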

cc @mruberry @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi

@pmeier added the module: python array api label (Issues related to the Python Array API) Jan 6, 2022
@albanD
Collaborator
albanD commented Jan 6, 2022

I think this is related to the fact that uint8 used to be our dtype for masking before bool was introduced.
But masking with uint8 is deprecated, so you should have a warning, right? Maybe it is time to forbid it completely, so that in the next release we can enable integer indexing with it?

@albanD added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jan 6, 2022
@pmeier
Collaborator Author
pmeier commented Jan 6, 2022

> you should have a warning, right?

Nope.

@mruberry
Collaborator
mruberry commented Jan 6, 2022

We would accept a PR that improves the deprecation warnings for uint8 indexing.

@ngimel
Collaborator
ngimel commented Jan 22, 2022

For the special case of a 0-d indexing tensor, we don't warn; we silently convert it to a long index before sending it to the index function:

```cpp
impl::recordTensorIndex(impl::boolToIndexingTensor(result, tensor.item<uint8_t>() != 0, original_tensor_device), outIndices, dim_ptr);
```

It's not particularly efficient, and it fails to trigger the warning that other uint8 indexing invocations trigger.
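To make the quoted branch concrete, here is a hypothetical pure-Python model of it, including the deprecation warning that the 0-d path currently skips. Shapes are tuples and the long index tensor is a plain list. `index_with_0d_mask` is an illustrative name, the warning text is paraphrased, and using `UserWarning` as the category is an assumption, not confirmed PyTorch behavior. The size-1 dimension inserted before masking is likewise an assumption, chosen because it reproduces the shapes reported in the issue.

```python
import warnings
from typing import List, Tuple


def index_with_0d_mask(shape: Tuple[int, ...], dtype: str, item: int,
                       warn: bool = True) -> Tuple[Tuple[int, ...], List[int]]:
    """Model t[scalar] for a 0-d bool/uint8 index on a tensor of `shape`.

    Returns the resulting shape and the long index list that was recorded.
    """
    if dtype not in ("bool", "uint8"):
        raise TypeError(f"not a mask dtype: {dtype}")
    if dtype == "uint8" and warn:
        # Hypothetical: the warning that other uint8 indexing paths emit.
        warnings.warn(
            "indexing with dtype torch.uint8 is deprecated; use torch.bool instead",
            UserWarning,
            stacklevel=2,
        )
    # Assumed unsqueeze step: insert a size-1 dimension at the indexed position.
    unsqueezed = (1,) + shape
    # Analogue of boolToIndexingTensor: [0] if the item is nonzero, [] otherwise.
    index = [0] if item != 0 else []
    # Indexing that size-1 dimension with `index` leaves len(index) entries.
    result = (len(index),) + unsqueezed[1:]
    return result, index


new_shape, recorded = index_with_0d_mask((2,), "uint8", 1, warn=False)
print(new_shape, recorded)  # (1, 2) [0]
```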

Projects: None yet · Development: No branches or pull requests · 4 participants