8000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0b5303e · commit 4a2aa0f (Copy full SHA for 4a2aa0f)
aten/src/ATen/native/cuda/Indexing.cu
@@ -230,7 +230,7 @@ void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>
230
std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
231
dim3 block(C10_WARP_SIZE, indices_per_block);
232
233
- AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
234
value_.scalar_type(), "indexing_backward", [&] {
235
indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
236
sorted_indices.data_ptr<int64_t>(),
test/test_indexing.py
@@ -762,9 +762,9 @@ def test_int_indices(self, device):
762
self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
763
self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
764
765
- @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
766
- @dtypesIfCPU(torch.float, torch.long, torch.bool, torch.bfloat16)
767
- @dtypesIfCUDA(torch.half, torch.long, torch.bool, torch.bfloat16)
+ @dtypes(torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool)
+ @dtypesIfCPU(torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16)
+ @dtypesIfCUDA(torch.cfloat, torch.cdouble, torch.half, torch.long, torch.bool, torch.bfloat16)
768
def test_index_put_src_datatype(self, device, dtype):
769
src = torch.ones(3, 2, 4, device=device, dtype=dtype)
770
vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
0 commit comments