mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
9 commits
master
...
fixup_fix_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
29262a7543 | ||
|
|
dcc6ddf0eb | ||
|
|
d1d935242b | ||
|
|
6b330f302d | ||
|
|
a27019383d | ||
|
|
8b285e193a | ||
|
|
f0c9b11b9e | ||
|
|
1902a85ac1 | ||
|
|
b2fc111e3f |
5 changed files with 47 additions and 42 deletions
46
.github/workflows/test.yml
vendored
46
.github/workflows/test.yml
vendored
|
|
@ -870,29 +870,29 @@ jobs:
|
||||||
- name: Test ONNX Runner (WEBGPU)
|
- name: Test ONNX Runner (WEBGPU)
|
||||||
run: WEBGPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_runner.py
|
run: WEBGPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_runner.py
|
||||||
|
|
||||||
osxremote:
|
#osxremote:
|
||||||
name: MacOS (remote metal)
|
# name: MacOS (remote metal)
|
||||||
runs-on: macos-15
|
# runs-on: macos-15
|
||||||
timeout-minutes: 10
|
# timeout-minutes: 10
|
||||||
env:
|
# env:
|
||||||
REMOTE: 1
|
# REMOTE: 1
|
||||||
REMOTEDEV: METAL
|
# REMOTEDEV: METAL
|
||||||
steps:
|
# steps:
|
||||||
- name: Checkout Code
|
# - name: Checkout Code
|
||||||
uses: actions/checkout@v4
|
# uses: actions/checkout@v4
|
||||||
- name: Setup Environment
|
# - name: Setup Environment
|
||||||
uses: ./.github/actions/setup-tinygrad
|
# uses: ./.github/actions/setup-tinygrad
|
||||||
with:
|
# with:
|
||||||
key: macos-remote
|
# key: macos-remote
|
||||||
deps: testing_minimal
|
# deps: testing_minimal
|
||||||
- name: Check Device.DEFAULT and print some source
|
# - name: Check Device.DEFAULT and print some source
|
||||||
run: |
|
# run: |
|
||||||
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
|
# python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
|
||||||
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
|
# python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
|
||||||
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
# DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
||||||
- name: Run REMOTE=1 Test
|
# - name: Run REMOTE=1 Test
|
||||||
run: |
|
# run: |
|
||||||
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
|
# python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
|
||||||
|
|
||||||
amdremote:
|
amdremote:
|
||||||
name: Linux (remote)
|
name: Linux (remote)
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,7 @@ view_right = merge_views+PatternMatcher([
|
||||||
# apply view after reduceops
|
# apply view after reduceops
|
||||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||||
# apply view after elementwise ops
|
# apply view after elementwise ops
|
||||||
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
|
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS, Ops.LOAD, Ops.STORE}, name="root"), elementwise_view_right),
|
||||||
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
|
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
|
||||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
|
|
||||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
|
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL}
|
||||||
|
|
||||||
# **** Grouper decides which of the UOps realize
|
# **** Grouper decides which of the UOps realize
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -150,18 +150,11 @@ create_kernels = PatternMatcher([
|
||||||
|
|
||||||
# **** fix kernel AST
|
# **** fix kernel AST
|
||||||
|
|
||||||
add_buffer_ops = PatternMatcher([
|
early_buffer_ops = PatternMatcher([
|
||||||
# LOAD
|
# LOAD
|
||||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
|
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x), tag=1)),
|
||||||
# STORE (except for meta ops)
|
# no SINK for meta ops
|
||||||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
|
|
||||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
|
||||||
# passthrough ASSIGN
|
|
||||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
|
||||||
# VALID
|
|
||||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
|
||||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
def check_load_st(glbl:UOp, view:UOp):
|
def check_load_st(glbl:UOp, view:UOp):
|
||||||
|
|
@ -175,6 +168,16 @@ def check_load_st(glbl:UOp, view:UOp):
|
||||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||||
|
|
||||||
fix_kernel_ops = PatternMatcher([
|
fix_kernel_ops = PatternMatcher([
|
||||||
|
# add the LOAD
|
||||||
|
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
|
||||||
|
# STORE (except for meta ops)
|
||||||
|
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
|
||||||
|
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||||
|
# passthrough ASSIGN
|
||||||
|
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||||
|
# VALID
|
||||||
|
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||||
|
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||||
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
||||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||||
|
|
@ -193,10 +196,6 @@ replace_globals = PatternMatcher([
|
||||||
|
|
||||||
def fix_kernel_ast(k:UOp) -> UOp|None:
|
def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||||
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
|
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
|
||||||
# replace global memory ops with the BUFFER they write to
|
|
||||||
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
|
||||||
# push views to edges
|
|
||||||
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
|
|
||||||
# replace buffer with define_global + add load/store last
|
# replace buffer with define_global + add load/store last
|
||||||
bufs = []
|
bufs = []
|
||||||
for s in k.src:
|
for s in k.src:
|
||||||
|
|
@ -204,9 +203,15 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||||
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
|
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
|
||||||
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
||||||
bufs.append(s)
|
bufs.append(s)
|
||||||
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
# replace global memory ops with the BUFFER they write to
|
||||||
|
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
||||||
|
ast = graph_rewrite(ast, early_buffer_ops, bufs, bottom_up=True, name="replace buffer early")
|
||||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
||||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||||
|
# TODO: move these to codegen
|
||||||
|
ast = graph_rewrite(ast, view_left, name="Main View Left")
|
||||||
|
ast = graph_rewrite(ast, view_right, name="Main View Right")
|
||||||
|
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="replace buffer")
|
||||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||||
|
|
||||||
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
|
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
|
||||||
|
|
|
||||||
|
|
@ -154,8 +154,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
sz = cast(PtrDType, self.dtype).size
|
sz = cast(PtrDType, self.dtype).size
|
||||||
return ShapeTracker.from_shape((sz,)) if sz > 0 else None
|
return ShapeTracker.from_shape((sz,)) if sz > 0 else None
|
||||||
|
|
||||||
# hack for PTX, CASTing the ptr loses the shape
|
# hack for PTX, CASTing the ptr loses the shape. even worse hack with tag
|
||||||
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
|
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL and self.src[0].tag is None: return None
|
||||||
|
|
||||||
# otherwise we get the shape from sources
|
# otherwise we get the shape from sources
|
||||||
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue