Support alpha=inf consistently for torch.celu #148065


Open
redwrasse opened this issue Feb 27, 2025 · 1 comment · May be fixed by #148066
Labels
module: nn (Related to torch.nn)
module: python frontend (For issues relating to PyTorch's Python frontend)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

redwrasse (Contributor) commented Feb 27, 2025

The CELU activation function introduced in https://arxiv.org/pdf/1704.07483 is described as being continuously differentiable (C1) across values of alpha.

  • as alpha -> inf it converges to the identity op (f(x) = x)
  • as alpha -> 0+ it converges to relu(x) = max(0, x)
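For concreteness, here is a quick numerical illustration of those two limits using the paper's definition, celu(x, alpha) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)). This is a standalone reference sketch (celu_ref is a made-up name), not PyTorch's implementation:

import torch

def celu_ref(x, alpha):
    # Paper definition: max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    return torch.clamp(x, min=0) + torch.clamp(alpha * (torch.exp(x / alpha) - 1), max=0)

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
print(celu_ref(x, 1e6))   # large alpha: close to the identity, ~tensor([-2.0000, -0.5000, 0.0000, 1.5000])
print(celu_ref(x, 1e-6))  # small alpha: close to relu, ~tensor([0.0000, 0.0000, 0.0000, 1.5000])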

The PyTorch celu (and related quantized_celu) implementations appear to call a parameterized elu, following the CELU definition, so they should have these properties within the bounds of overflow.

On the other hand, I am not sure what the intended PyTorch behavior is at alpha=0 and alpha=inf, but it would be nice for it to be consistent.

At alpha=0 it raises a RuntimeError (ZeroDivisionError), which is fine, though it could instead short-circuit for alpha=0 and just return relu(x), as in the sketch below.
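A minimal sketch of that short-circuit as a user-side wrapper (celu_or_relu is a hypothetical name, not an existing API):

import torch
import torch.nn.functional as F

def celu_or_relu(x, alpha=1.0):
    # Hypothetical wrapper: treat alpha == 0 as the relu limit instead of
    # letting the division by alpha raise.
    if alpha == 0:
        return F.relu(x)
    return torch.celu(x, alpha)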

At alpha=torch.inf, by contrast, positive values of x return x, while non-positive values of x return torch.nan:

import torch

x = torch.tensor(2.)
print(torch.celu(x, torch.inf))
# tensor(2.)
print(torch.celu(-x, torch.inf))
# tensor(nan)

x = torch.tensor(0.)
print(torch.celu(x, torch.inf))
# tensor(nan)
print(torch.celu(-x, torch.inf))
# tensor(nan)

This seems inconsistent: since the celu(x, alpha=torch.inf) implementation already returns f(x) = x on the positive domain, I'd suggest making it a C1 function for the alpha=torch.inf case by just returning the identity op there, as in the sketch below.
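A user-side sketch of that identity short-circuit (celu_inf_safe is a hypothetical name; this is the proposed behavior, not what PyTorch does today):

import math
import torch

def celu_inf_safe(x, alpha=1.0):
    # Hypothetical wrapper: return the identity at alpha == inf, matching
    # the limit celu(x, alpha) -> x, so the nan on the non-positive domain
    # goes away.
    if math.isinf(alpha) and alpha > 0:
        return x.clone()
    return torch.celu(x, alpha)

print(celu_inf_safe(torch.tensor(-2.0), torch.inf))  # tensor(-2.)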

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

@zou3519 zou3519 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module), module: python frontend (For issues relating to PyTorch's Python frontend), and module: nn (Related to torch.nn) labels Feb 27, 2025
redwrasse (Contributor, Author) commented:

Bump if there's interest; I have a PR that's going to go stale.
