@@ -1,5 +1,6 @@
 # Owner(s): ["module: inductor"]
 import copy
+import functools
 import itertools
 import os
 import unittest
@@ -32,6 +33,7 @@
 from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.library import register_fake
 from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89
 from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf
@@ -1655,6 +1657,88 @@ def my_func_static(x, w, epsilon):
         test, (code,) = run_and_get_code(my_func_static, *inputs)
         self.assertTrue("static_scaled_int8_quant" not in code)
 
+    def test_mutable_op_nonview_inputs_register_replacement(self):
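+        # Verifies that register_replacement can match a pattern built from
+        # mutable custom ops (with non-view inputs) and rewrite it into one fused op.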
+        @torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
+        def foo_inplace(x: torch.Tensor) -> None:
+            x.add_(1)
+
+        # NOTE: the custom op may only return None; it cannot return `out`, since `out` is one of its (mutated) inputs.
+        @torch.library.custom_op("mylib::bar", mutates_args={"out"})
+        def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            out.copy_(x + 2)
+
+        @register_fake("mylib::bar")
+        def bar_out_fake(x: torch.Tensor, out: torch.Tensor) -> None:
+            return None
+
+        @torch.library.custom_op("mylib::foobar_out", mutates_args={"out"})
+        def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            x.add_(1)
+            out.copy_(x + 7)  # intentionally different from bar_out
+
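+        # The pattern leaves out == x + 3 while the replacement leaves out == x + 8,
+        # so the checks at the end can tell which graph actually ran.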
+        def mutable_ops_pattern(x, out):
+            foo_inplace(x)
+            bar_out(x, out)
+            return out
+
+        def mutable_ops_replacement(x, out):
+            foobar_out(x, out)
+            return out
+
+        inp = torch.randn(3)
+
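+        # fwd_only with apply_auto_functionalize=True traces both functions with the
+        # mutable ops auto-functionalized, so they can be matched in the post-grad graph.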
+        my_patterns = PatternMatcherPass()
+        register_replacement(
+            search_fn=mutable_ops_pattern,
+            replace_fn=mutable_ops_replacement,
+            example_inputs=[inp.clone().detach(), inp.clone().detach()],
+            trace_fn=functools.partial(fwd_only, apply_auto_functionalize=True),
+            pass_dicts=my_patterns,
+        )
+
+        count = 0
+
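+        # PatternMatcherPass.apply returns the number of matches it replaced.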
+        def custom_pass(graph: torch.fx.Graph):
+            nonlocal count
+            count = my_patterns.apply(graph)
+
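+        # Compile with inductor, installing the pattern pass as a post-grad custom pass.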
+        def custom_backend(graph: torch.fx.GraphModule, example_inputs):
+            from torch._inductor import config
+            from torch._inductor.compile_fx import compile_fx
+
+            current_config = config.shallow_copy_dict()
+            current_config["post_grad_custom_post_pass"] = custom_pass
+            return compile_fx(graph, example_inputs, config_patches=current_config)
+
+        # User function, compiled with the custom backend above.
+        @torch.compile(fullgraph=True, backend=custom_backend)
+        def f(x):
+            x = x.clone()
+            out = torch.zeros_like(x)
+            foo_inplace(x)
+            bar_out(x, out)
+            return out
+
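+        # Eager reference that mirrors the replacement graph.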
+        def f_replaced(x):
+            x = x.clone()
+            out = torch.zeros_like(x)
+            foobar_out(x, out)
+            return out
+
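+        # Equal results mean the compiled f took the replacement path, and
+        # count confirms that exactly one match was rewritten.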
+        self.assertEqual(f(inp.clone().detach()), f_replaced(inp.clone().detach()))
+        self.assertEqual(count, 1)
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_GPU: