[multigraph] add backend_specialization kwarg to mark_dynamic by bobrenjc93 · Pull Request #152597 · pytorch/pytorch · GitHub

[multigraph] add backend_specialization kwarg to mark_dynamic #152597


Closed · wants to merge 12 commits

Changes from all commits
23 changes: 21 additions & 2 deletions torch/_dynamo/decorators.py
@@ -552,7 +552,7 @@ def mark_unbacked(t, index, strict=False):


@forbid_in_graph
-def mark_dynamic(t, index, *, min=None, max=None):
+def mark_dynamic(t, index, *, min=None, max=None, backend_specializations=None):
@zou3519 (Contributor) commented on May 12, 2025:

  1. Can we do this for mark_unbacked as well? I think vLLM might actually want to use mark_unbacked
  2. We should probably error somehow if the range on the symint ends up not including the backend_specialization

"""
Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim.

@@ -575,6 +575,18 @@ def mark_dynamic(t, index, *, min=None, max=None):
4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
before torch.compile.

5) If backend_specializations is passed in, we will perform a single generic Dynamo trace followed by
A Contributor commented:

Do multiple different mark'ed dynamic dims with backend specializations produce a cross product of specialized graphs?

multiple specialized compilations in addition to a single generic compilation. At runtime, we will dispatch
to a specialized compiled region if the input matches the specialization criteria. For example:

mark_dynamic(..., backend_specializations=[(
    16, # hint value
    lambda x: x == 16 # specialization predicate
)])

Comment on lines +578 to +586
@zou3519 (Contributor) commented on May 12, 2025:

backend_specializations is a bit wordy. I don't have a better idea though.

mark_dynamic(..., specialize_on=[
   16,
   lambda x: x == 16,
])

?

This approach results in one Dynamo trace and two backend compilations. When the input dimension equals 16
at runtime, execution will be directed to the specialized compiled region. Performance measurements indicate
2-8x speedups depending on the specific specialization and model architecture.
"""
if is_traceable_wrapper_subclass(t):
# default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
@@ -587,14 +599,21 @@ def mark_dynamic(t, index, *, min=None, max=None):
if not hasattr(t, "_dynamo_dynamic_indices"):
t._dynamo_dynamic_indices = set()
t._dynamo_dynamic_range = set()

if not hasattr(t, "_backend_specializations"):
t._backend_specializations = {}

# TODO(voz): Should we bounds check?
t._dynamo_dynamic_indices.add(index)
t._dynamo_dynamic_range.add(_DimRange(index, min, max))
t._backend_specializations[index] = backend_specializations
return

assert isinstance(index, (list, tuple))
for i in index:
-mark_dynamic(t, i, min=min, max=max)
+mark_dynamic(
+    t, i, min=min, max=max, backend_specializations=backend_specializations
+)


@forbid_in_graph
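For readers skimming this diff, a minimal end-to-end sketch of how the API described in the docstring above would be used. The toy function f, the tensor shapes, and the variable names are illustrative assumptions; only the mark_dynamic signature and the (hint, predicate) tuple format come from this PR. Note the PR is marked Closed rather than Merged, so the backend_specializations kwarg may not exist in a given PyTorch build; the sketch assumes the code in this diff.

import torch
from torch._dynamo import mark_dynamic

def f(x):
    return x * 2

x = torch.randn(8, 4)

# Dim 0 is dynamic; additionally request a backend compilation specialized
# for the case where that dim equals 16 (hypothetical usage per this PR).
mark_dynamic(x, 0, backend_specializations=[(
    16,                 # hint value used when compiling the specialized region
    lambda s: s == 16,  # runtime predicate that selects the specialized region
)])

compiled = torch.compile(f)
# One generic Dynamo trace; per the docstring, an additional specialized
# backend compilation is produced and dispatched to when the predicate matches.
compiled(x)

Per the docstring's description, runtime dispatch checks the specialization criteria against the actual input and otherwise falls back to the generic compiled region.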