This commit is contained in:
George Hotz 2025-08-09 08:14:02 -07:00
commit 7ddcb8632f
3 changed files with 15 additions and 5 deletions

View file

@ -81,5 +81,15 @@ class TestRangeify(unittest.TestCase):
out = blk._feed_forward(x)
out.realize()
def test_flash_attention(self):
  """Smoke test: realize scaled dot-product attention over empty q/k/v tensors.

  Nothing is asserted about the output; the test passes if `realize()`
  completes without raising.
  """
  # dims, in order: BS, HEADS, MATDIM, EMB
  qkv_shape = (4, 2, 16, 8)
  # q, k and v are three distinct tensors of identical shape
  q, k, v = [Tensor.empty(*qkv_shape) for _ in range(3)]
  q.scaled_dot_product_attention(k, v).realize()
# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
unittest.main()

View file

@ -103,7 +103,7 @@ class CStyleLanguage(Renderer):
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}
string_rewrite = base_rewrite
extra_matcher = extra_pm

View file

@ -94,11 +94,11 @@ pm_rangeify = PatternMatcher([
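# these rules effectively define the movement ops: each rewrites INDEX of a movement op (PERMUTE, SHRINK, FLIP, EXPAND, ...) into an INDEX of that op's source with remapped indices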
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([x.src[1+p] for p in argsort(x.src[0].arg)]))),
lambda r,x: r.src[0].index(*[x.src[1+p] for p in argsort(x.src[0].arg)])),
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="x"),
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)]))),
lambda r,x: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)])),
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="x"),
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)]))),
lambda r,x: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)])),
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="x"),
lambda r,x: r.src[0].index(*[a.const_like(0) if resolve(x!=y, False) else a for a,x,y in zip(x.src[1:], r.src[0].shape, r.shape)])),
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
@ -123,7 +123,7 @@ class AddBufferContext:
def add_store(ctx:AddBufferContext, x:UOp):
  """Create a buffer for `x` and return a store of `x.src[0]` into it.

  `x.src[1:]` are treated as the ranges: the buffer shape is `vmax+1` of each
  range and its element count is the product of that shape.

  The buffer is a DEFINE_GLOBAL when it holds more than 65536 elements or when
  `ctx.dg` is 0; otherwise it is a DEFINE_LOCAL. The new buffer is recorded in
  `ctx.map` as `(op, index)` and `ctx.dg` is incremented.

  Returns the store UOp: the buffer reshaped to `shape`, indexed by the ranges,
  storing `x.src[0]`.
  """
  rngs = x.src[1:]
  shape = tuple([r.vmax+1 for r in rngs])
  size = prod(shape)  # computed once; used for both the op choice and the pointer size
  # a single assignment: an earlier DEFINE_GLOBAL/DEFINE_LOCAL choice without the
  # `ctx.dg == 0` case was immediately overwritten and has been dropped
  buf_op = Ops.DEFINE_GLOBAL if size > 65536 or ctx.dg == 0 else Ops.DEFINE_LOCAL
  buf = UOp(buf_op, dtype=x.dtype.ptr(size=size), arg=ctx.dg)
  ctx.map[buf] = (buf.op, ctx.dg)
  ctx.dg += 1
  return buf.reshape(shape).index(*rngs).store(x.src[0], *rngs)