22
22
intc , single , double , csingle , cdouble , inexact , complexfloating , \
23
23
newaxis , ravel , all , Inf , dot , add , multiply , sqrt , maximum , \
24
24
fastCopyAndTranspose , sum , isfinite , size , finfo , errstate , \
25
- geterrobj , float128
25
+ geterrobj , float128 , rollaxis , amin , amax
26
26
from numpy .lib import triu , asfarray
27
27
from numpy .linalg import lapack_lite , _umath_linalg
28
28
from numpy .matrixlib .defmatrix import matrix_power
@@ -1866,6 +1866,44 @@ def lstsq(a, b, rcond=-1):
1866
1866
return wrap (x ), wrap (resids ), results ['rank' ], st
1867
1867
1868
1868
1869
+ def _multi_svd_norm (x , row_axis , col_axis , op ):
1870
+ """Compute the exteme singular values of the 2-D matrices in `x`.
1871
+
1872
+ This is a private utility function used by numpy.linalg.norm().
1873
+
1874
+ Parameters
1875
+ ----------
1876
+ x : ndarray
1877
+ row_axis, col_axis : int
1878
+ The axes of `x` that hold the 2-D matrices.
1879
+ op : callable
1880
+ This should be either numpy.amin or numpy.amax.
1881
+
1882
+ Returns
1883
+ -------
1884
+ result : float or ndarray
1885
+ If `x` is 2-D, the return values is a float.
1886
+ Otherwise, it is an array with ``x.ndim - 2`` dimensions.
1887
+ The return values are either the minimum or maximum of the
1888
+ singular values of the matrices, depending on whether `op`
1889
+ is `numpy.amin` or `numpy.amax`.
1890
+
1891
+ """
1892
+ if row_axis > col_axis :
1893
+ row_axis -= 1
1894
+ y = rollaxis (rollaxis (x , col_axis , x .ndim ), row_axis , - 1 )
1895
+ if x .ndim > 3 :
1896
+ z = y .reshape ((- 1 ,) + y .shape [- 2 :])
1897
+ else :
1898
+ z = y
1899
+ if x .ndim == 2 :
1900
+ result = op (svd (z , compute_uv = 0 ))
1901
+ else :
1902
+ result = array ([op (svd (m , compute_uv = 0 )) for m in z ])
1903
+ result .shape = y .shape [:- 2 ]
1904
+ return result
1905
+
1906
+
1869
1907
def norm (x , ord = None , axis = None ):
1870
1908
"""
1871
1909
Matrix or vector norm.
@@ -1881,10 +1919,12 @@ def norm(x, ord=None, axis=None):
1881
1919
ord : {non-zero int, inf, -inf, 'fro'}, optional
1882
1920
Order of the norm (see table under ``Notes``). inf means numpy's
1883
1921
`inf` object.
1884
- axis : int or None, optional
1885
- If `axis` is not None, it specifies the axis of `x` along which to
1886
- compute the vector norms. If `axis` is None, then either a vector
1887
- norm (when `x` is 1-D) or a matrix norm (when `x` is 2-D) is returned.
1922
+ axis : {int, 2-tuple of ints, None}, optional
1923
+ If `axis` is an integer, it specifies the axis of `x` along which to
1924
+ compute the vector norms. If `axis` is a 2-tuple, it specifies the
1925
+ axes that hold 2-D matrices, and the matrix norms of these matrices
1926
+ are computed. If `axis` is None then either a vector norm (when `x`
1927
+ is 1-D) or a matrix norm (when `x` is 2-D) is returned.
1888
1928
1889
1929
Returns
1890
1930
-------
@@ -1972,7 +2012,7 @@ def norm(x, ord=None, axis=None):
1972
2012
>>> LA.norm(a, -3)
1973
2013
nan
1974
2014
1975
- Using the `axis` argument:
2015
+ Using the `axis` argument to compute vector norms :
1976
2016
1977
2017
>>> c = np.array([[ 1, 2, 3],
1978
2018
... [-1, 1, 4]])
@@ -1983,6 +2023,14 @@ def norm(x, ord=None, axis=None):
1983
2023
>>> LA.norm(c, ord=1, axis=1)
1984
2024
array([6, 6])
1985
2025
2026
+ Using the `axis` argument to compute matrix norms:
2027
+
2028
+ >>> m = np.arange(8).reshape(2,2,2)
2029
+ >>> norm(m, axis=(1,2))
2030
+ array([ 3.74165739, 11.22497216])
2031
+ >>> norm(m[0]), norm(m[1])
2032
+ (3.7416573867739413, 11.224972160321824)
2033
+
1986
2034
"""
1987
2035
x = asarray (x )
1988
2036
@@ -1991,8 +2039,14 @@ def norm(x, ord=None, axis=None):
1991
2039
s = (x .conj () * x ).real
1992
2040
return sqrt (add .reduce ((x .conj () * x ).ravel ().real ))
1993
2041
2042
+ # Normalize the `axis` argument to a tuple.
2043
+ if axis is None :
2044
+ axis = tuple (range (x .ndim ))
2045
+ elif not isinstance (axis , tuple ):
2046
+ axis = (axis ,)
2047
+
1994
2048
nd = x .ndim
1995
- if nd == 1 or axis is not None :
2049
+ if len ( axis ) == 1 :
1996
2050
if ord == Inf :
1997
2051
return abs (x ).max (axis = axis )
1998
2052
elif ord == - Inf :
@@ -2018,21 +2072,36 @@ def norm(x, ord=None, axis=None):
2018
2072
# because it will downcast to float64.
2019
2073
absx = asfarray (abs (x ))
2020
2074
return add .reduce (absx ** ord , axis = axis )** (1.0 / ord )
2021
- elif nd == 2 :
2075
+ elif len (axis ) == 2 :
2076
+ row_axis , col_axis = axis
2077
+ if not (- x .ndim <= row_axis < x .ndim and
2078
+ - x .ndim <= col_axis < x .ndim ):
2079
+ raise ValueError ('Invalid axis %r for an array with shape %r' %
2080
+ (axis , x .shape ))
2081
+ if row_axis % x .ndim == col_axis % x .ndim :
2082
+ raise ValueError ('Duplicate axes given.' )
2022
2083
if ord == 2 :
2023
- return svd (x , compute_uv = 0 ). max ( )
2084
+ return _multi_svd_norm (x , row_axis , col_axis , amax )
2024
2085
elif ord == - 2 :
2025
- return svd (x , compute_uv = 0 ). min ( )
2086
+ return _multi_svd_norm (x , row_axis , col_axis , amin )
2026
2087
elif ord == 1 :
2027
- return abs (x ).sum (axis = 0 ).max ()
2088
+ if col_axis > row_axis :
2089
+ col_axis -= 1
2090
+ return add .reduce (abs (x ), axis = row_axis ).max (axis = col_axis )
2028
2091
elif ord == Inf :
2029
- return abs (x ).sum (axis = 1 ).max ()
2092
+ if row_axis > col_axis :
2093
+ row_axis -= 1
2094
+ return add .reduce (abs (x ), axis = col_axis ).max (axis = row_axis )
2030
2095
elif ord == - 1 :
2031
- return abs (x ).sum (axis = 0 ).min ()
2096
+ if col_axis > row_axis :
2097
+ col_axis -= 1
2098
+ return add .reduce (abs (x ), axis = row_axis ).min (axis = col_axis )
2032
2099
elif ord == - Inf :
2033
- return abs (x ).sum (axis = 1 ).min ()
2034
- elif ord in ['fro' ,'f' ]:
2035
- return sqrt (add .reduce ((x .conj () * x ).real .ravel ()))
2100
+ if row_axis > col_axis :
2101
+ row_axis -= 1
2102
+ return add .reduce (abs (x ), axis = col_axis ).min (axis = row_axis )
2103
+ elif ord in [None , 'fro' , 'f' ]:
2104
+ return sqrt (add .reduce ((x .conj () * x ).real , axis = axis ))
2036
2105
else :
2037
2106
raise ValueError ("Invalid norm order for matrices." )
2038
2107
else :
0 commit comments