@@ -293,7 +293,7 @@ def tree_flatten(
293
293
The flattening order (i.e., the order of elements in the output list) is deterministic,
294
294
corresponding to a left-to-right depth-first tree traversal.
295
295
296
- >>> tree = {'b' : (2, [3, 4]), 'a' : 1, 'c' : None, 'd' : 5}
296
+ >>> tree = {"b" : (2, [3, 4]), "a" : 1, "c" : None, "d" : 5}
297
297
>>> tree_flatten(tree)
298
298
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
299
299
>>> tree_flatten(1)
@@ -306,7 +306,7 @@ def tree_flatten(
306
306
if you want to keep the keys in the insertion order.
307
307
308
308
>>> from collections import OrderedDict
309
- >>> tree = OrderedDict([('b' , (2, [3, 4])), ('a' , 1), ('c' , None), ('d' , 5)])
309
+ >>> tree = OrderedDict([("b" , (2, [3, 4])), ("a" , 1), ("c" , None), ("d" , 5)])
310
310
>>> tree_flatten(tree)
311
311
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
312
312
@@ -335,7 +335,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
335
335
336
336
The inverse of :func:`tree_flatten`.
337
337
338
- >>> tree = {'b' : (2, [3, 4]), 'a' : 1, 'c' : None, 'd' : 5}
338
+ >>> tree = {"b" : (2, [3, 4]), "a" : 1, "c" : None, "d" : 5}
339
339
>>> leaves, treespec = tree_flatten(tree)
340
340
>>> tree == tree_unflatten(leaves, treespec)
341
341
True
@@ -365,7 +365,7 @@ def tree_iter(
365
365
366
366
See also :func:`tree_flatten`.
367
367
368
- >>> tree = {'b' : (2, [3, 4]), 'a' : 1, 'c' : None, 'd' : 5}
368
+ >>> tree = {"b" : (2, [3, 4]), "a" : 1, "c" : None, "d" : 5}
369
369
>>> list(tree_iter(tree))
370
370
[1, 2, 3, 4, None, 5]
371
371
>>> list(tree_iter(1))
@@ -400,7 +400,7 @@ def tree_leaves(
400
400
401
401
See also :func:`tree_flatten`.
402
402
403
- >>> tree = {'b' : (2, [3, 4]), 'a' : 1, 'c' : None, 'd' : 5}
403
+ >>> tree = {"b" : (2, [3, 4]), "a" : 1, "c" : None, "d" : 5}
404
404
>>> tree_leaves(tree)
405
405
[1, 2, 3, 4, None, 5]
406
406
>>> tree_leaves(1)
@@ -435,7 +435,7 @@ def tree_structure(
435
435
436
436
See also :func:`tree_flatten`.
437
437
438
- >>> tree = {'b' : (2, [3, 4]), 'a' : 1, 'c' : None, 'd' : 5}
438
+ >>> tree = {"b" : (2, [3, 4]), "a" : 1, "c" : None, "d" : 5}
439
439
>>> tree_structure(tree)
440
440
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
441
441
>>> tree_structure(1)
@@ -472,9 +472,9 @@ def tree_map(
472
472
473
473
See also :func:`tree_map_`.
474
474
475
- >>> tree_map(lambda x: x + 1, {'x' : 7, 'y' : (42, 64)})
475
+ >>> tree_map(lambda x: x + 1, {"x" : 7, "y" : (42, 64)})
476
476
{'x': 8, 'y': (43, 65)}
477
- >>> tree_map(lambda x: x is None, {'x' : 7, 'y' : (42, 64), 'z' : None})
477
+ >>> tree_map(lambda x: x is None, {"x" : 7, "y" : (42, 64), "z" : None})
478
478
{'x': False, 'y': (False, False), 'z': True}
479
479
480
480
If multiple inputs are given, the structure of the tree is taken from the first input;
@@ -572,7 +572,9 @@ def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
572
572
573
573
574
574
@overload
575
- def map_only (__type_or_types_or_pred : Type3 [T , S , U ]) -> MapOnlyFn [Fn3 [T , S , U , Any ]]:
575
+ def map_only (
576
+ __type_or_types_or_pred : Type3 [T , S , U ],
577
+ ) -> MapOnlyFn [Fn3 [T , S , U , Any ]]:
576
578
...
577
579
578
580
@@ -588,12 +590,14 @@ def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
588
590
589
591
590
592
@overload
591
- def map_only (__type_or_types_or_pred : Callable [[Any ], bool ]) -> MapOnlyFn [FnAny [Any ]]:
593
+ def map_only (
594
+ __type_or_types_or_pred : Callable [[Any ], bool ],
595
+ ) -> MapOnlyFn [FnAny [Any ]]:
592
596
...
593
597
594
598
595
599
def map_only (
596
- __type_or_types_or_pred : Union [TypeAny , Callable [[Any ], bool ]]
600
+ __type_or_types_or_pred : Union [TypeAny , Callable [[Any ], bool ]],
597
601
) -> MapOnlyFn [FnAny [Any ]]:
598
602
"""
599
603
Suppose you are writing a tree_map over tensors, leaving everything
@@ -858,7 +862,7 @@ def broadcast_prefix(
858
862
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
859
863
>>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
860
864
[1, 2, 3, 3]
861
- >>> broadcast_prefix([1, 2, 3], [1, 2, {'a' : 3, 'b' : 4, 'c' : (None, 5)}])
865
+ >>> broadcast_prefix([1, 2, 3], [1, 2, {"a" : 3, "b" : 4, "c" : (None, 5)}])
862
866
[1, 2, 3, 3, 3, 3]
863
867
864
868
Args:
@@ -873,13 +877,19 @@ def broadcast_prefix(
873
877
Returns:
874
878
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
875
879
"""
876
- return optree .broadcast_prefix (
880
+ result : List [Any ] = []
881
+
882
+ def add_leaves (x : Any , subtree : PyTree ) -> None :
883
+ subtreespec = tree_structure (subtree , is_leaf = is_leaf )
884
+ result .extend ([x ] * subtreespec .num_leaves )
885
+
886
+ tree_map_ (
887
+ add_leaves ,
877
888
prefix_tree ,
878
889
full_tree ,
879
890
is_leaf = is_leaf ,
880
- none_is_leaf = True ,
881
- namespace = "torch" ,
882
891
)
892
+ return result
883
893
884
894
885
895
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
0 commit comments