mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
opt work
This commit is contained in:
parent
3d31b0b5f6
commit
b4ab6de416
2 changed files with 31 additions and 11 deletions
|
|
@ -7,7 +7,7 @@ from tinygrad.uop.spec import type_verify
|
|||
from tinygrad.renderer import Renderer
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.rangeify import pm_rangeify, pm_name
|
||||
from tinygrad.codegen.rangeify import pm_rangeify, pm_name, RangeifyContext
|
||||
from tinygrad.codegen.lowerer import pm_lowerer, get_index
|
||||
from tinygrad.codegen.quantize import pm_quant
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
|
|
@ -45,7 +45,7 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
|||
ret: list[RewriteStep] = []
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
#ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
ret.append(RewriteStep(pm_rangeify, lambda _: [0], name="rangeify", bottom_up=True))
|
||||
ret.append(RewriteStep(pm_rangeify, lambda _: RangeifyContext(), name="rangeify", bottom_up=True))
|
||||
ret.append(RewriteStep(pm_name, lambda _: [0], name="name"))
|
||||
|
||||
# ** expander (expand_rewrite) **
|
||||
|
|
|
|||
|
|
@ -1,30 +1,44 @@
|
|||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, KernelInfo, GroupOp, AxisType
|
||||
from tinygrad.opt.kernel import axis_colors
|
||||
from tinygrad.opt.kernel import axis_colors, Opt, OptOps
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import argsort, colored
|
||||
from tinygrad.helpers import argsort, colored, prod
|
||||
|
||||
def rangify_store(ctx:list[int], x:UOp):
|
||||
@dataclass
|
||||
class RangeifyContext:
|
||||
idx: int = 0
|
||||
opts: tuple[Opt, ...] = ()
|
||||
|
||||
def rangify_store(ctx:RangeifyContext, x:UOp):
|
||||
if x.tag == 1: return None
|
||||
ranges = []
|
||||
for s in x.shape:
|
||||
for i,s in enumerate(x.shape):
|
||||
upcast_amount = prod([o.arg if o.arg != 0 else s for o in ctx.opts if o.axis == i and o.op == OptOps.UPCAST])
|
||||
if resolve(s!=1):
|
||||
ranges.append(UOp.range(dtypes.int, s, (ctx[0], AxisType.LOOP)))
|
||||
ctx[0] += 1
|
||||
if upcast_amount != 1:
|
||||
print(x.shape, upcast_amount)
|
||||
assert s%upcast_amount == 0
|
||||
rng = UOp.range(dtypes.int, s//upcast_amount, (ctx.idx, AxisType.LOOP)) * upcast_amount
|
||||
rng = rng + UOp.range(dtypes.int, upcast_amount, (ctx.idx, AxisType.UPCAST))
|
||||
ranges.append(rng)
|
||||
else:
|
||||
ranges.append(UOp.range(dtypes.int, s, (ctx.idx, AxisType.LOOP)))
|
||||
ctx.idx += 1
|
||||
else:
|
||||
ranges.append(UOp.const(dtypes.int, 0))
|
||||
mm = UOp(Ops.INDEX, dtype=x.src[0].dtype, src=(x.src[0],)+tuple(ranges))
|
||||
mm2 = UOp(Ops.INDEX, dtype=x.src[0].dtype, src=(x.src[1],)+tuple(ranges))
|
||||
return UOp(Ops.STORE, src=(mm, mm2)+tuple([x for x in ranges if x.op is not Ops.CONST]), tag=1)
|
||||
|
||||
def map_reduce(ctx:list[int], x:UOp, r:UOp):
|
||||
def map_reduce(ctx:RangeifyContext, x:UOp, r:UOp):
|
||||
rngs = list(x.src[1:])
|
||||
new_ranges = []
|
||||
for i,s in enumerate(r.src[0].shape):
|
||||
if i in r.arg[1]:
|
||||
assert rngs[i].op == Ops.CONST
|
||||
rngs[i] = UOp.range(dtypes.int, s, (ctx[0], AxisType.REDUCE))
|
||||
rngs[i] = UOp.range(dtypes.int, s, (ctx.idx, AxisType.REDUCE))
|
||||
new_ranges.append(rngs[i])
|
||||
ctx[0] += 1
|
||||
ctx.idx += 1
|
||||
mm = UOp(Ops.INDEX, r.src[0].dtype, src=(r.src[0],)+tuple(rngs))
|
||||
return UOp(Ops.REDUCE, r.dtype, src=(mm,)+tuple(new_ranges), arg=r.arg[0])
|
||||
|
||||
|
|
@ -62,8 +76,14 @@ def map_pad(x:UOp, r:UOp):
|
|||
# PAD is with 0
|
||||
return bigwhere.simplify().where(UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple(ret)), UOp.const(r.dtype, 0))
|
||||
|
||||
def capture_sink(ctx:RangeifyContext, x: UOp):
|
||||
if x.arg is None: return None
|
||||
if x.arg.opts_to_apply is not None: ctx.opts = x.arg.opts_to_apply
|
||||
return x.replace(arg=None)
|
||||
|
||||
pm_rangeify = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="x"), capture_sink),
|
||||
|
||||
# TODO: handle INDEX on STORE
|
||||
(UPat(Ops.STORE, name="x"), rangify_store),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="r"),), allow_any_len=True, name="x"), map_reduce),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue