mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into dsp_search
This commit is contained in:
commit
31ffa1607e
8 changed files with 31 additions and 10 deletions
|
|
@ -504,7 +504,7 @@ def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip
|
|||
obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0)
|
||||
weight_g, weight_v, parent, skip = None, None, None, False
|
||||
if not skip and obj.shape == v.shape:
|
||||
if "feature_extractor" in key and (isinstance(parent, nn.GroupNorm) or isinstance(parent, nn.LayerNorm)): # cast
|
||||
if "feature_extractor" in key and (isinstance(parent, (nn.GroupNorm, nn.LayerNorm))): # cast
|
||||
obj.assign(v.to(obj.device).float())
|
||||
else:
|
||||
obj.assign(v.to(obj.device))
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
|
|||
@TinyJit
|
||||
def run(*x):
|
||||
out = model.forward(*x) if hasattr(model, "forward") else model(*x)
|
||||
assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
|
||||
assert isinstance(out, (tuple, list, Tensor)), "model output must be a Tensor, tuple, or a list of Tensors for export"
|
||||
out = [out] if isinstance(out, Tensor) else out
|
||||
return [o.realize() for o in out]
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ CHUNK_CLASSES = {
|
|||
}
|
||||
|
||||
def pretty(val, pad=0) -> str:
|
||||
if isinstance(val, ctypes.Structure) or isinstance(val, ctypes.Union):
|
||||
if isinstance(val, (ctypes.Structure, ctypes.Union)):
|
||||
nl = '\n' # old python versions don't support \ in f-strings
|
||||
return f"{val.__class__.__name__}({nl}{' '*(pad+2)}{(f', {nl}'+' '*(pad+2)).join([f'{field[0]}={pretty(getattr(val, field[0]), pad=pad+2)}' for field in val._fields_])}{nl}{' '*pad})"
|
||||
if isinstance(val, ctypes.Array):
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class DispatchLog(TorchDispatchMode):
|
|||
should_call_tiny = kwargs.get('device') is not None and kwargs['device'].type == "cuda"
|
||||
|
||||
def can_print_arg(arg):
|
||||
return args is None or isinstance(arg, str) or isinstance(arg, int) or isinstance(arg, float) or isinstance(arg, bool)
|
||||
return args is None or isinstance(arg, (str, int, float, bool))
|
||||
|
||||
def create_tiny_mapping(arg):
|
||||
if WRAP_TINY:
|
||||
|
|
|
|||
|
|
@ -38,6 +38,10 @@ def helper_alloc_rawbuffer(device, fill=False):
|
|||
rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer())
|
||||
return rawbuf
|
||||
|
||||
def helper_create_offset_rawbuffer(base, offset=0):
|
||||
x = Buffer(base.device, base.size-offset, base.dtype, base=base, offset=offset)
|
||||
return x.ensure_allocated()
|
||||
|
||||
def helper_run_jit(jis, bufs, out_buffers):
|
||||
for rawbuf in out_buffers:
|
||||
mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize))
|
||||
|
|
@ -229,5 +233,18 @@ class TestGraph(unittest.TestCase):
|
|||
|
||||
helper_test_graphs(Device[d0].graph, graphs)
|
||||
|
||||
def test_graph_offset_bufs(self):
|
||||
d0 = Device.DEFAULT
|
||||
if not hasattr(Device[d0].allocator, "_offset"): self.skipTest("device does not support _offset")
|
||||
|
||||
b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(1)]
|
||||
b0 += [helper_create_offset_rawbuffer(b0[0]), helper_create_offset_rawbuffer(b0[0])]
|
||||
|
||||
graphs = [
|
||||
[helper_copy_op(d0, b0[0], b0[2]), helper_exec_op(d0, b0[1], [b0[0], b0[2]])],
|
||||
]
|
||||
|
||||
helper_test_graphs(Device[d0].graph, graphs)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from collections import defaultdict
|
|||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
|
||||
from tinygrad.ops import graph_rewrite, GroupOp
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic
|
||||
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.renderer import Renderer
|
||||
|
|
@ -15,14 +15,16 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
|||
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
|
||||
# first, extract all the relevant offsets
|
||||
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
|
||||
midx, mmask = graph_rewrite(UOp.sink(UOp.sink(*[vec.gep(i) for i in range(vec.dtype.count)]),
|
||||
UOp.sink(*[mask.gep(i) for i in range(vec.dtype.count)]) if mask is not None else UOp(Ops.NOOP)),
|
||||
symbolic, name=f"index_buf_{buf.arg}").src
|
||||
for i in range(vec.dtype.count):
|
||||
idx = vec.gep(i).simplify()
|
||||
#graph_rewrite(idx, PatternMatcher([]))
|
||||
idx: Any = midx.src[i]
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
|
||||
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
||||
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
||||
else: root_src, arg = idx, 0
|
||||
if mask is not None: root_src = (mask.gep(i).simplify(), root_src)
|
||||
if mask is not None: root_src = (mmask.src[i], root_src)
|
||||
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
|
||||
|
||||
# the buf.dtype is always a pointer
|
||||
|
|
|
|||
|
|
@ -120,7 +120,9 @@ class GraphRunner(Runner):
|
|||
if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
|
||||
if i in write:
|
||||
if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
|
||||
self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
|
||||
|
||||
for i,rawbuf in enumerate(rawbufs):
|
||||
if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
|
||||
else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
|
||||
|
||||
return list({id(x):x for x in wait_nodes}.values())
|
||||
|
|
|
|||
|
|
@ -714,7 +714,7 @@ class UPat(MathTrait):
|
|||
def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
|
||||
src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
|
||||
assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
|
||||
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
|
||||
self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
|
||||
self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
|
||||
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue