mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
split_deco
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e740b115f | ||
|
|
2ef5255b09 | ||
|
|
d319a044a6 | ||
|
|
27396b8eed |
2 changed files with 6 additions and 1 deletions
|
|
@ -82,11 +82,15 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
||||||
supported_ops = tuple(opts.code_for_op.keys())
|
supported_ops = tuple(opts.code_for_op.keys())
|
||||||
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
|
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
|
||||||
|
|
||||||
|
# decompositions
|
||||||
|
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)
|
||||||
|
ret.append(RewriteStep(pm_decomp, name="decompositions"))
|
||||||
|
|
||||||
# optional pre matcher
|
# optional pre matcher
|
||||||
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
|
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
|
||||||
|
|
||||||
# final rules for the renderer (without sym)
|
# final rules for the renderer (without sym)
|
||||||
pm_final_rewrite = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)+pm_render+extra_matcher
|
pm_final_rewrite = pm_decomp+pm_render+extra_matcher
|
||||||
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
|
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
|
||||||
|
|
||||||
# return the list (with optional linearizer)
|
# return the list (with optional linearizer)
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ doesnt_support_half: tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if o
|
||||||
ptx_matcher = PatternMatcher([
|
ptx_matcher = PatternMatcher([
|
||||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||||
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
||||||
|
(UPat.var('x', dtype=dtypes.bool).alu(Ops.CMPEQ, UPat.var('y')), lambda x,y: (x^y)^True),
|
||||||
(UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
|
(UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
|
||||||
# upcast to float32 all the ops that don't support half
|
# upcast to float32 all the ops that don't support half
|
||||||
(UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
|
(UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue