Fix uint view copy (#151598) by eellison · Pull Request #154121 · pytorch/pytorch · GitHub

Fix uint view copy (#151598) #154121


Merged · 1 commit · May 22, 2025
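For context on the fix: `Tensor.view(dtype)` reinterprets the underlying bits of a tensor, while `Tensor.to(dtype)` converts values, so the compiled kernel must emit a true bitcast when copying into a uint view of a floating-point tensor. A small eager-mode illustration of the difference (not part of this PR; assumes a PyTorch build with torch.uint16 support, as the test below requires):

import torch

x = torch.ones(4, dtype=torch.bfloat16)
# Bit reinterpretation: bfloat16 1.0 has the bit pattern 0x3F80 == 16256.
print(x.view(torch.uint16))  # tensor([16256, 16256, 16256, 16256], dtype=torch.uint16)
# Value conversion: each 1.0 becomes the integer 1.
print(x.to(torch.uint16))    # tensor([1, 1, 1, 1], dtype=torch.uint16)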
14 changes: 14 additions & 0 deletions test/inductor/test_cuda_repro.py
@@ -885,6 +885,20 @@ def test_scatter_index_not_wrapped(self):
             out, torch.scatter_reduce(input_orig.clone(), 0, index, src, "sum")
         )
 
+    def test_uint_view_copy(self):
+        @torch.compile
+        def view_copy(target, source):
+            assert target.dtype == torch.bfloat16
+            assert source.dtype == torch.uint16
+            target.view(torch.uint16).copy_(source)
+
+        target = torch.ones(1024, dtype=torch.bfloat16, device="cuda")
+        source = torch.full_like(target, 4, dtype=torch.uint16)
+
+        out = target.view(torch.uint16).copy_(source).clone()
+        view_copy(target, source)
+        self.assertEqual(out, target.view(torch.uint16))
+
     def test_embedding_var_mean(self):
         def forward(arg0_1):
             full = torch.ops.aten.full.default(
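One way to run just the new regression test locally (assumes a CUDA-enabled PyTorch build with Triton, and pytest installed; not part of the diff):

pytest test/inductor/test_cuda_repro.py -k test_uint_view_copy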
38 changes: 9 additions & 29 deletions torch/_inductor/codegen/triton.py
@@ -706,16 +706,6 @@ def triton_compute_type(dtype: torch.dtype) -> str:
     return triton_type(upcast_compute_type(dtype))
 
 
-def _get_primitive_bitwidth(dtype: torch.dtype) -> int:
-    """Number of bits of triton_compute_type()"""
-    dtype = upcast_compute_type(dtype)
-    itemsize = getattr(dtype, "itemsize", None)
-    if itemsize:
-        return itemsize * 8
-    else:
-        return -1
-
-
 def triton_store_type(dtype: torch.dtype) -> str:
     """Convert torch.dtype to triton type, with fix for storing tl.bool"""
     if dtype == torch.bool:
@@ -887,30 +877,20 @@ def _get_min_elements_per_thread(

     @staticmethod
     def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
-        triton_dtype = triton_compute_type(dtype)
+        assert src_dtype.itemsize == dtype.itemsize
         # We may promote float16 or bfloat16 to float32 and cause the
         # bitwidth of dtype to be different from the input tensor (i.e. float32).
         # In such as case, we will have to convert the input tensor to
         # its src_type, perform bitcast, and then convert the bit-casted
         # tensor back to float to ensure we use values with the right precision.
-        if (
-            src_dtype in (torch.float16, torch.bfloat16)
-            and config.triton.codegen_upcast_to_fp32
-        ):
-            triton_src_dtype = str(src_dtype).split(".")[-1]
-            cast_x = f"{x}.to(tl.{triton_src_dtype})"
-            if dtype in (torch.float16, torch.bfloat16):
-                triton_type_name = str(dtype).split(".")[-1]
-                triton_dtype = f"tl.{triton_type_name}"
-            cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)"
-            if dtype in (torch.float16, torch.bfloat16):
-                return f"{cast_x}.to(tl.float32)"
-            return cast_x
-        else:
-            src_dtype_bitwidth = _get_primitive_bitwidth(src_dtype)
-            target_dtype_bitwidth = _get_primitive_bitwidth(dtype)
-            bitcast = "True" if src_dtype_bitwidth == target_dtype_bitwidth else "False"
-            return f"{x}.to({triton_dtype}, bitcast={bitcast})"
+        if x.dtype != src_dtype:
+            x = f"{x}.to({triton_type(src_dtype)})"
+
+        out = f"{x}.to({triton_type(dtype)}, bitcast=True)"
+        if upcast_compute_type(dtype) != dtype:
+            out = f"{out}.to({triton_type(upcast_compute_type(dtype))})"
+
+        return out
 
     @staticmethod
     def _shaped_constant(value, dtype, shape):
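For reference, a minimal standalone sketch of what the rewritten to_dtype_bitcast emits for the two directions exercised by the uint16/bfloat16 case. The helpers below are simplified stand-ins (the real triton_type / upcast_compute_type handle many more dtypes, and in inductor `x` is a CSE variable that carries its compute dtype), so treat this as an illustration of the generated strings rather than the actual implementation:

import torch

# Hypothetical, simplified stand-ins for the inductor helpers.
_TRITON_NAMES = {
    torch.float32: "tl.float32",
    torch.bfloat16: "tl.bfloat16",
    torch.float16: "tl.float16",
    torch.uint16: "tl.uint16",
}

def triton_type(dtype):
    return _TRITON_NAMES[dtype]

def upcast_compute_type(dtype):
    # Inductor (by default) computes fp16/bf16 in fp32 registers.
    return torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype

def to_dtype_bitcast(x_name, x_compute_dtype, dtype, src_dtype):
    assert src_dtype.itemsize == dtype.itemsize
    x = x_name
    if x_compute_dtype != src_dtype:
        # Downcast from the upcasted compute dtype back to the source dtype
        # so the bit pattern matches the tensor's storage.
        x = f"{x}.to({triton_type(src_dtype)})"
    out = f"{x}.to({triton_type(dtype)}, bitcast=True)"
    if upcast_compute_type(dtype) != dtype:
        # Re-upcast so later computation happens at fp32 precision.
        out = f"{out}.to({triton_type(upcast_compute_type(dtype))})"
    return out

# bf16 value held in fp32 registers, reinterpreted as uint16 (the store in the test):
print(to_dtype_bitcast("tmp0", torch.float32, torch.uint16, torch.bfloat16))
# -> tmp0.to(tl.bfloat16).to(tl.uint16, bitcast=True)

# uint16 bits reinterpreted as bf16, then upcast back to fp32 for compute:
print(to_dtype_bitcast("tmp1", torch.uint16, torch.bfloat16, torch.uint16))
# -> tmp1.to(tl.bfloat16, bitcast=True).to(tl.float32)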