[Inductor] Broadcast to range tree shape before block pointer store by blaine-rister · Pull Request #151399 · pytorch/pytorch · GitHub

[Inductor] Broadcast to range tree shape before block pointer store #151399


Closed · blaine-rister wants to merge 5 commits

Conversation

blaine-rister (Contributor) commented Apr 16, 2025

# Feature

This fixes a bug related to block pointer stores. Since Triton's block pointer stores don't support implicit broadcasting, in certain cases we need to generate a `reshape->broadcast->reshape` pattern to ensure that the tensor being stored has the same shape as the block pointer. This happens when the block indexing expression involves strides of 0 or dimensions of 1, both of which we eliminate from the block pointer.

The existing logic missed an important edge case. We may need a broadcast prior to the first `reshape` of this pattern, in case the tensor comes from a load with implicit broadcasting. For example, if the range trees have shape `[YBLOCK, XBLOCK]`, but the load has a shape `[1, XBLOCK]`, we need to broadcast this to `[YBLOCK, XBLOCK]` prior to storing. See the example kernel below, which comes from `expand` -> `clone` with 3D tiling. The load has an implicit broadcast, and the store has a reshape. Thus, we need to insert an explicit broadcast between them.

```
@triton.jit
def triton_poi_fused_clone_0(in_ptr0, out_ptr0, znumel, ynumel, xnumel, ZBLOCK : tl.constexpr, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    znumel = 32
    ynumel = 1
    xnumel = 32
    zoffset = tl.program_id(2) * ZBLOCK
    zindex = zoffset + tl.arange(0, ZBLOCK)[:, None, None]
    zmask = zindex < znumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None]
    ymask = tl.full([ZBLOCK, YBLOCK, XBLOCK], True, tl.int1)
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]
    xmask = xindex < xnumel
    x1 = xindex
    z0 = zindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[32], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0], eviction_policy='evict_last')[None, None, :]
    tl.store(tl.make_block_ptr(out_ptr0, shape=[32, 32], strides=[32, 1], block_shape=[ZBLOCK, XBLOCK], order=[1, 0], offsets=[zoffset, xoffset]), tl.reshape(tl.broadcast_to(tmp0, [ZBLOCK, YBLOCK, XBLOCK]), [ZBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])
''', device_str='cuda')
```
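
For reference, here is a minimal repro sketch of the kind of program that can produce a kernel like this. The exact shapes and the `use_block_ptr` knob are illustrative assumptions, not the precise setup from this PR's test:

```
import torch
import torch._inductor.config as inductor_config

# Assumption: Inductor only emits tl.make_block_ptr when this option is enabled.
inductor_config.triton.use_block_ptr = True

def f(x):
    # expand -> clone: the expanded (stride-0) dimension turns into a load with
    # implicit broadcasting, which then feeds a block pointer store.
    return x.expand(32, 1, 32).clone()

compiled = torch.compile(f)
out = compiled(torch.randn(1, 1, 32, device="cuda"))
```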

The tricky part is that we don't want to emit redundant broadcasts in the store. This PR reworks the logic a bit to make sure we don't emit a second broadcast unless it actually changes the shape.
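
As a rough illustration of that rule (a sketch only, not the actual Inductor codegen), the store-side shape fixup amounts to something like:

```
# Hypothetical sketch of the shape fixup described above; the function name and
# structure are assumptions, not the real Inductor implementation.
def fix_shape_for_store(value_expr, value_shape, range_tree_shape, block_ptr_shape):
    # Broadcast up to the full range tree shape only when it changes the shape,
    # e.g. a load of shape [1, XBLOCK] broadcast to [YBLOCK, XBLOCK].
    if list(value_shape) != list(range_tree_shape):
        value_expr = f"tl.broadcast_to({value_expr}, {range_tree_shape})"
    # Then collapse singleton / stride-0 dims to match the block pointer shape.
    if list(range_tree_shape) != list(block_ptr_shape):
        value_expr = f"tl.reshape({value_expr}, {block_ptr_shape})"
    return value_expr
```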

# Test plan

Added a CI test for this case, which would fail on trunk. Checked that only one broadcast was emitted.
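
For context, a check along these lines can be written with `torch._inductor.utils.run_and_get_code`, counting explicit broadcasts in the generated source. The snippet below reuses the assumed expand/clone repro from above and is a sketch, not the exact test added by this PR:

```
import torch
from torch._inductor.utils import run_and_get_code

def f(x):
    return x.expand(32, 1, 32).clone()

# run_and_get_code returns the compiled result plus the generated source strings.
_, (code,) = run_and_get_code(torch.compile(f), torch.randn(1, 1, 32, device="cuda"))
assert code.count("tl.broadcast_to") == 1  # exactly one explicit broadcast emitted
```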

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot (bot) commented Apr 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151399

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c9881bb with merge base b0e28f6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison (Contributor) left a comment


Cc @isuruf, this is another case that will be simplified by #149905

blaine-rister marked this pull request as ready for review April 16, 2025 16:17
blaine-rister (Contributor, Author) commented Apr 16, 2025

Cc @isuruf, this is another case that will be simplified by #149905

@eellison @isuruf I'm glad you're working on this feature! We currently have to add broadcasts defensively, since it's hard to know what the actual shape is. The Triton compiler eliminates these no-op broadcasts anyway, but they make the code harder to read. It would be great if the IR tracked the shape more directly.

facebook-github-bot (Contributor) commented:

@blaine-rister has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Apr 16, 2025
blaine-rister (Contributor, Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
[Inductor] Broadcast to range tree shape before block pointer store (pytorch#151399)

Pull Request resolved: pytorch#151399
Approved by: https://github.com/jansel, https://github.com/eellison