Compare commits

...

57 commits

Author SHA1 Message Date
George Hotz
e76de41110 fixes 2026-06-02 12:53:00 -07:00
George Hotz
5768042e3f fix after merge 2026-06-02 12:45:51 -07:00
George Hotz
ef9c60238e
Merge branch 'master' into shrink_in_render 2026-06-02 11:20:24 -07:00
George Hotz
5825d4d833
Merge branch 'master' into shrink_in_render 2026-06-01 22:08:00 -07:00
George Hotz
394afe40c5 fix for nir 2026-06-01 21:46:45 -07:00
George Hotz
774847a54d is github broken 2026-06-01 19:25:48 -07:00
George Hotz
a35c1c9c38
Merge branch 'master' into shrink_in_render 2026-06-01 19:25:15 -07:00
George Hotz
2bcd46d946
Merge branch 'master' into shrink_in_render 2026-06-01 19:10:41 -07:00
George Hotz
1ffe08bab9 webgpu 2026-06-01 18:58:06 -07:00
George Hotz
894f1221c2 webgpu fixes 2026-06-01 18:44:09 -07:00
George Hotz
f3ecb4f8e8 amd emu fix 2026-06-01 18:16:52 -07:00
George Hotz
3dbaa526fa update for ANON 2026-06-01 17:49:12 -07:00
George Hotz
31a87addca renderer fixes 2026-06-01 16:58:01 -07:00
George Hotz
c64a37fa7d
Merge branch 'master' into shrink_in_render 2026-06-01 16:48:22 -07:00
George Hotz
d6f1aadeb7 test fixes 2026-06-01 16:47:26 -07:00
George Hotz
beb6c3ab3e fix llvm test_tiny 2026-06-01 16:28:28 -07:00
George Hotz
08fe658f74 fix tests 2026-06-01 16:22:16 -07:00
George Hotz
f532e1b2d0 test updates 2026-06-01 16:12:37 -07:00
George Hotz
1c18e1bae8 fix examples 2026-06-01 16:06:52 -07:00
George Hotz
50ac2872b3 fixes 2026-06-01 15:50:40 -07:00
George Hotz
8d327d4877 index gate 2026-06-01 15:19:00 -07:00
George Hotz
bdbee57f34 fix dl/dr shape 2026-06-01 14:53:41 -07:00
George Hotz
7b00120d92 oops, remove that 2026-06-01 14:45:29 -07:00
George Hotz
46541d70f4
Merge branch 'master' into shrink_in_render 2026-06-01 14:43:17 -07:00
George Hotz
8850ce9380 reg/local 2026-06-01 13:50:40 -07:00
George Hotz
4571b0d98a llvm work 2026-06-01 13:22:54 -07:00
George Hotz
9f78877d14 some llvm fixes 2026-06-01 11:47:20 -07:00
George Hotz
6f506dc55e fix python 2026-06-01 11:38:40 -07:00
George Hotz
12752b8a44 fix wmma 2026-05-31 17:59:48 -07:00
George Hotz
e808f698bc fix uops stats 2026-05-31 14:58:03 -07:00
George Hotz
27835b5a31 fix CHECK_OOB 2026-05-31 14:21:35 -07:00
George Hotz
9ccee6aae7
Merge branch 'master' into shrink_in_render 2026-05-31 09:49:24 -07:00
George Hotz
fdc7d4c0af work 2026-05-31 09:29:58 -07:00
George Hotz
ea70715344
Merge branch 'master' into shrink_in_render 2026-05-31 09:29:52 -07:00
George Hotz
8ba3ee138e
Merge branch 'master' into shrink_in_render 2026-05-31 09:27:40 -07:00
George Hotz
604b35aa67
Merge branch 'master' into shrink_in_render 2026-05-29 19:17:11 -07:00
George Hotz
7754025f2a just global/local 2026-05-29 18:20:48 -07:00
George Hotz
5ac25ba991 that 2026-05-29 18:18:36 -07:00
George Hotz
507c68dbc8 more passing 2026-05-29 18:02:25 -07:00
George Hotz
3e0335a4a0 work 2026-05-29 17:57:18 -07:00
George Hotz
cb755bded6 test tiny passes 2026-05-29 15:28:53 -07:00
George Hotz
10c2a50e79 no GEP in program 2026-05-29 15:03:01 -07:00
George Hotz
7b951e691e both 2026-05-29 14:51:56 -07:00
George Hotz
90b2c7e115 do it as renderer hacks for now 2026-05-29 13:42:52 -07:00
George Hotz
14394eb97d stack spec 2026-05-29 13:24:16 -07:00
George Hotz
213cc5b6b0 test plus works 2026-05-29 13:18:12 -07:00
George Hotz
1670dbfacd
Merge branch 'master' into shrink_in_render 2026-05-29 13:06:31 -07:00
George Hotz
927e16fbdb
Merge branch 'master' into shrink_in_render 2026-05-29 12:58:31 -07:00
George Hotz
f165839386
Merge branch 'master' into shrink_in_render 2026-05-29 11:40:00 -07:00
George Hotz
3baba3f23f
Merge branch 'master' into shrink_in_render 2026-05-29 01:37:25 -07:00
George Hotz
a75ad9fbaa no vec dtype 2026-05-28 20:34:40 -07:00
George Hotz
4115d330ab
Merge branch 'master' into shrink_in_render 2026-05-28 19:27:45 -07:00
George Hotz
7da2c151be
Merge branch 'master' into shrink_in_render 2026-05-28 19:18:31 -07:00
George Hotz
19246184d3
Merge branch 'master' into shrink_in_render 2026-05-28 15:06:43 -07:00
George Hotz
f17eb03634
Merge branch 'master' into shrink_in_render 2026-05-28 14:40:06 -07:00
George Hotz
1c10882bf0 something 2026-05-28 13:31:48 -07:00
George Hotz
6881b32e84 use shrink in renderers 2026-05-28 11:58:57 -07:00
9 changed files with 222 additions and 116 deletions

View file

@ -3,16 +3,16 @@ from dataclasses import replace
import itertools import itertools
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
from tinygrad.uop.render import pyrender from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
from tinygrad.renderer import Renderer, Estimates from tinygrad.renderer import Renderer, Estimates
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here # import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
@ -24,6 +24,25 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
# NOTE: this is temporary until we fix the devectorizer
# Rules used by the "index is shrink" graph_rewrite pass in full_rewrite_to_sink.
pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK
# The pattern is a two-source INDEX (buf, idx) wrapped with .cast(name="x"), so `x` is the node bound by that cast pattern.
# The rule only fires when buf.dtype is a PtrDType (otherwise it returns None, i.e. no rewrite).
# The new SHRINK has dtype buf.dtype.base and sources (buf, idx, int const holding x.dtype.count).
# NOTE(review): "non-image" is not enforced by an explicit ImageDType check here -- confirm an image INDEX cannot reach this rule.
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
UOp(Ops.SHRINK, dtype=buf.dtype.base, src=(buf, idx, UOp.const(dtypes.int, x.dtype.count))) if isinstance(buf.dtype, PtrDType) else None),
# rewrite GEP to INDEX
# The GEP keeps its sources, gains one extra int const source built from its old arg, and has its arg cleared.
(UPat(Ops.GEP, name="x"), lambda x: x.replace(op=Ops.INDEX, src=x.src+(UOp.const(dtypes.int, x.arg),), arg=None)),
])
# Rules used by the "remove vec dtypes" graph_rewrite pass in full_rewrite_to_sink; pm_clean_up_group_sink is appended with `+`.
pm_remove_vec_dtypes = PatternMatcher([
# rewrite PARAM to non pointer
# Applies to PARAM, DEFINE_LOCAL and DEFINE_REG whose dtype is a PtrDType but not an ImageDType (otherwise returns None).
# The dtype becomes buf.dtype.base and the sources are replaced by a single int const holding buf.ptrdtype.size.
(UPat((Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
# remove all vec dtypes
# Every op other than the three buffer-defining ops above gets its dtype replaced by x.dtype.base.scalar().base.
(UPat(GroupOp.All-{Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
])+pm_clean_up_group_sink
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast)) if DEBUG >= 5: print(pyrender(ast))
@ -100,6 +119,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([]) extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite") sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="remove vec dtypes")
# this was the linearizer # this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True) sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)

View file

@ -3,7 +3,7 @@ from typing import Callable, cast
from dataclasses import dataclass from dataclasses import dataclass
from tinygrad.helpers import prod, Target, EMULATED_DTYPES from tinygrad.helpers import prod, Target, EMULATED_DTYPES
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes from tinygrad.dtype import AddrSpace, DType, dtypes
from tinygrad.codegen.opt.tc import TensorCore from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.device import Compiler from tinygrad.device import Compiler
@ -41,7 +41,7 @@ class Estimates:
while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0] while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0]
if buf.op is Ops.PARAM: if buf.op is Ops.PARAM:
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul) # u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults accessed = mem.get((buf, u.op), 0) + u.max_numel() * u.src[0].dtype.itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize) mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
if u.op is Ops.RANGE: if u.op is Ops.RANGE:
mult_stack.append(mults) mult_stack.append(mults)
@ -51,10 +51,10 @@ class Estimates:
elif u.op is Ops.END: mults = mult_stack.pop(-1) elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1 elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
lds += u.dtype.itemsize * mults lds += u.max_numel() * u.dtype.itemsize * mults
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
lds += u.src[1].dtype.itemsize * mults lds += u.max_numel() * u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values())) return Estimates(flops, lds, sum(mem.values()))

View file

@ -8,9 +8,17 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu from tinygrad.codegen.late.devectorizer import no_vectorized_alu
def render_index(ctx,buf,idx):
  """Render the string for indexing `buf` with `idx`.

  `ctx` is subscripted with UOps to obtain their already-rendered strings.
  AFTER nodes wrapping `buf` are walked through (via src[0]) to find the node
  whose addrspace decides the form:
    * AddrSpace.ANON   -> "<buf>[<const>]"; `idx` must be a CONST (asserted).
    * anything else    -> "(<buf>+<idx>)".
  """
  # find the node underneath any chain of AFTER wrappers
  root = buf
  while root.op is Ops.AFTER: root = root.src[0]
  if root.addrspace != AddrSpace.ANON:
    ptr = ctx[buf]
    # NOTE(review): this tests idx.arg (not idx.op) against Ops.ADD; kept exactly as written -- confirm it is intended
    offset = strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]
    return f"({ptr}+{offset})"
  # ANON values are subscripted directly, which requires a literal index
  assert idx.op is Ops.CONST, f"{idx.op} must be CONST"
  return f"{ctx[buf]}[{idx.arg}]"
base_rewrite = PatternMatcher([ base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"), (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{ctx[x.src[0]]}];"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"), (UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"), (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
@ -18,14 +26,14 @@ base_rewrite = PatternMatcher([
(UPat(Ops.RANGE, name="x"), (UPat(Ops.RANGE, name="x"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"), lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
(UPat(Ops.STACK, name="x"), (UPat(Ops.STACK, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \ lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(ctx.render_dtype_with_shape(x)))}" + \
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"), f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: (UPat(Ops.CAST, name="x"), lambda ctx,x:
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None), f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: (UPat(Ops.BITCAST, name="x"), lambda ctx,x:
f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"), f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{ctx[x.src[0]]}];"),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier), (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"), (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"),
# const # const
@ -43,24 +51,26 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"), (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
# default const render # default const render
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)), (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# SHRINK/INDEX
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), render_index),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var('idx'), UPat.cvar())), render_index),
# new load/store # new load/store
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"), (UPat(Ops.LOAD, src=(UPat.var('bidx'),), name="x"), lambda ctx,x,bidx: ctx.render_access(bidx, ctx.render_dtype_with_shape(x))),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"), (UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")), name="x"),
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), lambda ctx,x,bidx,var,gate: f"({ctx[gate]}?{ctx.render_access(bidx, ctx.render_dtype_with_shape(x))}:{ctx[var]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))),
lambda ctx,bidx,var: f"{ctx.render_access(bidx, ctx.render_dtype_with_shape(var))} = {ctx[var]};"),
# alu/gep # alu/gep
# TODO: look for left-associative # TODO: look for left-associative
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op]( (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)), *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
# custom passes through with format # custom passes through with format
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])), (UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
]) ])
extra_pm = PatternMatcher([ extra_pm = PatternMatcher([
# devectorize any bools # devectorize any bools
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu), (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.SHRINK), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
# CAST (from bool) can't be vectorized # CAST (from bool) can't be vectorized
(UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu), (UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
# WHERE can't be vectorized # WHERE can't be vectorized
@ -95,7 +105,11 @@ pm_manual_bf16_cast = PatternMatcher([
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16), (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16),
]) ])
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType))) def dtype_with_shape(dtype:DType, shape:tuple) -> DType:
return dtype.scalar().vec(prod(shape)) if dtype.count == 1 and len(shape) == 1 and isinstance(shape[0], int) and shape[0] > 1 else dtype
def uops_to_dtypes(uops:list[UOp]) -> list[DType]:
  """Return the de-duplicated dtypes of `uops`, ignoring image and pointer dtypes.

  For a uop whose addrspace is GLOBAL or LOCAL the dtype is taken as-is; for every
  other uop it is passed through dtype_with_shape together with the uop's `_shape`
  (an empty tuple is used when `_shape` is falsy).
  """
  collected = []
  for u in uops:
    # image/pointer dtypes are not reported
    if isinstance(u.dtype, (ImageDType, PtrDType)): continue
    if u.addrspace in {AddrSpace.GLOBAL, AddrSpace.LOCAL}: collected.append(u.dtype)
    else: collected.append(dtype_with_shape(u.dtype, u._shape or ()))
  return dedup(collected)
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes) # (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
def wmma_args(uops:list[UOp]): def wmma_args(uops:list[UOp]):
@ -133,10 +147,12 @@ class CStyleLanguage(Renderer):
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str: def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
tmp = "" tmp = ""
if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs): def arg_dtype(u:UOp) -> DType:
return u.dtype if isinstance(u.dtype, (ImageDType, PtrDType)) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
if any(isinstance(arg_dtype(u), ImageDType) for _,(u,_) in bufs):
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else buftypes = [(name, self.render_dtype(dt, mutable)+self.buffer_suffix if isinstance(dt, (ImageDType, PtrDType)) else
self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs] self.arg_int_prefix if dt == dtypes.int else None) for name,(u,mutable) in bufs for dt in (arg_dtype(u),)]
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"] local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
launch_bounds = prod([d.vmax for d in local_dims]) launch_bounds = prod([d.vmax for d in local_dims])
prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] + prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
@ -145,6 +161,9 @@ class CStyleLanguage(Renderer):
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})" def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
def render_dtype_with_shape(self, u:UOp) -> DType:
  """Return dtype_with_shape applied to `u`'s dtype and `u`'s shape."""
  shaped = dtype_with_shape(u.dtype, u.shape)
  return shaped
def render_access(self, bidx:UOp, dtype:DType) -> str:
  """Render a dereference of the rendered `bidx` expression.

  When `dtype.count` is greater than 1 the expression is first cast to a pointer
  to `dtype` (built with dtype.ptr using bidx's addrspace) and then dereferenced;
  otherwise the expression is dereferenced directly.
  """
  if dtype.count > 1:
    ptr_type = self.render_dtype(dtype.ptr(addrspace=bidx.addrspace))
    return f"(*(({ptr_type})({self[bidx]})))"
  return f"(*{self[bidx]})"
def render_dtype(self, dt:DType, mutable=True) -> str: def render_dtype(self, dt:DType, mutable=True) -> str:
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t" if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
if isinstance(dt, PtrDType): if isinstance(dt, PtrDType):
@ -190,21 +209,24 @@ class CStyleLanguage(Renderer):
else: else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast", Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu") Ops.INDEX: "bidx", Ops.SHRINK: "bidx",
Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}" r[u] = f"{prefix}{c[prefix]}"
l = cast(str, self.string_rewrite.rewrite(u, ctx=self)) l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}" assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1 if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \ if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.SHRINK, Ops.INDEX, Ops.CUSTOMI} or \
(u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \ (u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \ (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))): (u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
r[u] = l r[u] = l
else: else:
if u.op is Ops.SHRINK or u._shape is None: u_dtype = u.src[0].dtype
else: u_dtype = self.render_dtype_with_shape(u)
if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} and u.dtype != dtypes.void: if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} and u.dtype != dtypes.void:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "") l = f"{self.render_dtype(u_dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
kernel.append(" "*depth + l) kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment if prefix: c[prefix] += 1 # if it was used, increment
if u.op in {Ops.IF, Ops.RANGE}: depth += 1 if u.op in {Ops.IF, Ops.RANGE}: depth += 1

View file

@ -5,12 +5,14 @@ from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import HIPRenderer, create_non_native_float_pats, pm_manual_bf16_cast from tinygrad.renderer.cstyle import HIPRenderer, create_non_native_float_pats, pm_manual_bf16_cast
from tinygrad.uop.decompositions import xexp2, xlog2 from tinygrad.uop.decompositions import xexp2, xlog2
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate, AddrSpace
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv, OSX from tinygrad.helpers import prod, Target, CPU_COUNT, getenv, OSX
def ldt(dt:DType): def ldt(dt:DType, count=1, ptr=False):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>" #if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
if isinstance(dt, PtrDType): return ldt(dt.base) + "*" #if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
if ptr: return ldt(dt, count) + "*"
if count > 1: return f"<{count} x {ldt(dt, 1, ptr)}>"
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64", return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", dtypes.fp8e4m3: "i8", dtypes.fp8e5m2: "i8", dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", dtypes.fp8e4m3: "i8", dtypes.fp8e5m2: "i8",
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt] dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
@ -54,12 +56,13 @@ def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
N,M,K = wmma.arg[1] N,M,K = wmma.arg[1]
if cdna: if cdna:
if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"}) if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"})
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \ return f" {ctx[wmma]} = call {ldt(wmma.dtype, count=wmma.max_numel())} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)" f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype, count=w.max_numel())} {ctx[w]}" for w in wmma.src]) + \
", i32 0, i32 0, i32 0)"
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
# example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101) # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \ return f" {ctx[wmma]} = call {ldt(wmma.dtype, count=wmma.max_numel())} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \ f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype, count=w.max_numel())} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
if wmma.dtype.scalar() != dtypes.float else ")") if wmma.dtype.scalar() != dtypes.float else ")")
# llvm ops, lop[<dtype>][<op>] # llvm ops, lop[<dtype>][<op>]
@ -75,25 +78,31 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
base_rewrite = PatternMatcher([ base_rewrite = PatternMatcher([
# memory load/store # memory load/store
(UPat(Ops.INDEX, name="x"), lambda ctx,x: (UPat(Ops.INDEX, name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"), f" {ctx[x]} = extractelement {ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]}, i32 {x.src[1].arg}" \
if x.addrspace == AddrSpace.ANON else None),
(UPat((Ops.INDEX, Ops.SHRINK), name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype)}, {ldt(x.dtype, ptr=True)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"), (UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
lambda ctx,x,idx,alt,mask: lambda ctx,x,idx,alt,mask:
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n" f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n" f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n" f" {ctx[x]}_yes = load {ldt(x.dtype, idx.max_numel())}, {ldt(idx.dtype, x.max_numel(), True)} {ctx[idx]}\n"
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n" f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"), f" {ctx[x]} = phi {ldt(x.dtype, idx.max_numel())} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
(UPat.var('idx').load(name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"), (UPat.var('idx').load(name="x"),
(UPat.var('idx').store(UPat.var("var")), lambda ctx,idx,var: f" store {ldt(var.dtype)} {ctx[var]}, {ldt(idx.dtype)} {ctx[idx]}"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(idx.dtype, idx.max_numel())}, {ldt(idx.dtype, idx.max_numel(), True)} {ctx[idx]}"),
(UPat.var('idx').store(UPat.var("var")),
lambda ctx,idx,var:
f" store {ldt(var.dtype, idx.max_numel())} {ctx[var]}, {ldt(idx.dtype, idx.max_numel(), True)} {ctx[idx]}"),
# GEP/VECTORIZE/CAST for float4 support # GEP/VECTORIZE/CAST for float4 support
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"), #(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
(UPat(Ops.STACK, src=UPat.var('y'), name="x"), lambda ctx,x,y: (UPat(Ops.STACK, src=UPat.var('y'), name="x"), lambda ctx,x,y:
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n" f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"), f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.max_numel()} x i32> zeroinitializer"),
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+ (UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(
f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+ f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])), f" = insertelement {ldt(x.dtype, x.max_numel())} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
# unary/binary/ternary ops # unary/binary/ternary ops
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
@ -137,7 +146,8 @@ class LLVMRenderer(Renderer):
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
def _render_fn(self, name:str, args:list[tuple[str,UOp]], kernel:list[str], prefix:list[str]|None=None) -> str: def _render_fn(self, name:str, args:list[tuple[str,UOp]], kernel:list[str], prefix:list[str]|None=None) -> str:
# NOTE: CPUAllocator promises 0x20 alignment # NOTE: CPUAllocator promises 0x20 alignment
sargs = ", ".join([f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {name}" for name,u in args]) sargs = ", ".join([f"{ldt(u.dtype, ptr=u.addrspace == AddrSpace.GLOBAL)}{' noalias align 32' if u.addrspace == AddrSpace.GLOBAL else ''} {name}"
for name,u in args])
return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"]) return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]: def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
r: dict[UOp, str] = {} r: dict[UOp, str] = {}

View file

@ -1,17 +1,17 @@
from typing import Callable, cast, Any from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
from tinygrad.runtime.autogen import mesa from tinygrad.runtime.autogen import mesa
from tinygrad.runtime.support.c import POINTER from tinygrad.runtime.support.c import POINTER
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools import base64, ctypes, ctypes.util, struct, functools, inspect, itertools, os, warnings
def g(s:str): return getattr(mesa, s) def g(s:str): return getattr(mesa, s)
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d)) def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else { def glsl_type(t:DType): return {
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]}, **{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t] **{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b")) def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def: def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
if isinstance(it, PtrDType) and ot == dtypes.long: return src
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src) return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable): def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
@ -86,9 +85,9 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)}, nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])( num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}"))) lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count, nload = nir_instr(nc=lambda count:count, bs=lambda dtype:dtype.bitsize, num_components=lambda count:count,
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])( intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}"))) lambda b, space, addr, dtype, count=1: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id)) ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id)) nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
@ -104,16 +103,31 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def: def _load_count(x:UOp) -> int: return x.max_numel() if 1 < x.max_numel() <= 4 else 1
def _pad_count(b:mesa.nir_builder, dtype:DType, count:int, val):
return val if val.num_components == count else nalu(b, f"vec{count}", val, *[nundef(b, dtype) for _ in range(count-1)])
def nidx(b:mesa.nir_builder, buf, off, dtype, addrspace, gate=None) -> mesa.nir_def:
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type)) @nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
def reg(b, buf): def reg(b, buf):
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array) deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off) deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
return deref return deref
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else f = (functools.partial(reg, b, buf) if addrspace == AddrSpace.REG else
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long)))) lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
return if_phi(b, gate, f, lambda: buf) if gate is not None else f() return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
def ngated_load_index(ctx, x, buf, off, alt, gate):
cnt = _load_count(x)
return if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace, ctx.r[gate]), x.dtype, cnt),
lambda: _pad_count(ctx.b, x.dtype, cnt, ctx.r[alt]))
def ngated_load_shrink(ctx, x, idx, alt, gate):
cnt = _load_count(idx)
return if_phi(ctx.b, ctx.r[gate], lambda: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, cnt),
lambda: _pad_count(ctx.b, x.dtype, cnt, ctx.r[alt]))
class NIRRenderer(Renderer): class NIRRenderer(Renderer):
suffix = "NIR" suffix = "NIR"
nir_options: bytes nir_options: bytes
@ -137,7 +151,7 @@ class NIRRenderer(Renderer):
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)), (UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
# load/store use pointer arithmetic, and the cast does nothing. NOTE: this doesn't apply to image indexing cause it's 1-D # load/store use pointer arithmetic, and the cast does nothing. NOTE: this doesn't apply to image indexing cause it's 1-D
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace( (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None), src=(buf,off.cast(dtypes.long))) if buf.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
# images need index to be int for nir # images need index to be int for nir
(UPat.var("buf").index(UPat.var("idx_y"), UPat.var("idx_x")), (UPat.var("buf").index(UPat.var("idx_y"), UPat.var("idx_x")),
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int))), lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int))),
@ -149,18 +163,28 @@ class NIRRenderer(Renderer):
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)), (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))), (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))), (UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)), lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), ctx.r[val], val.dtype)),
(UPat(Ops.STORE, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("val"))),
lambda ctx,idx,val: nstore(ctx.b, idx.addrspace, ctx.r[idx], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"), (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate], ngated_load_index),
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])), (UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("alt"), UPat.var("gate")), name="x"),
ngated_load_shrink),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"), (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)), lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), x.dtype, _load_count(x))),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])), (UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"),), name="x"),
lambda ctx,x,idx: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, _load_count(idx))),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("off"), UPat.cvar()), name="x"),
lambda ctx,x,buf,off: nidx(ctx.b, ctx.r[buf], ctx.r[off], x.dtype, x.addrspace)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: ctx.r[x.src[0]] if len(x.src) == 1 else
nalu(ctx.b, f"vec{len(x.src)}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])), (UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)), (UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]), (UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])), (UPat(Ops.INDEX, src=(UPat.var("a"), UPat.cvar("idx"))),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents), lambda ctx,a,idx: nchannel(ctx.b, ctx.r[a], idx.arg) if a.addrspace == AddrSpace.ANON else None),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
mesa.glsl_array_type(glsl_type(x.dtype), x.src[0].arg, 0), f"acc{x.arg}".encode()).contents),
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)), (UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])), (UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]]))) (UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
@ -189,8 +213,7 @@ class NIRRenderer(Renderer):
self.param_idx, ranges = 0, [] self.param_idx, ranges = 0, []
for u in uops: for u in uops:
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.INDEX and u.src[0].addrspace != AddrSpace.ANON): pass
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
elif u.op is Ops.AFTER: elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]] self.r[u] = self.r[u.src[0]]
elif u.op == Ops.SINK: elif u.op == Ops.SINK:
@ -198,7 +221,7 @@ class NIRRenderer(Renderer):
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char]) self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
elif u.op == Ops.DEFINE_LOCAL: elif u.op == Ops.DEFINE_LOCAL:
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long) self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
self.b.shader.contents.info.shared_size += u.dtype.nbytes() self.b.shader.contents.info.shared_size += u.src[0].arg * u.dtype.itemsize
elif u.op == Ops.RANGE: elif u.op == Ops.RANGE:
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents)) ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype) nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
@ -217,7 +240,17 @@ class NIRRenderer(Renderer):
self.r[u] = cast(mesa.nir_def, d) self.r[u] = cast(mesa.nir_def, d)
self.postrender(uops) self.postrender(uops)
mesa.nir_validate_shader(self.b.shader, b"after render") if DEBUG >= 2 and hasattr(os, "fork"):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
pid = os.fork()
if pid == 0:
mesa.nir_validate_shader(self.b.shader, b"after render")
os._exit(0)
_, status = os.waitpid(pid, 0)
if os.WIFSIGNALED(status): raise RuntimeError(f"NIR validation failed after render with signal {os.WTERMSIG(status)}")
if os.WEXITSTATUS(status) != 0: raise RuntimeError(f"NIR validation failed after render with exit code {os.WEXITSTATUS(status)}")
else: mesa.nir_validate_shader(self.b.shader, b"after render")
if DEBUG >= 4: mesa.nir_print_shader(self.b.shader, ctypes.POINTER(mesa.struct__IO_FILE).in_dll(ctypes.CDLL(ctypes.util.find_library('c')), if DEBUG >= 4: mesa.nir_print_shader(self.b.shader, ctypes.POINTER(mesa.struct__IO_FILE).in_dll(ctypes.CDLL(ctypes.util.find_library('c')),
"__stdoutp" if OSX else "stdout")) "__stdoutp" if OSX else "stdout"))
mesa.nir_serialize(blob:=mesa.struct_blob(), self.b.shader, False) mesa.nir_serialize(blob:=mesa.struct_blob(), self.b.shader, False)
@ -257,9 +290,10 @@ class LVPRenderer(NIRRenderer):
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int)) def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32 def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components, nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=4,
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)}, intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
srcs=lambda b,img,idx_y,idx_x,val:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])( srcs=lambda b,img,idx_y,idx_x,val,dtype:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int),
val if val.num_components == 4 else nalu(b, "vec4", val, nundef(b, dtype), nundef(b, dtype), nundef(b, dtype)), nimm(b, 0, dtypes.int)]])(
lambda b,img,idx_y,idx_x,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store"))) lambda b,img,idx_y,idx_x,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)}, _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
@ -277,8 +311,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
def_rewrite = PatternMatcher([ def_rewrite = PatternMatcher([
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("val")), allow_any_len=True), (UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("val")), allow_any_len=True),
lambda ctx,img,idx_y,idx_x,val: nstore_img(ctx.b, ctx.r[img], ctx.r[idx_y], ctx.r[idx_x], ctx.r[val], val.dtype)), lambda ctx,img,idx_y,idx_x,val: nstore_img(ctx.b, ctx.r[img], ctx.r[idx_y], ctx.r[idx_x], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate"))), (UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, idx_y, idx_x), lambda: ctx.r[alt])), lambda ctx,x,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: ctx.nload_img(img, idx_y, idx_x) if len(x.shape) > 0 and x.shape[-1] == 4 else nchannel(ctx.b, ctx.nload_img(img, idx_y, idx_x), 0),
lambda: ctx.r[alt] if len(x.shape) == 0 or x.shape[-1] != 4 or ctx.r[alt].num_components == 4 else
nalu(ctx.b, "vec4", ctx.r[alt], nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype)))),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')),)), nload_img), (UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')),)), nload_img),
]) + NIRRenderer.def_rewrite ]) + NIRRenderer.def_rewrite

View file

@ -27,12 +27,19 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
val = (load.cast(dtypes.uint32) >> shift_am) & mask val = (load.cast(dtypes.uint32) >> shift_am) & mask
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype) return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
def is_packed(dt:DType, odt:DType|None = None) -> bool: def is_packed(x:UOp) -> bool:
if odt is None: odt = dt if x.op is Ops.LOAD: dt, addrspace = x.dtype, x.src[0].addrspace
# registers aren't packed elif x.op is Ops.STORE: dt, addrspace = x.src[1].dtype, x.src[0].addrspace
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False else: dt, addrspace = x.dtype.base, x.dtype.addrspace if isinstance(x.dtype, PtrDType) else x.addrspace
return dt.itemsize < 4 and dt.base != dtypes.half return dt.itemsize < 4 and dt.base != dtypes.half and addrspace != AddrSpace.REG
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
def _packed_size(ctx, x:UOp):
size = ctx[x.src[0]]
if not is_packed(x): return size
elems = 4 // x.dtype.base.itemsize
return str((x.src[0].arg + elems - 1) // elems) if x.src[0].op is Ops.CONST else f"(({size}+{elems-1})/{elems})"
def _buf_map(ctx, x:UOp): return ctx.type_map[x.dtype.base] if x.addrspace == AddrSpace.REG else ctx.buf_map(x.dtype.base)
def is_nan(a): def is_nan(a):
bs, (exp, mant) = a.dtype.bitsize, dtypes.finfo(a.dtype) bs, (exp, mant) = a.dtype.bitsize, dtypes.finfo(a.dtype)
@ -43,12 +50,12 @@ wgsl_matcher = PatternMatcher([
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)), lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
# TODO: load alt value doesnt have to be a const # TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"), (UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None), lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l) else None),
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None), (UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")), (UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"), name="s"),
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None), lambda s,bidx,var,gate: packed_store(bidx,var,gate) if is_packed(s) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")), (UPat.store(UPat.var("bidx"), UPat.var("var"), name="s"),
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None), lambda s,bidx,var: packed_store(bidx,var) if is_packed(s) else None),
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None), (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None), (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
# fix nan check: 'a != a -> is_nan()' # fix nan check: 'a != a -> is_nan()'
@ -73,8 +80,8 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"), lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"), (UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{_buf_map(ctx,x)},{_packed_size(ctx,x)}>;"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"), (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{_buf_map(ctx,x)},{_packed_size(ctx,x)}>;"),
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)), (UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"), lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"), (UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
@ -85,12 +92,12 @@ class WGSLRenderer(CStyleLanguage):
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"), if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
# TODO: load alt value doesnt have to be a const # TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")), (UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate"), name="l"),
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"), lambda ctx,l,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b)}, {ctx[gate]})"),
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)), (UPat.load(UPat.var("b"), name="l"), lambda ctx,l,b: ctx.render_load(ctx[b], b)),
(UPat.store(UPat.var("b"), UPat.var("v")), lambda ctx,b,v:\ (UPat.store(UPat.var("b"), UPat.var("v"), name="s"), lambda ctx,s,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1] # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \ f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b) and b.addrspace != AddrSpace.REG \
else f"{ctx[b]} = {ctx[v]};"), else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))), (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
@ -98,9 +105,11 @@ class WGSLRenderer(CStyleLanguage):
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})" def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var" def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x def render_load(self, x:str, uop:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(uop) and uop.addrspace != AddrSpace.REG else x
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base] def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt.itemsize < 4 and dt != dtypes.half else self.type_map[dt.base]
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str: def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
def arg_dtype(u:UOp) -> DType:
return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)] local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
if not local_size: local_size = [1] if not local_size: local_size = [1]
bind_it = iter(range(len(bufs))) bind_it = iter(range(len(bufs)))
@ -110,8 +119,9 @@ class WGSLRenderer(CStyleLanguage):
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n" prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n" prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" + prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" + f"{'var<storage,read_write>' if isinstance(dt, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs]) f"{name}:{f'array<{self.buf_map(dt.base)}>' if isinstance(dt,PtrDType) else self.buf_map(dt)};"
for name,(u,_) in bufs for dt in (arg_dtype(u),)])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>," prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}" return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"

View file

@ -54,18 +54,20 @@ class DSPRenderer(ClangRenderer):
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops) 'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str: def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {', msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};', 'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
'HAP_power_set((void*)handle, (void*)&req);'] 'HAP_power_set((void*)handle, (void*)&req);']
msrc += ['if ((sc>>24) != 2) return 0;'] msrc += ['if ((sc>>24) != 2) return 0;']
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)] msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)] msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
if isinstance(b[1][0].dtype, PtrDType)] if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += ["unsigned long long start = HAP_perf_get_time_us();"] msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"] params = [(f'buf_{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)]
msrc += [f"{function_name}({', '.join(params)});"]
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"] msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)] msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += ["return 0; }"] msrc += ["return 0; }"]
return '\n'.join(msrc) return '\n'.join(msrc)
@ -275,22 +277,23 @@ class MockDSPRenderer(DSPRenderer):
def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True) def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True)
def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops) def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str: def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html # https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it # control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
msrc = [mockdsp_boilerplate, 'void _start(void) {'] msrc = [mockdsp_boilerplate, 'void _start(void) {']
for i,b in enumerate(bufs): for i,b in enumerate(bufs):
if isinstance(b[1][0].dtype, PtrDType): if isinstance(dt:=arg_dtype(b[1][0]), PtrDType):
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize sz = dt.size*dt.itemsize
# for loop for big reads # for loop for big reads
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));") msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
else: else:
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);") msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
msrc.append("unsigned int st = inscount();") msrc.append("unsigned int st = inscount();")
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)] params = [(f'(void*)buf{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
msrc.append(f"{function_name}({', '.join(params)});") msrc.append(f"{function_name}({', '.join(params)});")
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));") msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
for i,b in enumerate(bufs): for i,b in enumerate(bufs):
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});") if isinstance(dt:=arg_dtype(b[1][0]), PtrDType): msrc.append(f"write(1, buf{i}, {dt.size*dt.itemsize});")
msrc.append('exit(0); }') msrc.append('exit(0); }')
return '\n'.join(msrc) return '\n'.join(msrc)

View file

@ -5,8 +5,8 @@
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
import pickle, base64, itertools, time, sys, functools import pickle, base64, itertools, time, sys, functools
from dataclasses import replace from dataclasses import replace
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar from tinygrad.dtype import DType, dtypes, ImageDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar, AddrSpace
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, Target, IMAGE from tinygrad.helpers import all_same, getenv, flatten, Target, IMAGE
from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.opt import tc from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp, bitcast from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp, bitcast
@ -101,19 +101,21 @@ class PythonProgram:
if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size
elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp] elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp]
elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size
elif u.op is Ops.INDEX: elif u.op is Ops.SHRINK or (u.op is Ops.INDEX and len(src_values) == 2):
ret:list = [] if u.addrspace == AddrSpace.ANON:
if isinstance(src_dtypes[0], ImageDType): # old GEP
assert len(src_values) == 3, "image index must be 3 srcs" assert all_same(src_values[1]), "all index must be the same"
for m,oy,ox in zip(*src_values): values[u] = src_values[0][src_values[1][0]]
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
else: else:
assert len(src_values) == 2, "non-image index must be 2 srcs" # normal index
for m,o in zip(*src_values): ret.append((m,o)) values[u] = [(m,o) for m,o in zip(src_values[0], src_values[1])]
elif u.op is Ops.INDEX and len(src_values) == 3:
assert isinstance(src_dtypes[0], ImageDType), "3 src index is only for Image"
ret:list = []
for m,oy,ox in zip(*src_values):
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
values[u] = ret values[u] = ret
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType):
values[u] = src_values[0]
elif u.op is Ops.RANGE: elif u.op is Ops.RANGE:
if u not in values: values[u] = [0] * warp_size if u not in values: values[u] = [0] * warp_size
else: else:
@ -134,7 +136,6 @@ class PythonProgram:
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)] for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)]
else: else:
values[u] = load(src_values, 0, u.dtype) values[u] = load(src_values, 0, u.dtype)
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]
elif u.op is Ops.WMMA: elif u.op is Ops.WMMA:
first_src_dtype = u.src[0].dtype first_src_dtype = u.src[0].dtype
assert isinstance(first_src_dtype, DType) # mypy assert isinstance(first_src_dtype, DType) # mypy

View file

@ -114,6 +114,8 @@ class DLL(ctypes.CDLL):
def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs): def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs):
self.nm, self.emsg = nm, emsg or f"try setting {nm.upper()+'_PATH'}?" self.nm, self.emsg = nm, emsg or f"try setting {nm.upper()+'_PATH'}?"
if nm == 'llvm' and (ver:=getenv("LLVM_VERSION", "")):
paths = ([f"/opt/homebrew/opt/llvm@{ver}/lib/libLLVM.dylib"] if OSX else [f"LLVM-{ver}"]) + (paths if isinstance(paths, list) else [paths])
if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])): if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])):
if DEBUG >= 3: print(f"loading {nm} from {path}") if DEBUG >= 3: print(f"loading {nm} from {path}")
try: try: