mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
remove_vec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b8ff231cd |
||
|
|
a448b47e9c |
||
|
|
6a983bb72b | ||
|
|
fedad1681f |
3 changed files with 24 additions and 7 deletions
|
|
@ -198,7 +198,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||||
(UPat(Ops.REDUCE, name="reduce", src=(UPat.var("x"),)),
|
(UPat(Ops.REDUCE, name="reduce", src=(UPat.var("x"),)),
|
||||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if 0 in x.shape and 0 not in reduce.shape else None),
|
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if 0 in x.shape and 0 not in reduce.shape else None),
|
||||||
# handle size 0
|
# handle size 0
|
||||||
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and 0 in x.shape else None),
|
(UPat(GroupOp.All-{Ops.SINK, Ops.STACK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and 0 in x.shape else None),
|
||||||
])
|
])
|
||||||
|
|
||||||
# *****************
|
# *****************
|
||||||
|
|
@ -546,7 +546,7 @@ def split_store(x:UOp) -> UOp|None:
|
||||||
# if we have any open ranges here, we don't split
|
# if we have any open ranges here, we don't split
|
||||||
if x.ranges: return None
|
if x.ranges: return None
|
||||||
# raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently
|
# raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently
|
||||||
if x.op is Ops.STORE and x.src[0]._shape is not None: return None
|
#if x.op is Ops.STORE and x.src[0]._shape is not None: return None
|
||||||
|
|
||||||
# local kernel rewrite
|
# local kernel rewrite
|
||||||
lctx = LocalAddBufferContext()
|
lctx = LocalAddBufferContext()
|
||||||
|
|
|
||||||
|
|
@ -212,7 +212,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
match self.op:
|
match self.op:
|
||||||
# late ops don't have shape
|
# late ops don't have shape
|
||||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||||
Ops.STACK | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | \
|
Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | \
|
||||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -228,22 +228,39 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
return inner_shape
|
return inner_shape
|
||||||
|
|
||||||
case Ops.CAST:
|
case Ops.CAST:
|
||||||
|
if self.dtype.count > 1:
|
||||||
|
return (self.dtype.count,)
|
||||||
|
|
||||||
# when PTX casts from ptr to non ptr, remove the shape
|
# when PTX casts from ptr to non ptr, remove the shape
|
||||||
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
|
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
case Ops.STACK: return (len(self.src),)
|
||||||
|
case Ops.GEP:
|
||||||
|
if len(self.arg) > 1: return (len(self.arg),)
|
||||||
|
#assert len(self.arg) == 1
|
||||||
|
return ()
|
||||||
case Ops.INDEX:
|
case Ops.INDEX:
|
||||||
|
shp = []
|
||||||
|
for s in self.src[1:]: shp.extend(list(s.shape))
|
||||||
|
return tuple(shp)
|
||||||
|
"""
|
||||||
# non pointer index doesn't have a shape
|
# non pointer index doesn't have a shape
|
||||||
if not isinstance(self.dtype, PtrDType): return None
|
if not isinstance(self.dtype, PtrDType): return None
|
||||||
# fully indexed doesn't have a shape. TODO: remove this
|
# fully indexed doesn't have a shape. TODO: remove this
|
||||||
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
||||||
# pointer index
|
# pointer index
|
||||||
return self.src[0].shape[len(self.src[1:]):]
|
return self.src[0].shape[len(self.src[1:]):]
|
||||||
|
"""
|
||||||
|
|
||||||
# some ops init the shape
|
# some ops init the shape
|
||||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
case Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||||
|
case Ops.CONST:
|
||||||
|
if self.dtype.count > 1: return (self.dtype.count,)
|
||||||
|
return ()
|
||||||
|
case Ops.VCONST: return (len(self.arg),)
|
||||||
# TODO: VCONST should have the shape of the arg
|
# TODO: VCONST should have the shape of the arg
|
||||||
case Ops.VCONST: return ()
|
#case Ops.VCONST: return ()
|
||||||
case Ops.BUFFER: return (self.arg,)
|
case Ops.BUFFER: return (self.arg,)
|
||||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||||
case Ops.CUSTOM_FUNCTION: return None
|
case Ops.CUSTOM_FUNCTION: return None
|
||||||
|
|
@ -492,7 +509,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
|
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
|
||||||
arg=dtype.const(b),
|
arg=dtype.const(b),
|
||||||
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
|
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
|
||||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
|
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != ret.shape else ret
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
|
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
|
||||||
shape:tuple[sint, ...]|None=None, unique=True):
|
shape:tuple[sint, ...]|None=None, unique=True):
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
|
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
|
||||||
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
||||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
#if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||||
for u in toposort:
|
for u in toposort:
|
||||||
if u in excluded: continue
|
if u in excluded: continue
|
||||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue