8000 make sure we're not counting non-template, non-externs · pytorch/pytorch@76ffe99 · GitHub
[go: up one dir, main page]

Skip to content

Commit 76ffe99

Browse files
committed
make sure we're not counting non-template, non-externs
1 parent 730af8a commit 76ffe99

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

torch/_inductor/scheduler.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,7 +1333,16 @@ def fuse(
13331333
@cache_on_self
13341334
def estimate_flops(self) -> int | None:
13351335
# don't increment counters in fused methods so we don't double count
1336-
fps = list(filter(None, (node.estimate_flops() for node in self.get_nodes())))
1336+
fps = list(
1337+
filter(
1338+
None,
1339+
(
1340+
node.estimate_flops()
1341+
for node in self.get_nodes()
1342+
if node.is_template() or node.is_extern()
1343+
),
1344+
)
1345+
)
13371346
if len(fps) == 0:
13381347
return None
13391348
ret = sum(fps)
@@ -1888,7 +1897,16 @@ def get_outputs(self) -> list[SchedulerBuffer]:
18881897
@cache_on_self
18891898
def estimate_flops(self) -> int | None:
18901899
# don't increment counters in fused methods so we don't double count
1891-
fps = list(filter(None, (node.estimate_flops() for node in self.get_nodes())))
1900+
fps = list(
1901+
filter(
1902+
None,
1903+
(
1904+
node.estimate_flops()
1905+
for node in self.get_nodes()
1906+
if node.is_template() or node.is_extern()
1907+
),
1908+
)
1909+
)
18921910
if len(fps) == 0:
18931911
return None
18941912
ret = sum(fps)

0 commit comments

Comments
 (0)
0