mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
simple_pri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6809ff8fe1 |
2 changed files with 29 additions and 27 deletions
|
|
@ -20,16 +20,8 @@ def linearize(u:UOp) -> list[UOp]:
|
||||||
# this will cause ranges to be placed late and ends to be placed early
|
# this will cause ranges to be placed late and ends to be placed early
|
||||||
run_count = prod([int(r.vmax)+1 for r in u.ranges])
|
run_count = prod([int(r.vmax)+1 for r in u.ranges])
|
||||||
|
|
||||||
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
|
# simple priority
|
||||||
priority = [0] + [priorities[x][1] for x in consumers[u]]
|
priorities[u] = (run_count, 0)
|
||||||
if u.op is Ops.LOAD: priority.append(-1000)
|
|
||||||
if u.op is Ops.BARRIER: priority.append(-1500)
|
|
||||||
# ranges are scheduled as late as possible so anything that can be outside is
|
|
||||||
# if u.op is Ops.RANGE: priority = [2000]
|
|
||||||
if u.op is Ops.END: priority = [-1000]
|
|
||||||
# move defines and consts to the top
|
|
||||||
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
|
|
||||||
priorities[u] = (run_count, min(priority))
|
|
||||||
|
|
||||||
# number the uops in "ideal" order
|
# number the uops in "ideal" order
|
||||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
|
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,22 @@ class FastEnum(IntEnum):
|
||||||
|
|
||||||
# the order of these Ops controls the order of the toposort
|
# the order of these Ops controls the order of the toposort
|
||||||
class Ops(FastEnum):
|
class Ops(FastEnum):
|
||||||
|
# ** 1 -- defines/consts **
|
||||||
|
|
||||||
|
# TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
|
||||||
|
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
|
||||||
|
|
||||||
|
# this is for symbolic shapes
|
||||||
|
DEFINE_VAR = auto(); BIND = auto() # noqa: E702
|
||||||
|
|
||||||
|
# consts. VCONST is a vectorized const
|
||||||
|
VCONST = auto(); CONST = auto() # noqa: E702
|
||||||
|
|
||||||
|
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
|
||||||
|
SPECIAL = auto()
|
||||||
|
|
||||||
|
# ** 2 -- non op uops **
|
||||||
|
|
||||||
# uops that aren't rendered
|
# uops that aren't rendered
|
||||||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
|
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
|
||||||
SENTINEL = auto()
|
SENTINEL = auto()
|
||||||
|
|
@ -32,33 +48,28 @@ class Ops(FastEnum):
|
||||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
|
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
|
||||||
MULTI = auto() # MULTI is really a movement op
|
MULTI = auto() # MULTI is really a movement op
|
||||||
|
|
||||||
# TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
|
# reduce (movement)
|
||||||
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
|
|
||||||
|
|
||||||
# this is for symbolic shapes
|
|
||||||
DEFINE_VAR = auto(); BIND = auto() # noqa: E702
|
|
||||||
|
|
||||||
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
|
|
||||||
SPECIAL = auto()
|
|
||||||
|
|
||||||
# reduce
|
|
||||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
|
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
|
||||||
|
|
||||||
# optimization helper ops
|
# optimization helper ops
|
||||||
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||||
|
|
||||||
# UnaryOps
|
# ** 3 -- load/store **
|
||||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
|
|
||||||
|
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
|
||||||
|
INDEX = auto()
|
||||||
|
|
||||||
# load/store before math
|
# load/store before math
|
||||||
LOAD = auto(); STORE = auto() # noqa: E702
|
LOAD = auto(); STORE = auto() # noqa: E702
|
||||||
ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN
|
ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN
|
||||||
|
|
||||||
|
# ** 4 -- math **
|
||||||
|
|
||||||
# tensor core math op, not elementwise
|
# tensor core math op, not elementwise
|
||||||
WMMA = auto()
|
WMMA = auto()
|
||||||
|
|
||||||
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
|
# UnaryOps
|
||||||
INDEX = auto()
|
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
|
||||||
|
|
||||||
# BinaryOps
|
# BinaryOps
|
||||||
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
|
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
|
||||||
|
|
@ -69,12 +80,11 @@ class Ops(FastEnum):
|
||||||
# TernaryOps
|
# TernaryOps
|
||||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||||
|
|
||||||
|
# ** 5 -- control flow / other **
|
||||||
|
|
||||||
# control flow ops
|
# control flow ops
|
||||||
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702
|
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702
|
||||||
|
|
||||||
# consts. VCONST is a vectorized const
|
|
||||||
VCONST = auto(); CONST = auto() # noqa: E702
|
|
||||||
|
|
||||||
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
|
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
|
||||||
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
|
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue