mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove schedule() from tests batch 2 (#15923)
* remove schedule() from tests batch 2 * batch 4
This commit is contained in:
parent
1fdcb13bfb
commit
3c8a2db870
10 changed files with 69 additions and 62 deletions
|
|
@ -24,7 +24,7 @@ def get_ast(device:str, num_inputs:int) -> UOp:
|
|||
fst = [Tensor.randn(BUF_SIZE, dtype=dtypes.int).realize() for _ in range(num_inputs)]
|
||||
s = fst[0]
|
||||
for i in range(1, num_inputs): s = s.bitwise_xor(fst[i])
|
||||
cached_asts[(device, num_inputs)] = s.schedule()[-1].ast
|
||||
cached_asts[(device, num_inputs)] = s.schedule_linear().src[-1].src[0]
|
||||
return cached_asts[(device, num_inputs)]
|
||||
|
||||
def make_buffer(device, size=BUF_SIZE, fill=False):
|
||||
|
|
|
|||
|
|
@ -25,9 +25,9 @@ class TestLinearizer(unittest.TestCase):
|
|||
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
|
||||
sched = c.schedule()
|
||||
for si in sched: si.run()
|
||||
rawbufs = sched[-1].bufs
|
||||
linear = c.schedule_linear()
|
||||
run_linear(linear)
|
||||
rawbufs = [s.buffer for s in linear.src[-1].src[1:] if s.op is not Ops.BIND]
|
||||
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized}
|
||||
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
|
||||
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
|
@ -134,7 +134,8 @@ class TestLinearizer(unittest.TestCase):
|
|||
# these are of size 3 to avoid float4 coalesce
|
||||
r = a[:-1] + a[1:]
|
||||
|
||||
uops = get_program(replace_opts(r.schedule()[-1].ast, [Opt(op=OptOps.UPCAST, axis=0, arg=0)]), renderer=Device[Device.DEFAULT].renderer).uops
|
||||
uops = get_program(replace_opts(r.schedule_linear().src[-1].src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=0)]),
|
||||
renderer=Device[Device.DEFAULT].renderer).uops
|
||||
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
|
||||
assert num_loads <= 4, "more load uops than needed"
|
||||
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
|
||||
|
|
@ -146,7 +147,8 @@ class TestLinearizer(unittest.TestCase):
|
|||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = a.expand([2]) + b.expand([2])
|
||||
|
||||
uops = get_program(replace_opts(r.schedule()[-1].ast, [Opt(op=OptOps.UPCAST, axis=0, arg=0)]), renderer=Device[Device.DEFAULT].renderer).uops
|
||||
uops = get_program(replace_opts(r.schedule_linear().src[-1].src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=0)]),
|
||||
renderer=Device[Device.DEFAULT].renderer).uops
|
||||
num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
|
||||
assert num_ops <= 1, "more alu uops than needed"
|
||||
|
||||
|
|
@ -155,7 +157,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
|
||||
r = Tensor.conv2d(x,w,padding=1).relu()
|
||||
|
||||
uops = get_program(replace_opts(r.schedule()[-1].ast, [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]),
|
||||
uops = get_program(replace_opts(r.schedule_linear().src[-1].src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]),
|
||||
renderer=Device[Device.DEFAULT].renderer).uops
|
||||
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
|
||||
stores = [u for u in uops if u.op is Ops.STORE]
|
||||
|
|
@ -168,7 +170,8 @@ class TestLinearizer(unittest.TestCase):
|
|||
@unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
|
||||
def test_upcast_with_locals_cpu(self):
|
||||
out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
|
||||
prg = get_program(replace_opts(out.schedule()[-1].ast, [Opt(OptOps.LOCAL, axis=0, arg=4)]), renderer=Device[Device.DEFAULT].renderer).uops
|
||||
prg = get_program(replace_opts(out.schedule_linear().src[-1].src[0], [Opt(OptOps.LOCAL, axis=0, arg=4)]),
|
||||
renderer=Device[Device.DEFAULT].renderer).uops
|
||||
self.assertEqual(len(prg.src.split("for")), 5)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
|
|
@ -179,7 +182,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
|
||||
r = (x@y).relu()
|
||||
opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
|
||||
program = get_program(replace_opts(r.schedule()[-1].ast, opts_to_apply), renderer=Device[Device.DEFAULT].renderer)
|
||||
program = get_program(replace_opts(r.schedule_linear().src[-1].src[0], opts_to_apply), renderer=Device[Device.DEFAULT].renderer)
|
||||
|
||||
stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
|
||||
|
||||
|
|
@ -193,7 +196,8 @@ class TestLinearizer(unittest.TestCase):
|
|||
def test_zero_fold(self):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = Tensor.stack(a, b)
|
||||
uops = get_program(replace_opts(r.schedule()[-1].ast, [Opt(op=OptOps.UPCAST, axis=0, arg=0)]), renderer=Device[Device.DEFAULT].renderer).uops
|
||||
uops = get_program(replace_opts(r.schedule_linear().src[-1].src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=0)]),
|
||||
renderer=Device[Device.DEFAULT].renderer).uops
|
||||
num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
|
||||
assert num_ops == 0, "more alu uops than needed"
|
||||
|
||||
|
|
@ -202,14 +206,14 @@ class TestLinearizer(unittest.TestCase):
|
|||
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
|
||||
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
|
||||
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
|
||||
realized_ast = a.schedule()[-1].ast
|
||||
realized_ast = a.schedule_linear().src[-1].src[0]
|
||||
program = get_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
|
||||
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
|
||||
assert local[0].dtype.base == acc_dtype
|
||||
|
||||
def test_arg_acc_dtype(self):
|
||||
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
|
||||
realized_ast = c.schedule()[-1].ast
|
||||
realized_ast = c.schedule_linear().src[-1].src[0]
|
||||
program = get_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
|
||||
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
|
||||
self.assertEqual(local[0].dtype.base, expected_dtype)
|
||||
|
|
@ -267,10 +271,10 @@ class TestLinearizer(unittest.TestCase):
|
|||
|
||||
def test_sum_collapse(self):
|
||||
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
|
||||
sched = [si for si in t.schedule() if si.ast.op is Ops.SINK]
|
||||
sched = [si for si in t.schedule_linear().src if si.src[0].op is Ops.SINK]
|
||||
# sum_collapse is a full collapse now
|
||||
assert len(sched) == 1
|
||||
assert not any(u.op is Ops.REDUCE_AXIS for u in sched[0].ast.toposort()), "found reduce in sum collapse"
|
||||
assert not any(u.op is Ops.REDUCE_AXIS for u in sched[0].src[0].toposort()), "found reduce in sum collapse"
|
||||
#lin = Kernel(sched[0].ast)
|
||||
#assert not any(u.op is Ops.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import get_single_element
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.schedule import ExecItem
|
||||
from tinygrad.engine.realize import run_linear
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from test.helpers import replace_opts
|
||||
|
||||
class TestOptGemm(unittest.TestCase):
|
||||
|
|
@ -19,10 +19,10 @@ class TestOptGemm(unittest.TestCase):
|
|||
def _test_gemm_unrolled_permute_l(self, opts=[]):
|
||||
t = self.a.T @ self.b.T
|
||||
# TODO: this should be a generic test helper
|
||||
si = get_single_element(t.schedule())
|
||||
run = CompiledRunner(get_program(replace_opts(si.ast, opts), renderer=Device[Device.DEFAULT].renderer))
|
||||
ExecItem(si.ast, list(si.bufs), prg=run).run()
|
||||
test = si.bufs[0].numpy().reshape(self.res.shape)
|
||||
call = get_single_element(t.schedule_linear().src)
|
||||
new_call = call.replace(src=(replace_opts(call.src[0], opts), *call.src[1:]))
|
||||
run_linear(UOp(Ops.LINEAR, src=(new_call,)))
|
||||
test = call.src[1].buffer.numpy().reshape(self.res.shape)
|
||||
np.testing.assert_allclose(self.res, test, atol=1e-4)
|
||||
|
||||
def test_gemm_unrolled_permute_l_44(self):
|
||||
|
|
|
|||
|
|
@ -44,9 +44,9 @@ class TestProfiler(unittest.TestCase):
|
|||
|
||||
TestProfiler.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
|
||||
TestProfiler.b = self.a + 1
|
||||
si = self.b.schedule()[-1]
|
||||
si = self.b.schedule_linear().src[-1]
|
||||
|
||||
TestProfiler.runner = get_runner(TestProfiler.d0.device, si.ast)
|
||||
TestProfiler.runner = get_runner(TestProfiler.d0.device, si.src[0])
|
||||
TestProfiler.b.uop.buffer.allocate()
|
||||
|
||||
def test_profile_kernel_run(self):
|
||||
|
|
|
|||
|
|
@ -69,10 +69,9 @@ class TestCStyleFailures(unittest.TestCase):
|
|||
dtype = "bool" if op in (Ops.OR, Ops.XOR, Ops.AND) else None
|
||||
ret = Tensor.empty(1, dtype=dtype)
|
||||
for _ in range(5): ret = python_alu[op](ret, Tensor.empty(1, dtype=dtype))
|
||||
schedule = ret.schedule()
|
||||
assert len(schedule) == 1
|
||||
schedule[0].lower()
|
||||
src = schedule[0].prg.p.src
|
||||
linear = ret.schedule_linear()
|
||||
assert len(linear.src) == 1
|
||||
src = get_program(linear.src[0].src[0], Device[Device.DEFAULT].renderer).src
|
||||
self.assertEqual("("*5 not in src, should_strip_paren)
|
||||
|
||||
def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import unittest
|
|||
import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, Context, Device
|
||||
from tinygrad.dtype import DTypeLike, dtypes
|
||||
from tinygrad.engine.realize import run_linear
|
||||
from tinygrad.helpers import DEBUG, get_single_element
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
|
|
@ -26,7 +27,10 @@ def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Te
|
|||
out = e.div(ss).reshape(x_in.shape)
|
||||
return out
|
||||
|
||||
def run_one_schedule_item(out): get_single_element(out.schedule()).run()
|
||||
def run_one_schedule_item(out):
|
||||
linear = out.schedule_linear()
|
||||
get_single_element(linear.src)
|
||||
run_linear(linear)
|
||||
|
||||
class TestFuse(unittest.TestCase):
|
||||
def _test_fuse(self, fxn, *args, atol=1e-6, allow_multiple=False, **kwargs):
|
||||
|
|
@ -100,8 +104,8 @@ class TestFuse(unittest.TestCase):
|
|||
k = (x @ wk).contiguous()
|
||||
v = (x @ wv).contiguous()
|
||||
attn = q.scaled_dot_product_attention(k, v)
|
||||
s = attn.schedule()
|
||||
self.assertEqual(len(s), 4) # 3 matmul and 1 attention
|
||||
s = attn.schedule_linear()
|
||||
self.assertEqual(len(s.src), 4) # 3 matmul and 1 attention
|
||||
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_flash_attention(self):
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ class TestAssembly(unittest.TestCase):
|
|||
a = Tensor.empty(1024)
|
||||
b = Tensor.empty(1024)
|
||||
c = (a*b).sum()
|
||||
ast = c.schedule()[-1].ast
|
||||
ast = c.schedule_linear().src[-1].src[0]
|
||||
opts_to_apply = [Opt(OptOps.UNROLL, 0, 4)]
|
||||
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
|
||||
program = get_program(ast, Device[Device.DEFAULT].renderer)
|
||||
|
|
|
|||
|
|
@ -17,41 +17,41 @@ class TestTensorMetadata(unittest.TestCase):
|
|||
def test_exclude_noop_metadata(self):
|
||||
a = Tensor.rand(4, 4)*1
|
||||
self.assertEqual(a.uop.metadata[0].name, "__mul__")
|
||||
k = a.schedule()[-1]
|
||||
self.assertEqual([m.name for m in k.metadata], ["rand"])
|
||||
k = a.schedule_linear().src[-1]
|
||||
self.assertEqual([m.name for m in k.arg.metadata], ["rand"])
|
||||
|
||||
@unittest.skip("metadata not reaching kernel schedule")
|
||||
def test_exclude_const_metadata(self):
|
||||
a = Tensor.arange(4)
|
||||
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
|
||||
sched = Tensor.schedule(a, b)
|
||||
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
|
||||
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
|
||||
sched = a.schedule_linear(b)
|
||||
self.assertEqual([m.name for m in sched.src[0].arg.metadata], ["arange"])
|
||||
self.assertEqual([m.name for m in sched.src[1].arg.metadata], ["contiguous"])
|
||||
|
||||
def test_matmul(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
out = x.matmul(W)
|
||||
self.assertEqual(out.uop.metadata[0].name, "matmul")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "matmul")
|
||||
si = out.schedule_linear().src[-1]
|
||||
self.assertEqual(len(si.arg.metadata), 1)
|
||||
self.assertEqual(si.arg.metadata[0].name, "matmul")
|
||||
|
||||
def test_relu(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu()
|
||||
self.assertEqual(out.uop.metadata[0].name, "relu")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "relu")
|
||||
si = out.schedule_linear().src[-1]
|
||||
self.assertEqual(len(si.arg.metadata), 1)
|
||||
self.assertEqual(si.arg.metadata[0].name, "relu")
|
||||
|
||||
@unittest.skip("assign metadata no longer captured")
|
||||
def test_assign(self):
|
||||
x = Tensor.empty(10, 10).realize()
|
||||
x.assign(Tensor.ones(10, 10).contiguous())
|
||||
si = x.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "assign")
|
||||
si = x.schedule_linear().src[-1]
|
||||
self.assertEqual(len(si.arg.metadata), 1)
|
||||
self.assertEqual(si.arg.metadata[0].name, "assign")
|
||||
|
||||
def test_complex(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
|
|
@ -60,9 +60,9 @@ class TestTensorMetadata(unittest.TestCase):
|
|||
self.assertEqual(out.uop.metadata[0].name, "__mul__")
|
||||
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
|
||||
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
si = out.schedule_linear().src[-1]
|
||||
self.assertEqual(len(si.arg.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.arg.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
|
||||
@unittest.skip("flaky")
|
||||
def test_complex_backward(self):
|
||||
|
|
@ -75,10 +75,10 @@ class TestTensorMetadata(unittest.TestCase):
|
|||
#self.assertTrue(x.grad.uop.metadata[0].backward) # TODO: backward flag is False
|
||||
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
|
||||
#self.assertTrue(y.grad.uop.metadata[0].backward) # TODO: backward flag is False
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
#self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||
si = out.schedule_linear(x.grad, y.grad).src[-1]
|
||||
#self.assertEqual(len(si.arg.metadata), 3, f"failed with {si.arg.metadata}")
|
||||
# skip numpy, this is schedule cache
|
||||
self.assertSetEqual(set(m.name for m in si.metadata if m.name != "numpy"), {"sigmoid", "relu"})
|
||||
self.assertSetEqual(set(m.name for m in si.arg.metadata if m.name != "numpy"), {"sigmoid", "relu"})
|
||||
#bw = [m for m in si.metadata if m.backward]
|
||||
#self.assertEqual(len(bw), 1)
|
||||
#self.assertEqual(bw[0].name, "sigmoid")
|
||||
|
|
@ -90,8 +90,8 @@ class TestTensorMetadata(unittest.TestCase):
|
|||
out = (x.relu() * y.sigmoid()).sum()
|
||||
self.assertIsNone(out.uop.metadata)
|
||||
self.assertIsNone(out.uop.src[0].metadata)
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(si.metadata, ())
|
||||
si = out.schedule_linear().src[-1]
|
||||
self.assertEqual(si.arg.metadata, ())
|
||||
|
||||
def _has_metadata(self, h, name):
|
||||
linears = []
|
||||
|
|
|
|||
|
|
@ -320,13 +320,13 @@ class TestVizGC(unittest.TestCase):
|
|||
# VIZ integrates with other parts of tinygrad
|
||||
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.engine.realize import get_program, get_runner
|
||||
|
||||
class TestVizIntegration(unittest.TestCase):
|
||||
# codegen supports rendering of code blocks
|
||||
def test_codegen_tracing(self):
|
||||
with save_viz() as viz:
|
||||
ast = Tensor.schedule(Tensor.empty(4)+Tensor.empty(4))[0].ast
|
||||
ast = (Tensor.empty(4)+Tensor.empty(4)).schedule_linear().src[0].src[0]
|
||||
prg = get_program(ast, Device[Device.DEFAULT].renderer)
|
||||
lst = viz.list_items()
|
||||
self.assertEqual(len(lst), 3)
|
||||
|
|
@ -339,8 +339,8 @@ class TestVizIntegration(unittest.TestCase):
|
|||
with save_viz() as viz:
|
||||
c1 = Tensor.empty(4).add(1)
|
||||
c2 = Tensor.empty(8).add(1)
|
||||
sched = Tensor.schedule(c1, c2)
|
||||
prgs = [get_program(si.ast, Device[Device.DEFAULT].renderer).name for si in sched]
|
||||
sched = c1.schedule_linear(c2)
|
||||
prgs = [get_program(si.src[0], Device[Device.DEFAULT].renderer).name for si in sched.src]
|
||||
lst = viz.list_items()
|
||||
sched_idx = next(i for i,l in enumerate(lst) if l["name"].startswith("Schedule"))
|
||||
viz_kernel = next(i for i,s in enumerate(lst[sched_idx]["steps"]) if s["name"] == "View Kernel Graph")
|
||||
|
|
@ -356,7 +356,7 @@ class TestVizIntegration(unittest.TestCase):
|
|||
a = Tensor.empty(1)
|
||||
b = Tensor.empty(1)
|
||||
metadata = (alu:=a+b).uop.metadata
|
||||
alu.schedule()
|
||||
alu.schedule_linear()
|
||||
graph = next(viz.get_details(0, 0))["graph"]
|
||||
self.assertEqual(len([n for n in graph.values() if repr(metadata) in n["label"]]), 1)
|
||||
|
||||
|
|
@ -724,7 +724,7 @@ class TestCfg(unittest.TestCase):
|
|||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="NULL"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
with Context(DEV=f"NULL::{self.arch}"):
|
||||
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
|
||||
prg = out.schedule()[-1].lower().prg.p
|
||||
prg = get_runner(out.device, out.schedule_linear().src[-1].src[0]).p
|
||||
return amdgpu_cfg(prg.lib, self.arch)
|
||||
|
||||
def test_simple(self):
|
||||
|
|
|
|||
|
|
@ -18,14 +18,14 @@ class TestWinograd(unittest.TestCase):
|
|||
def test_forward_kernels(self):
|
||||
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
||||
out = Tensor.conv2d(x,w)
|
||||
self.assertEqual(len(out.schedule()), 2)
|
||||
self.assertEqual(len(out.schedule_linear().src), 2)
|
||||
|
||||
def test_backward_kernels(self):
|
||||
x,w = Tensor.empty(1,4,9,9,requires_grad=True).realize(), Tensor.empty(4,4,3,3,requires_grad=True).realize()
|
||||
out = Tensor.conv2d(x,w, padding=1)
|
||||
out.mean().backward()
|
||||
backward_schedule = Tensor.schedule(x.grad, w.grad)
|
||||
self.assertEqual(len(backward_schedule), 4)
|
||||
backward_schedule = x.grad.schedule_linear(w.grad)
|
||||
self.assertEqual(len(backward_schedule.src), 4)
|
||||
|
||||
def test_counters(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue