mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
changes_fr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
260da2017c | ||
|
|
65dcd6dd45 |
8 changed files with 35 additions and 17 deletions
|
|
@ -303,8 +303,8 @@ class TestRecurse(unittest.TestCase):
|
||||||
def test_inf_loop(self):
|
def test_inf_loop(self):
|
||||||
a = UOp.variable('a', 0, 10)
|
a = UOp.variable('a', 0, 10)
|
||||||
pm = PatternMatcher([
|
pm = PatternMatcher([
|
||||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
|
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||||
])
|
])
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
graph_rewrite(a, pm)
|
graph_rewrite(a, pm)
|
||||||
|
|
@ -312,8 +312,8 @@ class TestRecurse(unittest.TestCase):
|
||||||
def test_inf_loop_bottom_up(self):
|
def test_inf_loop_bottom_up(self):
|
||||||
a = UOp.variable('a', 0, 10)
|
a = UOp.variable('a', 0, 10)
|
||||||
pm = PatternMatcher([
|
pm = PatternMatcher([
|
||||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
|
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||||
])
|
])
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
graph_rewrite(a, pm, bottom_up=True)
|
graph_rewrite(a, pm, bottom_up=True)
|
||||||
|
|
|
||||||
|
|
@ -124,10 +124,10 @@ class TestViz(BaseTestViz):
|
||||||
|
|
||||||
def test_inf_loop(self):
|
def test_inf_loop(self):
|
||||||
a = UOp.variable('a', 0, 10)
|
a = UOp.variable('a', 0, 10)
|
||||||
b = a.replace(op=Ops.DEFINE_REG)
|
b = a.replace(op=Ops.CONST)
|
||||||
pm = PatternMatcher([
|
pm = PatternMatcher([
|
||||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
|
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||||
])
|
])
|
||||||
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
|
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
|
||||||
graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0]))
|
graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0]))
|
||||||
|
|
|
||||||
|
|
@ -285,7 +285,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||||
topo = inp.toposort()
|
topo = inp.toposort()
|
||||||
stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
|
stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
|
||||||
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
|
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
|
||||||
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
|
identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
|
||||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||||
do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
|
do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
|
||||||
lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
|
lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
|
||||||
|
|
|
||||||
|
|
@ -146,7 +146,7 @@ class CStyleLanguage(Renderer):
|
||||||
if u.arg is not None: name = u.arg.function_name
|
if u.arg is not None: name = u.arg.function_name
|
||||||
continue
|
continue
|
||||||
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
|
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
|
||||||
r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
|
r[u] = (f"data{u.arg}_{sz}" if (sz:=cast(PtrDType, u.dtype).size) > 0 else f"data{u.arg}") if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
|
||||||
bufs[u] = (r[u], (u.dtype, False))
|
bufs[u] = (r[u], (u.dtype, False))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2933,11 +2933,11 @@ class Tensor(MathTrait):
|
||||||
"""
|
"""
|
||||||
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
|
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
|
||||||
|
|
||||||
def contiguous(self) -> Tensor:
|
def contiguous(self, **kwargs) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Returns a contiguous tensor.
|
Returns a contiguous tensor.
|
||||||
"""
|
"""
|
||||||
return self._apply_uop(UOp.contiguous)
|
return self._apply_uop(UOp.contiguous, **kwargs)
|
||||||
|
|
||||||
def fuse(self) -> Tensor:
|
def fuse(self) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,9 @@ class GroupOp:
|
||||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||||
ALU = set.union(Unary, Binary, Ternary)
|
ALU = set.union(Unary, Binary, Ternary)
|
||||||
|
|
||||||
|
# TODO: is BITCAST always Elementwise if it's shape changing?
|
||||||
|
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
|
||||||
|
|
||||||
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||||
|
|
||||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
||||||
|
|
|
||||||
|
|
@ -182,6 +182,19 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
@property
|
@property
|
||||||
def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
|
def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
|
||||||
|
|
||||||
|
# determine what ranges this is in
|
||||||
|
@functools.cached_property
|
||||||
|
def ranges(self) -> dict[UOp, None]:
|
||||||
|
if self.op is Ops.RANGE: return {self:None}
|
||||||
|
if self.op in {Ops.CONTIGUOUS, Ops.REDUCE, Ops.STORE}:
|
||||||
|
ret = self.src[0].ranges.copy()
|
||||||
|
for s in self.src[1:]:
|
||||||
|
if s in ret: del ret[s]
|
||||||
|
else:
|
||||||
|
ret = {}
|
||||||
|
for s in self.src: ret.update(s.ranges)
|
||||||
|
return ret
|
||||||
|
|
||||||
# *** uop evaluation ***
|
# *** uop evaluation ***
|
||||||
|
|
||||||
def simplify(self):
|
def simplify(self):
|
||||||
|
|
@ -219,7 +232,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
return ret
|
return ret
|
||||||
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||||
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
||||||
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
def index(self, *srcs:UOp|None): return UOp(Ops.INDEX, self.dtype, (self,)+tuple([x for x in srcs if x is not None]))
|
||||||
def __getitem__(self, idx): return self.index(idx)
|
def __getitem__(self, idx): return self.index(idx)
|
||||||
def const_like(self, b:ConstLike):
|
def const_like(self, b:ConstLike):
|
||||||
# constants can optionally have a DEVICE source
|
# constants can optionally have a DEVICE source
|
||||||
|
|
@ -275,7 +288,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
|
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
|
||||||
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
||||||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||||
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||||
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
||||||
def fuse(self): return self.alu(Ops.FUSE)
|
def fuse(self): return self.alu(Ops.FUSE)
|
||||||
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
||||||
|
|
|
||||||
|
|
@ -75,11 +75,13 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||||
if x in excluded:
|
if x in excluded:
|
||||||
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
|
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
|
||||||
else: label += f"\n{x.op.name}{idx} {x.arg}"
|
else: label += f"\n{x.op.name}{idx} {x.arg}"
|
||||||
try:
|
|
||||||
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
|
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
|
||||||
|
try:
|
||||||
label += f"\n{shape_to_str(u.shape)}"
|
label += f"\n{shape_to_str(u.shape)}"
|
||||||
except Exception:
|
except Exception:
|
||||||
label += "\n<ISSUE GETTING SHAPE>"
|
label += "\n<ISSUE GETTING SHAPE>"
|
||||||
|
elif len(rngs:=u.ranges):
|
||||||
|
label += f"\n{str(sorted([x.arg for x in rngs]))}"
|
||||||
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||||
# NOTE: kernel already has metadata in arg
|
# NOTE: kernel already has metadata in arg
|
||||||
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
|
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue