8000 [WIP] Add `device` arg to `_lazy_clone` · pytorch/pytorch@d18182a · GitHub
[go: up one dir, main page]

Skip to content

Commit d18182a

Browse files
committed
[WIP] Add device arg to _lazy_clone
ghstack-source-id: 16fbf84 Pull Request resolved: #148408
1 parent 8531d24 commit d18182a

File tree

10 files changed

+166
-26
lines changed

10 files changed

+166
-26
lines changed

aten/src/ATen/native/AutogradComposite.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <ATen/core/Tensor.h>
33
#include <c10/util/SmallBuffer.h>
44
#include <c10/core/impl/COW.h>
5+
#include <c10/core/DispatchKey.h>
56

67
#ifndef AT_PER_OPERATOR_HEADERS
78
#include <ATen/Functions.h>
@@ -91,14 +92,21 @@ bool _has_same_storage_numel(const at::Tensor& base, const at::Tensor& other) {
9192
return base.storage().sym_nbytes() / base.itemsize() == other.storage().sym_nbytes() / other.itemsize();
9293
}
9394

94-
Tensor _lazy_clone(Tensor const& self) {
95+
Tensor _lazy_clone(Tensor const& self, optional<c10::Device> device_opt) {
9596
c10::StorageImpl* self_storage = self.storage().unsafeGetStorageImpl();
9697
c10::intrusive_ptr<c10::StorageImpl> storage =
97-
c10::impl::cow::lazy_clone_storage(*self_storage);
98+
c10::impl::cow::lazy_clone_storage(*self_storage, device_opt);
9899
TORCH_CHECK(storage != nullptr);
100+
c10::DispatchKeySet key_set = self.key_set();
101+
// If the target device differs, then we must change the key set
102+
if (device_opt.has_value() && device_opt.value().type() != self.device().type()) {
103+
c10::BackendComponent old_backend = c10::toBackendComponent(self.device().type());
104+
c10::BackendComponent new_backend = c10::toBackendComponent(device_opt.value().type());
105+
key_set = key_set.remove_backend(old_backend) | c10::DispatchKeySet(new_backend);
106+
}
99107
auto tensor = c10::make_intrusive<c10::TensorImpl>(
100108
c10::Storage(std::move(storage)),
101-
self.key_set(),
109+
key_set,
102110
self.dtype());
103111
tensor->set_sizes_and_strides(self.sym_sizes(),
104112
self.sym_strides(),

aten/src/ATen/native/TensorConversions.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
1818
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
1919
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
20+
#include <ATen/ops/_lazy_clone.h>
2021
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
2122
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
2223
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
@@ -422,6 +423,25 @@ bool to_will_alias(
422423
self.suggest_memory_format() == memory_format);
423424
}
424425

426+
bool _only_device_differs(
427+
const Tensor& self,
428+
std::optional<ScalarType> dtype,
429+
std::optional<Layout> layout,
430+
std::optional<Device> device,
431+
std::optional<bool> pin_memory,
432+
std::optional<c10::MemoryFormat> optional_memory_format) {
433+
bool device_differs = device.has_value() && device.value() != self.device();
434+
bool dtype_differs = dtype.has_value() && dtype.value() != self.scalar_type();
435+
bool layout_differs = layout.has_value() && layout.value() != self.layout();
436+
bool pin_memory_differs =
437+
pin_memory.has_value() && pin_memory.value() != self.is_pinned();
438+
auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
439+
bool memory_format_differs = memory_format != MemoryFormat::Preserve &&
440+
memory_format != self.suggest_memory_format();
441+
return device_differs && !dtype_differs && !layout_differs &&
442+
!pin_memory_differs && !memory_format_differs;
443+
}
444+
425445
static inline Tensor to_impl(
426446
const Tensor& self,
427447
std::optional<ScalarType> dtype,
@@ -436,6 +456,12 @@ static inline Tensor to_impl(
436456
self, dtype, layout, device, copy, optional_memory_format)) {
437457
return self;
438458
}
459+
// TODO: after I prove that this works, I should only allow it for CPU-MPS,
460+
// and we can enable others later if needed.
461+
// if (_only_device_differs(self, dtype, layout, device, pin_memory,
462+
// optional_memory_format)) {
463+
// return at::_lazy_clone(self, device);
464+
//}
439465
return at::_to_copy(
440466
self,
441467
dtype,

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,9 +1250,10 @@
12501250
CompositeExplicitAutograd: copysign_out
12511251
tags: pointwise
12521252

1253-
- func: _lazy_clone(Tensor self) -> Tensor
1253+
- func: _lazy_clone(Tensor self, *, Device? device=None) -> Tensor
12541254
# Like clone, but the copy takes place lazily, only if either the
1255-
# input or the output are written.
1255+
# input or the output are written. If `device` is given, the output
1256+
# will be copied to the specified device when the write occurs.
12561257
variants: function, method
12571258
dispatch:
12581259
CompositeExplicitAutograd: _lazy_clone

c10/core/DispatchKey.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
679679
}
680680
}
681681

682-
BackendComponent toBackendComponent(DeviceType device_type);
682+
C10_API BackendComponent toBackendComponent(DeviceType device_type);
683683

684684
// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns
685685
// DispatchKey::CUDA.

c10/core/impl/COW.cpp

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <c10/core/StorageImpl.h>
55
#include <c10/core/alignment.h>
66
#include <c10/core/impl/COWDeleter.h>
7+
#include <c10/cuda/CUDACachingAllocator.h>
8+
#include <c10/cuda/CUDAFunctions.h>
79
#include <c10/util/Exception.h>
810
#include <c10/util/ParallelGuard.h>
911
#include <c10/util/UniqueVoidPtr.h>
@@ -48,7 +50,9 @@ bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
4850
return (void*)data_ptr.get_deleter() == (void*)&cow::cow_deleter;
4951
}
5052

51-
c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
53+
c10::intrusive_ptr<StorageImpl> lazy_clone_storage(
54+
StorageImpl& storage,
55+
c10::optional<c10::Device> device_opt) {
5256
const at::DataPtr& data_ptr = storage.data_ptr();
5357

5458
// There are three possible circumstances:
@@ -76,38 +80,66 @@ c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
7680
//
7781
// No locking is required in this case.
7882

79-
std::optional<DataPtr> new_data_ptr; // must be set below
83+
std::optional<DataPtr> new_data_ptr_opt; // must be set below
8084

8185
if (has_simple_data_ptr(storage)) {
8286
// Case 1) We have a simple data pointer: wrap it.
8387
std::unique_ptr<void, DeleterFnPtr> original_ctx =
8488
storage._mutable_data_ptr_no_checks().move_context();
8589

8690
// Save this for the result.
87-
new_data_ptr = make_data_ptr(
88-
data_ptr, *new cow::COWDeleterContext(std::move(original_ctx)));
91+
new_data_ptr_opt = make_data_ptr(
92+
data_ptr,
93+
*new cow::COWDeleterContext(std::move(original_ctx), storage.device()));
8994

9095
// Update this storage to the new copy on write context.
91-
storage.set_data_ptr_noswap(copy_data_ptr(*new_data_ptr));
96+
storage.set_data_ptr_noswap(copy_data_ptr(*new_data_ptr_opt));
9297
} else if (is_cow_data_ptr(data_ptr)) {
9398
// Case 2): there is already a copy on write context. Just return a
9499
// new storage impl.
95-
new_data_ptr = copy_data_ptr(data_ptr);
100+
new_data_ptr_opt = copy_data_ptr(data_ptr);
96101
} else {
97102
// Case 3) There is a context and it's not copy-on-write. Nothing
98103
// we can do here.
99104
return nullptr;
100105
}
101106

102-
TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
107+
TORCH_INTERNAL_ASSERT(new_data_ptr_opt.has_value());
108+
109+
c10::Allocator* allocator = storage.allocator();
110+
c10::DeviceType device_type = storage.device_type();
111+
112+
if (device_opt.has_value()) {
113+
DeviceGuard device_guard(device_opt.value());
114+
Device device = device_guard.current_device();
115+
116+
// If a different target device was given, then convert the data pointer to
117+
// that device.
118+
if (device != storage.device()) {
119+
DataPtr& new_data_ptr = new_data_ptr_opt.value();
120+
auto* ctx = new_data_ptr.cast_context<c10::impl::cow::COWDeleterContext>(
121+
c10::impl::cow::cow_deleter);
122+
device_type = device.type();
123+
124+
if (device_type == c10::kCUDA) {
125+
allocator = c10::cuda::CUDACachingAllocator::get();
126+
} else {
127+
allocator = c10::GetAllocator(device.type());
128+
}
129+
130+
new_data_ptr.release_context();
131+
new_data_ptr_opt = c10::DataPtr(
132+
new_data_ptr.get(), ctx, c10::impl::cow::cow_deleter, device);
133+
}
134+
}
103135

104136
return make_storage_impl(
105137
StorageImpl::use_byte_size_t(),
106138
storage.sym_nbytes(),
107-
*std::move(new_data_ptr),
108-
storage.allocator(),
139+
*std::move(new_data_ptr_opt),
140+
allocator,
109141
storage.resizable(),
110-
storage.device_type());
142+
device_type);
111143
}
112144

113145
C10_API void materialize_cow_storage(StorageImpl& storage) {
@@ -118,13 +150,14 @@ C10_API void materialize_cow_storage(StorageImpl& storage) {
118150

119151
auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
120152
TORCH_INTERNAL_ASSERT(ctx != nullptr);
121-
153+
bool devices_match = storage.device() == ctx->original_device();
122154
auto result = ctx->decrement_refcount();
123155

124156
// This must be set by each branch below.
125157
std::optional<DataPtr> new_data_ptr;
126158

127-
if (std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
159+
if (devices_match &&
160+
std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
128161
// This is the only reference to the data. If there were any racing writes,
129162
// the context ensured they finished before giving us the result.
130163
std::unique_ptr<void, DeleterFnPtr> data =
@@ -133,12 +166,14 @@ C10_API void materialize_cow_storage(StorageImpl& storage) {
133166
new_data_ptr = DataPtr(
134167
data.release(), data_ptr.get(), data.get_deleter(), data_ptr.device());
135168
} else {
136-
TORCH_INTERNAL_ASSERT(
137-
std::holds_alternative<cow::COWDeleterContext::NotLastReference>(
138-
result));
139169
// We don't need to consume the result, it's just a shared lock ensuring
140170
// that the data will remain while we copy it.
141171
new_data_ptr = storage.allocator()->clone(data_ptr.get(), storage.nbytes());
172+
if (!devices_match) {
173+
if (storage.device().type() == c10::kCUDA) {
174+
c10::cuda::device_synchronize();
175+
}
176+
}
142177
}
143178

144179
TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());

c10/core/impl/COW.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

3+
#include <c10/core/Device.h>
34
#include <c10/macros/Macros.h>
5+
#include <c10/util/Optional.h>
46
#include <c10/util/intrusive_ptr.h>
57

68
namespace c10 {
@@ -17,8 +19,12 @@ namespace c10::impl::cow {
1719
// storage's DataPtr has some context (`DataPtr::get_context()`) which is not
1820
// equal to the data pointer (`DataPtr::get()`). In this case, a nullptr is
1921
// returned.
22+
//
23+
// If `device_opt` is given, the output will be copied to the specified device
24+
// when materialization occurs.
2025
C10_API c10::intrusive_ptr<StorageImpl> lazy_clone_storage(
21-
StorageImpl& storage);
26+
StorageImpl& storage,
27+
optional<Device> device_opt = nullopt);
2228

2329
// Check if a storage has a simple DataPtr with no abnormal context
2430
C10_API bool has_simple_data_ptr(const c10::StorageImpl& storage);

c10/core/impl/COWDeleter.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ void cow::cow_deleter(void* ctx) {
99
}
1010

1111
cow::COWDeleterContext::COWDeleterContext(
12-
std::unique_ptr<void, DeleterFnPtr> data)
13-
: data_(std::move(data)) {
12+
std::unique_ptr<void, DeleterFnPtr> data,
13+
c10::Device original_device)
14+
: data_(std::move(data)), original_device_(original_device) {
1415
// We never wrap a COWDeleterContext.
1516
TORCH_INTERNAL_ASSERT(data_.get_deleter() != cow::cow_deleter);
1617
}
@@ -39,4 +40,12 @@ cow::COWDeleterContext::~COWDeleterContext() {
3940
TORCH_INTERNAL_ASSERT(refcount_ == 0);
4041
}
4142

43+
// Returns (by value) the device that was passed to the constructor and stored
// in `original_device_`.
c10::Device cow::COWDeleterContext::original_device() {
  c10::Device recorded_device = original_device_;
  return recorded_device;
}
46+
47+
std::int64_t cow::COWDeleterContext::refcount() {
48+
return refcount_.load();
49+
}
50+
4251
} // namespace c10::impl

c10/core/impl/COWDeleter.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <c10/core/Device.h>
34
#include <c10/macros/Export.h>
45
#include <c10/util/UniqueVoidPtr.h>
56

@@ -21,7 +22,9 @@ class C10_API COWDeleterContext {
2122
// Note that the deleter will only be called in our destructor if
2223
// the last reference to this goes away without getting
2324
// materialized.
24-
explicit COWDeleterContext(std::unique_ptr<void, DeleterFnPtr> data);
25+
explicit COWDeleterContext(
26+
std::unique_ptr<void, DeleterFnPtr> data,
27+
c10::Device original_device);
2528

2629
// Increments the current refcount.
2730
void increment_refcount();
@@ -45,6 +48,10 @@ class C10_API COWDeleterContext {
4548
// do with it.
4649
std::variant<NotLastReference, LastReference> decrement_refcount();
4750

51+
c10::Device original_device();
52+
53+
std::int64_t refcount();
54+
4855
private:
4956
// The destructor is hidden, this should only ever be used within
5057
// UniqueVoidPtr using cow::delete_context as the deleter.
@@ -53,6 +60,7 @@ class C10_API COWDeleterContext {
5360
std::shared_mutex mutex_;
5461
std::unique_ptr<void, DeleterFnPtr> data_;
5562
std::atomic<std::int64_t> refcount_ = 1;
63+
c10::Device original_device_;
5664
};
5765

5866
// `cow_deleter` is used as the `ctx_deleter` for DataPtr to implement a COW

test/test_torch.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5336,6 +5336,53 @@ def run(num_threads, num_parallel, skip_first, should_error):
53365336
run(10, 2, False, True)
53375337
run(10, 2, True, True)
53385338

5339+
@onlyCUDA
def test_lazy_clone_to_device(self, device):
    """Check `_lazy_clone(device=...)` across several source/target devices.

    For each pair the test verifies that:
      * right after the lazy clone, both tensors are COW tensors that still
        report the original data address, and each reports its own device;
      * writing to the source leaves the clone untouched (still COW, same
        address) while the source gets a new data address;
      * writing to the clone then gives it a new data address too, and
        the two tensors hold independent values.

    Pairs that name a CUDA index this host does not have are skipped, so the
    test also passes on single-GPU machines.
    """
    device_pairs = [
        ('cpu', 'cuda'),
        ('cpu', 'cuda:0'),
        ('cpu', 'cuda:1'),
        ('cuda:1', 'cuda:0'),
        ('cuda:0', 'cuda:1'),
        # TODO: Figure out why CUDA to CPU segfaults
        # ('cuda', 'cpu'),
    ]
    num_cuda_devices = torch.cuda.device_count()

    def device_is_available(device_str):
        # Only an explicit CUDA index can be out of range; 'cpu' and a bare
        # 'cuda' are always usable here because the test is @onlyCUDA.
        parsed = torch.device(device_str)
        if parsed.type != 'cuda' or parsed.index is None:
            return True
        return parsed.index < num_cuda_devices

    for from_device, to_device in device_pairs:
        if not (device_is_available(from_device) and device_is_available(to_device)):
            continue

        # Resolve the strings to the concrete devices tensors will report
        # (e.g. 'cuda' -> the current CUDA device).
        from_device_check = torch.empty(0, device=from_device).device
        to_device_check = torch.empty(0, device=to_device).device

        a = torch.randn(10, device=from_device)
        orig_data_ptr = a.data_ptr()
        b = a._lazy_clone(device=to_device)

        # Before any write: both share the original data and are COW.
        self.assertEqual(a.device, from_device_check)
        self.assertEqual(b.device, to_device_check)
        self.assertTrue(torch._C._is_cow_tensor(a))
        self.assertEqual(torch._C._data_address(a), orig_data_ptr)
        self.assertTrue(torch._C._is_cow_tensor(b))
        self.assertEqual(torch._C._data_address(b), orig_data_ptr)

        a[0] = 1

        # Writing to `a` materializes only `a`.
        self.assertEqual(a.device, from_device_check)
        self.assertEqual(b.device, to_device_check)
        self.assertFalse(torch._C._is_cow_tensor(a))
        self.assertNotEqual(torch._C._data_address(a), orig_data_ptr)
        self.assertTrue(torch._C._is_cow_tensor(b))
        self.assertEqual(torch._C._data_address(b), orig_data_ptr)

        b[0] = 2

        # Writing to `b` materializes `b` as well.
        self.assertEqual(a.device, from_device_check)
        self.assertEqual(b.device, to_device_check)
        self.assertFalse(torch._C._is_cow_tensor(a))
        self.assertNotEqual(torch._C._data_address(a), orig_data_ptr)
        self.assertFalse(torch._C._is_cow_tensor(b))
        self.assertNotEqual(torch._C._data_address(b), orig_data_ptr)

        self.assertEqual(a[0], 1)
        self.assertEqual(b[0], 2)
5385+
53395386
# FIXME: move to test distributions
53405387
@skipIfMPS
53415388
@dtypesIfCUDA(torch.float, torch.double, torch.half)

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@
451451
self: grad
452452
result: auto_linear
453453

454-
- name: _lazy_clone(Tensor self) -> Tensor
454+
- name: _lazy_clone(Tensor self, *, Device? device=None) -> Tensor
455455
self: grad
456456
result: auto_linear
457457

0 commit comments

Comments
 (0)
0