@@ -2153,81 +2153,53 @@ static void _scatter_via_index_put(
     const Tensor& src,
     const Tensor& mut_out,
     bool accumulate) {
-  if (self.dim() == 1) {
-    torch::List<std::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index);
-    mut_out.index_put_(indices, src, accumulate);
-  } else {
-    Tensor mut_out_contig = mut_out.contiguous();
-
-    auto index_coords_sizes = index.sizes().vec();
-    index_coords_sizes.push_back(self.dim());
-    auto index_coords = at::empty(
-        index_coords_sizes,
-        at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));
+  // If index is expanded with zero strides across non-scatter dimensions,
+  // advanced indexing with the index tensor alone achieves the desired
+  // semantics and avoids creating large intermediate tensors.
+  bool broadcast_index = true;
+  for (const auto i : c10::irange(index.dim())) {
+    if (i == dim) {
+      continue;
+    }
+    if (index.stride(i) != 0) {
+      broadcast_index = false;
+      break;
+    }
+  }
 
-    for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
-      if (dim_other == dim) {
-        continue;
-      }
-      auto dim_coord_vals = at::arange(
-          index.size(dim_other), at::TensorOptions().device(self.device()));
+  auto src_view = at::as_strided(src, index.sizes(), src.strides());
+  torch::List<std::optional<Tensor>> indices;
+  indices.reserve(self.dim());
 
-      for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1;
-           dim_unsqueeze++) {
-        dim_coord_vals =
-            dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
+  if (self.dim() == 1 || broadcast_index) {
+    Tensor squeezed = index;
+    if (broadcast_index && index.dim() > 1) {
+      for (const auto d : c10::irange(index.dim())) {
+        if (d == dim) {
+          continue;
+        }
+        squeezed = squeezed.select(d, 0);
       }
-
-      auto view_sizes = index.sizes().vec();
-      view_sizes.push_back(1);
-      auto view_strides = index_coords.strides().vec();
-      view_strides[self.dim()] = self.dim();
-
-      at::as_strided(index_coords, view_sizes, view_strides, dim_other)
-          .copy_(dim_coord_vals.unsqueeze(-1));
     }
+    for ([[maybe_unused]] const auto d : c10::irange(dim)) {
+      indices.push_back(Tensor());
+    }
+    indices.push_back(squeezed);
+    mut_out.index_put_(indices, src_view, accumulate);
+    return;
+  }
 
-    auto view_sizes = index.sizes().vec();
-    view_sizes.push_back(1);
-    auto view_strides = index_coords.strides().vec();
-    view_strides[self.dim()] = self.dim();
-
-    at::as_strided(index_coords, view_sizes, view_strides, dim)
-        .copy_(index.unsqueeze(-1));
-
-    Tensor index_coords_flat = index_coords.flatten(0, -2);
-
-    // Copy mut_out_contig's strides into a tensor
-    // TODO: Is there a utility function that already does this?
-    IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
-    Tensor coord_strides = at::empty(
-        {mut_out_contig.dim()},
-        TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
-    std::memcpy(
-        coord_strides.mutable_data_ptr(),
-        mut_out_contig_strides.data(),
-        coord_strides.nbytes());
-    coord_strides = coord_strides.to(mut_out_contig.device());
-
-    // `index_flat` contains the 1-D indices corresponding with the
-    // flattened `mut_out`
-    Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
-    Tensor mut_out_flat = mut_out_contig.flatten();
-    Tensor src_flat =
-        at::as_strided(src, index.sizes(), src.strides()).flatten();
-
-    torch::List<std::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index_flat);
-
-    mut_out_flat.index_put_(indices, src_flat, accumulate);
-
-    if (!mut_out.is_contiguous()) {
-      mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
+  for (const auto d : c10::irange(self.dim())) {
+    if (d == dim) {
+      indices.push_back(index);
+    } else {
+      auto arange = at::arange(index.size(d), index.options());
+      std::vector<int64_t> shape(index.dim(), 1);
+      shape[d] = index.size(d);
+      indices.push_back(arange.view(shape).expand(index.sizes()));
     }
   }
+  mut_out.index_put_(indices, src_view, accumulate);
 }
 
 template <
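For readers following the change, here is a minimal standalone sketch (not part of the patch) of the equivalence both the old and new implementations rely on: a scatter along `dim` can be expressed as `index_put_` with the scatter index in slot `dim` and broadcast `arange()` tensors in every other slot. It is written against the public libtorch API; the shapes, values, and the use of `scatter_add_` with `accumulate=true` (to keep the comparison deterministic under duplicate indices) are illustrative assumptions only.

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  const int64_t dim = 0;
  auto src = torch::rand({3, 4});
  auto index = torch::randint(/*high=*/5, {3, 4}, torch::kLong);

  // Reference result: out[index[i][j]][j] += src[i][j].
  auto expected = torch::zeros({5, 4});
  expected.scatter_add_(dim, index, src);

  // Same result via advanced indexing, mirroring the general path of the
  // patch: the scatter dimension gets `index`, every other dimension gets an
  // arange() reshaped so it broadcasts against index.sizes().
  c10::List<std::optional<torch::Tensor>> indices;
  for (int64_t d = 0; d < index.dim(); ++d) {
    if (d == dim) {
      indices.push_back(index);
    } else {
      std::vector<int64_t> shape(index.dim(), 1);
      shape[d] = index.size(d);
      auto arange = torch::arange(index.size(d), torch::kLong).view(shape);
      indices.push_back(arange.expand(index.sizes()));
    }
  }
  auto out = torch::zeros({5, 4});
  out.index_put_(indices, src, /*accumulate=*/true);

  std::cout << torch::allclose(out, expected) << std::endl;  // prints 1
  return 0;
}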
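A companion sketch (again illustrative, under the same assumptions) of the fast path the patch adds: when `index` is a 1-D index expanded with zero strides across the non-scatter dimensions, squeezing those dimensions away and passing the 1-D index alone to `index_put_` gives the same result without materializing per-dimension coordinate tensors.

#include <torch/torch.h>
#include <iostream>

int main() {
  auto src = torch::rand({3, 4});
  auto idx = torch::randint(/*high=*/5, {3}, torch::kLong);
  // Expanded view of the index: shape (3, 4), stride 0 along dim 1.
  auto expanded = idx.view({3, 1}).expand({3, 4});

  // Reference result: out[expanded[i][j]][j] += src[i][j].
  auto expected = torch::zeros({5, 4});
  expected.scatter_add_(/*dim=*/0, expanded, src);

  // Fast path: index only the scatter dimension with the squeezed 1-D index;
  // the remaining dimensions are covered by broadcasting `src`.
  c10::List<std::optional<torch::Tensor>> indices;
  indices.push_back(idx);
  auto out = torch::zeros({5, 4});
  out.index_put_(indices, src, /*accumulate=*/true);

  std::cout << torch::allclose(out, expected) << std::endl;  // prints 1
  return 0;
}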