20 | 20 | from collections import defaultdict
21 | 21 | from torch import inf
22 | 22 | from torch.nn import Buffer, Parameter
|
23 | | -from torch.testing._internal import opinfo |
| 23 | +from torch.testing._internal import composite_compliance, opinfo |
24 | 24 | from torch.testing._internal.common_utils import \
25 | 25 | (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, MACOS_VERSION, IS_CI,
26 | 26 | NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
48 | 48 | import operator
49 | 49 |
50 | 50 | test_consistency_op_db = copy.deepcopy(op_db)
| 51 | +test_cow_inputs_op_db = copy.deepcopy(op_db) |
51 | 52 | test_error_inputs_op_db = copy.deepcopy(op_db)
52 | 53 |
53 | 54 | # Add bicubic2d_aa to test_consistency_op_db
@@ -12049,6 +12050,183 @@ def test_fmax_mixed_dtypes(self, device):
12049 | 12050 | self.assertEqual(op(x, y[0]), op(x.to("mps"), y.to("mps")[0]).cpu())
12050 | 12051 |
12051 | 12052 |
| 12053 | +class TestCOWInputs(TestCase): |
| 12054 | + # Tests that MPS ops do not mutate the underlying data of COW inputs. |
| 12055 | + # Materialization is allowed, but the original data buffer should never be |
| 12056 | + # written to. |
| 12057 | + # TODO: When we enable the `test_cow_input` test from `test_ops.py` for MPS, |
| 12058 | + # we can remove this test. |
| 12059 | + @ops(test_cow_inputs_op_db, allowed_dtypes=(torch.float,)) |
| 12060 | + def test_cow_input_not_mutated(self, device, dtype, op): |
| 12061 | + samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd) |
| 12062 | + |
| 12063 | + def is_strided_tensor(arg): |
| 12064 | + return torch.is_tensor(arg) and arg.layout == torch.strided |
| 12065 | + |
| 12066 | + def check_cow_input( |
| 12067 | + arg_copy, |
| 12068 | + arg_raw, |
| 12069 | + idx_or_kw, |
| 12070 | + backward_or_forward="forward", |
| 12071 | + ): |
| 12072 | + arg_name = ( |
| 12073 | + f"Argument {idx_or_kw}" |
| 12074 | + if isinstance(idx_or_kw, int) |
| 12075 | + else f"Keyword argument '{idx_or_kw}'" |
| 12076 | + ) + f" during {backward_or_forward} call" |
| 12077 | + |
| 12078 | + if is_strided_tensor(arg_raw): |
| 12079 | + self.assertTrue( |
| 12080 | + torch._C._is_cow_tensor(arg_raw), |
| 12081 | + msg=( |
| 12082 | + f"{arg_name} raw input should remain COW, but it " |
| 12083 | + "unexpectedly materialized." |
| 12084 | + ), |
| 12085 | + ) |
| 12086 | + # TODO: Make `torch.allclose` avoid materializing. We have to |
| 12087 | + # lazy clone arg_raw here before the comparison to prevent it |
| 12088 | + # from materializing and messing up subsequent checks. |
| 12089 | + arg_lazy_cloned = torch._lazy_clone(arg_raw) |
| 12090 | + print('------------------------------') |
| 12091 | + print('original value:') |
| 12092 | + print(arg_copy) |
| 12093 | + print('value after op:') |
| 12094 | + print(arg_lazy_cloned) |
| 12095 | + print('------------------------------') |
| 12096 | + self.assertTrue( |
| 12097 | + torch.allclose( |
| 12098 | + arg_lazy_cloned, arg_copy, rtol=0, atol=0, equal_nan=True |
| 12099 | + ), |
| 12100 | + msg=( |
| 12101 | + f"{arg_name} COW input data was mutated." |
| 12102 | + ), |
| 12103 | + ) |
| 12104 | + |
| 12105 | + for sample in samples: |
| 12106 | + args_raw = [sample.input] + list(sample.args) |
| 12107 | + kwargs_raw = sample.kwargs |
| 12108 | + |
| 12109 | + # Eagerly cloned inputs used to keep track of the original values of |
| 12110 | + # inputs |
| 12111 | + args_copy = [] |
| 12112 | + kwargs_copy = {} |
| 12113 | + |
| 12114 | + # The lazy cloned inputs to be passed to the op. |
| 12115 | + args_lazy_cloned = [] |
| 12116 | + kwargs_lazy_cloned = {} |
| 12117 | + |
| 12118 | + # In order to keep the original args/kwargs_raw COW in cases where |
| 12119 | + # the op materializes the input, we need to start with three sets of |
| 12120 | + # COW inputs. |
| 12121 | + args_lazy_cloned_2 = [] |
| 12122 | + kwargs_lazy_cloned_2 = {} |
| 12123 | + |
| 12124 | + leaf_tensors = composite_compliance.gather_leaf_tensors(args_raw, kwargs_raw) |
| 12125 | + |
| 12126 | + # Convert strided tensor inputs to COW tensors and make copies of |
| 12127 | + # all inputs |
| 12128 | + for idx, arg in enumerate(args_raw): |
| 12129 | + if is_strided_tensor(arg): |
| 12130 | + args_copy.append(arg.detach().clone()) |
| 12131 | + args_lazy_cloned.append(torch._lazy_clone(arg)) |
| 12132 | + args_lazy_cloned_2.append(torch._lazy_clone(arg)) |
| 12133 | + else: |
| 12134 | + if torch.is_tensor(arg): |
| 12135 | + args_copy.append(arg.detach().clone()) |
| 12136 | + else: |
| 12137 | + args_copy.append(copy.deepcopy(arg)) |
| 12138 | + args_lazy_cloned.append(arg) |
| 12139 | + args_lazy_cloned_2.append(arg) |
| 12140 | + |
| 12141 | + for kw, arg in kwargs_raw.items(): |
| 12142 | + if is_strided_tensor(arg): |
| 12143 | + kwargs_copy[kw] = arg.detach().clone() |
| 12144 | + kwargs_lazy_cloned[kw] = torch._lazy_clone(arg) |
| 12145 | + kwargs_lazy_cloned_2[kw] = torch._lazy_clone(arg) |
| 12146 | + else: |
| 12147 | + if torch.is_tensor(arg): |
| 12148 | + kwargs_copy[kw] = arg.detach().clone() |
| 12149 | + else: |
| 12150 | + kwargs_copy[kw] = copy.deepcopy(arg) |
| 12151 | + kwargs_lazy_cloned[kw] = arg |
| 12152 | + kwargs_lazy_cloned_2[kw] = arg |
| 12153 | + |
| 12154 | + # Call forward op |
| 12155 | + try: |
| 12156 | + results_raw = op.get_op()(*args_lazy_cloned, **kwargs_lazy_cloned) |
| 12157 | + except NotImplementedError: |
| 12158 | + raise unittest.SkipTest("Op not implemented") from None |
| 12159 | + |
| 12160 | + # Check that COW inputs remain COW after the forward op is executed |
| 12161 | + for idx, arg in enumerate(args_lazy_cloned): |
| 12162 | + check_cow_input(args_copy[idx], args_raw[idx], idx) |
| 12163 | + |
| 12164 | + for kw, arg in kwargs_lazy_cloned.items(): |
| 12165 | + check_cow_input(kwargs_copy[kw], kwargs_raw[kw], kw) |
| 12166 | + |
| 12167 | + # Call backward op if it is supported. This part of the test is |
| 12168 | + # based on `composite_compliance.check_backward_formula` |
| 12169 | + if ( |
| 12170 | + op.supports_autograd |
| 12171 | + and len(leaf_tensors) > 0 |
| 12172 | + and not op.skip_cow_input_backward |
| 12173 | + ): |
| 12174 | + if sample.output_process_fn_grad is not None: |
| 12175 | + results_raw = sample.output_process_fn_grad(results_raw) |
| 12176 | + |
| 12177 | + leaf_results = pytree.tree_leaves(results_raw) |
| 12178 | + results = [ |
| 12179 | + r |
| 12180 | + for r in leaf_results |
| 12181 | + if isinstance(r, torch.Tensor) and r.requires_grad |
| 12182 | + ] |
| 12183 | + |
| 12184 | + all_results_strided = all( |
| 12185 | + is_strided_tensor(result) for result in results |
| 12186 | + ) |
| 12187 | + |
| 12188 | + # Only test backward if the results are strided tensors |
| 12189 | + if all_results_strided: |
| 12190 | + output_grads_raw = [ |
| 12191 | + torch.ones(r.shape, device=r.device, dtype=r.dtype) |
| 12192 | + for r in results |
| 12193 | + ] |
| 12194 | + output_grads_copy = [] |
| 12195 | + output_grads_lazy_cloned = [] |
| 12196 | + output_grads_lazy_cloned_2 = [] |
| 12197 | + |
| 12198 | + # Convert output grads to COW tensors and make copies |
| 12199 | + for output_grad in output_grads_raw: |
| 12200 | + output_grads_copy.append(output_grad.detach().clone()) |
| 12201 | + output_grads_lazy_cloned.append(torch._lazy_clone(output_grad)) |
| 12202 | + output_grads_lazy_cloned_2.append(torch._lazy_clone(output_grad)) |
| 12203 | + |
| 12204 | + torch.autograd.grad( |
| 12205 | + results, |
| 12206 | + leaf_tensors, |
| 12207 | + output_grads_lazy_cloned, |
| 12208 | + allow_unused=True, |
| 12209 | + retain_graph=True, |
| 12210 | + ) |
| 12211 | + |
| 12212 | + # Check that COW inputs remain COW after the backward op is executed |
| 12213 | + for idx, arg in enumerate(args_lazy_cloned): |
| 12214 | + check_cow_input( |
| 12215 | + args_copy[idx], |
| 12216 | + args_raw[idx], |
| 12217 | + idx, |
| 12218 | + backward_or_forward="backward", |
| 12219 | + ) |
| 12220 | + |
| 12221 | + # Check that COW output grads remain COW after the backward op is executed |
| 12222 | + for idx, output_grad in enumerate(output_grads_lazy_cloned): |
| 12223 | + check_cow_input( |
| 12224 | + output_grads_copy[idx], |
| 12225 | + output_grads_raw[idx], |
| 12226 | + f"output grad {idx}", |
| 12227 | + backward_or_forward="backward", |
| 12228 | + ) |
| 12229 | + |
12052 | 12230 |
12053 | 12231 | class TestErrorInputs(TestCase):
12054 | 12232 | _ignore_not_implemented_error = True
@@ -12342,6 +12520,7 @@ def test_metal_capture(self):
12342 | 12520 | instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
12343 | 12521 | instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
12344 | 12522 | instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
| 12523 | +instantiate_device_type_tests(TestCOWInputs, globals(), allow_mps=True, only_for="mps") |
12345 | 12524 | instantiate_parametrized_tests(TestLogical)
12346 | 12525 | instantiate_parametrized_tests(TestMPS)
12347 | 12526 | instantiate_parametrized_tests(TestSDPA)
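For readers unfamiliar with the copy-on-write (COW) machinery exercised above, the check that the new test applies to every sample input boils down to the following small sketch. It is only an illustration, reusing the same private APIs the test itself calls (torch._lazy_clone and torch._C._is_cow_tensor); torch.add is just a stand-in for the op under test.

    import torch

    x = torch.rand(3, 3)             # raw input
    x_copy = x.detach().clone()      # eager copy holding the expected values
    x_lazy = torch._lazy_clone(x)    # COW tensor that is actually passed to the op

    # The op may materialize x_lazy, but it must never write into the buffer
    # that x still shares.
    torch.add(x_lazy, 1.0)

    # The raw input should still be COW ...
    assert torch._C._is_cow_tensor(x)
    # ... and its data should be unchanged. Lazy-clone x before comparing so the
    # comparison itself cannot materialize it (the same trick check_cow_input uses).
    assert torch.equal(torch._lazy_clone(x), x_copy)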