mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Ops.BEFORE finally fixes the DEFINE_REG hacks
This commit is contained in:
parent
c0c4bc9d7c
commit
386bbf311c
4 changed files with 8 additions and 5 deletions
|
|
@ -284,9 +284,11 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
|||
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
|
||||
# if we have a range
|
||||
if len(reduce_range) != 0:
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG),
|
||||
(red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
lst = [acc.load()] + lst # put acc as the first element
|
||||
input_ranges = tuple([x for x in inp.toposort() if x.op is Ops.RANGE and x not in reduce_range])
|
||||
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), src=(identity,), arg=(ctx.acc_num,))
|
||||
acc = acc.index(UOp.const(dtypes.int, 0)).store(identity, UOp(Ops.BEFORE, src=input_ranges))
|
||||
lst = [acc.load(*reduce_range)] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ class CStyleLanguage(Renderer):
|
|||
c: defaultdict[str, int] = defaultdict(int)
|
||||
name = "test"
|
||||
for u in uops:
|
||||
if u.op is Ops.BEFORE: continue
|
||||
if u.op is Ops.SINK:
|
||||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class FastEnum(IntEnum):
|
|||
# the order of these Ops controls the order of the toposort
|
||||
class Ops(FastEnum):
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto() # noqa: E702
|
||||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); BEFORE = auto() # noqa: E702
|
||||
|
||||
# buffer ops
|
||||
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
|
||||
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D"}
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.BEFORE: "#f8a0e0"}
|
||||
|
||||
# VIZ API
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue