pow not working

This commit is contained in:
Steven Anderson 2023-07-30 00:08:39 -04:00
commit 373f5831ec
2 changed files with 28 additions and 39 deletions

View file

@ -14,10 +14,11 @@ def compute_offsets(total):
#NOTE: Darwin needs lm functions to start with a "_"
def get_op(op):
  """Build the `bl` instruction string that calls the routine named `op`.

  On Darwin the symbol name is prefixed with an underscore; on every other
  platform it is used unchanged.
  """
  symbol_prefix = '_' if system() == 'Darwin' else ''
  return f"bl {symbol_prefix}{op}"
# Register-name prefix letter for each dtype. The LOAD/STORE handling appends a
# register number to this prefix to form names such as 's0' or 'w12'.
type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'x',dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
class ARM64Codegen(AssemblyCodegen):
def specialize(self, asm):
rtor:Dict[Register, str] = {}
var_size = 0
prev_uop = None
ins = []
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (9,10,11,12,13,14,15,16,17,18,19,20)]
@ -26,8 +27,6 @@ class ARM64Codegen(AssemblyCodegen):
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs", BinaryOps.CMPEQ: "subs",
UnaryOps.SIN: get_op('sinf'), UnaryOps.LOG2: get_op("log2f"), UnaryOps.EXP2: get_op("exp2f"), UnaryOps.SQRT: get_op("sqrtf"),
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcmp"}
reg_map = {}
var_size = 0
def mov_imm(value, to):
# Manually move value into reg if vin[1] can't fit
if value > 65535:
@ -35,6 +34,8 @@ class ARM64Codegen(AssemblyCodegen):
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
ins.append(f"sxtw {to}, w15")
elif to[0] == 's':
#NOTE: value comes as int when it should be float
#value = float(value) if value.__class__ is int else value
ins.append(f"movz x15, {'0x' + float_to_hex(value)[:4]}, lsl #16")
ins.append(f"movk x15, {'0x' + float_to_hex(value)[4:]}")
ins.append(f"scvtf {to}, w15")
@ -47,9 +48,8 @@ class ARM64Codegen(AssemblyCodegen):
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
mem_vars = {}
temp_floats = ['s0', 's1', 's0']
temp_ints = ['x13', 'x14', 'x13']
temp_ints = ['x13', 'x12', 'x13']
def load_var(vin):
prev_reg = None
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
@ -62,13 +62,15 @@ class ARM64Codegen(AssemblyCodegen):
if out.nm in mem_vars:
ins.append(f"str {rtor[out.nm]}, {mem_vars[out.nm]}")
mem_vars = {}
def allocate_regs(vars):
nonlocal var_size
for i,v in enumerate([v for v in vars if v is not None and v.__class__ is not int and v.nm not in rtor]):
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
#NOTE: Very simple spill: everything that doesn't fit in regs goes to mem
if len(available_regs) == 0:
var_size += 16
reg ='s1' if dtypes.is_float(out[1]) else 'x14'
reg ='s1' if dtypes.is_float(out[1]) else 'x12'
available_regs.append(reg)
mem_vars[v.nm] = f"[sp, #{var_size}]"
rtor[v.nm] = available_regs.pop()
@ -78,27 +80,13 @@ class ARM64Codegen(AssemblyCodegen):
available_regs = s_regs if reg[0] == 's' else x_regs
if var[1] != 'B' and var not in mem_vars and i > live_range[var][1]:
available_regs.append(rtor.pop(var))
# Then we assign a register to the variable produced by the instruction.
# Assign registers to the variables using live ranges.
allocate_regs(vin)
allocate_regs([out])
# for i,v in enumerate([v for v in vin if v is not None and v.__class__ is not int and v.nm not in rtor]):
# available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
# if len(available_regs) == 0:
# var_size += 16
# available_regs.append(temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i])
# mem_vars[v.nm] = f"[sp, #{var_size}]"
# rtor[v.nm] = available_regs.pop()
# if out is not None and out.nm not in rtor:
# available_regs = s_regs if dtypes.is_float(out[1]) else x_regs
# #print(available_regs)
# if len(available_regs) == 0:
# var_size += 16
# reg ='s1' if dtypes.is_float(out[1]) else 'x14'
# available_regs.append(reg)
# mem_vars[out.nm] = f"[sp, #{var_size}]"
# rtor[out.nm] = available_regs.pop()
if uop == UOps.DEFINE_GLOBAL:
if arg.startswith('data'):
# args 8 onward go on the stack, so we move them into regs
if int(arg[4:]) >= 8:
ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
ins.append(f"mov {rtor[out.nm]}, x15")
@ -122,24 +110,22 @@ class ARM64Codegen(AssemblyCodegen):
ins.append(f"{alu[arg]} {rtor[vin[0].nm]}, s0")
ins.append(f"fcsel {rtor[out.nm]},{rtor[vin[2].nm]}, {rtor[vin[1].nm]}, eq")
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
#ins.append(f"sub sp, sp, #{len(rtor)*16}")
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
ins.append(f"sub sp, sp, #{len(save_regs)*16}")
for i,k in enumerate(save_regs,1):
ins.append(f"mov x15, #{(16*i)}")
ins.append(f"str {rtor[k]}, [sp, x15]")
ins.append("stp x29, x30, [sp, #0]!")
ins.append("mov x29, sp")
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
for i,k in enumerate([k for k in rtor.keys() if k != out.nm and k not in mem_vars]):
var_size += 16
ins.append(f"str {rtor[k]}, [sp, #{(var_size)}]")
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
ins.append(alu[arg])
ins.append(f"fmov {rtor[out.nm]}, s0")
ins.append("mov sp, x29")
ins.append("ldp x29, x30, [sp], #0")
var_size_local = var_size
for i,k in enumerate([k for k in reversed(rtor.keys()) if k != out.nm and k not in mem_vars]):
if k != out.nm and k not in mem_vars:
ins.append(f"ldr {rtor[k]}, [sp, #{(var_size_local)}]")
var_size_local += -16
#ins.append(f"add sp, sp, #{(len(rtor) - len(mem_vars))*16}")
for i,k in enumerate(save_regs,1):
ins.append(f"mov x15, #{(16*i)}")
ins.append(f"ldr {rtor[k]}, [sp, x15]")
ins.append(f"add sp, sp, #{len(save_regs)*16}")
elif arg in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[0].nm]}, {'x15' if vin[1].__class__ is int else rtor[vin[1].nm]}" if reg == 'x' else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
elif arg == BinaryOps.MOD:
@ -153,13 +139,19 @@ class ARM64Codegen(AssemblyCodegen):
mov_imm(arg, rtor[out.nm])
else:
load_var(vin)
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
mov_imm(arg[0], "x15")
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr {rtor[out.nm]}, [x15]")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
store_var(out)
elif uop == UOps.STORE:
load_var(vin)
ins.append(f"str {rtor[vin[1].nm]}, [{rtor[vin[0].nm]}, #{arg[0]}]")
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2"}
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
elif uop == UOps.COND_BRANCH:
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
if prev_uop == UOps.LOAD:

View file

@ -190,7 +190,6 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True)
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
helper_test_op([], lambda: torch.floor(b), lambda: Tensor.floor(a), forward_only=True)
@unittest.skipIf(getenv("ARM64") >0, "working on it")
def test_ceil(self):
helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x: x.ceil(), forward_only=True)
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
@ -315,7 +314,6 @@ class TestOps(unittest.TestCase):
def test_leakyrelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
@unittest.skipIf(getenv("ARM64"), "fix later")
def test_celu(self):
for val in range(1, 5):
helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
@ -366,7 +364,6 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
@unittest.skipIf(getenv("ARM64"), "fix later")
def test_mish(self):
def _mish_pytorch(x):
return x*torch.tanh(torch.nn.functional.softplus(x))