8000 [WGSL] Add pass to rewrite entry points to be compatible with Metal · WebKit/WebKit@021a45c · GitHub
[go: up one dir, main page]

Skip to content

Commit 021a45c

Browse files
committed
[WGSL] Add pass to rewrite entry points to be compatible with Metal
https://bugs.webkit.org/show_bug.cgi?id=250832 <rdar://problem/104425157> Reviewed by Myles C. Maxfield. For entry point function, builtin inputs contained in structs must be hoisted into function parameters, and any non-builtin parameters must be moved into a `stage_in` struct. The implementation can't quite handle absolutely every use case yet, but the patch was getting pretty large so I thought I pause here and commit something that works. A couple other things had to happen to get this initial implementation working: - I needed some basic type information, so I added a very rudimentary pass to resolve type names and a new struct type. - I also had to modify the AST to use `Ref`/`RefPtr` instead of `UniqueRef`/`unique_ptr` for types, since we need to be able to copy nodes when moving parameters around. - Fixed a couple small serialization bugs with missing semicolons in MetalFunctionWriter and added serialization for matrix types. * Source/WebGPU/WGSL/AST/ASTAttribute.h: * Source/WebGPU/WGSL/AST/ASTBuiltinAttribute.h: * Source/WebGPU/WGSL/AST/ASTCallableExpression.h: * Source/WebGPU/WGSL/AST/ASTForward.h: * Source/WebGPU/WGSL/AST/ASTFunctionDecl.h: * Source/WebGPU/WGSL/AST/ASTNode.h: * Source/WebGPU/WGSL/AST/ASTStringDumper.cpp: (WGSL::AST::StringDumper::visit): * Source/WebGPU/WGSL/AST/ASTStringDumper.h: * Source/WebGPU/WGSL/AST/ASTStructureDecl.h: * Source/WebGPU/WGSL/AST/ASTTypeDecl.h: (WGSL::AST::ParameterizedType::ParameterizedType): (isType): * Source/WebGPU/WGSL/AST/ASTVariableDecl.h: * Source/WebGPU/WGSL/AST/ASTVisitor.cpp: (WGSL::AST::Visitor::visit): * Source/WebGPU/WGSL/AST/ASTVisitor.h: * Source/WebGPU/WGSL/EntryPointRewriter.cpp: Added. (WGSL::EntryPointRewriter::EntryPointRewriter): (WGSL::EntryPointRewriter::getResolvedType): (WGSL::EntryPointRewriter::rewrite): (WGSL::EntryPointRewriter::collectParameters): (WGSL::EntryPointRewriter::constructInputStruct): (WGSL::EntryPointRewriter::materialize): (WGSL::EntryPointRewriter::visit): (WGSL::EntryPointRewriter::appendBuiltins): (WGSL::RewriteEntryPoints::RewriteEntryPoints): (WGSL::RewriteEntryPoints::visit): (WGSL::rewriteEntryPoints): * Source/WebGPU/WGSL/EntryPointRewriter.h: Added. * Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp: (WGSL::Metal::FunctionDefinitionWriter::visit): * Source/WebGPU/WGSL/Parser.cpp: (WGSL::Parser<Lexer>::parseAttribute): (WGSL::Parser<Lexer>::parseTypeDecl): (WGSL::Parser<Lexer>::parseTypeDeclAfterIdentifier): (WGSL::Parser<Lexer>::parseArrayType): (WGSL::Parser<Lexer>::parseVariableDeclWithAttributes): (WGSL::Parser<Lexer>::parseFunctionDecl): (WGSL::Parser<Lexer>::parseStatement): (WGSL::Parser<Lexer>::parseUnaryExpression): (WGSL::Parser<Lexer>::parsePrimaryExpression): (WGSL::Parser<Lexer>::parseCoreLHSExpression): * Source/WebGPU/WGSL/ParserPrivate.h: * Source/WebGPU/WGSL/PhaseTimer.h: Added. (WGSL::dumpASTIfNeeded): (WGSL::dumpASTAfterParsingIfNeeded): (WGSL::dumpASTBetweenEachPassIfNeeded): (WGSL::dumpASTAtEndIfNeeded): (WGSL::logPhaseTimes): (WGSL::PhaseTimer::PhaseTimer): (WGSL::PhaseTimer::~PhaseTimer): * Source/WebGPU/WGSL/ResolveTypeReferences.cpp: Added. (WGSL::ResolveTypeReferences::ResolveTypeReferences): (WGSL::ResolveTypeReferences::visit): (WGSL::resolveTypeReferences): * Source/WebGPU/WGSL/ResolveTypeReferences.h: Added. * Source/WebGPU/WGSL/WGSL.cpp: (WGSL::prepare): * Source/WebGPU/WebGPU.xcodeproj/project.pbxproj: * Tools/TestWebKitAPI/Tests/WGSL/ParserTests.cpp: (TestWGSLAPI::TEST): * Websites/webkit.org/demos/webgpu/scripts/hello-triangle.js: (async helloTriangle): Canonical link: https://commits.webkit.org/259149@main
1 parent e6ff9ff commit 021a45c

25 files changed

+819
-130
lines changed

Source/WebGPU/WGSL/AST/ASTAttribute.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,16 @@
2727

2828
#include "ASTNode.h"
2929

30-
#include <wtf/UniqueRef.h>
31-
#include <wtf/UniqueRefVector.h>
30+
#include <wtf/RefCounted.h>
3231
#include <wtf/Vector.h>
3332

3433
namespace WGSL::AST {
3534

36-
class Attribute : public Node {
35+
class Attribute : public Node, public RefCounted<Attribute> {
3736
WTF_MAKE_FAST_ALLOCATED;
3837

3938
public:
40-
using List = UniqueRefVector<Attribute, 2>;
39+
using List = Vector<Ref<Attribute>, < 67E6 span class="pl-c1">2>;
4140

4241
Attribute(SourceSpan span)
4342
: Node(span)

Source/WebGPU/WGSL/AST/ASTBuiltinAttribute.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#pragma once
2727

2828
#include "ASTAttribute.h"
29+
#include <wtf/text/WTFString.h>
2930

3031
namespace WGSL::AST {
3132

Source/WebGPU/WGSL/AST/ASTCallableExpression.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,23 @@ class CallableExpression final : public Expression {
4141
WTF_MAKE_FAST_ALLOCATED;
4242

4343
public:
44-
CallableExpression(SourceSpan span, UniqueRef<TypeDecl>&& target, Expression::List&& arguments)
44+
CallableExpression(SourceSpan span, Ref<TypeDecl>&& target, Expression::List&& arguments)
4545
: Expression(span)
4646
, m_target(WTFMove(target))
4747
, m_arguments(WTFMove(arguments))
4848
{
4949
}
5050

5151
Kind kind() const override;
52-
TypeDecl& target() { return m_target; }
52+
TypeDecl& target() { return m_target.get(); }
5353
Expression::List& arguments() { return m_arguments; }
5454

5555
private:
5656
// If m_target is a NamedType, it could either be a:
5757
// * Type that does not accept parameters (bool, i32, u32, ...)
5858
// * Identifier that refers to a type alias.
5959
// * Identifier that refers to a function.
60-
UniqueRef<TypeDecl> m_target;
60+
Ref<TypeDecl> m_target;
6161
Expression::List m_arguments;
6262
};
6363

Source/WebGPU/WGSL/AST/ASTForward.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class TypeDecl;
6666
class ArrayType;
6767
class NamedType;
6868
class ParameterizedType;
69+
class StructType;
70+
class TypeReference;
6971

7072
class Parameter;
7173
class StructMember;

Source/WebGPU/WGSL/AST/ASTFunctionDecl.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Parameter final : public Node {
4141
public:
4242
using List = UniqueRefVector<Parameter>;
4343

44-
Parameter(SourceSpan span, const String& name, UniqueRef<TypeDecl>&& type, Attribute::List&& attributes)
44+
Parameter(SourceSpan span, const String& name, Ref<TypeDecl>&& type, Attribute::List&& attributes)
4545
: Node(span)
4646
, m_name(name)
4747
, m_type(WTFMove(type))
@@ -51,12 +51,12 @@ class Parameter final : public Node {
5151

5252
Kind kind() const override;
5353
const String& name() const { return m_name; }
54-
TypeDecl& type() { return m_type; }
54+
TypeDecl& type() { return m_type.get(); }
5555
Attribute::List& attributes() { return m_attributes; }
5656

5757
private:
5858
String m_name;
59-
UniqueRef<TypeDecl> m_type;
59+
Ref<TypeDecl> m_type;
6060
Attribute::List m_attributes;
6161
};
6262

@@ -66,7 +66,7 @@ class FunctionDecl final : public Decl {
6666
public:
6767
using List = UniqueRefVector<FunctionDecl>;
6868

69-
FunctionDecl(SourceSpan sourceSpan, const String& name, Parameter::List&& parameters, std::unique_ptr<TypeDecl>&& returnType, CompoundStatement&& body, Attribute::List&& attributes, Attribute::List&& returnAttributes)
69+
FunctionDecl(SourceSpan sourceSpan, const String& name, Parameter::List&& parameters, RefPtr<TypeDecl>&& returnType, CompoundStatement&& body, Attribute::List&& attributes, Attribute::List&& returnAttributes)
7070
: Decl(sourceSpan)
7171
, m_name(name)
7272
, m_parameters(WTFMove(parameters))
@@ -90,7 +90,7 @@ class FunctionDecl final : public Decl {
9090
Parameter::List m_parameters;
9191
Attribute::List m_attributes;
9292
Attribute::List m_returnAttributes;
93-
std::unique_ptr<TypeDecl> m_returnType;
93+
RefPtr<TypeDecl> m_returnType;
9494
CompoundStatement m_body;
9595
};
9696

Source/WebGPU/WGSL/AST/ASTNode.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class Node {
7777
ArrayType,
7878
NamedType,
7979
ParameterizedType,
80+
StructType,
81+
TypeReference,
8082

8183
Parameter,
8284
StructMember,

Source/WebGPU/WGSL/AST/ASTStringDumper.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,16 @@ void StringDumper::visit(ParameterizedType& type)
359359
m_out.print(">");
360360
}
361361

362+
void StringDumper::visit(StructType& type)
363+
{
364+
m_out.print(type.structDecl().name());
365+
}
366+
367+
void StringDumper::visit(TypeReference& type)
368+
{
369+
visit(type.type());
370+
}
371+
362372
void StringDumper::visit(Parameter& parameter)
363373
{
364374
m_out.print(m_indent);

Source/WebGPU/WGSL/AST/ASTStringDumper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class StringDumper final : public Visitor {
7979
void visit(ArrayType&) override;
8080
void visit(NamedType&) override;
8181
void visit(ParameterizedType&) override;
82+
void visit(StructType&) override;
83+
void visit(TypeReference&) override;
8284

8385
void visit(Parameter&) override;
8486
void visit(StructMember&) override;

Source/WebGPU/WGSL/AST/ASTStructureDecl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class StructMember final : public Node {
3838
public:
3939
using List = UniqueRefVector<StructMember>;
4040

41-
StructMember(SourceSpan span, const String& name, UniqueRef<TypeDecl>&& type, Attribute::List&& attributes)
41+
StructMember(SourceSpan span, const String& name, Ref<TypeDecl>&& type, Attribute::List&& attributes)
4242
: Node(span)
4343
, m_name(name)
4444
, m_attributes(WTFMove(attributes))
@@ -54,7 +54,7 @@ class StructMember final : public Node {
5454
private:
5555
String m_name;
5656
Attribute::List m_attributes;
57-
UniqueRef<TypeDecl> m_type;
57+
Ref<TypeDecl> m_type;
5858
};
5959

6060
class StructDecl final : public Decl {

Source/WebGPU/WGSL/AST/ASTTypeDecl.h

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,18 @@
2727

2828
#include "ASTExpression.h"
2929

30+
#include <wtf/RefCounted.h>
3031
#include <wtf/TypeCasts.h>
3132
#include <wtf/text/WTFString.h>
3233

33-
namespace WGSL::AST {
34+
namespace WGSL {
35+
class ResolveTypeReferences;
3436

35-
class TypeDecl : public Node {
37+
namespace AST {
38+
39+
class StructDecl;
40+
41+
class TypeDecl : public Node, public RefCounted<TypeDecl> {
3642
WTF_MAKE_FAST_ALLOCATED;
3743

3844
public:
@@ -46,7 +52,7 @@ class ArrayType final : public TypeDecl {
4652
WTF_MAKE_FAST_ALLOCATED;
4753

4854
public:
49-
ArrayType(SourceSpan span, std::unique_ptr<TypeDecl>&& elementType, std::unique_ptr<Expression>&& elementCount)
55+
ArrayType(SourceSpan span, RefPtr<TypeDecl>&& elementType, std::unique_ptr<Expression>&& elementCount)
5056
: TypeDecl(span)
5157
, m_elementType(WTFMove(elementType))
5258
, m_elementCount(WTFMove(elementCount))
@@ -58,13 +64,14 @@ class ArrayType final : public TypeDecl {
5864
Expression* maybeElementCount() const { return m_elementCount.get(); }
5965

6066
private:
61-
std::unique_ptr<TypeDecl> m_elementType;
67+
RefPtr<TypeDecl> m_elementType;
6268
std::unique_ptr<Expression> m_elementCount;
6369
};
6470

6571
class NamedType final : public TypeDecl {
6672
WTF_MAKE_FAST_ALLOCATED;
6773

74+
friend class ::WGSL::ResolveTypeReferences;
6875
public:
6976
NamedType(SourceSpan span, const String& name)
7077
: TypeDecl(span)
@@ -74,9 +81,16 @@ class NamedType final : public TypeDecl {
7481

7582
Kind kind() const override;
7683
const String& name() const { return m_name; }
84+
TypeDecl* maybeResolvedReference() const { return m_resolvedReference.get(); }
7785

7886
private:
87+
void resolveTypeReference(Ref<TypeDecl>&& typeDecl)
88+
{
89+
m_resolvedReference = WTFMove(typeDecl);
90+
}
91+
7992
String m_name;
93+
RefPtr<TypeDecl> m_resolvedReference;
8094
};
8195

8296
class ParameterizedType : public TypeDecl {
@@ -98,7 +112,7 @@ class ParameterizedType : public TypeDecl {
98112
Mat4x4
99113
};
100114

101-
ParameterizedType(SourceSpan span, Base base, UniqueRef<TypeDecl>&& elementType)
115+
ParameterizedType(SourceSpan span, Base base, Ref<TypeDecl>&& elementType)
102116
: TypeDecl(span)
103117
, m_base(base)
104118
, m_elementType(WTFMove(elementType))
@@ -140,10 +154,45 @@ class ParameterizedType : public TypeDecl {
140154

141155
private:
142156
Base m_base;
143-
UniqueRef<TypeDecl> m_elementType;
157+
Ref<TypeDecl> m_elementType;
158+
};
159+
160+
class StructType final : public TypeDecl {
161+
WTF_MAKE_FAST_ALLOCATED;
162+
163+
public:
164+
StructType(SourceSpan span, StructDecl& structDecl)
165+
: TypeDecl(span)
166+
, m_structDecl(structDecl)
167+
{
168+
}
169+
170+
Kind kind() const override;
171+
StructDecl& structDecl() const { return m_structDecl; }
172+
173+
private:
174+
StructDecl& m_structDecl;
175+
};
176+
177+
class TypeReference final : public TypeDecl {
178+
WTF_MAKE_FAST_ALLOCATED;
179+
180+
public:
181+
TypeReference(SourceSpan span, Ref<TypeDecl>&& type)
182+
: TypeDecl(span)
183+
, m_type(WTFMove(type))
184+
{
185+
}
186+
187+
Kind kind() const override;
188+
TypeDecl& type() const { return m_type.get(); }
189+
190+
private:
191+
Ref<TypeDecl> m_type;
144192
};
145193

146-
} // namespace WGSL::AST
194+
} // namespace AST
195+
} // namespace WGSL
147196

148197
#define SPECIALIZE_TYPE_TRAITS_WGSL_TYPE(ToValueTypeName, predicate) \
149198
SPECIALIZE_TYPE_TRAITS_BEGIN(WGSL::AST::ToValueTypeName) \
@@ -157,6 +206,8 @@ static bool isType(const WGSL::AST::Node& node)
157206
case WGSL::AST::Node::Kind::ArrayType:
158207
case WGSL::AST::Node::Kind::NamedType:
159208
case WGSL::AST::Node::Kind::ParameterizedType:
209+
case WGSL::AST::Node::Kind::StructType:
210+
case WGSL::AST::Node::Kind::TypeReference:
160211
return true;
161212
default:
162213
return false;
@@ -167,3 +218,5 @@ SPECIALIZE_TYPE_TRAITS_END()
167218
SPECIALIZE_TYPE_TRAITS_WGSL_AST(ArrayType)
168219
SPECIALIZE_TYPE_TRAITS_WGSL_AST(NamedType)
169220
SPECIALIZE_TYPE_TRAITS_WGSL_AST(ParameterizedType)
221+
SPECIALIZE_TYPE_TRAITS_WGSL_AST(StructType)
222+
SPECIALIZE_TYPE_TRAITS_WGSL_AST(TypeReference)

0 commit comments

Comments
 (0)
0