Compare commits

...

23 commits

Author SHA1 Message Date
George Hotz
cc21351428 move to SPEC=3 2026-04-30 12:42:33 -07:00
George Hotz
7e329c5219 spec=2 checks shape 2026-04-30 11:39:34 -07:00
George Hotz
d1193c72ac real fixes 2026-04-30 11:12:40 -07:00
George Hotz
fe2dcbd573 fix image dtypes 2026-04-30 10:58:41 -07:00
George Hotz
4b16c81944
Merge branch 'master' into dtype_shape 2026-04-30 10:35:32 -07:00
George Hotz
05638ed496
Merge branch 'master' into dtype_shape 2026-04-30 08:05:42 -07:00
George Hotz
f8460c1021 fix image 2026-04-30 07:37:40 -07:00
George Hotz
273e0a4fa6 test fix 2026-04-30 07:31:06 -07:00
George Hotz
649cdbf216
Merge branch 'master' into dtype_shape 2026-04-30 07:24:10 -07:00
George Hotz
7969b205dd fix 2026-04-30 07:22:06 -07:00
George Hotz
ac6dee758a fix test 2026-04-30 07:06:48 -07:00
George Hotz
4fb29cc0c4 const assert 2026-04-30 07:02:54 -07:00
George Hotz
c4d1792edf fixes 2026-04-30 06:33:00 -07:00
George Hotz
4a4455f5b1 rev 2026-04-30 06:30:28 -07:00
George Hotz
29dd605a91
Merge branch 'master' into dtype_shape 2026-04-29 19:41:10 -07:00
George Hotz
5325db3af6 direct buffer view 2026-04-29 18:38:04 -07:00
George Hotz
cfdff84df0 fixes 2026-04-29 18:20:20 -07:00
George Hotz
4bf0c35300 correct for index 2026-04-29 17:27:41 -07:00
George Hotz
8ad8249e06 tests pass 2026-04-29 16:01:29 -07:00
George Hotz
95d04048b0 more shapes 2026-04-29 15:57:47 -07:00
George Hotz
d1f9ade9a0 DEFINE_VAR can also have shape 2026-04-29 15:46:42 -07:00
George Hotz
dd19cdc0cd more shape 2026-04-29 15:31:30 -07:00
George Hotz
4b4cfc0d81 dtype.count is shape 2026-04-29 15:07:37 -07:00
3 changed files with 42 additions and 18 deletions

View file

@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
@ -265,7 +265,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op.
rctx.range_map[x] = (rngs, out_rngs)
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
# NOTE: SPEC=3 is broken here with shape
with Context(SPEC=min(SPEC.value, 2)):
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
return tsink, rctx
def render_ranges(*rngs_list, realized) -> str:

View file

@ -63,12 +63,17 @@ pm_fold_moved_after = PatternMatcher([
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
def move_mop_before_index(r:UOp, idx:UOp):
  """Rewrite an INDEX applied to movement op `r` as an INDEX on `r`'s source.

  The index sources `idx.src[1:]` are passed through `apply_movement_op` and the
  result is used to index `r.src[0]`, keeping the dtype and arg of `idx`.
  Returns None (no rewrite) when reading the source's `_shape` raises RuntimeError,
  when that shape is None, or when the number of indices differs from `len(r.shape)`.
  """
  # TODO: store requires this
  try: parent_shape = r.src[0]._shape
  except RuntimeError: return None
  # source has no shape: nothing to move the index through
  if parent_shape is None: return None
  indices = idx.src[1:]
  # only rewrite when there is exactly one index per dimension of the movement op
  if len(indices) != len(r.shape): return None
  parent = r.src[0]
  return parent.index(*apply_movement_op(r.op, parent.shape, r.marg, indices), dtype=idx.dtype, arg=idx.arg)
# movement op on INDEX as a PatternMatcher
# TODO: clean up .src[0]._shape is not None
pm_mops = PatternMatcher([
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), move_mop_before_index),
# move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
(UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),

View file

@ -100,7 +100,11 @@ class UOpMetaClass(type):
buffers[created] = _buffer
if SPEC > 1:
from tinygrad.uop.spec import full_spec, test_pyrender
if SPEC > 2: test_pyrender(created)
if SPEC > 2:
# SPEC=3 checks the shape
_ = created._shape
if SPEC > 3:
test_pyrender(created)
with Context(CHECK_OOB=0): fret = cast(bool|None, full_spec.rewrite(created))
if fret is not True: raise RuntimeError(f"SPEC ISSUE {fret}: {created}")
return created
@ -212,7 +216,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.STACK | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | \
Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None
@ -228,22 +232,28 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return inner_shape
case Ops.CAST:
# if it has a vec dtype, set the shape
if self.dtype.count > 1: return (self.dtype.count,)
# when PTX casts from ptr to non ptr, remove the shape
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
return None
case Ops.GEP: return (len(self.arg),) if len(self.arg) > 1 else ()
case Ops.STACK: return (len(self.src),)
case Ops.INDEX:
# non pointer index doesn't have a shape
if not isinstance(self.dtype, PtrDType): return None
# fully indexed doesn't have a shape. TODO: remove this
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
# pointer index
return self.src[0].shape[len(self.src[1:]):]
shp:list[sint] = []
# NOTE: the acc buffer can have a dtype with count, we need it back here
if self.src[0].dtype.count > 1: shp.append(self.src[0].dtype.count)
for s in self.src[1:]: shp.extend(list(s.shape))
return tuple(shp) + self.src[0].shape[len(self.src[1:]):]
# some ops init the shape
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
# TODO: VCONST should have the shape of the arg
case Ops.VCONST: return ()
case Ops.CONST | Ops.DEFINE_VAR:
# these can have shape if it has a vec dtype
if self.dtype.count > 1: return (self.dtype.count,)
return ()
case Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return ()
case Ops.VCONST: return (len(self.arg),)
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
case Ops.CUSTOM_FUNCTION: return None
@ -263,6 +273,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0:
# these can mismatch if there's a horizontal reduce
if self.src[0].dtype.count > 1:
assert len(self.src[0]._shape) == 1
return () if self.dtype.count == 1 else (self.dtype.count,)
return self.src[0]._shape
# TODO: disallow shape changing bitcast
@ -316,7 +330,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
input_shapes = [x._shape for x in self.src if x._shape is not None]
if len(input_shapes) == 0: return None
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes} {[x.op for x in self.src]}")
return input_shapes[0]
# all Ops must be explicitly handled
@ -439,7 +453,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
def const_like(self, b:ConstLike, dtype:DType|None=None):
# constants can optionally have a DEVICE source
ret = UOp.const(dtype or self.dtype.base, b, device=self._device, shape=self.shard_shape if self.axis is not None else self._shape)
dtype = dtype or self.dtype.base
shape = (self.shard_shape if self.axis is not None else self._shape) if dtype.count == 1 else None
ret = UOp.const(dtype, b, device=self._device, shape=shape)
return ret.multi(self.axis) if self.axis is not None else ret
def ufix(self, x):
if isinstance(x, UOp): return x
@ -483,6 +499,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(op, out_dtype, all_srcs, **kwargs)
@staticmethod
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
if shape == (): assert dtype.count == 1, "if shape is () you can't have a vec dtype"
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b):
assert len(b) > 0, "can't create const from empty tuple"