tests pass

This commit is contained in:
George Hotz 2026-04-29 16:01:29 -07:00
commit 8ad8249e06

View file

@ -272,6 +272,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0:
# these can mismatch if there's a horizonal reduce
if self.dtype.count > 1:
assert len(self.src[0]._shape) == 1
return (self.dtype.count,)
return self.src[0]._shape
# TODO: disallow shape changing bitcast
@ -289,8 +293,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE}):
ps = self.src[0]._shape
# TODO: WMMA is used for both axis WMMA and op WMMA. fix this and remove this hack. tested by BERT on AMD LLVM
if ps is None and self.op is Ops.WMMA: return None
if ps is None: raise RuntimeError(f"movement op {self.op} requires shape")
match self.op:
case Ops.RESHAPE: