From 407403d98ac44324ae5a2ca26c1e5814857e5622 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 13 Dec 2022 01:06:44 -0800 Subject: [PATCH 1/6] Add complex number support to `linalg.slogdet` --- spec/API_specification/array_api/linalg.py | 38 ++++++++++++++++------ 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index b3595e1fa..fae06c55b 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -345,31 +345,49 @@ def qr(x: array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> Tupl """ def slogdet(x: array, /) -> Tuple[array, array]: - """ + r""" Returns the sign and the natural logarithm of the absolute value of the determinant of a square matrix (or a stack of square matrices) ``x``. .. note:: The purpose of this function is to calculate the determinant more accurately when the determinant is either very small or very large, as calling ``det`` may overflow or underflow. + The sign of the determinant is given by + + .. math:: + \operatorname{sign}(\det x) = \begin{cases} + 0 & \textrm{if } \det x = 0 && + \frac{\det x}{|\det x|} + \end{cases} + + where :math:`|\det x|` is the absolute value of :math:`\det x`. + + **Special Cases** + + For real-valued floating-point operands, + + - If the determinant is zero, the corresponding ``sign`` should be ``0`` and ``logabsdet`` should be ``-infinity``. + + For complex floating-point operands, + + - If the determinant is ``0 + 0j``, the corresponding ``sign`` should be ``0 + 0j`` and ``logabsdet`` should be ``-infinity + 0j``. + + .. note:: + Depending on the underlying algorithm, when the determinant is zero, the returned result may differ from ``-infinity`` (or ``-infinity + 0j``). In all cases, the determinant should be equal to ``sign * exp(logabsdet)`` (although, again, the result may be subject to numerical precision errors). + Parameters ---------- x: array - input array having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Should have a real-valued floating-point data type. + input array having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Should have a floating-point data type. Returns ------- out: Tuple[array, array] a namedtuple (``sign``, ``logabsdet``) whose - - first element must have the field name ``sign`` and must be an array containing a number representing the sign of the determinant for each square matrix. - - second element must have the field name ``logabsdet`` and must be an array containing the determinant for each square matrix. - - For a real matrix, the sign of the determinant must be either ``1``, ``0``, or ``-1``. + - first element must have the field name ``sign`` and must be an array containing a number representing the sign of the determinant for each square matrix. Must have the same data type as ``x``. + - second element must have the field name ``logabsdet`` and must be an array containing the natural logarithm of the absolute value of the determinant for each square matrix. If ``x`` is real-valued, the returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. If ``x`` is complex`, the returned array must have a real-valued floating-point data type having the same precision as ``x`` (e.g., if ``x`` is ``complex64``, ``logabsdet`` must have a ``float32`` data type). - Each returned array must have shape ``shape(x)[:-2]`` and a real-valued floating-point data type determined by :ref:`type-promotion`. - - .. note:: - If a determinant is zero, then the corresponding ``sign`` should be ``0`` and ``logabsdet`` should be ``-infinity``; however, depending on the underlying algorithm, the returned result may differ. In all cases, the determinant should be equal to ``sign * exp(logsabsdet)`` (although, again, the result may be subject to numerical precision errors). + Each returned array must have shape ``shape(x)[:-2]``. """ def solve(x1: array, x2: array, /) -> array: From 97532c00e886d3c3bd8b59b2fc946a53e7c69a1b Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 13 Dec 2022 01:09:42 -0800 Subject: [PATCH 2/6] Update copy --- spec/API_specification/array_api/linalg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index fae06c55b..a487ea424 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -361,6 +361,8 @@ def slogdet(x: array, /) -> Tuple[array, array]: where :math:`|\det x|` is the absolute value of :math:`\det x`. + When ``x`` is a stack of matrices, the function must compute the sign and natural logarithm of the absolute value of the determinant for each matrix in the stack. + **Special Cases** For real-valued floating-point operands, From cb7c3245975ca012901f5f669e58258abf539538 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 13 Dec 2022 01:24:01 -0800 Subject: [PATCH 3/6] Fix equation --- spec/API_specification/array_api/linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index a487ea424..475a9a622 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -355,8 +355,8 @@ def slogdet(x: array, /) -> Tuple[array, array]: .. math:: \operatorname{sign}(\det x) = \begin{cases} - 0 & \textrm{if } \det x = 0 && - \frac{\det x}{|\det x|} + 0 & \textrm{if } \det x = 0 \\ + \frac{\det x}{|\det x|} & \textrm{otherwise} \end{cases} where :math:`|\det x|` is the absolute value of :math:`\det x`. From 5f4eea984d7a6835343df4fb5c79b04d593e38f2 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 13 Dec 2022 01:24:36 -0800 Subject: [PATCH 4/6] Update copy --- spec/API_specification/array_api/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index 475a9a622..6d48428d9 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -359,7 +359,7 @@ def slogdet(x: array, /) -> Tuple[array, array]: \frac{\det x}{|\det x|} & \textrm{otherwise} \end{cases} - where :math:`|\det x|` is the absolute value of :math:`\det x`. + where :math:`|\det x|` is the absolute value of the determinant of ``x``. When ``x`` is a stack of matrices, the function must compute the sign and natural logarithm of the absolute value of the determinant for each matrix in the stack. From d9d01caec446d8b578897851a6487140b3c7097b Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 13 Dec 2022 01:25:14 -0800 Subject: [PATCH 5/6] Fix stray backtick --- spec/API_specification/array_api/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index 6d48428d9..4f00b8bc3 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -387,7 +387,7 @@ def slogdet(x: array, /) -> Tuple[array, array]: a namedtuple (``sign``, ``logabsdet``) whose - first element must have the field name ``sign`` and must be an array containing a number representing the sign of the determinant for each square matrix. Must have the same data type as ``x``. - - second element must have the field name ``logabsdet`` and must be an array containing the natural logarithm of the absolute value of the determinant for each square matrix. If ``x`` is real-valued, the returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. If ``x`` is complex`, the returned array must have a real-valued floating-point data type having the same precision as ``x`` (e.g., if ``x`` is ``complex64``, ``logabsdet`` must have a ``float32`` data type). + - second element must have the field name ``logabsdet`` and must be an array containing the natural logarithm of the absolute value of the determinant for each square matrix. If ``x`` is real-valued, the returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. If ``x`` is complex, the returned array must have a real-valued floating-point data type having the same precision as ``x`` (e.g., if ``x`` is ``complex64``, ``logabsdet`` must have a ``float32`` data type). Each returned array must have shape ``shape(x)[:-2]``. """ From 01fd124056f1a36ad1ca1845a30f642822caaba6 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Tue, 13 Dec 2022 01:28:24 -0800 Subject: [PATCH 6/6] Remove word --- spec/API_specification/array_api/linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py index 4f00b8bc3..f15337020 100644 --- a/spec/API_specification/array_api/linalg.py +++ b/spec/API_specification/array_api/linalg.py @@ -367,11 +367,11 @@ def slogdet(x: array, /) -> Tuple[array, array]: For real-valued floating-point operands, - - If the determinant is zero, the corresponding ``sign`` should be ``0`` and ``logabsdet`` should be ``-infinity``. + - If the determinant is zero, the ``sign`` should be ``0`` and ``logabsdet`` should be ``-infinity``. For complex floating-point operands, - - If the determinant is ``0 + 0j``, the corresponding ``sign`` should be ``0 + 0j`` and ``logabsdet`` should be ``-infinity + 0j``. + - If the determinant is ``0 + 0j``, the ``sign`` should be ``0 + 0j`` and ``logabsdet`` should be ``-infinity + 0j``. .. note:: Depending on the underlying algorithm, when the determinant is zero, the returned result may differ from ``-infinity`` (or ``-infinity + 0j``). In all cases, the determinant should be equal to ``sign * exp(logabsdet)`` (although, again, the result may be subject to numerical precision errors).