type promotion with 0d-tensors diverges from array API specification #58736

Open
Tracked by #58743
pmeier opened this issue May 21, 2021 · 4 comments
Labels
module: python array api (Issues related to the Python Array API), module: type promotion (Related to semantics of type promotion), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

pmeier (Collaborator) commented May 21, 2021

The array API specification stipulates clear type promotion rules that are independent of array size and values:

[Image: type promotion table stipulated by the array API specification]
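
For reference, the dtype-to-dtype part of this table can be reproduced with torch.promote_types, which only compares the two dtypes and ignores shapes and values entirely (a minimal sketch; expected results per the spec's table in the comments):

import torch

# Pure dtype-to-dtype promotion; shapes and values never enter the picture.
print(torch.promote_types(torch.int8, torch.int64))       # torch.int64
print(torch.promote_types(torch.int8, torch.uint8))       # torch.int16
print(torch.promote_types(torch.float16, torch.float64))  # torch.float64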

PyTorch mostly adheres to this, with one exception: within a dtype category (integral, floating point, complex), 0d tensors do not participate in type promotion:

import torch

dtype_categories = (
    (torch.int8, torch.uint8, torch.int32, torch.int64),
    (torch.float16, torch.bfloat16, torch.float32, torch.float64),
    (torch.complex32, torch.complex64, torch.complex128),
)

for dtypes in dtype_categories:
    for idx in range(len(dtypes) - 1):
        dtype_nd = dtypes[idx]
        for dtype_0d in dtypes[idx + 1:]:
            # a is dimensioned; b is 0d with a "higher" dtype of the same category
            a = torch.empty((1,), dtype=dtype_nd)
            b = torch.empty((), dtype=dtype_0d)
            print(f"{a.dtype} + {b.dtype} = {torch.result_type(a, b)}")
torch.int8 + torch.uint8 = torch.int8
torch.int8 + torch.int32 = torch.int8
torch.int8 + torch.int64 = torch.int8
torch.uint8 + torch.int32 = torch.uint8
torch.uint8 + torch.int64 = torch.uint8
torch.int32 + torch.int64 = torch.int32
torch.float16 + torch.bfloat16 = torch.float16
torch.float16 + torch.float32 = torch.float16
torch.float16 + torch.float64 = torch.float16
torch.bfloat16 + torch.float32 = torch.bfloat16
torch.bfloat16 + torch.float64 = torch.bfloat16
torch.float32 + torch.float64 = torch.float32
torch.complex32 + torch.complex64 = torch.complex32
torch.complex32 + torch.complex128 = torch.complex32
torch.complex64 + torch.complex128 = torch.complex64

This is not well documented (see #58489), but seems to be intentional.
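
For contrast, the same dtype pairs do promote when both operands are dimensioned; only the 0d operand's dtype is ignored (a quick check, expected outputs in comments):

import torch

a = torch.empty((1,), dtype=torch.int8)

# Both operands dimensioned: the usual promotion applies.
print(torch.result_type(a, torch.empty((1,), dtype=torch.int64)))  # torch.int64

# 0d operand of the same category: its dtype is ignored.
print(torch.result_type(a, torch.empty((), dtype=torch.int64)))    # torch.int8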

cc @nairbv @mruberry

pmeier added the module: python array api and module: type promotion labels on May 21, 2021
pmeier (Collaborator, Author) commented May 21, 2021

It seems NumPy has its own set of problems with this:

import numpy as np

dtype_categories = (
    (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64),
    (np.float16, np.float32, np.float64),
)

for dtypes in dtype_categories:
    for idx in range(len(dtypes) - 1):
        dtype_nd = dtypes[idx]
        for dtype_0d in dtypes[idx + 1:]:
            a = np.empty((1,), dtype=dtype_nd)
            b = np.empty((), dtype=dtype_0d)
            print(f"{a.dtype} + {b.dtype} = {np.result_type(a, b)}")

(excluding all correct results)

int8 + uint8 = int8
int8 + uint16 = int16
int8 + uint32 = int32
int16 + uint16 = int16
int16 + uint32 = int32
int16 + uint64 = int64
int32 + uint32 = int32
int32 + uint64 = int32
int64 + uint64 = int64
uint32 + uint64 = uint32
float16 + float32 = float16
float16 + float64 = float16
float32 + float64 = float32
  • For the integral types, the result of mixed signed / unsigned promotions needs to go up by one size, e.g. int8 + uint8 should be int16 (see the sketch after this list).
  • For floating types and, weirdly enough, uint32 + uint64, they show the same behavior as we have: the 0d array does not participate in type promotion.
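
As a sanity check for the first point, the same mixed signed / unsigned pairs follow the spec when both operands are dimensioned, so the deviations above are specific to the 0d case (a small sketch, expected outputs in comments):

import numpy as np

# Both operands dimensioned: NumPy goes up one size, as the spec requires.
print(np.result_type(np.empty((1,), dtype=np.int8), np.empty((1,), dtype=np.uint8)))    # int16
print(np.result_type(np.empty((1,), dtype=np.int32), np.empty((1,), dtype=np.uint32)))  # int64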

pmeier (Collaborator, Author) commented May 21, 2021

JAX only fails mixed signed / unsigned promotions involving uint32:

import jax.numpy as jnp

dtype_categories = (
    (jnp.int8, jnp.int16, jnp.int32, jnp.uint8, jnp.uint16, jnp.uint32),
    (jnp.float16, jnp.float32),
)

for dtypes in dtype_categories:
    for idx in range(len(dtypes) - 1):
        dtype_nd = dtypes[idx]
        for dtype_0d in dtypes[idx + 1:]:
            a = jnp.empty((1,), dtype=dtype_nd)
            b = jnp.empty((), dtype=dtype_0d)
            print(f"{a.dtype} + {b.dtype} = {jnp.result_type(a, b)}")

(excluding all correct results)

int8 + uint32 = int32
int16 + uint32 = int32
int32 + uint32 = int32

IMO this only happens because they don't support uint64 (yet). This is also supported by the fact that the behavior does not change if no 0d tensor is involved.
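
To back up that last claim, the same uint32 pairs can be re-checked with both operands dimensioned; per the observation above, the printed results should be identical (a sketch):

import jax.numpy as jnp

# No 0d operand involved; per the comment above, uint32 promotes the same way.
for signed in (jnp.int8, jnp.int16, jnp.int32):
    a = jnp.empty((1,), dtype=signed)
    b = jnp.empty((2,), dtype=jnp.uint32)
    print(f"{a.dtype} + {b.dtype} = {jnp.result_type(a, b)}")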

pmeier (Collaborator, Author) commented May 21, 2021

To make this a little more accessible, I wrote a small tool to check an arbitrary array API: https://gist.github.com/pmeier/ea35bdffb597b35f4f6592c5ac201cd4

mruberry added the triaged label on May 23, 2021
asmeurer (Collaborator) commented

NumPy's type promotion rules are becoming much more consistent with NEP 50. You can test NEP 50 behavior in NumPy by setting NPY_PROMOTION_STATE=weak or np._set_promotion_state('weak') in NumPy 1.24.
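
A quick way to try this (a sketch assuming NumPy >= 1.24; _set_promotion_state is a provisional API, so the exact spelling may change in later releases):

import numpy as np

# Opt in to the NEP 50 "weak" promotion rules (NumPy >= 1.24).
np._set_promotion_state("weak")

a = np.empty((1,), dtype=np.int8)
b = np.empty((), dtype=np.uint8)
# Under NEP 50, 0d arrays keep their dtype in promotion, so this should
# follow the array API table (int8 + uint8 -> int16) instead of returning int8.
print(np.result_type(a, b))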
