Fix uint view copy (#151598) (#154121) · pytorch/pytorch@306ba12 · GitHub

Commit 306ba12
Fix uint view copy (#151598) (#154121)
Fix for #151156. We have some logic to undo our upcast prior to the dtype bitcast; this PR cleans up that logic by tracking dtypes in codegen.

Pull Request resolved: #151598
Approved by: https://github.com/zou3519
ghstack dependencies: #151562
1 parent 1ae9953 commit 306ba12
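
For context, the pattern the fix covers is a compiled copy into a uint16 view of a bfloat16 tensor. A minimal sketch of that pattern (mirroring the new test added below) looks like this:

import torch

@torch.compile
def view_copy(target: torch.Tensor, source: torch.Tensor) -> None:
    # Write the uint16 bits of `source` into the storage of the bfloat16 `target`.
    target.view(torch.uint16).copy_(source)

target = torch.ones(1024, dtype=torch.bfloat16, device="cuda")
source = torch.full_like(target, 4, dtype=torch.uint16)
view_copy(target, source)
# target.view(torch.uint16) should now match `source`, same as in eager mode.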

File tree

2 files changed: +23, -29 lines

test/inductor/test_cuda_repro.py

Lines changed: 14 additions & 0 deletions

@@ -885,6 +885,20 @@ def test_scatter_index_not_wrapped(self):
             out, torch.scatter_reduce(input_orig.clone(), 0, index, src, "sum")
         )
 
+    def test_uint_view_copy(self):
+        @torch.compile
+        def view_copy(target, source):
+            assert target.dtype == torch.bfloat16
+            assert source.dtype == torch.uint16
+            target.view(torch.uint16).copy_(source)
+
+        target = torch.ones(1024, dtype=torch.bfloat16, device="cuda")
+        source = torch.full_like(target, 4, dtype=torch.uint16)
+
+        out = target.view(torch.uint16).copy_(source).clone()
+        view_copy(target, source)
+        self.assertEqual(out, target.view(torch.uint16))
+
     def test_embedding_var_mean(self):
         def forward(arg0_1):
             full = torch.ops.aten.full.default(

torch/_inductor/codegen/triton.py

Lines changed: 9 additions & 29 deletions

@@ -706,16 +706,6 @@ def triton_compute_type(dtype: torch.dtype) -> str:
     return triton_type(upcast_compute_type(dtype))
 
 
-def _get_primitive_bitwidth(dtype: torch.dtype) -> int:
-    """Number of bits of triton_compute_type()"""
-    dtype = upcast_compute_type(dtype)
-    itemsize = getattr(dtype, "itemsize", None)
-    if itemsize:
-        return itemsize * 8
-    else:
-        return -1
-
-
 def triton_store_type(dtype: torch.dtype) -> str:
     """Convert torch.dtype to triton type, with fix for storing tl.bool"""
     if dtype == torch.bool:
@@ -887,30 +877,20 @@ def _get_min_elements_per_thread(
 
     @staticmethod
     def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
-        triton_dtype = triton_compute_type(dtype)
+        assert src_dtype.itemsize == dtype.itemsize
         # We may promote float16 or bfloat16 to float32 and cause the
         # bitwidth of dtype to be different from the input tensor (i.e. float32).
         # In such as case, we will have to convert the input tensor to
         # its src_type, perform bitcast, and then convert the bit-casted
         # tensor back to float to ensure we use values with the right precision.
-        if (
-            src_dtype in (torch.float16, torch.bfloat16)
-            and config.triton.codegen_upcast_to_fp32
-        ):
-            triton_src_dtype = str(src_dtype).split(".")[-1]
-            cast_x = f"{x}.to(tl.{triton_src_dtype})"
-            if dtype in (torch.float16, torch.bfloat16):
-                triton_type_name = str(dtype).split(".")[-1]
-                triton_dtype = f"tl.{triton_type_name}"
-            cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)"
-            if dtype in (torch.float16, torch.bfloat16):
-                return f"{cast_x}.to(tl.float32)"
-            return cast_x
-        else:
-            src_dtype_bitwidth = _get_primitive_bitwidth(src_dtype)
-            target_dtype_bitwidth = _get_primitive_bitwidth(dtype)
-            bitcast = "True" if src_dtype_bitwidth == target_dtype_bitwidth else "False"
-            return f"{x}.to({triton_dtype}, bitcast={bitcast})"
+        if x.dtype != src_dtype:
+            x = f"{x}.to({triton_type(src_dtype)})"
+
+        out = f"{x}.to({triton_type(dtype)}, bitcast=True)"
+        if upcast_compute_type(dtype) != dtype:
+            out = f"{out}.to({triton_type(upcast_compute_type(dtype))})"
+
+        return out
 
     @staticmethod
     def _shaped_constant(value, dtype, shape):
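
For illustration, this is roughly what the rewritten to_dtype_bitcast emits for the two directions exercised by the new test (tmp0 is a hypothetical CSE variable name; this sketch assumes config.triton.codegen_upcast_to_fp32 is enabled, so bfloat16 values are held as float32 during compute):

# bfloat16 -> uint16 bitcast: the upcast fp32 value is first narrowed back to
# its true source dtype, then bitcast; uint16 needs no compute upcast afterwards.
tmp0.to(tl.bfloat16).to(tl.uint16, bitcast=True)

# uint16 -> bfloat16 bitcast: bitcast first, then upcast the result back to
# the fp32 compute dtype.
tmp0.to(tl.bfloat16, bitcast=True).to(tl.float32)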
