UOps.VECTORIZE cleanups [run_process_replay] (#5314)

* still render_cast

* one extra line ok

* these are all just vectorize

* save space

* behavior change can go in a different diff
This commit is contained in:
qazal 2024-07-07 10:49:08 +03:00 committed by GitHub
commit ae10e936e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 21 additions and 22 deletions

View file

@ -100,20 +100,19 @@ class PythonProgram:
del ul[i]
i = loop_ends[i] + 1
continue
elif uop in (UOps.CAST, UOps.BITCAST, UOps.VECTORIZE):
if dtype.count > 1: ul[i] = inp
elif uop is UOps.VECTORIZE: ul[i] = inp
elif uop in {UOps.CAST, UOps.BITCAST}:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
else:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
else:
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
if dtypes.is_int(dtype):
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
elif dtypes.is_float(dtype):
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
if dtypes.is_int(dtype):
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
elif dtypes.is_float(dtype):
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
elif uop is UOps.LOAD:
if isinstance(dtp[0], ImageDType):
assert dtype.count == 4