8000 [testing] improve broadcasts_input error message (#58295) · pytorch/pytorch@49c2da0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 49c2da0

Browse files
kshitij12345facebook-github-bot
authored andcommitted
[testing] improve broadcasts_input error message (#58295)
Summary: Context: The Error message when `broadcasts_input` is marked incorrectly is uninformative [See Error Currently] #57941 (comment) Error Currently ``` Traceback (most recent call last): File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 326, in test_variant_consistency_eager _test_consistency_helper(samples, variants) File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 310, in _test_consistency_helper variant_forward = variant(cloned, File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 227, in __exit__ self._raiseFailure("{} not raised".format(exc_name)) File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 164, in _raiseFailure raise self.test_case.failureException(msg) AssertionError: RuntimeError not raised ``` Error After PR ``` Traceback (most recent call last): File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 329, in test_variant_consistency_eager _test_consistency_helper(samples, variants) File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 313, in _test_consistency_helper variant_forward = variant(cloned, File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 227, in __exit__ self._raiseFailure("{} not raised".format(exc_name)) File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 164, in _raiseFailure raise self.test_case.failureException(msg) AssertionError: RuntimeError not raised : inplace variant either allowed resizing or you have marked the sample SampleInput(input=Tensor, args=(tensor([[[ 2.1750, -8.5027, -3.1403, -6.9942, 3.2609], [-2.5057, -5.9123, -5.4633, 6.1203, -8.2124], [-3.5802, -8.4869, -6.0700, 2.3431, -8.1955], [-7.3316, 1.3248, -6.8661, 7.1483, -8.0719], [ 4.5977, -4.0448, -6.2044, -2.1314, -8.4956]], [[ 3.2769, -8.4360, 1.2826, 7.1749, 4.7653], [-0.2816, -2.5997, -4.7659, -3.7814, 3.9704], [-2.1778, -3.8117, -6.0276, -0.8423, -5.9646], [ 8.6544, -3.0922, 0.2558, -4.9318, -4.7596], [ 4.5583, 4.3830, 5.8793, 0.9713, -2.1481]], [[-1.0447, 0.9334, 7.6405, -4.8933, -7.4010], [ 7.7168, -8.4266, -5.5980, -6.9368, 7.1309], [-8.7720, -5.0890, -0.4975, 1.9518, 1.7074], [-8.5783, 8.5510, -8.5459, -3.5451, 8.4319], [ 8.5052, -8.9149, -6.6298, -1.2750, -5.7367]], [[-6.5625, 8.2795, -4.9311, 1.9501, -7.1777], [-8.4035, 1.1136, -7.6418, -7.0726, -2.8281], [ 4.2668, -0.2883, -6.2246, 2.3396, 1.2911], [ 4.6550, -1.9525, 4.4873, -3.8061, -0.8653], [-3.4256, 4.4423, 8.2937, -5.3456, -4.2624]], [[ 7.6128, -6.3932, 4.7131, -5.4938, 6.4792], [-6.5385, 2.4385, 4.5570, 3.7803, -8.3281], [-2.9785, -4.4745, -1.1778, -8.9324, 1.3663], [ 3.7437, 3.5171, -6.3135, -8.4519, -2.7033], [-5.0568, -8.4630, -4.2870, -3.7284, -1.5238]]], device='cuda:0', dtype=torch.float32, requires_grad=True),), broadcasts_input=True) incorrectly with `broadcasts_self=True ``` **NOTE**: Printing the sample looks very verbose and it may be hard to figure out which sample is incorrectly configured if there are multiple samples with similar input shapes. Two Options to make this error less verbose * Don't print the sample and just print `inplace variant either allowed resizing or you have marked one of the sample incorrectly with broadcasts_self=True` * Have some mechanism to name samples which will be printed in the `repr` (which will need extra machinery) Pull Request resolved: #58295 Reviewed By: ngimel Differential Revision: D28627308 Pulled By: mruberry fbshipit-source-id: b3bdeacac3cf9c0d984f0b85410ecce474291d20
1 parent 083d3bb commit 49c2da0

File tree

2 files changed

+53
-10
lines changed

2 files changed

+53
-10
lines changed

test/test_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,10 @@ def _test_consistency_helper(samples, variants):
321321
cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
322322

323323
if variant in inplace_ops and sample.broadcasts_input:
324-
with self.assertRaises(RuntimeError):
324+
with self.assertRaises(RuntimeError,
325+
msg=('inplace variant either incorrectly allowed '
326+
'resizing or you have marked the sample {}'
327+
' incorrectly with `broadcasts_self=True'.format(sample.summary()))):
325328
variant_forward = variant(cloned,
326329
*sample.args,
327330
**sample.kwargs)

torch/testing/_internal/common_methods_invocations.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ def __init__(self, cls_name=None, test_name=None, *,
8181
class SampleInput(object):
8282
"""Represents sample inputs to a function."""
8383

84-
__slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad', 'broadcasts_input']
84+
__slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad', 'broadcasts_input', 'name']
8585

86-
def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None, broadcasts_input=False):
86+
def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None, broadcasts_input=False, name=""):
8787
# input is the first input to the op and must be either a Tensor or TensorList (Sequence[Tensor]).
8888
# This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...).
8989
# op with TensorList inputs do not support method or inplace variants.
@@ -92,6 +92,7 @@ def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=N
9292
self.args = args
9393
self.kwargs = kwargs if kwargs is not None else {}
9494
self.output_process_fn_grad = output_process_fn_grad
95+
self.name = name
9596

9697
# Specifies if `self.input` is broadcasted or not,
9798
# given that the operator supports broadcasting.
@@ -103,17 +104,56 @@ def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=N
103104
# for such inputs (as they will error out otherwise).
104105
self.broadcasts_input = broadcasts_input
105106

106-
def __repr__(self):
107+
def _repr_helper(self, formatter):
108+
# Helper function to return the details of the SampleInput as `str`
109+
# It consolidates all the fields of SampleInput and allows,
110+
# formatting the fields like `input`, `args`, etc with `formatter`
111+
# callable to customize the representation.
112+
# Look at `summary` method for example.
107113
arguments = [
108-
'input=Tensor' if isinstance(self.input, torch.Tensor) else f'input=TensorList[{len(self.input)}]',
109-
f'args={self.args}' if len(self.args) > 0 else None,
110-
f'kwargs={self.kwargs}' if len(self.kwargs) > 0 else None,
111-
(f'output_process_fn_grad={self.output_process_fn_grad}'
112-
if self.output_process_fn_grad is not None else None),
113-
f'broadcasts_input={self.broadcasts_input}']
114+
f'input={formatter(self.input)}',
115+
f'args={formatter(self.args)}',
116+
f'kwargs={formatter(self.kwargs)}',
117+
f'output_process_fn_grad={self.output_process_fn_grad}',
118+
f'broadcasts_input={self.broadcasts_input}',
119+
f'name={repr(self.name)}']
114120

115121
return f'SampleInput({", ".join(a for a in arguments if a is not None)})'
116122

123+
def __repr__(self):
124+
return self._repr_helper(lambda x: x)
125+
126+
def summary(self):
127+
# Returns the SampleInput details in a more
128+
# friendly format.
129+
# It formats `Tensor` and `TensorList`
130+
# in a more condensed representation.
131+
def is_iter(arg):
132+
try:
133+
iter(arg)
134+
return True
135+
except TypeError as te:
136+
return False
137+
138+
def formatter(arg):
139+
# Format any instance of `Tensor` (standalone, in list, or in dict)
140+
# by Tensor[TensorShape]
141+
# Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4]
142+
if isinstance(arg, torch.Tensor):
143+
shape = str(tuple(arg.shape)).replace('(', '').replace(')', '')
144+
return f"Tensor[{shape}]"
145+
elif isinstance(arg, dict):
146+
return {k: formatter(v) for k, v in arg.items()}
147+
elif is_iterable_of_tensors(arg):
148+
return "TensorList[" + ", ".join(map(formatter, arg)) + "]"
149+
elif is_iter(arg): # Handle list, tuple or any iterable type
150+
return "(" + ",".join(map(formatter, arg)) + ")"
151+
152+
return repr(arg)
153+
154+
return self._repr_helper(formatter)
155+
156+
117157
class AliasInfo(object):
118158
"""Class holds alias information. For example, torch.abs ->
119159
torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_

0 commit comments

Comments
 (0)
0