8000 Avoid overwriting COW data in MPS code · pytorch/pytorch@c1e0870 · GitHub

Commit c1e0870

Avoid overwriting COW data in MPS code
ghstack-source-id: 671272f
Pull Request resolved: #150721
1 parent a8604ea commit c1e0870
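
What changed, in brief: getMTLBufferStorage previously bit-cast storage().data() — the non-materializing copy-on-write (COW) accessor — into an id<MTLBuffer> that Metal kernels could write through, silently overwriting bytes still shared with other tensors. It now uses storage().mutable_data(), which materializes COW storage before any write can land, and a new ConstMTLBufferTensor wrapper lets genuinely read-only call sites keep the cheap non-materializing access.

For readers unfamiliar with PyTorch COW storage: torch._lazy_clone creates a tensor that shares its data buffer with the source until one of them is written to. Reading must never write through the shared buffer; only a write should trigger materialization. A minimal sketch of that invariant, using the same private hooks the new test below relies on (assumes an MPS-enabled build; these are not public APIs):

import torch

orig = torch.rand(4, device="mps")
snapshot = orig.detach().clone()   # eager copy of the original bytes

cow = torch._lazy_clone(orig)      # shares the buffer; both become COW
assert torch._C._is_cow_tensor(orig) and torch._C._is_cow_tensor(cow)

_ = cow * 2                        # read-only use of the COW input

# Materializing `cow` would be acceptable, but the bytes that `orig`
# sees must be unchanged. Lazy-clone before comparing so the comparison
# itself cannot materialize `orig`.
assert torch.allclose(torch._lazy_clone(orig), snapshot, rtol=0, atol=0)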

File tree

8 files changed, +257 -12 lines changed


aten/src/ATen/mps/MPSGeneratorImpl.mm

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ Generator createMPSGenerator(uint64_t seed_val) {
 
   auto state_tensor = at::detail::empty_cpu(
       {(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
-  auto rng_state = state_tensor.data_ptr<uint8_t>();
+  auto rng_state = state_tensor.mutable_data_ptr<uint8_t>();
   auto current_seed = this->current_seed();
   auto current_offset = this->get_offset();

aten/src/ATen/native/AutogradComposite.cpp

Lines changed: 3 additions & 0 deletions
@@ -171,6 +171,9 @@ Tensor _lazy_clone(Tensor const& self, std::optional<c10::Device> device_opt) {
     if (self.device().type() == c10::kMPS) {
       at::detail::getMPSHooks().deviceSynchronize();
     }
+  } else if (self.device().type() == c10::kMPS) {
+    // CHECK: Do we always need to sync for MPS?
+    at::detail::getMPSHooks().deviceSynchronize();
   }
   return Tensor(std::move(tensor));
 }
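
The new else-if branch makes a same-device lazy clone of an MPS tensor synchronize as well, so the clone cannot share a buffer whose contents an in-flight kernel is still producing (the in-code CHECK comment records that the necessity of always syncing is an open question). A rough sketch of the user-visible contract (assumes an MPS-enabled build):

import torch

t = torch.ones(1024, device="mps")
t += 1                     # MPS kernels run asynchronously
c = torch._lazy_clone(t)   # with this hunk, _lazy_clone waits for the
                           # device before sharing t's buffer
assert torch.equal(c.cpu(), torch.full((1024,), 2.0))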

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 48 additions & 2 deletions
@@ -108,9 +108,32 @@ MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = Mem
 MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
 
 static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
-  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
+  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().mutable_data());
 }
 
+// This class wraps a tensor with an API that can obtain the underlying
+// `id<MTLBuffer>` while preventing COW materialization and attempting to
+// prevent mutations to the data. Unfortunately, there is no way to make the
+// compiler actually prevent mutating the data in the MPS code because Metal
+// APIs operate on `id<MTLBuffer>`, which resolves to `struct objc_object *`, a
+// pointer to non-const data.
+class ConstMTLBufferTensor {
+ public:
+  ConstMTLBufferTensor(const TensorBase& tensor) : _tensor(tensor) {}
+
+  // WARNING: Do not write to the buffer returned by this function.
+  id<MTLBuffer> mtl_buffer_unsafe() const {
+    return __builtin_bit_cast(id<MTLBuffer>, _tensor.storage().data());
+  }
+
+  const TensorBase& tensor() const {
+    return _tensor;
+  }
+
+ private:
+  const TensorBase& _tensor;
+};
+
 class Placeholder {
  public:
   Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}

@@ -355,7 +378,7 @@ static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigne
   if (C10_UNLIKELY(t.device().type() == kCPU)) {
     if constexpr (std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
       TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
-      [encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx];
+      [encoder setBytes:t.storage().mutable_data() length:t.element_size() atIndex:idx];
     } else {
       TORCH_CHECK(false, "Passed CPU tensor to MPS op");
     }

@@ -364,6 +387,25 @@ static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigne
   [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
 }
 
+template <typename encoder_t,
+          typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
+                                      std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
+static inline void mtl_setBuffer(encoder_t encoder, ConstMTLBufferTensor b, unsigned idx) {
+  const TensorBase& t = b.tensor();
+  if (C10_UNLIKELY(t.device().type() == kCPU)) {
+    if constexpr (std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
+      TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
+      [encoder setBytes:b.mtl_buffer_unsafe() length:t.element_size() atIndex:idx];
+      // [encoder setBytes:getMTLBufferStorage(t) length:t.element_size() atIndex:idx];
+    } else {
+      TORCH_CHECK(false, "Passed CPU tensor to MPS op");
+    }
+    return;
+  }
+  [encoder setBuffer:b.mtl_buffer_unsafe() offset:t.storage_offset() * t.element_size() atIndex:idx];
+  // [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
+}
+
 // Implementation of setBytes for containers vs trivially copiable types must be separate
 // Containers like `std::array` could have been uploaded directly, but `c10::ArrayRef`,
 // while trivially copiable, includes padding which if copied as Metal shader parameters

@@ -395,6 +437,10 @@ inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> val,
   [encoder setBuffer:val offset:0 atIndex:idx];
 }
 
+inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, ConstMTLBufferTensor val, unsigned idx) {
+  mtl_setBuffer(encoder, val, idx);
+}
+
 template <>
 inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const Tensor& val, unsigned idx) {
   mtl_setBuffer(encoder, val, idx);
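
The intended division of labor after this change: writers keep calling getMTLBufferStorage, which now materializes COW storage via storage().mutable_data() before the GPU can touch it, while read-only inputs are wrapped in ConstMTLBufferTensor and keep sharing their buffer. The Python-visible effect might look like the following (a hedged illustration — whether a given op preserves COW-ness depends on the kernel path it takes; the new TestCOWInputs test below is the authoritative check):

import torch

src = torch.arange(6.0, device="mps")
cow = torch._lazy_clone(src)

# index_select only reads its input; with this commit the MPS index
# kernel binds it as ConstMTLBufferTensor(inputTensor) rather than the
# materializing getMTLBufferStorage(inputTensor).
out = cow.index_select(0, torch.tensor([0, 2], device="mps"))

assert torch._C._is_cow_tensor(cow)
assert torch._C._is_cow_tensor(src)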

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 2 additions & 1 deletion
@@ -388,7 +388,8 @@ void printTensorNDArray(const TensorBase& t) {
   auto selfDType = getMPSDataType(t.scalar_type());
 
   // Initialize data
-  id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
+  id<MTLBuffer> selfBuf = ConstMTLBufferTensor(t).mtl_buffer_unsafe();
+  // id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
   MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape
                                                                     dataType:selfDType] autorelease];
   C10_CLANG_DIAGNOSTIC_PUSH()

aten/src/ATen/native/mps/operations/Copy.mm

Lines changed: 5 additions & 3 deletions
@@ -112,7 +112,7 @@ static void copy_cast_mps(at::Tensor& dst,
   MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModeShared;
   NSUInteger alignedLength = 0;
 
-  const void* host_dst = static_cast<const char*>(dst.storage().data()) + dst.storage_offset() * dst.itemsize();
+  void* host_dst = static_cast<char*>(dst.storage().mutable_data()) + dst.storage_offset() * dst.itemsize();
   void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)dst_tensor_nbytes, &alignedLength);
   NSUInteger destOffset = (uintptr_t(host_dst) - uintptr_t(alignedPtr));
   // 4 bytes alignment required on macos for blits.

@@ -258,7 +258,8 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
     src = src_;
   }
   id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
-  id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
+  id<MTLBuffer> sourceBuffer = ConstMTLBufferTensor(src).mtl_buffer_unsafe();
+  /// id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
 
   // Scatter to `dst` if the memory is not contiguous
   // If the memory is not contiguous, it means that the tensor has strides and we would not be

@@ -295,7 +296,8 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
   } else if (dst_byte_offset) {
     auto maybeCastedSource =
         at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
-    auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource);
+    auto maybeCastedSourceBuffer = ConstMTLBufferTensor(maybeCastedSource).mtl_buffer_unsafe();
+    // auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource);
     copy_cast_mps(maybeCastedSource, src, maybeCastedSourceBuffer, sourceBuffer);
 
     uint64_t profile_id = getMPSProfiler().beginProfileCopy(
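
Note the asymmetry above: the destination buffer still comes from the materializing getMTLBufferStorage(dst_), while the source side switches to the non-materializing ConstMTLBufferTensor, matching the direction of data flow in a blit. A hedged Python-level consequence (which path a copy takes depends on dtype, contiguity, and offsets):

import torch

src = torch.rand(16, device="mps")
cow_src = torch._lazy_clone(src)
dst = torch.empty(16, device="mps")

dst.copy_(cow_src)   # the blit only reads the source buffer

assert torch._C._is_cow_tensor(cow_src)   # no materialization needed
assert torch._C._is_cow_tensor(src)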

aten/src/ATen/native/mps/operations/Indexing.mm

Lines changed: 9 additions & 2 deletions
@@ -162,8 +162,15 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,
     getMPSProfiler().beginProfileKernel(indexSelectPSO, indexFunction, {inputTensor});
 
     [computeEncoder setComputePipelineState:indexSelectPSO];
-    mtl_setArgs(
-        computeEncoder, indexAB, index_size, index_stride, kernelDataOffsets, inputTensor, outputTensor, num_indices);
+    mtl_setArgs(computeEncoder,
+                indexAB,
+                index_size,
+                index_stride,
+                kernelDataOffsets,
+                ConstMTLBufferTensor(inputTensor),
+                // inputTensor,
+                outputTensor,
+                num_indices);
     MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
     if (serial_index_put) {
       mtl_setBytes(computeEncoder, numIters, 7);

aten/src/ATen/native/mps/operations/View.mm

Lines changed: 9 additions & 2 deletions
@@ -146,7 +146,13 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
   }
 
   [computeEncoder setComputePipelineState:gatherPSO];
-  mtl_setArgs(computeEncoder, src, dst.has_storage() ? dst : output, src_sizes, src_strides, numThreads);
+  mtl_setArgs(computeEncoder,
+              ConstMTLBufferTensor(src),
+              // src,
+              dst.has_storage() ? dst : output,
+              src_sizes,
+              src_strides,
+              numThreads);
   if (src.dim() > 4) {
     mtl_setBytes<int32_t>(computeEncoder, src.dim(), 5);
   }

@@ -192,7 +198,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
   }
 
   [computeEncoder setComputePipelineState:scatterPSO];
-  mtl_setArgs(computeEncoder, src, output, output_sizes, output_strides, numThreads);
+  mtl_setArgs(computeEncoder, ConstMTLBufferTensor(src), output, output_sizes, output_strides, numThreads);
+  // mtl_setArgs(computeEncoder, src, output, output_sizes, output_strides, numThreads);
   if (output.dim() > 4) {
     mtl_setBytes<int32_t>(computeEncoder, output.dim(), 5);
   }

test/test_mps.py

Lines changed: 180 additions & 1 deletion
@@ -20,7 +20,7 @@
 from collections import defaultdict
 from torch import inf
 from torch.nn import Buffer, Parameter
-from torch.testing._internal import opinfo
+from torch.testing._internal import composite_compliance, opinfo
 from torch.testing._internal.common_utils import \
     (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, MACOS_VERSION, IS_CI,
      NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)

@@ -48,6 +48,7 @@
 import operator
 
 test_consistency_op_db = copy.deepcopy(op_db)
+test_cow_inputs_op_db = copy.deepcopy(op_db)
 test_error_inputs_op_db = copy.deepcopy(op_db)
 
 # Add bicubic2d_aa to test_consistency_op_db

@@ -12049,6 +12050,183 @@ def test_fmax_mixed_dtypes(self, device):
         self.assertEqual(op(x, y[0]), op(x.to("mps"), y.to("mps")[0]).cpu())
 
 
+class TestCOWInputs(TestCase):
+    # Tests that MPS ops do not mutate the underlying data of COW inputs.
+    # Materialization is allowed, but the original data buffer should never be
+    # written to.
+    # TODO: When we enable the `test_cow_input` test from `test_ops.py` for MPS,
+    # we can remove this test.
+    @ops(test_cow_inputs_op_db, allowed_dtypes=(torch.float,))
+    def test_cow_input_not_mutated(self, device, dtype, op):
+        samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
+
+        def is_strided_tensor(arg):
+            return torch.is_tensor(arg) and arg.layout == torch.strided
+
+        def check_cow_input(
+            arg_copy,
+            arg_raw,
+            idx_or_kw,
+            backward_or_forward="forward",
+        ):
+            arg_name = (
+                f"Argument {idx_or_kw}"
+                if isinstance(idx_or_kw, int)
+                else f"Keyword argument '{idx_or_kw}'"
+            ) + f" during {backward_or_forward} call"
+
+            if is_strided_tensor(arg_raw):
+                self.assertTrue(
+                    torch._C._is_cow_tensor(arg_raw),
+                    msg=(
+                        f"{arg_name} raw input should remain COW, but it "
+                        "unexpectedly materialized."
+                    ),
+                )
+                # TODO: Make `torch.allclose` avoid materializing. We have to
+                # lazy clone arg_raw here before the comparison to prevent it
+                # from materializing and messing up subsequent checks.
+                arg_lazy_cloned = torch._lazy_clone(arg_raw)
+                print('------------------------------')
+                print('original value:')
+                print(arg_copy)
+                print('value after op:')
+                print(arg_lazy_cloned)
+                print('------------------------------')
+                self.assertTrue(
+                    torch.allclose(
+                        arg_lazy_cloned, arg_copy, rtol=0, atol=0, equal_nan=True
+                    ),
+                    msg=(
+                        f"{arg_name} COW input data was mutated."
+                    ),
+                )
+
+        for sample in samples:
+            args_raw = [sample.input] + list(sample.args)
+            kwargs_raw = sample.kwargs
+
+            # Eagerly cloned inputs used to keep track of the original values of
+            # inputs
+            args_copy = []
+            kwargs_copy = {}
+
+            # The lazy cloned inputs to be passed to the op.
+            args_lazy_cloned = []
+            kwargs_lazy_cloned = {}
+
+            # In order to keep the original args/kwargs_raw COW in cases where
+            # the op materializes the input, we need to start with three sets of
+            # COW inputs.
+            args_lazy_cloned_2 = []
+            kwargs_lazy_cloned_2 = {}
+
+            leaf_tensors = composite_compliance.gather_leaf_tensors(args_raw, kwargs_raw)
+
+            # Convert strided tensor inputs to COW tensors and make copies of
+            # all inputs
+            for idx, arg in enumerate(args_raw):
+                if is_strided_tensor(arg):
+                    args_copy.append(arg.detach().clone())
+                    args_lazy_cloned.append(torch._lazy_clone(arg))
+                    args_lazy_cloned_2.append(torch._lazy_clone(arg))
+                else:
+                    if torch.is_tensor(arg):
+                        args_copy.append(arg.detach().clone())
+                    else:
+                        args_copy.append(copy.deepcopy(arg))
+                    args_lazy_cloned.append(arg)
+                    args_lazy_cloned_2.append(arg)
+
+            for kw, arg in kwargs_raw.items():
+                if is_strided_tensor(arg):
+                    kwargs_copy[kw] = arg.detach().clone()
+                    kwargs_lazy_cloned[kw] = torch._lazy_clone(arg)
+                    kwargs_lazy_cloned_2[kw] = torch._lazy_clone(arg)
+                else:
+                    if torch.is_tensor(arg):
+                        kwargs_copy[kw] = arg.detach().clone()
+                    else:
+                        kwargs_copy[kw] = copy.deepcopy(arg)
+                    kwargs_lazy_cloned[kw] = arg
+                    kwargs_lazy_cloned_2[kw] = arg
+
+            # Call forward op
+            try:
+                results_raw = op.get_op()(*args_lazy_cloned, **kwargs_lazy_cloned)
+            except NotImplementedError:
+                raise unittest.SkipTest("Op not implemented") from None
+
+            # Check that COW inputs remain COW after the forward op is executed
+            for idx, arg in enumerate(args_lazy_cloned):
+                check_cow_input(args_copy[idx], args_raw[idx], idx)
+
+            for kw, arg in kwargs_lazy_cloned.items():
+                check_cow_input(kwargs_copy[kw], kwargs_raw[kw], kw)
+
+            # Call backward op if it is supported. This part of the test is
+            # based on `composite_compliance.check_backward_formula`
+            if (
+                op.supports_autograd
+                and len(leaf_tensors) > 0
+                and not op.skip_cow_input_backward
+            ):
+                if sample.output_process_fn_grad is not None:
+                    results_raw = sample.output_process_fn_grad(results_raw)
+
+                leaf_results = pytree.tree_leaves(results_raw)
+                results = [
+                    r
+                    for r in leaf_results
+                    if isinstance(r, torch.Tensor) and r.requires_grad
+                ]
+
+                all_results_strided = all(
+                    is_strided_tensor(result) for result in results
+                )
+
+                # Only test backward if the results are strided tensors
+                if all_results_strided:
+                    output_grads_raw = [
+                        torch.ones(r.shape, device=r.device, dtype=r.dtype)
+                        for r in results
+                    ]
+                    output_grads_copy = []
+                    output_grads_lazy_cloned = []
+                    output_grads_lazy_cloned_2 = []
+
+                    # Convert output grads to COW tensors and make copies
+                    for output_grad in output_grads_raw:
+                        output_grads_copy.append(output_grad.detach().clone())
+                        output_grads_lazy_cloned.append(torch._lazy_clone(output_grad))
+                        output_grads_lazy_cloned_2.append(torch._lazy_clone(output_grad))
+
+                    torch.autograd.grad(
+                        results,
+                        leaf_tensors,
+                        output_grads_lazy_cloned,
+                        allow_unused=True,
+                        retain_graph=True,
+                    )
+
+                    # Check that COW inputs remain COW after the backward op is
+                    # executed
+                    for idx, arg in enumerate(args_lazy_cloned):
+                        check_cow_input(
+                            args_copy[idx],
+                            args_raw[idx],
+                            idx,
+                            backward_or_forward="backward",
+                        )
+
+                    # Check that COW output grads remain COW after the backward
+                    # op is executed
+                    for idx, output_grad in enumerate(output_grads_lazy_cloned):
+                        check_cow_input(
+                            output_grads_copy[idx],
+                            output_grads_raw[idx],
+                            f"output grad {idx}",
+                            backward_or_forward="backward",
+                        )
+
 
 class TestErrorInputs(TestCase):
     _ignore_not_implemented_error = True

@@ -12342,6 +12520,7 @@ def test_metal_capture(self):
 instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
+instantiate_device_type_tests(TestCOWInputs, globals(), allow_mps=True, only_for="mps")
 instantiate_parametrized_tests(TestLogical)
 instantiate_parametrized_tests(TestMPS)
 instantiate_parametrized_tests(TestSDPA)
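
Distilled, each sample in TestCOWInputs runs the pattern below (a simplified sketch; the real test also covers kwargs, non-tensor args, and the backward pass, and keeps extra lazy clones so the raw inputs stay COW even when the op materializes its own copy):

import torch

def check_no_cow_mutation(op, arg):
    snapshot = arg.detach().clone()   # eager copy of the original bytes
    cow_in = torch._lazy_clone(arg)   # COW handle passed to the op

    op(cow_in)                        # the op may materialize cow_in...

    # ...but the raw input must still be COW (its buffer was never
    # claimed for writing)...
    assert torch._C._is_cow_tensor(arg)
    # ...and its bytes must be unchanged. Lazy-clone before comparing so
    # allclose itself cannot materialize `arg`.
    assert torch.allclose(
        torch._lazy_clone(arg), snapshot, rtol=0, atol=0, equal_nan=True
    )

check_no_cow_mutation(torch.sin, torch.rand(8, device="mps"))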
