[Regression][PT2.1][Dynamic] torch._dynamo.exc.TorchRuntimeError: Failed running call_method index_add · Issue #111203 · pytorch/pytorch · GitHub

Closed
jay746 opened this issue Oct 13, 2023 · 3 comments
Labels: module: dynamic shapes, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

jay746 commented Oct 13, 2023

🐛 Describe the bug

`Tensor.index_add` (called without an `out` tensor) regressed in the PT2.1 upgrade when compiled with `dynamic=True`. The same test worked before the upgrade to PT2.1.
Use the code below to reproduce the issue.

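```python
import torch


def fn(dim, index, source, alpha):
    # `inputs` is a module-level global defined under __main__ below.
    # With dynamic=True its sizes are traced symbolically (s1, s5, s2, s3 in the error).
    res = inputs.index_add(dim, index, source, alpha=alpha)
    return res


if __name__ == "__main__":
    inputs = torch.randn([4, 16, 32, 8])
    dim = 1
    alpha = 2  # a Python int; it shows up as the symbol s4 in the error below
    index = torch.tensor([0, 1, 2])
    source = torch.randn([4, 3, 32, 8])
    compl_fn = torch.compile(fn, dynamic=True, backend="eager")
    res = compl_fn(dim, index, source, alpha)
    print(res)
```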
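To check the compiled output against eager mode, it helps to spell out what `index_add` computes. The sketch below is a hypothetical pure-Python reference (`index_add_reference` is a made-up name, not a PyTorch API), restricted to 2-D nested lists: along `dim`, slice `i` of `source` is scaled by `alpha` and added into slice `index[i]` of a copy of the input, with repeated indices accumulating. Because `alpha` only scales `source`, pre-multiplying `source` by `alpha` and using the default `alpha=1` gives the same result. That suggests a possible, untested way to keep a symbolic `alpha` out of the `index_add` call under `dynamic=True`.

```python
from copy import deepcopy


def index_add_reference(x, dim, index, source, alpha=1):
    """Hypothetical reference for Tensor.index_add on 2-D nested lists.

    dim=0: out[index[i]][j] += alpha * source[i][j]
    dim=1: out[j][index[i]] += alpha * source[j][i]
    The input `x` is left unmodified (out-of-place, like Tensor.index_add).
    """
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim 0 or 1 of a 2-D input")
    out = deepcopy(x)
    for i, target in enumerate(index):
        if dim == 0:
            for j in range(len(out[0])):
                out[target][j] += alpha * source[i][j]
        else:
            for j in range(len(out)):
                out[j][target] += alpha * source[j][i]
    return out


if __name__ == "__main__":
    x = [[0, 0, 0, 0],
         [0, 0, 0, 0]]
    index = [0, 2]
    source = [[1, 2],
              [3, 4]]
    alpha = 2

    # Same call pattern as the repro: dim=1 with an integer alpha.
    with_alpha = index_add_reference(x, 1, index, source, alpha=alpha)

    # Same result with alpha folded into source and the default alpha=1
    # (the basis of a possible, untested workaround).
    scaled = [[alpha * v for v in row] for row in source]
    folded = index_add_reference(x, 1, index, scaled)

    print(with_alpha)  # [[2, 0, 4, 0], [6, 0, 8, 0]]
    assert with_alpha == folded
```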

Error logs

```
Traceback (most recent call last):
  File "debug_abs.py", line 16, in <module>
    res = compl_fn(dim, index, source, alpha)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1167, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/variables/misc.py", line 594, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/variables/tensor.py", line 655, in call_method
    return wrap_fx_proxy(
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 1187, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 1274, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1376, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1337, in get_fake_value
    return wrap_fake_exception(
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 916, in wrap_fake_exception
    return fn()
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1338, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1410, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/tmp/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1399, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/_prims_common/wrappers.py", line 229, in _fn
    result = fn(*args, **kwargs)
  File "/tmp/lib/python3.8/site-packages/torch/_decomp/decompositions.py", line 2054, in index_add
    return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
  File "/tmp/lib/python3.8/site-packages/torch/_decomp/decompositions.py", line 2075, in _index_add
    or utils.is_weakly_lesser_type(type(alpha), python_type),
  File "/tmp/lib/python3.8/site-packages/torch/_prims_common/__init__.py", line 1073, in is_weakly_lesser_type
    assert a in ordered_types
torch._dynamo.exc.TorchRuntimeError: Failed running call_method index_add(*(FakeTensor(..., size=(s1, s5, s2, s3)), 1, FakeTensor(..., size=(s0,), dtype=torch.int64), FakeTensor(..., size=(s1, s0, s2, s3))), **{'alpha': s4}):
```
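The last frames show where the trace breaks: `_index_add` checks whether `type(alpha)` can be safely cast to `python_type` (the tensor's Python scalar type, `float` for the `torch.randn` input in the repro), and `is_weakly_lesser_type` asserts that its argument is one of `bool`, `int`, `float`, `complex`. In the failing call `alpha` is the symbol `s4`, so `type(alpha)` is presumably `torch.SymInt`, which is not a subclass of `int` and therefore fails `assert a in ordered_types`. The pure-Python model below reproduces that failure mode. `SymIntStandIn`, `is_weakly_lesser_type_model`, `to_python_type` and `is_weakly_lesser_type_symbolic` are hypothetical names, and the symbol-aware variant is only one possible fix, not necessarily what was merged upstream.

```python
ORDERED_TYPES = (bool, int, float, complex)


class SymIntStandIn:
    """Hypothetical stand-in for torch.SymInt: a symbolic integer that,
    like torch.SymInt, is not a subclass of int."""

    def __init__(self, name):
        self.name = name


def is_weakly_lesser_type_model(a, b):
    """Simplified model of the check in the traceback: True if type `a` comes
    no later than type `b` in the order bool < int < float < complex."""
    assert a in ORDERED_TYPES  # fires when `a` is a symbolic type
    assert b in ORDERED_TYPES
    return ORDERED_TYPES.index(a) <= ORDERED_TYPES.index(b)


def to_python_type(t):
    """Hypothetical helper: map a symbolic scalar type to the Python type it stands for."""
    return int if t is SymIntStandIn else t


def is_weakly_lesser_type_symbolic(a, b):
    # One possible fix (an assumption): normalize symbolic types before the check.
    return is_weakly_lesser_type_model(to_python_type(a), to_python_type(b))


if __name__ == "__main__":
    alpha = SymIntStandIn("s4")

    # A plain int alpha against a float tensor passes the check.
    print(is_weakly_lesser_type_model(int, float))  # True

    # The symbolic alpha trips the assertion, as in the traceback.
    try:
        is_weakly_lesser_type_model(type(alpha), float)
    except AssertionError:
        print("AssertionError: type(alpha) is not in ORDERED_TYPES")

    # After normalization the symbolic alpha is treated like an int.
    print(is_weakly_lesser_type_symbolic(type(alpha), float))  # True
```

Normalizing inside the type check keeps callers such as `_index_add` unchanged; this thread does not say which approach the actual fix took.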

Minified repro

No response

Versions

[pip3] numpy==1.24.4
[pip3] torch==2.1.0
[pip3] torchaudio==2.0.1
[pip3] torchdata==0.6.1
[pip3] torchmetrics==1.2.0
[pip3] torchtext==0.15.2
[pip3] torchvision==0.15.1

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305

kjabon commented Oct 13, 2023

I've got a dumb question that I cannot for the life of me find the answer to anywhere in docs, wiki, contrib guides, etc. What is the oncall: pt2 label?

Contributor
ezyang commented Oct 13, 2023

We have an oncall rotation specifically for pt2 issues, so if you label an issue with that, it gets sent directly to those folks. Also, some people, like me, subscribe to all pt2 issues: #24422

Contributor
ezyang commented Oct 13, 2023

same as #111208
