[multigraph] use backend specializations in compile_and_call_fx_graph by bobrenjc93 · Pull Request #152601 · pytorch/pytorch · GitHub
Closed

Changes from all commits
25 commits
ade77f1
use backend specializations in compile_and_call_fx_graph
bobrenjc93 May 1, 2025
c8e359b
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 2, 2025
87b796e
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 2, 2025
2fc599a
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 3, 2025
7b49125
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 3, 2025
2f887ce
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 3, 2025
49e8ee5
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 3, 2025
26915b7
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
b9a6bc5
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
24b506c
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
6a82fcc
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
e8eb158
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
c19b818
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
fc279ed
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
81621b3
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
994b1d4
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
415ea24
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
b119bb1
Update on "use backend specializations in compile_and_call_fx_graph"
bobrenjc93 May 4, 2025
c2fb9cc
Update on "[multigraph] use backend specializations in compile_and_ca…
bobrenjc93 May 5, 2025
f0fb81e
Update on "[multigraph] use backend specializations in compile_and_ca…
bobrenjc93 May 5, 2025
becf7f6
Update on "[multigraph] use backend specializations in compile_and_ca…
bobrenjc93 May 5, 2025
00abea0
Update on "[multigraph] use backend specializations in compile_and_ca…
bobrenjc93 May 5, 2025
3e521df
Update on "[multigraph] use backend specializations in compile_and_ca…
bobrenjc93 May 5, 2025
6302ff8
Update on "[multigraph] use backend specializations in compile_and_ca…
bobrenjc93 May 5, 2025
ee88cb9
Update on "[multigraph] use backend specializations in compile_and_ca…
bobrenjc93 May 5, 2025
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
14256e6040d9e14698a877924456cdd92bfcd01d
8eeef7f5b5363e9f35576184659226cc082311d6
Contributor

intentional?

46 changes: 46 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -10464,6 +10464,52 @@ def f(x):
self.assertEqual(out_ref.stride(), out_test.stride())
self.assertEqual(x_ref, x_test)

@requires_gpu()
@skip_if_not_triton
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
def test_inductor_multiple_specializations(self):
from triton.testing import do_bench

@torch.compile(
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
dynamic=False,
)
def inductor_matmul(a, b):
torch._check(a.shape[0] == b.shape[1])
return (m, torch.mm(a, b))

m = 16
k = 1280
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
Collaborator

Hi, the function is decorated with requires_gpu, which means this case will run on GPUs like CUDA/XPU, but the hard-coded cuda here will fail on other GPUs like XPU.

Suggested change
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)

dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
Collaborator

Suggested change
dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)

b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
Collaborator

Suggested change
b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)

torch._dynamo.decorators.mark_dynamic(
dynamic_a,
0,
)
torch._dynamo.decorators.mark_dynamic(
dynamic_specialized_a,
0,
backend_specializations=[
(16, lambda x0: x0 == 16),
],
Comment on lines +10498 to +10500
Contributor

Do we have an API flow for when you want to specify conditions on multiple vars?

E.g.

lambda x, y: x == 1 and y == 1
lambda x, y: x % 16 and y % 16

You don't necessarily want to specialize on x == 1 and y % 16, which I assume would fall out of the pairwise specializations.

Contributor Author
@bobrenjc93 bobrenjc93 May 13, 2025

Not at the moment. vLLM actually only has one symbolic variable (https://www.anyscale.com/blog/continuous-batching-llm-inference) so we don't need to worry about that for our first customer. That being said, I'm happy to bikeshed what a better multi-var API may look like during composability.
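
For illustration only, here is one hypothetical shape a joint multi-variable specialization API could take; mark_dynamic_group_sketch and its signature are invented for this discussion and do not exist in this PR or in PyTorch:

# Purely hypothetical sketch of a joint, multi-variable specialization API.
# The idea is to register one predicate over several dims so specializations
# are chosen jointly rather than falling out of pairwise combinations.
from typing import Callable, Sequence

import torch


def mark_dynamic_group_sketch(
    dims: Sequence[tuple[torch.Tensor, int]],
    backend_specializations: Sequence[tuple[Sequence[int], Callable[..., bool]]],
) -> None:
    # Stub: a real implementation would record the joint predicate so the
    # backend compiles one extra graph per satisfied predicate.
    for hints, _check_fn in backend_specializations:
        assert len(hints) == len(dims)


x = torch.randn(8, 32)
y = torch.randn(32, 8)

mark_dynamic_group_sketch(
    [(x, 0), (y, 1)],
    backend_specializations=[
        ((1, 1), lambda d0, d1: d0 == 1 and d1 == 1),
        ((16, 16), lambda d0, d1: d0 % 16 == 0 and d1 % 16 == 0),
    ],
)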

)
torch._dynamo.decorators.mark_dynamic(
b,
1,
)
dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
torch._dynamo.reset()
dynamic_specialized = do_bench(
lambda: inductor_matmul(dynamic_specialized_a, b)
)
Comment on lines +10508 to +10510
Contributor
We should check the output code
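
A sketch of one way the generated code could be checked, assuming run_and_get_code from torch._inductor.utils as used elsewhere in the Inductor tests; the snippet would sit inside test_inductor_multiple_specializations after the benchmarking, and the exact substring asserted on is illustrative only:

from torch._inductor.utils import run_and_get_code

# run_and_get_code returns the call's result plus the generated source strings,
# which can then be inspected for the specialized compile.
_, specialized_code = run_and_get_code(inductor_matmul, dynamic_specialized_a, b)
# Illustrative assertion: the specialized graph should reflect the size hint;
# the precise string to check depends on the emitted kernels.
self.assertTrue(any("16" in src for src in specialized_code))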

self.assertGreaterEqual(dynamic, dynamic_specialized)

@requires_gpu()
def test_stride_preservation_with_stride_modifying_fx_pass(self):
def f(x):
4 changes: 2 additions & 2 deletions torch/__init__.py
@@ -2359,10 +2359,10 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
)
self.config[attr_name] = val

def __call__(self, model_, inputs_):
def __call__(self, model_, inputs_, **kwargs):
from torch._inductor.compile_fx import compile_fx

return compile_fx(model_, inputs_, config_patches=self.config)
return compile_fx(model_, inputs_, config_patches=self.config, **kwargs)

def get_compiler_config(self):
from torch._inductor.compile_fx import get_patched_config_dict
4 changes: 3 additions & 1 deletion torch/_dynamo/decorators.py
@@ -606,7 +606,9 @@ def mark_dynamic(t, index, *, min=None, max=None, backend_specializations=None):
# TODO(voz): Should we bounds check?
t._dynamo_dynamic_indices.add(index)
t._dynamo_dynamic_range.add(_DimRange(index, min, max))
t._backend_specializations[index] = backend_specializations
t._backend_specializations[index] = (
backend_specializations if backend_specializations is not None else []
)
return

assert isinstance(index, (list, tuple))
50 changes: 45 additions & 5 deletions torch/_dynamo/output_graph.py
@@ -44,6 +44,7 @@
import torch.nn
import torch.utils._pytree as pytree
from torch import fx
from torch._C._dynamo import guards
from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
from torch._guards import (
CompileContext,
@@ -157,6 +158,8 @@
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")

RootGuardManager = guards.RootGuardManager


@dataclass(frozen=True)
class VariableTrackerCacheKey:
@@ -1496,8 +1499,34 @@ def compile_and_call_fx_graph(self, tx, rv, root):
# a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
self.tracing_context.fake_mode = backend_fake_mode

specialized_compiles = []
with self.restore_global_state():
compiled_fn = self.call_user_compiler(gm)
sources = [a.source for a in self.graphargs]
for specialization in old_fake_mode.shape_env.backend_specializations:
source_index = sources.index(specialization.source)
check_fn_source = inspect.getsource(specialization.check_fn).strip()
check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined]
specialization.check_fn,
[check_fn_source],
)

log.debug(
"Compiling backend specialized graph with specialization=%s",
check_fn_source,
)

specialized_compiles.append(
(
functools.partial(
lambda idx, args, check_fn=check_fn: check_fn(
args[idx]
),
source_index,
),
self.call_user_compiler(gm, specialization=specialization),
Comment on lines +1506 to +1527
Contributor
@zou3519 zou3519 May 12, 2025

This calls the backend compiler with the (tensor_args,) for this graph and the specialization argument, right?

I'm not sure this is the right design. The (tensor_args,) don't have the same shape as the specialization -- will that be a problem?

An alternative design is that there is some lazy dispatching layer right after Dynamo but before AOTAutograd. Let's say the user calls the following for the first time:

# A
x = torch.randn(3)
mark_dynamic(x, 0, backend_specializations=[1, 2])
torch.compile(f)(x)

Then this traces out a graph from Dynamo with dynamic shapes.

Then, on future calls to torch.compile:

# B
y = torch.randn(1)
torch.compile(f)(y)
  • On seeing a specialized shape for the first time: this skips Dynamo but directly forwards the args (y,) to the backend to compile a graph
# C
z = torch.randn(1)
torch.compile(f)(z)
  • On seeing a specialized shape again: this pulls up the graph the backend compiled for said shape.

One way to implement this is:

  • Let's think about the Dynamo cache as a mapping from guards to a callable
  • After (A), there is a guard for each of the specializations: {"batch_size==1": call_backend_compile(), "batch_size==2": call_backend_compile(), "batch_size==anything_else": compiled_artifact}
  • (B) hits the call_backend_compile() function, which will compile a backend function and replace the Dynamo cache entry with {"batch_size==1": compiled_artifact}
  • Future hits to this guard (e.g. C) will just hit the compiled artifact.
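
A minimal, purely illustrative sketch of that lazy design (the names below, such as LazySpecializationTable and backend_compile, are invented and are not Dynamo internals): each specialization guard starts out deferred and is swapped for the backend-compiled artifact on its first hit.

from typing import Any, Callable, Optional

class LazySpecializationTable:
    def __init__(
        self,
        backend_compile: Callable[..., Callable[..., Any]],
        generic_compiled_fn: Callable[..., Any],
    ) -> None:
        # guard predicate -> compiled artifact (None means "compile lazily")
        self._entries: list[tuple[Callable[..., bool], Optional[Callable[..., Any]]]] = []
        self._backend_compile = backend_compile
        self._generic = generic_compiled_fn

    def add_guard(self, guard: Callable[..., bool]) -> None:
        self._entries.append((guard, None))

    def __call__(self, *args: Any) -> Any:
        for i, (guard, compiled) in enumerate(self._entries):
            if guard(*args):
                if compiled is None:
                    # Case (B): first time this specialized shape is seen --
                    # compile with the real args, then cache the artifact.
                    compiled = self._backend_compile(*args)
                    self._entries[i] = (guard, compiled)
                # Case (C): reuse the cached artifact.
                return compiled(*args)
        # No specialization matched: fall back to the generic dynamic graph.
        return self._generic(*args)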

Contributor

The benefit of the alternative lazy design is that the backend doesn't need to work hard to figure out how to do the specialization: it's almost like calling regular torch.compile again, except it is able to skip Dynamo.

One side effect is that we don't have to impose constraints on the strides (this PR needs to do that because it needs to figure out how to create a FakeTensor, right?)

Contributor

I think this makes sense. cc @anijain2305 for thoughts as well

Contributor
@anijain2305 anijain2305 May 13, 2025

There are a few details that we need to think about:

  1. We will have multiple cache entries per code object here. For example, our cache size limit is 8, but the specialization here will require us to raise the cache size limit for certain code objects.

  2. "Dynamo cache as a mapping from guards to a callable" - this is true, but there is a subtle difference. Dynamo maps guards to bytecode. That bytecode contains the call to the compiled_graph (not the FX graph, a compiled graph). So in this design, we will have to figure out how to (1) stash the bytecode, and (2) stash the Dynamo graph.

  3. Overwriting a cache entry is also questionable.

Maybe the bytecode calls backend_compile, and backend_compile internally checks whether compiled code already exists. If yes, run the compiled code; otherwise run the AOT + Inductor compilation.

Contributor

@anijain2305 thoughts on #153449 ?

Contributor

What is the issue with the current implementation? It's not bad. It gives the hierarchical feel, which kind of makes sense in this case.

Contributor

Consider the following code:

x = torch.randn(3)
mark_dynamic(x, 0, backend_specializations=[1, 2])
torch.compile(f)(x)
x = torch.randn(1)
torch.compile(f)(x)
x = torch.randn(2)
torch.compile(f)(x)

On the first torch.compile call, we will attempt to compile all of the backend specializations. That torch.compile call only has one set of sample inputs (of shape [3]). The problems I'm worried about are:
a) Compile time will be slow up front. On the first torch.compile it looks like we call the backend compiler three times.
b) Because there are no real tensor inputs of shape [1] and shape [2], we need to guess at those tensors and assume that they're contiguous. This doesn't seem very good.

The lazier design (#153449) solves this by (a) deferring compilation of shape [1] and shape [2] until we actually see inputs of those shapes, and (b) treating a stride change as a recompile.

)
)

from torch.fx._lazy_graph_module import _LazyGraphModule

@@ -1528,7 +1557,18 @@ def compile_and_call_fx_graph(self, tx, rv, root):

counters["stats"]["unique_graphs"] += 1
# This is safe because we pre-process name to be unique
self.install_global_unsafe(name, compiled_fn)
if specialized_compiles:

@torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
def specialized_dispatch(*args, **kwargs):
for check_fn, specialized_compiled_fn in specialized_compiles:
if check_fn(args):
return specialized_compiled_fn(*args, **kwargs)
return compiled_fn(*args, **kwargs)

self.install_global_unsafe(name, specialized_dispatch)
else:
self.install_global_unsafe(name, compiled_fn)

cg = PyCodegen(tx)
cg.make_call_generated_code(name)
@@ -1542,16 +1582,16 @@ def placeholders(self) -> list[fx.Node]:
def graphargs(self) -> list[GraphArg]:
return [node.meta["grapharg"] for node in self.placeholders]

def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
def call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
with dynamo_timed(
"OutputGraph.call_user_compiler",
phase_name="backend_compile",
log_pt2_compile_event=True,
dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
):
return self._call_user_compiler(gm)
return self._call_user_compiler(gm, **kwargs)

def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
def _call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
assert self.compiler_fn is not None
tot = 0
placeholders = []
@@ -1581,7 +1621,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
compiler_fn = self.compiler_fn
if config.verify_correctness:
compiler_fn = WrapperBackend(compiler_fn)
compiled_fn = compiler_fn(gm, self.example_inputs())
compiled_fn = compiler_fn(gm, self.example_inputs(), **kwargs)
_step_logger()(logging.INFO, f"done compiler function {name}")
assert callable(compiled_fn), "compiler_fn did not return callable"
except (TensorifyScalarRestartAnalysis, ShortenTraceback):
8 changes: 5 additions & 3 deletions torch/_dynamo/repro/after_dynamo.py
@@ -110,7 +110,7 @@ def add_paths(exc):
# Check for either accuracy (level 4) or other type of failures.
if config.repro_level == 4:
# Check Accuracy
compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs, **kwargs)
if _accuracy_fails(gm, example_inputs, compiler_fn):
log.warning(
"Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error."
@@ -125,7 +125,9 @@
raise exc
else:
try:
compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
compiled_gm = compiler_fn(
copy.deepcopy(gm), example_inputs, **kwargs
)
run_fwd_maybe_bwd(compiled_gm, example_inputs)
except Exception as exc:
log.warning(
@@ -147,7 +149,7 @@ def add_paths(exc):
add_paths(exc)
raise
else:
compiled_gm = compiler_fn(gm, example_inputs)
compiled_gm = compiler_fn(gm, example_inputs, **kwargs)

return compiled_gm

6 changes: 6 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -3053,6 +3053,7 @@ def update_dim2constraint(dim, constraint_range, name):
dynamic_strides = []
constraint_sizes = []
constraint_strides = []
backend_specializations = []
for i in range(e.dim()):
# NB: mark dynamic has precedence over static
marked_strict_unbacked = i in getattr(
@@ -3063,6 +3064,10 @@
marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
marked_static = i in getattr(e, "_dynamo_static_indices", set())

backend_specializations.append(
getattr(e, "_backend_specializations", {}).get(i, [])
)

# Reflect the user directive in the frame_state
# For dynamic, apply None always

@@ -3182,6 +3187,7 @@ def update_dim2constraint(dim, constraint_range, name):
dynamic_strides=dynamic_strides,
constraint_sizes=constraint_sizes,
constraint_strides=constraint_strides,
backend_specializations=backend_specializations,
view_base_context=view_base_context,
tensor_source=source,
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
8 changes: 5 additions & 3 deletions torch/_export/non_strict_utils.py
@@ -143,9 +143,11 @@ def fakify(
constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False) # type: ignore[call-overload]
else:
dynamic_sizes.append(DimDynamic.STATIC)
symbolic_context = StatelessSymbolicContext(
dynamic_sizes=dynamic_sizes,
constraint_sizes=constraint_sizes, # type: ignore[arg-type]
symbolic_context: StatelessSymbolicContext = ( # make mypy happy
StatelessSymbolicContext(
dynamic_sizes=dynamic_sizes,
constraint_sizes=constraint_sizes, # type: ignore[arg-type]
)
)
t_id = id(t)
assert mode.shape_env is not None
12 changes: 10 additions & 2 deletions torch/_functorch/aot_autograd.py
@@ -31,7 +31,7 @@
_pytree_subclasses_that_lose_info,
make_fx,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.experimental.symbolic_shapes import BackendSpecialization, ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


@@ -489,6 +489,7 @@ def process_inputs(
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
ignore_shape_env: bool = False,
specialization: Optional[BackendSpecialization] = None,
) -> FakifiedFlatArgs:
with fake_mode:

@@ -547,6 +548,7 @@ def convert(idx, x):
symbolic_context=symbolic_context,
source=source,
trace=trace,
specialization=specialization,
)
return result

@@ -1084,6 +1086,7 @@ def aot_module_simplified(
cudagraphs: Optional[BoxedBool] = None,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
ignore_shape_env: bool = False,
specialization: Optional[BackendSpecialization] = None,
) -> nn.Module:
"""
This is the simplified or low overhead version of aot_module. For frontends
@@ -1155,7 +1158,12 @@
)
fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
fake_flat_args = process_inputs(
full_args, aot_config, fake_mode, shape_env, ignore_shape_env
full_args,
aot_config,
fake_mode,
shape_env,
ignore_shape_env,
specialization=specialization,
)

def dispatch_and_compile():
6 changes: 6 additions & 0 deletions torch/_inductor/compile_fx.py
@@ -127,6 +127,7 @@

from torch._inductor.output_code import _StrideExprStr
from torch._ops import OpOverload
from torch.fx.experimental.symbolic_shapes import BackendSpecialization

from .ir import ExternKernelNode

@@ -1914,6 +1915,7 @@ def compile_fx(
config_patches: Optional[dict[str, Any]] = None,
decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
ignore_shape_env: bool = False,
specialization: Optional[BackendSpecialization] = None,
) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
"""
Main entry point for compiling given FX graph. Despite the fact that this
@@ -1939,6 +1941,7 @@
inner_compile=config.patch(config_patches)(inner_compile),
decompositions=decompositions,
ignore_shape_env=ignore_shape_env,
specialization=specialization,
)

# TODO: This probably shouldn't be a recursive call
@@ -1995,13 +1998,15 @@ def compile_fx(
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
decompositions=decompositions,
ignore_shape_env=ignore_shape_env,
specialization=specialization,
)

recursive_compile_fx = functools.partial(
compile_fx,
inner_compile=inner_compile,
decompositions=decompositions,
ignore_shape_env=ignore_shape_env,
specialization=specialization,
)

if not graph_returns_tuple(model_):
@@ -2332,6 +2337,7 @@ def bw_compiler(
cudagraphs=cudagraphs,
boxed_forward_device_index=forward_device,
ignore_shape_env=ignore_shape_env,
specialization=specialization,
)(model_, example_inputs_)
except ShortenTraceback as e:
# We will also shorten the traceback inside dynamo.
10 changes: 9 additions & 1 deletion torch/_subclasses/fake_tensor.py
@@ -59,7 +59,11 @@

from torch._guards import Source
from torch._ops import OpOverload
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
from torch.fx.experimental.symbolic_shapes import (
BackendSpecialization,
ShapeEnv,
SymbolicContext,
)

log = logging.getLogger(__name__)

@@ -354,6 +358,7 @@ def from_real_tensor(
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
trace: bool = True,
specialization: Optional[BackendSpecialization] = None,
) -> FakeTensor:
# see note [Tensor Fakification and Symbol Caching]
if not symbolic_context and not source and shape_env:
Expand Down Expand Up @@ -408,6 +413,7 @@ def mk_fake_tensor(
source=source,
symbolic_context=symbolic_context,
trace=trace,
specialization=specialization,
)
if out is NotImplemented:
raise UnsupportedFakeTensorException("meta converter nyi")
@@ -2864,6 +2870,7 @@ def from_tensor(
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
trace: bool = True,
specialization: Optional[BackendSpecialization] = None,
) -> FakeTensor:
shape_env: Optional[ShapeEnv] = self.shape_env
if static_shapes is None:
@@ -2880,6 +2887,7 @@
source=source,
symbolic_context=symbolic_context,
trace=trace,
specialization=specialization,
)

