Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
48dd1d6543 no long 2026-01-21 21:53:23 +09:00
George Hotz
a06023cd36 viz slowness 2026-01-21 21:50:09 +09:00
George Hotz
ac232bceb5 remove the device when we render 2026-01-21 21:41:29 +09:00
George Hotz
f19fbadce4 regression test 2026-01-21 21:30:04 +09:00
George Hotz
8eb762d6fa add device to local, fix PCONTIG=2 2026-01-21 19:11:49 +09:00
4 changed files with 18 additions and 3 deletions

View file

@ -76,6 +76,18 @@ class TestRangeifyEdgeCase(unittest.TestCase):
res = Tensor.cat(a, c, dim=0)
self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
def test_pcontig_multi_gather(self):
# regression test: local bufferize must have device set for const_like to work
with Context(PCONTIG=2):
# NOTE: with uint type, this will become a long and fail on WEBGPU
forest = Tensor(list(range(8)), dtype='int')
idx = Tensor([0, 0], dtype='int')
node_val = forest.gather(0, idx)
idx2 = idx * 2 + 1
node_val2 = forest.gather(0, idx2)
result = (node_val + node_val2).numpy()
self.assertEqual(result.tolist(), [1, 1])
if getenv("BIG") > 2:
# llama 8B (8192)
BS, HEADS, SEQLEN, EMB = 4, 32, 8192, 128

View file

@ -74,7 +74,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
BufferizeOpts(None, AddrSpace.LOCAL, removable=removable)
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
new_srcs.append(new_src)

View file

@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
@ -429,6 +429,9 @@ to_define_global = PatternMatcher([
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
# remove device from local BUFFERIZE
(UPat(Ops.BUFFERIZE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
# HACK in case any CONSTs were replaced
# this is only needed if you are using symbolic
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None),

View file

@ -120,7 +120,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u._shape is not None:
label += f"\n{shape_to_str(u.shape)}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
label += f"\n{u.render()}"
if len(u.toposort()) < 30: label += f"\n{u.render()}"
ranges: list[UOp] = []
for us in u.src[1:]: ranges += [s for s in us.toposort() if s.op in {Ops.RANGE, Ops.SPECIAL}]
if ranges: label += "\n"+' '.join([f"{s.render()}={s.vmax+1}" for s in ranges])