@@ -125,6 +125,9 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
125
125
if ( inShapeInfo . logicalShape . length > 0 ) {
126
126
const varShapeName = `shape${ inputInfo . name } ` ;
127
127
const varTexShapeName = `texShape${ inputInfo . name } ` ;
128
+ const varStridesName = `strides${ inputInfo . name } ` ;
129
+ const varPackedTexShapeName = `packedTexShape${ inputInfo . name } ` ;
130
+
128
131
const shouldThrow = false ;
129
132
const shapeLocation =
130
133
gpgpu . getUniformLocation ( webGLProgram , varShapeName , shouldThrow ) ;
@@ -137,6 +140,18 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
137
140
if ( texShapeLocation != null ) {
138
141
uniformLocations [ varTexShapeName ] = texShapeLocation ;
139
142
}
143
+
144
+ const varStridesLocation =
145
+ gpgpu . getUniformLocation ( webGLProgram , varStridesName , shouldThrow ) ;
146
+ if ( varStridesLocation != null ) {
147
+ uniformLocations [ varStridesName ] = varStridesLocation ;
148
+ }
149
+
150
+ const varPackedTexShapeLocation = gpgpu . getUniformLocation (
151
+ webGLProgram , varPackedTexShapeName , shouldThrow ) ;
152
+ if ( varPackedTexShapeLocation != null ) {
153
+ uniformLocations [ varPackedTexShapeName ] = varPackedTexShapeLocation ;
154
+ }
140
155
}
141
156
}
142
157
@@ -255,14 +270,12 @@ export function runProgram<T extends Tensor, K extends Tensor>(
255
270
}
256
271
return ;
257
272
}
258
- // Upload the shape/texShape information as a uniform as well.
259
- const varShapeName = `shape${ varName } ` ;
260
- const varShapeLoc = binary . uniformLocations [ varShapeName ] ;
273
+ // Upload shape information as uniform.
274
+ const varShapeLoc = binary . uniformLocations [ `shape${ varName } ` ] ;
261
275
if ( varShapeLoc != null ) {
262
276
let shape : number [ ] | Int32Array ;
263
277
if ( binary . program . usesPackedTextures ) {
264
278
shape = util . packedShapeTransform ( input . shape ) ;
265
- // TODO yassogba@ should anything special happen for isPackShader?
266
279
} else {
267
280
// Call squeezeShape to match the shape used in the shader program
268
281
const { newShape} = util . squeezeShape ( input . shape ) ;
@@ -275,8 +288,22 @@ export function runProgram<T extends Tensor, K extends Tensor>(
275
288
gpgpu . gl . uniform1iv ( varShapeLoc , shape ) ;
276
289
}
277
290
278
- const varTexShapeName = `texShape${ varName } ` ;
279
- const varTexShapeLoc = binary . uniformLocations [ varTexShapeName ] ;
291
+ // Upload precomputed strides
292
+ const varStridesLoc = binary . uniformLocations [ `strides${ varName } ` ] ;
293
+ if ( varStridesLoc != null ) {
294
+ let strides : number [ ] | Int32Array ;
295
+ const { newShape} = util . squeezeShape ( input . shape ) ;
296
+ strides = util . computeStrides ( newShape ) ;
297
+
298
+ if ( ! ( strides instanceof Int32Array ) ) {
299
+ strides = new Int32Array ( strides ) ;
300
+ }
301
+ gpgpu . gl . uniform1iv ( varStridesLoc , strides ) ;
302
+ }
303
+
304
+
305
+ // Upload texShape/packedTexShape information as uniform.
306
+ const varTexShapeLoc = binary . uniformLocations [ `texShape${ varName } ` ] ;
280
307
// TODO(yassogba, nsthoat) rename/document these two shapes:
281
308
// input.texData.shape and input.texData.texShape
282
309
// to make it more apparent why they are both needed.
@@ -285,6 +312,15 @@ export function runProgram<T extends Tensor, K extends Tensor>(
285
312
texShape = new Int32Array ( texShape ) ;
286
313
}
287
314
gpgpu . gl . uniform1iv ( varTexShapeLoc , texShape ) ;
315
+ if ( binary . program . usesPackedTextures ) {
316
+ const varPackedTexShapeLoc =
317
+ binary . uniformLocations [ `packedTexShape${ varName } ` ] ;
318
+ if ( varPackedTexShapeLoc != null ) {
319
+ const packedTexShape =
320
+ [ Math . ceil ( texShape [ 0 ] / 2 ) , Math . ceil ( texShape [ 1 ] / 2 ) ] ;
321
+ gpgpu . gl . uniform1iv ( varPackedTexShapeLoc , packedTexShape ) ;
322
+ }
323
+ }
288
324
289
325
// If the input was sliced, upload the flat offset index.
290
326
if ( input . texData . slice != null && varOffsetLoc != null ) {
@@ -295,6 +331,7 @@ export function runProgram<T extends Tensor, K extends Tensor>(
295
331
} ) ;
296
332
297
333
// Upload output shape uniforms
334
+ // TODO yassogba upload outputStrides and outputPackedTexShape
298
335
if ( output . shape . length > 0 ) {
299
336
const outputShapeName = `outputShape` ;
300
337
const outputShapeLoc = binary . uniformLocations [ outputShapeName ] ;
0 commit comments