Fuse conv2d with bias and activation. (#1859) · tensorflow/tfjs-core@5c0d017 · GitHub
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Fuse conv2d with bias and activation. (#1859)
FEATURE
PERF
annxingyuan authored Jul 29, 2019
1 parent 95e44a5 commit 5c0d017
Showing 7 changed files with 495 additions and 15 deletions.
6 changes: 6 additions & 0 deletions src/backends/backend.ts
@@ -411,6 +411,12 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
throw new Error('Not yet implemented');
}

fusedConv2d(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
activation?: Activation): Tensor4D {
throw new Error('Not yet implemented');
}

conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
throw new Error('Not yet implemented');
}
14 changes: 14 additions & 0 deletions src/backends/cpu/backend_cpu.ts
@@ -1513,6 +1513,20 @@ export class MathBackendCPU implements KernelBackend {
return Tensor.make(x.shape, {values: resultValues}) as T;
}

fusedConv2d(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
activation?: Activation): Tensor4D {
let result = this.conv2d(x, filter, convInfo);

if (bias) {
result = this.add(result, bias) as Tensor4D;
}
if (activation) {
result = mapActivation(this, activation, result) as Tensor4D;
}
return result;
}

conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
this.assertNotComplex([x, filter], 'conv2d');

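
The CPU path above simply chains the existing kernels: conv2d, then an elementwise add for the bias, then the activation. Below is a minimal sketch of the equivalent unfused computation using public tfjs ops; the shapes and the relu activation are illustrative, not taken from this commit.

import * as tf from '@tensorflow/tfjs';

// Illustrative shapes: NHWC input, HWIO filter, per-channel bias.
const x = tf.randomNormal([1, 8, 8, 4]) as tf.Tensor4D;
const filter = tf.randomNormal([3, 3, 4, 16]) as tf.Tensor4D;
const bias = tf.zeros([16]);

// Unfused path: three separate kernel invocations and two intermediate tensors.
const conv = tf.conv2d(x, filter, 1 /* strides */, 'same');
const biased = tf.add(conv, bias);
const activated = tf.relu(biased);

backend.fusedConv2d collapses these steps into a single backend call; on the WebGL backend below, the bias add and activation fold into the convolution (or matmul) shader instead of running as separate kernels.
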
67 changes: 55 additions & 12 deletions src/backends/webgl/backend_webgl.ts
@@ -873,10 +873,12 @@ export class MathBackendWebGL implements KernelBackend {

const dtype = upcastType(a.dtype, b.dtype);

const hasBias = bias != null;
const fusedActivation =
activation ? mapActivationToShaderProgram(activation, true) : null;
const program = new MatMulPackedProgram(
a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
!!bias,
activation ? mapActivationToShaderProgram(activation, true) : null);
hasBias, fusedActivation);
const output =
this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
const inputs: TensorHandle[] = [a, b];
@@ -1815,15 +1817,18 @@
return this.compileAndRun(program, [x]) as T;
return this.compileAndRun(program, [x]) as T;
}

conv2dByMatMul(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
Tensor4D {
private conv2dByMatMul(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
activation?: Activation): Tensor4D {
// Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
// result from 2D to 4D.
const xShape = x.shape;
const xTexData = this.texData.get(x.dataId);
const sharedMatMulDim = convInfo.inChannels;
const outerShapeX = xShape[0] * xShape[1] * xShape[2];
const outerShapeFilter = convInfo.outChannels;
const transposeA = false;
const transposeB = false;

// TODO: Once reduction ops are packed, batchMatMul will always be packed
// and we can remove this condition.
@@ -1843,8 +1848,11 @@
this.reshape(
filter, [1, convInfo.inChannels, convInfo.outChannels]) as
Tensor3D;

return this.reshape<Rank.R4>(
this.batchMatMul(xReshaped, filterReshaped, false, false),
this.fusedBatchMatMul(
xReshaped, filterReshaped, transposeA, transposeB, bias,
activation),
convInfo.outShape);
}

@@ -1880,8 +1888,8 @@
this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]) as
Tensor3D;

const pointwiseConv =
this.batchMatMul(xReshaped, filterReshaped, false, false);
const pointwiseConv = this.fusedBatchMatMul(
xReshaped, filterReshaped, transposeA, transposeB, bias, activation);
const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
util.assert(
pointwiseConvTexData.isPacked,
@@ -1896,8 +1904,9 @@
pointwiseConv.dtype, this) as Tensor4D;
}

conv2dWithIm2Row(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
Tensor4D {
private conv2dWithIm2Row(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
activation?: Activation): Tensor4D {
// Rearranges conv2d input so each block to be convolved over forms the
// column of a new matrix with shape [filterWidth * filterHeight *
// inChannels, outHeight * outWidth]. The filter is also rearranged so each
@@ -1915,6 +1924,8 @@
const sharedDim = filterWidth * filterHeight * inChannels;
const numCols = outHeight * outWidth;
const x2ColShape = [sharedDim, numCols];
const transposeA = true;
const transposeB = false;

const xSqueezed = x.squeeze([0]);
const w2Row = filter.reshape([1, sharedDim, -1]) as Tensor3D;
@@ -1926,14 +1937,46 @@
1, x2ColShape[0], x2ColShape[1]
]) as Tensor3D;

const hasBias = bias != null;
const fusedActivation =
activation ? mapActivationToShaderProgram(activation, true) : null;
const matmulProgram = new MatMulPackedProgram(
im2Col.shape, [1, numCols, convInfo.outChannels], true, false);
const product =
this.compileAndRun<Tensor4D>(matmulProgram, [im2Col, w2Row]);
im2Col.shape, [1, numCols, convInfo.outChannels], transposeA,
transposeB, hasBias, fusedActivation);
const inputs: TensorHandle[] = [im2Col, w2Row];
if (bias) {
inputs.push(bias);
}
const product = this.compileAndRun<Tensor4D>(matmulProgram, inputs);

return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
}
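
The comment at the top of conv2dWithIm2Row describes the im2col trick: every receptive field of the input is unrolled into one column, so the convolution reduces to a single matrix multiply with the flattened filter — which is why the fused MatMulPackedProgram (bias plus activation) can be reused here. A plain-loop sketch for a single-channel, unit-stride, VALID-padding case follows; the backend builds the same matrix with a WebGL program instead.

// Builds a [filterH * filterW, outH * outW] matrix; each column is one
// receptive field of the input.
function im2col(x: number[][], filterH: number, filterW: number): number[][] {
  const outH = x.length - filterH + 1;
  const outW = x[0].length - filterW + 1;
  const cols: number[][] = [];
  for (let fy = 0; fy < filterH; fy++) {
    for (let fx = 0; fx < filterW; fx++) {
      const row: number[] = [];
      for (let oy = 0; oy < outH; oy++) {
        for (let ox = 0; ox < outW; ox++) {
          row.push(x[oy + fy][ox + fx]);
        }
      }
      cols.push(row);
    }
  }
  return cols;
}

A single matMul of this matrix with the flattened filter then produces the convolution output, which is reshaped back to [1, outHeight, outWidth, outChannels] as in the return statement above.
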

fusedConv2d(
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
activation?: Activation): Tensor4D {
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
(convInfo.padInfo.type === 'SAME' ||
convInfo.padInfo.type === 'VALID')) {
return this.conv2dByMatMul(x, filter, convInfo, bias, activation);
}
if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
return this.conv2dWithIm2Row(x, filter, convInfo, bias, activation);
}

const hasBias = bias != null;
const fusedActivation =
activation ? mapActivationToShaderProgram(activation, false) : null;
const program = new Conv2DProgram(convInfo, hasBias, fusedActivation);
const inputs: TensorHandle[] = [x, filter];
if (bias) {
inputs.push(bias);
}
return this.compileAndRun(program, inputs);
}

conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
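
conv2dByMatMul exploits the fact that a 1×1 convolution with unit stride and no dilation is just a matrix multiply over the channel dimension, so the already-fused batchMatMul path (bias and activation baked into the matmul shader) can serve as the fused conv kernel. A sketch of that equivalence with public tfjs ops; the shapes are illustrative.

import * as tf from '@tensorflow/tfjs';

const [batch, height, width, inCh, outCh] = [1, 8, 8, 4, 16];
const x = tf.randomNormal([batch, height, width, inCh]) as tf.Tensor4D;
const filter = tf.randomNormal([1, 1, inCh, outCh]) as tf.Tensor4D;

// Direct 1x1 convolution.
const viaConv = tf.conv2d(x, filter, 1, 'valid');

// The same result as a matmul: flatten every spatial position into a row.
const viaMatMul = tf.matMul(
                        x.reshape([batch * height * width, inCh]),
                        filter.reshape([inCh, outCh]))
                      .reshape([batch, height, width, outCh]);

// viaConv and viaMatMul agree up to floating-point error.
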
27 changes: 25 additions & 2 deletions src/backends/webgl/conv_gpu.ts
@@ -23,7 +23,8 @@ export class Conv2DProgram implements GPGPUProgram {
outputShape: number[];
userCode: string;

constructor(convInfo: Conv2DInfo) {
constructor(
convInfo: Conv2DInfo, addBias = false, activation: string = null) {
this.outputShape = convInfo.outShape;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
@@ -37,7 +38,25 @@
const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
const inputDepthVec4Remainder = convInfo.inChannels % 4;

let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
activationSnippet = `
float activation(float x) {
${activation}
}
`;

applyActivationSnippet = `result = activation(result);`;
}

const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}

this.userCode = `
${activationSnippet}
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
@@ -113,7 +132,11 @@
}
}
}
setOutput(dotProd);
float result = dotProd;
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`;
}
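
The Conv2DProgram change above threads the optional bias and activation straight into the generated fragment shader: the activation body is wrapped in a float activation(float x) helper, the bias is read with getBiasAtOutCoords(), and both are applied to the dot product before setOutput. Below is a small sketch of how those strings compose; the relu-style body is a stand-in, since the actual snippet returned by mapActivationToShaderProgram is not part of this diff.

// Hypothetical relu-style activation body; the real string comes from
// mapActivationToShaderProgram.
const activation: string|null = 'return max(x, 0.0);';
const addBias = true;

const activationSnippet = activation ?
    `float activation(float x) {
      ${activation}
    }` :
    '';
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
const applyActivationSnippet = activation ? 'result = activation(result);' : '';

// Tail of the shader's main(), mirroring the diff above: the convolution's
// dot product is biased, activated, and only then written out.
const epilogue = `
  float result = dotProd;
  ${addBiasSnippet}
  ${applyActivationSnippet}
  setOutput(result);
`;
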
1 change: 1 addition & 0 deletions src/ops/conv.ts
@@ -927,6 +927,7 @@
export const conv1d = op({conv1d_});
export const conv2d = op({conv2d_});
export const conv3d = op({conv3d_});
export const conv2dDerFilter = op({conv2dDerFilter_});
export const conv2dDerInput = op({conv2dDerInput_});
export const depthwiseConv2d = op({depthwiseConv2d_});
export const separableConv2d = op({separableConv2d_});
export const conv2dTranspose = op({conv2dTranspose_});
