mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix test_xlm_roberta_large (#14564)
onnxruntime does not allow symlinks that point outside the model dir. Update snapshot_download to use local_dir instead of cache_dir. Also add an ad hoc migration step that copies the existing models out of the old cache layout.
This commit is contained in:
parent
aa9dc50577
commit
41a179f542
4 changed files with 21 additions and 10 deletions
12
.github/workflows/benchmark.yml
vendored
12
.github/workflows/benchmark.yml
vendored
|
|
@ -41,6 +41,18 @@ jobs:
|
|||
run: |
|
||||
echo "CACHEDB=/tmp/pytest-db-ci.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/pytest-db-ci*
|
||||
# TODO: remove this step once all old caches are migrated
|
||||
- name: Migrate old huggingface cache (symlinks break onnxruntime 1.24+)
|
||||
run: |
|
||||
cd ~/Library/Caches/tinygrad/downloads/models 2>/dev/null || exit 0
|
||||
for old_dir in models--*; do
|
||||
[ -d "$old_dir" ] || continue
|
||||
repo_id=$(echo "$old_dir" | sed 's/models--//; s/--/\//g')
|
||||
snapshot=$(ls -1 "$old_dir/snapshots" 2>/dev/null | head -1)
|
||||
[ -n "$snapshot" ] || continue
|
||||
mkdir -p "$repo_id"
|
||||
cp -RLn "$old_dir/snapshots/$snapshot/"* "$repo_id/" 2>/dev/null || true
|
||||
done
|
||||
- name: Run pytest -nauto
|
||||
run: |
|
||||
source /tmp/tinygrad_pytest_ci/bin/activate
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@ from tinygrad.helpers import _ensure_downloads_dir
|
|||
DOWNLOADS_DIR = _ensure_downloads_dir() / "models"
|
||||
from tinygrad.helpers import tqdm
|
||||
|
||||
def snapshot_download_with_retry(*, repo_id: str, allow_patterns: list[str]|tuple[str, ...]|None=None, cache_dir: str|Path|None=None,
|
||||
def snapshot_download_with_retry(*, repo_id: str, allow_patterns: list[str]|tuple[str, ...]|None=None, local_dir: str|Path|None=None,
|
||||
tries: int=2, **kwargs) -> Path:
|
||||
for attempt in range(tries):
|
||||
try:
|
||||
return Path(snapshot_download(
|
||||
repo_id=repo_id,
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=str(cache_dir) if cache_dir is not None else None,
|
||||
local_dir=str(local_dir) if local_dir is not None else None,
|
||||
**kwargs
|
||||
))
|
||||
except Exception as e:
|
||||
|
|
@ -144,14 +144,14 @@ class HuggingFaceONNXManager:
|
|||
root_path = snapshot_download_with_retry(
|
||||
repo_id=model_id,
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=str(self.models_dir)
|
||||
local_dir=str(self.models_dir / model_id)
|
||||
)
|
||||
|
||||
# Download config files (usually small)
|
||||
snapshot_download_with_retry(
|
||||
repo_id=model_id,
|
||||
allow_patterns=["*config.json"],
|
||||
cache_dir=str(self.models_dir)
|
||||
local_dir=str(self.models_dir / model_id)
|
||||
)
|
||||
|
||||
model_data["download_path"] = str(root_path)
|
||||
|
|
|
|||
|
|
@ -88,8 +88,8 @@ if __name__ == "__main__":
|
|||
# repo id
|
||||
# validates all onnx models inside repo
|
||||
repo_id = "/".join(path)
|
||||
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=DOWNLOADS_DIR)
|
||||
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)
|
||||
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], local_dir=DOWNLOADS_DIR / repo_id)
|
||||
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], local_dir=DOWNLOADS_DIR / repo_id)
|
||||
config = get_config(root_path)
|
||||
for onnx_model in root_path.rglob("*.onnx"):
|
||||
rtol, atol = get_tolerances(onnx_model.name)
|
||||
|
|
@ -101,8 +101,8 @@ if __name__ == "__main__":
|
|||
onnx_model = path[-1]
|
||||
assert path[-1].endswith(".onnx")
|
||||
repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
|
||||
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=DOWNLOADS_DIR)
|
||||
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)
|
||||
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=[relative_path], local_dir=DOWNLOADS_DIR / repo_id)
|
||||
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], local_dir=DOWNLOADS_DIR / repo_id)
|
||||
config = get_config(root_path)
|
||||
rtol, atol = get_tolerances(onnx_model)
|
||||
print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
|
||||
|
|
|
|||
|
|
@ -73,14 +73,13 @@ class TestHuggingFaceOnnxModels(unittest.TestCase):
|
|||
onnx_model_path = snapshot_download_with_retry(
|
||||
repo_id=repo_id,
|
||||
allow_patterns=["*.onnx", "*.onnx_data"],
|
||||
cache_dir=str(DOWNLOADS_DIR)
|
||||
local_dir=DOWNLOADS_DIR / repo_id
|
||||
)
|
||||
onnx_model_path = onnx_model_path / model_file
|
||||
file_size = onnx_model_path.stat().st_size
|
||||
print(f"Validating model: {repo_id}/{model_file} ({file_size/1e6:.2f}M)")
|
||||
validate(onnx_model_path, custom_inputs, rtol=rtol, atol=atol)
|
||||
|
||||
@unittest.skip("onnxruntime 1.24+ rejects huggingface_hub symlinks as path traversal")
|
||||
def test_xlm_roberta_large(self):
|
||||
repo_id = "FacebookAI/xlm-roberta-large"
|
||||
model_file = "onnx/model.onnx"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue