mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
This reverts commit 28897be9a2.
This commit is contained in:
parent
28897be9a2
commit
caee42e8a6
4 changed files with 5 additions and 6 deletions
|
|
@ -23,7 +23,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
|
|||
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
|
||||
inbufs = [cast(UOp,x.lazydata).base.buffer for x in inputs]
|
||||
src = Device[Device.DEFAULT].renderer.render(uops)
|
||||
ei = CompiledRunner(ProgramSpec(src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
|
||||
ei = CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
|
||||
ei.exec(outbufs+inbufs)
|
||||
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
|
||||
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ def _uops_to_prg(uops_list):
|
|||
uops = linearize_uop(full_graph_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
|
||||
src = Device[Device.DEFAULT].renderer.render(uops)
|
||||
has_local = Device[Device.DEFAULT].renderer.has_local
|
||||
return CompiledRunner(ProgramSpec(src, Device.DEFAULT, ast, uops=uops,
|
||||
global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
|
||||
return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, ast, uops=uops,
|
||||
global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
|
||||
|
||||
def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(src), arg))
|
||||
|
|
|
|||
|
|
@ -697,5 +697,5 @@ class Kernel:
|
|||
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
|
||||
for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
|
||||
key=lambda x: (x.op, x.src[0].arg)))
|
||||
return ProgramSpec(src, self.opts.device, self.ast, self.uops, mem_estimate=mem_bytes,
|
||||
return ProgramSpec(self.uops[0].arg, src, self.opts.device, self.ast, self.uops, mem_estimate=mem_bytes,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ class Estimates:
|
|||
|
||||
@dataclass
|
||||
class ProgramSpec:
|
||||
name:str
|
||||
src:str
|
||||
device:str
|
||||
ast:UOp # save the base ast (this is method cache key)
|
||||
|
|
@ -74,7 +75,6 @@ class ProgramSpec:
|
|||
mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
|
||||
|
||||
# filled in from uops (if we have uops)
|
||||
name:str="test"
|
||||
global_size:Optional[list[int]]=None
|
||||
local_size:Optional[list[int]]=None
|
||||
vars:list[Variable]=field(default_factory=list)
|
||||
|
|
@ -87,7 +87,6 @@ class ProgramSpec:
|
|||
if not self._ran_post_init and self.uops is not None:
|
||||
# single pass through the uops
|
||||
for u in self.uops:
|
||||
if u.op is Ops.NAME: self.name = u.arg
|
||||
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
|
||||
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
|
||||
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue