Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions decent_array/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,73 @@
from decent_array import interoperability, types
from decent_array._array import Array
from decent_array._constants import e, inf, nan, pi
from decent_array.types._dtypes import (
bfloat16,
bool_,
bytes_,
complex64,
complex128,
complex256,
float16,
float32,
float64,
float128,
int8,
int16,
int32,
int64,
object_,
qint8,
qint16,
qint32,
quint8,
quint16,
uint8,
uint16,
uint32,
uint64,
unicode_,
void,
)

__all_docs__ = [
"Array",
"interoperability",
"types",
]

__all__ = [
"Array",
"bfloat16",
"bool_",
"bytes_",
"complex64",
"complex128",
"complex256",
"e",
"float16",
"float32",
"float64",
"float128",
"inf",
"int8",
"int16",
"int32",
"int64",
"interoperability",
"nan",
"object_",
"pi",
"qint8",
"qint16",
"qint32",
"quint8",
"quint16",
"types",
"uint8",
"uint16",
"uint32",
"uint64",
"unicode_",
"void",
]
27 changes: 9 additions & 18 deletions decent_array/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
from typing import TYPE_CHECKING, Any, Self

from decent_array.interoperability._backend_manager import register_backend_listener
from decent_array.types import _STRING_TO_DTYPE
from decent_array.types._dtypes import _BACKEND_DTYPE_TO_DTYPE

if TYPE_CHECKING:
from numpy.typing import NDArray

from decent_array.interoperability._abstracts import Backend
from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes, SupportedDevices
from decent_array.types import ArrayKey, ArrayTypes, Devices, dtype


_BACKEND_INSTANCE: Backend | None = None
Expand All @@ -49,7 +49,7 @@ def _update_backend(backend: Backend | None) -> None:
register_backend_listener(_update_backend)


class Array: # noqa: PLR0904
class Array:
"""
Wrapper around a single backend-native array.

Expand All @@ -60,7 +60,7 @@ class Array: # noqa: PLR0904

__slots__ = ("_backend", "value")

def __init__(self, value: SupportedArrayTypes) -> None:
def __init__(self, value: ArrayTypes) -> None:
"""
Wrap ``value`` in an :class:`Array`.

Expand Down Expand Up @@ -401,21 +401,12 @@ def ndim(self) -> int:
return self._backend.ndim(self)

@property
def dtype(self) -> DTypes:
"""
Return dtype of the Array as item of DTypes enum.

Raises:
ValueError: for dtypes that are not supported by all decent-array functions

"""
# get framework-native dtype as string
# split takes care of types with names like "torch.float32"
dtype_name = str(self.value.dtype).split(".")[-1]
def dtype(self) -> dtype:
"""Return dtype of the Array."""
dtype = _BACKEND_DTYPE_TO_DTYPE.get(self.value.dtype)

dtype = _STRING_TO_DTYPE.get(dtype_name)
if dtype is None:
raise ValueError(f"dtype {self.value.dtype} is not supported by all decent-array functions.")
raise ValueError(f"dtype {self.value.dtype} is not supported.")

return dtype

Expand Down Expand Up @@ -445,7 +436,7 @@ def all(self) -> bool:
return self._backend.all(self)

@property
def device(self) -> SupportedDevices:
def device(self) -> Devices:
"""Return the device of the array."""
return self._backend.device_of(self)

Expand Down
10 changes: 10 additions & 0 deletions decent_array/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Numerical constants."""

import math

_CONSTANTS = ["e", "inf", "nan", "pi"]

e = math.e
inf = math.inf
nan = math.nan
pi = math.pi
14 changes: 14 additions & 0 deletions decent_array/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any

from decent_array._array import Array


def unwrap(x: Any) -> Any: # noqa: ANN401
"""
Return the underlying value of an :class:`Array`, or pass ``x`` through.

Typed as ``Any`` because operator dunders may pass either an :class:`Array` or a
Python scalar; the strict abstract signature would force a ``cast`` at every call
site without runtime benefit.
"""
return x.value if type(x) is Array else x
2 changes: 1 addition & 1 deletion decent_array/interoperability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
bitwise_right_shift,
bitwise_xor,
)
from ._iop.comparasion import equal, greater, greater_equal, less, less_equal, not_equal
from ._iop.comparison import equal, greater, greater_equal, less, less_equal, not_equal
from ._iop.creation import eye, ones, ones_like, zeros, zeros_like
from ._iop.linalg import dot, matmul, norm, vecdot, vector_norm
from ._iop.manipulations import (
Expand Down
Loading
Loading