assert for elementwise dtypes in lazy (#2888)

* assert for elementwise dtypes in lazy

* no image hack

* check dtype of scalar for IMAGE=2
This commit is contained in:
chenyu 2023-12-21 01:42:32 -05:00 committed by GitHub
commit 2d2c4980fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 15 deletions

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import sys, math
import numpy as np
from typing import Union, Optional, Any, Tuple, List, Set, Dict
from tinygrad.helpers import prod, dtypes, DType, merge_dicts, flatten, getenv, dedup, ImageDType, DEBUG, all_int
from tinygrad.helpers import prod, dtypes, DType, merge_dicts, flatten, getenv, dedup, ImageDType, DEBUG, all_int, all_same
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps
from tinygrad.ops import Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem, vars_from_ast
from tinygrad.shape.symbolic import sint, Variable
@ -59,9 +59,7 @@ class LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else ())
def const(self, val:Union[float, int]) -> LazyBuffer:
# NOTE: we force the image dtype const to be a float32
const_dtype = self.dtype if not isinstance(self.dtype, ImageDType) else dtypes.float32
return LazyBuffer.loadop(LoadOps.CONST, tuple(), const_dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
def contiguous(self):
if not self.st.contiguous or self.st.size() != self.base.st.size() or self.is_unrealized_const():
@ -96,16 +94,17 @@ class LazyBuffer:
out = self.contiguous()
return create_lazybuffer(device, out.st, out.dtype, LoadOps.COPY, srcs=(out,))
def e(self:LazyBuffer, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
srcs = (self,)+srcs
new_srcs = []
for s in srcs:
def e(self:LazyBuffer, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
srcs: List[LazyBuffer] = []
for s in (self,)+in_srcs:
if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
new_srcs.append(root._view(s.base.contiguous_child[1]))
srcs.append(root._view(s.base.contiguous_child[1]))
else:
new_srcs.append(s)
srcs = tuple(new_srcs)
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), max(x.dtype for x in srcs), op, arg, srcs)
srcs.append(s)
assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
assert op != TernaryOps.WHERE or self.device != "WEBGPU" or srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
output_dtype = srcs[-1].dtype if op != BinaryOps.CMPLT else (dtypes.bool if self.device != "WEBGPU" else dtypes.float32)
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), output_dtype, op, arg, tuple(srcs))
# *** reduce ops ***

View file

@ -159,7 +159,7 @@ class Max(Function):
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be more than one location when values tie)
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)).cast(self.x.dtype))
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

View file

@ -31,7 +31,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: x<y,
BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(output_type(x, y), copy=False),
BinaryOps.XOR: np.bitwise_xor,

View file

@ -33,7 +33,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {
UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: x<y,
BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul,
BinaryOps.DIV: lambda x,y: torch.div(x, y).type(torch.promote_types(x.dtype, y.dtype)),
BinaryOps.XOR: torch.bitwise_xor,