[go: up one dir, main page]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fx fixes #402

Merged
merged 16 commits into from
Jun 4, 2023
Merged

Fx fixes #402

merged 16 commits into from
Jun 4, 2023

Conversation

borisfom
Copy link
Contributor

Changes to make TensorProduct classes FX-traceable

Description

Those changes are partial and would only allow subset of classes used in DiffDock, pass FX symbolic_trace().
jit_script_fx needs to be set to False for this to work, too.

Motivation and Context

Resolves: #???

How Has This Been Tested?

Checklist:

  • I have read the CONTRIBUTING document.
  • My code follows the code style of this project.
  • I have updated the documentation (if relevant).
  • I have added tests that cover my changes (if relevant).
  • The modified code is cuda compatible (github tests don't test cuda) (if relevant).
  • I have updated the Changelog.

borisfom and others added 11 commits May 13, 2023 23:53
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
@mariogeiger
Copy link
Member

Thanks a lot for the PR, it looks good!
Is it ready to merge?

@borisfom
Copy link
Contributor Author

Thanks a lot for the PR, it looks good!
Is it ready to merge?

Not sure, there is a test failure, but I can't tell if it's a problem or if it actually fixes earlier issue :
FAILED tests/o3/tensor_product_test.py::test_input_weights_jit[float64] - Failed: DID NOT RAISE (<class 'RuntimeError'>, <class 'torch.jit.Error'>)

@mariogeiger
Copy link
Member

If you can investigate a bit this bug it's cool, otherwise we can also just comment it for now

@mariogeiger
Copy link
Member

I guess it's just because now some assertions are inside an if statement

@Linux-cpp-lisp
Copy link
Contributor
Linux-cpp-lisp commented May 31, 2023

Instead of wrapping things in conditionals, I think you could replace assert with torch._assert(condition, message) which is fx traceable.

See https://pytorch.org/docs/stable/generated/torch._assert.html.

This is cleaner and doesn't change the behavior of the functions under tracing.

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
@borisfom
Copy link
Contributor Author

I guess it's just because now some assertions are inside an if statement

Right; I have used Lisp's suggestion now :)

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
@borisfom
Copy link
Contributor Author
borisfom commented Jun 4, 2023

@mariogeiger : I guess it's ready for merge now :)

@mariogeiger mariogeiger merged commit dd5afda into e3nn:main Jun 4, 2023
@mariogeiger
Copy link
Member

Good job

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants