Commit 1ed2ce4

Update base for Update on "[FSDP][dtensor] use _StridedShard to represent nested sharding for correct full_tensor() result"
Fixes #129229 and #129206.

**Summary**

1. Have `FSDP` choose the `_StridedShard` placement for FSDP+TP sharding.
2. Added a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and plain TP sharding (i.e. non-strided) produce the same `full_tensor()` result.
3. Re-enabled the tests that were disabled in #129519.

**Test**

`pytest test/distributed/_composable/fsdp/`
`pytest test/distributed/_composable/test_composability/test_2d_composability.py`
`pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py`

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang LucasLLC MeetVadakkanchery mhorowitz

Differential Revision: [D60606114](https://our.internmc.facebook.com/intern/diff/D60606114)

[ghstack-poisoned]
2 parents 5da798e + da32021 · commit 1ed2ce4
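The parity property described in the summary, that a weight sharded by TP alone and one sharded by FSDP on top of TP must reconstruct to the same values, can be sketched as below. This is a minimal, hypothetical illustration and not the test added by this commit; it assumes 4 ranks launched via `torchrun`, a CUDA backend, and the per-parameter `fully_shard` from `torch.distributed._composable.fsdp`.

```python
# Hypothetical parity sketch: TP-only vs. FSDP+TP sharding of the same
# weights should produce identical full_tensor() results.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.distributed._composable.fsdp import fully_shard

# 2D mesh: data-parallel (FSDP) x tensor-parallel, 4 ranks total.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(16, 16, bias=False)).cuda()
ref = nn.Sequential(nn.Linear(16, 16, bias=False)).cuda()
ref.load_state_dict(model.state_dict())  # identical starting weights

# TP-only sharding (non-strided placements).
parallelize_module(ref, mesh_2d["tp"], {"0": ColwiseParallel()})

# FSDP over TP (with this commit: _StridedShard placements on the dp dim).
parallelize_module(model, mesh_2d["tp"], {"0": ColwiseParallel()})
fully_shard(model, mesh=mesh_2d["dp"])

# full_tensor() must agree between the 1D and 2D shardings.
for (_, p2d), (_, p1d) in zip(model.named_parameters(), ref.named_parameters()):
    torch.testing.assert_close(p2d.full_tensor(), p1d.full_tensor())
```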


45 files changed · +1,321 −387 lines

aten/src/ATen/native/Copy.cpp

Lines changed: 4 additions & 1 deletion
@@ -130,7 +130,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
 // (e.g. XLA) may be supported by overriding copy_ and _copy_from.
 bool is_supported_device(Device device) {
   DeviceType device_type = device.type();
-  return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal || device_type == kMPS;
+  return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal || device_type == kMPS || device_type == kXPU;
 }

 } // namespace
@@ -221,6 +221,7 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
   // cpu_tensor.copy_(xla_tensor) => xla_tensor._copy_from(cpu_tensor)
   // xla_tensor.copy_(cpu_tensor) => cpu_tensor._copy_from(xla_tensor)
   // Both the _copy_from calls above will be dispatched to XLA's _copy_from kernels.
+
   if (!is_supported_device(src.device()) || !is_supported_device(self.device())) {
     at::_copy_from(src, self, non_blocking);
     return self;
@@ -287,6 +288,8 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
     device_type = kHIP;
   } else if (iter.device_type(1) == kMPS) {
     device_type = kMPS;
+  } else if (iter.device_type(1) == kXPU) {
+    device_type = kXPU;
   }

   // TODO: if we need to, we can also enable this path for quantized tensor
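Adding `kXPU` to `is_supported_device` keeps XPU copies on the native `copy_impl` path instead of falling back to `at::_copy_from`. A hedged usage sketch, assuming a PyTorch build with XPU support and a visible XPU device:

```python
# Cross-device copy_ between CPU and XPU now takes the native copy path.
import torch

if torch.xpu.is_available():
    src = torch.randn(4)
    dst = torch.empty(4, device="xpu")
    dst.copy_(src)   # CPU -> XPU
    src.copy_(dst)   # XPU -> CPU
    print(torch.equal(src, dst.cpu()))  # True
```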

aten/src/ATen/native/Distance.cpp

Lines changed: 2 additions & 2 deletions
@@ -102,8 +102,8 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, std
   // See Note [cdist relies on cdist_impl redispatching]
   // Keep this condition in sync with the condition at the Note
   if (!(p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25))))) {
-    TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "cdist only supports CPU and CUDA devices, X1 got: ", device1);
-    TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "cdist only supports CPU and CUDA devices, X2 got: ", device2);
+    TORCH_CHECK(device1 == kCPU || device1 == kCUDA || device1 == kXPU, "cdist only supports CPU, XPU and CUDA devices, X1 got: ", device1);
+    TORCH_CHECK(device2 == kCPU || device2 == kCUDA || device2 == kXPU, "cdist only supports CPU, XPU and CUDA devices, X2 got: ", device2);
   }

   auto dim1 = x1.dim();
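The relaxed `TORCH_CHECK` only matters on the branch that does not redispatch (see the Note referenced in the context), i.e. `p != 2` or small inputs. A hedged example of what this enables, assuming an XPU device is available:

```python
# cdist on XPU tensors; p=1 takes the branch guarded by the device check.
import torch

if torch.xpu.is_available():
    x1 = torch.randn(8, 3, device="xpu")
    x2 = torch.randn(5, 3, device="xpu")
    d = torch.cdist(x1, x2, p=1.0)
    print(d.shape)  # torch.Size([8, 5])
```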

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 4 additions & 4 deletions
@@ -1814,8 +1814,8 @@ static Tensor& std_var_out(
     const char* fname, Tensor& result, const Tensor& self,
     at::OptionalIntArrayRef dim, const std::optional<Scalar>& correction_opt,
     bool keepdim, bool take_sqrt) {
-  TORCH_CHECK(self.device().is_cpu() || self.device().is_cuda(),
-              "std and var only supports tensors on a CPU or CUDA device, but got: ",
+  TORCH_CHECK(self.device().is_cpu() || self.device().is_cuda() || self.device().is_xpu(),
+              "std and var supports tensors on a CPU, CUDA, or XPU device only, but got: ",
               self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
               "std and var only supports strided layout, got: ", self.layout());
@@ -1887,8 +1887,8 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
     at::OptionalIntArrayRef dim, const std::optional<Scalar>& correction_opt,
     bool keepdim, bool take_sqrt) {
   AT_ASSERT(result1.defined() && result2.defined());
-  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
-              fname, " only supports tensors on a CPU or CUDA device, got: ",
+  TORCH_CHECK(self.device().is_cpu() || self.is_cuda() || self.is_xpu(),
+              fname, " supports tensors on a CPU, CUDA, or XPU device only, got: ",
               self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
               fname, " only supports strided layout, got: ", self.layout());

aten/src/ATen/native/TensorAdvancedIndexing.cpp

Lines changed: 1 addition & 1 deletion
@@ -811,7 +811,7 @@ Tensor & _index_put_impl_(Tensor & self, const torch::List<std::optional<Tensor>
       at::assert_no_overlap(self, *index);
     }
   }
-  if (self.device().type() == DeviceType::CUDA && (accumulate || globalContext().deterministicAlgorithms())) {
+  if ((self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU) && (accumulate || globalContext().deterministicAlgorithms())) {
     TORCH_CHECK(value_.device() == self.device(), "expected device ", self.device(), " but got device ",
                 value_.device(), " for value tensor");
     index_put_with_sort_stub(self.device().type(), self, indices, value_, accumulate, unsafe);
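Routing XPU through the same branch as CUDA means accumulating (or deterministic) `index_put_` calls go through `index_put_with_sort_stub`; this hunk presupposes an XPU kernel is registered for that stub. A hedged example:

```python
# Accumulating index_put_ on XPU takes the sort-based path after this change.
import torch

if torch.xpu.is_available():
    t = torch.zeros(5, device="xpu")
    idx = torch.tensor([0, 1, 1, 3], device="xpu")
    vals = torch.ones(4, device="xpu")
    t.index_put_((idx,), vals, accumulate=True)
    print(t)  # tensor([1., 2., 0., 1., 0.], device='xpu:0')
```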

aten/src/ATen/native/cuda/PointwiseOpsKernel.cu

Lines changed: 4 additions & 0 deletions
@@ -11,7 +11,9 @@

 namespace at::native {

+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 CONSTEXPR_EXCEPT_WIN_CUDA char addcmul_name[] = "addcmul";
+#endif
 void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
@@ -55,8 +57,10 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   }
 }

+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 // return a + alpha * (b / static_cast<accscalar_t>(c));
 CONSTEXPR_EXCEPT_WIN_CUDA char addcdiv_name[] = "addcdiv";
+#endif
 void addcdiv_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
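These guards only compile the jiterator kernel-name constants when the jiterator path itself is compiled in (`AT_USE_JITERATOR()` and CUDA >= 11.5), presumably to avoid unused-variable diagnostics otherwise; op behavior is unchanged. For reference, the semantics these kernels implement, with complex inputs being what takes the jiterator path on CUDA:

```python
# addcmul: a + value * b * c ; addcdiv: a + value * (b / c)
import torch

dev = "cuda" if torch.cuda.is_available() else "cpu"
a, b, c = (torch.randn(3, dtype=torch.complex64, device=dev) for _ in range(3))

print(torch.addcmul(a, b, c, value=2))
print(torch.addcdiv(a, b, c, value=2))
```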
