mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
remove_ass
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34bd24bbf2 | ||
|
|
0dd7c13aa8 |
||
|
|
c793a08fbf |
||
|
|
a6913b9add |
3 changed files with 24 additions and 13 deletions
|
|
@ -6,7 +6,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, g
|
|||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}
|
||||
|
||||
|
|
@ -16,10 +16,9 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
for s in rb.src:
|
||||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
||||
def realize_store(ctx:dict[UOp, None], a:UOp) -> None:
|
||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||
# if it's a kernel, we don't realize it
|
||||
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
||||
ctx[a] = None
|
||||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize SINK src
|
||||
|
|
@ -29,7 +28,7 @@ pm_generate_realize_map = PatternMatcher([
|
|||
# realize srcs of COPY, MSELECT, MSTACK
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
||||
# realize ASSIGN and input to assign (might be optimized out)
|
||||
(UPat(Ops.ASSIGN, name="a"), realize_assign),
|
||||
(UPat(Ops.STORE, name="a"), realize_store),
|
||||
])
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -56,16 +55,21 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
|||
new_srcs = []
|
||||
for s in x.src:
|
||||
new_src = s
|
||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
|
||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
|
||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||
elif s in ctx.realize_map:
|
||||
realized_ranges = ctx.realize_map[s]
|
||||
assert isinstance(realized_ranges, list), "realize map must contain range list"
|
||||
closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges])
|
||||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
||||
if s.op is Ops.STORE:
|
||||
# add the ends if this is a store
|
||||
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
|
||||
del ctx.realize_map[s]
|
||||
else:
|
||||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
||||
new_srcs.append(new_src)
|
||||
# NOTE: do we need this?
|
||||
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
|
||||
|
|
@ -85,7 +89,7 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
|
|||
return ret
|
||||
|
||||
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
||||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||
if (x in ctx.range_map or x.src[0].op is Ops.INDEX): return x.src[0]
|
||||
|
||||
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
||||
if assign.src[1].op is Ops.KERNEL: return None
|
||||
|
|
@ -176,7 +180,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
# mark all ranges as ended
|
||||
assert rctx.realize_map[x] is None
|
||||
rctx.realize_map[x] = list(range(len(x.shape)))
|
||||
elif x.op in {Ops.MSTACK, Ops.MSELECT}:
|
||||
elif x.op in {Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
|
||||
# treat MSTACK/MSELECT like SINK
|
||||
continue
|
||||
elif len(consumer_rngs) == 0:
|
||||
|
|
@ -209,6 +213,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
out_rngs = tuple(_out_rngs)
|
||||
|
||||
# we have to (partially) realize here if there's new ranges
|
||||
print(_realize_axis)
|
||||
if len(_realize_axis): rctx.realize_map[x] = _realize_axis
|
||||
|
||||
# if this element is a reduce and there's ended ranges, we might have to end some other ranges
|
||||
|
|
|
|||
|
|
@ -534,6 +534,12 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
|
||||
tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
|
||||
|
||||
# link any movementops with tags to sink, and remove the tags from other parts of the graph
|
||||
# we add "None" to the tag so it's deduped from the movementop that exists in the graph
|
||||
#tagged_mops = [x.replace(tag=x.tag+(None,)) for x in tsink.toposort() if x.op in GroupOp.Movement and x.tag is not None]
|
||||
#tsink = tsink.replace(src=tsink.src+tuple(tagged_mops))
|
||||
#tsink = tsink.substitute({})
|
||||
|
||||
# convert movement ops to ranges
|
||||
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
|
||||
|
||||
|
|
|
|||
|
|
@ -299,7 +299,7 @@ class Tensor(OpMixin):
|
|||
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
||||
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
|
||||
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
||||
return self.replace(self._apply_uop(UOp.assign, x))
|
||||
return self.replace(self._apply_uop(lambda x,y: x.after(x.store(y)), x))
|
||||
|
||||
def detach(self) -> Tensor:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue