mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6eee1a161b |
||
|
|
7bc9ebf201 | ||
|
|
1c8517d1a3 | ||
|
|
5a6790e58b |
7 changed files with 42 additions and 8 deletions
|
|
@ -269,6 +269,16 @@ class TestAssign(unittest.TestCase):
|
|||
out = attn.cache_k.flatten().numpy()
|
||||
np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])
|
||||
|
||||
def test_assign_after(self):
|
||||
t = Tensor.zeros(10).contiguous().realize()
|
||||
t.uop = t.uop.after(t.uop.assign((t+1).uop))
|
||||
np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||
|
||||
def test_assign_after_partial(self):
|
||||
t = Tensor.zeros(10).contiguous().realize()
|
||||
t.uop = t.uop.after(t[:5].uop.assign(Tensor.ones(5).uop))
|
||||
np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,0.,0.,0.,0.,0.])
|
||||
|
||||
def test_assign_contiguous(self):
|
||||
b = Tensor.arange(16).reshape(4,4).contiguous().realize()
|
||||
a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import unittest
|
||||
from tinygrad.function import function
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
class TestFunction(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
|
|
@ -102,7 +103,6 @@ class TestFunction(unittest.TestCase):
|
|||
np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6.])
|
||||
|
||||
def test_symbolic_index(self):
|
||||
from tinygrad.uop.ops import UOp
|
||||
table = Tensor([10,20,30,40]).contiguous().realize()
|
||||
@function
|
||||
def f(x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||
|
|
@ -111,6 +111,14 @@ class TestFunction(unittest.TestCase):
|
|||
v = UOp.variable("start_pos", 0, 3)
|
||||
np.testing.assert_equal(f(Tensor([1,2,3]), v.bind(0)).numpy(), [11,12,13])
|
||||
|
||||
def test_symbolic_shape_input(self):
|
||||
table = Tensor([10,20,30,40]).contiguous().realize()
|
||||
@function
|
||||
def f(x:Tensor) -> Tensor: return x * 2
|
||||
sz = UOp.variable("sz", 1, 3)
|
||||
slic = table[:sz.bind(2)]
|
||||
np.testing.assert_equal(f(slic)[:2].numpy(), [20,40])
|
||||
|
||||
def test_nested_calls(self):
|
||||
w = Tensor([10., 20., 30.])
|
||||
@function
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ class TransformerBlock:
|
|||
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
@function
|
||||
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||
x_norm = self.attn_norm(x) # (B,T,D)
|
||||
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
|
||||
|
|
@ -131,9 +132,15 @@ class TransformerBlock:
|
|||
q = apply_rope(q, freqs_cis)
|
||||
k = apply_rope(k, freqs_cis)
|
||||
|
||||
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
|
||||
k = self.cache_kv[0, :, :, 0:start_pos+T, :]
|
||||
v = self.cache_kv[1, :, :, 0:start_pos+T, :]
|
||||
# TODO: fix assign to behave like this
|
||||
assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
|
||||
tensor_assigned_kv = Tensor(assigned_kv, device=assigned_kv.device)
|
||||
k = tensor_assigned_kv[0, :, :, 0:start_pos+T, :]
|
||||
v = tensor_assigned_kv[1, :, :, 0:start_pos+T, :]
|
||||
|
||||
#self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
|
||||
#k = self.cache_kv[0, :, :, 0:start_pos+T, :]
|
||||
#v = self.cache_kv[1, :, :, 0:start_pos+T, :]
|
||||
|
||||
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal=True
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ def add_to_ctx(ctx, x:UOp):
|
|||
return ret
|
||||
|
||||
pm_ctx = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="x"), add_to_ctx),
|
||||
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
|
||||
(UPat((Ops.ASSIGN, Ops.CONTIGUOUS), name="x"),
|
||||
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) else None),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class IndexingContext:
|
|||
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
|
||||
|
||||
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.AFTER}: return None
|
||||
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
|
||||
new_srcs = []
|
||||
for s in x.src:
|
||||
new_src = s
|
||||
|
|
@ -179,6 +179,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
# no ranges on kernels, they are internal
|
||||
if x.op in {Ops.CALL, Ops.LINEAR}: continue
|
||||
|
||||
# no range on after
|
||||
if x.op is Ops.AFTER: continue
|
||||
|
||||
# treat MSTACK/MSELECT like SINK
|
||||
if x.op in {Ops.MSTACK, Ops.MSELECT}: continue
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
|||
|
||||
dict_map = {x:args[x.arg] for x in params}
|
||||
for i, (p, a) in enumerate(dict_map.items()):
|
||||
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||
if p.max_shape != a.max_shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||
return c.src[0].substitute(dict_map, walk=True)
|
||||
|
||||
|
|
@ -364,6 +364,11 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
|||
|
||||
# remove any RESHAPEs on KERNEL
|
||||
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
||||
|
||||
# remove MOP on AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(GroupOp.Movement, name="y"))), lambda x,y: x.after(y.src[0])),
|
||||
# remove double AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:]))
|
||||
])
|
||||
|
||||
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -258,7 +258,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||
# only ops with side effects (the set listed below) are kept as AFTER srcs; any other op is replaced by its own srcs
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR} else y.src for y in x.src[1:]])))),
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
|
||||
else y.src for y in x.src[1:]])))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
# VECTORIZE/CONST
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue