mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
simpler
This commit is contained in:
parent
e268eb2d5c
commit
7ddcb8632f
3 changed files with 15 additions and 5 deletions
|
|
@ -81,5 +81,15 @@ class TestRangeify(unittest.TestCase):
|
|||
out = blk._feed_forward(x)
|
||||
out.realize()
|
||||
|
||||
def test_flash_attention(self):
|
||||
BS = 4
|
||||
HEADS = 2
|
||||
MATDIM = 16
|
||||
EMB = 8
|
||||
q = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
k = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
v = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
q.scaled_dot_product_attention(k, v).realize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ class CStyleLanguage(Renderer):
|
|||
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
|
||||
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
|
||||
Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
|
||||
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
|
||||
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}
|
||||
|
||||
string_rewrite = base_rewrite
|
||||
extra_matcher = extra_pm
|
||||
|
|
|
|||
|
|
@ -94,11 +94,11 @@ pm_rangeify = PatternMatcher([
|
|||
|
||||
# this is like the definitions of these
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([x.src[1+p] for p in argsort(x.src[0].arg)]))),
|
||||
lambda r,x: r.src[0].index(*[x.src[1+p] for p in argsort(x.src[0].arg)])),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)]))),
|
||||
lambda r,x: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)])),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)]))),
|
||||
lambda r,x: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)])),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: r.src[0].index(*[a.const_like(0) if resolve(x!=y, False) else a for a,x,y in zip(x.src[1:], r.src[0].shape, r.shape)])),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
|
||||
|
|
@ -123,7 +123,7 @@ class AddBufferContext:
|
|||
def add_store(ctx:AddBufferContext, x:UOp):
|
||||
rngs = x.src[1:]
|
||||
shape = tuple([r.vmax+1 for r in rngs])
|
||||
buf = UOp(Ops.DEFINE_GLOBAL if prod(shape) > 65536 else Ops.DEFINE_LOCAL, dtype=x.dtype.ptr(size=prod(shape)), arg=ctx.dg)
|
||||
buf = UOp(Ops.DEFINE_GLOBAL if prod(shape) > 65536 or ctx.dg == 0 else Ops.DEFINE_LOCAL, dtype=x.dtype.ptr(size=prod(shape)), arg=ctx.dg)
|
||||
ctx.map[buf] = (buf.op, ctx.dg)
|
||||
ctx.dg += 1
|
||||
return buf.reshape(shape).index(*rngs).store(x.src[0], *rngs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue