fix: pass input device to ONNX helper internal tensors (#16242)

* fix: pass input device to onnx methods internal tensors

* test: onnx helper internal tensors use input device
This commit is contained in:
Sachith Shetty 2026-05-19 11:16:33 -07:00 committed by GitHub
commit 74567c1958
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 7 deletions

View file

@ -4,7 +4,7 @@
from typing import Any
import unittest, onnx, tempfile
from tinygrad import dtypes, Tensor
from tinygrad import dtypes, Tensor, Context
from tinygrad.nn.onnx import OnnxRunner
import numpy as np
from extra.onnx_helpers import validate
@ -364,6 +364,19 @@ class TestMainOnnxOps(TestOnnxOps):
inputs = {"data": np.random.randn(1, 1, 32, 32, 32).astype(np.half)*100}
self.helper_test_single_op("ReduceL2", inputs, {}, ["reduced"])
def test_same_device_as_input(self):
from tinygrad.nn.onnx import onnx_ops
EyeLike = onnx_ops["EyeLike"]
Shape = onnx_ops["Shape"]
Compress = onnx_ops["Compress"]
with Context(DEV="CPU"):
x = Tensor.arange(4, device="PYTHON").reshape(2,2)
self.assertEqual(EyeLike(x).device, x.device)
self.assertEqual(Shape(x).device, x.device)
out = Compress(x, [True, False, True, False])
self.assertEqual(out.device, x.device)
self.assertEqual(out.tolist(), [0, 2])
class TestTrainingOnnxOps(TestOnnxOps):
# NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx
DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN
@ -581,5 +594,28 @@ class TestContribOnnxOps(TestOnnxOps):
outputs = ["C"]
self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs)
def test_same_device_as_input(self):
from tinygrad.nn.onnx import onnx_ops, OpSetId, Domain
EmbedLayerNormalization = onnx_ops["EmbedLayerNormalization"]
Attention = onnx_ops["Attention"]
with Context(DEV="CPU"):
input_ids = Tensor([[1, 2]], device="PYTHON", dtype=dtypes.int32)
segment_ids = Tensor([[0, 0]], device="PYTHON", dtype=dtypes.int32)
word = Tensor.ones(4, 3, device="PYTHON")
pos = Tensor.ones(5, 3, device="PYTHON")
seg = Tensor.ones(1, 3, device="PYTHON")
gamma, beta = Tensor.ones(3, device="PYTHON"), Tensor.zeros(3, device="PYTHON")
out, _, _ = EmbedLayerNormalization(input_ids, segment_ids, word, pos, seg, gamma, beta)
self.assertEqual(out.device, input_ids.device)
out.realize()
attn = Attention[OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1)]
x = Tensor.ones(1, 2, 4, device="PYTHON")
w = Tensor.ones(4, 12, device="PYTHON")
mask = Tensor([2, 0], device="PYTHON", dtype=dtypes.int32)
out, _ = attn(x, w, mask_index=mask, num_heads=1, unidirectional=1)
self.assertEqual(out.device, x.device)
out.realize()
if __name__ == "__main__":
unittest.main()

View file

@ -586,7 +586,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=OnnxDataType(dtype).to_dtype() if dtype is not None else x.dtype)
ret = Tensor.eye(cast(int, min(x.shape)), dtype=OnnxDataType(dtype).to_dtype() if dtype is not None else x.dtype, device=x.device)
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
@ -597,7 +597,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
return value.expand(shape)
def Size(data:Tensor): return data.numel()
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64, device=data.device)
# ***** Unary Ops (math) *****
def Not(x:Tensor): return x.logical_not()
@ -934,7 +934,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight
# bert embedding layer
if position_ids is None: position_ids = Tensor.arange(seq_length).unsqueeze(0).expand(*input_shape)
if position_ids is None: position_ids = Tensor.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(*input_shape)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
@ -1036,14 +1036,14 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
elif mask_index.shape[0] == 2*batch_size:
end_positions = mask_index[:batch_size]
start_positions = mask_index[batch_size:]
arange = Tensor.arange(seq_len).unsqueeze(0)
arange = Tensor.arange(seq_len, device=mask_index.device).unsqueeze(0)
mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1))
else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented")
while mask.ndim < 4: mask = mask.unsqueeze(1)
attn_scores = mask.where(attn_scores, mask_filter_value)
if unidirectional:
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool).tril()
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device).tril()
attn_scores = causal_mask.where(attn_scores, mask_filter_value)
output = attn_scores.softmax(-1) @ v
@ -1199,7 +1199,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
inp = inp.flatten()
axis = 0
axis = inp._resolve_dim(axis)
con = Tensor([i for i,cond in enumerate(condition) if cond]) # compress in python
con = Tensor([i for i,cond in enumerate(condition) if cond], device=inp.device) # compress in python
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
# ***** Quantization Ops *****