transformer kvcache and mask have same dtype as input (#2771)

* transformer kvcache and mask have same dtype as input

* don't use `!=0` in the cstyle ternary where

* cast the where condition with `(bool)`

* add a float16 test for where
This commit is contained in:
chenyu 2023-12-14 22:41:51 -05:00 committed by GitHub
commit c0f76ed4ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 13 deletions

View file

@ -35,9 +35,7 @@ class Attention:
# create kv cache
if not hasattr(self, "cache_kv"):
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
if HALF:
self.cache_kv = self.cache_kv.half()
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype)
keys = self.cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
@ -89,13 +87,11 @@ class Transformer:
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
h = tok_emb + pos_emb
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
if HALF: h = h.half()
if HALF:
h = h.half()
if mask is not None: mask = mask.half()
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos.val+1).realize() if seqlen > 1 else None
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
for hi in self.h: h = hi(h, start_pos, mask)
logits = self.lm_head(self.ln_f(h))
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead

View file

@ -57,12 +57,13 @@ class Attention:
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
bsz, seqlen, n_heads, head_dim = xq.shape
# create kv cache
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
self.cache_k = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
@ -110,9 +111,9 @@ class Transformer:
# Runs `tokens` through every layer and returns the flattened softmax over the
# temperature-scaled logits of the last sequence position.
#   tokens:      2-D tensor; its shape is unpacked as (_bsz, seqlen)
#   start_pos:   offset of the first token, an int or a symbolic Variable
#   temperature: divisor applied to the logits before the softmax (default 0.0)
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
_bsz, seqlen = tokens.shape
# keep only the range [start_pos, start_pos+seqlen) on the second axis of freqs_cis
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
# -inf filled (1, 1, seqlen, start_pos+seqlen) mask passed through triu(start_pos+1);
# only built when more than one token is processed. This variant pins the dtype to float32.
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
h = self.tok_embeddings(tokens)
# same mask construction, but the dtype follows the embedding output `h`.
# NOTE(review): two mask assignments appear in this listing; it looks like a rendered diff,
# so the float32 one above is presumably the line being replaced — confirm against the file.
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos+1).realize() if seqlen > 1 else None
# each layer receives the running activations, the position offset, the sliced freqs_cis and the mask
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
logits = self.output(self.norm(h))
# the 1e-10 keeps the division defined when temperature is 0.0 (the default)
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()

View file

@ -7,6 +7,7 @@ from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
from test.test_dtype import is_dtype_supported
def _uops_to_prg(uops):
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
@ -95,8 +96,10 @@ class TestNonFloatUOps(TestUOps):
def test_div_int32(self):
  # int32 DIV checked against Python's a/b truncated by int().
  # no_b_zero=True presumably keeps zero out of the divisor inputs — confirm in _test_bop_fxn.
  truncating_div = lambda a,b: int(a/b)
  self._test_bop_fxn(BinaryOps.DIV, truncating_div, PtrDType(dtypes.int32), no_b_zero=True)
# int32 MOD. The reference computes abs(a) % abs(b) and then multiplies by
# (1,-1)[a<0], i.e. by -1 when a is negative, so the expected result carries
# the sign of the dividend `a`.
# no_b_zero=True presumably keeps zero out of the divisor inputs — confirm in _test_bop_fxn.
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], PtrDType(dtypes.int32), no_b_zero=True)
def test_cmplt_int32(self):
  # int32 CMPLT; the reference reports a<b as a float (1.0 or 0.0).
  less_than = lambda a,b: float(a<b)
  self._test_bop_fxn(BinaryOps.CMPLT, less_than, PtrDType(dtypes.int32))
# Two skip guards are listed: the first skips when the default device is WEBGPU,
# the second skips unless is_dtype_supported(dtypes.bool) is true.
# NOTE(review): this listing looks like a rendered diff, so one of the two guards is
# presumably the line being replaced — confirm against the file.
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no bool storage buffer on webgpu")
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
# MUL on bool operands is checked against the logical AND of the two inputs.
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), PtrDType(dtypes.bool))
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
def test_where_float16(self):
  # WHERE on float16 operands; the reference yields b when a is nonzero and c otherwise.
  pick = lambda a,b,c: b if a!=0 else c
  self._test_top_fxn(TernaryOps.WHERE, pick, PtrDType(dtypes.float16))
# When executed as a script, run the unittest runner with verbose (level 2) output.
if __name__ == '__main__':
unittest.main(verbosity=2)

View file

@ -36,7 +36,7 @@ class CStyleLanguage(NamedTuple):
BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})",
BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})"
TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"((bool){a}?{b}:{c})"
}
# returns a str expression of the casted xs with the given type