8000 [WIP] Add `device` arg to `_lazy_clone` · pytorch/pytorch@48dbbca · GitHub
[go: up one dir, main page]

Skip to content

Commit 48dbbca

Browse files
[WIP] Add device arg to _lazy_clone
ghstack-source-id: b84d27a Pull Request resolved: #148408
1 parent e51615c commit 48dbbca

29 files changed

+430
-59
lines changed

aten/src/ATen/EmptyTensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ struct MetaAllocator final : public at::Allocator {
348348
DeleterFnPtr raw_deleter() const override {
349349
return deleter;
350350
}
351-
void copy_data(void* dest, const void* src, std::size_t count) const final {}
351+
void copy_data(void* dest, const void* src, std::size_t count, bool sync=false) const final {}
352352
};
353353

354354
static MetaAllocator g_meta_alloc;

aten/src/ATen/core/CachingHostAllocator.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ struct CachingHostAllocatorImpl {
339339
return false;
340340
}
341341

342-
virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const {
342+
virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]], bool sync [[maybe_unused]] = false) const {
343343
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data");
344344
}
345345

@@ -641,9 +641,9 @@ struct CachingHostAllocatorInterface : public at::Allocator {
641641
impl_->empty_cache();
642642
}
643643

644-
void copy_data(void* dest, const void* src, std::size_t count)
644+
void copy_data(void* dest, const void* src, std::size_t count, bool sync=false)
645645
const override {
646-
impl_->copy_data(dest, src, count);
646+
impl_->copy_data(dest, src, count, sync);
647647
}
648648

649649
HostStats getStats() {

aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
2323
DeleterFnPtr raw_deleter() const override {
2424
return allocator_->raw_deleter();
2525
}
26-
void copy_data(void* dest, const void* src, std::size_t count) const final {
27-
allocator_->copy_data(dest, src, count);
26+
void copy_data(void* dest, const void* src, std::size_t count, bool sync=false) const final {
27+
allocator_->copy_data(dest, src, count, sync);
2828
}
2929
};
3030

aten/src/ATen/mps/MPSAllocator.mm

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <ATen/mps/MPSAllocator.h>
66
#include <c10/core/Allocator.h>
77
#include <c10/core/Storage.h>
8+
#include <ATen/detail/MPSHooksInterface.h>
89

910
#include <iostream>
1011

@@ -820,8 +821,11 @@ bool waitForEvents(c10::ArrayRef<const void*> buffers) const override {
820821
return _getAllocImpl().format_size(size);
821822
}
822823

823-
void copy_data(void* dest, const void* src, std::size_t count) const final {
824+
void copy_data(void* dest, const void* src, std::size_t count, bool sync = false) const final {
824825
default_copy_data(dest, src, count);
826+
if (sync) {
827+
at::detail::getMPSHooks().deviceSynchronize();
828+
}
825829
}
826830

827831
private:

aten/src/ATen/native/AutogradComposite.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <ATen/core/Tensor.h>
33
#include <c10/util/SmallBuffer.h>
44
#include <c10/core/impl/COW.h>
5+
#include <c10/core/DispatchKey.h>
56

67
#ifndef AT_PER_OPERATOR_HEADERS
78
#include <ATen/Functions.h>
@@ -13,6 +14,7 @@
1314
#include <ATen/ops/_unpack_dual_native.h>
1415
#include <ATen/ops/_lazy_clone_native.h>
1516
#include <ATen/ops/alias.h>
17+
#include <ATen/ops/empty.h>
1618
#include <ATen/ops/zeros.h>
1719
#endif
1820

@@ -91,14 +93,25 @@ bool _has_same_storage_numel(const at::Tensor& base, const at::Tensor& other) {
9193
return base.storage().sym_nbytes() / base.itemsize() == other.storage().sym_nbytes() / other.itemsize();
9294
}
9395

94-
Tensor _lazy_clone(Tensor const& self) {
96+
Tensor _lazy_clone(Tensor const& self, optional<c10::Device> device_opt) {
97+
optional<c10::Allocator*> allocator_opt = nullopt;
98+
if (device_opt.has_value()) {
99+
allocator_opt = at::empty({}, at::TensorOptions().device(device_opt.value())).storage().allocator();
100+
}
95101
c10::StorageImpl* self_storage = self.storage().unsafeGetStorageImpl();
96102
c10::intrusive_ptr<c10::StorageImpl> storage =
97-
c10::impl::cow::lazy_clone_storage(*self_storage);
103+
c10::impl::cow::lazy_clone_storage(*self_storage, device_opt, allocator_opt);
98104
TORCH_CHECK(storage != nullptr);
105+
c10::DispatchKeySet key_set = self.key_set();
106+
// If the target device differs, then we must change the key set
107+
if (device_opt.has_value() && device_opt.value().type() != self.device().type()) {
108+
c10::BackendComponent old_backend = c10::toBackendComponent(self.device().type());
109+
c10::BackendComponent new_backend = c10::toBackendComponent(device_opt.value().type());
110+
key_set = key_set.remove_backend(old_backend) | c10::DispatchKeySet(new_backend);
111+
}
99112
auto tensor = c10::make_intrusive<c10::TensorImpl>(
100113
c10::Storage(std::move(storage)),
101-
self.key_set(),
114+
key_set,
102115
self.dtype());
103116
tensor->set_sizes_and_strides(self.sym_sizes(),
104117
self.sym_strides(),

aten/src/ATen/native/TensorConversions.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
1818
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
1919
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
20+
#include <ATen/ops/_lazy_clone.h>
2021
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
2122
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
2223
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
@@ -422,6 +423,26 @@ bool to_will_alias(
422423
self.suggest_memory_format() == memory_format);
423424
}
424425

426+
// static bool _only_device_differs(
427+
// const Tensor& self,
428+
// std::optional<ScalarType> dtype,
429+
// std::optional<Layout> layout,
430+
// std::optional<Device> device,
431+
// std::optional<bool> pin_memory,
432+
// std::optional<c10::MemoryFormat> optional_memory_format) {
433+
// bool device_differs = device.has_value() && device.value() !=
434+
// self.device(); bool dtype_differs = dtype.has_value() && dtype.value() !=
435+
// self.scalar_type(); bool layout_differs = layout.has_value() &&
436+
// layout.value() != self.layout(); bool pin_memory_differs =
437+
// pin_memory.has_value() && pin_memory.value() != self.is_pinned();
438+
// auto memory_format =
439+
// optional_memory_format.value_or(MemoryFormat::Preserve); bool
440+
// memory_format_differs = memory_format != MemoryFormat::Preserve &&
441+
// memory_format != self.suggest_memory_format();
442+
// return device_differs && !dtype_differs && !layout_differs &&
443+
// !pin_memory_differs && !memory_format_differs;
444+
// }
445+
425446
static inline Tensor to_impl(
426447
const Tensor& self,
427448
std::optional<ScalarType> dtype,
@@ -436,6 +457,12 @@ static inline Tensor to_impl(
436457
self, dtype, layout, device, copy, optional_memory_format)) {
437458
return self;
438459
}
460+
// TODO: after I prove that this works, I should only allow it for CPU-MPS,
461+
// and we can enable others later if needed.
462+
// if (_only_device_differs(self, dtype, layout, device, pin_memory,
463+
// optional_memory_format)) {
464+
// return at::_lazy_clone(self, device);
465+
//}
439466
return at::_to_copy(
440467
self,
441468
dtype,

aten/src/ATen/native/TensorFactories.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ struct ZeroTensorAllocator final : public at::Allocator {
157157
void copy_data(
158158
void* dest [[maybe_unused]],
159159
const void* src [[maybe_unused]],
160-
std::size_t count [[maybe_unused]]) const final {}
160+
std::size_t count [[maybe_unused]],
161+
bool sync [[maybe_unused]] = false) const final {}
161162
at::Device device_;
162163
};
163164

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,9 +1250,10 @@
12501250
CompositeExplicitAutograd: copysign_out
12511251
tags: pointwise
12521252

1253-
- func: _lazy_clone(Tensor self) -> Tensor
1253+
- func: _lazy_clone(Tensor self, *, Device? device=None) -> Tensor
12541254
# Like clone, but the copy takes place lazily, only if either the
1255-
# input or the output are written.
1255+
# input or the output are written. If `device` is given, the output
1256+
# will be copied to the specified device when the write occurs.
12561257
variants: function, method
12571258
dispatch:
12581259
CompositeExplicitAutograd: _lazy_clone

aten/src/ATen/test/xla_tensor_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct XLAAllocator final : public at::Allocator {
2424
at::DeleterFnPtr raw_deleter() const override {
2525
return &XLAFree;
2626
}
27-
void copy_data(void* dest, const void* src, std::size_t count) const final {
27+
void copy_data(void* dest, const void* src, std::size_t count, bool sync=false) const final {
2828
default_copy_data(dest, src, count);
2929
}
3030
};

c10/core/Allocator.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
namespace c10 {
99

10-
DataPtr Allocator::clone(const void* data, std::size_t n) {
10+
DataPtr Allocator::clone(const void* data, std::size_t n, bool sync) {
1111
DataPtr new_data = allocate(n);
12-
copy_data(new_data.mutable_get(), data, n);
12+
copy_data(new_data.mutable_get(), data, n, sync);
1313
return new_data;
1414
}
1515

c10/core/Allocator.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,13 @@ struct C10_API Allocator {
173173
// Note that this explicitly ignores any context that may have been
174174
// attached to the input data.
175175
//
176-
// Requires: input data was allocated by the same allocator.
177-
DataPtr clone(const void* data, std::size_t n);
176+
// If `sync=true` is given, then the device will synchronize after the clone
177+
// happens, if the device is normally asynchronous.
178+
//
179+
// Requires: Depending on the details of the allocator, input data may need to
180+
// have been allocated by the same allocator. Some allocators do support
181+
// cloning from a different device.
182+
DataPtr clone(const void* data, std::size_t n, bool sync = false);
178183

179184
// Checks if DataPtr has a simple context, not wrapped with any out of the
180185
// ordinary contexts.
@@ -205,8 +210,11 @@ struct C10_API Allocator {
205210
//
206211
// Requires: src and dest were allocated by this allocator
207212
// Requires: src and dest both have length >= count
208-
virtual void copy_data(void* dest, const void* src, std::size_t count)
209-
const = 0;
213+
virtual void copy_data(
214+
void* dest,
215+
const void* src,
216+
std::size_t count,
217+
bool sync = false) const = 0;
210218

211219
protected:
212220
// Uses `std::memcpy` to copy data.

c10/core/CPUAllocator.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ struct C10_API DefaultCPUAllocator final : at::Allocator {
4141
return &ReportAndDelete;
4242
}
4343

44-
void copy_data(void* dest, const void* src, std::size_t count) const final {
44+
void copy_data(
45+
void* dest,
46+
const void* src,
47+
std::size_t count,
48+
bool sync = false) const final {
4549
default_copy_data(dest, src, count);
4650
}
4751
};
@@ -149,7 +153,11 @@ class DefaultMobileCPUAllocator final : public at::Allocator {
149153
PreGuardBytes;
150154
}
151155

152-
void copy_data(void* dest, const void* src, std::size_t count) const final {
156+
void copy_data(
157+
void* dest,
158+
const void* src,
159+
std::size_t count,
160+
bool sync = false) const final {
153161
default_copy_data(dest, src, count);
154162
}
155163
};

c10/core/DispatchKey.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
679679
}
680680
}
681681

682-
BackendComponent toBackendComponent(DeviceType device_type);
682+
C10_API BackendComponent toBackendComponent(DeviceType device_type);
683683

684684
// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns
685685
// DispatchKey::CUDA.

c10/core/impl/COW.cpp

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <c10/core/impl/COW.h>
22

33
#include <c10/core/Allocator.h>
4+
#include <c10/core/DeviceGuard.h>
45
#include <c10/core/StorageImpl.h>
56
#include <c10/core/alignment.h>
67
#include <c10/core/impl/COWDeleter.h>
@@ -48,7 +49,12 @@ bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
4849
return (void*)data_ptr.get_deleter() == (void*)&cow::cow_deleter;
4950
}
5051

51-
c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
52+
c10::intrusive_ptr<StorageImpl> lazy_clone_storage(
53+
StorageImpl& storage,
54+
c10::optional<c10::Device> device_opt,
55+
c10::optional<c10::Allocator*> allocator_opt) {
56+
TORCH_INTERNAL_ASSERT(device_opt.has_value() == allocator_opt.has_value());
57+
5258
const at::DataPtr& data_ptr = storage.data_ptr();
5359

5460
// There are three possible circumstances:
@@ -76,38 +82,61 @@ c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
7682
//
7783
// No locking is required in this case.
7884

79-
std::optional<DataPtr> new_data_ptr; // must be set below
85+
std::optional<DataPtr> new_data_ptr_opt; // must be set below
8086

8187
if (has_simple_data_ptr(storage)) {
8288
// Case 1) We have a simple data pointer: wrap it.
8389
std::unique_ptr<void, DeleterFnPtr> original_ctx =
8490
storage._mutable_data_ptr_no_checks().move_context();
8591

8692
// Save this for the result.
87-
new_data_ptr = make_data_ptr(
88-
data_ptr, *new cow::COWDeleterContext(std::move(original_ctx)));
93+
new_data_ptr_opt = make_data_ptr(
94+
data_ptr,
95+
*new cow::COWDeleterContext(std::move(original_ctx), storage.device()));
8996

9097
// Update this storage to the new copy on write context.
91-
storage.set_data_ptr_noswap(copy_data_ptr(*new_data_ptr));
98+
storage.set_data_ptr_noswap(copy_data_ptr(*new_data_ptr_opt));
9299
} else if (is_cow_data_ptr(data_ptr)) {
93100
// Case 2): there is already a copy on write context. Just return a
94101
// new storage impl.
95-
new_data_ptr = copy_data_ptr(data_ptr);
102+
new_data_ptr_opt = copy_data_ptr(data_ptr);
96103
} else {
97104
// Case 3) There is a context and it's not copy-on-write. Nothing
98105
// we can do here.
99106
return nullptr;
100107
}
101108

102-
TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
109+
TORCH_INTERNAL_ASSERT(new_data_ptr_opt.has_value());
110+
111+
c10::Allocator* allocator = storage.allocator();
112+
c10::DeviceType device_type = storage.device_type();
113+
114+
if (device_opt.has_value()) {
115+
allocator = allocator_opt.value();
116+
117+
DeviceGuard device_guard(device_opt.value());
118+
Device device = device_guard.current_device();
119+
120+
// If a different target device was given, then convert the data pointer to
121+
// that device.
122+
if (device != storage.device()) {
123+
DataPtr& new_data_ptr = new_data_ptr_opt.value();
124+
auto* ctx = new_data_ptr.cast_context<c10::impl::cow::COWDeleterContext>(
125+
c10::impl::cow::cow_deleter);
126+
device_type = device.type();
127+
new_data_ptr.release_context();
128+
new_data_ptr_opt = c10::DataPtr(
129+
new_data_ptr.get(), ctx, c10::impl::cow::cow_deleter, device);
130+
}
131+
}
103132

104133
return make_storage_impl(
105134
StorageImpl::use_byte_size_t(),
106135
storage.sym_nbytes(),
107-
*std::move(new_data_ptr),
108-
storage.allocator(),
136+
*std::move(new_data_ptr_opt),
137+
allocator,
109138
storage.resizable(),
110-
storage.device_type());
139+
device_type);
111140
}
112141

113142
C10_API void materialize_cow_storage(StorageImpl& storage) {
@@ -118,13 +147,14 @@ C10_API void materialize_cow_storage(StorageImpl& storage) {
118147

119148
auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
120149
TORCH_INTERNAL_ASSERT(ctx != nullptr);
121-
150+
bool devices_match = storage.device() == ctx->original_device();
122151
auto result = ctx->decrement_refcount();
123152

124153
// This must be set by each branch below.
125154
std::optional<DataPtr> new_data_ptr;
126155

127-
if (std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
156+
if (devices_match &&
157+
std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
128158
// This is the only reference to the data. If there were any racing writes,
129159
// the context ensured they finished before giving us the result.
130160
std::unique_ptr<void, DeleterFnPtr> data =
@@ -133,12 +163,10 @@ C10_API void materialize_cow_storage(StorageImpl& storage) {
133163
new_data_ptr = DataPtr(
134164
data.release(), data_ptr.get(), data.get_deleter(), data_ptr.device());
135165
} else {
136-
TORCH_INTERNAL_ASSERT(
137-
std::holds_alternative<cow::COWDeleterContext::NotLastReference>(
138-
result));
139166
// We don't need to consume the result, it's just a shared lock ensuring
140167
// that the data will remain while we copy it.
141-
new_data_ptr = storage.allocator()->clone(data_ptr.get(), storage.nbytes());
168+
new_data_ptr = storage.allocator()->clone(
169+
data_ptr.get(), storage.nbytes(), /*sync=*/!devices_match);
142170
}
143171

144172
TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());

0 commit comments

Comments
 (0)
0