mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add shape to range/special (#15862)
This commit is contained in:
parent
3821e442eb
commit
0560fa7b0f
2 changed files with 5 additions and 4 deletions
|
|
@ -211,8 +211,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
def _shape(self) -> tuple[sint, ...]|None:
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||
return None
|
||||
|
||||
|
|
@ -241,7 +241,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return ()
|
||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
# TODO: VCONST should have the shape of the arg
|
||||
case Ops.VCONST: return ()
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||
case Ops.CUSTOM_FUNCTION: return None
|
||||
|
|
|
|||
|
|
@ -113,7 +113,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
# always exclude DEVICE/CONST/UNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
|
||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue