[testing] Port torch.{repeat, tile} tests to use OpInfo machinery by kshitij12345 · Pull Request #50199 · pytorch/pytorch


Closed

Conversation

@kshitij12345 (Collaborator)

Reference: #50013

@facebook-github-bot (Contributor) commented Jan 7, 2021

💊 CI failures summary and remediations

As of commit ea92bca (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@kshitij12345 (Collaborator, Author)

Test Timings

test_shape_ops.py

torch.repeat

============================================================================= slowest 10 durations =============================================================================
1.15s call     test/test_shape_ops.py::TestShapeFuncsCUDA::test_repeat_tile_vs_numpy_repeat_cuda_complex128
0.14s call     test/test_shape_ops.py::TestShapeFuncsCPU::test_repeat_tile_vs_numpy_repeat_cpu_complex128
0.06s call     test/test_shape_ops.py::TestShapeFuncsCUDA::test_repeat_tile_vs_numpy_repeat_cuda_float64
0.06s call     test/test_shape_ops.py::TestShapeFuncsCUDA::test_repeat_tile_vs_numpy_repeat_cuda_int64
0.06s call     test/test_shape_ops.py::TestShapeFuncsCUDA::test_repeat_tile_vs_numpy_repeat_cuda_uint8
0.01s call     test/test_shape_ops.py::TestShapeFuncsCPU::test_repeat_tile_vs_numpy_repeat_cpu_float64
0.01s call     test/test_shape_ops.py::TestShapeFuncsCPU::test_repeat_tile_vs_numpy_repeat_cpu_int64
0.01s call     test/test_shape_ops.py::TestShapeFuncsCPU::test_repeat_tile_vs_numpy_repeat_cpu_uint8

(2 durations < 0.005s hidden.  Use -vv to show these durations.)
====================================================================== 8 passed, 117 deselected in 3.01s =======================================================================

torch.tile

============================================================================= slowest 10 durations =============================================================================
1.35s call     test/test_shape_ops.py::TestShapeFuncsCUDA::test_repeat_tile_vs_numpy_tile_cuda_complex128
0.24s call     test/test_shape_ops.py::TestShapeFuncsCPU::test_repeat_tile_vs_numpy_tile_cpu_complex128
0.08s call     test/test_shape_ops.py::TestShapeFuncsCUDA::test_repeat_tile_vs_numpy_tile_cuda_uint8
0.07s call     test/test_shape_ops.py::TestShapeFuncsCUDA::test_repeat_tile_vs_numpy_tile_cuda_float64
0.07s call     test/test_shape_ops.py::TestShapeFuncsCUDA::test_repeat_tile_vs_numpy_tile_cuda_int64
0.01s call     test/test_shape_ops.py::TestShapeFuncsCPU::test_repeat_tile_vs_numpy_tile_cpu_float64
0.01s call     test/test_shape_ops.py::TestShapeFuncsCPU::test_repeat_tile_vs_numpy_tile_cpu_uint8
0.01s call     test/test_shape_ops.py::TestShapeFuncsCPU::test_repeat_tile_vs_numpy_tile_cpu_int64

(2 durations < 0.005s hidden.  Use -vv to show these durations.)
====================================================================== 8 passed, 117 deselected in 3.32s =======================================================================

test_ops.py

torch.repeat

============================================================================= slowest 10 durations =============================================================================
2.18s call     test/test_ops.py::TestGradientsCUDA::test_fn_gradgrad_repeat_cuda_complex128
1.49s call     test/test_ops.py::TestOpInfoCUDA::test_supported_dtypes_repeat_cuda_bfloat16
0.72s call     test/test_ops.py::TestGradientsCUDA::test_fn_gradgrad_repeat_cuda_float64
0.65s call     test/test_ops.py::TestGradientsCPU::test_fn_gradgrad_repeat_cpu_complex128
0.45s call     test/test_ops.py::TestGradientsCUDA::test_fn_grad_repeat_cuda_complex128
0.28s call     test/test_ops.py::TestOpInfoCPU::test_supported_dtypes_repeat_cpu_bfloat16
0.27s call     test/test_ops.py::TestGradientsCUDA::test_fn_grad_repeat_cuda_float64
0.18s call     test/test_ops.py::TestGradientsCPU::test_fn_gradgrad_repeat_cpu_float64
0.17s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_eager_repeat_cuda_complex64
0.17s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_eager_repeat_cuda_complex128
=============================================================== 56 passed, 56 skipped, 5570 deselected in 13.90s ===============================================================

torch.tile

============================================================================= slowest 10 durations =============================================================================
2.97s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_complex64
2.87s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_complex128
2.68s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_float32
2.64s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_float16
2.62s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_bfloat16
2.61s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_float64
2.08s call     test/test_ops.py::TestCommonCPU::test_variant_consistency_jit_tile_cpu_complex64
2.07s call     test/test_ops.py::TestGradientsCUDA::test_fn_gradgrad_tile_cuda_complex128
2.02s call     test/test_ops.py::TestCommonCPU::test_variant_consistency_jit_tile_cpu_complex128
1.98s call     test/test_ops.py::TestCommonCPU::test_variant_consistency_jit_tile_cpu_float16
=============================================================== 80 passed, 32 skipped, 5570 deselected in 48.66s ===============================================================

@kshitij12345 (Collaborator, Author)

In terms of the time required by the tests, test_variant_consistency_jit dominates: running just test_variant_consistency_jit for torch.tile takes about 40s.

============================================================================= slowest 10 durations =============================================================================
3.80s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_bfloat16
2.98s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_complex64
2.88s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_complex128
2.63s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_float64
2.59s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_float32
2.56s call     test/test_ops.py::TestCommonCUDA::test_variant_consistency_jit_tile_cuda_float16
2.18s call     test/test_ops.py::TestCommonCPU::test_variant_consistency_jit_tile_cpu_complex64
2.17s call     test/test_ops.py::TestCommonCPU::test_variant_consistency_jit_tile_cpu_bfloat16
2.15s call     test/test_ops.py::TestCommonCPU::test_variant_consistency_jit_tile_cpu_complex128
2.03s call     test/test_ops.py::TestCommonCPU::test_variant_consistency_jit_tile_cpu_float16
===================================================================== 24 passed, 5658 deselected in 40.08s =====================================================================

Running all the other Op tests (gradcheck, gradgradcheck, etc.) together with test_variant_consistency_jit takes about 48-50s.
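For reference, a duration report like the one above can be reproduced with a pytest invocation along these lines (the exact -k expression is my assumption, not a command from the PR):

pytest test/test_ops.py -k "variant_consistency_jit and tile" --durations=10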

@kshitij12345 (Collaborator, Author)

Can torch.repeat be implemented as a call to torch.tile? I understand that torch.tile is actually implemented as a call to repeat currently, but from a UX standpoint, could we alias torch.repeat to torch.tile? It's true that torch.tile can accept more inputs than torch.repeat, but will every valid input to torch.repeat produce the same output when given to torch.tile?

torch.tile is more general than torch.repeat: it supports cases where the number of passed dims is less than, greater than, or equal to the tensor's actual number of dimensions, while torch.repeat only accepts cases where the number of passed dims is greater than or equal to the tensor's dimensionality.

So we can implement torch.repeat in terms of torch.tile simply by checking that reps.size() >= self.dim():

Example:

Tensor repeat(const Tensor& self, IntArrayRef reps) {
  // repeat's extra restriction: the reps list must cover every dimension of self.
  TORCH_CHECK(reps.size() >= self.dim(),
              "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
  return self.tile(reps);
}
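For illustration, a quick interactive check of the behavioral difference (plain Python, not part of this PR):

import torch

t = torch.arange(6).reshape(2, 3)

# tile accepts fewer reps than t.dim(); the missing leading dims behave as 1.
t.tile((2,)).shape          # torch.Size([2, 6])

# repeat requires len(reps) >= t.dim(), so the equivalent call raises:
# t.repeat(2)               # RuntimeError: Number of dimensions of repeat dims can not be smaller ...

# When reps covers every dimension, the two functions agree.
torch.equal(t.repeat(2, 2), t.tile((2, 2)))   # True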

@kshitij12345 requested a review from mruberry January 7, 2021 15:02
@kshitij12345 marked this pull request as ready for review January 7, 2021 15:02
@codecov (bot) commented Jan 7, 2021

Codecov Report

Merging #50199 (f7171b7) into master (3f052ba) will decrease coverage by 0.08%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #50199      +/-   ##
==========================================
- Coverage   80.66%   80.57%   -0.09%     
==========================================
  Files        1912     1912              
  Lines      208058   208078      +20     
==========================================
- Hits       167820   167662     -158     
- Misses      40238    40416     +178     

@@ -605,7 +606,19 @@ def test_nonzero_non_diff(self, device):
nz = x.nonzero()
self.assertFalse(nz.requires_grad)

class TestShapeFuncs(TestCase):
@dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128))
@ops([op for op in shape_funcs if op.name in ['tile', 'repeat']])
Collaborator:

This is a good way to filter the OpInfos. In the future we may want to consider allowing filtering directly by name or function. That is,

@ops('tile', 'repeat')
@ops(torch.tile, torch.repeat)

Which might be simpler to remember and more readable.
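A minimal sketch of how such name-based selection could be layered on top of the existing decorator (hypothetical ops_named helper, not part of this PR; assumes the @ops decorator and op_db from the testing internals):

from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_methods_invocations import op_db

def ops_named(*names):
    # Hypothetical convenience wrapper: pick OpInfos by name and delegate
    # to the existing @ops decorator.
    return ops([op for op in op_db if op.name in names])

# usage (sketch): @ops_named('tile', 'repeat') instead of the list comprehension above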

class TestShapeFuncs(TestCase):
@dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128))
@ops([op for op in shape_funcs if op.name in ['tile', 'repeat']])
def test_repeat_tile_vs_numpy(self, device, dtype, op):
Collaborator:

Since torch.repeat is more restrictive than tile, doesn't this test miss the cases where torch.tile and np.tile are compatible but torch.repeat isn't?

rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),)
shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1))

if requires_grad:
Collaborator:

This is a good way to filter the samples for gradcheck and gradgradcheck.

for t in (tensor, tensor.T):
if op_info.name == 'repeat' and len(rep_dim) >= t.dim():
samples.append(SampleInput((t, rep_dim),))
elif op_info.name == 'tile':
Collaborator:

Add a comment here explaining the filtering for tile and repeat. This is a clever way to filter, and this answers my previous question about test coverage.
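For illustration, the filter with such a comment might read (a sketch based on the snippet above, not the exact code in the PR):

for t in (tensor, tensor.T):
    # torch.repeat requires len(reps) >= t.dim(), while torch.tile accepts
    # any number of reps, so only tile gets samples with fewer reps than dims.
    if op_info.name == 'repeat' and len(rep_dim) >= t.dim():
        samples.append(SampleInput((t, rep_dim),))
    elif op_info.name == 'tile':
        samples.append(SampleInput((t, rep_dim),))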

@@ -500,6 +526,26 @@ def sample_inputs(self, device, dtype, requires_grad=False):
]


class ShapeFuncInfo(OpInfo):
Collaborator:

Add a brief comment here explaining what this derived class is intended for (maybe something like, "Early version of a specialized OpInfo for "shape" operations like tile and roll"?)
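For example, the suggested comment might sit on the class like this (sketch only; constructor details elided):

# Early version of a specialized OpInfo for "shape" operations like tile and roll.
class ShapeFuncInfo(OpInfo):
    ...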

@@ -605,7 +606,19 @@ def test_nonzero_non_diff(self, device):
nz = x.nonzero()
self.assertFalse(nz.requires_grad)

class TestShapeFuncs(TestCase):
Collaborator:

Add a comment here describing what this class is for.

Do we really want to add another test class instead of just putting this function into the existing TestShapeOps?

Collaborator (Author):

I just thought that, going forward, all ShapeFuncInfo op tests might as well go under TestShapeFuncs. But we can move it to the existing class as well.

Do let me know if that is preferred.

Thanks!

Collaborator:

Either way is fine, just add a comment explaining why people should put a test here, for example, vs. the other test suite.

@mruberry (Collaborator) left a comment:

Hey @kshitij12345, thanks for taking a look at these functions! I appreciate the new tricks you've used to acquire just their OpInfos and filter the sample inputs based on whether they'll be used for {grad}gradchecks or not. Nice work.

I've made a few small comments. Also, repeat has an entry here that we can remove:

('repeat', '', _small_2d, lambda t, d: [2, 2, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),

* Add comment for TestShapeFuncs class
* Add comment for ShapeFuncInfo class
* Add comment for filtering inputs for repeat
* Remove redundant test from test_torch.py
@kshitij12345 (Collaborator, Author)

@mruberry PTAL :)

@mruberry (Collaborator) left a comment:

Awesome! This looks great, thanks @kshitij12345!

@facebook-github-bot (Contributor) left a comment:

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@kshitij12345 deleted the develop/opinfo/repeat-tile branch January 19, 2021 14:40
@facebook-github-bot (Contributor)

@mruberry merged this pull request in 316f0b8.
