mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
start torch.compile support (#9279)
This commit is contained in:
parent
4342300eff
commit
b6a14911c8
2 changed files with 42 additions and 1 deletions
|
|
@ -15,6 +15,7 @@ def unwrap(x:torch.Tensor) -> Tensor:
|
|||
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
|
||||
return mod.unwrap(x)
|
||||
class TinyBackend:
|
||||
def is_initialized(self): return True
|
||||
def is_available(self): return True
|
||||
def current_device(self): return 0
|
||||
torch.utils.rename_privateuse1_backend("tiny")
|
||||
|
|
@ -230,6 +231,9 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
|||
"aten.gt.Tensor_out": Tensor.__gt__, "aten.gt.Scalar_out": Tensor.__gt__,
|
||||
"aten.lt.Tensor_out": Tensor.__lt__, "aten.lt.Scalar_out": Tensor.__lt__,
|
||||
"aten.le.Tensor_out": Tensor.__le__, "aten.le.Scalar_out": Tensor.__le__,
|
||||
# TODO: support this in tinygrad
|
||||
"aten.bitwise_left_shift.Tensor_out": lambda self, other: Tensor(self << other.numpy()),
|
||||
"aten.bitwise_right_shift.Tensor_out": lambda self, other: Tensor(self >> other.numpy()),
|
||||
# not in tinygrad. are there decomps for these?
|
||||
"aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)),
|
||||
"aten.log1p.out": lambda self: (self+1).log(),
|
||||
|
|
@ -244,8 +248,9 @@ def wrap_out(f):
|
|||
def _wrap_out(*args, **kwargs):
|
||||
out = kwargs.pop('out')
|
||||
assigned = f(*args, **kwargs)
|
||||
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
|
||||
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
|
||||
assert getenv("ALLOW_DTYPE_MISMATCH") or out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
|
||||
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
|
||||
return out.replace(assigned)
|
||||
return _wrap_out
|
||||
|
||||
|
|
|
|||
36
extra/torch_backend/test_compile.py
Normal file
36
extra/torch_backend/test_compile.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from extra.torch_backend.backend import unwrap, wrap
|
||||
|
||||
from torch._dynamo.backends.registry import register_backend
|
||||
from torch._functorch.aot_autograd import aot_module_simplified
|
||||
|
||||
from tinygrad import Tensor, TinyJit
|
||||
|
||||
@register_backend
|
||||
def tiny(gm:torch.fx.GraphModule, sample_inputs):
|
||||
def my_compiler(gm:torch.fx.GraphModule, sample_inputs):
|
||||
# TODO: the jit should capture the graph directly, not need three runs. this is a planned tinygrad refactor after becomes_map
|
||||
@TinyJit
|
||||
def tiny_function(*args:Tensor):
|
||||
outs = gm(*[wrap(x) for x in args])
|
||||
for x in outs: unwrap(x).realize()
|
||||
return outs
|
||||
# TODO: this should be able to pass in .tiny() Tensors, not need to convert them. it tries to access Storage if you pass in.
|
||||
def torch_function(*args:torch.Tensor): return tiny_function(*[unwrap(x.tiny()) for x in args])
|
||||
return torch_function
|
||||
return aot_module_simplified(gm, sample_inputs, decompositions={}, fw_compiler=my_compiler)
|
||||
|
||||
if __name__ == "__main__":
|
||||
def foo(x, y):
|
||||
a = torch.sin(x)
|
||||
b = torch.cos(y)
|
||||
return a + b
|
||||
|
||||
print("calling compile")
|
||||
opt_foo1 = torch.compile(foo, backend="tiny")
|
||||
print("compiled")
|
||||
for i in range(5):
|
||||
out = opt_foo1(torch.randn(10, 10), torch.randn(10, 10))
|
||||
print(out.device)
|
||||
Loading…
Add table
Add a link
Reference in a new issue