mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix: pass input device to ONNX helper internal tensors (#16242)
* fix: pass input device to onnx methods internal tensors * test: onnx helper internal tensors use input device
This commit is contained in:
parent
a178301dbe
commit
74567c1958
2 changed files with 43 additions and 7 deletions
38
test/external/external_test_onnx_ops.py
vendored
38
test/external/external_test_onnx_ops.py
vendored
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
from typing import Any
|
||||
import unittest, onnx, tempfile
|
||||
from tinygrad import dtypes, Tensor
|
||||
from tinygrad import dtypes, Tensor, Context
|
||||
from tinygrad.nn.onnx import OnnxRunner
|
||||
import numpy as np
|
||||
from extra.onnx_helpers import validate
|
||||
|
|
@ -364,6 +364,19 @@ class TestMainOnnxOps(TestOnnxOps):
|
|||
inputs = {"data": np.random.randn(1, 1, 32, 32, 32).astype(np.half)*100}
|
||||
self.helper_test_single_op("ReduceL2", inputs, {}, ["reduced"])
|
||||
|
||||
def test_same_device_as_input(self):
|
||||
from tinygrad.nn.onnx import onnx_ops
|
||||
EyeLike = onnx_ops["EyeLike"]
|
||||
Shape = onnx_ops["Shape"]
|
||||
Compress = onnx_ops["Compress"]
|
||||
with Context(DEV="CPU"):
|
||||
x = Tensor.arange(4, device="PYTHON").reshape(2,2)
|
||||
self.assertEqual(EyeLike(x).device, x.device)
|
||||
self.assertEqual(Shape(x).device, x.device)
|
||||
out = Compress(x, [True, False, True, False])
|
||||
self.assertEqual(out.device, x.device)
|
||||
self.assertEqual(out.tolist(), [0, 2])
|
||||
|
||||
class TestTrainingOnnxOps(TestOnnxOps):
|
||||
# NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx
|
||||
DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN
|
||||
|
|
@ -581,5 +594,28 @@ class TestContribOnnxOps(TestOnnxOps):
|
|||
outputs = ["C"]
|
||||
self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs)
|
||||
|
||||
def test_same_device_as_input(self):
|
||||
from tinygrad.nn.onnx import onnx_ops, OpSetId, Domain
|
||||
EmbedLayerNormalization = onnx_ops["EmbedLayerNormalization"]
|
||||
Attention = onnx_ops["Attention"]
|
||||
with Context(DEV="CPU"):
|
||||
input_ids = Tensor([[1, 2]], device="PYTHON", dtype=dtypes.int32)
|
||||
segment_ids = Tensor([[0, 0]], device="PYTHON", dtype=dtypes.int32)
|
||||
word = Tensor.ones(4, 3, device="PYTHON")
|
||||
pos = Tensor.ones(5, 3, device="PYTHON")
|
||||
seg = Tensor.ones(1, 3, device="PYTHON")
|
||||
gamma, beta = Tensor.ones(3, device="PYTHON"), Tensor.zeros(3, device="PYTHON")
|
||||
out, _, _ = EmbedLayerNormalization(input_ids, segment_ids, word, pos, seg, gamma, beta)
|
||||
self.assertEqual(out.device, input_ids.device)
|
||||
out.realize()
|
||||
|
||||
attn = Attention[OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1)]
|
||||
x = Tensor.ones(1, 2, 4, device="PYTHON")
|
||||
w = Tensor.ones(4, 12, device="PYTHON")
|
||||
mask = Tensor([2, 0], device="PYTHON", dtype=dtypes.int32)
|
||||
out, _ = attn(x, w, mask_index=mask, num_heads=1, unidirectional=1)
|
||||
self.assertEqual(out.device, x.device)
|
||||
out.realize()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -586,7 +586,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
|||
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
|
||||
|
||||
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
|
||||
ret = Tensor.eye(cast(int, min(x.shape)), dtype=OnnxDataType(dtype).to_dtype() if dtype is not None else x.dtype)
|
||||
ret = Tensor.eye(cast(int, min(x.shape)), dtype=OnnxDataType(dtype).to_dtype() if dtype is not None else x.dtype, device=x.device)
|
||||
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
|
||||
|
||||
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
|
||||
|
|
@ -597,7 +597,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
|||
return value.expand(shape)
|
||||
|
||||
def Size(data:Tensor): return data.numel()
|
||||
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
|
||||
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64, device=data.device)
|
||||
|
||||
# ***** Unary Ops (math) *****
|
||||
def Not(x:Tensor): return x.logical_not()
|
||||
|
|
@ -934,7 +934,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
|||
return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight
|
||||
|
||||
# bert embedding layer
|
||||
if position_ids is None: position_ids = Tensor.arange(seq_length).unsqueeze(0).expand(*input_shape)
|
||||
if position_ids is None: position_ids = Tensor.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(*input_shape)
|
||||
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
|
||||
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
|
||||
|
||||
|
|
@ -1036,14 +1036,14 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
|||
elif mask_index.shape[0] == 2*batch_size:
|
||||
end_positions = mask_index[:batch_size]
|
||||
start_positions = mask_index[batch_size:]
|
||||
arange = Tensor.arange(seq_len).unsqueeze(0)
|
||||
arange = Tensor.arange(seq_len, device=mask_index.device).unsqueeze(0)
|
||||
mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1))
|
||||
else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented")
|
||||
while mask.ndim < 4: mask = mask.unsqueeze(1)
|
||||
attn_scores = mask.where(attn_scores, mask_filter_value)
|
||||
|
||||
if unidirectional:
|
||||
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool).tril()
|
||||
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device).tril()
|
||||
attn_scores = causal_mask.where(attn_scores, mask_filter_value)
|
||||
|
||||
output = attn_scores.softmax(-1) @ v
|
||||
|
|
@ -1199,7 +1199,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
|||
inp = inp.flatten()
|
||||
axis = 0
|
||||
axis = inp._resolve_dim(axis)
|
||||
con = Tensor([i for i,cond in enumerate(condition) if cond]) # compress in python
|
||||
con = Tensor([i for i,cond in enumerate(condition) if cond], device=inp.device) # compress in python
|
||||
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
|
||||
|
||||
# ***** Quantization Ops *****
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue