mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
that hack broke things
This commit is contained in:
parent
d1d935242b
commit
dcc6ddf0eb
3 changed files with 5 additions and 3 deletions
|
|
@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
|
|||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL}
|
||||
|
||||
# **** Grouper decides which of the UOps realize
|
||||
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ create_kernels = PatternMatcher([
|
|||
|
||||
early_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x), tag=1)),
|
||||
# no SINK for meta ops
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||
])
|
||||
|
|
@ -168,6 +168,8 @@ def check_load_st(glbl:UOp, view:UOp):
|
|||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
|
||||
fix_kernel_ops = PatternMatcher([
|
||||
# add the LOAD
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
return ShapeTracker.from_shape((sz,)) if sz > 0 else None
|
||||
|
||||
# hack for PTX, CASTing the ptr loses the shape
|
||||
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
|
||||
#if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
|
||||
|
||||
# otherwise we get the shape from sources
|
||||
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue