[pytree] support PyStructSequence types for Python pytree by XuehaiPan · Pull Request #113258 · pytorch/pytorch
Open
wants to merge 154 commits into
base: gh/XuehaiPan/14/base
2fc3da2
[pytree] support PyStructSequence types for Python pytree
XuehaiPan Nov 8, 2023
62f0b8d
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 8, 2023
6a294a8
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 8, 2023
8f546f7
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 8, 2023
43a2325
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 11, 2023
ebc1b27
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 11, 2023
cf1ae5e
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 11, 2023
8475b97
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 11, 2023
d89d657
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 18, 2023
2eb6fff
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 18, 2023
c5fa703
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 18, 2023
9a940b5
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 22, 2023
380cfd0
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 22, 2023
e0a01e0
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 28, 2023
aa066d2
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 30, 2023
d9d16a7
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 30, 2023
14297e9
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 30, 2023
ebc844c
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 30, 2023
26faa30
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 30, 2023
845f037
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 30, 2023
05e7983
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Nov 30, 2023
bc3ca2f
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Dec 1, 2023
f0baf18
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Dec 1, 2023
a61ff5a
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Dec 7, 2023
6f2be9b
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Dec 8, 2023
ec64d86
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Dec 24, 2023
5e31000
Update on "[pytree] support PyStructSequence types for Python pytree"
XuehaiPan Jan 20, 2024
ed91b71
Update
XuehaiPan Mar 20, 2024
c5431ae
Update
XuehaiPan Mar 22, 2024
28605bd
Update
XuehaiPan Mar 24, 2024
484f024
Update
XuehaiPan Apr 21, 2024
28d82e5
Update
XuehaiPan Apr 21, 2024
e4224e7
Update
XuehaiPan Jun 21, 2024
78adeb4
Update
XuehaiPan Jul 22, 2024
27050fe
Update
XuehaiPan Aug 12, 2024
b0c0291
Update
XuehaiPan Oct 20, 2024
2b0c0d3
Update
XuehaiPan Oct 20, 2024
59ee2c8
Update
XuehaiPan Oct 20, 2024
0bb2611
Update
XuehaiPan Oct 20, 2024
3a8a3c4
Update
XuehaiPan Oct 20, 2024
3bc59c4
Update
XuehaiPan Oct 20, 2024
d103d6f
Update
XuehaiPan Oct 21, 2024
baf9639
Update
XuehaiPan Oct 21, 2024
2705eb0
Update
XuehaiPan Oct 21, 2024
ab26fb5
Update
XuehaiPan Oct 21, 2024
6392274
Update
XuehaiPan Oct 21, 2024
9963e2b
Update
XuehaiPan Oct 21, 2024
a0a7275
Update
XuehaiPan Oct 21, 2024
f692809
Update
XuehaiPan Oct 21, 2024
312b879
Update
XuehaiPan Oct 21, 2024
6a3a30b
Update
XuehaiPan Oct 21, 2024
0bde528
Update
XuehaiPan Oct 21, 2024
e2fbde8
Update
XuehaiPan Oct 22, 2024
0ccaa9b
Update
XuehaiPan Oct 22, 2024
61762c2
Update
XuehaiPan Oct 22, 2024
d622085
Update
XuehaiPan Oct 22, 2024
02e5a90
Update
XuehaiPan Oct 22, 2024
a6ae9ba
Update
XuehaiPan Oct 24, 2024
9d6d32d
Update
XuehaiPan Oct 24, 2024
d5cb3bb
Update
XuehaiPan Oct 24, 2024
7531a9f
Update
XuehaiPan Oct 25, 2024
81bab69
Update
XuehaiPan Oct 29, 2024
d5df9f5
Update
XuehaiPan Oct 29, 2024
0f2580c
Update
XuehaiPan Oct 29, 2024
f2429c0
Update
XuehaiPan Oct 29, 2024
35f3b8a
Update
XuehaiPan Oct 29, 2024
5987236
Update
XuehaiPan Oct 29, 2024
83d8183
Update
XuehaiPan Oct 29, 2024
607cc30
Update
XuehaiPan Oct 30, 2024
73f5fb3
Update
XuehaiPan Oct 30, 2024
2af386d
Update
XuehaiPan Nov 5, 2024
15d65d9
Update
XuehaiPan Nov 11, 2024
4fd7c75
Update
XuehaiPan Nov 17, 2024
40dadf3
Update
XuehaiPan Nov 20, 2024
a4a45b3
Update
XuehaiPan Nov 20, 2024
b2c5515
Update
XuehaiPan Nov 20, 2024
c8450bf
Update
XuehaiPan Nov 20, 2024
65c0297
Update
XuehaiPan Nov 20, 2024
92946c4
Update
XuehaiPan Nov 20, 2024
0f4b8fc
Update
XuehaiPan Nov 20, 2024
ca5f6df
Update
XuehaiPan Nov 21, 2024
735ac30
Update
XuehaiPan Nov 21, 2024
ae21d50
Update
XuehaiPan Nov 21, 2024
f5591b0
Update
XuehaiPan Nov 21, 2024
d7dd93e
Update
XuehaiPan Nov 22, 2024
d3b1ad0
Update
XuehaiPan Nov 22, 2024
ff1052b
Update
XuehaiPan Nov 26, 2024
baa1b22
Update
XuehaiPan Nov 26, 2024
6c73653
Update
XuehaiPan Nov 27, 2024
188d3ca
Update
XuehaiPan Dec 2, 2024
fe8181e
Update
XuehaiPan Dec 2, 2024
6c93d67
Update
XuehaiPan Dec 7, 2024
6cceebd
Update
XuehaiPan Dec 9, 2024
90f2365
Update
XuehaiPan Dec 13, 2024
7faa200
Update
XuehaiPan Jan 13, 2025
f251e7c
Update
XuehaiPan Feb 4, 2025
ff5523d
Update
XuehaiPan Feb 25, 2025
98a7dab
Update
XuehaiPan Feb 25, 2025
74c7d88
Update
XuehaiPan Feb 25, 2025
97a7ef3
Update
XuehaiPan Feb 25, 2025
23807f0
Update
XuehaiPan Feb 25, 2025
8573d0d
Update
XuehaiPan Feb 25, 2025
ee4586e
Update
XuehaiPan Feb 25, 2025
d1848b5
Update
XuehaiPan Feb 25, 2025
b930161
Update
XuehaiPan Feb 25, 2025
bf63d7a
Update
XuehaiPan Feb 26, 2025
7bcc529
Update
XuehaiPan Feb 26, 2025
c6dbbee
Update
XuehaiPan Feb 26, 2025
c8f6b44
Update
XuehaiPan Feb 26, 2025
2d7f03a
Update
XuehaiPan Feb 26, 2025
e9741e4
Update
XuehaiPan Feb 26, 2025
a11008d
Update
XuehaiPan Feb 26, 2025
61d14ef
Update
XuehaiPan Feb 28, 2025
ecf82cc
Update
XuehaiPan Mar 3, 2025
db4ff87
Update
XuehaiPan Mar 6, 2025
0573e00
Update
XuehaiPan Mar 7, 2025
4e9fe91
Update
XuehaiPan Mar 8, 2025
8b03f5c
Update
XuehaiPan Mar 8, 2025
4dd5394
Update
XuehaiPan Mar 8, 2025
c2ae349
Update
XuehaiPan Mar 9, 2025
8e2ca8d
Update
XuehaiPan Mar 11, 2025
b23377e
Update
XuehaiPan Mar 12, 2025
cd7154e
Update
XuehaiPan Mar 13, 2025
acff96a
Update
XuehaiPan Mar 14, 2025
58d6cf7
Update
XuehaiPan Mar 20, 2025
9451abd
Update
XuehaiPan Mar 31, 2025
8976c0a
Update
XuehaiPan Mar 31, 2025
67a7b31
Update
XuehaiPan Mar 31, 2025
ea3874a
Update
XuehaiPan Apr 1, 2025
84e9321
Update
XuehaiPan Apr 2, 2025
c3c1119
Update
XuehaiPan Apr 2, 2025
a4d0a3c
Update
XuehaiPan Apr 2, 2025
9f70896
Update
XuehaiPan Apr 3, 2025
4fe3104
Update
XuehaiPan Apr 5, 2025
cfc602b
Update
XuehaiPan Apr 7, 2025
f009887
Update
XuehaiPan Apr 23, 2025
ce0fd35
Update
XuehaiPan Apr 26, 2025
4523c06
Update
XuehaiPan May 1, 2025
e4184c7
Update
XuehaiPan May 5, 2025
7fb72b0
Update
XuehaiPan May 8, 2025
056a129
Update
XuehaiPan May 14, 2025
d7f944e
Update
XuehaiPan May 14, 2025
797bb0f
Update
XuehaiPan May 16, 2025
ce4623e
Update
XuehaiPan May 27, 2025
df0c904
Update
XuehaiPan May 28, 2025
1c5b296
Update
XuehaiPan May 31, 2025
c519e7c
Update
XuehaiPan Jun 6, 2025
c52cf1c
Update
XuehaiPan Jun 18, 2025
d6c487c
Update
XuehaiPan Jun 18, 2025
abe952a
Update
XuehaiPan Jun 27, 2025
4361c94
Update
XuehaiPan Jun 28, 2025
1e39c00
Update
XuehaiPan Jul 3, 2025
8f38334
Update
XuehaiPan Jul 9, 2025
f3a097e
Update
XuehaiPan Jul 25, 2025
Update on "[pytree] support PyStructSequence types for Python pytree"
[ghstack-poisoned]
XuehaiPan committed Nov 8, 2023
commit 62f0b8d1858a3c46b599518d3b51b055c3e729cf
6 changes: 4 additions & 2 deletions torch/return_types.py
@@ -10,8 +10,10 @@


 def pytree_register_structseq(cls):
-    if not torch.utils._pytree.is_structseq_class(cls):
-        warnings.warn(f"Class {cls!r} is not a PyStructSequence class.")
+    if torch.utils._pytree.is_structseq_class(cls):

Contributor

@XuehaiPan: what's the motivation for having all PyStructSequence be pytrees? All namedtuples being pytrees is already unexpected behavior for some people.

Collaborator Author

> what's the motivation for having all PyStructSequence be pytrees?

@zou3519 The idea is there might be a use case that someone writes a custom operator and wants it supported by pytree by default. Currently, if a third-party package has a custom return type, it needs to call torch.return_types.pytree_register_structseq manually.

There is also a codegen utility that generates C++ code for PyStructSequence types and assigns them to torch._C._return_types.

    if typename is None:
        typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
        typenames[tn_key] = typename
        definitions.append(
            f"""\
    PyTypeObject* get_{name}_namedtuple() {{
      static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
      static PyTypeObject {typename};
      static bool is_initialized = false;
      static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
      if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
      }}
      return &{typename};
    }}
    """
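
To make the template above easier to read, here is a small self-contained sketch that renders it for one return type. The helper name `render_structseq_getter` is hypothetical, the way each field is formatted as a `{"name", ""}` initializer is an assumption, and the `tp_repr` line is left out for brevity. It is only meant to show roughly what C++ the f-string expands to.

```python
def render_structseq_getter(name, fieldnames, definitions):
    """Render C++ source for one PyStructSequence type (illustrative only)."""
    # Assumption: each field is emitted as a {"name", ""} initializer.
    fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)
    # Same naming rule as the excerpt: append a counter after the first type.
    typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
    source = f"""\
PyTypeObject* get_{name}_namedtuple() {{
  static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
  static PyTypeObject {typename};
  static bool is_initialized = false;
  static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
  if (!is_initialized) {{
    PyStructSequence_InitType(&{typename}, &desc);
    is_initialized = true;
  }}
  return &{typename};
}}
"""
    definitions.append(source)
    return source


definitions = []
cpp = render_structseq_getter("topk", ["values", "indices"], definitions)
print(cpp)
```

Doubled braces in the f-string become literal C++ braces, while single braces are substituted with the Python values.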

See also:


> All namedtuples being pytrees is already unexpected behavior for some people.

I agree that some people want their namedtuple class to be an opaque object rather than a container. Using dataclasses.dataclass might be an alternative for them. Before this PR, all namedtuples are already pytrees and all dataclasses are leaf nodes by default.
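
The node-versus-leaf distinction can be illustrated with a toy flattener. This is a deliberately simplified sketch, not the `torch.utils._pytree` implementation: `toy_flatten` and `is_namedtuple_instance` are hypothetical names, and real pytree libraries also track the tree structure so that it can be rebuilt.

```python
from collections import namedtuple
from dataclasses import dataclass

Point = namedtuple("Point", ["x", "y"])


@dataclass
class Box:
    x: int
    y: int


def is_namedtuple_instance(obj):
    # Loose check used only for this toy example.
    return isinstance(obj, tuple) and hasattr(obj, "_fields")


def toy_flatten(tree):
    """Return the leaves of `tree`, recursing only into known containers."""
    if type(tree) is dict:
        return [leaf for value in tree.values() for leaf in toy_flatten(value)]
    if type(tree) in (list, tuple) or is_namedtuple_instance(tree):
        return [leaf for item in tree for leaf in toy_flatten(item)]
    # Anything else, including dataclass instances, is an opaque leaf.
    return [tree]


leaves = toy_flatten({"a": Point(1, 2), "b": [Box(3, 4)]})
print(leaves)
```

Here the namedtuple is opened up into its two fields, while the dataclass instance comes back as a single leaf.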

See also discussion for making all dataclasses be pytrees by default:

Contributor

> The idea is there might be a use case that someone writes a custom operator and wants it supported by pytree by default. Currently, if a third-party package has a custom return type, it needs to call torch.return_types.pytree_register_structseq manually.

What does JAX's pytree do on structseq types?

For custom ops that return structseq: I don't think there is a way to easily do this, so I wouldn't worry about that use case.

Collaborator Author
XuehaiPan, Nov 30, 2023

> What does JAX's pytree do on structseq types?

By default, structseqs are leaf nodes and all namedtuples are pytree nodes in JAX's pytree. Users need to manually register structseq types as pytree nodes, as we do with pytree_register_structseq.

The JAX community uses namedtuples instead of structseqs. They never use structseqs. If you search PyStructSequence in https://github.com/google/jax and https://github.com/tensorflow/tensorflow, you will get an empty result.

https://github.com/search?q=repo%3Agoogle%2Fjax+PyStructSequence&type=code
https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+PyStructSequence&type=code

Here are lists of structseq types:

CPython internals
In [1]: import sys

In [2]: import importlib

In [3]: for mod in sys.stdlib_module_names:
   ...:     try:
   ...:         importlib.import_module(mod)
   ...:     except ImportError:
   ...:         pass
   ...:         

In [4]: import optree

In [5]: list(filter(optree.is_structseq_class, tuple.__subclasses__()))
Out[5]: 
[
     <class 'sys.int_info'>,
     <class 'sys.float_info'>,
     <class 'UnraisableHookArgs'>,
     <class 'sys.hash_info'>,
     <class 'sys.version_info'>,
     <class 'sys.flags'>,
     <class 'sys.thread_info'>,
     <class 'asyncgen_hooks'>,
     <class '_thread._ExceptHookArgs'>,
     <class 'os.stat_result'>,
     <class 'os.statvfs_result'>,
     <class 'os.terminal_size'>,
     <class 'posix.times_result'>,
     <class 'posix.uname_result'>,
     <class 'time.struct_time'>,
     <class 'resource.struct_rusage'>,
     <class '_lsprof.profiler_entry'>,
     <class '_lsprof.profiler_subentry'>,
     <class 'curses.ncurses_version'>,
     <class 'pwd.struct_passwd'>,
     <class 'grp.struct_group'>
]
With `torch.return_types`
In [6]: import torch

In [7]: list(filter(optree.is_structseq_class, tuple.__subclasses__()))
Out[7]: 
[
     <class 'sys.int_info'>,
     <class 'sys.float_info'>,
     <class 'UnraisableHookArgs'>,
     <class 'sys.hash_info'>,
     <class 'sys.version_info'>,
     <class 'sys.flags'>,
     <class 'sys.thread_info'>,
     <class 'asyncgen_hooks'>,
     <class '_thread._ExceptHookArgs'>,
     <class 'os.stat_result'>,
     <class 'os.statvfs_result'>,
     <class 'os.terminal_size'>,
     <class 'posix.times_result'>,
     <class 'posix.uname_result'>,
     <class 'time.struct_time'>,
     <class 'resource.struct_rusage'>,
     <class '_lsprof.profiler_entry'>,
     <class '_lsprof.profiler_subentry'>,
     <class 'curses.ncurses_version'>,
     <class 'pwd.struct_passwd'>,
     <class 'grp.struct_group'>,
     <class 'numpy.core.multiarray.typeinfo'>,
     <class 'numpy.core.multiarray.typeinforanged'>,
     <class 'torch.return_types._fake_quantize_per_tensor_affine_cachemask_tensor_qparams'>,
     <class 'torch.return_types._fused_moving_avg_obs_fq_helper'>,
     <class 'torch.return_types._linalg_det'>,
     <class 'torch.return_types._linalg_det_out'>,
     <class 'torch.return_types._linalg_eigh'>,
     <class 'torch.return_types._linalg_eigh_out'>,
     <class 'torch.return_types._linalg_slogdet'>,
     <class 'torch.return_types._linalg_slogdet_out'>,
     <class 'torch.return_types._linalg_solve_ex'>,
     <class 'torch.return_types._linalg_solve_ex_out'>,
     <class 'torch.return_types._linalg_svd'>,
     <class 'torch.return_types._linalg_svd_out'>,
     <class 'torch.return_types._lu_with_info'>,
     <class 'torch.return_types._scaled_dot_product_efficient_attention'>,
     <class 'torch.return_types._scaled_dot_product_flash_attention'>,
     <class 'torch.return_types._unpack_dual'>,
     <class 'torch.return_types.aminmax'>,
     <class 'torch.return_types.aminmax_out'>,
     <class 'torch.return_types.cummax'>,
     <class 'torch.return_types.cummax_out'>,
     <class 'torch.return_types.cummin'>,
     <class 'torch.return_types.cummin_out'>,
     <class 'torch.return_types.frexp'>,
     <class 'torch.return_types.frexp_out'>,
     <class 'torch.return_types.geqrf_out'>,
     <class 'torch.return_types.geqrf'>,
     <class 'torch.return_types.histogram_out'>,
     <class 'torch.return_types.histogram'>,
     <class 'torch.return_types.histogramdd'>,
     <class 'torch.return_types.kthvalue'>,
     <class 'torch.return_types.kthvalue_out'>,
     <class 'torch.return_types.linalg_cholesky_ex'>,
     <class 'torch.return_types.linalg_cholesky_ex_out'>,
     <class 'torch.return_types.linalg_eig'>,
     <class 'torch.return_types.linalg_eig_out'>,
     <class 'torch.return_types.linalg_eigh'>,
     <class 'torch.return_types.linalg_eigh_out'>,
     <class 'torch.return_types.linalg_inv_ex'>,
     <class 'torch.return_types.linalg_inv_ex_out'>,
     <class 'torch.return_types.linalg_ldl_factor'>,
     <class 'torch.return_types.linalg_ldl_factor_out'>,
     <class 'torch.return_types.linalg_ldl_factor_ex'>,
     <class 'torch.return_types.linalg_ldl_factor_ex_out'>,
     <class 'torch.return_types.linalg_lstsq'>,
     <class 'torch.return_types.linalg_lstsq_out'>,
     <class 'torch.return_types.linalg_lu'>,
     <class 'torch.return_types.linalg_lu_out'>,
     <class 'torch.return_types.linalg_lu_factor'>,
     <class 'torch.return_types.linalg_lu_factor_out'>,
     <class 'torch.return_types.linalg_lu_factor_ex'>,
     <class 'torch.return_types.linalg_lu_factor_ex_out'>,
     <class 'torch.return_types.linalg_qr'>,
     <class 'torch.return_types.linalg_qr_out'>,
     <class 'torch.return_types.linalg_slogdet'>,
     <class 'torch.return_types.linalg_slogdet_out'>,
     <class 'torch.return_types.linalg_solve_ex'>,
     <class 'torch.return_types.linalg_solve_ex_out'>,
     <class 'torch.return_types.linalg_svd'>,
     <class 'torch.return_types.linalg_svd_out'>,
     <class 'torch.return_types.lu_unpack'>,
     <class 'torch.return_types.lu_unpack_out'>,
     <class 'torch.return_types.max'>,
     <class 'torch.return_types.max_out'>,
     <class 'torch.return_types.median'>,
     <class 'torch.return_types.median_out'>,
     <class 'torch.return_types.min'>,
     <class 'torch.return_types.min_out'>,
     <class 'torch.return_types.mode'>,
     <class 'torch.return_types.mode_out'>,
     <class 'torch.return_types.nanmedian'>,
     <class 'torch.return_types.nanmedian_out'>,
     <class 'torch.return_types.qr_out'>,
     <class 'torch.return_types.qr'>,
     <class 'torch.return_types.slogdet'>,
     <class 'torch.return_types.slogdet_out'>,
     <class 'torch.return_types.sort_out'>,
     <class 'torch.return_types.sort'>,
     <class 'torch.return_types.svd_out'>,
     <class 'torch.return_types.svd'>,
     <class 'torch.return_types.topk_out'>,
     <class 'torch.return_types.topk'>,
     <class 'torch.return_types.triangular_solve_out'>,
     <class 'torch.return_types.triangular_solve'>
]

In PyTorch's pytree use cases, the leaf values are mostly torch.Tensor or None. Python's internal PyStructSequence types should never be used as containers in the PyTorch use case.
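
For readers without optree installed, a rough standard-library sketch of how such types can be recognized is shown below. The function name `looks_like_structseq_class` is hypothetical, and the heuristic (a direct `tuple` subclass with integer `n_fields`, `n_sequence_fields`, and `n_unnamed_fields` attributes that cannot itself be subclassed) is an assumption aimed at CPython; the actual `optree.is_structseq_class` may differ in detail.

```python
import os
import sys
import time
from collections import namedtuple


def looks_like_structseq_class(cls):
    """Heuristic check for a PyStructSequence type (CPython-oriented sketch)."""
    if not (
        isinstance(cls, type)
        and cls.__bases__ == (tuple,)
        and isinstance(getattr(cls, "n_fields", None), int)
        and isinstance(getattr(cls, "n_sequence_fields", None), int)
        and isinstance(getattr(cls, "n_unnamed_fields", None), int)
    ):
        return False
    try:
        # CPython struct sequence types are not acceptable base types.
        type("_Probe", (cls,), {})
    except TypeError:
        return True
    return False


Point = namedtuple("Point", ["x", "y"])

for candidate in (time.struct_time, os.stat_result, type(sys.version_info), Point, tuple):
    print(candidate, looks_like_structseq_class(candidate))
```

Filtering `tuple.__subclasses__()` with a predicate like this is what produces listings such as the ones above.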

Contributor

IMO we should be consistent, so we should not register structseq as pytree nodes by default. However, users can manually add them (and we "manually" add them for all of the torch.return_types).

Does optree support structseq types as pytree nodes by default? If so, we probably want some way to toggle this behavior to make it consistent.

Collaborator Author
XuehaiPan, Dec 4, 2023

> IMO we should be consistent, so we should not register structseq as pytree nodes by default. However, users can manually add them (and we "manually" add them for all of the torch.return_types).

@zou3519 Do you mean to be consistent with JAX pytree or the old behavior of pytree in PyTorch?


> Does optree support structseq types as pytree nodes by default?

Yes. The motivations are:

  1. PyStructSequence is a concept in the Python C API but shares some semantics with namedtuple.

    Python Documentation: Struct Sequence Objects

    Struct sequence objects are the C equivalent of namedtuple() objects, i.e. a sequence whose items can also be accessed through attributes. To create a struct sequence, you first have to create a specific struct sequence type.

    Also, CPython upstream may add members to PyStructSequence to make it consistent with namedtuple on the Python side.

    Since we treat all namedtuple types as pytree nodes, it is reasonable to have all PyStructSequence types be pytree nodes once the upstream PR lands.

    Here is the implementation for is_namedtuple_class in the previous PR in the ghstack:

    def is_namedtuple_class(cls):
        return (
            isinstance(cls, type)
            and issubclass(cls, tuple)
            and isinstance(getattr(cls, "_fields", None), tuple)
            and all(type(field) is str for field in cls._fields)
            and callable(getattr(cls, "_make", None))    # <= structseq fails here even if python/cpython#108648 landed
            and callable(getattr(cls, "_asdict", None))
        )

    is_namedtuple(structseq) will still be False even if python/cpython#108648 (GH-46145: make PyStructSequence compatible with namedtuple) lands.

    But for JAX pytree, it only checks the presence of _fields:

    def is_namedtuple(obj):
        return isinstance(obj, tuple) and hasattr(obj, '_fields')

    All PyStructSequence instances will count as namedtuples if python/cpython#108648 (GH-46145: make PyStructSequence compatible with namedtuple) lands, i.e., they will automatically become pytree nodes in JAX pytree.

  2. PyStructSequence types are subclasses of tuple, and they are container types. This fits the definition of pytree node. That's why we manually register all types in torch.return_types.* as pytree node. Here is an example usage that requires this: tree_map(lambda x: x.cuda(), tree).

  3. The JAX community does not use PyStructSequence.

  4. The number of PyStructSequence types is limited (see #113258 (comment)), because they are defined in C code and there is no easy way to create a new type from the Python side.
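
The difference between the two namedtuple checks in item 1 can be tried out with a small sketch. The function names `is_namedtuple_class_strict` and `is_namedtuple_loose` are renamed copies of the two snippets above, and `FieldsOnly` is a hypothetical stand-in for a struct sequence that has gained only a `_fields` attribute; it is not a real PyStructSequence type.

```python
from collections import namedtuple


def is_namedtuple_class_strict(cls):
    # Mirrors the stricter class-level check quoted above.
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and isinstance(getattr(cls, "_fields", None), tuple)
        and all(type(field) is str for field in cls._fields)
        and callable(getattr(cls, "_make", None))
        and callable(getattr(cls, "_asdict", None))
    )


def is_namedtuple_loose(obj):
    # Mirrors the instance-level check that only looks for `_fields`.
    return isinstance(obj, tuple) and hasattr(obj, "_fields")


Point = namedtuple("Point", ["x", "y"])


class FieldsOnly(tuple):
    """Hypothetical tuple subclass exposing `_fields` but no `_make`/`_asdict`."""

    _fields = ("a", "b")


print(is_namedtuple_class_strict(Point), is_namedtuple_loose(Point(1, 2)))
print(is_namedtuple_class_strict(FieldsOnly), is_namedtuple_loose(FieldsOnly((1, 2))))
```

A real namedtuple passes both checks, while the stand-in passes only the loose one, which is the divergence the comment is pointing at.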


> If so, we probably want some way to toggle this behavior to make it consistent.

I can add an API to have namespace-specific configurations in optree. But before that, I want to hear your thoughts about the explanations above.

Contributor

> @zou3519 Do you mean to be consistent with JAX pytree or the old behavior of pytree in PyTorch?

We should be consistent with JAX pytree. Also, the old behavior of pytree in PyTorch is the same as JAX's behavior: structseq is not treated as a pytree node type, unless the user explicitly specifies it. We've explicitly specified that all of the torch.return_types are pytree nodes.

> 1. PyStructSequence is a concept in the Python C API but shares some semantics with namedtuple.

The only user of structseq we've seen out in the wild is torch.return_types. So it doesn't matter too much what we do with it, but we should be consistent because inconsistencies end up leading to UX issues.

> All PyStructSequence instances will be namedtuple if landed

I would wait until those PRs actually land before changing the behavior, to be consistent.

> 2. PyStructSequence types are subclasses of tuple, and they are container types. This fits the definition of pytree node. That's why we manually register all types in torch.return_types.* as pytree node. Here is an example usage that requires this: tree_map(lambda x: x.cuda(), tree).

I don't think subclasses of a type get the pytree treatment; we require exact match, right?
I agree PyStructSequence is a container type, but not all container types are pytree nodes by default: see dataclass.

We manually register all types in torch.return_types.* as pytree nodes because we own them and as the owners, we've decided they should have pytree semantics.

Contributor

> > All PyStructSequence instances will be namedtuple if landed
>
> I would wait until those PRs actually land before changing the behavior, to be consistent

Should we ask on the PR when they think they'll merge it? lol

Collaborator Author

> Should we ask on the PR when they think they'll merge it? lol

Based on the label of the issue, the initial plan was to land this in Python 3.9, but it has been delayed to 3.13. I think they will merge the relevant changes eventually.

Contributor

Maybe we should just leave this PR open and avoid changing optree for now. As we discussed above, people don't really use the structseq API.

+        return
+
+    warnings.warn(f"Class {cls!r} is not a PyStructSequence class.")

     def structseq_flatten(structseq):
         return list(structseq), type(structseq)
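
As a rough illustration of what this flatten helper does, the sketch below pairs it with an unflatten step and round-trips a standard-library struct sequence. `structseq_unflatten` is not shown in this hunk, so its body here is an assumption (rebuild by calling the type with the list of visible fields), and `time.struct_time` is used only as a stand-in for a `torch.return_types` class.

```python
import time


def structseq_flatten(structseq):
    # Children are the visible sequence fields; the context is the type itself.
    return list(structseq), type(structseq)


def structseq_unflatten(values, context):
    # Assumption: the struct sequence type accepts one iterable of field values.
    return context(values)


stamp = time.gmtime(0)  # the Unix epoch as a time.struct_time
children, context = structseq_flatten(stamp)
rebuilt = structseq_unflatten(children, context)

print(children)
print(context is time.struct_time, rebuilt == stamp)
```

Only the fields exposed through the sequence protocol are carried through `list(structseq)`, so any attribute-only fields on a struct sequence would not survive this round trip.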