Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
34bd24bbf2 maybe 2025-11-14 09:46:25 -08:00
George Hotz
0dd7c13aa8
Merge branch 'master' into remove_assign 2025-11-14 09:16:27 -08:00
George Hotz
c793a08fbf
Merge branch 'master' into remove_assign 2025-11-11 19:27:07 -08:00
George Hotz
a6913b9add replace ASSIGN with STORE/AFTER 2025-11-10 19:05:40 -08:00
3 changed files with 24 additions and 13 deletions

View file

@ -6,7 +6,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, g
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL} Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}
@ -16,10 +16,9 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src: for s in rb.src:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None: def realize_store(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
# if it's a kernel, we don't realize it ctx[a] = None
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
pm_generate_realize_map = PatternMatcher([ pm_generate_realize_map = PatternMatcher([
# always realize SINK src # always realize SINK src
@ -29,7 +28,7 @@ pm_generate_realize_map = PatternMatcher([
# realize srcs of COPY, MSELECT, MSTACK # realize srcs of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs), (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# realize ASSIGN and input to assign (might be optimized out) # realize ASSIGN and input to assign (might be optimized out)
(UPat(Ops.ASSIGN, name="a"), realize_assign), (UPat(Ops.STORE, name="a"), realize_store),
]) ])
@dataclass(frozen=True) @dataclass(frozen=True)
@ -56,16 +55,21 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
new_srcs = [] new_srcs = []
for s in x.src: for s in x.src:
new_src = s new_src = s
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL): if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map: elif s in ctx.realize_map:
realized_ranges = ctx.realize_map[s] realized_ranges = ctx.realize_map[s]
assert isinstance(realized_ranges, list), "realize map must contain range list" assert isinstance(realized_ranges, list), "realize map must contain range list"
closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges]) closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges])
# None in the device assigns it a number later if s.op is Ops.STORE:
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL) # add the ends if this is a store
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None) new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges]) del ctx.realize_map[s]
else:
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
new_srcs.append(new_src) new_srcs.append(new_src)
# NOTE: do we need this? # NOTE: do we need this?
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
@ -85,7 +89,7 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
return ret return ret
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp): def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0] if (x in ctx.range_map or x.src[0].op is Ops.INDEX): return x.src[0]
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp): def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
if assign.src[1].op is Ops.KERNEL: return None if assign.src[1].op is Ops.KERNEL: return None
@ -176,7 +180,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# mark all ranges as ended # mark all ranges as ended
assert rctx.realize_map[x] is None assert rctx.realize_map[x] is None
rctx.realize_map[x] = list(range(len(x.shape))) rctx.realize_map[x] = list(range(len(x.shape)))
elif x.op in {Ops.MSTACK, Ops.MSELECT}: elif x.op in {Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
# treat MSTACK/MSELECT like SINK # treat MSTACK/MSELECT like SINK
continue continue
elif len(consumer_rngs) == 0: elif len(consumer_rngs) == 0:
@ -209,6 +213,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
out_rngs = tuple(_out_rngs) out_rngs = tuple(_out_rngs)
# we have to (partially) realize here if there's new ranges # we have to (partially) realize here if there's new ranges
print(_realize_axis)
if len(_realize_axis): rctx.realize_map[x] = _realize_axis if len(_realize_axis): rctx.realize_map[x] = _realize_axis
# if this element is a reduce and there's ended ranges, we might have to end some other ranges # if this element is a reduce and there's ended ranges, we might have to end some other ranges

View file

@ -534,6 +534,12 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites") tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
# link any movementops with tags to sink, and remove the tags from other parts of the graph
# we add "None" to the tag so it's deduped from the movementop that exists in the graph
#tagged_mops = [x.replace(tag=x.tag+(None,)) for x in tsink.toposort() if x.op in GroupOp.Movement and x.tag is not None]
#tsink = tsink.replace(src=tsink.src+tuple(tagged_mops))
#tsink = tsink.substitute({})
# convert movement ops to ranges # convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY) tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)

View file

@ -299,7 +299,7 @@ class Tensor(OpMixin):
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}" assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
return self.replace(self._apply_uop(UOp.assign, x)) return self.replace(self._apply_uop(lambda x,y: x.after(x.store(y)), x))
def detach(self) -> Tensor: def detach(self) -> Tensor:
""" """