15
15
* =============================================================================
16
16
*/
17
17
18
+ import { matMulHeader } from './matmul_webgpu' ;
18
19
import { WebGPUProgram } from './webgpu_program' ;
19
20
21
/**
 * Builds the GLSL source of a tiled matrix-multiply routine, `mm_matMul`, in
 * which every invocation accumulates a `workPerThread` x `workPerThread` patch
 * of the output.
 *
 * The returned source starts with `matMulHeader` and calls `mm_readA`,
 * `mm_readB` and `mm_write`, and reads `TileSize`; none of these are defined
 * here, so the shader that embeds this source must supply them.
 *
 * @param workPerThread Side length of the square output patch computed by a
 *     single invocation. Must be a positive integer because it is emitted as
 *     a `const uint` that sizes shader arrays.
 * @returns GLSL source text to be embedded in a compute shader.
 * @throws Error if `workPerThread` is not a positive integer.
 */
export function makeMatMulPackedSource(workPerThread: number): string {
  // Fail fast: any other value would be interpolated verbatim into the shader
  // and only surface later as an opaque shader compilation failure.
  if (!Number.isInteger(workPerThread) || workPerThread < 1) {
    throw new Error(
        `workPerThread must be a positive integer, got ${workPerThread}.`);
  }

  return `
    ${matMulHeader}

    const uint TileSide = TileSize.x;  // TileSize.x == TileSize.y
    const uint WorkPerThread = ${workPerThread};
    shared float mm_Asub[TileSide * WorkPerThread][TileSide * WorkPerThread];
    shared float mm_Bsub[TileSide * WorkPerThread][TileSide * WorkPerThread];

    void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter) {
      uint row = gl_LocalInvocationID.y;  // 0..local_size_y
      uint col = gl_LocalInvocationID.x;  // 0..local_size_x
      // First row / column of this invocation's patch inside the tile.
      uint tileRow = row * WorkPerThread;
      uint tileCol = col * WorkPerThread;

      // First row / column of this invocation's patch in the output.
      uint globalRow = TileSide * gl_WorkGroupID.y + tileRow;
      uint globalCol = TileSide * gl_WorkGroupID.x + tileCol;

      uint numTiles = (dimInner - 1) / TileSize.x + 1;

      float acc[WorkPerThread][WorkPerThread];
      float ACached;
      float BCached[WorkPerThread];

      // Without this initialization strange values show up in acc.
      for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
        for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
          acc[innerRow][innerCol] = 0.0;
        }
      }

      // Loop over shared dimension.
      for (uint t = 0; t < numTiles; t++) {
        // Load one tile of A and B into local memory.
        for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
          for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
            uint inputRow = tileRow + innerRow;
            uint inputCol = tileCol + innerCol;

            mm_Asub[inputRow][inputCol] = mm_readA(
                globalRow + innerRow,
                t * TileSize.x + tileCol + innerCol);
            mm_Bsub[inputRow][inputCol] = mm_readB(
                t * TileSize.x + tileRow + innerRow,
                globalCol + innerCol);
          }
        }

        barrier();

        // Compute acc values for a single thread.
        for (uint k = 0; k < TileSize.x; k++) {
          for (uint inner = 0; inner < WorkPerThread; inner++) {
            BCached[inner] = mm_Bsub[k][tileCol + inner];
          }

          for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
            ACached = mm_Asub[tileRow + innerRow][k];
            for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
              acc[innerRow][innerCol] += ACached * BCached[innerCol];
            }
          }
        }

        barrier();
      }

      // Write back only the patch elements that lie inside the output.
      for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
        for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
          if ((globalCol + innerCol) < dimBOuter &&
              (globalRow + innerRow) < dimAOuter) {
            mm_write(globalRow + innerRow,
                     globalCol + innerCol,
                     acc[innerRow][innerCol]);
          }
        }
      }
    }
  `;
}
105
+
20
106
export class MatMulPackedProgram implements WebGPUProgram {
21
107
outputShape : number [ ] ;
22
108
userCode : string ;
@@ -38,92 +124,30 @@ export class MatMulPackedProgram implements WebGPUProgram {
38
124
// about boundary conditions when loading from Asub / Bsub when tiles fit
39
125
// neatly inside of output. May slightly improve performance.
40
126
this . userCode = `
41
- const uint WorkPerThread = ${ workPerThread } ;
42
- shared float Asub[TileSize.x * WorkPerThread][TileSize.x * WorkPerThread];
43
- shared float Bsub[TileSize.x * WorkPerThread][TileSize.x * WorkPerThread];
127
+ ${ makeMatMulPackedSource ( workPerThread ) }
44
128
45
- void main() {
46
- uint row = gl_LocalInvocationID.y; // 0..local_size_x
47
- uint col = gl_LocalInvocationID.x; // 0..local_size_y
48
- uint tileRow = row * WorkPerThread; // 0..TileSize, stride by local_size
49
- uint tileCol = col * WorkPerThread; // 0..TileSize
50
-
51
- // 0..AOuter, stride by tileSize
52
- uint globalRow = TileSize.x*gl_WorkGroupID.y + tileRow;
53
- uint globalCol = TileSize.x*gl_WorkGroupID.x + tileCol;
54
-
55
- uint numTiles = (dimInner - 1) / TileSize.x + 1;
56
-
57
- float acc[WorkPerThread][WorkPerThread];
58
- float ACached;
59
- float BCached[WorkPerThread];
60
-
61
- // Without this initialization strange values show up in acc.
62
- for(uint innerRow=0; innerRow<WorkPerThread; innerRow++) {
63
- for(uint innerCol=0; innerCol<WorkPerThread; innerCol++) {
64
- acc[innerRow][innerCol] = 0.0;
65
- }
129
+ float mm_readA(uint row, uint col) {
130
+ if (row < dimAOuter && col < dimInner) {
131
+ return A[row * dimInner + col];
132
+ } else {
133
+ return 0.0;
66
134
}
135
+ }
67
136
68
- // Loop over shared dimension.
69
- for(uint t=0; t<numTiles; t++) {
70
- // Load one tile of A and B into local memory.
71
- for(uint innerRow=0; innerRow<WorkPerThread; innerRow++) {
72
- for(uint innerCol=0; innerCol<WorkPerThread; innerCol++) {
73
- uint inputRow = tileRow + innerRow;
74
- uint inputCol = tileCol + innerCol;
75
-
76
- uint AColumnIndex = t * TileSize.x + tileCol + innerCol;
77
- uint AFlatIndex =
78
- (globalRow + innerRow) * dimInner + AColumnIndex;
79
-
80
- if(AColumnIndex < dimInner && AFlatIndex < dimAOuter * dimInner) {
81
- Asub[inputRow][inputCol] = A[AFlatIndex];
82
- } else {
83
- Asub[inputRow][inputCol] = 0.0;
84
- }
85
-
86
- uint BRowIndex = t * TileSize.x + tileRow + innerRow;
87
- uint BFlatIndex = BRowIndex * dimBOuter + (globalCol + innerCol);
88
-
89
- if(BRowIndex < dimInner && BFlatIndex < dimInner * dimBOuter) {
90
- Bsub[inputRow][inputCol] = B[BFlatIndex];
91
- } else {
92
- Bsub[inputRow][inputCol] = 0.0;
93
- }
94
- }
95
- }
96
-
97
- barrier();
98
-
99
- // Compute acc values for a single thread.
100
- for(uint k=0; k<TileSize.x; k++) {
101
- for(uint inner=0; inner<WorkPerThread; inner++) {
102
- BCached[inner] = Bsub[k][tileCol + inner];
103
- }
104
-
105
- for(uint innerRow=0; innerRow<WorkPerThread; innerRow++) {
106
- ACached = Asub[tileRow + innerRow][k];
107
- for(uint innerCol=0; innerCol<WorkPerThread; innerCol++) {
108
- acc[innerRow][innerCol] += ACached * BCached[innerCol];
109
- }
110
- }
111
- }
112
-
113
- barrier();
137
+ float mm_readB(uint row, uint col) {
138
+ if (row < dimInner && col < dimBOuter) {
139
+ return B[row * dimBOuter + col];
140
+ } else {
141
+ return 0.0;
114
142
}
143
+ }
115
144
116
- for (uint innerRow=0; innerRow<WorkPerThread; innerRow++) {
117
- for (uint innerCol=0; innerCol<WorkPerThread; innerCol++) {
118
- uint globalFlatIndex =
119
- (globalRow + innerRow) * dimBOuter + (globalCol + innerCol);
120
-
121
- if((globalCol + innerCol) < dimBOuter &&
122
- (globalRow + innerRow) < dimAOuter) {
123
- setOutput(globalFlatIndex, acc[innerRow][innerCol]);
124
- }
125
- }
126
- }
145
+ void mm_write(uint row, uint col, float value) {
146
+ setOutput(row * dimBOuter + col, value);
147
+ }
148
+
149
+ void main() {
150
+ mm_matMul(dimAOuter, dimInner, dimBOuter);
127
151
}
128
152
` ;
129
153
}
0 commit comments