mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
23 commits
master
...
dtype_shap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc21351428 | ||
|
|
7e329c5219 | ||
|
|
d1193c72ac | ||
|
|
fe2dcbd573 | ||
|
|
4b16c81944 |
||
|
|
05638ed496 |
||
|
|
f8460c1021 | ||
|
|
273e0a4fa6 | ||
|
|
649cdbf216 |
||
|
|
7969b205dd | ||
|
|
ac6dee758a | ||
|
|
4fb29cc0c4 | ||
|
|
c4d1792edf | ||
|
|
4a4455f5b1 | ||
|
|
29dd605a91 |
||
|
|
5325db3af6 | ||
|
|
cfdff84df0 | ||
|
|
4bf0c35300 | ||
|
|
8ad8249e06 | ||
|
|
95d04048b0 | ||
|
|
d1f9ade9a0 | ||
|
|
dd19cdc0cd | ||
|
|
4b4cfc0d81 |
3 changed files with 42 additions and 18 deletions
|
|
@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes, AddrSpace
|
|||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
|
|
@ -265,7 +265,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
# assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op.
|
||||
rctx.range_map[x] = (rngs, out_rngs)
|
||||
|
||||
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
|
||||
# NOTE: SPEC=3 is broken here with shape
|
||||
with Context(SPEC=min(SPEC.value, 2)):
|
||||
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
|
||||
return tsink, rctx
|
||||
|
||||
def render_ranges(*rngs_list, realized) -> str:
|
||||
|
|
|
|||
|
|
@ -63,12 +63,17 @@ pm_fold_moved_after = PatternMatcher([
|
|||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
def move_mop_before_index(r:UOp, idx:UOp):
|
||||
# TODO: store requires this
|
||||
try: src_shape = r.src[0]._shape
|
||||
except RuntimeError: return None
|
||||
return r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg) \
|
||||
if src_shape is not None and len(idx.src[1:]) == len(r.shape) else None
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
|
||||
# TODO: clean up .src[0]._shape is not None
|
||||
pm_mops = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
|
||||
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
|
||||
if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
|
||||
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), move_mop_before_index),
|
||||
# move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
|
||||
(UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),
|
||||
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
|
||||
|
|
|
|||
|
|
@ -100,7 +100,11 @@ class UOpMetaClass(type):
|
|||
buffers[created] = _buffer
|
||||
if SPEC > 1:
|
||||
from tinygrad.uop.spec import full_spec, test_pyrender
|
||||
if SPEC > 2: test_pyrender(created)
|
||||
if SPEC > 2:
|
||||
# SPEC=3 checks the shape
|
||||
_ = created._shape
|
||||
if SPEC > 3:
|
||||
test_pyrender(created)
|
||||
with Context(CHECK_OOB=0): fret = cast(bool|None, full_spec.rewrite(created))
|
||||
if fret is not True: raise RuntimeError(f"SPEC ISSUE {fret}: {created}")
|
||||
return created
|
||||
|
|
@ -212,7 +216,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.STACK | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | \
|
||||
Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||
return None
|
||||
|
||||
|
|
@ -228,22 +232,28 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return inner_shape
|
||||
|
||||
case Ops.CAST:
|
||||
# if it has a vec dtype, set the shape
|
||||
if self.dtype.count > 1: return (self.dtype.count,)
|
||||
# when PTX casts from ptr to non ptr, remove the shape
|
||||
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
|
||||
return None
|
||||
|
||||
case Ops.GEP: return (len(self.arg),) if len(self.arg) > 1 else ()
|
||||
case Ops.STACK: return (len(self.src),)
|
||||
case Ops.INDEX:
|
||||
# non pointer index doesn't have a shape
|
||||
if not isinstance(self.dtype, PtrDType): return None
|
||||
# fully indexed doesn't have a shape. TODO: remove this
|
||||
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
||||
# pointer index
|
||||
return self.src[0].shape[len(self.src[1:]):]
|
||||
shp:list[sint] = []
|
||||
# NOTE: the acc buffer can have a dtype with count, we need it back here
|
||||
if self.src[0].dtype.count > 1: shp.append(self.src[0].dtype.count)
|
||||
for s in self.src[1:]: shp.extend(list(s.shape))
|
||||
return tuple(shp) + self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
# TODO: VCONST should have the shape of the arg
|
||||
case Ops.VCONST: return ()
|
||||
case Ops.CONST | Ops.DEFINE_VAR:
|
||||
# these can have shape if it has a vec dtype
|
||||
if self.dtype.count > 1: return (self.dtype.count,)
|
||||
return ()
|
||||
case Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return ()
|
||||
case Ops.VCONST: return (len(self.arg),)
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||
case Ops.CUSTOM_FUNCTION: return None
|
||||
|
|
@ -263,6 +273,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return self.src[0]._shape
|
||||
# REDUCE with empty axis is passthrough (lowered form)
|
||||
case Ops.REDUCE if len(self.arg[1]) == 0:
|
||||
# these can mismatch if there's a horizonal reduce
|
||||
if self.src[0].dtype.count > 1:
|
||||
assert len(self.src[0]._shape) == 1
|
||||
return () if self.dtype.count == 1 else (self.dtype.count,)
|
||||
return self.src[0]._shape
|
||||
|
||||
# TODO: disallow shape changing bitcast
|
||||
|
|
@ -316,7 +330,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
|
||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||
if len(input_shapes) == 0: return None
|
||||
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
||||
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes} {[x.op for x in self.src]}")
|
||||
return input_shapes[0]
|
||||
|
||||
# all Ops must be explicitly handled
|
||||
|
|
@ -439,7 +453,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
|
||||
def const_like(self, b:ConstLike, dtype:DType|None=None):
|
||||
# constants can optionally have a DEVICE source
|
||||
ret = UOp.const(dtype or self.dtype.base, b, device=self._device, shape=self.shard_shape if self.axis is not None else self._shape)
|
||||
dtype = dtype or self.dtype.base
|
||||
shape = (self.shard_shape if self.axis is not None else self._shape) if dtype.count == 1 else None
|
||||
ret = UOp.const(dtype, b, device=self._device, shape=shape)
|
||||
return ret.multi(self.axis) if self.axis is not None else ret
|
||||
def ufix(self, x):
|
||||
if isinstance(x, UOp): return x
|
||||
|
|
@ -483,6 +499,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return UOp(op, out_dtype, all_srcs, **kwargs)
|
||||
@staticmethod
|
||||
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
|
||||
if shape == (): assert dtype.count == 1, "if shape is () you can't have a vec dtype"
|
||||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b):
|
||||
assert len(b) > 0, "can't create const from empty tuple"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue