# [Inductor] Broadcast to range tree shape before block pointer store #151399
## Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151399. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit c9881bb with merge base b0e28f6. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@eellison @isuruf I'm glad you're working on this feature! We currently have to add broadcasts defensively since it's hard to know what the actual shape is. The Triton compiler eliminates these no-op broadcasts anyway, but they make the code harder to read. It would be great if the IR tracked the shape more directly.
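For context, a minimal standalone sketch (not from this PR or from Inductor's output) of such a defensive broadcast: the `tl.broadcast_to` below leaves the shape unchanged, so the Triton compiler drops it, but it still clutters the kernel text.

```
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(in_ptr, out_ptr, numel, XBLOCK: tl.constexpr):
    offset = tl.program_id(0) * XBLOCK
    index = offset + tl.arange(0, XBLOCK)
    mask = index < numel
    tmp0 = tl.load(in_ptr + index, mask=mask)
    # Defensive broadcast: tmp0 already has shape [XBLOCK], so this is a no-op
    # the compiler eliminates, but it makes the generated code harder to read.
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tl.store(out_ptr + index, tmp1, mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
copy_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), XBLOCK=256)
torch.testing.assert_close(y, x)
```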
@blaine-rister has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
# Feature

This fixes a bug related to block pointer stores. Since Triton's block pointer stores don't support implicit broadcasting, in certain cases we need to generate a `reshape -> broadcast -> reshape` pattern to ensure that the tensor being stored has the same shape as the block pointer. This happens when the block indexing expression involves strides of 0 or dimensions of 1, both of which we eliminate from the block pointer.

The existing logic missed an important edge case: we may need a broadcast prior to the first `reshape` of this pattern, in case the tensor comes from a load with implicit broadcasting. For example, if the range trees have shape `[YBLOCK, XBLOCK]` but the load has shape `[1, XBLOCK]`, we need to broadcast it to `[YBLOCK, XBLOCK]` prior to storing.

See the example kernel below, which comes from `expand` -> `clone` with 3D tiling. The load has an implicit broadcast and the store has a reshape, so we need to insert an explicit broadcast between them.

```
@triton.jit
def triton_poi_fused_clone_0(in_ptr0, out_ptr0, znumel, ynumel, xnumel, ZBLOCK : tl.constexpr, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    znumel = 32
    ynumel = 1
    xnumel = 32
    zoffset = tl.program_id(2) * ZBLOCK
    zindex = zoffset + tl.arange(0, ZBLOCK)[:, None, None]
    zmask = zindex < znumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None]
    ymask = tl.full([ZBLOCK, YBLOCK, XBLOCK], True, tl.int1)
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]
    xmask = xindex < xnumel
    x1 = xindex
    z0 = zindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[32], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0], eviction_policy='evict_last')[None, None, :]
    tl.store(tl.make_block_ptr(out_ptr0, shape=[32, 32], strides=[32, 1], block_shape=[ZBLOCK, XBLOCK], order=[1, 0], offsets=[zoffset, xoffset]), tl.reshape(tl.broadcast_to(tmp0, [ZBLOCK, YBLOCK, XBLOCK]), [ZBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])
''', device_str='cuda')
```

The tricky part is that we don't want to emit redundant broadcasts in the store. This PR reworks the logic a bit to make sure we don't emit a second broadcast unless it actually changes the shape.

# Test plan

Added a CI test for this case, which would fail on trunk. Checked that only one broadcast was emitted.

Pull Request resolved: pytorch#151399
Approved by: https://github.com/jansel, https://github.com/eellison
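For illustration, here is a simplified, hypothetical sketch of the store-side decision described above. The helper name and the string-based shapes are made up for this example; the real logic lives in Inductor's Triton codegen, works on symbolic shapes, and also handles the leading reshape of the full `reshape -> broadcast -> reshape` pattern.

```
# Hypothetical, simplified sketch of the store-side codegen decision; not the
# actual Inductor implementation.
def codegen_store_value(value, value_shape, range_tree_shape, block_shape):
    """Broadcast `value` to the range tree shape, then reshape it to the block
    pointer's shape, emitting each op only when it actually changes the shape."""
    # The fix in this PR: broadcast up to the full range tree shape first, so a
    # value coming from a broadcasting load (e.g. shape [1, XBLOCK]) matches the
    # iteration space (e.g. [YBLOCK, XBLOCK]) before the store's reshape.
    if value_shape != range_tree_shape:
        value = f"tl.broadcast_to({value}, [{', '.join(range_tree_shape)}])"
        value_shape = range_tree_shape

    # Reshape down to the block pointer's shape, which may have dropped size-1 /
    # stride-0 dimensions; skip it when it would be a no-op.
    if value_shape != block_shape:
        value = f"tl.reshape({value}, [{', '.join(block_shape)}])"
    return value


# The example kernel above: a [1, 1, XBLOCK] load stored through a
# [ZBLOCK, XBLOCK] block pointer gets exactly one broadcast and one reshape.
print(codegen_store_value(
    "tmp0",
    ["1", "1", "XBLOCK"],
    ["ZBLOCK", "YBLOCK", "XBLOCK"],
    ["ZBLOCK", "XBLOCK"],
))
# -> tl.reshape(tl.broadcast_to(tmp0, [ZBLOCK, YBLOCK, XBLOCK]), [ZBLOCK, XBLOCK])
```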
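The CI test itself isn't shown here; below is a rough sketch of how such a test might look, reproducing `expand` -> `clone` with block pointers and 3D tiling and counting the emitted broadcasts. The Inductor config names (`triton.use_block_ptr`, `triton.prefer_nd_tiling`, `triton.max_tiles`) and the single-generated-kernel assumption are assumptions for this sketch, not a description of the PR's actual test.

```
# Rough sketch of a possible regression test; config knobs and helper usage are
# assumptions, not the PR's actual CI test.
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code


def fn(x):
    # expand -> clone: the load broadcasts implicitly, the store does not.
    return x.expand(32, 1, 32).clone()


with config.patch(
    {
        "triton.use_block_ptr": True,     # assumed knob to enable block pointers
        "triton.prefer_nd_tiling": True,  # assumed knob to get 3D tiling
        "triton.max_tiles": 3,
    }
):
    x = torch.randn(32, device="cuda")
    result, (code,) = run_and_get_code(torch.compile(fn), x)

torch.testing.assert_close(result, fn(x))
# The store should need exactly one broadcast, with no redundant second one.
assert code.count("tl.broadcast_to") == 1
```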
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov