10000 std/var: Deprecate default `correction` value and `unbiased` argument · peterbell10/pytorch@b148bd7 · GitHub
[go: up one dir, main page]

Skip to content

Commit b148bd7

Browse files
committed
std/var: Deprecate default correction value and unbiased argument
The default correction value of 1 is incompatible with the array API, which defaults to a correction of 0. So, we should deprecate the existing default as the first step in changing the default without introducing a silent BC-break. This also deprecates the unbiased overloads entirely, for better consistency with the array API. ghstack-source-id: 955bfe6 Pull Request resolved: pytorch#55679
1 parent 3d0b2e8 commit b148bd7

19 files changed

+281
-255
lines changed

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,6 +1707,13 @@ static Tensor& std_var_out(
17071707
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
17081708
"std and var only support floating point and complex dtypes");
17091709

1710+
if (!correction_opt.has_value()) {
1711+
TORCH_WARN_ONCE(
1712+
fname, ": the default for the correction parameter is deprecated. ",
1713+
"Call with correction=1 to maintain the current default behavior.")
1714+
}
1715+
1716+
const auto correction = correction_opt.value_or(1);
17101717
if (at::isComplexType(self.scalar_type())) {
17111718
// For complex, calculate variance of real and imaginary components
17121719
// separately then add to get overall variance.
@@ -1718,7 +1725,7 @@ static Tensor& std_var_out(
17181725
real_out,
17191726
real_in,
17201727
dim,
1721-
correction_opt,
1728+
correction,
17221729
keepdim,
17231730
/*take_sqrt=*/false);
17241731

@@ -1729,7 +1736,7 @@ static Tensor& std_var_out(
17291736
imag_out,
17301737
imag_in,
17311738
dim,
1732-
correction_opt,
1739+
correction,
17331740
keepdim,
17341741
/*take_sqrt=*/false);
17351742

@@ -1741,7 +1748,6 @@ static Tensor& std_var_out(
17411748
}
17421749

17431750
// Computation for floating point
1744-
const auto correction = correction_opt.value_or(1);
17451751
ScalarType dtype = get_dtype_from_result(result, {});
17461752
auto iter = make_reduction(fname, result, self, dim, keepdim, dtype);
17471753
TORCH_CHECK(at::canCast(self.scalar_type(), result.scalar_type()),
@@ -1781,7 +1787,13 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
17811787
TORCH_CHECK(result1.scalar_type() == c10::toRealValueType(result2.scalar_type()),
17821788
fname, " expected result1 to be real and match the precision of result2. Got ",
17831789
result1.scalar_type(), " and ", result2.scalar_type(), ".");
1790+
if (!correction_opt.has_value()) {
1791+
TORCH_WARN_ONCE(
1792+
fname, ": the default for the correction parameter is deprecated. ",
1793+
"Call with correction=1 to maintain the current default behavior.")
1794+
}
17841795

1796+
const auto correction = correction_opt.value_or(1);
17851797
if (at::isComplexType(self.scalar_type())) {
17861798
// For complex, calculate for real and imaginary components separately then combine as:
17871799
// variance = var_real + var_imag
@@ -1796,7 +1808,7 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
17961808
real_out_mean,
17971809
real_in,
17981810
dim,
1799-
correction_opt,
1811+
correction,
18001812
keepdim,
18011813
/*take_sqrt=*/false);
18021814

@@ -1809,7 +1821,7 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
18091821
imag_out_mean,
18101822
imag_in,
18111823
dim,
1812-
correction_opt,
1824+
correction,
18131825
keepdim,
18141826
/*take_sqrt=*/false);
18151827

@@ -1822,7 +1834,6 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
18221834
}
18231835

18241836
// Computation for floating point
1825-
const auto correction = correction_opt.value_or(1);
18261837
ScalarType dtype = get_dtype_from_result(result1, {});
18271838
auto iter =
18281839
make_reduction(fname, result1, result2, self, dim, keepdim, dtype);
@@ -1837,32 +1848,41 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
18371848
return std::tuple<Tensor&, Tensor&>(result1, result2);
18381849
}
18391850

1851+
static inline c10::optional<int64_t> correction_from_unbiased(
1852+
c10::string_view fname, bool unbiased) {
1853+
if (unbiased) {
1854+
TORCH_WARN_ONCE(
1855+
fname, ": The 'unbiased' parameter and its default value are deprecated in favor of 'correction'. "
1856+
"Use correction=1 for Bessel's correction, equivalent to unbiased=True.");
1857+
return 1;
1858+
} else {
1859+
TORCH_WARN_ONCE(
1860+
fname, ": The 'unbiased' parameter is deprecated. "
1861+
"Use correction=0 to apply no Bessel's correction, equivalent to unbiased=False.");
1862+
return 0;
1863+
}
1864+
}
1865+
18401866
std::tuple<Tensor, Tensor> var_mean(
18411867
const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
1842-
return at::var_mean(
1843-
self, /*dim=*/at::OptionalIntArrayRef(dim),
1844-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
1845-
keepdim);
1868+
const auto correction = correction_from_unbiased("var_mean", unbiased);
1869+
return at::var_mean(self, dim, correction, keepdim);
18461870
}
18471871

18481872
std::tuple<Tensor, Tensor> std_mean(
18491873
const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
1850-
return at::std_mean(
1851-
self, /*dim=*/at::OptionalIntArrayRef(dim),
1852-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
1853-
keepdim);
1874+
const auto correction = correction_from_unbiased("std_mean", unbiased);
1875+
return at::std_mean(self, dim, correction, keepdim);
18541876
}
18551877

18561878
std::tuple<Tensor, Tensor> std_mean(const Tensor& self, bool unbiased) {
1857-
return at::std_mean(
1858-
self, /*dim=*/c10::nullopt,
1859-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
1879+
const auto correction = correction_from_unbiased("std_mean", unbiased);
1880+
return at::std_mean(self, /*dim=*/c10::nullopt, correction);
18601881
}
18611882

18621883
std::tuple<Tensor, Tensor> var_mean(const Tensor& self, bool unbiased) {
1863-
return at::var_mean(
1864-
self, /*dim=*/c10::nullopt,
1865-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
1884+
const auto correction = correction_from_unbiased("var_mean", unbiased);
1885+
return at::var_mean(self, /*dim=*/c10::nullopt, correction);
18661886
}
18671887

18681888
std::tuple<Tensor&, Tensor&> var_mean_out(
@@ -1896,38 +1916,34 @@ std::tuple<Tensor, Tensor> std_mean(
18961916
}
18971917

18981918
Tensor var(const Tensor& self, bool unbiased) {
1899-
return at::var(
1900-
self, /*dim=*/c10::nullopt,
1901-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
1919+
const auto correction = correction_from_unbiased("var", unbiased);
1920+
return at::var(self, /*dim=*/c10::nullopt, correction);
19021921
}
19031922

1904-
Tensor var(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
1905-
return at::var(
1906-
self, /*dim=*/at::OptionalIntArrayRef(dim),
1907-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
1908-
keepdim);
1923+
Tensor var(const Tensor& self, at::OptionalIntArrayRef dim,
1924+
bool unbiased, bool keepdim) {
1925+
const auto correction = correction_from_unbiased("var", unbiased);
1926+
return at::var(self, dim, correction, keepdim);
19091927
}
19101928

19111929
Tensor& var_out(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, Tensor& result) {
1912-
return at::var_out(
1913-
result, self, /*dim=*/at::OptionalIntArrayRef(dim),
1914-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
1915-
keepdim);
1930+
const auto correction = correction_from_unbiased("var", unbiased);
1931+
return at::var_out(result, self, dim, correction, keepdim);
19161932
}
19171933

19181934
Tensor std(const Tensor& self, bool unbiased) {
1919-
return at::std(
1920-
self, /*dim=*/c10::nullopt, /*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
1935+
const auto correction = correction_from_unbiased("std", unbiased);
1936+
return at::std(self, /*dim=*/c10::nullopt, correction);
19211937
}
19221938

19231939
Tensor std(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
1924-
return at::std(self, dim,
1925-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}), keepdim);
1940+
const auto correction = correction_from_unbiased("std", unbiased);
1941+
return at::std(self, dim, correction, keepdim);
19261942
}
19271943

19281944
Tensor& std_out(const Tensor& self, at::OptionalIntArrayRef opt_dim, bool unbiased, bool keepdim, Tensor& result) {
1929-
return at::std_out(result, self, opt_dim,
1930-
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}), keepdim);
1945+
const auto correction = correction_from_unbiased("std", unbiased);
1946+
return at::std_out(result, self, opt_dim, correction, keepdim);
19311947
}
19321948

19331949
Tensor std(const Tensor& self, at::OptionalIntArrayRef dim,

benchmarks/fastrnns/custom_lstms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(self, normalized_shape):
135135
@jit.script_method
136136
def compute_layernorm_stats(self, input):
137137
mu = input.mean(-1, keepdim=True)
138-
sigma = input.std(-1, keepdim=True, unbiased=False)
138+
sigma = input.std(-1, keepdim=True, correction=0)
139139
return mu, sigma
140140

141141
@jit.script_method

test/distributions/test_distributions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,7 +1208,7 @@ def test_binomial_vectorized_count(self):
12081208
samples = bin1.sample(torch.Size((100000,)))
12091209
self.assertTrue((samples <= total_count.type_as(samples)).all())
12101210
self.assertEqual(samples.mean(dim=0), bin1.mean, atol=0.02, rtol=0)
1211-
self.assertEqual(samples.var(dim=0), bin1.variance, atol=0.02, rtol=0)
1211+
self.assertEqual(samples.var(correction=1, dim=0), bin1.variance, atol=0.02, rtol=0)
12121212

12131213
def test_negative_binomial(self):
12141214
p = torch.arange(0.05, 1, 0.1).requires_grad_()
@@ -2074,7 +2074,7 @@ def test_lowrank_multivariate_normal_moments(self):
20742074
samples = d.rsample((100000,))
20752075
empirical_mean = samples.mean(0)
20762076
self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
2077-
empirical_var = samples.var(0)
2077+
empirical_var = samples.var(correction=1, dim=0)
20782078
self.assertEqual(d.variance, empirical_var, atol=0.02, rtol=0)
20792079

20802080
def test_multivariate_normal_shape(self):
@@ -2216,7 +2216,7 @@ def test_multivariate_normal_moments(self):
22162216
samples = d.rsample((100000,))
22172217
empirical_mean = samples.mean(0)
22182218
self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
2219-
empirical_var = samples.var(0)
2219+
empirical_var = samples.var(correction=1, dim=0)
22202220
self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)
22212221

22222222
# We applied same tests in Multivariate Normal distribution for Wishart distribution
@@ -2380,7 +2380,7 @@ def test_wishart_moments(self):
23802380
samples = d.rsample((ndim * ndim * 100000,))
23812381
empirical_mean = samples.mean(0)
23822382
self.assertEqual(d.mean, empirical_mean, atol=0.5, rtol=0)
2383-
empirical_var = samples.var(0)
2383+
empirical_var = samples.var(correction=1, dim=0)
23842384
self.assertEqual(d.variance, empirical_var, atol=0.5, rtol=0)
23852385

23862386
def test_exponential(self):
@@ -2621,7 +2621,7 @@ def test_kumaraswamy_mean_variance(self):
26212621
max_error = max(error[error == error])
26222622
self.assertLess(max_error, 0.01,
26232623
"Kumaraswamy example {}/{}, incorrect .mean".format(i + 1, len(cases)))
2624-
expected = samples.var(0)
2624+
expected = samples.var(correction=1, dim=0)
26252625
actual = m.variance
26262626
error = (expected - actual).abs()
26272627
max_error = max(error[error == error])

0 commit comments

Comments
 (0)
0