mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
time with real global buffers in search (#4621)
* filter fake buffers in search
* test that
* update test
This commit is contained in:
parent
e5d4e6a8aa
commit
c86adabe15
2 changed files with 20 additions and 3 deletions
|
|
@ -3,12 +3,15 @@ import unittest
|
|||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions
|
||||
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.ops import LoadOps, BufferOps
|
||||
from tinygrad.ops import LazyOp, LoadOps, BufferOps, ReduceOps, BinaryOps, MemBuffer, ConstBuffer
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.engine.realize import capturing
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
class TestTimeLinearizer(unittest.TestCase):
|
||||
@unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
|
||||
|
|
@ -66,5 +69,18 @@ class TestBEAM(unittest.TestCase):
|
|||
if Opt(OptOps.GROUPTOP, 0, 0) in actions:
|
||||
assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP"
|
||||
|
||||
# Skipped when the default device is "NV"; the skip message says the test tries to open CUDA (see #4607).
@unittest.skipIf(Device.DEFAULT in {"NV"}, "Tries to open CUDA. #4607")
|
||||
def test_filter_global_buffer(self):
# Regression test: build a Linearizer from a fixed AST, run beam_search on it, then time the
# resulting linearizer with time_linearizer. Both the search result and the measured time must be truthy.
|
||||
# taken from https://github.com/tinygrad/tinygrad/issues/4612
# The AST stores to MemBuffer idx=0 the MAX-reduction of (sum of six masked loads, idx=1..6) * a float constant.
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, 
mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4285714285714286, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
lin = Linearizer(ast)
|
||||
|
||||
# The same bufs list is passed to both the search and the final timing call below.
bufs = bufs_from_lin(lin)
|
||||
# Third argument is 3 (presumably the beam width -- TODO confirm against beam_search's signature).
best_lin = beam_search(lin, bufs, 3)
|
||||
assert best_lin
|
||||
# need disable_cache to trigger the original failure
# (presumably a cached timing would be returned without running the kernel -- TODO confirm).
|
||||
tm = time_linearizer(best_lin, bufs, allow_test_size=False, cnt=2, disable_cache=True)
|
||||
assert tm
|
||||
|
||||
# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -42,10 +42,11 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
|
|||
try: car = CompiledRunner(p, precompiled=lib)
|
||||
except AssertionError: return [math.inf] * cnt
|
||||
tms = []
|
||||
input_bufs = [rawbufs[i] for i,_ in car.p.globals]
|
||||
for _ in range(cnt):
|
||||
if clear_l2:
|
||||
with Context(DEBUG=0, BEAM=0, CACHECOLLECTING=0): Tensor.ones(1024,1024).contiguous().realize()
|
||||
tms.append(cast(float, car(rawbufs, var_vals, wait=True))*factor)
|
||||
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
|
||||
if early_stop is not None and early_stop < tms[-1]: break
|
||||
return tms
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue