[mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 by srcarroll · Pull Request #97607 · llvm/llvm-project · GitHub
[go: up one dir, main page]

Skip to content

[mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 #97607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
5020e49
Add getters for multi dim loop variables in LoopLikeOpInterface
srcarroll Jun 5, 2024
50852d5
Refactor LoopFuseSiblingOp and support parallel fusion
srcarroll Jun 4, 2024
b73238a
add checkFusionStructuralLegality
srcarroll Jun 5, 2024
f5bbd13
replace isLoopWithIdenticalConfiguration with checkFusionStructuralLe…
srcarroll Jun 5, 2024
7d99581
address review comment
srcarroll Jun 5, 2024
a5fa3b3
Make return types optional and change names
srcarroll Jun 6, 2024
1babe68
change return type of getInductionVars to SmallVector<Value>
srcarroll Jun 6, 2024
009fd15
address maks's comments
srcarroll Jun 6, 2024
d34ad95
change interface method names again and revert steps operand change
srcarroll Jun 6, 2024
e0e5262
return option induction vars
srcarroll Jun 6, 2024
7115a6e
address review comments
srcarroll Jun 7, 2024
1d4a444
Merge branch 'main' into add-loop-like-interface-methods
srcarroll Jun 7, 2024
af6b030
Merge branch 'add-loop-like-interface-methods' into scf-parallel-loop…
srcarroll Jun 7, 2024
6336fdf
update after rebase
srcarroll Jun 7, 2024
aa15617
Merge branch 'main' into scf-parallel-loop-fusion
srcarroll Jun 7, 2024
7dbe646
Merge branch 'main' into scf-parallel-loop-fusion
srcarroll Jun 8, 2024
86406c3
refactor main parallel fusion logic from fuseIfLegal to util func
srcarroll Jun 9, 2024
694d589
remove unused functions
srcarroll Jun 9, 2024
67cb64f
refactor fuseIndependentSiblingForLoops to reuse replaceWithAdditiona…
srcarroll Jun 9, 2024
cc8599f
refactor fuseIndependentSiblingForallLoops to reuse replaceWithAdditi…
srcarroll Jun 9, 2024
48b1af9
wip
srcarroll Jun 10, 2024
7a51cb3
Decouple concrete loop type from `createFused` function
srcarroll Jun 17, 2024
3087326
Refactor ForallOp::replaceWithAdditionalYields
srcarroll Jun 17, 2024
bcf3d4a
revert unnecessary changes
srcarroll Jun 17, 2024
0cb3c4e
cleanup
srcarroll Jun 18, 2024
7e41a54
address some review comments
srcarroll Jun 21, 2024
cc95d75
move `createFused` to `LoopLikeInterface.h`
srcarroll Jun 24, 2024
3430a36
address more review comments
srcarroll Jun 26, 2024
8447c12
switch to function_ref
srcarroll Jun 27, 2024
fbd7b72
check optional values
srcarroll Jun 27, 2024
ffb73a7
replace equalIterationSpaces with checkFusionStructuredLegality
srcarroll Jun 27, 2024
a6d0588
check if isOpSibling in checkFusionStructuralLegality
srcarroll Jun 27, 2024
ff47980
remove extra dominance check
srcarroll Jun 27, 2024
c6847ec
address more review comments
srcarroll Jun 27, 2024
f50c6aa
add more lit tests for scf.parallel
srcarroll Jun 27, 2024
6dd68c1
check for equal loop types in checkFusionStructuralLegality
srcarroll Jun 27, 2024
99d821b
address more comments
srcarroll Jun 27, 2024
6825c15
Merge branch 'main' into scf-parallel-loop-fusion
srcarroll Jun 27, 2024
7f9c172
Fix bug in fusion refactor and add test
srcarroll Jul 3, 2024
4b4fd91
add comment
srcarroll Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
move createFused to LoopLikeInterface.h
  • Loading branch information
srcarroll committed Jun 24, 2024
commit cc95d75d2cc09f8a33850f3867c8313e374a0dfd
20 changes: 20 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,24 @@ class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
/// Include the generated interface declarations.
#include "mlir/Interfaces/LoopLikeInterface.h.inc"

namespace mlir {
/// A function that rewrites `target`'s terminator as a terminator obtained by
/// fusing `source` into `target`.
///
/// `createFused` invokes this callback last, after the non-terminator
/// operations of `source` have been cloned into the fused loop. It passes the
/// rewriter, the original `source` loop, the fused loop (as `target`), and the
/// value mapping built while fusing. `mapping` is taken by value, so the
/// callback works on its own copy of the mapping.
using FuseTerminatorFn =
std::function<void(RewriterBase &rewriter, LoopLikeOpInterface source,
LoopLikeOpInterface &target, IRMapping mapping)>;

/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
/// `target`. The `NewYieldValuesFn` callback is used to pass to the
/// `replaceWithAdditionalYields` interface method to replace the loop with a
/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
/// callback is responsible for updating the fused loop terminator.
LoopLikeOpInterface createFused(LoopLikeOpInterface target,
                                LoopLikeOpInterface source,
                                RewriterBase &rewriter,
                                NewYieldValuesFn newYieldValuesFn,
                                FuseTerminatorFn fuseTerminatorFn);

} // namespace mlir

#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
101 changes: 20 additions & 81 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1082,93 +1082,14 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
target.getLoopSteps() == source.getLoopSteps();
auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
auto forAllSource = dyn_cast<scf::ForallOp>(*source);
// TODO: Decouple checks on concrete loop types and move this function
// somewhere for general utility for `LoopLikeOpInterface`
if (forAllTarget && forAllSource)
return iterSpaceEq &&
forAllTarget.getMapping() == forAllSource.getMapping();
return iterSpaceEq;
}

template <typename LoopTy>
void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused,
IRMapping &mapping) {}

/// `scf.forall` specialization: clones every yielding op of `source`'s
/// `in_parallel` terminator to the end of `fused`'s `in_parallel` terminator
/// body, remapping operands through `mapping`.
template <>
void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source,
                    scf::ForallOp &fused, IRMapping &mapping) {
  // New ops are appended after whatever the fused terminator already holds.
  scf::InParallelOp destTerminator = fused.getTerminator();
  rewriter.setInsertionPointToEnd(destTerminator.getBody());

  scf::InParallelOp sourceTerminator = source.getTerminator();
  for (Operation &yieldingOp : sourceTerminator.getYieldingOps())
    rewriter.clone(yieldingOp, mapping);
}

/// `scf.for` specialization: rebuilds the terminator of `fused` by cloning it
/// through `mapping` and replacing the old terminator with the clone.
template <>
void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
                    scf::ForOp &fused, IRMapping &mapping) {
  // NOTE(review): relies on the caller having left the insertion point in
  // front of the current terminator, so the clone lands before it and the
  // second terminator lookup still finds the old op -- confirm at call sites.
  Operation *remappedYield =
      rewriter.clone(*fused.getBody()->getTerminator(), mapping);
  Operation *staleYield = fused.getBody()->getTerminator();
  rewriter.replaceOp(staleYield, remappedYield);
}

// TODO: We should maybe add a method to LoopLikeOpInterface that will
// facilitate this transformation. For now, this acts as a placeholder.
/// Interface-level entry point: forwards to the `fuseTerminator` overload for
/// the concrete loop type when `source` and `fused` are both `scf.for`, both
/// `scf.forall`, or both `scf.parallel`. Any other combination is reported as
/// unreachable.
template <>
void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
                    LoopLikeOpInterface &fused, IRMapping &mapping) {
  Operation *sourceOp = source.getOperation();
  Operation *fusedOp = fused.getOperation();

  // Each pair is materialised as named locals: a single `dyn_cast` replaces
  // the `isa` + `cast` double lookup, and the concrete overloads take `fused`
  // by non-const reference, which needs an lvalue.
  auto sourceFor = dyn_cast<scf::ForOp>(sourceOp);
  auto fusedFor = dyn_cast<scf::ForOp>(fusedOp);
  if (sourceFor && fusedFor) {
    fuseTerminator(rewriter, sourceFor, fusedFor, mapping);
    return;
  }

  auto sourceForall = dyn_cast<scf::ForallOp>(sourceOp);
  auto fusedForall = dyn_cast<scf::ForallOp>(fusedOp);
  if (sourceForall && fusedForall) {
    fuseTerminator(rewriter, sourceForall, fusedForall, mapping);
    return;
  }

  auto sourceParallel = dyn_cast<scf::ParallelOp>(sourceOp);
  auto fusedParallel = dyn_cast<scf::ParallelOp>(fusedOp);
  if (sourceParallel && fusedParallel) {
    fuseTerminator(rewriter, sourceParallel, fusedParallel, mapping);
    return;
  }

  llvm_unreachable("unsupported loop types.");
}

/// Fuses `source` into `target` and returns the resulting loop.
///
/// `target` is replaced, via `replaceWithAdditionalYields`, by a loop that
/// additionally carries `source`'s inits; `newYieldValuesFn` is forwarded to
/// that call. The non-terminator operations of `source`'s first region are
/// then cloned in front of the fused loop's terminator, and `fuseTerminator`
/// updates the terminator. `source` itself is not erased or replaced here.
LoopLikeOpInterface createFused(LoopLikeOpInterface target,
                                LoopLikeOpInterface source,
                                RewriterBase &rewriter,
                                NewYieldValuesFn newYieldValuesFn) {
  // Capture everything needed from both loops before calling
  // `replaceWithAdditionalYields` on `target`.
  auto targetIterArgs = target.getRegionIterArgs();
  auto targetInductionVars = target.getLoopInductionVars();
  SmallVector<Value> targetYieldOperands(target.getYieldedValues());
  auto sourceIterArgs = source.getRegionIterArgs();
  auto sourceInductionVars = source.getLoopInductionVars();
  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
  auto sourceRegion = source.getLoopRegions().front();

  // The replacement may fail; do not dereference the result unchecked.
  auto maybeFusedLoop = target.replaceWithAdditionalYields(
      rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
      newYieldValuesFn);
  if (failed(maybeFusedLoop))
    llvm_unreachable("failed to replace loop");
  LoopLikeOpInterface fusedLoop = *maybeFusedLoop;

  // Map control operands.
  IRMapping mapping;
  auto fusedInductionVars = fusedLoop.getLoopInductionVars();
  if (fusedInductionVars) {
    // Induction variables are optional on the interface; only dereference
    // them once all three loops are known to provide them.
    if (!targetInductionVars || !sourceInductionVars)
      llvm_unreachable(
          "expected target and source loops to have induction vars");
    mapping.map(*targetInductionVars, *fusedInductionVars);
    mapping.map(*sourceInductionVars, *fusedInductionVars);
  }
  // Target values map to the leading iter args/yields of the fused loop,
  // source values to the trailing ones.
  mapping.map(targetIterArgs,
              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
  mapping.map(targetYieldOperands,
              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
  mapping.map(sourceIterArgs,
              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
  mapping.map(sourceYieldOperands,
              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));

  // Append everything except the terminator into the fused operation.
  rewriter.setInsertionPoint(
      fusedLoop.getLoopRegions().front()->front().getTerminator());
  for (Operation &op : sourceRegion->front().without_terminator())
    rewriter.clone(op, mapping);

  // TODO: Replace with corresponding interface method if added
  fuseTerminator(rewriter, source, fusedLoop, mapping);

  return fusedLoop;
}

scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter) {
Expand All @@ -1177,6 +1098,15 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
[&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
// `ForallOp` does not have yields, rather an `InParallelOp` terminator.
return ValueRange{};
},
[&](RewriterBase &b, LoopLikeOpInterface source,
LoopLikeOpInterface &target, IRMapping mapping) {
auto sourceForall = cast<scf::ForallOp>(source);
auto targetForall = cast<scf::ForallOp>(target);
scf::InParallelOp fusedTerm = targetForall.getTerminator();
b.setInsertionPointToEnd(fusedTerm.getBody());
for (Operation &op : sourceForall.getTerminator().getYieldingOps())
b.clone(op, mapping);
}));
rewriter.replaceOp(source,
fusedLoop.getResults().take_back(source.getNumResults()));
Expand All @@ -1191,12 +1121,21 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
target, source, rewriter,
[&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
return source.getYieldedValues();
},
[&](RewriterBase &b, LoopLikeOpInterface source,
LoopLikeOpInterface &target, IRMapping mapping) {
auto sourceFor = cast<scf::ForOp>(source);
auto targetFor = cast<scf::ForOp>(target);
auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
}));
rewriter.replaceOp(source,
fusedLoop.getResults().take_back(source.getNumResults()));
return fusedLoop;
}

// TODO: Finish refactoring this a la the above, but likely requires additional
// interface methods.
scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
Block *block1 = target.getBody();
Expand Down
42 changes: 42 additions & 0 deletions mlir/lib/Interfaces/LoopLikeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "mlir/Interfaces/LoopLikeInterface.h"

#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/DenseSet.h"

Expand Down Expand Up @@ -113,3 +115,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {

return success();
}

67ED LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
LoopLikeOpInterface source,
RewriterBase &rewriter,
NewYieldValuesFn newYieldValuesFn,
FuseTerminatorFn fuseTerminatorFn) {
auto targetIterArgs = target.getRegionIterArgs();
auto targetInductionVar = *target.getLoopInductionVars();
SmallVector<Value> targetYieldOperands(target.getYieldedValues());
auto sourceIterArgs = source.getRegionIterArgs();
auto sourceInductionVar = *source.getLoopInductionVars();
SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
auto sourceRegion = source.getLoopRegions().front();
LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields(
rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
newYieldValuesFn);

// Map control operands.
IRMapping mapping;
mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars());
mapping.map(targetIterArgs,
fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
mapping.map(targetYieldOperands,
fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars());
mapping.map(sourceIterArgs,
fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
mapping.map(sourceYieldOperands,
fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
// Append everything except the terminator into the fused operation.
rewriter.setInsertionPoint(
fusedLoop.getLoopRegions().front()->front().getTerminator());
for (Operation &op : sourceRegion->front().without_terminator())
rewriter.clone(op, mapping);

// TODO: Replace with corresponding interface method if added
fuseTerminatorFn(rewriter, source, fusedLoop, mapping);

return fusedLoop;
}
0