-
Notifications
You must be signed in to change notification settings - Fork 24.2k
[pytree] add APIs to determine a class is a namedtuple or PyStructSequence #113257
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…uence [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113257
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (4 Unrelated Failures)As of commit 1979e85 with merge base f74d5d5 ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this going to change the behavior for every namedtuple class in the wild? That sounds pretty BC-breaking?
It will not affect namedtuple classes directly created by # Not affected by this PR
DirectNamedTuple1 = collections.namedtuple('DirectNamedTuple1', ['field1', 'field2', 'field3'])
# Not affected by this PR
class DirectNamedTuple2(typing.NamedTuple):
fieldA: int
fieldB: str
fieldC: torch.Tensor This PR affects subclasses of the namedtuple classes. They are kind of # Not affected by this PR
DirectNamedTuple1 = collections.namedtuple('DirectNamedTuple1', ['field1', 'field2', 'field3'])
# Before this PR. The following type is not a namedtuple. It is a leaf type in pytree.
# After this PR. The following type is a namedtuple. It is a node type in pytree.
class ChildNamedTuple1(DirectNamedTuple1):
pass
# Not affected by this PR
class DirectNamedTuple2(typing.NamedTuple):
fieldA: int
fieldB: str
fieldC: torch.Tensor
# Before this PR. The following type is not a namedtuple. It is a leaf type in pytree.
# After this PR. The following type is a namedtuple. It is a node type in pytree.
class ChildNamedTuple2(DirectNamedTuple2):
pass Before this PR: >>> is_namedtuple_class(DirectNamedTuple1)
True
>>> is_namedtuple_class(DirectNamedTuple2)
True
>>> is_namedtuple_class(ChildNamedTuple1)
False
>>> is_namedtuple_class(ChildNamedTuple2)
False After this PR: >>> is_namedtuple_class(DirectNamedTuple1)
True
>>> is_namedtuple_class(DirectNamedTuple2)
True
>>> is_namedtuple_class(ChildNamedTuple1)
True
>>> is_namedtuple_class(ChildNamedTuple2)
True Some references: a subclass of a dataclass is still a dataclass. In [1]: import dataclasses
In [2]: @dataclasses.dataclass
...: class MyClass:
...: a: int
...: b: str
...:
In [3]: dataclasses.is_dataclass(MyClass)
Out[3]: True
In [4]: class OtherClass(MyClass):
...: pass
...:
In [5]: dataclasses.is_dataclass(OtherClass)
Out[5]: True |
…uence ghstack-source-id: a089db0 Pull Request resolved: pytorch#113257
…uence ghstack-source-id: 56d7a24 Pull Request resolved: pytorch#113257
@angelayi This PR only changes the behavior during flattening. If the treespec is created and saved by the old code, there should be not break by this PR. |
Ah you're right sorry, it's not BC breaking, it's FC breaking (but that's not a super strong requirement). The internal failure is that since the type is now the actual namedtuple class instead of "namedtuple", our internal cpp implementation of pytree fails since "we didn't do a pytree registration for that type". |
…uence ghstack-source-id: 6b8161c Pull Request resolved: pytorch#113257
Is there anything I can help? The 2.7.0 release branch is finalized today. cc @zou3519 @angelayi |
Sorry, what I meant is that this PR should not change the TreeSpec schema since it's breaking assumptions for frameworks downstream, like our cpp implementation. |
The assumption would be that the type can be reconstructed by the field names, which is not fulfilled for named tuple subclasses. Is it correct? |
The assumption is that for namedtuple pytrees, the treespec looks like |
This PR does not change that. In [1]: import torch.utils._pytree as pytree
In [2]: from typing import NamedTuple
In [3]: class DirectNamedTuple(NamedTuple):
...: x: int
...: y: int
...: z: int
...:
In [4]: class ChildNamedTuple(DirectNamedTuple):
...: pass
...:
In [5]: pytree.tree_structure(DirectNamedTuple(1, 2, 3))
Out[5]:
TreeSpec(namedtuple, <class '__main__.DirectNamedTuple'>, [*,
*,
*])
In [6]: pytree.tree_structure(ChildNamedTuple(1, 2, 3))
Out[6]:
TreeSpec(namedtuple, <class '__main__.ChildNamedTuple'>, [*,
*,
*]) |
|
@angelayi Thanks for the hint. I updated the PR and resolved this in the latest commit. A new test is also included. Please take a review, thanks! |
…uence ghstack-source-id: ad7098a Pull Request resolved: pytorch#113257
…uence ghstack-source-id: 846269e Pull Request resolved: pytorch#113257
@@ -1243,7 +1243,7 @@ def serialize_treespec(self, treespec): | |||
def store_namedtuple_fields(ts): | |||
if ts.type is None: | |||
return | |||
if ts.type == namedtuple: | |||
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NB: Due to this, previously, all tests passed in OSS while break internally. With the latest commit, this can be reverted.
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: inductor / unit-test / linux-jammy-cpu-py3.12-gcc11-inductor-halide / build Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…13257) Summary: Changes in this PR: 1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence. 2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types. 3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class. Resolves #75982. New tests are included in this PR. - #75982 X-link: pytorch/pytorch#113257 Approved by: https://github.com/zou3519 Reviewed By: clee2000 Differential Revision: D72251158 fbshipit-source-id: cd57a40c0507ff7cb21fc42f42d42edf95379eef
…uence (pytorch#113257) Changes in this PR: 1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence. 2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types. 3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class. Resolves pytorch#75982. New tests are included in this PR. - pytorch#75982 Pull Request resolved: pytorch#113257 Approved by: https://github.com/zou3519
Stack from ghstack (oldest at bottom):
treespec_{leaf,tuple,dict}
functions for args_spec modification #138214Changes in this PR:
is_structseq
andis_structseq_class
functions to determine a object or a class is PyStructSequence.structseq
which can be used as the registration key for PyStructSequence types likenamedtuple
for Named Tuple types.is_namedtuple
to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created bycollections.namedtuple
ortyping.NamedTuple
were namedtuple classes while their subclasses were not. This PR makesis_namedtuple
return true for subclasses of namedtuple class.Resolves #75982. New tests are included in this PR.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec