std/var: Deprecate default `correction` value and `unbiased` argument by peterbell10 · Pull Request #55679 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

std/var: Deprecate default correction value and unbiased argument #55679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 57 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
58ac5d9
std/var: Deprecate default correction value
peterbell10 Apr 9, 2021
30d6401
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 13, 2021
558993a
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 15, 2021
64c6311
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 15, 2021
616fdf6
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 19, 2021
6917922
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 19, 2021
0b84a3b
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 19, 2021
e634721
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 20, 2021
8ad2a46
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 20, 2021
46f3956
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 21, 2021
37460b1
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 21, 2021
5b23cc9
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 22, 2021
27c30ed
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 23, 2021
b4427c7
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 27, 2021
de12904
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 29, 2021
5845b5c
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 29, 2021
c010d60
Update on "std/var: Deprecate default correction value"
peterbell10 May 4, 2021
b905ec4
Update on "std/var: Deprecate default correction value"
peterbell10 Jun 14, 2021
7eaa98f
Update on "std/var: Deprecate default correction value"
peterbell10 Jun 24, 2021
b1bcd66
Update on "std/var: Deprecate default correction value"
peterbell10 Jul 6, 2021
f931557
Update on "std/var: Deprecate default correction value"
peterbell10 Jul 7, 2021
58dbff1
Update on "std/var: Deprecate default correction value"
peterbell10 Jul 9, 2021
28b34e3
Update on "std/var: Deprecate default correction value"
peterbell10 Jul 9, 2021
87a60a6
Update on "std/var: Deprecate default correction value"
peterbell10 Aug 9, 2021
4fb2746
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 6, 2022
d4714c5
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 7, 2022
15a6441
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 12, 2022
f16eaeb
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 15, 2022
86d0e28
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 15, 2022
bcba142
Update on "std/var: Deprecate default correction value"
peterbell10 Apr 15, 2022
a9ed413
Update on "std/var: Deprecate default correction value"
peterbell10 Jul 30, 2022
888dcb3
Update on "std/var: Deprecate default correction value"
peterbell10 Jul 30, 2022
2ffe094
Rebase on "std/var: Deprecate default correction value"
peterbell10 Sep 28, 2022
28d522e
Update on "std/var: Deprecate default correction value"
peterbell10 Sep 28, 2022
2a43515
Update on "std/var: Deprecate default correction value"
peterbell10 Sep 29, 2022
b9cab95
Update on "std/var: Deprecate default correction value"
peterbell10 Sep 29, 2022
0793756
Add test for correction=None on "std/var: Deprecate default correctio…
peterbell10 Oct 3, 2022
7e69023
Update on "std/var: Deprecate default correction value"
peterbell10 Oct 4, 2022
404c53e
Update on "std/var: Deprecate default correction value"
peterbell10 Oct 13, 2022
61fafce
Update on "std/var: Deprecate default correction value"
peterbell10 Oct 15, 2022
cac16c6
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Oct 15, 2022
87266e0
Rebase and fix merge conflicts on "std/var: Deprecate default `correc…
peterbell10 Oct 16, 2022
a02bf2d
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Oct 16, 2022
e93ff6d
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Oct 17, 2022
11fc629
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Oct 17, 2022
638db4f
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Oct 19, 2022
211dc6d
Fix lint on "std/var: Deprecate default `correction` value and `unbia…
peterbell10 Oct 19, 2022
0b14c2c
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Oct 20, 2022
ad41bbb
Rebase and fix merge conflicts on "std/var: Deprecate default `correc…
peterbell10 Oct 21, 2022
ecbe7f5
Split std/var opinfos into two on "std/var: Deprecate default `correc…
peterbell10 Nov 1, 2022
9338b8c
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Nov 1, 2022
00473bd
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Nov 2, 2022
cfe3ffc
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Nov 2, 2022
dee08f2
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Nov 3, 2022
37d5873
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Nov 7, 2022
be40d20
Rebase on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Dec 4, 2022
8b2bc01
Update on "std/var: Deprecate default `correction` value and `unbiase…
peterbell10 Dec 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 54 additions & 38 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1707,6 +1707,13 @@ static Tensor& std_var_out(
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
"std and var only support floating point and complex dtypes");

if (!correction_opt.has_value()) {
TORCH_WARN_ONCE(
fname, ": the default for the correction parameter is deprecated. ",
"Call with correction=1 to maintain the current default behavior.")
}

const auto correction = correction_opt.value_or(1);
if (at::isComplexType(self.scalar_type())) {
// For complex, calculate variance of real and imaginary components
// separately then add to get overall variance.
Expand All @@ -1718,7 +1725,7 @@ static Tensor& std_var_out(
real_out,
real_in,
dim,
correction_opt,
correction,
keepdim,
/*take_sqrt=*/false);

Expand All @@ -1729,7 +1736,7 @@ static Tensor& std_var_out(
imag_out,
imag_in,
dim,
correction_opt,
correction,
keepdim,
/*take_sqrt=*/false);

Expand All @@ -1741,7 +1748,6 @@ static Tensor& std_var_out(
}

// Computation for floating point
const auto correction = correction_opt.value_or(1);
ScalarType dtype = get_dtype_from_result(result, {});
auto iter = make_reduction(fname, result, self, dim, keepdim, dtype);
TORCH_CHECK(at::canCast(self.scalar_type(), result.scalar_type()),
Expand Down Expand Up @@ -1781,7 +1787,13 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
TORCH_CHECK(result1.scalar_type() == c10::toRealValueType(result2.scalar_type()),
fname, " expected result1 to be real and match the precision of result2. Got ",
result1.scalar_type(), " and ", result2.scalar_type(), ".");
if (!correction_opt.has_value()) {
TORCH_WARN_ONCE(
fname, ": the default for the correction parameter is deprecated. ",
"Call with correction=1 to maintain the current default behavior.")
}

const auto correction = correction_opt.value_or(1);
if (at::isComplexType(self.scalar_type())) {
// For complex, calculate for real and imaginary components separately then combine as:
// variance = var_real + var_imag
Expand All @@ -1796,7 +1808,7 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
real_out_mean,
real_in,
dim,
correction_opt,
correction,
keepdim,
/*take_sqrt=*/false);

Expand All @@ -1809,7 +1821,7 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
imag_out_mean,
imag_in,
dim,
correction_opt,
correction,
keepdim,
/*take_sqrt=*/false);

Expand All @@ -1822,7 +1834,6 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
}

// Computation for floating point
const auto correction = correction_opt.value_or(1);
ScalarType dtype = get_dtype_from_result(result1, {});
auto iter =
make_reduction(fname, result1, result2, self, dim, keepdim, dtype);
Expand All @@ -1837,32 +1848,41 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
return std::tuple<Tensor&, Tensor&>(result1, result2);
}

static inline c10::optional<int64_t> correction_from_unbiased(
c10::string_view fname, bool unbiased) {
if (unbiased) {
TORCH_WARN_ONCE(
fname, ": The 'unbiased' parameter and its default value are deprecated in favor of 'correction'. "
"Use correction=1 for Bessel's correction, equivalent to unbiased=True.");
return 1;
} else {
TORCH_WARN_ONCE(
fname, ": The 'unbiased' parameter is deprecated. "
"Use correction=0 to apply no Bessel's correction, equivalent to unbiased=False.");
return 0;
}
}

std::tuple<Tensor, Tensor> var_mean(
const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
return at::var_mean(
self, /*dim=*/at::OptionalIntArrayRef(dim),
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
keepdim);
const auto correction = correction_from_unbiased("var_mean", unbiased);
return at::var_mean(self, dim, correction, keepdim);
}

std::tuple<Tensor, Tensor> std_mean(
const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
return at::std_mean(
self, /*dim=*/at::OptionalIntArrayRef(dim),
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
keepdim);
const auto correction = correction_from_unbiased("std_mean", unbiased);
return at::std_mean(self, dim, correction, keepdim);
}

std::tuple<Tensor, Tensor> std_mean(const Tensor& self, bool unbiased) {
return at::std_mean(
self, /*dim=*/c10::nullopt,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
const auto correction = correction_from_unbiased("std_mean", unbiased);
return at::std_mean(self, /*dim=*/c10::nullopt, correction);
}

std::tuple<Tensor, Tensor> var_mean(const Tensor& self, bool unbiased) {
return at::var_mean(
self, /*dim=*/c10::nullopt,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
const auto correction = correction_from_unbiased("var_mean", unbiased);
return at::var_mean(self, /*dim=*/c10::nullopt, correction);
}

std::tuple<Tensor&, Tensor&> var_mean_out(
Expand Down Expand Up @@ -1896,38 +1916,34 @@ std::tuple<Tensor, Tensor> std_mean(
}

Tensor var(const Tensor& self, bool unbiased) {
return at::var(
self, /*dim=*/c10::nullopt,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
const auto correction = correction_from_unbiased("var", unbiased);
return at::var(self, /*dim=*/c10::nullopt, correction);
}

Tensor var(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
return at::var(
self, /*dim=*/at::OptionalIntArrayRef(dim),
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
keepdim);
Tensor var(const Tensor& self, at::OptionalIntArrayRef dim,
bool unbiased, bool keepdim) {
const auto correction = correction_from_unbiased("var", unbiased);
return at::var(self, dim, correction, keepdim);
}

Tensor& var_out(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, Tensor& result) {
return at::var_out(
result, self, /*dim=*/at::OptionalIntArrayRef(dim),
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
keepdim);
const auto correction = correction_from_unbiased("var", unbiased);
return at::var_out(result, self, dim, correction, keepdim);
}

Tensor std(const Tensor& self, bool unbiased) {
return at::std(
self, /*dim=*/c10::nullopt, /*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
const auto correction = correction_from_unbiased("std", unbiased);
return at::std(self, /*dim=*/c10::nullopt, correction);
}

Tensor std(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
return at::std(self, dim,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}), keepdim);
const auto correction = correction_from_unbiased("std", unbiased);
return at::std(self, dim, correction, keepdim);
}

Tensor& std_out(const Tensor& self, at::OptionalIntArrayRef opt_dim, bool unbiased, bool keepdim, Tensor& result) {
return at::std_out(result, self, opt_dim,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}), keepdim);
const auto correction = correction_from_unbiased("std", unbiased);
return at::std_out(result, self, opt_dim, correction, keepdim);
}

Tensor std(const Tensor& self, at::OptionalIntArrayRef dim,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/fastrnns/custom_lstms.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(self, normalized_shape):
@jit.script_method
def compute_layernorm_stats(self, input):
mu = input.mean(-1, keepdim=True)
sigma = input.std(-1, keepdim=True, unbiased=False)
sigma = input.std(-1, keepdim=True, correction=0)
return mu, sigma

@jit.script_method
Expand Down
10 changes: 5 additions & 5 deletions test/distributions/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,7 @@ def test_binomial_vectorized_count(self):
samples = bin1.sample(torch.Size((100000,)))
self.assertTrue((samples <= total_count.type_as(samples)).all())
self.assertEqual(samples.mean(dim=0), bin1.mean, atol=0.02, rtol=0)
self.assertEqual(samples.var(dim=0), bin1.variance, atol=0.02, rtol=0)
self.assertEqual(samples.var(correction=1, dim=0), bin1.variance, atol=0.02, rtol=0)

def test_negative_binomial(self):
p = torch.arange(0.05, 1, 0.1).requires_grad_()
Expand Down Expand Up @@ -2074,7 +2074,7 @@ def test_lowrank_multivariate_normal_moments(self):
samples = d.rsample((100000,))
empirical_mean = samples.mean(0)
self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
empirical_var = samples.var(0)
empirical_var = samples.var(correction=1, dim=0)
self.assertEqual(d.variance, empirical_var, atol=0.02, rtol=0)

def test_multivariate_normal_shape(self):
Expand Down Expand Up @@ -2216,7 +2216,7 @@ def test_multivariate_normal_moments(self):
samples = d.rsample((100000,))
empirical_mean = samples.mean(0)
self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
empirical_var = samples.var(0)
empirical_var = samples.var(correction=1, dim=0)
self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)

# We applied same tests in Multivariate Normal distribution for Wishart distribution
Expand Down Expand Up @@ -2380,7 +2380,7 @@ def test_wishart_moments(self):
samples = d.rsample((ndim * ndim * 100000,))
empirical_mean = samples.mean(0)
self.assertEqual(d.mean, empirical_mean, atol=0.5, rtol=0)
empirical_var = samples.var(0)
empirical_var = samples.var(correction=1, dim=0)
self.assertEqual(d.variance, empirical_var, atol=0.5, rtol=0)

def test_exponential(self):
Expand Down Expand Up @@ -2621,7 +2621,7 @@ def test_kumaraswamy_mean_variance(self):
max_error = max(error[error == error])
self.assertLess(max_error, 0.01,
"Kumaraswamy example {}/{}, incorrect .mean".format(i + 1, len(cases)))
expected = samples.var(0)
expected = samples.var(correction=1, dim=0)
actual = m.variance
error = (expected - actual).abs()
max_error = max(error[error == error])
Expand Down
Loading
0