[inductor] build failed with gcc8.3 #130815
Comments
PyTorch requires GCC >= 9.4, so it's better to install GCC 9 (9.4 or later) and build with it.
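As a hypothetical illustration of the GCC >= 9.4 requirement, a translation unit can refuse to compile on an older GCC with a readable message, instead of failing with the parse errors in this issue's log. The helper `version_at_least` is an assumption for this sketch, not PyTorch code; `__GNUC__` and `__GNUC_MINOR__` are GCC's predefined version macros.

```cpp
// Hypothetical helper: does version major.minor satisfy req_major.req_minor?
constexpr bool version_at_least(int major, int minor, int req_major, int req_minor) {
    return major > req_major || (major == req_major && minor >= req_minor);
}

// Clang also defines __GNUC__ (reporting an old GCC version), so the guard
// is applied to real GCC only.
#if defined(__GNUC__) && !defined(__clang__)
static_assert(version_at_least(__GNUC__, __GNUC_MINOR__, 9, 4),
              "GCC >= 9.4 is required; older GCC versions such as 8.3 are unsupported");
#endif
```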
It's almost impossible to lower the required GCC version, because doing so would affect too much.
Some compatibility code already exists here: pytorch/torch/_inductor/codecache.py, line 2183 (at commit fedae41).
A similar issue was recently reported here: https://discuss.pytorch.org/t/torch-compile-fails-with-c-compile-error-expected-identifier-before-token [Edit] I see there is already a fix, so let's just merge it.
Replace [[unlikely]] with unlikely(x) (#130816)
Do not use `[[unlikely]]`, as it is a C++20 language feature; see https://en.cppreference.com/w/cpp/language/attributes/likely
Fixes #130815
Pull Request resolved: #130816
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/malfet
(cherry picked from commit 32f9a80)
Co-authored-by: Danielmic <30855238+Danielmic@users.noreply.github.com>
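The commit replaces the attribute with an `unlikely(x)` macro, but its definition is not shown in this thread. The sketch below assumes the common GCC/Clang `__builtin_expect` idiom; `checked_int` is a hypothetical stand-in for the generated `parse_arg` check that failed in the log. The function is `constexpr` only so the check can be exercised at compile time.

```cpp
#include <stdexcept>

// Assumed definition: a branch-prediction hint via __builtin_expect on
// GCC/Clang; elsewhere the macro is just the plain condition.
#if defined(__GNUC__) || defined(__clang__)
#define unlikely(x) __builtin_expect(!!(x), 0)
#else
#define unlikely(x) (x)
#endif

// Hypothetical check modeled on the failing
// `[[unlikely]] throw std::runtime_error("expected int arg");` line.
// It uses only features that predate C++20.
constexpr long checked_int(bool is_int, long value) {
    if (unlikely(!is_int)) {
        throw std::runtime_error("expected int arg");
    }
    return value;
}
```

The hint only influences code layout: where `__builtin_expect` is unavailable the macro degrades to the bare condition, so the behavior of the check is the same either way.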
Validated with PyTorch 2.4.1 RC.
🐛 Describe the bug
Building on Debian 10 with its default compiler, GCC 8.3, fails. It looks like this GCC version does not support the [[likely]] and [[unlikely]] attributes.
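A different approach from the merged fix would keep the attribute only where the compiler reports support for it, using the `__has_cpp_attribute` feature-test operator. This is a sketch under that assumption; `MAYBE_UNLIKELY` and `require_arg_count` are hypothetical names, and the check mirrors the "requires 3 args" error in the log.

```cpp
#include <stdexcept>

// Expand to [[unlikely]] only if the compiler says it supports the
// attribute; otherwise expand to nothing, so older compilers never see it.
#if defined(__has_cpp_attribute)
#  if __has_cpp_attribute(unlikely)
#    define MAYBE_UNLIKELY [[unlikely]]
#  endif
#endif
#ifndef MAYBE_UNLIKELY
#  define MAYBE_UNLIKELY
#endif

// Hypothetical argument-count check; the attribute (if any) marks the
// error branch as unlikely.
constexpr long require_arg_count(long n) {
    if (n != 3) MAYBE_UNLIKELY {
        throw std::invalid_argument("requires 3 args");
    }
    return n;
}
```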
Error logs
Traceback (most recent call last):
File "//test.py", line 7, in <module>
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
result = inner_convert(
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
File "/opt/py3.10/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
out_code = transform_code_object(code, transform)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
tracer.run()
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
super().run()
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2268, in RETURN_VALUE
self.output.compile_subgraph(
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 971, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/opt/py3.10/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1168, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1241, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1222, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/__init__.py", line 1729, in __call__
return compile_fx(model, inputs, config_patches=self.config)
File "/opt/py3.10/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1330, in compile_fx
return aot_autograd(
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 58, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 903, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 628, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 443, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 648, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 119, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1257, in fw_compiler_base
return inner_compile(
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/debug.py", line 304, in inner
return fn(*args, **kwargs)
File "/opt/py3.10/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/py3.10/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 438, in compile_fx_inner
compiled_graph = fx_codegen_and_compile(
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 714, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1307, in compile_to_fn
return self.compile_to_module().call
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1254, in compile_to_module
mod = PyCodeCache.load_by_key_path(
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2172, in load_by_key_path
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_root/mm/cmmtcbmmfakm37djocu6jtpak3f6yuldla55frjautjluv52xvyz.py", line 59, in <module>
async_compile.wait(globals())
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2727, in wait
scope[key] = result.result()
File "/opt/py3.10/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/opt/py3.10/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/opt/py3.10/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2086, in load_pybinding
result = cls.load(source_code + suffix, cuda)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1948, in load
compile_file(input_path, output_path, cmd)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/torch/venv3/pytorch/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1888, in compile_file
raise exc.CppCompileError(cmd, output) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CppCompileError: C++ compile error
Command:
g++ /tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp -shared -fPIC -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -D_GLIBCXX_USE_CXX11_ABI=0 -I/torch/venv3/pytorch/lib/python3.10/site-packages/torch/include -I/torch/venv3/pytorch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/torch/venv3/pytorch/lib/python3.10/site-packages/torch/include/TH -I/torch/venv3/pytorch/lib/python3.10/site-packages/torch/include/THC -I/opt/py3.10/include/python3.10 -L/torch/venv3/pytorch/lib/python3.10/site-packages/torch/lib -L/opt/py3.10/lib -L/torch/venv3/pytorch/lib/python3.10/site-packages/torch/lib -ltorch -ltorch_cpu -lgomp -ltorch_python -lc10 -mavx2 -mfma -DCPU_CAPABILITY_AVX2 -O3 -DNDEBUG -ffast-math -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -march=native -fopenmp -D C10_USING_CUSTOM_GENERATED_MACROS -o /tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.so
Output:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp: In function ‘T parse_arg(PyObject*, size_t) [with T = long int; PyObject = _object; size_t = long unsigned int]’:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:59:10: error: expected identifier before ‘[’ token
[[unlikely]] throw std::runtime_error("expected int arg");
^
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp: In lambda function:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:59:22: error: expected ‘{’ before ‘throw’
[[unlikely]] throw std::runtime_error("expected int arg");
^~~~~
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp: In function ‘T parse_arg(PyObject*, size_t) [with T = long int; PyObject = _object; size_t = long unsigned int]’:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:59:21: error: expected ‘;’ before ‘throw’
[[unlikely]] throw std::runtime_error("expected int arg");
^~~~~~
;
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp: In function ‘PyObject* kernel_py(PyObject*, PyObject*)’:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:68:14: error: expected identifier before ‘[’ token
[[unlikely]] throw std::runtime_error("tuple args required");
^
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp: In lambda function:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:68:26: error: expected ‘{’ before ‘throw’
[[unlikely]] throw std::runtime_error("tuple args required");
^~~~~
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp: In function ‘PyObject* kernel_py(PyObject*, PyObject*)’:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:68:25: error: expected ‘;’ before ‘throw’
[[unlikely]] throw std::runtime_error("tuple args required");
^~~~~~
;
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:70:14: error: expected identifier before ‘[’ token
[[unlikely]] throw std::runtime_error("requires 3 args");
^
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp: In lambda function:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:70:26: error: expected ‘{’ before ‘throw’
[[unlikely]] throw std::runtime_error("requires 3 args");
^~~~~
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp: In function ‘PyObject* kernel_py(PyObject*, PyObject*)’:
/tmp/torchinductor_root/i2/ci274xaqjfbtggjldthcbhqysu7zyadxcxalq56dxuyllfyilbfl.cpp:70:25: error: expected ‘;’ before ‘throw’
[[unlikely]] throw std::runtime_error("requires 3 args");
^~~~~~
;
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
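The compile command in the log passes `-std=c++17`, while `[[likely]]` and `[[unlikely]]` are C++20 attributes. A translation unit can check which standard it is being compiled against through the predefined `__cplusplus` macro (201703L for C++17, 202002L for C++20). `standard_year` is a hypothetical helper for this sketch; note that MSVC reports the real value only with `/Zc:__cplusplus`.

```cpp
// Map a __cplusplus value to the standard it denotes.
// 11 stands for "C++11 or earlier".
constexpr int standard_year(long v) {
    return v >= 202002L ? 20
         : v >= 201703L ? 17
         : v >= 201402L ? 14
         : 11;
}

// True only when this translation unit is compiled as C++20 or later,
// the first standard that defines [[likely]] and [[unlikely]].
constexpr bool kIsCxx20OrLater = standard_year(__cplusplus) >= 20;
```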
Minified repro
Versions
Debian 10 with its default compiler, GCC 8.3
cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire