a few more UPat.var -> UPat.cvar in the scheduler [pr] (#8391)

* a few more UPat.var -> UPat.cvar in the scheduler [pr]

* keep it assert

* minimal diff
This commit is contained in:
qazal 2024-12-24 14:36:24 +02:00 committed by GitHub
commit 5c2fe04bb6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 10 additions and 12 deletions

View file

@ -403,17 +403,16 @@ class UPatScheduled(UPat):
# ** this is schedule level const folding
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
if not all_int(x.shape): return None
# remove reduce on unmasked const
if all_int(x.shape) and x.is_unrealized_unmasked_const():
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
ret = x.const_arg
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
case _: return None
return UOp.const(reduce.dtype, ret)
return None
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
ret = x.const_arg
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
case _: return None
return UOp.const(reduce.dtype, ret)
def simplify_alu(alu:UOp):
if not all(x.is_unrealized_unmasked_const() for x in alu.src): return None
@ -456,7 +455,7 @@ ops_folding = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x:UOp.const(reduce.dtype, identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_reduceop),
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
# CONST doesn't need COPY
(UPat(Ops.COPY, src=(UPat.cvar("x"),)), lambda x: x),
# no double COPY

View file

@ -325,7 +325,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def const_arg(self) -> ConstType:
match self.base.op:
case Ops.CONST: ret = self.base.arg
case Ops.VIEW: ret = self.base.src[1].const_arg
case op: raise AssertionError(f"const_arg called on {op}")
assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
return ret