8000 ENH: np.rescale(): add return_params argument · numpy/numpy@c37e546 · GitHub
[go: up one dir, main page]

Skip to content

Commit c37e546

Browse files
committed
ENH: np.rescale(): add return_params argument
1 parent b131eaa commit c37e546

File tree

2 files changed

+43
-14
lines changed

2 files changed

+43
-14
lines changed

numpy/lib/function_base.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4401,6 +4401,7 @@ def append(arr, values, axis=None):
44014401

44024402

44034403
def rescale(arr, out_min=0, out_max=1, in_min=None, in_max=None,
4404+
axis=None, return_params=False, out=None):
44044405
"""Linearly scale non-NaN values in array to the specified interval.
44054406
44064407
.. versionadded:: 1.15.0
@@ -4426,6 +4427,10 @@ def rescale(arr, out_min=0, out_max=1, in_min=None, in_max=None,
44264427
axis : int, optional
44274428
Axis along which to scale `arr`. If `None`, scaling is done over
44284429
all axes.
4430+
return_params : bool, optional
4431+
If True, also return `offset` and `scale` parameters so that
4432+
``offset + scale * x`` maps ``x`` from input interval onto
4433+
output interval.
44294434
out : array_like, optional
44304435
Array of the same shape as `arr`, in which to store the result.
44314436
@@ -4434,6 +4439,10 @@ def rescale(arr, out_min=0, out_max=1, in_min=None, in_max=None,
44344439
scaled_array : ndarray
44354440
Output array with the same shape as `arr` and all non-NaN values
44364441
scaled from interval [in_min, in_max] to [out_min, out_max].
4442+
offset : scalar or ndarray, optional
4443+
Only provided if `return_params` is True.
4444+
scale : scalar or ndarray, optional
4445+
Only provided if `return_params` is True.
44374446
44384447
Raises
44394448
------
@@ -4460,14 +4469,25 @@ def rescale(arr, out_min=0, out_max=1, in_min=None, in_max=None,
44604469
out_min = asanyarray(out_min)
44614470
out_max = asanyarray(out_max)
44624471

4472+
keepdims = axis is not None # Avoids extra, unnecessary dims
44634473
if in_min is None:
4464-
in_min = np.nanmin(arr, axis=axis, keepdims=True)
4474+
in_min = np.nanmin(arr, axis=axis, keepdims=keepdims).astype(float)
44654475
if in_max is None:
4466-
in_max = np.nanmax(arr, axis=axis, keepdims=True)
4476+
in_max = np.nanmax(arr, axis=axis, keepdims=keepdims).astype(float)
44674477

4468-
res = (arr - in_min) / (in_max - in_min) * (out_max - out_min) + out_min
4478+
oldlen = in_max - in_min
4479+
newlen = out_max - out_min
4480+
4481+
offset = (in_max * out_min - in_min * out_max) / oldlen
4482+
scale = newlen / oldlen
4483+
4484+
res = arr * scale + offset
44694485

44704486
if out is None:
4471-
return res
4472-
out[...] = res
4487+
out = res
4488+
else:
4489+
out[...] = res
4490+
4491+
if return_params:
4492+
return out, offset, scale
44734493
return out

numpy/lib/tests/test_function_base.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2982,9 +2982,9 @@ def setup(self):
29822982
self.arr3d = np.arange(8, dtype=float).reshape((2, 2, 2))
29832983

29842984
def test_basic(self):
2985-
assert_equal(rescale(self.arr),
2986-
[[0, .2, .4, .6],
2987-
[1, .8, .6, .4]])
2985+
assert_allclose(rescale(self.arr),
2986+
[[0, .2, .4, .6],
2987+
[1, .8, .6, .4]])
29882988
assert_equal(rescale(self.arr, self.arr.min(), self.arr.max()),
29892989
self.arr)
29902990

@@ -2998,9 +2998,9 @@ def test_min_max(self):
29982998
assert_equal(rescale(self.arr, [1, 0, 1, 2], [2, 1, 1, 3], axis=0),
29992999
[[1, 0, 1, 3],
30003000
[2, 1, 1, 2]])
3001-
assert_equal(rescale(self.arr, [[0], [0]], [[1], [1]], axis=1),
3002-
[[0, 1/3, 2/3, 1],
3003-
[1, 2/3, 1/3, 0]])
3001+
assert_allclose(rescale(self.arr, [[0], [0]], [[1], [1]], axis=1),
3002+
[[0, 1/3, 2/3, 1],
3003+
[1, 2/3, 1/3, 0]])
30043004
assert_equal(rescale(self.arr3d, out_max=3.5),
30053005
(self.arr3d / 2).reshape((2, 2, 2)))
30063006

@@ -3012,15 +3012,24 @@ def test_axis(self):
30123012
assert_equal(rescale(self.arr, axis=0),
30133013
[[0, 0, 0, 1],
30143014
[1, 1, 1, 0]])
3015-
assert_equal(rescale(self.arr, axis=1),
3016-
[[0, 1/3, 2/3, 1],
3017-
[1, 2/3, 1/3, 0]])
3015+
assert_allclose(rescale(self.arr, axis=1),
3016+
[[0, 1/3, 2/3, 1],
3017+
[1, 2/3, 1/3, 0]])
30183018
assert_equal(rescale(self.arr3d, axis=2),
30193019
[[[0, 1],
30203020
[0, 1]],
30213021
[[0, 1],
30223022
[0, 1]]])
30233023

3024+
def test_return_params(self):
3025+
_, offset, scale = rescale(self.arr, return_params=True)
3026+
assert_equal(offset, 0)
3027+
assert_equal(scale, .2)
3028+
3029+
_, offset, scale = rescale(self.arr, axis=1, return_params=True)
3030+
assert_equal(offset, np.c_[0, -2/3].T)
3031+
assert_equal(scale, np.c_[1/3, 1/3].T)
3032+
30243033
def test_out(self):
30253034
out = np.zeros_like(self.arr)
30263035
assert_equal(rescale(self.arr, self.arr.min(), self.arr.max(), out=out),

0 commit comments

Comments
 (0)
0