Improve the precision of abs() and sign() for large values (#99550) · pytorch/pytorch@09b189e · GitHub

Commit 09b189e

lezcano authored and pytorchmergebot committed
Improve the precision of abs() and sign() for large values (#99550)
@ev-br found in Quansight-Labs/numpy_pytorch_interop#117 (comment) that the precision of `abs()` for large values in the vectorised case is less than good. This PR fixes this issue. While doing that, we are able to remove a few test skips on extremal values.

Fixes #53958 #48486

Pull Request resolved: #99550
Approved by: https://github.com/ngimel, https://github.com/peterbell10
1 parent 5ee5afb commit 09b189e
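
The vectorised complex abs() previously computed sqrt(re*re + im*im), so the intermediate squares overflow (or underflow) long before the true magnitude does; hypot() avoids that by rescaling internally. A minimal scalar sketch of the failure mode, in standard C++ rather than the SIMD code touched by this commit:

    // Scalar illustration only; the actual patch changes the AVX2/AVX-512 paths below.
    #include <cmath>
    #include <complex>
    #include <cstdio>

    int main() {
      std::complex<double> z(3e200, 4e200);  // |z| = 5e200 is representable...
      double naive = std::sqrt(z.real() * z.real() + z.imag() * z.imag());  // ...but re*re overflows
      double safe  = std::hypot(z.real(), z.imag());
      std::printf("naive sqrt(a*a+b*b): %g\n", naive);  // inf
      std::printf("std::hypot(a, b):    %g\n", safe);   // 5e+200
    }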

File tree

5 files changed, +22 -53 lines changed

aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h

Lines changed: 6 additions & 6 deletions

@@ -121,7 +121,10 @@ template <> class Vectorized<c10::complex<double>> {
     return _mm256_hadd_pd(val_2, val_2);  // a*a+b*b a*a+b*b
   }
   __m256d abs_() const {
-    return _mm256_sqrt_pd(abs_2_());  // abs abs
+    auto real = _mm256_movedup_pd(values);       // real real
+    // movehdup_pd does not exist...
+    auto imag = _mm256_permute_pd(values, 0xf);  // imag imag
+    return Sleef_hypotd4_u05(real, imag);        // abs abs
   }
   Vectorized<c10::complex<double>> abs() const {
     const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
@@ -143,11 +146,8 @@ template <> class Vectorized<c10::complex<double>> {
     auto abs = abs_();
     auto zero = _mm256_setzero_pd();
     auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
-    auto abs_val = Vectorized(abs);
-
-    auto div = values / abs_val.values;  // x / abs(x)
-
-    return blendv(div, zero, mask);
+    auto div = values / abs;
+    return _mm256_blendv_pd(div, zero, mask);
   }
   __m256d real_() const {
     const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,

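In the AVX2 double path above, each __m256d holds two complex numbers interleaved as [re0, im0, re1, im1]. _mm256_movedup_pd duplicates the even (real) lanes and _mm256_permute_pd(values, 0xf) duplicates the odd (imaginary) lanes, so Sleef_hypotd4_u05 (SLEEF's 0.5-ULP vector hypot) can fill both lanes of each pair with |z|. A small standalone sketch of the shuffles, assuming an AVX2-capable host and using a per-lane std::hypot in place of SLEEF so it builds without extra dependencies:

    // Sketch only (compile with e.g. g++ -O2 -mavx shuffle_demo.cpp).
    #include <immintrin.h>
    #include <cmath>
    #include <cstdio>

    int main() {
      // Two complex<double> values, interleaved: [re0, im0, re1, im1]
      __m256d values = _mm256_setr_pd(3e200, 4e200, -5.0, 12.0);

      __m256d real = _mm256_movedup_pd(values);       // [re0, re0, re1, re1]
      __m256d imag = _mm256_permute_pd(values, 0xf);  // [im0, im0, im1, im1]

      double re[4], im[4];
      _mm256_storeu_pd(re, real);
      _mm256_storeu_pd(im, imag);
      for (int i = 0; i < 4; ++i) {
        // std::hypot stands in for Sleef_hypotd4_u05; each lane gets |z|.
        std::printf("lane %d: |z| = %g\n", i, std::hypot(re[i], im[i]));
      }
    }
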
aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h

Lines changed: 4 additions & 5 deletions

@@ -157,7 +157,9 @@ template <> class Vectorized<c10::complex<float>> {
     return _mm256_permute_ps(ret, 0xD8);
   }
   __m256 abs_() const {
-    return _mm256_sqrt_ps(abs_2_());  // abs abs
+    auto real = _mm256_moveldup_ps(values);  // real real
+    auto imag = _mm256_movehdup_ps(values);  // imag imag
+    return Sleef_hypotf8_u05(real, imag);    // abs abs
   }
   Vectorized<c10::complex<float>> abs() const {
     const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
@@ -179,10 +181,7 @@ template <> class Vectorized<c10::complex<float>> {
     auto abs = abs_();
     auto zero = _mm256_setzero_ps();
     auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
-    auto abs_val = Vectorized(abs);
-
-    auto div = values / abs_val.values;  // x / abs(x)
-
+    auto div = values / abs;
     return _mm256_blendv_ps(div, zero, mask);
   }
   __m256 real_() const {

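The second hunk in each of these files rewrites what, per the commit title, is the complex sign() path: sign(z) = z / |z|, with |z| == 0 blended back to zero so the division cannot produce NaN. The rewrite drops the temporary Vectorized wrapper and divides and blends on the raw registers. A scalar sketch of the same logic (std::complex standing in for the interleaved lanes), showing why the hypot-based |z| matters here too:

    // Scalar sketch of the vectorised logic; not the actual kernel.
    #include <cmath>
    #include <complex>
    #include <cstdio>

    static std::complex<float> sign(std::complex<float> z) {
      float a = std::hypot(z.real(), z.imag());  // |z|, overflow-safe
      if (a == 0.0f) {
        return {0.0f, 0.0f};                     // the blend-with-zero branch
      }
      return {z.real() / a, z.imag() / a};       // x / abs(x)
    }

    int main() {
      auto s = sign({3e30f, 4e30f});  // the squares would overflow float; hypot gives 5e30
      std::printf("sign(3e30+4e30j) = (%g, %g)\n", s.real(), s.imag());  // (0.6, 0.8)
      auto z = sign({0.0f, 0.0f});
      std::printf("sign(0) = (%g, %g)\n", z.real(), z.imag());           // (0, 0)
    }
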
aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h

Lines changed: 6 additions & 8 deletions

@@ -174,7 +174,10 @@ template <> class Vectorized<c10::complex<double>> {
     return hadd_pd(val_2, val_2);  // a*a+b*b a*a+b*b
   }
   __m512d abs_() const {
-    return _mm512_sqrt_pd(abs_2_());  // abs abs
+    auto real = _mm512_movedup_pd(values);        // real real
+    // movehdup_pd does not exist...
+    auto imag = _mm512_permute_pd(values, 0xff);  // imag imag
+    return Sleef_hypotd8_u05(real, imag);         // abs abs
   }
   Vectorized<c10::complex<double>> abs() const {
     const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
@@ -200,13 +203,8 @@ template <> class Vectorized<c10::complex<double>> {
     auto abs = abs_();
     auto zero = _mm512_setzero_pd();
     auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
-    auto mask_vec = _mm512_mask_set1_epi64(_mm512_castpd_si512(zero), mask,
-                                           0xFFFFFFFFFFFFFFFF);
-    auto abs_val = Vectorized(abs);
-
-    auto div = values / abs_val.values;  // x / abs(x)
-
-    return blendv(div, zero, _mm512_castsi512_pd(mask_vec));
+    auto div = values / abs;
+    return _mm512_mask_blend_pd(mask, div, zero);
   }
   __m512d real_() const {
     const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,

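One further simplification in the AVX-512 variants: _mm512_cmp_pd_mask / _mm512_cmp_ps_mask already return a k-register mask, which _mm512_mask_blend_pd / _mm512_mask_blend_ps can consume directly, so the old detour through _mm512_mask_set1_epi64 to materialise a vector-valued mask is no longer needed. A tiny sketch of that pattern, assuming an AVX-512F host (compile with -mavx512f):

    // Sketch only; `div` is a stand-in for values / abs from the kernel above.
    #include <immintrin.h>
    #include <cstdio>

    int main() {
      __m512d abs  = _mm512_setr_pd(0.0, 2.0, 0.0, 3.0, 1.0, 0.0, 4.0, 5.0);
      __m512d div  = _mm512_set1_pd(42.0);
      __m512d zero = _mm512_setzero_pd();

      // Bit i of the k-mask is set where abs[i] == 0.
      __mmask8 mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
      // Take `zero` where the mask bit is set, `div` elsewhere.
      __m512d out = _mm512_mask_blend_pd(mask, div, zero);

      double o[8];
      _mm512_storeu_pd(o, out);
      for (int i = 0; i < 8; ++i) std::printf("%g ", o[i]);  // 0 42 0 42 42 0 42 42
      std::printf("\n");
    }
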
aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h

Lines changed: 4 additions & 5 deletions

@@ -680,7 +680,9 @@ template <> class Vectorized<c10::complex<float>> {
     return ret;
   }
   __m512 abs_() const {
-    return _mm512_sqrt_ps(abs_2_());  // abs abs
+    auto real = _mm512_moveldup_ps(values);  // real real
+    auto imag = _mm512_movehdup_ps(values);  // imag imag
+    return Sleef_hypotf16_u05(real, imag);   // abs abs
   }
   Vectorized<c10::complex<float>> abs() const {
     const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
@@ -706,10 +708,7 @@ template <> class Vectorized<c10::complex<float>> {
     auto abs = abs_();
     auto zero = _mm512_setzero_ps();
     auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ);
-    auto abs_val = Vectorized(abs);
-
-    auto div = values / abs_val.values;  // x / abs(x)
-
+    auto div = values / abs;
     return _mm512_mask_blend_ps(mask, div, zero);
  }
   __m512 real_() const {

torch/testing/_internal/common_methods_invocations.py

Lines changed: 2 additions & 29 deletions

@@ -8884,10 +8884,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                         'test_inplace_gradgrad', dtypes=(torch.cdouble,)),
            DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestFwdGradients',
                         'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)),
-           DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
-                        device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
-           DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
-                        device_type='cpu', dtypes=[torch.cfloat]),
            DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestSparseUnaryUfuncs",
                         "test_inplace", dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
            # Reference: https://github.com/pytorch/pytorch/issues/49224
@@ -13185,11 +13181,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ],
            skips=(
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
-                            dtypes=(torch.int, torch.int8)),
-               # pytorch computes (0+nanj), numpy computes (-5e-18-1j) for input (-501.-1.0000e+20j)
-               DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
-                            "test_reference_numerics_large", dtypes=(torch.complex64,), device_type='cpu',
-                            active_if=not IS_MACOS and not IS_WINDOWS),),
+                            dtypes=(torch.int, torch.int8)),),
        ),
        UnaryUfuncInfo(
            'nn.functional.tanhshrink',
@@ -14078,14 +14070,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            # Reference: https://github.com/pytorch/pytorch/issues/41245
            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                         dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]),
-           # Reference: https://github.com/pytorch/pytorch/issues/53958
-           # Test fails in comparison on Nan as the `equal_nan` is True for
-           # comparing the CPU tensors.
-           DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
-                        device_type='cpu', dtypes=[torch.complex64, torch.complex128]),
-           # Reference: https://github.com/pytorch/pytorch/issues/48486
-           DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
-                        device_type='cpu', dtypes=[torch.complex64]),
            DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                         'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
        )),
@@ -18077,18 +18061,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
    #
    ElementwiseUnaryPythonRefInfo(
        "_refs.abs",
-        torch_opinfo_name="abs",
-        skips=(
-            # Reference result was farther (0.0) from the precise computation
-            # than the torch result was (nan)!
-            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
-                         dtypes=(torch.chalf,), device_type='cpu', active_if=not (IS_MACOS or IS_WINDOWS)),
-            # Reference result was farther (0.0) from the precise computation
-            # than the torch result was (nan)!
-            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
-                         dtypes=(torch.chalf,), device_type='cpu', active_if=not (IS_MACOS or IS_WINDOWS)),
-        )
-    ),
+        torch_opinfo_name="abs"),
    ElementwiseUnaryPythonRefInfo(
        "_refs.acos",
        torch_opinfo_name="acos",

0 commit comments
