remove shapeless const check in full_shape [pr] (#9911)

* remove shapeless const check in full_shape [pr]

* those can go too
This commit is contained in:
qazal 2025-04-18 00:00:26 +03:00 committed by GitHub
commit d287afe3b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -319,9 +319,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def full_shape(self) -> tuple[sint, ...]:
if self.op is Ops.VIEW: return self.shape
# TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
parent_shapes = [x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} and not (x.op is Ops.CONST and x.st is None)]
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
# NOTE: if a parent doesn't have st its full_shape is empty
parent_shapes = [x.full_shape for x in self.src]
return tuple(smax(x) for x in zip(*[x for x in parent_shapes if x != ()]))
@property
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
@ -1031,4 +1030,4 @@ merge_views = PatternMatcher([
lambda v: v.const_like(0) if (mask:=v.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
# movement ops apply a new view on the base
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
])
])