`custom_jvp` and `custom_vjp` · Issue #87222 · pytorch/pytorch · GitHub
Open
0x00b1 opened this issue Oct 18, 2022 · 3 comments
Labels
module: functorch (Pertaining to torch.func or pytorch/functorch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

0x00b1 commented Oct 18, 2022

🚀 The feature, motivation and pitch

Add functorch equivalents of jax.custom_jvp and jax.custom_vjp, i.e., decorators for defining custom derivatives:

PyTorch API

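from torch.autograd import Function

class F(Function):
    @staticmethod
    def forward(ctx, x):
        pass

    @staticmethod
    def jvp(ctx, x_tangent):
        pass

    @staticmethod
    def vjp(ctx, grad_output):
        pass

f = F.apply  # Functions are invoked through .apply, not by instantiating the class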
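
To make the existing approach concrete, here is a minimal sketch of a `torch.autograd.Function` subclass that supplies both rules. The `Sin` class and the choice of sin(x) are illustrative assumptions, not part of the original request; the forward-mode check uses `torch.autograd.forward_ad`.

```python
import torch
import torch.autograd.forward_ad as fwAD
from torch.autograd import Function


class Sin(Function):
    """Illustrative example: sin(x) with hand-written derivative rules."""

    @staticmethod
    def forward(ctx, x):
        # Save the input for both the reverse-mode (vjp) and forward-mode (jvp) rules.
        ctx.save_for_backward(x)
        ctx.save_for_forward(x)
        return torch.sin(x)

    @staticmethod
    def jvp(ctx, x_tangent):
        # Forward mode: push the input tangent through d/dx sin(x) = cos(x).
        (x,) = ctx.saved_tensors
        return torch.cos(x) * x_tangent

    @staticmethod
    def vjp(ctx, grad_output):
        # Reverse mode: pull the output cotangent back through the same derivative.
        (x,) = ctx.saved_tensors
        return grad_output * torch.cos(x)


sin = Sin.apply

x = torch.tensor(0.5, dtype=torch.double, requires_grad=True)

# Reverse mode: populates x.grad using the vjp rule.
sin(x).backward()

# Forward mode: the tangent of the output comes from the jvp rule.
with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, torch.ones_like(x))
    tangent = fwAD.unpack_dual(sin(dual_x)).tangent
```

Both `x.grad` and `tangent` should equal cos(0.5) here, since the input is a scalar and the seed tangent and cotangent are 1.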

Hypothetical functorch API

@functorch.custom_jvp
@functorch.custom_vjp
def f(x: Tensor) -> Tensor:
    pass

@f.x_jvp
def f_x_jvp():
    pass

@f.x_vjp
def f_x_vjp():
    pass
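
As a rough illustration of what such a decorator could look like, the sketch below layers a decorator over `torch.autograd.Function`. The names `custom_vjp` and `defvjp` are hypothetical (borrowed from JAX's naming) and are not an existing functorch API; the sketch handles only a single tensor input and only the reverse-mode rule, whereas the proposal above registers rules per argument.

```python
import torch


def custom_vjp(fn):
    """Hypothetical helper: wrap `fn` so a separately registered rule supplies its VJP."""
    rules = {}

    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return fn(x)

        @staticmethod
        def vjp(ctx, grad_output):
            # Delegate to the user-registered rule instead of differentiating fn.
            (x,) = ctx.saved_tensors
            return rules["vjp"](x, grad_output)

    def wrapper(x):
        return _Wrapped.apply(x)

    def defvjp(rule):
        # Register the rule and return it so this can be used as a decorator.
        rules["vjp"] = rule
        return rule

    wrapper.defvjp = defvjp
    return wrapper


@custom_vjp
def softplus(x):
    return torch.log1p(torch.exp(x))


@softplus.defvjp
def softplus_vjp(x, grad_output):
    # Hand-written derivative: sigmoid(x), which stays finite for large x.
    return grad_output * torch.sigmoid(x)


x = torch.tensor(0.0, requires_grad=True)
softplus(x).backward()  # x.grad is computed by softplus_vjp
```

Keeping the rule in a closure lets the function body and its derivative be defined as plain functions, which is the ergonomic gain the decorator form is after.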

Alternatives

No response

Additional context

No response

cc @zou3519 @Chillee @samdow @soumith

Chillee commented Oct 18, 2022

cc: @zou3519

zou3519 commented Oct 18, 2022

Yes, this is on our radar. @0x00b1 do you have a concrete use case for this?

0x00b1 commented Oct 19, 2022

> Yes, this is on our radar. @0x00b1 do you have a concrete use case for this?

I do! There are alternatives, but it'd make implementing (or, minimally, prototyping) this RFC easier!

See my description in the Implementation section.

samdow added the triaged and module: functorch labels on Oct 20, 2022