mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add inline flag for call
This commit is contained in:
parent
69574542ab
commit
e89221e9aa
4 changed files with 11 additions and 10 deletions
|
|
@ -360,7 +360,8 @@ class Embedding:
|
|||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
|
||||
if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
|
||||
if USE_ATOMICS:
|
||||
return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd, inline=True)
|
||||
return _embedding_fwd(self.weight, idx)
|
||||
|
||||
class LSTMCell:
|
||||
|
|
|
|||
|
|
@ -79,9 +79,8 @@ mop_cleanup = PatternMatcher([
|
|||
])
|
||||
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
if c.src[0].op is Ops.PROGRAM: return None
|
||||
# we only resolve here if the call is inlined
|
||||
if not c.arg.inline: return None
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
|
|
|
|||
|
|
@ -239,8 +239,8 @@ class Tensor(OpMixin):
|
|||
else:
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param, device=self.device)
|
||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
|
||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, **kwargs) -> Tensor:
|
||||
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], **kwargs), device=self.device)
|
||||
|
||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -823,10 +823,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
|
||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), inline=False) -> UOp:
|
||||
# TODO: reenable this after ENCDEC is fixed
|
||||
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, inline))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||
|
|
@ -848,9 +848,10 @@ class KernelInfo:
|
|||
class CallInfo:
|
||||
grad_fxn: Callable|None = None
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
inline: bool = False
|
||||
# grad_fxn can't be pickled, but metadata can
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata))
|
||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata, self.inline))
|
||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata}, {self.inline})"
|
||||
|
||||
# ******** ops in python ********
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue