[export] torch.tensor constructor specializes on float value · Issue #153411 · pytorch/pytorch · GitHub

[export] torch.tensor constructor specializes on float value #153411

Closed
pianpwk opened this issue May 12, 2025 · 2 comments

Labels: export-triaged (issue reviewed by the PT2 Export team and next step determined), module: dynamic shapes, oncall: export, oncall: pt2

Comments

@pianpwk
Contributor
pianpwk commented May 12, 2025

🐛 Describe the bug

Exporting a torch.tensor() constructor call on a float scalar obtained via .item() specializes on the scalar's value, which leads to a data-dependent error:

import torch

# fails
class Foo(torch.nn.Module):
    def forward(self, flt):
        scalar = flt.item()
        return torch.tensor([scalar])

ep = torch.export.export(Foo(), (torch.tensor([3.14]),))

error:

W0512 13:48:07.270000 3995695 torch/fx/experimental/symbolic_shapes.py:6837] failed during evaluate_expr(zuf0, hint=None, size_oblivious=False, forcing_spec=False
W0512 13:48:07.271000 3995695 torch/fx/experimental/symbolic_shapes.py:7449] Unable to find user code corresponding to {zuf0}



def forward(self, arg0_1: "f32[1]"):
     # File: /data/users/pianpwk/ptclone/pytorch/test_float_spec.py:6 in forward, code: scalar = flt.item()
    item: "Sym(zuf0)" = torch.ops.aten.item.default(arg0_1);  arg0_1 = item = None
    



def forward(self, arg0_1: "f32[1]"):
     # File: /data/users/pianpwk/ptclone/pytorch/test_float_spec.py:6 in forward, code: scalar = flt.item()
    item: "Sym(zuf0)" = torch.ops.aten.item.default(arg0_1);  arg0_1 = item = None
    
Traceback (most recent call last):
  File "/data/users/pianpwk/ptclone/pytorch/test_float_spec.py", line 9, in <module>
    ep = torch.export.export(Foo(), (torch.tensor([3.14]),))
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/__init__.py", line 318, in export
    raise e
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/__init__.py", line 285, in export
    return _export(
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1104, in wrapper
    raise e
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1070, in wrapper
    ep = fn(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 2117, in _export
    ep = _export_for_training(
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1104, in wrapper
    raise e
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1070, in wrapper
    ep = fn(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1978, in _export_for_training
    export_artifact = export_func(
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1920, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1705, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1846, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1624, in _make_fx_helper
    gm = make_fx(
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 2290, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 2228, in trace
    return self._trace_inner(f, *args)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner
    t = dispatch_trace(
  File "/data/users/pianpwk/ptclone/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 1787, in trace
    res = super().trace(root, concrete_args)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1528, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/data/users/pianpwk/ptclone/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/nn/modules/module.py", line 1766, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/export/_trace.py", line 1830, in forward
    tree_out = mod(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/nn/modules/module.py", line 1766, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/test_float_spec.py", line 7, in forward
    return torch.tensor([scalar])
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 1326, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/proxy_tensor.py", line 1373, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/_export/non_strict_utils.py", line 973, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/sym_node.py", line 526, in guard_float
    r = self.evaluate()
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/sym_node.py", line 510, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6810, in evaluate_sym_node
    return self.evaluate_expr(
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/recording.py", line 264, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6826, in evaluate_expr
    return self._evaluate_expr(
  File "/data/users/pianpwk/ptclone/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7098, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression zuf0 (unhinted: zuf0).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:973 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="zuf0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/data/users/pianpwk/ptclone/pytorch/test_float_spec.py", line 7, in forward
    return torch.tensor([scalar])

Strangely enough, all of these pass:

# passes
@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch.compile(fullgraph=True, backend="eager", dynamic=True)
def fn(flt):
    scalar = flt.item()
    return torch.tensor([scalar])

fn(torch.tensor([3.14]))

# passes
class Foo(torch.nn.Module):
    def forward(self, flt):
        scalar = flt.item()
        return torch.tensor([scalar])

ep = torch.export.export(Foo(), (torch.tensor([3]),))

# passes
class Foo(torch.nn.Module):
    def forward(self, flt):
        scalar = flt.item()
        return torch.scalar_tensor(scalar)

ep = torch.export.export(Foo(), (torch.tensor([3.4]),))
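For comparison, here is a minimal eager-mode sketch of how the passing torch.scalar_tensor variant relates to the failing torch.tensor([scalar]) call. It assumes the goal is a drop-in replacement with the same output shape [1], so it appends .unsqueeze(0) to the 0-d result; via_tensor and via_scalar_tensor are hypothetical helper names used only for illustration. It checks eager behavior only and does not by itself show that export succeeds on any particular PyTorch version.

```python
import torch


def via_tensor(flt):
    # Pattern from the failing repro: extract a Python scalar, rebuild a tensor.
    return torch.tensor([flt.item()])


def via_scalar_tensor(flt):
    # Pattern from the passing case, plus unsqueeze(0) so the result has
    # shape [1] instead of the 0-d shape torch.scalar_tensor returns.
    return torch.scalar_tensor(flt.item()).unsqueeze(0)


x = torch.tensor([3.14])
a, b = via_tensor(x), via_scalar_tensor(x)
print(a.shape, b.shape, a.dtype, b.dtype)
print(torch.equal(a, b))

# Caveat for integer inputs: torch.tensor infers int64 from a Python int,
# while torch.scalar_tensor uses the default floating dtype unless dtype= is
# passed, so the two forms are only interchangeable for float inputs.
i = torch.tensor([3])
print(via_tensor(i).dtype, via_scalar_tensor(i).dtype)
```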

Versions

Collecting environment information...
/data/users/pianpwk/ptclone/pytorch/torch/cuda/__init__.py:799: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
PyTorch version: 2.8.0a0+git05326b7
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-5)
Clang version: Could not collect
CMake version: version 4.0.2
Libc version: glibc-2.34

Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0_fbk14_hardened_2601_gcd42476b84e9-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.8.0
/usr/lib64/libcudnn_adv_infer.so.8.8.0
/usr/lib64/libcudnn_adv_train.so.8.8.0
/usr/lib64/libcudnn_cnn_infer.so.8.8.0
/usr/lib64/libcudnn_cnn_train.so.8.8.0
/usr/lib64/libcudnn_ops_infer.so.8.8.0
/usr/lib64/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 92
On-line CPU(s) list: 0-91
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 1
Core(s) per socket: 92
Socket(s): 1
Stepping: 1
BogoMIPS: 4792.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm arch_capabilities
Virtualization: AMD-V
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 5.8 MiB (92 instances)
L1i cache: 5.8 MiB (92 instances)
L2 cache: 46 MiB (92 instances)
L3 cache: 1.4 GiB (92 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-91
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.15.0
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0a0+git05326b7
[conda] mkl-include 2025.1.0 pypi_0 pypi
[conda] mkl-static 2025.1.0 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] optree 0.15.0 pypi_0 pypi
[conda] pytorch-triton 3.3.0+git96316ce5 pypi_0 pypi
[conda] torch 2.8.0a0+git88a068f dev_0

cc @chauhang @penguinwu @ezyang @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@AbhiLegend

I am sharing a workaround and some details, @pianpwk please have a look.
PyTorch Export and .item() Issue

  • flt.item() produces an unbacked float symbol (zuf0); passing it to torch.tensor() then forces a guard on its value, which torch.export() cannot resolve during symbolic tracing.
  • The same code works in torch.compile() with capture_scalar_outputs=True.
  • Fixes:
    • Use .reshape(1) or .unsqueeze(0) to avoid scalar extraction.
    • Use torch.scalar_tensor() if scalar extraction is needed (the report above shows this form exporting successfully).

I am also sharing the notebook for reference: https://colab.research.google.com/drive/1WbzOktDJ6CR-xQI_9o7GU4ku4nPuNAOS?usp=sharing
If this is the approach you want, please assign the issue to me, @pianpwk.

@pianpwk
Contributor Author
pianpwk commented May 12, 2025

> I am sharing a workaround and also some details, please @pianpwk have a look.

Ideally we'd make a framework fix so the torch.tensor() call doesn't throw an error; the workarounds aren't important.

avikchaudhuri added the export-triaged label May 20, 2025
nWEIdia pushed a commit to nWEIdia/pytorch that referenced this issue Jun 2, 2025