8000 [macros] allow forwarding generic arguments through macro declaration… · swiftlang/swift@5acd366 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5acd366

Browse files
authored
[macros] allow forwarding generic arguments through macro declarations (#71271)
* [macros] allow forwarding generic arguments through macro declarations [macros] add more tests for generic argument forwarding in macro declarations * [macros] correct replacement picking logic
1 parent e4d6bee commit 5acd366

File tree

11 files changed

+188
-26
lines changed

11 files changed

+188
-26
lines changed

include/swift/AST/MacroDefinition.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,26 @@ class ExpandedMacroDefinition {
7878
/// The macro replacements, ASTContext-allocated.
7979
ArrayRef<ExpandedMacroReplacement> replacements;
8080

81+
/// Same as above but for generic argument replacements
82+
ArrayRef<ExpandedMacroReplacement> genericReplacements;
83+
8184
ExpandedMacroDefinition(
8285
StringRef expansionText,
83-
ArrayRef<ExpandedMacroReplacement> replacements
84-
) : expansionText(expansionText), replacements(replacements) { }
86+
ArrayRef<ExpandedMacroReplacement> replacements,
87+
ArrayRef<ExpandedMacroReplacement> genericReplacements
88+
) : expansionText(expansionText),
89+
replacements(replacements),
90+
genericReplacements(genericReplacements) { }
8591

8692
public:
8793
StringRef getExpansionText() const { return expansionText; }
8894

8995
ArrayRef<ExpandedMacroReplacement> getReplacements() const {
9096
return replacements;
9197
}
98+
ArrayRef<ExpandedMacroReplacement> getGenericReplacements() const {
99+
return genericReplacements;
100+
}
92101
};
93102

94103
/// Provides the definition of a macro.
@@ -162,7 +171,8 @@ class MacroDefinition {
162171
static MacroDefinition forExpanded(
163172
ASTContext &ctx,
164173
StringRef expansionText,
165-
ArrayRef<ExpandedMacroReplacement> replacements
174+
ArrayRef<ExpandedMacroReplacement> replacements,
175+
ArrayRef<ExpandedMacroReplacement> genericReplacements
166176
);
167177

168178
/// Retrieve the external macro being referenced.

include/swift/Bridging/ASTGen.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ ptrdiff_t swift_ASTGen_checkMacroDefinition(
7373
const void *_Nonnull macroSourceLocation,
7474
BridgedStringRef *_Nonnull expansionSourceOutPtr,
7575
ptrdiff_t *_Nullable *_Nonnull replacementsPtr,
76-
ptrdiff_t *_Nonnull numReplacements);
76+
ptrdiff_t *_Nonnull numReplacements,
77+
ptrdiff_t *_Nullable *_Nonnull genericReplacementsPtr,
78+
ptrdiff_t *_Nonnull numGenericReplacements);
7779
void swift_ASTGen_freeExpansionReplacements(
7880
ptrdiff_t *_Nullable replacementsPtr, ptrdiff_t numReplacements);
7981

lib/AST/Decl.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11853,10 +11853,12 @@ llvm::Optional<BuiltinMacroKind> MacroDecl::getBuiltinKind() const {
1185311853
MacroDefinition MacroDefinition::forExpanded(
1185411854
ASTContext &ctx,
1185511855
StringRef expansionText,
11856-
ArrayRef<ExpandedMacroReplacement> replacements
11856+
ArrayRef<ExpandedMacroReplacement> replacements,
11857+
ArrayRef<ExpandedMacroReplacement> genericReplacements
1185711858
) {
1185811859
return ExpandedMacroDefinition{ctx.AllocateCopy(expansionText),
11859-
ctx.AllocateCopy(replacements)};
11860+
ctx.AllocateCopy(replacements),
11861+
ctx.AllocateCopy(genericReplacements)};
1186011862
}
1186111863

1186211864
MacroExpansionDecl::MacroExpansionDecl(DeclContext *dc,

lib/ASTGen/Sources/ASTGen/Macros.swift

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ func checkMacroDefinition(
200200
macroLocationPtr: UnsafePointer<UInt8>,
201201
externalMacroOutPtr: UnsafeMutablePointer<BridgedStringRef>,
202202
replacementsPtr: UnsafeMutablePointer<UnsafeMutablePointer<Int>?>,
203-
numReplacementsPtr: UnsafeMutablePointer<Int>
203+
numReplacementsPtr: UnsafeMutablePointer<Int>,
204+
genericReplacementsPtr: UnsafeMutablePointer<UnsafeMutablePointer<Int>?>,
205+
numGenericReplacementsPtr: UnsafeMutablePointer<Int>
204206
) -> Int {
205207
// Assert "out" parameters are initialized.
206208
assert(externalMacroOutPtr.pointee.isEmptyInitialized)
@@ -293,7 +295,7 @@ func checkMacroDefinition(
293295
)
294296
return Int(BridgedMacroDefinitionKind.externalMacro.rawValue)
295297

296-
case let .expansion(expansionSyntax, replacements: _)
298+
case let .expansion(expansionSyntax, replacements: _, genericReplacements: _)
297299
where expansionSyntax.macroName.text == "externalMacro":
298300
// Extract the identifier from the "module" argument.
299301
guard let firstArg = expansionSyntax.arguments.first,
@@ -334,13 +336,15 @@ func checkMacroDefinition(
334336
allocateBridgedString("\(module).\(type)")
335337
return Int(BridgedMacroDefinitionKind.externalMacro.rawValue)
336338

337-
case let .expansion(expansionSyntax, replacements: replacements):
339+
case let .expansion(expansionSyntax,
340+
replacements: replacements, genericReplacements: genericReplacements):
338341
// Provide the expansion syntax.
339342
externalMacroOutPtr.pointee =
340343
allocateBridgedString(expansionSyntax.trimmedDescription)
341344

342345
// If there are no replacements, we're done.
343-
if replacements.isEmpty {
346+
let totalReplacementsCount = replacements.count + genericReplacements.count
347+
guard totalReplacementsCount > 0 else {
344348
return Int(BridgedMacroDefinitionKind.expandedMacro.rawValue)
345349
}
346350

@@ -355,9 +359,24 @@ func checkMacroDefinition(
355359
replacement.reference.endPositionBeforeTrailingTrivia.utf8Offset - expansionStart
356360
replacementBuffer[index * 3 + 2] = replacement.parameterIndex
357361
}
358-
359362
replacementsPtr.pointee = replacementBuffer.baseAddress
360363
numReplacementsPtr.pointee = replacements.count
364+
365+
// The replacements are triples: (startOffset, endOffset, parameter index).
366+
let genericReplacementBuffer = UnsafeMutableBufferPointer<Int>.allocate(capacity: 3 * genericReplacements.count)
367+
for (index, genericReplacement) in genericReplacements.enumerated() {
368+
let expansionStart = expansionSyntax.positionAfterSkippingLeadingTrivia.utf8Offset
369+
370+
genericReplacementBuffer[index * 3] =
371+
genericReplacement.reference.positionAfterSkippingLeadingTrivia.utf8Offset - expansionStart
372+
genericReplacementBuffer[index * 3 + 1] =
373+
genericReplacement.reference.endPositionBeforeTrailingTrivia.utf8Offset - expansionStart
374+
genericReplacementBuffer[index * 3 + 2] =
375+
genericReplacement.parameterIndex
376+
}
377+
genericReplacementsPtr.pointee = genericReplacementBuffer.baseAddress
378+
numGenericReplacementsPtr.pointee = genericReplacements.count
379+
361380
return Int(BridgedMacroDefinitionKind.expandedMacro.rawValue)
362381
#if RESILIENT_SWIFT_SYNTAX
363382
@unknown default:

lib/Sema/TypeCheckMacros.cpp

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,19 @@ MacroDefinition MacroDefinitionRequest::evaluate(
148148
BridgedStringRef externalMacroName{nullptr, 0};
149149
ptrdiff_t *replacements = nullptr;
150150
ptrdiff_t numReplacements = 0;
151+
ptrdiff_t *genericReplacements = nullptr;
152+
ptrdiff_t numGenericReplacements = 0;
151153
auto checkResult = swift_ASTGen_checkMacroDefinition(
152154
&ctx.Diags, sourceFile->getExportedSourceFile(),
153155
macro->getLoc().getOpaquePointerValue(), &externalMacroName,
154-
&replacements, &numReplacements);
156+
&replacements, &numReplacements,
157+
&genericReplacements, &numGenericReplacements);
155158

156159
// Clean up after the call.
157160
SWIFT_DEFER {
158161
swift_ASTGen_freeBridgedString(externalMacroName);
159162
swift_ASTGen_freeExpansionReplacements(replacements, numReplacements);
163+
// swift_ASTGen_freeExpansionGenericReplacements(genericReplacements, numGenericReplacements); // FIXME: !!!!!!
160164
};
161165

162166
if (checkResult < 0 && ctx.CompletionCallback) {
@@ -177,6 +181,7 @@ MacroDefinition MacroDefinitionRequest::evaluate(
177181
case BridgedExternalMacro: {
178182
// An external macro described as ModuleName.TypeName. Get both identifiers.
179183
assert(!replacements && "External macro doesn't have replacements");
184+
assert(!genericReplacements && "External macro doesn't have genericReplacements");
180185
StringRef externalMacroStr = externalMacroName.unbridged();
181186
StringRef externalModuleName, externalTypeName;
182187
std::tie(externalModuleName, externalTypeName) = externalMacroStr.split('.');
@@ -232,8 +237,16 @@ MacroDefinition MacroDefinitionRequest::evaluate(
232237
static_cast<unsigned>(replacements[3*i+1]),
233238
static_cast<unsigned>(replacements[3*i+2])});
234239
}
240+
// Copy over the genericReplacements.
241+
SmallVector<ExpandedMacroReplacement, 2> genericReplacementsVec;
242+
for (unsigned i: range(0, numGenericReplacements)) {
243+
genericReplacementsVec.push_back(
244+
{ static_cast<unsigned>(genericReplacements[3*i]),
245+
static_cast<unsigned>(genericReplacements[3*i+1]),
246+
static_cast<unsigned>(genericReplacements[3*i+2])});
247+
}
235248

236-
return MacroDefinition::forExpanded(ctx, expansionText, replacementsVec);
249+
return MacroDefinition::forExpanded(ctx, expansionText, replacementsVec, genericReplacementsVec);
237250
#else
238251
macro->diagnose(diag::macro_unsupported);
239252
return MacroDefinition::forInvalid();
@@ -781,24 +794,68 @@ static bool isFromExpansionOfMacro(SourceFile *sourceFile, MacroDecl *macro,
781794

782795
/// Expand a macro definition.
783796
static std::string expandMacroDefinition(
784-
ExpandedMacroDefinition def, MacroDecl *macro, ArgumentList *args) {
797+
ExpandedMacroDefinition def, MacroDecl *macro,
798+
SubstitutionMap subs,
799+
ArgumentList *args) {
785800
ASTContext &ctx = macro->getASTContext();
786801

787802
std::string expandedResult;
788803

789804
StringRef originalText = def.getExpansionText();
805+
790806
unsigned startIdx = 0;
791-
for (const auto replacement: def.getReplacements()) {
807+
unsigned replacementsIdx = 0;
808+
unsigned genericReplacementsIdx = 0;
809+
auto totalReplacementsCount =
810+
def.getReplacements().size() + def.getGenericReplacements().size();
811+
812+
while (replacementsIdx + genericReplacementsIdx < totalReplacementsCount) {
813+
ExpandedMacroReplacement replacement;
814+
bool isExpressionReplacement = true;
815+
816+
// Pick the "next" replacement, in order as they appear in the source text
817+
auto canPickExpressionReplacement = replacementsIdx < def.getReplacements().size();
818+
auto canPickGenericReplacement = genericReplacementsIdx < def.getGenericReplacements().size();
819+
if (canPickExpressionReplacement && canPickGenericReplacement) {
820+
auto expressionReplacement = def.getReplacements()[replacementsIdx];
821+
auto genericReplacement =
822+
def.getGenericReplacements()[genericReplacementsIdx];
823+
isExpressionReplacement =
824+
expressionReplacement.startOffset < genericReplacement.startOffset;
825+
replacement =
826+
isExpressionReplacement ? expressionReplacement : genericReplacement;
827+
} else if (canPickExpressionReplacement) {
828+
isExpressionReplacement = true;
829+
replacement = def.getReplacements()[replacementsIdx];
830+
} else if (canPickGenericReplacement) {
831+
isExpressionReplacement = false;
832+
replacement = def.getGenericReplacements()[replacementsIdx];
833+
} else {
834+
assert(false && "should always select a requirement explicitly rather "
835+
"than fall through");
836+
}
837+
838+
replacementsIdx += isExpressionReplacement ? 1 : 0;
839+
genericReplacementsIdx += isExpressionReplacement ? 0 : 1;
840+
792841
// Add the original text up to the first replacement.
793-
expandedResult.append(
794-
originalText.begin() + startIdx,
795-
originalText.begin() + replacement.startOffset);
842+
expandedResult.append(originalText.begin() + startIdx,
843+
originalText.begin() + replacement.startOffset);
796844

797845
// Add the replacement text.
798-
auto argExpr = args->getArgExprs()[replacement.parameterIndex];
799-
SmallString<32> argTextBuffer;
800-
auto argText = extractInlinableText(ctx.SourceMgr, argExpr, argTextBuffer);
801-
expandedResult.append(argText);
846+
if (isExpressionReplacement) {
847+
auto argExpr = args->getArgExprs()[replacement.parameterIndex];
848+
SmallString<32> argTextBuffer;
849+
auto argText =
850+
extractInlinableText(ctx.SourceMgr, argExpr, argTextBuffer);
851+
expandedResult.append(argText);
852+
} else {
853+
auto typeArgType = subs.getReplacementTypes()[replacement.parameterIndex];
854+
std::string typeNameString;
855+
llvm::raw_string_ostream os(typeNameString);
856+
typeArgType.print(os);
857+
expandedResult.append(typeNameString);
858+
}
802859

803860
// Update the starting position.
804861
startIdx = replacement.endOffset;
@@ -1093,6 +1150,7 @@ evaluateFreestandingMacro(FreestandingMacroExpansion *expansion,
10931150
case MacroDefinition::Kind::Expanded: {
10941151
// Expand the definition with the given arguments.
10951152
auto result = expandMacroDefinition(macroDef.getExpanded(), macro,
1153+
expansion->getMacroRef().getSubstitutions(),
10961154
expansion->getArgs());
10971155
evaluatedSource = llvm::MemoryBuffer::getMemBufferCopy(
10981156
result, adjustMacroExpansionBufferName(*discriminator));
@@ -1397,7 +1455,9 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
13971455
case MacroDefinition::Kind::Expanded: {
13981456
// Expand the definition with the given arguments.
13991457
auto result = expandMacroDefinition(
1400-
macroDef.getExpanded(), macro, attr->getArgs());
1458+
macroDef.getExpanded(), macro,
1459+
/*genericArgs=*/{}, // attached macros don't have generic parameters
1460+
attr->getArgs());
14011461
evaluatedSource = llvm::MemoryBuffer::getMemBufferCopy(
14021462
result, adjustMacroExpansionBufferName(*discriminator));
14031463
break;

lib/Serialization/Deserialization.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5229,6 +5229,7 @@ class DeclDeserializer {
52295229

52305230
// Macro replacements block.
52315231
SmallVector<ExpandedMacroReplacement, 2> replacements;
5232+
SmallVector<ExpandedMacroReplacement, 2> genericReplacements;
52325233
if (hasReplacements) {
52335234
llvm::BitstreamEntry entry =
52345235
MF.fatalIfUnexpected(
@@ -5253,12 +5254,26 @@ class DeclDeserializer {
52535254
replacements.push_back(replacement);
52545255
}
52555256
}
5257+
5258+
ArrayRef<uint64_t> serializedGenericReplacements;
5259+
decls_block::ExpandedMacroReplacementsLayout::readRecord(
5260+
scratch, serializedGenericReplacements);
5261+
if (serializedGenericReplacements.size() % 3 == 0) {
5262+
for (unsigned i : range(0, serializedGenericReplacements.size() / 3)) {
5263+
ExpandedMacroReplacement genericReplacement{
5264+
static_cast<unsigned>(serializedGenericReplacements[3*i]),
5265+
static_cast<unsigned>(serializedGenericReplacements[3*i + 1]),
5266+
static_cast<unsigned>(serializedGenericReplacements[3*i + 2])
5267+
};
5268+
genericReplacements.push_back(genericReplacement);
5269+
}
5270+
}
52565271
}
52575272
}
52585273

52595274
ctx.evaluator.cacheOutput(
52605275
MacroDefinitionRequest{macro},
5261-
MacroDefinition::forExpanded(ctx, expansionText, replacements)
5276+
MacroDefinition::forExpanded(ctx, expansionText, replacements, genericReplacements)
52625277
);
52635278
}
52645279

test/Macros/Inputs/freestanding_macro_library.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,20 @@ public macro stringify<T>(_ value: T) -> (T, String) = #externalMacro(module: "M
1818

1919
@freestanding(declaration, names: named(value))
2020
public macro varValue() = #externalMacro(module: "MacroDefinition", type: "VarValueMacro")
21+
22+
// Macros that pass along generic arguments
23+
24+
@freestanding(expression)
25+
public macro checkGeneric_root<DAS>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
26+
@freestanding(expression)
27+
public macro checkGeneric<DAS>() = #checkGeneric_root<DAS>()
28+
29+
@freestanding(expression)
30+
public macro checkGeneric2_root<A, B>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
31+
@freestanding(expression)
32+
public macro checkGeneric2<A, B>() = #checkGeneric2_root<A, B>()
33+
34+
@freestanding(expression)
35+
public macro checkGenericHashableCodable_root<A: Hashable, B: Codable>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
36+
@freestanding(expression)
37+
public macro checkGenericHashableCodable<A: Hashable, B: Codable>() = #checkGenericHashableCodable_root<A, B>()

test/Macros/Inputs/syntax_macro_definitions.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,6 +1819,17 @@ public struct VarValueMacro: DeclarationMacro, PeerMacro {
18191819
}
18201820
}
18211821

1822+
public struct GenericToVoidMacro: ExpressionMacro {
1823+
public static func expansion(
1824+
of node: some FreestandingMacroExpansionSyntax,
1825+
in context: some MacroExpansionContext
1826+
) throws -> ExprSyntax {
1827+
return """
1828+
()
1829+
"""
1830+
}
1831+
}
1832+
18221833
public struct SingleMemberMacro: MemberMacro {
18231834
public static func expansion(
18241835
of node: AttributeSyntax,

test/Macros/top_level_freestanding.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ macro freestandingWithClosure<T>(_ value: T, body: (T) -> T) = #externalMacro(mo
3535
@freestanding(declaration, names: arbitrary) macro bitwidthNumberedStructs(_ baseName: String) = #externalMacro(module: "MacroDefinition", type: "DefineBitwidthNumberedStructsMacro")
3636
@freestanding(expression) macro stringify<T>(_ value: T) -> (T, String) = #externalMacro(module: "MacroDefinition", type: "StringifyMacro")
3737
@freestanding(declaration, names: named(value)) macro varValue() = #externalMacro(module: "MacroDefinition", type: "VarValueMacro")
38+
39+
@freestanding(expression) macro checkGeneric_root<A>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
40+
@freestanding(expression) macro checkGeneric<A>() = #checkGeneric_root<A>()
41+
42+
@freestanding(expression) macro checkGeneric2_root<A, B>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
43+
@freestanding(expression) macro checkGeneric2<A, B>() = #checkGeneric2_root<A, B>()
44+
45+
@freestanding(expression) macro checkGenericHashableCodable_root<A: Hashable, B: Codable>() = #externalMacro(module: "MacroDefinition", type: "GenericToVoidMacro")
46+
@freestanding(expression) macro checkGenericHashableCodable<A: Hashable, B: Codable>() = #checkGenericHashableCodable_root<A, B>()
47+
3848
#endif
3949

4050
// Test unqualified lookup from within a macro expansion
@@ -128,3 +138,10 @@ protocol Initializable {
128138
struct S {
129139
init(a: Int, b: Int) {}
130140
}
141+
142+
// Check that generic type arguments are passed along in expansions,
143+
// when macro is implemented using another macro.
144+
145+
#checkGeneric<String>()
146+
#checkGeneric2<String, Int>()
147+
#checkGenericHashableCodable<String, Int>()

tools/SourceKit/include/SourceKit/Core/LangSupport.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,10 @@ struct MacroExpansionInfo {
262262
};
263263
std::string expansionText;
264264
std::vector<Replacement> replacements;
265+
std::vector<Replacement> genericReplacements;
265266

266267
ExpandedMacroDefinition(StringRef expansionText)
267-
: expansionText(expansionText), replacements(){};
268+
: expansionText(expansionText), replacements(), genericReplacements() {};
268269
};
269270

270271
// Offset of the macro expansion syntax (i.e. attribute or #<macro name>) from

0 commit comments

Comments
 (0)
0