mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
verify_sin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c85dfe1500 |
1 changed files with 6 additions and 2 deletions
|
|
@ -135,8 +135,12 @@ kernel_spec = buffer_spec+PatternMatcher([
|
|||
# *** this is the UOp shape spec ***
|
||||
|
||||
def verify_sink_dims(sink:UOp):
|
||||
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])]
|
||||
return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)
|
||||
# get all shapes from ShapeTrackers in the graph, grouped by dimension
|
||||
all_shapes = [x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None]
|
||||
assert all_same([len(x) for x in all_shapes]), f"shapes aren't all the same length {all_shapes}"
|
||||
shape_dims = [dedup(dims) for dims in zip(*all_shapes)]
|
||||
# confirm each dimension is either all the same, or has two values, a 1 (after the reduce), and n (before the reduce)
|
||||
return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and (x[0] == 1 or x[1] == 1)) for x in shape_dims)
|
||||
|
||||
shape_spec = PatternMatcher([
|
||||
# shapes must have either 1 or n in each dimension
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue