Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 150 additions & 4 deletions decent_array/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

Hot-path notes:

* ``__add__``/``__sub__``/``__mul__``/``__truediv__``/``__matmul__``, the unary
``__neg__``/``__abs__``/``__pow__``, the comparisons ``__eq__``/``__ne__``/``__lt__``/
* ``__add__``/``__sub__``/``__mul__``/``__truediv__``/``__matmul__``, ``__floordiv__``, ``__mod__``,
the unary ``__neg__``/``__abs__``/``__pow__``, the comparisons ``__eq__``/``__ne__``/``__lt__``/
``__le__``/``__gt__``/``__ge__`` and the bitwise ``__and__``/``__rand__`` are
inlined: every supported framework's tensor implements the equivalent operator
natively with numpy-equivalent semantics, so routing through the interoperability layer
Expand All @@ -29,12 +29,13 @@
from typing import TYPE_CHECKING, Any, Self

from decent_array.interoperability._backend_manager import register_backend_listener
from decent_array.types import _STRING_TO_DTYPE

if TYPE_CHECKING:
from numpy.typing import NDArray

from decent_array.interoperability._abstracts import Backend
from decent_array.types import ArrayKey, SupportedArrayTypes, SupportedDevices
from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes, SupportedDevices


_BACKEND_INSTANCE: Backend | None = None
Expand Down Expand Up @@ -119,6 +120,22 @@ def __rtruediv__(self, other: int | float | complex | Array, /) -> Array:
"""Return the true division of ``other`` by the array."""
return Array(other / self.value)

def __floordiv__(self, other: int | float | Array, /) -> Array:
"""Return the floor division of the array by ``other``."""
return Array(self.value // (other.value if type(other) is Array else other))

def __rfloordiv__(self, other: int | float | Array, /) -> Array:
"""Return the floor division of ``other`` by the array."""
return Array(other // self.value)

def __mod__(self, other: int | float | Array, /) -> Array:
"""Return the remainder after floor division of the array by ``other``."""
return Array(self.value % (other.value if type(other) is Array else other))

def __rmod__(self, other: int | float | Array, /) -> Array:
"""Return the remainder after floor division of ``other`` by the array."""
return Array(other % self.value)

def __matmul__(self, other: Array, /) -> Array:
"""Return the matrix multiplication of the array with ``other``."""
return Array(self.value @ other.value)
Expand All @@ -128,12 +145,16 @@ def __rmatmul__(self, other: Array, /) -> Array:
return Array(other.value @ self.value)

def __pow__(self, other: int | float | complex | Array, /) -> Array:
"""Exponentiate the array by a scalar power."""
"""Exponentiate the array element-wise."""
# numpy/torch/jax/tf all implement ``tensor ** p`` with semantics matching the
# backend's ``pow``; routing through the backend would cost an extra method
# call for no behavioral difference.
return Array(self.value ** (other.value if type(other) is Array else other))

def __rpow__(self, other: int | float | complex | Array, /) -> Array:
"""Exponentiate other element-wise by array."""
return Array(other**self.value)

# Comparisons ----------------------------------------------------------
#
# Element-wise comparisons return an :class:`Array` of bools. The ``__eq__`` and
Expand Down Expand Up @@ -189,6 +210,42 @@ def __rand__(self, other: bool | int | Array, /) -> Array:
"""Element-wise bitwise/logical AND with the array on the right."""
return Array((other.value if type(other) is Array else other) & self.value)

def __or__(self, other: bool | int | Array, /) -> Array:
"""Element-wise bitwise/logical OR."""
return Array(self.value | (other.value if type(other) is Array else other))

def __ror__(self, other: bool | int | Array, /) -> Array:
"""Element-wise bitwise/logical OR with the array on the right."""
return Array((other.value if type(other) is Array else other) | self.value)

def __xor__(self, other: bool | int | Array, /) -> Array:
"""Element-wise bitwise/logical XOR."""
return Array(self.value ^ (other.value if type(other) is Array else other))

def __rxor__(self, other: bool | int | Array, /) -> Array:
"""Element-wise bitwise/logical XOR with the array on the right."""
return Array((other.value if type(other) is Array else other) ^ self.value)

def __lshift__(self, other: int | Array, /) -> Array:
"""Element-wise bitwise left shift as specified by int/int array."""
return self._backend.bitwise_left_shift(self, other)

def __rlshift__(self, other: int | Array, /) -> Array:
"""Element-wise bitwise left shift as specified by int/int array on the right."""
return self._backend.bitwise_left_shift(other, self)

def __rshift__(self, other: int | Array, /) -> Array:
"""Element-wise bitwise right shift as specified by int/int array."""
return self._backend.bitwise_right_shift(self, other)

def __rrshift__(self, other: int | Array, /) -> Array:
"""Element-wise bitwise right shift as specified by int/int array on the right."""
return self._backend.bitwise_right_shift(other, self)

def __invert__(self) -> Array:
"""Element-wise bitwise/logical NOT."""
return Array(~self.value)

# In-place arithmetic --------------------------------------------------
#
# The backend handles the framework's mutability semantics: numpy/pytorch mutate
Expand All @@ -215,6 +272,51 @@ def __itruediv__(self, other: int | float | complex | Array, /) -> Self:
self._backend.idivide(self, other)
return self

def __ifloordiv__(self, other: int | float | Array) -> Self:
"""In-place floor division."""
self._backend.ifloordiv(self, other)
return self

def __imod__(self, other: int | float | Array) -> Self:
"""In-place remainder after floor division."""
self._backend.imod(self, other)
return self

def __ipow__(self, other: int | float | complex | Array) -> Self:
"""In-place raise array to power ``other``."""
self._backend.ipow(self, other)
return self

def __imatmul__(self, other: Array) -> Self:
"""In-place matrix multiplication."""
self._backend.imatmul(self, other)
return self

def __iand__(self, other: bool | int | Array) -> Self:
"""In-place bitwise/logical AND."""
self._backend.iand(self, other)
return self

def __ior__(self, other: bool | int | Array) -> Self:
"""In-place bitwise/logical OR."""
self._backend.ior(self, other)
return self

def __ixor__(self, other: bool | int | Array) -> Self:
"""In-place bitwise/logical XOR."""
self._backend.ixor(self, other)
return self

def __ilshift__(self, other: int | Array) -> Self:
"""In-place bitwise left shift."""
self._backend.ilshift(self, other)
return self

def __irshift__(self, other: int | Array) -> Self:
"""In-place bitwise right shift."""
self._backend.irshift(self, other)
return self

# Unary ----------------------------------------------------------------

def __neg__(self) -> Array:
Expand All @@ -223,6 +325,10 @@ def __neg__(self) -> Array:
# supported frameworks, so the indirection is not needed.
return Array(-self.value)

def __pos__(self) -> Array:
"""Return the array itself."""
return self
Comment on lines +328 to +330

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this needed for?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's to handle the case of writing +x. not very common but it can happen

@Simpag Simpag Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough, that would probably be more of a syntax error for this class but fine to keep


def __abs__(self) -> Array:
"""Return the absolute value of the array."""
# Same rationale as ``__neg__`` — native ``abs(tensor)`` matches each
Expand Down Expand Up @@ -251,6 +357,22 @@ def __float__(self) -> float:
"""Coerce a scalar array to a Python float."""
return float(self._backend.squeeze(self).value)

def __bool__(self) -> bool:
"""Coerce a scalar array to a Python bool."""
return bool(self._backend.squeeze(self).value)

def __int__(self) -> int:
"""Coerce a scalar array to a Python int."""
return int(self._backend.squeeze(self).value)

def __complex__(self) -> complex:
"""Coerce a scalar array to a Python complex."""
return complex(self._backend.squeeze(self).value)

def __index__(self) -> int:
"""Coerce a scalar array to a Python int."""
return int(self._backend.squeeze(self).value)

# Repr -----------------------------------------------------------------

def __repr__(self) -> str:
Expand Down Expand Up @@ -278,6 +400,25 @@ def ndim(self) -> int:
"""Return the number of dimensions of the array."""
return self._backend.ndim(self)

@property
def dtype(self) -> DTypes:
"""
Return dtype of the Array as item of DTypes enum.

Raises:
ValueError: for dtypes that are not supported by all decent-array functions

"""
# get framework-native dtype as string
# split takes care of types with names like "torch.float32"
dtype_name = str(self.value.dtype).split(".")[-1]

dtype = _STRING_TO_DTYPE.get(dtype_name)
Comment on lines +414 to +416

@Simpag Simpag Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit hacky and hard to extend. It would be cleaner if this is added to the iop functions. Maybe "dtype_of" or just "dtype" but the former is a bit clearer. Doing it in the iop layer would just require a reversed map of the one thats already in each backend, it would basically be something like:

_INV_DTYPE_MAP = {v: k for k, v in _DTYPE_MAP.items()}

def dtype_of(x):
    v = _INV_DTYPE_MAP.get(x.dtype)
    if not v:
        # Should never reach here since we control dtype but can be if user does something hacky
        raise SOME_ERROR
    return v

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes this is definitely really hacky. I'm starting to work on a much more solid version now based on the ideas in #8. but that is beyond the scope of this PR, so I thought to leave this hacky placeholder for now

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, feel free to merge when you feel ready

if dtype is None:
raise ValueError(f"dtype {self.value.dtype} is not supported by all decent-array functions.")

return dtype

@property
def transpose(self) -> Array:
"""Return a transposed view of the array."""
Expand All @@ -288,6 +429,11 @@ def T(self) -> Array: # noqa: N802
"""Return a transposed view of the array."""
return self.transpose

@property
def mT(self) -> Array: # noqa: N802
"""Return the matrix transpose (last two dimensions swapped)."""
return self._backend.matrix_transpose(self)

@property
def any(self) -> bool:
"""Return True if any element of the array is truthy."""
Expand Down
34 changes: 32 additions & 2 deletions decent_array/interoperability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@
"""

from ._backend_manager import default_device, set_backend
from ._iop.bit_operators import bitwise_and
from ._iop.bit_operators import (
bitwise_and,
bitwise_invert,
bitwise_left_shift,
bitwise_or,
bitwise_right_shift,
bitwise_xor,
)
from ._iop.comparasion import equal, greater, greater_equal, less, less_equal, not_equal
from ._iop.creation import eye, ones, ones_like, zeros, zeros_like
from ._iop.linalg import dot, matmul, norm, vecdot, vector_norm
Expand All @@ -26,6 +33,7 @@
expand_dims,
from_numpy,
from_numpy_like,
matrix_transpose,
ndim,
reshape,
shape,
Expand All @@ -36,7 +44,20 @@
transpose,
unsqueeze,
)
from ._iop.math import abs, absolute, add, divide, multiply, negative, pow, sqrt, subtract # noqa: A004
from ._iop.math import (
abs, # noqa: A004
absolute,
add,
divide,
floor_divide,
multiply,
negative,
positive,
pow, # noqa: A004
remainder,
sqrt,
subtract,
)
from ._iop.operators import argmax, argmin, maximum, sign
from ._iop.reductions import all, any, max, mean, min, sum # noqa: A004
from ._iop.rng import (
Expand Down Expand Up @@ -65,6 +86,11 @@
"asarray",
"astype",
"bitwise_and",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_or",
"bitwise_right_shift",
"bitwise_xor",
"choice",
"copy",
"default_device",
Expand All @@ -77,6 +103,7 @@
"equal",
"expand_dims",
"eye",
"floor_divide",
"from_numpy",
"from_numpy_like",
"get_numpy_rng",
Expand All @@ -87,6 +114,7 @@
"less",
"less_equal",
"matmul",
"matrix_transpose",
"max",
"maximum",
"mean",
Expand All @@ -100,7 +128,9 @@
"not_equal",
"ones",
"ones_like",
"positive",
"pow",
"remainder",
"reshape",
"set_backend",
"set_rng_state",
Expand Down
Loading
Loading