Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
c85dfe1500 improve verify_sink_dims 2025-03-18 19:38:22 +08:00

View file

@ -135,8 +135,12 @@ kernel_spec = buffer_spec+PatternMatcher([
# *** this is the UOp shape spec ***
def verify_sink_dims(sink:UOp):
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])]
return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)
# get all shapes from ShapeTrackers in the graph, grouped by dimension
all_shapes = [x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None]
assert all_same([len(x) for x in all_shapes]), f"shapes aren't all the same length {all_shapes}"
shape_dims = [dedup(dims) for dims in zip(*all_shapes)]
# confirm each dimension is either all the same, or has two values, a 1 (after the reduce), and n (before the reduce)
return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and (x[0] == 1 or x[1] == 1)) for x in shape_dims)
shape_spec = PatternMatcher([
# shapes must have either 1 or n in each dimension