rename to callify + fix mypy (#15727)

* rename to callify + fix mypy

* update test
This commit is contained in:
George Hotz 2026-04-14 23:43:19 +08:00 committed by GitHub
commit 2450c8cba8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 23 additions and 19 deletions

View file

@ -20,7 +20,7 @@ def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0]
def decode_dpp16(dpp: int) -> tuple[str, int | tuple[int, int, int, int]]:
  """Decode a DPP16 control word into a symbolic operation and argument.

  Returns an ``(op_name, arg)`` pair, resolved in this order:

  * ``dpp < 0x100``: ``("quad_perm", (s0, s1, s2, s3))`` where each ``s`` is one of the four
    2-bit fields of the low byte, lowest field first.
  * ``dpp`` is a key of ``_DPP16_EXACT_OPS``: the entry stored in that table, returned as-is.
  * ``dpp & 0x1f0`` is a key of ``_DPP16_RANGE_OPS``: ``(name, dpp & 0xf)``, i.e. the table's
    name paired with the low 4 bits of the control word.
  * anything else: ``("dpp", dpp)`` so the caller can report the raw, undecoded value.
  """
  # The selectors are spelled out as a literal 4-tuple (instead of tuple(<generator>)) so the
  # value is statically a tuple[int, int, int, int], matching the declared return type.
  if dpp < 0x100: return "quad_perm", ((dpp >> 0) & 0x3, (dpp >> 2) & 0x3, (dpp >> 4) & 0x3, (dpp >> 6) & 0x3)
  if dpp in _DPP16_EXACT_OPS: return _DPP16_EXACT_OPS[dpp]
  if (base := dpp & 0x1f0) in _DPP16_RANGE_OPS: return _DPP16_RANGE_OPS[base], dpp & 0xf
  return "dpp", dpp

View file

@ -940,18 +940,22 @@ def _dpp16_ctrl(lane: UOp, dpp: int, row_mask: int, bank_mask: int, wave_size: i
op, arg = decode_dpp16(dpp)
src_lane, valid = lane_i, UOp.const(dtypes.bool, True)
if op == 'quad_perm': src_lane = (lane_i & _c(~3, dtypes.int)) + _dpp_quad_sel(lane_i & _c(3, dtypes.int), arg)
elif op == 'row_shl': src_lane, valid = row_base + lane_in_row + _c(arg, dtypes.int), lane_in_row <= _c(15 - arg, dtypes.int)
elif op == 'row_shr': src_lane, valid = row_base + lane_in_row - _c(arg, dtypes.int), lane_in_row >= _c(arg, dtypes.int)
elif op == 'row_ror': src_lane = row_base + ((lane_in_row - _c(arg, dtypes.int)) & _c(15, dtypes.int))
elif op == 'row_mirror': src_lane = row_base + (_c(15, dtypes.int) - lane_in_row)
elif op == 'row_half_mirror': src_lane = row_base + ((lane_in_row & _c(8, dtypes.int)) | (_c(7, dtypes.int) - (lane_in_row & _c(7, dtypes.int))))
elif op == 'row_bcast': src_lane = row_base
elif op == 'wave_shl': src_lane, valid = lane_i + _c(arg, dtypes.int), lane_i < _c(wave_size - arg, dtypes.int)
elif op == 'wave_rol': src_lane = (lane_i + _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
elif op == 'wave_shr': src_lane, valid = lane_i - _c(arg, dtypes.int), lane_i >= _c(arg, dtypes.int)
elif op == 'wave_ror': src_lane = (lane_i - _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
else: raise NotImplementedError(f"DPP16 control {dpp:#x} ({op}:{arg}) not implemented in emulator")
if op == 'quad_perm':
assert isinstance(arg, tuple)
src_lane = (lane_i & _c(~3, dtypes.int)) + _dpp_quad_sel(lane_i & _c(3, dtypes.int), arg)
else:
assert isinstance(arg, int)
if op == 'row_shl': src_lane, valid = row_base + lane_in_row + _c(arg, dtypes.int), lane_in_row <= _c(15 - arg, dtypes.int)
elif op == 'row_shr': src_lane, valid = row_base + lane_in_row - _c(arg, dtypes.int), lane_in_row >= _c(arg, dtypes.int)
elif op == 'row_ror': src_lane = row_base + ((lane_in_row - _c(arg, dtypes.int)) & _c(15, dtypes.int))
elif op == 'row_mirror': src_lane = row_base + (_c(15, dtypes.int) - lane_in_row)
elif op == 'row_half_mirror': src_lane = row_base + ((lane_in_row & _c(8, dtypes.int)) | (_c(7, dtypes.int) - (lane_in_row & _c(7, dtypes.int))))
elif op == 'row_bcast': src_lane = row_base
elif op == 'wave_shl': src_lane, valid = lane_i + _c(arg, dtypes.int), lane_i < _c(wave_size - arg, dtypes.int)
elif op == 'wave_rol': src_lane = (lane_i + _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
elif op == 'wave_shr': src_lane, valid = lane_i - _c(arg, dtypes.int), lane_i >= _c(arg, dtypes.int)
elif op == 'wave_ror': src_lane = (lane_i - _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
else: raise NotImplementedError(f"DPP16 control {dpp:#x} ({op}:{arg}) not implemented in emulator")
return src_lane, enabled, valid
def _load_dpp16_src0(ctx: _Ctx, inst, lane: UOp, fallback: UOp) -> UOp:
@ -1055,8 +1059,8 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP1_DPP16 | ir3.VOP2 |
if 'ACCVGPR_MOV' in op_name:
lane, exec_mask = ctx.range(), ctx.rexec()
vdst_reg = ctx.inst_field(type(inst).vdst) # VGPRField: raw ACCVGPR index (0-255)
src0_off = ctx.inst_field(type(inst).src0) # SrcField: raw 256 + ACCVGPR index
val = ctx.raccvgpr_dyn(src0_off - _c(256), lane)
acc_src0_off = ctx.inst_field(type(inst).src0) # SrcField: raw 256 + ACCVGPR index
val = ctx.raccvgpr_dyn(acc_src0_off - _c(256), lane)
return UOp.sink(ctx.waccvgpr_dyn(vdst_reg, lane, val, exec_mask).end(lane), *ctx.inc_pc())
lane, exec_mask, bits = ctx.range(), ctx.rexec(), inst.canonical_op_bits
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
@ -1067,7 +1071,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP1_DPP16 | ir3.VOP2 |
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
elif write_hi_half: vdst_reg -= 128
src0_off = None
src0_off: UOp | None = None
if isinstance(inst, (ir3.VOP1, ir4.VOP1, irc.VOP1)):
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))

View file

@ -335,7 +335,7 @@ class TestVizIntegration(unittest.TestCase):
prg = get_program(ast, Device[Device.DEFAULT].renderer)
lst = viz.list_items()
self.assertEqual(len(lst), 3)
self.assertEqual(lst[0]["name"], "Process 1 Buffer n1")
self.assertEqual(lst[0]["name"], "Callify 1 Buffer n1")
self.assertEqual(lst[1]["name"], "Schedule 1 Kernel n1")
self.assertEqual(lst[2]["name"], prg.name)

View file

@ -178,7 +178,7 @@ pm_replace_buf = PatternMatcher([
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
])
@track_rewrites(lambda _,ret: f"Process {pluralize('Buffer', len(ret[1]))}")
@track_rewrites(lambda _,ret: f"Callify {pluralize('Buffer', len(ret[1]))}")
def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Tensor Graph")
# uop list is a list in the original_sink graph and we can map to the tags later

View file

@ -16,7 +16,7 @@ from tinygrad.uop.ops import _broadcast_shape
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
from tinygrad.device import Buffer, canonicalize_device
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.allocations import transform_to_call
from tinygrad.engine.callify import transform_to_call
# *** all in scope Tensors are here. this gets relevant UOps ***