
[ROCm] [Inductor] Nightly torch.compile assert_size_stride AssertionError: wrong number of dimensions #137414


Closed
functionstackx opened this issue Oct 6, 2024 · 6 comments
Labels
module: inductor · module: rocm (AMD GPU support for PyTorch) · oncall: pt2 · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@functionstackx
Contributor
functionstackx commented Oct 6, 2024

🐛 Describe the bug

Hi @hliuca,

ROCm nightly performance has improved greatly since the F.linear fix, but unfortunately torch.compile does not work on ROCm even though it works on CUDA.

I am hitting an assert_size_stride failure in the ROCm Inductor backend. My guess is that the bug is in the CausalSelfAttention layer. I have attached a repro script below.

cc: @hongxiayang

Eager command (this works without crashing)

DISABLE_ADDMM_HIP_LT=0 python train.py

Compile command (this causes the error)

DISABLE_ADDMM_HIP_LT=0 python train.py --pt-compile=True

Error Trace

  File "/workspace/llm-train-bench/train.py", line 98, in train
    loss.backward()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_tensor.py", line 613, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2082, in backward
    out = call_compiled_backward()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2014, in call_compiled_backward
    out = call_func_at_runtime_with_args(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 629, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1562, in __call__
    return self.current_callable(inputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/utils.py", line 2002, in run
    return model(new_inputs)
  File "/tmp/torchinductor_root/w2/cw2ykhght4cwmzshdlwygh6t25c4cydylbclfhamin6rxvug7plt.py", line 1095, in call
    assert_size_stride(getitem_6, (8, 12, 1024), (12288, 1024, 1))
AssertionError: wrong number of dimensions
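
For context, assert_size_stride is the guard helper that Inductor emits into its generated wrapper code (imported there from torch._C._dynamo.guards). A minimal sketch of how it produces this exact error, using a hypothetical tensor whose rank differs from the expected (8, 12, 1024) shape (the trace does not show the rank that was actually returned):

import torch
from torch._C._dynamo.guards import assert_size_stride  # same helper the generated wrapper imports

# The guard first checks the rank, then each size and stride Inductor planned for;
# a tensor with a different number of dimensions fails before sizes are compared.
lse = torch.empty(8 * 12, 1024)  # hypothetical 2-D tensor where a 3-D one was expected
assert_size_stride(lse, (8, 12, 1024), (12288, 1024, 1))
# AssertionError: wrong number of dimensions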

Repro Script

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        self.d_head = d_embd // n_heads  # D
        self.attn_proj = nn.Linear(d_embd, 3*d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)
 
    def forward(self, x_BTE):
        qkv = self.attn_proj(x_BTE).split(x_BTE.size(-1), -1)
        split_attn_head = lambda z: z.unflatten(-1, [-1, self.d_head]).transpose(1, 2)
        q_BHTD, k_BHTD, v_BHTD = map(split_attn_head, qkv)
        o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        y_BTE = self.out_proj(o_BTE)
        return y_BTE

class GPTBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, 4*d_embd),
            nn.GELU(),
            nn.Linear(4*d_embd, d_embd)
        )

    def forward(self, x_BTE):
        x_BTE = x_BTE + self.attn(self.attn_norm(x_BTE))
        y_BTE = x_BTE + self.ffn(self.ffn_norm(x_BTE))
        return y_BTE

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV


def train(
    bsz: int = 8,
    n_workers: int = 8,
    n_steps: int = 128,
    pt_compile: bool = False,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(0)

    cfg_json = {
        "n_layers": 1,
        "n_heads": 12,
        "d_embd": 768,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }

    model = GPT(**cfg_json).to('cuda:0')
    if pt_compile:
        model = torch.compile(model)

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)

    model.train()

    for step_idx in range(100):
        input_BT = torch.randint(50304, [8, 1024], dtype=torch.int64).to('cuda:0')
        label_BT = torch.randint(50304, [8, 1024], dtype=torch.int64).to('cuda:0')

        with torch.amp.autocast('cuda', torch.bfloat16):
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        loss.backward()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        torch.cuda.synchronize()
        print(f"finish {step_idx} step")

if __name__ == "__main__":
    import fire
    fire.Fire(train)
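
Assuming the script above is saved as repro.py, it can be launched the same way as train.py, e.g. python repro.py --pt-compile=True to trigger the failure, or python repro.py to confirm the eager path works.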

Versions

ROCm

# pip list | grep torch
pytorch-triton-rocm     3.1.0+cf34004b8a
torch                   2.6.0.dev20241006+rocm6.2
torchvision             0.18.0a0+68ba7ec

Nvidia

Versions where this works on Nvidia include the 24.07 NGC container, 24.08, and the PyPI nightly.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

@pytorch-bot pytorch-bot bot added the module: rocm label Oct 6, 2024
@eellison eellison added the triaged and module: inductor labels Oct 8, 2024
@jataylo
Collaborator
jataylo commented Oct 9, 2024

Hi @OrenLeung

I can reproduce this. As a workaround, if we use the math backend for the scaled_dot_product_attention call then the code works, e.g.:

with torch.backends.cuda.sdp_kernel(
    enable_math=True,
    enable_flash=False,
    enable_mem_efficient=False
):
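
A minimal sketch of applying this in the repro's training step, assuming the context manager is placed around the compiled forward and backward (variable names come from the repro script above):

with torch.backends.cuda.sdp_kernel(
    enable_math=True,
    enable_flash=False,
    enable_mem_efficient=False
):
    with torch.amp.autocast('cuda', torch.bfloat16):
        logits_BTV = model(input_BT)
        loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
    loss.backward()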

It seems we are hitting some issues with flash attention under torch.compile; if I run the generated code independently, we see this:

  File "/tmp/torchinductor_root/jm/cjmyhjyelpqbkmepriz55zdgziqhvzfjuo6rcplxaz4ik3z3pssj.py", line 1218, in call
    buf34 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(reinterpret_tensor(buf29, (8, 12, 1024, 64), (786432, 64, 768, 1), 0), permute_1, permute_2, permute_3, getitem_5, getitem_6, None, None, 1024, 1024, 0.0, True, getitem_11, getitem_12, scale=0.125)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
RuntimeError: L's rank should be 2 but is 3

@jataylo
Collaborator
jataylo commented Oct 9, 2024

Hi @OrenLeung, I've confirmed ROCm#1428 fixes the issue.

I'll ping here again once we have upstreamed this.

@functionstackx
Contributor Author

Hi @OrenLeung, I've confirmed ROCm#1428 fixes the issue.

I'll ping here again once we have upstreamed this.

Thanks for looking into this!

Will look into using the math backend for SDPA, or maybe just explicitly graph-breaking on flash SDPA.

@functionstackx
Contributor Author
functionstackx commented Oct 15, 2024

Workaround I am doing until #137717 gets into the nightly:

def disable_torch_compile_if_amd(func):
    # Define a wrapper that applies the torch.compiler.disable decorator conditionally
    if torch.cuda.is_available() and "MI300X" in torch.cuda.get_device_name():
        return torch.compiler.disable()(func)
    else:
        return func

@disable_torch_compile_if_amd
def scaled_dot_product_attention_wrapper(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True):
    # with torch.nn.attention.sdpa_kernel(
    #     enable_math=True,
    #     enable_flash=False,
    #     enable_mem_efficient=False
    # ):
    o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=dropout_p, is_causal=is_causal)
    return o_BHTD
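
Usage in the repro's CausalSelfAttention.forward is then a one-line swap (a sketch; scaled_dot_product_attention_wrapper is the helper defined above):

# inside CausalSelfAttention.forward, replacing the direct F.scaled_dot_product_attention call
o_BHTD = scaled_dot_product_attention_wrapper(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)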

pytorchmergebot pushed a commit that referenced this issue Oct 17, 2024
The logsumexp tensor was considered for internal use only but apparently exposed to unit tests and inductors.

The stream should be selected after picking the current device. Otherwise the code is checking the default device's architecture.

Fixes #131316 #137414

Pull Request resolved: #137717
Approved by: https://github.com/drisspg

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
@functionstackx
Contributor Author

Hi @hliuca @jataylo @xinyazhang,

I can confirm this fixed the issue in my internal codebase. Thank you for the fix!

Closing this issue as fixed by PR #137717.

@hliuca
hliuca commented Oct 21, 2024

Thank you @OrenLeung

jataylo added a commit to jataylo/pytorch that referenced this issue Nov 13, 2024
The logsumexp tensor was considered for internal use only but apparently exposed to unit tests and inductors.

The stream should be selected after picking the current device. Otherwise the code is checking the default device's architecture.

Fixes pytorch#131316 pytorch#137414

Pull Request resolved: pytorch#137717
Approved by: https://github.com/drisspg

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
(cherry picked from commit 770fcaf)
pruthvistony pushed a commit to ROCm/pytorch that referenced this issue Nov 13, 2024
…torch#137717) (#1695)

The logsumexp tensor was considered for internal use only but apparently
exposed to unit tests and inductors.

The stream should be selected after picking the current device.
Otherwise the code is checking the default device's architecture.

Fixes pytorch#131316 pytorch#137414

Pull Request resolved: pytorch#137717
Approved by: https://github.com/drisspg

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
(cherry picked from commit 770fcaf)

Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>