mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into shrink_in_render
This commit is contained in:
commit
7da2c151be
40 changed files with 260 additions and 263 deletions
101
.github/workflows/test.yml
vendored
101
.github/workflows/test.yml
vendored
|
|
@ -585,8 +585,19 @@ jobs:
|
|||
- name: Test quantize onnx
|
||||
run: DEBUG=2 DEV=DSP python3 test/backend/test_quantize_onnx.py
|
||||
|
||||
testwebgpu:
|
||||
name: Linux (WebGPU)
|
||||
testlinux:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
dev:
|
||||
- 'CPU:CLANGJIT'
|
||||
- 'CPU:LLVM'
|
||||
- 'CPU:LVP'
|
||||
- 'CPU:X86'
|
||||
- 'CL'
|
||||
- 'WEBGPU'
|
||||
|
||||
name: Linux (DEV=${{ matrix.dev }})
|
||||
runs-on: *linux
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
|
|
@ -595,17 +606,21 @@ jobs:
|
|||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: webgpu-minimal
|
||||
key: linux-${{ matrix.dev }}
|
||||
deps: testing_unit
|
||||
python-version: '3.12'
|
||||
webgpu: 'true'
|
||||
- name: Check Device.DEFAULT (WEBGPU) and print some source
|
||||
llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') || contains(matrix.dev, 'CLANGJIT') }}
|
||||
mesa: ${{ contains(matrix.dev, 'LVP') && 'cpu' || 'false' }}
|
||||
webgpu: ${{ matrix.dev == 'WEBGPU' }}
|
||||
opencl: ${{ matrix.dev == 'CL' }}
|
||||
- name: Set env
|
||||
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANGJIT' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
DEV=WEBGPU python -c "from tinygrad import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
|
||||
DEV=WEBGPU DEBUG=4 FORWARD_ONLY=1 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run selected webgpu tests
|
||||
run: |
|
||||
DEV=WEBGPU WEBGPU_BACKEND="WGPUBackendType_Vulkan" python3 -m pytest -n=auto test/backend --durations=20
|
||||
python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device"
|
||||
DEBUG=4 python test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run backend tests
|
||||
run: python -m pytest -n=auto test/backend --durations=20
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
|
@ -757,39 +772,6 @@ jobs:
|
|||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testcpuopencl:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [llvm, cpu, opencl, lvp, x86]
|
||||
|
||||
name: Linux (${{ matrix.backend }})
|
||||
runs-on: *linux
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: ${{ matrix.backend }}-minimal
|
||||
deps: testing_unit
|
||||
opencl: ${{ matrix.backend == 'opencl' && 'true' }}
|
||||
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'cpu' || matrix.backend == 'lvp' }}
|
||||
mesa: ${{ matrix.backend == 'lvp' && 'cpu' }}
|
||||
- name: Set env
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'DEV=CL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' || matrix.backend == 'x86' && 'DEV=CPU:X86' }}" >> $GITHUB_ENV
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','CL'], Device.DEFAULT"
|
||||
DEBUG=5 FORWARD_ONLY=1 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run pytest (${{ matrix.backend }})
|
||||
run: python -m pytest -n=auto test/backend --durations=20
|
||||
- name: Run TRANSCENDENTAL math
|
||||
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/backend/test_ops.py::TestOps::test_sin test/backend/test_ops.py::TestOps::test_cos test/backend/test_ops.py::TestOps::test_tan test/backend/test_ops.py::TestOps::test_exp test/backend/test_ops.py::TestOps::test_log --durations=20
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
# ****** OSX Tests ******
|
||||
|
||||
testmetal:
|
||||
|
|
@ -920,13 +902,17 @@ jobs:
|
|||
|
||||
# ****** Windows Tests ******
|
||||
|
||||
wintests:
|
||||
testwindows:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [llvm, cpu, webgpu, x86]
|
||||
dev:
|
||||
- 'CPU:CLANGJIT'
|
||||
- 'CPU:LLVM'
|
||||
- 'CPU:X86'
|
||||
- 'WEBGPU'
|
||||
|
||||
name: Windows (${{ matrix.backend }})
|
||||
name: Windows (DEV=${{ matrix.dev }})
|
||||
runs-on: windows-latest
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
|
|
@ -935,25 +921,20 @@ jobs:
|
|||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: windows-${{ matrix.backend }}-minimal
|
||||
key: windows-${{ matrix.dev }}-minimal
|
||||
deps: testing_unit
|
||||
pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }}
|
||||
pydeps: ${{ matrix.dev == 'WEBGPU' && 'dawn-python' || '' }}
|
||||
- name: Set env
|
||||
shell: bash
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'DEV=WEBGPU' || matrix.backend == 'x86' && 'DEV=CPU:X86' }}" >> $GITHUB_ENV
|
||||
- name: Run unit tests
|
||||
if: matrix.backend=='llvm'
|
||||
# test_newton_schulz hits RecursionError
|
||||
run: python -m pytest -n=auto test/unit/ --ignore=test/unit/test_disk_tensor.py --ignore=test/unit/test_tar.py --ignore=test/unit/test_linalg.py --durations=20
|
||||
- name: Run NULL backend tests
|
||||
if: matrix.backend=='llvm'
|
||||
shell: bash
|
||||
run: DEV=NULL python -m pytest -n=auto test/null/ --ignore=test/null/test_elf.py --durations=20
|
||||
- name: Run pytest (${{ matrix.backend }})
|
||||
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANGJIT' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
shell: bash
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU', 'X86':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
|
||||
python -m pytest -n=auto test/test_tiny.py test/backend/test_ops.py --durations=20
|
||||
python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device"
|
||||
DEBUG=4 python test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run test_tiny
|
||||
shell: bash
|
||||
run: python -m pytest -n=auto test/test_tiny.py --durations=20
|
||||
|
||||
# ****** Compile-only Tests ******
|
||||
|
||||
|
|
|
|||
|
|
@ -1438,13 +1438,14 @@ def train_llama3():
|
|||
|
||||
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
|
||||
fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts] if hasattr(model, "_fp8_grad_amax") else []
|
||||
fp8_inv_scales = list(model._fp8_inv_scale.values())
|
||||
fp8_inv_scales = list(model._fp8_inv_scale.values()) + list(model._fp8_next_inv_scale.values())
|
||||
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
model_state = get_state_dict(model)
|
||||
for wname in model._fp8_inv_scale:
|
||||
w = model_state[wname]
|
||||
w._inv_scale = model._fp8_inv_scale[wname]
|
||||
w._next_inv_scale = model._fp8_next_inv_scale[wname]
|
||||
if optim.master_params:
|
||||
idx = next(j for j, p in enumerate(optim.params) if p is w)
|
||||
master = optim.master_params[idx]
|
||||
|
|
|
|||
|
|
@ -136,6 +136,7 @@ class FlatTransformer:
|
|||
w_scales = [("wqkv", s_qkv), ("wo", s_o), ("w2", s_2)]
|
||||
w_scales += [("w1", s_1), ("w3", s_3)] if SPLIT_W13 else [("w13", s_13)]
|
||||
self._fp8_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
|
||||
self._fp8_next_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
|
||||
|
||||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
|
|
@ -227,7 +228,7 @@ class FlatTransformer:
|
|||
self.w1.shard_(device, axis=1).realize()
|
||||
self.w3.shard_(device, axis=1).realize()
|
||||
else:
|
||||
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
|
||||
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
|
||||
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
|
||||
self.attention_norm.shard_(device, axis=None).realize()
|
||||
self.ffn_norm.shard_(device, axis=None).realize()
|
||||
|
|
@ -241,6 +242,8 @@ class FlatTransformer:
|
|||
amax_dict[name][i] = amax_dict[name][i].to(device).contiguous().is_param_(False)
|
||||
for name in self._fp8_inv_scale:
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
for name in self._fp8_next_inv_scale:
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
|
||||
def __call__(self, tokens:Tensor, save:bool=True):
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,8 @@ class GradAccClipAdamW(Optimizer):
|
|||
for i, tt in enumerate(self.params): tt.assign(self._apply_update(tt, updates[i], self.master_params[i] if self.master_params else None))
|
||||
# collect inv_scale tensors attached to fp8 params (set by _apply_update)
|
||||
fp8_inv_scales = [tt._inv_scale for tt in self.params if hasattr(tt, '_inv_scale')]
|
||||
to_realize = extra+self.params+self.buffers+(self.master_params or [])+fp8_inv_scales
|
||||
fp8_next_inv_scales = [tt._next_inv_scale for tt in self.params if hasattr(tt, '_next_inv_scale')]
|
||||
to_realize = extra+self.params+self.buffers+(self.master_params or [])+fp8_inv_scales+fp8_next_inv_scales
|
||||
|
||||
Tensor.realize(*to_realize)
|
||||
return extra[-1]
|
||||
|
|
@ -89,13 +90,14 @@ class GradAccClipAdamW(Optimizer):
|
|||
if t.dtype in dtypes.fp8s:
|
||||
from examples.mlperf.models.flat_llama import FP8_MAX
|
||||
# delayed scaling: reuse previous step's inv_scale
|
||||
t._inv_scale.assign(t._next_inv_scale)
|
||||
scale = t._inv_scale.reciprocal().reshape(-1, *([1]*(new_w.ndim-1)))
|
||||
scaled = (new_w * scale).clamp(-FP8_MAX, FP8_MAX)
|
||||
ret = scaled.cast(t.dtype)
|
||||
# update inv_scale for next step from quantized result
|
||||
new_amax = (ret.float().abs().max(axis=tuple(range(1, ret.ndim))) * t._inv_scale).detach()
|
||||
inv = ((new_amax + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype)
|
||||
t._inv_scale.assign(inv.shard_like(t._inv_scale) if offloaded else inv)
|
||||
new_inv = ((new_amax + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype)
|
||||
t._next_inv_scale.assign(new_inv.shard_like(t._next_inv_scale) if offloaded else new_inv)
|
||||
return ret.shard_like(t) if offloaded else ret
|
||||
out = new_w.cast(t.dtype)
|
||||
return out.shard_like(t) if offloaded else out
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
|
|||
|
||||
def name_of(bu:UOp, is_out:bool) -> str:
|
||||
nonlocal n
|
||||
if bu.op is Ops.PARAM: key, name, size = ("in", bu.arg), f"input{bu.arg}", prod(bu.shape)*bu.dtype.itemsize
|
||||
if bu.op is Ops.PARAM: key, name, size = ("in", bu.arg.slot), f"input{bu.arg.slot}", prod(bu.shape)*bu.dtype.itemsize
|
||||
else:
|
||||
b = bu.buffer
|
||||
key, size = (id(b.base), b.offset, b.size, b.dtype), b.size*b.dtype.itemsize
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
|
||||
def test_global_atomic_add_f32_parsing(self):
|
||||
"""Test GLOBAL_ATOMIC_ADD_F32 keeps memory values in float dtype."""
|
||||
vmem = UOp(Ops.PARAM, dtypes.uint32.ptr(1024), arg=2)
|
||||
vmem = UOp.param(2, dtypes.uint32.ptr(1024))
|
||||
srcs = {
|
||||
'ADDR': UOp.const(dtypes.uint64, 0),
|
||||
'DATA': UOp.const(dtypes.uint32, 0x3f800000),
|
||||
|
|
@ -198,7 +198,7 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
def test_mem_read_parsing(self):
|
||||
"""Test MEM[addr].type read expression parsing."""
|
||||
# Create a mock LDS buffer
|
||||
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
||||
lds = UOp.param(3, dtypes.uint32.ptr(16384))
|
||||
addr = UOp.const(dtypes.uint32, 0)
|
||||
vrs = {'_lds': lds, 'ADDR': addr, 'OFFSET': UOp.const(dtypes.uint32, 0)}
|
||||
|
||||
|
|
@ -233,7 +233,7 @@ class TestDSPcodePatterns(unittest.TestCase):
|
|||
pcode = PCODE.get(DSOp.DS_LOAD_2ADDR_B32)
|
||||
self.assertIsNotNone(pcode)
|
||||
assert pcode is not None
|
||||
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
||||
lds = UOp.param(3, dtypes.uint32.ptr(16384))
|
||||
srcs = {
|
||||
'ADDR': UOp.const(dtypes.uint32, 0),
|
||||
'OFFSET0': UOp.const(dtypes.uint32, 0),
|
||||
|
|
@ -314,7 +314,7 @@ class TestConcatWidthParsing(unittest.TestCase):
|
|||
self.assertEqual(parsed.simplify().arg, expected)
|
||||
|
||||
def test_permlane64_wave64_pcode_indices(self):
|
||||
vgpr = UOp(Ops.PARAM, dtypes.uint32.ptr(256), arg=0)
|
||||
vgpr = UOp.param(0, dtypes.uint32.ptr(256))
|
||||
srcs = {
|
||||
'SRC0': UOp.const(dtypes.uint32, 0),
|
||||
'VDST': UOp.const(dtypes.uint32, 1),
|
||||
|
|
@ -347,7 +347,7 @@ class TestAllPcode(unittest.TestCase):
|
|||
def _make_srcs(self):
|
||||
"""Create dummy source variables for pcode parsing."""
|
||||
u32, u64 = lambda v=0: UOp.const(dtypes.uint32, v), lambda v=0: UOp.const(dtypes.uint64, v)
|
||||
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
||||
lds = UOp.param(3, dtypes.uint32.ptr(16384))
|
||||
return {'laneId': u32(), 'laneID': u32(), 'S0': u32(), 'S1': u32(), 'S2': u32(), 'S3': u32(), 'SRC0': u32(),
|
||||
'D0': u32(), 'D1': u32(), 'DST': u32(), 'VDST': u32(), 'SDST': u32(),
|
||||
'VCC': u64(), 'VCCZ': u32(), 'EXEC': u64(), 'EXEC_LO': u32(), 'EXECZ': u32(), 'SCC': u32(),
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class TestIselX86(unittest.TestCase):
|
|||
# need to move src from gpr to xmm before broadcasting
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and n.src[0].arg is X86Ops.VMOVD)
|
||||
# if we can fuse a load we can skip the move and access memory directly
|
||||
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
n = self.isel_rewrite(load.broadcast(4))
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
|
||||
|
||||
|
|
@ -122,20 +122,20 @@ class TestIselX86(unittest.TestCase):
|
|||
# complex address is [base + index*scale + displacement]
|
||||
def test_complex_address(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(a + 1, ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(a + 1, ptr=True).load()
|
||||
n = self.isel_rewrite(load)
|
||||
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
|
||||
self.assertTrue(n.src[2].op is Ops.CONST and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
|
||||
|
||||
def test_fold_load(self):
|
||||
load1 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load2 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load()
|
||||
load1 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load2 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 1), ptr=True).load()
|
||||
n = self.isel_rewrite(load1 + load2)
|
||||
self.assertTrue(len(n.src) == 4)
|
||||
|
||||
# don't fold when used multiple times
|
||||
def test_dont_fold_load(self):
|
||||
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
# used by multiple users
|
||||
n = self.isel_rewrite(load + 1 + load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
|
|
@ -144,4 +144,4 @@ class TestIselX86(unittest.TestCase):
|
|||
self.assertTrue(len(n.src) == 2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -11,16 +11,16 @@ from tinygrad.codegen import to_program
|
|||
class TestLinearizerFailure(unittest.TestCase):
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
|
||||
def test_failure_beam_mnist(self):
|
||||
c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(4014080), arg=0, src=())
|
||||
c0 = UOp.param(0, dtypes.uchar.ptr(4014080))
|
||||
c1 = UOp.range(UOp.const(dtypes.weakint, 512), 0, AxisType.GLOBAL)
|
||||
c2 = UOp.range(UOp.const(dtypes.weakint, 784), 1, AxisType.GLOBAL)
|
||||
c3 = UOp.range(UOp.const(dtypes.weakint, 10), 3, AxisType.GLOBAL)
|
||||
c4 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
|
||||
c4 = UOp.param(1, dtypes.int.ptr(512))
|
||||
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True)))
|
||||
c6 = UOp.range(UOp.const(dtypes.weakint, 6000), 1004, AxisType.REDUCE)
|
||||
c7 = UOp.range(UOp.const(dtypes.weakint, 3750), 2006, AxisType.REDUCE)
|
||||
c8 = UOp.range(UOp.const(dtypes.weakint, 16), 2007, AxisType.GROUP_REDUCE)
|
||||
c9 = UOp(Ops.PARAM, dtypes.uchar.ptr(47040000), arg=2, src=())
|
||||
c9 = UOp.param(2, dtypes.uchar.ptr(47040000))
|
||||
c10 = c9.index((((c3*UOp.const(dtypes.weakint, 4704000))+c2)+(c6*UOp.const(dtypes.weakint, 784))).valid(UOp.const(dtypes.bool, True)))
|
||||
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.weakint, 6000))+c6)+((c7*UOp.const(dtypes.weakint, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.weakint, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
|
||||
c12 = c0.index((((c1*UOp.const(dtypes.weakint, 7840))+(c2*UOp.const(dtypes.weakint, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ def _test_uop_result(inputs:list[Tensor], sink:UOp, local_size=None):
|
|||
|
||||
def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
|
||||
dtype = alu_src_uops[0].dtype
|
||||
a = UOp(Ops.PARAM, dtype.ptr(), (), 0)
|
||||
b = UOp(Ops.PARAM, dtype.ptr(), (), 1)
|
||||
a = UOp.param(0, dtype.ptr())
|
||||
b = UOp.param(1, dtype.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = b.index(idx)
|
||||
alu = ld.alu(alu_op, *alu_src_uops)
|
||||
|
|
@ -33,7 +33,7 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
|
|||
class TestRendererFailures(unittest.TestCase):
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
|
||||
def test_gated_store_with_alu(self):
|
||||
a = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
a = UOp.param(0, dtypes.int.ptr())
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
|
||||
|
|
@ -42,7 +42,7 @@ class TestRendererFailures(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
|
||||
def test_gated_store_with_alu_2d(self):
|
||||
a = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
a = UOp.param(0, dtypes.int.ptr())
|
||||
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
|
||||
|
|
@ -87,7 +87,7 @@ class TestWGSLFailures(unittest.TestCase):
|
|||
class TestPTXFailures(unittest.TestCase):
|
||||
@unittest.skip("INDEX can only have a gate ALU parent, not an IF")
|
||||
def test_gated_store_with_if(self):
|
||||
a = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
a = UOp.param(0, dtypes.int.ptr())
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
val = UOp.const(dtypes.int, 1)
|
||||
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ def run_uops(uops_list:list[UOp], bufs:list[Buffer]):
|
|||
|
||||
def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
if op is Ops.CONST: uops.append(UOp.const(dtype, arg))
|
||||
elif op is Ops.PARAM: uops.append(UOp.param(arg, dtype).replace(src=()))
|
||||
else: uops.append(UOp(op, dtype, tuple(src), arg))
|
||||
return uops[-1]
|
||||
|
||||
|
|
@ -220,8 +221,8 @@ class TestLocalAccess(unittest.TestCase):
|
|||
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends")
|
||||
class TestAssembly(unittest.TestCase):
|
||||
def test_bitshift_left(self):
|
||||
g1 = UOp(Ops.PARAM, dtypes.int32.ptr(), (), 0)
|
||||
out = UOp(Ops.PARAM, dtypes.int32.ptr(), (), 1)
|
||||
g1 = UOp.param(0, dtypes.int32.ptr())
|
||||
out = UOp.param(1, dtypes.int32.ptr())
|
||||
c1 = UOp.const(dtypes.int, 2)
|
||||
c2 = UOp.const(dtypes.int, 3)
|
||||
l1 = g1.index(c1)
|
||||
|
|
@ -248,7 +249,7 @@ class TestAssembly(unittest.TestCase):
|
|||
self.assertGreaterEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4)
|
||||
|
||||
def test_mulacc_shl(self):
|
||||
g1 = UOp(Ops.PARAM, dtypes.int32.ptr(), (), 0)
|
||||
g1 = UOp.param(0, dtypes.int32.ptr())
|
||||
c1 = UOp.const(dtypes.int, 0)
|
||||
c2 = UOp.const(dtypes.int, 1)
|
||||
expr = g1.index(c1) * UOp.const(dtypes.int, 4096) + g1.index(c2)
|
||||
|
|
@ -257,7 +258,7 @@ class TestAssembly(unittest.TestCase):
|
|||
self.assertIn(Ops.MULACC, [x.op for x in uops])
|
||||
|
||||
def test_use_cmpeq(self):
|
||||
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
c = UOp.const(dtypes.uint, 7)
|
||||
comp = g.index(c).ne(c).ne(True)
|
||||
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
|
||||
|
|
|
|||
26
test/external/external_benchmark_op_conv.py
vendored
26
test/external/external_benchmark_op_conv.py
vendored
|
|
@ -12,7 +12,7 @@ from tinygrad.dtype import ImageDType, Invalid
|
|||
# PYTHONPATH="." DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
|
||||
|
||||
def vision_conv_143():
|
||||
c0 = UOp(Ops.PARAM, dtypes.imageh((16, 1024, 4)), (), 0)
|
||||
c0 = UOp.param(0, dtypes.imageh((16, 1024, 4)))
|
||||
c2 = UOp.range(32, 3, AxisType.LOOP)
|
||||
c5 = UOp.range(128, 4, AxisType.LOOP)
|
||||
c8 = UOp.range(16, 2, AxisType.LOOP)
|
||||
|
|
@ -22,13 +22,13 @@ def vision_conv_143():
|
|||
c26 = UOp.range(7, 1, AxisType.REDUCE)
|
||||
c27 = c2*2+c26
|
||||
c32 = ((c27<3)!=True)&(c27<67)
|
||||
c34 = UOp(Ops.PARAM, dtypes.imageh((32, 1024, 4)), (), 1)
|
||||
c34 = UOp.param(1, dtypes.imageh((32, 1024, 4)))
|
||||
c38 = c5//2
|
||||
c45 = (c32&c24).where((c27*64+c38+c17*4096+-12480), UOp.const(dtypes.weakint, Invalid))
|
||||
c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
|
||||
c49 = UOp(Ops.PARAM, dtypes.imageh((64, 49, 4)), (), 2)
|
||||
c49 = UOp.param(2, dtypes.imageh((64, 49, 4)))
|
||||
c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
|
||||
c63 = UOp(Ops.PARAM, dtypes.float.ptr(128), (), 3)
|
||||
c63 = UOp.param(3, dtypes.float.ptr(128))
|
||||
c65 = c61.reduce(c16, c26, arg=Ops.ADD)+c63.index(c5)
|
||||
c67 = c0.index((c2*128+c5+c8*4096), ptr=True).store(c65).end(c8, c2, c5)
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ def vision_conv_143():
|
|||
return c67.sink(arg=KernelInfo(name="conv", opts_to_apply=opts))
|
||||
|
||||
def vision_conv_153():
|
||||
c0 = UOp(Ops.PARAM, dtypes.imageh((8, 1024, 4)), (), 0)
|
||||
c0 = UOp.param(0, dtypes.imageh((8, 1024, 4)))
|
||||
c2 = UOp.range(16, 3, AxisType.LOOP)
|
||||
c5 = UOp.range(256, 4, AxisType.LOOP)
|
||||
c8 = UOp.range(8, 2, AxisType.LOOP)
|
||||
|
|
@ -48,13 +48,13 @@ def vision_conv_153():
|
|||
c26 = UOp.range(7, 1, AxisType.REDUCE)
|
||||
c27 = c2*2+c26
|
||||
c32 = ((c27<3)!=True)&(c27<35)
|
||||
c34 = UOp(Ops.PARAM, dtypes.imageh((16, 1024, 4)), (), 1)
|
||||
c34 = UOp.param(1, dtypes.imageh((16, 1024, 4)))
|
||||
c38 = c5//2
|
||||
c45 = (c32&c24).where((c27*128+c38+c17*4096+-12672), UOp.const(dtypes.weakint, Invalid))
|
||||
c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
|
||||
c49 = UOp(Ops.PARAM, dtypes.imageh((128, 49, 4)), (), 2)
|
||||
c49 = UOp.param(2, dtypes.imageh((128, 49, 4)))
|
||||
c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
|
||||
c63 = UOp(Ops.PARAM, dtypes.float.ptr(256), (), 3)
|
||||
c63 = UOp.param(3, dtypes.float.ptr(256))
|
||||
c65 = c61.reduce(c16, c26, arg=Ops.ADD)+c63.index(c5)
|
||||
c67 = c0.index((c2*256+c5+c8*4096), ptr=True).store(c65).end(c8, c2, c5)
|
||||
|
||||
|
|
@ -64,16 +64,16 @@ def vision_conv_153():
|
|||
return c67.sink(arg=KernelInfo(name="conv", opts_to_apply=opts))
|
||||
|
||||
def dm_conv_172():
|
||||
c0 = UOp(Ops.PARAM, dtypes.imageh((1, 240, 4)), (), 0)
|
||||
c0 = UOp.param(0, dtypes.imageh((1, 240, 4)))
|
||||
c2 = UOp.range(960, 4, AxisType.LOOP)
|
||||
c5 = UOp(Ops.PARAM, dtypes.imageh((8, 384, 4)), (), 1)
|
||||
c5 = UOp.param(1, dtypes.imageh((8, 384, 4)))
|
||||
c7 = UOp.range(32, 0, AxisType.REDUCE)
|
||||
c10 = UOp.range(4, 1, AxisType.REDUCE)
|
||||
c13 = UOp.range(12, 3, AxisType.REDUCE)
|
||||
c18 = UOp.range(8, 2, AxisType.REDUCE)
|
||||
c23 = UOp(Ops.PARAM, dtypes.imageh((240, 128, 4)), (), 2)
|
||||
c23 = UOp.param(2, dtypes.imageh((240, 128, 4)))
|
||||
c35 = c5.index((c7*4+c10+c13*128+c18*1536))*c23.index((c10*4+c2%4+c7*16+c2//4*512))
|
||||
c37 = UOp(Ops.PARAM, dtypes.float.ptr(960), (), 3)
|
||||
c37 = UOp.param(3, dtypes.float.ptr(960))
|
||||
c39 = c35.reduce(c7, c10, arg=Ops.ADD)+c37.index(c2)
|
||||
c50 = (1.0+((c39+0.044708251953125*(c39*(c39*c39)))*-2.3021129851685216).exp2()).reciprocal()*c39
|
||||
c53 = c50.reduce(c18, c13, arg=Ops.ADD)*0.010416666666666666
|
||||
|
|
@ -99,4 +99,4 @@ bufs = [Buffer(ps.arg.device, g.size, g.dtype if isinstance(g.dtype, ImageDType)
|
|||
|
||||
gsize, lsize = ps.arg.launch_dims({})
|
||||
t = rt(*[b._buf for b in bufs], global_size=gsize, local_size=lsize, vals=ps.arg.vals({}), wait=True)
|
||||
print(f"{t*1e6:.2f} us")
|
||||
print(f"{t*1e6:.2f} us")
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None, vals:tuple
|
|||
for buf_dt, data in inputs or []:
|
||||
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
|
||||
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
|
||||
g = UOp(Ops.PARAM, uop.dtype.ptr(), arg=0, src=())
|
||||
g = UOp.param(0, uop.dtype.ptr())
|
||||
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
|
||||
prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg))
|
||||
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs, vals=vals)
|
||||
|
|
|
|||
|
|
@ -423,10 +423,10 @@ def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode
|
|||
class _Ctx:
|
||||
"""Context for instruction compilation - holds buffers and helpers."""
|
||||
__slots__ = ('inst_size', 'dyn_fields', '_axis_id', 'wave_size', 'vgpr', 'accvgpr')
|
||||
sgpr = UOp(Ops.PARAM, dtypes.uint32.ptr(SGPR_COUNT), arg=0)
|
||||
vmem = UOp(Ops.PARAM, dtypes.uint32.ptr(1 << 46), arg=2)
|
||||
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
||||
scratch = UOp(Ops.PARAM, dtypes.uint8.ptr(1 << 30), arg=4)
|
||||
sgpr = UOp.param(0, dtypes.uint32.ptr(SGPR_COUNT))
|
||||
vmem = UOp.param(2, dtypes.uint32.ptr(1 << 46))
|
||||
lds = UOp.param(3, dtypes.uint32.ptr(16384))
|
||||
scratch = UOp.param(4, dtypes.uint8.ptr(1 << 30))
|
||||
# Cache PARAM UOps by wave_size so all _Ctx instances with same wave_size share identical UOp references
|
||||
_vgpr_cache: dict[int, UOp] = {}
|
||||
_accvgpr_cache: dict[int, UOp] = {}
|
||||
|
|
@ -434,10 +434,10 @@ class _Ctx:
|
|||
def __init__(self, inst_size: int, wave_size: int = 32):
|
||||
self.inst_size, self._axis_id, self.wave_size = inst_size, 0, wave_size
|
||||
self.dyn_fields: list[tuple[int, int]] = [] # (lo, hi) of fields read dynamically
|
||||
if wave_size not in _Ctx._vgpr_cache: _Ctx._vgpr_cache[wave_size] = UOp(Ops.PARAM, dtypes.uint32.ptr(256 * wave_size), arg=1)
|
||||
if wave_size not in _Ctx._vgpr_cache: _Ctx._vgpr_cache[wave_size] = UOp.param(1, dtypes.uint32.ptr(256 * wave_size))
|
||||
self.vgpr = _Ctx._vgpr_cache[wave_size]
|
||||
if wave_size == 64:
|
||||
if wave_size not in _Ctx._accvgpr_cache: _Ctx._accvgpr_cache[wave_size] = UOp(Ops.PARAM, dtypes.uint32.ptr(256 * wave_size), arg=5)
|
||||
if wave_size not in _Ctx._accvgpr_cache: _Ctx._accvgpr_cache[wave_size] = UOp.param(5, dtypes.uint32.ptr(256 * wave_size))
|
||||
self.accvgpr = _Ctx._accvgpr_cache[wave_size]
|
||||
else:
|
||||
self.accvgpr = self.vgpr
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class TestGroupedDims(unittest.TestCase):
|
|||
|
||||
def test_global_prod_max(self):
|
||||
g, l = UOp.range(256, 0, AxisType.GLOBAL), UOp.range(256, 1, AxisType.LOCAL)
|
||||
sink = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0).index(g + l).store(UOp.const(dtypes.float, 1.0)).end(g, l).sink(arg=KernelInfo())
|
||||
sink = UOp.param(0, dtypes.float.ptr()).index(g + l).store(UOp.const(dtypes.float, 1.0)).end(g, l).sink(arg=KernelInfo())
|
||||
class R(Renderer): global_max, local_max, global_prod_max = (256, 256, 256), (128, 128, 128), (128, 128, 128)
|
||||
specials = [u for u in add_gpudims(R(Target()), sink).toposort() if u.op is Ops.SPECIAL]
|
||||
self.assertGreater(len([s for s in specials if "lidx" in s.arg]), 1)
|
||||
|
|
|
|||
|
|
@ -7,14 +7,14 @@ from tinygrad.codegen import to_program
|
|||
|
||||
class TestLinearizerFailures(unittest.TestCase):
|
||||
def test_fail_1(self):
|
||||
c0 = UOp(Ops.PARAM, dtypes.float.ptr(64), arg=0, src=())
|
||||
c0 = UOp.param(0, dtypes.float.ptr(64))
|
||||
c1 = UOp.range(UOp.const(dtypes.weakint, 2), 1, AxisType.LOOP)
|
||||
c2 = UOp.range(UOp.const(dtypes.weakint, 32), 2, AxisType.LOOP)
|
||||
c3 = ((c1*UOp.const(dtypes.weakint, 32))+c2)
|
||||
c4 = UOp(Ops.PARAM, dtypes.float.ptr(163840), arg=1, src=())
|
||||
c4 = UOp.param(1, dtypes.float.ptr(163840))
|
||||
c5 = UOp.range(UOp.const(dtypes.weakint, 2560), 0, AxisType.REDUCE)
|
||||
c6 = c4.index(((((((c5//UOp.const(dtypes.weakint, 8))%UOp.const(dtypes.weakint, 8))*UOp.const(dtypes.weakint, 8))+(c5%UOp.const(dtypes.weakint, 8)))+(((c2*UOp.const(dtypes.weakint, 40))+(c5//UOp.const(dtypes.weakint, 64)))*UOp.const(dtypes.weakint, 64)))+(c1*UOp.const(dtypes.weakint, 81920))))
|
||||
c7 = UOp(Ops.PARAM, dtypes.float.ptr(64), arg=2, src=())
|
||||
c7 = UOp.param(2, dtypes.float.ptr(64))
|
||||
c8 = c7.index(c3)
|
||||
c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
|
||||
c10 = c0.index(c3).store(c9).end(c1, c2)
|
||||
|
|
|
|||
|
|
@ -15,13 +15,13 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
|
|||
|
||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||
return UOp(Ops.LOAD, dtypes.float, (
|
||||
UOp(Ops.PARAM, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True),
|
||||
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
|
||||
UOp.const(dtypes.float, 0.0)
|
||||
))
|
||||
|
||||
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
||||
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
||||
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
|
||||
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
|
||||
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
|
||||
))
|
||||
|
||||
|
|
@ -513,7 +513,7 @@ class TestDropTrueGate(unittest.TestCase):
|
|||
from tinygrad.codegen.late.devectorizer import load_store_indexing
|
||||
from tinygrad.uop.ops import graph_rewrite
|
||||
from tinygrad.uop.symbolic import sym
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
|
||||
buf = UOp.param(0, dtypes.int.ptr())
|
||||
idx = UOp.const(dtypes.weakint, 0)
|
||||
true_gate = UOp.const(dtypes.bool, True)
|
||||
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx.valid(true_gate)))
|
||||
|
|
@ -557,7 +557,7 @@ class TestRangeShrink(unittest.TestCase):
|
|||
# one load guards r < 4, but another load uses r without a gate -> no shrink
|
||||
r = Range(0, 204)
|
||||
load1 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
|
||||
load2 = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.PARAM, dtypes.float.ptr(), arg=1).index(r, ptr=True),))
|
||||
load2 = UOp(Ops.LOAD, dtypes.float, (UOp.param(1, dtypes.float.ptr()).index(r, ptr=True),))
|
||||
ranges = self.get_ranges(UOp.sink(load1, load2))
|
||||
self.assertEqual(len(ranges), 1)
|
||||
self.assertEqual(ranges[0].src[0].arg, 204)
|
||||
|
|
@ -583,7 +583,7 @@ class TestRangeShrink(unittest.TestCase):
|
|||
from tinygrad.dtype import Invalid
|
||||
r = Range(0, 204)
|
||||
x = (r < 4).where(UOp.const(dtypes.float, 1), Invalid)
|
||||
ranges = self.get_ranges(UOp(Ops.PARAM, dtypes.float.ptr(), arg=0).index(r).store((r < 4).where(x, 0)).sink())
|
||||
ranges = self.get_ranges(UOp.param(0, dtypes.float.ptr()).index(r).store((r < 4).where(x, 0)).sink())
|
||||
self.assertEqual(len(ranges), 1)
|
||||
self.assertEqual(ranges[0].src[0].arg, 4)
|
||||
|
||||
|
|
@ -592,7 +592,7 @@ class TestRangeShrink(unittest.TestCase):
|
|||
from tinygrad.dtype import Invalid
|
||||
r = Range(0, 204)
|
||||
x = (r < 4).where(UOp.const(dtypes.float, 1), Invalid)
|
||||
ranges = self.get_ranges(UOp(Ops.PARAM, dtypes.float.ptr(), arg=0).index(r).store((r < 4).where(0, x)).sink())
|
||||
ranges = self.get_ranges(UOp.param(0, dtypes.float.ptr()).index(r).store((r < 4).where(0, x)).sink())
|
||||
self.assertEqual(len(ranges), 1)
|
||||
self.assertEqual(ranges[0].src[0].arg, 4)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.uop.decompositions import TRANSCENDENTAL_DTYPES, payne_hanek_reduction, cody_waite_reduction
|
||||
from tinygrad.uop.decompositions import frexp, rintk, xpow, xexp2, xlog2, trig_poly, pow2if
|
||||
from test.helpers import eval_uop
|
||||
|
|
@ -10,7 +10,7 @@ class TestTranscendentalFunctions(unittest.TestCase):
|
|||
def test_payne_hanek_reduction(self):
|
||||
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
|
||||
# Load input value from a buffer to prevent constant folding
|
||||
input_buf = UOp(Ops.PARAM, dtypes.double.ptr(), arg=1, src=())
|
||||
input_buf = UOp.param(1, dtypes.double.ptr())
|
||||
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
|
||||
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
|
||||
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))
|
||||
|
|
|
|||
|
|
@ -260,7 +260,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
@unittest.skip("this test isn't valid uops")
|
||||
def test_noop_vectorize_fold(self):
|
||||
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), arg=0)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx))
|
||||
vec = UOp(Ops.STACK, dtypes.float.vec(2), (ld,))
|
||||
|
|
@ -272,9 +272,9 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
@unittest.skip("this test isn't valid uops")
|
||||
def test_gep_vec_fold(self):
|
||||
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
|
||||
d1 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
|
||||
d2 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 2)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
d1 = UOp.param(1, dtypes.float.ptr())
|
||||
d2 = UOp.param(2, dtypes.float.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
def _test_vec(geps, count=4):
|
||||
vec = UOp(Ops.STACK, dtypes.float.vec(count), geps)
|
||||
|
|
@ -380,8 +380,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
def test_cast_alu_fold(self):
|
||||
d0 = UOp(Ops.PARAM, dtypes.bool.ptr(), arg=0)
|
||||
d1 = UOp(Ops.PARAM, dtypes.int.ptr(), arg=1)
|
||||
d0 = UOp.param(0, dtypes.bool.ptr())
|
||||
d1 = UOp.param(1, dtypes.int.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = d1.index(idx)
|
||||
alu = (ld<1).cast(dtypes.bool)
|
||||
|
|
@ -390,8 +390,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
|
||||
|
||||
def test_double_cast_fold(self):
|
||||
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), arg=0)
|
||||
d1 = UOp(Ops.PARAM, dtypes.int.ptr(), arg=1)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
d1 = UOp.param(1, dtypes.int.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = d1.index(idx)
|
||||
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
||||
|
|
@ -414,7 +414,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_bitcast_to_same_dtype_fold(self):
|
||||
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
|
||||
d0 = UOp(Ops.PARAM, dt.ptr(), arg=0)
|
||||
d0 = UOp.param(0, dt.ptr())
|
||||
v = d0.index(UOp.const(dtypes.int, 0))
|
||||
uops = to_uops_list([v.bitcast(dt)])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}")
|
||||
|
|
@ -427,10 +427,10 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_where_on_gated_load_fold(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = (ridx0<50).where(ld, 5)
|
||||
out = UOp(Ops.PARAM, dtypes.long.ptr(), (), 1)
|
||||
out = UOp.param(1, dtypes.long.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
|
|
@ -438,7 +438,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_where_on_gated_load_folds_swapped_branches(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
|
||||
w = (ridx0<50).where(5, ld)
|
||||
uops = to_uops_list([w])
|
||||
|
|
@ -448,11 +448,11 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_where_on_gated_load_with_cast(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
d0 = UOp.param(0, dtypes.int.ptr())
|
||||
gate_idx = ridx0.valid((ridx0<50))
|
||||
ld = d0.index(gate_idx).cast(dtypes.float)
|
||||
w = (ridx0<50).where(ld, 5.0)
|
||||
out = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
|
||||
out = UOp.param(1, dtypes.float.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
|
|
@ -460,27 +460,27 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_where_on_casted_gated_load_extra_cond(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = ((ridx0<50) & (ridx0>30)).where(ld, UOp.const(dtypes.float, 0)).cast(dtypes.half)
|
||||
out = UOp(Ops.PARAM, dtypes.half.ptr(), (), 1)
|
||||
out = UOp.param(1, dtypes.half.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
|
||||
def test_where_on_casted_gated_load_extra_cond_swapped(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = ((ridx0<50) & (ridx0>30)).where(UOp.const(dtypes.float, 0), ld).cast(dtypes.half)
|
||||
out = UOp(Ops.PARAM, dtypes.half.ptr(), (), 1)
|
||||
out = UOp.param(1, dtypes.half.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
|
||||
def test_where_in_store_becomes_gate(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
idx = d0.index(ridx0)
|
||||
ld = idx.load()
|
||||
val = (ridx0<50).where(5, ld)
|
||||
|
|
@ -493,14 +493,14 @@ class TestUOpGraph(unittest.TestCase):
|
|||
def test_load_idx_becomes_int(self):
|
||||
# mnist indexing with split reduceop
|
||||
# Make sure we are not doign math on the loaded index, which would promote it to long
|
||||
c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(128000), arg=0, src=())
|
||||
c0 = UOp.param(0, dtypes.uchar.ptr(128000))
|
||||
c1 = UOp.range(UOp.const(dtypes.weakint, 512), 1, AxisType.LOOP)
|
||||
c2 = UOp.range(UOp.const(dtypes.weakint, 250), 2, AxisType.LOOP)
|
||||
c3 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
|
||||
c3 = UOp.param(1, dtypes.int.ptr(512))
|
||||
c4 = c3.index(c1)
|
||||
c5 = UOp.range(UOp.const(dtypes.weakint, 240), 0, AxisType.REDUCE)
|
||||
c6 = ((c2*UOp.const(dtypes.weakint, 240))+c5)
|
||||
c7 = UOp(Ops.PARAM, dtypes.uchar.ptr(60000), arg=2, src=())
|
||||
c7 = UOp.param(2, dtypes.uchar.ptr(60000))
|
||||
c8 = c7.index(c6)
|
||||
c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD)
|
||||
c10 = c0.index(((c1*UOp.const(dtypes.weakint, 250))+c2)).store(c9).end(c1, c2)
|
||||
|
|
@ -510,14 +510,14 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_load_idx_no_math_on_loaded(self):
|
||||
# test the (x+y)<c pattern where x has loads - we shouldn't do math on loaded indices
|
||||
c0 = UOp(Ops.PARAM, dtypes.uchar.ptr(128000), arg=0, src=())
|
||||
c0 = UOp.param(0, dtypes.uchar.ptr(128000))
|
||||
c1 = UOp.range(UOp.const(dtypes.weakint, 512), 1, AxisType.LOOP)
|
||||
c2 = UOp.range(UOp.const(dtypes.weakint, 250), 2, AxisType.LOOP)
|
||||
c3 = UOp(Ops.PARAM, dtypes.int.ptr(512), arg=1, src=())
|
||||
c3 = UOp.param(1, dtypes.int.ptr(512))
|
||||
c4 = c3.index(c1) # c4 is a load
|
||||
c5 = UOp.range(UOp.const(dtypes.weakint, 240), 0, AxisType.REDUCE)
|
||||
c6 = ((c2*UOp.const(dtypes.weakint, 240))+c5)
|
||||
c7 = UOp(Ops.PARAM, dtypes.uchar.ptr(60000), arg=2, src=())
|
||||
c7 = UOp.param(2, dtypes.uchar.ptr(60000))
|
||||
c8 = c7.index(c6)
|
||||
# (loaded + range) < const pattern - loaded value shouldn't be promoted to long
|
||||
loaded_idx = c4.cast(dtypes.weakint)
|
||||
|
|
@ -529,9 +529,9 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertNotEqual(u.dtype, dtypes.long)
|
||||
|
||||
def test_fold_gated_load(self):
|
||||
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
glbl1 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 1)
|
||||
glbl2 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 2)
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr())
|
||||
glbl1 = UOp.param(1, dtypes.int.ptr())
|
||||
glbl2 = UOp.param(2, dtypes.int.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = glbl1.index(UOp.invalid())
|
||||
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
|
||||
|
|
@ -541,7 +541,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr())
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
|
||||
|
|
@ -555,7 +555,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
glbl = UOp.param(0, dtypes.int.ptr())
|
||||
idx0 = UOp.const(dtypes.int, 0)
|
||||
idx1 = UOp.const(dtypes.int, 0)
|
||||
val = UOp.const(dtypes.int, 42)
|
||||
|
|
@ -563,12 +563,12 @@ class TestUOpGraph(unittest.TestCase):
|
|||
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val)
|
||||
uops = to_uops_list([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops), 6) # +1 for SINK
|
||||
self.assertEqual(len(uops), 7) # +1 for SINK, +1 for PARAM shape sentinel
|
||||
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK
|
||||
|
||||
@unittest.skip("this is a uop type error")
|
||||
def test_asserts_bad_gate(self):
|
||||
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
bad_gate = UOp.const(dtypes.int, 1)
|
||||
with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
||||
|
|
@ -779,7 +779,7 @@ class TestLoadStoreFolding(unittest.TestCase):
|
|||
def test_gated_load_gep_preserves_alt(self):
|
||||
"""Test that LOAD(GEP, alt) preserves alt value after rewrite"""
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding
|
||||
buf = UOp(Ops.PARAM, dtypes.float.vec(4).ptr(), (), 0)
|
||||
buf = UOp.param(0, dtypes.float.vec(4).ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
gate = UOp.const(dtypes.bool, True)
|
||||
gated_index = buf.index(idx.valid(gate))
|
||||
|
|
@ -797,8 +797,8 @@ class TestLoadStoreFolding(unittest.TestCase):
|
|||
def test_gated_load_ptrcat_preserves_alt(self):
|
||||
"""Test that LOAD(PTRCAT, alt) preserves alt value after rewrite"""
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding
|
||||
buf1 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
|
||||
buf2 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
|
||||
buf1 = UOp.param(0, dtypes.float.ptr())
|
||||
buf2 = UOp.param(1, dtypes.float.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
idx1 = buf1.index(idx)
|
||||
idx2 = buf2.index(idx)
|
||||
|
|
|
|||
|
|
@ -951,7 +951,7 @@ class TestSymbolic(unittest.TestCase):
|
|||
expr = cond.where(a, b).cast(dtypes.half)
|
||||
|
||||
# TODO: copied from render, render does not support cast
|
||||
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
|
||||
glbl = UOp.param(0, dtypes.int.ptr())
|
||||
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
|
||||
|
||||
|
|
@ -1270,7 +1270,7 @@ class TestStoreLoadFolding(unittest.TestCase):
|
|||
"""Tests for store(index, load(index)) -> NOOP rule. This rule matches patterns that EMERGE during simplification."""
|
||||
def test_store_load_folding(self):
|
||||
# store(idx, load(idx)) -> NOOP, including emergent patterns like store(idx, load(idx) + 0)
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
|
||||
buf = UOp.param(0, dtypes.int.ptr())
|
||||
index = buf.index(UOp.const(dtypes.weakint, 0))
|
||||
# Direct: store(idx, load(idx)) -> NOOP
|
||||
self.assertEqual(graph_rewrite(index.store(index.load()), sym).op, Ops.NOOP)
|
||||
|
|
@ -1340,7 +1340,7 @@ class TestRangeSplitting(unittest.TestCase):
|
|||
from tinygrad.codegen.simplify import pm_split_ranges, pm_flatten_range
|
||||
r0 = UOp.range(uconst(8), 0)
|
||||
# create a simple expression using the range with mod: store range%2 to a buffer
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
|
||||
buf = UOp.param(0, dtypes.int.ptr())
|
||||
val = (r0 % uconst(2)).cast(dtypes.int)
|
||||
store = UOp(Ops.STORE, dtypes.void, (buf.index(uconst(0)), val))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.END, dtypes.void, (store, r0)),))
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class TestVminVmaxProperties(unittest.TestCase):
|
|||
def test_vmin_vmax_multiplication_0_inf(self):
|
||||
# vmin and vmax for multiplication with a variable
|
||||
x = UOp.const(dtypes.float, 0.0)
|
||||
y = UOp.load(UOp(Ops.PARAM, dtypes.float.ptr(1), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
|
||||
y = UOp.load(UOp.param(0, dtypes.float.ptr(1)), UOp.const(dtypes.int, 0), dtype=dtypes.float)
|
||||
uop = x * y
|
||||
# TODO: these should be 0, but definitely should not be nan
|
||||
self.assertEqual(uop.vmin, -math.inf)
|
||||
|
|
@ -316,7 +316,7 @@ class TestVminVmaxVConst(unittest.TestCase):
|
|||
|
||||
def test_vmin_vmax_vector_with_gep(self):
|
||||
# vmin and vmax for a vector constant of bool values
|
||||
d1 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 1)
|
||||
d1 = UOp.param(1, dtypes.int.ptr())
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
val = UOp(Ops.LOAD, dtypes.int.vec(2), (d1.index(idx).cast(dtypes.int.vec(2).ptr()),))
|
||||
uop = (val // 32).gep(0)
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class TestExecALU(unittest.TestCase):
|
|||
|
||||
class TestGatedStoreRewrite(unittest.TestCase):
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
|
||||
gmem = UOp.param(0, dtypes.float.ptr())
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
|
||||
|
|
@ -126,8 +126,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
|||
self.assertEqual(len(gated_uops[-1].src), 2)
|
||||
|
||||
def test_gate_some_stores(self):
|
||||
gmem0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
|
||||
gmem1 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr())
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr())
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
|
||||
|
|
@ -146,8 +146,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
|||
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
|
||||
@unittest.skip("we don't merge ifs anymore")
|
||||
def test_merge_ifs_alt(self):
|
||||
gmem0 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
|
||||
gmem1 = UOp(Ops.PARAM, dtypes.float.ptr(), (), 1)
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr())
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr())
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
|
|
@ -170,7 +170,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
|||
class TestFastIdiv(unittest.TestCase):
|
||||
def test_division_power_of_two(self):
|
||||
for dt in (dtypes.int32, dtypes.uint32):
|
||||
g = UOp(Ops.PARAM, dt.ptr(), (), 0)
|
||||
g = UOp.param(0, dt.ptr())
|
||||
c = UOp.const(dt, 2)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dt, (l, c))
|
||||
|
|
@ -183,7 +183,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
def test_floormod_power_of_two(self):
|
||||
# FLOORMOD by a power of two lowers to AND (correct floor mod for any sign in two's complement)
|
||||
for dt in (dtypes.int32, dtypes.uint32):
|
||||
g = UOp(Ops.PARAM, dt.ptr(), (), 0)
|
||||
g = UOp.param(0, dt.ptr())
|
||||
c = UOp.const(dt, 8)
|
||||
a = UOp(Ops.FLOORMOD, dt, (g.index(c), c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
|
|
@ -195,7 +195,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
def test_floordiv_power_of_two_uint(self):
|
||||
# uint FLOORDIV by a power of two lowers to a shift, leaving no IDIV/FLOORDIV in the kernel
|
||||
for dt in (dtypes.uint32, dtypes.uint64):
|
||||
g = UOp(Ops.PARAM, dt.ptr(), (), 0)
|
||||
g = UOp.param(0, dt.ptr())
|
||||
c = UOp.const(dt, 2)
|
||||
a = UOp(Ops.FLOORDIV, dt, (g.index(c), c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
|
|
@ -207,7 +207,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
@Context(DISABLE_FAST_IDIV=0)
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long")
|
||||
def test_fast_idiv_and_mod(self):
|
||||
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
c = UOp.const(dtypes.uint, 3)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
|
|
@ -242,7 +242,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
@unittest.expectedFailure
|
||||
def test_fast_idiv_overflow(self):
|
||||
# This will be possible with a slightly different method for fast_idiv
|
||||
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
c = UOp.const(dtypes.uint, 7)
|
||||
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
|
|
@ -253,7 +253,7 @@ class TestFastIdiv(unittest.TestCase):
|
|||
self.assertNotIn(Ops.CDIV, ops)
|
||||
|
||||
def test_disable_fast_idiv(self):
|
||||
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
c = UOp.const(dtypes.uint, 3)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
|
|
@ -290,8 +290,8 @@ class TestUOpMethod(unittest.TestCase):
|
|||
self.assertEqual((gidx0*3+1).const_factor(), 1)
|
||||
|
||||
def test_replace(self):
|
||||
x = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
self.assertIs(x.replace(arg=None).arg, None)
|
||||
x = UOp.param(0, dtypes.int.ptr())
|
||||
self.assertEqual(x.replace(arg=UOp.param(1, dtypes.int.ptr()).arg).arg.slot, 1)
|
||||
with self.assertRaises(AssertionError): x.replace(field="a")
|
||||
|
||||
def test_const_zero_neg_zero_different(self):
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ class TestUOpsStats(unittest.TestCase):
|
|||
|
||||
#MULACC should have the same stats as MUL + ADD
|
||||
def test_mulacc(self):
|
||||
globl = UOp(Ops.PARAM, dtypes.int.ptr(), tuple())
|
||||
globl = UOp.param(0, dtypes.int.ptr())
|
||||
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = globl.index(o1)
|
||||
|
|
@ -149,7 +149,7 @@ class TestUOpsStats(unittest.TestCase):
|
|||
u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
|
||||
uops = tuple(u5.toposort())
|
||||
|
||||
globl = UOp(Ops.PARAM, dtypes.int.ptr(), tuple())
|
||||
globl = UOp.param(0, dtypes.int.ptr())
|
||||
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = globl.index(o1)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
# basic index patterns
|
||||
def test_const_index(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
to_uops_list([buf.index(UOp.const(dtypes.int, 0), ptr=True).load(dtype=dtypes.int)]) # valid
|
||||
to_uops_list([buf.index(UOp.const(dtypes.int, 15), ptr=True).load(dtype=dtypes.int)]) # valid (last element)
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -21,7 +21,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
|
||||
def test_variable_index(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
to_uops_list([buf.index(Variable("i", 0, 15), ptr=True).load(dtype=dtypes.int)]) # valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
to_uops_list([buf.index(Variable("i", 0, 20), ptr=True).load(dtype=dtypes.int)]) # oob
|
||||
|
|
@ -30,7 +30,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
|
||||
def test_range_with_mask(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
r = UOp.range(42, 0, AxisType.GLOBAL)
|
||||
to_uops_list([buf.index(r.valid(r < 16), ptr=True).load(dtype=dtypes.int)]) # valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -38,7 +38,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
|
||||
def test_variable_with_mask(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
v = Variable("v", -5, 80)
|
||||
to_uops_list([buf.index(v.valid((v >= 0) & (v < 16)), ptr=True).load(dtype=dtypes.int)]) # valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -46,7 +46,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
|
||||
def test_gated_store(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
v = Variable("v", 0, 20)
|
||||
to_uops_list([buf.index(v.valid(v < 16), ptr=True).store(0)]) # valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -55,14 +55,14 @@ class TestValidateOOB(unittest.TestCase):
|
|||
# ALU ops in index
|
||||
def test_floordiv(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
to_uops_list([buf.index(UOp.range(32, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
to_uops_list([buf.index(UOp.range(34, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..16 oob
|
||||
|
||||
def test_mod(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
r = UOp.range(100, 0, AxisType.GLOBAL)
|
||||
to_uops_list([buf.index(r % 16, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -70,14 +70,14 @@ class TestValidateOOB(unittest.TestCase):
|
|||
|
||||
def test_shr(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
to_uops_list([buf.index(UOp.range(64, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
to_uops_list([buf.index(UOp.range(128, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..31 oob
|
||||
|
||||
def test_shl(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(64), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(64))
|
||||
r = UOp.range(8, 0, AxisType.GLOBAL)
|
||||
to_uops_list([buf.index(r << 2, ptr=True).load(dtype=dtypes.int)]) # 0..28 valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -85,7 +85,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
|
||||
def test_and(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
r = UOp.range(100, 0, AxisType.GLOBAL)
|
||||
to_uops_list([buf.index(r & 15, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -93,14 +93,14 @@ class TestValidateOOB(unittest.TestCase):
|
|||
|
||||
def test_max(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
to_uops_list([buf.index(Variable("v", -10, 15).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
to_uops_list([buf.index(Variable("v2", -10, 20).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..20 oob
|
||||
|
||||
def test_xor_in_mask(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
r = UOp.range(32, 0, AxisType.GLOBAL)
|
||||
to_uops_list([buf.index(r.valid((r < 8) ^ ((r >= 8) & (r < 16))), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -109,22 +109,22 @@ class TestValidateOOB(unittest.TestCase):
|
|||
# cast patterns
|
||||
def test_float_cast_in_index(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(16))
|
||||
r = UOp.range(20, 0)
|
||||
i = (r.cast(dtypes.float) * 0.68).trunc().cast(dtypes.int)
|
||||
to_uops_list([buf.index(i.valid((i >= 0) & (i < 16)), ptr=True).load(dtype=dtypes.int)])
|
||||
|
||||
def test_bool_cast_in_mask(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(1), (), 0)
|
||||
buf = UOp.param(0, dtypes.int.ptr(1))
|
||||
r = UOp.range(20, 0)
|
||||
to_uops_list([buf.index(r.valid(r.cast(dtypes.bool).logical_not()), ptr=True).load(dtype=dtypes.int)]) # only r=0 valid
|
||||
|
||||
# load result as index/mask
|
||||
def test_load_as_index(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf0 = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
buf1 = UOp(Ops.PARAM, dtypes.int.ptr(64), (), 1)
|
||||
buf0 = UOp.param(0, dtypes.int.ptr(16))
|
||||
buf1 = UOp.param(1, dtypes.int.ptr(64))
|
||||
r = UOp.range(42, 0, AxisType.GLOBAL)
|
||||
ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.weakint)
|
||||
to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 32)), ptr=True).load(dtype=dtypes.int)]) # valid
|
||||
|
|
@ -133,8 +133,8 @@ class TestValidateOOB(unittest.TestCase):
|
|||
|
||||
def test_load_bool_as_mask(self):
|
||||
with Context(CHECK_OOB=1, SPEC=2):
|
||||
buf_bool = UOp(Ops.PARAM, dtypes.bool.ptr(16), (), 0)
|
||||
buf_int = UOp(Ops.PARAM, dtypes.int.ptr(8), (), 1)
|
||||
buf_bool = UOp.param(0, dtypes.bool.ptr(16))
|
||||
buf_int = UOp.param(1, dtypes.int.ptr(8))
|
||||
gidx = UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, 16),), "gidx0")
|
||||
ld_bool = buf_bool.index(gidx, ptr=True).load()
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -145,7 +145,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
def test_in_bounds_access_gated_local(self):
|
||||
with Context(CHECK_OOB=1):
|
||||
# Define buffers
|
||||
gbuf = UOp(Ops.PARAM, dtypes.uint.ptr(400), (), 0)
|
||||
gbuf = UOp.param(0, dtypes.uint.ptr(400))
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
|
||||
|
||||
# Define indices, valids and barrier
|
||||
|
|
@ -169,8 +169,8 @@ class TestValidateOOB(unittest.TestCase):
|
|||
@unittest.skip("Bool load is not supported yet")
|
||||
def test_load_mask(self):
|
||||
with Context(CHECK_OOB=1):
|
||||
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(16), (), 0)
|
||||
mask = UOp(Ops.PARAM, dtypes.bool.ptr(16), (), 0)
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr(16))
|
||||
mask = UOp.param(0, dtypes.bool.ptr(16))
|
||||
ridx = UOp.range(20, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
|
||||
to_uops_list([ld0])
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, AddrSpace, PtrDType, ImageDType
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import VIZ, pluralize, all_int
|
||||
|
||||
|
|
@ -176,7 +176,8 @@ def finalize_after(ctx:AllocCtx, x:UOp):
|
|||
def replace_input_buffer(ctx:AllocCtx, b:UOp):
|
||||
ctx.replacements.append(b)
|
||||
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b.device,
|
||||
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None)
|
||||
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None,
|
||||
b.addrspace if isinstance(b.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL)
|
||||
|
||||
pm_finalize_call = PatternMatcher([
|
||||
(UPat(Ops.AFTER, name="x"), finalize_after),
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def linearize(sink:UOp) -> list[UOp]:
|
|||
extra = None
|
||||
match u.op:
|
||||
# the order and placement of these defines is important
|
||||
case Ops.PARAM: priority, extra = -20, u.arg
|
||||
case Ops.PARAM: priority, extra = -20, u.arg.slot
|
||||
case Ops.DEFINE_VAR: priority, extra = -19, u.arg
|
||||
case Ops.DEFINE_REG: priority = -18
|
||||
case Ops.DEFINE_LOCAL: priority = -17
|
||||
|
|
@ -93,4 +93,4 @@ def do_split_ends(e:UOp):
|
|||
pm_split_ends = PatternMatcher([
|
||||
# split the ends
|
||||
(UPat(Ops.END, name="e"), do_split_ends),
|
||||
])
|
||||
])
|
||||
|
|
|
|||
|
|
@ -329,7 +329,7 @@ class Scheduler:
|
|||
def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
|
||||
|
||||
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
|
||||
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg)
|
||||
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg.slot)
|
||||
return [Buffer(dname, x.max_numel(), x.dtype.base) for x in glbls]
|
||||
|
||||
def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ class GraphRunner:
|
|||
self.runtimes: list[Any|None] = []
|
||||
self.uop_replace: list[list[tuple[int, int]]] = []
|
||||
for call in self.linear.src:
|
||||
replace = [(p, b.arg) for p, b in enumerate(get_call_arg_uops(call)) if b.op is Ops.PARAM]
|
||||
replace = [(p, b.arg.slot) for p, b in enumerate(get_call_arg_uops(call)) if b.op is Ops.PARAM]
|
||||
for dev_idx, (bufs, device_vars) in enumerate(unwrap_multi(call, resolve_params(call, input_uops))):
|
||||
self.calls.append((dev_idx, call.src[0], [b.ensure_allocated() for b in bufs], device_vars))
|
||||
self.runtimes.append(get_runtime(bufs[0].device, call.src[0]) if call.src[0].op is Ops.PROGRAM else None)
|
||||
|
|
|
|||
|
|
@ -137,8 +137,8 @@ class ExecContext:
|
|||
cache: bool = True
|
||||
|
||||
def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
|
||||
if b.op in (Ops.SLICE, Ops.MSELECT) and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg], *b.src[1:]))
|
||||
return inputs[b.arg] if b.op is Ops.PARAM else b
|
||||
if b.op in (Ops.SLICE, Ops.MSELECT) and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg.slot], *b.src[1:]))
|
||||
return inputs[b.arg.slot] if b.op is Ops.PARAM else b
|
||||
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]: return [_resolve(b, inputs) for b in get_call_arg_uops(call)]
|
||||
|
||||
def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
|
||||
|
|
|
|||
|
|
@ -16,8 +16,9 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
|||
|
||||
def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
|
||||
"""Remove unused PARAMs from body and return compacted (body, args)."""
|
||||
used = sorted({p.arg: p for p in body.toposort() if p.op is Ops.PARAM}.items())
|
||||
return body.substitute({p: p.replace(arg=j) for j,(_, p) in enumerate(used)}, walk=True), tuple(all_args[i] for i,_ in used)
|
||||
used = sorted({p.arg.slot: p for p in body.toposort() if p.op is Ops.PARAM}.items())
|
||||
body = body.substitute({p: p.replace(arg=dataclasses.replace(p.arg, slot=j)) for j,(_, p) in enumerate(used)}, walk=True)
|
||||
return body, tuple(all_args[i] for i,_ in used)
|
||||
|
||||
def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
||||
fxn, args = k.src[0], k.src[1:]
|
||||
|
|
@ -29,7 +30,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
|||
return (None,) + (k.arg.grad_fxn(*real, call=k) if len(real) > 1 else k.arg.grad_fxn(real[0], k))
|
||||
return (None,) + k.arg.grad_fxn(on_dev(ctx, 0), k)
|
||||
assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"
|
||||
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||
grad_args = ctx.src
|
||||
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
|
||||
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
|
||||
|
|
|
|||
|
|
@ -179,8 +179,8 @@ class CStyleLanguage(Renderer):
|
|||
continue
|
||||
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
|
||||
if u.op is not Ops.PARAM: r[u] = u.arg[0]
|
||||
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
|
||||
else: r[u] = f"data{u.arg}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg}"
|
||||
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
|
||||
else: r[u] = f"data{u.arg.slot}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg.slot}"
|
||||
bufs[u] = (r[u], (u.dtype, u in writable_params))
|
||||
continue
|
||||
|
||||
|
|
@ -321,8 +321,8 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
def aux(self, uops:list[UOp]):
|
||||
arg_dtypes:list[list[tuple[int, DType]]] = []
|
||||
for i,u in enumerate(u for u in uops if u.op is Ops.PARAM):
|
||||
if len(arg_dtypes) >= u.arg: arg_dtypes.append([])
|
||||
arg_dtypes[u.arg].append((i, u.dtype))
|
||||
while len(arg_dtypes) <= u.arg.slot: arg_dtypes.append([])
|
||||
arg_dtypes[u.arg.slot].append((i, u.dtype))
|
||||
return tuple(tuple(a) for a in arg_dtypes),
|
||||
|
||||
def supported_dtypes(self): return {d for d in super().supported_dtypes()
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ class LLVMRenderer(Renderer):
|
|||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
|
||||
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.expr}"
|
||||
r[u] = f"%data{u.arg.slot}" if u.op is Ops.PARAM else f"%{u.expr}"
|
||||
args.append((r[u], u.dtype))
|
||||
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
|
||||
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ def mem_type(x:UOp) -> str:
|
|||
match x.op:
|
||||
case Ops.AFTER: return mem_type(x.src[0])
|
||||
case Ops.DEFINE_LOCAL: return 'shared'
|
||||
case Ops.PARAM: return 'global'
|
||||
case Ops.PARAM: return 'shared' if x.addrspace == AddrSpace.LOCAL else 'global'
|
||||
case _: raise RuntimeError(f"{x.op} needs to be memory")
|
||||
|
||||
def render_wmma(ctx: "PTXRenderer", wmma: UOp):
|
||||
|
|
@ -90,7 +90,7 @@ string_rewrite = PatternMatcher([
|
|||
(UPat.cvar("x", dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
|
||||
(UPat.cvar("x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg}, %{'ctaid' if x.arg[0] == 'g' else 'tid'}.{chr(120+int(x.arg[-1]))};"),
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg.slot}+0];"),
|
||||
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
|
||||
lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
|
||||
|
|
@ -219,7 +219,7 @@ class PTXRenderer(Renderer):
|
|||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
|
||||
elif u.op is Ops.LOAD:
|
||||
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
|
||||
elif u.op is Ops.PARAM: bufs.append((f"data{u.arg}", u.dtype))
|
||||
elif u.op is Ops.PARAM: bufs.append((f"data{u.arg.slot}", u.dtype))
|
||||
elif u.op is Ops.WMMA:
|
||||
# registers for packing/unpacking input and acc
|
||||
self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ BUFTYPE_BUF, BUFTYPE_TEX, BUFTYPE_IBO = 0, 1, 2
|
|||
def dcache_flush():
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.codegen import to_program
|
||||
buf, n = UOp(Ops.PARAM, dtypes.uint8.ptr(), arg=0), UOp(Ops.PARAM, dtypes.uint8.ptr(), arg=1)
|
||||
buf, n = UOp.param(0, dtypes.uint8.ptr()), UOp.param(1, dtypes.uint8.ptr())
|
||||
i = UOp.range(n.cast(dtypes.int), 0, dtype=dtypes.int)
|
||||
flush = UOp(Ops.CUSTOM, dtypes.void, (buf.cast(dtypes.ulong) + i.cast(dtypes.ulong) * UOp.const(dtypes.ulong, 64),),
|
||||
arg='__asm__ volatile("dc cvac, %0" :: "r"({0}) : "memory");')
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
|
|||
return ret
|
||||
|
||||
pm_post_sched_cache = PatternMatcher([
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg.slot]),
|
||||
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ def rewrite_into_function(call:UOp):
|
|||
|
||||
def param_to_multi(p:UOp):
|
||||
if p.axis is None: return None
|
||||
return UOp.param(p.arg, p.dtype, p.shard_shape, p.device).multi(p.axis)
|
||||
return UOp.param(p.arg.slot, p.dtype, p.shard_shape, p.device, p.arg.vmin_vmax, p.arg.name, p.arg.addrspace).multi(p.axis)
|
||||
|
||||
# NOTE: this is the same pattern as Ops.UNROLL
|
||||
multi_pm = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, ParamArg
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, identity_element
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
|
|
@ -130,15 +130,15 @@ def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
|||
if c.arg.precompile: return None
|
||||
params: list[UOp] = []
|
||||
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
|
||||
params = sorted(params, key=lambda x: x.arg)
|
||||
params = sorted(params, key=lambda x: x.arg.slot)
|
||||
args = c.src[1:]
|
||||
|
||||
# NOTE: this isn't really needed. it's okay if there's unused args in the function
|
||||
if not allow_param_mismatch:
|
||||
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
|
||||
if [x.arg.slot for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg.slot for x in params]}")
|
||||
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
|
||||
|
||||
dict_map = {x:args[x.arg] for x in params}
|
||||
dict_map = {x:args[x.arg.slot] for x in params}
|
||||
for i, (p, a) in enumerate(dict_map.items()):
|
||||
if p.axis != a.axis: raise TypeError(f"arg {i} axis mismatch: expected {p.axis}, got {a.axis}")
|
||||
if p.max_shape != a.max_shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||
|
|
@ -477,7 +477,7 @@ class LocalAddBufferContext:
|
|||
opts:tuple|None = None
|
||||
|
||||
def debuf(ctx:LocalAddBufferContext, buf:UOp):
|
||||
ret = UOp(Ops.PARAM, buf.dtype.ptr(prod(buf.max_shape)), arg=ctx.dg).reshape(buf.max_shape)
|
||||
ret = UOp(Ops.PARAM, buf.dtype.ptr(prod(buf.max_shape), buf.addrspace), arg=ParamArg(ctx.dg, addrspace=buf.addrspace)).reshape(buf.max_shape)
|
||||
# if the buffer has symbolic shape, shrink the max-sized view to the actual shape
|
||||
if buf.max_shape != buf.shape: ret = ret.shrink(tuple((0, s) for s in buf.shape))
|
||||
if buf not in ctx.map: ctx.map[buf] = buf
|
||||
|
|
@ -512,9 +512,11 @@ def find_bufs(x:UOp):
|
|||
to_define_global = PatternMatcher([
|
||||
(UPat(Ops.STORE, name="x"), find_bufs),
|
||||
(UPat(Ops.BUFFER, name="buf"), debuf),
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="buf"), debuf),
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat.cvar('vmin'), UPat.cvar('vmax'), UPat.var("nm")), name="v"),
|
||||
lambda v, vmin, vmax, nm: UOp.variable(nm.arg, vmin.arg, vmax.arg, v.dtype)),
|
||||
(UPat(Ops.PARAM, name="v"), lambda v:
|
||||
UOp.variable(v.arg.name, v.arg.vmin_vmax[0], v.arg.vmin_vmax[1], v.dtype)
|
||||
if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
|
||||
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
|
||||
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
|
||||
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
|
|
|
|||
|
|
@ -203,7 +203,7 @@ class Tensor(OpMixin):
|
|||
|
||||
def as_param(self, slot:int):
|
||||
if self.uop.axis is not None:
|
||||
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device).multi(self.uop.axis)
|
||||
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device, axis=self.uop.axis)
|
||||
else:
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,19 @@ class AxisType(Enum):
|
|||
def __repr__(self): return str(self)
|
||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||
THREAD = auto(); PLACEHOLDER = auto() # noqa: E702
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class ParamArg:
|
||||
slot: int
|
||||
vmin_vmax: tuple[PyConst, PyConst]|None = None
|
||||
name: str|None = None
|
||||
addrspace: AddrSpace = AddrSpace.GLOBAL
|
||||
axis: int|None = None
|
||||
device: str|tuple[str, ...]|None = None
|
||||
def __repr__(self):
|
||||
fields = (("vmin_vmax", None), ("name", None), ("addrspace", AddrSpace.GLOBAL), ("axis", None), ("device", None))
|
||||
args = [str(self.slot)] + [f"{k}={v!r}" for k,default in fields if (v:=getattr(self, k)) != default]
|
||||
return f"ParamArg({', '.join(args)})"
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||
|
|
@ -281,9 +294,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
case Ops.PARAM:
|
||||
if isinstance(self.dtype, ImageDType): return self.dtype.shape
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
|
||||
# NOTE: copied from marg
|
||||
if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
|
||||
return None
|
||||
return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count)) if len(self.src) >= 1 else None
|
||||
|
||||
# wmma output shape = accumulator shape (src[2])
|
||||
case Ops.WMMA | Ops.SHAPED_WMMA: return self.src[2]._shape
|
||||
|
|
@ -305,7 +316,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),) if len(ps) > 0 else ps
|
||||
return ps
|
||||
|
||||
# MULTI marker (axis info in PARAM sources) has no shape
|
||||
# MULTI marker has no shape
|
||||
case Ops.MULTI if len(self.src) == 0: return None
|
||||
|
||||
# movement ops change the shape
|
||||
|
|
@ -611,11 +622,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.GETTUPLE:
|
||||
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0]
|
||||
return in_tuple.src[self.arg].axis if in_tuple.op is Ops.TUPLE else None
|
||||
# PARAM: axis is stored as a MULTI source
|
||||
if self.op is Ops.PARAM:
|
||||
for s in self.src:
|
||||
if s.op is Ops.MULTI: return s.arg
|
||||
return None
|
||||
if self.op is Ops.PARAM: return self.arg.axis
|
||||
# NOTE: they all have to share an axis, we always choose [-1]
|
||||
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
|
||||
if len(self.src) == 0: return None
|
||||
|
|
@ -736,6 +743,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return ret.after(ret.store(src))
|
||||
@recursive_property
|
||||
def device(self) -> str|tuple[str, ...]|None:
|
||||
if self.op is Ops.PARAM: return self.arg.device
|
||||
if self.op is Ops.DEVICE: return self.arg
|
||||
if self.op is Ops.STAGE: return self.arg.device
|
||||
if self.op is Ops.AFTER: return self.src[0].device
|
||||
|
|
@ -749,7 +757,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return None
|
||||
@property
|
||||
def addrspace(self) -> AddrSpace:
|
||||
if self.op in {Ops.PARAM, Ops.BUFFER}: return AddrSpace.GLOBAL
|
||||
if self.op is Ops.PARAM: return self.arg.addrspace
|
||||
if self.op is Ops.BUFFER: return AddrSpace.GLOBAL
|
||||
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
|
||||
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
|
||||
if self.op is Ops.STACK:
|
||||
|
|
@ -965,7 +974,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
# float has NAN issue and we use explicit NAN in transcendental
|
||||
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is Ops.PARAM and len(self.src) >= 4: return self.src[2].arg, self.src[3].arg
|
||||
if self.op is Ops.PARAM and self.arg.vmin_vmax is not None: return self.arg.vmin_vmax
|
||||
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
|
||||
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
|
||||
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
|
||||
|
|
@ -1010,7 +1019,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@staticmethod
|
||||
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot)
|
||||
arg = ParamArg(slot, addrspace=addrspace) if addrspace is AddrSpace.GLOBAL else slot
|
||||
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=arg)
|
||||
if len(shape) > 1: ret = ret.reshape(shape)
|
||||
return ret
|
||||
def placeholder_like(self, slot:int):
|
||||
|
|
@ -1023,18 +1033,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
# TODO: this should replace placeholder
|
||||
@staticmethod
|
||||
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name=None):
|
||||
src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + \
|
||||
(UOp(Ops.NOOP) if device is None else UOp(Ops.DEVICE, arg=device),)
|
||||
if vmin_vmax is not None: src += (UOp.const(dtype, vmin_vmax[0]), UOp.const(dtype.scalar(), vmin_vmax[1]))
|
||||
if name is not None: src += (UOp(Ops.NOOP, arg=name),)
|
||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name=None,
|
||||
addrspace=AddrSpace.GLOBAL, axis:int|None=None):
|
||||
if shape is not None and axis is not None and isinstance(device, tuple):
|
||||
shape = tuple(s*len(device) if i == axis else s for i,s in enumerate(shape))
|
||||
src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),)
|
||||
return UOp(Ops.PARAM, dtype, src, arg=ParamArg(slot, vmin_vmax, name, addrspace, axis, device))
|
||||
def param_like(self, slot:int):
|
||||
addrspace = self.addrspace if isinstance(self.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL
|
||||
if self.op is Ops.BIND:
|
||||
return UOp.param(slot, self.dtype, self._shape, self.device, self._min_max, self.src[0].arg[0])
|
||||
p = UOp.param(slot, self.dtype, self._shape, self.device)
|
||||
if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
|
||||
return p
|
||||
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].arg[0], addrspace)
|
||||
return UOp.param(slot, self.dtype, self.shard_shape if self.axis is not None else self._shape, self.device, addrspace=addrspace, axis=self.axis)
|
||||
|
||||
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
|
||||
_OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.SLICE, Ops.CUSTOM_FUNCTION}
|
||||
|
|
@ -1098,10 +1107,10 @@ class ProgramInfo:
|
|||
local_size: list[int]|None = [1, 1, 1]
|
||||
for u in sink.toposort():
|
||||
if u.op is Ops.DEFINE_VAR: _vars.append(u)
|
||||
if u.op is Ops.PARAM: _globals.append(u.arg)
|
||||
if u.op is Ops.PARAM: _globals.append(u.arg.slot)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg)
|
||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
|
||||
if u.op is Ops.SPECIAL:
|
||||
if u.arg[0] == 'i': local_size = None
|
||||
special_size = local_size if u.arg[0] == 'l' else global_size
|
||||
|
|
@ -1646,7 +1655,7 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||
|
||||
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
_pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg])])
|
||||
_pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg.slot])])
|
||||
remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
def gate_kernel_sink(x:UOp) -> bool:
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> str:
|
|||
|
||||
renderer = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr),
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat(), UPat(), UPat(Ops.NOOP, name="x"))), lambda x: x.arg),
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: x.arg.name if x.arg.name is not None else f"p{x.arg.slot}"),
|
||||
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: str(x.arg)),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import math
|
||||
from typing import cast, Any
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, KernelInfo
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, KernelInfo, ParamArg
|
||||
from tinygrad.uop.render import print_uops, pyrender
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
|
||||
|
|
@ -71,7 +71,8 @@ spec_shared = PatternMatcher([
|
|||
(UPat(Ops.END, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(u.op is Ops.RANGE for u in x.src[1:])),
|
||||
|
||||
# PARAM (that's really a DEFINE_GLOBAL)
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: isinstance(x.arg, ParamArg) and isinstance(x.dtype, (PtrDType, ImageDType)) and
|
||||
x.addrspace == x.dtype.addrspace),
|
||||
|
||||
# GROUP of stores (or groups, or NOOPs)
|
||||
# TODO: remove UNROLL here, it's for SPEC=2
|
||||
|
|
@ -130,9 +131,6 @@ spec_tensor = PatternMatcher([
|
|||
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
|
||||
|
||||
# PARAM (that's really a variable)
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat(), UPat(), UPat()), name="x"), lambda x: True),
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
|
||||
|
||||
|
|
@ -148,9 +146,7 @@ spec_tensor = PatternMatcher([
|
|||
(UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
|
||||
|
||||
# PARAM
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.NOOP)), name="x"), lambda x: True), # TODO: why does this have NOOP?
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="x"), lambda x: True),
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.MULTI)), name="x"), lambda x: True),
|
||||
(UPat(Ops.PARAM, src=(UPat(),), name="x"), lambda x: isinstance(x.arg, ParamArg)),
|
||||
|
||||
# inputs to movement ops
|
||||
(UPat(Ops.STACK), lambda: True),
|
||||
|
|
@ -264,7 +260,7 @@ from tinygrad.schedule.rangeify import BufferizeOpts
|
|||
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Metadata": Metadata,
|
||||
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid,
|
||||
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic,
|
||||
"ConstFloat": ConstFloat}
|
||||
"ConstFloat": ConstFloat, "ParamArg": ParamArg}
|
||||
def eval_pyrender(code:str) -> UOp:
|
||||
lcls:dict[str, Any] = {}
|
||||
exec(code, glbls, lcls)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue