mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
no_assign_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1ed293239 | ||
|
|
810d7ee7ec | ||
|
|
743e5334a4 |
||
|
|
75d311d766 | ||
|
|
fee1f4ccb9 |
5 changed files with 13 additions and 9 deletions
|
|
@ -27,11 +27,11 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
|
||||||
|
|
||||||
pm_generate_realize_map = PatternMatcher([
|
pm_generate_realize_map = PatternMatcher([
|
||||||
# always realize
|
# always realize
|
||||||
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN}, name="tr"), realize),
|
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE}, name="tr"), realize),
|
||||||
# realize srcs of these
|
# realize srcs of these
|
||||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
||||||
# sometimes realize src of assign
|
# sometimes realize src of assign
|
||||||
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
|
(UPat(Ops.STORE, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
|
||||||
])
|
])
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
|
||||||
root_target = target
|
root_target = target
|
||||||
while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
|
while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
|
||||||
# when RHS depends on the previous assign result, break with contiguous
|
# when RHS depends on the previous assign result, break with contiguous
|
||||||
if target in src.toposort(): src = src.contiguous()
|
#if target in src.toposort(): src = src.contiguous()
|
||||||
return assign.replace(src=(root_target, src))
|
return assign.replace(src=(root_target, src))
|
||||||
|
|
||||||
def split_reduceop(reduce:UOp, x:UOp):
|
def split_reduceop(reduce:UOp, x:UOp):
|
||||||
|
|
@ -160,7 +160,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
|
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
|
||||||
|
|
||||||
# make source contiguous if it has hazardous movement ops on the dest buffer
|
# make source contiguous if it has hazardous movement ops on the dest buffer
|
||||||
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
(UPat(Ops.STORE, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||||
|
|
||||||
# ** size 0 **
|
# ** size 0 **
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,8 @@ class Ops(FastEnum):
|
||||||
# ** 6 -- ops that don't exist in programs **
|
# ** 6 -- ops that don't exist in programs **
|
||||||
|
|
||||||
# tensor graph ops
|
# tensor graph ops
|
||||||
UNIQUE = auto(); DEVICE = auto(); ASSIGN = auto()
|
UNIQUE = auto(); DEVICE = auto() #; ASSIGN = auto()
|
||||||
|
ASSIGN = AFTER # ASSIGN is AFTER now (remove it)
|
||||||
|
|
||||||
# local unique
|
# local unique
|
||||||
LUNIQUE = auto()
|
LUNIQUE = auto()
|
||||||
|
|
|
||||||
|
|
@ -301,10 +301,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
raise ValueError(f"invalid type for axis: {axis_arg}")
|
raise ValueError(f"invalid type for axis: {axis_arg}")
|
||||||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||||
|
|
||||||
if self.op is Ops.ASSIGN: return self.src[1]._shape
|
if self.op is Ops.STORE: return self.src[1]._shape
|
||||||
|
|
||||||
# elementwise ops keep the shape the same. all inputs with shape must match
|
# elementwise ops keep the shape the same. all inputs with shape must match
|
||||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
|
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE}):
|
||||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||||
if len(input_shapes) == 0: return None
|
if len(input_shapes) == 0: return None
|
||||||
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
||||||
|
|
@ -447,7 +447,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
|
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
|
||||||
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
|
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
|
||||||
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
|
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
|
||||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
def assign(self, x:UOp): return self.after(self.store(x))
|
||||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||||
def contract(self, *rngs:UOp):
|
def contract(self, *rngs:UOp):
|
||||||
assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
|
assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
|
||||||
|
|
@ -1051,12 +1051,12 @@ class UPat(OpMixin):
|
||||||
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
|
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
|
||||||
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
|
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||||
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
|
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
|
||||||
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.match_dtype, (self,x), **kwargs)
|
|
||||||
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
|
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
|
||||||
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
|
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
|
||||||
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)
|
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)
|
||||||
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.match_dtype, (self,)+src, **kwargs)
|
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.match_dtype, (self,)+src, **kwargs)
|
||||||
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.match_dtype, (self,)+src, **kwargs)
|
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.match_dtype, (self,)+src, **kwargs)
|
||||||
|
def assign(self, x:UPat, **kwargs): return self.after(self.store(x), **kwargs)
|
||||||
|
|
||||||
def const_like(self, b:ConstLike): return UPat.const(self.match_dtype, cast(ConstType, b))
|
def const_like(self, b:ConstLike): return UPat.const(self.match_dtype, cast(ConstType, b))
|
||||||
def alu(self, op:Ops, *src:UPat):
|
def alu(self, op:Ops, *src:UPat):
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,9 @@ _tensor_spec = PatternMatcher([
|
||||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||||
|
|
||||||
|
# STORE in tensor graph: store a value into a target
|
||||||
|
(UPat(Ops.STORE, dtypes.void, (UPat(), UPat())), lambda: True),
|
||||||
|
|
||||||
# MSELECT chooses one of the multi buffers
|
# MSELECT chooses one of the multi buffers
|
||||||
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue