-
Notifications
You must be signed in to change notification settings - Fork 24.7k
[export] Fix sym_bool serialization #144295
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144295
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit fca2326 with merge base 127f836 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This pull request was exported from Phabricator. Differential Revision: D67883754 |
21a010d
to
cf11419
Compare
Summary: As title Test Plan: CI Differential Revision: D67883754
This pull request was exported from Phabricator. Differential Revision: D67883754 |
cf11419
to
3b89ca4
Compare
Summary: As title Test Plan: CI Differential Revision: D67883754
This pull request was exported from Phabricator. Differential Revision: D67883754 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add some tests for each case of float, int, bool to exercise the new checks of ops?
torch/_export/serde/serialize.py
Outdated
if isinstance(meta_val, torch.SymInt): | ||
sym_output = Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val)) | ||
elif isinstance(meta_val, torch.SymFloat): | ||
if isinstance(meta_val, torch.SymFloat) or node.target in _SYM_FLOAT_OPS: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The order of these checks is not significant, right? Are _SYM_X_OPS
for X in FLOAT
, INT
, BOOL
disjoint?
Can you comment why you needed to check the ops? Is meta_val
sometimes None or does it not match the torch.SymX
types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@avikchaudhuri
The order of these checks matter. Although _SYM_X_OPS
are disjoint, consider the line
x / torch.sym_sqrt(x.shape[0])
in test test_sym_sqrt
, the op is torch.sym_sqrt
belongs to _SYM_INT_OPS
, but the meta_val
is torch.SymFloat
. That's why float has to go first.
The reason I came up with this fix is that #136364 caused regression on serializing internal models reflected by the internal dashboard. Because for some internal models, there could be a node with target operator.ge
but has a meta_val of type bool
instead of torch.SymBool
. I tried to add tests but I was unable to reproduce this behavior in OSS. Any thoughts on this?
Also before #136364, we check both meta_val and op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@avikchaudhuri Please see the updated summary and added test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm puzzled why torch.sym_sqrt
belongs to _SYM_INT_OPS
(Is it because the expected input type is symint but here we're talking about output type? it looks like a floaty op to me.) @angelayi
OK otherwise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like we don't actually need to differentiate between "symint ops" and "symfloat ops". Previously we only had "symint ops" because we only had symints 😅.
I think we can also replace this code with just self.serialize_output(node.name, meta_val)
and that function should handle everything for us.
3b89ca4
to
364a04c
Compare
This pull request was exported from Phabricator. Differential Revision: D67883754 |
Summary: When there is a `torch._check()` that checks if a sym_int is equal to some constant, it will generate 3 nodes in the graph with target `operation.ge`, `operator.le` and `operator.eq`. These operators belong to `_SYM_BOOL_OPS` but the `meta_val` of these nodes are are `bool` instead of `torch.SymBool`. Similar things can happen for `torch.SymInt` as well, where a `node.target` belongs to `_SYM_INT_OPS` but `node.meta["val"]` is an `int` instead of `torch.SymInt`. Therefore, we need to check both `meta_val` type and `node.target` type during serialization. Test Plan: ``` buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_sym_bool_torch_check_equal buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_sym_int_torch_check_equal ``` Differential Revision: D67883754
364a04c
to
4cd99bd
Compare
This pull request was exported from Phabricator. Differential Revision: D67883754 |
4cd99bd
to
13673bc
Compare
Summary: When there is a `torch._check()` that checks if a sym_int is equal to some constant, it will generate 3 nodes in the graph with target `operation.ge`, `operator.le` and `operator.eq`. These operators belong to `_SYM_BOOL_OPS` but the `meta_val` of these nodes are are `bool` instead of `torch.SymBool`. Similar things can happen for `torch.SymInt` as well, where a `node.target` belongs to `_SYM_INT_OPS` but `node.meta["val"]` is an `int` instead of `torch.SymInt`. Therefore, we need to check both `meta_val` type and `node.target` type during serialization. Test Plan: ``` buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_sym_bool_torch_check_equal buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_sym_int_torch_check_equal ``` Reviewed By: avikchaudhuri Differential Revision: D67883754
This pull request was exported from Phabricator. Differential Revision: D67883754 |
Summary: When there is a `torch._check()` that checks if a sym_int is equal to some constant, it will generate 3 nodes in the graph with target `operation.ge`, `operator.le` and `operator.eq`. These operators belong to `_SYM_BOOL_OPS` but the `meta_val` of these nodes are are `bool` instead of `torch.SymBool`. Similar things can happen for `torch.SymInt` as well, where a `node.target` belongs to `_SYM_INT_OPS` but `node.meta["val"]` is an `int` instead of `torch.SymInt`. Therefore, we need to check both `meta_val` type and `node.target` type during serialization. Test Plan: ``` buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_sym_bool_torch_check_equal buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_sym_int_torch_check_equal ``` Reviewed By: avikchaudhuri, angelayi Differential Revision: D67883754
13673bc
to
fca2326
Compare
This pull request was exported from Phabricator. Differential Revision: D67883754 |
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged) |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Summary:
When there is a
torch._check()
that checks if a sym_int is equal to some constant, it will generate 3 nodes in the graph with targetoperation.ge
,operator.le
andoperator.eq
. These operators belong to_SYM_BOOL_OPS
but themeta_val
of these nodes are arebool
instead oftorch.SymBool
.Similar things can happen to
torch.SymInt
, where anode.target
belongs to_SYM_INT_OPS
butnode.meta["val"]
is anint
instead oftorch.SymInt
.Therefore, we need to check both
meta_val
type andnode.target
type during serialization.Test Plan:
Differential Revision: D67883754