webgpu: uniforms, makeBindGroup/Layout, non-square tiles (#1689) · tensorflow/tfjs-core@7f26888 · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 7f26888

Browse files
kainino0x and annxingyuan
authored and committed
webgpu: uniforms, makeBindGroup/Layout, non-square tiles (#1689)
INTERNAL
1 parent 95a2139 commit 7f26888

File tree

4 files changed

+170
-104
lines changed

4 files changed

+170
-104
lines changed

src/backends/webgpu/src/backend_webgpu.ts

Lines changed: 66 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import './flags_webgpu';
2121

22-
import {DataMover, DataType, ENV, KernelBackend, Rank, ShapeMap, Tensor, tensor1d, Tensor3D, util} from '@tensorflow/tfjs-core';
22+
import {DataMover, DataType, ENV, KernelBackend, Rank, ShapeMap, Tensor, Tensor3D, util} from '@tensorflow/tfjs-core';
2323
import * as shaderc from '@webgpu/shaderc';
2424

2525
import * as binary_op from './kernels/binary_op_webgpu';
@@ -32,8 +32,7 @@ import * as webgpu_program from './kernels/webgpu_program';
3232
import {WebGPUBinary} from './kernels/webgpu_program';
3333

3434
type TensorInfo = {
35-
shape: number[],
36-
dtype: DataType,
35+
byteSize: number,
3736
values: Float32Array|Int32Array|Uint8Array,
3837
id: number,
3938
buffer: GPUBuffer
@@ -75,28 +74,31 @@ export class WebGPUBackend extends KernelBackend {
7574
private tensorMap = new WeakMap<DataId, TensorInfo>();
7675

7776
disposeData(dataId: DataId): void {
78-
// Tensor disposal logic.
77+
if (!this.tensorMap.has(dataId)) {
78+
throw new Error(`Tensor ${dataId} was not registered!`);
79+
}
80+
81+
const info = this.tensorMap.get(dataId);
82+
this.destroyBuffer(info.byteSize, info.buffer);
7983
}
8084

81-
private createBuffer(size: number) {
82-
return this.device.createBuffer({
83-
size,
84-
usage: GPUBufferUsage.TRANSFER_SRC | GPUBufferUsage.TRANSFER_DST |
85-
GPUBufferUsage.STORAGE,
86-
});
85+
private createBuffer(
86+
size: number,
87+
usage: GPUBufferUsage = GPUBufferUsage.STORAGE |
88+
GPUBufferUsage.TRANSFER_SRC | GPUBufferUsage.TRANSFER_DST) {
89+
return this.device.createBuffer({size, usage});
8790
}
8891

89-
private setBufferData(
90-
buffer: GPUBuffer, data: Float32Array|Int32Array|Uint8Array) {
91-
buffer.setSubData(0, data);
92+
private destroyBuffer(byteSize: number, buffer: GPUBuffer) {
93+
// TODO: recycle deleted buffers
94+
buffer.destroy();
9295
}
9396

9497
register(dataId: object, shape: number[], dtype: DataType): void {
9598
if (!this.tensorMap.has(dataId)) {
96-
const buffer = this.createBuffer(
97-
util.sizeFromShape(shape) * util.bytesPerElement(dtype));
98-
99-
this.tensorMap.set(dataId, {shape, dtype, values: null, id: -1, buffer});
99+
const byteSize = util.sizeFromShape(shape) * util.bytesPerElement(dtype);
100+
const buffer = this.createBuffer(byteSize);
101+
this.tensorMap.set(dataId, {byteSize, values: null, id: -1, buffer});
100102
}
101103
}
102104

@@ -107,7 +109,7 @@ export class WebGPUBackend extends KernelBackend {
107109

108110
const info = this.tensorMap.get(dataId);
109111
info.values = values;
110-
this.setBufferData(info.buffer, values);
112+
info.buffer.setSubData(0, values);
111113
this.tensorMap.set(dataId, info);
112114
}
113115

@@ -118,15 +120,11 @@ export class WebGPUBackend extends KernelBackend {
118120
}
119121

120122
private async getBufferData(info: TensorInfo): Promise<ArrayBuffer> {
121-
const size =
122-
util.sizeFromShape(info.shape) * util.bytesPerElement(info.dtype);
123-
const staging = this.device.createBuffer({
124-
size,
125-
usage: GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.MAP_READ,
126-
});
123+
const staging = this.createBuffer(
124+
info.byteSize, GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.MAP_READ);
127125
{
128126
const encoder = this.device.createCommandEncoder({});
129-
encoder.copyBufferToBuffer(info.buffer, 0, staging, 0, size);
127+
encoder.copyBufferToBuffer(info.buffer, 0, staging, 0, info.byteSize);
130128
this.commandQueue.push(encoder);
131129
this.submitQueue();
132130
}
@@ -158,36 +156,40 @@ export class WebGPUBackend extends KernelBackend {
158156
return Tensor.make(shape, {}, dtype, this) as T;
159157
}
160158

159+
private tensorToBinding(tensor?: Tensor): webgpu_program.BindingInfo {
160+
if (!tensor) {
161+
return null;
162+
}
163+
164+
const tensorData = this.tensorMap.get(tensor.dataId);
165+
166+
return {
167+
resource: {
168+
offset: 0,
169+
size: tensor.size * util.bytesPerElement(tensor.dtype),
170+
buffer: tensorData.buffer
171+
}
172+
};
173+
}
174+
161175
private compileAndRun<
162176
K extends {dtype: DataType, size: number, dataId: {}, shape: number[]}>(
163-
program: webgpu_program.WebGPUProgram, inputs: Tensor[],
164-
output?: Tensor): K {
177+
program: webgpu_program.WebGPUProgram, inputs: Tensor[], output?: Tensor,
178+
uniforms?: webgpu_program.BindingInfo): K {
165179
if (output == null) {
166180
output = this.makeOutputArray(program.outputShape, inputs[0].dtype);
167181
}
168182
const key = webgpu_program.makeShaderKey(program);
169183
const {bindGroupLayout, pipeline} = this.getAndSavePipeline(key, () => {
170184
return webgpu_program.compileProgram(
171185
this.compiler, this.shaderc.shader_kind.compute, this.compileOpts,
172-
this.device, program, inputs, output);
186+
this.device, program, inputs, output, uniforms);
173187
});
174188

175189
// Creating bind groups on the fly should never be a bottleneck.
176-
const bg = this.device.createBindGroup({
177-
layout: bindGroupLayout,
178-
bindings: inputs.concat(output).map((tensor, i: number) => {
179-
const tensorData = this.tensorMap.get(tensor.dataId);
180-
181-
return {
182-
binding: i,
183-
resource: {
184-
offset: 0,
185-
size: tensor.size * util.bytesPerElement(tensor.dtype),
186-
buffer: tensorData.buffer
187-
}
188-
};
189-
})
190-
});
190+
const bg = webgpu_program.makeBindGroup(
191+
this.device, bindGroupLayout, inputs.map(t => this.tensorToBinding(t)),
192+
this.tensorToBinding(output), uniforms);
191193

192194
const encoder = this.device.createCommandEncoder({});
193195
const pass = encoder.beginComputePass();
@@ -204,6 +206,17 @@ export class WebGPUBackend extends KernelBackend {
204206
return output as {} as K;
205207
}
206208

209+
private makeUniforms(data: Uint32Array): webgpu_program.BindingInfo {
210+
const dimensionsBuffer = this.createBuffer(
211+
data.byteLength,
212+
GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.UNIFORM);
213+
dimensionsBuffer.setSubData(0, data);
214+
215+
return {
216+
resource: {offset: 0, size: data.byteLength, buffer: dimensionsBuffer}
217+
};
218+
}
219+
207220
pad<T extends Tensor>(
208221
x: T, paddings: Array<[number, number]>, constantValue: number): T {
209222
const program = new PadProgram(x.shape, paddings, constantValue);
@@ -244,12 +257,17 @@ export class WebGPUBackend extends KernelBackend {
244257
const output =
245258
Tensor.make([batch, outerShapeA, outerShapeB], {}, a.dtype, this) as
246259
Tensor3D;
247-
248260
const program = new MatMulProgram(output.shape);
249-
const dimensions =
250-
tensor1d([outerShapeA, sharedDim, outerShapeB, batch], 'int32');
251-
// TODO: dispose mnkb
252261

253-
return this.compileAndRun(program, [a, b, dimensions], output) as Tensor3D;
262+
const dimensionsData =
263+
new Uint32Array([outerShapeA, sharedDim, outerShapeB, batch]);
264+
const dimensions = this.makeUniforms(dimensionsData);
265+
266+
const result =
267+
this.compileAndRun(program, [a, b], output, dimensions) as Tensor3D;
268+
269+
this.destroyBuffer(dimensionsData.byteLength, dimensions.resource.buffer);
270+
271+
return result;
254272
}
255273
}

src/backends/webgpu/src/kernels/matmul_webgpu.ts

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,53 +21,51 @@ export class MatMulProgram implements WebGPUProgram {
2121
outputShape: number[];
2222
userCode: string;
2323
dispatch: [number, number, number];
24-
variableNames = ['A', 'B', 'Dimensions'];
25-
tileSize = 8;
24+
variableNames = ['A', 'B'];
25+
uniforms = 'uint dimAOuter, dimInner, dimBOuter, batch;';
26+
tileSize: [number, number] = [16, 16]; // Must be square.
2627

2728
constructor(outputShape: [number, number, number]) {
2829
this.outputShape = outputShape;
2930
this.dispatch = [
30-
Math.ceil(outputShape[1] / this.tileSize),
31-
Math.ceil(outputShape[2] / this.tileSize), 1
31+
Math.ceil(outputShape[1] / this.tileSize[0]),
32+
Math.ceil(outputShape[2] / this.tileSize[1]), 1
3233
];
3334

3435
this.userCode = `
35-
shared float Asub[TileSize][TileSize];
36-
shared float Bsub[TileSize][TileSize];
36+
shared float Asub[TileSize.x][TileSize.x];
37+
shared float Bsub[TileSize.x][TileSize.x];
3738
3839
void main() {
39-
// M is A outer, N is shared, K is B outer
40-
uint M = Dimensions[0], N = Dimensions[1],
41-
K = Dimensions[2], batch = Dimensions[3];
42-
uint row = gl_LocalInvocationID.x; // Local row ID (max: TileSize)
43-
uint col = gl_LocalInvocationID.y; // Local col ID (max: TileSize)
44-
uint globalRow = TileSize*gl_WorkGroupID.x + row; // Row ID of C (0..M)
45-
uint globalCol = TileSize*gl_WorkGroupID.y + col; // Col ID of C (0..N)
40+
uint localRow = gl_LocalInvocationID.x; // < TileSize.x
41+
uint localCol = gl_LocalInvocationID.y; // < TileSize.x
42+
uint globalRow = TileSize.x*gl_WorkGroupID.x + localRow; // < dimAOuter
43+
uint globalCol = TileSize.x*gl_WorkGroupID.y + localCol; // < dimInner
4644
4745
float acc = 0.0;
4846
49-
uint numTiles = (N - 1)/TileSize + 1;
47+
uint numTiles = (dimInner - 1) / TileSize.x + 1;
5048
5149
for (uint t=0; t<numTiles; t++) {
5250
// Load one tile of A and B into local memory
53-
uint tiledRow = TileSize*t + row;
54-
uint tiledCol = TileSize*t + col;
55-
Asub[row][col] = A[globalRow*N + tiledCol];
56-
Bsub[row][col] = B[tiledRow*K + globalCol];
51+
uint tiledACol = TileSize.x*t + localCol;
52+
uint tiledBRow = TileSize.x*t + localRow;
53+
Asub[localRow][localCol] = A[globalRow * dimInner + tiledACol];
54+
Bsub[localRow][localCol] = B[tiledBRow * dimBOuter + globalCol];
5755
5856
// Synchronise to make sure the tile is loaded
5957
barrier();
6058
61-
for (uint k=0; k<TileSize; k++) {
62-
acc += Asub[row][k] * Bsub[k][col];
59+
for (uint k=0; k<TileSize.x; k++) {
60+
acc += Asub[localRow][k] * Bsub[k][localCol];
6361
}
6462
6563
// Synchronise before loading the next tile
6664
barrier();
6765
}
6866
69-
if(globalCol < K && globalRow < M) {
70-
setOutput(globalRow*K + globalCol, acc);
67+
if (globalCol < dimBOuter && globalRow < dimAOuter) {
68+
setOutput(globalRow * dimBOuter + globalCol, acc);
7169
}
7270
}
7371
`;

src/backends/webgpu/src/kernels/webgpu_program.ts

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ export interface WebGPUProgram {
2626
// Dispatch determines the layout of thread groups.
2727
dispatch: [number, number, number];
2828
variableNames: string[];
29-
tileSize?: number;
29+
uniforms?: string;
30+
tileSize?: [number, number?, number?];
3031
}
3132

3233
export interface WebGPUBinary {
@@ -38,35 +39,62 @@ export interface TensorData {
3839
dtype: DataType;
3940
}
4041

42+
export interface BindingInfo {
43+
resource: {offset: number, size: number, buffer: GPUBuffer};
44+
}
45+
46+
export const makeBindGroup =
47+
(device: GPUDevice, bindGroupLayout: GPUBindGroupLayout,
48+
inputs: BindingInfo[], output: BindingInfo, uniforms?: BindingInfo) => {
49+
const bindings = [output, ...inputs];
50+
if (uniforms) {
51+
bindings.push(uniforms);
52+
}
53+
return device.createBindGroup({
54+
layout: bindGroupLayout,
55+
bindings: bindings.map((b, i) => ({binding: i, resource: b.resource})),
56+
});
57+
};
58+
59+
const makeBindGroupLayout =
60+
(device: GPUDevice, inputs: Tensor[], output: Tensor,
61+
uniforms?: BindingInfo): GPUBindGroupLayout => {
62+
const bindings = Array(1 + inputs.length).fill({
63+
visibility: GPUShaderStageBit.COMPUTE,
64+
type: 'storage-buffer' as GPUBindingType
65+
});
66+
if (uniforms) {
67+
bindings.push({
68+
visibility: GPUShaderStageBit.COMPUTE,
69+
type: 'uniform-buffer' as GPUBindingType
70+
});
71+
}
72+
return device.createBindGroupLayout({
73+
bindings: bindings.map((b, i) => ({binding: i, ...b})),
74+
});
75+
};
76+
4177
export const compileProgram =
4278
(shaderCompiler: shaderc.Compiler, shaderKind: shaderc.ShaderKind,
4379
compileOptions: shaderc.CompileOptions, device: GPUDevice,
44-
program: WebGPUProgram, inputs: Tensor[],
45-
output: Tensor): WebGPUBinary => {
46-
const bindings =
47-
inputs.concat(output).map((input: Tensor, idx: number) => {
48-
return {
49-
binding: idx,
50-
visibility: GPUShaderStageBit.COMPUTE,
51-
type: 'storage-buffer'
52-
} as GPUBindGroupLayoutBinding;
53-
});
80+
program: WebGPUProgram, inputs: Tensor[], output: Tensor,
81+
uniforms?: BindingInfo): WebGPUBinary => {
5482
const inputsData = inputs.map((input: Tensor) => {
5583
return {dtype: input.dtype, shape: input.shape};
5684
});
5785
const outputData = {dtype: output.dtype, shape: output.shape};
5886

59-
const source = shader_preprocessor.makeShader(
60-
inputsData, program.variableNames, outputData, program.userCode,
61-
program.tileSize);
87+
const source =
88+
shader_preprocessor.makeShader(inputsData, outputData, program);
6289
const result = shaderCompiler.CompileGlslToSpv(
6390
source, shaderKind, 'file', 'main', compileOptions);
6491
const error = result.GetErrorMessage();
6592
if (error.length) {
6693
throw new Error(`Shader compilation failed: ${error}`);
6794
}
95+
const bindGroupLayout =
96+
makeBindGroupLayout(device, inputs, output, uniforms);
6897
const code = result.GetBinary();
69-
const bindGroupLayout = device.createBindGroupLayout({bindings});
7098
const layout =
7199
device.createPipelineLayout({bindGroupLayouts: [bindGroupLayout]});
72100
const module = device.createShaderModule({code});
@@ -79,4 +107,4 @@ export const compileProgram =
79107
export function makeShaderKey(program: WebGPUProgram): string {
80108
const key = program.userCode;
81109
return key;
82-
};
110+
}

0 commit comments

Comments (0)