8000 [Inductor] Broadcast to range tree shape before block pointer store by blaine-rister · Pull Request #151399 · pytorch/pytorch · GitHub

Closed · wants to merge 5 commits
Changes from all commits
27 changes: 27 additions & 0 deletions test/inductor/test_torchinductor_strided_blocks.py
@@ -1037,6 +1037,33 @@ def test_pointwise_index_order(self):
xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]""", # noqa: B950
)

    def test_expand_clone_broadcast(self):
        """
        Test expand followed by clone. This uses an explicit Triton broadcast.
        """
        base_size = (1, 32)
        expanded_size = (32, 32)

        def foo(x):
            return x.expand(*expanded_size).clone()

        inps = [torch.randn(base_size, device=self.device)]
        result, (triton_code,) = run_and_compare(
            self,
            foo,
            *inps,
            expected_num_triton_kernels=1,
            expected_num_block_pointers=2,
            config_patches={
                "triton.max_tiles": 3,
                "triton.prefer_nd_tiling": True,
            },
        )

        # We should only need one broadcast.
        num_broadcasts = triton_code.count("tl.broadcast_to")
        self.assertEqual(num_broadcasts, 1)
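
For reference, the pattern this test exercises can be reproduced outside the test harness with a short standalone script (an illustrative sketch, assuming a CUDA device; running with TORCH_LOGS=output_code prints the generated Triton kernel for inspection):

import torch

# Hedged repro of the tested pattern, not the test harness itself:
# expanding a (1, 32) view to (32, 32) and cloning forces Inductor to
# materialize the broadcast in the generated kernel.
def foo(x):
    return x.expand(32, 32).clone()

compiled = torch.compile(foo)
x = torch.randn(1, 32, device="cuda")  # assumption: CUDA device available
torch.testing.assert_close(compiled(x), x.expand(32, 32))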


@unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
@config.patch(cpu_backend="triton")
29 changes: 20 additions & 9 deletions torch/_inductor/codegen/triton.py
@@ -282,18 +282,16 @@ def codegen_broadcast_and_reshape(
         # We need an explicit broadcast for stores, or if the final reshape does more
         # than add singletons.
         sizevars = V.graph.sizevars
-        require_broadcast = any(self.broadcasting_dims) and (
-            len(pre_broadcast_shape) != len(final_shape)
-            or any(
-                not (
-                    sizevars.statically_known_equals(pre_dim, 1)
-                    or sizevars.statically_known_equals(pre_dim, post_dim)
-                )
+        supports_implicit_broadcast = allow_implicit and (
+            len(pre_broadcast_shape) == len(final_shape)
+            and all(
+                sizevars.statically_known_equals(pre_dim, 1)
+                or sizevars.statically_known_equals(pre_dim, post_dim)
                 for pre_dim, post_dim in zip(pre_broadcast_shape, final_shape)
             )
         )
 
-        if not allow_implicit or require_broadcast:
+        if any(self.broadcasting_dims) and not supports_implicit_broadcast:
             value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(self.broadcast_shape)})"
 
         # Reshape to the final shape.
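
The supports_implicit_broadcast condition above mirrors Triton's elementwise broadcasting rule: when the ranks match and every mismatched dimension is 1, a binary op broadcasts without an explicit tl.broadcast_to. A minimal standalone illustration of that rule (a hypothetical kernel assuming standard Triton semantics and a CUDA device, not Inductor output):

import torch
import triton
import triton.language as tl

@triton.jit
def add_row(x_ptr, out_ptr, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):
    # Load one row with shape [1, XBLOCK].
    row = tl.load(x_ptr + tl.arange(0, XBLOCK))[None, :]
    # Adding it to a [YBLOCK, XBLOCK] tensor broadcasts implicitly:
    # the ranks match and the mismatched dim is 1, so no tl.broadcast_to.
    out = tl.zeros([YBLOCK, XBLOCK], dtype=tl.float32) + row
    offs = tl.arange(0, YBLOCK)[:, None] * XBLOCK + tl.arange(0, XBLOCK)[None, :]
    tl.store(out_ptr + offs, out)

x = torch.randn(32, device="cuda")  # assumption: CUDA device available
out = torch.empty(32, 32, device="cuda")
add_row[(1,)](x, out, YBLOCK=32, XBLOCK=32)
torch.testing.assert_close(out, x[None, :].expand(32, 32))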
@@ -2099,7 +2097,20 @@ def codegen_block_ptr(
         return block_ptr, other
 
     def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
-        # Stores require an explicit broadcast.
+        # Stores require an explicit broadcast. We do this in two phases:
+        # 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
+        #    YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
+        # 2. In case the block pointer has different dimensionality, broadcast/reshape the
+        #    result to the shape of the pointer.
+        value = f"tl.broadcast_to({value}, {indexing.final_shape})"
+
+        # These dims no longer need broadcasting.
+        for idx, (dim, broadcast_dim) in enumerate(
+            zip(indexing.final_shape, indexing.broadcast_shape)
+        ):
+            if V.graph.sizevars.statically_known_equals(dim, broadcast_dim):
+                indexing.broadcasting_dims[idx] = False
+
         value = indexing.codegen_broadcast_and_reshape(
             value, indexing.final_shape, indexing.block_shape, False
         )
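
The two-phase scheme matters because a store through a block pointer expects its value to already match the pointer's block shape; unlike elementwise ops, the store does not implicitly broadcast its operand. A minimal standalone sketch of the explicit-broadcast-then-store pattern (a hypothetical kernel assuming a CUDA device, not the code Inductor emits):

import torch
import triton
import triton.language as tl

@triton.jit
def fill_rows(x_ptr, out_ptr, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):
    # Phase 1: broadcast the [1, XBLOCK] row up to the full tile shape.
    row = tl.load(x_ptr + tl.arange(0, XBLOCK))[None, :]
    val = tl.broadcast_to(row, [YBLOCK, XBLOCK])
    # Phase 2: store through a block pointer whose block shape matches val.
    out_bp = tl.make_block_ptr(
        out_ptr,
        shape=[YBLOCK, XBLOCK],
        strides=[XBLOCK, 1],
        offsets=[0, 0],
        block_shape=[YBLOCK, XBLOCK],
        order=[1, 0],
    )
    tl.store(out_bp, val)

x = torch.randn(32, device="cuda")  # assumption: CUDA device available
out = torch.empty(32, 32, device="cuda")
fill_rows[(1,)](x, out, YBLOCK=32, XBLOCK=32)
torch.testing.assert_close(out, x[None, :].expand(32, 32))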