-
-
Notifications
You must be signed in to change notification settings - Fork 18.7k
ENH: Add numba engine for rolling apply #30151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
3b9bff8
9a302bf
0e9a600
36a77ed
dbb2a9b
f0e9a4d
1250aee
4e7fd1a
cb976cf
45420bb
17851cf
20767ca
9619f8d
66fa69c
b8908ea
135f2ad
34a5687
6da8199
123f77e
54e74d1
04d3530
4bbf587
f849bc7
0c30e48
c4c952e
8645976
987c916
b775684
2e04e60
9b20ff5
0c14033
c7106dc
1640085
2846faf
5a645c0
6bac000
6f1c73f
a890337
0a9071c
9d8d40b
84c3491
a429206
5826ad9
cf7571b
4bc9787
18eed60
f715b55
6a765bf
af3fe50
eb7b5e1
f7dfcf4
a42a960
d019830
29d145f
248149c
a3da51e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -6,7 +6,11 @@ | |||||
|
||||||
|
||||||
def _generate_numba_apply_func( | ||||||
args: Tuple, kwargs: Dict, func: Callable, engine_kwargs: Optional[Dict] | ||||||
args: Tuple, | ||||||
kwargs: Dict, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
func: Callable, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm not sure about the *args piece but here are the docs for it: https://docs.python.org/3/library/typing.html#typing.Callable So I would think Scalar can be imported from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. I'll give that a shot. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Had to go with |
||||||
engine_kwargs: Optional[Dict], | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Slightly more explicit would be preferable |
||||||
function_cache: Dict, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add subtypes here? Is this |
||||||
): | ||||||
""" | ||||||
Generate a numba jitted apply function specified by values from engine_kwargs. | ||||||
|
@@ -37,6 +41,10 @@ def _generate_numba_apply_func( | |||||
else: | ||||||
loop_range = range | ||||||
|
||||||
# Return an already compiled version of roll_apply if available | ||||||
if func in function_cache: | ||||||
return function_cache[func] | ||||||
|
||||||
def make_rolling_apply(func): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you move make_rolling_apply to module scope (out of this function)? also don't you need to actually assign to the cache? (on a miss) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cache assignment happens here https://github.com/pandas-dev/pandas/pull/30151/files#diff-0de5c5d9abfcdd141e83701eaaec4358R541 (the function needs to run on some data first) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. then i wouldn’t even check the cache here ; that must be at the higher level There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay sure thing. I'll move it up then. |
||||||
|
||||||
if isinstance(func, numba.targets.registry.CPUDispatcher): | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,6 +93,7 @@ def __init__( | |
self.win_freq = None | ||
self.axis = obj._get_axis_number(axis) if axis is not None else None | ||
self.validate() | ||
self._numba_func_cache = dict() | ||
|
||
@property | ||
def _constructor(self): | ||
|
@@ -443,6 +444,7 @@ def _apply( | |
floor: int = 1, | ||
is_weighted: bool = False, | ||
name: Optional[str] = None, | ||
use_numba_cache: Optional[bool] = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there a reason we would not want this ever? (e.g. shouldn't we always cache) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variable name is a bit of a misnomer. i.e. this is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. k cool, can you document that somewhere There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can just be annotated as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seconding Will's question. this appears above too |
||
**kwargs, | ||
): | ||
""" | ||
|
@@ -455,10 +457,11 @@ def _apply( | |
func : callable function to apply | ||
center : bool | ||
require_min_periods : int | ||
floor: int | ||
is_weighted | ||
name: str, | ||
floor : int | ||
is_weighted : bool | ||
name : str, | ||
compatibility with groupby.rolling | ||
use_numba_cache : bool | ||
**kwargs | ||
additional arguments for rolling function and window function | ||
|
||
|
@@ -533,6 +536,9 @@ def calc(x): | |
result = calc(values) | ||
result = np.asarray(result) | ||
|
||
if use_numba_cache: | ||
self._numba_func_cache[name] = func | ||
|
||
if center: | ||
result = self._center_window(result, window) | ||
|
||
|
@@ -1303,13 +1309,21 @@ def apply( | |
elif engine == "numba": | ||
if raw is False: | ||
raise ValueError("raw must be `True` when using the numba engine") | ||
apply_func = _generate_numba_apply_func(args, kwargs, func, engine_kwargs) | ||
apply_func = _generate_numba_apply_func( | ||
args, kwargs, func, engine_kwargs, self._numba_func_cache | ||
) | ||
else: | ||
raise ValueError("engine must be either 'numba' or 'cython'") | ||
|
||
# TODO: Why do we always pass center=False? | ||
# name=func for WindowGroupByMixin._apply | ||
jbrockmendel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return self._apply(apply_func, center=False, floor=0, name=func) | ||
return self._apply( | ||
apply_func, | ||
center=False, | ||
floor=0, | ||
name=func, | ||
use_numba_cache=engine == "numba", | ||
) | ||
|
||
def _generate_cython_apply_func(self, args, kwargs, raw, offset, func): | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from pandas import Series | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok with this being non-private (e.g. generate_numba_apply_func), unless its only used in this module