mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
refactoring a bit
This commit is contained in:
parent
9ab399a67f
commit
247db8d408
3 changed files with 16 additions and 26 deletions
|
|
@ -165,20 +165,14 @@ class AssemblyCodegen(Linearizer):
|
|||
pred = args.valid.render(render_ops)
|
||||
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||
if args.valid.max == 1:
|
||||
if buf_to_dtype[args.name] != dtypes.float:
|
||||
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global', args.memory_dtype))) #if args.i != -1 else 'shared')
|
||||
else:
|
||||
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
||||
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global'))) # if args.i != -1 else 'shared'
|
||||
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None))) #if args.i != -1 else 'shared')
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||
skipload_branch += 1
|
||||
elif uop == UOps.STORE:
|
||||
idx, treg, off = addr_w_offset(args)
|
||||
if buf_to_dtype['data0'] != dtypes.float:
|
||||
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global', args.memory_dtype))) #if args.i != -1 else 'shared')
|
||||
else:
|
||||
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global'))) #if args.i != -1 else 'shared'
|
||||
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None))) #if args.i != -1 else 'shared')
|
||||
|
||||
# define registers
|
||||
ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins
|
||||
|
|
|
|||
|
|
@ -29,12 +29,13 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
ins.append(f"movk w2, #{(value >> 16) & 0xffff}, lsl #16")
|
||||
ins.append(f"sxtw {to}, w2")
|
||||
else:
|
||||
ins.append(f"mov {to}, #{value}")
|
||||
ins.append(f"mov {to}, {'#' + str(value) if value.__class__ is int else '0x' + float_to_hex(arg)}")
|
||||
|
||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||
if uop == UOps.DEFINE_REGISTER:
|
||||
for i in range(arg[2]):
|
||||
var_size += 16
|
||||
#TODO: Find a way to use fewer memory lookups. Graph coloring?
|
||||
reg_map[f"%{arg[1]}{i}"] = f"[sp, #{var_size}]"
|
||||
elif uop == UOps.DEFINE_GLOBAL:
|
||||
if arg.startswith('data'):
|
||||
|
|
@ -78,35 +79,29 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
ins.append(f"{'f' if reg == 's' else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {reg}0, {reg}0, {reg}1")
|
||||
ins.append(f"str {reg}{'2' if arg == BinaryOps.MOD else '0'}, {reg_map[out.nm]}")
|
||||
elif uop == UOps.LOAD:
|
||||
if isinstance(arg, float):
|
||||
ins.append(f"mov x0, 0x{float_to_hex(arg)}")
|
||||
ins.append("fmov s0, w0")
|
||||
ins.append(f"str s0, {reg_map[out.nm]}")
|
||||
elif arg.__class__ is int:
|
||||
mov_imm(arg, f"x0")
|
||||
if arg.__class__ in (int, float):
|
||||
mov_imm(arg,"x0")
|
||||
ins.append(f"str x0, {reg_map[out.nm]}")
|
||||
else:
|
||||
need_cast = len(arg) == 3
|
||||
reg_out = 's0' if dtypes.is_float(out[1]) else 'x0'
|
||||
reg_in = type_to_reg[arg[2] if need_cast else out[1]] + '0'
|
||||
reg_in = type_to_reg[arg[2] if arg[2] is not None else out[1]] + '0'
|
||||
ins.append(f"ldr x1, {reg_map[vin[0].nm]}")
|
||||
# Manually offset in case it can't fit in imm
|
||||
mov_imm(abs(arg[0]), "x2")
|
||||
ins.append(f"{'sub' if arg[0] < 0 else 'add'} x1, x1, x2")
|
||||
ins.append(f"ldr{'sb' if need_cast and arg[2] in (dtypes.int8, dtypes.uint8) else ''} {reg_in}, [x1]")
|
||||
if need_cast: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} s0, {reg_in}")
|
||||
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8) else ''} {reg_in}, [x1]")
|
||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} s0, {reg_in}")
|
||||
ins.append(f"str {reg_out}, {reg_map[out.nm]}")
|
||||
elif uop == UOps.STORE:
|
||||
need_cast = len(arg) == 3
|
||||
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2"}
|
||||
ins.append(f"ldr s0, {reg_map[vin[1].nm]}")
|
||||
reg_out = (type_to_reg[arg[2]] if need_cast else "s") + '0'
|
||||
if need_cast: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, s0")
|
||||
reg_out = (type_to_reg[arg[2]] if arg[2] is not None else "s") + '0'
|
||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, s0")
|
||||
ins.append(f"mov x3, #{arg[0]}")
|
||||
ins.append(f"ldr x2, {reg_map[vin[0].nm]}")
|
||||
ins.append(f"str {reg_out}, [x2, x3, lsl {shifts[arg[2]] if need_cast and arg[2] in shifts else '#0'}]")
|
||||
ins.append(f"str {reg_out}, [x2, x3, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
#TODO: this is a hack it should always be a cmp before a cond branch?
|
||||
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
|
||||
if prev_uop == UOps.LOAD:
|
||||
ins.append(f"ldr x1, {reg_map[vin[0].nm]}")
|
||||
ins.append(f"cmp x1, #0")
|
||||
|
|
|
|||
|
|
@ -21,8 +21,9 @@ class ClangProgram:
|
|||
os.rename(fn+'.tmp', fn)
|
||||
else:
|
||||
if DEBUG >= 5: print(prg)
|
||||
subprocess.check_output(["as","-arch", "arm64", "-o", f"{fn}.o"], input=prg.encode('utf-8'))
|
||||
subprocess.check_output(["clang", "-lm", "-shared", f"{fn}.o", "-o", fn])
|
||||
if getenv('ARM64'):
|
||||
subprocess.check_output(args=('as -arch arm64 -o '+fn+'.o').split(), input=prg.encode('utf-8'))
|
||||
subprocess.check_output(args=('clang -lm -O2 -Wall -shared '+fn+'.o -o'+fn).split())
|
||||
self.lib = ctypes.CDLL(fn)
|
||||
self.fxn = self.lib[name]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue