mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move up migrate + new gated fold (#7403)
* move up migrate + new gated fold [pr] * vcount for const ptr * move those rules there * fix openpilot
This commit is contained in:
parent
16e60d25b9
commit
adccfade7f
3 changed files with 56 additions and 69 deletions
|
|
@ -5,7 +5,7 @@ from tinygrad.helpers import DEBUG
|
|||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||
from tinygrad.ops import UPat, PatternMatcher
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding, finalize
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
|
||||
|
|
@ -389,12 +389,10 @@ class TestUOpGraph(unittest.TestCase):
|
|||
glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||
glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[-1].src
|
||||
# ld0 becomes the invalid value
|
||||
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
|
||||
ld0 = uops[-1].src[-1]
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
|
||||
|
||||
|
|
@ -404,12 +402,10 @@ class TestUOpGraph(unittest.TestCase):
|
|||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
|
||||
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[-1].src
|
||||
# ld0 becomes the invalid value
|
||||
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
|
||||
ld0 = uops[-1].src[-1]
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.index(lidx+2))
|
||||
|
||||
|
|
@ -449,7 +445,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def expander_rewrite(sink):
|
||||
sink = graph_rewrite(sink, sym + expander)
|
||||
return graph_rewrite(sink, sym + reducer)
|
||||
sink = graph_rewrite(sink, sym + reducer)
|
||||
return graph_rewrite(sink, sym + finalize)
|
||||
def float4_rewrite(sink): return graph_rewrite(sink, sym + expander + float4_folding)
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
|
|
@ -660,8 +657,6 @@ class TestLoadStoreFolder(unittest.TestCase):
|
|||
print(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
|
||||
|
||||
def gate_rewrite(sink): return graph_rewrite(sink, sym + expander + reducer)
|
||||
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
|
|
@ -675,12 +670,12 @@ class TestIFUOps(unittest.TestCase):
|
|||
lbuf = UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
|
||||
store = UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (store,))
|
||||
sink = gate_rewrite(sink)
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
def test_expand_ifs_one_gate(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
|
|
@ -693,12 +688,12 @@ class TestIFUOps(unittest.TestCase):
|
|||
lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
|
||||
stores = [UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
|
||||
sink = gate_rewrite(sink)
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
# this will be fixed with the merge gated stores bounty
|
||||
@unittest.expectedFailure
|
||||
|
|
@ -709,12 +704,12 @@ class TestIFUOps(unittest.TestCase):
|
|||
gate = valid&(lidx.ne(2))
|
||||
stores = [UOp(UOps.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
|
||||
sink = gate_rewrite(sink)
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, DefaultDict, Callable
|
||||
import functools, itertools, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat
|
||||
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
|
||||
|
|
@ -61,13 +61,13 @@ def fold_expanded(ex, buf):
|
|||
return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
|
||||
|
||||
def fix_unfoldable_image_load(load:UOp, buf:UOp):
|
||||
if not isinstance(buf.dtype, ImageDType) or load.src[1].dtype.count == 2: return None
|
||||
id4 = load.src[1] % 4
|
||||
if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
|
||||
id4 = oidx % 4
|
||||
new_src = list(load.src)
|
||||
# TODO: copied logic from above
|
||||
new_src[1] = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((load.src[1] // 4) % buf.dtype.shape[1], (load.src[1] // (4 * buf.dtype.shape[1]))))
|
||||
if len(new_src) >= 4:
|
||||
new_src[2] = UOp(UOps.VECTORIZE, new_src[2].dtype.vec(4), tuple(new_src[2] for _ in range(4)))
|
||||
new_src[0] = load.src[0].src[0].index(
|
||||
UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
|
||||
load.src[0].src[2] if len(load.src[0].src) == 3 else None)
|
||||
vec_load = UOp(UOps.LOAD, load.dtype.vec(4), tuple(new_src))
|
||||
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
|
||||
|
||||
|
|
@ -78,16 +78,12 @@ float4_folding = PatternMatcher([
|
|||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
def simplify_buffer_load(load:UOp) -> Optional[UOp]:
|
||||
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
|
||||
buf, start_idx, invalid_val, valid = load.src
|
||||
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
return None if idx is start_idx else load.replace(src=((buf, idx, invalid_val, valid)))
|
||||
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
|
||||
if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
|
||||
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
|
||||
|
||||
def simplify_image_load(load:UOp) -> Optional[UOp]:
|
||||
if not isinstance(buf_dtype:=load.src[0].dtype, ImageDType) or len(load.src) != 4: return None
|
||||
buf, start_idx, invalid_val, valid = load.src
|
||||
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
# wait for it to be image indexed before running simplification
|
||||
if start_idx.dtype.count != 2: return None
|
||||
|
||||
# can drop valid if idx is out of bound when valid is False
|
||||
drop_stmt = []
|
||||
|
|
@ -105,7 +101,7 @@ def simplify_image_load(load:UOp) -> Optional[UOp]:
|
|||
# if X <= c, check if it's out of bound when X = c+1
|
||||
# if X >= c, check if it's out of bound when X = c-1
|
||||
test_value = c + 1 if is_upper_bound else c - 1
|
||||
for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])):
|
||||
for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
|
||||
if is_increasing(i):
|
||||
rw = graph_rewrite(i.substitute({X:X.const_like(test_value)}), sym)
|
||||
if rw.vmin >= b or rw.vmax < 0:
|
||||
|
|
@ -114,7 +110,8 @@ def simplify_image_load(load:UOp) -> Optional[UOp]:
|
|||
|
||||
if not drop_stmt and idx is start_idx: return None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))
|
||||
return buf.index(idx, new_valid)
|
||||
|
||||
|
||||
# ***** optional patterns *****
|
||||
|
||||
|
|
@ -290,12 +287,10 @@ sym = symbolic_flat+PatternMatcher([
|
|||
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat.var("buf"), UPat.var("idx")))),
|
||||
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
|
||||
# fold gated LOAD/STORE
|
||||
(UPat.load(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, True), name="ld"), lambda ld: ld.replace(src=ld.src[:2])),
|
||||
(UPat.load(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, True), UPat.var("bar"), name="ld"), lambda ld,bar: ld.replace(src=ld.src[:2]+(bar,))),
|
||||
(UPat.load(UPat(), UPat(), UPat.var("var"), UPat.const(dtypes.bool, False)), lambda var: var),
|
||||
(UPat.load(UPat(), UPat(), UPat.var("var"), UPat.const(dtypes.bool, False), UPat()), lambda var: var),
|
||||
(UPat.store(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, True), name="store"), lambda store: store.replace(src=store.src[:3])),
|
||||
(UPat.store(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
|
||||
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
|
||||
(UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
|
||||
(UPat(UOps.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
|
||||
(UPat(UOps.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(UOps.NOOP)), # NULL pointer store does nothing
|
||||
# remove NOOPs from SINK
|
||||
(UPat(UOps.SINK, name="root"),
|
||||
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
|
||||
|
|
@ -450,14 +445,6 @@ def no_vectorized_acc(acc:UOp):
|
|||
tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
|
||||
return UOp(UOps.VECTORIZE, acc.dtype, alus)
|
||||
|
||||
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
|
||||
@functools.lru_cache(None)
|
||||
def find_gate(x:UOp) -> Optional[UOp]:
|
||||
if x.op is UOps.IF: return x
|
||||
return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None)
|
||||
if len(root.src) == 3 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[3]: return None
|
||||
return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
|
||||
|
||||
just_reduce = PatternMatcher([
|
||||
# do reduce
|
||||
(UPat(UOps.REDUCE, name="root"), do_reduce),
|
||||
|
|
@ -472,20 +459,12 @@ devectorize = PatternMatcher([
|
|||
])
|
||||
|
||||
reducer = PatternMatcher([
|
||||
(UPat(UOps.CONST, name='c'),
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
|
||||
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
# delete_redundant_gates (after expand, is this still needed?)
|
||||
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
|
||||
# late fixup of unfoldable image loads
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
|
||||
# simplify valid
|
||||
(UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
|
||||
# image load valid idx simplification
|
||||
(UPat(UOps.LOAD, name="load"), simplify_image_load),
|
||||
# buffer load valid idx simplification
|
||||
(UPat(UOps.LOAD, name="load"), simplify_buffer_load),
|
||||
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
|
||||
])
|
||||
|
||||
def idx_load_store(x:UOp):
|
||||
|
|
@ -506,14 +485,25 @@ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp
|
|||
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
|
||||
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
|
||||
|
||||
masked_index = UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask")))
|
||||
move_masks = PatternMatcher([
|
||||
# NOTE: this shouldn't be here
|
||||
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
|
||||
@functools.lru_cache(None)
|
||||
def find_gate(x:UOp) -> Optional[UOp]:
|
||||
if x.op is UOps.IF: return x
|
||||
return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None)
|
||||
if len(root.src) == 2 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[2]: return None
|
||||
return UOp(UOps.STORE, root.dtype, root.src[:2], root.arg)
|
||||
|
||||
finalize = PatternMatcher([
|
||||
(UPat(UOps.CONST, name='c'),
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
|
||||
# fix up loads/stores
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
|
||||
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
# move masks of loads/stores
|
||||
# TODO: this should be an IF instead of a masked STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index, masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
|
||||
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
|
||||
# delete_redundant_gates (after expand)
|
||||
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
|
@ -532,13 +522,15 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
|||
# convert REDUCE to DEFINE_ACC + ASSIGN
|
||||
sink = graph_rewrite(sink, sym+just_reduce)
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
|
||||
sink = graph_rewrite(sink, sym+reducer)
|
||||
|
||||
# temp for indexing migration
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing)
|
||||
|
||||
# cleanups
|
||||
sink = graph_rewrite(sink, sym+reducer)
|
||||
|
||||
# finalize
|
||||
sink = graph_rewrite(sink, sym+move_masks+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
sink = graph_rewrite(sink, sym+finalize+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
|
||||
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
|
||||
return sink
|
||||
|
|
|
|||
|
|
@ -561,7 +561,7 @@ class UPat(MathTrait):
|
|||
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
|
||||
|
||||
# copied from UOp
|
||||
def index(self, idx:UPat): return UPat(UOps.INDEX, self.dtype, (self,idx))
|
||||
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue