Compute bounds for the variables created during codegen (#123100) · pytorch/pytorch@bb668c6


Commit bb668c6

lezcano authored and pytorchmergebot committed
Compute bounds for the variables created during codegen (#123100)
Before, we would just bail out on these bounds for all variables that did not come from the FX graph. Now we propagate the bounds whenever we have a rule for that op.

Pull Request resolved: #123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
1 parent 3827810 commit bb668c6
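The core of the change: when a CSE variable is produced during codegen by an op that has no FX-graph bound, its range is now derived op-by-op with ValueRangeAnalysis instead of defaulting to unknown. A minimal sketch of that interval arithmetic, assuming the add rule of torch.utils._sympy.value_ranges (the concrete ranges below are illustrative):

    from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges

    vr = ValueRangeAnalysis()
    a = ValueRanges(-5, 4)    # e.g. what randint(low=-5, high=5, ...) can produce
    b = ValueRanges.wrap(10)  # the singleton range [10, 10]
    print(vr.add(a, b))       # [5, 14]: each op rule maps input ranges to an output range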

File tree

7 files changed: +109 / -27 lines changed

test/inductor/test_torchinductor.py
torch/_inductor/codegen/common.py
torch/_inductor/codegen/cpp.py
torch/_inductor/codegen/triton.py
torch/_inductor/config.py
torch/_inductor/utils.py
torch/utils/_sympy/value_ranges.py

test/inductor/test_torchinductor.py

Lines changed: 6 additions & 1 deletion
@@ -9611,7 +9611,12 @@ def test_randint_int64_mod(self):
         # This used to not compile due to a wrong return type of randint64_cpu
         # See https://github.com/pytorch/pytorch/issues/117435
         def fn(n):
-            return torch.randint(low=-5, high=5, size=(n,), dtype=torch.int64) % 10
+            return (
+                torch.randint(
+                    low=-5, high=5, size=(n,), dtype=torch.int64, device=self.device
+                )
+                % 10
+            )
 
         res = torch.compile(fn)(20)
         self.assertTrue(torch.all((0 <= res) & (res < 10)).item())

torch/_inductor/codegen/common.py

Lines changed: 54 additions & 13 deletions
@@ -27,7 +27,7 @@
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.utils import _pytree as pytree
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
-from torch.utils._sympy.value_ranges import ValueRanges
+from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
 
 from .. import config, metrics
 from ..utils import DeferredLineBase, IndentedBuffer, sympy_dot, sympy_subs, unique
@@ -269,6 +269,7 @@ def deduce_node_dtype(self, node: torch.fx.Node):
         if node.target in (
             "get_index",
             "index_expr",
+            "randint64",
         ):
             return torch.int64
 
@@ -529,7 +530,7 @@ def constant(value, dtype):
 
     @staticmethod
     def reciprocal(x):
-        return ops.truediv("1", x)
+        return ops.truediv(ops.constant(1, torch.int32), x)
 
     @staticmethod
     def square(x):
@@ -566,7 +567,11 @@ def bitwise_right_shift(x, y):
     @staticmethod
     def remainder(a, b):
         r = ops.mod(a, b)
-        return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)
+        cond = ops.and_(
+            ops.ne(r, ops.constant(0, torch.int32)),
+            ops.ne(ops.signbit(r), ops.signbit(b)),
+        )
+        return ops.where(cond, ops.add(r, b), r)
 
     @staticmethod
     def load_seed(name, offset):
@@ -1473,31 +1478,67 @@ def __enter__(self):
         # TODO: hoist this to top level
         class CSEProxy:
             self.name = "CSEProxy"
+            vr_analysis = ValueRangeAnalysis()
 
             @staticmethod
             def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                 def inner(*args, **kwargs):
-                    # TritonTemplateKernel has no current_node
-                    buf_bounds = ValueRanges.unknown()
-                    if (
-                        fx_node := getattr(V.interpreter, "current_node", None)
-                    ) and fx_node.target == name:
-                        assert isinstance(self.node_to_bounds, dict)
-                        buf_bounds = self.node_to_bounds.get(
-                            fx_node, ValueRanges.unknown()
-                        )
+                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)
 
                     value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
 
                     def do_cse(v):
-                        csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
+                        csevar = self.cse.generate(self.compute, v, bounds=bounds)
                         csevar.update_on_args(name, args, kwargs)
                         return csevar
 
                     return pytree.tree_map(do_cse, value)
 
                 return inner
 
+            @staticmethod
+            def _bound_variable(name, *args, **kwargs):
+                """
+                If the variable comes from an FX node, we forward the bound we have already computed.
+                Else, if the variable is created while codegen'ing another op, we try to compute its bounds.
+                """
+                from ..select_algorithm import TritonTemplateKernel
+
+                if isinstance(V.kernel, TritonTemplateKernel):
+                    return ValueRanges.unknown()
+
+                fx_node = V.interpreter.current_node
+                if fx_node.target == name:
+                    assert isinstance(self.node_to_bounds, dict)
+                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
+                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
+                    # These create lots of inner strings. We would need to compute the bounds at the ops
+                    # We will also likely not get much from computing VRs on these nodes
+                    if any(
+                        s in fx_node.target
+                        for s in ("set_indirect", "reduction", "scan")
+                    ):
+                        return ValueRanges.unknown()
+
+                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
+                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.
+
+                    # If there is no FX bound but we know how to compute one, we do so.
+                    assert not kwargs
+
+                    def arg_to_bound(x):
+                        if isinstance(x, CSEVariable):
+                            return x.bounds
+                        elif isinstance(x, sympy.Expr):
+                            return bound_sympy(x)
+                        else:
+                            return x
+
+                    arg_bounds = list(map(arg_to_bound, args))
+                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
+                else:
+                    return ValueRanges.unknown()
+
             @staticmethod
             def indirect_indexing(
                 var: CSEVariable, size: sympy.Expr, check: bool = True
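The remainder rewrite above expresses the sign fix-up through ops.* calls instead of a raw format string, so the new bound propagation can see every sub-op. The identity itself (turning a C-style mod, which keeps the sign of the dividend, into a Python/torch-style one, which keeps the sign of the divisor) can be checked in plain Python; this sketch is illustrative, not inductor code:

    import math

    def remainder(a: float, b: float) -> float:
        r = math.fmod(a, b)  # C-style: r has the sign of a
        if r != 0 and math.copysign(1, r) != math.copysign(1, b):
            r += b  # shift the result so it takes b's sign, matching Python's %
        return r

    assert remainder(-7, 3) == -7 % 3 == 2
    assert remainder(7, -3) == 7 % -3 == -2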

torch/_inductor/codegen/cpp.py

Lines changed: 12 additions & 3 deletions
@@ -34,6 +34,7 @@
 )
 from ..utils import (
     cache_on_self,
+    get_bounds_index_expr,
     get_fused_kernel_name,
     is_welford_reduction,
     parallel_num_threads,
@@ -841,7 +842,7 @@ def mod(a, b):
     @staticmethod
     def constant(val, dtype):
         opt_ctx: OptimizationContext = get_current_node_opt_ctx()
-        assert opt_ctx and opt_ctx.dtype is not None
+        assert opt_ctx and opt_ctx.dtype is not None, opt_ctx
         dtype = opt_ctx.dtype
         if dtype in DTYPE_LOWP_FP:
             # Since load promotes all half-precision inputs to float, constants
@@ -854,7 +855,12 @@ def index_expr(expr, dtype):
         opt_ctx: OptimizationContext = get_current_node_opt_ctx()
         assert opt_ctx and opt_ctx.dtype is not None
         dtype = opt_ctx.dtype
-        return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)
+
+        idx_str = cexpr(V.kernel.rename_indexing(expr))
+        var = V.kernel.cse.generate(
+            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
+        )
+        return ops.to_dtype(var, dtype)
 
     @staticmethod
     def masked(mask, body, other):
@@ -1451,7 +1457,10 @@ def index_expr(expr, dtype):
         if stride == 0:
             return CppOverrides.index_expr(expr, dtype)
         elif stride is not None:
-            value = ops.to_dtype(cexpr(index), dtype)
+            idx = V.kernel.cse.generate(
+                V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr)
+            )
+            value = ops.to_dtype(idx, dtype)
             if isinstance(value, OpsValue):
                 value = value.value
             csevar = V.kernel.arange(value, stride)

torch/_inductor/codegen/triton.py

Lines changed: 12 additions & 8 deletions
@@ -58,6 +58,7 @@
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
 from ..utils import (
     cache_on_self,
+    get_bounds_index_expr,
     get_dtype_size,
     get_fused_kernel_name,
     get_kernel_metadata,
@@ -619,7 +620,7 @@ def relu(x):
         elif bug == "accuracy":
             return f"{x} + 1"
         elif bug is None:
-            return ops.maximum("0", x)
+            return ops.maximum(ops.constant(0, torch.int32), x)
         else:
             raise AssertionError(
                 f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}"
@@ -864,11 +865,9 @@ def floordiv(a, b):
 
     @staticmethod
     def sign(x):
-        def to_int(s):
-            return f"{s}.to(tl.int8)"
-
-        left = to_int(ops.lt("0", x))
-        right = to_int(ops.lt(x, "0"))
+        z = ops.constant(0, torch.int32)
+        left = ops.to_dtype(ops.lt(z, x), torch.int8)
+        right = ops.to_dtype(ops.lt(x, z), torch.int8)
         sub = ops.sub(left, right)
         return f"{sub}.to({x}.dtype)"
 
@@ -916,8 +915,9 @@ def constant(cls, value, dtype):
     def index_expr(cls, expr, dtype):
         indexing = V.kernel.indexing(expr, block_ptr=False)
         assert isinstance(indexing, IndexingOptions)
-        # This is called from CSEProxy.__getattr__, so we'll set the bounds there
-        var = V.kernel.cse.generate(V.kernel.compute, indexing.index_str)
+        var = V.kernel.cse.generate(
+            V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
+        )
 
         if dtype not in {torch.int32, torch.int64}:
             var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
@@ -929,10 +929,14 @@ def masked(mask, body, other):
         with V.kernel.mask_loads(mask) as new_mask:
             result = body()
 
+        # Remove once CSEVariables track the dtype
+        if result.bounds.is_bool:
+            other = bool(other)
         # Take dtype from result to prevent accidental promotion
         other = V.kernel.cse.generate(
             V.kernel.compute,
             f"tl.full({result}.shape, {triton_constant(other)}, {result}.dtype)",
+            bounds=ValueRanges.wrap(other),
         )
         return ops.where(new_mask, result, other)
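The sign rewrite keeps the usual branchless (0 < x) - (x < 0) formulation, but routes it through ops.constant / ops.lt / ops.to_dtype so each intermediate is a CSE variable with a boolean bound the analysis can propagate. The identity is easy to sanity-check in eager mode (plain torch here, not the generated Triton):

    import torch

    x = torch.tensor([-3.0, -0.5, 0.0, 2.5])
    left = (0 < x).to(torch.int8)   # 1 where x is positive
    right = (x < 0).to(torch.int8)  # 1 where x is negative
    assert torch.equal((left - right).to(x.dtype), torch.sign(x))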

torch/_inductor/config.py

Lines changed: 3 additions & 0 deletions
@@ -355,6 +355,9 @@ def is_fbcode():
 # assert that indirect indexing does not read / write out of bounds
 assert_indirect_indexing = True
 
+# compute CSE bounds on variables that do not appear in the FX graph
+compute_all_bounds = False
+
 # constant folding on the joint graph
 joint_graph_constant_folding = True
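The new flag defaults to False, so existing behavior is unchanged; with it set, CSEProxy._bound_variable and get_bounds_index_expr compute ranges for codegen-created variables too. One way to exercise it, mirroring the updated test above:

    import torch
    import torch._inductor.config as inductor_config

    inductor_config.compute_all_bounds = True  # the flag added in this diff

    def fn(n):
        return torch.randint(low=-5, high=5, size=(n,), dtype=torch.int64) % 10

    res = torch.compile(fn)(20)
    assert torch.all((0 <= res) & (res < 10)).item()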

torch/_inductor/utils.py

Lines changed: 15 additions & 0 deletions
@@ -49,6 +49,7 @@
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import make_symbol, SymT
+from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
 from . import config
 from .runtime.runtime_utils import ceildiv as runtime_ceildiv
 
@@ -539,6 +540,20 @@ def sympy_str(expr: sympy.Expr) -> str:
     return str(expr)
 
 
+def get_bounds_index_expr(index):
+    from .virtualized import V
+
+    # If this expression does not come from an FX node, we compute its bounds
+    if (
+        config.compute_all_bounds
+        and (fx_node := getattr(V.interpreter, "current_node", None))
+        and fx_node.target != "index_expr"
+    ):
+        return bound_sympy(index)
+    else:
+        return ValueRanges.unknown()
+
+
 def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
     """
     Used to generate an integer-nonnegative symbol.
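get_bounds_index_expr defers to bound_sympy, which evaluates a sympy expression with interval arithmetic over the ranges known for each free symbol. A standalone sketch with an explicitly supplied range (the symbol and its range here are assumptions for illustration):

    import sympy
    from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

    i = sympy.Symbol("i", integer=True, nonnegative=True)  # stand-in for a loop index
    print(bound_sympy(2 * i + 1, {i: ValueRanges(0, 127)}))  # the interval [1, 255]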

torch/utils/_sympy/value_ranges.py

Lines changed: 7 additions & 2 deletions
@@ -138,7 +138,7 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None:
             if not sympy_generic_le(lower, upper):
                 raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
         except TypeError:
-            raise TypeError(f"Could not compare {lower} <= {upper}")
+            raise TypeError(f"Could not compare {lower} <= {upper}")  # noqa: TRY200
         # Because this is a frozen class
         object.__setattr__(self, "lower", lower)
         object.__setattr__(self, "upper", upper)
@@ -340,6 +340,9 @@ class SymPyValueRangeAnalysis:
 
     @staticmethod
     def constant(value, dtype):
+        if isinstance(value, ValueRanges):
+            assert value.is_singleton()
+            value = value.lower
         # NB: value is NOT a sympy expression, it's a constant!
         is_python = isinstance(value, (int, float, bool))
         assert is_python or isinstance(
@@ -663,7 +666,9 @@ def where(a, b, c):
         b = ValueRanges.wrap(b)
         c = ValueRanges.wrap(c)
         a = a.boolify()
-        assert b.is_bool == c.is_bool
+        # We sometimes write unknown without specifying the type correctly.
+        # In particular, we do that when initialising the bounds for loads in bounds.py
+        assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c)
         if b.is_bool:
             return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
         else:
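For reference, the constant and where changes lean on ValueRanges helpers that already exist in this file; a short sketch of their behavior (not part of the diff):

    from torch.utils._sympy.value_ranges import ValueRanges

    five = ValueRanges.wrap(5)  # the singleton interval [5, 5]
    assert five.is_singleton() and five.lower == 5

    top = ValueRanges.unknown()  # [-oo, oo]
    # The relaxed assert in `where` accepts a bool/non-bool mismatch
    # exactly when one of the branches carries this unknown range.
    assert top == ValueRanges.unknown()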
