Merge branch 'master' into dsp_search

This commit is contained in:
George Hotz 2025-03-26 21:43:22 +08:00 committed by GitHub
commit 31ffa1607e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 31 additions and 10 deletions

View file

@ -504,7 +504,7 @@ def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip
obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0)
weight_g, weight_v, parent, skip = None, None, None, False
if not skip and obj.shape == v.shape:
if "feature_extractor" in key and (isinstance(parent, nn.GroupNorm) or isinstance(parent, nn.LayerNorm)): # cast
if "feature_extractor" in key and (isinstance(parent, (nn.GroupNorm, nn.LayerNorm))): # cast
obj.assign(v.to(obj.device).float())
else:
obj.assign(v.to(obj.device))

View file

@ -38,7 +38,7 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
@TinyJit
def run(*x):
out = model.forward(*x) if hasattr(model, "forward") else model(*x)
assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
assert isinstance(out, (tuple, list, Tensor)), "model output must be a Tensor, tuple, or a list of Tensors for export"
out = [out] if isinstance(out, Tensor) else out
return [o.realize() for o in out]

View file

@ -22,7 +22,7 @@ CHUNK_CLASSES = {
}
def pretty(val, pad=0) -> str:
if isinstance(val, ctypes.Structure) or isinstance(val, ctypes.Union):
if isinstance(val, (ctypes.Structure, ctypes.Union)):
nl = '\n' # old python versions don't support \ in f-strings
return f"{val.__class__.__name__}({nl}{' '*(pad+2)}{(f', {nl}'+' '*(pad+2)).join([f'{field[0]}={pretty(getattr(val, field[0]), pad=pad+2)}' for field in val._fields_])}{nl}{' '*pad})"
if isinstance(val, ctypes.Array):

View file

@ -39,7 +39,7 @@ class DispatchLog(TorchDispatchMode):
should_call_tiny = kwargs.get('device') is not None and kwargs['device'].type == "cuda"
def can_print_arg(arg):
return args is None or isinstance(arg, str) or isinstance(arg, int) or isinstance(arg, float) or isinstance(arg, bool)
return args is None or isinstance(arg, (str, int, float, bool))
def create_tiny_mapping(arg):
if WRAP_TINY:

View file

@ -38,6 +38,10 @@ def helper_alloc_rawbuffer(device, fill=False):
rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer())
return rawbuf
def helper_create_offset_rawbuffer(base, offset=0):
x = Buffer(base.device, base.size-offset, base.dtype, base=base, offset=offset)
return x.ensure_allocated()
def helper_run_jit(jis, bufs, out_buffers):
for rawbuf in out_buffers:
mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize))
@ -229,5 +233,18 @@ class TestGraph(unittest.TestCase):
helper_test_graphs(Device[d0].graph, graphs)
def test_graph_offset_bufs(self):
d0 = Device.DEFAULT
if not hasattr(Device[d0].allocator, "_offset"): self.skipTest("device does not support _offset")
b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(1)]
b0 += [helper_create_offset_rawbuffer(b0[0]), helper_create_offset_rawbuffer(b0[0])]
graphs = [
[helper_copy_op(d0, b0[0], b0[2]), helper_exec_op(d0, b0[1], [b0[0], b0[2]])],
]
helper_test_graphs(Device[d0].graph, graphs)
if __name__ == '__main__':
unittest.main()

View file

@ -4,7 +4,7 @@ from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
from tinygrad.ops import graph_rewrite, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@ -15,14 +15,16 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
# first, extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
midx, mmask = graph_rewrite(UOp.sink(UOp.sink(*[vec.gep(i) for i in range(vec.dtype.count)]),
UOp.sink(*[mask.gep(i) for i in range(vec.dtype.count)]) if mask is not None else UOp(Ops.NOOP)),
symbolic, name=f"index_buf_{buf.arg}").src
for i in range(vec.dtype.count):
idx = vec.gep(i).simplify()
#graph_rewrite(idx, PatternMatcher([]))
idx: Any = midx.src[i]
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if mask is not None: root_src = (mask.gep(i).simplify(), root_src)
if mask is not None: root_src = (mmask.src[i], root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
# the buf.dtype is always a pointer

View file

@ -120,7 +120,9 @@ class GraphRunner(Runner):
if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
if i in write:
if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
for i,rawbuf in enumerate(rawbufs):
if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
return list({id(x):x for x in wait_nodes}.values())

View file

@ -714,7 +714,7 @@ class UPat(MathTrait):
def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject