mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
assert for elementwise dtypes in lazy (#2888)
* assert for elementwise dtypes in lazy * no image hack * check dtype of scalar for IMAGE=2
This commit is contained in:
parent
41b2a25be6
commit
2d2c4980fe
4 changed files with 14 additions and 15 deletions
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
import sys, math
|
||||
import numpy as np
|
||||
from typing import Union, Optional, Any, Tuple, List, Set, Dict
|
||||
from tinygrad.helpers import prod, dtypes, DType, merge_dicts, flatten, getenv, dedup, ImageDType, DEBUG, all_int
|
||||
from tinygrad.helpers import prod, dtypes, DType, merge_dicts, flatten, getenv, dedup, ImageDType, DEBUG, all_int, all_same
|
||||
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps
|
||||
from tinygrad.ops import Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem, vars_from_ast
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
|
|
@ -59,9 +59,7 @@ class LazyBuffer:
|
|||
return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else ())
|
||||
|
||||
def const(self, val:Union[float, int]) -> LazyBuffer:
|
||||
# NOTE: we force the image dtype const to be a float32
|
||||
const_dtype = self.dtype if not isinstance(self.dtype, ImageDType) else dtypes.float32
|
||||
return LazyBuffer.loadop(LoadOps.CONST, tuple(), const_dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
|
||||
return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
|
||||
|
||||
def contiguous(self):
|
||||
if not self.st.contiguous or self.st.size() != self.base.st.size() or self.is_unrealized_const():
|
||||
|
|
@ -96,16 +94,17 @@ class LazyBuffer:
|
|||
out = self.contiguous()
|
||||
return create_lazybuffer(device, out.st, out.dtype, LoadOps.COPY, srcs=(out,))
|
||||
|
||||
def e(self:LazyBuffer, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
|
||||
srcs = (self,)+srcs
|
||||
new_srcs = []
|
||||
for s in srcs:
|
||||
def e(self:LazyBuffer, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
|
||||
srcs: List[LazyBuffer] = []
|
||||
for s in (self,)+in_srcs:
|
||||
if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
|
||||
new_srcs.append(root._view(s.base.contiguous_child[1]))
|
||||
srcs.append(root._view(s.base.contiguous_child[1]))
|
||||
else:
|
||||
new_srcs.append(s)
|
||||
srcs = tuple(new_srcs)
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), max(x.dtype for x in srcs), op, arg, srcs)
|
||||
srcs.append(s)
|
||||
assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
|
||||
assert op != TernaryOps.WHERE or self.device != "WEBGPU" or srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
|
||||
output_dtype = srcs[-1].dtype if op != BinaryOps.CMPLT else (dtypes.bool if self.device != "WEBGPU" else dtypes.float32)
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), output_dtype, op, arg, tuple(srcs))
|
||||
|
||||
# *** reduce ops ***
|
||||
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ class Max(Function):
|
|||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
# 1s in locations where the max was chosen (can be two locations)
|
||||
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
|
||||
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)).cast(self.x.dtype))
|
||||
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
|
||||
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
|
|||
UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt,
|
||||
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
|
||||
UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)),
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: x<y,
|
||||
BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
|
||||
BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(output_type(x, y), copy=False),
|
||||
BinaryOps.XOR: np.bitwise_xor,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {
|
|||
UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt,
|
||||
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
|
||||
UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
|
||||
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
|
||||
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: x<y,
|
||||
BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul,
|
||||
BinaryOps.DIV: lambda x,y: torch.div(x, y).type(torch.promote_types(x.dtype, y.dtype)),
|
||||
BinaryOps.XOR: torch.bitwise_xor,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue