upcasted warp experiments

This commit is contained in:
George Hotz 2025-06-29 09:04:37 -07:00
commit 0df4355cd8
2 changed files with 41 additions and 0 deletions

24
extra/upcasted_warps.py Normal file
View file

@ -0,0 +1,24 @@
# play with upcasted warps
from tinygrad import Tensor, Device
from tinygrad.uop.ops import KernelInfo
from tinygrad.opt import get_optimized_ast
from tinygrad.opt.kernel import OptOps, Opt
from tinygrad.engine.realize import get_program
if __name__ == "__main__":
renderer = Device.default.renderer
N = 64
a = Tensor.empty(N,N)
b = Tensor.empty(N,N)
# metal TC
#opts = (Opt(OptOps.UPCAST, 0, 2), # not the warp
# Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2), Opt(OptOps.UPCAST, 1, 2),
# Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2))
# new TC should just be able to extract from this and swizzle as needed
opts = (Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UPCAST, 1, 8), Opt(OptOps.UNROLL, 0, 8))
c = (a@b)
ast = c.schedule()[-1].ast
ast = ast.replace(arg=KernelInfo(opts_to_apply=opts))
ast = get_optimized_ast(ast, renderer)
prg = get_program(ast, renderer)
print(prg.src)

View file

@ -37,6 +37,17 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri
# cache with the values of the context vars
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
# tensor cores
from tinygrad.uop.ops import PatternMatcher, UPat, UOp
def tensor_cores(a:UOp, b:UOp, r:UOp):
print("use tensor cores")
pm_tensor_cores = PatternMatcher([
((UPat.var().gep(name='a') * UPat.var().gep(name='b')).reduce(name='r', allow_any_len=True), tensor_cores),
])
@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
# ** lowerer (rewrite_shapetracker_with_index) **
@ -50,10 +61,16 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
# expand
ret.append(RewriteStep(sym+expander, name="expander"))
# use tensor cores
ret.append(RewriteStep(pm_tensor_cores, name="tensor cores"))
# ** devectorizer (full_graph_rewrite) **
# remove reduce
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
# factorize warp (before gpu dims)
#ret.append(RewriteStep(pm_warp, name="warpcast"))
# add gpu dims (late)
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))