use more efficient implementation for broadcasted indexing in deterministic scatter_add · pytorch/pytorch@beb52f5 · GitHub

Commit beb52f5

ngimel authored and pytorchmergebot committed

use more efficient implementation for broadcasted indexing in deterministic scatter_add (#156744)

Per title.

Pull Request resolved: #156744
Approved by: https://github.com/suo
1 parent 9b498d3 commit beb52f5
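
In deterministic mode, scatter_add_ falls back to _scatter_via_index_put, which previously materialized full coordinate tensors for every output dimension. This commit detects a broadcasted index (zero strides in all non-scatter dimensions) and, in that case, calls index_put_ with the unexpanded 1-D index alone. A minimal Python sketch of why that is equivalent for dim == 0 (illustrative shapes and names, not the ATen code; for dim > 0 the fast path prepends null indices first):

import torch

# Broadcasted index: a 1-D index expanded across the non-scatter dim,
# so its stride there is 0.
inp = torch.zeros(4, 3)
src = torch.randn(10, 3)
idx_1d = torch.randint(4, (10,))
idx = idx_1d.unsqueeze(1).expand(10, 3)   # strides (1, 0)

ref = inp.clone().scatter_add_(0, idx, src)

# Advanced indexing with just the 1-D index on the scatter dim broadcasts
# across the remaining dim; accumulate=True sums duplicate destinations,
# matching scatter_add_ semantics.
out = inp.clone()
out.index_put_((idx_1d,), src, accumulate=True)
assert torch.allclose(out, ref)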

File tree: 2 files changed, +56 −68 lines

aten/src/ATen/native/TensorAdvancedIndexing.cpp
40 additions, 68 deletions

@@ -2153,81 +2153,53 @@ static void _scatter_via_index_put(
     const Tensor& src,
     const Tensor& mut_out,
     bool accumulate) {
-  if (self.dim() == 1) {
-    torch::List<std::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index);
-    mut_out.index_put_(indices, src, accumulate);
-  } else {
-    Tensor mut_out_contig = mut_out.contiguous();
-
-    auto index_coords_sizes = index.sizes().vec();
-    index_coords_sizes.push_back(self.dim());
-    auto index_coords = at::empty(
-        index_coords_sizes,
-        at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));
+  // If index is expanded with zero strides across non-scatter dimensions,
+  // advanced indexing with the index tensor alone achieves the desired
+  // semantics and avoids creating large intermediate tensors.
+  bool broadcast_index = true;
+  for (const auto i : c10::irange(index.dim())) {
+    if (i == dim) {
+      continue;
+    }
+    if (index.stride(i) != 0) {
+      broadcast_index = false;
+      break;
+    }
+  }

-    for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
-      if (dim_other == dim) {
-        continue;
-      }
-      auto dim_coord_vals = at::arange(
-          index.size(dim_other), at::TensorOptions().device(self.device()));
+  auto src_view = at::as_strided(src, index.sizes(), src.strides());
+  torch::List<std::optional<Tensor>> indices;
+  indices.reserve(self.dim());

-      for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1;
-           dim_unsqueeze++) {
-        dim_coord_vals =
-            dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
+  if (self.dim() == 1 || broadcast_index) {
+    Tensor squeezed = index;
+    if (broadcast_index && index.dim() > 1) {
+      for (const auto d : c10::irange(index.dim())) {
+        if (d == dim) {
+          continue;
+        }
+        squeezed = squeezed.select(d, 0);
       }
-
-      auto view_sizes = index.sizes().vec();
-      view_sizes.push_back(1);
-      auto view_strides = index_coords.strides().vec();
-      view_strides[self.dim()] = self.dim();
-
-      at::as_strided(index_coords, view_sizes, view_strides, dim_other)
-          .copy_(dim_coord_vals.unsqueeze(-1));
     }
+    for ([[maybe_unused]] const auto d : c10::irange(dim)) {
+      indices.push_back(Tensor());
+    }
+    indices.push_back(squeezed);
+    mut_out.index_put_(indices, src_view, accumulate);
+    return;
+  }

-    auto view_sizes = index.sizes().vec();
-    view_sizes.push_back(1);
-    auto view_strides = index_coords.strides().vec();
-    view_strides[self.dim()] = self.dim();
-
-    at::as_strided(index_coords, view_sizes, view_strides, dim)
-        .copy_(index.unsqueeze(-1));
-
-    Tensor index_coords_flat = index_coords.flatten(0, -2);
-
-    // Copy mut_out_contig's strides into a tensor
-    // TODO: Is there a utility function that already does this?
-    IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
-    Tensor coord_strides = at::empty(
-        {mut_out_contig.dim()},
-        TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
-    std::memcpy(
-        coord_strides.mutable_data_ptr(),
-        mut_out_contig_strides.data(),
-        coord_strides.nbytes());
-    coord_strides = coord_strides.to(mut_out_contig.device());
-
-    // `index_flat` contains the 1-D indices corresponding with the
-    // flattened `mut_out`
-    Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
-    Tensor mut_out_flat = mut_out_contig.flatten();
-    Tensor src_flat =
-        at::as_strided(src, index.sizes(), src.strides()).flatten();
-
-    torch::List<std::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index_flat);
-
-    mut_out_flat.index_put_(indices, src_flat, accumulate);
-
-    if (!mut_out.is_contiguous()) {
-      mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
+  for (const auto d : c10::irange(self.dim())) {
+    if (d == dim) {
+      indices.push_back(index);
+    } else {
+      auto arange = at::arange(index.size(d), index.options());
+      std::vector<int64_t> shape(index.dim(), 1);
+      shape[d] = index.size(d);
+      indices.push_back(arange.view(shape).expand(index.sizes()));
     }
   }
+  mut_out.index_put_(indices, src_view, accumulate);
 }

 template <
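
For an index that is not broadcasted, the rewrite replaces the old stride-arithmetic flattening with a per-dimension list of index tensors: the scatter dimension gets the index itself, every other dimension a broadcast arange of coordinates. A Python mirror of that loop, assuming accumulate is true (the helper name scatter_add_via_index_put is mine, not from the codebase):

import torch

def scatter_add_via_index_put(out, dim, index, src):
    # Restrict src to index's shape, matching at::as_strided in the C++.
    src_view = src.as_strided(index.shape, src.stride())
    indices = []
    for d in range(out.dim()):
        if d == dim:
            indices.append(index)
        else:
            # Coordinates 0..index.size(d)-1, viewed with singleton dims
            # elsewhere and expanded to index's full shape.
            shape = [1] * index.dim()
            shape[d] = index.size(d)
            ar = torch.arange(index.size(d), device=index.device)
            indices.append(ar.view(shape).expand(index.shape))
    return out.index_put_(tuple(indices), src_view, accumulate=True)

out = torch.zeros(3, 4)
src = torch.randn(3, 5)
index = torch.randint(4, (3, 5))
ref = out.clone().scatter_add_(1, index, src)
res = scatter_add_via_index_put(out.clone(), 1, index, src)
assert torch.allclose(res, ref)

Unlike the removed code, nothing here depends on the output being contiguous, and no flattened coordinate tensor of shape index.sizes() + [self.dim()] is ever materialized.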

test/test_scatter_gather_ops.py
16 additions, 0 deletions

@@ -380,6 +380,22 @@ def helper(input_size, idx_size):
         helper([50, 8, 7], 100)
         helper([50, 3, 4, 5], 100)

+    @dtypes(torch.float32)
+    def test_scatter_add_broadcasted_index_deterministic(self, device, dtype):
+        for d in (0, 1):
+            inp = torch.randn(3, 4, device=device, dtype=dtype)
+            idx_1d = torch.randint(3, (10,), device=device)
+            src_shape = list(inp.shape)
+            src_shape[d] = 10
+            src = torch.randn(src_shape, device=device, dtype=dtype)
+            idx = idx_1d.unsqueeze(1 - d).expand(src_shape)
+            print(idx.stride())
+            ref = inp.clone().scatter_add_(d, idx, src)
+            with DeterministicGuard(True):
+                res = inp.clone().scatter_add_(d, idx, src)
+            self.assertEqual(res, ref)
+
+
     @onlyCPU
     @dtypes(torch.float32, torch.float64, torch.bfloat16)
     def test_gather_expanded_index(self, device, dtype):
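
The new test exercises both scatter dimensions, building the broadcasted index with unsqueeze(1 - d).expand(...) so the non-scatter dimension has stride 0, then compares the deterministic result against the default implementation. Outside the test harness, the public counterpart of DeterministicGuard is torch.use_deterministic_algorithms; a rough sketch of the same check with illustrative shapes (deterministic mode is what routes scatter_add_ through _scatter_via_index_put):

import torch

inp = torch.zeros(3, 4)
idx = torch.randint(3, (10,)).unsqueeze(1).expand(10, 4)  # stride 0 in dim 1
src = torch.randn(10, 4)

prev = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(True)
try:
    out = inp.clone().scatter_add_(0, idx, src)
finally:
    torch.use_deterministic_algorithms(prev)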
