Consider changing the behavior of Tensor.__contains__(Tensor) to make more sense · Issue #24338 · pytorch/pytorch
Open
zou3519 opened this issue Aug 14, 2019 · 6 comments
Labels: module: numpy, module: ux, triaged

Comments

zou3519 (Contributor) commented Aug 14, 2019

🐛 Bug

import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([5, 6, 3])
y in x  # True
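Here is one way to see where the True comes from. It assumes that Tensor.__contains__ currently reduces an elementwise comparison with any(), roughly `(element == self).any()`, so treat those exact internals as an assumption. The sketch repeats that reduction by hand:

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([5, 6, 3])

# Elementwise comparison: only the last position matches (3 == 3)
matches = y == x
print(matches.tolist())      # [False, False, True]

# Reducing with any() turns a single matching position into True
print(bool(matches.any()))   # True
print(y in x)                # True, the same answer as the manual reduction
```

Because of this, any two same-shaped tensors that share a value at one position count as "contained" in each other.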

Expected behavior

This particular case should be False. Related: #17733 #24156
The incorrect semantics were introduced in PyTorch 1.2 (May), but no one else has complained about them in the three months since.

Environment

PyTorch master.

cc @mruberry @rgommers @heitorschueroff

gchanan (Contributor) commented Aug 14, 2019
>>> import numpy as np
>>> np.array([5, 6, 3]) in np.array([1, 2, 3])
True
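NumPy gives the same answer for the same reason: `ndarray.__contains__` is effectively `(self == element).any()`. The sketch below compares that against `np.isin`, which answers the per-value question explicitly:

```python
import numpy as np

a = np.array([5, 6, 3])
b = np.array([1, 2, 3])

# ndarray.__contains__ behaves like (b == a).any(): an elementwise,
# broadcasting comparison followed by "is any entry True?"
print(a in b)              # True
print((b == a).any())      # True

# np.isin checks each value of a for membership in b
print(np.isin(a, b).tolist())   # [False, False, True]
print(np.isin(a, b).all())      # False: not every value of a occurs in b
```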

zou3519 (Contributor, Author) commented Aug 14, 2019

A brief look at https://stackoverflow.com/questions/18320624/how-does-contains-work-for-ndarrays suggests that NumPy may behave this way for backward compatibility.

I think this might be one of those cases where we shouldn't follow NumPy's behavior. If we think of a tensor as a generalized Python list, it doesn't make sense for something like [5, 6, 3] in [1, 2, 3] to be True.
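The "generalized Python list" reading can be made concrete with a helper. `list_like_contains` below is hypothetical and is not a PyTorch API. It is a minimal sketch that assumes a tensor is treated as a sequence of sub-tensors along dimension 0. An element matches only when both its shape and its values equal one of those sub-tensors, which mirrors how Python compares list items with `==`.

```python
import torch

# Hypothetical helper (not part of PyTorch): membership with Python-list
# semantics, treating `container` as the sequence container[0], container[1], ...
def list_like_contains(container: torch.Tensor, element) -> bool:
    element = torch.as_tensor(element)
    # A sub-tensor matches only if the shapes agree AND every value is equal;
    # `==` applies type promotion, so 3 and 3.5 do not compare equal.
    return any(
        item.shape == element.shape and bool((item == element).all())
        for item in container
    )

# Plain Python lists, for reference
print([5, 6, 3] in [1, 2, 3])                 # False
print([1, 2, 3] in [[1, 2, 3], [4, 5, 6]])    # True

x = torch.tensor([1, 2, 3])
print(list_like_contains(x, torch.tensor([5, 6, 3])))   # False
print(list_like_contains(x, 3))                         # True
print(list_like_contains(torch.tensor([[1, 2, 3], [4, 5, 6]]),
                         torch.tensor([1, 2, 3])))      # True
```

With this definition, a scalar can only be found in a 1-D tensor, just as `1 in [[1, 2, 3]]` is False for nested lists.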

zou3519 changed the title from "Tensor.__contains__(Tensor) is wrong" to "Consider changing the behavior of Tensor.__contains__(Tensor) to make more sense" Aug 14, 2019
zou3519 (Contributor, Author) commented Aug 14, 2019

Related NumPy issue about this behavior: numpy/numpy#3016

zhangguanheng66 added the module: operators and triaged labels Aug 14, 2019
zhangguanheng66 (Contributor) commented

I think we should fix this magic behavior.

tczhangzhi (Contributor) commented Aug 15, 2019

Interesting! I think this is really a feature request to support non-scalar input, as numpy.isin does.
Otherwise, we need to decide what these cases should return:

x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 2])
x in y
Currently raises RuntimeError ("The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0"); should it return False instead?

x = torch.tensor([1, 2, 3])
y = torch.tensor([[1], [2], [3]])
x in y
Currently returns True; should it return False?

x = torch.tensor([1, 2, 3])
y = torch.tensor([[[1, 2, 3]]])
x in y
Currently returns True; should it return False, since the element of y is [[1, 2, 3]] rather than [1, 2, 3]?
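You can try the numpy.isin-style reading on these three cases with `torch.isin`, which is available in newer PyTorch releases (it postdates this thread). It checks each value of `x` against all values of `y` regardless of shape, so it returns a per-value answer rather than a single membership result. Here is a minimal sketch:

```python
import torch

x = torch.tensor([1, 2, 3])
cases = [
    torch.tensor([1, 2]),
    torch.tensor([[1], [2], [3]]),
    torch.tensor([[[1, 2, 3]]]),
]

for y in cases:
    # Boolean tensor with the same shape as x: is each value of x present anywhere in y?
    per_value = torch.isin(x, y)
    print(tuple(y.shape), per_value.tolist(), bool(per_value.all()))
```

Under this reading, shape plays no role. The first case gives [True, True, False], and the second and third both report that every value of x is present. That differs from the "should be False" suggestions above.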

tczhangzhi (Contributor) commented

I sketched a possible implementation while on the road; I'm not sure whether it meets our expectations.

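import torch
from numbers import Number

# a: the element being searched for (left operand of `in`)
# b: the container tensor (`self` in Tensor.__contains__)
def new_in(a, b):
    if isinstance(a, Number):
        return (a == b).any().item()
    if isinstance(a, torch.Tensor):
        dim = a.dim()
        # An element with more dimensions than the container cannot be inside it
        if dim > b.dim():
            return False

        try:
            result = a == b
        except RuntimeError:
            # The shapes cannot be broadcast together
            return False

        # For each trailing dimension of a: if a and b have the same size there,
        # require every value along it to match (keepdim=True keeps axes aligned);
        # if a is larger than b there, a cannot fit inside b
        for axis in range(-dim, 0):
            if a.size(axis) == b.size(axis):
                result = result.all(axis, True)
            elif a.size(axis) > b.size(axis):
                return False
        return result.any().item()

    raise RuntimeError(
        "Tensor.__contains__ only supports Tensor or Number, but you passed in a %s." %
        type(a)
    )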

x = torch.arange(0, 10)
print(new_in(4, x))
# Should be True
print(new_in(12, x))
# Should be False

x = torch.arange(1, 10).view(3, 3)
val = torch.arange(1, 4)
print(new_in(val, x))
# Should be True

val += 10
print(new_in(val, x))
# Should be False

x = torch.tensor([1, 2, 3])
y = torch.tensor([5, 6, 3])
print(new_in(x, y))
# Should be False

x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 2])
print(new_in(x, y))
# Should be False

x = torch.tensor([1, 2, 3])
y = torch.tensor([[1], [2], [3]])
print(new_in(x, y))
# Should be False

x = torch.tensor([1, 2, 3])
y = torch.tensor([[[1, 2, 3]]])
print(new_in(x, y))
# Should be True

mruberry added the module: numpy and module: ux labels and removed the module: operators (deprecated) label Oct 10, 2020