Support alpha=inf consistently for torch.celu #148065
Labels
module: nn
Related to torch.nn
module: python frontend
For issues relating to PyTorch's Python frontend
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
The celu activation function introduced in https://arxiv.org/pdf/1704.07483 is described as being C1 (continuously differentiable) for all values of alpha.
The PyTorch celu (and the related quantized_celu) implementations appear to call a parameterized elu, following the celu definition, so they should have these properties within the bounds of overflow.
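For reference, a minimal sketch of that relationship (an illustration of the math, not the actual ATen code): celu(x, alpha) can be written as alpha * elu(x / alpha), and the 1/alpha scaling of the input is what makes alpha=0 a division-by-zero hazard.

```python
import torch
import torch.nn.functional as F

def celu_via_elu(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # celu(x, alpha) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    #                = alpha * elu(x / alpha)
    # The 1/alpha term applied to the input is why alpha = 0 is problematic.
    return alpha * F.elu(x / alpha)

x = torch.linspace(-3.0, 3.0, 7)
assert torch.allclose(celu_via_elu(x, alpha=2.0), F.celu(x, alpha=2.0))
```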
That said, I am not sure what the intended PyTorch behavior is for alpha=0 and alpha=inf, but I think it would be nice for it to be consistent.
At alpha=0 it raises a ZeroDivisionError, which is fine, though it could instead short-circuit for alpha=0 and return relu(x), which is the limiting behavior as alpha approaches 0 (see the repro below).
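For concreteness, a minimal repro of the alpha=0 behavior described above (the exact error type may vary by version, so both are caught here):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, 0.0, 3.0])

# alpha = 0 currently errors out (a division by zero in the 1/alpha term);
# the suggestion above is to short-circuit to relu(x) instead.
try:
    F.celu(x, alpha=0.0)
except (ZeroDivisionError, RuntimeError) as e:
    print(type(e).__name__, e)

# The proposed short-circuit would return the alpha -> 0 limit:
print(F.relu(x))  # tensor([0., 0., 3.])
```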
On the other hand, at alpha=torch.inf, positive values of x appear to return x, while non-positive values of x return torch.nan. This seems inconsistent: since the celu(x, alpha=torch.inf) implementation already returns f(x) = x on the positive domain, I'd suggest making it a C1 function for the alpha=torch.inf case by simply returning the identity op there, which is the limiting behavior as alpha approaches inf.
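To make the suggestion concrete, here is a hypothetical user-side wrapper (celu_with_limits is not an existing PyTorch API) sketching the proposed edge-case handling via the mathematical limits:

```python
import math
import torch
import torch.nn.functional as F

def celu_with_limits(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Hypothetical wrapper sketching the proposed edge-case handling."""
    if alpha == 0.0:
        # limit of celu as alpha -> 0 is relu(x)
        return F.relu(x)
    if math.isinf(alpha):
        # limit of celu as alpha -> inf is the identity, which keeps it C1
        return x.clone()
    return F.celu(x, alpha=alpha)

x = torch.tensor([-2.0, 0.0, 3.0])
print(celu_with_limits(x, alpha=float("inf")))  # tensor([-2., 0., 3.])
print(celu_with_limits(x, alpha=0.0))           # tensor([0., 0., 3.])
```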
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki