Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
260da2017c fix unit test dtypes 2025-08-13 12:15:12 -07:00
George Hotz
65dcd6dd45 render ranges in viz, name gbufs with sizes. changes from rangeify 2025-08-13 12:06:34 -07:00
8 changed files with 35 additions and 17 deletions

View file

@ -303,8 +303,8 @@ class TestRecurse(unittest.TestCase):
def test_inf_loop(self): def test_inf_loop(self):
a = UOp.variable('a', 0, 10) a = UOp.variable('a', 0, 10)
pm = PatternMatcher([ pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)), (UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
]) ])
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
graph_rewrite(a, pm) graph_rewrite(a, pm)
@ -312,8 +312,8 @@ class TestRecurse(unittest.TestCase):
def test_inf_loop_bottom_up(self): def test_inf_loop_bottom_up(self):
a = UOp.variable('a', 0, 10) a = UOp.variable('a', 0, 10)
pm = PatternMatcher([ pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)), (UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
]) ])
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
graph_rewrite(a, pm, bottom_up=True) graph_rewrite(a, pm, bottom_up=True)

View file

@ -124,10 +124,10 @@ class TestViz(BaseTestViz):
def test_inf_loop(self): def test_inf_loop(self):
a = UOp.variable('a', 0, 10) a = UOp.variable('a', 0, 10)
b = a.replace(op=Ops.DEFINE_REG) b = a.replace(op=Ops.CONST)
pm = PatternMatcher([ pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)), (UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
]) ])
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm]) with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0])) graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0]))

View file

@ -285,7 +285,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
topo = inp.toposort() topo = inp.toposort()
stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE]) stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges]) input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
identity = red.const_like(identity_element(red.arg, red.dtype.scalar())) identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0)) acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity) do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element

View file

@ -146,7 +146,7 @@ class CStyleLanguage(Renderer):
if u.arg is not None: name = u.arg.function_name if u.arg is not None: name = u.arg.function_name
continue continue
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR): if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0] r[u] = (f"data{u.arg}_{sz}" if (sz:=cast(PtrDType, u.dtype).size) > 0 else f"data{u.arg}") if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
bufs[u] = (r[u], (u.dtype, False)) bufs[u] = (r[u], (u.dtype, False))
continue continue

View file

@ -2933,11 +2933,11 @@ class Tensor(MathTrait):
""" """
return self*-1 if self.dtype != dtypes.bool else self.logical_not() return self*-1 if self.dtype != dtypes.bool else self.logical_not()
def contiguous(self) -> Tensor: def contiguous(self, **kwargs) -> Tensor:
""" """
Returns a contiguous tensor. Returns a contiguous tensor.
""" """
return self._apply_uop(UOp.contiguous) return self._apply_uop(UOp.contiguous, **kwargs)
def fuse(self) -> Tensor: def fuse(self) -> Tensor:
""" """

View file

@ -86,6 +86,9 @@ class GroupOp:
Ternary = {Ops.WHERE, Ops.MULACC} Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary) ALU = set.union(Unary, Binary, Ternary)
# TODO: is BITCAST always Elementwise if it's shape changing?
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG} Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE} Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}

View file

@ -182,6 +182,19 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property @property
def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
# determine what ranges this is in
@functools.cached_property
def ranges(self) -> dict[UOp, None]:
  """
  Collect the RANGE UOps this UOp sits inside, as an insertion-ordered dict used as a set.

  A RANGE reports itself. CONTIGUOUS, REDUCE and STORE start from the ranges of their
  first source and drop every later source that is found there. Any other op reports
  the union of the ranges of all of its sources.
  """
  if self.op is Ops.RANGE: return {self: None}
  if self.op not in {Ops.CONTIGUOUS, Ops.REDUCE, Ops.STORE}:
    # plain op: merge what each source reports, in source order
    merged: dict[UOp, None] = {}
    for parent in self.src: merged.update(parent.ranges)
    return merged
  # copy first: src[0].ranges is cached on that UOp and must not be mutated
  remaining = self.src[0].ranges.copy()
  for closed in self.src[1:]: remaining.pop(closed, None)
  return remaining
# *** uop evaluation *** # *** uop evaluation ***
def simplify(self): def simplify(self):
@ -219,7 +232,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret return ret
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs) def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) def index(self, *srcs:UOp|None): return UOp(Ops.INDEX, self.dtype, (self,)+tuple([x for x in srcs if x is not None]))
def __getitem__(self, idx): return self.index(idx) def __getitem__(self, idx): return self.index(idx)
def const_like(self, b:ConstLike): def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source # constants can optionally have a DEVICE source
@ -275,7 +288,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis)) ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)])) return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def contiguous(self): return self.alu(Ops.CONTIGUOUS) def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD) def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
def fuse(self): return self.alu(Ops.FUSE) def fuse(self): return self.alu(Ops.FUSE)
def allreduce(self, op, device:str|tuple[str, ...]|UOp): def allreduce(self, op, device:str|tuple[str, ...]|UOp):

View file

@ -75,11 +75,13 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if x in excluded: if x in excluded:
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
else: label += f"\n{x.op.name}{idx} {x.arg}" else: label += f"\n{x.op.name}{idx} {x.arg}"
try: if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None: try:
label += f"\n{shape_to_str(u.shape)}" label += f"\n{shape_to_str(u.shape)}"
except Exception: except Exception:
label += "\n<ISSUE GETTING SHAPE>" label += "\n<ISSUE GETTING SHAPE>"
elif len(rngs:=u.ranges):
label += f"\n{str(sorted([x.arg for x in rngs]))}"
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}" if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
# NOTE: kernel already has metadata in arg # NOTE: kernel already has metadata in arg
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata) if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)