Compare commits

...

9 commits

Author SHA1 Message Date
George Hotz
29262a7543 fine for ptx 2025-08-05 18:26:06 -07:00
George Hotz
dcc6ddf0eb that hack broke things 2025-08-05 18:24:58 -07:00
George Hotz
d1d935242b Revert "fix tests"
This reverts commit a27019383d.
2025-08-05 18:06:27 -07:00
George Hotz
6b330f302d remote metal was flaky 2025-08-05 17:50:14 -07:00
George Hotz
a27019383d fix tests 2025-08-05 17:45:35 -07:00
George Hotz
8b285e193a move those to fix_kernel_ops 2025-08-05 17:07:12 -07:00
George Hotz
f0c9b11b9e early meta ops 2025-08-05 17:00:49 -07:00
George Hotz
1902a85ac1 early load buffer 2025-08-05 16:56:10 -07:00
George Hotz
b2fc111e3f cleanup fix_kernel 2025-08-05 16:46:13 -07:00
5 changed files with 47 additions and 42 deletions

View file

@@ -870,29 +870,29 @@ jobs:
- name: Test ONNX Runner (WEBGPU)
run: WEBGPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_runner.py
osxremote:
name: MacOS (remote metal)
runs-on: macos-15
timeout-minutes: 10
env:
REMOTE: 1
REMOTEDEV: METAL
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: macos-remote
deps: testing_minimal
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run REMOTE=1 Test
run: |
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
#osxremote:
# name: MacOS (remote metal)
# runs-on: macos-15
# timeout-minutes: 10
# env:
# REMOTE: 1
# REMOTEDEV: METAL
# steps:
# - name: Checkout Code
# uses: actions/checkout@v4
# - name: Setup Environment
# uses: ./.github/actions/setup-tinygrad
# with:
# key: macos-remote
# deps: testing_minimal
# - name: Check Device.DEFAULT and print some source
# run: |
# python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
# python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
# DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
# - name: Run REMOTE=1 Test
# run: |
# python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
amdremote:
name: Linux (remote)

View file

@@ -95,7 +95,7 @@ view_right = merge_views+PatternMatcher([
# apply view after reduceops
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS, Ops.LOAD, Ops.STORE}, name="root"), elementwise_view_right),
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),

View file

@@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
from tinygrad.shape.shapetracker import ShapeTracker
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL}
# **** Grouper decides which of the UOps realize

View file

@@ -150,18 +150,11 @@ create_kernels = PatternMatcher([
# **** fix kernel AST
add_buffer_ops = PatternMatcher([
early_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
# STORE (except for meta ops)
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x), tag=1)),
# no SINK for meta ops
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
# passthrough ASSIGN
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
# VALID
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
])
def check_load_st(glbl:UOp, view:UOp):
@@ -175,6 +168,16 @@ def check_load_st(glbl:UOp, view:UOp):
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = PatternMatcher([
# add the LOAD
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
# STORE (except for meta ops)
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
# passthrough ASSIGN
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
# VALID
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
@@ -193,10 +196,6 @@ replace_globals = PatternMatcher([
def fix_kernel_ast(k:UOp) -> UOp|None:
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
# replace global memory ops with the BUFFER they write to
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
# push views to edges
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
# replace buffer with define_global + add load/store last
bufs = []
for s in k.src:
@@ -204,9 +203,15 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
bufs.append(s)
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
# replace global memory ops with the BUFFER they write to
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
ast = graph_rewrite(ast, early_buffer_ops, bufs, bottom_up=True, name="replace buffer early")
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
# TODO: move these to codegen
ast = graph_rewrite(ast, view_left, name="Main View Left")
ast = graph_rewrite(ast, view_right, name="Main View Right")
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="replace buffer")
return k.replace(arg=Kernel(ast, k.arg.metadata))
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])

View file

@@ -154,8 +154,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
sz = cast(PtrDType, self.dtype).size
return ShapeTracker.from_shape((sz,)) if sz > 0 else None
# hack for PTX, CASTing the ptr loses the shape
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
# hack for PTX, CASTing the ptr loses the shape. even worse hack with tag
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL and self.src[0].tag is None: return None
# otherwise we get the shape from sources
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None