[dynamo] Initial support for `nonstrict_trace` by StrongerXi · Pull Request #146367 · pytorch/pytorch · GitHub
[dynamo] Initial support for nonstrict_trace #146367

Closed
wants to merge 15 commits into from
Update
[ghstack-poisoned]
StrongerXi committed Feb 11, 2025
commit ceff0c99ff99a50ae6340faf9f685f3136dcfbc5
19 changes: 19 additions & 0 deletions test/dynamo/test_misc.py
@@ -10274,6 +10274,25 @@ def fn(x, y):

         self.assertEqual(actual, expected)

+    def test_pytree_leafspec_as_proxy(self):
+        import torch.utils._pytree as pytree
+
+        @allow_in_graph
+        def inner_fn(spec, x):
+            if spec.num_leaves == 1:
+                return x + 1
+            return x + 2
+
+        def fn(x):
+            return inner_fn(pytree._LEAF_SPEC, x)
+
+        fn_opt = torch.compile(fullgraph=True)(fn)
+        inps = (torch.ones(2),)
+        actual = fn_opt(*inps)
+        expected = fn(*inps)
+
+        self.assertEqual(actual, expected)
+
     def test_shape_env_no_recording(self):
         main = ShapeEnv(should_record_events=False)

20 changes: 2 additions & 18 deletions torch/_dynamo/output_graph.py
@@ -847,7 +847,6 @@ def wrap_name(module_key):
                 return wrap_name(k)

         name = OutputGraph.module_key_name(*names)
-
         name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
Contributor Author (StrongerXi), Feb 11, 2025:

This chunk is just refactoring name generation into this call.

         self.nn_modules[name] = target
         if isinstance(target, torch.nn.Module):
@@ -1515,15 +1514,7 @@ def dedup_pass(self):
         return dict()

     def install_subgraph(self, name, sub_gm):
-        next_name = None
-        i = 0
-        while not next_name:
-            candidate = f"{name}_{i}"
-            if candidate in self.nn_modules:
-                i += 1
-            else:
-                next_name = candidate
-
+        next_name = get_unique_name_wrt(name, self.nn_modules)
         sub_gm.__name__ = next_name
         sub_gm.torchdynamo_force_dynamic = False
         # This graph module is not present in the user space, so it can't be
@@ -2229,14 +2220,7 @@ def create_graph_input(
                     TracingContext.extract_stack()
                 )

-        # unique
-        if name in self.input_name_to_proxy:
-            for i in itertools.count():
-                candidate_name = f"{name}_{i}"
-                if candidate_name not in self.input_name_to_proxy:
-                    name = candidate_name
-                    break
-
+        name = get_unique_name_wrt(name, self.input_name_to_proxy)
         if self.input_name_to_proxy:
             prev_name = next(reversed(self.input_name_to_proxy))
             node = self.input_name_to_proxy[prev_name].node
7 changes: 5 additions & 2 deletions torch/_dynamo/utils.py
@@ -2554,15 +2554,18 @@ def get_unique_name_wrt(prefix: str, *containers) -> str:
     Return a name that starts with `prefix` and is not in any of the
     `containers` (e.g., map, set).
     """
+    name = prefix
     for i in itertools.count():
-        name = f"{prefix}_{i}"
         found = False

         for container in containers:
             if name in container:
                 found = True
                 break

         if not found:
             return name
+        # else update and retry
+        name = f"{prefix}_{i}"

     raise AssertionError("unreachable")
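
For readers following the review, the behaviour of the helper after this hunk can be sketched standalone. This is a minimal reconstruction from the diff above, not necessarily the code as finally merged; the two dictionaries are hypothetical stand-ins for `nn_modules` and `global_scope`.

```python
import itertools


def get_unique_name_wrt(prefix: str, *containers) -> str:
    """Return `prefix`, or the first `prefix_<i>`, that no container holds."""
    name = prefix
    for i in itertools.count():
        # The candidate is acceptable only if no container already holds it.
        if not any(name in container for container in containers):
            return name
        # Otherwise move on to the next suffixed candidate and retry.
        name = f"{prefix}_{i}"

    raise AssertionError("unreachable")


# Hypothetical containers, for illustration only.
nn_modules = {"linear": object(), "linear_0": object()}
global_scope = {"linear_1": object()}

fresh = get_unique_name_wrt("conv", nn_modules, global_scope)
taken = get_unique_name_wrt("linear", nn_modules, global_scope)
```

Under this sketch an unused prefix comes back without a suffix (`fresh` is `"conv"`), whereas the loop removed from `install_subgraph` always tried `f"{name}_0"` first, so the two are not strictly equivalent for a name that is not yet taken.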

9 changes: 7 additions & 2 deletions torch/utils/_pytree.py
@@ -964,11 +964,16 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
         return unflatten_fn(child_pytrees, self.context)


+@dataclasses.dataclass(frozen=True, repr=False)
Contributor:

Does this affect the inherited __eq__ in any way?

Contributor Author:

@dataclasses.dataclass has eq=True by default, and will generate a LeafSpec.__eq__ that is semantically equivalent to TreeSpec.__eq__. Lemme use eq=False explicitly here so the inherited one is reused. In theory one could also provide a faster LeafSpec.__eq__ that's type-based only (rather than comparing individual fields, as the generated __eq__ does).
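
A small standalone check of the `eq` behaviour described in this reply, using only the standard library. `Tree`, `LeafDefaultEq` and `LeafInheritedEq` are hypothetical stand-ins, not the real `TreeSpec`/`LeafSpec`.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Tree:  # hypothetical stand-in for TreeSpec
    kind: object = None


@dataclasses.dataclass(frozen=True)  # eq=True is the default
class LeafDefaultEq(Tree):
    pass


@dataclasses.dataclass(frozen=True, eq=False)
class LeafInheritedEq(Tree):
    pass


# With the default, the decorator writes a fresh __eq__ onto the subclass.
generated_own_eq = "__eq__" in vars(LeafDefaultEq)
# With eq=False, nothing is generated and the parent's __eq__ is reused.
reuses_parent_eq = LeafInheritedEq.__eq__ is Tree.__eq__
# Instances still compare equal through the inherited method.
still_equal = LeafInheritedEq() == LeafInheritedEq()
```

Either way equality semantics stay the same here; `eq=False` only avoids generating a second, redundant method.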

Contributor:

thanks for checking

 class LeafSpec(TreeSpec):
-    def __init__(self) -> None:
-        super().__init__(None, None, [])
+    type: Any = dataclasses.field(default=None, init=False)
+    context: Context = dataclasses.field(default=None, init=False)
+    children_specs: list["TreeSpec"] = dataclasses.field(
+        default_factory=list, init=False
+    )

     def __post_init__(self) -> None:
+        # Override `__post_init__` for `num_leaves` derivation.
Contributor Author (StrongerXi), Feb 11, 2025:

Basically making LeafSpec a true dataclass whose __init__ signature accords with dataclasses.fields(leaf_spec_instance). This avoids overspecializing FrozenDataclassVariable.as_python_constant.

Contributor:

Could you explain a bit more why we needed to change LeafSpec?

Contributor Author:

Yep, so without this change, the implementation in FrozenDataclassVariable.as_python_constant (which is basically a more correct variation of FrozenDataclassVariable.as_proxy) wouldn't work for LeafSpec, because the method assumes a correspondence between dataclasses.fields(obj) and obj.__class__.__init__:

args = []
kwargs = {}
for field in fields(self.value):
    if field.init:
        data = self.fields[field.name].as_python_constant()
        if getattr(field, "kw_only", False):
            kwargs[field.name] = data
        else:
            args.append(data)
                                                       
ctor = self.python_type()
return ctor(*args, **kwargs)

Specifically, in the old LeafSpec, __init__ takes no arguments, but fields(obj) would give the type, context and children_specs fields with init=True (inherited from TreeSpec), so the impl above would mistakenly pass too many positional args to __init__.

So we could either

  1. specialize FrozenDataclassVariable.as_python_constant to LeafSpec and TreeSpec by grabbing the relevant fields and invoking constructor manually
  2. do what I did in this patch, which I chose because it also makes it easier to reason about the LeafSpec dataclass in general (with the dataclasses.fields method).

Lmk what you think, I have no strong preference here.
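
The mismatch described in this reply can be reproduced with plain dataclasses. The sketch below is illustrative only: `Tree`, `OldLeaf`, `NewLeaf` and `rebuild` are hypothetical names, and `rebuild` is a simplified version of the reconstruction loop quoted above (it ignores `kw_only`).

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Tree:  # hypothetical stand-in for TreeSpec
    type: Any
    context: Any
    children: list


class OldLeaf(Tree):
    # Old style: hand-written zero-argument __init__, while the inherited
    # fields still report init=True.
    def __init__(self) -> None:
        super().__init__(None, None, [])


@dataclasses.dataclass(frozen=True, eq=False)
class NewLeaf(Tree):
    # New style: redeclare the inherited fields with init=False so that the
    # generated __init__ takes no arguments and fields() agrees with it.
    type: Any = dataclasses.field(default=None, init=False)
    context: Any = dataclasses.field(default=None, init=False)
    children: list = dataclasses.field(default_factory=list, init=False)


def rebuild(obj):
    """Reconstruct `obj` by passing its init=True fields to its constructor."""
    args = [getattr(obj, f.name) for f in dataclasses.fields(obj) if f.init]
    return type(obj)(*args)


try:
    rebuild(OldLeaf())
    old_style_ok = True
except TypeError:
    # Three positional args are passed to an __init__ that accepts none.
    old_style_ok = False

new_style_ok = rebuild(NewLeaf()) == NewLeaf()
```

With the fields marked `init=False`, a generic reconstruction driven by `dataclasses.fields` needs no special case for the leaf class, which is the second option listed above.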

Contributor:

SGTM

         object.__setattr__(self, "num_nodes", 1)
         object.__setattr__(self, "num_leaves", 1)
         object.__setattr__(self, "num_children", 0)