mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
earlier_ga
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42c1f4d5b6 | ||
|
|
e36288d047 |
||
|
|
41c14d3558 |
||
|
|
eaa362f49e | ||
|
|
ebf45b4f34 |
2 changed files with 10 additions and 6 deletions
|
|
@ -80,6 +80,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
# move gates from unrenderable INVALID where
|
||||
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
|
||||
# final symbolic
|
||||
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||
|
||||
# optional pre matcher
|
||||
|
|
@ -93,9 +96,6 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren.target), name="decomp dtypes")
|
||||
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
|
||||
|
||||
# move gates from unrenderable INVALID where
|
||||
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
|
||||
|
|
|
|||
|
|
@ -419,8 +419,10 @@ def f2f_clamp(val:UOp, dt:DType) -> UOp:
|
|||
return val.ne(val).where(val, (val < -mx).where(-sat, (mx < val).where(sat, val)))
|
||||
|
||||
def f2f_load(x: UOp, fr:DType, to:DType) -> UOp:
|
||||
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to)
|
||||
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n)))
|
||||
if (n:=x.dtype.count) == 1:
|
||||
return f2f(x.replace(dtype=f2f_dt[fr], src=(x.src[0], x.src[1].cast(f2f_dt[fr]), x.src[2]) if len(x.src) > 1 else x.src), fr, to)
|
||||
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)+
|
||||
((x.src[1].cast(f2f_dt[fr]), x.src[2]) if len(x.src) > 1 else ())), fr, to) for i in range(n)))
|
||||
|
||||
def f2f_store(st, idx, val, fr:DType, to:DType):
|
||||
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr)))
|
||||
|
|
@ -522,7 +524,9 @@ pm_long_decomp = PatternMatcher([
|
|||
(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
|
||||
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
|
||||
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag] if x.tag is not None else None),
|
||||
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
|
||||
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), allow_any_len=True, name='x'),
|
||||
lambda x,idx: x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag),)+
|
||||
(x.src[1].cast(l2i_dt[x.dtype]), x.src[2]) if len(x.src) > 1 else ())),
|
||||
(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
|
||||
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))
|
||||
])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue