mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2379d5e607 |
||
|
|
a1e7bd5c09 |
3 changed files with 36 additions and 3 deletions
|
|
@ -41,6 +41,9 @@ if __name__ == "__main__":
|
|||
if i%10 == 9: test_acc = get_test_acc().item()
|
||||
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
||||
|
||||
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
|
||||
test_acc = get_test_acc().item()
|
||||
|
||||
# verify eval acc
|
||||
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
|
||||
if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green"))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Callable
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, RANGEIFY
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp
|
||||
from tinygrad.uop.spec import type_verify
|
||||
from tinygrad.renderer import Renderer
|
||||
|
|
@ -18,6 +18,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
|
|||
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.codegen.opt import pm_optimize
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
|
||||
@dataclass
|
||||
class RewriteStep:
|
||||
|
|
@ -44,10 +45,10 @@ rewrites_for_linearizer = [
|
|||
|
||||
def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
|
||||
# cache with the values of the context vars
|
||||
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
|
||||
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value, RANGEIFY.value)
|
||||
|
||||
@functools.cache
|
||||
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
|
||||
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL, _RANGEIFY) -> list[RewriteStep]:
|
||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
ret: list[RewriteStep] = []
|
||||
|
||||
|
|
@ -57,6 +58,10 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
|||
# this is kernel.py
|
||||
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
|
||||
|
||||
# this is the new optimizer
|
||||
if _RANGEIFY:
|
||||
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="new optimize ast"))
|
||||
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
|
||||
|
|
|
|||
25
tinygrad/codegen/opt/postrange.py
Normal file
25
tinygrad/codegen/opt/postrange.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
|
||||
from tinygrad.helpers import partition, flatten, prod
|
||||
|
||||
def unroll_range(r:UOp):
  """Replace a small, constant-bounded RANGE with an UNROLL.

  The rewrite applies only when the first source of `r` is a CONST and `r.vmax` is below 5.
  The resulting UNROLL is an int UOp whose single source is a vector constant holding
  0..r.vmax, and whose arg is a one-element tuple pairing the range's own arg with the
  number of values.

  Returns None when the range does not qualify (presumably read by the pattern matcher
  as "no rewrite" -- confirm against PatternMatcher).
  """
  # all ranges sub 5 can be UNROLLS; anything else is left alone
  if not (r.src[0].op is Ops.CONST and r.vmax < 5): return None
  size = r.vmax+1
  # one vector lane per value the range would have taken
  lanes = UOp.const(dtypes.int.vec(size), tuple(range(size)))
  return UOp(Ops.UNROLL, dtypes.int, (lanes,), ((r.arg, size),), tag=1)
|
||||
|
||||
def fix_reduce(r:UOp):
  """Rebuild a REDUCE whose trailing sources include non-RANGE entries.

  The sources after the first are split into RANGE ops and everything else. If there is
  nothing besides RANGEs, None is returned. Otherwise every non-RANGE source must be an
  UNROLL (asserted); their args are flattened into one axis list and the reduced value is
  wrapped in a CONTRACT over those axes. The returned REDUCE keeps the original dtype and
  arg and only the RANGE sources.
  """
  ranges, unrolls = partition(r.src[1:], lambda u: u.op is Ops.RANGE)
  if not unrolls: return None
  assert all(u.op is Ops.UNROLL for u in unrolls), f"not all UNROLLS in {unrolls}"
  axes = flatten(u.arg for u in unrolls)
  inner = r.src[0]
  if axes:
    # the CONTRACT vector width is the product of the second element of every axis entry
    width = prod(ax[1] for ax in axes)
    inner = UOp(Ops.CONTRACT, r.dtype.vec(width), (inner,), tuple(axes), tag=1)
  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
  return UOp(Ops.REDUCE, r.dtype, (inner, *ranges), r.arg)
|
||||
|
||||
# rewrite rules for the post-range optimizer: each entry pairs a pattern with the
# function that receives the matched UOp under the name "r"
pm_postrange_opt = PatternMatcher([
  # RANGE -> unroll_range (turns qualifying ranges into UNROLLs)
  (UPat(Ops.RANGE, name="r"), unroll_range),
  # REDUCE -> fix_reduce (moves UNROLL sources into a CONTRACT on the reduced value)
  (UPat(Ops.REDUCE, name="r"), fix_reduce),
])
|
||||
Loading…
Add table
Add a link
Reference in a new issue