Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
8ba57c3246 we name the warp warp now 2025-11-18 22:17:18 -08:00
George Hotz
a3490252e2 fix negative 2025-11-18 17:28:42 -08:00
George Hotz
db150af97e str in limit_bufs 2025-11-18 17:19:39 -08:00
George Hotz
e473daff5a str 2025-11-18 17:15:52 -08:00
George Hotz
1852463f24 string_ranges 2025-11-18 17:12:09 -08:00
6 changed files with 23 additions and 21 deletions

View file

@ -134,7 +134,7 @@ def fix_group_for_reduce(x:UOp):
# do only the non grouped reduces early # do only the non grouped reduces early
ret = x.replace(src=(x.src[0],)+tuple(reduce_r)) ret = x.replace(src=(x.src[0],)+tuple(reduce_r))
reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr] reduce_loop = [x.replace(arg=(x.arg[0]+"_gfr", AxisType.REDUCE)) for x in reduce_gfr]
buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop) buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)
# do the final reduce (if/barrier are added in gpudims step) # do the final reduce (if/barrier are added in gpudims step)

View file

@ -18,6 +18,7 @@ class Scheduler:
self.ast, self.ren = ast, ren self.ast, self.ren = ast, ren
self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else [] self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
self.opt_range = itertools.count()
@property @property
def rngs(self): def rngs(self):
@ -29,8 +30,6 @@ class Scheduler:
def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs] def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
@property @property
def axis_types(self): return [x.arg[-1] for x in self.rngs] def axis_types(self): return [x.arg[-1] for x in self.rngs]
@property
def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2'] # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
def shape_str(self) -> list[str]: def shape_str(self) -> list[str]:
@ -95,10 +94,10 @@ class Scheduler:
def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None): def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
if (old_sz:=rng.src[0].divides(amount)) is None: if (old_sz:=rng.src[0].divides(amount)) is None:
raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}") raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng new_rng = UOp.range(amount, f"o{next(self.opt_range)}", new_type) if input_new_rng is None else input_new_rng
replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),)) replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng) sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[:-1]} {amount} {str(new_type).split('.')[1].lower()}") self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[0]} by {amount} {str(new_type).split('.')[1].lower()}")
return replaced_rng, new_rng return replaced_rng, new_rng
def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type] def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
@ -231,9 +230,9 @@ class Scheduler:
for tc in tensor_cores: for tc in tensor_cores:
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar(): if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
# tensor cores have three ranges. X, Y, and REDUCE # tensor cores have three ranges. X, Y, and REDUCE
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0]) in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0], reverse=True)
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0]) in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: x.arg[0], reverse=True)
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0]) red_ranges = sorted(reduceop.src[1:], key=lambda x: x.arg[0], reverse=True)
if DEBUG >= 3: if DEBUG >= 3:
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}", print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}") f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
@ -260,7 +259,8 @@ class Scheduler:
except KernelOptError: continue except KernelOptError: continue
# we create the warp as a whole thing, in case some of these ranges are moved/removed later # we create the warp as a whole thing, in case some of these ranges are moved/removed later
warp = UOp.range(tc.threads, -1, AxisType.WARP) # the $ puts it before any numbered ranges
warp = UOp.range(tc.threads, '$warp', AxisType.WARP)
ne: list[UOp] = [] ne: list[UOp] = []
for opt in tc.opts: for opt in tc.opts:
if opt[0] == "l": if opt[0] == "l":

View file

@ -47,7 +47,7 @@ def do_substitute(ctx, x: UOp):
subs = {} subs = {}
for k,v in ctx.items(): for k,v in ctx.items():
if v is not None: if v is not None:
subs[k] = k.replace(src=(k.src[0]//v,), arg=k.arg[0:-1]+(0,k.arg[-1]))*v + k.replace(src=(v,), arg=k.arg[0:-1]+(1,k.arg[-1])) subs[k] = k.replace(src=(k.src[0]//v,), arg=(k.arg[0]+"_0", k.arg[-1]))*v + k.replace(src=(v,), arg=(k.arg[0]+"_1", k.arg[-1]))
if not len(subs): return None if not len(subs): return None
ret = x.substitute(subs).simplify() ret = x.substitute(subs).simplify()
ctx.clear() ctx.clear()
@ -152,7 +152,7 @@ def cut_store_range(ctx, store:UOp, r:UOp):
if r.src[0].op is not Ops.CONST or ctx!="CPU": return None if r.src[0].op is not Ops.CONST or ctx!="CPU": return None
if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None
cuts = sorted(dedup([0] + cuts + [r.src[0].arg])) cuts = sorted(dedup([0] + cuts + [r.src[0].arg]))
ranges = [UOp.range((end-start), *(r.arg[0:-1]+(i,r.arg[-1]))) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))] ranges = [UOp.range((end-start), r.arg[0]+f"_{i}", r.arg[-1]) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))]
return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])]) return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])])

View file

@ -46,7 +46,7 @@ def split_reduceop(reduce:UOp, x:UOp):
# get expanded by rangeifying the UOp x # get expanded by rangeifying the UOp x
indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)]) indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)])
range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges] range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges]
is_expanded = [i not in range_nums for i in range(len(x.shape))] is_expanded = [str(i) not in range_nums for i in range(len(x.shape))]
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1) if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
if x.shape[i]%d==0 and not is_expanded[i]]): return None if x.shape[i]%d==0 and not is_expanded[i]]): return None
@ -289,7 +289,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp):
for s in root.src: for s in root.src:
if s.op in GroupOp.Elementwise: if s.op in GroupOp.Elementwise:
# Insert bufferize: all AxisType.REDUCE before bufferize are AxisType.LOOP # Insert bufferize: all AxisType.REDUCE before bufferize are AxisType.LOOP
orig_ranges, end_ranges = s.ranges, [x.replace(arg=(next(ctx.range_idx), AxisType.LOOP)) if x.op is Ops.RANGE else x for x in s.ranges] orig_ranges, end_ranges = s.ranges, [x.replace(arg=(str(next(ctx.range_idx)), AxisType.LOOP)) if x.op is Ops.RANGE else x for x in s.ranges]
s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=s.device)).index(*orig_ranges) s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=s.device)).index(*orig_ranges)
srcs.append(s) srcs.append(s)
return root.replace(src=tuple(srcs)) return root.replace(src=tuple(srcs))
@ -412,7 +412,7 @@ def renumber_range(ctx:LocalAddBufferContext, r:UOp):
if r.arg[-1] == AxisType.OUTER: if r.arg[-1] == AxisType.OUTER:
# for outer range, we replace with a bound variable # for outer range, we replace with a bound variable
return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None)) return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None) ret = r.replace(arg=(str(ctx.range),)+r.arg[1:], tag=None)
ctx.range += 1 ctx.range += 1
return ret return ret

View file

@ -49,8 +49,8 @@ def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
def range_str(u:UOp, color=False) -> str: def range_str(u:UOp, color=False) -> str:
ret = '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]]) assert len(u.arg) == 2
return colored(ret, axis_colors[u.arg[-1]]) if color else ret return colored(u.arg[0], axis_colors[u.arg[-1]]) if color else u.arg[0]
def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str: def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
ret = ','.join([range_str(x, color=color) for x in sorted(rngs, key=lambda x: x.arg)]) ret = ','.join([range_str(x, color=color) for x in sorted(rngs, key=lambda x: x.arg)])
@ -86,6 +86,7 @@ class UOpMetaClass(type):
if _buffer is not None: if _buffer is not None:
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}" assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
buffers[created] = _buffer buffers[created] = _buffer
if op is Ops.RANGE: assert isinstance(arg[0], str)
if SPEC > 1: if SPEC > 1:
from tinygrad.uop.spec import full_spec, test_pyrender from tinygrad.uop.spec import full_spec, test_pyrender
if SPEC > 2: test_pyrender(created) if SPEC > 2: test_pyrender(created)
@ -425,8 +426,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape) if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret return ret
@staticmethod @staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs): def range(end:sint, axis_id:str|int, axis_type=AxisType.LOOP, dtype=dtypes.index, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs) assert isinstance(axis_type, AxisType), f"{axis_type} must be an AxisType"
str_axis_id = ("m"+str(-axis_id)) if isinstance(axis_id, int) and axis_id < 0 else str(axis_id)
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(str_axis_id, axis_type), **kwargs)
@staticmethod @staticmethod
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name) def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
def r(self, op:Ops, axis:tuple[int, ...]): def r(self, op:Ops, axis:tuple[int, ...]):
@ -1350,7 +1353,7 @@ pm_pyrender_extra = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"), (UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"),
# NOTE: range has srcs sometimes after control flow # NOTE: range has srcs sometimes after control flow
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c: (UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
"UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+ "UOp.range("+', '.join([str(c.arg)] + [repr(y) for y in x.arg])+
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"), (f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
# TODO: index shouldn't mismatch dtype # TODO: index shouldn't mismatch dtype
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x: (UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:

View file

@ -37,8 +37,7 @@ shared_spec = PatternMatcher([
# RANGE can be in the big graph now # RANGE can be in the big graph now
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x: (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) == 2 and isinstance(rng.arg[0], str) and isinstance(rng.arg[-1], AxisType)),
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None), (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
# RANGE/SPECIAL define loops, END closes them # RANGE/SPECIAL define loops, END closes them