[Inductor] track block shape of intermediary variables #149905
Labels
module: inductor
oncall: pt2
triaged
🚀 The feature, motivation and pitch
During codegen we track the dtype and value range of each intermediary variable we emit in Triton. See CSEVariable.
Dtype was recently added in #136778 by @arui-meta and subsequently iterated on in PRs like #141495 and #140057.
While dtypes are a bit finicky to get right, shapes are very easy to track in triton. More or less each operator broadcasts its inputs, reductions remove reduction dims, and then there are a few remaining ops.
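To make the two rules above concrete, here is a minimal sketch of shape propagation, assuming block shapes are modeled as tuples of symbolic block-size names with "1" marking a broadcastable dimension. The names `broadcast_shapes`, `reduce_shape`, and the `"XBLOCK"`/`"RBLOCK"` labels are hypothetical illustrations, not Inductor's actual API, and the remaining ops that need special handling are not covered.

```python
from typing import Tuple

# Hypothetical model: a block shape is a tuple of symbolic dimension names,
# where "1" stands for a size-1 (broadcastable) dimension.
Shape = Tuple[str, ...]


def broadcast_shapes(*shapes: Shape) -> Shape:
    """Right-align the input shapes and merge them dimension by dimension."""
    rank = max(len(s) for s in shapes)
    # Pad shorter shapes with leading "1" dims so all shapes have equal rank.
    padded = [("1",) * (rank - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        merged = "1"
        for d in dims:
            if d == "1":
                continue
            if merged != "1" and merged != d:
                raise ValueError(f"cannot broadcast {merged} with {d}")
            merged = d
        out.append(merged)
    return tuple(out)


def reduce_shape(shape: Shape, dim: int) -> Shape:
    """Drop the reduced dimension, as a reduction that does not keep dims would."""
    if not shape:
        raise ValueError("cannot reduce a scalar")
    dim = dim % len(shape)  # allow negative indices
    return shape[:dim] + shape[dim + 1:]


# A pointwise op on an x-indexed and an r-indexed operand yields a 2D block.
pointwise = broadcast_shapes(("XBLOCK", "1"), ("1", "RBLOCK"))
# Reducing over the last dimension removes it again.
reduced = reduce_shape(pointwise, -1)
```

With something like this attached to each variable, a pointwise op would record the broadcast of its operands' shapes and a reduction would record the shape with the reduced dimension removed.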
@kundaMwiza recently had a use case for shapes in a PR:
Ideally the shape of the input would be an attribute of a TritonCSEVariable via shape propagation.
Similarly, I ran into a bug in prologue fusion where I now need to add possibly extraneous broadcasts, because in particular cases of loading a constant index we return a different shape.
I'm sure other future changes will run into needing shapes, and after adding them we'll discover other places in the codebase that we can simplify.
Alternatives
No response
Additional context
No response
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov