wip by bobrenjc93 · Pull Request #152749 · pytorch/pytorch · GitHub

wip #152749


Closed · wants to merge 1 commit

19 changes: 10 additions & 9 deletions torch/_dynamo/codegen.py
@@ -18,7 +18,7 @@
import sys
import types
from collections import Counter
-from typing import Optional, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union, Any

import torch.nn
from torch.utils._ordered_set import OrderedSet
@@ -597,12 +597,14 @@
        if source not in self.tempvars:
            self.tempvars[source] = None

-    def make_specialized_call_generated_code(
+    def make_call_specialized_code(
        self,
        fn_name: str,
-        specializations: list[tuple[str, BackendSpecialization]],
+        specializations: list[tuple[str, BackendSpecialization, Any]],
    ) -> None:
        """Try specializations in order; fall back to fn_name if none match"""
        from .variables.builder import (GraphArg)

        graphargs = self.tx.output.graphargs
        seen_sources: OrderedSet[Source] = OrderedSet()

@@ -620,18 +622,17 @@
            if arg.source is not None:
                collect_temp_source(arg.source)

-        for fn_spec_name, spec in specializations:
-            # Reconstruct the source expression to evaluate the specialization condition
-            self.call_reconstruct(GraphArg(source=spec.source, example_value=None))
-            self.extend_output(self.create_store_var("spec_value"))
+        for fn_spec_name, spec, _ in specializations:
+            self.call_reconstruct(spec.source)
+            self.extend_output(self.create_store("spec_value"))

            # Load the specialization function and call it with spec_value
-            self.extend_output(self.create_load_const(spec.specialization))
-            self.extend_output(self.create_load_var("spec_value"))
+            self.extend_output(self.create_load_const(spec.check_fn))
+            self.extend_output(self.create_load("spec_value"))
            self.extend_output(create_call_function(1, False))

            skip_label = self.new_block()

Check failure on line 634 in torch/_dynamo/codegen.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PyCodegen" has no attribute "new_block"

            self.extend_output(create_jump_if_false(skip_label))

Check failure on line 635 in torch/_dynamo/codegen.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [name-defined]: Name "create_jump_if_false" is not defined

            # If specialization matched, call the specialized function
            self.extend_output(self.load_function_name(fn_spec_name, True))
@@ -648,8 +649,8 @@
                else:
                    self.call_reconstruct(arg)
            self.extend_output(create_call_function(len(graphargs), False))
            self.extend_output(create_jump(self.end_block))

Check failure on line 652 in torch/_dynamo/codegen.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PyCodegen" has no attribute "end_block"

Check failure on line 652 in torch/_dynamo/codegen.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [name-defined]: Name "create_jump" is not defined
            self.start_block(skip_label)

Check failure on line 653 in torch/_dynamo/codegen.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PyCodegen" has no attribute "start_block"

        # No specialization matched — call base function
        self.extend_output(self.load_function_name(fn_name, True))
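Conceptually, the bytecode that make_call_specialized_code emits is a guard-then-dispatch chain: for each entry it reconstructs the specialization's source value, runs the specialization's check_fn on it, jumps past the specialized call if the check fails, and otherwise calls the specialized function and skips the remaining candidates and the generic fallback. The following is a minimal Python sketch of that control flow, not the PR's implementation; the Spec and dispatch names and the example predicates are made up for illustration.

from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Spec:
    # Stand-in for BackendSpecialization: a predicate plus the backend-compiled callable.
    check_fn: Callable[[Any], bool]
    compiled_fn: Callable[..., Any]


def dispatch(specs, generic_fn, spec_value, *args):
    # Try specializations in order; fall back to generic_fn if none match
    # (mirrors the skip_label / jump-if-false structure of the generated bytecode).
    for spec in specs:
        if spec.check_fn(spec_value):
            return spec.compiled_fn(*args)
    return generic_fn(*args)


# Usage: specialize the compiled graph for a dynamic size hint of 1.
specs = [Spec(check_fn=lambda n: n == 1, compiled_fn=lambda x: ("specialized", x))]
generic_fn = lambda x: ("generic", x)
print(dispatch(specs, generic_fn, 1, "input"))  # ('specialized', 'input')
print(dispatch(specs, generic_fn, 8, "input"))  # ('generic', 'input')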
12 changes: 6 additions & 6 deletions torch/_dynamo/output_graph.py
@@ -1496,11 +1496,11 @@ def compile_and_call_fx_graph(self, tx, rv, root):
            # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
            self.tracing_context.fake_mode = backend_fake_mode

-        specialized_compiled_fns = []
+        specialized_compiles = []
        with self.restore_global_state():
            compiled_fn = self.call_user_compiler(gm)
            for specialization in old_fake_mode.shape_env.backend_specializations:
-                specialized_compiled_fns.append((
+                specialized_compiles.append((
                    unique_id("__specialized_compiled_fn"),
                    specialization,
                    self.call_user_compiler(gm, specialization=specialization)
@@ -1539,10 +1539,10 @@ def compile_and_call_fx_graph(self, tx, rv, root):

        cg = PyCodegen(tx)

-        if specialized_compiled_fns:
-            for name, specialized_compiled_fn in specialized_compiled_fns:
-                self.install_global_unsafe(name, specialized_compiled_fn)
-            cg.make_call_specialized_code(name, specialized_compiled_fns)
+        if specialized_compiles:
+            for fn_name, _, specialized_compiled_fn in specialized_compiles:
+                self.install_global_unsafe(fn_name, specialized_compiled_fn)
+            cg.make_call_specialized_code(name, specialized_compiles)
        else:
            cg.make_call_generated_code(name)
        return cg.get_instructions()
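With this change, each element of specialized_compiles is a triple (fn_name, specialization, compiled_fn): a unique global name, the BackendSpecialization describing when it applies, and the artifact returned by call_user_compiler for that specialization; renaming the loop variable to fn_name also stops it from clobbering the outer name that is still passed to make_call_specialized_code as the generic fallback. Below is a rough, self-contained sketch of the assembly and installation step, with unique_id, call_user_compiler, and install_global_unsafe replaced by simplified stand-ins rather than the real APIs.

import itertools

_counter = itertools.count()
_installed_globals = {}  # stand-in for the globals installed into the generated code


def unique_id(prefix):
    # Simplified stand-in for the real unique_id helper.
    return f"{prefix}_{next(_counter)}"


def call_user_compiler(gm, specialization=None):
    # Stand-in: the real method hands gm (plus an optional specialization) to the backend.
    return lambda *args: (specialization, args)


def install_global_unsafe(name, value):
    # Stand-in: the real method makes the compiled fn reachable from the emitted bytecode.
    _installed_globals[name] = value


def build_specialized_compiles(gm, backend_specializations):
    specialized_compiles = []
    for specialization in backend_specializations:
        specialized_compiles.append((
            unique_id("__specialized_compiled_fn"),
            specialization,
            call_user_compiler(gm, specialization=specialization),
        ))
    return specialized_compiles


specialized_compiles = build_specialized_compiles(gm=object(), backend_specializations={"s0 == 1"})
for fn_name, _, specialized_compiled_fn in specialized_compiles:
    install_global_unsafe(fn_name, specialized_compiled_fn)
print(sorted(_installed_globals))  # ['__specialized_compiled_fn_0']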
6 changes: 3 additions & 3 deletions torch/fx/experimental/symbolic_shapes.py
@@ -937,7 +937,7 @@ def find_symbol_binding_fx_nodes(
    return r


-@dataclass
+@dataclass(frozen=True)
class BackendSpecialization:
    source: TensorPropertySource
    hint: int
@@ -3579,7 +3579,7 @@ def _init(

        self.trace_asserts = trace_asserts

-        self.backend_specializations = []
+        self.backend_specializations = set()

        from torch.fx.experimental.validator import translation_validation_enabled

@@ -4065,7 +4065,7 @@ def _produce_dyn_sizes_from_int_tuple(
                symbolic_context=symbolic_context,
            )
            for specialization in symbolic_context.backend_specializations[i]:
-                self.backend_specializations.append(BackendSpecialization(
+                self.backend_specializations.add(BackendSpecialization(
                    TensorPropertySource(source, TensorProperty.SIZE, i),
                    *specialization,
                ))
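Marking BackendSpecialization as frozen=True makes its instances hashable with value-based equality, which is what allows the shape env to keep backend_specializations in a set and silently deduplicate the same specialization being registered more than once. A small illustration follows, using only the two fields visible in the hunk and a plain string in place of TensorPropertySource.

from dataclasses import dataclass


@dataclass(frozen=True)
class BackendSpecialization:
    # Only the fields visible in this hunk; the real class has more.
    source: str  # simplified stand-in for TensorPropertySource
    hint: int


specs = set()
specs.add(BackendSpecialization(source="L['x'].size()[0]", hint=1))
specs.add(BackendSpecialization(source="L['x'].size()[0]", hint=1))  # duplicate, dropped
specs.add(BackendSpecialization(source="L['x'].size()[0]", hint=8))
print(len(specs))  # 2

# Without frozen=True, @dataclass generates __eq__ and sets __hash__ to None,
# so instances could not be stored in a set at all.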