fix test_xlm_roberta_large (#14564)

onnxruntime does not allow symlink that's outside model dir. update snapshot_download to use local_dir instead of cache_dir. some ad hoc migration step to copy the existing model too
This commit is contained in:
chenyu 2026-02-05 14:56:06 -05:00 committed by GitHub
commit 41a179f542
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 21 additions and 10 deletions

View file

@ -41,6 +41,18 @@ jobs:
run: |
echo "CACHEDB=/tmp/pytest-db-ci.db" >> $GITHUB_ENV
rm -f /tmp/pytest-db-ci*
# TODO: remove this step once all old caches are migrated
- name: Migrate old huggingface cache (symlinks break onnxruntime 1.24+)
run: |
cd ~/Library/Caches/tinygrad/downloads/models 2>/dev/null || exit 0
for old_dir in models--*; do
[ -d "$old_dir" ] || continue
repo_id=$(echo "$old_dir" | sed 's/models--//; s/--/\//g')
snapshot=$(ls -1 "$old_dir/snapshots" 2>/dev/null | head -1)
[ -n "$snapshot" ] || continue
mkdir -p "$repo_id"
cp -RLn "$old_dir/snapshots/$snapshot/"* "$repo_id/" 2>/dev/null || true
done
- name: Run pytest -nauto
run: |
source /tmp/tinygrad_pytest_ci/bin/activate

View file

@ -8,14 +8,14 @@ from tinygrad.helpers import _ensure_downloads_dir
DOWNLOADS_DIR = _ensure_downloads_dir() / "models"
from tinygrad.helpers import tqdm
def snapshot_download_with_retry(*, repo_id: str, allow_patterns: list[str]|tuple[str, ...]|None=None, cache_dir: str|Path|None=None,
def snapshot_download_with_retry(*, repo_id: str, allow_patterns: list[str]|tuple[str, ...]|None=None, local_dir: str|Path|None=None,
tries: int=2, **kwargs) -> Path:
for attempt in range(tries):
try:
return Path(snapshot_download(
repo_id=repo_id,
allow_patterns=allow_patterns,
cache_dir=str(cache_dir) if cache_dir is not None else None,
local_dir=str(local_dir) if local_dir is not None else None,
**kwargs
))
except Exception as e:
@ -144,14 +144,14 @@ class HuggingFaceONNXManager:
root_path = snapshot_download_with_retry(
repo_id=model_id,
allow_patterns=allow_patterns,
cache_dir=str(self.models_dir)
local_dir=str(self.models_dir / model_id)
)
# Download config files (usually small)
snapshot_download_with_retry(
repo_id=model_id,
allow_patterns=["*config.json"],
cache_dir=str(self.models_dir)
local_dir=str(self.models_dir / model_id)
)
model_data["download_path"] = str(root_path)

View file

@ -88,8 +88,8 @@ if __name__ == "__main__":
# repo id
# validates all onnx models inside repo
repo_id = "/".join(path)
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=DOWNLOADS_DIR)
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], local_dir=DOWNLOADS_DIR / repo_id)
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], local_dir=DOWNLOADS_DIR / repo_id)
config = get_config(root_path)
for onnx_model in root_path.rglob("*.onnx"):
rtol, atol = get_tolerances(onnx_model.name)
@ -101,8 +101,8 @@ if __name__ == "__main__":
onnx_model = path[-1]
assert path[-1].endswith(".onnx")
repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=DOWNLOADS_DIR)
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=[relative_path], local_dir=DOWNLOADS_DIR / repo_id)
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], local_dir=DOWNLOADS_DIR / repo_id)
config = get_config(root_path)
rtol, atol = get_tolerances(onnx_model)
print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")

View file

@ -73,14 +73,13 @@ class TestHuggingFaceOnnxModels(unittest.TestCase):
onnx_model_path = snapshot_download_with_retry(
repo_id=repo_id,
allow_patterns=["*.onnx", "*.onnx_data"],
cache_dir=str(DOWNLOADS_DIR)
local_dir=DOWNLOADS_DIR / repo_id
)
onnx_model_path = onnx_model_path / model_file
file_size = onnx_model_path.stat().st_size
print(f"Validating model: {repo_id}/{model_file} ({file_size/1e6:.2f}M)")
validate(onnx_model_path, custom_inputs, rtol=rtol, atol=atol)
@unittest.skip("onnxruntime 1.24+ rejects huggingface_hub symlinks as path traversal")
def test_xlm_roberta_large(self):
repo_id = "FacebookAI/xlm-roberta-large"
model_file = "onnx/model.onnx"