add inline flag for call

This commit is contained in:
George Hotz 2026-02-10 12:19:51 +08:00
commit e89221e9aa
4 changed files with 11 additions and 10 deletions

View file

@@ -360,7 +360,8 @@ class Embedding:
def __call__(self, idx:Tensor) -> Tensor:
if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
if USE_ATOMICS:
return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd, inline=True)
return _embedding_fwd(self.weight, idx)
class LSTMCell:

View file

@@ -79,9 +79,8 @@ mop_cleanup = PatternMatcher([
])
def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
if c.src[0].op is Ops.PROGRAM: return None
# we only resolve here if the call is inlined
if not c.arg.inline: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here

View file

@@ -239,8 +239,8 @@ class Tensor(OpMixin):
else:
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param, device=self.device)
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
def call(self, *lst:Tensor, fxn:Tensor|UOp, **kwargs) -> Tensor:
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], **kwargs), device=self.device)
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
"""

View file

@@ -823,10 +823,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), inline=False) -> UOp:
# TODO: reenable this after ENCDEC is fixed
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, inline))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
@@ -848,9 +848,10 @@ class KernelInfo:
class CallInfo:
grad_fxn: Callable|None = None
metadata: tuple[Metadata, ...] = ()
inline: bool = False
# grad_fxn can't be pickled (it is dropped and restored as None), but metadata and inline can
def __reduce__(self): return (CallInfo, (None, self.metadata))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
def __reduce__(self): return (CallInfo, (None, self.metadata, self.inline))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata}, {self.inline})"
# ******** ops in python ********