mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
single kernel softmax (#9776)
* real single kernel softmax
* cleanup
* fix blockend insertion
* add to bert test
This commit is contained in:
parent
9963bb51e0
commit
fefee5d3ab
5 changed files with 93 additions and 117 deletions
|
|
@ -1,8 +1,15 @@
|
|||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad import Tensor, dtypes, Context, GlobalCounters
|
||||
dtypes.default_float = dtypes.float16
|
||||
from test.test_softmax_fusion import single_kernel_softmax
|
||||
|
||||
if __name__ == "__main__":
  # Script entry point: realize the regular softmax and then the
  # single-kernel formulation on the same input.

  # softmax in bert layers
  BS = 96 // 6
  scores = Tensor.empty(BS, 16, 512, 512)
  scores.softmax(-1, dtype="half").realize()

  # test single kernel softmax
  GlobalCounters.reset()
  with Context(DONT_GROUP_REDUCES=1):
    fused = single_kernel_softmax(scores, -1, "half")
    fused.realize()
|
||||
|
||||
|
|
|
|||
112
test/external/external_test_softmax_fusion.py
vendored
112
test/external/external_test_softmax_fusion.py
vendored
|
|
@ -1,112 +0,0 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, GlobalCounters, Context, Device
|
||||
from tinygrad.ops import Ops, UOp, graph_rewrite, PatternMatcher, track_rewrites, UPat
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
|
||||
from tinygrad.dtype import dtypes # noqa: F401 # pylint: disable=unused-import
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
# softmax kernel
|
||||
softmax_ast = eval("""UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
x1:=UOp(Ops.EXP2, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x4:=UOp(Ops.VIEW, dtypes.float,
|
||||
arg=ShapeTracker(views=(View(shape=(32, 10), strides=(10, 1), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=320, src=(
|
||||
x6:=UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),
|
||||
UOp(Ops.UNIQUE, dtypes.void, arg=0, src=()),)),)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float,
|
||||
arg=ShapeTracker(views=(View(shape=(32, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=(
|
||||
x4,)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=(
|
||||
x12:=UOp(Ops.VIEW, dtypes.void,
|
||||
arg=ShapeTracker(views=(View(shape=(32, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x6,)),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=1.4426950408889634, src=(
|
||||
x12,)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(32, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
x1,)),)),)),)),))""")
|
||||
|
||||
def _view_to_expand_axis(view):
  # Rewrite callback for pm_expand_view: only a VIEW whose last view has
  # strides exactly (1, 0) is rewritten; returning None leaves the node as is.
  strides = view.arg.views[-1].strides
  if strides != (1, 0): return None
  # the EXPAND_AXIS arg is the tuple of axis indices whose stride is 0
  zero_stride_axes = tuple(axis for axis, stride in enumerate(strides) if stride == 0)
  return UOp(Ops.EXPAND_AXIS, view.dtype, view.src, zero_stride_axes)

# Matches every VIEW and hands it to _view_to_expand_axis.
pm_expand_view = PatternMatcher([
  (UPat(Ops.VIEW, name="view"), _view_to_expand_axis),
])
|
||||
|
||||
@track_rewrites()
def rewrite_softmax(ast):
  """Rewrite `ast` in two passes and return the resulting sink.

  The first pass applies `pm_expand_view`. The second pass applies
  `merge_views + add_buffer_ops + fix_kernel_ops` bottom-up, with an empty
  dict and two METAL float buffers as the rewrite context.
  """
  from tinygrad.engine.grouper import merge_views, add_buffer_ops, fix_kernel_ops

  def metal_buffer(unique):
    # BUFFER with arg=320 (presumably the 32x10 element count -- TODO confirm)
    return UOp(Ops.BUFFER, dtypes.float, arg=320, src=(
      UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),
      UOp(Ops.UNIQUE, dtypes.void, arg=unique, src=()),))

  expanded = graph_rewrite(ast, pm_expand_view)
  # the buffer with UNIQUE arg=1 comes first, then the one with arg=0
  buffers = (metal_buffer(1), metal_buffer(0))
  return graph_rewrite(expanded, merge_views+add_buffer_ops+fix_kernel_ops, ctx=({}, buffers), bottom_up=True)
|
||||
|
||||
class TestSoftmaxFusion(unittest.TestCase):
  """Realizes softmax / norm variants on a shared 32x10 tensor of ones.

  Every test runs under Context(NOOPT=1, DEBUG=2); none of them asserts on
  the computed values.
  """

  @classmethod
  def setUpClass(cls):
    # one realized input shared by all tests
    with Context(TRACK_MATCH_STATS=0):
      cls.test = Tensor.ones(32, 10).contiguous().realize()

  def setUp(self):
    GlobalCounters.reset()

  def test_softmax(self):
    # this is the softmax from scaled_dot_product_attention
    # it becomes 3 kernels
    print("*** softmax ***")
    with Context(NOOPT=1, DEBUG=2):
      result = self.test.softmax(-1)
      result.realize()

  @unittest.skip("no EXPAND_AXIS")
  def test_softmax_fuse(self):
    rewritten = rewrite_softmax(softmax_ast)
    kernel = Kernel(rewritten, Device.default.renderer)
    program = kernel.to_program()
    print(program.src)

  def test_norm(self):
    print("*** norm ***")
    with Context(NOOPT=1, DEBUG=2):
      # NOTE: you don't actually need the expand, it's broadcasted
      mean = self.test.mean(-1, keepdim=True).expand(32, 10)
      result = self.test / mean
      result.realize()

  def test_single_kernel_norm(self):
    with Context(NOOPT=1, DEBUG=2):
      numerator = self.test.reshape(32, 10, 1)
      # the mean is taken over an expanded copy of the row
      denominator = self.test.reshape(32, 1, 10).expand(32, 10, 10).mean(axis=-1, keepdim=True)
      result = numerator / denominator
      result.realize()

  def test_single_kernel_softmax(self):
    with Context(NOOPT=1, DEBUG=2):
      values = self.test.reshape(32, 10, 1)
      row_max = self.test.reshape(32, 1, 10).expand(32, 10, 10).max(axis=-1, keepdim=True)
      shifted = values - row_max.detach()
      exps = shifted.exp()
      exp_sum = exps.reshape(32,1,10).expand(32, 10, 10).sum(axis=-1, keepdim=True)
      result = exps.div(exp_sum)
      result.realize()

    # Alternative formulation that was kept disabled:
    #   inp = self.test.reshape(32, 10, 1, 1)
    #   imx = self.test.reshape(32, 1, 10, 1).expand(32, 10, 10, 1).max(axis=-2, keepdim=True)
    #   m = inp - imx.detach()
    #   e = m.exp()
    #   ss = e.reshape(32,1,1,10).expand(32, 10, 1, 10).sum(axis=-1, keepdim=True)
    #   out = e.div(ss)
    #   out.realize()
|
||||
|
||||
if __name__ == "__main__":
  # run the test cases in this module when executed as a script
  unittest.main()
|
||||
69
test/test_softmax_fusion.py
Normal file
69
test/test_softmax_fusion.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, Context
|
||||
from tinygrad.dtype import DTypeLike
|
||||
from tinygrad.helpers import DEBUG
|
||||
|
||||
def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
  """Softmax over the last axis of `x_in`, with the max and the sum recomputed per output element.

  Each reduction is expressed over an expanded copy of the row instead of
  reusing an earlier reduction's result. Callers in this file run it under
  Context(DONT_GROUP_REDUCES=1).

  Args:
    x_in: input tensor; it is flattened to (-1, x_in.shape[-1]) internally.
    axis: must refer to the last axis (-1 or len(x_in.shape)-1).
    dtype: if not None, the max-subtracted values are cast to it before exp.

  Returns:
    A tensor with the same shape as `x_in`.

  Raises:
    NotImplementedError: if `axis` is not the last axis.
  """
  # only the last axis is supported; refuse anything else rather than silently reducing over the wrong axis
  if axis not in (-1, len(x_in.shape)-1):
    raise NotImplementedError(f"single_kernel_softmax only supports the last axis, got axis={axis}")

  x = x_in.reshape(-1, x_in.shape[-1])
  nr_dim, r_dim = x.shape

  def shifted_exp(values, row_max):
    # exp(values - max), with the max detached and an optional cast before the exp
    m = values - row_max.detach()
    if dtype is not None: m = m.cast(dtype)
    return m.exp()

  # denominator: every output position gets its own copy of the row on the last axis,
  # the max is taken over axis -2, and the exps are summed over the last axis -> (nr_dim, r_dim, 1, 1)
  inp = x.reshape(nr_dim, 1, 1, r_dim).expand(nr_dim, r_dim, 1, r_dim)
  imx = x.reshape(nr_dim, 1, r_dim, 1).expand(nr_dim, r_dim, r_dim, r_dim).max(axis=-2, keepdim=True)
  ss = shifted_exp(inp, imx).sum(axis=-1, keepdim=True)

  # numerator: one value per output position, with the max again taken over axis -2 -> (nr_dim, r_dim, 1, 1)
  inp = x.reshape(nr_dim, r_dim, 1, 1)
  imx = x.reshape(nr_dim, 1, r_dim, 1).expand(nr_dim, r_dim, r_dim, 1).max(axis=-2, keepdim=True)
  e = shifted_exp(inp, imx)

  return e.div(ss).reshape(x_in.shape)
|
||||
|
||||
class TestSoftmaxFusion(unittest.TestCase):
  """Compares the regular norm / softmax against their single-kernel formulations.

  Both tests realize the reference and the rewritten expression on a shared
  random 32x10 tensor and check the two results with assert_allclose.
  """

  @classmethod
  def setUpClass(cls):
    # one realized random input shared by all tests
    with Context(TRACK_MATCH_STATS=0):
      cls.test = Tensor.rand(32, 10).contiguous().realize()

  def setUp(self):
    GlobalCounters.reset()

  def test_norm(self):
    print("*** norm ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      # NOTE: there's an implied expand on the mean here
      expected = self.test / self.test.mean(-1, keepdim=True)
      expected.realize()

    print("*** single kernel norm ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      numerator = self.test.reshape(32, 10, 1)
      # the mean is taken over an expanded copy of the row
      denominator = self.test.reshape(32, 1, 10).expand(32, 10, 10).mean(axis=-1, keepdim=True)
      fused = (numerator / denominator).reshape(32, 10)
      fused.realize()

    np.testing.assert_allclose(expected.numpy(), fused.numpy())

  def test_softmax(self):
    # this is the softmax from scaled_dot_product_attention
    # it becomes 3 kernels
    print("*** softmax ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      expected = self.test.softmax(-1)
      expected.realize()

    print("*** single kernel softmax ***")
    # NOTE: DONT_GROUP_REDUCES is required here
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
      fused = single_kernel_softmax(self.test)
      fused.realize()

    np.testing.assert_allclose(expected.numpy(), fused.numpy())
|
||||
|
||||
if __name__ == "__main__":
  # run the test cases in this module when executed as a script
  unittest.main()
|
||||
|
|
@ -51,7 +51,8 @@ class Kernel:
|
|||
self.bufs: list[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]
|
||||
|
||||
# get earlybufs, before any reduceops
|
||||
earlybufs: list[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
|
||||
earlybufs: list[UOp] = sorted([x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer],
|
||||
key=lambda x: -prod(x.shape))
|
||||
self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
|
||||
# NOTE: full_shape can be wrong if there's a tree of reduces
|
||||
|
||||
|
|
|
|||
|
|
@ -215,9 +215,20 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
|||
# add BLOCKFORK (slow!)
|
||||
block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
|
||||
non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
|
||||
forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
|
||||
for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
|
||||
|
||||
forks = {}
|
||||
for u,child_count in block_parent_count.items():
|
||||
if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents:
|
||||
# TODO: this is copied from append_to_block
|
||||
new_block = UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,)))
|
||||
rng = block_ctxs[u]
|
||||
lrng = list(rng)
|
||||
for r in rng[::-1]:
|
||||
# if none of the children of u are in the same context, we need a BLOCKEND
|
||||
if all(r not in block_ctxs[c] for c in children[u]) and r.op is not Ops.BLOCKSTART:
|
||||
lrng.remove(r)
|
||||
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
|
||||
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
|
||||
forks[u] = UOp(Ops.BLOCKFORK, src=(new_block,), arg=child_count)
|
||||
if not len(forks): break
|
||||
sink = sink.substitute(forks)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue