correct but slower

This commit is contained in:
George Hotz 2025-04-01 16:11:47 +08:00
commit 910cddbbca
4 changed files with 38 additions and 18 deletions

View file

@ -1,10 +1,10 @@
import pickle, sys
from dataclasses import replace
from tinygrad import Device, Context, Tensor
from tinygrad import Device, Context, Tensor, GlobalCounters
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
import numpy as np
@ -48,12 +48,29 @@ if __name__ == "__main__":
if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
p: ProgramSpec = ei.prg.p
k = Kernel(p.ast, Device["DSP"].renderer)
if getenv("VALIDATE"):
with Context(NOOPT=1):
lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
correct = ei.bufs[0].numpy()
GlobalCounters.kernel_count -= 1
#if knum != 1 and not getenv("NOOPT"): k.hand_coded_optimizations()
if not getenv("NOOPT"): k.hand_coded_optimizations()
#if knum == 6:
# k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
# k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
# k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
p2 = k.to_program()
new_ei = replace(ei, prg=CompiledRunner(p2))
new_ei.run()
new_jit.append(new_ei)
test = ei.bufs[0].numpy()
if getenv("VALIDATE"):
import numpy as np
np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
knum += 1
if getenv("RUN_JIT", 0):

View file

@ -2,7 +2,7 @@ from typing import Optional, cast, Generator
import time, pprint
from dataclasses import dataclass, replace
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, Context
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
@ -178,7 +178,8 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, i
ei.run(var_vals, do_update_stats=do_update_stats)
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
with Context(NOOPT=1):
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
import numpy as np
np.testing.assert_allclose(nb[0].numpy(), si.bufs[0].numpy(), rtol=1e-3, atol=1e-3)
else:

View file

@ -402,7 +402,7 @@ if CAPTURE_PROCESS_REPLAY:
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
metadata: tuple[Metadata, ...] = ()
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:

View file

@ -53,13 +53,13 @@ def multi_mul(a0, a1, b0, b1, c0, c1, d0=None, d1=None, acc=None):
# Vx32.uw+=vrmpy(Vu32.ub,Rt32.ub) -> __builtin_HEXAGON_V6_vrmpyub_acc
scalar_m1 = simp_m1.src[0].gep(simp_m1.arg[0:4])
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpybus_128B({0}, {1})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_acc_128B({0}, {1}, {2})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_128B({0}, {1})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, m1), "__builtin_HEXAGON_V6_vrmpyubv_128B({0}, {1})")
# __builtin_HEXAGON_A2_vraddub_acc
@ -96,11 +96,12 @@ def multi_add_int32(**aa):
m0 = UOp(Ops.CAT, dtypes.uchar.vec(128), src=tuple(aa[k].src[0] for k in sorted(aa.keys()))).gep(swizzle)
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, UOp.const(dtypes.uint, 0x01010101)),
"__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
"__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, UOp.const(dtypes.uint, 0x01010101)), "__builtin_HEXAGON_V6_vrmpybus_128B({0}, {1})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, UOp.const(dtypes.uint, 0x01010101)), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
def multi_add_int2(**aa):
return None
if 'acc' in aa:
acc = aa['acc']
del aa['acc']
@ -122,8 +123,9 @@ conv_pm = PatternMatcher([
UPat(name="c0")*UPat(name="c1"), multi_mul),
# __builtin_HEXAGON_V6_vrmpybus x3
(UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
(UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
# NOTE: the multi_add_int32 patterns below are wrong, so they are commented out
#(UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
#(UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
])
dsp_pm = PatternMatcher([
@ -176,10 +178,10 @@ dsp_pm = PatternMatcher([
])+gep_pushing
def add_to_mul(c:UOp, x:UOp):
if c.arg.startswith("__builtin_HEXAGON_V6_vrmpybus_128B"):
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
elif c.arg.startswith("__builtin_HEXAGON_V6_vrmpybusv_128B"):
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpybusv_acc_128B({0}, {1}, {2})")
if c.arg.startswith("__builtin_HEXAGON_V6_vrmpyub_128B"):
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
elif c.arg.startswith("__builtin_HEXAGON_V6_vrmpyubv_128B"):
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})")
elif 'acc' in c.arg and x.op is not Ops.CUSTOM:
return c.replace(src=(x+c.src[0], c.src[1], c.src[2]))
else:
@ -311,7 +313,7 @@ dsp_pm_late = PatternMatcher([
(UPat(Ops.VECTORIZE, dtypes.uchar.vec(128), name="vec"), vectorize_shuffle),
# __builtin_HEXAGON_V6_vrmpybus_acc_128B
# __builtin_HEXAGON_V6_vrmpyub_acc_128B
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
# add acc to __builtin_HEXAGON_A2_vraddub (must be after the reduce expansion)