fix sd-example dtype for CLIP embeddings (#14397)

This commit is contained in:
Jakob Sachs 2026-01-28 15:07:19 +01:00 committed by GitHub
commit 2b7c00d3d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -93,7 +93,7 @@ if __name__ == "__main__":
forward: Any = None
sub_steps = [
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
Step(name = "textModel", input = [Tensor.randint(1, 77, low=0, high=49408, dtype=dtypes.int32)], forward = model.cond_stage_model.transformer.text_model),
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode),
Step(name = "f16tof32", input = [Tensor.randn(2097120, dtype=dtypes.uint32)], forward = u32_to_f16)