mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
string_ran
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8ba57c3246 | ||
|
|
a3490252e2 | ||
|
|
db150af97e | ||
|
|
e473daff5a | ||
|
|
1852463f24 |
6 changed files with 23 additions and 21 deletions
|
|
@ -134,7 +134,7 @@ def fix_group_for_reduce(x:UOp):
|
||||||
|
|
||||||
# do only the non grouped reduces early
|
# do only the non grouped reduces early
|
||||||
ret = x.replace(src=(x.src[0],)+tuple(reduce_r))
|
ret = x.replace(src=(x.src[0],)+tuple(reduce_r))
|
||||||
reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
|
reduce_loop = [x.replace(arg=(x.arg[0]+"_gfr", AxisType.REDUCE)) for x in reduce_gfr]
|
||||||
buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)
|
buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)
|
||||||
|
|
||||||
# do the final reduce (if/barrier are added in gpudims step)
|
# do the final reduce (if/barrier are added in gpudims step)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ class Scheduler:
|
||||||
self.ast, self.ren = ast, ren
|
self.ast, self.ren = ast, ren
|
||||||
self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
|
self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
|
||||||
self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
|
self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
|
||||||
|
self.opt_range = itertools.count()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rngs(self):
|
def rngs(self):
|
||||||
|
|
@ -29,8 +30,6 @@ class Scheduler:
|
||||||
def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
|
def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
|
||||||
@property
|
@property
|
||||||
def axis_types(self): return [x.arg[-1] for x in self.rngs]
|
def axis_types(self): return [x.arg[-1] for x in self.rngs]
|
||||||
@property
|
|
||||||
def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)
|
|
||||||
|
|
||||||
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
|
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
|
||||||
def shape_str(self) -> list[str]:
|
def shape_str(self) -> list[str]:
|
||||||
|
|
@ -95,10 +94,10 @@ class Scheduler:
|
||||||
def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
|
def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
|
||||||
if (old_sz:=rng.src[0].divides(amount)) is None:
|
if (old_sz:=rng.src[0].divides(amount)) is None:
|
||||||
raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
|
raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
|
||||||
new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
|
new_rng = UOp.range(amount, f"o{next(self.opt_range)}", new_type) if input_new_rng is None else input_new_rng
|
||||||
replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
|
replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
|
||||||
sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
|
sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
|
||||||
self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[:-1]} {amount} {str(new_type).split('.')[1].lower()}")
|
self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[0]} by {amount} {str(new_type).split('.')[1].lower()}")
|
||||||
return replaced_rng, new_rng
|
return replaced_rng, new_rng
|
||||||
|
|
||||||
def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
|
def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
|
||||||
|
|
@ -231,9 +230,9 @@ class Scheduler:
|
||||||
for tc in tensor_cores:
|
for tc in tensor_cores:
|
||||||
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
|
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
|
||||||
# tensor cores have three ranges. X, Y, and REDUCE
|
# tensor cores have three ranges. X, Y, and REDUCE
|
||||||
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
|
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0], reverse=True)
|
||||||
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
|
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: x.arg[0], reverse=True)
|
||||||
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
|
red_ranges = sorted(reduceop.src[1:], key=lambda x: x.arg[0], reverse=True)
|
||||||
if DEBUG >= 3:
|
if DEBUG >= 3:
|
||||||
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
|
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
|
||||||
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
|
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
|
||||||
|
|
@ -260,7 +259,8 @@ class Scheduler:
|
||||||
except KernelOptError: continue
|
except KernelOptError: continue
|
||||||
|
|
||||||
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
|
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
|
||||||
warp = UOp.range(tc.threads, -1, AxisType.WARP)
|
# the $ puts it before any numbered ranges
|
||||||
|
warp = UOp.range(tc.threads, '$warp', AxisType.WARP)
|
||||||
ne: list[UOp] = []
|
ne: list[UOp] = []
|
||||||
for opt in tc.opts:
|
for opt in tc.opts:
|
||||||
if opt[0] == "l":
|
if opt[0] == "l":
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ def do_substitute(ctx, x: UOp):
|
||||||
subs = {}
|
subs = {}
|
||||||
for k,v in ctx.items():
|
for k,v in ctx.items():
|
||||||
if v is not None:
|
if v is not None:
|
||||||
subs[k] = k.replace(src=(k.src[0]//v,), arg=k.arg[0:-1]+(0,k.arg[-1]))*v + k.replace(src=(v,), arg=k.arg[0:-1]+(1,k.arg[-1]))
|
subs[k] = k.replace(src=(k.src[0]//v,), arg=(k.arg[0]+"_0", k.arg[-1]))*v + k.replace(src=(v,), arg=(k.arg[0]+"_1", k.arg[-1]))
|
||||||
if not len(subs): return None
|
if not len(subs): return None
|
||||||
ret = x.substitute(subs).simplify()
|
ret = x.substitute(subs).simplify()
|
||||||
ctx.clear()
|
ctx.clear()
|
||||||
|
|
@ -152,7 +152,7 @@ def cut_store_range(ctx, store:UOp, r:UOp):
|
||||||
if r.src[0].op is not Ops.CONST or ctx!="CPU": return None
|
if r.src[0].op is not Ops.CONST or ctx!="CPU": return None
|
||||||
if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None
|
if not (cuts:=[c.src[1].arg for c in store.get_consumer_map()[r] if c.op is Ops.CMPLT and r is c.src[0] and c.src[1].op is Ops.CONST]): return None
|
||||||
cuts = sorted(dedup([0] + cuts + [r.src[0].arg]))
|
cuts = sorted(dedup([0] + cuts + [r.src[0].arg]))
|
||||||
ranges = [UOp.range((end-start), *(r.arg[0:-1]+(i,r.arg[-1]))) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))]
|
ranges = [UOp.range((end-start), r.arg[0]+f"_{i}", r.arg[-1]) for i,(start,end) in enumerate(zip(cuts[:-1], cuts[1:]))]
|
||||||
|
|
||||||
return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])])
|
return UOp.group(*[store.substitute({r: new_r+start}).end(new_r) for new_r, start in zip(ranges, cuts[:-1])])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ def split_reduceop(reduce:UOp, x:UOp):
|
||||||
# get expanded by rangeifying the UOp x
|
# get expanded by rangeifying the UOp x
|
||||||
indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)])
|
indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)])
|
||||||
range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges]
|
range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges]
|
||||||
is_expanded = [i not in range_nums for i in range(len(x.shape))]
|
is_expanded = [str(i) not in range_nums for i in range(len(x.shape))]
|
||||||
|
|
||||||
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
|
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
|
||||||
if x.shape[i]%d==0 and not is_expanded[i]]): return None
|
if x.shape[i]%d==0 and not is_expanded[i]]): return None
|
||||||
|
|
@ -289,7 +289,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp):
|
||||||
for s in root.src:
|
for s in root.src:
|
||||||
if s.op in GroupOp.Elementwise:
|
if s.op in GroupOp.Elementwise:
|
||||||
# Insert bufferize: all AxisType.REDUCE before bufferize are AxisType.LOOP
|
# Insert bufferize: all AxisType.REDUCE before bufferize are AxisType.LOOP
|
||||||
orig_ranges, end_ranges = s.ranges, [x.replace(arg=(next(ctx.range_idx), AxisType.LOOP)) if x.op is Ops.RANGE else x for x in s.ranges]
|
orig_ranges, end_ranges = s.ranges, [x.replace(arg=(str(next(ctx.range_idx)), AxisType.LOOP)) if x.op is Ops.RANGE else x for x in s.ranges]
|
||||||
s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=s.device)).index(*orig_ranges)
|
s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=s.device)).index(*orig_ranges)
|
||||||
srcs.append(s)
|
srcs.append(s)
|
||||||
return root.replace(src=tuple(srcs))
|
return root.replace(src=tuple(srcs))
|
||||||
|
|
@ -412,7 +412,7 @@ def renumber_range(ctx:LocalAddBufferContext, r:UOp):
|
||||||
if r.arg[-1] == AxisType.OUTER:
|
if r.arg[-1] == AxisType.OUTER:
|
||||||
# for outer range, we replace with a bound variable
|
# for outer range, we replace with a bound variable
|
||||||
return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
|
return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
|
||||||
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
|
ret = r.replace(arg=(str(ctx.range),)+r.arg[1:], tag=None)
|
||||||
ctx.range += 1
|
ctx.range += 1
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,8 +49,8 @@ def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
||||||
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
||||||
|
|
||||||
def range_str(u:UOp, color=False) -> str:
|
def range_str(u:UOp, color=False) -> str:
|
||||||
ret = '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
|
assert len(u.arg) == 2
|
||||||
return colored(ret, axis_colors[u.arg[-1]]) if color else ret
|
return colored(u.arg[0], axis_colors[u.arg[-1]]) if color else u.arg[0]
|
||||||
|
|
||||||
def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
||||||
ret = ','.join([range_str(x, color=color) for x in sorted(rngs, key=lambda x: x.arg)])
|
ret = ','.join([range_str(x, color=color) for x in sorted(rngs, key=lambda x: x.arg)])
|
||||||
|
|
@ -86,6 +86,7 @@ class UOpMetaClass(type):
|
||||||
if _buffer is not None:
|
if _buffer is not None:
|
||||||
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
|
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
|
||||||
buffers[created] = _buffer
|
buffers[created] = _buffer
|
||||||
|
if op is Ops.RANGE: assert isinstance(arg[0], str)
|
||||||
if SPEC > 1:
|
if SPEC > 1:
|
||||||
from tinygrad.uop.spec import full_spec, test_pyrender
|
from tinygrad.uop.spec import full_spec, test_pyrender
|
||||||
if SPEC > 2: test_pyrender(created)
|
if SPEC > 2: test_pyrender(created)
|
||||||
|
|
@ -425,8 +426,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
|
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
|
||||||
return ret
|
return ret
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
|
def range(end:sint, axis_id:str|int, axis_type=AxisType.LOOP, dtype=dtypes.index, src=(), **kwargs):
|
||||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
|
assert isinstance(axis_type, AxisType), f"{axis_type} must be an AxisType"
|
||||||
|
str_axis_id = ("m"+str(-axis_id)) if isinstance(axis_id, int) and axis_id < 0 else str(axis_id)
|
||||||
|
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(str_axis_id, axis_type), **kwargs)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
|
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
|
||||||
def r(self, op:Ops, axis:tuple[int, ...]):
|
def r(self, op:Ops, axis:tuple[int, ...]):
|
||||||
|
|
@ -1350,7 +1353,7 @@ pm_pyrender_extra = PatternMatcher([
|
||||||
(UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"),
|
(UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"),
|
||||||
# NOTE: range has srcs sometimes after control flow
|
# NOTE: range has srcs sometimes after control flow
|
||||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
|
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
|
||||||
"UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+
|
"UOp.range("+', '.join([str(c.arg)] + [repr(y) for y in x.arg])+
|
||||||
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
|
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
|
||||||
# TODO: index shouldn't mismatch dtype
|
# TODO: index shouldn't mismatch dtype
|
||||||
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
|
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,7 @@ shared_spec = PatternMatcher([
|
||||||
|
|
||||||
# RANGE can be in the big graph now
|
# RANGE can be in the big graph now
|
||||||
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
|
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
|
||||||
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
|
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) == 2 and isinstance(rng.arg[0], str) and isinstance(rng.arg[-1], AxisType)),
|
||||||
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
|
|
||||||
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
||||||
|
|
||||||
# RANGE/SPECIAL define loops, END closes them
|
# RANGE/SPECIAL define loops, END closes them
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue