8000 ENH: linalg: allow the 'axis' argument of linalg.norm to be a 2-tuple… · numpy/numpy@a292614 · GitHub
[go: up one dir, main page]

Skip to content

Commit a292614

Browse files
ENH: linalg: allow the 'axis' argument of linalg.norm to be a 2-tuple, in which case matrix norms of the collection of 2-D matrices are computed.
1 parent 6eb57a7 commit a292614

File tree

2 files changed

+134
-19
lines changed

2 files changed

+134
-19
lines changed

numpy/linalg/linalg.py

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
intc, single, double, csingle, cdouble, inexact, complexfloating, \
2323
newaxis, ravel, all, Inf, dot, add, multiply, sqrt, maximum, \
2424
fastCopyAndTranspose, sum, isfinite, size, finfo, errstate, \
25-
geterrobj, float128
25+
geterrobj, float128, rollaxis, amin, amax
2626
from numpy.lib import triu, asfarray
2727
from numpy.linalg import lapack_lite, _umath_linalg
2828
from numpy.matrixlib.defmatrix import matrix_power
@@ -1866,6 +1866,44 @@ def lstsq(a, b, rcond=-1):
18661866
return wrap(x), wrap(resids), results['rank'], st
18671867

18681868

1869+
def _multi_svd_norm(x, row_axis, col_axis, op):
1870+
"""Compute the exteme singular values of the 2-D matrices in `x`.
1871+
1872+
This is a private utility function used by numpy.linalg.norm().
1873+
1874+
Parameters
1875+
----------
1876+
x : ndarray
1877+
row_axis, col_axis : int
1878+
The axes of `x` that hold the 2-D matrices.
1879+
op : callable
1880+
This should be either numpy.amin or numpy.amax.
1881+
1882+
Returns
1883+
-------
1884+
result : float or ndarray
1885+
If `x` is 2-D, the return values is a float.
1886+
Otherwise, it is an array with ``x.ndim - 2`` dimensions.
1887+
The return values are either the minimum or maximum of the
1888+
singular values of the matrices, depending on whether `op`
1889+
is `numpy.amin` or `numpy.amax`.
1890+
1891+
"""
1892+
if row_axis > col_axis:
1893+
row_axis -= 1
1894+
y = rollaxis(rollaxis(x, col_axis, x.ndim), row_axis, -1)
1895+
if x.ndim > 3:
1896+
z = y.reshape((-1,) + y.shape[-2:])
1897+
else:
1898+
z = y
1899+
if x.ndim == 2:
1900+
result = op(svd(z, compute_uv=0))
1901+
else:
1902+
result = array([op(svd(m, compute_uv=0)) for m in z])
1903+
result.shape = y.shape[:-2]
1904+
return result
1905+
1906+
18691907
def norm(x, ord=None, axis=None):
18701908
"""
18711909
Matrix or vector norm.
@@ -1881,10 +1919,12 @@ def norm(x, ord=None, axis=None):
18811919
ord : {non-zero int, inf, -inf, 'fro'}, optional
18821920
Order of the norm (see table under ``Notes``). inf means numpy's
18831921
`inf` object.
1884-
axis : int or None, optional
1885-
If `axis` is not None, it specifies the axis of `x` along which to
1886-
compute the vector norms. If `axis` is None, then either a vector
1887-
norm (when `x` is 1-D) or a matrix norm (when `x` is 2-D) is returned.
1922+
axis : {int, 2-tuple of ints, None}, optional
1923+
If `axis` is an integer, it specifies the axis of `x` along which to
1924+
compute the vector norms. If `axis` is a 2-tuple, it specifies the
1925+
axes that hold 2-D matrices, and the matrix norms of these matrices
1926+
are computed. If `axis` is None then either a vector norm (when `x`
1927+
is 1-D) or a matrix norm (when `x` is 2-D) is returned.
18881928
18891929
Returns
18901930
-------
@@ -1972,7 +2012,7 @@ def norm(x, ord=None, axis=None):
19722012
>>> LA.norm(a, -3)
19732013
nan
19742014
1975-
Using the `axis` argument:
2015+
Using the `axis` argument to compute vector norms:
19762016
19772017
>>> c = np.array([[ 1, 2, 3],
19782018
... [-1, 1, 4]])
@@ -1983,6 +2023,14 @@ def norm(x, ord=None, axis=None):
19832023
>>> LA.norm(c, ord=1, axis=1)
19842024
array([6, 6])
19852025
2026+
Using the `axis` argument to compute matrix norms:
2027+
2028+
>>> m = np.arange(8).reshape(2,2,2)
2029+
>>> norm(m, axis=(1,2))
2030+
array([ 3.74165739, 11.22497216])
2031+
>>> norm(m[0]), norm(m[1])
2032+
(3.7416573867739413, 11.224972160321824)
2033+
19862034
"""
19872035
x = asarray(x)
19882036

@@ -1991,8 +2039,14 @@ def norm(x, ord=None, axis=None):
19912039
s = (x.conj() * x).real
19922040
return sqrt(add.reduce((x.conj() * x).ravel().real))
19932041

2042+
# Normalize the `axis` argument to a tuple.
2043+
if axis is None:
2044+
axis = tuple(range(x.ndim))
2045+
elif not isinstance(axis, tuple):
2046+
axis = (axis,)
2047+
19942048
nd = x.ndim
1995-
if nd == 1 or axis is not None:
2049+
if len(axis) == 1:
19962050
if ord == Inf:
19972051
return abs(x).max(axis=axis)
19982052
elif ord == -Inf:
@@ -2018,21 +2072,36 @@ def norm(x, ord=None, axis=None):
20182072
# because it will downcast to float64.
20192073
absx = asfarray(abs(x))
20202074
return add.reduce(absx**ord, axis=axis)**(1.0/ord)
2021-
elif nd == 2:
2075+
elif len(axis) == 2:
2076+
row_axis, col_axis = axis
2077+
if not (-x.ndim <= row_axis < x.ndim and
2078+
-x.ndim <= col_axis < x.ndim):
2079+
raise ValueError('Invalid axis %r for an array with shape %r' %
2080+
(axis, x.shape))
2081+
if row_axis % x.ndim == col_axis % x.ndim:
2082+
raise ValueError('Duplicate axes given.')
20222083
if ord == 2:
2023-
return svd(x, compute_uv=0).max()
2084+
return _multi_svd_norm(x, row_axis, col_axis, amax)
20242085
elif ord == -2:
2025-
return svd(x, compute_uv=0).min()
2086+
return _multi_svd_norm(x, row_axis, col_axis, amin)
20262087
elif ord == 1:
2027-
return abs(x).sum(axis=0).max()
2088+
if col_axis > row_axis:
2089+
col_axis -= 1
2090+
return add.reduce(abs(x), axis=row_axis).max(axis=col_axis)
20282091
elif ord == Inf:
2029-
return abs(x).sum(axis=1).max()
2092+
if row_axis > col_axis:
2093+
row_axis -= 1
2094+
return add.reduce(abs(x), axis=col_axis).max(axis=row_axis)
20302095
elif ord == -1:
2031-
return abs(x).sum(axis=0).min()
2096+
if col_axis > row_axis:
2097+
col_axis -= 1
2098+
return add.reduce(abs(x), axis=row_axis).min(axis=col_axis)
20322099
elif ord == -Inf:
2033-
return abs(x).sum(axis=1).min()
2034-
elif ord in ['fro','f']:
2035-
return sqrt(add.reduce((x.conj() * x).real.ravel()))
2100+
if row_axis > col_axis:
2101+
row_axis -= 1
2102+
return add.reduce(abs(x), axis=col_axis).min(axis=row_axis)
2103+
elif ord in [None, 'fro', 'f']:
2104+
return sqrt(add.reduce((x.conj() * x).real, axis=axis))
20362105
else:
20372106
raise ValueError("Invalid norm order for matrices.")
20382107
else:

numpy/linalg/tests/test_linalg.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,7 @@ def test_matrix(self):
563563
self.assertRaises(ValueError, norm, A, 0)
564564

565565
def test_axis(self):
566+
# Vector norms.
566567
# Compare the use of `axis` with computing the norm of each row
567568
# or column separately.
568569
A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt)
@@ -572,10 +573,55 @@ def test_axis(self):
572573
expected1 = [norm(A[k,:], ord=order) for k in range(A.shape[0])]
573574
assert_almost_equal(norm(A, ord=order, axis=1), expected1)
574575

575-
# Check bad case. Using `axis` implies vector norms are being
576-
# computed, so also using `ord='fro'` raises a ValueError
577-
# (just like `norm([1,2,3], ord='fro')` does).
576+
# Matrix norms.
577+
B = np.arange(1, 25, dtype=self.dt).reshape(2,3,4)
578+
579+
for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']:
580+
assert_almost_equal(norm(A, ord=order), norm(A, ord=order, axis=(0,1)))
581+
582+
n = norm(B, ord=order, axis=(1,2))
583+
expected = [norm(B[k], ord=order) for k in range(B.shape[0])]
584+
print("shape is %r, axis=(1,2)" % (B.shape,))
585+
assert_almost_equal(n, expected)
586+
587+
n = norm(B, ord=order, axis=(2,1))
588+
expected = [norm(B[k].T, ord=order) for k in range(B.shape[0])]
589+
print("shape is %r, axis=(2,1)" % (B.shape,))
590+
assert_almost_equal(n, expected)
591+
592+
n = norm(B, ord=order, axis=(0,2))
593+
expected = [norm(B[:,k,:], ord=order) for k in range(B.shape[1])]
594+
print("shape is %r, axis=(0,2)" % (B.shape,))
595+
assert_almost_equal(n, expected)
596+
597+
n = norm(B, ord=order, axis=(0,1))
598+
expected = [norm(B[:,:,k], ord=order) for k in range(B.shape[2])]
599+
print("shape is %r, axis=(0,1)" % (B.shape,))
600+
assert_almost_equal(n, expected)
601+
602+
def test_bad_args(self):
603+
# Check that bad arguments raise the appropriate exceptions.
604+
605+
A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt)
606+
B = np.arange(1, 25, dtype=self.dt).reshape(2,3,4)
607+
608+
# Using `axis=<integer>` or passing in a 1-D array implies vector
609+
# norms are being computed, so also using `ord='fro'` raises a
610+
# ValueError.
578611
self.assertRaises(ValueError, norm, A, 'fro', 0)
612+
self.assertRaises(ValueError, norm, [3, 4], 'fro', None)
613+
614+
# Similarly, norm should raise an exception when ord is any finite
615+
# number other than 1, 2, -1 or -2 when computing matrix norms.
616+
for order in [0, 3]:
617+
self.assertRaises(ValueError, norm, A, order, None)
618+
self.assertRaises(ValueError, norm, A, order, (0,1))
619+
self.assertRaises(ValueError, norm, B, order, (1,2))
620+
621+
# Invalid axis
622+
self.assertRaises(ValueError, norm, B, None, 3)
623+
self.assertRaises(ValueError, norm, B, None, (2,3))
624+
self.assertRaises(ValueError, norm, B, None, (0,1,2))
579625

580626

581627
class TestNormDouble(_TestNorm):

0 commit comments

Comments
 (0)
0