lil cleanups from uop branch [pr] (#11197)

This commit is contained in:
George Hotz 2025-07-12 09:46:28 -07:00 committed by GitHub
commit 770a558585
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 7 additions and 7 deletions

View file

@ -490,11 +490,11 @@ class Kernel:
st = ShapeTracker.from_shape(local_shape).expand(self.full_shape[:self.global_dims]+local_shape[self.global_dims:])
local_size = st.real_size()
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), ret)))
local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
if op is self.reduceops[-1]: return grouped_reduce
st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)]))
return UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), grouped_reduce)))
return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce))
return ret
self.finalized = True

View file

@ -171,7 +171,7 @@ class ASM24Controller:
if fmt_type == 0x60 and size == 4 and self._is_pci_cacheable(address) and self._pci_cache.get(address) == value: return []
assert fmt_type >> 8 == 0 and size > 0 and size <= 4, f"Invalid fmt_type {fmt_type} or size {size}"
if DEBUG >= 3: print("pcie_request", hex(fmt_type), hex(address), value, size)
if DEBUG >= 5: print("pcie_request", hex(fmt_type), hex(address), value, size)
masked_address, offset = address & 0xFFFFFFFC, address & 0x3
assert size + offset <= 4 and (value is None or value >> (8 * size) == 0)

View file

@ -254,14 +254,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def valid(self): return UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)
@staticmethod
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
def r(self, op:Ops, axis:tuple[int, ...]):
def r(self, op:Ops, axis:tuple[int, ...], permute=True):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self
# move any non reduce axis before the first reduce axis
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
if move_early:
permute = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
ret = self.permute(permute)
if move_early and permute:
permaxis = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
ret = self.permute(permaxis)
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
assert len(axis) == len(new_axis)
else: