Update on "[4/N] Test NaN checker against broadcast" · pytorch/pytorch@edeb15e · GitHub
Commit edeb15e

Update on "[4/N] Test NaN checker against broadcast"
cc XilunWu H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
2 parents: 7cfdb24 + 8f44404 · commit edeb15e


48 files changed: +967 −240 lines

aten/src/ATen/AccumulateType.cpp

Lines changed: 2 additions & 0 deletions
Lines changed: 2 additions & 0 deletions

@@ -9,6 +9,8 @@ c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) {
     switch (device) { \
       case DeviceType::CUDA: \
         return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::CUDA>>::value; \
+      case DeviceType::XPU: \
+        return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::XPU>>::value; \
       case DeviceType::MPS: \
         return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::MPS>>::value; \
       default: \

aten/src/ATen/native/cuda/Nonzero.cu

Lines changed: 11 additions & 1 deletion
Lines changed: 11 additions & 1 deletion

@@ -1,6 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
+#include <ATen/EmptyTensor.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <ATen/cuda/EmptyTensor.h>
@@ -70,7 +71,16 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
   auto temp_storage = allocator.allocate(temp_storage_bytes);
   cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
   int num_nonzeros_h;
-  at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream);
+  auto pinned_num_nonzeros_h = at::detail::empty_cpu(
+      {1}, /* size */
+      c10::CppTypeToScalarType<int>(), /* dtype */
+      std::nullopt, /* layout */
+      std::nullopt, /* device */
+      true, /* pin_memory */
+      std::nullopt /* memory format */
+  );
+  at::cuda::memcpy_and_sync((void *)pinned_num_nonzeros_h.const_data_ptr<int>(), num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream);
+  num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr<int>());
   //expected output size is num_nonzeros x ndim
   //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output)
   //we are able to directly use passed output with this size and strides, and we can also (per contract)

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 1 addition & 2 deletions
@@ -870,7 +870,6 @@ _scaled_dot_product_flash_attention_cpu(
   int64_t batchSize = query.size(0);
   int64_t qSize = query.size(2);
   int64_t num_head = query.size(1);
-  int64_t headSize = query.size(3);

   TORCH_CHECK(c10::isFloatingType(dtype),
     "scaled_dot_product_attention_flash_attention: Expected data type in FP32, FP64, BF16, FP16, but got ", dtype, " instead.");
@@ -888,7 +887,7 @@ _scaled_dot_product_flash_attention_cpu(
     (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4),
     "scaled_dot_product_attention_flash_attention: Attention mask dim in {2, 4}");

-  at::Tensor output = at::empty({batchSize, qSize, num_head, headSize}, query.options());
+  at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2);
   const auto accumulate_dtype = toOpMathType(dtype);
   at::Tensor logsumexp = at::empty({batchSize, qSize, num_head},
     query.options().dtype(accumulate_dtype));
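The change above allocates the CPU flash-attention output as a transposed view of an empty_like(query) buffer rather than a fresh contiguous {batchSize, qSize, num_head, headSize} tensor, so the output keeps query's memory layout while exposing a (batch, seq, heads, head_dim) shape. A small Python sketch of the shape/stride difference, with made-up dimensions:

import torch

B, H, S, D = 2, 4, 8, 16                    # batch, heads, seq len, head dim (illustrative sizes)
query = torch.randn(B, H, S, D)

out_old = torch.empty(B, S, H, D)                   # old: fresh contiguous (B, S, H, D) buffer
out_new = torch.empty_like(query).transpose(1, 2)   # new: query's layout, viewed as (B, S, H, D)

print(out_old.shape, out_old.stride())      # torch.Size([2, 8, 4, 16]) (512, 64, 16, 1)
print(out_new.shape, out_new.stride())      # torch.Size([2, 8, 4, 16]) (512, 16, 128, 1)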

build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -835,6 +835,7 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/extra_state.cpp",
     "torch/csrc/dynamo/framelocals_mapping.cpp",
     "torch/csrc/dynamo/guards.cpp",
+    "torch/csrc/dynamo/utils.cpp",
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
     "torch/csrc/fx/node.cpp",

docs/source/notes/serialization.rst

Lines changed: 1 addition & 0 deletions
@@ -398,3 +398,4 @@ The following utility functions are related to serialization:
 .. autofunction:: clear_safe_globals
 .. autofunction:: get_safe_globals
 .. autoclass:: safe_globals
+.. autoclass:: skip_data
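The new entry documents the skip_data class in the serialization notes. A minimal usage sketch, assuming torch.serialization.skip_data is the context manager this docs entry refers to (it writes checkpoint structure without the tensor storage bytes); treat it as illustrative rather than authoritative:

import torch

# Hedged sketch: save a checkpoint "skeleton" without writing storage bytes,
# assuming the torch.serialization.skip_data context manager behaves as documented.
state = {"weight": torch.randn(4, 4)}
with torch.serialization.skip_data():
    torch.save(state, "skeleton.pt")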

test/distributed/_tensor/test_dtensor.py

Lines changed: 36 additions & 2 deletions
@@ -15,6 +15,7 @@
     init_device_mesh,
 )
 from torch.distributed._tensor.debug import CommDebugMode
+from torch.distributed._tensor.experimental import implicit_replication
 from torch.distributed._tensor.placement_types import (
     DTensorSpec,
     Partial,
@@ -778,8 +779,6 @@ def test_implicit_replication(self):
         local_tensor1 = torch.ones(4, 3)
         sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])

-        from torch.distributed._tensor.experimental import implicit_replication
-
         with implicit_replication():
             # We put the scalar tensor as the left operand so we can test out
             # when a non-dtensor is the arg in the args list.
@@ -816,6 +815,41 @@ def add_scalar_tensor_with_dtensor():
             (numel_1_tensor + sharded_dtensor).to_local(), numel_1_tensor + local_tensor
         )

+    @with_comms
+    def test_implicit_replication_for_foreach_ops(self):
+        mesh = init_device_mesh(
+            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
+        )
+        global_tensor1 = torch.randn(4, 2)
+        dtensor_2d = distribute_tensor(global_tensor1, mesh, [Shard(0), Shard(1)])
+        self.assertEqual(dtensor_2d.full_tensor(), global_tensor1)
+        global_tensor2 = torch.randn(4)
+        dtensor_1d = distribute_tensor(global_tensor2, mesh["dp"], [Shard(0)])
+        dtensor_list = [dtensor_2d, dtensor_1d]
+
+        # Check that without implicit replication, a cross-mesh error is raised.
+        with self.assertRaisesRegex(
+            RuntimeError, "DTensor does not support cross-mesh operation yet!"
+        ):
+            torch._foreach_mul(dtensor_list, 2.0)
+
+        # Check that the DTensor results match the plain-tensor results.
+        with implicit_replication():
+            torch._foreach_mul_(dtensor_list, 2.0)
+            self.assertEqual(dtensor_list[0].full_tensor(), global_tensor1 * 2.0)
+            self.assertEqual(dtensor_list[1].full_tensor(), global_tensor2 * 2.0)
+
+        mesh_1d = DeviceMesh.from_group(mesh["tp"].get_group(), self.device_type)
+        dtensor_1d = distribute_tensor(global_tensor2, mesh_1d, [Shard(0)])
+        dtensor_list = [dtensor_2d, dtensor_1d]
+
+        # Check that even with implicit replication, a cross-mesh error is raised
+        # when the device meshes do not belong to the same root mesh.
+        with self.assertRaisesRegex(
+            RuntimeError, "DTensor does not support cross-mesh operation yet!"
+        ):
+            torch._foreach_mul_(dtensor_list, 2.0)
+
     @with_comms
     def test_metadata_consistency_check(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
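Outside the test harness, the implicit_replication pattern that the new test extends to foreach ops looks roughly like the sketch below. It uses a single-rank gloo group so it runs without a launcher; the mesh shape, tensor sizes, and port are illustrative only.

import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh
from torch.distributed._tensor.experimental import implicit_replication

# Single-process stand-in for the multi-rank setup used in the test.
dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
mesh = init_device_mesh("cpu", (1,))
dtensor = distribute_tensor(torch.randn(4, 2), mesh, [Shard(0)])

with implicit_replication():
    # A plain numel-1 tensor is treated as replicated across the mesh,
    # so it can be mixed with DTensors, including in foreach-style ops.
    out = torch.tensor(1.0) + dtensor
    torch._foreach_mul_([dtensor], 2.0)

dist.destroy_process_group()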

test/distributed/_tensor/test_dtensor_ops.py

Lines changed: 0 additions & 3 deletions
@@ -191,9 +191,6 @@ def wrapped(fn):
     xfail("index_reduce", "amin"),
     xfail("index_select"),
     xfail("isin"),
-    xfail("isinf"),
-    xfail("isneginf"),
-    xfail("isposinf"),
     xfail("kthvalue"),
     xfail("linalg.cholesky"),
     xfail("linalg.cholesky_ex"),
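Removing these xfail entries means isinf, isneginf, and isposinf are now expected to pass under DTensor. A quick hedged sanity check on a single-rank gloo mesh (the setup mirrors the sketch above; sizes and port are illustrative):

import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh

dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
)
mesh = init_device_mesh("cpu", (1,))
dtensor = distribute_tensor(
    torch.tensor([1.0, float("inf"), -float("inf")]), mesh, [Shard(0)]
)

# Each op should return a DTensor whose full tensor matches the plain-tensor result.
for op in (torch.isinf, torch.isposinf, torch.isneginf):
    print(op.__name__, op(dtensor).full_tensor())

dist.destroy_process_group()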

0 commit comments