ENH: Replace _lazywhere with xpx.apply_where #22557
They should be running in the array API job because we install |
To reduce the diff and potential for regressions, can we move the existing |
scipy/stats/_stats_py.py
@@ -670,7 +668,7 @@ def tmean(a, limits=None, inclusive=(True, True), axis=None):
     # explicit dtype specification required due to data-apis/array-api-compat#152
     sum = xp.sum(a, axis=axis, dtype=a.dtype)
     n = xp.sum(xp.asarray(~mask, dtype=a.dtype), axis=axis, dtype=a.dtype)
-    mean = _lazywhere(n != 0, (sum, n), xp.divide, xp.nan)
+    mean = xpx.apply_where(n != 0, operator.truediv, (sum, n), fill_value=xp.nan)
Explanation here: https://github.com/data-apis/array-api-extra/pull/141/files#r1961964801
Haven't studied the reason yet, but change looks fine.
You're correct, my bad.
To clarify, are you asking me to revert to
and leave
Sorry, I'm not familiar with what you mean by "old" versus "new" infrastructure. Could you explain, give me a pointer to the documentation, or give me the exact verbiage?
Thanks for asking. Yeah - very close.
Never mind. Let me review more closely. If it's all very systematic, we can just change it all. I thought there were more changes in |
scipy/stats/_discrete_distns.py
    (r != 0) | (k != 0),
    lambda k, M, n, r:
        (-betaln(k+1, r) + betaln(k+r, 1)
         - betaln(n-k+1, M-r-n+1) + betaln(M-r-k+1, 1)
         + betaln(n+1, M-n+1) - betaln(M+1, 1)),
    (k, M, n, r), fill_value=0.0)
One thing that is confusing me while reviewing these is that the third positional argument can be the arguments when fill_value is used, but the third positional argument is f2 elsewhere. In these cases, please consider https://github.com/data-apis/array-api-extra/pull/141/files#r1962054546 or using keywords.
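To make the concern about positional arguments concrete, here is a minimal NumPy-only sketch. `lazy_where_sketch` is a hypothetical stand-in written for this illustration; it is not the array-api-extra implementation, and its signature is an assumption. It always takes the argument tuple second and makes `fill_value` keyword-only, so the third position can only ever be a function.

```python
import operator

import numpy as np


def lazy_where_sketch(cond, args, f1, f2=None, *, fill_value=None):
    """Hypothetical helper: apply f1 where cond is true, f2 or fill_value elsewhere."""
    if (f2 is None) == (fill_value is None):
        raise TypeError("pass exactly one of f2 or fill_value")
    cond, *args = np.broadcast_arrays(np.asarray(cond, dtype=bool), *args)
    # f1 only ever sees the elements selected by cond.
    true_part = np.asarray(f1(*(a[cond] for a in args)))
    if f2 is None:
        fill = np.asarray(fill_value)
        dtype = np.result_type(true_part.dtype, fill.dtype)
        out = np.full(cond.shape, fill, dtype=dtype)
    else:
        # f2 only ever sees the elements where cond is false.
        false_part = np.asarray(f2(*(a[~cond] for a in args)))
        dtype = np.result_type(true_part.dtype, false_part.dtype)
        out = np.empty(cond.shape, dtype=dtype)
        out[~cond] = false_part
    out[cond] = true_part
    return out


total = np.array([6.0, 0.0, 9.0])
n = np.array([3.0, 0.0, 2.0])

# Keyword-only fill_value: no doubt about what the third argument is.
mean = lazy_where_sketch(n != 0, (total, n), operator.truediv, fill_value=np.nan)

# Two functions: f2 takes the fourth position, args stays in the second.
both = lazy_where_sketch(n != 0, (total, n), operator.truediv, lambda t, m: t - 1.0)
```

Because the division is only evaluated on the selected elements, the zero denominator never reaches `operator.truediv`.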
I reviewed most files and it all looks pretty good, but I'd like to address https://github.com/scipy/scipy/pull/22557#discussion_r1962042105 / https://github.com/data-apis/array-api-extra/pull/141/files#r1962054546 before finishing up. Also, I'd feel more comfortable with the change if the test with hypothesis were run on apply_where (e.g. if it were brought over to array-api-extra). That was written pretty carefully IIRC.
This does look quite good to me. One minor question inline. I agree with @mdhaber's review - can be merged once xpx.apply_where is in.
Slightly reworked, hopefully improved UI
Yep I think so!
@@ -692,11 +689,11 @@ def tmean(a, limits=None, inclusive=(True, True), axis=None):
     # explicit dtype specification required due to data-apis/array-api-compat#152
     sum = xp.sum(a, axis=axis, dtype=a.dtype)
     n = xp.sum(xp.asarray(~mask, dtype=a.dtype), axis=axis, dtype=a.dtype)
-    mean = _lazywhere(n != 0, (sum, n), xp.divide, xp.nan)
+    mean = xpx.apply_where(n != 0, (sum, n), operator.truediv, fill_value=xp.nan)
Did xp.divide not work?
See data-apis/array-api-extra#160:
If you forget about the meta-namespace and just use xp in the lambdas, at the moment most things will keep working.
This is because, by accident, several functions in the dask.array, numpy, and cupy namespaces are interoperable or even the same function. However, you will find cases where this doesn't hold true and you need the correct namespace. This will become a much bigger source of headaches in the future when dask around generic Array API compatible namespaces will become commonplace (note: Dask does NOT support them today).
This pattern repeats itself many, many times in scipy. At the moment there are only a handful of cases that are array API-aware, and they all use xp.divide, so the problem can be worked around by replacing it with operator.truediv. But if you look at scipy.stats in #22557 you'll find a myriad of calls to np. functions inside the lambdas.
Ah yes, I remember that now. Given that the code now looks a little harder to understand and this trick won't actually work in other places, I'd probably keep it unchanged and have an issue for solving this problem - and just fail the Dask test in the meantime.
If that's too much churn and this one-line change is more pragmatic, then fine with me as well of course. Dask just has a bunch of issues with namespacing, and this feels like a hack that happens to work.
AFAIK operator.truediv and xp.divide are identical though? It doesn't feel harder to read to me before or after the change?
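A small sketch of why the two can be swapped in this spot, using NumPy only as an assumed backend: `operator.truediv(a, b)` goes through the array's own `__truediv__`, so it needs no namespace lookup at all. The `Tagged` class is a hypothetical toy type added purely to show that dispatch; it is not part of any library. Whether the results agree on every backend is not something this sketch establishes.

```python
import operator

import numpy as np

a = np.array([1.0, 2.0, 0.0])
b = np.array([4.0, 0.5, 2.0])

# operator.truediv(a, b) is the function form of `a / b`.
via_operator = operator.truediv(a, b)
# The namespace function has to come from the right module for the array type.
via_namespace = np.divide(a, b)


class Tagged:
    """Hypothetical toy array type that has no namespace function at all."""

    def __init__(self, values):
        self.values = list(values)

    def __truediv__(self, other):
        return Tagged(x / y for x, y in zip(self.values, other.values))


# Works without knowing which namespace Tagged belongs to.
quotient = operator.truediv(Tagged([1.0, 3.0]), Tagged([2.0, 4.0]))
```

For these float arrays both calls give the same elementwise result, which is why the one-line swap in tmean sidesteps the namespace question for division only.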
Test failures are unrelated. @mdhaber this is ready to merge!
jax.jit
full and similar should accept 0-d masked array as input mdhaber/marray#89. Unit tests remain all green. @mdhaber