mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
spend lines on const_arg for tensor and scheduler [pr] (#8132)
* spend lines on const_arg for tensor and scheduler [pr] * simple test_const_arg * base on lazy
This commit is contained in:
parent
917deb88a4
commit
7436ebef2f
5 changed files with 23 additions and 3 deletions
|
|
@ -405,6 +405,14 @@ class TestUOpMethod(unittest.TestCase):
|
|||
self.assertEqual(const._device, None)
|
||||
with self.assertRaises(AssertionError): const.device
|
||||
|
||||
def test_const_arg(self):
|
||||
var = UOp.variable("a", 1, 10)
|
||||
with self.assertRaises(AssertionError): UOp.const(dtypes.int, var).const_arg
|
||||
const = UOp.const(dtypes.int, 1)
|
||||
self.assertEqual(const.const_arg, 1)
|
||||
tensor_const = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), const), ShapeTracker.from_shape(()))
|
||||
self.assertEqual(tensor_const.const_arg, 1)
|
||||
|
||||
class TestUOpStr(unittest.TestCase):
|
||||
def test_uop_str(self):
|
||||
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)
|
||||
|
|
|
|||
|
|
@ -116,6 +116,10 @@ class LazyBuffer(MathTrait):
|
|||
|
||||
def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
|
||||
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
|
||||
@property
|
||||
def const_arg(self) -> ConstType:
|
||||
assert self.base.op is Ops.CONST and isinstance(self.base.arg, get_args(ConstType)), f"const_arg called on {self}"
|
||||
return self.base.arg
|
||||
|
||||
def _copy(self, device:str) -> LazyBuffer:
|
||||
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
|
||||
|
|
|
|||
|
|
@ -343,7 +343,7 @@ def simplify_reduceop(ctx, reduce:UOp, x:UOp) -> Optional[UOp]:
|
|||
# remove reduce on unmasked const
|
||||
if all_int(x.shape) and x.is_unrealized_unmasked_const():
|
||||
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
|
||||
ret = x.base.src[1].arg
|
||||
ret = x.const_arg
|
||||
match reduce.arg[0]:
|
||||
case Ops.ADD: ret *= prshape
|
||||
case Ops.MUL: ret **= prshape
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict, Literal
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict, Literal, get_args
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
|
|
@ -292,6 +292,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
|
||||
return ret.arg
|
||||
@property
|
||||
def const_arg(self) -> ConstType:
|
||||
match self.base.op:
|
||||
case Ops.CONST: ret = self.base.arg
|
||||
case Ops.VIEW: ret = self.base.src[1].const_arg
|
||||
case op: raise AssertionError(f"const_arg called on {op}")
|
||||
assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
|
||||
return ret
|
||||
@property
|
||||
def axis_arg(self) -> Tuple[int, ...]:
|
||||
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
|
||||
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
||||
|
|
|
|||
|
|
@ -2995,7 +2995,7 @@ class Tensor(SimpleMathTrait):
|
|||
return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
|
||||
|
||||
def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
|
||||
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
|
||||
return x.lazydata.const_arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
|
||||
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
|
||||
|
||||
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue