8000 webgpu: implement a version of conv2d using matmul (#1710) · tensorflow/tfjs-core@0940cf0 · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 0940cf0

Browse files
kainino0x and annxingyuan
authored and committed
webgpu: implement a version of conv2d using matmul (#1710)
FEATURE
1 parent 28fd404 commit 0940cf0

File tree

3 files changed

+140
-1
lines changed

3 files changed

+140
-1
lines changed

src/backends/webgpu/src/backend_webgpu.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import * as shaderc from '@webgpu/shaderc';
2626

2727
import * as binary_op from './kernels/binary_op_webgpu';
2828
import {BinaryOpProgram} from './kernels/binary_op_webgpu';
29+
import {Conv2DMMProgram} from './kernels/conv2d_mm_webgpu';
2930
import {Conv2DNaiveProgram} from './kernels/conv2d_naive_webgpu';
3031
import {MatMulPackedProgram} from './kernels/matmul_packed_webgpu';
3132
import {MatMulProgram} from './kernels/matmul_webgpu';
@@ -262,7 +263,15 @@ export class WebGPUBackend extends KernelBackend {
262263
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
263264
const output =
264265
Tensor.make(convInfo.outShape, {}, x.dtype, this) as Tensor4D;
265-
const program = new Conv2DNaiveProgram(convInfo);
266+
let program: Conv2DMMProgram|Conv2DNaiveProgram;
267+
268+
const workPerThread = ENV.get('WEBGPU_CONV2D_WORK_PER_THREAD') as number;
269+
if (workPerThread === -1) {
270+
// TODO(kainino0x): This may be obsolete, but is kept for reference.
271+
program = new Conv2DNaiveProgram(convInfo);
272+
} else {
273+
program = new Conv2DMMProgram(convInfo, workPerThread);
274+
}
266275

267276
const pad = convInfo.padInfo.type === 'VALID' ?
268277
[0, 0] :

src/backends/webgpu/src/flags_webgpu.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,10 @@ ENV.registerFlag('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', () => true);
2525
* matMul without register blocking.
2626
*/
2727
ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4);
28+
29+
/**
30+
* -1: conv2d_naive
31+
* 0: conv2d_mm with matmul without register blocking
32+
* >0: conv2d_mm with matmul_packed with WPT=this
33+
*/
34+
ENV.registerFlag('WEBGPU_CONV2D_WORK_PER_THREAD', () => 2);
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs-core';
19+
import {Conv2DInfo} from '@tensorflow/tfjs-core/dist/ops/conv_util';
20+
21+
import {generateGetOutputCoords} from '../shader_util';
22+
import {computeDispatch} from '../webgpu_util';
23+
24+
import {makeMatMulPackedSource} from './matmul_packed_webgpu';
25+
import {makeMatMulSource} from './matmul_webgpu';
26+
import {WebGPUProgram} from './webgpu_program';
27+
28+
export class Conv2DMMProgram implements WebGPUProgram {
29+
outputShape: number[];
30+
userCode: string;
31+
dispatch: [number, number, number];
32+
variableNames = ['x', 'W'];
33+
uniforms = 'ivec4 xShape, outShape; ivec2 WShape, pad, stride;';
34+
workGroupSize: [number, number, number] = [
35+
16, 16, // must be square (for matmul)
36+
1
37+
];
38+
39+
constructor(convInfo: Conv2DInfo, workPerThread: number) {
40+
this.outputShape = convInfo.outShape;
41+
const dispatchLayout = {x: [1], y: [2], z: [0, 3]};
42+
43+
tf.util.assert(
44+
convInfo.dataFormat === 'channelsLast',
45+
() => 'TODO: NCHW is unimplemented');
46+
tf.util.assert(
47+
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1,
48+
() => 'TODO: Dilation is unimplemented');
49+
50+
let elementsPerThread: [number, number, number];
51+
let matMulSource: string;
52+
if (workPerThread === 0) {
53+
elementsPerThread = [1, 1, 1];
54+
matMulSource = makeMatMulSource();
55+
} else {
56+
elementsPerThread = [workPerThread, workPerThread, 1];
57+
matMulSource = makeMatMulPackedSource(workPerThread);
58+
}
59+
this.dispatch = computeDispatch(
60+
dispatchLayout, this.outputShape, this.workGroupSize,
61+
elementsPerThread);
62+
63+
this.userCode = `
64+
${matMulSource}
65+
66+
bool coordIsValid(ivec4 coord, ivec4 shape) {
67+
return all(greaterThanEqual(coord, ivec4(0))) &&
68+
all(lessThan(coord, shape));
69+
}
70+
71+
${generateGetOutputCoords(dispatchLayout, this.outputShape.length)}
72+
73+
int batch;
74+
75+
float mm_readA(uint row, uint col) {
76+
ivec4 coord = ivec4(
77+
(col / WShape[1]) % WShape[0],
78+
col % WShape[1],
79+
col / (WShape[1] * WShape[0]),
80+
row);
81+
82+
ivec4 shape = ivec4(WShape, xShape[3], outShape[3]);
83+
return coordIsValid(coord, shape) ? W[getFlatIndex(coord, shape)] : 0;
84+
}
85+
86+
float mm_readB(uint row, uint col) {
87+
int outRow = int(col) / outShape[2];
88+
int outCol = int(col) % outShape[2];
89+
90+
int WRow = (int(row) / WShape[1]) % WShape[0];
91+
int WCol = int(row) % WShape[1];
92+
93+
ivec4 coord = ivec4(
94+
batch,
95+
pad[0] + outRow * stride[0] + WRow,
96+
pad[1] + outCol * stride[1] + WCol,
97+
row / (WShape[1] * WShape[0]));
98+
return coordIsValid(coord, xShape) ?
99+
x[getFlatIndex(coord, xShape)] : 0;
100+
}
101+
102+
void mm_write(uint row, uint col, float value) {
103+
ivec4 outCoord = ivec4(
104+
batch,
105+
col / outShape[2],
106+
col % outShape[2],
107+
row);
108+
if (coordIsValid(outCoord, outShape)) {
109+
result[getFlatIndex(outCoord, outShape)] = value;
110+
}
111+
}
112+
113+
void main() {
114+
batch = getOutputCoords()[0];
115+
116+
int dimAOuter = outShape[3];
117+
int dimBOuter = outShape[1] * outShape[2];
118+
int dimInner = WShape[0] * WShape[1] * xShape[3];
119+
mm_matMul(dimAOuter, dimInner, dimBOuter);
120+
}
121+
`;
122+
}
123+
}

0 commit comments

Comments (0)