This commit is contained in:
George Hotz 2026-05-26 13:55:34 -07:00
commit ee8ea27637
2 changed files with 14 additions and 12 deletions

View file

@@ -101,16 +101,18 @@ class PythonProgram:
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[i] = [arg] * warp_size
elif uop is Ops.INDEX:
elif uop is Ops.SLICE:
ret:list = []
if isinstance(src_dtypes[0], ImageDType):
assert len(src_values) == 3, "image index must be 3 srcs"
for m,oy,ox in zip(*src_values):
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
else:
assert len(src_values) == 2, "non-image index must be 2 srcs"
for m,o in zip(*src_values): ret.append((m,o))
assert len(src_values) == 2, "non-image index must be 2 srcs"
for m,o in zip(*src_values): ret.append((m,o))
values[i] = ret
elif uop is Ops.INDEX:
assert isinstance(src_dtypes[0], ImageDType), "only image INDEX is supported"
ret:list = []
assert len(src_values) == 3, "image index must be 3 srcs"
for m,oy,ox in zip(*src_values):
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
values[i] = ret
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
values[i] = src_values[0]

View file

@@ -99,11 +99,11 @@ spec_shared = PatternMatcher([
(UPat(Ops.INS), lambda: True),
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
(UPat(Ops.SLICE, name="uidx").or_casted().load(), validate_index),
(UPat(Ops.SLICE, name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(), validate_index),
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
lambda uidx,gate,alt,load: validate_index(uidx, gate) if alt.dtype == load.dtype else False),
(UPat(Ops.SLICE, name="uidx").or_casted().store(UPat()), validate_index),
(UPat(Ops.SLICE, name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat()), validate_index),
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
# STORE in tensor graph: store a value into a target
(UPat(Ops.STORE, dtypes.void, (UPat(name="x"), UPat())), lambda x: True),