mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
view_after
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e59c5c3a1 |
2 changed files with 5 additions and 8 deletions
|
|
@ -15,9 +15,6 @@ merge_views = PatternMatcher([
|
|||
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# only unmaksed VIEW on CONST replaces the ShapeTracker
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
|
||||
])
|
||||
|
||||
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
|
||||
|
|
@ -124,6 +121,9 @@ fix_kernel_ops = view_left_through_load+PatternMatcher([
|
|||
# add view to LOAD and STORE
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="g").store(UPat.var('x')), lambda g,x: g.view(g.st).store(x)),
|
||||
# only unmaksed VIEW on CONST replaces the ShapeTracker
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=(UOp(Ops.VIEW, arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
|
||||
# VALID
|
||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||
|
|
|
|||
|
|
@ -255,11 +255,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
|
||||
if shape is not None:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
|
||||
if device is not None:
|
||||
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
|
||||
if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
|
||||
return ret
|
||||
@staticmethod
|
||||
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue