mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
py
This commit is contained in:
parent
ab71e51ae1
commit
ee8ea27637
2 changed files with 14 additions and 12 deletions
|
|
@ -101,16 +101,18 @@ class PythonProgram:
|
|||
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
|
||||
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
|
||||
elif uop is Ops.CONST: values[i] = [arg] * warp_size
|
||||
elif uop is Ops.INDEX:
|
||||
elif uop is Ops.SLICE:
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
for m,oy,ox in zip(*src_values):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
else:
|
||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
values[i] = ret
|
||||
elif uop is Ops.INDEX:
|
||||
assert isinstance(src_dtypes[0], ImageDType), "only image INDEX is supported"
|
||||
ret:list = []
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
for m,oy,ox in zip(*src_values):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
values[i] = ret
|
||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||
values[i] = src_values[0]
|
||||
|
|
|
|||
|
|
@ -99,11 +99,11 @@ spec_shared = PatternMatcher([
|
|||
(UPat(Ops.INS), lambda: True),
|
||||
|
||||
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
|
||||
(UPat(Ops.SLICE, name="uidx").or_casted().load(), validate_index),
|
||||
(UPat(Ops.SLICE, name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||
lambda uidx,gate,alt,load: validate_index(uidx, gate) if alt.dtype == load.dtype else False),
|
||||
(UPat(Ops.SLICE, name="uidx").or_casted().store(UPat()), validate_index),
|
||||
(UPat(Ops.SLICE, name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat()), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||
|
||||
# STORE in tensor graph: store a value into a target
|
||||
(UPat(Ops.STORE, dtypes.void, (UPat(name="x"), UPat())), lambda x: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue