Enable lazy cloning in `Tensor.to` between CPU and MPS · pytorch/pytorch@c41fbc5 · GitHub
[go: up one dir, main page]

Skip to content

Commit c41fbc5

Browse files
committed
Enable lazy cloning in Tensor.to between CPU and MPS
ghstack-source-id: e508c76 Pull Request resolved: #150569
1 parent c1e0870 commit c41fbc5

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

aten/src/ATen/native/TensorConversions.cpp

+40
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
1818
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
1919
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
20+
#include <ATen/ops/_lazy_clone.h>
2021
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
2122
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
2223
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
@@ -422,6 +423,25 @@ bool to_will_alias(
422423
self.suggest_memory_format() == memory_format);
423424
}
424425

426+
static bool _only_device_differs(
427+
const Tensor& self,
428+
std::optional<ScalarType> dtype,
429+
std::optional<Layout> layout,
430+
std::optional<Device> device,
431+
std::optional<bool> pin_memory,
432+
std::optional<c10::MemoryFormat> optional_memory_format) {
433+
bool device_differs = device.has_value() && device.value() != self.device();
434+
bool dtype_differs = dtype.has_value() && dtype.value() != self.scalar_type();
435+
bool layout_differs = layout.has_value() && layout.value() != self.layout();
436+
bool pin_memory_differs =
437+
pin_memory.has_value() && pin_memory.value() != self.is_pinned();
438+
auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
439+
bool memory_format_differs = memory_format != MemoryFormat::Preserve &&
440+
memory_format != self.suggest_memory_format();
441+
return device_differs && !dtype_differs && !layout_differs &&
442+
!pin_memory_differs && !memory_format_differs;
443+
}
444+
425445
static inline Tensor to_impl(
426446
const Tensor& self,
427447
std::optional<ScalarType> dtype,
@@ -436,6 +456,26 @@ static inline Tensor to_impl(
436456
self, dtype, layout, device, copy, optional_memory_format)) {
437457
return self;
438458
}
459+
if (device.has_value()) {
460+
c10::DeviceType src_device_type = self.device().type();
461+
c10::DeviceType dst_device_type = device.value().type();
462+
// Conversion between MPS and CPU is done lazily, as long as `device` is the
463+
// only thing that is changed. Also, in order to lazy clone from CPU to MPS,
464+
// the CPU data must be pinned.
465+
if ((src_device_type == c10::kCPU && dst_device_type == c10::kMPS &&
466+
self.is_pinned()) ||
467+
(src_device_type == c10::kMPS && dst_device_type == c10::kCPU)) {
468+
if (_only_device_differs(
469+
self,
470+
dtype,
471+
layout,
472+
device,
473+
pin_memory,
474+
optional_memory_format)) {
475+
return at::_lazy_clone(self, device);
476+
}
477+
}
478+
}
439479
return at::_to_copy(
440480
self,
441481
dtype,

test/test_lazy_clone.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -76,7 +83,15 @@ def test_interdevice_materialize(self, device, materialize_first, case):
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ def get_src_dest_devices(self, case, device):
5656
@skipCUDAIf(True, "Does not work for CUDA")
5757
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
5858
@skipXLA
59+
@parametrize(
60+
"op",
61+
[
62+
"_lazy_clone",
63+
"to",
64+
],
65+
)
5966
@parametrize("materialize_first", ("src", "dest"))
6067
@parametrize(
6168
"case",
@@ -67,7 +74,7 @@ def get_src_dest_devices(self, case, device):
6774
"from_1_to_0",
6875
],
6976
)
70-
def test_interdevice_materialize(self, device, materialize_first, case):
77+
def test_interdevice_materialize(self, device, op, materialize_first, case):
7178
src_device, dest_device = self.get_src_dest_devices(case, device)
7279

7380
src_device_check = torch.empty(0, device=src_device).device
7683

7784
a = torch.randn(10, device=src_device, pin_memory=pin_memory)
7885
orig_data_ptr = torch._C._data_address_resolve_unified(a)
79-
b = a._lazy_clone(device=dest_device)
86+
87+
if op == "_lazy_clone":
88+
b = a._lazy_clone(device=dest_device)
89+
elif op == "to":
90+
if torch.device(device).type != "mps":
91+
self.skipTest("op='to' only runs if device='mps'")
92+
b = a.to(device=dest_device)
93+
else:
94+
raise AssertionError(f"op='{op}' not recognized")
8095

8196
self.assertEqual(a.device, src_device_check)
8297
self.assertEqual(b.device, dest_device_check)
@@ -146,6 +161,13 @@ def test_interdevice_materialize(self, device, materialize_first, case):
146161
@skipCUDAIf(True, "Does not work for CUDA")
147162
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
148163
@skipXLA
164+
@parametrize(
165+
"op",
166+
[
167+
"_lazy_clone",
168+
"to",
169+
],
170+
)
149171
@parametrize(
150172
"case",
151173
[
@@ -156,7 +178,7 @@ def test_interdevice_materialize(self, device, materialize_first, case):
156178
"from_1_to_0",
157179
],
158180
)
159-
def test_interdevice_read(self, device, case):
181+
def test_interdevice_read(self, device, op, case):
160182
src_device, dest_device = self.get_src_dest_devices(case, device)
161183

162184
src_device_check = torch.empty(0, device=src_device).device
@@ -168,7 +190,14 @@ def test_interdevice_read(self, device, case):
168190
a.copy_(orig_tensor)
169191

170192
orig_data_ptr = torch._C._data_address_resolve_unified(a)
171-
b = a._lazy_clone(device=dest_device)
193+
if op == "_lazy_clone":
194+
b = a._lazy_clone(device=dest_device)
195+
elif op == "to":
196+
if torch.device(device).type != "mps":
197+
self.skipTest("op='to' only runs if device='mps'")
198+
b = a.to(device=dest_device)
199+
else:
200+
raise AssertionError(f"op='{op}' not recognized")
172201

173202
self.assertEqual(a.device, src_device_check)
174203
self.assertEqual(b.device, dest_device_check)

0 commit comments

Comments
 (0)
0