mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
dtype_shap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
35e13e08d5 |
1 changed files with 25 additions and 4 deletions
|
|
@ -208,7 +208,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
|
||||
Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||
return None
|
||||
|
||||
|
|
@ -224,21 +224,38 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return inner_shape
|
||||
|
||||
case Ops.CAST:
|
||||
# if it has a vec dtype, set the shape
|
||||
if self.dtype.count > 1: return (self.dtype.count,)
|
||||
# when PTX casts from ptr to non ptr, remove the shape of the buffer
|
||||
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
|
||||
return ()
|
||||
|
||||
case Ops.GEP: return (len(self.arg),) if len(self.arg) > 1 else ()
|
||||
case Ops.STACK: return (len(self.src),)
|
||||
case Ops.INDEX:
|
||||
shp:list[sint] = []
|
||||
# NOTE: the acc buffer can have a dtype with count, we need it back here
|
||||
if self.src[0].dtype.count > 1: shp.append(self.src[0].dtype.count)
|
||||
for s in self.src[1:]: shp.extend(list(s.shape))
|
||||
return tuple(shp) + self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
# TODO: these should have the shape of the dtype.count
|
||||
case Ops.CONST | Ops.DEFINE_VAR: return ()
|
||||
case Ops.GEP | Ops.STACK | Ops.VCONST | Ops.VCAT: return ()
|
||||
case Ops.CONST | Ops.DEFINE_VAR:
|
||||
if self.dtype.count > 1: return (self.dtype.count,)
|
||||
return ()
|
||||
case Ops.VCONST: return (len(self.arg),)
|
||||
case Ops.VCAT: return ()
|
||||
|
||||
# TODO: contract and unroll should be deleted
|
||||
case Ops.CONTRACT:
|
||||
amt = prod([x[1] for x in self.arg])
|
||||
return (amt,) if amt > 1 else ()
|
||||
case Ops.UNROLL:
|
||||
amt = prod([x[1] for x in self.arg])
|
||||
return () if self.src[0]._shape[0] == amt else (amt,)
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return ()
|
||||
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW:
|
||||
# HACK: BUFFER_VIEW is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
|
|
@ -261,6 +278,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return self.src[0]._shape
|
||||
# REDUCE with empty axis is passthrough (lowered form)
|
||||
case Ops.REDUCE if len(self.arg[1]) == 0:
|
||||
# these can mismatch if there's a horizonal reduce
|
||||
if self.src[0].dtype.count > 1:
|
||||
assert self.src[0]._shape is not None and len(self.src[0]._shape) == 1, f"bad reduce shape on {self.src[0].op} {self.src[0]._shape}"
|
||||
return () if self.dtype.count == 1 else (self.dtype.count,)
|
||||
return self.src[0]._shape
|
||||
|
||||
# TODO: disallow shape changing bitcast
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue