mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
59dfb234eb
commit
01e8b60911
15 changed files with 77 additions and 77 deletions
|
|
@ -103,7 +103,7 @@ class Int8Embedding:
|
|||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
|
||||
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T
|
||||
return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
|
||||
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
|
||||
|
||||
def NF4Linear(block_size):
|
||||
_CODE = [
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class EmbeddingBert(nn.Embedding):
|
|||
arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
|
||||
return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
|
||||
return (arange == idx).mul(vals).sum(2, dtype=vals.dtype)
|
||||
|
||||
class LayerNormBert:
|
||||
def __init__(self, normalized_shape:Union[int, tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ if __name__ == "__main__":
|
|||
for K in range(K_START, K_STOP+1, K_STEP):
|
||||
print(f"testing {M=} {N=} {K=}")
|
||||
a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
|
||||
c = a.matmul(b, acc_dtype=acc_dtype).realize()
|
||||
c = a.matmul(b, dtype=acc_dtype).realize()
|
||||
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||
nc = c.numpy()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ if __name__ == "__main__":
|
|||
for i in range(CNT):
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = rand_input()
|
||||
c = a.conv2d(b, padding=PADDING, acc_dtype=acc_dtype).realize()
|
||||
c = a.conv2d(b, padding=PADDING, dtype=acc_dtype).realize()
|
||||
|
||||
if COMP:
|
||||
import numpy as np, time, torch
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ if __name__ == "__main__":
|
|||
for i in range(CNT):
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = init_matrix(M, K), init_matrix(K, N)
|
||||
c = a.matmul(b, acc_dtype=acc_dtype).realize()
|
||||
c = a.matmul(b, dtype=acc_dtype).realize()
|
||||
|
||||
ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||
res = c.numpy()
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ if __name__ == "__main__":
|
|||
for i in range(CNT):
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = _rand(device)
|
||||
c = a.matmul(b, acc_dtype=acc_dtype).realize()
|
||||
c = a.matmul(b, dtype=acc_dtype).realize()
|
||||
nc = c.numpy()
|
||||
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||
np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
|
||||
|
|
|
|||
|
|
@ -331,7 +331,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
|||
# NOTE: axis=[] in torch means all, change tinygrad?
|
||||
"aten.sum.IntList_out": lambda self,axis,keepdim=False,dtype=None:
|
||||
self.sum(axis if axis is None or len(axis) else None, keepdim,
|
||||
acc_dtype = _from_torch_dtype(dtype) if dtype is not None else None),
|
||||
dtype = _from_torch_dtype(dtype) if dtype is not None else None),
|
||||
}}
|
||||
|
||||
# we add the "out" here
|
||||
|
|
|
|||
|
|
@ -494,7 +494,7 @@ class TestTypeSpec(unittest.TestCase):
|
|||
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
|
||||
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
|
||||
|
||||
np.testing.assert_equal(Tensor(n).sum(acc_dtype="int16").numpy(), Tensor(n).sum(acc_dtype=dtypes.int16).numpy())
|
||||
np.testing.assert_equal(Tensor(n).sum(dtype="int16").numpy(), Tensor(n).sum(dtype=dtypes.int16).numpy())
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_creation(self, default_int, default_float):
|
||||
|
|
@ -694,21 +694,21 @@ class TestAutoCastType(unittest.TestCase):
|
|||
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
|
||||
def test_sum_acc_dtype(self):
|
||||
def test_sum_dtype_arg(self):
|
||||
t = Tensor([40000, 40000], dtype=dtypes.float16)
|
||||
# default float16 sum returns in float16, overflowed in this case
|
||||
assert t.sum().dtype == dtypes.float16
|
||||
assert math.isinf(t.sum().numpy().item())
|
||||
# specifiying acc_dtype and it's not downcasted
|
||||
assert t.sum(acc_dtype=dtypes.float32).dtype == dtypes.float32
|
||||
np.testing.assert_allclose(t.sum(acc_dtype=dtypes.float32).numpy(), 80000)
|
||||
# specifiying dtype and it's not downcasted
|
||||
assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32
|
||||
np.testing.assert_allclose(t.sum(dtype=dtypes.float32).numpy(), 80000)
|
||||
|
||||
def test_prod_acc_dtype(self):
|
||||
def test_prod_dtype_arg(self):
|
||||
t = Tensor([100, 200], dtype=dtypes.int32)
|
||||
assert t.prod().dtype == dtypes.int32
|
||||
np.testing.assert_allclose(t.prod().numpy(), 20000)
|
||||
assert t.prod(acc_dtype=dtypes.float32).dtype == dtypes.float32
|
||||
np.testing.assert_allclose(t.prod(acc_dtype=dtypes.float32).numpy(), 20000)
|
||||
assert t.prod(dtype=dtypes.float32).dtype == dtypes.float32
|
||||
np.testing.assert_allclose(t.prod(dtype=dtypes.float32).numpy(), 20000)
|
||||
|
||||
def test_mean(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
|
||||
|
|
@ -745,8 +745,8 @@ class TestAutoCastType(unittest.TestCase):
|
|||
t1 = Tensor([0, 1], dtype=dt1)
|
||||
t2 = Tensor([0, 1], dtype=dt2)
|
||||
assert (t1 @ t2).dtype == least_upper_dtype(dt1, dt2)
|
||||
# if acc_dtype is specified, return in acc_dtype
|
||||
assert (t1.matmul(t2, acc_dtype=acc_dt).dtype == acc_dt)
|
||||
# if dtype is specified, return in dtype
|
||||
assert (t1.matmul(t2, dtype=acc_dt).dtype == acc_dt)
|
||||
|
||||
@staticmethod
|
||||
def check_where_alternate_input_other(input_, other, data_type):
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def helper_realized_ast(r:Union[Tensor, list[Tensor]]) -> tuple[UOp, list[Buffer
|
|||
def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0):
|
||||
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
r = a.matmul(b, acc_dtype=dtype_out)
|
||||
r = a.matmul(b, dtype=dtype_out)
|
||||
sched = r.schedule()
|
||||
realized_ast = sched[-1].ast
|
||||
run_schedule(sched)
|
||||
|
|
@ -47,7 +47,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
|
|||
def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0,
|
||||
ensure_triggered:bool=True):
|
||||
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
|
||||
r = a.matmul(b, acc_dtype=dtype_out)
|
||||
r = a.matmul(b, dtype=dtype_out)
|
||||
sched = r.schedule()
|
||||
realized_ast = sched[-1].ast
|
||||
k = Kernel(realized_ast)
|
||||
|
|
@ -1050,11 +1050,11 @@ class TestLinearizer(unittest.TestCase):
|
|||
for tensor_dtype, acc_dtype, expected_dtype in tests:
|
||||
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype):
|
||||
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, acc_dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(a.sum(dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(a.matmul(b, dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, dtype=acc_dtype), expected_dtype)
|
||||
d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(d.conv2d(w, dtype=acc_dtype), expected_dtype)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_cores(self):
|
||||
|
|
@ -1101,7 +1101,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
for axis in range(9):
|
||||
a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize()
|
||||
b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize()
|
||||
c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
|
||||
c = a.conv2d(b, padding=1, dtype=tc.dtype_out)
|
||||
realized_ast, real_bufs = helper_realized_ast(c)
|
||||
|
||||
k = Kernel(realized_ast)
|
||||
|
|
@ -1130,7 +1130,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
def test_tensor_cores_unroll_phi(self):
|
||||
tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
|
||||
x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
|
||||
r = x.matmul(y, acc_dtype=tc.dtype_out)
|
||||
r = x.matmul(y, dtype=tc.dtype_out)
|
||||
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
|
||||
for u in k.uops:
|
||||
if u.op is Ops.WMMA:
|
||||
|
|
@ -1141,7 +1141,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
def test_tensor_cores_unroll_casted_phi(self):
|
||||
tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
|
||||
x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
|
||||
r = x.matmul(y, acc_dtype=tc.dtype_out)
|
||||
r = x.matmul(y, dtype=tc.dtype_out)
|
||||
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
|
||||
for u in k.uops:
|
||||
if u.op is Ops.WMMA:
|
||||
|
|
@ -1154,7 +1154,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
# all ASSIGN children are outside the loop
|
||||
tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
|
||||
x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
|
||||
r = x.matmul(y, acc_dtype=tc.dtype_out).relu()
|
||||
r = x.matmul(y, dtype=tc.dtype_out).relu()
|
||||
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
|
||||
for u in k.uops:
|
||||
if u.op is Ops.WMMA:
|
||||
|
|
@ -2000,7 +2000,7 @@ class TestKernelOpts(unittest.TestCase):
|
|||
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
|
||||
if tc.dtype_in != dtypes.half and tc.dtype_out != dtypes.half: continue
|
||||
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
|
||||
r = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
r = a.matmul(b, dtype=tc.dtype_out)
|
||||
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
|
||||
helper_linearizer_opt(r, [
|
||||
[],
|
||||
|
|
@ -2027,7 +2027,7 @@ class TestKernelOpts(unittest.TestCase):
|
|||
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
|
||||
if tc.dtype_in != dtypes.half and tc.dtype_out != dtypes.half: continue
|
||||
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
|
||||
r = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
r = a.matmul(b, dtype=tc.dtype_out)
|
||||
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
|
||||
|
|
@ -2093,10 +2093,10 @@ class TestKernelOpts(unittest.TestCase):
|
|||
helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
if Device.DEFAULT != "WEBGPU":
|
||||
helper_linearizer_opt(b.sum(0, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(1, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(0, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(1, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
|
||||
# having unsafe ops after sum is fine
|
||||
helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
|
|
|
|||
|
|
@ -1273,11 +1273,11 @@ class TestOps(unittest.TestCase):
|
|||
self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError)
|
||||
self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError)
|
||||
|
||||
def test_sum_acc_dtype(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(acc_dtype=dtypes.float32))
|
||||
if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(acc_dtype=dtypes.float64))
|
||||
def test_sum_dtype_arg(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(dtype=dtypes.float32))
|
||||
if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(dtype=dtypes.float64))
|
||||
|
||||
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(acc_dtype="")
|
||||
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(dtype="")
|
||||
|
||||
def test_sum_with_zeros_shape(self):
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)))
|
||||
|
|
@ -1294,8 +1294,8 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([()], lambda x: x.prod(0))
|
||||
helper_test_op([()], lambda x: x.prod(-1))
|
||||
|
||||
def test_prod_acc_dtype(self):
|
||||
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).prod(acc_dtype="")
|
||||
def test_prod_dtype_arg(self):
|
||||
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).prod(dtype="")
|
||||
|
||||
def test_min(self):
|
||||
helper_test_op([(3,3)], lambda x: x.min())
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class TestQuantizeOnnx(unittest.TestCase):
|
|||
def test_prequant_conv2d_1x1(self):
|
||||
X = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8))
|
||||
W = Tensor(np.random.uniform(0, 255, size=(64, 32, 1, 1)).astype(np.uint8))
|
||||
out = X.conv2d(W, acc_dtype=X.dtype)
|
||||
out = X.conv2d(W, dtype=X.dtype)
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
|
||||
sexec(out, opts)
|
||||
|
||||
|
|
@ -77,7 +77,7 @@ class TestQuantizeOnnx(unittest.TestCase):
|
|||
N = 512
|
||||
X = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
|
||||
W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
|
||||
out = X.matmul(W, acc_dtype=X.dtype)
|
||||
out = X.matmul(W, dtype=X.dtype)
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
|
||||
sexec(out, opts)
|
||||
|
||||
|
|
@ -204,7 +204,7 @@ class TestQuantizeOnnx(unittest.TestCase):
|
|||
W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8)).realize()
|
||||
#out = X.cast(dtypes.int) @ W.cast(dtypes.int)
|
||||
#out = X @ W
|
||||
out = X.matmul(W, acc_dtype=X.dtype)
|
||||
out = X.matmul(W, dtype=X.dtype)
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
|
||||
sexec(out, opts)
|
||||
|
||||
|
|
|
|||
|
|
@ -68,8 +68,8 @@ class TestPTXFailures(unittest.TestCase):
|
|||
def test_gated_define_acc_with_half_dtype(self):
|
||||
a = Tensor.randn(32, 32, dtype=dtypes.half).realize()
|
||||
b = Tensor.randn(34, 32, dtype=dtypes.half).realize()
|
||||
result = a.pad((1,1)).matmul(b, acc_dtype=dtypes.half).numpy()
|
||||
reference = a.pad((1,1)).matmul(b, acc_dtype=dtypes.float).numpy()
|
||||
result = a.pad((1,1)).matmul(b, dtype=dtypes.half).numpy()
|
||||
reference = a.pad((1,1)).matmul(b, dtype=dtypes.float).numpy()
|
||||
np.testing.assert_allclose(result, reference)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class TestBEAM(unittest.TestCase):
|
|||
for (dtype_in, dtype_out) in multi_shape_dtype_pairs:
|
||||
a = Tensor.rand(16, 16, dtype=dtype_in)
|
||||
b = Tensor.rand(16, 16, dtype=dtype_in)
|
||||
realized_ast, _ = helper_realized_ast(a.matmul(b, acc_dtype=dtype_out))
|
||||
realized_ast, _ = helper_realized_ast(a.matmul(b, dtype=dtype_out))
|
||||
|
||||
lins = get_kernel_actions(Kernel(realized_ast)).values()
|
||||
assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1
|
||||
|
|
|
|||
|
|
@ -323,7 +323,7 @@ class Embedding:
|
|||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
|
||||
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
|
||||
return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
|
||||
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
|
||||
|
||||
class LSTMCell:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1164,7 +1164,7 @@ class Tensor(SimpleMathTrait):
|
|||
# inject 1's for the extra dims added in create masks
|
||||
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
|
||||
# sum reduce the extra dims introduced in create masks
|
||||
x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)
|
||||
x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
|
||||
|
||||
# special permute case
|
||||
if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
|
||||
|
|
@ -1255,7 +1255,7 @@ class Tensor(SimpleMathTrait):
|
|||
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
||||
index = index.to(self.device)
|
||||
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
||||
return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
|
||||
return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, dtype=self.dtype)
|
||||
|
||||
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
||||
"""
|
||||
|
|
@ -1564,14 +1564,14 @@ class Tensor(SimpleMathTrait):
|
|||
ret = self._apply_uop(UOp.r, op=op, axis=axis)
|
||||
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
|
||||
|
||||
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
|
||||
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
|
||||
"""
|
||||
Returns the sum of the elements of the tensor along the specified axis or axes.
|
||||
|
||||
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
||||
which the maximum is computed and whether the reduced dimensions are retained.
|
||||
|
||||
You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
|
||||
You can pass in `dtype` keyword argument to control the data type of the accumulation.
|
||||
If not specified, the accumulation data type is chosen based on the input tensor's data type.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
|
|
@ -1588,17 +1588,17 @@ class Tensor(SimpleMathTrait):
|
|||
print(t.sum(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
|
||||
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
||||
ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim)
|
||||
return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
||||
|
||||
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
|
||||
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
|
||||
"""
|
||||
Returns the product of the elements of the tensor along the specified axis or axes.
|
||||
|
||||
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
||||
which the maximum is computed and whether the reduced dimensions are retained.
|
||||
|
||||
You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
|
||||
You can pass in `dtype` keyword argument to control the data type of the accumulation.
|
||||
If not specified, the accumulation data type is chosen based on the input tensor's data type.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
|
|
@ -1615,7 +1615,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(t.prod(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
|
||||
return self.cast(dtype if dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
|
||||
|
||||
def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
||||
"""
|
||||
|
|
@ -2005,7 +2005,7 @@ class Tensor(SimpleMathTrait):
|
|||
return self._inverse().argmax(axis=axis, keepdim=keepdim)
|
||||
|
||||
@staticmethod
|
||||
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:DTypeLike|None=None) -> Tensor:
|
||||
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], dtype:DTypeLike|None=None) -> Tensor:
|
||||
"""
|
||||
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
||||
|
||||
|
|
@ -2047,7 +2047,7 @@ class Tensor(SimpleMathTrait):
|
|||
|
||||
# sum over all axes that's not in the output, then permute to the output order
|
||||
return functools.reduce(lambda a,b:a*b, xs_) \
|
||||
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order)
|
||||
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)
|
||||
|
||||
# ***** processing ops *****
|
||||
|
||||
|
|
@ -2182,7 +2182,7 @@ class Tensor(SimpleMathTrait):
|
|||
return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
|
||||
acc_dtype:DTypeLike|None=None) -> Tensor:
|
||||
dtype:DTypeLike|None=None) -> Tensor:
|
||||
"""
|
||||
Applies a convolution over a tensor with a given `weight` and optional `bias`.
|
||||
|
||||
|
|
@ -2208,7 +2208,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(t.conv2d(w).numpy())
|
||||
```
|
||||
"""
|
||||
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
|
||||
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
|
||||
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
||||
padding_ = self._resolve_pool_pads(padding, len(HW))
|
||||
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
||||
|
|
@ -2221,7 +2221,7 @@ class Tensor(SimpleMathTrait):
|
|||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
|
||||
|
||||
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx) # noqa: E501
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
|
||||
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
||||
|
|
@ -2246,7 +2246,7 @@ class Tensor(SimpleMathTrait):
|
|||
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
|
||||
|
||||
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
||||
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
|
||||
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
|
||||
|
||||
# interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
|
||||
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
|
||||
|
|
@ -2294,14 +2294,14 @@ class Tensor(SimpleMathTrait):
|
|||
padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
|
||||
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
def dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
|
||||
def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
|
||||
|
||||
"""
|
||||
Performs dot product between two tensors.
|
||||
If `w` is 1-D, it's a sum product over the last axis of `self` and `w`.
|
||||
If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`.
|
||||
|
||||
You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
|
||||
You can pass in the optional `dtype` keyword argument to control the data type of the accumulation.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
a = Tensor([1, 2, 3])
|
||||
|
|
@ -2314,20 +2314,20 @@ class Tensor(SimpleMathTrait):
|
|||
print(a.dot(b).numpy())
|
||||
```
|
||||
"""
|
||||
if IMAGE: return self.image_dot(w, acc_dtype)
|
||||
if IMAGE: return self.image_dot(w, dtype)
|
||||
x, dx, dw = self, self.ndim, w.ndim
|
||||
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
|
||||
if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
|
||||
x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1])
|
||||
w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
|
||||
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
|
||||
return (x*w).sum(-1, dtype=dtype).cast(least_upper_dtype(x.dtype, w.dtype) if dtype is None else dtype)
|
||||
|
||||
def matmul(self, x:Tensor, reverse=False, acc_dtype:DTypeLike|None=None) -> Tensor:
|
||||
def matmul(self, x:Tensor, reverse=False, dtype:DTypeLike|None=None) -> Tensor:
|
||||
"""
|
||||
Performs matrix multiplication between two tensors.
|
||||
|
||||
You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
|
||||
You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
|
||||
You can pass in the optional `dtype` keyword argument to control the data type of the accumulation.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
a = Tensor([[1, 2], [3, 4]])
|
||||
|
|
@ -2335,7 +2335,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(a.matmul(b).numpy())
|
||||
```
|
||||
"""
|
||||
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
|
||||
return x.dot(self, dtype=dtype) if reverse else self.dot(x, dtype=dtype)
|
||||
|
||||
def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
|
||||
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
|
||||
|
|
@ -2553,14 +2553,14 @@ class Tensor(SimpleMathTrait):
|
|||
"""
|
||||
src, mask = self._pre_scatter(dim, index, src)
|
||||
def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
|
||||
# TODO: should not overwrite acc_dtype here?
|
||||
if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
|
||||
if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
|
||||
# TODO: should not overwrite dtype here?
|
||||
if reduce == "sum": return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
|
||||
if reduce == "prod": return mask.where(src, 1).prod(-1, dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
|
||||
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
|
||||
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
|
||||
if reduce == "mean":
|
||||
count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
|
||||
return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
|
||||
count = mask.where(1, 0).sum(-1, dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
|
||||
return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
|
||||
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
|
||||
|
||||
def topk(self, k, dim=-1, largest=True, sorted_=True):
|
||||
|
|
@ -3677,7 +3677,7 @@ class Tensor(SimpleMathTrait):
|
|||
"""
|
||||
# NOTE: it also works when `key` and `value` have symbolic shape.
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
|
||||
qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
|
||||
# handle attention mask
|
||||
if is_causal:
|
||||
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
||||
|
|
@ -3983,7 +3983,7 @@ class Tensor(SimpleMathTrait):
|
|||
|
||||
# *** image Tensor function replacements ***
|
||||
|
||||
def image_dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
|
||||
def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
|
||||
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
||||
x, dx, dw = self, self.ndim, w.ndim
|
||||
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
|
||||
|
|
@ -3997,9 +3997,9 @@ class Tensor(SimpleMathTrait):
|
|||
cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
|
||||
# groups*cout x cin x H, W
|
||||
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
|
||||
return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
|
||||
return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
|
||||
|
||||
def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
|
||||
def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
|
||||
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
||||
|
||||
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
|
||||
|
|
@ -4048,7 +4048,7 @@ class Tensor(SimpleMathTrait):
|
|||
w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
|
||||
|
||||
# the conv!
|
||||
ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
|
||||
ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype)
|
||||
|
||||
# undo hack for non multiples of 4 on C.rcout
|
||||
if added_output_channels != 0:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue