mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
884592f6c8
commit
6838b35cff
3 changed files with 56 additions and 7 deletions
|
|
@ -130,15 +130,18 @@ class NVDriver(VirtDriver):
|
|||
struct.hObjectNew = self._alloc_handle()
|
||||
self.object_by_handle[struct.hObjectNew] = NVContextShare(self.object_by_handle[struct.hObjectParent])
|
||||
elif struct.hClass == nv_gpu.AMPERE_CHANNEL_GPFIFO_A:
|
||||
assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVChannelGroup)
|
||||
parent = self.object_by_handle.get(struct.hObjectParent)
|
||||
assert parent is not None and isinstance(parent, (NVChannelGroup, NVGPU))
|
||||
struct.hObjectNew = self._alloc_handle()
|
||||
params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS.from_address(params_ptr)
|
||||
gpu = self.object_by_handle[struct.hObjectParent].device
|
||||
gpu = parent.device if isinstance(parent, NVChannelGroup) else parent
|
||||
gpfifo_token = gpu.add_gpfifo(params.gpFifoOffset, params.gpFifoEntries)
|
||||
self.object_by_handle[struct.hObjectNew] = NVGPFIFO(gpu, gpfifo_token)
|
||||
elif struct.hClass == nv_gpu.AMPERE_DMA_COPY_B or struct.hClass == nv_gpu.ADA_COMPUTE_A:
|
||||
elif struct.hClass in (nv_gpu.AMPERE_DMA_COPY_B, nv_gpu.ADA_COMPUTE_A, nv_gpu.NVC9B0_VIDEO_DECODER, nv_gpu.NVCFB0_VIDEO_DECODER):
|
||||
assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVGPFIFO)
|
||||
struct.hObjectNew = self._alloc_handle()
|
||||
gpfifo = self.object_by_handle[struct.hObjectParent]
|
||||
gpfifo.device.queues[gpfifo.token].bound_engines.add(struct.hClass)
|
||||
elif struct.hClass == nv_gpu.GT200_DEBUGGER:
|
||||
struct.hObjectNew = self._alloc_handle()
|
||||
elif struct.hClass == nv_gpu.MAXWELL_PROFILER_DEVICE:
|
||||
|
|
@ -194,7 +197,7 @@ class NVDriver(VirtDriver):
|
|||
params = nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN_PARAMS.from_address(params_ptr)
|
||||
gpu_fifo = self.object_by_handle[struct.hObject]
|
||||
params.workSubmitToken = gpu_fifo.token
|
||||
elif struct.cmd == nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE: pass
|
||||
elif struct.cmd in (nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE, nv_gpu.NVA06F_CTRL_CMD_BIND, nv_gpu.NVA06F_CTRL_CMD_GPFIFO_SCHEDULE): pass
|
||||
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_PERF_BOOST: pass
|
||||
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_FB_FLUSH_GPU_CACHE: pass
|
||||
elif struct.cmd == nv_gpu.NV83DE_CTRL_CMD_DEBUG_READ_ALL_SM_ERROR_STATES:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class GPFIFO:
|
|||
self.gpfifo = to_mv(self.base, self.entries_cnt * 8).cast("Q")
|
||||
self.ctrl = nv_gpu.AmpereAControlGPFifo.from_address(self.base + self.entries_cnt * 8)
|
||||
self.state = {}
|
||||
self.bound_engines: set[int] = set()
|
||||
|
||||
# Buf exec state
|
||||
self.buf = None
|
||||
|
|
@ -115,9 +116,11 @@ class GPFIFO:
|
|||
def execute_cmd(self, cmd) -> SchedResult:
|
||||
if cmd == nv_gpu.NVC56F_SEM_EXECUTE: return self._exec_signal()
|
||||
elif cmd == nv_gpu.NVC6C0_LAUNCH_DMA: return self._exec_nvc6c0_dma()
|
||||
elif cmd == nv_gpu.NVC6B5_LAUNCH_DMA: return self._exec_nvc6b5_dma()
|
||||
elif cmd == nv_gpu.NVC6B5_LAUNCH_DMA: # NOTE: NVC6B5_LAUNCH_DMA == NVC9B0_EXECUTE == 0x300
|
||||
return self._exec_vid_decode() if self.bound_engines & {nv_gpu.NVC9B0_VIDEO_DECODER, nv_gpu.NVCFB0_VIDEO_DECODER} else self._exec_nvc6b5_dma()
|
||||
elif cmd == nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B: return self._exec_pcas2()
|
||||
elif cmd == 0x0320: return self._exec_load_inline_qmd() # NVC6C0_LOAD_INLINE_QMD_DATA
|
||||
elif cmd == nv_gpu.NVC9B0_SEMAPHORE_D: return self._exec_vid_semaphore()
|
||||
else: self.state[cmd] = self._next_dword() # just state update
|
||||
return SchedResult.CONT
|
||||
|
||||
|
|
@ -136,6 +139,27 @@ class GPFIFO:
|
|||
else: raise RuntimeError(f"Unsupported type={typ} in exec wait/signal")
|
||||
return SchedResult.CONT
|
||||
|
||||
def _exec_vid_decode(self) -> SchedResult:
|
||||
self._next_dword() # consume execute flags
|
||||
# validate that all required decode state was set up correctly
|
||||
assert self._state(nv_gpu.NVC9B0_SET_APPLICATION_ID) == nv_gpu.NVC9B0_SET_APPLICATION_ID_ID_HEVC
|
||||
pic_desc_addr = self._state(nv_gpu.NVC9B0_SET_DRV_PIC_SETUP_OFFSET) << 8
|
||||
pic = nv_gpu.nvdec_hevc_pic_s.from_address(pic_desc_addr)
|
||||
assert pic.stream_len > 0 and pic.pic_width_in_luma_samples > 0 and pic.pic_height_in_luma_samples > 0
|
||||
assert self._state(nv_gpu.NVC9B0_SET_IN_BUF_BASE_OFFSET) != 0
|
||||
assert self._state(nv_gpu.NVC9B0_SET_COLOC_DATA_OFFSET) != 0
|
||||
assert self._state(nv_gpu.NVC9B0_SET_NVDEC_STATUS_OFFSET) != 0
|
||||
assert self._state(nv_gpu.NVC9B0_HEVC_SET_FILTER_BUFFER_OFFSET) != 0
|
||||
return SchedResult.CONT
|
||||
|
||||
def _exec_vid_semaphore(self) -> SchedResult:
|
||||
signal = self._state64(nv_gpu.NVC9B0_SEMAPHORE_A)
|
||||
val = self._state(nv_gpu.NVC9B0_SEMAPHORE_C)
|
||||
self._next_dword() # flags
|
||||
to_mv(signal, 8).cast('Q')[0] = val
|
||||
to_mv(signal + 8, 8).cast('Q')[0] = int(time.perf_counter() * 1e9)
|
||||
return SchedResult.CONT
|
||||
|
||||
def _exec_load_inline_qmd(self):
|
||||
qmd_addr = self._state64(nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A) << 8
|
||||
assert qmd_addr != 0x0, f"invalid qmd address {qmd_addr}"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import unittest
|
||||
|
||||
from tinygrad import Device
|
||||
from tinygrad.helpers import fetch
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.helpers import fetch, round_up
|
||||
from extra.hevc.hevc import parse_hevc_file_headers, nv_gpu
|
||||
from extra.hevc.decode import hevc_decode
|
||||
|
||||
class TestHevc(unittest.TestCase):
|
||||
def test_hevc_parser(self):
|
||||
|
|
@ -61,5 +62,26 @@ class TestHevc(unittest.TestCase):
|
|||
self.assertEqual(list(frame3.initreflistidxl0), [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
self.assertEqual(list(frame3.initreflistidxl1), [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
self.assertEqual(list(frame3.RefDiffPicOrderCnts), [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "NV", "NV only")
|
||||
def test_hevc_decode(self):
|
||||
url = "https://github.com/haraschax/filedump/raw/09a497959f7fa6fd8dba501a25f2cdb3a41ecb12/comma_video.hevc"
|
||||
dat = fetch(url, headers={"Range": f"bytes=0-{512<<10}"}).read_bytes()
|
||||
|
||||
opaque, frame_info, w, h, luma_w, luma_h, chroma_off = parse_hevc_file_headers(dat)
|
||||
frame_info = frame_info[:4]
|
||||
out_image_size = luma_h + (luma_h + 1) // 2, round_up(luma_w, 64)
|
||||
|
||||
hevc_tensor = Tensor(dat, device="NV")
|
||||
opaque_nv = opaque.to("NV").contiguous().realize()
|
||||
|
||||
frames = list(hevc_decode(hevc_tensor, opaque_nv, frame_info, luma_h, luma_w))
|
||||
Device.default.synchronize()
|
||||
self.assertEqual(len(frames), 4)
|
||||
for f in frames:
|
||||
self.assertEqual(f.shape, out_image_size)
|
||||
self.assertEqual(f.dtype, dtypes.uint8)
|
||||
self.assertEqual(f.device, "NV")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue