use pathlib

This commit is contained in:
George Hotz 2024-10-22 14:27:02 +08:00
commit 7c38489820
3 changed files with 12 additions and 13 deletions

View file

@ -21,7 +21,7 @@ repos:
pass_filenames: false
- id: devicetests
name: select GPU tests
entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_search.py
entry: env GPU=1 PYTHONPATH="." python3 -m pytest test/test_uops.py test/test_search.py
language: system
always_run: true
pass_filenames: false
@ -39,7 +39,7 @@ repos:
pass_filenames: false
- id: pylint
name: pylint
entry: env PYTHONPATH="." python3 -m pylint tinygrad/
entry: python3 -m pylint tinygrad/
language: system
always_run: true
pass_filenames: false

View file

@ -1,5 +1,5 @@
from typing import Optional, Tuple, Any, List
import unittest, math
import unittest, math, pathlib
import numpy as np
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.tensor import Tensor, _to_np_dtype
@ -443,12 +443,12 @@ class TestIndexingOrdering(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
self.assertEqual(append_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
self.assertEqual(sym.patterns[-1][0].location[0].name, "uopgraph.py")
self.assertEqual(append_bufs.patterns[0][0].location[0].name, "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].name, "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(UOps.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.split("/")[-1])
self.assertEqual(test_upat.location[0].name, pathlib.Path(__file__).name)
if __name__ == '__main__':
unittest.main(verbosity=2)

View file

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from weakref import WeakValueDictionary
@ -477,15 +477,14 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
# ***** pattern matcher *****
def get_location() -> Tuple[str, int]:
def get_location() -> Tuple[pathlib.Path, int]:
frm = sys._getframe(1)
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
return pathlib.Path(frm.f_code.co_filename), frm.f_lineno
@functools.lru_cache(None)
def lines(fn) -> List[str]:
with open(fn) as f: return f.readlines()
def lines(fn:pathlib.Path) -> List[str]: return fn.read_text().splitlines()
class UPat(MathTrait):
__slots__ = ["op", "dtype", "arg", "name", "src", "_any"]