Compiling RMSNorm Triton Kernel gives error #121526
Labels
module: user triton
related to ability to directly torch.compile triton kernels
oncall: pt2
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
I tried to torch.compile some code from the mamba_ssm repo and hit a compilation error. From the traceback, Dynamo traces into Triton's autotuner and into triton/runtime/jit.py, which calls tensor.is_pinned() on a FakeTensor, and that op is not implemented for fake tensors (NYI: aten.is_pinned.default).
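As a guess at a minimal trigger (my own sketch, not taken from the mamba code, and not verified on the nightly), simply calling .is_pinned() on a CUDA tensor inside a compiled region should hit the same NYI op:

import torch

def check_pinned(x):
    # Dynamo traces this into aten.is_pinned.default, which FakeTensor
    # reports as NYI (the same assertion as in the log below)
    return x.is_pinned()

compiled = torch.compile(check_pinned)
compiled(torch.randn(8192, 768, device="cuda"))  # shape mirrors the FakeTensor in the log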
Error logs
Traceback (most recent call last):
File "/llmlib/scripts/train_mosaic_bert.py", line 287, in
main(cfg)
File "/llmlib/scripts/train_mosaic_bert.py", line 274, in main
trainer.fit()
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 1972, in fit
self._train_loop()
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2153, in _train_loop
total_loss_dict = self._train_batch(use_grad_scaling)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2336, in _train_batch
optimizer.step(closure=lambda loss_dict=total_loss_dict, **kwargs: self._train_microbatches(
File "/usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py", line 75, in wrapper
return wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/optim/optimizer.py", line 385, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/optim/decoupled_weight_decay.py", line 278, in step
loss = closure()
^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2336, in
optimizer.step(closure=lambda loss_dict=total_loss_dict, **kwargs: self._train_microbatches(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2439, in _train_microbatches
microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 2503, in _train_microbatch
self.state.outputs = self.state.model(self.state.batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward
else self._run_ddp_forward(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/composer/models/huggingface.py", line 446, in forward
output = self.model(**batch) # type: ignore (thirdparty)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_caduceus.py", line 430, in forward
def forward(
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_caduceus.py", line 344, in forward
def forward(
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_caduceus.py", line 210, in forward
hidden_states, residual = layer(
^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_rcps.py", line 152, in forward
def forward(
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/yairschiff/caduceus_base/c699cf9dce9f94a2295a000c9635abd41d372779/modeling_rcps.py", line 104, in forward
def forward(self, x, residual=None, prenorm=False):
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 493, in forward
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 477, in rms_norm_fn
def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
File "/usr/lib/python3/dist-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 381, in forward
@staticmethod
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 123, in _layer_norm_fwd
def _layer_norm_fwd(
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 149, in resume_in__layer_norm_fwd
MAX_FUSED_SIZE = 65536 // x.element_size()
File "/usr/lib/python3/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in resume_in__layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 126, in run
def run(self, *args, **kwargs):
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
self.nargs = dict(zip(self.arg_names, args))
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
self.nargs = dict(zip(self.arg_names, args))
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
self.nargs = dict(zip(self.arg_names, args))
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 143, in resume_in_run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 143, in
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/triton/testing.py", line 102, in do_bench
fn()
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 106, in kernel_call
def kernel_call():
File "/usr/lib/python3/dist-packages/triton/runtime/autotuner.py", line 110, in resume_in_kernel_call
self.fn.run(
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 402, in run
def run(self, *args, **kwargs):
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 413, in resume_in_run
grid = get_special_arg("grid")
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 414, in resume_in_run
num_warps = get_special_arg("num_warps")
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 415, in resume_in_run
num_ctas = get_special_arg("num_ctas", 1)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 416, in resume_in_run
num_stages = get_special_arg("num_stages")
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run
bound_args = self.signature.bind(*args, **kwargs)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 426, in resume_in_run
bound_args = self.signature.bind(*args, **kwargs)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 427, in resume_in_run
bound_args.apply_defaults()
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run
assert len(bound_args.arguments) == len(self.params)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 429, in resume_in_run
assert len(bound_args.arguments) == len(self.params)
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 453, in resume_in_run
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 453, in
[self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/eval_frame.py", line 652, in catch_errors
return hijacked_callback(frame, cache_entry, hooks, frame_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
compiled_product = _compile(
^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/convert_frame.py", line 527, in transform
tracer.run()
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
super().run()
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 1802, in CALL
self.call_function(fn, args, kwargs)
File "/usr/lib/python3/dist-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/misc.py", line 660, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/tensor.py", line 756, in call_method
return wrap_fx_proxy(
^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/builder.py", line 1314, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/builder.py", line 1399, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1525, in get_fake_value
raise TorchRuntimeError(str(e)).with_traceback(e.traceback) from None
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1486, in get_fake_value
ret_val = wrap_fake_exception(
^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1027, in wrap_fake_exception
return fn()
^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1487, in
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1592, in run_node
raise RuntimeError(fn_str + str(e)).with_traceback(e.traceback) from e
File "/usr/lib/python3/dist-packages/torch/_dynamo/utils.py", line 1573, in run_node
return getattr(args[0], node.target)(*args[1:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_subclasses/fake_tensor.py", line 1392, in torch_dispatch
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_subclasses/fake_tensor.py", line 1649, in dispatch
op_impl_out = op_impl(self, func, args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/torch/_subclasses/fake_tensor.py", line 737, in nyi
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_method is_pinned((FakeTensor(..., device='cuda:5', size=(8192, 768), requires_grad=True),), **{}):
NYI: aten.is_pinned.default
from user code:
File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 241, in _pinned_memory_of
return arg.is_pinned()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Minified repro
Unminified, but you can replicate it by trying to torch.compile the mamba repo. The exact nn.Module that causes the issue can be found here:
https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/layernorm.py#L481
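A rough standalone sketch of the unminified repro (assumes mamba_ssm with its Triton kernels is installed and a CUDA device is available; the hidden size just mirrors the FakeTensor shape in the log, and I haven't minified this further):

import torch
from mamba_ssm.ops.triton.layernorm import RMSNorm  # the module linked above

norm = RMSNorm(768).to("cuda")
x = torch.randn(8192, 768, device="cuda", requires_grad=True)

compiled = torch.compile(norm)
out = compiled(x)  # tracing reaches triton/runtime/jit.py and fails on is_pinned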
Versions
The error was occurring on a recent nightly build.
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @oulgen @aakhundov