8000 webgpu: factor out matmul into a reusable snippet (#1709) · tensorflow/tfjs-core@00c5455 · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 00c5455

Browse files
kainino0xannxingyuan
authored andcommitted
webgpu: factor out matmul into a reusable snippet (#1709)
FEATURE
1 parent 4987ad9 commit 00c5455

File tree

2 files changed

+174
-123
lines changed

2 files changed

+174
-123
lines changed

src/backends/webgpu/src/kernels/matmul_packed_webgpu.ts

Lines changed: 105 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,94 @@
1515
* =============================================================================
1616
*/
1717

18+
import {matMulHeader} from './matmul_webgpu';
1819
import {WebGPUProgram} from './webgpu_program';
1920

21+
export function makeMatMulPackedSource(workPerThread: number): string {
22+
return `
23+
${matMulHeader}
24+
25+
const uint TileSide = TileSize.x; // TileSize.x == TileSize.y
26+
const uint WorkPerThread = ${workPerThread};
27+
shared float mm_Asub[TileSide * WorkPerThread][TileSide * WorkPerThread];
28+
shared float mm_Bsub[TileSide * WorkPerThread][TileSide * WorkPerThread];
29+
30+
void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter) {
31+
uint row = gl_LocalInvocationID.y; // 0..local_size_x
32+
uint col = gl_LocalInvocationID.x; // 0..local_size_y
33+
uint tileRow = row * WorkPerThread; // 0..TileSide, stride by local_size
34+
uint tileCol = col * WorkPerThread; // 0..TileSide
35+
36+
// 0..AOuter, stride by tileSize
37+
uint globalRow = TileSide * gl_WorkGroupID.y + tileRow;
38+
uint globalCol = TileSide * gl_WorkGroupID.x + tileCol;
39+
40+
uint numTiles = (dimInner - 1) / TileSize.x + 1;
41+
42+
float acc[WorkPerThread][WorkPerThread];
43+
float ACached;
44+
float BCached[WorkPerThread];
45+
46+
// Without this initialization strange values show up in acc.
47+
for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
48+
for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
49+
acc[innerRow][innerCol] = 0.0;
50+
}
51+
}
52+
53+
// Loop over shared dimension.
54+
for (uint t = 0; t < numTiles; t++) {
55+
// Load one tile of A and B into local memory.
56+
for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
57+
for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
58+
uint inputRow = tileRow + innerRow;
59+
uint inputCol = tileCol + innerCol;
60+
61+
mm_Asub[inputRow][inputCol] = mm_readA(
62+
globalRow + innerRow,
63+
t * TileSize.x + tileCol + innerCol);
64+
mm_Bsub[inputRow][inputCol] = mm_readB(
65+
t * TileSize.x + tileRow + innerRow,
66+
globalCol + innerCol);
67+
}
68+
}
69+
70+
barrier();
71+
72+
// Compute acc values for a single thread.
73+
for (uint k = 0; k < TileSize.x; k++) {
74+
for (uint inner = 0; inner < WorkPerThread; inner++) {
75+
BCached[inner] = mm_Bsub[k][tileCol + inner];
76+
}
77+
78+
for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
79+
ACached = mm_Asub[tileRow + innerRow][k];
80+
for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
81+
acc[innerRow][innerCol] += ACached * BCached[innerCol];
82+
}
83+
}
84+
}
85+
86+
barrier();
87+
}
88+
89+
for (uint innerRow = 0; innerRow < WorkPerThread; innerRow++) {
90+
for (uint innerCol = 0; innerCol < WorkPerThread; innerCol++) {
91+
uint globalFlatIndex =
92+
(globalRow + innerRow) * dimBOuter + (globalCol + innerCol);
93+
94+
if ((globalCol + innerCol) < dimBOuter &&
95+
(globalRow + innerRow) < dimAOuter) {
96+
mm_write(globalRow + innerRow,
97+
globalCol + innerCol,
98+
acc[innerRow][innerCol]);
99+
}
100+
}
101+
}
102+
}
103+
`;
104+
}
105+
20106
export class MatMulPackedProgram implements WebGPUProgram {
21107
outputShape: number[];
22108
userCode: string;
@@ -38,92 +124,30 @@ export class MatMulPackedProgram implements WebGPUProgram {
38124
// about boundary conditions when loading from Asub / Bsub when tiles fit
39125
// neatly inside of output. May slightly improve performance.
40126
this.userCode = `
41-
const uint WorkPerThread = ${workPerThread};
42-
shared float Asub[TileSize.x * WorkPerThread][TileSize.x * WorkPerThread];
43-
shared float Bsub[TileSize.x * WorkPerThread][TileSize.x * WorkPerThread];
127+
${makeMatMulPackedSource(workPerThread)}
44128
45-
void main() {
46-
uint row = gl_LocalInvocationID.y; // 0..local_size_x
47-
uint col = gl_LocalInvocationID.x; // 0..local_size_y
48-
uint tileRow = row * WorkPerThread; // 0..TileSize, stride by local_size
49-
uint tileCol = col * WorkPerThread; // 0..TileSize
50-
51-
// 0..AOuter, stride by tileSize
52-
uint globalRow = TileSize.x*gl_WorkGroupID.y + tileRow;
53-
uint globalCol = TileSize.x*gl_WorkGroupID.x + tileCol;
54-
55-
uint numTiles = (dimInner - 1) / TileSize.x + 1;
56-
57-
float acc[WorkPerThread][WorkPerThread];
58-
float ACached;
59-
float BCached[WorkPerThread];
60-
61-
// Without this initialization strange values show up in acc.
62-
for(uint innerRow=0; innerRow<WorkPerThread; innerRow++) {
63-
for(uint innerCol=0; innerCol<WorkPerThread; innerCol++) {
64-
acc[innerRow][innerCol] = 0.0;
65-
}
129+
float mm_readA(uint row, uint col) {
130+
if (row < dimAOuter && col < dimInner) {
131+
return A[row * dimInner + col];
132+
} else {
133+
return 0.0;
66134
}
135+
}
67136
68-
// Loop over shared dimension.
69-
for(uint t=0; t<numTiles; t++) {
70-
// Load one tile of A and B into local memory.
71-
for(uint innerRow=0; innerRow<WorkPerThread; innerRow++) {
72-
for(uint innerCol=0; innerCol<WorkPerThread; innerCol++) {
73-
uint inputRow = tileRow + innerRow;
74-
uint inputCol = tileCol + innerCol;
75-
76-
uint AColumnIndex = t * TileSize.x + tileCol + innerCol;
77-
uint AFlatIndex =
78-
(globalRow + innerRow) * dimInner + AColumnIndex;
79-
80-
if(AColumnIndex < dimInner && AFlatIndex < dimAOuter * dimInner) {
81-
Asub[inputRow][inputCol] = A[AFlatIndex];
82-
} else {
83-
Asub[inputRow][inputCol] = 0.0;
84-
}
85-
86-
uint BRowIndex = t * TileSize.x + tileRow + innerRow;
87-
uint BFlatIndex = BRowIndex * dimBOuter + (globalCol + innerCol);
88-
89-
if(BRowIndex < dimInner && BFlatIndex < dimInner * dimBOuter) {
90-
Bsub[inputRow][inputCol] = B[BFlatIndex];
91-
} else {
92-
Bsub[inputRow][inputCol] = 0.0;
93-
}
94-
}
95-
}
96-
97-
barrier();
98-
99-
// Compute acc values for a single thread.
100-
for(uint k=0; k<TileSize.x; k++) {
101-
for(uint inner=0; inner<WorkPerThread; inner++) {
102-
BCached[inner] = Bsub[k][tileCol + inner];
103-
}
104-
105-
for(uint innerRow=0; innerRow<WorkPerThread; innerRow++) {
106-
ACached = Asub[tileRow + innerRow][k];
107-
for(uint innerCol=0; innerCol<WorkPerThread; innerCol++) {
108-
acc[innerRow][innerCol] += ACached * BCached[innerCol];
109-
}
110-
}
111-
}
112-
113-
barrier();
137+
float mm_readB(uint row, uint col) {
138+
if (row < dimInner && col < dimBOuter) {
139+
return B[row * dimBOuter + col];
140+
} else {
141+
return 0.0;
114142
}
143+
}
115144
116-
for (uint innerRow=0; innerRow<WorkPerThread; innerRow++) {
117-
for (uint innerCol=0; innerCol<WorkPerThread; innerCol++) {
118-
uint globalFlatIndex =
119-
(globalRow + innerRow) * dimBOuter + (globalCol + innerCol);
120-
121-
if((globalCol + innerCol) < dimBOuter &&
122-
(globalRow + innerRow) < dimAOuter) {
123-
setOutput(globalFlatIndex, acc[innerRow][innerCol]);
124-
}
125-
}
126-
}
145+
void mm_write(uint row, uint col, float value) {
146+
setOutput(row * dimBOuter + col, value);
147+
}
148+
149+
void main() {
150+
mm_matMul(dimAOuter, dimInner, dimBOuter);
127151
}
128152
`;
129153
}

src/backends/webgpu/src/kernels/matmul_webgpu.ts

Lines changed: 69 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,69 +17,96 @@
1717

1818
import {WebGPUProgram} from './webgpu_program';
1919

20-
export class MatMulProgram implements WebGPUProgram {
21-
outputShape: number[];
22-
userCode: string;
23-
dispatch: [number, number, number];
24-
variableNames = ['A', 'B'];
25-
uniforms = 'uint dimAOuter, dimInner, dimBOuter, batch;';
26-
tileSize: [number, number] = [16, 16]; // Must be square.
20+
export const matMulHeader = `
21+
float mm_readA(uint row, uint col);
22+
float mm_readB(uint row, uint col);
23+
void mm_write(uint row, uint col, float value);
24+
void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter);`;
2725

28-
constructor(outputShape: [number, number, number]) {
29-
this.outputShape = outputShape;
30-
this.dispatch = [
31-
Math.ceil(outputShape[1] / this.tileSize[0]),
32-
Math.ceil(outputShape[2] / this.tileSize[1]), 1
33-
];
26+
export function makeMatMulSource(): string {
27+
return `
28+
${matMulHeader}
3429
35-
this.userCode = `
36-
shared float Asub[TileSize.x][TileSize.x];
37-
shared float Bsub[TileSize.x][TileSize.x];
30+
const uint TileSide = TileSize.x; // TileSize.x == TileSize.y
31+
shared float mm_Asub[TileSide][TileSide];
32+
shared float mm_Bsub[TileSide][TileSide];
3833
39-
void main() {
40-
uint localRow = gl_LocalInvocationID.x; // < TileSize.x
41-
uint localCol = gl_LocalInvocationID.y; // < TileSize.x
42-
uint globalRow = TileSize.x*gl_WorkGroupID.x + localRow; // < dimAOuter
43-
uint globalCol = TileSize.x*gl_WorkGroupID.y + localCol; // < dimInner
34+
void mm_matMul(uint dimAOuter, uint dimInner, uint dimBOuter) {
35+
uint localRow = gl_LocalInvocationID.x; // 0..TileSide
36+
uint localCol = gl_LocalInvocationID.y; // 0..TileSide
37+
uint globalRow = TileSize.x * gl_WorkGroupID.x + localRow; // AOuter
38+
uint globalCol = TileSize.x * gl_WorkGroupID.y + localCol; // Inner
4439
4540
float acc = 0.0;
4641
4742
uint numTiles = (dimInner - 1) / TileSize.x + 1;
4843
49-
for (uint t=0; t<numTiles; t++) {
44+
for (uint t = 0; t < numTiles; t++) {
5045
// Load one tile of A and B into local memory
51-
uint tiledACol = TileSize.x*t + localCol;
52-
uint tiledBRow = TileSize.x*t + localRow;
53-
54-
uint AFlatIndex = globalRow * dimInner + tiledACol;
55-
if(AFlatIndex < dimAOuter * dimInner) {
56-
Asub[localRow][localCol] = A[AFlatIndex];
57-
} else {
58-
Asub[localRow][localCol] = 0.0;
59-
}
60-
61-
uint BFlatIndex = tiledBRow * dimBOuter + globalCol;
62-
if(BFlatIndex < dimInner * dimBOuter) {
63-
Bsub[localRow][localCol] = B[BFlatIndex];
64-
} else {
65-
Bsub[localRow][localCol] = 0.0;
66-
}
46+
uint tiledACol = TileSize.x * t + localCol;
47+
uint tiledBRow = TileSize.x * t + localRow;
48+
mm_Asub[localRow][localCol] = mm_readA(globalRow, tiledACol);
49+
mm_Bsub[localRow][localCol] = mm_readB(tiledBRow, globalCol);
6750
6851
// Synchronise to make sure the tile is loaded
6952
barrier();
7053
71-
for (uint k=0; k<TileSize.x; k++) {
72-
acc += Asub[localRow][k] * Bsub[k][localCol];
54+
for (uint k = 0; k < TileSize.x; k++) {
55+
acc += mm_Asub[localRow][k] * mm_Bsub[k][localCol];
7356
}
7457
7558
// Synchronise before loading the next tile
7659
barrier();
7760
}
7861
7962
if (globalCol < dimBOuter && globalRow < dimAOuter) {
80-
setOutput(globalRow * dimBOuter + globalCol, acc);
63+
mm_write(globalRow, globalCol, acc);
64+
}
65+
}
66+
`;
67+
}
68+
69+
export class MatMulProgram implements WebGPUProgram {
70+
outputShape: number[];
71+
userCode: string;
72+
dispatch: [number, number, number];
73+
variableNames = ['A', 'B'];
74+
uniforms = 'uint dimAOuter, dimInner, dimBOuter, batch;';
75+
tileSize: [number, number] = [16, 16]; // Must be square.
76+
77+
constructor(outputShape: [number, number, number]) {
78+
this.outputShape = outputShape;
79+
this.dispatch = [
80+
Math.ceil(outputShape[1] / this.tileSize[0]),
81+
Math.ceil(outputShape[2] / this.tileSize[1]), 1
82+
];
83+
84+
this.userCode = `
85+
${makeMatMulSource()}
86+
87+
float mm_readA(uint row, uint col) {
88+
if (row < dimAOuter && col < dimInner) {
89+
return A[row * dimInner + col];
90+
} else {
91+
return 0.0;
92+
}
93+
}
94+
95+
float mm_readB(uint row, uint col) {
96+
if (row < dimInner && col < dimBOuter) {
97+
return B[row * dimBOuter + col];
98+
} else {
99+
return 0.0;
81100
}
82101
}
102+
103+
void mm_write(uint row, uint col, float value) {
104+
setOutput(row * dimBOuter + col, value);
105+
}
106+
107+
void main() {
108+
mm_matMul(dimAOuter, dimInner, dimBOuter);
109+
}
83110
`;
84111
}
85-
}
112+
}

0 commit comments

Comments
 (0)
0