torch.onnx.export with dynamic axes fails for torch.nn.InstanceNorm1d with track_running_stats=True · Issue #128501 · pytorch/pytorch

@cdeln

🐛 Describe the bug

torch.onnx.export fails for torch.nn.InstanceNorm1d (and InstanceNorm2d, for that matter) when the module is initialized with track_running_stats=True and the export uses a dynamic batch size.

Minimal reproducing code:

import torch

B, C, N = 3, 7, 11
model = torch.nn.InstanceNorm1d(C, track_running_stats=True)
model.eval()
inputs = torch.zeros(B, C, N)
torch.onnx.export(model, inputs, '/tmp/model.onnx',
                  input_names=['inputs'],
                  output_names=['outputs'],
                  dynamic_axes={
                      'inputs': {0: 'batch_size'},
                      'outputs': {0: 'batch_size'}
                  })

Traceback (I believe the relevant line is: SymbolicValueError: Unsupported: ONNX export of instance_norm training for unknown batch size.):

---------------------------------------------------------------------------
SymbolicValueError                        Traceback (most recent call last)
Cell In[57], line 8
      6 model.eval()
      7 inputs = torch.zeros(B, C, N)
----> 8 torch.onnx.export(model, inputs, '/tmp/model.onnx',
      9                   input_names=['inputs'],
     10                   output_names=['outputs'],
     11                   dynamic_axes={
     12                       'inputs': {0: 'batch_size'},
     13                       'outputs': {0: 'batch_size'}
     14                   })

File ~/.pyenv/versions/3.10.6/envs/default/lib/python3.10/site-packages/torch/onnx/utils.py:506, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
    188 @_beartype.beartype
    189 def export(
    190     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    206     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
    207 ) -> None:
    208     r"""Exports a model into ONNX format.
    209 
    210     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    503             All errors are subclasses of :class:`errors.OnnxExporterError`.
    504     """
--> 506     _export(
    507         model,
    508         args,
    509         f,
    510         export_params,
    511         verbose,
    512         training,
    513         input_names,
    514         output_names,
    515         operator_export_type=operator_export_type,
    516         opset_version=opset_version,
    517         do_constant_folding=do_constant_folding,
    518         dynamic_axes=dynamic_axes,
    519         keep_initializers_as_inputs=keep_initializers_as_inputs,
    520         custom_opsets=custom_opsets,
    521         export_modules_as_functions=export_modules_as_functions,
    522     )

File ~/.pyenv/versions/3.10.6/envs/default/lib/python3.10/site-packages/torch/onnx/utils.py:1548, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions)
   1545     dynamic_axes = {}
   1546 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1548 graph, params_dict, torch_out = _model_to_graph(
   1549     model,
   1550     args,
   1551     verbose,
   1552     input_names,
   1553     output_names,
   1554     operator_export_type,
   1555     val_do_constant_folding,
   1556     fixed_batch_size=fixed_batch_size,
   1557     training=training,
   1558     dynamic_axes=dynamic_axes,
   1559 )
   1561 # TODO: Don't allocate a in-memory string for the protobuf
   1562 defer_weight_export = (
   1563     export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
   1564 )

File ~/.pyenv/versions/3.10.6/envs/default/lib/python3.10/site-packages/torch/onnx/utils.py:1117, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1114 params_dict = _get_named_param_dict(graph, params)
   1116 try:
-> 1117     graph = _optimize_graph(
   1118         graph,
   1119         operator_export_type,
   1120         _disable_torch_constant_prop=_disable_torch_constant_prop,
   1121         fixed_batch_size=fixed_batch_size,
   1122         params_dict=params_dict,
   1123         dynamic_axes=dynamic_axes,
   1124         input_names=input_names,
   1125         module=module,
   1126     )
   1127 except Exception as e:
   1128     torch.onnx.log("Torch IR graph at exception: ", graph)

File ~/.pyenv/versions/3.10.6/envs/default/lib/python3.10/site-packages/torch/onnx/utils.py:665, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    662     _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    663 _C._jit_pass_onnx_lint(graph)
--> 665 graph = _C._jit_pass_onnx(graph, operator_export_type)
    666 _C._jit_pass_onnx_lint(graph)
    667 _C._jit_pass_lint(graph)

File ~/.pyenv/versions/3.10.6/envs/default/lib/python3.10/site-packages/torch/onnx/utils.py:1891, in _run_symbolic_function(graph, block, node, inputs, env, operator_export_type)
   1886     if symbolic_fn is not None:
   1887         # TODO Wrap almost identical attrs assignment or comment the difference.
   1888         attrs = {
   1889             k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
   1890         }
-> 1891         return symbolic_fn(graph_context, *inputs, **attrs)
   1893 attrs = {
   1894     k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
   1895     for k in node.attributeNames()
   1896 }
   1897 if namespace == "onnx":
   1898     # Clone node to trigger ONNX shape inference

File ~/.pyenv/versions/3.10.6/envs/default/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py:306, in parse_args.<locals>.decorator.<locals>.wrapper(g, *args, **kwargs)
    300 if len(kwargs) == 1:
    301     assert "_outputs" in kwargs, (
    302         f"Symbolic function {fn.__name__}'s '**kwargs' can only contain "
    303         f"'_outputs' key at '**kwargs'. "
    304         f"{FILE_BUG_MSG}"
    305     )
--> 306 return fn(g, *args, **kwargs)

File ~/.pyenv/versions/3.10.6/envs/default/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:2895, in instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled)
   2893 n = input_size[0]
   2894 if n is None:
-> 2895     raise errors.SymbolicValueError(
   2896         "Unsupported: ONNX export of instance_norm training for unknown "
   2897         "batch size.",
   2898         input,
   2899     )
   2900 c = input_size[1]
   2901 input_size_reshape[0] = 1

SymbolicValueError: Unsupported: ONNX export of instance_norm training for unknown batch size.  [Caused by the value 'input defined in (%input : Float(*, 3, 11, strides=[33, 11, 1], requires_grad=0, device=cpu), %running_mean : Float(3, strides=[1], requires_grad=0, device=cpu), %running_var : Float(3, strides=[1], requires_grad=0, device=cpu), %num_batches_tracked : Long(requires_grad=0, device=cpu) = prim::Param()
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'prim::Param'.] 

    Inputs:
        Empty
    Outputs:
        #0: input defined in (%input : Float(*, 3, 11, strides=[33, 11, 1], requires_grad=0, device=cpu), %running_mean : Float(3, strides=[1], requires_grad=0, device=cpu), %running_var : Float(3, strides=[1], requires_grad=0, device=cpu), %num_batches_tracked : Long(requires_grad=0, device=cpu) = prim::Param()
    )  (type 'Tensor')
        #1: running_mean defined in (%input : Float(*, 3, 11, strides=[33, 11, 1], requires_grad=0, device=cpu), %running_mean : Float(3, strides=[1], requires_grad=0, device=cpu), %running_var : Float(3, strides=[1], requires_grad=0, device=cpu), %num_batches_tracked : Long(requires_grad=0, device=cpu) = prim::Param()
    )  (type 'Tensor')
        #2: running_var defined in (%input : Float(*, 3, 11, strides=[33, 11, 1], requires_grad=0, device=cpu), %running_mean : Float(3, strides=[1], requires_grad=0, device=cpu), %running_var : Float(3, strides=[1], requires_grad=0, device=cpu), %num_batches_tracked : Long(requires_grad=0, device=cpu) = prim::Param()
    )  (type 'Tensor')
        #3: num_batches_tracked defined in (%input : Float(*, 3, 11, strides=[33, 11, 1], requires_grad=0, device=cpu), %running_mean : Float(3, strides=[1], requires_grad=0, device=cpu), %running_var : Float(3, strides=[1], requires_grad=0, device=cpu), %num_batches_tracked : Long(requires_grad=0, device=cpu) = prim::Param()
    )  (type 'Tensor')

Note that the export succeeds when dynamic_axes is removed.
When the layer is initialized with track_running_stats=False, the export "only" emits a warning:

============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

/home/cdeln/.pyenv/versions/3.10.6/envs/default/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py:1466: UserWarning: ONNX export mode is set to TrainingMode.EVAL, but operator 'instance_norm' is set to train=True. Exporting with train=True.
  warnings.warn(
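
A possible workaround sketch, not verified against the exporter in this issue: the traceback shows the instance_norm symbolic raising when input_size[0] is unknown. In eval mode with track_running_stats=True, InstanceNorm1d normalizes with the stored running_mean and running_var instead of per-instance statistics, so the same result can be written as a per-channel elementwise expression that does not depend on the batch size. The class name EvalInstanceNorm1d below is hypothetical (it is not a PyTorch API), and the code only checks numerical agreement with the original module; whether exporting it with dynamic_axes avoids the error is an assumption.

```python
import torch


class EvalInstanceNorm1d(torch.nn.Module):
    """Hypothetical stand-in for an eval-mode InstanceNorm1d that tracks running stats."""

    def __init__(self, norm: torch.nn.InstanceNorm1d):
        super().__init__()
        if not norm.track_running_stats:
            raise ValueError("expected a module created with track_running_stats=True")
        self.eps = norm.eps
        # Copy the running statistics so this module is independent of the original.
        self.register_buffer("running_mean", norm.running_mean.detach().clone())
        self.register_buffer("running_var", norm.running_var.detach().clone())
        # InstanceNorm1d defaults to affine=False, in which case both are None.
        self.weight = norm.weight
        self.bias = norm.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the per-channel statistics over the batch and length dimensions.
        mean = self.running_mean.view(1, -1, 1)
        var = self.running_var.view(1, -1, 1)
        y = (x - mean) / torch.sqrt(var + self.eps)
        if self.weight is not None:
            y = y * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
        return y


torch.manual_seed(0)
B, C, N = 3, 7, 11
norm = torch.nn.InstanceNorm1d(C, track_running_stats=True)

# Run a few training-mode batches so the running statistics are non-trivial.
norm.train()
with torch.no_grad():
    for _ in range(3):
        norm(torch.randn(B, C, N) * 2.0 + 1.0)
norm.eval()

wrapped = EvalInstanceNorm1d(norm).eval()

# Compare on a batch size that differs from the one used above.
x = torch.randn(5, C, N)
with torch.no_grad():
    max_diff = (norm(x) - wrapped(x)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

If the two modules agree, the wrapper could be passed to torch.onnx.export in place of the original layer; that step is left out here because it was not tested in this report.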

Versions

Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.26.3
Libc version: glibc-2.35

Python version: 3.10.6 (main, Apr 26 2023, 09:09:46) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-6.5.0-27-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: 12th Gen Intel(R) Core(TM) i5-1235U
CPU family: 6
Model: 154
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 4
CPU max MHz: 4400.0000
CPU min MHz: 400.0000
BogoMIPS: 4992.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi umip pku ospke waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize arch_lbr ibt flush_l1d arch_capabilities
L1d cache: 352 KiB (10 instances)
L1i cache: 576 KiB (10 instances)
L2 cache: 6.5 MiB (4 instances)
L3 cache: 12 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] fft-conv-pytorch==1.1.3
[pip3] numpy==1.24.3
[pip3] onnx==1.14.0
[pip3] onnxruntime==1.15.1
[pip3] onnxsim==0.4.33
[pip3] pytorch3d==0.7.4
[pip3] pytorchltr==0.2.1
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.2
[pip3] torchmetrics==1.2.0
[pip3] torchvision==0.15.1
[pip3] triton==2.0.0
[conda] Could not collect

Labels: OSS contribution wanted, module: onnx, triaged