|
61 | 61 | ]
|
62 | 62 |
|
63 | 63 |
|
| 64 | +__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch") |
| 65 | +__TORCH_DICT_SESSION.__enter__() # enable globally and permanently |
| 66 | + |
| 67 | + |
64 | 68 | T = TypeVar("T")
|
65 | 69 | S = TypeVar("S")
|
66 | 70 | U = TypeVar("U")
|
@@ -285,20 +289,15 @@ def tree_flatten(
|
285 | 289 |
|
286 | 290 | >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
287 | 291 | >>> tree_flatten(tree)
|
288 |
| - ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)) |
| 292 | + ([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')) |
289 | 293 | >>> tree_flatten(1)
|
290 |
| - ([1], PyTreeSpec(*, NoneIsLeaf)) |
| 294 | + ([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch')) |
291 | 295 | >>> tree_flatten(None)
|
292 |
| - ([None], PyTreeSpec(*, NoneIsLeaf)) |
293 |
| -
|
294 |
| - For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is |
295 |
| - dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` |
296 |
| - if you want to keep the keys in the insertion order. |
297 |
| -
|
| 296 | + ([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch')) |
298 | 297 | >>> from collections import OrderedDict
|
299 | 298 | >>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
|
300 | 299 | >>> tree_flatten(tree)
|
301 |
| - ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)) |
| 300 | + ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch')) |
302 | 301 |
|
303 | 302 | Args:
|
304 | 303 | tree (pytree): A pytree to flatten.
|
@@ -357,7 +356,7 @@ def tree_iter(
|
357 | 356 |
|
358 | 357 | >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
359 | 358 | >>> list(tree_iter(tree))
|
360 |
| - [1, 2, 3, 4, None, 5] |
| 359 | + [2, 3, 4, 1, None, 5] |
361 | 360 | >>> list(tree_iter(1))
|
362 | 361 | [1]
|
363 | 362 | >>> list(tree_iter(None))
|
@@ -392,7 +391,7 @@ def tree_leaves(
|
392 | 391 |
|
393 | 392 | >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
394 | 393 | >>> tree_leaves(tree)
|
395 |
| - [1, 2, 3, 4, None, 5] |
| 394 | + [2, 3, 4, 1, None, 5] |
396 | 395 | >>> tree_leaves(1)
|
397 | 396 | [1]
|
398 | 397 | >>> tree_leaves(None)
|
@@ -427,11 +426,11 @@ def tree_structure(
|
427 | 426 |
|
428 | 427 | >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
429 | 428 | >>> tree_structure(tree)
|
430 |
| - PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf) |
| 429 | + PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch') |
431 | 430 | >>> tree_structure(1)
|
432 |
| - PyTreeSpec(*, NoneIsLeaf) |
| 431 | + PyTreeSpec(*, NoneIsLeaf, namespace='torch') |
433 | 432 | >>> tree_structure(None)
|
434 |
| - PyTreeSpec(*, NoneIsLeaf) |
| 433 | + PyTreeSpec(*, NoneIsLeaf, namespace='torch') |
435 | 434 |
|
436 | 435 | Args:
|
437 | 436 | tree (pytree): A pytree to flatten.
|
|
0 commit comments