add shape to range/special (#15862)

This commit is contained in:
George Hotz 2026-04-22 11:15:02 +08:00 committed by GitHub
commit 0560fa7b0f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 5 additions and 4 deletions

View file

@ -211,8 +211,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def _shape(self) -> tuple[sint, ...]|None:
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None
@ -241,7 +241,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0].shape[len(self.src[1:]):]
# some ops init the shape
case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return ()
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
# TODO: VCONST should have the shape of the arg
case Ops.VCONST: return ()
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
case Ops.CUSTOM_FUNCTION: return None

View file

@ -113,7 +113,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)