mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix commavq benchmark (#4712)
* fix _slice and assert explicit device * with _slice
This commit is contained in:
parent
84255069e7
commit
c170ddceaf
2 changed files with 8 additions and 7 deletions
|
|
@ -650,7 +650,7 @@ def Attention(x:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional
|
|||
if unidirectional: # gpt-style
|
||||
assert hidden_size == v_hidden_size
|
||||
xqkv = x.linear(weights, bias)
|
||||
xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
|
||||
xq, xk, xv = [xqkv._slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
|
||||
else: # bert-style
|
||||
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
|
||||
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None
|
||||
|
|
|
|||
13
test/external/external_model_benchmark.py
vendored
13
test/external/external_model_benchmark.py
vendored
|
|
@ -1,5 +1,6 @@
|
|||
import csv, pathlib, time, numpy as np
|
||||
from os import getenv
|
||||
from tinygrad.device import CompileError
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
import onnx
|
||||
|
|
@ -60,8 +61,8 @@ def benchmark_model(m, devices, validate_outs=False):
|
|||
|
||||
# print input names
|
||||
if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
|
||||
try:
|
||||
for device in devices:
|
||||
for device in devices:
|
||||
try:
|
||||
Device.DEFAULT = device
|
||||
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
|
||||
tinygrad_model = get_run_onnx(onnx_model)
|
||||
|
|
@ -72,10 +73,10 @@ def benchmark_model(m, devices, validate_outs=False):
|
|||
for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
|
||||
benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
|
||||
del inputs, tinygrad_model, tinygrad_jitted_model
|
||||
except Exception as e:
|
||||
# model crashed
|
||||
print(f"{m} crashed on {device} with: {e}")
|
||||
return
|
||||
except CompileError as e:
|
||||
# METAL fails with buffer count limit
|
||||
if m == "dm" and device == "METAL": return
|
||||
raise e
|
||||
|
||||
# convert model to torch
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue