8000 Fix test failures in JAX under NumPy 1.25.0rc1. by copybara-service[bot] · Pull Request #16297 · jax-ml/jax · GitHub
[go: up one dir, main page]

Skip to content

Conversation

@copybara-service
Copy link

Fix test failures in JAX under NumPy 1.25.0rc1.

jnp.finfo(...) of an Array type yields:

TypeError: unhashable type: 'ArrayImpl'

However, np.finfo(...) no longer accepts NumPy arrays as input either, so it would be consistent to require the user to pass a dtype where they are currently passing an array.

@hawkinsp
Copy link
Collaborator
hawkinsp commented Jun 9, 2023

We're going to hold off submitting this because numpy/numpy#23867 might resolve this.

@eendebakpt
Copy link

@hawkinsp numpy/numpy#23867 has been resolved via numpy/numpy#14847. The changes in this PR might be good anyway. jnp.finfo(alpha_k.dtype) is a bit faster than jnp.finfo(alpha_k) (although I have no idea whether that part of the code is performance critical)

@hawkinsp
Copy link
Collaborator
hawkinsp commented Jun 9, 2023

@eendebakpt Thanks for the fix!

Yes, it seems like this might be a tiny bit better anyway. I'll submit it.

@copybara-service copybara-service bot closed this Jun 9, 2023
@copybara-service copybara-service bot deleted the test_538516573 branch June 9, 2023 21:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

0