Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
2379d5e607
Merge branch 'master' into r_new_opt 2025-08-21 16:54:51 -07:00
George Hotz
a1e7bd5c09 new opt with rangeify 2025-08-21 15:37:07 -07:00
3 changed files with 36 additions and 3 deletions

View file

@ -41,6 +41,9 @@ if __name__ == "__main__":
if i%10 == 9: test_acc = get_test_acc().item()
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
test_acc = get_test_acc().item()
# verify eval acc
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green"))

View file

@ -1,7 +1,7 @@
from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, RANGEIFY
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp
from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer
@ -18,6 +18,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt import pm_optimize
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
from tinygrad.codegen.opt.postrange import pm_postrange_opt
@dataclass
class RewriteStep:
@ -44,10 +45,10 @@ rewrites_for_linearizer = [
def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
# cache with the values of the context vars
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value, RANGEIFY.value)
@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL, _RANGEIFY) -> list[RewriteStep]:
# ** lowerer (rewrite_shapetracker_with_index) **
ret: list[RewriteStep] = []
@ -57,6 +58,10 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
# this is kernel.py
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
# this is the new optimizer
if _RANGEIFY:
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="new optimize ast"))
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))

View file

@ -0,0 +1,25 @@
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.helpers import partition, flatten, prod
def unroll_range(r:UOp):
  """Turn a small RANGE into an UNROLL, or return None to leave it alone.

  A RANGE qualifies when its first source is a CONST and its `vmax` is below 5.
  The result is an int UNROLL whose single source is a vector constant holding
  0..vmax and whose arg is the one-element tuple ((r.arg, vmax+1),); it is
  created with tag=1.
  """
  # guard clause: anything that is not a small constant range is not rewritten
  if r.src[0].op is not Ops.CONST or not (r.vmax < 5): return None
  axis, size = r.arg, r.vmax+1
  # one vector lane per iteration value of the original range
  lanes = UOp.const(dtypes.int.vec(size), tuple(range(size)))
  return UOp(Ops.UNROLL, dtypes.int, (lanes,), ((axis, size),), tag=1)
def fix_reduce(r:UOp):
  """Rebuild a REDUCE so that only RANGE sources remain after its input.

  The sources after src[0] are split into RANGEs and everything else. When there
  is nothing besides RANGEs, None is returned and the REDUCE is left as is.
  Otherwise every non-RANGE source must be an UNROLL; their args are flattened
  into one axis list and the input is wrapped in a CONTRACT (tag=1) whose dtype
  is r.dtype vectorized by the product of the axis sizes. The returned REDUCE
  keeps r.dtype and r.arg and takes the (possibly contracted) input plus the
  remaining RANGEs as sources.
  """
  kept_ranges, reduce_expand = partition(r.src[1:], lambda y: y.op is Ops.RANGE)
  if not reduce_expand: return None
  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand}"
  inp = r.src[0]
  axes = flatten(x.arg for x in reduce_expand)
  if axes:
    # x[1] is the size stored in each (axis, size) pair of the UNROLL args
    width = prod(x[1] for x in axes)
    inp = UOp(Ops.CONTRACT, r.dtype.vec(width), (inp,), tuple(axes), tag=1)
  # REDUCE supports both "horizontal" reduction and range reduction; the horizontal
  # elements are taken in the nearest group
  return UOp(Ops.REDUCE, r.dtype, (inp,)+tuple(kept_ranges), r.arg)
# Rewrite rules of the post-range optimizer: each entry pairs a UPat with the
# function that receives the matched UOp as `r`.
pm_postrange_opt = PatternMatcher([
# RANGE UOps go to unroll_range, which returns an UNROLL or None
(UPat(Ops.RANGE, name="r"), unroll_range),
# REDUCE UOps go to fix_reduce, which returns a rebuilt REDUCE or None
(UPat(Ops.REDUCE, name="r"), fix_reduce),
])