Multireduce Kernels - Linearizer and Scheduler Changes #4208
Conversation
Can you break this up into Linearizer and Scheduler changes? First Linearizer to support generation of the kernels (with lots of good tests!), then scheduler to actually enable it.
yeah 100%
Does this bring layernorm to a single kernel? Right now it's three (discussing in #scheduler on discord).
I can get it to fuse into two and probably could get it to be one; it depends mostly on how we adjust the rules around what to fuse.
left suggestions with the unit tests in #4220 - in general:
Thanks for the tests! I would like to try to generalize scheduling fusions; for reduceops my thinking was that any set of consecutive shape transformations from the same shape ought to be fused, e.g. the two SUMs in standard deviation or layernorm. There is the issue of control flow divergence, where any operations on the reduced shape won't use all the threads, e.g. the DIV by N in a mean calculation. I started working on a
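For concreteness, a minimal sketch (assuming the tinygrad Tensor API; exact calls may differ between versions) of the two consecutive SUM reductions in a standard-deviation calculation, both reducing from the same (32,) shape and therefore candidates for fusion into one kernel, with the DIV by N being the divergent work on the reduced shape:

```python
from tinygrad import Tensor

x = Tensor.rand(32)
N = x.shape[0]
mean = x.sum(axis=0, keepdim=True) / N     # first SUM over shape (32,); the DIV by N runs on the reduced shape
var = ((x - mean) ** 2).sum(axis=0) / N    # second SUM, again reducing from shape (32,)
std = var.sqrt()
std.realize()
```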
An issue with getting layernorm to fuse into one kernel is that the output shape != the reduced shape; e.g. if the scheduler allows it to fuse, it will look like this:
alu2 is computed in a previous loop, so it can't be accessed in the final parse. I could store or recompute it; both options make some sense for fusion because they don't require any trips to DRAM, but from an occupancy perspective it might make sense to give the final (x - μ)/σ operation its own kernel, plus whatever comes after it.
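A hedged layernorm sketch (again assuming the tinygrad Tensor API) of that shape mismatch: the reductions produce shape (1,) while the final (x - μ)/σ step is computed back on the full (32,) shape, which is why it could reasonably be split off into its own kernel:

```python
from tinygrad import Tensor

x = Tensor.rand(32)
mean = x.mean(axis=0, keepdim=True)                  # reduced shape (1,)
var = ((x - mean) ** 2).mean(axis=0, keepdim=True)   # reduced shape (1,)
out = (x - mean) / (var + 1e-5).sqrt()               # output shape (32,) != reduced shape
out.realize()
```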
# reduce op
# reduce ops
assert len(self.reduceops) == len(set(self.reduceops)), "All reduceops must be unique"
self.reduceops.reverse()
yeah you need to linearize the deepest reduceop first.
I think .reverse() won't work for multi-output because you could have an AST like:
0 ━┳ STORE MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)))
1 ┗━┳ SUM ((0,), dtypes.float)
2 ┗━━ LOAD MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32,), strides=(1,), offset=0, mask=None, contiguous=True),)))
0 ━┳ STORE MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)))
1 ┗━┳ SUM ((0,), dtypes.float)
2 ┗━┳ SUB
3 ┣━━ LOAD MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32,), strides=(1,), offset=0, mask=None, contiguous=True),)))
4 ┗━┳ SUM ((0,), dtypes.float)
5 ┗━━ LOAD MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32,), strides=(1,), offset=0, mask=None, contiguous=True),)))
two ideas:
- self.reduceops orders by depth
- render_reduceop recurses (similar to parse_ast)
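A hedged sketch of the first idea, ordering self.reduceops by depth so the deepest reduceop is linearized first. LazyOp and ReduceOps are tinygrad types; node_depth and reduceops_by_depth are illustrative helpers, not existing tinygrad functions:

```python
from typing import Dict, List
from tinygrad.ops import LazyOp, ReduceOps

def node_depth(op:LazyOp, d:int=0) -> int:
  # maximum depth of any node in the subtree rooted at op
  return max([node_depth(s, d+1) for s in op.src if isinstance(s, LazyOp)], default=d)

def reduceops_by_depth(outputs:List[LazyOp]) -> List[LazyOp]:
  # collect the unique ReduceOps across all output STOREs, deepest first, so the
  # nested SUM in the second STORE above is linearized before the SUM that consumes it
  found: Dict[LazyOp, int] = {}
  def walk(op:LazyOp, d:int):
    if op.op in ReduceOps: found[op] = max(found.get(op, 0), d)
    for s in op.src:
      if isinstance(s, LazyOp): walk(s, d+1)
  for out in outputs: walk(out, 0)
  return sorted(found, key=found.get, reverse=True)
```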
@@ -111,7 +109,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
     if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
        prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
       simple_pads.add(buf.base)
-    else:
+    elif all([x.op not in ReduceOps for x in buf.base.srcs if hasattr(x, "op")]):
what if buf isn't the only child of x?
from tinygrad import Tensor

x = Tensor.empty(32)               # x has two children: the mean's reduce and the SUB inside r1
r0 = x.mean(axis=0, keepdim=True)
r1 = (x - r0).sum(axis=0).div(2)
e = r0 + r1
e.realize()
https://tiny-tools-client.vercel.app/?id=7b4dc4f7c0c34dcb94e1d82639cc3180

This branch currently is behind tinygrad/master. The line count difference bot is disabled.
re: an earlier pr, this implements the actual kernel fusion of standard deviation.
it:
- schedule.py: to put two ReduceOps in a kernel if they have the same shapes
- linearizer.py: to handle the multiple reduceops

an example of a fused kernel:
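As a hedged stand-in for the generated kernel text (assuming tinygrad's Tensor API and Tensor.schedule(), which may differ between versions), one way to observe the fusion is to count the kernels scheduled for a standard-deviation computation:

```python
from tinygrad import Tensor

x = Tensor.rand(32)
out = x.std()             # two reductions (mean and variance) over the same shape
sched = out.schedule()    # list of ScheduleItems, one per kernel
print(len(sched))         # fewer items once the two reductions share a kernel
```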