8000 [Set] Add correct set/frozenset __init__ behavior by guilhermeleobas · Pull Request #152908 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Set] Add correct set/frozenset __init__ behavior #152908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
15 changes: 15 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4340,6 +4340,20 @@ def fn(x):
ref = opt_fn(x)
self.assertEqual(ref, res)

@parametrize("_type", [set, frozenset], name_fn=lambda t: t.__name__)
def test_set_call___init__(self, _type):
@make_test
def fn(a, b):
s = _type({"apple", "banana", "cherry"})
s.__init__({"google", "microsoft", "apple"})
# frozenset should remain the same while set gets updated
if "banana" in s:
return a + b
else:
return a - b

fn(self)

def test_frozenset_construction(self):
def fn(x):
s = frozenset({x})
Expand Down Expand Up @@ -4989,6 +5003,7 @@ def __getattribute__(self, name):


instantiate_parametrized_tests(FunctionTests)
instantiate_parametrized_tests(DefaultsTests)

if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
Empty file.
19 changes: 17 additions & 2 deletions torch/_dynamo/variables/dicts.py
5E19
Original file line number Diff line number Diff line change
Expand Up @@ -754,8 +754,14 @@ def call_method(
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
# We foward the calls to the dictionary model
if name == "add":
# We forward the calls to the dictionary model
if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
tx.output.side_effects.mutation(self)
self.items.clear()
self.items.update(temp_set_vt.items)
return ConstantVariable.create(None)
elif name == "add":
assert not kwargs
if len(args) != 1:
raise_args_mismatch(tx, name)
Expand Down Expand Up @@ -905,6 +911,15 @@ def call_method(
) -> "VariableTracker":
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
elif name == "__init__":
# frozenset is immutable. Calling __init__ again shouldn't have any effect
# In[1]: s = frozenset([1, 2])
#
# In[2]: s.__init__([3, 4])
#
# In[3]: s
# frozenset({1, 2})
return ConstantVariable.create(None)
return super().call_method(tx, name, args, kwargs)


Expand Down
Loading
0