8000 [torch][ao] Do not crash numerics debugger if the shape of the tensor… · pytorch/pytorch@c755508 · GitHub
[go: up one dir, main page]

Skip to content

Commit c755508

Browse files
committed
[torch][ao] Do not crash numerics debugger if the shape of the tensors do not match (#149330)
Summary: Pull Request resolved: #149330 Occasionally we see the loss function to crash because the shapes of `a` and `b` tensors are different. This diff avoids crashing is such scenarios and lets the comparison work for other nodes where the shapes match. Test Plan: - CI Reviewed By: jerryzh168 Differential Revision: D66245053
1 parent 5e9f792 commit c755508

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

torch/ao/quantization/pt2e/_numeric_debugger.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,10 +318,28 @@ def compare_results(
318318
)
319319
continue
320320
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
321+
for a, b in zip(actual_stats, ref_stats):
322+
if (
323+
isinstance(a, torch.Tensor)
324+
and isinstance(b, torch.Tensor)
325+
and a.shape != b.shape
326+
):
327+
log.warning(
328+
"Cannot compare tensors with different shapes: actual=%s vs ref=%s",
329+
a.shape,
330+
b.shape,
331+
)
321332
try:
322333
results = [
323334
QuantizationComparisonResult(actual=a, ref=b)
324335
for a, b in zip(actual_stats, ref_stats)
336+
# Only compare objects if they're both collections, or both single
337+
# tensors of the same shape.
338+
# TODO: use torch.fx.node.map_aggregate for the collection of
339+
# comparison results instead.
340+
if not isinstance(a, torch.Tensor)
341+
or not isinstance(b, torch.Tensor)
342+
or a.shape == b.shape
325343
]
326344
except Exception as e:
327345
# Add extra information for an exception from QuantizationComparisonResult

0 commit comments

Comments
 (0)
0