Compiling RMSNorm Triton Kernel gives error · Issue #121526 · pytorch/pytorch · GitHub

Compiling RMSNorm Triton Kernel gives error #121526


Open
Skylion007 opened this issue Mar 8, 2024 · 2 comments
Labels
module: user triton (related to ability to directly torch.compile triton kernels), oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@Skylion007 (Collaborator) commented Mar 8, 2024

🐛 Describe the bug

I tried to compile some code from the mamba_ssm repo and hit a compilation error.

Error logs

Traceback (most recent call last):
File "/llmlib/scripts/train_mosaic_bert.py", line 287, in
main(cfg)
File "/llmlib/scripts/train_mosaic_bert.py", line 274, in main
trainer.fit()
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 1972, in fit
self._train_loop()
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2153, in _train_loop
total_loss_dict = self._train_batch(use_grad_scaling)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2336, in _train_batch
optimizer.step(closure=lambda loss_dict=total_loss_dict, **kwargs: self._train_microbatches(
File "/usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py", line 75, in wrapper
return wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/optim/optimizer.py", line 385, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/optim/decoupled_weight_decay.py", line 278, in step
loss = closure()
^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2336, in
optimizer.step(closure=lambda loss_dict=total_loss_dict, **kwargs: self._train_microbatches(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2439, in _train_microbatches
microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2503, in _train_microbatch
self.state.outputs = self.state.model(self.state.batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward
else self._run_ddp_forward(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/models/huggingface.py", line 446, in forward
output = self.model(**batch) # type: ignore (thirdparty)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_caduceus.py", line 430, in forward
def forward(
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_caduceus.py", line 344, in forward
def forward(
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_caduceus.py", line 210, in forward
hidden_states, residual = layer(
^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_rcps.py", line 152, in forward
def forward(
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_rcps.py", line 104, in forward
def forward(self, x, residual=None, prenorm=False):
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 493, in forward
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 477, in rms_norm_fn
def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
File "/usr/lib/python3/dist-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 381, in forward
@staticmethod
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 123, in _layer_norm_fwd
def _layer_norm_fwd(
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 149, in resume_in__layer_norm_fwd
MAX_FUSED_SIZE = 65536 // x.element_size()
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in resume_in__layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 126, in run
def run(self, *args, **kwargs):
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
self.nargs = dict(zip(self.arg_names, args))
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
self.nargs = dict(zip(self.arg_names, args))
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
self.nargs = dict(zip(self.arg_names, args))
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 143, in resume_in_run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 143, in
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/triton/testing.py", line 102, in do_bench
fn()
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 106, in kernel_call
def kernel_call():
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 110, in resume_in_kernel_call
self.fn.run(
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 402, in run
def run(self, *args, **kwargs):
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 413, in resume_in_run
grid = get_special_arg("grid")
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 414, in resume_in_run
num_warps = get_special_arg("num_warps")
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 415, in resume_in_run
num_ctas = get_special_arg("num_ctas", 1)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 416, in resume_in_run
num_stages = get_special_arg("num_stages")
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run
bound_args = self.signature.bind(*args, **kwargs)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run
bound_args = self.signature.bind(*args, **kwargs)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 427, in resume_in_run
bound_args.apply_defaults()
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run
assert len(bound_args.arguments) == len(self.params)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run
assert len(bound_args.arguments) == len(self.params)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 453, in resume_in_run
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 453, in
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/eval_frame.py", line 652, in catch_errors
return hijacked_callback(frame, cache_entry, hooks, frame_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
compiled_product = _compile(
^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 527, in transform
tracer.run()
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
super().run()
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 1802, in CALL
self.call_function(fn, args, kwargs)
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/misc.py", line 660, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/tensor.py", line 756, in call_method
return wrap_fx_proxy(
^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/builder.py", line 1314, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/builder.py", line 1399, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1525, in get_fake_value
raise TorchRuntimeError(str(e)).with_traceback(e.traceback) from None
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1486, in get_fake_value
ret_val = wrap_fake_exception(
^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1027, in wrap_fake_exception
return fn()
^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1487, in
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1592, in run_node
raise RuntimeError(fn_str + str(e)).with_traceback(e.traceback) from e
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1573, in run_node
return getattr(args[0], node.target)(*args[1:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_subclasses/fake_tensor.py", line 1392, in torch_dispatch
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_subclasses/fake_tensor.py", line 1649, in dispatch
op_impl_out = op_impl(self, func, args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_subclasses/fake_tensor.py", line 737, in nyi
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_method is_pinned(
(FakeTensor(..., device='cuda:5', size=(8192, 768), requires_grad=True),), **{}):
NYI: aten.is_pinned.default

from user code:
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 241, in _pinned_memory_of
return arg.is_pinned()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
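
As a side note, here is a minimal sketch (not from the original report; it assumes only stock PyTorch on a CUDA machine) that appears to exercise the same FakeTensor path. Calling Tensor.is_pinned() inside a compiled region is exactly what the Triton launcher's _pinned_memory_of helper does for every kernel argument:

import torch

def uses_is_pinned(x):
    # Dynamo traces this as call_method is_pinned on a FakeTensor, which
    # dispatches to the NYI aten.is_pinned.default path shown above.
    if x.is_pinned():
        return x + 1
    return x - 1

compiled = torch.compile(uses_is_pinned)
compiled(torch.randn(8192, 768, device="cuda", requires_grad=True))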

Minified repro

Unminified, but you can replicate it by trying to torch.compile the mamba repo. The exact nn.Module that causes the issue can be found here:
https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/layernorm.py#L481
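
A rough repro sketch (assuming the RMSNorm module from the mamba_ssm file in the traceback is importable as below; the hidden size and batch shape are illustrative):

import torch
from mamba_ssm.ops.triton.layernorm import RMSNorm

norm = torch.compile(RMSNorm(768).cuda())
x = torch.randn(8192, 768, device="cuda", requires_grad=True)
# The forward dispatches to the Triton kernel via rms_norm_fn and fails
# while Dynamo traces the launcher's is_pinned() check on each argument.
norm(x)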

Versions

The error occurs on a recent nightly.

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @oulgen @aakhundov

@oulgen added the module: user triton (related to ability to directly torch.compile triton kernels) label Mar 8, 2024
@oulgen (Contributor) commented Mar 8, 2024

It looks like the kernel is being executed using FakeTensor inputs, will take a look

@williamwen42 added the triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Mar 8, 2024
@s-rog commented May 22, 2024

https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py#L926

from flash_attn.ops.triton.layer_norm import RMSNorm

Not sure if it's the same in mamba, but RMSNorm can be imported and tested through flash_attn. I ran into an error as well when trying to compile an nn.Sequential with RMSNorm from flash_attn.
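
For reference, a sketch of the setup described above (the layer sizes are made up; the import path matches the flash_attn link):

import torch
from flash_attn.ops.triton.layer_norm import RMSNorm

model = torch.nn.Sequential(
    torch.nn.Linear(768, 768),
    RMSNorm(768),
).cuda()
compiled = torch.compile(model)
# Errors similarly once the Triton RMSNorm kernel is reached.
compiled(torch.randn(16, 768, device="cuda"))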
