refactoring a bit

This commit is contained in:
Steven Anderson 2023-07-25 17:22:16 -04:00
commit 247db8d408
3 changed files with 16 additions and 26 deletions

View file

@ -165,20 +165,14 @@ class AssemblyCodegen(Linearizer):
pred = args.valid.render(render_ops)
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
if buf_to_dtype[args.name] != dtypes.float:
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global', args.memory_dtype))) #if args.i != -1 else 'shared')
else:
# NOTE: you can't compute the index in here, because it assumes it's all available later
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global'))) # if args.i != -1 else 'shared'
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None))) #if args.i != -1 else 'shared')
if args.valid.min == 0 and args.valid.max == 1:
ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
idx, treg, off = addr_w_offset(args)
if buf_to_dtype['data0'] != dtypes.float:
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global', args.memory_dtype))) #if args.i != -1 else 'shared')
else:
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global'))) #if args.i != -1 else 'shared'
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None))) #if args.i != -1 else 'shared')
# define registers
ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins

View file

@ -29,12 +29,13 @@ class ARM64Codegen(AssemblyCodegen):
ins.append(f"movk w2, #{(value >> 16) & 0xffff}, lsl #16")
ins.append(f"sxtw {to}, w2")
else:
ins.append(f"mov {to}, #{value}")
ins.append(f"mov {to}, {'#' + str(value) if value.__class__ is int else '0x' + float_to_hex(arg)}")
for i, (uop, out, vin, arg) in enumerate(asm):
if uop == UOps.DEFINE_REGISTER:
for i in range(arg[2]):
var_size += 16
#TODO: Find a way to use fewer memory lookups. Graph coloring?
reg_map[f"%{arg[1]}{i}"] = f"[sp, #{var_size}]"
elif uop == UOps.DEFINE_GLOBAL:
if arg.startswith('data'):
@ -78,35 +79,29 @@ class ARM64Codegen(AssemblyCodegen):
ins.append(f"{'f' if reg == 's' else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {reg}0, {reg}0, {reg}1")
ins.append(f"str {reg}{'2' if arg == BinaryOps.MOD else '0'}, {reg_map[out.nm]}")
elif uop == UOps.LOAD:
if isinstance(arg, float):
ins.append(f"mov x0, 0x{float_to_hex(arg)}")
ins.append("fmov s0, w0")
ins.append(f"str s0, {reg_map[out.nm]}")
elif arg.__class__ is int:
mov_imm(arg, f"x0")
if arg.__class__ in (int, float):
mov_imm(arg,"x0")
ins.append(f"str x0, {reg_map[out.nm]}")
else:
need_cast = len(arg) == 3
reg_out = 's0' if dtypes.is_float(out[1]) else 'x0'
reg_in = type_to_reg[arg[2] if need_cast else out[1]] + '0'
reg_in = type_to_reg[arg[2] if arg[2] is not None else out[1]] + '0'
ins.append(f"ldr x1, {reg_map[vin[0].nm]}")
# Manually offset in case it can't fit in imm
mov_imm(abs(arg[0]), "x2")
ins.append(f"{'sub' if arg[0] < 0 else 'add'} x1, x1, x2")
ins.append(f"ldr{'sb' if need_cast and arg[2] in (dtypes.int8, dtypes.uint8) else ''} {reg_in}, [x1]")
if need_cast: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} s0, {reg_in}")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8) else ''} {reg_in}, [x1]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} s0, {reg_in}")
ins.append(f"str {reg_out}, {reg_map[out.nm]}")
elif uop == UOps.STORE:
need_cast = len(arg) == 3
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2"}
ins.append(f"ldr s0, {reg_map[vin[1].nm]}")
reg_out = (type_to_reg[arg[2]] if need_cast else "s") + '0'
if need_cast: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, s0")
reg_out = (type_to_reg[arg[2]] if arg[2] is not None else "s") + '0'
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, s0")
ins.append(f"mov x3, #{arg[0]}")
ins.append(f"ldr x2, {reg_map[vin[0].nm]}")
ins.append(f"str {reg_out}, [x2, x3, lsl {shifts[arg[2]] if need_cast and arg[2] in shifts else '#0'}]")
ins.append(f"str {reg_out}, [x2, x3, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
elif uop == UOps.COND_BRANCH:
#TODO: this is a hack it should always be a cmp before a cond branch?
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
if prev_uop == UOps.LOAD:
ins.append(f"ldr x1, {reg_map[vin[0].nm]}")
ins.append(f"cmp x1, #0")

View file

@ -21,8 +21,9 @@ class ClangProgram:
os.rename(fn+'.tmp', fn)
else:
if DEBUG >= 5: print(prg)
subprocess.check_output(["as","-arch", "arm64", "-o", f"{fn}.o"], input=prg.encode('utf-8'))
subprocess.check_output(["clang", "-lm", "-shared", f"{fn}.o", "-o", fn])
if getenv('ARM64'):
subprocess.check_output(args=('as -arch arm64 -o '+fn+'.o').split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('clang -lm -O2 -Wall -shared '+fn+'.o -o'+fn).split())
self.lib = ctypes.CDLL(fn)
self.fxn = self.lib[name]