mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
13ea4979d5
commit
3e8225299c
4 changed files with 27 additions and 20 deletions
|
|
@ -491,6 +491,17 @@ reducer = PatternMatcher([
|
|||
(UPat(UOps.LOAD, name="load"), simplify_buffer_load),
|
||||
])
|
||||
|
||||
def idx_load_store(x:UOp):
|
||||
idx = x.src[0].index(x.src[1])
|
||||
v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
|
||||
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
|
||||
return UOp(x.op, x.dtype, (idx,)+x.src[2:], x.arg)
|
||||
|
||||
indexing = PatternMatcher([
|
||||
# use indexing for LOAD/STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
linearize_cnt = 0
|
||||
|
|
@ -510,7 +521,8 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
|||
sink = graph_rewrite(sink, sym+just_reduce)
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
|
||||
sink = graph_rewrite(sink, sym+reducer)
|
||||
sink = graph_rewrite(sink, sym+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
sink = graph_rewrite(sink, sym+(indexing if opts is not None and opts.indexing else PatternMatcher([]))+\
|
||||
get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
|
||||
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
|
||||
return sink
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class Renderer:
|
|||
tensor_cores: List[TensorCore] = []
|
||||
extra_matcher: Any = None
|
||||
code_for_op: Dict[Op, Callable] = {}
|
||||
indexing: bool = False
|
||||
|
||||
def __reduce__(self): return self.__class__, ()
|
||||
def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")
|
||||
|
|
|
|||
|
|
@ -47,12 +47,6 @@ base_rewrite = PatternMatcher([
|
|||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
])
|
||||
|
||||
def idx_load_store(x:UOp):
|
||||
idx = x.src[0].index(x.src[1])
|
||||
v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
|
||||
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
|
||||
return UOp(x.op, x.dtype, (idx,)+x.src[2:], x.arg)
|
||||
|
||||
extra_pm = PatternMatcher([
|
||||
# consts are rendered to larger type and casted
|
||||
(UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="c"), lambda c: UOp.const(dtypes.float, c.arg).cast(c.dtype)),
|
||||
|
|
@ -61,8 +55,6 @@ extra_pm = PatternMatcher([
|
|||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(UOps.BITCAST, name="x"),
|
||||
lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
|
||||
# use indexing for LOAD/STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
|
||||
# gate any stores that aren't gated with ifs
|
||||
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
|
||||
lambda store: UOp(UOps.STORE, src=store.src[:2]+(UOp(UOps.IF, src=(store.src[2],)),))),
|
||||
|
|
@ -85,6 +77,7 @@ class CStyleLanguage(Renderer):
|
|||
type_map: Dict[DType, str] = {}
|
||||
infinity: str = "INFINITY"
|
||||
nan: str = "NAN"
|
||||
indexing: bool = True
|
||||
code_for_op: Dict = {
|
||||
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
|
||||
UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class LLVMRenderer(Renderer):
|
|||
has_local = False
|
||||
has_shared = False
|
||||
global_max = None
|
||||
indexing = True
|
||||
code_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
|
||||
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
|
||||
|
|
@ -89,12 +90,13 @@ class LLVMRenderer(Renderer):
|
|||
|
||||
for u in uops:
|
||||
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
|
||||
if uop is UOps.STORE:
|
||||
idx = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
|
||||
if len(src) > 3:
|
||||
with bb[-1].if_then(lvars[src[3]]): bb[-1].store(lvars[src[2]], idx)
|
||||
if uop is UOps.INDEX:
|
||||
lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
|
||||
elif uop is UOps.STORE:
|
||||
if len(src) > 2:
|
||||
with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
|
||||
else:
|
||||
bb[-1].store(lvars[src[2]], idx)
|
||||
bb[-1].store(lvars[src[1]], lvars[src[0]])
|
||||
elif uop is UOps.ENDRANGE:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
|
||||
|
|
@ -121,18 +123,17 @@ class LLVMRenderer(Renderer):
|
|||
lvars[u] = const(src[0].arg, dtype)
|
||||
reduce_phis.append(u)
|
||||
elif uop is UOps.LOAD:
|
||||
idx = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
|
||||
if len(src) > 2:
|
||||
with bb[-1].if_else(lvars[src[3]]) as (then, otherwise):
|
||||
if len(src) > 1:
|
||||
with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
|
||||
with then:
|
||||
val1 = bb[-1].load(idx)
|
||||
val1 = bb[-1].load(lvars[src[0]])
|
||||
then_blk = bb[-1].block
|
||||
with otherwise: otherwise_blk = bb[-1].block
|
||||
val = bb[-1].phi(val1.type)
|
||||
val.add_incoming(val1, then_blk)
|
||||
val.add_incoming(lvars[src[2]], otherwise_blk)
|
||||
val.add_incoming(lvars[src[1]], otherwise_blk)
|
||||
else:
|
||||
val = bb[-1].load(idx)
|
||||
val = bb[-1].load(lvars[src[0]])
|
||||
lvars[u] = val
|
||||
elif uop is UOps.ASSIGN:
|
||||
lvars[u] = lvars[src[1]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue