mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
rename to callify + fix mypy (#15727)
* rename to callify + fix mypy * update test
This commit is contained in:
parent
528faa18ec
commit
2450c8cba8
5 changed files with 23 additions and 19 deletions
|
|
@ -20,7 +20,7 @@ def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0]
|
|||
|
||||
def decode_dpp16(dpp: int) -> tuple[str, int | tuple[int, int, int, int]]:
|
||||
"""Decode a DPP16 control word into a symbolic operation and argument."""
|
||||
if dpp < 0x100: return "quad_perm", tuple((dpp >> shift) & 0x3 for shift in range(0, 8, 2))
|
||||
if dpp < 0x100: return "quad_perm", ((dpp >> 0) & 0x3, (dpp >> 2) & 0x3, (dpp >> 4) & 0x3, (dpp >> 6) & 0x3)
|
||||
if dpp in _DPP16_EXACT_OPS: return _DPP16_EXACT_OPS[dpp]
|
||||
if (base := dpp & 0x1f0) in _DPP16_RANGE_OPS: return _DPP16_RANGE_OPS[base], dpp & 0xf
|
||||
return "dpp", dpp
|
||||
|
|
|
|||
|
|
@ -940,18 +940,22 @@ def _dpp16_ctrl(lane: UOp, dpp: int, row_mask: int, bank_mask: int, wave_size: i
|
|||
op, arg = decode_dpp16(dpp)
|
||||
src_lane, valid = lane_i, UOp.const(dtypes.bool, True)
|
||||
|
||||
if op == 'quad_perm': src_lane = (lane_i & _c(~3, dtypes.int)) + _dpp_quad_sel(lane_i & _c(3, dtypes.int), arg)
|
||||
elif op == 'row_shl': src_lane, valid = row_base + lane_in_row + _c(arg, dtypes.int), lane_in_row <= _c(15 - arg, dtypes.int)
|
||||
elif op == 'row_shr': src_lane, valid = row_base + lane_in_row - _c(arg, dtypes.int), lane_in_row >= _c(arg, dtypes.int)
|
||||
elif op == 'row_ror': src_lane = row_base + ((lane_in_row - _c(arg, dtypes.int)) & _c(15, dtypes.int))
|
||||
elif op == 'row_mirror': src_lane = row_base + (_c(15, dtypes.int) - lane_in_row)
|
||||
elif op == 'row_half_mirror': src_lane = row_base + ((lane_in_row & _c(8, dtypes.int)) | (_c(7, dtypes.int) - (lane_in_row & _c(7, dtypes.int))))
|
||||
elif op == 'row_bcast': src_lane = row_base
|
||||
elif op == 'wave_shl': src_lane, valid = lane_i + _c(arg, dtypes.int), lane_i < _c(wave_size - arg, dtypes.int)
|
||||
elif op == 'wave_rol': src_lane = (lane_i + _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
|
||||
elif op == 'wave_shr': src_lane, valid = lane_i - _c(arg, dtypes.int), lane_i >= _c(arg, dtypes.int)
|
||||
elif op == 'wave_ror': src_lane = (lane_i - _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
|
||||
else: raise NotImplementedError(f"DPP16 control {dpp:#x} ({op}:{arg}) not implemented in emulator")
|
||||
if op == 'quad_perm':
|
||||
assert isinstance(arg, tuple)
|
||||
src_lane = (lane_i & _c(~3, dtypes.int)) + _dpp_quad_sel(lane_i & _c(3, dtypes.int), arg)
|
||||
else:
|
||||
assert isinstance(arg, int)
|
||||
if op == 'row_shl': src_lane, valid = row_base + lane_in_row + _c(arg, dtypes.int), lane_in_row <= _c(15 - arg, dtypes.int)
|
||||
elif op == 'row_shr': src_lane, valid = row_base + lane_in_row - _c(arg, dtypes.int), lane_in_row >= _c(arg, dtypes.int)
|
||||
elif op == 'row_ror': src_lane = row_base + ((lane_in_row - _c(arg, dtypes.int)) & _c(15, dtypes.int))
|
||||
elif op == 'row_mirror': src_lane = row_base + (_c(15, dtypes.int) - lane_in_row)
|
||||
elif op == 'row_half_mirror': src_lane = row_base + ((lane_in_row & _c(8, dtypes.int)) | (_c(7, dtypes.int) - (lane_in_row & _c(7, dtypes.int))))
|
||||
elif op == 'row_bcast': src_lane = row_base
|
||||
elif op == 'wave_shl': src_lane, valid = lane_i + _c(arg, dtypes.int), lane_i < _c(wave_size - arg, dtypes.int)
|
||||
elif op == 'wave_rol': src_lane = (lane_i + _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
|
||||
elif op == 'wave_shr': src_lane, valid = lane_i - _c(arg, dtypes.int), lane_i >= _c(arg, dtypes.int)
|
||||
elif op == 'wave_ror': src_lane = (lane_i - _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
|
||||
else: raise NotImplementedError(f"DPP16 control {dpp:#x} ({op}:{arg}) not implemented in emulator")
|
||||
return src_lane, enabled, valid
|
||||
|
||||
def _load_dpp16_src0(ctx: _Ctx, inst, lane: UOp, fallback: UOp) -> UOp:
|
||||
|
|
@ -1055,8 +1059,8 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP1_DPP16 | ir3.VOP2 |
|
|||
if 'ACCVGPR_MOV' in op_name:
|
||||
lane, exec_mask = ctx.range(), ctx.rexec()
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst) # VGPRField: raw ACCVGPR index (0-255)
|
||||
src0_off = ctx.inst_field(type(inst).src0) # SrcField: raw 256 + ACCVGPR index
|
||||
val = ctx.raccvgpr_dyn(src0_off - _c(256), lane)
|
||||
acc_src0_off = ctx.inst_field(type(inst).src0) # SrcField: raw 256 + ACCVGPR index
|
||||
val = ctx.raccvgpr_dyn(acc_src0_off - _c(256), lane)
|
||||
return UOp.sink(ctx.waccvgpr_dyn(vdst_reg, lane, val, exec_mask).end(lane), *ctx.inc_pc())
|
||||
lane, exec_mask, bits = ctx.range(), ctx.rexec(), inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
|
||||
|
|
@ -1067,7 +1071,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP1_DPP16 | ir3.VOP2 |
|
|||
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
|
||||
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
|
||||
elif write_hi_half: vdst_reg -= 128
|
||||
src0_off = None
|
||||
src0_off: UOp | None = None
|
||||
if isinstance(inst, (ir3.VOP1, ir4.VOP1, irc.VOP1)):
|
||||
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
|
||||
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))
|
||||
|
|
|
|||
|
|
@ -335,7 +335,7 @@ class TestVizIntegration(unittest.TestCase):
|
|||
prg = get_program(ast, Device[Device.DEFAULT].renderer)
|
||||
lst = viz.list_items()
|
||||
self.assertEqual(len(lst), 3)
|
||||
self.assertEqual(lst[0]["name"], "Process 1 Buffer n1")
|
||||
self.assertEqual(lst[0]["name"], "Callify 1 Buffer n1")
|
||||
self.assertEqual(lst[1]["name"], "Schedule 1 Kernel n1")
|
||||
self.assertEqual(lst[2]["name"], prg.name)
|
||||
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ pm_replace_buf = PatternMatcher([
|
|||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||
])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Process {pluralize('Buffer', len(ret[1]))}")
|
||||
@track_rewrites(lambda _,ret: f"Callify {pluralize('Buffer', len(ret[1]))}")
|
||||
def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
|
||||
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Tensor Graph")
|
||||
# uop list is a list in the original_sink graph and we can map to the tags later
|
||||
|
|
@ -16,7 +16,7 @@ from tinygrad.uop.ops import _broadcast_shape
|
|||
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
|
||||
from tinygrad.device import Buffer, canonicalize_device
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.allocations import transform_to_call
|
||||
from tinygrad.engine.callify import transform_to_call
|
||||
|
||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue