hotfix: touchups from presentation

This commit is contained in:
George Hotz 2024-06-04 16:31:03 +02:00
commit 052c928d06
3 changed files with 4 additions and 5 deletions

View file

@ -26,8 +26,6 @@ We'll use the model from [the Keras tutorial](https://keras.io/examples/vision/m
```python
from tinygrad import Tensor, nn
Tensor.manual_seed(42)
class Model:
def __init__(self):
self.l1 = nn.Conv2d(1, 32, kernel_size=(3,3))

View file

@ -19,11 +19,12 @@ class _Device:
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Compiled:
if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
f"can only open device {ix} from parent, not {cpn}"
x = ix.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
return ret
@functools.cached_property
def DEFAULT(self) -> str:
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore

View file

@ -38,7 +38,7 @@ def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
# TODO: check the correctness inline once compare_linearizer is in core
if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
if DEBUG >= 4: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
return k
# **************** Runners ****************