mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove shapeless const check in full_shape [pr] (#9911)
* remove shapeless const check in full_shape [pr] * those can go too
This commit is contained in:
parent
fe6a482f1d
commit
d287afe3b1
1 changed files with 3 additions and 4 deletions
|
|
@ -319,9 +319,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
@functools.cached_property
|
||||
def full_shape(self) -> tuple[sint, ...]:
|
||||
if self.op is Ops.VIEW: return self.shape
|
||||
# TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
|
||||
parent_shapes = [x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} and not (x.op is Ops.CONST and x.st is None)]
|
||||
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
|
||||
# NOTE: if a parent doesn't have st its full_shape is empty
|
||||
parent_shapes = [x.full_shape for x in self.src]
|
||||
return tuple(smax(x) for x in zip(*[x for x in parent_shapes if x != ()]))
|
||||
@property
|
||||
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
|
||||
|
|
@ -1031,4 +1030,4 @@ merge_views = PatternMatcher([
|
|||
lambda v: v.const_like(0) if (mask:=v.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# movement ops apply a new view on the base
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
|
||||
])
|
||||
])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue