Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
b1ed293239 go 2026-03-12 14:52:37 +08:00
George Hotz
810d7ee7ec fixes 2026-03-12 14:49:02 +08:00
George Hotz
743e5334a4 Merge branch 'master' into no_assign_2 2026-03-12 14:25:13 +08:00
George Hotz
75d311d766 no contig there 2026-03-12 13:53:11 +08:00
George Hotz
fee1f4ccb9 ASSIGN is STORE+AFTER 2026-03-12 11:59:10 +08:00
5 changed files with 13 additions and 9 deletions

View file

@@ -27,11 +27,11 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
pm_generate_realize_map = PatternMatcher([ pm_generate_realize_map = PatternMatcher([
# always realize # always realize
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN}, name="tr"), realize), (UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE}, name="tr"), realize),
# realize srcs of these # realize srcs of these
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs), (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# sometimes realize src of assign # sometimes realize src of assign
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src), (UPat(Ops.STORE, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
]) ])
@dataclass(frozen=True) @dataclass(frozen=True)

View file

@@ -61,7 +61,7 @@ def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
root_target = target root_target = target
while root_target.op is Ops.ASSIGN: root_target = root_target.src[0] while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
# when RHS depends on the previous assign result, break with contiguous # when RHS depends on the previous assign result, break with contiguous
if target in src.toposort(): src = src.contiguous() #if target in src.toposort(): src = src.contiguous()
return assign.replace(src=(root_target, src)) return assign.replace(src=(root_target, src))
def split_reduceop(reduce:UOp, x:UOp): def split_reduceop(reduce:UOp, x:UOp):
@@ -160,7 +160,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain), (UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
# make source contiguous if it has hazardous movement ops on the dest buffer # make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard), (UPat(Ops.STORE, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
# ** size 0 ** # ** size 0 **

View file

@@ -82,7 +82,8 @@ class Ops(FastEnum):
# ** 6 -- ops that don't exist in programs ** # ** 6 -- ops that don't exist in programs **
# tensor graph ops # tensor graph ops
UNIQUE = auto(); DEVICE = auto(); ASSIGN = auto() UNIQUE = auto(); DEVICE = auto() #; ASSIGN = auto()
ASSIGN = AFTER # ASSIGN is AFTER now (remove it)
# local unique # local unique
LUNIQUE = auto() LUNIQUE = auto()

View file

@@ -301,10 +301,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
raise ValueError(f"invalid type for axis: {axis_arg}") raise ValueError(f"invalid type for axis: {axis_arg}")
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps)) return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
if self.op is Ops.ASSIGN: return self.src[1]._shape if self.op is Ops.STORE: return self.src[1]._shape
# elementwise ops keep the shape the same. all inputs with shape must match # elementwise ops keep the shape the same. all inputs with shape must match
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}): if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE}):
input_shapes = [x._shape for x in self.src if x._shape is not None] input_shapes = [x._shape for x in self.src if x._shape is not None]
if len(input_shapes) == 0: return None if len(input_shapes) == 0: return None
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}") if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
@@ -447,7 +447,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs) return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) def assign(self, x:UOp): return self.after(self.store(x))
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src) def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
def contract(self, *rngs:UOp): def contract(self, *rngs:UOp):
assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast" assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
@@ -1051,12 +1051,12 @@ class UPat(OpMixin):
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs) def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs) def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs) def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.match_dtype, (self,x), **kwargs)
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs) def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs) def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs) def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.match_dtype, (self,)+src, **kwargs) def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.match_dtype, (self,)+src, **kwargs)
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.match_dtype, (self,)+src, **kwargs) def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.match_dtype, (self,)+src, **kwargs)
def assign(self, x:UPat, **kwargs): return self.after(self.store(x), **kwargs)
def const_like(self, b:ConstLike): return UPat.const(self.match_dtype, cast(ConstType, b)) def const_like(self, b:ConstLike): return UPat.const(self.match_dtype, cast(ConstType, b))
def alu(self, op:Ops, *src:UPat): def alu(self, op:Ops, *src:UPat):

View file

@@ -96,6 +96,9 @@ _tensor_spec = PatternMatcher([
# ASSIGN has a target and a value. It can also optionally depend on other assigns # ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])), (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
# STORE in tensor graph: store a value into a target
(UPat(Ops.STORE, dtypes.void, (UPat(), UPat())), lambda: True),
# MSELECT chooses one of the multi buffers # MSELECT chooses one of the multi buffers
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)), (UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),