mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
openpilot warp (#13283)
* openpilot image warp test * 0.4 ms on metal, 1 ms on CPU * new inputs each time * reshape
This commit is contained in:
parent
7c110e1a57
commit
e5351699bd
1 changed files with 55 additions and 0 deletions
55
test/external/external_openpilot_image_warp.py
vendored
Normal file
55
test/external/external_openpilot_image_warp.py
vendored
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import time
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
|
||||
MODEL_WIDTH = 512
|
||||
MODEL_HEIGHT = 256
|
||||
MODEL_FRAME_SIZE = MODEL_WIDTH * MODEL_HEIGHT * 3 // 2
|
||||
IMG_INPUT_SHAPE = (1, 12, 128, 256)
|
||||
|
||||
def tensor_arange(end): return Tensor([float(i) for i in range(end)])
|
||||
def tensor_round(tensor:Tensor): return (tensor + 0.5).floor()
|
||||
|
||||
h_src, w_src = 1208, 1928
|
||||
h_dst, w_dst = MODEL_HEIGHT, MODEL_WIDTH
|
||||
x = tensor_arange(w_dst).reshape(1, w_dst).expand(h_dst, w_dst)
|
||||
y = tensor_arange(h_dst).reshape(h_dst, 1).expand(h_dst, w_dst)
|
||||
ones = Tensor.ones_like(x)
|
||||
dst_coords = x.reshape((1,-1)).cat(y.reshape((1,-1))).cat(ones.reshape((1,-1)))
|
||||
|
||||
def warp_perspective_tinygrad(src:Tensor, M_inv:Tensor) -> Tensor:
|
||||
src_coords = M_inv @ dst_coords
|
||||
src_coords = src_coords / src_coords[2:3, :]
|
||||
|
||||
x_src = src_coords[0].reshape(h_dst, w_dst)
|
||||
y_src = src_coords[1].reshape(h_dst, w_dst)
|
||||
|
||||
x_nearest = tensor_round(x_src).clip(0, w_src - 1).cast('int')
|
||||
y_nearest = tensor_round(y_src).clip(0, h_src - 1).cast('int')
|
||||
|
||||
# TODO: make 2d indexing fast
|
||||
idx = y_nearest*src.shape[1] + x_nearest
|
||||
dst = src.flatten()[idx]
|
||||
return dst.reshape(h_dst, w_dst)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
update_img_jit = TinyJit(warp_perspective_tinygrad, prune=True)
|
||||
|
||||
step_times = []
|
||||
for _ in range(10):
|
||||
# regenerate inputs
|
||||
inputs = [Tensor.randn(1928,1208), Tensor.randn(3,3)]
|
||||
Tensor.realize(*inputs)
|
||||
Device.default.synchronize()
|
||||
|
||||
# do the warp
|
||||
st = time.perf_counter()
|
||||
out = update_img_jit(*inputs)
|
||||
mt = time.perf_counter()
|
||||
val = out.contiguous().realize()
|
||||
Device.default.synchronize()
|
||||
et = time.perf_counter()
|
||||
|
||||
# measure the time
|
||||
step_times.append((et-st)*1e3)
|
||||
print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {step_times[-1]:6.2f} ms")
|
||||
Loading…
Add table
Add a link
Reference in a new issue