8000 [torch][ao] Do not crash numerics debugger if the shape of the tensors do not match by dulinriley · Pull Request #149330 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[torch][ao] Do not crash numerics debugger if the shape of the tensors do not match #149330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions torch/ao/quantization/pt2e/_numeric_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,28 @@ def compare_results(
)
continue
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
for a, b in zip(actual_stats, ref_stats):
if (
isinstance(a, torch.Tensor)
and isinstance(b, torch.Tensor)
and a.shape != b.shape
):
log.warning(
"Cannot compare tensors with different shapes: actual=%s vs ref=%s",
a.shape,
b.shape,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we raise exception here instead of ignore the error

Copy link
Contributor Author
@dulinriley dulinriley Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, because this type of error would typically indicate the debug handle was placed on the wrong nodes to compare before and after. Debug handles should always produce tensors of the same shape.

But debug handle errors should not prevent reporting of other accuracy errors, which an exception would do

try:
results = [
QuantizationComparisonResult(actual=a, ref=b)
for a, b in zip(actual_stats, ref_stats)
# Only compare objects if they're both collections, or both single
# tensors of the same shape.
# TODO: use torch.fx.node.map_aggregate for the collection of
# comparison results instead.
if not isinstance(a, torch.Tensor)
or not isinstance(b, torch.Tensor)
or a.shape == b.shape
Comment on lines +340 to +342
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also with this, some of the mismatches will be ignored, that's probably not what we want I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to ignore mismatches if they were comparing this of different types, or when the tensors have different shapes, because it's a problem with the debug handles at that point

]
except Exception as e:
# Add extra information for an exception from QuantizationComparisonResult
Expand Down
Loading
0