mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix examples
This commit is contained in:
parent
50ac2872b3
commit
1c18e1bae8
2 changed files with 5 additions and 4 deletions
|
|
@ -8,7 +8,7 @@ from tinygrad.uop.render import pyrender
|
|||
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
|
|
@ -36,9 +36,10 @@ pm_index_is_shrink = PatternMatcher([
|
|||
pm_remove_vec_dtypes = PatternMatcher([
|
||||
# rewrite PARAM to non pointer
|
||||
(UPat((Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
|
||||
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) if isinstance(buf.dtype, PtrDType) else None),
|
||||
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
|
||||
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
|
||||
# remove all vec dtypes
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
|
||||
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ class CStyleLanguage(Renderer):
|
|||
def render_dtype_with_shape(self, u:UOp) -> DType: return dtype_with_shape(u.dtype, u.shape)
|
||||
def render_access(self, bidx:UOp, dtype:DType) -> str:
|
||||
if bidx.addrspace == AddrSpace.REG: return self[bidx]
|
||||
return f"(*(({self.render_dtype(dtype)}*)({self[bidx]})))" if dtype.count > 1 else f"(*{self[bidx]})"
|
||||
return f"(*(({self.render_dtype(dtype.ptr(addrspace=bidx.addrspace))})({self[bidx]})))" if dtype.count > 1 else f"(*{self[bidx]})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str:
|
||||
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
|
||||
if isinstance(dt, PtrDType):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue