Add remaining method and tests for dtype propagation by eellison · Pull Request #140057 · pytorch/pytorch
Add remaining method and tests for dtype propagation #140057


Closed
wants to merge 14 commits into from
Changes from 1 commit
Add remaining method and tests
[ghstack-poisoned]
eellison committed Nov 7, 2024
commit e2263ba6ec909f6f2906bcb64a4ade15aa4a6868
83 changes: 83 additions & 0 deletions test/inductor/test_op_dtype_prop.py
@@ -0,0 +1,83 @@
# Owner(s): ["module: inductor"]
import importlib
import os
import sys

import torch
from torch._dynamo.utils import disable_cache_limit
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)


importlib.import_module("functorch")
importlib.import_module("filelock")


from torch._inductor.lowering import lowerings
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.inductor_utils import HAS_GPU


unique_pointwise_op_names = set()

for op in lowerings:
if not isinstance(op, torch._ops.OpOverload):
continue

if torch.Tag.pointwise not in op.tags:
continue

if op._schema.is_mutable:
continue

op_name = (op.name().split("::")[-1]).split(".")[0]
unique_pointwise_op_names.add(op_name)

pointwise_ops = [
op
for op in op_db
if op.name in unique_pointwise_op_names and "reduction" not in op.variant_test_name
]


class TestCase(InductorTestCase):
@ops(
pointwise_ops,
allowed_dtypes=(
torch.float32,
torch.float64,
torch.int32,
torch.int64,
torch.bool,
),
)
# @config.patch("triton.codegen_upcast_to_fp32", False) # TODO enable
@config.patch("test_configs.runtime_triton_dtype_assert", True)
@disable_cache_limit()
def test_op_dtype_propagation(self, op, dtype):
def run(op, args, kwargs):
return op(*args, **kwargs)

sample_inputs_itr = op.sample_inputs("cuda", dtype, requires_grad=False)
for sample_input in sample_inputs_itr:
args = (sample_input.input,) + sample_input.args
kwargs = sample_input.kwargs
out = run(op.get_op(), args, kwargs)
out_c = torch.compile(run)(op.get_op(), args, kwargs)
self.assertEqual(out, out_c)


instantiate_device_type_tests(TestCase, globals(), only_for=("cuda",))

if __name__ == "__main__":
from torch._inductor.test_case import run_tests

if HAS_GPU:
run_tests(needs="filelock")
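
Outside the OpInfo sweep above, the same flag can be exercised directly with torch.compile. The snippet below is a minimal sketch, not part of the PR, assuming a CUDA device and an arbitrary mixed-dtype pointwise function chosen for illustration.

import torch
from torch._inductor import config


def f(x, y):
    # int32 + int64 promotes to int64; multiplying by a Python float upcasts
    # again, so the compiled kernel carries intermediates of several dtypes.
    return (x + y) * 2.0


x = torch.arange(8, device="cuda", dtype=torch.int32)
y = torch.arange(8, device="cuda", dtype=torch.int64)

# With the flag enabled, Inductor emits a tl.static_assert on the dtype of each
# CSE'd intermediate, so a propagation mismatch fails when the kernel compiles.
with config.patch("test_configs.runtime_triton_dtype_assert", True):
    compiled = torch.compile(f)(x, y)

torch.testing.assert_close(compiled, f(x, y))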
18 changes: 16 additions & 2 deletions torch/_inductor/codegen/common.py
@@ -23,10 +23,10 @@
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from sympy.printing.printer import Printer
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
@@ -2010,12 +2010,13 @@ def inner(*args, **kwargs):
value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
dtype_handler = DtypePropagationOpsHandler()

idx = 0

def do_cse(v):
# TODO - throw on default
output_dtype = getattr(
dtype_handler,
name,
dtype_handler.default_handler,
)(*args)

csevar = V.kernel.cse.generate(
@@ -2024,7 +2025,20 @@ def do_cse(v):
bounds=bounds,
dtype=output_dtype,
)

nonlocal idx
if config.test_configs.runtime_triton_dtype_assert:
from torch._inductor.codegen.triton import triton_type

if isinstance(output_dtype, (list, tuple)):
output_dtype = output_dtype[idx]
V.kernel.compute.writeline(
f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
)
idx += 1

csevar.update_on_args(name, args, kwargs)

return csevar

return pytree.tree_map(do_cse, value)
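
The assert is written immediately after each CSE'd value is generated; the nonlocal idx counter indexes into output_dtype when a handler returns one dtype per output (a list or tuple), so multi-output ops get per-output checks. In a generated kernel the result looks roughly like the sketch below; tmp0, tmp2, in_ptr0, out_ptr0, x0 and xmask are hypothetical names, assuming a float32 pointwise kernel with the flag enabled, not real generated output.

tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = 2.0
tmp2 = tmp0 * tmp1
tl.static_assert(tmp2.dtype == tl.float32)  # emitted right after the mul is CSE'd
tl.store(out_ptr0 + (x0), tmp2, xmask)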
4 changes: 3 additions & 1 deletion torch/_inductor/codegen/triton.py
@@ -1087,7 +1087,9 @@ def sigmoid(x):
@staticmethod
def signbit(x):
# XX: This is wrong for the value -0.0 in floating point
return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"
return (
f"(libdevice.signbit({x}) != 0) if ({x}).dtype is tl.float32 else {x} < 0"
)

@staticmethod
def fmod(a, b):
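
The signbit tweak keeps the Triton expression's dtype consistent with what the propagation handler records: torch.signbit yields a boolean tensor, while the raw libdevice.signbit(...) call does not by itself produce a Triton boolean, hence the added != 0 comparison. A quick eager-mode check of the expected output dtype (illustrative, not from the PR):

import torch

x = torch.tensor([-1.5, -0.0, 0.0, 2.0])
print(torch.signbit(x).dtype)  # torch.bool
print(torch.signbit(x))        # tensor([ True,  True, False, False])

Note that signbit is True for -0.0, which is why the x < 0 fallback in the non-float32 branch is flagged as wrong in the comment above.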
2 changes: 2 additions & 0 deletions torch/_inductor/config.py
@@ -1314,6 +1314,8 @@ class test_configs:
class test_configs:
force_extern_kernel_in_multi_template = False

runtime_triton_dtype_assert = False


if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
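
The new flag lives under test_configs and defaults to False, so nothing changes unless a test or a user opts in, either via the config.patch decorator used in test_op_dtype_prop.py or a plain attribute toggle (sketch):

from torch._inductor import config

# Off by default; flip it module-wide for local debugging of dtype propagation.
config.test_configs.runtime_triton_dtype_assert = True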