8000 np.linalg.vector_norm: return correct shape for keepdims · asmeurer/numpy@e22c1b8 · GitHub
[go: up one dir, main page]

Skip to content

Commit e22c1b8

Browse files
committed
np.linalg.vector_norm: return correct shape for keepdims
1 parent 6e3b923 commit e22c1b8

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

numpy/linalg/_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3306,6 +3306,7 @@ def vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
33063306
33073307
"""
33083308
x = asanyarray(x)
3309+
shape = list(x.shape)
33093310
if axis is None:
33103311
# Note: np.linalg.norm() doesn't handle 0-D arrays
33113312
x = x.ravel()
@@ -3331,9 +3332,8 @@ def vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
33313332
if keepdims:
33323333
# We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
33333334
# above to avoid matrix norm logic.
3334-
shape = list(x.shape)
33353335
_axis = normalize_axis_tuple(
3336-
range(x.ndim) if axis is None else axis, x.ndim
3336+
range(len(shape)) if axis is None else axis, len(shape)
33373337
)
33383338
for i in _axis:
33393339
shape[i] = 1

numpy/linalg/tests/test_linalg.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,3 +2355,8 @@ def test_vector_norm():
23552355
assert_almost_equal(
23562356
actual, np.array([6.7082, 8.124, 9.6436]), double_decimal=3
23572357
)
2358+
2359+
actual = np.linalg.vector_norm(x, keepdims=True)
2360+
expected = np.full((1, 1), 14.2828, dtype='float64')
2361+
assert_equal(actual.shape, expected.shape)
2362+
assert_almost_equal(actual, expected, double_decimal=3)

0 commit comments

Comments
 (0)
0