8000 Use `torch.cumsum` instead of numpy one (#109400) · pytorch/pytorch@fb58a72 · GitHub
[go: up one dir, main page]

Skip to content

Commit fb58a72

Browse files
malfetpytorchmergebot
authored andcommitted
Use torch.cumsum instead of numpy one (#109400)
`s/list(numpy.cumsum(foo))/torch.cumsum(torch.tensor(foo), 0).tolist()/` Test plan: ` python3 ../test/inductor/test_split_cat_fx_passes.py -v` Partially addresses #109387 Pull Request resolved: #109400 Approved by: https://github.com/ezyang
1 parent 4ee179c commit fb58a72

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

torch/_inductor/fx_passes/split_cat.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import operator
44
from typing import Callable, List, Sequence, Tuple, Union
55

6-
import numpy
7-
86
import torch
97
from torch._dynamo.utils import counters
108

@@ -454,8 +452,7 @@ def get_simplified_split_ranges(
454452
for user_input in user_inputs
455453
if isinstance(user_input, tuple)
456454
}
457-
458-
cumulative_sizes = [0] + list(numpy.cumsum(split_sections))
455+
cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
459456
split_ranges = sorted(
460457
[(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges]
461458
)
@@ -578,7 +575,7 @@ def replace_split(
578575
for i in range(len(split_ranges))
579576
]
580577
# Now assign the right getitem to the right input
581-
cumulative_sizes = [0] + list(numpy.cumsum(split_sections))
578+
cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
582579
new_user_inputs_list = []
583580
for user_inputs in user_inputs_list:
584581
new_user_inputs = []

0 commit comments

Comments
 (0)
0