Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
1e740b115f restore that 2025-08-12 11:24:01 -07:00
George Hotz
2ef5255b09 pack load store early 2025-08-12 11:00:14 -07:00
George Hotz
d319a044a6 fix ptx 2025-08-12 10:55:37 -07:00
George Hotz
27396b8eed split decompositions pass 2025-08-12 10:42:20 -07:00
2 changed files with 6 additions and 1 deletion

View file

@@ -82,11 +82,15 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
supported_ops = tuple(opts.code_for_op.keys())
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
# decompositions
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)
ret.append(RewriteStep(pm_decomp, name="decompositions"))
# optional pre matcher
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
# final rules for the renderer (without sym)
-pm_final_rewrite = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)+pm_render+extra_matcher
+pm_final_rewrite = pm_decomp+pm_render+extra_matcher
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
# return the list (with optional linearizer)

View file

@@ -38,6 +38,7 @@ doesnt_support_half: tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if o
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).alu(Ops.CMPEQ, UPat.var('y')), lambda x,y: (x^y)^True),
(UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
# upcast to float32 all the ops that don't support half
(UPat(doesnt_support_half, dtype=dtypes.half, name="x"),