mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
transformer kvcache and mask have same dtype as input (#2771)
* transformer kvcache and mask have same dtype as input
* don't use `=0` in cstyle ternary where
* (bool)
* where float16 test
This commit is contained in:
parent
2dd0dd4ae0
commit
c0f76ed4ea
4 changed files with 13 additions and 13 deletions
|
|
@ -35,9 +35,7 @@ class Attention:
|
|||
|
||||
# create kv cache
|
||||
if not hasattr(self, "cache_kv"):
|
||||
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
|
||||
if HALF:
|
||||
self.cache_kv = self.cache_kv.half()
|
||||
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype)
|
||||
|
||||
keys = self.cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||
values = self.cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
|
|
@ -89,13 +87,11 @@ class Transformer:
|
|||
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
|
||||
h = tok_emb + pos_emb
|
||||
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
|
||||
if HALF: h = h.half()
|
||||
|
||||
if HALF:
|
||||
h = h.half()
|
||||
if mask is not None: mask = mask.half()
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos.val+1).realize() if seqlen > 1 else None
|
||||
|
||||
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
|
||||
for hi in self.h: h = hi(h, start_pos, mask)
|
||||
|
||||
logits = self.lm_head(self.ln_f(h))
|
||||
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
|
||||
|
|
|
|||
|
|
@ -57,12 +57,13 @@ class Attention:
|
|||
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
|
||||
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
|
||||
bsz, seqlen, n_heads, head_dim = xq.shape
|
||||
|
||||
# create kv cache
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
|
||||
self.cache_k = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
|
||||
self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
|
||||
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
|
|
@ -110,9 +111,9 @@ class Transformer:
|
|||
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
|
||||
_bsz, seqlen = tokens.shape
|
||||
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
|
||||
|
||||
h = self.tok_embeddings(tokens)
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos+1).realize() if seqlen > 1 else None
|
||||
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
|
||||
logits = self.output(self.norm(h))
|
||||
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from tinygrad.device import Buffer, Device
|
|||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.device import CompiledASTRunner, Compiled
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from test.test_dtype import is_dtype_supported
|
||||
|
||||
def _uops_to_prg(uops):
|
||||
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
|
||||
|
|
@ -95,8 +96,10 @@ class TestNonFloatUOps(TestUOps):
|
|||
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), PtrDType(dtypes.int32), no_b_zero=True)
|
||||
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], PtrDType(dtypes.int32), no_b_zero=True)
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), PtrDType(dtypes.int32))
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no bool storage buffer on webgpu")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
|
||||
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), PtrDType(dtypes.bool))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
|
||||
def test_where_float16(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, PtrDType(dtypes.float16))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class CStyleLanguage(NamedTuple):
|
|||
BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})",
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
|
||||
TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})"
|
||||
TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"((bool){a}?{b}:{c})"
|
||||
}
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue