mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
correct but slower
This commit is contained in:
parent
e6e0c0ec86
commit
910cddbbca
4 changed files with 38 additions and 18 deletions
|
|
@ -1,10 +1,10 @@
|
|||
import pickle, sys
|
||||
from dataclasses import replace
|
||||
from tinygrad import Device, Context, Tensor
|
||||
from tinygrad import Device, Context, Tensor, GlobalCounters
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import getenv, BEAM
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
import numpy as np
|
||||
|
|
@ -48,12 +48,29 @@ if __name__ == "__main__":
|
|||
if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
|
||||
p: ProgramSpec = ei.prg.p
|
||||
k = Kernel(p.ast, Device["DSP"].renderer)
|
||||
|
||||
if getenv("VALIDATE"):
|
||||
with Context(NOOPT=1):
|
||||
lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
|
||||
correct = ei.bufs[0].numpy()
|
||||
GlobalCounters.kernel_count -= 1
|
||||
|
||||
#if knum != 1 and not getenv("NOOPT"): k.hand_coded_optimizations()
|
||||
if not getenv("NOOPT"): k.hand_coded_optimizations()
|
||||
#if knum == 6:
|
||||
# k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
# k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
|
||||
# k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
|
||||
p2 = k.to_program()
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2))
|
||||
new_ei.run()
|
||||
new_jit.append(new_ei)
|
||||
test = ei.bufs[0].numpy()
|
||||
|
||||
if getenv("VALIDATE"):
|
||||
import numpy as np
|
||||
np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
|
||||
knum += 1
|
||||
|
||||
if getenv("RUN_JIT", 0):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Optional, cast, Generator
|
|||
import time, pprint
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, Context
|
||||
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
|
|
@ -178,7 +178,8 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, i
|
|||
ei.run(var_vals, do_update_stats=do_update_stats)
|
||||
|
||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
|
||||
with Context(NOOPT=1):
|
||||
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
|
||||
import numpy as np
|
||||
np.testing.assert_allclose(nb[0].numpy(), si.bufs[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -402,7 +402,7 @@ if CAPTURE_PROCESS_REPLAY:
|
|||
class ScheduleItem:
|
||||
ast: UOp
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...]
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
|
||||
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
|
||||
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
|
|
|
|||
|
|
@ -53,13 +53,13 @@ def multi_mul(a0, a1, b0, b1, c0, c1, d0=None, d1=None, acc=None):
|
|||
# Vx32.uw+=vrmpy(Vu32.ub,Rt32.ub) -> __builtin_HEXAGON_V6_vrmpyub_acc
|
||||
scalar_m1 = simp_m1.src[0].gep(simp_m1.arg[0:4])
|
||||
if acc is not None:
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
|
||||
else:
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpybus_128B({0}, {1})")
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
|
||||
if acc is not None:
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_acc_128B({0}, {1}, {2})")
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})")
|
||||
else:
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_128B({0}, {1})")
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, m1), "__builtin_HEXAGON_V6_vrmpyubv_128B({0}, {1})")
|
||||
|
||||
# __builtin_HEXAGON_A2_vraddub_acc
|
||||
|
||||
|
|
@ -96,11 +96,12 @@ def multi_add_int32(**aa):
|
|||
m0 = UOp(Ops.CAT, dtypes.uchar.vec(128), src=tuple(aa[k].src[0] for k in sorted(aa.keys()))).gep(swizzle)
|
||||
if acc is not None:
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, UOp.const(dtypes.uint, 0x01010101)),
|
||||
"__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
|
||||
"__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
|
||||
else:
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, UOp.const(dtypes.uint, 0x01010101)), "__builtin_HEXAGON_V6_vrmpybus_128B({0}, {1})")
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, UOp.const(dtypes.uint, 0x01010101)), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
|
||||
|
||||
def multi_add_int2(**aa):
|
||||
return None
|
||||
if 'acc' in aa:
|
||||
acc = aa['acc']
|
||||
del aa['acc']
|
||||
|
|
@ -122,8 +123,9 @@ conv_pm = PatternMatcher([
|
|||
UPat(name="c0")*UPat(name="c1"), multi_mul),
|
||||
|
||||
# __builtin_HEXAGON_V6_vrmpybus x3
|
||||
(UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
|
||||
(UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
|
||||
# this is wrong
|
||||
#(UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
|
||||
#(UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
|
||||
])
|
||||
|
||||
dsp_pm = PatternMatcher([
|
||||
|
|
@ -176,10 +178,10 @@ dsp_pm = PatternMatcher([
|
|||
])+gep_pushing
|
||||
|
||||
def add_to_mul(c:UOp, x:UOp):
|
||||
if c.arg.startswith("__builtin_HEXAGON_V6_vrmpybus_128B"):
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
|
||||
elif c.arg.startswith("__builtin_HEXAGON_V6_vrmpybusv_128B"):
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpybusv_acc_128B({0}, {1}, {2})")
|
||||
if c.arg.startswith("__builtin_HEXAGON_V6_vrmpyub_128B"):
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
|
||||
elif c.arg.startswith("__builtin_HEXAGON_V6_vrmpyubv_128B"):
|
||||
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})")
|
||||
elif 'acc' in c.arg and x.op is not Ops.CUSTOM:
|
||||
return c.replace(src=(x+c.src[0], c.src[1], c.src[2]))
|
||||
else:
|
||||
|
|
@ -311,7 +313,7 @@ dsp_pm_late = PatternMatcher([
|
|||
|
||||
(UPat(Ops.VECTORIZE, dtypes.uchar.vec(128), name="vec"), vectorize_shuffle),
|
||||
|
||||
# __builtin_HEXAGON_V6_vrmpybus_acc_128B
|
||||
# __builtin_HEXAGON_V6_vrmpyub_acc_128B
|
||||
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
|
||||
|
||||
# add acc to __builtin_HEXAGON_A2_vraddub (must be after the reduce expansion)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue