Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
42c1f4d5b6 fixes of load alt value 2026-05-12 20:44:22 -07:00
George Hotz
e36288d047
Merge branch 'master' into earlier_gater 2026-05-12 20:27:59 -07:00
George Hotz
41c14d3558
Merge branch 'master' into earlier_gater 2026-05-12 19:36:50 -07:00
George Hotz
eaa362f49e even earlier 2026-05-12 19:32:14 -07:00
George Hotz
ebf45b4f34 move gater earlier 2026-05-12 19:26:36 -07:00
2 changed files with 10 additions and 6 deletions

View file

@ -80,6 +80,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# final symbolic
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
# optional pre matcher
@ -93,9 +96,6 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren.target), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends

View file

@ -419,8 +419,10 @@ def f2f_clamp(val:UOp, dt:DType) -> UOp:
return val.ne(val).where(val, (val < -mx).where(-sat, (mx < val).where(sat, val)))
def f2f_load(x: UOp, fr:DType, to:DType) -> UOp:
# Retypes the load `x` to f2f_dt[fr] and passes the result through f2f(..., fr, to).
# NOTE(review): this listing shows two bodies back to back; judging by the hunk header (-419,8 +419,10) the first two
# statements look like the pre-change body and the four lines after them its replacement -- confirm against the real file.
# first body: only the dtype is replaced (count == 1), or the load gets a single re-indexed source per i (count > 1)
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to)
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n)))
# second body: when the load has more than one source, src[1] is cast to f2f_dt[fr] and src[2] is kept unchanged;
# a load with a single source keeps x.src as is
if (n:=x.dtype.count) == 1:
return f2f(x.replace(dtype=f2f_dt[fr], src=(x.src[0], x.src[1].cast(f2f_dt[fr]), x.src[2]) if len(x.src) > 1 else x.src), fr, to)
# count > 1: one f2f call per i in range(n), each on a load whose first source is reindex(x.src[0].src[0], i, 1);
# the extra sources (if any) are appended to every per-i load and the results are combined with UOp.vectorize
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)+
((x.src[1].cast(f2f_dt[fr]), x.src[2]) if len(x.src) > 1 else ())), fr, to) for i in range(n)))
def f2f_store(st, idx, val, fr:DType, to:DType):
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr)))
@ -522,7 +524,9 @@ pm_long_decomp = PatternMatcher([
(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag] if x.tag is not None else None),
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
# LOAD of a dtype in l2i_dt: retype to l2i_dt[x.dtype] and re-index by x.tag. Any sources after the index are kept:
# src[1] is cast to the new dtype and src[2] is passed through (presumably the load's alt value and gate -- confirm).
# The conditional must be parenthesized: `a + b if c else ()` parses as `(a + b) if c else ()`, which would give a
# single-source LOAD an empty src and drop its index.
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), allow_any_len=True, name='x'),
lambda x,idx: x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag),)+
((x.src[1].cast(l2i_dt[x.dtype]), x.src[2]) if len(x.src) > 1 else ()))),
(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))
])