mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6eee1a161b |
||
|
|
7bc9ebf201 | ||
|
|
1c8517d1a3 | ||
|
|
5a6790e58b |
7 changed files with 42 additions and 8 deletions
|
|
@ -269,6 +269,16 @@ class TestAssign(unittest.TestCase):
|
||||||
out = attn.cache_k.flatten().numpy()
|
out = attn.cache_k.flatten().numpy()
|
||||||
np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])
|
np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])
|
||||||
|
|
||||||
|
def test_assign_after(self):
|
||||||
|
t = Tensor.zeros(10).contiguous().realize()
|
||||||
|
t.uop = t.uop.after(t.uop.assign((t+1).uop))
|
||||||
|
np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||||
|
|
||||||
|
def test_assign_after_partial(self):
|
||||||
|
t = Tensor.zeros(10).contiguous().realize()
|
||||||
|
t.uop = t.uop.after(t[:5].uop.assign(Tensor.ones(5).uop))
|
||||||
|
np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,0.,0.,0.,0.,0.])
|
||||||
|
|
||||||
def test_assign_contiguous(self):
|
def test_assign_contiguous(self):
|
||||||
b = Tensor.arange(16).reshape(4,4).contiguous().realize()
|
b = Tensor.arange(16).reshape(4,4).contiguous().realize()
|
||||||
a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)
|
a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad.function import function
|
from tinygrad.function import function
|
||||||
from tinygrad import Tensor
|
from tinygrad import Tensor
|
||||||
|
from tinygrad.uop.ops import UOp
|
||||||
|
|
||||||
class TestFunction(unittest.TestCase):
|
class TestFunction(unittest.TestCase):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
|
|
@ -102,7 +103,6 @@ class TestFunction(unittest.TestCase):
|
||||||
np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6.])
|
np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6.])
|
||||||
|
|
||||||
def test_symbolic_index(self):
|
def test_symbolic_index(self):
|
||||||
from tinygrad.uop.ops import UOp
|
|
||||||
table = Tensor([10,20,30,40]).contiguous().realize()
|
table = Tensor([10,20,30,40]).contiguous().realize()
|
||||||
@function
|
@function
|
||||||
def f(x:Tensor, start_pos:int|UOp) -> Tensor:
|
def f(x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||||
|
|
@ -111,6 +111,14 @@ class TestFunction(unittest.TestCase):
|
||||||
v = UOp.variable("start_pos", 0, 3)
|
v = UOp.variable("start_pos", 0, 3)
|
||||||
np.testing.assert_equal(f(Tensor([1,2,3]), v.bind(0)).numpy(), [11,12,13])
|
np.testing.assert_equal(f(Tensor([1,2,3]), v.bind(0)).numpy(), [11,12,13])
|
||||||
|
|
||||||
|
def test_symbolic_shape_input(self):
|
||||||
|
table = Tensor([10,20,30,40]).contiguous().realize()
|
||||||
|
@function
|
||||||
|
def f(x:Tensor) -> Tensor: return x * 2
|
||||||
|
sz = UOp.variable("sz", 1, 3)
|
||||||
|
slic = table[:sz.bind(2)]
|
||||||
|
np.testing.assert_equal(f(slic)[:2].numpy(), [20,40])
|
||||||
|
|
||||||
def test_nested_calls(self):
|
def test_nested_calls(self):
|
||||||
w = Tensor([10., 20., 30.])
|
w = Tensor([10., 20., 30.])
|
||||||
@function
|
@function
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,7 @@ class TransformerBlock:
|
||||||
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
|
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
|
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
|
||||||
|
|
||||||
|
@function
|
||||||
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
|
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||||
x_norm = self.attn_norm(x) # (B,T,D)
|
x_norm = self.attn_norm(x) # (B,T,D)
|
||||||
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
|
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
|
||||||
|
|
@ -131,9 +132,15 @@ class TransformerBlock:
|
||||||
q = apply_rope(q, freqs_cis)
|
q = apply_rope(q, freqs_cis)
|
||||||
k = apply_rope(k, freqs_cis)
|
k = apply_rope(k, freqs_cis)
|
||||||
|
|
||||||
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
|
# TODO: fix assign to behave like this
|
||||||
k = self.cache_kv[0, :, :, 0:start_pos+T, :]
|
assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
|
||||||
v = self.cache_kv[1, :, :, 0:start_pos+T, :]
|
tensor_assigned_kv = Tensor(assigned_kv, device=assigned_kv.device)
|
||||||
|
k = tensor_assigned_kv[0, :, :, 0:start_pos+T, :]
|
||||||
|
v = tensor_assigned_kv[1, :, :, 0:start_pos+T, :]
|
||||||
|
|
||||||
|
#self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
|
||||||
|
#k = self.cache_kv[0, :, :, 0:start_pos+T, :]
|
||||||
|
#v = self.cache_kv[1, :, :, 0:start_pos+T, :]
|
||||||
|
|
||||||
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
|
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
|
||||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None
|
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ def add_to_ctx(ctx, x:UOp):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
pm_ctx = PatternMatcher([
|
pm_ctx = PatternMatcher([
|
||||||
(UPat(Ops.BUFFER, name="x"), add_to_ctx),
|
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
|
||||||
(UPat((Ops.ASSIGN, Ops.CONTIGUOUS), name="x"),
|
(UPat((Ops.ASSIGN, Ops.CONTIGUOUS), name="x"),
|
||||||
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) else None),
|
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) else None),
|
||||||
])
|
])
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,7 @@ class IndexingContext:
|
||||||
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
|
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
|
||||||
|
|
||||||
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||||
if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.AFTER}: return None
|
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
|
||||||
new_srcs = []
|
new_srcs = []
|
||||||
for s in x.src:
|
for s in x.src:
|
||||||
new_src = s
|
new_src = s
|
||||||
|
|
@ -179,6 +179,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||||
# no ranges on kernels, they are internal
|
# no ranges on kernels, they are internal
|
||||||
if x.op in {Ops.CALL, Ops.LINEAR}: continue
|
if x.op in {Ops.CALL, Ops.LINEAR}: continue
|
||||||
|
|
||||||
|
# no range on after
|
||||||
|
if x.op is Ops.AFTER: continue
|
||||||
|
|
||||||
# treat MSTACK/MSELECT like SINK
|
# treat MSTACK/MSELECT like SINK
|
||||||
if x.op in {Ops.MSTACK, Ops.MSELECT}: continue
|
if x.op in {Ops.MSTACK, Ops.MSELECT}: continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||||
|
|
||||||
dict_map = {x:args[x.arg] for x in params}
|
dict_map = {x:args[x.arg] for x in params}
|
||||||
for i, (p, a) in enumerate(dict_map.items()):
|
for i, (p, a) in enumerate(dict_map.items()):
|
||||||
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
if p.max_shape != a.max_shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||||
return c.src[0].substitute(dict_map, walk=True)
|
return c.src[0].substitute(dict_map, walk=True)
|
||||||
|
|
||||||
|
|
@ -364,6 +364,11 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||||
|
|
||||||
# remove any RESHAPEs on KERNEL
|
# remove any RESHAPEs on KERNEL
|
||||||
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
||||||
|
|
||||||
|
# remove MOP on AFTER
|
||||||
|
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(GroupOp.Movement, name="y"))), lambda x,y: x.after(y.src[0])),
|
||||||
|
# remove double AFTER
|
||||||
|
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:]))
|
||||||
])
|
])
|
||||||
|
|
||||||
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||||
|
|
|
||||||
|
|
@ -258,7 +258,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||||
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||||
# only RANGE/IF/STORE/KERNEL have side effects
|
# only RANGE/IF/STORE/KERNEL have side effects
|
||||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR} else y.src for y in x.src[1:]])))),
|
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
|
||||||
|
else y.src for y in x.src[1:]])))),
|
||||||
# after with 1 src is just src[0]
|
# after with 1 src is just src[0]
|
||||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||||
# VECTORIZE/CONST
|
# VECTORIZE/CONST
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue