mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
7 commits
master
...
remove_vec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
691ec7b1b9 | ||
|
|
b604b96e99 |
||
|
|
997b0959b7 |
||
|
|
63898ab864 |
||
|
|
03f2f8a206 |
||
|
|
8f498ad0db | ||
|
|
dd33e79281 |
7 changed files with 36 additions and 25 deletions
|
|
@ -76,7 +76,7 @@ def expand_index(buf:UOp, vec:UOp):
|
||||||
buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4)))
|
buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4)))
|
||||||
if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
|
if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
|
||||||
# generate the individual indexes
|
# generate the individual indexes
|
||||||
return UOp(Ops.STACK, buf.dtype, tuple(buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)))
|
return UOp(Ops.STACK, buf.dtype, tuple(buf.index(vec.gep(i), ptr=True) for i in range(vec.shape[0])))
|
||||||
|
|
||||||
def fold_expanded_index(midx:UOp):
|
def fold_expanded_index(midx:UOp):
|
||||||
buf = midx.src[0].src[0]
|
buf = midx.src[0].src[0]
|
||||||
|
|
@ -104,7 +104,7 @@ def fold_expanded_index(midx:UOp):
|
||||||
for grp in grouped_offsets:
|
for grp in grouped_offsets:
|
||||||
# get the index offset for this element. using [0] is okay, because they are the same
|
# get the index offset for this element. using [0] is okay, because they are the same
|
||||||
lidx = midx.src[offsets[grp[0]][0]]
|
lidx = midx.src[offsets[grp[0]][0]]
|
||||||
if len(grp) > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(len(grp)).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
|
#if len(grp) > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(len(grp)).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
|
||||||
# set the idxs of the output
|
# set the idxs of the output
|
||||||
for i,g in enumerate(grp):
|
for i,g in enumerate(grp):
|
||||||
for oo in offsets[g]: idxs[oo] = global_offset+i
|
for oo in offsets[g]: idxs[oo] = global_offset+i
|
||||||
|
|
@ -143,7 +143,7 @@ load_store_folding = PatternMatcher([
|
||||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), name="sto"), gep_on_store),
|
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), name="sto"), gep_on_store),
|
||||||
# put PTRCAT after LOAD
|
# put PTRCAT after LOAD
|
||||||
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
|
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
|
||||||
lambda cat,ld: UOp(Ops.VCAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
lambda cat,ld: UOp(Ops.VCAT, cat.dtype.base, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
||||||
# put PTRCAT after STORE
|
# put PTRCAT after STORE
|
||||||
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), name="sto"), cat_after_store),
|
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), name="sto"), cat_after_store),
|
||||||
])
|
])
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,8 @@ def do_expand(root:UOp):
|
||||||
assert root.dtype.count == 1
|
assert root.dtype.count == 1
|
||||||
# is this right?
|
# is this right?
|
||||||
new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
|
new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
|
||||||
nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
|
#nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
|
||||||
|
nsrc = UOp(root.op, root.dtype, tuple(new_srcs), new_arg)
|
||||||
return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
|
return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
|
||||||
|
|
||||||
def do_contract(con:UOp):
|
def do_contract(con:UOp):
|
||||||
|
|
@ -147,7 +148,7 @@ def fix_group_for_reduce(x:UOp):
|
||||||
pm_pre_expander = PatternMatcher([
|
pm_pre_expander = PatternMatcher([
|
||||||
# rewrite UPCAST/UNROLL range to something to be expanded
|
# rewrite UPCAST/UNROLL range to something to be expanded
|
||||||
(UPat(Ops.RANGE, name="r"),
|
(UPat(Ops.RANGE, name="r"),
|
||||||
lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \
|
lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype, tuple(range(s:=r.vmax+1))),), ((r.arg[0],s),)) \
|
||||||
if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None),
|
if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None),
|
||||||
# fix REDUCEs with UNROLLs
|
# fix REDUCEs with UNROLLs
|
||||||
(UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
|
(UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ class DType(metaclass=DTypeMetaClass):
|
||||||
return float("inf") if dtypes.is_float(self) else True
|
return float("inf") if dtypes.is_float(self) else True
|
||||||
def const(self, val: tuple[ConstType, ...]|ConstType):
|
def const(self, val: tuple[ConstType, ...]|ConstType):
|
||||||
if isinstance(val, tuple):
|
if isinstance(val, tuple):
|
||||||
assert len(val) == self.count, f"mismatch {val} {self}"
|
#assert len(val) == self.count, f"mismatch {val} {self}"
|
||||||
return tuple(map(self.const, val))
|
return tuple(map(self.const, val))
|
||||||
if isinstance(val, InvalidType): return val
|
if isinstance(val, InvalidType): return val
|
||||||
# NOTE: float('nan') != float('nan'), so we canonicalize here
|
# NOTE: float('nan') != float('nan'), so we canonicalize here
|
||||||
|
|
|
||||||
|
|
@ -65,9 +65,9 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
||||||
if len(arg) == 0: return UOp(Ops.STACK, dtypes.weakint.vec(0))
|
if len(arg) == 0: return UOp(Ops.STACK, dtypes.weakint)
|
||||||
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
|
elif all_int(arg): return UOp.const(dtypes.weakint, arg)
|
||||||
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
||||||
|
|
||||||
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||||
ret: dict[UOp, dict[UOp, None]] = {}
|
ret: dict[UOp, dict[UOp, None]] = {}
|
||||||
|
|
@ -91,6 +91,7 @@ class UOpMetaClass(type):
|
||||||
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
|
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
|
||||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
|
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
|
||||||
metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
|
metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
|
||||||
|
assert dtype.count == 1, "dtype with count is not supported"
|
||||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
|
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
|
||||||
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
|
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
|
||||||
if metadata is not None: all_metadata[created] = metadata
|
if metadata is not None: all_metadata[created] = metadata
|
||||||
|
|
@ -212,7 +213,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
match self.op:
|
match self.op:
|
||||||
# late ops don't have shape
|
# late ops don't have shape
|
||||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||||
Ops.STACK | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | \
|
Ops.CONTRACT | Ops.SINK | Ops.END | \
|
||||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -232,18 +233,26 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
|
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
case Ops.STACK: return (len(self.src),)
|
||||||
|
case Ops.GEP:
|
||||||
|
assert len(self.arg) == 1
|
||||||
|
return ()
|
||||||
case Ops.INDEX:
|
case Ops.INDEX:
|
||||||
|
shp = []
|
||||||
|
for s in self.src[1:]: shp.extend(list(s.shape))
|
||||||
|
return tuple(shp)
|
||||||
|
"""
|
||||||
# non pointer index doesn't have a shape
|
# non pointer index doesn't have a shape
|
||||||
if not isinstance(self.dtype, PtrDType): return None
|
if not isinstance(self.dtype, PtrDType): return None
|
||||||
# fully indexed doesn't have a shape. TODO: remove this
|
# fully indexed doesn't have a shape. TODO: remove this
|
||||||
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
||||||
# pointer index
|
# pointer index
|
||||||
return self.src[0].shape[len(self.src[1:]):]
|
return self.src[0].shape[len(self.src[1:]):]
|
||||||
|
"""
|
||||||
|
|
||||||
# some ops init the shape
|
# some ops init the shape
|
||||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return ()
|
||||||
# TODO: VCONST should have the shape of the arg
|
case Ops.VCONST: return (len(self.arg),)
|
||||||
case Ops.VCONST: return ()
|
|
||||||
case Ops.BUFFER: return (self.arg,)
|
case Ops.BUFFER: return (self.arg,)
|
||||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||||
case Ops.CUSTOM_FUNCTION: return None
|
case Ops.CUSTOM_FUNCTION: return None
|
||||||
|
|
@ -419,7 +428,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||||
def vectorize(self, *srcs, **kwargs):
|
def vectorize(self, *srcs, **kwargs):
|
||||||
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs)
|
return UOp(Ops.STACK, self.dtype, (self,)+srcs, **kwargs)
|
||||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
|
@ -446,7 +455,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
def broadcast(self, count:int):
|
def broadcast(self, count:int):
|
||||||
assert self.dtype.vcount == 1
|
assert self.dtype.vcount == 1
|
||||||
if count == 1: return self
|
if count == 1: return self
|
||||||
return UOp(Ops.STACK, self.dtype.vec(count), (self,)*count)
|
return UOp(Ops.STACK, self.dtype, (self,)*count)
|
||||||
def cast(self, dtype:DType):
|
def cast(self, dtype:DType):
|
||||||
# TODO: we shouldn't have to check for dtype.count == 1 here, but CAST is misused in AMD LLVM
|
# TODO: we shouldn't have to check for dtype.count == 1 here, but CAST is misused in AMD LLVM
|
||||||
if dtype.count == 1 and dtype.count != self.dtype.count: dtype = dtype.vec(self.dtype.count)
|
if dtype.count == 1 and dtype.count != self.dtype.count: dtype = dtype.vec(self.dtype.count)
|
||||||
|
|
@ -461,7 +470,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
||||||
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
||||||
i = (i,)
|
i = (i,)
|
||||||
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
return UOp(Ops.GEP, self.dtype, (self,), i)
|
||||||
|
#return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
||||||
def store(self, src:UOp|ConstType, **kwargs):
|
def store(self, src:UOp|ConstType, **kwargs):
|
||||||
return UOp(Ops.STORE, dtypes.void, (self, self.const_like(src) if not isinstance(src, UOp) else src), **kwargs)
|
return UOp(Ops.STORE, dtypes.void, (self, self.const_like(src) if not isinstance(src, UOp) else src), **kwargs)
|
||||||
|
|
@ -483,13 +493,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
|
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
|
||||||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||||
if isinstance(b, tuple) and all_same(b):
|
#if isinstance(b, tuple) and all_same(b):
|
||||||
assert len(b) > 0, "can't create const from empty tuple"
|
# assert len(b) > 0, "can't create const from empty tuple"
|
||||||
b = b[0] # doesn't have to be a VCONST if they are all the same
|
# b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
|
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
|
||||||
arg=dtype.const(b),
|
arg=dtype.const(b),
|
||||||
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
|
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
|
||||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
|
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape is not () else ret
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
|
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
|
||||||
shape:tuple[sint, ...]|None=None, unique=True):
|
shape:tuple[sint, ...]|None=None, unique=True):
|
||||||
|
|
|
||||||
|
|
@ -172,7 +172,7 @@ shared_codegen_spec = PatternMatcher([
|
||||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
||||||
|
|
||||||
# VECTORIZE/GEP
|
# VECTORIZE/GEP
|
||||||
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
(UPat(Ops.STACK, name="x"), lambda x: True),
|
||||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||||
|
|
||||||
# LOAD(idx) / STORE(idx, val)
|
# LOAD(idx) / STORE(idx, val)
|
||||||
|
|
|
||||||
|
|
@ -199,7 +199,7 @@ gep_pushing = PatternMatcher([
|
||||||
# GEP on void is skipped
|
# GEP on void is skipped
|
||||||
(UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
|
(UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
|
||||||
# GEP in order is removed
|
# GEP in order is removed
|
||||||
(UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
|
(UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].shape[0])) else None),
|
||||||
# push all GEPs through ALUs for index (TODO: remove this)
|
# push all GEPs through ALUs for index (TODO: remove this)
|
||||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.weakint, name='gep'),
|
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.weakint, name='gep'),
|
||||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||||
|
|
@ -294,8 +294,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||||
# after with 1 src is just src[0]
|
# after with 1 src is just src[0]
|
||||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||||
# VECTORIZE/CONST
|
# VECTORIZE/CONST
|
||||||
(UPat(Ops.STACK, src=UPat(Ops.CONST), name="vec"),
|
#(UPat(Ops.STACK, src=UPat(Ops.CONST), name="vec"),
|
||||||
lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src)) if len(vec.src) > 0 else None),
|
# lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src)) if len(vec.src) > 0 else None),
|
||||||
])+div_and_mod_symbolic+gep_pushing
|
])+div_and_mod_symbolic+gep_pushing
|
||||||
|
|
||||||
# ******** we take a small aside to "simplify_valid" to rewrite valids ********
|
# ******** we take a small aside to "simplify_valid" to rewrite valids ********
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
|
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
|
||||||
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
||||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
#if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||||
for u in toposort:
|
for u in toposort:
|
||||||
if u in excluded: continue
|
if u in excluded: continue
|
||||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue