@@ -192,6 +192,31 @@ TORCH_META_FUNC2(sum, dim_IntList)
  namedinference::propagate_names_for_reduction(result, self, dims, keepdim);
}

+ TORCH_META_FUNC2(mean, dim)
+ (const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+   auto self_dtype = self.scalar_type();
+   TORCH_CHECK(
+       at::isFloatingType(self_dtype) || at::isComplexType(self_dtype),
+       "Can only calculate the mean of floating types. Got ",
+       toString(self_dtype), " instead.");
+
+   ScalarType dtype;
+   const auto& result = maybe_get_output();
+
+   if (result.defined()) {
+     dtype = opt_dtype.value_or(result.scalar_type());
+   } else {
+     dtype = at::native::get_dtype_from_self(self, opt_dtype, true);
+   }
+
+   DimVector dims(dim);
+   maybe_wrap_dims(dims, self.dim());
+
+   DimVector shape = get_reduction_shape(self, dims, keepdim);
+   set_output(shape, self.options().dtype(dtype));
+   namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
+ }
+
} // namespace meta

namespace meta {
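
Note: the meta function added above only performs dtype checking, shape inference, and output allocation via set_output(); the kernel body lives in the impl function further down. As a rough illustration of what maybe_wrap_dims() plus get_reduction_shape() compute for it, here is a standalone sketch (illustrative only, not the actual ATen helpers):

// Illustrative sketch: reproduces the reduction-shape rule used by the meta
// function above. The real helpers are ATen-internal; this version exists
// purely to show the shape logic.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> reduction_shape(
    const std::vector<int64_t>& sizes,
    std::vector<int64_t> dims,
    bool keepdim) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  // Wrap negative dims, as maybe_wrap_dims() does (e.g. -1 -> ndim - 1).
  for (auto& d : dims) {
    if (d < 0) d += ndim;
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < ndim; ++i) {
    const bool reduced =
        std::find(dims.begin(), dims.end(), i) != dims.end();
    if (!reduced) {
      out.push_back(sizes[i]);  // untouched dimension keeps its size
    } else if (keepdim) {
      out.push_back(1);         // reduced dimension collapses to 1
    }                           // keepdim=false drops it entirely
  }
  return out;
}

// reduction_shape({2, 3, 4}, {1}, false) -> {2, 4}
// reduction_shape({2, 3, 4}, {-2}, true) -> {2, 1, 4}
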
@@ -1056,15 +1081,13 @@ Tensor& prod_out(const Tensor& self, Dimname dim,
  return at::prod_out(result, self, dimname_to_position(self, dim), keepdim, opt_dtype);
}

- Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
-                          bool keepdim, c10::optional<ScalarType> opt_dtype, Tensor &result) {
-   ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
-   TORCH_CHECK(
-       at::isFloatingType(scalarType) || at::isComplexType(scalarType),
-       "Can only calculate the mean of floating types. Got ",
-       toString(scalarType),
-       " instead.");
-   ScalarType dtype = get_dtype_from_result(result, opt_dtype);
+ TORCH_IMPL_FUNC(mean_out)
+ (const Tensor& self,
+  IntArrayRef dim,
+  bool keepdim,
+  c10::optional<ScalarType> opt_dtype,
+  const Tensor& result) {
+   ScalarType dtype = result.scalar_type();
  // TODO: the TensorIterator reduction implementation of mean
  // (mean_kernel_impl()) is unvectorized and leads to very poor performance
  // for production workloads. Once that's fixed, the following code can be used
@@ -1078,27 +1101,22 @@ Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
        dim_prod *= self.size(d);
      }
    }
-     at::sum_out(result, self, dim, keepdim, dtype).div_(dim_prod);
-     return result;
-   }
-
-   auto iter = make_reduction("mean", result, self, dim, keepdim, dtype);
-   if (iter.numel() == 0) {
-     result.fill_(std::numeric_limits<double>::quiet_NaN());
+     auto& result_mut = const_cast<Tensor&>(result);
+     at::sum_out(result_mut, self, dim, keepdim, dtype).div_(dim_prod);
  } else {
-     mean_stub(iter.device_type(), iter);
+     DimVector dims(dim);
+     auto iter = at::meta::make_reduction_from_out_ty(
+         self, result, dims, keepdim, dtype);
+     if (iter.numel() == 0) {
+       result.fill_(std::numeric_limits<double>::quiet_NaN());
+     } else {
+       mean_stub(iter.device_type(), iter);
+     }
  }
-   return result;
}

Tensor mean_cpu_gpu(const Tensor &self, optional<ScalarType> dtype) {
-   return at::native::mean_cpu_gpu(self, IntArrayRef{}, false, dtype);
- }
-
- Tensor mean_cpu_gpu(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
-   ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
-   Tensor result = create_reduction_result(self, dim, keepdim, dtype);
-   return at::native::mean_out_cpu_gpu(self, dim, keepdim, dtype, result);
+   return at::mean(self, IntArrayRef{}, false, dtype);
}

Tensor mean(const Tensor& self, DimnameList dim, bool keepdim, optional<ScalarType> dtype) {
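
With the port to a structured kernel, the functional, out=, and dim-less overloads of mean all funnel through the same meta/impl pair shown in this diff. A minimal C++ usage sketch (a hypothetical standalone program; assumes an ATen build with <ATen/ATen.h> available):

#include <ATen/ATen.h>

int main() {
  at::Tensor t = at::rand({2, 3, 4});

  // Functional form: reduces over dim 1 with keepdim=true, so the result has shape {2, 1, 4}.
  at::Tensor m = at::mean(t, /*dim=*/{1}, /*keepdim=*/true);

  // out= form: the meta function validates/derives the dtype and resizes `out`,
  // and TORCH_IMPL_FUNC(mean_out) then fills it.
  at::Tensor out = at::empty({0}, t.options());
  at::mean_out(out, t, /*dim=*/{1}, /*keepdim=*/true);

  // Full reduction (no dims) now also routes through at::mean, as in the last hunk.
  at::Tensor scalar_mean = at::mean(t);
  return 0;
}
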