[testing] Port torch.{repeat, tile} tests to use OpInfo machinery #50199
Conversation
💊 CI failures summary and remediations — as of commit ea92bca (more details on the Dr. CI page):

💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
Test timings: test_shape_ops.py, test_ops.py

In terms of time required by the test, I see that running all the other Op tests (gradcheck, gradgradcheck, etc.) with …
So we can implement repeat in terms of tile. For example:

```cpp
Tensor repeat(const Tensor& self, IntArrayRef reps) {
  TORCH_CHECK(reps.size() >= self.dim(),
              "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
  return self.tile(reps);
}
```
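As a quick sanity check of those semantics, NumPy's `np.tile` behaves the same way when `reps` has at least as many entries as the array has dimensions (the case `Tensor.repeat`'s check permits):

```python
import numpy as np

a = np.array([1, 2, 3])

# With len(reps) == a.ndim, tile repeats along the existing dimension.
assert np.tile(a, (2,)).tolist() == [1, 2, 3, 1, 2, 3]

# With len(reps) > a.ndim, tile prepends new dimensions first --
# exactly the regime in which repeat and tile agree.
assert np.tile(a, (2, 2)).shape == (2, 6)
```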
Codecov Report

```
@@            Coverage Diff             @@
##           master   #50199      +/-   ##
==========================================
- Coverage   80.66%   80.57%   -0.09%
==========================================
  Files        1912     1912
  Lines      208058   208078      +20
==========================================
- Hits       167820   167662     -158
- Misses      40238    40416     +178
```
```python
@@ -605,7 +606,19 @@ def test_nonzero_non_diff(self, device):
        nz = x.nonzero()
        self.assertFalse(nz.requires_grad)

class TestShapeFuncs(TestCase):
    @dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128))
    @ops([op for op in shape_funcs if op.name in ['tile', 'repeat']])
```
This is a good way to filter the OpInfos. In the future we may want to consider allowing filtering directly by name or function. That is,

```python
@ops('tile', 'repeat')
@ops(torch.tile, torch.repeat)
```

which might be simpler to remember and more readable.
```python
class TestShapeFuncs(TestCase):
    @dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128))
    @ops([op for op in shape_funcs if op.name in ['tile', 'repeat']])
    def test_repeat_tile_vs_numpy(self, device, dtype, op):
```
Since torch.repeat is more restrictive than tile, doesn't this test miss the cases where torch.tile and np.tile are compatible but torch.repeat isn't?
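To illustrate the restriction being asked about, here is the behavior in NumPy terms: `np.tile` left-pads `reps` with ones when it is shorter than the array's ndim, which is exactly the case `Tensor.repeat` rejects (torch.tile follows NumPy here):

```python
import numpy as np

a = np.ones((3, 2))

# np.tile accepts fewer reps than dimensions: (2,) is treated as (1, 2),
# so only the last dimension is repeated.
assert np.tile(a, (2,)).shape == (3, 4)

# Tensor.repeat would raise for this input, since len(reps) = 1 < a.ndim = 2.
```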
```python
rep_dims = ((), (0,), (1,), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),)
shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1))
```
```python
if requires_grad:
```
This is a good way to filter the samples for gradcheck and gradgradcheck.
```python
for t in (tensor, tensor.T):
    if op_info.name == 'repeat' and len(rep_dim) >= t.dim():
        samples.append(SampleInput((t, rep_dim),))
    elif op_info.name == 'tile':
```
Add a comment here explaining the filtering for tile and repeat. This is a clever way to filter, and this answers my previous question about test coverage.
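The filtering in question can be summarized in a standalone sketch. `SampleInput` and the tensor stand-ins below are simplified, hypothetical versions of the real helpers:

```python
from collections import namedtuple

SampleInput = namedtuple('SampleInput', ['input'])

def make_samples(op_name, ndims_and_tensors, rep_dims):
    """repeat requires len(rep_dim) >= tensor.dim(); tile accepts any rep_dim."""
    samples = []
    for ndim, t in ndims_and_tensors:
        for rep_dim in rep_dims:
            if op_name == 'repeat' and len(rep_dim) >= ndim:
                samples.append(SampleInput((t, rep_dim)))
            elif op_name == 'tile':
                samples.append(SampleInput((t, rep_dim)))
    return samples

# tile gets every (tensor, rep_dim) pair; repeat only the valid ones.
tensors = [(1, 'vec'), (2, 'mat')]
reps = [(2,), (2, 3)]
assert len(make_samples('tile', tensors, reps)) == 4
assert len(make_samples('repeat', tensors, reps)) == 3
```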
```python
@@ -500,6 +526,26 @@ def sample_inputs(self, device, dtype, requires_grad=False):
    ]


class ShapeFuncInfo(OpInfo):
```
Add a brief comment here explaining what this derived class is intended for (maybe something like, "Early version of a specialized OpInfo for "shape" operations like tile and roll"?)
```python
class TestShapeFuncs(TestCase):
```
Add a comment here describing what this class is for.
Do we really want to add another test class instead of just putting this function into the existing TestShapeOps?
I just thought that, going forward, all ShapeFuncInfo Op tests might as well go under TestShapeFuncs. But we can just move it to the existing class as well. Do let me know if that is preferred. Thanks!
Either way is fine, just add a comment explaining why people should put a test here, for example, vs. the other test suite.
Hey @kshitij12345, thanks for taking a look at these functions! I appreciate the new tricks you've used to acquire just their OpInfos and filter the sample inputs based on whether they'll be used for {grad}gradchecks or not. Nice work.
I've made a few small comments. Also, repeat has an entry here that we can remove:
Line 6818 in 55919a4:

```python
('repeat', '', _small_2d, lambda t, d: [2, 2, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
```
* Add comment for TestShapeFuncs class
* Add comment for ShapeFuncInfo class
* Add comment for filtering inputs for repeat
* Remove redundant test from test_torch.py
@mruberry PTAL :)
Awesome! This looks great, thanks @kshitij12345!
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Reference: #50013