E64F [inductor] dtype promotion error in cat decomp by pianpwk · Pull Request #152995 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[inductor] dtype promotion error in cat decomp #152995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
remove all_tensors_same
  • Loading branch information
pianpwk committed May 7, 2025
commit 24bcb5ee6c89465688f3a630842792bee8ff9224
6 changes: 1 addition & 5 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,11 +557,7 @@ def forward(self, x, y):

model = Foo()
inps = (torch.randn(4, 10, dtype=torch.bfloat16), torch.randn(4, 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.check_model is the common utility for AOTI compile and check.

ep = torch.export.export(model, inps, strict=False)
optimized = torch._inductor.aoti_load_package(
torch._inductor.aoti_compile_and_package(ep)
)
self.assertTrue(same(optimized(*inps), model(*inps)))
self.check_model(model, inps)

@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
Expand Down
21 changes: 6 additions & 15 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,6 @@ def cat(
tensors: list[torch.Tensor],
dim: int = 0,
) -> torch.Tensor:
def all_tensors_same_dtype(tensors: list[torch.Tensor]) -> bool:
return all(t.dtype == tensors[0].dtype for t in tensors)

def non_empty_tensor(x: torch.Tensor) -> bool:
# For better or worse, this is a valid cat:
#
Expand Down Expand Up @@ -394,18 +391,12 @@ def non_empty_tensor(x: torch.Tensor) -> bool:

if len(filtered_tensors) == 1:
# check dtype promotion
if (
not all_tensors_same_dtype(tensors)
and (
promoted_dtype := elementwise_dtypes(
*tensors,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)[1]
) != filtered_tensors[0].dtype
):
return filtered_tensors[0].to(dtype=promoted_dtype)
else:
return filtered_tensors[0].clone()
promoted_dtype = elementwise_dtypes(
*tensors,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)[1]
filtered_t = filtered_tensors[0]
return filtered_t.clone() if promoted_dtype == filtered_t.dtype else filtered_t.to(dtype=promoted_dtype)
elif 1 < len(filtered_tensors) < len(tensors):
# on the first call, when we remove empty tensors, we redispatch recursively
return aten.cat.default(filtered_tensors, dim)
Expand Down
Loading