single kernel softmax (#9776)

* real single kernel softmax

* cleanup

* fix blockend insertion

* add to bert test
This commit is contained in:
George Hotz 2025-04-08 12:35:48 +08:00 committed by GitHub
commit fefee5d3ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 93 additions and 117 deletions

View file

@ -1,8 +1,15 @@
from tinygrad import Tensor, dtypes
from tinygrad import Tensor, dtypes, Context, GlobalCounters
dtypes.default_float = dtypes.float16
from test.test_softmax_fusion import single_kernel_softmax
if __name__ == "__main__":
# softmax in bert layers
BS = 96//6
t = Tensor.empty(BS, 16, 512, 512)
t.softmax(-1, dtype="half").realize()
# test single kernel softmax
GlobalCounters.reset()
with Context(DONT_GROUP_REDUCES=1):
single_kernel_softmax(t, -1, "half").realize()

View file

@ -1,112 +0,0 @@
import unittest
from tinygrad import Tensor, GlobalCounters, Context, Device
from tinygrad.ops import Ops, UOp, graph_rewrite, PatternMatcher, track_rewrites, UPat
from tinygrad.codegen.kernel import Kernel
from tinygrad.dtype import dtypes # noqa: F401 # pylint: disable=unused-import
from tinygrad.shape.shapetracker import ShapeTracker, View # noqa: F401 # pylint: disable=unused-import
# softmax kernel
softmax_ast = eval("""UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x1:=UOp(Ops.EXP2, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x4:=UOp(Ops.VIEW, dtypes.float,
arg=ShapeTracker(views=(View(shape=(32, 10), strides=(10, 1), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.BUFFER, dtypes.float, arg=320, src=(
x6:=UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),
UOp(Ops.UNIQUE, dtypes.void, arg=0, src=()),)),)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float,
arg=ShapeTracker(views=(View(shape=(32, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=(
x4,)),)),
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=(
x12:=UOp(Ops.VIEW, dtypes.void,
arg=ShapeTracker(views=(View(shape=(32, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
x6,)),)),)),)),
UOp(Ops.CONST, dtypes.float, arg=1.4426950408889634, src=(
x12,)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(32, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
x1,)),)),)),)),))""")
pm_expand_view = PatternMatcher([
(UPat(Ops.VIEW, name="view"),
lambda view: UOp(Ops.EXPAND_AXIS, view.dtype, view.src,
tuple(i for i,x in enumerate(view.arg.views[-1].strides) if x == 0)) if view.arg.views[-1].strides == (1, 0) else None),
])
@track_rewrites()
def rewrite_softmax(ast):
from tinygrad.engine.grouper import merge_views, add_buffer_ops, fix_kernel_ops
sink = graph_rewrite(ast, pm_expand_view)
buffers = (UOp(Ops.BUFFER, dtypes.float, arg=320, src=(
UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),
UOp(Ops.UNIQUE, dtypes.void, arg=1, src=()),)), UOp(Ops.BUFFER, dtypes.float, arg=320, src=(
UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),
UOp(Ops.UNIQUE, dtypes.void, arg=0, src=()),)))
sink = graph_rewrite(sink, merge_views+add_buffer_ops+fix_kernel_ops, ctx=({}, buffers), bottom_up=True)
return sink
class TestSoftmaxFusion(unittest.TestCase):
@classmethod
def setUpClass(cls):
with Context(TRACK_MATCH_STATS=0): cls.test = Tensor.ones(32, 10).contiguous().realize()
def setUp(self):
GlobalCounters.reset()
def test_softmax(self):
# this is the softmax from scaled_dot_product_attention
# it becomes 3 kernels
print("*** softmax ***")
with Context(NOOPT=1, DEBUG=2):
out = self.test.softmax(-1)
out.realize()
@unittest.skip("no EXPAND_AXIS")
def test_softmax_fuse(self):
sink = rewrite_softmax(softmax_ast)
k = Kernel(sink, Device.default.renderer)
prg = k.to_program()
print(prg.src)
def test_norm(self):
print("*** norm ***")
with Context(NOOPT=1, DEBUG=2):
# NOTE: you don't actually need the expand, it's broadcasted
out = self.test / self.test.mean(-1, keepdim=True).expand(32, 10)
out.realize()
def test_single_kernel_norm(self):
with Context(NOOPT=1, DEBUG=2):
inp = self.test.reshape(32, 10, 1)
div = self.test.reshape(32, 1, 10).expand(32, 10, 10).mean(axis=-1, keepdim=True)
out = inp / div
out.realize()
def test_single_kernel_softmax(self):
with Context(NOOPT=1, DEBUG=2):
inp = self.test.reshape(32, 10, 1)
imx = self.test.reshape(32, 1, 10).expand(32, 10, 10).max(axis=-1, keepdim=True)
m = inp - imx.detach()
e = m.exp()
ss = e.reshape(32,1,10).expand(32, 10, 10).sum(axis=-1, keepdim=True)
out = e.div(ss)
out.realize()
"""
inp = self.test.reshape(32, 10, 1, 1)
imx = self.test.reshape(32, 1, 10, 1).expand(32, 10, 10, 1).max(axis=-2, keepdim=True)
m = inp - imx.detach()
e = m.exp()
ss = e.reshape(32,1,1,10).expand(32, 10, 1, 10).sum(axis=-1, keepdim=True)
out = e.div(ss)
out.realize()
"""
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,69 @@
import unittest
import numpy as np
from tinygrad import Tensor, GlobalCounters, Context
from tinygrad.dtype import DTypeLike
from tinygrad.helpers import DEBUG
def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
# only support axis =-1
x = x_in.reshape(-1, x_in.shape[-1])
nr_dim, r_dim = x.shape
inp = x.reshape(nr_dim, 1, 1, r_dim).expand(nr_dim, r_dim, 1, r_dim)
imx = x.reshape(nr_dim, 1, r_dim, 1).expand(nr_dim, r_dim, r_dim, r_dim).max(axis=-2, keepdim=True)
m = inp - imx.detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
ss = e.sum(axis=-1, keepdim=True)
inp = x.reshape(nr_dim, r_dim, 1, 1)
imx = x.reshape(nr_dim, 1, r_dim, 1).expand(nr_dim, r_dim, r_dim, 1).max(axis=-2, keepdim=True)
m = inp - imx.detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
out = e.div(ss).reshape(x_in.shape)
return out
class TestSoftmaxFusion(unittest.TestCase):
@classmethod
def setUpClass(cls):
with Context(TRACK_MATCH_STATS=0): cls.test = Tensor.rand(32, 10).contiguous().realize()
def setUp(self):
GlobalCounters.reset()
def test_norm(self):
print("*** norm ***")
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
# NOTE: there's an implied expand on the mean here
sout = self.test / self.test.mean(-1, keepdim=True)
sout.realize()
print("*** single kernel norm ***")
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
inp = self.test.reshape(32, 10, 1)
div = self.test.reshape(32, 1, 10).expand(32, 10, 10).mean(axis=-1, keepdim=True)
out = (inp / div).reshape(32, 10)
out.realize()
np.testing.assert_allclose(sout.numpy(), out.numpy())
def test_softmax(self):
# this is the softmax from scaled_dot_product_attention
# it becomes 3 kernels
print("*** softmax ***")
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
sout = self.test.softmax(-1)
sout.realize()
print("*** single kernel softmax ***")
# NOTE: DONT_GROUP_REDUCES is required here
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
out = single_kernel_softmax(self.test)
out.realize()
np.testing.assert_allclose(sout.numpy(), out.numpy())
if __name__ == '__main__':
unittest.main()

View file

@ -51,7 +51,8 @@ class Kernel:
self.bufs: list[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]
# get earlybufs, before any reduceops
earlybufs: list[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
earlybufs: list[UOp] = sorted([x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer],
key=lambda x: -prod(x.shape))
self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
# NOTE: full_shape can be wrong if there's a tree of reduces

View file

@ -215,9 +215,20 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
# add BLOCKFORK (slow!)
block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
forks = {}
for u,child_count in block_parent_count.items():
if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents:
# TODO: this is copied from append_to_block
new_block = UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,)))
rng = block_ctxs[u]
lrng = list(rng)
for r in rng[::-1]:
# if none of the children of u are in the same context, we need a BLOCKEND
if all(r not in block_ctxs[c] for c in children[u]) and r.op is not Ops.BLOCKSTART:
lrng.remove(r)
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
forks[u] = UOp(Ops.BLOCKFORK, src=(new_block,), arg=child_count)
if not len(forks): break
sink = sink.substitute(forks)