8000 [inductor] dtype promotion error in cat decomp (#152995) · pytorch/pytorch@8ea95d2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ea95d2

Browse files
pianpwkpytorchmergebot
authored andcommitted
[inductor] dtype promotion error in cat decomp (#152995)
cloning single tensor wasn't following dtype promotion rules for SAM model: #152606 Pull Request resolved: #152995 Approved by: https://github.com/yushangdi, https://github.com/eellison
1 parent e21ff9c commit 8ea95d2

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ resnext50_32x4d,pass,0
286286

287287

288288

289-
sam,fail_to_run,0
289+
sam,pass,0
290290

291291

292292

test/inductor/test_aot_inductor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,17 @@ def forward(self, y):
573573
model = LinearModel(device=self.device)
574574
self.check_model(model, example_inputs)
575575

576+
def test_empty_cat_dtype_promotion(self):
577+
class Foo(torch.nn.Module):
578+
def forward(self, x, y):
579+
z = torch.cat([x, y], dim=1)
580+
z = z.to(dtype=torch.bfloat16)
581+
return z * 2
582+
583+
model = Foo()
584+
inps = (torch.randn(4, 10, dtype=torch.bfloat16), torch.randn(4, 0))
585+
self.check_model(model, inps)
586+
576587
@unittest.skipIf(
577588
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
578589
)

torch/_inductor/decomposition.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,17 @@ def non_empty_tensor(x: torch.Tensor) -> bool:
390390
filtered_tensors = list(filter(non_empty_tensor, tensors))
391391

392392
if len(filtered_tensors) == 1:
393-
return filtered_tensors[0].clone()
393+
# check dtype promotion
394+
promoted_dtype = elementwise_dtypes(
395+
*tensors,
396+
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
397+
)[1]
398+
filtered_t = filtered_tensors[0]
399+
return (
400+
filtered_t.clone()
401+
if promoted_dtype == filtered_t.dtype
402+
else filtered_t.to(dtype=promoted_dtype)
403+
)
394404
elif 1 < len(filtered_tensors) < len(tensors):
395405
# on the first call, when we remove empty tensors, we redispatch recursively
396406
return aten.cat.default(filtered_tensors, dim)

0 commit comments

Comments
 (0)
0