Port `mean` kernel to structured kernels. · pytorch/pytorch@c484167 · GitHub

Commit c484167

ysiraichi committed
Port mean kernel to structured kernels.
Tracking issue: #55070

ghstack-source-id: b45602c
Pull Request resolved: #61643
1 parent 2369c49 commit c484167
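In outline, the port follows the usual structured-kernels recipe seen in the diffs below: a new TORCH_META_FUNC2(mean, dim) computes the reduction shape, resolves the output dtype, and validates it; the old mean_out_cpu_gpu becomes a TORCH_IMPL_FUNC(mean_out) that writes into the preallocated output; and native_functions.yaml marks mean.out as structured, with mean.dim delegating to it.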

File tree

3 files changed: +38 -30 lines


aten/src/ATen/native/ReduceOps.cpp

Lines changed: 34 additions & 26 deletions
@@ -141,6 +141,21 @@ TORCH_META_FUNC2(sum, dim_IntList)
   set_output(shape, self.options().dtype(dtype));
 }
 
+TORCH_META_FUNC2(mean, dim)
+(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  DimVector dims(dim);
+  maybe_wrap_dims(dims, self.dim());
+
+  ScalarType dtype = at::native::get_dtype_from_self(self, opt_dtype, true);
+  DimVector shape = get_reduction_shape(self, dims, keepdim);
+  set_output(shape, self.options().dtype(dtype));
+
+  TORCH_CHECK(
+      at::isFloatingType(dtype) || at::isComplexType(dtype),
+      "Can only calculate the mean of floating types. Got ",
+      toString(dtype), " instead.");
+}
+
 } // namespace meta
 
 namespace native {
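For context on the new meta function: it runs for both the functional and out= variants before any kernel is dispatched, so the dtype validation now happens up front. A minimal sketch of what the TORCH_CHECK means for callers, assuming a standalone program linked against libtorch (not part of this commit):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor t = at::ones({2, 3}, at::kInt);
  // Integral inputs promote to Long via get_dtype_from_self, which then
  // fails the floating/complex TORCH_CHECK, so this throws before any
  // kernel runs.
  try {
    at::mean(t, /*dim=*/{1}, /*keepdim=*/false);
  } catch (const c10::Error& e) {
    std::cout << "mean(Int) throws: " << e.what() << "\n";
  }
  // An explicit floating dtype satisfies the check.
  at::Tensor m = at::mean(t, /*dim=*/{1}, /*keepdim=*/false, at::kFloat);
  std::cout << m << "\n";
}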
@@ -1030,15 +1045,13 @@ Tensor& prod_out(const Tensor& self, Dimname dim,
   return at::prod_out(result, self, dimname_to_position(self, dim), keepdim, opt_dtype);
 }
 
-Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
-                 bool keepdim, c10::optional<ScalarType> opt_dtype, Tensor &result) {
-  ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
-  TORCH_CHECK(
-      at::isFloatingType(scalarType) || at::isComplexType(scalarType),
-      "Can only calculate the mean of floating types. Got ",
-      toString(scalarType),
-      " instead.");
-  ScalarType dtype = get_dtype_from_result(result, opt_dtype);
+TORCH_IMPL_FUNC(mean_out)
+(const Tensor& self,
+ IntArrayRef dim,
+ bool keepdim,
+ c10::optional<ScalarType> opt_dtype,
+ const Tensor& result) {
+  ScalarType dtype = result.scalar_type();
   // TODO: the TensorIterator reduction implementation of mean
   // (mean_kernel_impl()) is unvectorized and leads to very poor performance
   // for production workloads. Once that's fixed, the following code can be used
@@ -1052,27 +1065,22 @@ Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
         dim_prod *= self.size(d);
       }
     }
-    at::sum_out(result, self, dim, keepdim, dtype).div_(dim_prod);
-    return result;
-  }
-
-  auto iter = make_reduction("mean", result, self, dim, keepdim, dtype);
-  if (iter.numel() == 0) {
-    result.fill_(std::numeric_limits<double>::quiet_NaN());
+    auto& result_mut = const_cast<Tensor&>(result);
+    at::sum_out(result_mut, self, dim, keepdim, dtype).div_(dim_prod);
   } else {
-    mean_stub(iter.device_type(), iter);
+    DimVector dims(dim);
+    auto iter = at::meta::make_reduction_from_out_ty(
+        self, result, dims, keepdim, dtype);
+    if (iter.numel() == 0) {
+      result.fill_(std::numeric_limits<double>::quiet_NaN());
+    } else {
+      mean_stub(iter.device_type(), iter);
+    }
   }
-  return result;
 }
 
 Tensor mean_cpu_gpu(const Tensor &self, optional<ScalarType> dtype) {
-  return at::native::mean_cpu_gpu(self, IntArrayRef{}, false, dtype);
-}
-
-Tensor mean_cpu_gpu(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
-  ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
-  Tensor result = create_reduction_result(self, dim, keepdim, dtype);
-  return at::native::mean_out_cpu_gpu(self, dim, keepdim, dtype, result);
+  return at::mean(self, IntArrayRef{}, false, dtype);
 }
 
 Tensor mean(const Tensor& self, DimnameList dim, bool keepdim, optional<ScalarType> dtype) {
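A note on the implementation above: structured kernels receive the output as const Tensor&, already allocated by the meta function, hence the const_cast before at::sum_out on the CPU fast path. The arithmetic itself is unchanged: each output element is the sum over the reduced dimensions divided by the count of reduced elements. A minimal standalone sketch of that sum-then-divide path, with names local to this example rather than taken from the commit:

#include <ATen/ATen.h>

at::Tensor mean_via_sum(const at::Tensor& self, at::IntArrayRef dim, bool keepdim) {
  // Count of elements each output value averages over: the product of
  // the reduced dimension sizes, or numel() for a full reduction
  // (empty dim list or 0-d tensor).
  int64_t dim_prod = 1;
  if (dim.empty() || self.ndimension() == 0) {
    dim_prod = self.numel();
  } else {
    for (auto d : dim) {
      dim_prod *= self.size(d);
    }
  }
  // mean = sum / count, mirroring the vectorized fast path above.
  return at::sum(self, dim, keepdim).div(dim_prod);
}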
@@ -1878,4 +1886,4 @@ Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const
   return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad);
 }
 
-}} // namespace at::native
+}} // namespace at

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 2 deletions
@@ -2791,16 +2791,17 @@
     QuantizedCPU: mean_quantized_cpu
 
 - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  structured_delegate: mean.out
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
-    CPU, CUDA: mean_cpu_gpu
     QuantizedCPU: mean_quantized_cpu
 
 - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  structured: True
   device_check: NoCheck # TensorIterator
   dispatch:
-    CPU, CUDA: mean_out_cpu_gpu
+    CPU, CUDA: mean_out
     QuantizedCPU: mean_out_quantized_cpu
 
 - func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
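Reading the YAML changes: structured: True marks mean.out as the structured entry point, and structured_delegate: mean.out lets codegen derive the functional mean.dim from it, which is why mean.dim loses its explicit CPU, CUDA dispatch line. One observable consequence, sketched assuming a program linked against libtorch:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::rand({4, 5});
  // The functional variant now delegates to the same structured kernel
  // as the out= variant, so the two agree by construction.
  at::Tensor a = at::mean(x, /*dim=*/{0});
  at::Tensor out = at::empty({5}, x.options());
  at::mean_out(out, x, /*dim=*/{0});
  std::cout << std::boolalpha << at::allclose(a, out) << "\n";  // true
}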

aten/src/ATen/native/quantized/cpu/qreduction.cpp

Lines changed: 1 addition & 2 deletions
@@ -97,8 +97,7 @@ Tensor& mean_out_quantized_cpu(
   }
 #endif
   auto self_dequantized = self.dequantize();
-  auto result_dequantized =
-      at::native::mean_cpu_gpu(self_dequantized, dim, keepdim, opt_dtype);
+  auto result_dequantized = at::mean(self_dequantized, dim, keepdim, opt_dtype);
   result = at::quantize_per_tensor(
       result_dequantized,
       self.q_scale(),
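With mean_cpu_gpu gone, the quantized kernel now routes through the dispatcher-level at::mean while keeping its dequantize, compute, requantize round trip. A simplified sketch of that pattern (the function name is ours, and the real kernel also has a fast path behind the #ifdef above):

#include <ATen/ATen.h>

at::Tensor quantized_mean_sketch(const at::Tensor& qself,
                                 at::IntArrayRef dim,
                                 bool keepdim) {
  // Compute the mean in floating point, then requantize with the
  // input's quantization parameters.
  at::Tensor deq = qself.dequantize();
  at::Tensor m = at::mean(deq, dim, keepdim);
  return at::quantize_per_tensor(
      m, qself.q_scale(), qself.q_zero_point(), qself.scalar_type());
}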
