8000 [Dynamo] Replace `unimplemented` with `unimplemented_v2` in `torch/_dynamo/variables/iter.py` by shink · Pull Request #151789 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Dynamo] Replace unimplemented with unimplemented_v2 in torch/_dynamo/variables/iter.py #151789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 77 additions & 23 deletions torch/_dynamo/variables/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import sys
from typing import Optional, TYPE_CHECKING, Union

from .. import polyfills, variables
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
handle_observed_exception,
ObservedUserStopIteration,
raise_observed_exception,
unimplemented,
unimplemented_v2,
UserError,
)
from .base import ValueMutationNew, VariableTracker
Expand Down Expand Up @@ -76,13 +76,14 @@ def call_function(
from .builtin import BuiltinVariable

if any(key not in ["initial", "func"] for key in kwargs.keys()):
unimplemented(
"Unsupported kwargs for itertools.accumulate: "
f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
unimplemented_v2(
gb_type="Unsupported kwargs for itertools.accumulate",
context=f"call_function {self} {args} {kwargs}",
explanation=f"Expected kwargs: 'initial', 'func', but got "
f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}",
hints=[*graph_break_hints.USER_ERROR],
)

acc = kwargs.get("initial")

if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
seq = args[0].unpack_var_sequence(tx)

Expand All @@ -94,13 +95,32 @@ def call_function(
# Default to operator.add
func = BuiltinVariable(operator.add).call_function
else:
unimplemented(
"itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
unimplemented_v2(
gb_type="Unsupported `func` in itertools.accumulate",
context=f"call_function {self} {args} {kwargs}",
explanation="Dynamo does not know how to get the "
"function to use for itertools.accumulate. "
"itertools.accumulate expects the `func` as the second "
"argument or as a keyword argument.",
hints=[*graph_break_hints.USER_ERROR],
)
else:
unimplemented("Unsupported arguments for itertools.accumulate")
unimplemented_v2(
gb_type="Unsupported arguments for itertools.accumulate",
context=f"call_function {self} {args} {kwargs}",
explanation="Dynamo does not know how to trace "
f"itertools.accumulate with args: {args} and kwargs: {kwargs}. "
"itertools.accumulate expects an iterable, an optional "
"binary function for accumulation, and an optional initial "
"value to set the starting state.",
hints=[
"Make sure the arguments to itertools.accumulate are correct.",
*graph_break_hints.SUPPORTABLE,
],
)

items = []
acc = kwargs.get("initial")
if acc is not None:
items.append(acc)
for item in seq:
Expand All @@ -110,8 +130,12 @@ def call_function(
try:
acc = func(tx, [acc, item], {})
except Exception as e:
unimplemented(
f"Unexpected failure in invoking function during accumulate. Failed running func {func}({item}{acc})",
unimplemented_v2(
gb_type="Unexpected failure during itertools.accumulate() iteration",
context=f"call_function {self} {args} {kwargs}",
explanation="Unexpected failure in invoking function during accumulate. "
f"Failed running func {func}({item}{acc})",
hints=[*graph_break_hints.DIFFICULT],
from_exc=e,
)
items.append(acc)
Expand All @@ -137,9 +161,12 @@ def call_function(
)
elif self.value is itertools.groupby:
if any(kw != "key" for kw in kwargs.keys()):
unimplemented(
"Unsupported kwar 8000 gs for itertools.groupby: "
f"{','.join(set(kwargs.keys()) - {'key'})}"
unimplemented_v2(
gb_type="Unsupported kwargs for itertools.groupby",
context=f"call_function {self} {args} {kwargs}",
explanation=f"Expected kwargs: 'key', but got "
f"{','.join(set(kwargs.keys()) - {'key'})}",
hints=[*graph_break_hints.USER_ERROR],
)

def retrieve_const_key(key):
Expand All @@ -148,14 +175,30 @@ def retrieve_const_key(key):
elif isinstance(key, variables.ConstantVariable):
return key.as_python_constant()
else:
unimplemented(
"Unsupported key type for itertools.groupby: " + str(type(key))
unimplemented_v2(
gb_type="Unsupported key type for itertools.groupby",
context=f"call_function {self} {args} {kwargs}",
explanation="Dynamo does not know how to trace "
f"itertools.groupby with key type: {str(type(key))}. "
"We only support grouping keys that are constants (int, float, str, etc.)",
hints=[*graph_break_hints.SUPPORTABLE],
)

if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
seq = args[0].unpack_var_sequence(tx)
else:
unimplemented("Unsupported arguments for itertools.groupby")
unimplemented_v2(
gb_type="Unsupported arguments for itertools.groupby",
context=f"call_function {self} {args} {kwargs}",
explanation="Dynamo does not know how to trace "
f"itertools.groupby with args: {args} and kwargs: {kwargs}. "
"itertools.groupby expects an iterable to group and an "
"optional key function to determine groupings.",
hints=[
"Make sure the arguments to itertools.groupby are correct.",
*graph_break_hints.SUPPORTABLE,
],
)

if "key" in kwargs:

Expand Down Expand Up @@ -186,8 +229,11 @@ def keyfunc(x):
< 8000 /td> )
)
except Exception as e:
unimplemented(
"Unexpected failure when calling itertools.groupby",
unimplemented_v2(
gb_type="Unexpected failure during itertools.groupby() iteration",
context=f"call_function {self} {args} {kwargs}",
explanation="Unexpected failure in invoking function during groupby",
hints=[*graph_break_hints.SUPPORTABLE],
from_exc=e,
)
return variables.ListIteratorVariable(
Expand Down Expand Up @@ -219,7 +265,12 @@ def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

def next_variable(self, tx):
unimplemented("abstract method, must implement")
unimplemented_v2(
gb_type="Unsupported next() call",
context=f"next({self})",
explanation="This abstract method must be implemented",
hints=[*graph_break_hints.DYNAMO_BUG],
)

# NOTE: only call when unpacking this iterator safely done eagerly!
# Normally, iterators are accessed lazily.
Expand Down Expand Up @@ -321,8 +372,11 @@ def next_variable(self, tx):
try:
new_item = self.iterator.next_variable(tx)
if len(self.saved) > MAX_ITERATOR_LIMIT:
unimplemented(
"input iterator to itertools.cycle has too many items"
unimplemented_v2(
gb_type="input iterator to itertools.cycle has too many items",
context=f"next({self})",
explanation=f"Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}",
hints=[],
)
tx.output.side_effects.mutation(self)
self.saved.append(new_item)
Expand Down
Loading
0