8000 Merge pull request #12116 from shoyer/array-function-numpy-lib · numpy/numpy@2ae0014 · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 2ae0014

Browse files
authored
Merge pull request #12116 from shoyer/array-function-numpy-lib
ENH: __array_function__ support for np.lib, part 1/2
2 parents eb2bd11 + 4141e24 commit 2ae0014

File tree

7 files changed

+379
-0
lines changed

7 files changed

+379
-0
lines changed

numpy/lib/arraypad.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import division, absolute_import, print_function
77

88
import numpy as np
9+
from numpy.core.overrides import array_function_dispatch
910

1011

1112
__all__ = ['pad']
@@ -990,6 +991,11 @@ def _validate_lengths(narray, number_elements):
990991
# Public functions
991992

992993

994+
def _pad_dispatcher(array, pad_width, mode, **kwargs):
995+
return (array,)
996+
997+
998+
@array_function_dispatch(_pad_dispatcher)
993999
def pad(array, pad_width, mode, **kwargs):
9941000
"""
9951001
Pads an array.

numpy/lib/arraysetops.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from __future__ import division, absolute_import, print_function
2929

3030
import numpy as np
31+
from numpy.core.overrides import array_function_dispatch
3132

3233

3334
__all__ = [
@@ -36,6 +37,11 @@
3637
]
3738

3839

40+
def _ediff1d_dispatcher(ary, to_end=None, to_begin=None):
41+
return (ary, to_end, to_begin)
42+
43+
44+
@array_function_dispatch(_ediff1d_dispatcher)
3945
def ediff1d(ary, to_end=None, to_begin=None):
4046
"""
4147
The differences between consecutive elements of an array.
@@ -133,6 +139,12 @@ def _unpack_tuple(x):
133139
return x
134140

135141

142+
def _unique_dispatcher(ar, return_index=None, return_inverse=None,
143+
return_counts=None, axis=None):
144+
return (ar,)
145+
146+
147+
@array_function_dispatch(_unique_dispatcher)
136148
def unique(ar, return_index=False, return_inverse=False,
137149
return_counts=False, axis=None):
138150
"""
@@ -313,6 +325,12 @@ def _unique1d(ar, return_index=False, return_inverse=False,
313325
return ret
314326

315327

328+
def _intersect1d_dispatcher(
329+
ar1, ar2, assume_unique=None, return_indices=None):
330+
return (ar1, ar2)
331+
332+
333+
@array_function_dispatch(_intersect1d_dispatcher)
316334
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
317335
"""
318336
Find the intersection of two arrays.
@@ -408,6 +426,11 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
408426
return int1d
409427

410428

429+
def _setxor1d_dispatcher(ar1, ar2, assume_unique=None):
430+
return (ar1, ar2)
431+
432+
433+
@array_function_dispatch(_setxor1d_dispatcher)
411434
def setxor1d(ar1, ar2, assume_unique=False):
412435
"""
413436
Find the set exclusive-or of two arrays.
@@ -562,6 +585,11 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
562585
return ret[rev_idx]
563586

564587

588+
def _isin_dispatcher(element, test_elements, assume_unique=None, invert=None):
589+
return (element, test_elements)
590+
591+
592+
@array_function_dispatch(_isin_dispatcher)
565593
def isin(element, test_elements, assume_unique=False, invert=False):
566594
"""
567595
Calculates `element in test_elements`, broadcasting over `element` only.
@@ -660,6 +688,11 @@ def isin(element, test_elements, assume_unique=False, invert=False):
660688
invert=invert).reshape(element.shape)
661689

662690

691+
def _union1d_dispatcher(ar1, ar2):
692+
return (ar1, ar2)
693+
694+
695+
@array_function_dispatch(_union1d_dispatcher)
663696
def union1d(ar1, ar2):
664697
"""
665698
Find the union of two arrays.
@@ -695,6 +728,12 @@ def union1d(ar1, ar2):
695728
"""
696729
return unique(np.concatenate((ar1, ar2), axis=None))
697730

731+
732+
def _setdiff1d_dispatcher(ar1, ar2, assume_unique=None):
733+
return (ar1, ar2)
734+
735+
736+
@array_function_dispatch(_setdiff1d_dispatcher)
698737
def setdiff1d(ar1, ar2, assume_unique=False):
699738
"""
700739
Find the set difference of two arrays.

numpy/lib/financial.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from decimal import Decimal
1616

1717
import numpy as np
18+
from numpy.core.overrides import array_function_dispatch
19+
1820

1921
__all__ = ['fv', 'pmt', 'nper', 'ipmt', 'ppmt', 'pv', 'rate',
2022
'irr', 'npv', 'mirr']
@@ -36,6 +38,12 @@ def _convert_when(when):
3638
except (KeyError, TypeError):
3739
return [_when_to_num[x] for x in when]
3840

41+
42+
def _fv_dispatcher(rate, nper, pmt, pv, when=None):
43+
return (rate, nper, pmt, pv)
44+
45+
46+
@array_function_dispatch(_fv_dispatcher)
3947
def fv(rate, nper, pmt, pv, when='end'):
4048
"""
4149
Compute the future value.
@@ -124,6 +132,12 @@ def fv(rate, nper, pmt, pv, when='end'):
124132
(1 + rate*when)*(temp - 1)/rate)
125133
return -(pv*temp + pmt*fact)
126134

135+
136+
def _pmt_dispatcher(rate, nper, pv, fv=None, when=None):
137+
return (rate, nper, pv, fv)
138+
139+
140+
@array_function_dispatch(_pmt_dispatcher)
127141
def pmt(rate, nper, pv, fv=0, when='end'):
128142
"""
129143
Compute the payment against loan principal plus interest.
@@ -216,6 +230,12 @@ def pmt(rate, nper, pv, fv=0, when='end'):
216230
(1 + masked_rate*when)*(temp - 1)/masked_rate)
217231
return -(fv + pv*temp) / fact
218232

233+
234+
def _nper_dispatcher(rate, pmt, pv, fv=None, when=None):
235+
return (rate, pmt, pv, fv)
236+
237+
238+
@array_function_dispatch(_nper_dispatcher)
219239
def nper(rate, pmt, pv, fv=0, when='end'):
220240
"""
221241
Compute the number of periodic payments.
@@ -284,6 +304,12 @@ def nper(rate, pmt, pv, fv=0, when='end'):
284304
B = np.log((-fv+z) / (pv+z))/np.log(1+rate)
285305
return np.where(rate == 0, A, B)
286306

307+
308+
def _ipmt_dispatcher(rate, per, nper, pv, fv=None, when=None):
309+
return (rate, per, nper, pv, fv)
310+
311+
312+
@array_function_dispatch(_ipmt_dispatcher)
287313
def ipmt(rate, per, nper, pv, fv=0, when='end'):
288314
"""
289315
Compute the interest portion of a payment.
@@ -379,6 +405,7 @@ def ipmt(rate, per, nper, pv, fv=0, when='end'):
379405
pass
380406
return ipmt
381407

408+
382409
def _rbl(rate, per, pmt, pv, when):
383410
"""
384411
This function is here to simply have a different name for the 'fv'
@@ -388,6 +415,12 @@ def _rbl(rate, per, pmt, pv, when):
388415
"""
389416
return fv(rate, (per - 1), pmt, pv, when)
390417

418+
419+
def _ppmt_dispatcher(rate, per, nper, pv, fv=None, when=None):
420+
return (rate, per, nper, pv, fv)
421+
422+
423+
@array_function_dispatch(_ppmt_dispatcher)
391424
def ppmt(rate, per, nper, pv, fv=0, when='end'):
392425
"""
393426
Compute the payment against loan principal.
@@ -416,6 +449,12 @@ def ppmt(rate, per, nper, pv, fv=0, when='end'):
416449
total = pmt(rate, nper, pv, fv, when)
417450
return total - ipmt(rate, per, nper, pv, fv, when)
418451

452+
453+
def _pv_dispatcher(rate, nper, pmt, fv=None, when=None):
454+
return (rate, nper, nper, pv, fv)
455+
456+
457+
@array_function_dispatch(_pv_dispatcher)
419458
def pv(rate, nper, pmt, fv=0, when='end'):
420459
"""
421460
Compute the present value.
@@ -520,13 +559,20 @@ def _g_div_gp(r, n, p, x, y, w):
520559
(n*t2*x - p*(t1 - 1)*(r*w + 1)/(r**2) + n*p*t2*(r*w + 1)/r +
521560
p*(t1 - 1)*w/r))
522561

562+
563+
def _rate_dispatcher(nper, pmt, pv, fv, when=None, guess=None, tol=None,
564+
maxiter=None):
565+
return (nper, pmt, pv, fv)
566+
567+
523568
# Use Newton's iteration until the change is less than 1e-6
524569
# for all values or a maximum of 100 iterations is reached.
525570
# Newton's rule is
526571
# r_{n+1} = r_{n} - g(r_n)/g'(r_n)
527572
# where
528573
# g(r) is the formula
529574
# g'(r) is the derivative with respect to r.
575+
@array_function_dispatch(_rate_dispatcher)
530576
def rate(nper, pmt, pv, fv, when='end', guess=None, tol=None, maxiter=100):
531577
"""
532578< 1241 /code>
Compute the rate of interest per period.
@@ -598,6 +644,12 @@ def rate(nper, pmt, pv, fv, when='end', guess=None, tol=None, maxiter=100):
598644
else:
599645
return rn
600646

647+
648+
def _irr_dispatcher(values):
649+
return (values,)
650+
651+
652+
@array_function_dispatch(_irr_dispatcher)
601653
def irr(values):
602654
"""
603655
Return the Internal Rate of Return (IRR).
@@ -677,6 +729,12 @@ def irr(values):
677729
rate = rate.item(np.argmin(np.abs(rate)))
678730
return rate
679731

732+
733+
def _npv_dispatcher(rate, values):
734+
return (values,)
735+
736+
737+
@array_function_dispatch(_npv_dispatcher)
680738
def npv(rate, values):
681739
"""
682740
Returns the NPV (Net Present Value) of a cash flow series.
@@ -722,6 +780,12 @@ def npv(rate, values):
722780
values = np.asarray(values)
723781
return (values / (1+rate)**np.arange(0, len(values))).sum(axis=0)
724782

783+
784+
def _mirr_dispatcher(values, finance_rate, reinvest_rate):
785+
return (values,)
786+
787+
788+
@array_function_dispatch(_mirr_dispatcher)
725789
def mirr(values, finance_rate, reinvest_rate):
726790
"""
727791
Modified internal rate of return.

0 commit comments

Comments
 (0)
0