8000 [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. by sun-jacobi · Pull Request #81248 · llvm/llvm-project · GitHub
[go: up one dir, main page]

Skip to content

[RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. #81248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 21, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
8000
Diff view
Prev Previous commit
Next Next commit
add AllowExtMask
  • Loading branch information
sun-jacobi committed Feb 21, 2024
commit f0e6c8b2f91c2476d8ddc2d44f611881db142791
47 changes: 20 additions & 27 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13316,8 +13316,7 @@ namespace {
// apply a combine.
struct CombineResult;

enum class ExtKind { ZExt, SExt, FPExt };

enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl -> vwadd(u) | vwadd(u)_w
Expand Down Expand Up @@ -13448,13 +13447,11 @@ struct NodeExtensionHelper {
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;

unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;

MVT EltVT = SupportsExt == ExtKind::FPExt
? MVT::getFloatingPointVT(NarrowSize)
: MVT::getIntegerVT(NarrowSize);

assert(NarrowSize >= NarrowMinSize &&
assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
"Trying to extend something we can't represent");
MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
return NarrowVT;
Expand Down Expand Up @@ -13823,33 +13820,32 @@ struct CombineResult {
/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
/// are zext) and LHS and RHS can be folded into Root.
/// AllowSExt and AllozZExt define which form `ext` can take in this pattern.
/// AllowExtMask define which form `ext` can take in this pattern.
///
/// \note If the pattern can match with both zext and sext, the returned
/// CombineResult will feature the zext result.
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
assert((AllowSExt || AllowZExt || AllowFPExt) &&
"Forgot to set what you want?");
static std::optional<CombineResult>
canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS,
uint8_t AllowExtMask, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
Root->getOpcode(), ExtKind::ZExt),
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
/*RHSExt=*/{ExtKind::ZExt});
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
Root->getOpcode(), ExtKind::SExt),
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
/*RHSExt=*/{ExtKind::SExt});
if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
Root->getOpcode(), ExtKind::FPExt),
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
Expand All @@ -13867,9 +13863,9 @@ static std::optional<CombineResult>
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
/*AllowZExt=*/true,
/*AllowFPExt=*/true, DAG, Subtarget);
return canFoldToVWWithSameExtensionImpl(
Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
Subtarget);
}

/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
Expand Down Expand Up @@ -13911,9 +13907,8 @@ static std::optional<CombineResult>
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
/*AllowZExt=*/false,
/*AllowFPExt=*/false, DAG, Subtarget);
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
Subtarget);
}

/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
Expand All @@ -13924,9 +13919,8 @@ static std::optional<CombineResult>
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
/*AllowZExt=*/true,
/*AllowFPExt=*/false, DAG, Subtarget);
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
Subtarget);
}

/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
Expand All @@ -13937,9 +13931,8 @@ static std::optional<CombineResult>
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
/*AllowZExt=*/false,
/*AllowFPExt=*/true, DAG, Subtarget);
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
Subtarget);
}

/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
Expand Down
0