
[inductor] torch.slice_scatter throws AssertionError when meeting internal float32 #147842

Open
shaoyuyoung opened this issue Feb 25, 2025 · 9 comments · May be fixed by #149814 or #151911
Labels
good first issue module: inductor oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@shaoyuyoung
Contributor
shaoyuyoung commented Feb 25, 2025

🐛 Describe the bug

Description: when an internal float32 tensor is involved (it's y in my case), eager passes the check and returns 0, while inductor throws an AssertionError.
Device: reproduces on both the Triton and CPP backends.

import torch
from torch._inductor import config

config.fallback_random = True
torch.set_grad_enabled(False)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        y = torch.Tensor([0])  # y dtype: torch.float32
        x = torch.slice_scatter(y, x, 0)
        return x


model = Model()

x = torch.Tensor([0]).to(torch.int64)  # int64 src scattered into a float32 base

inputs = [x]


def run_test(model, inputs, backend):
    model.eval()
    torch.manual_seed(0)
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    try:
        c_output = model(*inputs)
        print(c_output)
    except Exception as e:
        print(e)


run_test(model, inputs, 'eager')
run_test(model, inputs, 'inductor')

Error logs

tensor([0.])
LoweringException: AssertionError: 
  target: aten.slice_scatter.default
  args[0]: TensorBox(StorageBox(
    Pointwise(
      'cpu',
      torch.float32,
      def inner_fn(index):
          _ = index
          tmp0 = ops.constant(0.0, torch.float32)
          return tmp0
      ,
      ranges=[1],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    )
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.int64, size=[1], stride=[1]))
  ))

Versions

nightly 20250225

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

@Ajay-26
Ajay-26 commented Feb 28, 2025

Hi!
Correct me if I'm wrong, but shouldn't x have dtype torch.float32? It works for me in that case. The AssertionError looks like it is raised because the two tensors have different dtypes.

@shaoyuyoung
Contributor Author

Hi, @Ajay-26

Correct me if I'm wrong, but shouldn't x have dtype torch.float32?

You're right! However, the eager backend performs an implicit dtype conversion: if you print(c_output.dtype) you will find it is torch.float32, even though the original input dtype is torch.int64 (x = torch.Tensor([0]).to(torch.int64)).

Unfortunately, the inductor compiler cannot do this dtype conversion and throws an AssertionError, which violates the expectation for a DL compiler (it should simulate any behavior of eager). So I think this is a behavior inconsistency between PyTorch eager and inductor.

Feel free to continue the discussion if that is helpful to you. :)
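
For reference, a minimal standalone check of this implicit promotion under eager (using only the tensors from the repro above):

import torch

y = torch.Tensor([0])                  # float32 base
x = torch.Tensor([0]).to(torch.int64)  # int64 src
out = torch.slice_scatter(y, x, 0)
print(out, out.dtype)  # tensor([0.]) torch.float32 -- eager casts src to the base dtype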

@Numbers0689

Hi @shaoyuyoung, I'd like to work on this issue! From my understanding, the inductor backend currently does not perform the implicit dtype conversion, which leads to the assertion error. I plan to modify the inductor compiler to align with the eager backend's behavior.

Before proceeding, I wanted to confirm:

  • should the fix involve explicitly converting int64 inputs to float32 within inductor's handling of slice_scatter?
  • are there any existing tests that check for dtype consistency between eager and inductor?

Let me know if this is the right way, thanks!

@shaoyuyoung
Contributor Author
shaoyuyoung commented Mar 1, 2025

Hi, @Numbers0689, thanks for your comment and kindness!

  • should the fix involve explicitly converting int64 inputs to float32 within inductor's handling of slice_scatter?

This solution is enough for this case, but I am not sure whether other dtypes (int32, int8, uint64, etc.) behave similarly to int64. It would be better if we could deal with all of these similar problems at once.

  • are there any existing tests that check for dtype consistency between eager and inductor?

To be honest, I am also not sure (but I think no such tests exist currently). We previously discussed some dtype-inconsistency issues in #147666; maybe you can get some inspiration from #147666 (comment).
Regardless of whether tests already exist, you should write a UT (unit test) to verify that the fix is correct; see the sketch below. :)

I'm not sure if my answer is correct, so feel free to discuss more. Or you can draft a PR first, and then PyTorch developers will take a look at it for code review. :)
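
A minimal sketch of the kind of unit test suggested above (the test name and structure are illustrative, not an existing PyTorch test):

import torch

def test_slice_scatter_int64_src_into_float32_base():
    def f(x):
        y = torch.Tensor([0])  # float32 base
        return torch.slice_scatter(y, x, 0)

    x = torch.tensor([0], dtype=torch.int64)
    eager_out = f(x)
    compiled_out = torch.compile(f, backend="inductor")(x)
    # both outputs should be float32 and numerically equal
    torch.testing.assert_close(eager_out, compiled_out)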

@Numbers0689

Thanks for the clarification! I'll check the dtype behavior across the other integer types and go through #147666.

I'll also add a unit test to verify the fix and start drafting a PR.

@Numbers0689

Hi, @shaoyuyoung, while investigating torch.slice_scatter I came across torch.select_scatter, and upon testing, both of them throw an AssertionError for every integer dtype (int8, int16, int32, int64, uint8, uint16, uint32, uint64).
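
A sketch of the kind of per-dtype sweep that surfaces this (the repro helper is illustrative; the wider unsigned dtypes such as uint16/uint32/uint64 have limited support and may need a recent PyTorch build, so they are omitted here):

import torch

def repro(dtype):
    def f(x):
        y = torch.Tensor([0])  # float32 base
        return torch.slice_scatter(y, x, 0)

    x = torch.tensor([0], dtype=dtype)
    try:
        torch.compile(f, backend="inductor")(x)
        return "ok"
    except Exception as e:
        return type(e).__name__

for dt in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
    print(dt, repro(dt))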

@shaoyuyoung
Contributor Author
shaoyuyoung commented Mar 2, 2025

Well, that matches my expectations, as I have encountered similar cases before. Dtype processing in inductor seems very fragile with respect to staying consistent with eager's dtype handling; you can find more details in #144362.

Back to fixing this issue: previously I tried adding a manual check for these dtypes one by one, as in #145136. I think that is enough for this issue, but it would be great if we could find some way to solve the problem uniformly; a sketch of the intended semantics follows below. :)
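
For illustration, the eager semantics that a uniform fix would have to match fit in a few lines of user-level Python (a sketch of the expected behavior, not the inductor-internal change):

import torch

def slice_scatter_like_eager(base, src, dim=0, start=None, end=None, step=1):
    # eager semantics: the scattered values adopt the base tensor's dtype,
    # regardless of src's dtype
    return torch.slice_scatter(base, src.to(base.dtype), dim, start, end, step)

out = slice_scatter_like_eager(torch.Tensor([0]), torch.tensor([0], dtype=torch.int64))
print(out, out.dtype)  # tensor([0.]) torch.float32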

@desertfire desertfire added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Mar 4, 2025
@Vikramjeetsingh07

@shaoyuyoung
To ensure dtype consistency, I tried replicating the error, and the simple fix below helps:

y = torch.Tensor([0]).to(x.dtype)
x = torch.slice_scatter(y, x, 0)

As you said, this only resolves this particular case; the broader solution is to fix the inductor compiler's type inference. If you want, I can fix it and commit a branch. Hope this helps.

@shaoyuyoung
Contributor Author

Hi, @Vikramjeetsingh07
Thanks for the comment. I think developers are happy to see any PR that fixes issues.
Feel free to draft a PR. :)
