Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
2e59c5c3a1 move view after const 2025-08-06 11:20:23 -07:00
2 changed files with 5 additions and 8 deletions

View file

@ -15,9 +15,6 @@ merge_views = PatternMatcher([
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
# only unmaksed VIEW on CONST replaces the ShapeTracker
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
])
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
@ -124,6 +121,9 @@ fix_kernel_ops = view_left_through_load+PatternMatcher([
# add view to LOAD and STORE
(UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
(UPat(Ops.DEFINE_GLOBAL, name="g").store(UPat.var('x')), lambda g,x: g.view(g.st).store(x)),
# only unmaksed VIEW on CONST replaces the ShapeTracker
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
lambda x,view: x.replace(src=(UOp(Ops.VIEW, arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
# VALID
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),

View file

@ -255,11 +255,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
if shape is not None:
from tinygrad.shape.shapetracker import ShapeTracker
ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
if device is not None:
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
@staticmethod
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)