[export] Fix sym_bool serialization by yiming0416 · Pull Request #144295 · pytorch/pytorch · GitHub

[export] Fix sym_bool serialization #144295


Closed · wants to merge 1 commit

Conversation

@yiming0416 (Contributor) commented on Jan 7, 2025

Summary:
When there is a `torch._check()` that checks whether a sym_int is equal to some constant, it generates three nodes in the graph with targets `operator.ge`, `operator.le`, and `operator.eq`. These operators belong to `_SYM_BOOL_OPS`, but the `meta_val` of these nodes is `bool` instead of `torch.SymBool`.

Something similar can happen with `torch.SymInt`, where a `node.target` belongs to `_SYM_INT_OPS` but `node.meta["val"]` is an `int` instead of `torch.SymInt`.

Therefore, we need to check both the `meta_val` type and the `node.target` type during serialization.
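For illustration, a minimal module of roughly this shape triggers the pattern (a hypothetical example, not the test added in this PR; the values are arbitrary):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        n = x.item()          # unbacked SymInt
        torch._check(n == 5)  # equality check against a constant
        return x + n

# Per the summary above, the exported graph contains assertion nodes whose
# targets are operator.ge / operator.le / operator.eq, with a plain bool in
# node.meta["val"] rather than a torch.SymBool.
ep = torch.export.export(M(), (torch.tensor(5),))
print(ep.graph)
```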

Test Plan:

```
buck2 run @mode/dev-nosan caffe2/test:test_export -- -r test_sym_bool_torch_check_equal
buck2 run @mode/dev-nosan caffe2/test:test_export -- -r test_sym_int_torch_check_equal
```

Differential Revision: D67883754

pytorch-bot (bot) commented on Jan 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144295

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit fca2326 with merge base 127f836:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D67883754

yiming0416 added a commit to yiming0416/pytorch that referenced this pull request Jan 7, 2025
Summary:

As title

Test Plan: CI

Differential Revision: D67883754
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D67883754

yiming0416 added a commit to yiming0416/pytorch that referenced this pull request Jan 7, 2025
Summary:

As title

Test Plan: CI

Differential Revision: D67883754
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D67883754

@avikchaudhuri (Contributor) left a comment


Can you add some tests for each case of float, int, bool to exercise the new checks of ops?
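One way such a test could look (a hypothetical sketch assuming that round-tripping through torch.export.save/load exercises the serializer; the module and values are illustrative, not the tests added in this PR):

```python
import io
import torch

class M(torch.nn.Module):
    def forward(self, x):
        n = x.item()          # unbacked SymInt
        torch._check(n == 5)  # emits sym_bool-style comparison nodes
        return x + n

ep = torch.export.export(M(), (torch.tensor(5),))

# Round-trip through serialization; the regression described in this PR could
# surface here when a node's meta["val"] is a plain bool/int rather than a
# SymBool/SymInt.
buf = io.BytesIO()
torch.export.save(ep, buf)
buf.seek(0)
loaded = torch.export.load(buf)
assert torch.equal(loaded.module()(torch.tensor(5)), torch.tensor(10))
```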

```
if isinstance(meta_val, torch.SymInt):
    sym_output = Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))
elif isinstance(meta_val, torch.SymFloat):
if isinstance(meta_val, torch.SymFloat) or node.target in _SYM_FLOAT_OPS:
```
Contributor

The order of these checks is not significant, right? Are _SYM_X_OPS for X in FLOAT, INT, BOOL disjoint?

Can you comment why you needed to check the ops? Is meta_val sometimes None or does it not match the torch.SymX types?

Contributor Author

@avikchaudhuri
The order of these checks matters. Although the _SYM_X_OPS sets are disjoint, consider the line

x / torch.sym_sqrt(x.shape[0])

in the test test_sym_sqrt: the op torch.sym_sqrt belongs to _SYM_INT_OPS, but the meta_val is torch.SymFloat. That's why the float check has to go first.
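A sketch of the ordering this implies (a hypothetical helper for illustration only, not the serializer's actual code; the op-set parameters stand in for _SYM_FLOAT_OPS, _SYM_INT_OPS, and _SYM_BOOL_OPS):

```python
import torch

def classify_sym_output(meta_val, target, sym_float_ops, sym_int_ops, sym_bool_ops):
    # Float must be checked before int: an op listed under the int ops
    # (e.g. torch.sym_sqrt) can still produce a SymFloat meta_val.
    if isinstance(meta_val, torch.SymFloat) or target in sym_float_ops:
        return "sym_float"
    if isinstance(meta_val, torch.SymInt) or target in sym_int_ops:
        return "sym_int"
    # A plain bool meta_val still counts as a sym_bool output when the
    # target is a sym-bool op (e.g. operator.ge from a torch._check).
    if isinstance(meta_val, torch.SymBool) or (
        isinstance(meta_val, bool) and target in sym_bool_ops
    ):
        return "sym_bool"
    return "other"
```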

The reason I came up with this fix is that #136364 caused a regression in serializing internal models, as reflected on the internal dashboard: for some internal models there can be a node with target operator.ge whose meta_val is of type bool instead of torch.SymBool. I tried to add tests but was unable to reproduce this behavior in OSS. Any thoughts on this?

Also, before #136364, we checked both meta_val and the op.

Contributor Author

@avikchaudhuri Please see the updated summary and added test.

Contributor

I'm puzzled why torch.sym_sqrt belongs to _SYM_INT_OPS. (Is it because the expected input type is a SymInt, while here we're talking about the output type? It looks like a floaty op to me.) @angelayi

OK otherwise.

Contributor

I feel like we don't actually need to differentiate between "symint ops" and "symfloat ops". Previously we only had "symint ops" because we only had symints 😅.

I think we can also replace this code with just `self.serialize_output(node.name, meta_val)` and that function should handle everything for us.

@pytorch-bot (bot) added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Jan 7, 2025
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D67883754

yiming0416 added a commit to yiming0416/pytorch that referenced this pull request Jan 8, 2025
Summary:

When there is a `torch._check()` that checks whether a sym_int is equal to some constant, it generates three nodes in the graph with targets `operator.ge`, `operator.le`, and `operator.eq`. These operators belong to `_SYM_BOOL_OPS`, but the `meta_val` of these nodes is `bool` instead of `torch.SymBool`.

Similar things can happen for `torch.SymInt` as well, where a `node.target` belongs to `_SYM_INT_OPS` but `node.meta["val"]` is an `int` instead of `torch.SymInt`.

Therefore, we need to check both `meta_val` type and `node.target` type during serialization.

Test Plan:
```
buck2 run @mode/dev-nosan caffe2/test:test_export -- -r test_sym_bool_torch_check_equal
buck2 run @mode/dev-nosan caffe2/test:test_export -- -r test_sym_int_torch_check_equal
```

Differential Revision: D67883754
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D67883754

yiming0416 added a commit to yiming0416/pytorch that referenced this pull request Jan 9, 2025
Summary:

When there is a `torch._check()` that checks whether a sym_int is equal to some constant, it generates three nodes in the graph with targets `operator.ge`, `operator.le`, and `operator.eq`. These operators belong to `_SYM_BOOL_OPS`, but the `meta_val` of these nodes is `bool` instead of `torch.SymBool`.

Similar things can happen for `torch.SymInt` as well, where a `node.target` belongs to `_SYM_INT_OPS` but `node.meta["val"]` is an `int` instead of `torch.SymInt`.

Therefore, we need to check both `meta_val` type and `node.target` type during serialization.

Test Plan:
```
buck2 run @mode/dev-nosan caffe2/test:test_export -- -r test_sym_bool_torch_check_equal
buck2 run @mode/dev-nosan caffe2/test:test_export -- -r test_sym_int_torch_check_equal
```

Reviewed By: avikchaudhuri

Differential Revision: D67883754
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D67883754

Summary:

When there is a `torch._check()` that checks whether a sym_int is equal to some constant, it generates three nodes in the graph with targets `operator.ge`, `operator.le`, and `operator.eq`. These operators belong to `_SYM_BOOL_OPS`, but the `meta_val` of these nodes is `bool` instead of `torch.SymBool`.

Similar things can happen for `torch.SymInt` as well, where a `node.target` belongs to `_SYM_INT_OPS` but `node.meta["val"]` is an `int` instead of `torch.SymInt`.

Therefore, we need to check both `meta_val` type and `node.target` type during serialization.

Test Plan:
```
buck2 run @mode/dev-nosan caffe2/test:test_export -- -r test_sym_bool_torch_check_equal
buck2 run @mode/dev-nosan caffe2/test:test_export -- -r test_sym_int_torch_check_equal
```

Reviewed By: avikchaudhuri, angelayi

Differential Revision: D67883754
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D67883754

@facebook-github-bot (Contributor)

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.
