This commit is contained in:
George Hotz 2026-06-23 17:26:11 -07:00
commit f2c93fd0d4

View file

@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
step = THREADS_PER_WG // 2
while step:
active = tid < step
other = lds[active.where(tid + step, UOp.invalid())].load()
lds = lds.after(lds[active.where(tid, UOp.invalid())].store(lds[tid].maximum(other)).barrier())
other = lds[(tid + step).valid(active)].load()
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
step //= 2
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])