[pytree] add APIs to determine a class is a namedtuple or PyStructSequence by XuehaiPan · Pull Request #113257 · pytorch/pytorch · GitHub

[pytree] add APIs to determine a class is a namedtuple or PyStructSequence #113257


Closed
wants to merge 117 commits into from

Conversation

XuehaiPan
Collaborator
@XuehaiPan XuehaiPan commented Nov 8, 2023

Stack from ghstack (oldest at bottom):

Changes in this PR:

  1. Add is_structseq and is_structseq_class functions to determine whether an object or a class is a PyStructSequence.
  2. Add a generic class structseq that can be used as the registration key for PyStructSequence types, just as namedtuple is used for named tuple types.
  3. Change is_namedtuple to treat subclasses of namedtuple classes as namedtuples. Before this PR, only namedtuple classes created directly by collections.namedtuple or typing.NamedTuple counted as namedtuple classes; their subclasses did not. This PR makes is_namedtuple return True for subclasses of namedtuple classes.

Resolves #75982. New tests are included in this PR.
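For readers unfamiliar with PyStructSequence types (for example `time.struct_time` or the type of `sys.version_info`), here is a minimal sketch of how such a class can be told apart from a namedtuple class in pure Python. The helper name `looks_like_structseq_class` is hypothetical and this is a heuristic under stated assumptions, not necessarily the check this PR implements.

```python
import sys
import time
from collections import namedtuple


def looks_like_structseq_class(cls) -> bool:
    """Heuristic (hypothetical helper): does ``cls`` look like a PyStructSequence type?"""
    if not (isinstance(cls, type) and issubclass(cls, tuple)):
        return False
    # Struct sequence types derive directly from tuple.
    if cls.__bases__ != (tuple,):
        return False
    # Struct sequence types expose these three integer attributes.
    for name in ("n_fields", "n_sequence_fields", "n_unnamed_fields"):
        if not isinstance(getattr(cls, name, None), int):
            return False
    # Assumption: struct sequence types cannot be subclassed in CPython,
    # while a plain Python class that mimics the attributes above can.
    try:
        type("_Probe", (cls,), {})
    except TypeError:
        return True
    return False


Point = namedtuple("Point", ["x", "y"])

struct_time_result = looks_like_structseq_class(time.struct_time)
version_info_result = looks_like_structseq_class(type(sys.version_info))
namedtuple_result = looks_like_structseq_class(Point)
```

Under these assumptions the first two results are truthy on CPython and the namedtuple result is falsy, because namedtuple classes carry `_fields` but none of the `n_*` attributes.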

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

pytorch-bot bot commented Nov 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113257

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 1979e85 with merge base f74d5d5 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@XuehaiPan XuehaiPan requested a review from zou3519 November 8, 2023 12:42
@XuehaiPan XuehaiPan self-assigned this Nov 8, 2023
@XuehaiPan XuehaiPan added the topic: not user facing topic category label Nov 8, 2023
@XuehaiPan XuehaiPan linked an issue Nov 8, 2023 that may be closed by this pull request
Collaborator
@albanD albanD left a comment


Is this going to change the behavior for every namedtuple class in the wild? That sounds pretty BC-breaking?

@XuehaiPan XuehaiPan added the topic: bc breaking topic category label Nov 9, 2023
@XuehaiPan
Collaborator Author
XuehaiPan commented Nov 9, 2023

Is this going to change the behavior for every namedtuple class in the wild? That sounds pretty BC-breaking?

It will not affect namedtuple classes directly created by collections.namedtuple or typing.NamedTuple:

import collections
import typing

import torch

# Not affected by this PR
DirectNamedTuple1 = collections.namedtuple('DirectNamedTuple1', ['field1', 'field2', 'field3'])


# Not affected by this PR
class DirectNamedTuple2(typing.NamedTuple):
    fieldA: int
    fieldB: str
    fieldC: torch.Tensor

This PR affects subclasses of namedtuple classes. They behave like namedtuples but are not created directly by collections.namedtuple or typing.NamedTuple.

# Not affected by this PR
DirectNamedTuple1 = collections.namedtuple('DirectNamedTuple1', ['field1', 'field2', 'field3'])


# Before this PR: the following type is not a namedtuple; it is a leaf type in pytree.
# After this PR: the following type is a namedtuple; it is a node type in pytree.
class ChildNamedTuple1(DirectNamedTuple1):
    pass


# Not affected by this PR
class DirectNamedTuple2(typing.NamedTuple):
    fieldA: int
    fieldB: str
    fieldC: torch.Tensor


# Before this PR: the following type is not a namedtuple; it is a leaf type in pytree.
# After this PR: the following type is a namedtuple; it is a node type in pytree.
class ChildNamedTuple2(DirectNamedTuple2):
    pass

Before this PR:

>>> is_namedtuple_class(DirectNamedTuple1)
True
>>> is_namedtuple_class(DirectNamedTuple2)
True
>>> is_namedtuple_class(ChildNamedTuple1)
False
>>> is_namedtuple_class(ChildNamedTuple2)
False

After this PR:

>>> is_namedtuple_class(DirectNamedTuple1)
True
>>> is_namedtuple_class(DirectNamedTuple2)
True
>>> is_namedtuple_class(ChildNamedTuple1)
True
>>> is_namedtuple_class(ChildNamedTuple2)
True
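The before/after difference can be reproduced with two small standalone checks. This is a sketch only: `strict_is_namedtuple_class` and `inclusive_is_namedtuple_class` are hypothetical names, and neither is claimed to be the exact code in `torch.utils._pytree`. The strict variant requires `tuple` to be the only direct base, while the inclusive variant accepts any tuple subclass that carries the namedtuple protocol.

```python
import collections
import typing


def _has_str_fields(cls) -> bool:
    """True if ``cls._fields`` is a tuple made only of ``str`` entries."""
    fields = getattr(cls, "_fields", None)
    return isinstance(fields, tuple) and all(type(f) is str for f in fields)


def strict_is_namedtuple_class(cls) -> bool:
    """Hypothetical 'before' check: only classes whose sole base is ``tuple``."""
    return isinstance(cls, type) and cls.__bases__ == (tuple,) and _has_str_fields(cls)


def inclusive_is_namedtuple_class(cls) -> bool:
    """Hypothetical 'after' check: any tuple subclass with the namedtuple protocol."""
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and _has_str_fields(cls)
        and callable(getattr(cls, "_make", None))
        and callable(getattr(cls, "_asdict", None))
    )


DirectNamedTuple1 = collections.namedtuple("DirectNamedTuple1", ["field1", "field2"])


class ChildNamedTuple1(DirectNamedTuple1):
    pass


class DirectNamedTuple2(typing.NamedTuple):
    fieldA: int
    fieldB: str


class ChildNamedTuple2(DirectNamedTuple2):
    pass
```

With these definitions, the strict check accepts only the two direct classes, while the inclusive check accepts all four, matching the two transcripts above.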

For reference, a subclass of a dataclass is still a dataclass:

In [1]: import dataclasses

In [2]: @dataclasses.dataclass
   ...: class MyClass:
   ...:     a: int
   ...:     b: str
   ...:     

In [3]: dataclasses.is_dataclass(MyClass)
Out[3]: True

In [4]: class OtherClass(MyClass):
   ...:     pass
   ...:     

In [5]: dataclasses.is_dataclass(OtherClass)
Out[5]: True

XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Nov 11, 2023
…PyStructSequence"


Resolves #75982. New tests are included in this PR.

- #75982


[ghstack-poisoned]
@XuehaiPan
Collaborator Author

Internally, we have been saving the treespec and loading it for models

@angelayi This PR only changes the behavior during flattening. If the treespec was created and saved by the old code, nothing should break with this PR.


@angelayi
Contributor

Ah you're right sorry, it's not BC breaking, it's FC breaking (but that's not a super strong requirement). The internal failure is that since the type is now the actual namedtuple class instead of "namedtuple", our internal cpp implementation of pytree fails since "we didn't do a pytree registration for that type".
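A rough illustration of this kind of failure, assuming a downstream consumer that resolves handlers by an exact registry key. The registry, the handler string, and both lookup functions are hypothetical and are not taken from PyTorch or from the internal C++ implementation.

```python
from collections import namedtuple


def is_namedtuple_class(cls) -> bool:
    """Hypothetical check: a tuple subclass whose ``_fields`` are all strings."""
    fields = getattr(cls, "_fields", None)
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and isinstance(fields, tuple)
        and all(type(f) is str for f in fields)
    )


def exact_lookup(registry, key):
    """Consumer that only knows the keys it registered explicitly."""
    return registry.get(key)


def lookup_with_generic_fallback(registry, key):
    """Consumer that falls back to the generic ``namedtuple`` key."""
    handler = registry.get(key)
    if handler is None and is_namedtuple_class(key):
        handler = registry.get(namedtuple)
    return handler


# Only the generic key is registered, as a consumer might do when it
# expects every namedtuple node to be tagged with ``namedtuple`` itself.
registry = {namedtuple: "generic namedtuple handler"}
Point = namedtuple("Point", ["x", "y"])

found_generic = exact_lookup(registry, namedtuple)
found_concrete = exact_lookup(registry, Point)
found_with_fallback = lookup_with_generic_fallback(registry, Point)
```

In this sketch the exact lookup finds nothing for the concrete class, which corresponds to the "no registration for that type" symptom, while the fallback lookup still resolves to the generic handler.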

XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Mar 31, 2025
@XuehaiPan
Collaborator Author
XuehaiPan commented Mar 31, 2025

The internal failure is that since the type is now the actual namedtuple class instead of "namedtuple", our internal cpp implementation of pytree fails since "we didn't do a pytree registration for that type".

Is there anything I can help with? The 2.7.0 release branch is being finalized today. cc @zou3519 @angelayi

@angelayi
Contributor

Sorry, what I meant is that this PR should not change the TreeSpec schema since it's breaking assumptions for frameworks downstream, like our cpp implementation.

@XuehaiPan
Collaborator Author

it's breaking assumptions for frameworks downstream, like our cpp implementation.

Is the assumption that the type can be reconstructed from the field names? That does not hold for named tuple subclasses. Is that correct?
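To make the question concrete, here is a small sketch of why rebuilding a type from its name and field names is lossy for subclasses. `NormedPoint` and its `norm1` method are hypothetical, and this only illustrates the hypothesis being asked about, not a confirmed description of the downstream code.

```python
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])


class NormedPoint(Point):
    """Hypothetical subclass that adds behavior on top of the fields."""

    def norm1(self):
        return abs(self.x) + abs(self.y)


original = NormedPoint(3, -4)

# Rebuild a class from the type name and field names only.
rebuilt_cls = namedtuple(type(original).__name__, type(original)._fields)
rebuilt = rebuilt_cls(*original)

same_values = tuple(rebuilt) == tuple(original)  # the field values survive
same_type = type(rebuilt) is type(original)      # the class identity does not
has_method = hasattr(rebuilt, "norm1")           # nor does the added method
```

The values round-trip, but the rebuilt object is an instance of a fresh class with no `norm1` method, so the original subclass cannot be recovered from field names alone.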

@angelayi
Contributor

The assumption is that for namedtuple pytrees, the treespec looks like TreeSpec(namedtuple, <class '__main__.DirectNamedTuple1'>, [*, *])

@XuehaiPan
Collaborator Author

The assumption is that for namedtuple pytrees, the treespec looks like TreeSpec(namedtuple, <class '__main__.DirectNamedTuple1'>, [*, *])

This PR does not change that.

In [1]: import torch.utils._pytree as pytree

In [2]: from typing import NamedTuple

In [3]: class DirectNamedTuple(NamedTuple):
   ...:     x: int
   ...:     y: int
   ...:     z: int
   ...:     

In [4]: class ChildNamedTuple(DirectNamedTuple):
   ...:     pass
   ...:     

In [5]: pytree.tree_structure(DirectNamedTuple(1, 2, 3))
Out[5]: 
TreeSpec(namedtuple, <class '__main__.DirectNamedTuple'>, [*,
  *,
  *])

In [6]: pytree.tree_structure(ChildNamedTuple(1, 2, 3))
Out[6]: 
TreeSpec(namedtuple, <class '__main__.ChildNamedTuple'>, [*,
  *,
  *])
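The output above shows the concrete class being kept as the node's context. A toy flatten/unflatten pair, sketched below, illustrates why that matters for round-tripping a subclass instance. The functions `flatten` and `unflatten` and the spec layout are hypothetical simplifications and are not the `torch.utils._pytree` implementation.

```python
from typing import NamedTuple


def is_namedtuple_instance(obj) -> bool:
    """Hypothetical check that also accepts subclasses of namedtuple classes."""
    fields = getattr(type(obj), "_fields", None)
    return (
        isinstance(obj, tuple)
        and isinstance(fields, tuple)
        and all(type(f) is str for f in fields)
    )


def flatten(tree):
    """Return (leaves, spec); spec is None for a leaf."""
    if not is_namedtuple_instance(tree):
        return [tree], None
    leaves, child_specs = [], []
    for child in tree:
        child_leaves, child_spec = flatten(child)
        leaves.extend(child_leaves)
        child_specs.append((len(child_leaves), child_spec))
    # Keep the concrete class so that unflatten can rebuild the same type.
    return leaves, (type(tree), child_specs)


def unflatten(leaves, spec):
    """Inverse of flatten for specs produced by this sketch."""
    if spec is None:
        (leaf,) = leaves
        return leaf
    cls, child_specs = spec
    children, start = [], 0
    for count, child_spec in child_specs:
        children.append(unflatten(leaves[start:start + count], child_spec))
        start += count
    return cls(*children)


class DirectNamedTuple(NamedTuple):
    x: object
    y: object
    z: object


class ChildNamedTuple(DirectNamedTuple):
    pass


leaves, spec = flatten(DirectNamedTuple(ChildNamedTuple(1, 2, 3), 4, 5))
restored = unflatten(leaves, spec)
```

Because the subclass is treated as a node rather than a leaf, its fields are flattened into the leaf list, and the stored class lets `unflatten` return a `ChildNamedTuple` again instead of a plain parent instance.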

@angelayi
Contributor
(Pdb) from collections import namedtuple
(Pdb) import torch.utils._pytree as pytree
(Pdb) DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
(Pdb) pytree._register_namedtuple(DirectNamedTuple1, serialized_type_name="DirectNamedTuple1")
(Pdb) pytree.tree_flatten(DirectNamedTuple1(1, 2))[1]
TreeSpec(DirectNamedTuple1, <class 'sigmoid.inference.test.e2e_test_utils.DirectNamedTuple1'>, [*,
  *])

@XuehaiPan
Collaborator Author
XuehaiPan commented Mar 31, 2025
(Pdb) from collections import namedtuple
(Pdb) import torch.utils._pytree as pytree
(Pdb) DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
(Pdb) pytree._register_namedtuple(DirectNamedTuple1, serialized_type_name="DirectNamedTuple1")
(Pdb) pytree.tree_flatten(DirectNamedTuple1(1, 2))[1]
TreeSpec(DirectNamedTuple1, <class 'sigmoid.inference.test.e2e_test_utils.DirectNamedTuple1'>, [*,
  *])

@angelayi Thanks for the hint. I updated the PR and resolved this in the latest commit. A new test is also included. Please take a look, thanks!

XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Mar 31, 2025
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Mar 31, 2025
@@ -1243,7 +1243,7 @@ def serialize_treespec(self, treespec):
         def store_namedtuple_fields(ts):
             if ts.type is None:
                 return
-            if ts.type == namedtuple:
+            if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
Collaborator Author


NB: Because of this change, all tests previously passed in OSS while breaking internally. With the latest commit, this change can be reverted.

@XuehaiPan
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / unit-test / linux-jammy-cpu-py3.12-gcc11-inductor-halide / build

Details for Dev Infra team Raised by workflow job

@XuehaiPan
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started


facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Apr 2, 2025
…13257)

Summary:
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine whether an object or a class is a PyStructSequence.
2. Add a generic class `structseq` that can be used as the registration key for PyStructSequence types, just as `namedtuple` is used for named tuple types.
3. Change `is_namedtuple` to treat subclasses of namedtuple classes as namedtuples. Before this PR, only namedtuple classes created directly by `collections.namedtuple` or `typing.NamedTuple` counted as namedtuple classes; their subclasses did not. This PR makes `is_namedtuple` return True for subclasses of namedtuple classes.

Resolves #75982. New tests are included in this PR.

- #75982

X-link: pytorch/pytorch#113257
Approved by: https://github.com/zou3519

Reviewed By: clee2000

Differential Revision: D72251158

fbshipit-source-id: cd57a40c0507ff7cb21fc42f42d42edf95379eef
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…uence (pytorch#113257)

Pull Request resolved: pytorch#113257
Approved by: https://github.com/zou3519
@github-actions github-actions bot deleted the gh/XuehaiPan/13/head branch May 2, 2025 02:18
Labels
ci-no-td Do not run TD on this PR ci-test-showlocals Show local variables on test failures ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo open source Reverted Stale topic: bc breaking topic category topic: not user facing topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

API to determine if a torch.return_type is a "structseq"
8 participants