mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
remove_ass
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34bd24bbf2 | ||
|
|
0dd7c13aa8 |
||
|
|
c793a08fbf |
||
|
|
a6913b9add |
3 changed files with 24 additions and 13 deletions
|
|
@ -6,7 +6,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, g
|
||||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||||
|
|
||||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
|
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
|
||||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}
|
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}
|
||||||
|
|
||||||
|
|
@ -16,10 +16,9 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
||||||
for s in rb.src:
|
for s in rb.src:
|
||||||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||||
|
|
||||||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
def realize_store(ctx:dict[UOp, None], a:UOp) -> None:
|
||||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||||
# if it's a kernel, we don't realize it
|
ctx[a] = None
|
||||||
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
|
||||||
|
|
||||||
pm_generate_realize_map = PatternMatcher([
|
pm_generate_realize_map = PatternMatcher([
|
||||||
# always realize SINK src
|
# always realize SINK src
|
||||||
|
|
@ -29,7 +28,7 @@ pm_generate_realize_map = PatternMatcher([
|
||||||
# realize srcs of COPY, MSELECT, MSTACK
|
# realize srcs of COPY, MSELECT, MSTACK
|
||||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
||||||
# realize ASSIGN and input to assign (might be optimized out)
|
# realize ASSIGN and input to assign (might be optimized out)
|
||||||
(UPat(Ops.ASSIGN, name="a"), realize_assign),
|
(UPat(Ops.STORE, name="a"), realize_store),
|
||||||
])
|
])
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -56,16 +55,21 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||||
new_srcs = []
|
new_srcs = []
|
||||||
for s in x.src:
|
for s in x.src:
|
||||||
new_src = s
|
new_src = s
|
||||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
|
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
|
||||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||||
elif s in ctx.realize_map:
|
elif s in ctx.realize_map:
|
||||||
realized_ranges = ctx.realize_map[s]
|
realized_ranges = ctx.realize_map[s]
|
||||||
assert isinstance(realized_ranges, list), "realize map must contain range list"
|
assert isinstance(realized_ranges, list), "realize map must contain range list"
|
||||||
closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges])
|
closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges])
|
||||||
# None in the device assigns it a number later
|
if s.op is Ops.STORE:
|
||||||
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
|
# add the ends if this is a store
|
||||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
|
||||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
del ctx.realize_map[s]
|
||||||
|
else:
|
||||||
|
# None in the device assigns it a number later
|
||||||
|
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
|
||||||
|
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
||||||
|
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
||||||
new_srcs.append(new_src)
|
new_srcs.append(new_src)
|
||||||
# NOTE: do we need this?
|
# NOTE: do we need this?
|
||||||
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
|
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
|
||||||
|
|
@ -85,7 +89,7 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
||||||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
if (x in ctx.range_map or x.src[0].op is Ops.INDEX): return x.src[0]
|
||||||
|
|
||||||
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
||||||
if assign.src[1].op is Ops.KERNEL: return None
|
if assign.src[1].op is Ops.KERNEL: return None
|
||||||
|
|
@ -176,7 +180,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||||
# mark all ranges as ended
|
# mark all ranges as ended
|
||||||
assert rctx.realize_map[x] is None
|
assert rctx.realize_map[x] is None
|
||||||
rctx.realize_map[x] = list(range(len(x.shape)))
|
rctx.realize_map[x] = list(range(len(x.shape)))
|
||||||
elif x.op in {Ops.MSTACK, Ops.MSELECT}:
|
elif x.op in {Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
|
||||||
# treat MSTACK/MSELECT like SINK
|
# treat MSTACK/MSELECT like SINK
|
||||||
continue
|
continue
|
||||||
elif len(consumer_rngs) == 0:
|
elif len(consumer_rngs) == 0:
|
||||||
|
|
@ -209,6 +213,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||||
out_rngs = tuple(_out_rngs)
|
out_rngs = tuple(_out_rngs)
|
||||||
|
|
||||||
# we have to (partially) realize here if there's new ranges
|
# we have to (partially) realize here if there's new ranges
|
||||||
|
print(_realize_axis)
|
||||||
if len(_realize_axis): rctx.realize_map[x] = _realize_axis
|
if len(_realize_axis): rctx.realize_map[x] = _realize_axis
|
||||||
|
|
||||||
# if this element is a reduce and there's ended ranges, we might have to end some other ranges
|
# if this element is a reduce and there's ended ranges, we might have to end some other ranges
|
||||||
|
|
|
||||||
|
|
@ -534,6 +534,12 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||||
|
|
||||||
tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
|
tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
|
||||||
|
|
||||||
|
# link any movementops with tags to sink, and remove the tags from other parts of the graph
|
||||||
|
# we add "None" to the tag so it's deduped from the movementop that exists in the graph
|
||||||
|
#tagged_mops = [x.replace(tag=x.tag+(None,)) for x in tsink.toposort() if x.op in GroupOp.Movement and x.tag is not None]
|
||||||
|
#tsink = tsink.replace(src=tsink.src+tuple(tagged_mops))
|
||||||
|
#tsink = tsink.substitute({})
|
||||||
|
|
||||||
# convert movement ops to ranges
|
# convert movement ops to ranges
|
||||||
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
|
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -299,7 +299,7 @@ class Tensor(OpMixin):
|
||||||
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
||||||
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
|
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
|
||||||
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
||||||
return self.replace(self._apply_uop(UOp.assign, x))
|
return self.replace(self._apply_uop(lambda x,y: x.after(x.store(y)), x))
|
||||||
|
|
||||||
def detach(self) -> Tensor:
|
def detach(self) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue