mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
[Feature] Added BinaryOps.AND/BinaryOps.OR (#5223)
* [Feature] Added BinaryOps.AND/BinaryOps.OR * Add: __rand__, __ror__
This commit is contained in:
parent
50b05dd3f4
commit
ad1ca7da64
9 changed files with 68 additions and 3 deletions
|
|
@ -26,7 +26,8 @@ if Device.DEFAULT == "LLVM":
|
|||
binary_operations.remove(operator.lt)
|
||||
binary_operations.remove(operator.eq)
|
||||
|
||||
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor)]
|
||||
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),
|
||||
(Tensor.bitwise_or, np.bitwise_or)]
|
||||
unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin),
|
||||
(Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]
|
||||
|
||||
|
|
|
|||
|
|
@ -467,6 +467,20 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
|
||||
|
||||
def test_and(self):
|
||||
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
|
||||
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
|
||||
helper_test_op([], lambda: tor&tor, lambda: ten&ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor&0x1337, lambda: ten&0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337&tor, lambda: 0x1337&ten, forward_only=True)
|
||||
|
||||
def test_or(self):
|
||||
tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
|
||||
ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
|
||||
helper_test_op([], lambda: tor|tor, lambda: ten|ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor|0x1337, lambda: ten|0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337|tor, lambda: 0x1337|ten, forward_only=True)
|
||||
|
||||
def test_lshift(self):
|
||||
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
|
|
|
|||
|
|
@ -124,6 +124,8 @@ class TestNonFloatUOps(TestUOps):
|
|||
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
|
||||
def test_div_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_and_int32(self): self._test_bop_fxn(BinaryOps.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_or_int32(self): self._test_bop_fxn(BinaryOps.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_mod_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.MOD,
|
||||
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
|
|
@ -157,6 +159,8 @@ class TestBoolUOps(TestUOps):
|
|||
def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
|
||||
def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
|
||||
def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
|
||||
def test_and_bool(self): self._test_bop_bool_fxn(BinaryOps.AND, lambda a,b: a & b)
|
||||
def test_or_bool(self): self._test_bop_bool_fxn(BinaryOps.OR, lambda a,b: a | b)
|
||||
def test_cmpne_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPNE, lambda a,b: a != b)
|
||||
def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
|
||||
def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)
|
||||
|
|
|
|||
|
|
@ -106,6 +106,12 @@ class Neq(Function):
|
|||
class Xor(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)
|
||||
|
||||
class BitwiseAnd(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.AND, y)
|
||||
|
||||
class BitwiseOr(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.OR, y)
|
||||
|
||||
class Add(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class UnaryOps(Enum):
|
|||
class BinaryOps(Enum):
|
||||
"""A + A -> A (elementwise)"""
|
||||
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
||||
SHR = auto(); SHL = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto() # noqa: E702
|
||||
class TernaryOps(Enum):
|
||||
"""A + A + A -> A (elementwise)"""
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
|
|
@ -125,6 +125,7 @@ python_alu = {
|
|||
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
|
||||
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
|
||||
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
||||
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
|
||||
TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ class PTXRenderer(Renderer):
|
|||
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.AND: lambda d, a, b, dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.OR: lambda d, a, b, dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
|
||||
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class CStyleLanguage(Renderer):
|
|||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
|
||||
BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
|
||||
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
|
||||
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
|||
BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
|
||||
BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
|
||||
BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
|
||||
BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501
|
||||
TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
|
||||
|
||||
dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
|
||||
|
|
|
|||
|
|
@ -2462,6 +2462,36 @@ class Tensor:
|
|||
"""
|
||||
return F.Xor.apply(*self._broadcasted(x, reverse))
|
||||
|
||||
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
Compute the bit-wise AND of `self` and `x`.
|
||||
Equivalent to `self & x`.
|
||||
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
|
||||
```
|
||||
"""
|
||||
assert dtypes.is_int(self.dtype)
|
||||
return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
|
||||
|
||||
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
Compute the bit-wise OR of `self` and `x`.
|
||||
Equivalent to `self | x`.
|
||||
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
|
||||
```
|
||||
"""
|
||||
assert dtypes.is_int(self.dtype)
|
||||
return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
|
||||
|
||||
def lshift(self, x:int):
|
||||
"""
|
||||
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
||||
|
|
@ -2586,6 +2616,8 @@ class Tensor:
|
|||
def __pow__(self, x) -> Tensor: return self.pow(x)
|
||||
def __truediv__(self, x) -> Tensor: return self.div(x)
|
||||
def __matmul__(self, x) -> Tensor: return self.matmul(x)
|
||||
def __and__(self, x) -> Tensor: return self.bitwise_and(x)
|
||||
def __or__(self, x) -> Tensor: return self.bitwise_or(x)
|
||||
def __xor__(self, x) -> Tensor: return self.xor(x)
|
||||
def __lshift__(self, x) -> Tensor: return self.lshift(x)
|
||||
def __rshift__(self, x) -> Tensor: return self.rshift(x)
|
||||
|
|
@ -2596,6 +2628,8 @@ class Tensor:
|
|||
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
|
||||
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
|
||||
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
|
||||
def __rand__(self, x) -> Tensor: return self.bitwise_and(x, True)
|
||||
def __ror__(self, x) -> Tensor: return self.bitwise_or(x, True)
|
||||
def __rxor__(self, x) -> Tensor: return self.xor(x, True)
|
||||
|
||||
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
|
||||
|
|
@ -2604,6 +2638,8 @@ class Tensor:
|
|||
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
|
||||
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
|
||||
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
|
||||
def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
|
||||
def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
|
||||
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
|
||||
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
|
||||
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue