Improve the precision of abs() and sign() for large values by lezcano · Pull Request #99550 · pytorch/pytorch · GitHub

Improve the precision of abs() and sign() for large values #99550

Closed
wants to merge 3 commits into from
Changes from 1 commit
Update on "Improve the precision of abs() and sign() for large values"
ev-br found in
Quansight-Labs/numpy_pytorch_interop#117 (comment)
that the precision of `abs()` for large values in the vectorised case is poor.
This PR fixes that issue. While doing so, we are also able to remove a few
skips on tests with extremal values.
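
The failure mode with values this large is intermediate overflow: forming squares such as `a*a + b*b` overflows single precision well before `|z|` itself does (e.g. `(1e20)^2 = 1e40 > FLT_MAX ≈ 3.4e38`), whereas a hypot-style magnitude rescales internally and stays finite. A minimal scalar sketch of that failure mode (an illustration of the class of problem, not code taken from this PR):

// Scalar illustration only: shows why squaring the components of a large
// complex number loses the magnitude, while a hypot-based formulation
// (as used via Sleef in the vectorised code) does not.
#include <cmath>
#include <cstdio>

int main() {
  float a = -501.0f;
  float b = -1e20f;                        // component sizes from the complex64 repro (-501 - 1e20j)

  float naive = std::sqrt(a * a + b * b);  // inf: b*b = 1e40 overflows float
  float hyp   = std::hypot(a, b);          // ~1e20: no intermediate overflow

  std::printf("naive = %g, hypot = %g\n", naive, hyp);
  return 0;
}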

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
lezcano committed Apr 20, 2023
commit 4c83086c3cf90f49866b9e1052463bd4fa362253
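
For reference, a scalar model of the per-element semantics that the new vectorised `abs_()`, `abs()` and `sgn()` below are assumed to implement (the actual code operates on AVX2/AVX-512 lanes and uses Sleef's hypot; the `model_*` names are hypothetical):

// Hypothetical scalar model, as a reading aid for the intrinsics below; not part of the diff.
#include <cmath>
#include <complex>
#include <cstdio>

// abs_(): |z| via hypot, so a*a + b*b is never formed explicitly.
static double model_abs_(std::complex<double> z) {
  return std::hypot(z.real(), z.imag());
}

// abs(): |z| kept in the real lane, imaginary lane masked to zero.
static std::complex<double> model_abs(std::complex<double> z) {
  return {model_abs_(z), 0.0};
}

// sgn(): both lanes of z divided element-wise by |z|, blended to zero when |z| == 0.
static std::complex<double> model_sgn(std::complex<double> z) {
  const double r = model_abs_(z);
  if (r == 0.0) {
    return {0.0, 0.0};
  }
  return {z.real() / r, z.imag() / r};
}

int main() {
  std::complex<double> z(-501.0, -1e20);
  std::complex<double> m = model_abs(z);
  std::complex<double> s = model_sgn(z);
  std::printf("abs = (%g, %g)  sgn = (%g, %g)\n",
              m.real(), m.imag(), s.real(), s.imag());  // sgn ~ (-5.01e-18, -1)
  return 0;
}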
16 changes: 10 additions & 6 deletions aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
@@ -120,8 +120,16 @@ template <> class Vectorized<c10::complex<double>> {
auto val_2 = _mm256_mul_pd(values, values); // a*a b*b
return _mm256_hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b
}
__m256d abs_() const {
auto real = _mm256_movedup_pd(values); // real real
// movehdup_pd does not exist...
auto imag = _mm256_permute_pd(values, 0xf); // imag imag
return Sleef_hypotd4_u05(real, imag); // abs abs
}
Vectorized<c10::complex<double>> abs() const {
return Sleef_hypotd4_u05(real_(), imag().values);
const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
return _mm256_and_pd(abs_(), real_mask); // abs 0
}
__m256d angle_() const {
//angle = atan2(b/a)
@@ -135,14 +143,10 @@ template <> class Vectorized<c10::complex<double>> {
return _mm256_and_pd(angle, real_mask); // angle 0
}
Vectorized<c10::complex<double>> sgn() const {
auto abs_zero = abs().values; // abs 0
auto abs = _mm256_movedup_pd(abs_zero); // abs abs

auto abs = abs_();
auto zero = _mm256_setzero_pd();
auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);

auto div = values / abs;

return _mm256_blendv_pd(div, zero, mask);
}
__m256d real_() const {
15 changes: 9 additions & 6 deletions aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
@@ -156,8 +156,15 @@ template <> class Vectorized<c10::complex<float>> {
auto ret = _mm256_hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b
return _mm256_permute_ps(ret, 0xD8);
}
__m256 abs_() const {
auto real = _mm256_moveldup_ps(values); // real real
auto imag = _mm256_movehdup_ps(values); // imag imag
return Sleef_hypotf8_u05(real, imag); // abs abs
}
Vectorized<c10::complex<float>> abs() const {
return Sleef_hypotf8_u05(real_(), imag().values);
const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
return _mm256_and_ps(abs_(), real_mask); // abs 0
}
__m256 angle_() const {
//angle = atan2(b/a)
@@ -171,14 +178,10 @@ template <> class Vectorized<c10::complex<float>> {
return _mm256_and_ps(angle, real_mask); // angle 0
}
Vectorized<c10::complex<float>> sgn() const {
auto abs_zero = abs().values; // abs 0
auto abs = _mm256_moveldup_ps(abs_zero); // abs abs

auto abs = abs_();
auto zero = _mm256_setzero_ps();
auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);

auto div = values / abs;

return _mm256_blendv_ps(div, zero, mask);
}
__m256 real_() const {
18 changes: 12 additions & 6 deletions aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
@@ -173,8 +173,18 @@ template <> class Vectorized<c10::complex<double>> {
auto val_2 = _mm512_mul_pd(values, values); // a*a b*b
return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b
}
__m512d abs_() const {
auto real = _mm512_movedup_pd(values); // real real
// movehdup_pd does not exist...
auto imag = _mm512_permute_pd(values, 0xff); // imag imag
return Sleef_hypotd8_u05(real, imag); // abs abs
}
Vectorized<c10::complex<double>> abs() const {
return Sleef_hypotd8_u05(real_(), imag().values);
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
return _mm512_and_pd(abs_(), real_mask); // abs 0
}
__m512d angle_() const {
//angle = atan2(b/a)
@@ -190,14 +200,10 @@ template <> class Vectorized<c10::complex<double>> {
return _mm512_and_pd(angle, real_mask); // angle 0
}
Vectorized<c10::complex<double>> sgn() const {
auto abs_zero = abs().values; // abs 0
auto abs = _mm512_movedup_pd(abs_zero); // abs abs

auto abs = abs_();
auto zero = _mm512_setzero_pd();
auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);

auto div = values / abs;

return _mm512_mask_blend_pd(mask, div, zero);
}
__m512d real_() const {
17 changes: 11 additions & 6 deletions aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
@@ -679,8 +679,17 @@ template <> class Vectorized<c10::complex<float>> {
auto ret = hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b
return ret;
}
__m512 abs_() const {
auto real = _mm512_moveldup_ps(values); // real real
auto imag = _mm512_movehdup_ps(values); // imag imag
return Sleef_hypotf16_u05(real, imag); // abs abs
}
Vectorized<c10::complex<float>> abs() const {
return Sleef_hypotf16_u05(real_(), imag().values);
const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
return _mm512_and_ps(abs_(), real_mask); // abs 0
}
__m512 angle_() const {
//angle = atan2(b/a)
@@ -696,14 +705,10 @@ template <> class Vectorized<c10::complex<float>> {
return _mm512_and_ps(angle, real_mask); // angle 0
}
Vectorized<c10::complex<float>> sgn() const {
auto abs_zero = abs().values; // abs 0
auto abs = _mm512_moveldup_ps(abs_zero); // abs abs

auto abs = abs_();
auto zero = _mm512_setzero_ps();
auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ);

auto div = values / abs;

return _mm512_mask_blend_ps(mask, div, zero);
}
__m512 real_() const {
14 changes: 1 addition & 13 deletions torch/testing/_internal/common_methods_invocations.py
@@ -13181,11 +13181,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ],
skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
dtypes=(torch.int, torch.int8)),
# pytorch computes (0+nanj), numpy computes (-5e-18-1j) for input (-501.-1.0000e+20j)
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
"test_reference_numerics_large", dtypes=(torch.complex64,), device_type='cpu',
active_if=not IS_MACOS and not IS_WINDOWS),),
dtypes=(torch.int, torch.int8)),),
),
UnaryUfuncInfo(
'nn.functional.tanhshrink',
@@ -14074,14 +14070,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
# Reference: https://github.com/pytorch/pytorch/issues/41245
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]),
# Reference: https://github.com/pytorch/pytorch/issues/53958
# Test fails in comparison on Nan as the `equal_nan` is True for
# comparing the CPU tensors.
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
device_type='cpu', dtypes=[torch.complex64, torch.complex128]),
# Reference: https://github.com/pytorch/pytorch/issues/48486
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
device_type='cpu', dtypes=[torch.complex64]),
DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
)),