mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move type verify to codegen [pr] (#10816)
This commit is contained in:
parent
27cf836958
commit
cc5e4e54b8
2 changed files with 6 additions and 5 deletions
|
|
@ -3,6 +3,7 @@ import functools
|
|||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp
|
||||
from tinygrad.uop.spec import type_verify
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# import all pattern matchers here
|
||||
|
|
@ -72,4 +73,8 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
|||
|
||||
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp:
|
||||
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer))
|
||||
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: return list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst)
|
||||
|
||||
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
|
||||
lst = list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst)
|
||||
if __debug__: type_verify(lst)
|
||||
return lst
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from collections import defaultdict
|
|||
from dataclasses import dataclass, replace
|
||||
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.helpers import dedup, partition, all_same, flatten, getenv
|
||||
from tinygrad.uop.spec import type_verify
|
||||
|
||||
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
|
||||
def block_reorder(lst:list[UOp]) -> list[UOp]:
|
||||
|
|
@ -237,9 +236,6 @@ def finalize(sink:UOp) -> UOp:
|
|||
|
||||
# place the early things
|
||||
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
|
||||
|
||||
if __debug__: type_verify(lst)
|
||||
|
||||
return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))
|
||||
|
||||
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue