[testing] Port torch.{repeat, tile} tests to use OpInfo machinery #50199
Changes from all commits
c20dd9f
126058c
af87706
ea92bca
1973767
f7171b7
test/test_shape_ops.py
@@ -8,9 +8,10 @@
 from torch._six import nan
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict)
+from torch.testing._internal.common_methods_invocations import shape_funcs
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCPU, dtypes, onlyOnCPUAndCUDA,
-    dtypesIfCPU, dtypesIfCUDA)
+    dtypesIfCPU, dtypesIfCUDA, ops)

 # TODO: replace with make_tensor
 def _generate_input(shape, dtype, device, with_extremal):
@@ -599,7 +600,21 @@ def test_nonzero_non_diff(self, device):
         nz = x.nonzero()
         self.assertFalse(nz.requires_grad)

+class TestShapeFuncs(TestCase):
+    """Test suite for Shape manipulating operators using the ShapeFuncInfo."""
+
+    @dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128))
+    @ops([op for op in shape_funcs if op.name in ['tile', 'repeat']])
Review comment: This is a good way to filter the OpInfos. In the future we may want to consider allowing filtering directly by name or function, which might be simpler to remember and more readable.
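A rough sketch of what such a name-based filter could look like (the `ops_by_name` helper below is hypothetical and does not exist in the codebase; in the diff the filtering is done with the list comprehension shown above):

def ops_by_name(op_list, names):
    # Return only the OpInfo entries whose name appears in `names`.
    return [op for op in op_list if op.name in names]

# Hypothetical usage at the decorator site:
#   @ops(ops_by_name(shape_funcs, ['tile', 'repeat']))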
+    def test_repeat_tile_vs_numpy(self, device, dtype, op):
Review comment: Since torch.repeat is more restrictive than tile, does this test not test where torch.tile and np.tile are compatible but torch.repeat isn't?
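For context on that question, `torch.tile` follows NumPy and pads missing leading repeat dims with 1, while `Tensor.repeat` requires at least as many repeat dims as the tensor has dimensions. A small illustrative snippet (not part of the diff):

import numpy as np
import torch

t = torch.arange(6).reshape(2, 3)

# tile pads missing leading repeat dims with 1, matching np.tile.
print(torch.tile(t, (2,)).shape)       # torch.Size([2, 6])
print(np.tile(t.numpy(), (2,)).shape)  # (2, 6)

# Tensor.repeat rejects fewer repeat dims than tensor dims.
try:
    t.repeat(2)
except RuntimeError as err:
    print("repeat raised:", err)

The sample generator later in this diff handles this per op: 'repeat' samples are filtered to len(rep_dim) >= t.dim(), while 'tile' keeps all combinations, so tile still gets coverage for the NumPy-compatible cases that repeat rejects.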
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+        for sample in samples:
+            (t, dims) = sample.input
+            expected = op.ref(t.cpu().numpy(), dims, **sample.kwargs)
+            result = op(t, dims, **sample.kwargs).cpu().numpy()
+            self.assertEqual(expected, result)
+
 instantiate_device_type_tests(TestShapeOps, globals())
+instantiate_device_type_tests(TestShapeFuncs, globals())

 if __name__ == '__main__':
     run_tests()
torch/testing/_internal/common_methods_invocations.py
@@ -443,6 +443,34 @@ def sample_movedim_moveaxis(op_info, device, dtype, requires_grad):
                          requires_grad=requires_grad),
              (0, -1, -2, -3), (-3, -2, -1, -0))))

+
+def sample_repeat_tile(op_info, device, dtype, requires_grad):
+    rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),)
+    shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1))
+
+    if requires_grad:
Review comment: This is a good way to filter the samples for gradcheck and gradgradcheck.
+        # Tests for variant_consistency_jit, grad, gradgrad
+        # are slower. Use smaller bags of `rep_dims` and `shapes`
+        # in this case.
+        rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1))  # type: ignore
+        shapes = ((), (0,), (2,), (3, 2))  # type: ignore
+
+    tensors = [make_tensor(shape, device, dtype,
+                           low=None, high=None,
+                           requires_grad=requires_grad) for shape in shapes]
+
+    samples = []
+    for rep_dim, tensor in product(rep_dims, tensors):
+        for t in (tensor, tensor.T):
+            if op_info.name == 'repeat' and len(rep_dim) >= t.dim():
+                # `torch.repeat` errors for `len(rep_dims) < t.dim()`,
+                # so we filter such combinations.
+                samples.append(SampleInput((t, rep_dim),))
+            elif op_info.name == 'tile':
Review comment: Add a comment here explaining the filtering for tile and repeat. This is a clever way to filter, and this answers my previous question about test coverage.
+                samples.append(SampleInput((t, rep_dim),))
+
+    return samples
+
+
 def np_unary_ufunc_integer_promotion_wrapper(fn):
     # Wrapper that passes PyTorch's default scalar
     # type as an argument to the wrapped NumPy
@@ -529,6 +557,28 @@ def sample_inputs(self, device, dtype, requires_grad=False):
         ]


+class ShapeFuncInfo(OpInfo):
Review comment: Add a brief comment here explaining what this derived class is intended for (maybe something like, "Early version of a specialized OpInfo for 'shape' operations like tile and roll"?).
+    """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
+    def __init__(self,
+                 name,  # the string name of the function
+                 *,
+                 ref,  # a reference function
+                 dtypes=floating_types(),
+                 dtypesIfCPU=None,
+                 dtypesIfCUDA=None,
+                 dtypesIfROCM=None,
+                 sample_inputs_func=None,
+                 **kwargs):
+        super(ShapeFuncInfo, self).__init__(name,
+                                            dtypes=dtypes,
+                                            dtypesIfCPU=dtypesIfCPU,
+                                            dtypesIfCUDA=dtypesIfCUDA,
+                                            dtypesIfROCM=dtypesIfROCM,
+                                            sample_inputs_func=sample_inputs_func,
+                                            **kwargs)
+        self.ref = ref
+
+
 class HermitianOpInfo(OpInfo):
     """Operator information for Hermitian functions
     These are functions that take Hermitian matrices as input.
@@ -578,7 +628,6 @@ def sample_inputs_linalg_pinv_hermitian(op_info, device, dtype, requires_grad=Fa
         o.kwargs = {"hermitian": True}
     return out

-
 def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False):
     """
     This function generates always solvable input for torch.linalg.solve
@@ -1405,6 +1454,24 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
            test_inplace_grad=False,
            supports_tensor_out=False,
            sample_inputs_func=sample_movedim_moveaxis),
+    ShapeFuncInfo('repeat',
+                  op=lambda x, dims: x.repeat(dims),
+                  ref=np.tile,
+                  dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+                  supports_tensor_out=False,
+                  test_inplace_grad=False,
+                  skips=(
+                      # torch.repeat does not exist so we get a RuntimeError.
+                      SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                               dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16)),
+                  ),
+                  sample_inputs_func=sample_repeat_tile),
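Regarding the `test_variant_consistency_jit` skip above: `repeat` exists only as a Tensor method, not as a function in the `torch` namespace, which is presumably why the functional variant raises a RuntimeError under that test. A quick illustrative check (not part of the diff):

import torch

print(hasattr(torch.Tensor, 'repeat'))  # True  -- method variant exists
print(hasattr(torch, 'repeat'))         # False -- no torch.repeat functional
print(hasattr(torch, 'tile'))           # True  -- tile has a functional variant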
+    ShapeFuncInfo('tile',
+                  ref=np.tile,
+                  dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+                  supports_tensor_out=False,
+                  test_inplace_grad=False,
+                  sample_inputs_func=sample_repeat_tile),
 ]

 if TEST_SCIPY:
@@ -1506,6 +1573,7 @@ def reference_sigmoid(x):
 unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)]
 spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
 sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse is True]
+shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)]

 def index_variable(shape, max_indices, device=torch.device('cpu')):
     if not isinstance(shape, tuple):
@@ -1982,14 +2050,6 @@ def method_tests():
     ('renorm', (S, S, S), (2, 1, 0.5), 'dim', (), [1]),
     ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
     ('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'),
-    ('repeat', (S,), (2,), 'single_number'),
-    ('repeat', (), (2, 3), 'scalar'),
-    ('repeat', (2, 2), (3, 2)),
-    ('repeat', (2, 2), (1, 3, 1, 2), 'unsqueeze'),
-    ('repeat', (S, S), (1, 1), 'keepdim0'),
-    ('repeat', (S, S), (3, 1, 1), 'keepdim1'),
-    ('repeat', (S,), (0, ), 'zero_dim'),
-    ('repeat', (S,), (0, 2), 'zero_dim_multi'),
     ('logcumsumexp', (S, S, S), (0,), 'dim0', (), [0]),
     ('logcumsumexp', (S, S, S), (1,), 'dim1', (), [0]),
     ('logcumsumexp', (), (0,), 'dim0_scalar', (), [0]),
@@ -2206,11 +2266,6 @@ def method_tests():
     ('diagonal', (M, M, M), (1, 1, 2), '3d_1'),
     ('diagonal', (M, M, M), (2, 0, 1), '3d_2'),
     ('diagonal', (M, M, M), (-2, 0, 1), '3d_3'),
-    ('tile', (2, 2), ([2, 2, 2],), 'more_reps_dims', (False,)),
-    ('tile', (2, 2), ([2, 2],), 'same_reps_dims', (False,)),
-    ('tile', (2, 2), ([2, 3],), 'less_reps_dims', (False,)),
-    ('tile', (2, 2, 2), ([2, 2, 0],), 'zero_rep_dim', (False,)),
-    ('tile', (), ([S, S, S],), 'empty_tensor', (False,)),
     ('tril', (M, M), NO_ARGS),
     ('tril', (M, M), (2,), 'idx'),
     ('tril', (S, M, M), NO_ARGS, 'batched'),
Review comment: Add a comment here describing what this class is for. Do we really want to add another test class instead of just putting this function into the existing TestShapeOps?
Reply: I just thought that, going forward, all ShapeFuncInfo op tests might as well go under TestShapeFuncs. But we can just move it to the existing class as well. Do let me know if that is preferred. Thanks!
Review comment: Either way is fine, just add a comment explaining why people should put a test here, for example, vs. the other test suite.