This commit is contained in:
George Hotz 2025-08-01 16:56:47 -07:00
commit b4ab6de416
2 changed files with 31 additions and 11 deletions

View file

@ -7,7 +7,7 @@ from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer
# import all pattern matchers here
from tinygrad.codegen.rangeify import pm_rangeify, pm_name
from tinygrad.codegen.rangeify import pm_rangeify, pm_name, RangeifyContext
from tinygrad.codegen.lowerer import pm_lowerer, get_index
from tinygrad.codegen.quantize import pm_quant
from tinygrad.codegen.gpudims import pm_add_gpudims
@ -45,7 +45,7 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
ret: list[RewriteStep] = []
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
#ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
ret.append(RewriteStep(pm_rangeify, lambda _: [0], name="rangeify", bottom_up=True))
ret.append(RewriteStep(pm_rangeify, lambda _: RangeifyContext(), name="rangeify", bottom_up=True))
ret.append(RewriteStep(pm_name, lambda _: [0], name="name"))
# ** expander (expand_rewrite) **

View file

@ -1,30 +1,44 @@
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, KernelInfo, GroupOp, AxisType
from tinygrad.opt.kernel import axis_colors
from tinygrad.opt.kernel import axis_colors, Opt, OptOps
from dataclasses import dataclass
from tinygrad.dtype import dtypes
from tinygrad.helpers import argsort, colored
from tinygrad.helpers import argsort, colored, prod
def rangify_store(ctx:list[int], x:UOp):
@dataclass
class RangeifyContext:
idx: int = 0
opts: tuple[Opt, ...] = ()
def rangify_store(ctx:RangeifyContext, x:UOp):
if x.tag == 1: return None
ranges = []
for s in x.shape:
for i,s in enumerate(x.shape):
upcast_amount = prod([o.arg if o.arg != 0 else s for o in ctx.opts if o.axis == i and o.op == OptOps.UPCAST])
if resolve(s!=1):
ranges.append(UOp.range(dtypes.int, s, (ctx[0], AxisType.LOOP)))
ctx[0] += 1
if upcast_amount != 1:
print(x.shape, upcast_amount)
assert s%upcast_amount == 0
rng = UOp.range(dtypes.int, s//upcast_amount, (ctx.idx, AxisType.LOOP)) * upcast_amount
rng = rng + UOp.range(dtypes.int, upcast_amount, (ctx.idx, AxisType.UPCAST))
ranges.append(rng)
else:
ranges.append(UOp.range(dtypes.int, s, (ctx.idx, AxisType.LOOP)))
ctx.idx += 1
else:
ranges.append(UOp.const(dtypes.int, 0))
mm = UOp(Ops.INDEX, dtype=x.src[0].dtype, src=(x.src[0],)+tuple(ranges))
mm2 = UOp(Ops.INDEX, dtype=x.src[0].dtype, src=(x.src[1],)+tuple(ranges))
return UOp(Ops.STORE, src=(mm, mm2)+tuple([x for x in ranges if x.op is not Ops.CONST]), tag=1)
def map_reduce(ctx:list[int], x:UOp, r:UOp):
def map_reduce(ctx:RangeifyContext, x:UOp, r:UOp):
rngs = list(x.src[1:])
new_ranges = []
for i,s in enumerate(r.src[0].shape):
if i in r.arg[1]:
assert rngs[i].op == Ops.CONST
rngs[i] = UOp.range(dtypes.int, s, (ctx[0], AxisType.REDUCE))
rngs[i] = UOp.range(dtypes.int, s, (ctx.idx, AxisType.REDUCE))
new_ranges.append(rngs[i])
ctx[0] += 1
ctx.idx += 1
mm = UOp(Ops.INDEX, r.src[0].dtype, src=(r.src[0],)+tuple(rngs))
return UOp(Ops.REDUCE, r.dtype, src=(mm,)+tuple(new_ranges), arg=r.arg[0])
@ -62,8 +76,14 @@ def map_pad(x:UOp, r:UOp):
# PAD is with 0
return bigwhere.simplify().where(UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple(ret)), UOp.const(r.dtype, 0))
def capture_sink(ctx:RangeifyContext, x: UOp):
if x.arg is None: return None
if x.arg.opts_to_apply is not None: ctx.opts = x.arg.opts_to_apply
return x.replace(arg=None)
pm_rangeify = PatternMatcher([
(UPat(Ops.SINK, name="x"), capture_sink),
# TODO: handle INDEX on STORE
(UPat(Ops.STORE, name="x"), rangify_store),
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="r"),), allow_any_len=True, name="x"), map_reduce),