Compare commits

...

57 commits

Author SHA1 Message Date
George Hotz
e76de41110 fixes 2026-06-02 12:53:00 -07:00
George Hotz
5768042e3f fix after merge 2026-06-02 12:45:51 -07:00
George Hotz
ef9c60238e
Merge branch 'master' into shrink_in_render 2026-06-02 11:20:24 -07:00
George Hotz
5825d4d833
Merge branch 'master' into shrink_in_render 2026-06-01 22:08:00 -07:00
George Hotz
394afe40c5 fix for nir 2026-06-01 21:46:45 -07:00
George Hotz
774847a54d is github broken 2026-06-01 19:25:48 -07:00
George Hotz
a35c1c9c38
Merge branch 'master' into shrink_in_render 2026-06-01 19:25:15 -07:00
George Hotz
2bcd46d946
Merge branch 'master' into shrink_in_render 2026-06-01 19:10:41 -07:00
George Hotz
1ffe08bab9 webgpu 2026-06-01 18:58:06 -07:00
George Hotz
894f1221c2 webgpu fixes 2026-06-01 18:44:09 -07:00
George Hotz
f3ecb4f8e8 amd emu fix 2026-06-01 18:16:52 -07:00
George Hotz
3dbaa526fa update for ANON 2026-06-01 17:49:12 -07:00
George Hotz
31a87addca renderer fixes 2026-06-01 16:58:01 -07:00
George Hotz
c64a37fa7d
Merge branch 'master' into shrink_in_render 2026-06-01 16:48:22 -07:00
George Hotz
d6f1aadeb7 test fixes 2026-06-01 16:47:26 -07:00
George Hotz
beb6c3ab3e fix llvm test_tiny 2026-06-01 16:28:28 -07:00
George Hotz
08fe658f74 fix tests 2026-06-01 16:22:16 -07:00
George Hotz
f532e1b2d0 test updates 2026-06-01 16:12:37 -07:00
George Hotz
1c18e1bae8 fix examples 2026-06-01 16:06:52 -07:00
George Hotz
50ac2872b3 fixes 2026-06-01 15:50:40 -07:00
George Hotz
8d327d4877 index gate 2026-06-01 15:19:00 -07:00
George Hotz
bdbee57f34 fix dl/dr shape 2026-06-01 14:53:41 -07:00
George Hotz
7b00120d92 oops, remove that 2026-06-01 14:45:29 -07:00
George Hotz
46541d70f4
Merge branch 'master' into shrink_in_render 2026-06-01 14:43:17 -07:00
George Hotz
8850ce9380 reg/local 2026-06-01 13:50:40 -07:00
George Hotz
4571b0d98a llvm work 2026-06-01 13:22:54 -07:00
George Hotz
9f78877d14 some llvm fixes 2026-06-01 11:47:20 -07:00
George Hotz
6f506dc55e fix python 2026-06-01 11:38:40 -07:00
George Hotz
12752b8a44 fix wmma 2026-05-31 17:59:48 -07:00
George Hotz
e808f698bc fix uops stats 2026-05-31 14:58:03 -07:00
George Hotz
27835b5a31 fix CHECK_OOB 2026-05-31 14:21:35 -07:00
George Hotz
9ccee6aae7
Merge branch 'master' into shrink_in_render 2026-05-31 09:49:24 -07:00
George Hotz
fdc7d4c0af work 2026-05-31 09:29:58 -07:00
George Hotz
ea70715344
Merge branch 'master' into shrink_in_render 2026-05-31 09:29:52 -07:00
George Hotz
8ba3ee138e
Merge branch 'master' into shrink_in_render 2026-05-31 09:27:40 -07:00
George Hotz
604b35aa67
Merge branch 'master' into shrink_in_render 2026-05-29 19:17:11 -07:00
George Hotz
7754025f2a just global/local 2026-05-29 18:20:48 -07:00
George Hotz
5ac25ba991 that 2026-05-29 18:18:36 -07:00
George Hotz
507c68dbc8 more passing 2026-05-29 18:02:25 -07:00
George Hotz
3e0335a4a0 work 2026-05-29 17:57:18 -07:00
George Hotz
cb755bded6 test tiny passes 2026-05-29 15:28:53 -07:00
George Hotz
10c2a50e79 no GEP in program 2026-05-29 15:03:01 -07:00
George Hotz
7b951e691e both 2026-05-29 14:51:56 -07:00
George Hotz
90b2c7e115 do it as renderer hacks for now 2026-05-29 13:42:52 -07:00
George Hotz
14394eb97d stack spec 2026-05-29 13:24:16 -07:00
George Hotz
213cc5b6b0 test plus works 2026-05-29 13:18:12 -07:00
George Hotz
1670dbfacd
Merge branch 'master' into shrink_in_render 2026-05-29 13:06:31 -07:00
George Hotz
927e16fbdb
Merge branch 'master' into shrink_in_render 2026-05-29 12:58:31 -07:00
George Hotz
f165839386
Merge branch 'master' into shrink_in_render 2026-05-29 11:40:00 -07:00
George Hotz
3baba3f23f
Merge branch 'master' into shrink_in_render 2026-05-29 01:37:25 -07:00
George Hotz
a75ad9fbaa no vec dtype 2026-05-28 20:34:40 -07:00
George Hotz
4115d330ab
Merge branch 'master' into shrink_in_render 2026-05-28 19:27:45 -07:00
George Hotz
7da2c151be
Merge branch 'master' into shrink_in_render 2026-05-28 19:18:31 -07:00
George Hotz
19246184d3
Merge branch 'master' into shrink_in_render 2026-05-28 15:06:43 -07:00
George Hotz
f17eb03634
Merge branch 'master' into shrink_in_render 2026-05-28 14:40:06 -07:00
George Hotz
1c10882bf0 something 2026-05-28 13:31:48 -07:00
George Hotz
6881b32e84 use shrink in renderers 2026-05-28 11:58:57 -07:00
9 changed files with 222 additions and 116 deletions

View file

@ -3,16 +3,16 @@ from dataclasses import replace
import itertools
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
from tinygrad.renderer import Renderer, Estimates
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
@ -24,6 +24,25 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
# NOTE: this is temporary until we fix the devectorizer
pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
UOp(Ops.SHRINK, dtype=buf.dtype.base, src=(buf, idx, UOp.const(dtypes.int, x.dtype.count))) if isinstance(buf.dtype, PtrDType) else None),
# rewrite GEP to INDEX
(UPat(Ops.GEP, name="x"), lambda x: x.replace(op=Ops.INDEX, src=x.src+(UOp.const(dtypes.int, x.arg),), arg=None)),
])
pm_remove_vec_dtypes = PatternMatcher([
# rewrite PARAM to non pointer
(UPat((Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
# remove all vec dtypes
(UPat(GroupOp.All-{Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
])+pm_clean_up_group_sink
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
@ -100,6 +119,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="remove vec dtypes")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)

View file

@ -3,7 +3,7 @@ from typing import Callable, cast
from dataclasses import dataclass
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
from tinygrad.dtype import AddrSpace, DType, dtypes
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.device import Compiler
@ -41,7 +41,7 @@ class Estimates:
while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0]
if buf.op is Ops.PARAM:
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
accessed = mem.get((buf, u.op), 0) + u.max_numel() * u.src[0].dtype.itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
if u.op is Ops.RANGE:
mult_stack.append(mults)
@ -51,10 +51,10 @@ class Estimates:
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.dtype.itemsize * mults
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.src[1].dtype.itemsize * mults
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.dtype.itemsize * mults
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))

View file

@ -8,9 +8,17 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
def render_index(ctx,buf,idx):
base = buf
while base.op is Ops.AFTER: base = base.src[0]
if base.addrspace == AddrSpace.ANON:
assert idx.op is Ops.CONST, f"{idx.op} must be CONST"
return f"{ctx[buf]}[{idx.arg}]"
else:
return f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"
base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{ctx[x.src[0]]}];"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
@ -18,14 +26,14 @@ base_rewrite = PatternMatcher([
(UPat(Ops.RANGE, name="x"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
(UPat(Ops.STACK, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(ctx.render_dtype_with_shape(x)))}" + \
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x:
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{ctx[x.src[0]]}];"),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"),
# const
@ -43,24 +51,26 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
# default const render
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# SHRINK/INDEX
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), render_index),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var('idx'), UPat.cvar())), render_index),
# new load/store
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), name="x"), lambda ctx,x,bidx: ctx.render_access(bidx, ctx.render_dtype_with_shape(x))),
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")), name="x"),
lambda ctx,x,bidx,var,gate: f"({ctx[gate]}?{ctx.render_access(bidx, ctx.render_dtype_with_shape(x))}:{ctx[var]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))),
lambda ctx,bidx,var: f"{ctx.render_access(bidx, ctx.render_dtype_with_shape(var))} = {ctx[var]};"),
# alu/gep
# TODO: look for left-associative
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
# custom passes through with format
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
])
extra_pm = PatternMatcher([
# devectorize any bools
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.SHRINK), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
# CAST (from bool) can't be vectorized
(UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
# WHERE can't be vectorized
@ -95,7 +105,11 @@ pm_manual_bf16_cast = PatternMatcher([
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16),
])
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
def dtype_with_shape(dtype:DType, shape:tuple) -> DType:
return dtype.scalar().vec(prod(shape)) if dtype.count == 1 and len(shape) == 1 and isinstance(shape[0], int) and shape[0] > 1 else dtype
def uops_to_dtypes(uops:list[UOp]) -> list[DType]:
return dedup(dtype_with_shape(u.dtype, u._shape or ()) if u.addrspace not in {AddrSpace.GLOBAL, AddrSpace.LOCAL} else u.dtype
for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
def wmma_args(uops:list[UOp]):
@ -133,10 +147,12 @@ class CStyleLanguage(Renderer):
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
tmp = ""
if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs):
def arg_dtype(u:UOp) -> DType:
return u.dtype if isinstance(u.dtype, (ImageDType, PtrDType)) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
if any(isinstance(arg_dtype(u), ImageDType) for _,(u,_) in bufs):
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else
self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs]
buftypes = [(name, self.render_dtype(dt, mutable)+self.buffer_suffix if isinstance(dt, (ImageDType, PtrDType)) else
self.arg_int_prefix if dt == dtypes.int else None) for name,(u,mutable) in bufs for dt in (arg_dtype(u),)]
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
launch_bounds = prod([d.vmax for d in local_dims])
prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
@ -145,6 +161,9 @@ class CStyleLanguage(Renderer):
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
def render_dtype_with_shape(self, u:UOp) -> DType: return dtype_with_shape(u.dtype, u.shape)
def render_access(self, bidx:UOp, dtype:DType) -> str:
return f"(*(({self.render_dtype(dtype.ptr(addrspace=bidx.addrspace))})({self[bidx]})))" if dtype.count > 1 else f"(*{self[bidx]})"
def render_dtype(self, dt:DType, mutable=True) -> str:
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
if isinstance(dt, PtrDType):
@ -190,21 +209,24 @@ class CStyleLanguage(Renderer):
else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
Ops.INDEX: "bidx", Ops.SHRINK: "bidx",
Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}"
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.SHRINK, Ops.INDEX, Ops.CUSTOMI} or \
(u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
r[u] = l
else:
if u.op is Ops.SHRINK or u._shape is None: u_dtype = u.src[0].dtype
else: u_dtype = self.render_dtype_with_shape(u)
if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} and u.dtype != dtypes.void:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
l = f"{self.render_dtype(u_dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment
if u.op in {Ops.IF, Ops.RANGE}: depth += 1

View file

@ -5,12 +5,14 @@ from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import HIPRenderer, create_non_native_float_pats, pm_manual_bf16_cast
from tinygrad.uop.decompositions import xexp2, xlog2
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate, AddrSpace
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv, OSX
def ldt(dt:DType):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
def ldt(dt:DType, count=1, ptr=False):
#if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
#if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
if ptr: return ldt(dt, count) + "*"
if count > 1: return f"<{count} x {ldt(dt, 1, ptr)}>"
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", dtypes.fp8e4m3: "i8", dtypes.fp8e5m2: "i8",
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
@ -54,12 +56,13 @@ def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
N,M,K = wmma.arg[1]
if cdna:
if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"})
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
return f" {ctx[wmma]} = call {ldt(wmma.dtype, count=wmma.max_numel())} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype, count=w.max_numel())} {ctx[w]}" for w in wmma.src]) + \
", i32 0, i32 0, i32 0)"
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
# example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
return f" {ctx[wmma]} = call {ldt(wmma.dtype, count=wmma.max_numel())} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype, count=w.max_numel())} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
if wmma.dtype.scalar() != dtypes.float else ")")
# llvm ops, lop[<dtype>][<op>]
@ -75,25 +78,31 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
base_rewrite = PatternMatcher([
# memory load/store
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
f" {ctx[x]} = extractelement {ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]}, i32 {x.src[1].arg}" \
if x.addrspace == AddrSpace.ANON else None),
(UPat((Ops.INDEX, Ops.SHRINK), name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype)}, {ldt(x.dtype, ptr=True)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
lambda ctx,x,idx,alt,mask:
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
f" {ctx[x]}_yes = load {ldt(x.dtype, idx.max_numel())}, {ldt(idx.dtype, x.max_numel(), True)} {ctx[idx]}\n"
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
(UPat.var('idx').load(name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
(UPat.var('idx').store(UPat.var("var")), lambda ctx,idx,var: f" store {ldt(var.dtype)} {ctx[var]}, {ldt(idx.dtype)} {ctx[idx]}"),
f" {ctx[x]} = phi {ldt(x.dtype, idx.max_numel())} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
(UPat.var('idx').load(name="x"),
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(idx.dtype, idx.max_numel())}, {ldt(idx.dtype, idx.max_numel(), True)} {ctx[idx]}"),
(UPat.var('idx').store(UPat.var("var")),
lambda ctx,idx,var:
f" store {ldt(var.dtype, idx.max_numel())} {ctx[var]}, {ldt(idx.dtype, idx.max_numel(), True)} {ctx[idx]}"),
# GEP/VECTORIZE/CAST for float4 support
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
#(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
(UPat(Ops.STACK, src=UPat.var('y'), name="x"), lambda ctx,x,y:
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.max_numel()} x i32> zeroinitializer"),
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(
f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
f" = insertelement {ldt(x.dtype, x.max_numel())} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
# unary/binary/ternary ops
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
@ -137,7 +146,8 @@ class LLVMRenderer(Renderer):
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
def _render_fn(self, name:str, args:list[tuple[str,UOp]], kernel:list[str], prefix:list[str]|None=None) -> str:
# NOTE: CPUAllocator promises 0x20 alignment
sargs = ", ".join([f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {name}" for name,u in args])
sargs = ", ".join([f"{ldt(u.dtype, ptr=u.addrspace == AddrSpace.GLOBAL)}{' noalias align 32' if u.addrspace == AddrSpace.GLOBAL else ''} {name}"
for name,u in args])
return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
r: dict[UOp, str] = {}

View file

@ -1,17 +1,17 @@
from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
from tinygrad.runtime.autogen import mesa
from tinygrad.runtime.support.c import POINTER
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools, os, warnings
def g(s:str): return getattr(mesa, s)
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
def glsl_type(t:DType): return {
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
if isinstance(it, PtrDType) and ot == dtypes.long: return src
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
@ -86,9 +85,9 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
nload = nir_instr(nc=lambda count:count, bs=lambda dtype:dtype.bitsize, num_components=lambda count:count,
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
lambda b, space, addr, dtype, count=1: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
@ -104,16 +103,31 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
def _load_count(x:UOp) -> int: return x.max_numel() if 1 < x.max_numel() <= 4 else 1
def _pad_count(b:mesa.nir_builder, dtype:DType, count:int, val):
return val if val.num_components == count else nalu(b, f"vec{count}", val, *[nundef(b, dtype) for _ in range(count-1)])
def nidx(b:mesa.nir_builder, buf, off, dtype, addrspace, gate=None) -> mesa.nir_def:
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
def reg(b, buf):
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
return deref
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else
f = (functools.partial(reg, b, buf) if addrspace == AddrSpace.REG else
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
def ngated_load_index(ctx, x, buf, off, alt, gate):
cnt = _load_count(x)
return if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace, ctx.r[gate]), x.dtype, cnt),
lambda: _pad_count(ctx.b, x.dtype, cnt, ctx.r[alt]))
def ngated_load_shrink(ctx, x, idx, alt, gate):
cnt = _load_count(idx)
return if_phi(ctx.b, ctx.r[gate], lambda: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, cnt),
lambda: _pad_count(ctx.b, x.dtype, cnt, ctx.r[alt]))
class NIRRenderer(Renderer):
suffix = "NIR"
nir_options: bytes
@ -137,7 +151,7 @@ class NIRRenderer(Renderer):
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
# load/store use pointer arithmetic, and the cast does nothing. NOTE: this doesn't apply to image indexing cause it's 1-D
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
src=(buf,off.cast(dtypes.long))) if buf.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
# images need index to be int for nir
(UPat.var("buf").index(UPat.var("idx_y"), UPat.var("idx_x")),
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int))),
@ -149,18 +163,28 @@ class NIRRenderer(Renderer):
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), ctx.r[val], val.dtype)),
(UPat(Ops.STORE, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("val"))),
lambda ctx,idx,val: nstore(ctx.b, idx.addrspace, ctx.r[idx], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
ngated_load_index),
(UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("alt"), UPat.var("gate")), name="x"),
ngated_load_shrink),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), x.dtype, _load_count(x))),
(UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"),), name="x"),
lambda ctx,x,idx: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, _load_count(idx))),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("off"), UPat.cvar()), name="x"),
lambda ctx,x,buf,off: nidx(ctx.b, ctx.r[buf], ctx.r[off], x.dtype, x.addrspace)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: ctx.r[x.src[0]] if len(x.src) == 1 else
nalu(ctx.b, f"vec{len(x.src)}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents),
(UPat(Ops.INDEX, src=(UPat.var("a"), UPat.cvar("idx"))),
lambda ctx,a,idx: nchannel(ctx.b, ctx.r[a], idx.arg) if a.addrspace == AddrSpace.ANON else None),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
mesa.glsl_array_type(glsl_type(x.dtype), x.src[0].arg, 0), f"acc{x.arg}".encode()).contents),
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
@ -189,8 +213,7 @@ class NIRRenderer(Renderer):
self.param_idx, ranges = 0, []
for u in uops:
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.INDEX and u.src[0].addrspace != AddrSpace.ANON): pass
elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
elif u.op == Ops.SINK:
@ -198,7 +221,7 @@ class NIRRenderer(Renderer):
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
elif u.op == Ops.DEFINE_LOCAL:
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
self.b.shader.contents.info.shared_size += u.dtype.nbytes()
self.b.shader.contents.info.shared_size += u.src[0].arg * u.dtype.itemsize
elif u.op == Ops.RANGE:
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
@ -217,7 +240,17 @@ class NIRRenderer(Renderer):
self.r[u] = cast(mesa.nir_def, d)
self.postrender(uops)
mesa.nir_validate_shader(self.b.shader, b"after render")
if DEBUG >= 2 and hasattr(os, "fork"):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
pid = os.fork()
if pid == 0:
mesa.nir_validate_shader(self.b.shader, b"after render")
os._exit(0)
_, status = os.waitpid(pid, 0)
if os.WIFSIGNALED(status): raise RuntimeError(f"NIR validation failed after render with signal {os.WTERMSIG(status)}")
if os.WEXITSTATUS(status) != 0: raise RuntimeError(f"NIR validation failed after render with exit code {os.WEXITSTATUS(status)}")
else: mesa.nir_validate_shader(self.b.shader, b"after render")
if DEBUG >= 4: mesa.nir_print_shader(self.b.shader, ctypes.POINTER(mesa.struct__IO_FILE).in_dll(ctypes.CDLL(ctypes.util.find_library('c')),
"__stdoutp" if OSX else "stdout"))
mesa.nir_serialize(blob:=mesa.struct_blob(), self.b.shader, False)
@ -257,9 +290,10 @@ class LVPRenderer(NIRRenderer):
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components,
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=4,
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
srcs=lambda b,img,idx_y,idx_x,val:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
srcs=lambda b,img,idx_y,idx_x,val,dtype:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int),
val if val.num_components == 4 else nalu(b, "vec4", val, nundef(b, dtype), nundef(b, dtype), nundef(b, dtype)), nimm(b, 0, dtypes.int)]])(
lambda b,img,idx_y,idx_x,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
@ -277,8 +311,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
def_rewrite = PatternMatcher([
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("val")), allow_any_len=True),
lambda ctx,img,idx_y,idx_x,val: nstore_img(ctx.b, ctx.r[img], ctx.r[idx_y], ctx.r[idx_x], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate"))),
lambda ctx,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, idx_y, idx_x), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,x,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: ctx.nload_img(img, idx_y, idx_x) if len(x.shape) > 0 and x.shape[-1] == 4 else nchannel(ctx.b, ctx.nload_img(img, idx_y, idx_x), 0),
lambda: ctx.r[alt] if len(x.shape) == 0 or x.shape[-1] != 4 or ctx.r[alt].num_components == 4 else
nalu(ctx.b, "vec4", ctx.r[alt], nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype)))),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')),)), nload_img),
]) + NIRRenderer.def_rewrite

View file

@ -27,12 +27,19 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
val = (load.cast(dtypes.uint32) >> shift_am) & mask
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
def is_packed(dt:DType, odt:DType|None = None) -> bool:
if odt is None: odt = dt
# registers aren't packed
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False
return dt.itemsize < 4 and dt.base != dtypes.half
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
def is_packed(x:UOp) -> bool:
if x.op is Ops.LOAD: dt, addrspace = x.dtype, x.src[0].addrspace
elif x.op is Ops.STORE: dt, addrspace = x.src[1].dtype, x.src[0].addrspace
else: dt, addrspace = x.dtype.base, x.dtype.addrspace if isinstance(x.dtype, PtrDType) else x.addrspace
return dt.itemsize < 4 and dt.base != dtypes.half and addrspace != AddrSpace.REG
def _packed_size(ctx, x:UOp):
size = ctx[x.src[0]]
if not is_packed(x): return size
elems = 4 // x.dtype.base.itemsize
return str((x.src[0].arg + elems - 1) // elems) if x.src[0].op is Ops.CONST else f"(({size}+{elems-1})/{elems})"
def _buf_map(ctx, x:UOp): return ctx.type_map[x.dtype.base] if x.addrspace == AddrSpace.REG else ctx.buf_map(x.dtype.base)
def is_nan(a):
bs, (exp, mant) = a.dtype.bitsize, dtypes.finfo(a.dtype)
@ -43,12 +50,12 @@ wgsl_matcher = PatternMatcher([
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None),
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")),
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")),
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l) else None),
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"), name="s"),
lambda s,bidx,var,gate: packed_store(bidx,var,gate) if is_packed(s) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), name="s"),
lambda s,bidx,var: packed_store(bidx,var) if is_packed(s) else None),
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
# fix nan check: 'a != a -> is_nan()'
@ -73,8 +80,8 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{_buf_map(ctx,x)},{_packed_size(ctx,x)}>;"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{_buf_map(ctx,x)},{_packed_size(ctx,x)}>;"),
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
@ -85,12 +92,12 @@ class WGSLRenderer(CStyleLanguage):
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")),
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"),
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
(UPat.store(UPat.var("b"), UPat.var("v")), lambda ctx,b,v:\
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate"), name="l"),
lambda ctx,l,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b)}, {ctx[gate]})"),
(UPat.load(UPat.var("b"), name="l"), lambda ctx,l,b: ctx.render_load(ctx[b], b)),
(UPat.store(UPat.var("b"), UPat.var("v"), name="s"), lambda ctx,s,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b) and b.addrspace != AddrSpace.REG \
else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
@ -98,9 +105,11 @@ class WGSLRenderer(CStyleLanguage):
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
def render_load(self, x:str, uop:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(uop) and uop.addrspace != AddrSpace.REG else x
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt.itemsize < 4 and dt != dtypes.half else self.type_map[dt.base]
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
def arg_dtype(u:UOp) -> DType:
return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
if not local_size: local_size = [1]
bind_it = iter(range(len(bufs)))
@ -110,8 +119,9 @@ class WGSLRenderer(CStyleLanguage):
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
f"{'var<storage,read_write>' if isinstance(dt, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(dt.base)}>' if isinstance(dt,PtrDType) else self.buf_map(dt)};"
for name,(u,_) in bufs for dt in (arg_dtype(u),)])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"

View file

@ -54,18 +54,20 @@ class DSPRenderer(ClangRenderer):
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
'HAP_power_set((void*)handle, (void*)&req);']
msrc += ['if ((sc>>24) != 2) return 0;']
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
if isinstance(b[1][0].dtype, PtrDType)]
if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
params = [(f'buf_{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)]
msrc += [f"{function_name}({', '.join(params)});"]
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += ["return 0; }"]
return '\n'.join(msrc)
@ -275,22 +277,23 @@ class MockDSPRenderer(DSPRenderer):
def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True)
def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
msrc = [mockdsp_boilerplate, 'void _start(void) {']
for i,b in enumerate(bufs):
if isinstance(b[1][0].dtype, PtrDType):
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize
if isinstance(dt:=arg_dtype(b[1][0]), PtrDType):
sz = dt.size*dt.itemsize
# for loop for big reads
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
else:
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
msrc.append("unsigned int st = inscount();")
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
params = [(f'(void*)buf{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
msrc.append(f"{function_name}({', '.join(params)});")
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
for i,b in enumerate(bufs):
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});")
if isinstance(dt:=arg_dtype(b[1][0]), PtrDType): msrc.append(f"write(1, buf{i}, {dt.size*dt.itemsize});")
msrc.append('exit(0); }')
return '\n'.join(msrc)

View file

@ -5,8 +5,8 @@
from typing import Any, TYPE_CHECKING
import pickle, base64, itertools, time, sys, functools
from dataclasses import replace
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, Target, IMAGE
from tinygrad.dtype import DType, dtypes, ImageDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar, AddrSpace
from tinygrad.helpers import all_same, getenv, flatten, Target, IMAGE
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp, bitcast
@ -101,19 +101,21 @@ class PythonProgram:
if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size
elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp]
elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size
elif u.op is Ops.INDEX:
ret:list = []
if isinstance(src_dtypes[0], ImageDType):
assert len(src_values) == 3, "image index must be 3 srcs"
for m,oy,ox in zip(*src_values):
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
elif u.op is Ops.SHRINK or (u.op is Ops.INDEX and len(src_values) == 2):
if u.addrspace == AddrSpace.ANON:
# old GEP
assert all_same(src_values[1]), "all index must be the same"
values[u] = src_values[0][src_values[1][0]]
else:
assert len(src_values) == 2, "non-image index must be 2 srcs"
for m,o in zip(*src_values): ret.append((m,o))
# normal index
values[u] = [(m,o) for m,o in zip(src_values[0], src_values[1])]
elif u.op is Ops.INDEX and len(src_values) == 3:
assert isinstance(src_dtypes[0], ImageDType), "3 src index is only for Image"
ret:list = []
for m,oy,ox in zip(*src_values):
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
values[u] = ret
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType):
values[u] = src_values[0]
elif u.op is Ops.RANGE:
if u not in values: values[u] = [0] * warp_size
else:
@ -134,7 +136,6 @@ class PythonProgram:
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)]
else:
values[u] = load(src_values, 0, u.dtype)
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]
elif u.op is Ops.WMMA:
first_src_dtype = u.src[0].dtype
assert isinstance(first_src_dtype, DType) # mypy

View file

@ -114,6 +114,8 @@ class DLL(ctypes.CDLL):
def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs):
self.nm, self.emsg = nm, emsg or f"try setting {nm.upper()+'_PATH'}?"
if nm == 'llvm' and (ver:=getenv("LLVM_VERSION", "")):
paths = ([f"/opt/homebrew/opt/llvm@{ver}/lib/libLLVM.dylib"] if OSX else [f"LLVM-{ver}"]) + (paths if isinstance(paths, list) else [paths])
if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])):
if DEBUG >= 3: print(f"loading {nm} from {path}")
try: