fix type in fold_bitcast [pr] (#11853)

This commit is contained in:
chenyu 2025-08-26 13:22:30 -04:00 committed by GitHub
commit aabe7756be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,5 +1,5 @@
# all of symbolic lives here now
from typing import Any, cast
from typing import cast
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
@ -19,7 +19,7 @@ def simplify_pow(x:UOp, c:UOp) -> UOp|None:
def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
if c.dtype.itemsize != root.dtype.itemsize: return None
def convert(v:Any): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
def convert(v:ConstType): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
symbolic_simple = PatternMatcher([