[multigraph] add backend_specialization kwarg to mark_dynamic #152597
Changes from all commits
@@ -552,7 +552,7 @@ def mark_unbacked(t, index, strict=False):

 @forbid_in_graph
-def mark_dynamic(t, index, *, min=None, max=None):
+def mark_dynamic(t, index, *, min=None, max=None, backend_specializations=None):
     """
     Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim.
@@ -575,6 +575,18 @@ def mark_dynamic(t, index, *, min=None, max=None):
     4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
        before torch.compile.

+    5) If backend_specializations is passed in, we will perform a single generic Dynamo trace followed by
Review comment: Do multiple different mark'ed dynamic dims with backend specializations produce a cross product of specialized graphs?
+       multiple specialized compilations in addition to a single generic compilation. At runtime, we will dispatch
+       to a specialized compiled region if the input matches the specialization criteria. For example:
+
+           mark_dynamic(..., backend_specializations=[(
+               16,                # hint value
+               lambda x: x == 16  # specialization predicate
+           )])
Review comment on lines +578 to +586: backend_specializations is a bit wordy. I don't have a better idea though.
+
+       This approach results in one Dynamo trace and two backend compilations. When the input dimension equals 16
+       at runtime, execution will be directed to the specialized compiled region. Performance measurements indicate
+       2-8x speedups depending on the specific specialization and model architecture.
     """
     if is_traceable_wrapper_subclass(t):
         # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
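To make the intended usage concrete, here is a minimal end-to-end sketch. It is not part of the diff: it assumes the kwarg behaves as the docstring above describes, and the runtime dispatch itself is implemented elsewhere in this PR.

    import torch
    from torch._dynamo import mark_dynamic

    def f(x):
        return x.sin() + x.cos()

    x = torch.randn(16, 32)

    # Mark dim 0 as dynamic, and additionally request a compiled region
    # specialized for the case where dim 0 equals 16.
    mark_dynamic(
        x,
        0,
        backend_specializations=[
            (16, lambda s: s == 16),  # (hint value, specialization predicate)
        ],
    )

    compiled_f = torch.compile(f)
    compiled_f(x)                   # dim 0 == 16 matches the predicate -> specialized region
    compiled_f(torch.randn(8, 32))  # other sizes run the generic dynamic compilation

Per the docstring, this costs one Dynamo trace and two backend compilations (one generic, one specialized).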
@@ -587,14 +599,21 @@ def mark_dynamic(t, index, *, min=None, max=None):
         if not hasattr(t, "_dynamo_dynamic_indices"):
             t._dynamo_dynamic_indices = set()
             t._dynamo_dynamic_range = set()
+        if not hasattr(t, "_backend_specializations"):
+            t._backend_specializations = {}
+
         # TODO(voz): Should we bounds check?
         t._dynamo_dynamic_indices.add(index)
         t._dynamo_dynamic_range.add(_DimRange(index, min, max))
+        t._backend_specializations[index] = backend_specializations
         return

     assert isinstance(index, (list, tuple))
     for i in index:
-        mark_dynamic(t, i, min=min, max=max)
+        mark_dynamic(
+            t, i, min=min, max=max, backend_specializations=backend_specializations
+        )


 @forbid_in_graph