Avoid overwriting COW data in MPS code · pytorch/pytorch@12c6960 · GitHub

Commit 12c6960

Avoid overwriting COW data in MPS code
ghstack-source-id: a1708b8
Pull Request resolved: #150721
1 parent 01930db commit 12c6960
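
The hazard this commit guards against, as a minimal sketch (illustrative only; at::_lazy_clone and getMTLBufferStorage appear in the diff below, but the scenario itself is not from the commit):

#include <ATen/ATen.h>
#include <ATen/native/mps/OperationUtils.h> // internal header; provides getMTLBufferStorage

void cow_overwrite_hazard_sketch() {
  at::Tensor a = at::randn({16}, at::kMPS);
  at::Tensor b = at::_lazy_clone(a); // a and b now share one COW storage
  // getMTLBufferStorage reads storage().data() and does NOT materialize the
  // COW copy, so a Metal kernel writing through this buffer would silently
  // rewrite a as well -- the overwrite this commit avoids.
  id<MTLBuffer> buf = at::native::mps::getMTLBufferStorage(b);
  (void)buf;
}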

File tree: 10 files changed, +302 −10 lines

aten/src/ATen/mps/MPSAllocator.mm
Lines changed: 11 additions & 0 deletions

@@ -853,7 +853,18 @@ void copy_data(void* dest, const void* src, std::size_t count) const final {
   } else if (isSharedBufferCPUPtr(dest)) {
     TORCH_INTERNAL_ASSERT(isSharedBufferCPUPtr(src));
   }
+  // CHECK: Do we need to sync here?
+  auto stream = getDefaultMPSStream();
+  dispatch_sync(stream->queue(), ^() {
+    stream->synchronize(SyncType::COMMIT_AND_WAIT);
+  });
+
   default_copy_data(dest, src, count);
+
+  // CHECK: Do we need to sync here?
+  dispatch_sync(stream->queue(), ^() {
+    stream->synchronize(SyncType::COMMIT_AND_WAIT);
+  });
 }

 void* get_cpu_ptr_from_device_ptr(void* device_ptr) const override {
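
For context, a hedged sketch of how this copy_data path is reached (the accessors are existing ATen API; the exact internal call chain is my reading of the c10 COW machinery):

#include <ATen/ATen.h>

void materialize_sketch() {
  at::Tensor a = at::randn({4}, at::kMPS);
  at::Tensor b = at::_lazy_clone(a); // a and b share a COW storage
  // The first mutable access materializes b: fresh MPS memory is allocated
  // and the allocator's copy_data runs -- now bracketed by the stream syncs
  // above so in-flight GPU work on the shared buffer finishes first.
  float* p = b.mutable_data_ptr<float>();
  (void)p;
}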

aten/src/ATen/mps/MPSGeneratorImpl.mm
Lines changed: 1 addition & 1 deletion

@@ -70,7 +70,7 @@ Generator createMPSGenerator(uint64_t seed_val) {

   auto state_tensor = at::detail::empty_cpu(
       {(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
-  auto rng_state = state_tensor.data_ptr<uint8_t>();
+  auto rng_state = state_tensor.mutable_data_ptr<uint8_t>();
   auto current_seed = this->current_seed();
   auto current_offset = this->get_offset();
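
The one-line change above follows the general COW discipline: reads go through the const accessor, writes through the mutable one, which materializes a private copy first. A minimal sketch of the distinction (both accessors are existing ATen API; the function is illustrative only):

#include <ATen/ATen.h>

void accessor_sketch(at::Tensor& t) {
  const float* r = t.const_data_ptr<float>(); // read-only: never materializes
  float* w = t.mutable_data_ptr<float>();     // write: materializes COW first
  (void)r;
  (void)w;
}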

aten/src/ATen/native/AutogradComposite.cpp
Lines changed: 3 additions & 0 deletions

@@ -171,6 +171,9 @@ Tensor _lazy_clone(Tensor const& self, std::optional<c10::Device> device_opt) {
     if (self.device().type() == c10::kMPS) {
       at::detail::getMPSHooks().deviceSynchronize();
     }
+  } else if (self.device().type() == c10::kMPS) {
+    // CHECK: Do we always need to sync for MPS?
+    at::detail::getMPSHooks().deviceSynchronize();
   }
   return Tensor(std::move(tensor));
 }
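
In user-visible terms, the hunk means lazily cloning an MPS tensor may now synchronize the device before the storages are aliased. A hedged sketch (the single-argument at::_lazy_clone overload is existing API; the device-taking signature comes from the hunk header):

#include <ATen/ATen.h>

void lazy_clone_sync_sketch() {
  at::Tensor a = at::ones({8}, at::kMPS);
  // Kernels may still be writing `a` here; the branch added above calls
  // getMPSHooks().deviceSynchronize() so those writes land before the
  // storage is aliased.
  at::Tensor b = at::_lazy_clone(a); // afterwards both a and b are COW
  (void)b;
}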

aten/src/ATen/native/mps/OperationUtils.h
Lines changed: 48 additions & 0 deletions

@@ -108,9 +108,33 @@ MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = Mem
 MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);

 static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
+  // return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().mutable_data());
   return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
 }

+// This class wraps a tensor with an API that can obtain the underlying
+// `id<MTLBuffer>` while preventing COW materialization and attempting to
+// prevent mutations to the data. Unfortunately, there is no way to make the
+// compiler actually prevent mutating the data in the MPS code because Metal
+// APIs operate on `id<MTLBuffer>`, which resolves to `struct objc_object *`, a
+// pointer to non-const data.
+class ConstMTLBufferTensor {
+ public:
+  ConstMTLBufferTensor(const TensorBase& tensor) : _tensor(tensor) {}
+
+  // WARNING: Do not write to the buffer returned by this function.
+  id<MTLBuffer> mtl_buffer_unsafe() const {
+    return __builtin_bit_cast(id<MTLBuffer>, _tensor.storage().data());
+  }
+
+  const TensorBase& tensor() const {
+    return _tensor;
+  }
+
+ private:
+  const TensorBase& _tensor;
+};
+
 class Placeholder {
  public:
   Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}

@@ -355,6 +379,7 @@ static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigne
   if (C10_UNLIKELY(t.device().type() == kCPU)) {
     if constexpr (std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
       TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
+      // [encoder setBytes:t.storage().mutable_data() length:t.element_size() atIndex:idx];
       [encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx];
     } else {
       TORCH_CHECK(false, "Passed CPU tensor to MPS op");

@@ -364,6 +389,25 @@ static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigne
   [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
 }

+template <typename encoder_t,
+          typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
+                                      std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
+static inline void mtl_setBuffer(encoder_t encoder, ConstMTLBufferTensor b, unsigned idx) {
+  const TensorBase& t = b.tensor();
+  if (C10_UNLIKELY(t.device().type() == kCPU)) {
+    if constexpr (std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
+      TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
+      [encoder setBytes:b.mtl_buffer_unsafe() length:t.element_size() atIndex:idx];
+      // [encoder setBytes:getMTLBufferStorage(t) length:t.element_size() atIndex:idx];
+    } else {
+      TORCH_CHECK(false, "Passed CPU tensor to MPS op");
+    }
+    return;
+  }
+  [encoder setBuffer:b.mtl_buffer_unsafe() offset:t.storage_offset() * t.element_size() atIndex:idx];
+  // [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
+}
+
 // Implementation of setBytes for containers vs trivially copiable types must be separate
 // Containers like `std::array` could have been uploaded directly, but `c10::ArrayRef`,
 // while trivially copiable, includes padding which if copied as Metal shader parameters

@@ -395,6 +439,10 @@ inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> val,
   [encoder setBuffer:val offset:0 atIndex:idx];
 }

+inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, ConstMTLBufferTensor val, unsigned idx) {
+  mtl_setBuffer(encoder, val, idx);
+}
+
 template <>
 inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const Tensor& val, unsigned idx) {
   mtl_setBuffer(encoder, val, idx);
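
The intended call pattern, mirrored in the Indexing.mm and View.mm hunks below: wrap read-only inputs in ConstMTLBufferTensor so their buffer is bound without COW materialization, and pass writable outputs as plain tensors. A minimal sketch (the helper function is hypothetical; mtl_setArgs and the wrapper come from this header):

#include <ATen/native/mps/OperationUtils.h>

// Hypothetical helper: bind one read-only input and one writable output.
static void bindKernelArgs(id<MTLComputeCommandEncoder> encoder,
                           const at::Tensor& input,
                           const at::Tensor& output) {
  using at::native::mps::ConstMTLBufferTensor;
  // input: bound without COW materialization; output: plain overload, writable.
  at::native::mps::mtl_setArgs(encoder, ConstMTLBufferTensor(input), output);
}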

aten/src/ATen/native/mps/OperationUtils.mm
Lines changed: 2 additions & 1 deletion

@@ -387,7 +387,8 @@ void printTensorNDArray(const TensorBase& t) {
   auto selfDType = getMPSDataType(t.scalar_type());

   // Initialize data
-  id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
+  id<MTLBuffer> selfBuf = ConstMTLBufferTensor(t).mtl_buffer_unsafe();
+  // id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
   MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape
                                                                     dataType:selfDType] autorelease];
 C10_CLANG_DIAGNOSTIC_PUSH()

aten/src/ATen/native/mps/operations/Copy.mm
Lines changed: 6 additions & 3 deletions

@@ -112,7 +112,8 @@ static void copy_cast_mps(at::Tensor& dst,
   MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModeShared;
   NSUInteger alignedLength = 0;

-  const void* host_dst = static_cast<const char*>(dst.storage().data()) + dst.storage_offset() * dst.itemsize();
+  // void* host_dst = static_cast<char*>(dst.storage().mutable_data()) + dst.storage_offset() * dst.itemsize();
+  void* host_dst = static_cast<char*>(dst.storage().data()) + dst.storage_offset() * dst.itemsize();
   void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)dst_tensor_nbytes, &alignedLength);
   NSUInteger destOffset = (uintptr_t(host_dst) - uintptr_t(alignedPtr));
   // 4 bytes alignment required on macos for blits.

@@ -258,7 +259,8 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
     src = src_;
   }
   id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
-  id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
+  id<MTLBuffer> sourceBuffer = ConstMTLBufferTensor(src).mtl_buffer_unsafe();
+  /// id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);

   // Scatter to `dst` if the memory is not contiguous
   // If the memory is not contiguous, it means that the tensor has strides and we would not be

@@ -295,7 +297,8 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
   } else if (dst_byte_offset) {
     auto maybeCastedSource =
         at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
-    auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource);
+    auto maybeCastedSourceBuffer = ConstMTLBufferTensor(maybeCastedSource).mtl_buffer_unsafe();
+    // auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource);
     copy_cast_mps(maybeCastedSource, src, maybeCastedSourceBuffer, sourceBuffer);

     uint64_t profile_id = getMPSProfiler().beginProfileCopy(

aten/src/ATen/native/mps/operations/Indexing.mm
Lines changed: 9 additions & 2 deletions

@@ -162,8 +162,15 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,
       getMPSProfiler().beginProfileKernel(indexSelectPSO, indexFunction, {inputTensor});

       [computeEncoder setComputePipelineState:indexSelectPSO];
-      mtl_setArgs(
-          computeEncoder, indexAB, index_size, index_stride, kernelDataOffsets, inputTensor, outputTensor, num_indices);
+      mtl_setArgs(computeEncoder,
+                  indexAB,
+                  index_size,
+                  index_stride,
+                  kernelDataOffsets,
+                  ConstMTLBufferTensor(inputTensor),
+                  // inputTensor,
+                  outputTensor,
+                  num_indices);
       MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
       if (serial_index_put) {
         mtl_setBytes(computeEncoder, numIters, 7);

aten/src/ATen/native/mps/operations/View.mm
Lines changed: 9 additions & 2 deletions

@@ -146,7 +146,13 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
   }

   [computeEncoder setComputePipelineState:gatherPSO];
-  mtl_setArgs(computeEncoder, src, dst.has_storage() ? dst : output, src_sizes, src_strides, numThreads);
+  mtl_setArgs(computeEncoder,
+              ConstMTLBufferTensor(src),
+              // src,
+              dst.has_storage() ? dst : output,
+              src_sizes,
+              src_strides,
+              numThreads);
   if (src.dim() > 4) {
     mtl_setBytes<int32_t>(computeEncoder, src.dim(), 5);
   }

@@ -192,7 +198,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
   }

   [computeEncoder setComputePipelineState:scatterPSO];
-  mtl_setArgs(computeEncoder, src, output, output_sizes, output_strides, numThreads);
+  mtl_setArgs(computeEncoder, ConstMTLBufferTensor(src), output, output_sizes, output_strides, numThreads);
+  // mtl_setArgs(computeEncoder, src, output, output_sizes, output_strides, numThreads);
   if (output.dim() > 4) {
     mtl_setBytes<int32_t>(computeEncoder, output.dim(), 5);
   }

test/test_lazy_clone.py
Lines changed: 31 additions & 0 deletions

@@ -208,6 +208,37 @@ def test_interdevice_read(self, device, case):
         self.assertTrue(torch._C._is_cow_tensor(b))
         self.assertEqual(torch._C._data_address_resolve_unified(b), orig_data_ptr)

+    def test_clone_after_lazy_clone(self, device):
+        a = torch.randn(10, device=device)
+        orig_data_ptr = torch._C._data_address_resolve_unified(a)
+        b = torch._lazy_clone(a)
+
+        self.assertTrue(torch._C._is_cow_tensor(a))
+        self.assertTrue(torch._C._is_cow_tensor(b))
+        self.assertEqual(torch._C._data_address_resolve_unified(a), orig_data_ptr)
+        self.assertEqual(torch._C._data_address_resolve_unified(b), orig_data_ptr)
+
+        c = b.clone()
+
+        self.assertTrue(torch._C._is_cow_tensor(a))
+        self.assertTrue(torch._C._is_cow_tensor(b))
+        self.assertFalse(torch._C._is_cow_tensor(c))
+        self.assertEqual(torch._C._data_address_resolve_unified(a), orig_data_ptr)
+        self.assertEqual(torch._C._data_address_resolve_unified(b), orig_data_ptr)
+
+        self.assertEqual(b.clone(), c)
+        self.assertEqual(a.clone(), c)
+
+        self.assertTrue(torch._C._is_cow_tensor(a))
+        self.assertTrue(torch._C._is_cow_tensor(b))
+        self.assertFalse(torch._C._is_cow_tensor(c))
+        self.assertEqual(torch._C._data_address_resolve_unified(a), orig_data_ptr)
+        self.assertEqual(torch._C._data_address_resolve_unified(b), orig_data_ptr)
+
+        self.assertEqual(a, b)
+        self.assertEqual(a, c)
+        self.assertEqual(b, c)
+

 instantiate_device_type_tests(TestLazyCloneDeviceType, globals(), allow_mps=True)
