type promotion with 0d-tensors diverges from array API specification #58736
It seems

```python
import numpy as np

dtype_categories = (
    (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64),
    (np.float16, np.float32, np.float64),
)
for dtypes in dtype_categories:
    for idx in range(len(dtypes) - 1):
        dtype_nd = dtypes[idx]
        for dtype_0d in dtypes[idx + 1:]:
            a = np.empty((1,), dtype=dtype_nd)
            b = np.empty((), dtype=dtype_0d)
            print(f"{a.dtype} + {b.dtype} = {np.result_type(a, b)}")
```

(excluding all correct results)
```python
import jax.numpy as jnp

dtype_categories = (
    (jnp.int8, jnp.int16, jnp.int32, jnp.uint8, jnp.uint16, jnp.uint32),
    (jnp.float16, jnp.float32),
)
for dtypes in dtype_categories:
    for idx in range(len(dtypes) - 1):
        dtype_nd = dtypes[idx]
        for dtype_0d in dtypes[idx + 1:]:
            a = jnp.empty((1,), dtype=dtype_nd)
            b = jnp.empty((), dtype=dtype_0d)  # 0d array, matching the NumPy snippet above
            print(f"{a.dtype} + {b.dtype} = {jnp.result_type(a, b)}")
```

(excluding all correct results)
IMO this only happens because they don't support …
To make this a little more accessible, I wrote a small tool to check an arbitrary array API: https://gist.github.com/pmeier/ea35bdffb597b35f4f6592c5ac201cd4
NumPy's type promotion rules are becoming much more consistent with NEP 50. You can test NEP 50 behavior in NumPy by setting …
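For illustration (the exact result depends on the NumPy version and its promotion state): under NEP 50 semantics, which are the default starting with NumPy 2.0, a 0d array participates in promotion like any other array, while legacy promotion special-cases it.

```python
import numpy as np

a = np.empty((1,), dtype=np.int8)
b = np.empty((), dtype=np.int64)  # 0d array

res = np.result_type(a, b)
# Legacy promotion treats the 0d array as scalar-like and yields int8;
# under NEP 50 (NumPy >= 2.0) the 0d array counts fully and yields int64.
print(res)
```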
The array API specification stipulates clear type promotion rules that are independent of the array size and values:
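As a rough illustration of that rule, here is a toy sketch (a hypothetical helper, not any library's API) for the signed-integer category: the result depends only on the dtypes, never on shapes or values.

```python
# Toy sketch (hypothetical, not a real library API) of spec-compliant
# promotion within the signed-integer category.
WIDTH = {"int8": 8, "int16": 16, "int32": 32, "int64": 64}

def spec_result_type(a, b):
    # Promote to the wider of the two dtypes, regardless of array shape.
    return a if WIDTH[a] >= WIDTH[b] else b

print(spec_result_type("int16", "int64"))  # int64, whether the operands are 0d or nd
```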
PyTorch mostly adheres to this, with one exception: within a dtype category (integral, floating, complex), 0d tensors do not participate in type promotion:
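A minimal reproduction of that exception, assuming a PyTorch build with this behavior:

```python
import torch

a = torch.empty((1,), dtype=torch.int32)  # 1d tensor
b = torch.empty((), dtype=torch.int64)    # 0d tensor in the same (integral) category

# The 0d tensor does not drive promotion within its category,
# so the result keeps the 1d tensor's dtype instead of widening to int64.
print(torch.result_type(a, b))  # torch.int32
```

The array API specification would instead require `int64` here, since promotion must not depend on dimensionality.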
This is not documented well (see #58489), but seems to be intentional.
cc @nairbv @mruberry