mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove use of full with buffer=False and non-None device= (#16489)
This commit is contained in:
parent
6f2a2857c8
commit
19eb72ff60
10 changed files with 26 additions and 29 deletions
|
|
@ -32,9 +32,9 @@ class TestMovedConstFolding(unittest.TestCase):
|
|||
_check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))
|
||||
|
||||
def test_copy_padded_const(self):
|
||||
schedule = Tensor.ones(4, device="CPU:0", buffer=False).pad(((1, 1),)).to("CPU:1").schedule_linear()
|
||||
schedule = Tensor.ones(4, buffer=False).pad(((1, 1),)).to("CPU:1").schedule_linear()
|
||||
assert not any(si.src[0].op is Ops.COPY for si in schedule.src), "const copy should be folded"
|
||||
np.testing.assert_equal(Tensor.ones(4, device="CPU:0", buffer=False).pad(((1, 1),)).to("CPU:1").numpy(), [0, 1, 1, 1, 1, 0])
|
||||
np.testing.assert_equal(Tensor.ones(4, buffer=False).pad(((1, 1),)).to("CPU:1").numpy(), [0, 1, 1, 1, 1, 0])
|
||||
|
||||
def test_cast_padded(self):
|
||||
# NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count
|
||||
|
|
|
|||
|
|
@ -86,9 +86,9 @@ class TestQuantizeFP8(unittest.TestCase):
|
|||
class TestLocalAmax(unittest.TestCase):
|
||||
def test_multi_tensor_local_shard_amax(self):
|
||||
devices = ("CPU:0", "CPU:1")
|
||||
x = Tensor.arange(16, device=devices[0]).reshape(4, 4).cast(dtypes.float).contiguous().realize().shard(devices, axis=0).realize()
|
||||
x = Tensor.arange(16).reshape(4, 4).cast(dtypes.float).clone(devices[0]).realize().shard(devices, axis=0).realize()
|
||||
GlobalCounters.reset()
|
||||
out = (x * local_abs_max(x)).contiguous().realize()
|
||||
out = (x * local_abs_max(x)).clone().realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 4)
|
||||
self.assertEqual(out.tolist(), [[0., 7., 14., 21.], [28., 35., 42., 49.], [120., 135., 150., 165.], [180., 195., 210., 225.]])
|
||||
|
||||
|
|
|
|||
|
|
@ -1278,11 +1278,6 @@ class TestCopyFolding(unittest.TestCase):
|
|||
x = y.one_hot(10)
|
||||
check_schedule(x, 3, filter_sink=False)
|
||||
|
||||
def test_const_copy_multi(self):
|
||||
x = Tensor.ones(1, device="CPU", buffer=False).to_(["CPU", "CPU:1"]) * 2
|
||||
run_linear(*check_schedule(x, 2, filter_sink=False))
|
||||
self.assertEqual(x.item(), 2.0)
|
||||
|
||||
def test_late_const_copy_folding(self):
|
||||
a = Tensor.arange(3).clone().realize()
|
||||
zeros = Tensor.zeros(3, buffer=False).realize()
|
||||
|
|
|
|||
2
test/external/external_test_onnx_ops.py
vendored
2
test/external/external_test_onnx_ops.py
vendored
|
|
@ -371,7 +371,7 @@ class TestMainOnnxOps(TestOnnxOps):
|
|||
Shape = onnx_ops["Shape"]
|
||||
Compress = onnx_ops["Compress"]
|
||||
with Context(DEV="CPU"):
|
||||
x = Tensor.arange(4, device="PYTHON").reshape(2,2)
|
||||
x = Tensor.arange(4).clone("PYTHON").reshape(2,2)
|
||||
self.assertEqual(EyeLike(x).device, x.device)
|
||||
self.assertEqual(Shape(x).device, x.device)
|
||||
out = Compress(x, [True, False, True, False])
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TestAttention(unittest.TestCase):
|
|||
rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
|
||||
rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
|
||||
assert_jit_cache_len(rope_prune, 1)
|
||||
assert_jit_cache_len(rope_noprune, 2)
|
||||
assert_jit_cache_len(rope_noprune, 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class TestLinearizerRewrite(unittest.TestCase):
|
|||
print(prg.src[3].arg)
|
||||
|
||||
def test_arange(self):
|
||||
out = Tensor.arange(32, device="NULL")
|
||||
out = Tensor.arange(32).clone("NULL")
|
||||
with Context(SPLIT_REDUCEOP=0):
|
||||
si = out.schedule_linear().src[-1]
|
||||
opts_to_apply = []
|
||||
|
|
@ -28,7 +28,7 @@ class TestLinearizerRewrite(unittest.TestCase):
|
|||
print(prg.src[3].arg)
|
||||
|
||||
def test_kernel_info(self):
|
||||
out = Tensor.arange(4, device="NULL")
|
||||
out = Tensor.arange(4).clone("NULL")
|
||||
si = out.schedule_linear().src[-1]
|
||||
|
||||
ast = si.src[0].replace(arg=KernelInfo(opts_to_apply=()))
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ from tinygrad.uop.ops import resolve
|
|||
|
||||
@functools.cache
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, device:str|None=None) -> Tensor:
|
||||
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2, device=device)[:(dim // 2)] / dim))
|
||||
freqs = Tensor.arange(end, device=device).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
||||
return freqs.cos().cat(freqs.sin(), dim=-1).contiguous()
|
||||
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
||||
return freqs.cos().cat(freqs.sin(), dim=-1).clone(device)
|
||||
|
||||
class ExpertWeights:
|
||||
"""Like nn.Linear but with num_experts dimension. Weight shape: (num_experts, out_features, in_features)."""
|
||||
|
|
@ -177,7 +177,7 @@ class TransformerBlock(FFNBlock):
|
|||
|
||||
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
|
||||
# TODO: this if statement should be removed and it shouldn't generate extra kernels
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device, buffer=False).triu(start_pos+1) \
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, buffer=False).triu(start_pos+1) \
|
||||
if resolve(T != 1) else None
|
||||
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
|
||||
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
|
||||
|
|
@ -223,7 +223,7 @@ class MLATransformerBlock(FFNBlock):
|
|||
k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :]
|
||||
v = k[..., :self.config.kv_lora_rank]
|
||||
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device, buffer=False).triu(start_pos+1) \
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, buffer=False).triu(start_pos+1) \
|
||||
if resolve(T != 1) else None
|
||||
attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5)
|
||||
if mask is not None: attn = attn + mask
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
print(Tensor.full((2, 3), False).numpy())
|
||||
```
|
||||
"""
|
||||
# TODO: enable this check
|
||||
# if not buffer: assert device is None, "buffer=False does not support device specification"
|
||||
from tinygrad.uop.ops import UOp
|
||||
new_shape = argfix(shape)
|
||||
dt = to_dtype(dtype) if dtype is not None else None
|
||||
|
|
@ -724,8 +726,8 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), type(self).zeros(self.shape, dtype=dtypes.int32, device=self.device, buffer=False)
|
||||
values, n = self._split_cumalu(axis, Ops.MAX), int(self.shape[axis])
|
||||
x, values_t = self.transpose(axis, -1), values.transpose(axis, -1)
|
||||
match = x.unsqueeze(-1).eq(values_t.unsqueeze(-2)) * type(self).ones(n, n, device=self.device, buffer=False).triu()
|
||||
idx = (-(match * type(self).arange(n, 0, -1, device=self.device).reshape(n, 1)).max(-2) + n).cast(dtypes.int32)
|
||||
match = x.unsqueeze(-1).eq(values_t.unsqueeze(-2)) * type(self).ones(n, n, buffer=False).triu()
|
||||
idx = (-(match * type(self).arange(n, 0, -1).reshape(n, 1)).max(-2) + n).cast(dtypes.int32)
|
||||
return values, idx.transpose(-1, axis)
|
||||
|
||||
def cummin(self, axis:int=0) -> tuple[Self, Self]:
|
||||
|
|
@ -771,7 +773,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
last_dim_size = x.shape[-1]
|
||||
x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
|
||||
x_cummax, _ = x.cummax(-1)
|
||||
mask = type(self).ones(last_dim_size, last_dim_size, device=self.device, buffer=False).tril()
|
||||
mask = type(self).ones(last_dim_size, last_dim_size, buffer=False).tril()
|
||||
ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), self.dtype.min).exp().sum(-1).log() + x_cummax
|
||||
return ret.transpose(-1, axis)
|
||||
|
||||
|
|
@ -868,7 +870,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
|
||||
x = x.flatten(dim, dim+n_stages-1).shrink_to(self.shape)
|
||||
# compute indices for sorted values
|
||||
mask = type(self).ones(orig_len, orig_len, dtype=dtypes.bool, device=self.device, buffer=False).tril()
|
||||
mask = type(self).ones(orig_len, orig_len, dtype=dtypes.bool, buffer=False).tril()
|
||||
mask = mask.reshape((None, None) + (1,)*(self.ndim-dim-1))
|
||||
def compute_counts(t:Self): return (mask & t.unsqueeze(dim).eq(t.unsqueeze(dim+1))).sum(dim+1)
|
||||
count_orig, count_sorted = compute_counts(self), compute_counts(x)
|
||||
|
|
@ -1075,7 +1077,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
```
|
||||
"""
|
||||
if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
|
||||
if isinstance(src, (int, float, bool)): src = type(self).full(index.shape, src, dtype=self.dtype, device=self.device, buffer=False)
|
||||
if isinstance(src, (int, float, bool)): src = type(self).full(index.shape, src, dtype=self.dtype, buffer=False)
|
||||
elif reduce: raise TypeError("non-scalar src is not supported with reduce arg. use scatter_reduce")
|
||||
if reduce == "add": return self.scatter_reduce(dim, index, src, "sum", include_self=True)
|
||||
if reduce == "multiply": return self.scatter_reduce(dim, index, src, "prod", include_self=True)
|
||||
|
|
|
|||
|
|
@ -1075,7 +1075,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
|||
qk_matmul_return_val = scores
|
||||
|
||||
if is_causal:
|
||||
causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool, buffer=False).tril(0)
|
||||
causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], dtype=dtypes.bool, buffer=False).tril(0)
|
||||
scores = scores.masked_fill(causal_mask.logical_not(), -float("inf"))
|
||||
|
||||
if attn_mask is not None:
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp:
|
|||
return ret
|
||||
|
||||
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...]|None, dtype:DType) -> list[list[Tensor]]:
|
||||
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype, buffer=False) for m in mat], dim=dim)
|
||||
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), dtype=dtype, buffer=False) for m in mat], dim=dim)
|
||||
for k in range(len(mat[0]))] for dim in range(dims)]
|
||||
|
||||
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
||||
|
|
@ -1049,10 +1049,10 @@ class Tensor(OpMixin):
|
|||
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
||||
mask_cumsum = mask.cumsum()
|
||||
if size is None:
|
||||
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device, buffer=False)
|
||||
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
|
||||
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
|
||||
counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
|
||||
return (Tensor.arange(size, device=self.device) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
|
||||
counts = Tensor.zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
|
||||
return (Tensor.arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
|
||||
|
||||
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
|
||||
"""
|
||||
|
|
@ -1127,7 +1127,7 @@ class Tensor(OpMixin):
|
|||
|
||||
data = (data.flatten(1) ^ pad_mask).reshape(*data.shape[:2], 200).bitcast(dtypes.uint64)
|
||||
|
||||
state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64, buffer=False)
|
||||
state = Tensor.zeros(bs, 25, dtype=dtypes.uint64, buffer=False)
|
||||
for k in range(int(data.shape[1])):
|
||||
state = state ^ data[:, k]
|
||||
for i in range(24): # f1600
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue