mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
lil cleanups from uop branch [pr] (#11197)
This commit is contained in:
parent
5625e1904b
commit
770a558585
3 changed files with 7 additions and 7 deletions
|
|
@ -490,11 +490,11 @@ class Kernel:
|
|||
st = ShapeTracker.from_shape(local_shape).expand(self.full_shape[:self.global_dims]+local_shape[self.global_dims:])
|
||||
local_size = st.real_size()
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
|
||||
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), ret)))
|
||||
local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
|
||||
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
|
||||
if op is self.reduceops[-1]: return grouped_reduce
|
||||
st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)]))
|
||||
return UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), grouped_reduce)))
|
||||
return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce))
|
||||
|
||||
return ret
|
||||
self.finalized = True
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class ASM24Controller:
|
|||
if fmt_type == 0x60 and size == 4 and self._is_pci_cacheable(address) and self._pci_cache.get(address) == value: return []
|
||||
|
||||
assert fmt_type >> 8 == 0 and size > 0 and size <= 4, f"Invalid fmt_type {fmt_type} or size {size}"
|
||||
if DEBUG >= 3: print("pcie_request", hex(fmt_type), hex(address), value, size)
|
||||
if DEBUG >= 5: print("pcie_request", hex(fmt_type), hex(address), value, size)
|
||||
|
||||
masked_address, offset = address & 0xFFFFFFFC, address & 0x3
|
||||
assert size + offset <= 4 and (value is None or value >> (8 * size) == 0)
|
||||
|
|
|
|||
|
|
@ -254,14 +254,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def valid(self):
  # Build the boolean VALID op over a VIEW carrying this UOp's shapetracker (self.st),
  # then hand it to UOp.where together with const_like(self.base.arg) and the literal 0.
  mask = UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),))
  return UOp.where(mask, self.const_like(self.base.arg), 0)
|
||||
@staticmethod
|
||||
def range(dtype:DType, end:sint, idx:int):
  # Construct an Ops.RANGE UOp of the given dtype: its only source is `end`
  # converted through sint_to_uop, and `idx` is stored as the op's arg.
  end_uop = sint_to_uop(end)
  return UOp(Ops.RANGE, dtype=dtype, src=(end_uop,), arg=idx)
|
||||
def r(self, op:Ops, axis:tuple[int, ...]):
|
||||
def r(self, op:Ops, axis:tuple[int, ...], permute=True):
|
||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||
if len(axis) == 0: return self
|
||||
# move any non reduce axis before the first reduce axis
|
||||
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
|
||||
if move_early:
|
||||
permute = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
|
||||
ret = self.permute(permute)
|
||||
if move_early and permute:
|
||||
permaxis = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
|
||||
ret = self.permute(permaxis)
|
||||
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
|
||||
assert len(axis) == len(new_axis)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue