mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
upcasted warp experiments
This commit is contained in:
parent
be53ef4f0a
commit
0df4355cd8
2 changed files with 41 additions and 0 deletions
24
extra/upcasted_warps.py
Normal file
24
extra/upcasted_warps.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
# play with upcasted warps
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.uop.ops import KernelInfo
|
||||
from tinygrad.opt import get_optimized_ast
|
||||
from tinygrad.opt.kernel import OptOps, Opt
|
||||
from tinygrad.engine.realize import get_program
|
||||
|
||||
if __name__ == "__main__":
  # renderer of the default device; used both to optimize the AST and to render the program
  target = Device.default.renderer

  # two N x N operands for a square matmul
  N = 64
  lhs = Tensor.empty(N,N)
  rhs = Tensor.empty(N,N)

  # metal TC
  #opts = (Opt(OptOps.UPCAST, 0, 2), # not the warp
  #        Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2), Opt(OptOps.UPCAST, 1, 2),
  #        Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2))

  # new TC should just be able to extract from this and swizzle as needed
  upcast_opts = (
    Opt(OptOps.UPCAST, 0, 8),
    Opt(OptOps.UPCAST, 1, 8),
    Opt(OptOps.UNROLL, 0, 8),
  )

  # take the AST of the last scheduled item for the matmul
  product = lhs @ rhs
  kernel_ast = product.schedule()[-1].ast

  # attach the hand-picked opts, optimize, then lower to a program for the renderer
  kernel_ast = kernel_ast.replace(arg=KernelInfo(opts_to_apply=upcast_opts))
  kernel_ast = get_optimized_ast(kernel_ast, target)
  program = get_program(kernel_ast, target)

  # dump the generated source for inspection
  print(program.src)
|
||||
|
|
@ -37,6 +37,17 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri
|
|||
# cache with the values of the context vars
|
||||
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
|
||||
|
||||
# tensor cores
|
||||
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, UOp
|
||||
|
||||
def tensor_cores(a:UOp, b:UOp, r:UOp):
  """Callback registered in `pm_tensor_cores`; for now it only reports that the pattern matched.

  `a`, `b` and `r` receive the UOps bound to the pattern names 'a', 'b' and 'r'.
  Returns None (presumably meaning "no rewrite" to the PatternMatcher -- TODO confirm).
  """
  print("use tensor cores")
  return None
|
||||
|
||||
# single rule: a reduce (bound as 'r') over the product of two gep'd vars (geps bound as 'a' and 'b')
# is handed to tensor_cores
pm_tensor_cores = PatternMatcher([
  (
    (UPat.var().gep(name='a') * UPat.var().gep(name='b')).reduce(name='r', allow_any_len=True),
    tensor_cores,
  ),
])
|
||||
|
||||
@functools.cache
|
||||
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
|
||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
|
|
@ -50,10 +61,16 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
|||
# expand
|
||||
ret.append(RewriteStep(sym+expander, name="expander"))
|
||||
|
||||
# use tensor cores
|
||||
ret.append(RewriteStep(pm_tensor_cores, name="tensor cores"))
|
||||
|
||||
# ** devectorizer (full_graph_rewrite) **
|
||||
# remove reduce
|
||||
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
|
||||
|
||||
# factorize warp (before gpu dims)
|
||||
#ret.append(RewriteStep(pm_warp, name="warpcast"))
|
||||
|
||||
# add gpu dims (late)
|
||||
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue