mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hevc: decoder as iterator (#14091)
This commit is contained in:
parent
35c9701df0
commit
3e2c05ee9f
1 changed files with 42 additions and 28 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import argparse, os, hashlib
|
||||
import argparse, os, hashlib, functools
|
||||
from typing import Iterator, Callable
|
||||
from tinygrad.helpers import getenv, DEBUG, round_up, Timing, tqdm, fetch, ceildiv
|
||||
from extra.hevc.hevc import parse_hevc_file_headers, untile_nv12, to_bgr, nv_gpu
|
||||
from tinygrad import Tensor, dtypes, Device, Variable, TinyJit
|
||||
|
|
@ -6,6 +7,36 @@ from tinygrad import Tensor, dtypes, Device, Variable, TinyJit
|
|||
# rounds up hevc input data to 32 bytes, so more optimal kernels can be generated
|
||||
HEVC_ROUNDUP = getenv("DATA_ROUNDUP", 32)
|
||||
|
||||
@functools.cache
|
||||
def _hevc_jitted_decoder(out_image_size:tuple[int, int], max_hist:int, inplace:bool):
|
||||
def hevc_decode_frame(pos:Variable, hevc_tensor:Tensor, offset:Variable, sz:Variable, opaque:Tensor, i:Variable, *hist:Tensor, outbuf:Tensor|None=None):
|
||||
x = hevc_tensor[offset:offset+sz*HEVC_ROUNDUP].decode_hevc_frame(pos, out_image_size, opaque[i], hist)
|
||||
if outbuf is not None: outbuf.assign(x).realize()
|
||||
return x.realize()
|
||||
return TinyJit(hevc_decode_frame)
|
||||
|
||||
def hevc_decode(hevc_tensor:Tensor, opaque:Tensor, frame_info:list, luma_h:int, luma_w:int,
|
||||
history:list[Tensor]|None=None, preallocated_outputs:list[Tensor]|None=None, warmup=False) -> Iterator[Tensor]:
|
||||
out_image_size = luma_h + (luma_h + 1) // 2, round_up(luma_w, 64)
|
||||
max_hist = max((hs for _, _, _, hs, _ in frame_info), default=0)
|
||||
|
||||
v_pos = Variable("pos", 0, max_hist + 1)
|
||||
v_offset = Variable("offset", 0, hevc_tensor.numel()-1)
|
||||
v_sz = Variable("sz", 1, ceildiv(hevc_tensor.numel(), HEVC_ROUNDUP))
|
||||
v_i = Variable("i", 0, len(frame_info)-1)
|
||||
|
||||
decode_jit = _hevc_jitted_decoder(out_image_size, max_hist, preallocated_outputs is not None)
|
||||
history = history or [Tensor.empty(*out_image_size, dtype=dtypes.uint8, device="NV").contiguous().realize() for _ in range(max_hist)]
|
||||
assert len(history) == max_hist, f"history length {len(history)} does not match max_hist {max_hist}"
|
||||
|
||||
for i, (offset, sz, frame_pos, _, is_hist) in enumerate(frame_info):
|
||||
history = history[-max_hist:] if max_hist > 0 else []
|
||||
img = decode_jit(v_pos.bind(frame_pos), hevc_tensor, v_offset.bind(offset), v_sz.bind(ceildiv(sz, HEVC_ROUNDUP)),
|
||||
opaque, v_i.bind(i), *history, outbuf=preallocated_outputs[i] if preallocated_outputs else None)
|
||||
res = preallocated_outputs[i] if preallocated_outputs else img.clone().realize()
|
||||
if is_hist: history.append(res)
|
||||
yield res
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_file", type=str, default="")
|
||||
|
|
@ -22,7 +53,6 @@ if __name__ == "__main__":
|
|||
dat_hash = hashlib.md5(dat).hexdigest()
|
||||
|
||||
with Timing("prep infos: "):
|
||||
dat_nv = hevc_tensor.to("NV")
|
||||
opaque, frame_info, w, h, luma_w, luma_h, chroma_off = parse_hevc_file_headers(dat)
|
||||
|
||||
frame_info = frame_info[:getenv("MAX_FRAMES", len(frame_info))]
|
||||
|
|
@ -33,38 +63,22 @@ if __name__ == "__main__":
|
|||
hevc_tensor = hevc_tensor.to("NV")
|
||||
|
||||
out_image_size = luma_h + (luma_h + 1) // 2, round_up(luma_w, 64)
|
||||
max_hist = max(history_sz for _, _, _, history_sz, _ in frame_info)
|
||||
|
||||
# define variables
|
||||
v_pos = Variable("pos", 0, max_hist + 1)
|
||||
v_offset = Variable("offset", 0, hevc_tensor.numel()-1)
|
||||
v_sz = Variable("sz", 1, ceildiv(hevc_tensor.numel(), HEVC_ROUNDUP))
|
||||
v_i = Variable("i", 0, len(frame_info)-1)
|
||||
|
||||
@TinyJit
|
||||
def decode_jit(pos:Variable, hevc_tensor:Tensor, offset:Variable, sz:Variable, opaque_nv:Tensor, i:Variable, outbuf:Tensor, *hist:Tensor):
|
||||
x = hevc_tensor[offset:offset+sz*HEVC_ROUNDUP].decode_hevc_frame(pos, out_image_size, opaque_nv[i], hist)
|
||||
outbuf.assign(x).realize()
|
||||
return x
|
||||
|
||||
# preallocate output buffers
|
||||
# preallocate output/hist buffers
|
||||
max_hist = max((hs for _, _, _, hs, _ in frame_info), default=0)
|
||||
hist = [Tensor.empty(*out_image_size, dtype=dtypes.uint8, device="NV").contiguous().realize() for _ in range(max_hist)]
|
||||
out_images = [Tensor.zeros(*out_image_size, dtype=dtypes.uint8, device="NV").contiguous().realize() for _ in range(len(frame_info))]
|
||||
|
||||
# warm up
|
||||
history = [Tensor.empty(*out_image_size, dtype=dtypes.uint8, device="NV") for _ in range(max_hist)]
|
||||
for i in range(3):
|
||||
decode_jit(v_pos.bind(0), hevc_tensor, v_offset.bind(frame_info[0][0]), v_sz.bind(ceildiv(frame_info[0][1], HEVC_ROUNDUP)), opaque_nv,
|
||||
v_i.bind(0), out_images[i], *history)
|
||||
# warmup decode
|
||||
_ = list(hevc_decode(hevc_tensor, opaque_nv, frame_info[:3], luma_h, luma_w, history=hist, preallocated_outputs=out_images))
|
||||
Device.default.synchronize()
|
||||
|
||||
# decode all frames using the iterator
|
||||
with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")):
|
||||
for i, (offset, sz, frame_pos, history_sz, is_hist) in enumerate(frame_info):
|
||||
history = history[-max_hist:] if max_hist > 0 else []
|
||||
decode_jit(v_pos.bind(frame_pos), hevc_tensor, v_offset.bind(offset), v_sz.bind(ceildiv(sz, HEVC_ROUNDUP)), opaque_nv,
|
||||
v_i.bind(i), out_images[i], *history)
|
||||
if is_hist: history.append(out_images[i])
|
||||
|
||||
images = list(hevc_decode(hevc_tensor, opaque_nv, frame_info, luma_h, luma_w, history=hist, preallocated_outputs=out_images))
|
||||
Device.default.synchronize()
|
||||
|
||||
# validation
|
||||
if getenv("VALIDATE", 0):
|
||||
import pickle
|
||||
if dat_hash == "b813bfdbec194fd17fdf0e3ceb8cea1c":
|
||||
|
|
@ -73,7 +87,7 @@ if __name__ == "__main__":
|
|||
else: decoded_frames = pickle.load(open(f"extra/hevc/decoded_frames_{dat_hash}.pkl", "rb"))
|
||||
else: import cv2
|
||||
|
||||
for i, img in tqdm(enumerate(out_images)):
|
||||
for i, img in tqdm(enumerate(images)):
|
||||
if getenv("VALIDATE", 0):
|
||||
if i < len(decoded_frames) and len(decoded_frames[i]) > 0:
|
||||
img = untile_nv12(img, h, w, luma_w, chroma_off).realize()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue