mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
no_assign_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1ed293239 | ||
|
|
810d7ee7ec | ||
|
|
743e5334a4 |
||
|
|
75d311d766 | ||
|
|
fee1f4ccb9 |
5 changed files with 13 additions and 9 deletions
|
|
@ -27,11 +27,11 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
|
|||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize
|
||||
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN}, name="tr"), realize),
|
||||
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE}, name="tr"), realize),
|
||||
# realize srcs of these
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
||||
# sometimes realize src of assign
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
|
||||
(UPat(Ops.STORE, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
|
||||
])
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
|
|||
root_target = target
|
||||
while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
|
||||
# when RHS depends on the previous assign result, break with contiguous
|
||||
if target in src.toposort(): src = src.contiguous()
|
||||
#if target in src.toposort(): src = src.contiguous()
|
||||
return assign.replace(src=(root_target, src))
|
||||
|
||||
def split_reduceop(reduce:UOp, x:UOp):
|
||||
|
|
@ -160,7 +160,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
|
||||
|
||||
# make source contiguous if it has hazardous movement ops on the dest buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||
(UPat(Ops.STORE, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||
|
||||
# ** size 0 **
|
||||
|
||||
|
|
|
|||
|
|
@ -82,7 +82,8 @@ class Ops(FastEnum):
|
|||
# ** 6 -- ops that don't exist in programs **
|
||||
|
||||
# tensor graph ops
|
||||
UNIQUE = auto(); DEVICE = auto(); ASSIGN = auto()
|
||||
UNIQUE = auto(); DEVICE = auto() #; ASSIGN = auto()
|
||||
ASSIGN = AFTER # ASSIGN is AFTER now (remove it)
|
||||
|
||||
# local unique
|
||||
LUNIQUE = auto()
|
||||
|
|
|
|||
|
|
@ -301,10 +301,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
raise ValueError(f"invalid type for axis: {axis_arg}")
|
||||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||
|
||||
if self.op is Ops.ASSIGN: return self.src[1]._shape
|
||||
if self.op is Ops.STORE: return self.src[1]._shape
|
||||
|
||||
# elementwise ops keep the shape the same. all inputs with shape must match
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE}):
|
||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||
if len(input_shapes) == 0: return None
|
||||
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
||||
|
|
@ -447,7 +447,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
|
||||
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
|
||||
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
def assign(self, x:UOp): return self.after(self.store(x))
|
||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||
def contract(self, *rngs:UOp):
|
||||
assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
|
||||
|
|
@ -1051,12 +1051,12 @@ class UPat(OpMixin):
|
|||
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
|
||||
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.match_dtype, (self,x), **kwargs)
|
||||
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
|
||||
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
|
||||
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)
|
||||
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.match_dtype, (self,)+src, **kwargs)
|
||||
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.match_dtype, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat, **kwargs): return self.after(self.store(x), **kwargs)
|
||||
|
||||
def const_like(self, b:ConstLike): return UPat.const(self.match_dtype, cast(ConstType, b))
|
||||
def alu(self, op:Ops, *src:UPat):
|
||||
|
|
|
|||
|
|
@ -96,6 +96,9 @@ _tensor_spec = PatternMatcher([
|
|||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||
|
||||
# STORE in tensor graph: store a value into a target
|
||||
(UPat(Ops.STORE, dtypes.void, (UPat(), UPat())), lambda: True),
|
||||
|
||||
# MSELECT chooses one of the multi buffers
|
||||
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue