pytorch/pytorch commit 5dae62e
std/var: Require correction parameter always be given explicitly
With this change, leaving `correction` at its default of `None`, or calling any overload that takes the `unbiased` parameter, results in an error.

BC-breaking message: `torch.std`, `torch.std_mean`, `torch.var` and `torch.var_mean` now require the `correction` argument to be passed in all function calls. This means it cannot be left defaulted, and the `unbiased` overloads can no longer be used. To recover the old default, use `correction=1`, which is equivalent to `unbiased=True`; or use `correction=0` for the same behavior as `unbiased=False`.

ghstack-source-id: 3f68338
Pull Request resolved: #55680
1 parent 5df8317 commit 5dae62e
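For downstream code the migration is mechanical. A minimal before/after sketch (illustrative shapes; assumes a build that exposes the `correction` keyword, as this commit requires):

    import torch

    x = torch.randn(4, 5)

    # Old spellings that now raise an error under this change:
    #   torch.var(x)                  # relied on the implicit default
    #   torch.var(x, unbiased=True)   # removed 'unbiased' overload

    # New spellings with an explicit correction:
    var_bessel = torch.var(x, correction=1)  # same as the old unbiased=True
    var_biased = torch.var(x, correction=0)  # same as the old unbiased=False
    std, mean = torch.std_mean(x, correction=1)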

File tree

10 files changed: +116 lines, -170 lines

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 12 additions & 17 deletions
@@ -1707,13 +1707,10 @@ static Tensor& std_var_out(
   TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
               "std and var only support floating point and complex dtypes");

-  if (!correction_opt.has_value()) {
-    TORCH_WARN_ONCE(
-        fname, ": the default for the correction parameter is deprecated. ",
-        "Call with correction=1 to maintain the current default behavior.")
-  }
-
-  const auto correction = correction_opt.value_or(1);
+  TORCH_CHECK(correction_opt.has_value(),
+              fname, ": the correction parameter must be given explicitly. ",
+              "Call with correction=1 for the old default behavior.")
+  const auto correction = *correction_opt;
   if (at::isComplexType(self.scalar_type())) {
     // For complex, calculate variance of real and imaginary components
     // separately then add to get overall variance.
@@ -1787,13 +1784,11 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
   TORCH_CHECK(result1.scalar_type() == c10::toRealValueType(result2.scalar_type()),
               fname, " expected result1 to be real and match the precision of result2. Got ",
               result1.scalar_type(), " and ", result2.scalar_type(), ".");
-  if (!correction_opt.has_value()) {
-    TORCH_WARN_ONCE(
-        fname, ": the default for the correction parameter is deprecated. ",
-        "Call with correction=1 to maintain the current default behavior.")
-  }

-  const auto correction = correction_opt.value_or(1);
+  TORCH_CHECK(correction_opt.has_value(),
+              fname, ": the correction parameter must be given explicitly. ",
+              "Call with correction=1 for the old default behavior.")
+  const auto correction = *correction_opt;
   if (at::isComplexType(self.scalar_type())) {
     // For complex, calculate for real and imaginary components separately then combine as:
     // variance = var_real + var_imag
@@ -1851,13 +1846,13 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
 static inline c10::optional<int64_t> correction_from_unbiased(
     c10::string_view fname, bool unbiased) {
   if (unbiased) {
-    TORCH_WARN_ONCE(
-        fname, ": The 'unbiased' parameter and it's default value are deprecated in favor of 'correction'. "
+    TORCH_CHECK(
+        false, fname, ": The 'unbiased' parameter and its default value have been removed. "
         "Use correction=1 for Bessel's correction, equivalent to unbiased=True.");
     return 1;
   } else {
-    TORCH_WARN_ONCE(
-        fname, ": The 'unbiased; parameter is deprecated. "
+    TORCH_CHECK(
+        false, fname, ": The 'unbiased' parameter has been removed. "
         "Use correction=0 to apply no Bessel's correction, equivalent to unbiased=False.");
     return 0;
   }
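The semantics these checks guard: with an explicit `correction` c, the variance divides the sum of squared deviations by N - c, so `correction=1` is Bessel's correction and `correction=0` is the population variance. A quick self-contained check (illustrative; assumes a build exposing the `correction` keyword):

    import torch

    x = torch.randn(100)
    n = x.numel()

    for c in (0, 1, 2):
        # variance with explicit correction c divides by (n - c)
        expected = ((x - x.mean()) ** 2).sum() / (n - c)
        torch.testing.assert_close(torch.var(x, correction=c), expected)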

test/onnx/test_operators.py

Lines changed: 1 addition & 1 deletion
@@ -864,7 +864,7 @@ def test_master_opset(self):
     def test_std(self):
         x = torch.randn(2, 3, 4).float()
         self.assertONNX(
-            lambda x: torch.std(x, dim=(0, 1), unbiased=True, keepdim=True), x
+            lambda x: torch.std(x, dim=(0, 1), correction=1, keepdim=True), x
         )

     def test_cumsum(self):
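The updated ONNX tests exercise export of `torch.std` with an explicit correction. A self-contained export sketch along the same lines (illustrative only; `assertONNX` in the suite wraps a roughly comparable `torch.onnx.export` call):

    import io
    import torch

    class Std(torch.nn.Module):
        def forward(self, x):
            # correction=1 keeps the semantics the test previously
            # requested with unbiased=True
            return torch.std(x, dim=(0, 1), correction=1, keepdim=True)

    buf = io.BytesIO()
    torch.onnx.export(Std(), torch.randn(2, 3, 4), buf)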

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 2 additions & 2 deletions
@@ -2964,7 +2964,7 @@ def forward(self, x):
     def test_std(self):
         class StandardDeviation(torch.nn.Module):
             def forward(self, input):
-                return torch.std(input, unbiased=False)
+                return torch.std(input, correction=0)

         model = StandardDeviation()
         inputs = torch.randn(2, 3, 4)
@@ -2973,7 +2973,7 @@ def forward(self, input):
     def test_std_along_dims(self):
         class StandardDeviationAlongDims(torch.nn.Module):
             def forward(self, input):
-                return torch.std(input, dim=(0, 1), unbiased=False, keepdim=False)
+                return torch.std(input, dim=(0, 1), correction=0, keepdim=False)

         model = StandardDeviationAlongDims()
         inputs = torch.randn(2, 3, 4)

test/quantization/eager/test_quantize_eager_qat.py

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ def _forward(self, input):
         else:
             conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
             batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
-            batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
+            batch_var = torch.var(conv_orig, dim=[0, 2, 3], correction=0)
             n = float(conv_orig.numel() / conv_orig.size()[1])
             unbiased_batch_var = batch_var * (n / (n - 1))
             batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)
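The rescaling in this hunk works because the biased and Bessel-corrected variances differ only by a factor of n / (n - 1). A sketch of that identity (the shape here is illustrative, not taken from the test):

    import torch

    conv_orig = torch.randn(2, 8, 4, 4)  # (N, C, H, W); illustrative shape
    batch_var = torch.var(conv_orig, dim=[0, 2, 3], correction=0)

    # n = number of elements averaged per channel
    n = float(conv_orig.numel() / conv_orig.size()[1])

    # biased variance * n / (n - 1) == Bessel-corrected variance
    torch.testing.assert_close(
        batch_var * (n / (n - 1)),
        torch.var(conv_orig, dim=[0, 2, 3], correction=1))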

test/test_mps.py

Lines changed: 32 additions & 40 deletions
@@ -2958,53 +2958,49 @@ def helper(shape):
         cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
         x = cpu_x.detach().clone().to('mps')

-        for correction_kwarg in [
-            dict(unbiased=False),
-            dict(unbiased=True),
-            dict(correction=2),
-        ]:
-            all_std = torch.std(x, **correction_kwarg)
-            all_std_cpu = torch.std(cpu_x, **correction_kwarg)
+        for correction in [0, 1, 2]:
+            all_std = torch.std(x, correction=correction)
+            all_std_cpu = torch.std(cpu_x, correction=correction)

             self.assertEqual(all_std, all_std_cpu)

-            nil_dim_std = torch.std(x, dim=[], **correction_kwarg)
-            nil_dim_std_cpu = torch.std(cpu_x, dim=[], **correction_kwarg)
+            nil_dim_std = torch.std(x, dim=[], correction=correction)
+            nil_dim_std_cpu = torch.std(cpu_x, dim=[], correction=correction)

             self.assertEqual(nil_dim_std, nil_dim_std_cpu)

-            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, **correction_kwarg)
-            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, **correction_kwarg)
+            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, correction=correction)
+            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, correction=correction)

             self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)

-            zero_dim_std = torch.std(x, dim=[0], **correction_kwarg)
-            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], **correction_kwarg)
+            zero_dim_std = torch.std(x, dim=[0], correction=correction)
+            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], correction=correction)

             self.assertEqual(zero_dim_std, zero_dim_std_cpu)

-            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, **correction_kwarg)
-            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, **correction_kwarg)
+            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, correction=correction)
+            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, correction=correction)

             self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)

-            zero_one_dim_std = torch.std(x, dim=[0, 1], **correction_kwarg)
-            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], **correction_kwarg)
+            zero_one_dim_std = torch.std(x, dim=[0, 1], correction=correction)
+            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], correction=correction)

             self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)

-            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, **correction_kwarg)
-            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, **correction_kwarg)
+            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, correction=correction)
+            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, correction=correction)

             self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)

-            two_three_dim_std = torch.std(x, dim=[2, 3], **correction_kwarg)
-            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], **correction_kwarg)
+            two_three_dim_std = torch.std(x, dim=[2, 3], correction=correction)
+            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], correction=correction)

             self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)

-            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, **correction_kwarg)
-            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, **correction_kwarg)
+            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, correction=correction)
+            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, correction=correction)

             self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
@@ -3021,40 +3017,36 @@ def helper():
         cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
         x = cpu_x.detach().clone().to('mps')

-        for correction_kwarg in [
-            dict(unbiased=False),
-            dict(unbiased=True),
-            dict(correction=2),
-        ]:
+        for correction in [0, 1, 2]:
             for keepdim in [False, True]:

-                zero_dim_var = x.var(-1, keepdim=keepdim, **correction_kwarg)
-                zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, **correction_kwarg)
+                zero_dim_var = x.var(-1, keepdim=keepdim, correction=correction)
+                zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, correction=correction)

                 self.assertEqual(zero_dim_var, zero_dim_var_cpu)

-                all_var = torch.var(x, **correction_kwarg)
-                all_var_cpu = torch.var(cpu_x, **correction_kwarg)
+                all_var = torch.var(x, correction=correction)
+                all_var_cpu = torch.var(cpu_x, correction=correction)

                 self.assertEqual(all_var, all_var_cpu)

-                nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, **correction_kwarg)
-                nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, **correction_kwarg)
+                nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, correction=correction)
+                nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, correction=correction)

                 self.assertEqual(nil_dim_var, nil_dim_var_cpu)

-                zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, **correction_kwarg)
-                zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, **correction_kwarg)
+                zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, correction=correction)
+                zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, correction=correction)

                 self.assertEqual(zero_dim_var, zero_dim_var_cpu)

-                zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, **correction_kwarg)
-                zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, **correction_kwarg)
+                zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, correction=correction)
+                zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, correction=correction)

                 self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)

-                two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, **correction_kwarg)
-                two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, **correction_kwarg)
+                two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, correction=correction)
+                two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, correction=correction)

                 self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
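The pattern these MPS tests rely on, reduced to its core (a sketch, guarded so it only runs where an MPS device exists):

    import torch

    if torch.backends.mps.is_available():
        cpu_x = torch.randn(2, 3, 4, 5)
        x = cpu_x.detach().clone().to('mps')
        for correction in (0, 1, 2):
            # same explicit correction on both devices, compare on CPU
            torch.testing.assert_close(
                torch.std(x, dim=[2, 3], correction=correction).cpu(),
                torch.std(cpu_x, dim=[2, 3], correction=correction))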

0 commit comments