mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
image_no_v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b910f1d5c0 | ||
|
|
e14b2b41c6 | ||
|
|
bf05a2762e |
||
|
|
08747264cf | ||
|
|
f68c224b71 |
10 changed files with 159 additions and 65 deletions
|
|
@ -21,7 +21,7 @@ def get_gated_load_uop(valid:UOp, idx:UOp):
|
|||
|
||||
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
||||
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
||||
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.STACK, dtypes.weakint.vec(2), idx).valid(valid), ptr=True),
|
||||
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(idx[0].valid(valid), idx[1].valid(valid), ptr=True),
|
||||
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
|
||||
))
|
||||
|
||||
|
|
@ -222,17 +222,15 @@ class TestValidIdxSimplification(unittest.TestCase):
|
|||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
load = simplify_image_idx(load.sink()).src[0]
|
||||
off = load.src[0].src[1]
|
||||
idx = off.get_idx()
|
||||
self.assertEqual(idx.op, Ops.STACK)
|
||||
self.assertEqual(len(idx.src), 2)
|
||||
idx0, idx1 = idx.src[0], idx.src[1]
|
||||
off = load.src[0]
|
||||
idx0, idx1 = off.src[1].get_idx(), off.src[2].get_idx()
|
||||
check_uop_against_string(self, idx0, sidx0)
|
||||
check_uop_against_string(self, idx1, sidx1)
|
||||
self.assertEqual(off.src[1].get_valid(), off.src[2].get_valid())
|
||||
if svalid is not None:
|
||||
check_uop_against_string(self, off.get_valid(), svalid)
|
||||
check_uop_against_string(self, off.src[1].get_valid(), svalid)
|
||||
else:
|
||||
self.assertEqual(off.get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
|
||||
self.assertEqual(off.src[1].get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
|
||||
|
||||
def test_idx_gt_c(self):
|
||||
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_f
|
|||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
||||
from tinygrad.codegen.late.gater import pm_image_index, pm_move_gates_from_index
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
|
|
@ -77,6 +77,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
|
||||
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
|
||||
|
||||
# convert image linear offsets to image coordinates before symbolic/index dtype cleanup
|
||||
sink = graph_rewrite(sink, pm_image_index, name="image indexing")
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||
|
|
|
|||
|
|
@ -38,19 +38,24 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
|||
|
||||
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
||||
idx = uop_given_valid(valid, start_idx)
|
||||
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)
|
||||
return None if isinstance(buf.dtype, ImageDType) or idx is start_idx else buf.index(idx.valid(valid), ptr=True)
|
||||
|
||||
# wait for it to be image indexed before running simplification
|
||||
if start_idx.dtype.count != 2: return None
|
||||
|
||||
drop_stmt = _drop_valid_stmts(valid, idx, buf.dtype.shape[0], buf.dtype.shape[1])
|
||||
|
||||
if not drop_stmt and idx is start_idx: return None
|
||||
def simplify_valid_image_load(buf:UOp, start_x:UOp, start_y:UOp, valid:UOp) -> UOp|None:
|
||||
if not isinstance(buf.dtype, ImageDType) or start_x.dtype.scalar() is not dtypes.weakint or \
|
||||
start_y.dtype.scalar() is not dtypes.weakint: return None
|
||||
x, y = uop_given_valid(valid, start_x), uop_given_valid(valid, start_y)
|
||||
drop_stmt = _drop_valid_stmts(valid, UOp.vectorize(x, y), buf.dtype.shape[0], buf.dtype.shape[1])
|
||||
if not drop_stmt and x is start_x and y is start_y: return None
|
||||
new_valid = UOp.uprod(*ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
|
||||
return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True)
|
||||
return buf.index(x.valid(new_valid) if new_valid is not None else x, y.valid(new_valid) if new_valid is not None else y, ptr=True)
|
||||
|
||||
|
||||
image_invalid_gate_x = UPat.var("cond").where(UPat.var("x"), UPat(Ops.CONST, arg=Invalid))
|
||||
image_invalid_gate_y = UPat.var("cond").where(UPat.var("y"), UPat(Ops.CONST, arg=Invalid))
|
||||
load_store_indexing = PatternMatcher([
|
||||
# image load valid idx simplification with scalar x/y coordinates
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), image_invalid_gate_x, image_invalid_gate_y)),
|
||||
lambda buf,x,y,cond: simplify_valid_image_load(buf, x, y, cond)),
|
||||
# image load valid idx simplification
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
|
||||
])
|
||||
|
|
@ -192,27 +197,9 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
|||
if len(ret) <= 1: return None
|
||||
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
|
||||
|
||||
def get_image_idx(idx:UOp, width:int):
|
||||
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
|
||||
idx_x, idx_y = (x // 4) % width, x // (4*width)
|
||||
return idx.replace(src=(idx.src[0], UOp.vectorize(idx_x, idx_y).valid(valid)))
|
||||
|
||||
def image_fixup(ls:UOp):
|
||||
# normal image load or store, with the CAST from expand_index
|
||||
if isinstance(dt:=ls.src[0].src[0].dtype, ImageDType) and ls.src[0].op is Ops.CAST:
|
||||
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
|
||||
return ls.replace(src=(get_image_idx(ls.src[0].src[0], dt.shape[1]),)+ls.src[1:])
|
||||
|
||||
# this is an unprocessed image without a cast, we should just make it a buffer
|
||||
if isinstance(dt, ImageDType) and (off:=ls.src[0].src[1]).get_idx().dtype != dtypes.weakint.vec(2):
|
||||
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(off)
|
||||
return ls.replace(src=(idx,), dtype=new_dt).cast(dtypes.float) if ls.op is Ops.LOAD else ls.replace(src=(idx, ls.src[1].cast(new_dt)))
|
||||
|
||||
correct_load_store = PatternMatcher([
|
||||
# split LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
|
||||
# image indexing, including unfoldable images
|
||||
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
|
||||
])
|
||||
|
||||
# *** uop expander ***
|
||||
|
|
@ -231,7 +218,7 @@ def no_vectorized_wmma(wmma:UOp):
|
|||
|
||||
def no_vectorized_alu(alu:UOp):
|
||||
if alu.dtype.vcount == 1: return None
|
||||
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # image load/store has cond.where(idx.vec(2), Invalid) as the index
|
||||
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # gated indexes use cond.where(idx, Invalid)
|
||||
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
|
||||
return UOp(Ops.STACK, alu.dtype, alus)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,48 @@
|
|||
# this is a temporary intermediate step while we remove this index style
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
|
||||
from tinygrad.dtype import Invalid, dtypes
|
||||
from tinygrad.dtype import Invalid, dtypes, ImageDType
|
||||
|
||||
def move_image_load_gate(buf:UOp, gate:UOp, x:UOp, y:UOp, cast:UOp, l:UOp):
|
||||
if not isinstance(buf.dtype, ImageDType): return None
|
||||
return buf.index(x, y, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)
|
||||
|
||||
def move_image_store_gate(buf:UOp, gate:UOp, x:UOp, y:UOp, cast:UOp, data:UOp):
|
||||
if not isinstance(buf.dtype, ImageDType): return None
|
||||
return buf.index(x, y, ptr=True).cast(cast.dtype).store(data, gate)
|
||||
|
||||
def image_coords_to_int(idx:UOp, buf:UOp, x:UOp, y:UOp):
|
||||
if not isinstance(buf.dtype, ImageDType) or (x.dtype != dtypes.long and y.dtype != dtypes.long): return None
|
||||
return idx.replace(src=(buf, x.cast(dtypes.int) if x.dtype == dtypes.long else x, y.cast(dtypes.int) if y.dtype == dtypes.long else y))
|
||||
|
||||
def index_and_valid(idx:UOp) -> tuple[UOp, UOp]:
|
||||
if idx.dtype.scalar() is dtypes.weakint: return idx.get_idx(), idx.get_valid()
|
||||
if idx.op is Ops.WHERE and idx.src[2].arg is Invalid: return idx.src[1], idx.src[0]
|
||||
return idx, UOp.const(dtypes.bool, idx.arg is not Invalid)
|
||||
|
||||
def valid_idx(idx:UOp, valid:UOp) -> UOp:
|
||||
return idx if valid.op is Ops.CONST and valid.arg is True else valid.where(idx, idx.const_like(Invalid))
|
||||
|
||||
def get_image_idx(idx:UOp, height:int, width:int) -> UOp:
|
||||
x, valid = index_and_valid(idx.src[1])
|
||||
px = x // 4
|
||||
idx_x, idx_y = (px, px.const_like(0)) if height == 1 else (px % width, px // width)
|
||||
return idx.replace(src=(idx.src[0], valid_idx(idx_x, valid), valid_idx(idx_y, valid)))
|
||||
|
||||
def image_fixup(ls:UOp):
|
||||
# normal image load/store from split_load_store: casted linear offset -> image x/y coordinates
|
||||
if ls.src[0].op is Ops.CAST and (cast_idx:=ls.src[0].src[0]).op is Ops.INDEX and isinstance(dt:=cast_idx.src[0].dtype, ImageDType):
|
||||
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
|
||||
return ls.replace(src=(cast_idx if len(cast_idx.src) == 3 else get_image_idx(cast_idx, dt.shape[0], dt.shape[1]),)+ls.src[1:])
|
||||
|
||||
if ls.src[0].op is not Ops.INDEX or not isinstance(dt:=ls.src[0].src[0].dtype, ImageDType) or len(ls.src[0].src) == 3: return None
|
||||
|
||||
# this is an unprocessed image without a cast, we should just make it a buffer
|
||||
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(ls.src[0].src[1])
|
||||
return ls.replace(src=(idx,), dtype=new_dt).cast(dtypes.float) if ls.op is Ops.LOAD else ls.replace(src=(idx, ls.src[1].cast(new_dt)))
|
||||
|
||||
pm_image_index = PatternMatcher([
|
||||
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
|
||||
])
|
||||
|
||||
pm_move_gates_from_index = PatternMatcher([
|
||||
# here we create the alt value for load to be 0s and remove the where Invalid
|
||||
|
|
@ -8,6 +50,12 @@ pm_move_gates_from_index = PatternMatcher([
|
|||
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
|
||||
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("x"), UPat(arg=Invalid)),
|
||||
UPat.var("gate").where(UPat.var("y"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
|
||||
move_image_load_gate),
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("x"), UPat(arg=Invalid)),
|
||||
UPat.var("gate").where(UPat.var("y"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
|
||||
move_image_store_gate),
|
||||
|
||||
# Where after gated load becomes alt value
|
||||
(UPat.var("gate").where(UPat().load(UPat(), UPat.var("gate"), name="l").or_casted(), UPat.var("a")), lambda gate,l,a:
|
||||
|
|
@ -15,7 +63,8 @@ pm_move_gates_from_index = PatternMatcher([
|
|||
(UPat.var("gate").where(UPat.var("a"), UPat().load(UPat(), ~UPat.var("gate", dtype=dtypes.bool), name="l").or_casted()), lambda gate,l,a:
|
||||
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
|
||||
|
||||
# vectorized indexes (ie. images) must be int
|
||||
# vectorized indexes must be int
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"),
|
||||
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:])))
|
||||
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:]))),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.var("y")), name="idx"), image_coords_to_int),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -97,6 +97,14 @@ pm_manual_bf16_cast = PatternMatcher([
|
|||
])
|
||||
|
||||
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
|
||||
def image_coord(ctx, x:UOp, y:UOp) -> str: return f"(int2)({ctx[x]}, {ctx[y]})"
|
||||
def render_image_load(ctx, buf:UOp, x:UOp, y:UOp, var:UOp|None=None, gate:UOp|None=None) -> str|None:
|
||||
if not isinstance(buf.dtype, ImageDType): return None
|
||||
load = f"read_imagef({ctx[buf]}, smp, {image_coord(ctx, x, y)})"
|
||||
return f"({ctx[gate]}?{load}:{ctx[var]})" if gate is not None and var is not None else load
|
||||
def render_image_store(ctx, buf:UOp, x:UOp, y:UOp, var:UOp) -> str|None:
|
||||
if not isinstance(buf.dtype, ImageDType): return None
|
||||
return f"write_imagef({ctx[buf]}, {image_coord(ctx, x, y)}, {ctx[var]});"
|
||||
|
||||
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
|
||||
def wmma_args(uops:list[UOp]):
|
||||
|
|
@ -301,13 +309,15 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
|
||||
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
|
||||
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),
|
||||
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var('buf'), UPat.var('x'), UPat.var('y')), name="idx"),
|
||||
lambda ctx,buf,x,y,idx: image_coord(ctx, x, y) if isinstance(buf.dtype, ImageDType) else None),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda ctx,buf,x,y,var,gate: render_image_load(ctx, buf, x, y, var, gate)),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')),)),
|
||||
lambda ctx,buf,x,y: render_image_load(ctx, buf, x, y)),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')),
|
||||
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda ctx,buf,x,y,var: render_image_store(ctx, buf, x, y, var)),
|
||||
]) + base_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
|
|
|
|||
|
|
@ -114,6 +114,11 @@ def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
|
|||
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
|
||||
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
|
||||
|
||||
def cast_global_index(x:UOp, buf:UOp, off:UOp):
|
||||
if isinstance(buf.dtype, ImageDType) or not isinstance(buf.dtype, PtrDType) or buf.dtype.addrspace == AddrSpace.REG or \
|
||||
off.op in (Ops.CAST, Ops.STACK): return None
|
||||
return x.replace(src=(buf, off.cast(dtypes.long))+x.src[2:])
|
||||
|
||||
class NIRRenderer(Renderer):
|
||||
suffix = "NIR"
|
||||
nir_options: bytes
|
||||
|
|
@ -136,8 +141,7 @@ class NIRRenderer(Renderer):
|
|||
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
|
||||
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
|
||||
# load/store use pointer arithmetic, and the cast does nothing
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace(
|
||||
src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), cast_global_index),
|
||||
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
|
||||
])
|
||||
|
||||
|
|
@ -249,30 +253,39 @@ class LVPRenderer(NIRRenderer):
|
|||
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
|
||||
|
||||
# FIXME: this should be a rewrite rule
|
||||
def tovec(b, coord): return nalu(b, "vec4", nchannel(b, coord, 0), nchannel(b, coord, 1), nundef(b, dtypes.int), nundef(b, dtypes.int))
|
||||
def tovec(b, x, y): return nalu(b, "vec4", x, y, nundef(b, dtypes.int), nundef(b, dtypes.int))
|
||||
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
|
||||
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components,
|
||||
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
|
||||
srcs=lambda b,img,coord,val:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,coord,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
|
||||
srcs=lambda b,img,x,y,val:[nsrc(z) for z in [img, tovec(b, x, y), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,x,y,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
|
||||
|
||||
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
|
||||
nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
|
||||
nc=4, bs=32, num_components=4, srcs=lambda b,img,x,y:[nsrc(z) for z in [img, tovec(b, x, y), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,x,y,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
|
||||
|
||||
def nstore_img_checked(ctx, img:UOp, x:UOp, y:UOp, val:UOp):
|
||||
if not isinstance(img.dtype, ImageDType): return None
|
||||
return nstore_img(ctx.b, ctx.r[img], ctx.r[x], ctx.r[y], ctx.r[val], val.dtype)
|
||||
|
||||
def nload_img_gated(ctx, img:UOp, x:UOp, y:UOp, alt:UOp, gate:UOp):
|
||||
if not isinstance(img.dtype, ImageDType): return None
|
||||
return if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, x, y), lambda: ctx.r[alt])
|
||||
|
||||
class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
||||
has_aux = True
|
||||
|
||||
def nload_img(ctx,img,coord):
|
||||
def nload_img(ctx,img,x,y):
|
||||
if not isinstance(img.dtype, ImageDType): return None
|
||||
ctx.texs.add(img)
|
||||
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)
|
||||
return _nload_img(ctx.b, ctx.r[img], ctx.r[x], ctx.r[y], img.dtype)
|
||||
|
||||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("val")), allow_any_len=True),
|
||||
lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("alt"), UPat.var("gate"))),
|
||||
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img),
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')), UPat.var("val")), allow_any_len=True),
|
||||
nstore_img_checked),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')), UPat.var("alt"), UPat.var("gate"))),
|
||||
nload_img_gated),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')),)), nload_img),
|
||||
]) + NIRRenderer.def_rewrite
|
||||
|
||||
_param = LVPRenderer.param
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
from typing import Any, TYPE_CHECKING
|
||||
import pickle, base64, itertools, time, sys, functools
|
||||
from dataclasses import replace
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar, Invalid
|
||||
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, Target
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.codegen.opt import tc
|
||||
|
|
@ -92,11 +92,13 @@ class PythonProgram:
|
|||
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
|
||||
elif uop is Ops.CONST: values[i] = [arg] * warp_size
|
||||
elif uop is Ops.INDEX:
|
||||
if len(src_values) != 2: raise RuntimeError("gates must be on LOAD/STORE, not INDEX")
|
||||
if len(src_values) != 2 and not isinstance(src_dtypes[0], ImageDType): raise RuntimeError("gates must be on LOAD/STORE, not INDEX")
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
xs, ys = (src_values[1][0], src_values[1][1]) if len(src_values) == 2 else (src_values[1], src_values[2])
|
||||
for m,ox,oy in zip(src_values[0], xs, ys):
|
||||
invalid = ox is Invalid or oy is Invalid
|
||||
if invalid or ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
else:
|
||||
for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
|
||||
|
|
|
|||
|
|
@ -1531,6 +1531,19 @@ def sint_to_uop(x:sint, dtype=dtypes.weakint) -> UOp: return UOp.const(dtype, x)
|
|||
def to_max_shape(shape:tuple[sint, ...]) -> tuple[int, ...]: return tuple(int(x.vmax) if isinstance(x, UOp) else x for x in shape)
|
||||
|
||||
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
|
||||
def lower_index_casts(idx:UOp) -> UOp|None:
|
||||
new_src, changed = [idx.src[0]], False
|
||||
for s in idx.src[1:]:
|
||||
ns = None
|
||||
if s.op is Ops.CAST and s.dtype == dtypes.weakint and s.src[0].dtype.scalar() in dtypes.ints:
|
||||
ns = s.src[0]
|
||||
elif s.op is Ops.WHERE and s.dtype.scalar() is dtypes.weakint and s.src[2].op is Ops.CONST and s.src[2].arg is Invalid:
|
||||
val = s.src[1]
|
||||
if val.op is Ops.CAST and val.dtype == dtypes.weakint and val.src[0].dtype.scalar() in dtypes.ints:
|
||||
ns = s.src[0].where(val.src[0], val.src[0].const_like(Invalid))
|
||||
new_src.append(ns if ns is not None else s)
|
||||
changed = changed or ns is not None
|
||||
return idx.replace(src=tuple(new_src)) if changed else None
|
||||
pm_lower_index_dtype = PatternMatcher([
|
||||
# There are no Unary ops at this point in symbolic, those are introduced later
|
||||
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y:
|
||||
|
|
@ -1552,6 +1565,7 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
|
||||
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
|
||||
(UPat(Ops.INDEX, name="idx"), lower_index_casts),
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
|
|||
from tinygrad.uop.validate import validate_index_with_z3
|
||||
return validate_index_with_z3(sz, idx, gate)
|
||||
|
||||
def validate_image_index(buf:UOp, idx0:UOp, idx1:UOp, gate:UOp|None=None):
|
||||
if not isinstance(buf.dtype, ImageDType): return None
|
||||
return validate_index(buf, idx0, gate) and validate_index(buf, idx1, gate)
|
||||
|
||||
# four specs:
|
||||
# shared_spec -- usable anywhere
|
||||
# tensor_spec -- usable in tensor graph
|
||||
|
|
@ -180,6 +184,13 @@ shared_codegen_spec = PatternMatcher([
|
|||
lambda buf,idx,gate,alt,load: validate_index(buf, idx, gate) if alt.dtype == load.dtype else False),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat()), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().load(), validate_image_index),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().load(
|
||||
UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||
lambda buf,idx0,idx1,gate,alt,load: validate_image_index(buf, idx0, idx1, gate) if alt.dtype == load.dtype else False),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().store(UPat()), validate_image_index),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().store(
|
||||
UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_image_index),
|
||||
|
||||
# CUSTOM (inline and non inline)
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
|
||||
|
|
@ -189,6 +200,8 @@ shared_codegen_spec = PatternMatcher([
|
|||
|
||||
# INDEX is just address calculation. OOB validation is on LOAD/STORE where the gate is available.
|
||||
(UPat(GroupOp.Defines|{Ops.AFTER}).index(UPat()), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat(), UPat())),
|
||||
lambda buf: True if isinstance(buf.dtype, ImageDType) else None),
|
||||
|
||||
# SPECIAL
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
|
|
|
|||
|
|
@ -447,8 +447,13 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
|
|||
lambda index, gate, alt: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt)),
|
||||
# fold gated LOAD/STORE
|
||||
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())), lambda: UOp(Ops.NOOP)),
|
||||
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid), UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())),
|
||||
lambda: UOp(Ops.NOOP)),
|
||||
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
|
||||
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid),
|
||||
UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)),
|
||||
(UPat(Ops.STORE, src=(UPat(), invalid_pat)), lambda i: UOp(Ops.NOOP)),
|
||||
# store of where with invalid -> gated store
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="index"), UPat.var("cond").where(UPat.var("val"), invalid_pat))),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue