8000 [macros] allow forwarding generic arguments through macro declarations by ktoso · Pull Request #71271 · swiftlang/swift · GitHub
[go: up one dir, main page]

Skip to content

[macros] allow forwarding generic arguments through macro declarations #71271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions include/swift/AST/MacroDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,26 @@ class ExpandedMacroDefinition {
/// The macro replacements, ASTContext-allocated.
ArrayRef<ExpandedMacroReplacement> replacements;

/// Same as above but for generic argument replacements
ArrayRef<ExpandedMacroReplacement> genericReplacements;

ExpandedMacroDefinition(
StringRef expansionText,
ArrayRef<ExpandedMacroReplacement> replacements
) : expansionText(expansionText), replacements(replacements) { }
ArrayRef<ExpandedMacroReplacement> replacements,
ArrayRef<ExpandedMacroReplacement> genericReplacements
) : expansionText(expansionText),
replacements(replacements),
genericReplacements(genericReplacements) { }

public:
StringRef getExpansionText() const { return expansionText; }

ArrayRef<ExpandedMacroReplacement> getReplacements() const {
return replacements;
}
ArrayRef<ExpandedMacroReplacement> getGenericReplacements() const {
return genericReplacements;
}
};

/// Provides the definition of a macro.
Expand Down Expand Up @@ -162,7 +171,8 @@ class MacroDefinition {
static MacroDefinition forExpanded(
ASTContext &ctx,
StringRef expansionText,
ArrayRef<ExpandedMacroReplacement> replacements
ArrayRef<ExpandedMacroReplacement> replacements,
ArrayRef<ExpandedMacroReplacement> genericReplacements
);

/// Retrieve the external macro being referenced.
Expand Down
4 changes: 3 additions & 1 deletion include/swift/Bridging/ASTGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ ptrdiff_t swift_ASTGen_checkMacroDefinition(
const void *_Nonnull macroSourceLocation,
BridgedStringRef *_Nonnull expansionSourceOutPtr,
ptrdiff_t *_Nullable *_Nonnull replacementsPtr,
ptrdiff_t *_Nonnull numReplacements);
ptrdiff_t *_Nonnull numReplacements,
ptrdiff_t *_Nullable *_Nonnull genericReplacementsPtr,
ptrdiff_t *_Nonnull numGenericReplacements);
void swift_ASTGen_freeExpansionReplacements(
ptrdiff_t *_Nullable replacementsPtr, ptrdiff_t numReplacements);

Expand Down
6 changes: 4 additions & 2 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11846,10 +11846,12 @@ llvm::Optional<BuiltinMacroKind> MacroDecl::getBuiltinKind() const {
MacroDefinition MacroDefinition::forExpanded(
ASTContext &ctx,
StringRef expansionText,
ArrayRef<ExpandedMacroReplacement> replacements
ArrayRef<ExpandedMacroReplacement> replacements,
ArrayRef<ExpandedMacroReplacement> genericReplacements
) {
return ExpandedMacroDefinition{ctx.AllocateCopy(expansionText),
ctx.AllocateCopy(replacements)};
ctx.AllocateCopy(replacements),
ctx.AllocateCopy(genericReplacements)};
}

MacroExpansionDecl::MacroExpansionDecl(DeclContext *dc,
Expand Down
29 changes: 24 additions & 5 deletions lib/ASTGen/Sources/ASTGen/Macros.swift
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ func checkMacroDefinition(
macroLocationPtr: UnsafePointer<UInt8>,
externalMacroOutPtr: UnsafeMutablePointer<BridgedStringRef>,
replacementsPtr: UnsafeMutablePointer<UnsafeMutablePointer<Int>?>,
numReplacementsPtr: UnsafeMutablePointer<Int>
numReplacementsPtr: UnsafeMutablePointer<Int>,
genericReplacementsPtr: UnsafeMutablePointer<UnsafeMutablePointer<Int>?>,
numGenericReplacementsPtr: UnsafeMutablePointer<Int>
) -> Int {
// Assert "out" parameters are initialized.
assert(externalMacroOutPtr.pointee.isEmptyInitialized)
Expand Down Expand Up @@ -293,7 +295,7 @@ func checkMacroDefinition(
)
return Int(BridgedMacroDefinitionKind.externalMacro.rawValue)

case let .expansion(expansionSyntax, replacements: _)
case let .expansion(expansionSyntax, replacements: _, genericReplacements: _)
where expansionSyntax.macroName.text == "externalMacro":
// Extract the identifier from the "module" argument.
guard let firstArg = expansionSyntax.arguments.first,
Expand Down Expand Up @@ -334,13 +336,15 @@ func checkMacroDefinition(
allocateBridgedString("\(module).\(type)")
retur 8000 n Int(BridgedMacroDefinitionKind.externalMacro.rawValue)

case let .expansion(expansionSyntax, replacements: replacements):
case let .expansion(expansionSyntax,
replacements: replacements, genericReplacements: genericReplacements):
// Provide the expansion syntax.
externalMacroOutPtr.pointee =
allocateBridgedString(expansionSyntax.trimmedDescription)

// If there are no replacements, we're done.
if replacements.isEmpty {
let totalReplacementsCount = replacements.count + genericReplacements.count
guard totalReplacementsCount > 0 else {
return Int(BridgedMacroDefinitionKind.expandedMacro.rawValue)
}

Expand All @@ -355,9 +359,24 @@ func checkMacroDefinition(
replacement.reference.endPositionBeforeTrailingTrivia.utf8Offset - expansionStart
replacementBuffer[index * 3 + 2] = replacement.parameterIndex
}

replacementsPtr.pointee = replacementBuffer.baseAddress
numReplacementsPtr.pointee = replacements.count

// The replacements are triples: (startOffset, endOffset, parameter index).
let genericReplacementBuffer = UnsafeMutableBufferPointer<Int>.allocate(capacity: 3 * genericReplacements.count)
for (index, genericReplacement) in genericReplacements.enumerated() {
let expansionStart = expansionSyntax.positionAfterSkippingLeadingTrivia.utf8Offset

genericReplacementBuffer[index * 3] =
genericReplacement.reference.positionAfterSkippingLeadingTrivia.utf8Offset - expansionStart
genericReplacementBuffer[index * 3 + 1] =
genericReplacement.reference.endPositionBeforeTrailingTrivia.utf8Offset - expansionStart
genericReplacementBuffer[index * 3 + 2] =
genericReplacement.parameterIndex
}
genericReplacementsPtr.pointee = genericReplacementBuffer.baseAddress
numGenericReplacementsPtr.pointee = genericReplacements.count

return Int(BridgedMacroDefinitionKind.expandedMacro.rawValue)
}
} catch let errDiags as DiagnosticsError {
Expand Down
84 changes: 72 additions & 12 deletions lib/Sema/TypeCheckMacros.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,19 @@ MacroDefinition MacroDefinitionRequest::evaluate(
BridgedStringRef externalMacroName{nullptr, 0};
ptrdiff_t *replacements = nullptr;
ptrdiff_t numReplacements = 0;
ptrdiff_t *genericReplacements = nullptr;
ptrdiff_t numGenericReplacements = 0;
auto checkResult = swift_ASTGen_checkMacroDefinition(
&ctx.Diags, sourceFile->getExportedSourceFile(),
macro->getLoc().getOpaquePointerValue(), &externalMacroName,
&replacements, &numReplacements);
&replacements, &numReplacements,
&genericReplacements, &numGenericReplacements);

// Clean up after the call.
SWIFT_DEFER {
swift_ASTGen_freeBridgedString(externalMacroName);
swift_ASTGen_freeExpansionReplacements(replacements, numReplacements);
// swift_ASTGen_freeExpansionGenericReplacements(genericReplacements, numGenericReplacements); // FIXME: !!!!!!
};

if (checkResult < 0 && ctx.CompletionCallback) {
Expand All @@ -177,6 +181,7 @@ MacroDefinition MacroDefinitionRequest::evaluate(
case BridgedExternalMacro: {
// An external macro described as ModuleName.TypeName. Get both identifiers.
assert(!replacements && "External macro doesn't have replacements");
assert(!genericReplacements && "External macro doesn't have genericReplacements");
StringRef externalMacroStr = externalMacroName.unbridged();
StringRef externalModuleName, externalTypeName;
std::tie(externalModuleName, externalTypeName) = externalMacroStr.split('.');
Expand Down Expand Up @@ -232,8 +237,16 @@ MacroDefinition MacroDefinitionRequest::evaluate(
static_cast<unsigned>(replacements[3*i+1]),
static_cast<unsigned>(replacements[3*i+2])});
}
// Copy over the genericReplacements.
SmallVector<ExpandedMacroReplacement, 2> genericReplacementsVec;
for (unsigned i: range(0, numGenericReplacements)) {
genericReplacementsVec.push_back(
{ static_cast<unsigned>(genericReplacements[3*i]),
static_cast<unsigned>(genericReplacements[3*i+1]),
static_cast<unsigned>(genericReplacements[3*i+2])});
}

return MacroDefinition::forExpanded(ctx, expansionText, replacementsVec);
return MacroDefinition::forExpanded(ctx, expansionText, replacementsVec, genericReplacementsVec);
#else
macro->diagnose(diag::macro_unsupported);
return MacroDefinition::forInvalid();
Expand Down Expand Up @@ -781,24 +794,68 @@ static bool isFromExpansionOfMacro(SourceFile *sourceFile, MacroDecl *macro,

/// Expand a macro definition.
static std::string expandMacroDefinition(
ExpandedMacroDefinition def, MacroDecl *macro, ArgumentList *args) {
ExpandedMacroDefinition def, MacroDecl *macro,
SubstitutionMap subs,
ArgumentList *args) {
ASTContext &ctx = macro->getASTContext();

std::string expandedResult;

StringRef originalText = def.getExpansionText();

unsigned startIdx = 0;
for (const auto replacement: def.getReplacements()) {
unsigned replacementsIdx = 0;
unsigned genericReplacementsIdx = 0;
auto totalReplacementsCount =
def.getReplacements().size() + def.getGenericReplacements().size();

while (replacementsIdx + genericReplacementsIdx < totalReplacementsCount) {
ExpandedMacroReplacement replacement;
bool isExpressionReplacement = true;

// Pick the "next" replacement, in order as they appear in the source text
auto canPickExpressionReplacement = replacementsIdx < def.getReplacements().size();
auto canPickGenericReplacement = genericReplacementsIdx < def.getGenericReplacements().size();
if (canPickExpressionReplacement && canPickGenericReplacement) {
auto expressionReplacement = def.getReplacements()[replacementsIdx];
auto genericReplacement =
def.getGenericReplacements()[genericReplacementsIdx];
isExpressionReplacement =
expressionReplacement.startOffset < genericReplacement.startOffset;
replacement =
isExpressionReplacement ? expressionReplacement : genericReplacement;
} else if (canPickExpressionReplacement) {
isExpressionReplacement = true;
replacement = def.getReplacements()[replacementsIdx];
} else if (canPickGenericReplacement) {
isExpressionReplacement = false;
replacement = def.getGenericReplacements()[replacementsIdx];
} else {
assert(false && "should always select a requirement explicitly rather "
"than fall through");
}

replacementsIdx += isExpressionReplacement ? 1 : 0;
genericReplacementsIdx += isExpressionReplacement ? 0 : 1;

// Add the original text up to the first replacement.
expandedResult.append(
originalText.begin() + startIdx,
originalText.begin() + replacement.startOffset);
expandedResult.append(originalText.begin() + startIdx,
originalText.begin() + replacement.startOffset);

// Add the replacement text.
auto argExpr = args->getArgExprs()[replacement.parameterIndex];
SmallString<32> argTextBuffer;
auto argText = extractInlinableText(ctx.SourceMgr, argExpr, argTextBuffer);
expandedResult.append(argText);
if (isExpressionReplacement) {
auto argExpr = args->getArgExprs()[replacement.parameterIndex];
SmallString<32> argTextBuffer;
auto argText =
extractInlinableText(ctx.SourceMgr, argExpr, argTextBuffer);
expandedResult.append(argText);
} else {
auto typeArgType = subs.getReplacementTypes()[replacement.parameterIndex];
std::string typeNameString;
llvm::raw_string_ostream os(typeNameString);
typeArgType.print(os);
expandedResult.append(typeNameString);
}

// Update the starting position.
startIdx = replacement.endOffset;
Expand Down Expand Up @@ -1093,6 +1150,7 @@ evaluateFreestandingMacro(FreestandingMacroExpansion *expansion,
case MacroDefinition::Kind::Expanded: {
// Expand the definition with the given arguments.
auto result = expandMacroDefinition(macroDef.getExpanded(), macro,
expansion->getMacroRef().getSubstitutions(),
expansion->getArgs());
evaluatedSource = llvm::MemoryBuffer::getMemBufferCopy(
result, adjustMacroExpansionBufferName(*discriminator));
Expand Down Expand Up @@ -1397,7 +1455,9 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
case MacroDefinition::Kind::Expanded: {
// Expand the definition with the given arguments.
auto result = expandMacroDefinition(
macroDef.getExpanded(), macro, attr->getArgs());
macroDef.getExpanded(), macro,
/*genericArgs=*/{}, // attached macros don't have generic parameters
attr->getArgs());
evaluatedSource = llvm::MemoryBuffer::getMemBufferCopy(
result, adjustMacroExpansionBufferName(*discriminator));
break;
Expand Down
17 changes: 16 additions & 1 deletion lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5208,6 +5208,7 @@ class DeclDeserializer {

// Macro replacements block.
SmallVector<ExpandedMacroReplacement, 2> replacements;
SmallVector<ExpandedMacroReplacement, 2> genericReplacements;
if (hasReplacements) {
llvm::BitstreamEntry entry =
MF.fatalIfUnexpected(
Expand All @@ -5232,12 +5233,26 @@ class DeclDeserializer {
replacements.push_back(replacement);
}
}

ArrayRef<uint64_t> serializedGenericReplacements;
decls_block::ExpandedMacroReplacementsLayout::readRecord(
scratch, serializedGenericReplacements);
if (serializedGenericReplacements.size() % 3 == 0) {
for (unsigned i : range(0, serializedGenericReplacements.size() / 3)) {
ExpandedMacroReplacement genericReplacement{
static_cast<unsigned>(serializedGenericReplacements[3*i]),
static_cast<unsigned>(serializedGenericReplacements[3*i + 1]),
static_cast<unsigned>(serializedGenericReplacements[3*i + 2])
};
genericReplacements.push_back(genericReplacement);
}
}
}
}

ctx.evaluator.cacheOutput(
MacroDefinitionRequest{macro},
MacroDefinition::forExpanded(ctx, expansionText, replacements)
MacroDefinition::forExpanded(ctx, expansionText, replacements, genericReplacements)
);
}

Expand Down
17 changes: 17 additions & 0 deletions test/Macros/Inputs/freestanding_macro_library.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,20 @@ public macro stringify<T>(_ value: T) -> (T, String) = #externalMacro(module: "M

@freestanding(declaration, names: named(value))
public macro varValue() = #externalMacro(module: "MacroDefinition", type: "VarValueMacro")

// Macros that pass along generic arguments

@freestanding(expression)
public macro checkGeneric_root<DAS>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
@freestanding(expression)
public macro checkGeneric<DAS>() = #checkGeneric_root<DAS>()

@freestanding(expression)
public macro checkGeneric2_root<A, B>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
@freestanding(expression)
public macro checkGeneric2<A, B>() = #checkGeneric2_root<A, B>()

@freestanding(expression)
public macro checkGenericHashableCodable_root<A: Hashable, B: Codable>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
@freestanding(expression)
public macro checkGenericHashableCodable<A: Hashable, B: Codable>() = #checkGenericHashableCodable_root<A, B>()
11 changes: 11 additions & 0 deletions test/Macros/Inputs/syntax_macro_definitions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1819,6 +1819,17 @@ public struct VarValueMacro: DeclarationMacro, PeerMacro {
}
}

public struct GenericToVoidMacro: ExpressionMacro {
public static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
return """
()
"""
}
}

public struct SingleMemberMacro: MemberMacro {
public static func expansion(
of node: AttributeSyntax,
Expand Down
17 changes: 17 additions & 0 deletions test/Macros/top_level_freestanding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ macro freestandingWithClosure<T>(_ value: T, body: (T) -> T) = #externalMacro(mo
@freestanding(declaration, names: arbitrary) macro bitwidthNumberedStructs(_ baseName: String) = #externalMacro(module: "MacroDefinition", type: "DefineBitwidthNumberedStructsMacro")
@freestanding(expression) macro stringify<T>(_ value: T) -> (T, String) = #externalMacro(module: "MacroDefinition", type: "StringifyMacro")
@freestanding(declaration, names: named(value)) macro varValue() = #externalMacro(module: "MacroDefinition", type: "VarValueMacro")

@freestanding(expression) macro checkGeneric_root<A>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
@freestanding(expression) macro checkGeneric<A>() = #checkGeneric_root<A>()

@freestanding(expression) macro checkGeneric2_root<A, B>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
@freestanding(expression) macro checkGeneric2<A, B>() = #checkGeneric2_root<A, B>()

@freestanding(expression) macro checkGenericHashableCodable_root<A: Hashable, B: Codable>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
@freestanding(expression) macro checkGenericHashableCodable<A: Hashable, B: Codable>() = #checkGenericHashableCodable_root<A, B>()

#endif

// Test unqualified lookup from within a macro expansion
Expand Down Expand Up @@ -128,3 +138,10 @@ protocol Initializable {
struct S {
init(a: Int, b: Int) {}
}

// Check that generic type arguments are passed along in expansions,
// when macro is implemented using another macro.

#checkGeneric<String>()
#checkGeneric2<String, Int>()
#checkGenericHashableCodable<String, Int>()
3 changes: 2 additions & 1 deletion tools/SourceKit/include/SourceKit/Core/LangSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,10 @@ struct MacroExpansionInfo {
};
std::string expansionText;
std::vector<Replacement> replacements;
std::vector<Replacement> genericReplacements;

ExpandedMacroDefinition(StringRef expansionText)
: expansionText(expansionText), replacements(){};
: expansionText(expansionText), replacements(), genericReplacements() {};
};

// Offset of the macro expansion syntax (i.e. attribute or #<macro name>) from
Expand Down
Loading
0