Implement logaddexp #38384
Conversation
💊 CI failures summary (Dr. CI, as of commit 97ab0ca): 💚 Looks good so far! There are no failures yet.
Remove WIP from the title when you're ready for review.
@mruberry anything left?
Just mark it as no longer WIP/draft when you're ready for review.
aten/src/ATen/native/BinaryOps.cpp
@@ -742,6 +744,28 @@ Tensor& fmod_(Tensor& self, Scalar other) {
  return at::fmod_out(self, self, other);
}

Tensor& logaddexp_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other);
Do you need a /*check_mem_overlap=*/true here? What if result is self or other?
There will be no problem if result is self or other, but a problem might occur when a view of one tensor overlaps with a view of another, so check_mem_overlap is definitely needed here.
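For reference, a minimal sketch of what the overlap-checked call could look like; the logaddexp_stub name is assumed here for illustration and is not copied from the diff above:

Tensor& logaddexp_out(Tensor& result, const Tensor& self, const Tensor& other) {
  // check_mem_overlap=true asks TensorIterator to reject partially
  // overlapping views between the output and the inputs; exact aliasing
  // (e.g. result being the same tensor as self) remains fine.
  auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true);
  logaddexp_stub(iter.device_type(), iter);  // assumed dispatch stub name
  return result;
}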
aten/src/ATen/native/BinaryOps.cpp
}

Tensor& logaddexp2_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other);
Same check_mem_overlap question as above.
cpu_kernel_vec(
    iter,
    [=](scalar_t a, scalar_t b) -> scalar_t {
      scalar_t m = std::max(a, b);
What happens if a == b == inf or -inf?
cpu_kernel_vec(
    iter,
    [=](scalar_t a, scalar_t b) -> scalar_t {
      scalar_t m = std::max(a, b);
Same question when a == b == +/- inf
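For context, one common way to keep the result correct at a == b == +/- inf is to special-case infinite inputs before the stable log1p path. A rough sketch of the base-e lambda (the base-2 variant would use exp2/log2 analogously); this is not the exact code from this PR:

[=](scalar_t a, scalar_t b) -> scalar_t {
  if (std::isinf(a) && a == b) {
    // log(exp(inf) + exp(inf)) = inf and log(exp(-inf) + exp(-inf)) = -inf,
    // while the generic path below would turn inf - inf into NaN.
    return a;
  }
  scalar_t m = std::max(a, b);
  // Numerically stable form: max(a, b) + log1p(exp(-|a - b|)).
  return m + std::log1p(std::exp(-std::abs(a - b)));
}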
@@ -514,6 +514,14 @@
- name: log2(Tensor self) -> Tensor
  self: grad / (self * 0.6931471805599453)

- name: logaddexp(Tensor self, Tensor other) -> Tensor
@albanD Take a look, would you?
The formula looks good.
For testing, you want to add an entry here to make sure the gradients will be properly checked.
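For the record, the backward formula under discussion reduces to a sigmoid of the input difference (restated here in LaTeX as a sanity check, not quoted from the diff):

\frac{\partial}{\partial a}\log\left(e^{a} + e^{b}\right) = \frac{e^{a}}{e^{a} + e^{b}} = \sigma(a - b),
\qquad
\frac{\partial}{\partial b}\log\left(e^{a} + e^{b}\right) = \sigma(b - a)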
test/test_torch.py
@@ -313,6 +313,38 @@ def test_dim_reduction_less_than_64(self):
        with self.assertRaisesRegex(RuntimeError, "PyTorch doesn't support reduction operations for dim>=64"):
            torch.sum(x, -1)

    def _test_logaddexp(self, base2=False):
@albanD What's our plan for testing the derivatives of new functions like this?
I am also curious.
test/test_torch.py
@@ -313,6 +313,38 @@ def test_dim_reduction_less_than_64(self):
        with self.assertRaisesRegex(RuntimeError, "PyTorch doesn't support reduction operations for dim>=64"):
            torch.sum(x, -1)

    def _test_logaddexp(self, base2=False):
        import numpy as np
Don't import NumPy; instead, assert TEST_NUMPY.
test/test_torch.py
        ours = our_func(a, b)
        self.assertTrue(np.allclose(ours.numpy(), gt))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
You should move these tests into TestTorchDeviceType so you can run them on both the CPU and GPU. Right now you're just testing the CPU. Then see the helper function "_np_compare," which you can use to simplify your tests, and the @dtypes decorator so you can test multiple dtypes. Since you're only enabling logaddexp on float types, you should test at least torch.long (assert throws RuntimeError), torch.float32, and torch.complex64 (assert throws RuntimeError). That way if someone implements logaddexp for complex64, for example, they'll know to update this test.
_np_compare works by taking values to test at. Your value generation is OK but you should add some more "interesting" values like -math.pi, 0, and math.pi, as well as extremal values like nan, -inf, and inf.
torch/_torch_docs.py
Args:
    {input}
    other (Tensor): the tensor to compute logaddexp with
This comment is a little misleading. Maybe something like: "the tensor whose exponential is added to the exponential of input before the log is taken"?
Your suggestion is a bit verbose; maybe I should just delete it and let the generator produce the description for the second input tensor?
Sure.
Args:
    {input}
    other (Tensor): the tensor to compute logaddexp2 with
    {out}
Same language/doc changes here as with the above.
This PR is looking really good. I requested changes to the tests and docs and have a couple additional questions. I also want to check with @albanD if/how we're planning to validate the gradients of these functions.
Once we get that fixed up I think this PR will be good to go!
LGTM
I'll let @mruberry do a final pass.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Resolves #38377
Related: #38349
This op should be disambiguated from logsumexp, which performs a reduction on a tensor over a specific axis.
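To make the distinction concrete (a restatement for clarity, not part of the PR diff): logaddexp combines two tensors elementwise, while logsumexp reduces a single tensor along a dimension.

\operatorname{logaddexp}(a, b) = \log\left(e^{a} + e^{b}\right) \quad \text{(elementwise over two tensors)}
\qquad
\operatorname{logsumexp}(x, \mathrm{dim}) = \log \sum_{i \in \mathrm{dim}} e^{x_{i}} \quad \text{(reduction over one tensor)}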