ext gate indexing (#7349)

* ext gate indexing

* copy paste better
This commit is contained in:
George Hotz 2024-10-29 13:46:10 +07:00 committed by GitHub
commit 3e8225299c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 27 additions and 20 deletions

View file

@ -491,6 +491,17 @@ reducer = PatternMatcher([
(UPat(UOps.LOAD, name="load"), simplify_buffer_load),
])
def idx_load_store(x:UOp):
idx = x.src[0].index(x.src[1])
v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
return UOp(x.op, x.dtype, (idx,)+x.src[2:], x.arg)
indexing = PatternMatcher([
# use indexing for LOAD/STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
])
# *** uop graph ***
linearize_cnt = 0
@ -510,7 +521,8 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
sink = graph_rewrite(sink, sym+just_reduce)
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, sym+reducer)
sink = graph_rewrite(sink, sym+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
sink = graph_rewrite(sink, sym+(indexing if opts is not None and opts.indexing else PatternMatcher([]))+\
get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
return sink

View file

@ -84,6 +84,7 @@ class Renderer:
tensor_cores: List[TensorCore] = []
extra_matcher: Any = None
code_for_op: Dict[Op, Callable] = {}
indexing: bool = False
def __reduce__(self): return self.__class__, ()
def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")

View file

@ -47,12 +47,6 @@ base_rewrite = PatternMatcher([
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])
def idx_load_store(x:UOp):
idx = x.src[0].index(x.src[1])
v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
return UOp(x.op, x.dtype, (idx,)+x.src[2:], x.arg)
extra_pm = PatternMatcher([
# consts are rendered to larger type and casted
(UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="c"), lambda c: UOp.const(dtypes.float, c.arg).cast(c.dtype)),
@ -61,8 +55,6 @@ extra_pm = PatternMatcher([
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(UOps.BITCAST, name="x"),
lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
# use indexing for LOAD/STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
# gate any stores that aren't gated with ifs
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(UOps.STORE, src=store.src[:2]+(UOp(UOps.IF, src=(store.src[2],)),))),
@ -85,6 +77,7 @@ class CStyleLanguage(Renderer):
type_map: Dict[DType, str] = {}
infinity: str = "INFINITY"
nan: str = "NAN"
indexing: bool = True
code_for_op: Dict = {
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",

View file

@ -52,6 +52,7 @@ class LLVMRenderer(Renderer):
has_local = False
has_shared = False
global_max = None
indexing = True
code_for_op: Dict[Op, Callable] = {
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
@ -89,12 +90,13 @@ class LLVMRenderer(Renderer):
for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.STORE:
idx = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
if len(src) > 3:
with bb[-1].if_then(lvars[src[3]]): bb[-1].store(lvars[src[2]], idx)
if uop is UOps.INDEX:
lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
elif uop is UOps.STORE:
if len(src) > 2:
with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
else:
bb[-1].store(lvars[src[2]], idx)
bb[-1].store(lvars[src[1]], lvars[src[0]])
elif uop is UOps.ENDRANGE:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
@ -121,18 +123,17 @@ class LLVMRenderer(Renderer):
lvars[u] = const(src[0].arg, dtype)
reduce_phis.append(u)
elif uop is UOps.LOAD:
idx = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
if len(src) > 2:
with bb[-1].if_else(lvars[src[3]]) as (then, otherwise):
if len(src) > 1:
with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
with then:
val1 = bb[-1].load(idx)
val1 = bb[-1].load(lvars[src[0]])
then_blk = bb[-1].block
with otherwise: otherwise_blk = bb[-1].block
val = bb[-1].phi(val1.type)
val.add_incoming(val1, then_blk)
val.add_incoming(lvars[src[2]], otherwise_blk)
val.add_incoming(lvars[src[1]], otherwise_blk)
else:
val = bb[-1].load(idx)
val = bb[-1].load(lvars[src[0]])
lvars[u] = val
elif uop is UOps.ASSIGN:
lvars[u] = lvars[src[1]]