start torch.compile support (#9279)

This commit is contained in:
George Hotz 2025-02-27 10:29:51 +08:00 committed by GitHub
commit b6a14911c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 42 additions and 1 deletions

View file

@ -15,6 +15,7 @@ def unwrap(x:torch.Tensor) -> Tensor:
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
return mod.unwrap(x)
class TinyBackend:
def is_initialized(self): return True
def is_available(self): return True
def current_device(self): return 0
torch.utils.rename_privateuse1_backend("tiny")
@ -230,6 +231,9 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
"aten.gt.Tensor_out": Tensor.__gt__, "aten.gt.Scalar_out": Tensor.__gt__,
"aten.lt.Tensor_out": Tensor.__lt__, "aten.lt.Scalar_out": Tensor.__lt__,
"aten.le.Tensor_out": Tensor.__le__, "aten.le.Scalar_out": Tensor.__le__,
# TODO: support this in tinygrad
"aten.bitwise_left_shift.Tensor_out": lambda self, other: Tensor(self << other.numpy()),
"aten.bitwise_right_shift.Tensor_out": lambda self, other: Tensor(self >> other.numpy()),
# not in tinygrad. are there decomps for these?
"aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)),
"aten.log1p.out": lambda self: (self+1).log(),
@ -244,8 +248,9 @@ def wrap_out(f):
def _wrap_out(*args, **kwargs):
out = kwargs.pop('out')
assigned = f(*args, **kwargs)
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
assert getenv("ALLOW_DTYPE_MISMATCH") or out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
return out.replace(assigned)
return _wrap_out

View file

@ -0,0 +1,36 @@
# https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
import torch
import torch._dynamo
from extra.torch_backend.backend import unwrap, wrap
from torch._dynamo.backends.registry import register_backend
from torch._functorch.aot_autograd import aot_module_simplified
from tinygrad import Tensor, TinyJit
@register_backend
def tiny(gm:torch.fx.GraphModule, sample_inputs):
def my_compiler(gm:torch.fx.GraphModule, sample_inputs):
# TODO: the jit should capture the graph directly, not need three runs. this is a planned tinygrad refactor after becomes_map
@TinyJit
def tiny_function(*args:Tensor):
outs = gm(*[wrap(x) for x in args])
for x in outs: unwrap(x).realize()
return outs
# TODO: this should be able to pass in .tiny() Tensors, not need to convert them. it tries to access Storage if you pass in.
def torch_function(*args:torch.Tensor): return tiny_function(*[unwrap(x.tiny()) for x in args])
return torch_function
return aot_module_simplified(gm, sample_inputs, decompositions={}, fw_compiler=my_compiler)
if __name__ == "__main__":
def foo(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
print("calling compile")
opt_foo1 = torch.compile(foo, backend="tiny")
print("compiled")
for i in range(5):
out = opt_foo1(torch.randn(10, 10), torch.randn(10, 10))
print(out.device)