mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
a few more UPat.var -> UPat.cvar in the scheduler [pr] (#8391)
* a few more UPat.var -> UPat.cvar in the scheduler [pr] * keep it assert * minimal diff
This commit is contained in:
parent
3273972f44
commit
5c2fe04bb6
2 changed files with 10 additions and 12 deletions
|
|
@ -403,17 +403,16 @@ class UPatScheduled(UPat):
|
|||
# ** this is schedule level const folding
|
||||
|
||||
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
|
||||
if not all_int(x.shape): return None
|
||||
# remove reduce on unmasked const
|
||||
if all_int(x.shape) and x.is_unrealized_unmasked_const():
|
||||
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
|
||||
ret = x.const_arg
|
||||
match reduce.arg[0]:
|
||||
case Ops.ADD: ret *= prshape
|
||||
case Ops.MUL: ret **= prshape
|
||||
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
|
||||
case _: return None
|
||||
return UOp.const(reduce.dtype, ret)
|
||||
return None
|
||||
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
|
||||
ret = x.const_arg
|
||||
match reduce.arg[0]:
|
||||
case Ops.ADD: ret *= prshape
|
||||
case Ops.MUL: ret **= prshape
|
||||
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
|
||||
case _: return None
|
||||
return UOp.const(reduce.dtype, ret)
|
||||
|
||||
def simplify_alu(alu:UOp):
|
||||
if not all(x.is_unrealized_unmasked_const() for x in alu.src): return None
|
||||
|
|
@ -456,7 +455,7 @@ ops_folding = PatternMatcher([
|
|||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x:UOp.const(reduce.dtype, identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_reduceop),
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
|
||||
# CONST doesn't need COPY
|
||||
(UPat(Ops.COPY, src=(UPat.cvar("x"),)), lambda x: x),
|
||||
# no double COPY
|
||||
|
|
|
|||
|
|
@ -325,7 +325,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def const_arg(self) -> ConstType:
|
||||
match self.base.op:
|
||||
case Ops.CONST: ret = self.base.arg
|
||||
case Ops.VIEW: ret = self.base.src[1].const_arg
|
||||
case op: raise AssertionError(f"const_arg called on {op}")
|
||||
assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
|
||||
return ret
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue