|
1 | 1 | # Owner(s): ["module: inductor"]
|
2 | 2 | import copy
|
| 3 | +import functools |
3 | 4 | import itertools
|
4 | 5 | import os
|
5 | 6 | import unittest
|
|
32 | 33 | from torch._inductor.utils import run_and_get_code
|
33 | 34 | from torch._inductor.virtualized import V
|
34 | 35 | from torch.fx.experimental.proxy_tensor import make_fx
|
| 36 | +from torch.library import register_fake |
35 | 37 | from torch.testing import FileCheck
|
36 | 38 | from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89
|
37 | 39 | from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf
|
@@ -1655,6 +1657,99 @@ def my_func_static(x, w, epsilon):
|
1655 | 1657 | test, (code,) = run_and_get_code(my_func_static, *inputs)
|
1656 | 1658 | self.assertTrue("static_scaled_int8_quant" not in code)
|
1657 | 1659 |
|
| 1660 | + def test_mutable_op_nonview_inputs_register_replacement(self): |
| 1661 | + @torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"}) |
| 1662 | + def foo_inplace(x: torch.Tensor) -> None: |
| 1663 | + x.add_(1) |
| 1664 | + |
| 1665 | + # NOTE: only returning None is supported; the custom op cannot return `out` because it's part of op input. |
| 1666 | + @torch.library.custom_op("mylib::bar", mutates_args={"out"}) |
| 1667 | + def bar_out(x: torch.Tensor, out: torch.Tensor) -> None: |
| 1668 | + out.copy_(x + 2) |
| 1669 | + |
| 1670 | + @torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"}) |
| 1671 | + def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None: |
| 1672 | + x.add_(1) |
| 1673 | + out.copy_(x + 7) # intentionally different from bar_out |
| 1674 | + |
| 1675 | + def mutable_ops_pattern(x, out): |
| 1676 | + foo_inplace(x) |
| 1677 | + bar_out(x, out) |
| 1678 | + return x, out |
| 1679 | + |
| 1680 | + def mutable_ops_replacement(x, out): |
| 1681 | + foobar_out(x, out) |
| 1682 | + return x, out |
| 1683 | + |
| 1684 | + inp = torch.randn(3) |
| 1685 | + |
| 1686 | + my_patterns = PatternMatcherPass() |
| 1687 | + register_replacement( |
| 1688 | + search_fn=mutable_ops_pattern, |
| 1689 | + replace_fn=mutable_ops_replacement, |
| 1690 | + example_inputs=[inp.clone().detach(), inp.clone().detach()], |
| 1691 | + trace_fn=functools.partial(fwd_only, apply_auto_functionalize=True), |
| 1692 | + pass_dicts=my_patterns, |
| 1693 | + ) |
| 1694 | + |
| 1695 | + count = 0 |
| 1696 | + |
| 1697 | + def custom_pass(graph: torch.fx.Graph): |
| 1698 | + nonlocal count |
| 1699 | + count = my_patterns.apply(graph) |
| 1700 | + |
| 1701 | + def custom_backend(graph: torch.fx.GraphModule, example_inputs): |
| 1702 | + from torch._inductor import config |
| 1703 | + |
| 1704 | + current_config = config.shallow_copy_dict() |
| 1705 | + from torch._inductor.compile_fx import compile_fx |
| 1706 | + |
| 1707 | + current_config["post_grad_custom_post_pass"] = custom_pass |
| 1708 | + return compile_fx(graph, example_inputs, config_patches=current_config) |
| 1709 | + |
| 1710 | + # Case 1: mutates a clone of graph input |
| 1711 | + @torch.compile(fullgraph=True, backend=custom_backend) |
| 1712 | + def f1(x): |
| 1713 | + x = x.clone() |
| 1714 | + out = torch.zeros_like(x) |
| 1715 | + foo_inplace(x) |
| 1716 | + bar_out(x, out) |
| 1717 | + return out |
| 1718 | + |
| 1719 | + def f1_replaced(x): |
| 1720 | + x = x.clone() |
| 1721 | + out = torch.zeros_like(x) |
| 1722 | + foobar_out(x, out) |
| 1723 | + return out |
| 1724 | + |
| 1725 | + f1_inp = inp.clone().detach() |
| 1726 | + f1_replaced_inp = inp.clone().detach() |
| 1727 | + f1_out = f1(f1_inp) |
| 1728 | + f1_replaced_out = f1_replaced(f1_replaced_inp) |
| 1729 | + self.assertEqual(f1_inp, f1_replaced_inp) |
| 1730 | + self.assertEqual(f1_out, f1_replaced_out) |
| 1731 | + self.assertEqual(count, 1) |
| 1732 | + |
| 1733 | + # Case 2: mutates graph input (not supported yet) |
| 1734 | + @torch.compile(fullgraph=True, backend=custom_backend) |
| 1735 | + def f2(x): |
| 1736 | + out = torch.zeros_like(x) |
| 1737 | + foo_inplace(x) |
| 1738 | + bar_out(x, out) |
| 1739 | + return out |
| 1740 | + |
| 1741 | + def f2_replaced(x): |
| 1742 | + out = torch.zeros_like(x) |
| 1743 | + foobar_out(x, out) |
| 1744 | + return out |
| 1745 | + |
| 1746 | + f2_inp = inp.clone().detach() |
| 1747 | + f2_replaced_inp = inp.clone().detach() |
| 1748 | + f2_out = f2(f2_inp) |
| 1749 | + f2_replaced_out = f2_replaced(f2_replaced_inp) |
| 1750 | + self.assertEqual(f1_inp, f1_replaced_inp) |
| 1751 | + self.assertEqual(f2_out, f2_replaced_out) |
| 1752 | + self.assertEqual(count, 1) |
1658 | 1753 |
|
1659 | 1754 | if __name__ == "__main__":
|
1660 | 1755 | if IS_LINUX and HAS_GPU:
|
|
0 commit comments