-
Notifications
You must be signed in to change notification settings - Fork 2
enh(array, iop): add new Array attributes and dunders, and corresponding iop functions #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
23887cd
31d75ae
e3c106a
742717f
bfc4f85
a15466d
dbf063c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,8 +11,8 @@ | |
|
|
||
| Hot-path notes: | ||
|
|
||
| * ``__add__``/``__sub__``/``__mul__``/``__truediv__``/``__matmul__``, the unary | ||
| ``__neg__``/``__abs__``/``__pow__``, the comparisons ``__eq__``/``__ne__``/``__lt__``/ | ||
| * ``__add__``/``__sub__``/``__mul__``/``__truediv__``/``__matmul__``, ``__floordiv__``, ``__mod__``, | ||
| the unary ``__neg__``/``__abs__``/``__pow__``, the comparisons ``__eq__``/``__ne__``/``__lt__``/ | ||
| ``__le__``/``__gt__``/``__ge__`` and the bitwise ``__and__``/``__rand__`` are | ||
| inlined: every supported framework's tensor implements the equivalent operator | ||
| natively with numpy-equivalent semantics, so routing through the interoperability layer | ||
|
|
@@ -29,12 +29,13 @@ | |
| from typing import TYPE_CHECKING, Any, Self | ||
|
|
||
| from decent_array.interoperability._backend_manager import register_backend_listener | ||
| from decent_array.types import _STRING_TO_DTYPE | ||
|
|
||
| if TYPE_CHECKING: | ||
| from numpy.typing import NDArray | ||
|
|
||
| from decent_array.interoperability._abstracts import Backend | ||
| from decent_array.types import ArrayKey, SupportedArrayTypes, SupportedDevices | ||
| from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes, SupportedDevices | ||
|
|
||
|
|
||
| _BACKEND_INSTANCE: Backend | None = None | ||
|
|
@@ -119,6 +120,22 @@ def __rtruediv__(self, other: int | float | complex | Array, /) -> Array: | |
| """Return the true division of ``other`` by the array.""" | ||
| return Array(other / self.value) | ||
|
|
||
| def __floordiv__(self, other: int | float | Array, /) -> Array: | ||
| """Return the floor division of the array by ``other``.""" | ||
| return Array(self.value // (other.value if type(other) is Array else other)) | ||
|
|
||
| def __rfloordiv__(self, other: int | float | Array, /) -> Array: | ||
| """Return the floor division of ``other`` by the array.""" | ||
| return Array(other // self.value) | ||
|
|
||
| def __mod__(self, other: int | float | Array, /) -> Array: | ||
| """Return the remainder after floor division of the array by ``other``.""" | ||
| return Array(self.value % (other.value if type(other) is Array else other)) | ||
|
|
||
| def __rmod__(self, other: int | float | Array, /) -> Array: | ||
| """Return the remainder after floor division of ``other`` by the array.""" | ||
| return Array(other % self.value) | ||
|
|
||
| def __matmul__(self, other: Array, /) -> Array: | ||
| """Return the matrix multiplication of the array with ``other``.""" | ||
| return Array(self.value @ other.value) | ||
|
|
@@ -128,12 +145,16 @@ def __rmatmul__(self, other: Array, /) -> Array: | |
| return Array(other.value @ self.value) | ||
|
|
||
| def __pow__(self, other: int | float | complex | Array, /) -> Array: | ||
| """Exponentiate the array by a scalar power.""" | ||
| """Exponentiate the array element-wise.""" | ||
| # numpy/torch/jax/tf all implement ``tensor ** p`` with semantics matching the | ||
| # backend's ``pow``; routing through the backend would cost an extra method | ||
| # call for no behavioral difference. | ||
| return Array(self.value ** (other.value if type(other) is Array else other)) | ||
|
|
||
| def __rpow__(self, other: int | float | complex | Array, /) -> Array: | ||
| """Exponentiate other element-wise by array.""" | ||
| return Array(other**self.value) | ||
|
|
||
| # Comparisons ---------------------------------------------------------- | ||
| # | ||
| # Element-wise comparisons return an :class:`Array` of bools. The ``__eq__`` and | ||
|
|
@@ -189,6 +210,42 @@ def __rand__(self, other: bool | int | Array, /) -> Array: | |
| """Element-wise bitwise/logical AND with the array on the right.""" | ||
| return Array((other.value if type(other) is Array else other) & self.value) | ||
|
|
||
| def __or__(self, other: bool | int | Array, /) -> Array: | ||
| """Element-wise bitwise/logical OR.""" | ||
| return Array(self.value | (other.value if type(other) is Array else other)) | ||
|
|
||
| def __ror__(self, other: bool | int | Array, /) -> Array: | ||
| """Element-wise bitwise/logical OR with the array on the right.""" | ||
| return Array((other.value if type(other) is Array else other) | self.value) | ||
|
|
||
| def __xor__(self, other: bool | int | Array, /) -> Array: | ||
| """Element-wise bitwise/logical XOR.""" | ||
| return Array(self.value ^ (other.value if type(other) is Array else other)) | ||
|
|
||
| def __rxor__(self, other: bool | int | Array, /) -> Array: | ||
| """Element-wise bitwise/logical XOR with the array on the right.""" | ||
| return Array((other.value if type(other) is Array else other) ^ self.value) | ||
|
|
||
| def __lshift__(self, other: int | Array, /) -> Array: | ||
| """Element-wise bitwise left shift as specified by int/int array.""" | ||
| return self._backend.bitwise_left_shift(self, other) | ||
|
|
||
| def __rlshift__(self, other: int | Array, /) -> Array: | ||
| """Element-wise bitwise left shift as specified by int/int array on the right.""" | ||
| return self._backend.bitwise_left_shift(other, self) | ||
|
|
||
| def __rshift__(self, other: int | Array, /) -> Array: | ||
| """Element-wise bitwise right shift as specified by int/int array.""" | ||
| return self._backend.bitwise_right_shift(self, other) | ||
|
|
||
| def __rrshift__(self, other: int | Array, /) -> Array: | ||
| """Element-wise bitwise right shift as specified by int/int array on the right.""" | ||
| return self._backend.bitwise_right_shift(other, self) | ||
|
|
||
| def __invert__(self) -> Array: | ||
| """Element-wise bitwise/logical NOT.""" | ||
| return Array(~self.value) | ||
|
|
||
| # In-place arithmetic -------------------------------------------------- | ||
| # | ||
| # The backend handles the framework's mutability semantics: numpy/pytorch mutate | ||
|
|
@@ -215,6 +272,51 @@ def __itruediv__(self, other: int | float | complex | Array, /) -> Self: | |
| self._backend.idivide(self, other) | ||
| return self | ||
|
|
||
| def __ifloordiv__(self, other: int | float | Array) -> Self: | ||
| """In-place floor division.""" | ||
| self._backend.ifloordiv(self, other) | ||
| return self | ||
|
|
||
| def __imod__(self, other: int | float | Array) -> Self: | ||
| """In-place remainder after floor division.""" | ||
| self._backend.imod(self, other) | ||
| return self | ||
|
|
||
| def __ipow__(self, other: int | float | complex | Array) -> Self: | ||
| """In-place raise array to power ``other``.""" | ||
| self._backend.ipow(self, other) | ||
| return self | ||
|
|
||
| def __imatmul__(self, other: Array) -> Self: | ||
| """In-place matrix multiplication.""" | ||
| self._backend.imatmul(self, other) | ||
| return self | ||
|
|
||
| def __iand__(self, other: bool | int | Array) -> Self: | ||
| """In-place bitwise/logical AND.""" | ||
| self._backend.iand(self, other) | ||
| return self | ||
|
|
||
| def __ior__(self, other: bool | int | Array) -> Self: | ||
| """In-place bitwise/logical OR.""" | ||
| self._backend.ior(self, other) | ||
| return self | ||
|
|
||
| def __ixor__(self, other: bool | int | Array) -> Self: | ||
| """In-place bitwise/logical XOR.""" | ||
| self._backend.ixor(self, other) | ||
| return self | ||
|
|
||
| def __ilshift__(self, other: int | Array) -> Self: | ||
| """In-place bitwise left shift.""" | ||
| self._backend.ilshift(self, other) | ||
| return self | ||
|
|
||
| def __irshift__(self, other: int | Array) -> Self: | ||
| """In-place bitwise right shift.""" | ||
| self._backend.irshift(self, other) | ||
| return self | ||
|
|
||
| # Unary ---------------------------------------------------------------- | ||
|
|
||
| def __neg__(self) -> Array: | ||
|
|
@@ -223,6 +325,10 @@ def __neg__(self) -> Array: | |
| # supported frameworks, so the indirection is not needed. | ||
| return Array(-self.value) | ||
|
|
||
| def __pos__(self) -> Array: | ||
| """Return the array itself.""" | ||
| return self | ||
|
|
||
| def __abs__(self) -> Array: | ||
| """Return the absolute value of the array.""" | ||
| # Same rationale as ``__neg__`` — native ``abs(tensor)`` matches each | ||
|
|
@@ -251,6 +357,22 @@ def __float__(self) -> float: | |
| """Coerce a scalar array to a Python float.""" | ||
| return float(self._backend.squeeze(self).value) | ||
|
|
||
| def __bool__(self) -> bool: | ||
| """Coerce a scalar array to a Python bool.""" | ||
| return bool(self._backend.squeeze(self).value) | ||
|
|
||
| def __int__(self) -> int: | ||
| """Coerce a scalar array to a Python int.""" | ||
| return int(self._backend.squeeze(self).value) | ||
|
|
||
| def __complex__(self) -> complex: | ||
| """Coerce a scalar array to a Python complex.""" | ||
| return complex(self._backend.squeeze(self).value) | ||
|
|
||
| def __index__(self) -> int: | ||
| """Coerce a scalar array to a Python int.""" | ||
| return int(self._backend.squeeze(self).value) | ||
|
|
||
| # Repr ----------------------------------------------------------------- | ||
|
|
||
| def __repr__(self) -> str: | ||
|
|
@@ -278,6 +400,25 @@ def ndim(self) -> int: | |
| """Return the number of dimensions of the array.""" | ||
| return self._backend.ndim(self) | ||
|
|
||
| @property | ||
| def dtype(self) -> DTypes: | ||
| """ | ||
| Return dtype of the Array as item of DTypes enum. | ||
|
|
||
| Raises: | ||
| ValueError: for dtypes that are not supported by all decent-array functions | ||
|
|
||
| """ | ||
| # get framework-native dtype as string | ||
| # split takes care of types with names like "torch.float32" | ||
| dtype_name = str(self.value.dtype).split(".")[-1] | ||
|
|
||
| dtype = _STRING_TO_DTYPE.get(dtype_name) | ||
|
Comment on lines
+414
to
+416
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit hacky and hard to extend. It would be cleaner if this is added to the iop functions. Maybe "dtype_of" or just "dtype" but the former is a bit clearer. Doing it in the iop layer would just require a reversed map of the one thats already in each backend, it would basically be something like: _INV_DTYPE_MAP = {v: k for k, v in _DTYPE_MAP.items()}
def dtype_of(x):
v = _INV_DTYPE_MAP.get(x.dtype)
if not v:
# Should never reach here since we control dtype but can be if user does something hacky
raise SOME_ERROR
return v
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes this is definitely really hacky. I'm starting to work on a much more solid version now based on the ideas in #8. but that is beyond the scope of this PR, so I thought to leave this hacky placeholder for now
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair enough, feel free to merge when you feel ready |
||
| if dtype is None: | ||
| raise ValueError(f"dtype {self.value.dtype} is not supported by all decent-array functions.") | ||
|
|
||
| return dtype | ||
|
|
||
| @property | ||
| def transpose(self) -> Array: | ||
| """Return a transposed view of the array.""" | ||
|
|
@@ -288,6 +429,11 @@ def T(self) -> Array: # noqa: N802 | |
| """Return a transposed view of the array.""" | ||
| return self.transpose | ||
|
|
||
| @property | ||
| def mT(self) -> Array: # noqa: N802 | ||
| """Return the matrix transpose (last two dimensions swapped).""" | ||
| return self._backend.matrix_transpose(self) | ||
|
|
||
| @property | ||
| def any(self) -> bool: | ||
| """Return True if any element of the array is truthy.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this needed for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's to handle the case of writing
+x. not very common but it can happenUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fair enough, that would probably be more of a syntax error for this class but fine to keep