mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
pow not working
This commit is contained in:
parent
a898517b39
commit
373f5831ec
2 changed files with 28 additions and 39 deletions
|
|
@ -14,10 +14,11 @@ def compute_offsets(total):
|
|||
|
||||
#NOTE: Darwin needs lm functions to start with a "_"
|
||||
def get_op(op):
  """Build the `bl` instruction string that calls the routine named `op`.

  The symbol name is prefixed with "_" when `system()` reports 'Darwin';
  on every other platform the name is used unchanged.
  """
  prefix = '_' if system() == 'Darwin' else ''
  return f"bl {prefix}{op}"
|
||||
|
||||
type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'x',dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
|
||||
class ARM64Codegen(AssemblyCodegen):
|
||||
def specialize(self, asm):
|
||||
rtor:Dict[Register, str] = {}
|
||||
var_size = 0
|
||||
prev_uop = None
|
||||
ins = []
|
||||
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (9,10,11,12,13,14,15,16,17,18,19,20)]
|
||||
|
|
@ -26,8 +27,6 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs", BinaryOps.CMPEQ: "subs",
|
||||
UnaryOps.SIN: get_op('sinf'), UnaryOps.LOG2: get_op("log2f"), UnaryOps.EXP2: get_op("exp2f"), UnaryOps.SQRT: get_op("sqrtf"),
|
||||
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcmp"}
|
||||
reg_map = {}
|
||||
var_size = 0
|
||||
def mov_imm(value, to):
|
||||
# Manually move value into reg if vin[1] can't fit
|
||||
if value > 65535:
|
||||
|
|
@ -35,6 +34,8 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
|
||||
ins.append(f"sxtw {to}, w15")
|
||||
elif to[0] == 's':
|
||||
#NOTE: value comes as int when it should be float
|
||||
#value = float(value) if value.__class__ is int else value
|
||||
ins.append(f"movz x15, {'0x' + float_to_hex(value)[:4]}, lsl #16")
|
||||
ins.append(f"movk x15, {'0x' + float_to_hex(value)[4:]}")
|
||||
ins.append(f"scvtf {to}, w15")
|
||||
|
|
@ -47,9 +48,8 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
|
||||
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
|
||||
|
||||
mem_vars = {}
|
||||
temp_floats = ['s0', 's1', 's0']
|
||||
temp_ints = ['x13', 'x14', 'x13']
|
||||
temp_ints = ['x13', 'x12', 'x13']
|
||||
def load_var(vin):
|
||||
prev_reg = None
|
||||
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
|
||||
|
|
@ -62,13 +62,15 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
if out.nm in mem_vars:
|
||||
ins.append(f"str {rtor[out.nm]}, {mem_vars[out.nm]}")
|
||||
|
||||
mem_vars = {}
|
||||
def allocate_regs(vars):
|
||||
nonlocal var_size
|
||||
for i,v in enumerate([v for v in vars if v is not None and v.__class__ is not int and v.nm not in rtor]):
|
||||
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
|
||||
#NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
|
||||
if len(available_regs) == 0:
|
||||
var_size += 16
|
||||
reg ='s1' if dtypes.is_float(out[1]) else 'x14'
|
||||
reg ='s1' if dtypes.is_float(out[1]) else 'x12'
|
||||
available_regs.append(reg)
|
||||
mem_vars[v.nm] = f"[sp, #{var_size}]"
|
||||
rtor[v.nm] = available_regs.pop()
|
||||
|
|
@ -78,27 +80,13 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
available_regs = s_regs if reg[0] == 's' else x_regs
|
||||
if var[1] != 'B' and var not in mem_vars and i > live_range[var][1]:
|
||||
available_regs.append(rtor.pop(var))
|
||||
# Then we assign a register to the variable produced by the instruction.
|
||||
# Assign registers to the variables using live ranges.
|
||||
allocate_regs(vin)
|
||||
allocate_regs([out])
|
||||
# for i,v in enumerate([v for v in vin if v is not None and v.__class__ is not int and v.nm not in rtor]):
|
||||
# available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
|
||||
# if len(available_regs) == 0:
|
||||
# var_size += 16
|
||||
# available_regs.append(temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i])
|
||||
# mem_vars[v.nm] = f"[sp, #{var_size}]"
|
||||
# rtor[v.nm] = available_regs.pop()
|
||||
# if out is not None and out.nm not in rtor:
|
||||
# available_regs = s_regs if dtypes.is_float(out[1]) else x_regs
|
||||
# #print(available_regs)
|
||||
# if len(available_regs) == 0:
|
||||
# var_size += 16
|
||||
# reg ='s1' if dtypes.is_float(out[1]) else 'x14'
|
||||
# available_regs.append(reg)
|
||||
# mem_vars[out.nm] = f"[sp, #{var_size}]"
|
||||
# rtor[out.nm] = available_regs.pop()
|
||||
|
||||
if uop == UOps.DEFINE_GLOBAL:
|
||||
if arg.startswith('data'):
|
||||
# args 8 onward go onto the stack, so we move them into regs
|
||||
if int(arg[4:]) >= 8:
|
||||
ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
|
||||
ins.append(f"mov {rtor[out.nm]}, x15")
|
||||
|
|
@ -122,24 +110,22 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
ins.append(f"{alu[arg]} {rtor[vin[0].nm]}, s0")
|
||||
ins.append(f"fcsel {rtor[out.nm]},{rtor[vin[2].nm]}, {rtor[vin[1].nm]}, eq")
|
||||
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
|
||||
#ins.append(f"sub sp, sp, #{len(rtor)*16}")
|
||||
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
|
||||
ins.append(f"sub sp, sp, #{len(save_regs)*16}")
|
||||
for i,k in enumerate(save_regs,1):
|
||||
ins.append(f"mov x15, #{(16*i)}")
|
||||
ins.append(f"str {rtor[k]}, [sp, x15]")
|
||||
ins.append("stp x29, x30, [sp, #0]!")
|
||||
ins.append("mov x29, sp")
|
||||
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
|
||||
for i,k in enumerate([k for k in rtor.keys() if k != out.nm and k not in mem_vars]):
|
||||
var_size += 16
|
||||
ins.append(f"str {rtor[k]}, [sp, #{(var_size)}]")
|
||||
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
|
||||
ins.append(alu[arg])
|
||||
ins.append(f"fmov {rtor[out.nm]}, s0")
|
||||
ins.append("mov sp, x29")
|
||||
ins.append("ldp x29, x30, [sp], #0")
|
||||
var_size_local = var_size
|
||||
for i,k in enumerate([k for k in reversed(rtor.keys()) if k != out.nm and k not in mem_vars]):
|
||||
if k != out.nm and k not in mem_vars:
|
||||
ins.append(f"ldr {rtor[k]}, [sp, #{(var_size_local)}]")
|
||||
var_size_local += -16
|
||||
#ins.append(f"add sp, sp, #{(len(rtor) - len(mem_vars))*16}")
|
||||
for i,k in enumerate(save_regs,1):
|
||||
ins.append(f"mov x15, #{(16*i)}")
|
||||
ins.append(f"ldr {rtor[k]}, [sp, x15]")
|
||||
ins.append(f"add sp, sp, #{len(save_regs)*16}")
|
||||
elif arg in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
|
||||
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[0].nm]}, {'x15' if vin[1].__class__ is int else rtor[vin[1].nm]}" if reg == 'x' else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
|
||||
elif arg == BinaryOps.MOD:
|
||||
|
|
@ -153,13 +139,19 @@ class ARM64Codegen(AssemblyCodegen):
|
|||
mov_imm(arg, rtor[out.nm])
|
||||
else:
|
||||
load_var(vin)
|
||||
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
|
||||
mov_imm(arg[0], "x15")
|
||||
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
||||
ins.append(f"ldr {rtor[out.nm]}, [x15]")
|
||||
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8) else ''} {reg_in}, [x15]")
|
||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
||||
store_var(out)
|
||||
elif uop == UOps.STORE:
|
||||
load_var(vin)
|
||||
ins.append(f"str {rtor[vin[1].nm]}, [{rtor[vin[0].nm]}, #{arg[0]}]")
|
||||
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2"}
|
||||
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
|
||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
|
||||
ins.append(f"mov x15, #{arg[0]}")
|
||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
|
||||
if prev_uop == UOps.LOAD:
|
||||
|
|
|
|||
|
|
@ -190,7 +190,6 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
helper_test_op([], lambda: torch.floor(b), lambda: Tensor.floor(a), forward_only=True)
|
||||
@unittest.skipIf(getenv("ARM64") >0, "working on it")
|
||||
def test_ceil(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x: x.ceil(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
|
|
@ -315,7 +314,6 @@ class TestOps(unittest.TestCase):
|
|||
def test_leakyrelu(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
|
||||
@unittest.skipIf(getenv("ARM64"), "fix later")
|
||||
def test_celu(self):
|
||||
for val in range(1, 5):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
|
||||
|
|
@ -366,7 +364,6 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
|
||||
helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
|
||||
|
||||
@unittest.skipIf(getenv("ARM64"), "fix later")
|
||||
def test_mish(self):
|
||||
def _mish_pytorch(x):
|
||||
return x*torch.tanh(torch.nn.functional.softplus(x))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue