mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move pow folding tests to test_schedule [pr] (#8955)
not really belongs to test_const_folding
This commit is contained in:
parent
c2b4c43edb
commit
cfd28517df
2 changed files with 26 additions and 35 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import unittest, math
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.ops import Ops, GroupOp
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.helpers import CI
|
||||
import numpy as np
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
|
@ -97,34 +97,6 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
|
|||
def test_tensor_one_pow(self):
|
||||
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
|
||||
|
||||
def test_2_pow_is_exp2(self):
|
||||
t = 2.0 ** Tensor([1.0, 2.0, 3.0])
|
||||
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
|
||||
self.assertEqual(len(s), 1)
|
||||
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
|
||||
self.assertEqual(alu, [Ops.EXP2])
|
||||
|
||||
def test_pow_05_is_sqrt(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** 0.5
|
||||
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
|
||||
self.assertEqual(len(s), 1)
|
||||
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
|
||||
self.assertEqual(alu, [Ops.SQRT])
|
||||
|
||||
def test_pow_neg_05_is_rsqrt(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** -0.5
|
||||
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
|
||||
self.assertEqual(len(s), 1)
|
||||
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
|
||||
self.assertEqual(alu, [Ops.RECIP, Ops.SQRT])
|
||||
|
||||
def test_pow_8_has_3_muls(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** 8
|
||||
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
|
||||
self.assertEqual(len(s), 1)
|
||||
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
|
||||
self.assertEqual(alu, [Ops.MUL, Ops.MUL, Ops.MUL])
|
||||
|
||||
# folds advance indexing into basic indexing
|
||||
class TestIndexingConstFolding(unittest.TestCase):
|
||||
def test_scalar_index(self):
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from tinygrad.device import is_dtype_supported
|
|||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views, GroupOp
|
||||
from tinygrad.spec import type_verify, shape_spec
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
|
|
@ -559,11 +559,30 @@ class TestSchedule(unittest.TestCase):
|
|||
out = x.to('python')
|
||||
check_schedule(out, 0, filter_sink=False)
|
||||
|
||||
def test_pow_const_tensor_simplified(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
# NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5)
|
||||
out = x ** Tensor(2.0)
|
||||
check_schedule(out, 1)
|
||||
def _alu_from_tensor(self, t:Tensor):
|
||||
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
|
||||
self.assertEqual(len(s), 1)
|
||||
return [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
|
||||
|
||||
def test_2_pow_is_exp2(self):
|
||||
t = 2.0 ** Tensor([1.0, 2.0, 3.0])
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2])
|
||||
|
||||
def test_pow_05_is_sqrt(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** 0.5
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT])
|
||||
|
||||
def test_pow_neg_05_is_rsqrt(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** -0.5
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT])
|
||||
|
||||
def test_pow_2_has_1_mul(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL])
|
||||
|
||||
def test_pow_8_has_3_muls(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** 8
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])
|
||||
|
||||
def test_pow_const_tensor_to_zero(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue