export: `tensor.view()` fails with dynamic shapes. · Issue #153174 · pytorch/pytorch · GitHub


Open
ysiraichi opened this issue May 8, 2025 · 3 comments

ysiraichi commented May 8, 2025

🐛 Describe the bug

As far as I understand it, the following should work:

import torch
from torch.export import Dim, export

class Foo(torch.nn.Module):
    def forward(self, a, b):
        u0 = a.item()
        y = torch.zeros(u0, 18, b.shape[0])
        torch._check((u0 * 18 * b.shape[0]) % 144 == 0)
        return y.view(-1, 144)

ep = export(
    Foo(),
    (torch.tensor([6]), torch.randn(8)),
    dynamic_shapes={
        "a": None,
        "b": (Dim.DYNAMIC,),
    },
)
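For the concrete sample inputs, the shape arithmetic checks out. A quick sketch (plain Python, no torch) of the inference that view(-1, 144) needs to perform:

```python
# Sketch of the -1 inference in view(-1, 144), using the sample
# inputs from the repro above (values are from that repro, the
# variable names here are only illustrative).
u0 = 6         # a.item() for torch.tensor([6])
b_len = 8      # b.shape[0] for torch.randn(8)

numel = u0 * 18 * b_len      # y.numel() == 864
assert numel % 144 == 0      # the fact asserted by torch._check(...)
inferred = numel // 144      # the value the -1 should resolve to
print(inferred)              # 6
```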

It also fails under Dynamo when the following options are turned on:

torch._dynamo.config.capture_dynamic_output_shape_ops = 1
torch._dynamo.config.capture_scalar_outputs = 1

However, the following error is raised:

Traceback (most recent call last):
  File "../examples/issue-151491.py", line 23, in <module>
    ep = export(
  File "torch/export/__init__.py", line 318, in export
    raise e
  File "torch/export/__init__.py", line 285, in export
    return _export(
  File "torch/export/_trace.py", line 1104, in wrapper
    raise e
  File "torch/export/_trace.py", line 1070, in wrapper
    ep = fn(*args, **kwargs)
  File "torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "torch/export/_trace.py", line 2117, in _export
    ep = _export_for_training(
  File "torch/export/_trace.py", line 1104, in wrapper
    raise e
  File "torch/export/_trace.py", line 1070, in wrapper
    ep = fn(*args, **kwargs)
  File "torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "torch/export/_trace.py", line 1978, in _export_for_training
    export_artifact = export_func(
  File "torch/export/_trace.py", line 1920, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "torch/export/_trace.py", line 1705, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "torch/export/_trace.py", line 1846, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "torch/export/_trace.py", line 1624, in _make_fx_helper
    gm = make_fx(
  File "torch/fx/experimental/proxy_tensor.py", line 2288, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "torch/fx/experimental/proxy_tensor.py", line 2226, in trace
    return self._trace_inner(f, *args)
  File "torch/fx/experimental/proxy_tensor.py", line 2197, in _trace_inner
    t = dispatch_trace(
  File "torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "torch/_dynamo/eval_frame.py", line 857, in _fn
    return fn(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 1221, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "torch/fx/experimental/proxy_tensor.py", line 1785, in trace
    res = super().trace(root, concrete_args)
  File "torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
  File "torch/fx/experimental/proxy_tensor.py", line 1276, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "torch/export/_trace.py", line 1528, in wrapped_fn
    return tuple(flat_fn(*args))
  File "torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 903, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 1855, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1766, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/export/_trace.py", line 1830, in forward
    tree_out = mod(*args, **kwargs)
  File "torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 1855, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1766, in _call_impl
    return forward_call(*args, **kwargs)
  File "../examples/issue-151491.py", line 13, in forward
    return y.view(-1, 144)
  File "torch/fx/experimental/proxy_tensor.py", line 1324, in __torch_function__
    return func(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 1371, in __torch_function__
    return func(*args, **kwargs)
  File "torch/_export/non_strict_utils.py", line 969, in __torch_function__
    return func(*args, **kwargs)
  File "torch/_ops.py", line 925, in handler
    return torch._library.utils.handle_dispatch_mode(
  File "torch/_library/utils.py", line 296, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
  File "torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 1426, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 926, in proxy_call
    out = func(*args, **kwargs)
  File "torch/_ops.py", line 806, in __call__
    return self._op(*args, **kwargs)
  File "torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1338, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1986, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1450, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 2529, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
  File "torch/_refs/__init__.py", line 4728, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "torch/_refs/__init__.py", line 3840, in _reshape_view_helper
    maybe_throw_dde()
  File "torch/_refs/__init__.py", line 3833, in maybe_throw_dde
    f()
  File "torch/_refs/__init__.py", line 3838, in <lambda>
    deferred.append(lambda: bool(accum % length != 0))
  File "torch/__init__.py", line 753, in __bool__
    return self.node.bool_()
  File "torch/fx/experimental/sym_node.py", line 614, in bool_
    return self.guard_bool("", 0)
  File "torch/fx/experimental/sym_node.py", line 536, in guard_bool
    r = self.evaluate()
  File "torch/fx/experimental/sym_node.py", line 510, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
  File "torch/fx/experimental/symbolic_shapes.py", line 6840, in evaluate_sym_node
    return self.evaluate_expr(
  File "torch/fx/experimental/recording.py", line 264, in wrapper
    return retlog(fn(*args, **kwargs))
  File "torch/fx/experimental/symbolic_shapes.py", line 6856, in evaluate_expr
    return self._evaluate_expr(
  File "torch/fx/experimental/symbolic_shapes.py", line 7128, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression False (unhinted: Ne(Mod(18*s58*u0, ((s58*u0)//8)), 0)).  (Size-like symbols: none)

Caused by: (_refs/__init__.py:3838 in <lambda>)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=""
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "../examples/issue-151491.py", line 13, in forward
    return y.view(-1, 144)

To fix the error, insert one of the following checks before this call:
  1. torch._check(False)
  2. torch._check(True)

(These suggested fixes were derived by replacing  in False and its negation.)

Versions

PyTorch: e9e1aac (May 2, 2025)

cc @chauhang @penguinwu @ezyang @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

ysiraichi (Author) commented:

I believe this should work, since we can infer the size of the first dimension of the view as y.numel() // 144. That inference is valid whenever y.numel() % 144 == 0, which is exactly what the torch._check() asserts.

ydwu4 commented May 12, 2025

Yeah, the symbolic reasoning system sometimes has a hard time solving complex constraints like this. cc @pianpwk
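For what it's worth, the guard the solver gives up on does hold under the concrete hints. A small sketch (plain Python) substituting the repro's sample values into the unhinted expression Ne(Mod(18*s58*u0, ((s58*u0)//8)), 0) from the traceback:

```python
# Substitute the repro's concrete values into the unhinted guard from
# the traceback: Ne(Mod(18*s58*u0, ((s58*u0)//8)), 0).
u0 = 6    # data-dependent value from a.item()
s58 = 8   # the dynamic dimension b.shape[0]

mod = (18 * s58 * u0) % ((s58 * u0) // 8)   # 864 % 6 == 0
print(mod != 0)   # the guard expression; False means the view is valid
```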

pianpwk commented May 12, 2025

Did we need a copy of this? #151491
