[ONNX] Exported ONNX module with for loop + scatter operation on tensor seems to be incorrect #29647


Closed
julioasotodv opened this issue Nov 12, 2019 · 6 comments
Labels
module: onnx (Related to torch.onnx) · onnx-triaged (triaged by ONNX team) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@julioasotodv
julioasotodv commented Nov 12, 2019

🐛 Bug

It looks like creating an nn.Module with the combination of:

  1. A for loop that loops over a tensor
  2. An iterative scatter operation over another tensor
  3. Exporting this module to ONNX

generates an invalid ONNX file that the official onnx library cannot load correctly.

To Reproduce

Please take a look at the following code (don't worry; it is not as long as it looks):

import torch
from torch import nn

# Module creation:
class IterativelyModifyTensor(nn.Module):
    """
    Creates a module that takes a 2D tensor/matrix (input_2d_tensor),
    and replaces each row of it with a 1D tensor (substitution_tensor)
    with a for loop.
    
    This is just for demonstration purposes; to see if we can iteratively
    modify matrix rows in Pytorch.
    """
    def __init__(self):
        super(IterativelyModifyTensor, self).__init__()
        
    def forward(self, input_2d_tensor, substitution_tensor):
        output_tensor = (torch
                         .zeros(1)
                         .repeat(input_2d_tensor.size(0) * input_2d_tensor.size(1))
                         .view(input_2d_tensor.size(0), input_2d_tensor.size(1))
                         .long()
                        )
        num_rows = input_2d_tensor.size(0)
        
        for row in range(num_rows):
            dim = 0
            index = torch.ones(1).repeat(input_2d_tensor.size(1)).unsqueeze(0).long() * row
            src = substitution_tensor.unsqueeze(0)
            output_tensor = output_tensor.scatter(dim, index, src)
        
        return output_tensor
    
# Module instantiation:
my_modifier = IterativelyModifyTensor()

# Generate inputs:
input_2d_tensor = torch.zeros(4, 3).long()
substitution_tensor = torch.arange(3)

print("input_2d_tensor:\n%s\n" % input_2d_tensor)
print("substitution_tensor:\n%s\n" % substitution_tensor)

# Call forward:
output = my_modifier(input_2d_tensor, substitution_tensor)

print("output:\n%s\n" % output)

# Scripting the module:
print("scripting the module...\n")

my_modifier_scripted = torch.jit.script(my_modifier)

# See if output from scripted module is correct (it is):
output_from_scripted = my_modifier_scripted(input_2d_tensor, substitution_tensor)

print("output_from_scripted:\n%s\n" % output_from_scripted)

# Exporting the scripted module to ONNX:
print("exporting to ONNX...\n")

torch.onnx.export(model=my_modifier_scripted,
                  args=(input_2d_tensor, substitution_tensor),
                  f="modifier.onnx",
                  example_outputs=[output_from_scripted],
                  input_names=["input_2d_tensor", "substitution_tensor"],
                  output_names=["output"])

# Loading the ONNX model with Microsoft's onnxruntime:
print("loading exported ONNX model with onnxruntime...\n")

import onnxruntime as ort

ort_session = ort.InferenceSession('modifier.onnx')

# See if output from onnxruntime is correct (it is):
output_onnx = ort_session.run(None, 
                              {"input_2d_tensor": input_2d_tensor.numpy(),
                               "substitution_tensor": substitution_tensor.numpy()
                              })

print("output_onnx:\n%s\n" % output_onnx[0])

# Loading the ONNX model with the onnx library + caffe2:
print("loading model with onnx (+ caffe2 runtime)...\n")

import onnx
import caffe2.python.onnx.backend as backend

model = onnx.load('modifier.onnx')

caffe_model = backend.prepare(model)

# check_model fails (backend.prepare above already runs this check internally):
onnx.checker.check_model(model)

# This next line does not work either:
#output_caffe = caffe_model.run((input_2d_tensor.numpy(), substitution_tensor.numpy()))

The caffe_model = backend.prepare(model) call (which runs onnx.checker.check_model internally) fails with:

---------------------------------------------------------------------------
ValidationError                           Traceback (most recent call last)
<ipython-input-102-a14d40767c4c> in <module>
     90 model = onnx.load('modifier.onnx')
     91 
---> 92 caffe_model = backend.prepare(model)
     93 
     94 # check_model fails:

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/caffe2/python/onnx/backend.py in prepare(cls, model, device, raw_values_dict, **kwargs)
    691         '''
    692         if not kwargs.pop('no_check_UNSAFE', False):
--> 693             super(Caffe2Backend, cls).prepare(model, device, **kwargs)
    694         opset_version = None
    695         for imp in model.opset_import:

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/onnx/backend/base.py in prepare(cls, model, device, **kwargs)
     72                 ):  # type: (...) -> Optional[BackendRep]
     73         # TODO Remove Optional from return type
---> 74         onnx.checker.check_model(model)
     75         return None
     76 

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/onnx/checker.py in check_model(model)
     84         C.check_model_path(model)
     85     else:
---> 86         C.check_model(model.SerializeToString())
     87 
     88 

ValidationError: Field 'shape' of type is required but missing.

==> Context: Bad node spec: input: "24" input: "25" input: "22" output: "output" op_type: "Loop" attribute { name: "body" g { node { output: "30" op_type: "Constant" attribute { name: "value" t { dims: 1 data_type: 1 raw_data: "\000\000\200?" } type: TENSOR } } node { input: "input_2d_tensor" output: "31" op_type: "Shape" } node { input: "31" input: "3" output: "32" op_type: "Gather" attribute { name: "axis" i: 0 type: INT } } node { input: "32" output: "33" op_type: "Unsqueeze" attribute { name: "axes" ints: 0 type: INTS } } node { input: "33" output: "34" op_type: "Concat" attribute { name: "axis" i: 0 type: INT } } node { input: "30" input: "34" output: "35" op_type: "Tile" } node { input: "35" output: "36" op_type: "Unsqueeze" attribute { name: "axes" ints: 0 type: INTS } } node { input: "36" output: "37" op_type: "Cast" attribute { name: "to" i: 7 type: INT } } node { input: "37" input: "row.1" output: "38" op_type: "Mul" } node { input: "substitution_tensor" output: "39" op_type: "Unsqueeze" attribute { name: "axes" ints: 0 type: INTS } } node { input: "output_tensor.6" input: "38" input: "39" output: "40" op_type: "Scatter" attribute { name: "axis" i: 0 type: INT } } node { input: "2" output: "41" op_type: "Cast" attribute { name: "to" i: 9 type: INT } } name: "torch-jit-export1" input { name: "row.1" type { tensor_type { elem_type: 7 shape { } } } } input { name: "cond" type { tensor_type { elem_type: 9 } } } input { name: "output_tensor.6" } output { name: "41" } output { name: "40" } } type: GRAPH }

What is weird is that onnxruntime is able to load the model and run inference without any problem (it even outputs the right result and seems to capture the dynamic control flow). However, the onnx library does not seem to agree.

Expected behavior

The exported ONNX file should pass the onnx checker, and caffe2 should be able to emit the same output as onnxruntime.
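
Concretely, something like the following parity check should pass (a sketch continuing the repro session above; the inputs are passed to run() as a tuple, as in the commented-out line in the repro):

import numpy as np

# Expected to succeed once caffe2 can load the model:
output_caffe = caffe_model.run((input_2d_tensor.numpy(), substitution_tensor.numpy()))
np.testing.assert_array_equal(output_caffe[0], output_onnx[0])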

Environment

PyTorch == 1.3.0 (on macOS, pre-built binary from the pytorch conda channel)
onnxruntime == 0.5.0
onnx == 1.5.0

Additional context

A for loop (in this case iterating over the rows of a matrix) combined with iteratively updating a tensor via scatter is, IMO, a frequent pattern in dynamic control flows such as custom attention modules.

Thank you!

cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof

@mruberry added the module: onnx and triaged labels Nov 12, 2019
@dashesy
Contributor
dashesy commented Nov 12, 2019

JIT-scripted modules currently have many issues with ONNX. You should try to find a way to export without jit.script, for example by finding a single scatter operation that does the same thing as that for loop. Usually you will need a combination of torch.expand (to create the proper shape), then torch.gather, and then torch.scatter; see the sketch below.
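
For this particular module the loop can be replaced by a single expand + scatter. The following is a minimal sketch along those lines (assuming the goal is simply to write substitution_tensor into every row of a zero matrix shaped like input_2d_tensor; VectorizedModifyTensor is a hypothetical name, not part of the original repro):

import torch
from torch import nn

class VectorizedModifyTensor(nn.Module):
    """Hypothetical loop-free variant of IterativelyModifyTensor."""
    def forward(self, input_2d_tensor, substitution_tensor):
        num_rows = input_2d_tensor.size(0)
        num_cols = input_2d_tensor.size(1)
        output_tensor = torch.zeros(num_rows, num_cols, dtype=torch.long)
        # index[i, j] = i, so scattering along dim 0 targets row i at every column j
        index = torch.arange(num_rows).unsqueeze(1).expand(num_rows, num_cols)
        # src[i, j] = substitution_tensor[j], i.e. the same row repeated num_rows times
        src = substitution_tensor.unsqueeze(0).expand(num_rows, num_cols)
        # output_tensor[index[i, j], j] = src[i, j] -> every row becomes substitution_tensor
        return output_tensor.scatter(0, index, src)

Since there is no Python loop left, the exported graph should not contain a Loop node at all, which sidesteps the subgraph that the checker complains about.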

@BowenBao
Collaborator
BowenBao commented Jan 6, 2020

hi @julioasotodv, thanks for reporting this issue.
I could not repro the checker failure locally. It seems to have already been addressed by onnx/onnx#2009. Could you update onnx to 1.6 and try again?
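
A minimal recheck after upgrading (a sketch, assuming the model was exported to modifier.onnx as in the repro above):

import onnx

model = onnx.load("modifier.onnx")
onnx.checker.check_model(model)  # raises a ValidationError if the graph is still invalid
print(onnx.helper.printable_graph(model.graph))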

@julioasotodv
Author

@BowenBao sure! Let me try it and see what happens...

@julioasotodv
Author

@BowenBao now there is no error; when this line runs:

caffe_model = backend.prepare(model)

The Python process segfaults... I don't know whether this is due to my environment (I just updated PyTorch to 1.3.1 from the conda channel and onnx to 1.6.0 from pip), or whether the previous exception has now turned into a segmentation fault.
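
One way to narrow this down (a sketch, relying on the no_check_UNSAFE kwarg visible in the caffe2 traceback above) is to run the checker separately from the backend, so a crash in prepare() clearly points at caffe2 rather than at the exported graph:

import onnx
import caffe2.python.onnx.backend as backend

model = onnx.load("modifier.onnx")
onnx.checker.check_model(model)  # validate first, outside of prepare()
# skip the checker that prepare() would run internally; a crash here is caffe2's own
rep = backend.prepare(model, no_check_UNSAFE=True)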

@garymm
Collaborator
garymm commented Jan 28, 2022

With the latest version of PyTorch built from the onnx_ms_1 branch, the exported graph seems to replace input_2d_tensor with constants, so the exported graph has only 1 input rather than 2.

TorchScript graph:

graph(%self : __torch__.IterativelyModifyTensor,
      %input_2d_tensor.1 : Tensor,
      %substitution_tensor.1 : Tensor):
  %35 : bool = prim::Constant[value=1]() # /tmp/ipykernel_260140/1389552357.py:25:8
  %26 : bool = prim::Constant[value=0]()
  %25 : int = prim::Constant[value=4]() # /tmp/ipykernel_260140/1389552357.py:17:25
  %6 : NoneType = prim::Constant()
  %3 : int = prim::Constant[value=1]() # /tmp/ipykernel_260140/1389552357.py:18:32
  %12 : int = prim::Constant[value=0]() # /tmp/ipykernel_260140/1389552357.py:19:54
  %5 : int[] = prim::ListConstruct(%3)
  %10 : Tensor = aten::zeros(%5, %6, %6, %6, %6) # /tmp/ipykernel_260140/1389552357.py:17:25
  %13 : int = aten::size(%input_2d_tensor.1, %12) # /tmp/ipykernel_260140/1389552357.py:19:33
  %15 : int = aten::size(%input_2d_tensor.1, %3) # /tmp/ipykernel_260140/1389552357.py:19:59
  %16 : int = aten::mul(%13, %15) # /tmp/ipykernel_260140/1389552357.py:19:33
  %17 : int[] = prim::ListConstruct(%16)
  %18 : Tensor = aten::repeat(%10, %17) # /tmp/ipykernel_260140/1389552357.py:17:25
  %20 : int = aten::size(%input_2d_tensor.1, %12) # /tmp/ipykernel_260140/1389552357.py:20:31
  %22 : int = aten::size(%input_2d_tensor.1, %3) # /tmp/ipykernel_260140/1389552357.py:20:56
  %23 : int[] = prim::ListConstruct(%20, %22)
  %24 : Tensor = aten::view(%18, %23) # /tmp/ipykernel_260140/1389552357.py:17:25
  %output_tensor.1 : Tensor = aten::to(%24, %25, %26, %26, %6) # /tmp/ipykernel_260140/1389552357.py:17:25
  %num_rows.1 : int = aten::size(%input_2d_tensor.1, %12) # /tmp/ipykernel_260140/1389552357.py:23:19
  %output_tensor : Tensor = prim::Loop(%num_rows.1, %35, %output_tensor.1) # /tmp/ipykernel_260140/1389552357.py:25:8
    block0(%row.1 : int, %output_tensor.11 : Tensor):
      %38 : int[] = prim::ListConstruct(%3)
      %43 : Tensor = aten::ones(%38, %6, %6, %6, %6) # /tmp/ipykernel_260140/1389552357.py:27:20
      %45 : int = aten::size(%input_2d_tensor.1, %3) # /tmp/ipykernel_260140/1389552357.py:27:41
      %46 : int[] = prim::ListConstruct(%45)
      %47 : Tensor = aten::repeat(%43, %46) # /tmp/ipykernel_260140/1389552357.py:27:20
      %48 : Tensor = aten::unsqueeze(%47, %12) # /tmp/ipykernel_260140/1389552357.py:27:20
      %53 : Tensor = aten::to(%48, %25, %26, %26, %6) # /tmp/ipykernel_260140/1389552357.py:27:20
      %index.1 : Tensor = aten::mul(%53, %row.1) # /tmp/ipykernel_260140/1389552357.py:27:20
      %src.1 : Tensor = aten::unsqueeze(%substitution_tensor.1, %12) # /tmp/ipykernel_260140/1389552357.py:28:18
      %output_tensor.5 : Tensor = aten::scatter(%output_tensor.11, %12, %index.1, %src.1) # /tmp/ipykernel_260140/1389552357.py:29:28
      -> (%35, %output_tensor.5)
  return (%output_tensor)

ONNX graph (from export(verbose=True)):

graph(%substitution_tensor : Long(3, strides=[1], requires_grad=0, device=cpu),
      %42 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %43 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %46 : Long(2, strides=[1], requires_grad=0, device=cpu)):
  %2 : Float(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %3 : Float(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %4 : Bool(device=cpu) = onnx::Constant[value={1}]()
  %5 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={4}]() # /tmp/ipykernel_260140/1389552357.py:19:33
  %6 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={3}]() # /tmp/ipykernel_260140/1389552357.py:19:59
  %13 : Long(1, strides=[1], device=cpu) = onnx::ConstantOfShape[value={1}](%43) # /tmp/ipykernel_260140/1389552357.py:17:25
  %14 : Float(1, device=cpu) = onnx::Expand(%3, %13) # /tmp/ipykernel_260140/1389552357.py:17:25
  %15 : Float(12, device=cpu) = onnx::Tile(%14, %42) # /tmp/ipykernel_260140/1389552357.py:17:25
  %19 : Float(4, 3, device=cpu) = onnx::Reshape(%15, %46) # /tmp/ipykernel_260140/1389552357.py:17:25
  %output_tensor : Long(4, 3, strides=[3, 1], device=cpu) = onnx::Cast[to=7](%19) # /tmp/ipykernel_260140/1389552357.py:17:25
  %output : Long(4, 3, strides=[3, 1], device=cpu) = onnx::Loop(%5, %4, %output_tensor) # /tmp/ipykernel_260140/1389552357.py:25:8
    block0(%row.1 : Long(requires_grad=0, device=cpu), %cond : Bool(device=cpu), %output_tensor.11 : Long(4, 3, strides=[3, 1], device=cpu)):
      %25 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[axes=[0]](%6)
      %26 : Long(1, strides=[1], device=cpu) = onnx::Concat[axis=0](%25)
      %27 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[axes=[0]](%6)
      %28 : Long(1, strides=[1], device=cpu) = onnx::Concat[axis=0](%27) # /tmp/ipykernel_260140/1389552357.py:27:20
      %29 : Long(1, strides=[1], device=cpu) = onnx::Shape(%26)
      %30 : Long(1, strides=[1], device=cpu) = onnx::ConstantOfShape[value={1}](%29)
      %31 : Float(1, device=cpu) = onnx::Expand(%2, %30)
      %32 : Float(3, device=cpu) = onnx::Tile(%31, %28) # /tmp/ipykernel_260140/1389552357.py:27:20
      %33 : Float(1, 3, strides=[3, 1], device=cpu) = onnx::Unsqueeze[axes=[0]](%32) # /tmp/ipykernel_260140/1389552357.py:27:20
      %34 : Long(1, 3, strides=[3, 1], device=cpu) = onnx::Cast[to=7](%33) # /tmp/ipykernel_260140/1389552357.py:27:20
      %index : Long(*, *, strides=[3, 1], device=cpu) = onnx::Mul(%34, %row.1) # /tmp/ipykernel_260140/1389552357.py:27:20
      %src : Long(1, 3, strides=[3, 1], device=cpu) = onnx::Unsqueeze[axes=[0]](%substitution_tensor) # /tmp/ipykernel_260140/1389552357.py:28:18
      %output_tensor.3 : Long(*, *, strides=[3, 1], device=cpu) = onnx::Scatter[axis=0](%output_tensor.11, %index, %src) # /tmp/ipykernel_260140/1389552357.py:29:28
      %38 : Bool(device=cpu) = onnx::Identity(%4)
      -> (%38, %output_tensor.3)
  return (%output)

Tracked internally at Microsoft by https://msdata.visualstudio.com/Vienna/_workitems/edit/1600858
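
As a possible workaround sketch for the constant-folded shapes (my assumption, not verified against the onnx_ms_1 build): declaring the axes as dynamic at export time should keep input_2d_tensor as a real graph input instead of baking its shape into constants. modifier_dynamic.onnx is a hypothetical output path:

torch.onnx.export(
    my_modifier_scripted,
    (input_2d_tensor, substitution_tensor),
    "modifier_dynamic.onnx",
    input_names=["input_2d_tensor", "substitution_tensor"],
    output_names=["output"],
    dynamic_axes={
        "input_2d_tensor": {0: "rows", 1: "cols"},
        "substitution_tensor": {0: "cols"},
        "output": {0: "rows", 1: "cols"},
    },
)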

@garymm added the onnx-triaged label Jan 28, 2022
@abock added this to ONNX Jun 14, 2023
github-project-automation bot moved this to Inbox in ONNX Jun 14, 2023
@thiagocrepaldi
Collaborator

Closing this, as Caffe2 is being removed from the PyTorch repo (#125038).

github-project-automation bot moved this from Inbox to Done in ONNX May 2, 2024