Merge branch 'master' into shrink_in_render

This commit is contained in:
George Hotz 2026-05-28 19:18:31 -07:00 committed by GitHub
commit 7da2c151be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 260 additions and 263 deletions

View file

@ -585,8 +585,19 @@ jobs:
- name: Test quantize onnx
run: DEBUG=2 DEV=DSP python3 test/backend/test_quantize_onnx.py
testwebgpu:
name: Linux (WebGPU)
testlinux:
strategy:
fail-fast: false
matrix:
dev:
- 'CPU:CLANGJIT'
- 'CPU:LLVM'
- 'CPU:LVP'
- 'CPU:X86'
- 'CL'
- 'WEBGPU'
name: Linux (DEV=${{ matrix.dev }})
runs-on: *linux
timeout-minutes: 20
steps:
@ -595,17 +606,21 @@ jobs:
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: webgpu-minimal
key: linux-${{ matrix.dev }}
deps: testing_unit
python-version: '3.12'
webgpu: 'true'
- name: Check Device.DEFAULT (WEBGPU) and print some source
llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') || contains(matrix.dev, 'CLANGJIT') }}
mesa: ${{ contains(matrix.dev, 'LVP') && 'cpu' || 'false' }}
webgpu: ${{ matrix.dev == 'WEBGPU' }}
opencl: ${{ matrix.dev == 'CL' }}
- name: Set env
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANGJIT' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
DEV=WEBGPU python -c "from tinygrad import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
DEV=WEBGPU DEBUG=4 FORWARD_ONLY=1 python3 test/test_tiny.py TestTiny.test_plus
- name: Run selected webgpu tests
run: |
DEV=WEBGPU WEBGPU_BACKEND="WGPUBackendType_Vulkan" python3 -m pytest -n=auto test/backend --durations=20
python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device"
DEBUG=4 python test/test_tiny.py TestTiny.test_plus
- name: Run backend tests
run: python -m pytest -n=auto test/backend --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay
@ -757,39 +772,6 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
testcpuopencl:
strategy:
fail-fast: false
matrix:
backend: [llvm, cpu, opencl, lvp, x86]
name: Linux (${{ matrix.backend }})
runs-on: *linux
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: ${{ matrix.backend }}-minimal
deps: testing_unit
opencl: ${{ matrix.backend == 'opencl' && 'true' }}
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'cpu' || matrix.backend == 'lvp' }}
mesa: ${{ matrix.backend == 'lvp' && 'cpu' }}
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'DEV=CL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' || matrix.backend == 'x86' && 'DEV=CPU:X86' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','CL'], Device.DEFAULT"
DEBUG=5 FORWARD_ONLY=1 python3 test/test_tiny.py TestTiny.test_plus
- name: Run pytest (${{ matrix.backend }})
run: python -m pytest -n=auto test/backend --durations=20
- name: Run TRANSCENDENTAL math
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/backend/test_ops.py::TestOps::test_sin test/backend/test_ops.py::TestOps::test_cos test/backend/test_ops.py::TestOps::test_tan test/backend/test_ops.py::TestOps::test_exp test/backend/test_ops.py::TestOps::test_log --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay
# ****** OSX Tests ******
testmetal:
@ -920,13 +902,17 @@ jobs:
# ****** Windows Tests ******
wintests:
testwindows:
strategy:
fail-fast: false
matrix:
backend: [llvm, cpu, webgpu, x86]
dev:
- 'CPU:CLANGJIT'
- 'CPU:LLVM'
- 'CPU:X86'
- 'WEBGPU'
name: Windows (${{ matrix.backend }})
name: Windows (DEV=${{ matrix.dev }})
runs-on: windows-latest
timeout-minutes: 15
steps:
@ -935,25 +921,20 @@ jobs:
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: windows-${{ matrix.backend }}-minimal
key: windows-${{ matrix.dev }}-minimal
deps: testing_unit
pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }}
pydeps: ${{ matrix.dev == 'WEBGPU' && 'dawn-python' || '' }}
- name: Set env
shell: bash
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'DEV=WEBGPU' || matrix.backend == 'x86' && 'DEV=CPU:X86' }}" >> $GITHUB_ENV
- name: Run unit tests
if: matrix.backend=='llvm'
# test_newton_schulz hits RecursionError
run: python -m pytest -n=auto test/unit/ --ignore=test/unit/test_disk_tensor.py --ignore=test/unit/test_tar.py --ignore=test/unit/test_linalg.py --durations=20
- name: Run NULL backend tests
if: matrix.backend=='llvm'
shell: bash
run: DEV=NULL python -m pytest -n=auto test/null/ --ignore=test/null/test_elf.py --durations=20
- name: Run pytest (${{ matrix.backend }})
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANGJIT' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
shell: bash
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU', 'X86':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
python -m pytest -n=auto test/test_tiny.py test/backend/test_ops.py --durations=20
python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device"
DEBUG=4 python test/test_tiny.py TestTiny.test_plus
- name: Run test_tiny
shell: bash
run: python -m pytest -n=auto test/test_tiny.py --durations=20
# ****** Compile-only Tests ******

View file

@ -1438,13 +1438,14 @@ def train_llama3():
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts] if hasattr(model, "_fp8_grad_amax") else []
fp8_inv_scales = list(model._fp8_inv_scale.values())
fp8_inv_scales = list(model._fp8_inv_scale.values()) + list(model._fp8_next_inv_scale.values())
from tinygrad.nn.state import get_state_dict
model_state = get_state_dict(model)
for wname in model._fp8_inv_scale:
w = model_state[wname]
w._inv_scale = model._fp8_inv_scale[wname]
w._next_inv_scale = model._fp8_next_inv_scale[wname]
if optim.master_params:
idx = next(j for j, p in enumerate(optim.params) if p is w)
master = optim.master_params[idx]

View file

@ -136,6 +136,7 @@ class FlatTransformer:
w_scales = [("wqkv", s_qkv), ("wo", s_o), ("w2", s_2)]
w_scales += [("w1", s_1), ("w3", s_3)] if SPLIT_W13 else [("w13", s_13)]
self._fp8_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
self._fp8_next_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
@ -227,7 +228,7 @@ class FlatTransformer:
self.w1.shard_(device, axis=1).realize()
self.w3.shard_(device, axis=1).realize()
else:
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize()
@ -241,6 +242,8 @@ class FlatTransformer:
amax_dict[name][i] = amax_dict[name][i].to(device).contiguous().is_param_(False)
for name in self._fp8_inv_scale:
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
for name in self._fp8_next_inv_scale:
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
def __call__(self, tokens:Tensor, save:bool=True):
h = self.tok_embeddings(tokens)

View file

@ -39,7 +39,8 @@ class GradAccClipAdamW(Optimizer):
for i, tt in enumerate(self.params): tt.assign(self._apply_update(tt, updates[i], self.master_params[i] if self.master_params else None))
# collect inv_scale tensors attached to fp8 params (set by _apply_update)
fp8_inv_scales = [tt._inv_scale for tt in self.params if hasattr(tt, '_inv_scale')]
to_realize = extra+self.params+self.buffers+(self.master_params or [])+fp8_inv_scales
fp8_next_inv_scales = [tt._next_inv_scale for tt in self.params if hasattr(tt, '_next_inv_scale')]
to_realize = extra+self.params+self.buffers+(self.master_params or [])+fp8_inv_scales+fp8_next_inv_scales
Tensor.realize(*to_realize)
return extra[-1]
@ -89,13 +90,14 @@ class GradAccClipAdamW(Optimizer):
if t.dtype in dtypes.fp8s:
from examples.mlperf.models.flat_llama import FP8_MAX
# delayed scaling: reuse previous step's inv_scale
t._inv_scale.assign(t._next_inv_scale)
scale = t._inv_scale.reciprocal().reshape(-1, *([1]*(new_w.ndim-1)))
scaled = (new_w * scale).clamp(-FP8_MAX, FP8_MAX)
ret = scaled.cast(t.dtype)
# update inv_scale for next step from quantized result
new_amax = (ret.float().abs().max(axis=tuple(range(1, ret.ndim))) * t._inv_scale).detach()
inv = ((new_amax + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype)
t._inv_scale.assign(inv.shard_like(t._inv_scale) if offloaded else inv)
new_inv = ((new_amax + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype)
t._next_inv_scale.assign(new_inv.shard_like(t._next_inv_scale) if offloaded else new_inv)
return ret.shard_like(t) if offloaded else ret
out = new_w.cast(t.dtype)
return out.shard_like(t) if offloaded else out

View file

@ -23,7 +23,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
def name_of(bu:UOp, is_out:bool) -> str:
nonlocal n
if bu.op is Ops.PARAM: key, name, size = ("in", bu.arg), f"input{bu.arg}", prod(bu.shape)*bu.dtype.itemsize
if bu.op is Ops.PARAM: key, name, size = ("in", bu.arg.slot), f"input{bu.arg.slot}", prod(bu.shape)*bu.dtype.itemsize
else:
b = bu.buffer
key, size = (id(b.base), b.offset, b.size, b.dtype), b.size*b.dtype.itemsize

View file

@ -167,7 +167,7 @@ class TestDSPcodePatterns(unittest.TestCase):
def test_global_atomic_add_f32_parsing(self):
"""Test GLOBAL_ATOMIC_ADD_F32 keeps memory values in float dtype."""
vmem = UOp(Ops.PARAM, dtypes.uint32.ptr(1024), arg=2)
vmem = UOp.param(2, dtypes.uint32.ptr(1024))
srcs = {
'ADDR': UOp.const(dtypes.uint64, 0),
'DATA': UOp.const(dtypes.uint32, 0x3f800000),
@ -198,7 +198,7 @@ class TestDSPcodePatterns(unittest.TestCase):
def test_mem_read_parsing(self):
"""Test MEM[addr].type read expression parsing."""
# Create a mock LDS buffer
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
lds = UOp.param(3, dtypes.uint32.ptr(16384))
addr = UOp.const(dtypes.uint32, 0)
vrs = {'_lds': lds, 'ADDR': addr, 'OFFSET': UOp.const(dtypes.uint32, 0)}
@ -233,7 +233,7 @@ class TestDSPcodePatterns(unittest.TestCase):
pcode = PCODE.get(DSOp.DS_LOAD_2ADDR_B32)
self.assertIsNotNone(pcode)
assert pcode is not None
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
lds = UOp.param(3, dtypes.uint32.ptr(16384))
srcs = {
'ADDR': UOp.const(dtypes.uint32, 0),
'OFFSET0': UOp.const(dtypes.uint32, 0),
@ -314,7 +314,7 @@ class TestConcatWidthParsing(unittest.TestCase):
self.assertEqual(parsed.simplify().arg, expected)
def test_permlane64_wave64_pcode_indices(self):
vgpr = UOp(Ops.PARAM, dtypes.uint32.ptr(256), arg=0)
vgpr = UOp.param(0, dtypes.uint32.ptr(256))
srcs = {
'SRC0': UOp.const(dtypes.uint32, 0),
'VDST': UOp.const(dtypes.uint32, 1),
@ -347,7 +347,7 @@ class TestAllPcode(unittest.TestCase):
def _make_srcs(self):
"""Create dummy source variables for pcode parsing."""
u32, u64 = lambda v=0: UOp.const(dtypes.uint32, v), lambda v=0: UOp.const(dtypes.uint64, v)
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
lds = UOp.param(3, dtypes.uint32.ptr(16384))
return {'laneId': u32(), 'laneID': u32(), 'S0': u32(), 'S1': u32(), 'S2': u32(), 'S3': u32(), 'SRC0': u32(),
'D0': u32(), 'D1': u32(), 'DST': u32(), 'VDST': u32(), 'SDST': u32(),
'VCC': u64(), 'VCCZ': u32(), 'EXEC': u64(), 'EXEC_LO': u32(), 'EXECZ': u32(), 'SCC': u32(),

View file

@ -57,7 +57,7 @@ class TestIselX86(unittest.TestCase):
# need to move src from gpr to xmm before broadcasting
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and n.src[0].arg is X86Ops.VMOVD)
# if we can fuse a load we can skip the move and access memory directly
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
n = self.isel_rewrite(load.broadcast(4))
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
@ -122,20 +122,20 @@ class TestIselX86(unittest.TestCase):
# complex address is [base + index*scale + displacement]
def test_complex_address(self):
a = UOp.variable("a", 0, 0, dtypes.int32)
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(a + 1, ptr=True).load()
load = UOp.param(0, dtypes.int32.ptr()).index(a + 1, ptr=True).load()
n = self.isel_rewrite(load)
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
self.assertTrue(n.src[2].op is Ops.CONST and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
def test_fold_load(self):
load1 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load2 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load()
load1 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load2 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 1), ptr=True).load()
n = self.isel_rewrite(load1 + load2)
self.assertTrue(len(n.src) == 4)
# don't fold when used multiple times
def test_dont_fold_load(self):
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
# used by multiple users
n = self.isel_rewrite(load + 1 + load)
self.assertTrue(len(n.src) == 2)
@ -144,4 +144,4 @@ class TestIselX86(unittest.TestCase):
self.assertTrue(len(n.src) == 2)
if __name__ == "__main__":
unittest.main()
unittest.main()

View file

@ -11,16 +11,16 @@ from tinygrad.codegen import to_program
class TestLinearizerFailure(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
def test_failure_beam_mnist(self):
c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(4014080), arg=0, src=())
c0 = UOp.param(0, dtypes.uchar.ptr(4014080))
c1 = UOp.range(UOp.const(dtypes.weakint, 512), 0, AxisType.GLOBAL)
c2 = UOp.range(UOp.const(dtypes.weakint, 784), 1, AxisType.GLOBAL)
c3 = UOp.range(UOp.const(dtypes.weakint, 10), 3, AxisType.GLOBAL)
c4 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
c4 = UOp.param(1, dtypes.int.ptr(512))
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True)))
c6 = UOp.range(UOp.const(dtypes.weakint, 6000), 1004, AxisType.REDUCE)
c7 = UOp.range(UOp.const(dtypes.weakint, 3750), 2006, AxisType.REDUCE)
c8 = UOp.range(UOp.const(dtypes.weakint, 16), 2007, AxisType.GROUP_REDUCE)
c9 = UOp(Ops.PARAM, dtypes.uchar.ptr(47040000), arg=2, src=())
c9 = UOp.param(2, dtypes.uchar.ptr(47040000))
c10 = c9.index((((c3*UOp.const(dtypes.weakint, 4704000))+c2)+(c6*UOp.const(dtypes.weakint, 784))).valid(UOp.const(dtypes.bool, True)))
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.weakint, 6000))+c6)+((c7*UOp.const(dtypes.weakint, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.weakint, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
c12 = c0.index((((c1*UOp.const(dtypes.weakint, 7840))+(c2*UOp.const(dtypes.weakint, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)

View file

@ -22,8 +22,8 @@ def _test_uop_result(inputs:list[Tensor], sink:UOp, local_size=None):
def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
dtype = alu_src_uops[0].dtype
a = UOp(Ops.PARAM, dtype.ptr(), (), 0)
b = UOp(Ops.PARAM, dtype.ptr(), (), 1)
a = UOp.param(0, dtype.ptr())
b = UOp.param(1, dtype.ptr())
idx = UOp.const(dtypes.int, 0)
ld = b.index(idx)
alu = ld.alu(alu_op, *alu_src_uops)
@ -33,7 +33,7 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
class TestRendererFailures(unittest.TestCase):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
def test_gated_store_with_alu(self):
a = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
a = UOp.param(0, dtypes.int.ptr())
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
@ -42,7 +42,7 @@ class TestRendererFailures(unittest.TestCase):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
def test_gated_store_with_alu_2d(self):
a = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
a = UOp.param(0, dtypes.int.ptr())
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
@ -87,7 +87,7 @@ class TestWGSLFailures(unittest.TestCase):
class TestPTXFailures(unittest.TestCase):
@unittest.skip("INDEX can only have a gate ALU parent, not an IF")
def test_gated_store_with_if(self):
a = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
a = UOp.param(0, dtypes.int.ptr())
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
val = UOp.const(dtypes.int, 1)
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))

View file

@ -20,6 +20,7 @@ def run_uops(uops_list:list[UOp], bufs:list[Buffer]):
def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
  """Create a single UOp, append it to `uops`, and return the appended node.

  How the node is built depends on `op`:
    * Ops.CONST -> UOp.const(dtype, arg); `src` is not used.
    * Ops.PARAM -> UOp.param(arg, dtype) with its sources replaced by an empty tuple; `src` is not used.
    * otherwise -> UOp(op, dtype, tuple(src), arg).
  """
  if op is Ops.CONST:
    node = UOp.const(dtype, arg)
  elif op is Ops.PARAM:
    # NOTE(review): presumably UOp.param attaches sources by default and they are stripped here
    # to get a bare PARAM -- confirm against UOp.param.
    node = UOp.param(arg, dtype).replace(src=())
  else:
    node = UOp(op, dtype, tuple(src), arg)
  uops.append(node)
  return node
@ -220,8 +221,8 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_bitshift_left(self):
g1 = UOp(Ops.PARAM, dtypes.int32.ptr(), (), 0)
out = UOp(Ops.PARAM, dtypes.int32.ptr(), (), 1)
g1 = UOp.param(0, dtypes.int32.ptr())
out = UOp.param(1, dtypes.int32.ptr())
c1 = UOp.const(dtypes.int, 2)
c2 = UOp.const(dtypes.int, 3)
l1 = g1.index(c1)
@ -248,7 +249,7 @@ class TestAssembly(unittest.TestCase):
self.assertGreaterEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4)
def test_mulacc_shl(self):
g1 = UOp(Ops.PARAM, dtypes.int32.ptr(), (), 0)
g1 = UOp.param(0, dtypes.int32.ptr())
c1 = UOp.const(dtypes.int, 0)
c2 = UOp.const(dtypes.int, 1)
expr = g1.index(c1) * UOp.const(dtypes.int, 4096) + g1.index(c2)
@ -257,7 +258,7 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.MULACC, [x.op for x in uops])
def test_use_cmpeq(self):
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)
g = UOp.param(0, dtypes.uint32.ptr())
c = UOp.const(dtypes.uint, 7)
comp = g.index(c).ne(c).ne(True)
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)

View file

@ -12,7 +12,7 @@ from tinygrad.dtype import ImageDType, Invalid
# PYTHONPATH="." DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
def vision_conv_143():
c0 = UOp(Ops.PARAM, dtypes.imageh((16, 1024, 4)), (), 0)
c0 = UOp.param(0, dtypes.imageh((16, 1024, 4)))
c2 = UOp.range(32, 3, AxisType.LOOP)
c5 = UOp.range(128, 4, AxisType.LOOP)
c8 = UOp.range(16, 2, AxisType.LOOP)
@ -22,13 +22,13 @@ def vision_conv_143():
c26 = UOp.range(7, 1, AxisType.REDUCE)
c27 = c2*2+c26
c32 = ((c27<3)!=True)&(c27<67)
c34 = UOp(Ops.PARAM, dtypes.imageh((32, 1024, 4)), (), 1)
c34 = UOp.param(1, dtypes.imageh((32, 1024, 4)))
c38 = c5//2
c45 = (c32&c24).where((c27*64+c38+c17*4096+-12480), UOp.const(dtypes.weakint, Invalid))
c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
c49 = UOp(Ops.PARAM, dtypes.imageh((64, 49, 4)), (), 2)
c49 = UOp.param(2, dtypes.imageh((64, 49, 4)))
c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
c63 = UOp(Ops.PARAM, dtypes.float.ptr(128), (), 3)
c63 = UOp.param(3, dtypes.float.ptr(128))
c65 = c61.reduce(c16, c26, arg=Ops.ADD)+c63.index(c5)
c67 = c0.index((c2*128+c5+c8*4096), ptr=True).store(c65).end(c8, c2, c5)
@ -38,7 +38,7 @@ def vision_conv_143():
return c67.sink(arg=KernelInfo(name="conv", opts_to_apply=opts))
def vision_conv_153():
c0 = UOp(Ops.PARAM, dtypes.imageh((8, 1024, 4)), (), 0)
c0 = UOp.param(0, dtypes.imageh((8, 1024, 4)))
c2 = UOp.range(16, 3, AxisType.LOOP)
c5 = UOp.range(256, 4, AxisType.LOOP)
c8 = UOp.range(8, 2, AxisType.LOOP)
@ -48,13 +48,13 @@ def vision_conv_153():
c26 = UOp.range(7, 1, AxisType.REDUCE)
c27 = c2*2+c26
c32 = ((c27<3)!=True)&(c27<35)
c34 = UOp(Ops.PARAM, dtypes.imageh((16, 1024, 4)), (), 1)
c34 = UOp.param(1, dtypes.imageh((16, 1024, 4)))
c38 = c5//2
c45 = (c32&c24).where((c27*128+c38+c17*4096+-12672), UOp.const(dtypes.weakint, Invalid))
c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
c49 = UOp(Ops.PARAM, dtypes.imageh((128, 49, 4)), (), 2)
c49 = UOp.param(2, dtypes.imageh((128, 49, 4)))
c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
c63 = UOp(Ops.PARAM, dtypes.float.ptr(256), (), 3)
c63 = UOp.param(3, dtypes.float.ptr(256))
c65 = c61.reduce(c16, c26, arg=Ops.ADD)+c63.index(c5)
c67 = c0.index((c2*256+c5+c8*4096), ptr=True).store(c65).end(c8, c2, c5)
@ -64,16 +64,16 @@ def vision_conv_153():
return c67.sink(arg=KernelInfo(name="conv", opts_to_apply=opts))
def dm_conv_172():
c0 = UOp(Ops.PARAM, dtypes.imageh((1, 240, 4)), (), 0)
c0 = UOp.param(0, dtypes.imageh((1, 240, 4)))
c2 = UOp.range(960, 4, AxisType.LOOP)
c5 = UOp(Ops.PARAM, dtypes.imageh((8, 384, 4)), (), 1)
c5 = UOp.param(1, dtypes.imageh((8, 384, 4)))
c7 = UOp.range(32, 0, AxisType.REDUCE)
c10 = UOp.range(4, 1, AxisType.REDUCE)
c13 = UOp.range(12, 3, AxisType.REDUCE)
c18 = UOp.range(8, 2, AxisType.REDUCE)
c23 = UOp(Ops.PARAM, dtypes.imageh((240, 128, 4)), (), 2)
c23 = UOp.param(2, dtypes.imageh((240, 128, 4)))
c35 = c5.index((c7*4+c10+c13*128+c18*1536))*c23.index((c10*4+c2%4+c7*16+c2//4*512))
c37 = UOp(Ops.PARAM, dtypes.float.ptr(960), (), 3)
c37 = UOp.param(3, dtypes.float.ptr(960))
c39 = c35.reduce(c7, c10, arg=Ops.ADD)+c37.index(c2)
c50 = (1.0+((c39+0.044708251953125*(c39*(c39*c39)))*-2.3021129851685216).exp2()).reciprocal()*c39
c53 = c50.reduce(c18, c13, arg=Ops.ADD)*0.010416666666666666
@ -99,4 +99,4 @@ bufs = [Buffer(ps.arg.device, g.size, g.dtype if isinstance(g.dtype, ImageDType)
gsize, lsize = ps.arg.launch_dims({})
t = rt(*[b._buf for b in bufs], global_size=gsize, local_size=lsize, vals=ps.arg.vals({}), wait=True)
print(f"{t*1e6:.2f} us")
print(f"{t*1e6:.2f} us")

View file

@ -82,7 +82,7 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None, vals:tuple
for buf_dt, data in inputs or []:
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
g = UOp(Ops.PARAM, uop.dtype.ptr(), arg=0, src=())
g = UOp.param(0, uop.dtype.ptr())
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg))
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs, vals=vals)

View file

@ -423,10 +423,10 @@ def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode
class _Ctx:
"""Context for instruction compilation - holds buffers and helpers."""
__slots__ = ('inst_size', 'dyn_fields', '_axis_id', 'wave_size', 'vgpr', 'accvgpr')
sgpr = UOp(Ops.PARAM, dtypes.uint32.ptr(SGPR_COUNT), arg=0)
vmem = UOp(Ops.PARAM, dtypes.uint32.ptr(1 << 46), arg=2)
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
scratch = UOp(Ops.PARAM, dtypes.uint8.ptr(1 << 30), arg=4)
sgpr = UOp.param(0, dtypes.uint32.ptr(SGPR_COUNT))
vmem = UOp.param(2, dtypes.uint32.ptr(1 << 46))
lds = UOp.param(3, dtypes.uint32.ptr(16384))
scratch = UOp.param(4, dtypes.uint8.ptr(1 << 30))
# Cache PARAM UOps by wave_size so all _Ctx instances with same wave_size share identical UOp references
_vgpr_cache: dict[int, UOp] = {}
_accvgpr_cache: dict[int, UOp] = {}
@ -434,10 +434,10 @@ class _Ctx:
def __init__(self, inst_size: int, wave_size: int = 32):
self.inst_size, self._axis_id, self.wave_size = inst_size, 0, wave_size
self.dyn_fields: list[tuple[int, int]] = [] # (lo, hi) of fields read dynamically
if wave_size not in _Ctx._vgpr_cache: _Ctx._vgpr_cache[wave_size] = UOp(Ops.PARAM, dtypes.uint32.ptr(256 * wave_size), arg=1)
if wave_size not in _Ctx._vgpr_cache: _Ctx._vgpr_cache[wave_size] = UOp.param(1, dtypes.uint32.ptr(256 * wave_size))
self.vgpr = _Ctx._vgpr_cache[wave_size]
if wave_size == 64:
if wave_size not in _Ctx._accvgpr_cache: _Ctx._accvgpr_cache[wave_size] = UOp(Ops.PARAM, dtypes.uint32.ptr(256 * wave_size), arg=5)
if wave_size not in _Ctx._accvgpr_cache: _Ctx._accvgpr_cache[wave_size] = UOp.param(5, dtypes.uint32.ptr(256 * wave_size))
self.accvgpr = _Ctx._accvgpr_cache[wave_size]
else:
self.accvgpr = self.vgpr

View file

@ -96,7 +96,7 @@ class TestGroupedDims(unittest.TestCase):
def test_global_prod_max(self):
g, l = UOp.range(256, 0, AxisType.GLOBAL), UOp.range(256, 1, AxisType.LOCAL)
sink = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0).index(g + l).store(UOp.const(dtypes.float, 1.0)).end(g, l).sink(arg=KernelInfo())
sink = UOp.param(0, dtypes.float.ptr()).index(g + l).store(UOp.const(dtypes.float, 1.0)).end(g, l).sink(arg=KernelInfo())
class R(Renderer): global_max, local_max, global_prod_max = (256, 256, 256), (128, 128, 128), (128, 128, 128)
specials = [u for u in add_gpudims(R(Target()), sink).toposort() if u.op is Ops.SPECIAL]
self.assertGreater(len([s for s in specials if "lidx" in s.arg]), 1)

View file

@ -7,14 +7,14 @@ from tinygrad.codegen import to_program
class TestLinearizerFailures(unittest.TestCase):
def test_fail_1(self):
c0 = UOp(Ops.PARAM, dtypes.float.ptr(64), arg=0, src=())
c0 = UOp.param(0, dtypes.float.ptr(64))
c1 = UOp.range(UOp.const(dtypes.weakint, 2), 1, AxisType.LOOP)
c2 = UOp.range(UOp.const(dtypes.weakint, 32), 2, AxisType.LOOP)
c3 = ((c1*UOp.const(dtypes.weakint, 32))+c2)
c4 = UOp(Ops.PARAM, dtypes.float.ptr(163840), arg=1, src=())
c4 = UOp.param(1, dtypes.float.ptr(163840))
c5 = UOp.range(UOp.const(dtypes.weakint, 2560), 0, AxisType.REDUCE)
c6 = c4.index(((((((c5//UOp.const(dtypes.weakint, 8))%UOp.const(dtypes.weakint, 8))*UOp.const(dtypes.weakint, 8))+(c5%UOp.const(dtypes.weakint, 8)))+(((c2*UOp.const(dtypes.weakint, 40))+(c5//UOp.const(dtypes.weakint, 64)))*UOp.const(dtypes.weakint, 64)))+(c1*UOp.const(dtypes.weakint, 81920))))
c7 = UOp(Ops.PARAM, dtypes.float.ptr(64), arg=2, src=())
c7 = UOp.param(2, dtypes.float.ptr(64))
c8 = c7.index(c3)
c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
c10 = c0.index(c3).store(c9).end(c1, c2)

View file

@ -15,13 +15,13 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, (
UOp(Ops.PARAM, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True),
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
UOp.const(dtypes.float, 0.0)
))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
))
@ -513,7 +513,7 @@ class TestDropTrueGate(unittest.TestCase):
from tinygrad.codegen.late.devectorizer import load_store_indexing
from tinygrad.uop.ops import graph_rewrite
from tinygrad.uop.symbolic import sym
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
buf = UOp.param(0, dtypes.int.ptr())
idx = UOp.const(dtypes.weakint, 0)
true_gate = UOp.const(dtypes.bool, True)
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx.valid(true_gate)))
@ -557,7 +557,7 @@ class TestRangeShrink(unittest.TestCase):
# one load guards r < 4, but another load uses r without a gate -> no shrink
r = Range(0, 204)
load1 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
load2 = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.PARAM, dtypes.float.ptr(), arg=1).index(r, ptr=True),))
load2 = UOp(Ops.LOAD, dtypes.float, (UOp.param(1, dtypes.float.ptr()).index(r, ptr=True),))
ranges = self.get_ranges(UOp.sink(load1, load2))
self.assertEqual(len(ranges), 1)
self.assertEqual(ranges[0].src[0].arg, 204)
@ -583,7 +583,7 @@ class TestRangeShrink(unittest.TestCase):
from tinygrad.dtype import Invalid
r = Range(0, 204)
x = (r < 4).where(UOp.const(dtypes.float, 1), Invalid)
ranges = self.get_ranges(UOp(Ops.PARAM, dtypes.float.ptr(), arg=0).index(r).store((r < 4).where(x, 0)).sink())
ranges = self.get_ranges(UOp.param(0, dtypes.float.ptr()).index(r).store((r < 4).where(x, 0)).sink())
self.assertEqual(len(ranges), 1)
self.assertEqual(ranges[0].src[0].arg, 4)
@ -592,7 +592,7 @@ class TestRangeShrink(unittest.TestCase):
from tinygrad.dtype import Invalid
r = Range(0, 204)
x = (r < 4).where(UOp.const(dtypes.float, 1), Invalid)
ranges = self.get_ranges(UOp(Ops.PARAM, dtypes.float.ptr(), arg=0).index(r).store((r < 4).where(0, x)).sink())
ranges = self.get_ranges(UOp.param(0, dtypes.float.ptr()).index(r).store((r < 4).where(0, x)).sink())
self.assertEqual(len(ranges), 1)
self.assertEqual(ranges[0].src[0].arg, 4)

View file

@ -1,7 +1,7 @@
import unittest, math
import numpy as np
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.ops import UOp
from tinygrad.uop.decompositions import TRANSCENDENTAL_DTYPES, payne_hanek_reduction, cody_waite_reduction
from tinygrad.uop.decompositions import frexp, rintk, xpow, xexp2, xlog2, trig_poly, pow2if
from test.helpers import eval_uop
@ -10,7 +10,7 @@ class TestTranscendentalFunctions(unittest.TestCase):
def test_payne_hanek_reduction(self):
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
# Load input value from a buffer to prevent constant folding
input_buf = UOp(Ops.PARAM, dtypes.double.ptr(), arg=1, src=())
input_buf = UOp.param(1, dtypes.double.ptr())
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))

View file

@ -260,7 +260,7 @@ class TestUOpGraph(unittest.TestCase):
@unittest.skip("this test isn't valid uops")
def test_noop_vectorize_fold(self):
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), arg=0)
d0 = UOp.param(0, dtypes.float.ptr())
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx))
vec = UOp(Ops.STACK, dtypes.float.vec(2), (ld,))
@ -272,9 +272,9 @@ class TestUOpGraph(unittest.TestCase):
@unittest.skip("this test isn't valid uops")
def test_gep_vec_fold(self):
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
d1 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
d2 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 2)
d0 = UOp.param(0, dtypes.float.ptr())
d1 = UOp.param(1, dtypes.float.ptr())
d2 = UOp.param(2, dtypes.float.ptr())
idx = UOp.const(dtypes.int, 0)
def _test_vec(geps, count=4):
vec = UOp(Ops.STACK, dtypes.float.vec(count), geps)
@ -380,8 +380,8 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
def test_cast_alu_fold(self):
d0 = UOp(Ops.PARAM, dtypes.bool.ptr(), arg=0)
d1 = UOp(Ops.PARAM, dtypes.int.ptr(), arg=1)
d0 = UOp.param(0, dtypes.bool.ptr())
d1 = UOp.param(1, dtypes.int.ptr())
idx = UOp.const(dtypes.int, 0)
ld = d1.index(idx)
alu = (ld<1).cast(dtypes.bool)
@ -390,8 +390,8 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
def test_double_cast_fold(self):
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), arg=0)
d1 = UOp(Ops.PARAM, dtypes.int.ptr(), arg=1)
d0 = UOp.param(0, dtypes.float.ptr())
d1 = UOp.param(1, dtypes.int.ptr())
idx = UOp.const(dtypes.int, 0)
ld = d1.index(idx)
alu = ld.cast(dtypes.float).cast(dtypes.float)
@ -414,7 +414,7 @@ class TestUOpGraph(unittest.TestCase):
def test_bitcast_to_same_dtype_fold(self):
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
d0 = UOp(Ops.PARAM, dt.ptr(), arg=0)
d0 = UOp.param(0, dt.ptr())
v = d0.index(UOp.const(dtypes.int, 0))
uops = to_uops_list([v.bitcast(dt)])
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}")
@ -427,10 +427,10 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_fold(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0)
d0 = UOp.param(0, dtypes.long.ptr())
ld = d0.index(ridx0.valid(ridx0<50))
w = (ridx0<50).where(ld, 5)
out = UOp(Ops.PARAM, dtypes.long.ptr(), (), 1)
out = UOp.param(1, dtypes.long.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
@ -438,7 +438,7 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0)
d0 = UOp.param(0, dtypes.long.ptr())
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
w = (ridx0<50).where(5, ld)
uops = to_uops_list([w])
@ -448,11 +448,11 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_with_cast(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
d0 = UOp.param(0, dtypes.int.ptr())
gate_idx = ridx0.valid((ridx0<50))
ld = d0.index(gate_idx).cast(dtypes.float)
w = (ridx0<50).where(ld, 5.0)
out = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
out = UOp.param(1, dtypes.float.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
@ -460,27 +460,27 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_casted_gated_load_extra_cond(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
d0 = UOp.param(0, dtypes.float.ptr())
ld = d0.index(ridx0.valid(ridx0<50))
w = ((ridx0<50) & (ridx0>30)).where(ld, UOp.const(dtypes.float, 0)).cast(dtypes.half)
out = UOp(Ops.PARAM, dtypes.half.ptr(), (), 1)
out = UOp.param(1, dtypes.half.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
def test_where_on_casted_gated_load_extra_cond_swapped(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
d0 = UOp.param(0, dtypes.float.ptr())
ld = d0.index(ridx0.valid(ridx0<50))
w = ((ridx0<50) & (ridx0>30)).where(UOp.const(dtypes.float, 0), ld).cast(dtypes.half)
out = UOp(Ops.PARAM, dtypes.half.ptr(), (), 1)
out = UOp.param(1, dtypes.half.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
def test_where_in_store_becomes_gate(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0)
d0 = UOp.param(0, dtypes.long.ptr())
idx = d0.index(ridx0)
ld = idx.load()
val = (ridx0<50).where(5, ld)
@ -493,14 +493,14 @@ class TestUOpGraph(unittest.TestCase):
def test_load_idx_becomes_int(self):
# mnist indexing with split reduceop
# Make sure we are not doing math on the loaded index, which would promote it to long
c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(128000), arg=0, src=())
c0 = UOp.param(0, dtypes.uchar.ptr(128000))
c1 = UOp.range(UOp.const(dtypes.weakint, 512), 1, AxisType.LOOP)
c2 = UOp.range(UOp.const(dtypes.weakint, 250), 2, AxisType.LOOP)
c3 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
c3 = UOp.param(1, dtypes.int.ptr(512))
c4 = c3.index(c1)
c5 = UOp.range(UOp.const(dtypes.weakint, 240), 0, AxisType.REDUCE)
c6 = ((c2*UOp.const(dtypes.weakint, 240))+c5)
c7 = UOp(Ops.PARAM, dtypes.uchar.ptr(60000), arg=2, src=())
c7 = UOp.param(2, dtypes.uchar.ptr(60000))
c8 = c7.index(c6)
c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD)
c10 = c0.index(((c1*UOp.const(dtypes.weakint, 250))+c2)).store(c9).end(c1, c2)
@ -510,14 +510,14 @@ class TestUOpGraph(unittest.TestCase):
def test_load_idx_no_math_on_loaded(self):
# test the (x+y)<c pattern where x has loads - we shouldn't do math on loaded indices
c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(128000), arg=0, src=())
c0 = UOp.param(0, dtypes.uchar.ptr(128000))
c1 = UOp.range(UOp.const(dtypes.weakint, 512), 1, AxisType.LOOP)
c2 = UOp.range(UOp.const(dtypes.weakint, 250), 2, AxisType.LOOP)
c3 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
c3 = UOp.param(1, dtypes.int.ptr(512))
c4 = c3.index(c1) # c4 is a load
c5 = UOp.range(UOp.const(dtypes.weakint, 240), 0, AxisType.REDUCE)
c6 = ((c2*UOp.const(dtypes.weakint, 240))+c5)
c7 = UOp(Ops.PARAM, dtypes.uchar.ptr(60000), arg=2, src=())
c7 = UOp.param(2, dtypes.uchar.ptr(60000))
c8 = c7.index(c6)
# (loaded + range) < const pattern - loaded value shouldn't be promoted to long
loaded_idx = c4.cast(dtypes.weakint)
@ -529,9 +529,9 @@ class TestUOpGraph(unittest.TestCase):
self.assertNotEqual(u.dtype, dtypes.long)
def test_fold_gated_load(self):
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
glbl1 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 1)
glbl2 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 2)
glbl0 = UOp.param(0, dtypes.int.ptr())
glbl1 = UOp.param(1, dtypes.int.ptr())
glbl2 = UOp.param(2, dtypes.int.ptr())
idx = UOp.const(dtypes.int, 0)
ld0 = glbl1.index(UOp.invalid())
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
@ -541,7 +541,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
glbl0 = UOp.param(0, dtypes.int.ptr())
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
@ -555,7 +555,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
def test_fold_gated_store(self):
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
glbl = UOp.param(0, dtypes.int.ptr())
idx0 = UOp.const(dtypes.int, 0)
idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42)
@ -563,12 +563,12 @@ class TestUOpGraph(unittest.TestCase):
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val)
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 6) # +1 for SINK
self.assertEqual(len(uops), 7) # +1 for SINK, +1 for PARAM shape sentinel
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK
@unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self):
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
glbl0 = UOp.param(0, dtypes.int.ptr())
idx = UOp.const(dtypes.int, 0)
bad_gate = UOp.const(dtypes.int, 1)
with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
@ -779,7 +779,7 @@ class TestLoadStoreFolding(unittest.TestCase):
def test_gated_load_gep_preserves_alt(self):
"""Test that LOAD(GEP, alt) preserves alt value after rewrite"""
from tinygrad.codegen.late.devectorizer import load_store_folding
buf = UOp(Ops.PARAM, dtypes.float.vec(4).ptr(), (), 0)
buf = UOp.param(0, dtypes.float.vec(4).ptr())
idx = UOp.const(dtypes.int, 0)
gate = UOp.const(dtypes.bool, True)
gated_index = buf.index(idx.valid(gate))
@ -797,8 +797,8 @@ class TestLoadStoreFolding(unittest.TestCase):
def test_gated_load_ptrcat_preserves_alt(self):
"""Test that LOAD(PTRCAT, alt) preserves alt value after rewrite"""
from tinygrad.codegen.late.devectorizer import load_store_folding
buf1 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
buf2 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
buf1 = UOp.param(0, dtypes.float.ptr())
buf2 = UOp.param(1, dtypes.float.ptr())
idx = UOp.const(dtypes.int, 0)
idx1 = buf1.index(idx)
idx2 = buf2.index(idx)

View file

@ -951,7 +951,7 @@ class TestSymbolic(unittest.TestCase):
expr = cond.where(a, b).cast(dtypes.half)
# TODO: copied from render, render does not support cast
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
glbl = UOp.param(0, dtypes.int.ptr())
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
@ -1270,7 +1270,7 @@ class TestStoreLoadFolding(unittest.TestCase):
"""Tests for store(index, load(index)) -> NOOP rule. This rule matches patterns that EMERGE during simplification."""
def test_store_load_folding(self):
# store(idx, load(idx)) -> NOOP, including emergent patterns like store(idx, load(idx) + 0)
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
buf = UOp.param(0, dtypes.int.ptr())
index = buf.index(UOp.const(dtypes.weakint, 0))
# Direct: store(idx, load(idx)) -> NOOP
self.assertEqual(graph_rewrite(index.store(index.load()), sym).op, Ops.NOOP)
@ -1340,7 +1340,7 @@ class TestRangeSplitting(unittest.TestCase):
from tinygrad.codegen.simplify import pm_split_ranges, pm_flatten_range
r0 = UOp.range(uconst(8), 0)
# create a simple expression using the range with mod: store range%2 to a buffer
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
buf = UOp.param(0, dtypes.int.ptr())
val = (r0 % uconst(2)).cast(dtypes.int)
store = UOp(Ops.STORE, dtypes.void, (buf.index(uconst(0)), val))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.END, dtypes.void, (store, r0)),))

View file

@ -82,7 +82,7 @@ class TestVminVmaxProperties(unittest.TestCase):
def test_vmin_vmax_multiplication_0_inf(self):
# vmin and vmax for multiplication with a variable
x = UOp.const(dtypes.float, 0.0)
y = UOp.load(UOp(Ops.PARAM, dtypes.float.ptr(1), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
y = UOp.load(UOp.param(0, dtypes.float.ptr(1)), UOp.const(dtypes.int, 0), dtype=dtypes.float)
uop = x * y
# TODO: these should be 0, but definitely should not be nan
self.assertEqual(uop.vmin, -math.inf)
@ -316,7 +316,7 @@ class TestVminVmaxVConst(unittest.TestCase):
def test_vmin_vmax_vector_with_gep(self):
# vmin and vmax for a GEP on a loaded int vector after floor division by a constant
d1 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 1)
d1 = UOp.param(1, dtypes.int.ptr())
idx = UOp.const(dtypes.int, 0)
val = UOp(Ops.LOAD, dtypes.int.vec(2), (d1.index(idx).cast(dtypes.int.vec(2).ptr()),))
uop = (val // 32).gep(0)

View file

@ -110,7 +110,7 @@ class TestExecALU(unittest.TestCase):
class TestGatedStoreRewrite(unittest.TestCase):
def test_tiny_gate_store(self):
gmem = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
gmem = UOp.param(0, dtypes.float.ptr())
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
gate = gidx0<UOp.const(dtypes.int, 1)
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
@ -126,8 +126,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
self.assertEqual(len(gated_uops[-1].src), 2)
def test_gate_some_stores(self):
gmem0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
gmem1 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
gmem0 = UOp.param(0, dtypes.float.ptr())
gmem1 = UOp.param(1, dtypes.float.ptr())
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0 * UOp.const(dtypes.int, 2)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
@ -146,8 +146,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
@unittest.skip("we don't merge ifs anymore")
def test_merge_ifs_alt(self):
gmem0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
gmem1 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
gmem0 = UOp.param(0, dtypes.float.ptr())
gmem1 = UOp.param(1, dtypes.float.ptr())
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0*UOp.const(dtypes.int, 2)
gate = gidx0<UOp.const(dtypes.int, 1)
@ -170,7 +170,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
class TestFastIdiv(unittest.TestCase):
def test_division_power_of_two(self):
for dt in (dtypes.int32, dtypes.uint32):
g = UOp(Ops.PARAM, dt.ptr(), (), 0)
g = UOp.param(0, dt.ptr())
c = UOp.const(dt, 2)
l = g.index(c)
a = UOp(Ops.CDIV, dt, (l, c))
@ -183,7 +183,7 @@ class TestFastIdiv(unittest.TestCase):
def test_floormod_power_of_two(self):
# FLOORMOD by a power of two lowers to AND (correct floor mod for any sign in two's complement)
for dt in (dtypes.int32, dtypes.uint32):
g = UOp(Ops.PARAM, dt.ptr(), (), 0)
g = UOp.param(0, dt.ptr())
c = UOp.const(dt, 8)
a = UOp(Ops.FLOORMOD, dt, (g.index(c), c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
@ -195,7 +195,7 @@ class TestFastIdiv(unittest.TestCase):
def test_floordiv_power_of_two_uint(self):
# uint FLOORDIV by a power of two lowers to a shift, leaving no IDIV/FLOORDIV in the kernel
for dt in (dtypes.uint32, dtypes.uint64):
g = UOp(Ops.PARAM, dt.ptr(), (), 0)
g = UOp.param(0, dt.ptr())
c = UOp.const(dt, 2)
a = UOp(Ops.FLOORDIV, dt, (g.index(c), c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
@ -207,7 +207,7 @@ class TestFastIdiv(unittest.TestCase):
@Context(DISABLE_FAST_IDIV=0)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long")
def test_fast_idiv_and_mod(self):
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)
g = UOp.param(0, dtypes.uint32.ptr())
c = UOp.const(dtypes.uint, 3)
l = g.index(c)
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
@ -242,7 +242,7 @@ class TestFastIdiv(unittest.TestCase):
@unittest.expectedFailure
def test_fast_idiv_overflow(self):
# This will be possible with a slightly different method for fast_idiv
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)
g = UOp.param(0, dtypes.uint32.ptr())
c = UOp.const(dtypes.uint, 7)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
@ -253,7 +253,7 @@ class TestFastIdiv(unittest.TestCase):
self.assertNotIn(Ops.CDIV, ops)
def test_disable_fast_idiv(self):
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)
g = UOp.param(0, dtypes.uint32.ptr())
c = UOp.const(dtypes.uint, 3)
l = g.index(c)
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
@ -290,8 +290,8 @@ class TestUOpMethod(unittest.TestCase):
self.assertEqual((gidx0*3+1).const_factor(), 1)
def test_replace(self):
x = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
self.assertIs(x.replace(arg=None).arg, None)
x = UOp.param(0, dtypes.int.ptr())
self.assertEqual(x.replace(arg=UOp.param(1, dtypes.int.ptr()).arg).arg.slot, 1)
with self.assertRaises(AssertionError): x.replace(field="a")
def test_const_zero_neg_zero_different(self):

View file

@ -139,7 +139,7 @@ class TestUOpsStats(unittest.TestCase):
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):
globl = UOp(Ops.PARAM, dtypes.int.ptr(), tuple())
globl = UOp.param(0, dtypes.int.ptr())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = globl.index(o1)
@ -149,7 +149,7 @@ class TestUOpsStats(unittest.TestCase):
u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
uops = tuple(u5.toposort())
globl = UOp(Ops.PARAM, dtypes.int.ptr(), tuple())
globl = UOp.param(0, dtypes.int.ptr())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = globl.index(o1)

View file

@ -11,7 +11,7 @@ class TestValidateOOB(unittest.TestCase):
# basic index patterns
def test_const_index(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
to_uops_list([buf.index(UOp.const(dtypes.int, 0), ptr=True).load(dtype=dtypes.int)]) # valid
to_uops_list([buf.index(UOp.const(dtypes.int, 15), ptr=True).load(dtype=dtypes.int)]) # valid (last element)
with self.assertRaises(RuntimeError):
@ -21,7 +21,7 @@ class TestValidateOOB(unittest.TestCase):
def test_variable_index(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
to_uops_list([buf.index(Variable("i", 0, 15), ptr=True).load(dtype=dtypes.int)]) # valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(Variable("i", 0, 20), ptr=True).load(dtype=dtypes.int)]) # oob
@ -30,7 +30,7 @@ class TestValidateOOB(unittest.TestCase):
def test_range_with_mask(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
r = UOp.range(42, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r.valid(r < 16), ptr=True).load(dtype=dtypes.int)]) # valid
with self.assertRaises(RuntimeError):
@ -38,7 +38,7 @@ class TestValidateOOB(unittest.TestCase):
def test_variable_with_mask(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
v = Variable("v", -5, 80)
to_uops_list([buf.index(v.valid((v >= 0) & (v < 16)), ptr=True).load(dtype=dtypes.int)]) # valid
with self.assertRaises(RuntimeError):
@ -46,7 +46,7 @@ class TestValidateOOB(unittest.TestCase):
def test_gated_store(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
v = Variable("v", 0, 20)
to_uops_list([buf.index(v.valid(v < 16), ptr=True).store(0)]) # valid
with self.assertRaises(RuntimeError):
@ -55,14 +55,14 @@ class TestValidateOOB(unittest.TestCase):
# ALU ops in index
def test_floordiv(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
to_uops_list([buf.index(UOp.range(32, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(UOp.range(34, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..16 oob
def test_mod(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
r = UOp.range(100, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r % 16, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
@ -70,14 +70,14 @@ class TestValidateOOB(unittest.TestCase):
def test_shr(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
to_uops_list([buf.index(UOp.range(64, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(UOp.range(128, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..31 oob
def test_shl(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(64), (), 0)
buf = UOp.param(0, dtypes.int.ptr(64))
r = UOp.range(8, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r << 2, ptr=True).load(dtype=dtypes.int)]) # 0..28 valid
with self.assertRaises(RuntimeError):
@ -85,7 +85,7 @@ class TestValidateOOB(unittest.TestCase):
def test_and(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
r = UOp.range(100, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r & 15, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
@ -93,14 +93,14 @@ class TestValidateOOB(unittest.TestCase):
def test_max(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
to_uops_list([buf.index(Variable("v", -10, 15).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(Variable("v2", -10, 20).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..20 oob
def test_xor_in_mask(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
r = UOp.range(32, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r.valid((r < 8) ^ ((r >= 8) & (r < 16))), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
@ -109,22 +109,22 @@ class TestValidateOOB(unittest.TestCase):
# cast patterns
def test_float_cast_in_index(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf = UOp.param(0, dtypes.int.ptr(16))
r = UOp.range(20, 0)
i = (r.cast(dtypes.float) * 0.68).trunc().cast(dtypes.int)
to_uops_list([buf.index(i.valid((i >= 0) & (i < 16)), ptr=True).load(dtype=dtypes.int)])
def test_bool_cast_in_mask(self):
with Context(CHECK_OOB=1, SPEC=2):
buf = UOp(Ops.PARAM, dtypes.int.ptr(1), (), 0)
buf = UOp.param(0, dtypes.int.ptr(1))
r = UOp.range(20, 0)
to_uops_list([buf.index(r.valid(r.cast(dtypes.bool).logical_not()), ptr=True).load(dtype=dtypes.int)]) # only r=0 valid
# load result as index/mask
def test_load_as_index(self):
with Context(CHECK_OOB=1, SPEC=2):
buf0 = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
buf1 = UOp(Ops.PARAM, dtypes.int.ptr(64), (), 1)
buf0 = UOp.param(0, dtypes.int.ptr(16))
buf1 = UOp.param(1, dtypes.int.ptr(64))
r = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.weakint)
to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 32)), ptr=True).load(dtype=dtypes.int)]) # valid
@ -133,8 +133,8 @@ class TestValidateOOB(unittest.TestCase):
def test_load_bool_as_mask(self):
with Context(CHECK_OOB=1, SPEC=2):
buf_bool = UOp(Ops.PARAM, dtypes.bool.ptr(16), (), 0)
buf_int = UOp(Ops.PARAM, dtypes.int.ptr(8), (), 1)
buf_bool = UOp.param(0, dtypes.bool.ptr(16))
buf_int = UOp.param(1, dtypes.int.ptr(8))
gidx = UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, 16),), "gidx0")
ld_bool = buf_bool.index(gidx, ptr=True).load()
with self.assertRaises(RuntimeError):
@ -145,7 +145,7 @@ class TestValidateOOB(unittest.TestCase):
def test_in_bounds_access_gated_local(self):
with Context(CHECK_OOB=1):
# Define buffers
gbuf = UOp(Ops.PARAM, dtypes.uint.ptr(400), (), 0)
gbuf = UOp.param(0, dtypes.uint.ptr(400))
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
# Define indices, valids and barrier
@ -169,8 +169,8 @@ class TestValidateOOB(unittest.TestCase):
@unittest.skip("Bool load is not supported yet")
def test_load_mask(self):
with Context(CHECK_OOB=1):
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
mask = UOp(Ops.PARAM, dtypes.bool.ptr(16), (), 0)
glbl0 = UOp.param(0, dtypes.int.ptr(16))
mask = UOp.param(0, dtypes.bool.ptr(16))
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
to_uops_list([ld0])

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, AddrSpace, PtrDType, ImageDType
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
from tinygrad.helpers import VIZ, pluralize, all_int
@ -176,7 +176,8 @@ def finalize_after(ctx:AllocCtx, x:UOp):
def replace_input_buffer(ctx:AllocCtx, b:UOp):
ctx.replacements.append(b)
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b.device,
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None)
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None,
b.addrspace if isinstance(b.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL)
pm_finalize_call = PatternMatcher([
(UPat(Ops.AFTER, name="x"), finalize_after),

View file

@ -22,7 +22,7 @@ def linearize(sink:UOp) -> list[UOp]:
extra = None
match u.op:
# the order and placement of these defines is important
case Ops.PARAM: priority, extra = -20, u.arg
case Ops.PARAM: priority, extra = -20, u.arg.slot
case Ops.DEFINE_VAR: priority, extra = -19, u.arg
case Ops.DEFINE_REG: priority = -18
case Ops.DEFINE_LOCAL: priority = -17
@ -93,4 +93,4 @@ def do_split_ends(e:UOp):
pm_split_ends = PatternMatcher([
# split the ends
(UPat(Ops.END, name="e"), do_split_ends),
])
])

View file

@ -329,7 +329,7 @@ class Scheduler:
def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg)
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg.slot)
return [Buffer(dname, x.max_numel(), x.dtype.base) for x in glbls]
def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:

View file

@ -94,7 +94,7 @@ class GraphRunner:
self.runtimes: list[Any|None] = []
self.uop_replace: list[list[tuple[int, int]]] = []
for call in self.linear.src:
replace = [(p, b.arg) for p, b in enumerate(get_call_arg_uops(call)) if b.op is Ops.PARAM]
replace = [(p, b.arg.slot) for p, b in enumerate(get_call_arg_uops(call)) if b.op is Ops.PARAM]
for dev_idx, (bufs, device_vars) in enumerate(unwrap_multi(call, resolve_params(call, input_uops))):
self.calls.append((dev_idx, call.src[0], [b.ensure_allocated() for b in bufs], device_vars))
self.runtimes.append(get_runtime(bufs[0].device, call.src[0]) if call.src[0].op is Ops.PROGRAM else None)

View file

@ -137,8 +137,8 @@ class ExecContext:
cache: bool = True
def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
if b.op in (Ops.SLICE, Ops.MSELECT) and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg], *b.src[1:]))
return inputs[b.arg] if b.op is Ops.PARAM else b
if b.op in (Ops.SLICE, Ops.MSELECT) and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg.slot], *b.src[1:]))
return inputs[b.arg.slot] if b.op is Ops.PARAM else b
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]: return [_resolve(b, inputs) for b in get_call_arg_uops(call)]
def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]:

View file

@ -16,8 +16,9 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
"""Remove unused PARAMs from body and return compacted (body, args)."""
used = sorted({p.arg: p for p in body.toposort() if p.op is Ops.PARAM}.items())
return body.substitute({p: p.replace(arg=j) for j,(_, p) in enumerate(used)}, walk=True), tuple(all_args[i] for i,_ in used)
used = sorted({p.arg.slot: p for p in body.toposort() if p.op is Ops.PARAM}.items())
body = body.substitute({p: p.replace(arg=dataclasses.replace(p.arg, slot=j)) for j,(_, p) in enumerate(used)}, walk=True)
return body, tuple(all_args[i] for i,_ in used)
def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
fxn, args = k.src[0], k.src[1:]
@ -29,7 +30,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
return (None,) + (k.arg.grad_fxn(*real, call=k) if len(real) > 1 else k.arg.grad_fxn(real[0], k))
return (None,) + k.arg.grad_fxn(on_dev(ctx, 0), k)
assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
grad_args = ctx.src
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))

View file

@ -179,8 +179,8 @@ class CStyleLanguage(Renderer):
continue
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
if u.op is not Ops.PARAM: r[u] = u.arg[0]
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
else: r[u] = f"data{u.arg}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg}"
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
else: r[u] = f"data{u.arg.slot}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg.slot}"
bufs[u] = (r[u], (u.dtype, u in writable_params))
continue
@ -321,8 +321,8 @@ class OpenCLRenderer(CStyleLanguage):
def aux(self, uops:list[UOp]):
arg_dtypes:list[list[tuple[int, DType]]] = []
for i,u in enumerate(u for u in uops if u.op is Ops.PARAM):
if len(arg_dtypes) >= u.arg: arg_dtypes.append([])
arg_dtypes[u.arg].append((i, u.dtype))
while len(arg_dtypes) <= u.arg.slot: arg_dtypes.append([])
arg_dtypes[u.arg.slot].append((i, u.dtype))
return tuple(tuple(a) for a in arg_dtypes),
def supported_dtypes(self): return {d for d in super().supported_dtypes()

View file

@ -164,7 +164,7 @@ class LLVMRenderer(Renderer):
if u.arg is not None: name = u.arg.function_name
continue
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.expr}"
r[u] = f"%data{u.arg.slot}" if u.op is Ops.PARAM else f"%{u.expr}"
args.append((r[u], u.dtype))
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"

View file

@ -63,7 +63,7 @@ def mem_type(x:UOp) -> str:
match x.op:
case Ops.AFTER: return mem_type(x.src[0])
case Ops.DEFINE_LOCAL: return 'shared'
case Ops.PARAM: return 'global'
case Ops.PARAM: return 'shared' if x.addrspace == AddrSpace.LOCAL else 'global'
case _: raise RuntimeError(f"{x.op} needs to be memory")
def render_wmma(ctx: "PTXRenderer", wmma: UOp):
@ -90,7 +90,7 @@ string_rewrite = PatternMatcher([
(UPat.cvar("x", dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
(UPat.cvar("x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg}, %{'ctaid' if x.arg[0] == 'g' else 'tid'}.{chr(120+int(x.arg[-1]))};"),
(UPat(Ops.PARAM, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
(UPat(Ops.PARAM, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg.slot}+0];"),
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
(UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
@ -219,7 +219,7 @@ class PTXRenderer(Renderer):
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
elif u.op is Ops.LOAD:
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
elif u.op is Ops.PARAM: bufs.append((f"data{u.arg}", u.dtype))
elif u.op is Ops.PARAM: bufs.append((f"data{u.arg.slot}", u.dtype))
elif u.op is Ops.WMMA:
# registers for packing/unpacking input and acc
self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],

View file

@ -20,7 +20,7 @@ BUFTYPE_BUF, BUFTYPE_TEX, BUFTYPE_IBO = 0, 1, 2
def dcache_flush():
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.codegen import to_program
buf, n = UOp(Ops.PARAM, dtypes.uint8.ptr(), arg=0), UOp(Ops.PARAM, dtypes.uint8.ptr(), arg=1)
buf, n = UOp.param(0, dtypes.uint8.ptr()), UOp.param(1, dtypes.uint8.ptr())
i = UOp.range(n.cast(dtypes.int), 0, dtype=dtypes.int)
flush = UOp(Ops.CUSTOM, dtypes.void, (buf.cast(dtypes.ulong) + i.cast(dtypes.ulong) * UOp.const(dtypes.ulong, 64),),
arg='__asm__ volatile("dc cvac, %0" :: "r"({0}) : "memory");')

View file

@ -77,7 +77,7 @@ def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
return ret
pm_post_sched_cache = PatternMatcher([
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg.slot]),
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
])

View file

@ -137,7 +137,7 @@ def rewrite_into_function(call:UOp):
def param_to_multi(p:UOp):
if p.axis is None: return None
return UOp.param(p.arg, p.dtype, p.shard_shape, p.device).multi(p.axis)
return UOp.param(p.arg.slot, p.dtype, p.shard_shape, p.device, p.arg.vmin_vmax, p.arg.name, p.arg.addrspace).multi(p.axis)
# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([

View file

@ -1,7 +1,7 @@
from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, ParamArg
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, identity_element
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
@ -130,15 +130,15 @@ def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
if c.arg.precompile: return None
params: list[UOp] = []
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
params = sorted(params, key=lambda x: x.arg)
params = sorted(params, key=lambda x: x.arg.slot)
args = c.src[1:]
# NOTE: this isn't really needed. it's okay if there's unused args in the function
if not allow_param_mismatch:
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
if [x.arg.slot for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg.slot for x in params]}")
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
dict_map = {x:args[x.arg] for x in params}
dict_map = {x:args[x.arg.slot] for x in params}
for i, (p, a) in enumerate(dict_map.items()):
if p.axis != a.axis: raise TypeError(f"arg {i} axis mismatch: expected {p.axis}, got {a.axis}")
if p.max_shape != a.max_shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
@ -477,7 +477,7 @@ class LocalAddBufferContext:
opts:tuple|None = None
def debuf(ctx:LocalAddBufferContext, buf:UOp):
ret = UOp(Ops.PARAM, buf.dtype.ptr(prod(buf.max_shape)), arg=ctx.dg).reshape(buf.max_shape)
ret = UOp(Ops.PARAM, buf.dtype.ptr(prod(buf.max_shape), buf.addrspace), arg=ParamArg(ctx.dg, addrspace=buf.addrspace)).reshape(buf.max_shape)
# if the buffer has symbolic shape, shrink the max-sized view to the actual shape
if buf.max_shape != buf.shape: ret = ret.shrink(tuple((0, s) for s in buf.shape))
if buf not in ctx.map: ctx.map[buf] = buf
@ -512,9 +512,11 @@ def find_bufs(x:UOp):
to_define_global = PatternMatcher([
(UPat(Ops.STORE, name="x"), find_bufs),
(UPat(Ops.BUFFER, name="buf"), debuf),
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="buf"), debuf),
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat.cvar('vmin'), UPat.cvar('vmax'), UPat.var("nm")), name="v"),
lambda v, vmin, vmax, nm: UOp.variable(nm.arg, vmin.arg, vmax.arg, v.dtype)),
(UPat(Ops.PARAM, name="v"), lambda v:
UOp.variable(v.arg.name, v.arg.vmin_vmax[0], v.arg.vmin_vmax[1], v.dtype)
if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
(UPat(Ops.BIND, name="b"), unbind_kernel),

View file

@ -203,7 +203,7 @@ class Tensor(OpMixin):
def as_param(self, slot:int):
if self.uop.axis is not None:
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device).multi(self.uop.axis)
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device, axis=self.uop.axis)
else:
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param)

View file

@ -18,6 +18,19 @@ class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto(); PLACEHOLDER = auto() # noqa: E702
@dataclass(frozen=True, order=True)
class ParamArg:
  """
  Immutable `arg` payload carried by an `Ops.PARAM` UOp.

  Only `slot` is required; every other field has a default and is omitted from the repr when it still holds that default.
  """
  # positional index of the parameter
  slot: int
  # optional (min, max) bounds for the parameter's value
  vmin_vmax: tuple[PyConst, PyConst]|None = None
  # optional name for the parameter
  name: str|None = None
  # address space of the parameter, GLOBAL unless stated otherwise
  addrspace: AddrSpace = AddrSpace.GLOBAL
  # optional axis, None when unset
  axis: int|None = None
  # optional device: a single device string or a tuple of device strings
  device: str|tuple[str, ...]|None = None
  def __repr__(self):
    # field name -> default value; insertion order fixes the order the fields are printed in
    defaults = {"vmin_vmax": None, "name": None, "addrspace": AddrSpace.GLOBAL, "axis": None, "device": None}
    # slot is always shown, positionally
    parts = [str(self.slot)]
    for key, default in defaults.items():
      value = getattr(self, key)
      # only show fields that were changed from their default
      if value != default: parts.append(f"{key}={value!r}")
    return "ParamArg(" + ", ".join(parts) + ")"
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
@ -281,9 +294,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.PARAM:
if isinstance(self.dtype, ImageDType): return self.dtype.shape
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
# NOTE: copied from marg
if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
return None
return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count)) if len(self.src) >= 1 else None
# wmma output shape = accumulator shape (src[2])
case Ops.WMMA | Ops.SHAPED_WMMA: return self.src[2]._shape
@ -305,7 +316,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),) if len(ps) > 0 else ps
return ps
# MULTI marker (axis info in PARAM sources) has no shape
# MULTI marker has no shape
case Ops.MULTI if len(self.src) == 0: return None
# movement ops change the shape
@ -611,11 +622,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.GETTUPLE:
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0]
return in_tuple.src[self.arg].axis if in_tuple.op is Ops.TUPLE else None
# PARAM: axis is stored as a MULTI source
if self.op is Ops.PARAM:
for s in self.src:
if s.op is Ops.MULTI: return s.arg
return None
if self.op is Ops.PARAM: return self.arg.axis
# NOTE: they all have to share an axis, we always choose [-1]
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
if len(self.src) == 0: return None
@ -736,6 +743,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return ret.after(ret.store(src))
@recursive_property
def device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.PARAM: return self.arg.device
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.STAGE: return self.arg.device
if self.op is Ops.AFTER: return self.src[0].device
@ -749,7 +757,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return None
@property
def addrspace(self) -> AddrSpace:
if self.op in {Ops.PARAM, Ops.BUFFER}: return AddrSpace.GLOBAL
if self.op is Ops.PARAM: return self.arg.addrspace
if self.op is Ops.BUFFER: return AddrSpace.GLOBAL
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
if self.op is Ops.STACK:
@ -965,7 +974,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# float has NAN issue and we use explicit NAN in transcendental
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
# NOTE: returned UOp is assumed to be CONST
if self.op is Ops.PARAM and len(self.src) >= 4: return self.src[2].arg, self.src[3].arg
if self.op is Ops.PARAM and self.arg.vmin_vmax is not None: return self.arg.vmin_vmax
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
@ -1010,7 +1019,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@staticmethod
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot)
arg = ParamArg(slot, addrspace=addrspace) if addrspace is AddrSpace.GLOBAL else slot
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=arg)
if len(shape) > 1: ret = ret.reshape(shape)
return ret
def placeholder_like(self, slot:int):
@ -1023,18 +1033,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# TODO: this should replace placeholder
@staticmethod
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name=None):
src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + \
(UOp(Ops.NOOP) if device is None else UOp(Ops.DEVICE, arg=device),)
if vmin_vmax is not None: src += (UOp.const(dtype, vmin_vmax[0]), UOp.const(dtype.scalar(), vmin_vmax[1]))
if name is not None: src += (UOp(Ops.NOOP, arg=name),)
return UOp(Ops.PARAM, dtype, src, arg=slot)
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name=None,
addrspace=AddrSpace.GLOBAL, axis:int|None=None):
if shape is not None and axis is not None and isinstance(device, tuple):
shape = tuple(s*len(device) if i == axis else s for i,s in enumerate(shape))
src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),)
return UOp(Ops.PARAM, dtype, src, arg=ParamArg(slot, vmin_vmax, name, addrspace, axis, device))
def param_like(self, slot:int):
addrspace = self.addrspace if isinstance(self.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL
if self.op is Ops.BIND:
return UOp.param(slot, self.dtype, self._shape, self.device, self._min_max, self.src[0].arg[0])
p = UOp.param(slot, self.dtype, self._shape, self.device)
if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
return p
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].arg[0], addrspace)
return UOp.param(slot, self.dtype, self.shard_shape if self.axis is not None else self._shape, self.device, addrspace=addrspace, axis=self.axis)
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
_OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.SLICE, Ops.CUSTOM_FUNCTION}
@ -1098,10 +1107,10 @@ class ProgramInfo:
local_size: list[int]|None = [1, 1, 1]
for u in sink.toposort():
if u.op is Ops.DEFINE_VAR: _vars.append(u)
if u.op is Ops.PARAM: _globals.append(u.arg)
if u.op is Ops.PARAM: _globals.append(u.arg.slot)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg)
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
if u.op is Ops.SPECIAL:
if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size
@ -1646,7 +1655,7 @@ pm_lower_index_dtype = PatternMatcher([
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
_pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg])])
_pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg.slot])])
remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
def gate_kernel_sink(x:UOp) -> bool:

View file

@ -34,7 +34,7 @@ def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> str:
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr),
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat(), UPat(), UPat(Ops.NOOP, name="x"))), lambda x: x.arg),
(UPat(Ops.PARAM, name="x"), lambda x: x.arg.name if x.arg.name is not None else f"p{x.arg.slot}"),
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
(UPat(Ops.CONST, name="x"), lambda x: str(x.arg)),

View file

@ -1,6 +1,6 @@
import math
from typing import cast, Any
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, KernelInfo
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, KernelInfo, ParamArg
from tinygrad.uop.render import print_uops, pyrender
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
@ -71,7 +71,8 @@ spec_shared = PatternMatcher([
(UPat(Ops.END, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(u.op is Ops.RANGE for u in x.src[1:])),
# PARAM (that's really a DEFINE_GLOBAL)
(UPat(Ops.PARAM, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
(UPat(Ops.PARAM, name="x"), lambda x: isinstance(x.arg, ParamArg) and isinstance(x.dtype, (PtrDType, ImageDType)) and
x.addrspace == x.dtype.addrspace),
# GROUP of stores (or groups, or NOOPs)
# TODO: remove UNROLL here, it's for SPEC=2
@ -130,9 +131,6 @@ spec_tensor = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
# PARAM (that's really a variable)
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat(), UPat(), UPat()), name="x"), lambda x: True),
# Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
@ -148,9 +146,7 @@ spec_tensor = PatternMatcher([
(UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
# PARAM
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.NOOP)), name="x"), lambda x: True), # TODO: why does this have NOOP?
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="x"), lambda x: True),
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.MULTI)), name="x"), lambda x: True),
(UPat(Ops.PARAM, src=(UPat(),), name="x"), lambda x: isinstance(x.arg, ParamArg)),
# inputs to movement ops
(UPat(Ops.STACK), lambda: True),
@ -264,7 +260,7 @@ from tinygrad.schedule.rangeify import BufferizeOpts
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Metadata": Metadata,
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid,
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic,
"ConstFloat": ConstFloat}
"ConstFloat": ConstFloat, "ParamArg": ParamArg}
def eval_pyrender(code:str) -> UOp:
lcls:dict[str, Any] = {}
exec(code, glbls, lcls)