Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
e738b2d4a5
Merge branch 'master' into delete_ones 2025-07-25 18:28:23 -07:00
George Hotz
dfb3e99b09 late remove ones 2025-07-25 15:51:56 -07:00
George Hotz
d2473586d1 no keepdims in reduce 2025-07-25 15:44:31 -07:00
3 changed files with 7 additions and 2 deletions

View file

@ -76,6 +76,9 @@ class Kernel:
full_shape = ast.full_shape
self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape)))
# extend all shapes of all shapetrackers
self.sts = [x.reshape(x.shape+(1,)*(len(full_shape)-len(x.shape))) for x in self.sts]
# parameters for optimization
self.tensor_core: TensorCore|None = None
self.tensor_core_opts: TensorCoreOptions|None = None
@ -448,6 +451,8 @@ class Kernel:
ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821
if op.op in GroupOp.Buffer and op in self.bufs:
st = self.sts[self.bufs.index(op)]
# late remove all ones
st = st.reshape(tuple([x for x in st.shape if resolve(x != 1)]))
# NOTE: if CONST got masked after applying opts, we create a new VALID
if op.op is Ops.CONST and any(v.mask is not None for v in st.views): return op.view(st).valid()
# otherwise we just replace the VIEW source

View file

@ -191,7 +191,7 @@ view_left = merge_views+PatternMatcher([
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
# if there's ones added after reduce, put this before the reduce
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
#(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
])
def apply_swizzle(u:UOp) -> UOp:
  """Rewrite `u` with the `view_left` pattern matcher; the rewrite pass is named "Sub View Left"."""
  rewritten = graph_rewrite(u, view_left, name="Sub View Left")
  return rewritten

View file

@ -84,7 +84,7 @@ class ShapeTracker:
@property
def size(self) -> int:
  """The size reported by the final view, i.e. `self.views[-1].size()`."""
  last_view = self.views[-1]
  return last_view.size()
def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]:
  """Return `self.shape` with every dimension whose index is in `axis` replaced by 1 (rank is kept)."""
  dims = []
  for idx, dim in enumerate(self.shape):
    if idx in axis:
      dims.append(1)
    else:
      dims.append(dim)
  return tuple(dims)
def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]:
  """Return `self.shape` with every dimension whose index is in `axis` dropped (no kept size-1 dims)."""
  kept = []
  for idx, dim in enumerate(self.shape):
    # dimensions listed in `axis` are removed entirely rather than turned into 1
    if idx in axis: continue
    kept.append(dim)
  return tuple(kept)
def to_uop(self) -> UOp:
  """Build a `UOp` from this object: op `Ops.VIEW`, dtype `dtypes.void`, an empty tuple, and `self` as the last argument."""
  empty = ()
  return UOp(Ops.VIEW, dtypes.void, empty, self)
def to_indexed_uops(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]: