mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
tests pass
This commit is contained in:
parent
95d04048b0
commit
8ad8249e06
1 changed files with 4 additions and 2 deletions
|
|
@ -272,6 +272,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return self.src[0]._shape
|
||||
# REDUCE with empty axis is passthrough (lowered form)
|
||||
case Ops.REDUCE if len(self.arg[1]) == 0:
|
||||
# these can mismatch if there's a horizonal reduce
|
||||
if self.dtype.count > 1:
|
||||
assert len(self.src[0]._shape) == 1
|
||||
return (self.dtype.count,)
|
||||
return self.src[0]._shape
|
||||
|
||||
# TODO: disallow shape changing bitcast
|
||||
|
|
@ -289,8 +293,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
|
||||
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE}):
|
||||
ps = self.src[0]._shape
|
||||
# TODO: WMMA is used for both axis WMMA and op WMMA. fix this and remove this hack. tested by BERT on AMD LLVM
|
||||
if ps is None and self.op is Ops.WMMA: return None
|
||||
if ps is None: raise RuntimeError(f"movement op {self.op} requires shape")
|
||||
match self.op:
|
||||
case Ops.RESHAPE:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue