Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
35e13e08d5 improve shapes to make them behave like dtype.count, try 2 2026-05-01 10:27:54 -07:00

View file

@ -208,7 +208,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None
@ -224,21 +224,38 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return inner_shape
case Ops.CAST:
# if it has a vec dtype, set the shape
if self.dtype.count > 1: return (self.dtype.count,)
# when PTX casts from ptr to non ptr, remove the shape of the buffer
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
return ()
case Ops.GEP: return (len(self.arg),) if len(self.arg) > 1 else ()
case Ops.STACK: return (len(self.src),)
case Ops.INDEX:
shp:list[sint] = []
# NOTE: the acc buffer can have a dtype with count, we need it back here
if self.src[0].dtype.count > 1: shp.append(self.src[0].dtype.count)
for s in self.src[1:]: shp.extend(list(s.shape))
return tuple(shp) + self.src[0].shape[len(self.src[1:]):]
# TODO: these should have the shape of the dtype.count
case Ops.CONST | Ops.DEFINE_VAR: return ()
case Ops.GEP | Ops.STACK | Ops.VCONST | Ops.VCAT: return ()
case Ops.CONST | Ops.DEFINE_VAR:
if self.dtype.count > 1: return (self.dtype.count,)
return ()
case Ops.VCONST: return (len(self.arg),)
case Ops.VCAT: return ()
# TODO: contract and unroll should be deleted
case Ops.CONTRACT:
amt = prod([x[1] for x in self.arg])
return (amt,) if amt > 1 else ()
case Ops.UNROLL:
amt = prod([x[1] for x in self.arg])
return () if self.src[0]._shape[0] == amt else (amt,)
# some ops init the shape
case Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return ()
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW:
# HACK: BUFFER_VIEW is used inside kernels, so we set the shape to () if it's on an INDEX
@ -261,6 +278,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0:
# these can mismatch if there's a horizonal reduce
if self.src[0].dtype.count > 1:
assert self.src[0]._shape is not None and len(self.src[0]._shape) == 1, f"bad reduce shape on {self.src[0].op} {self.src[0]._shape}"
return () if self.dtype.count == 1 else (self.dtype.count,)
return self.src[0]._shape
# TODO: disallow shape changing bitcast