Port `mean` kernel to structured kernels. (#61643) · pytorch/pytorch@fbce0fe · GitHub
[go: up one dir, main page]

Skip to content

Commit fbce0fe

Browse files
ysiraichi and alanwaketan
authored and committed
Port mean kernel to structured kernels. (#61643)
Summary: Pull Request resolved: #61643 Tracking issue: #55070 Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D29783866 Pulled By: ezyang fbshipit-source-id: dc95baf593096c03fb5f292ee6c36de3cc7f2b35
1 parent 77ca0d9 commit fbce0fe

File tree

3 files changed

+47
-29
lines changed

3 files changed

+47
-29
lines changed

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,31 @@ TORCH_META_FUNC2(sum, dim_IntList)
192192
namedinference::propagate_names_for_reduction(result, self, dims, keepdim);
193193
}
194194

195+
TORCH_META_FUNC2(mean, dim)
196+
(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
197+
auto self_dtype = self.scalar_type();
198+
TORCH_CHECK(
199+
at::isFloatingType(self_dtype) || at::isComplexType(self_dtype),
200+
"Can only calculate the mean of floating types. Got ",
201+
toString(self_dtype), " instead.");
202+
203+
ScalarType dtype;
204+
const auto& result = maybe_get_output();
205+
206+
if (result.defined()) {
207+
dtype = opt_dtype.value_or(result.scalar_type());
208+
} else {
209+
dtype = at::native::get_dtype_from_self(self, opt_dtype, true);
210+
}
211+
212+
DimVector dims(dim);
213+
maybe_wrap_dims(dims, self.dim());
214+
215+
DimVector shape = get_reduction_shape(self, dims, keepdim);
216+
set_output(shape, self.options().dtype(dtype));
217+
namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
218+
}
219+
195220
} // namespace meta
196221

197222
namespace meta {
@@ -1056,15 +1081,13 @@ Tensor& prod_out(const Tensor& self, Dimname dim,
10561081
return at::prod_out(result, self, dimname_to_position(self, dim), keepdim, opt_dtype);
10571082
}
10581083

1059-
Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
1060-
bool keepdim, c10::optional<ScalarType> opt_dtype, Tensor &result) {
1061-
ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
1062-
TORCH_CHECK(
1063-
at::isFloatingType(scalarType) || at::isComplexType(scalarType),
1064-
"Can only calculate the mean of floating types. Got ",
1065-
toString(scalarType),
1066-
" instead.");
1067-
ScalarType dtype = get_dtype_from_result(result, opt_dtype);
1084+
TORCH_IMPL_FUNC(mean_out)
1085+
(const Tensor& self,
1086+
IntArrayRef dim,
1087+
bool keepdim,
1088+
c10::optional<ScalarType> opt_dtype,
1089+
const Tensor& result) {
1090+
ScalarType dtype = result.scalar_type();
10681091
// TODO: the TensorIterator reduction implementation of mean
10691092
// (mean_kernel_impl()) is unvectorized and leads to very poor performance
10701093
// for production workloads. Once that's fixed, the following code can be used
@@ -1078,27 +1101,22 @@ Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
10781101
dim_prod *= self.size(d);
10791102
}
10801103
}
1081-
at::sum_out(result, self, dim, keepdim, dtype).div_(dim_prod);
1082-
return result;
1083-
}
1084-
1085-
auto iter = make_reduction("mean", result, self, dim, keepdim, dtype);
1086-
if (iter.numel() == 0) {
1087-
result.fill_(std::numeric_limits<double>::quiet_NaN());
1104+
auto& result_mut = const_cast<Tensor&>(result);
1105+
at::sum_out(result_mut, self, dim, keepdim, dtype).div_(dim_prod);
10881106
} else {
1089-
mean_stub(iter.device_type(), iter);
1107+
DimVector dims(dim);
1108+
auto iter = at::meta::make_reduction_from_out_ty(
1109+
self, result, dims, keepdim, dtype);
1110+
if (iter.numel() == 0) {
1111+
result.fill_(std::numeric_limits<double>::quiet_NaN());
1112+
} else {
1113+
mean_stub(iter.device_type(), iter);
1114+
}
10901115
}
1091-
return result;
10921116
}
10931117

10941118
Tensor mean_cpu_gpu(const Tensor &self, optional<ScalarType> dtype) {
1095-
return at::native::mean_cpu_gpu(self, IntArrayRef{}, false, dtype);
1096-
}
1097-
1098-
Tensor mean_cpu_gpu(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
1099-
ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
1100-
Tensor result = create_reduction_result(self, dim, keepdim, dtype);
1101-
return at::native::mean_out_cpu_gpu(self, dim, keepdim, dtype, result);
1119+
return at::mean(self, IntArrayRef{}, false, dtype);
11021120
}
11031121

11041122
Tensor mean(const Tensor& self, DimnameList dim, bool keepdim, optional<ScalarType> dtype) {

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2815,16 +2815,17 @@
28152815
QuantizedCPU: mean_quantized_cpu
28162816

28172817
- func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
2818+
structured_delegate: mean.out
28182819
device_check: NoCheck # TensorIterator
28192820
variants: function, method
28202821
dispatch:
2821-
CPU, CUDA: mean_cpu_gpu
28222822
QuantizedCPU: mean_quantized_cpu
28232823

28242824
- func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
2825+
structured: True
28252826
device_check: NoCheck # TensorIterator
28262827
dispatch:
2827-
CPU, CUDA: mean_out_cpu_gpu
2828+
CPU, CUDA: mean_out
28282829
QuantizedCPU: mean_out_quantized_cpu
28292830

28302831
- func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor

aten/src/ATen/native/quantized/cpu/qreduction.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,7 @@ Tensor& mean_out_quantized_cpu(
9797
}
9898
#endif
9999
auto self_dequantized = self.dequantize();
100-
auto result_dequantized =
101-
at::native::mean_cpu_gpu(self_dequantized, dim, keepdim, opt_dtype);
100+
auto result_dequantized = at::mean(self_dequantized, dim, keepdim, opt_dtype);
102101
result = at::quantize_per_tensor(
103102
result_dequantized,
104103
self.q_scale(),

0 commit comments

Comments
 (0)
0