8000 add guard_else_true, guard_else_false and guard_size_oblivious in dec… · pytorch/pytorch@086987e · GitHub
[go: up one dir, main page]

Skip to content

Commit 086987e

Browse files
committed
add guard_else_true, guard_else_false and guard_size_oblivious in decompositions.py
ghstack-source-id: 424fdbf Pull Request resolved: #148430
1 parent 0bd2caa commit 086987e

File tree

2 files changed

+41
-16
lines changed

2 files changed

+41
-16
lines changed

torch/_inductor/decomposition.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@
3131
ELEMENTWISE_TYPE_PROMOTION_KIND,
3232
type_to_dtype,
3333
)
34-
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious
34+
from torch.fx.experimental.symbolic_shapes import (
35+
definitely_true,
36+
guard_else_false,
37+
guard_else_true,
38+
)
3539

3640
from . import config, inductor_prims
3741
from .utils import (
@@ -261,13 +265,13 @@ def bmm(
261265
batch2: torch.Tensor,
262266
) -> torch.Tensor:
263267
if config.coordinate_descent_tuning and self.device.type != "cpu":
264-
if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious(
268+
if guard_else_false(self.shape[1] == 1) or guard_else_false(
265269
batch2.shape[2] == 1
266270
):
267271
out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
268272
return out
269273
if self.device.type == "cpu":
270-
if guard_size_oblivious(self.size(1) == 1) and guard_size_oblivious(
274+
if guard_else_false(self.size(1) == 1) and guard_else_false(
271275
batch2.size(-1) == 1
272276
):
273277
counters["inductor"]["decompose_bmm"] += 1
@@ -287,16 +291,14 @@ def addmm(
287291
alpha: torch.types.Number = 1,
288292
) -> torch.Tensor:
289293
if self.device.type == "cpu":
290-
if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious(
291-
mat2.size(-1) == 1
292-
):
294+
if guard_else_false(mat1.size(0) == 1) and guard_else_false(mat2.size(-1) == 1):
293295
counters["inductor"]["decompose_addmm"] += 1
294296
out = torch.sum(
295297
mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
296298
).unsqueeze(0)
297299
return alpha * out + beta * self
298300
if (
299-
guard_size_oblivious(mat1.size(0) == 1)
301+
guard_else_false(mat1.size(0) == 1)
300302
and definitely_true(mat2.size(0) <= 16)
301303
and definitely_true(mat2.size(1) <= 16)
302304
):
@@ -315,21 +317,21 @@ def mm(
315317
# Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
316318
# todo: Look into why and fix it (hopefully)
317319
if config.coordinate_descent_tuning and self.device.type != "cpu":
318-
if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious(
320+
if guard_else_false(self.shape[0] == 1) or guard_else_false(
319321
input2.shape[1] == 1
320322
):
321323
return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
322324
if self.device.type == "cpu":
323325
if (
324-
guard_size_oblivious(self.size(-1) == 1)
325-
and guard_size_oblivious(self.size(0) > 0)
326-
and guard_size_oblivious(input2.size(0) == 1)
326+
guard_else_false(self.size(-1) == 1)
327+
and guard_else_true(self.size(0) > 0)
328+
and guard_else_false(input2.size(0) == 1)
327329
and (self.dtype == input2.dtype)
328330
and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
329331
):
330332
counters["inductor"]["decompose_mm"] += 1
331333
return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
332-
if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
334+
if guard_else_false(self.size(0) == 1) and guard_else_false(
333335
input2.size(-1) == 1
334336
):
335337
counters["inductor"]["decompose_mm"] += 1
@@ -348,8 +350,6 @@ def cat(
348350
tensors: list[torch.Tensor],
349351
dim: int = 0,
350352
) -> torch.Tensor:
351-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
352-
353353
def non_empty_tensor(x: torch.Tensor) -> bool:
354354
# For better or worse, this is a valid cat:
355355
#
@@ -367,10 +367,10 @@ def non_empty_tensor(x: torch.Tensor) -> bool:
367367
# runtime assert forcing u0 to be zero. So if this hasn't happened,
368368
# we know that the unbacked SymInt has appropriate size and there are
369369
# no problems.
370-
if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0):
370+
if len(x.shape) == 1 and guard_else_false(x.shape[0] == 0):
371371
return False
372372

373-
if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0):
373+
if dim < len(x.shape) and guard_else_false(x.shape[dim] == 0):
374374
return False
375375

376376
return True

torch/fx/experimental/symbolic_shapes.py

+25
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,31 @@ def _symint_wrap(s: sympy.Symbol) -> SymInt:
11491149
return symbol_to_path
11501150

11511151

1152+
# This is used for size oblivious reasoning to avoid 0/1 specializations.
1153+
def guard_else_false(a: BoolLikeType) -> bool:
1154+
"""
1155+
Try to guard a; if a data-dependent error is encountered, just return False.
1156+
"""
1157+
if isinstance(a, SymBool):
1158+
try:
1159+
guard_bool(a)
1160+
except GuardOnDataDependentSymNode:
1161+
return False
1162+
return bool(a)
1163+
1164+
1165+
def guard_else_true(a: BoolLikeType) -> bool:
1166+
"""
1167+
Try to guard a; if a data-dependent error is encountered, just return True.
1168+
"""
1169+
if isinstance(a, SymBool):
1170+
try:
1171+
guard_bool(a)
1172+
except GuardOnDataDependentSymNode:
1173+
return True
1174+
return bool(a)
1175+
1176+
11521177
def definitely_true(a: BoolLikeType) -> bool:
11531178
"""
11541179
Returns True only if we can tell that a is True, possibly introducing

0 commit comments

Comments
 (0)
0