8000 [inductor] Fix block ptr store if input is constant by kundaMwiza · Pull Request #148679 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[inductor] Fix block ptr store if input is constant #148679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

kundaMwiza
Copy link
Contributor
@kundaMwiza kundaMwiza commented Mar 6, 2025

Since block ptr stores require explicit broadcasts, the input to tl.store needs to be reshaped and broadcasted. Currently, it is assumed that the input to be stored is in block form (e.g. XBLOCK), however it is possible for the input to be a scalar, and so special handling is required to reshape + broadcast the scalar to the output block shape.

Ideally the shape of the input would be an attribute of a TritonCSEVariable via shape propagation but that is not the case today. The patch in this PR determines if the input is a constant by checking the arguments to an FX store node which is not ideal. Maybe there is an alternative and simpler method

Fixes #ISSUE_NUMBER

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

… block ptrs

Remove formatting changes

Lint

Rename vars

Dont reshape float32 scalars

Handle constants separately
Copy link
pytorch-bot bot commented Mar 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148679

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 10ced66 with merge base c65ee72 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

assert block_ptr not in advancements, (
"duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
)
assert (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lintrunner changes

@@ -3039,9 +3059,9 @@ def sort(
self.filter_masks(masks)
masks = sorted(masks)
assert not self._load_mask, "ops.sort not supported inside ops.masked"
assert self.persistent_reduction, (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lintrunner changes

@kundaMwiza
Copy link
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Mar 6, 2025
@colesbury colesbury requested a review from eellison March 6, 2025 18:35
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 6, 2025
@eellison eellison requested a review from blaine-rister March 11, 2025 22:41
Comment on lines +2073 to +2076
assert isinstance(value, CSEVariable)
# See `_shaped_constant`. Tensor constants are only created
# if the dtype is not float32.
if value.dtype is not torch.float32:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot of action at a distance. could we refactor this ?

And yes, i agree, we should add shape.

value = triton_reshape(
str(value), [sympy.S.One], [sympy.S.One] * len(indexing.block_shape)
)
value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(indexing.block_shape)})"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch!

https://github.com/pytorch/pytorch/pull/151399/files recently fixed a related bug by always broadcasting block ptr stores to indexing.block_shape. In light of that change, would it make sense to reuse indexing.codegen_broadcast_and_reshape for scalars instead of having separate code paths here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: inductor open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants
0