diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py index a72bf313a..7c2d24201 100644 --- a/src/spatialdata/_core/_elements.py +++ b/src/spatialdata/_core/_elements.py @@ -4,12 +4,13 @@ from collections import UserDict from collections.abc import Iterable, KeysView, ValuesView -from typing import Any +from typing import TypeVar from warnings import warn from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame +from xarray import DataArray, DataTree from spatialdata._core.validation import check_key_is_case_insensitively_unique, check_valid_name from spatialdata._types import Raster_T @@ -25,8 +26,10 @@ get_model, ) +T = TypeVar("T") -class Elements(UserDict[str, Any]): + +class Elements(UserDict[str, T]): def __init__(self, shared_keys: set[str | None]) -> None: self._shared_keys = shared_keys super().__init__() @@ -49,7 +52,7 @@ def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | Non # Validation raises ValueError, but inappropriate mapping key must raise KeyError. raise KeyError(*e.args) from e - def __setitem__(self, key: str, value: Any) -> None: + def __setitem__(self, key: str, value: T) -> None: self._add_shared_key(key) super().__setitem__(key, value) @@ -61,12 +64,12 @@ def keys(self) -> KeysView[str]: """Return the keys of the Elements.""" return self.data.keys() - def values(self) -> ValuesView[Any]: + def values(self) -> ValuesView[T]: """Return the values of the Elements.""" return self.data.values() -class Images(Elements): +class Images(Elements[DataArray | DataTree]): def __setitem__(self, key: str, value: Raster_T) -> None: self._check_key(key, self.keys(), self._shared_keys) schema = get_model(value) @@ -83,7 +86,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None: NotImplementedError("TODO: implement for ndim > 4.") -class Labels(Elements): +class Labels(Elements[DataArray | DataTree]): def __setitem__(self, key: str, value: Raster_T) -> None: self._check_key(key, self.keys(), self._shared_keys) schema = get_model(value) @@ -100,7 +103,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None: NotImplementedError("TODO: implement for ndim > 3.") -class Shapes(Elements): +class Shapes(Elements[GeoDataFrame]): def __setitem__(self, key: str, value: GeoDataFrame) -> None: self._check_key(key, self.keys(), self._shared_keys) schema = get_model(value) @@ -110,7 +113,7 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None: super().__setitem__(key, value) -class Points(Elements): +class Points(Elements[DaskDataFrame]): def __setitem__(self, key: str, value: DaskDataFrame) -> None: self._check_key(key, self.keys(), self._shared_keys) schema = get_model(value) @@ -120,7 +123,7 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None: super().__setitem__(key, value) -class Tables(Elements): +class Tables(Elements[AnnData]): def __setitem__(self, key: str, value: AnnData) -> None: self._check_key(key, self.keys(), self._shared_keys) schema = get_model(value) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index f011d08f8..3cb9f8d52 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2336,7 +2336,7 @@ def subset( ) return SpatialData(**elements_dict, tables=tables, attrs=self.attrs) - def __getitem__(self, item: str) -> SpatialElement: + def __getitem__(self, item: str) -> SpatialElement | AnnData: """ Return the element with the given name. diff --git a/src/spatialdata/testing.py b/src/spatialdata/testing.py index 18199a237..1e5946aa1 100644 --- a/src/spatialdata/testing.py +++ b/src/spatialdata/testing.py @@ -10,14 +10,19 @@ from xarray.testing import assert_equal from spatialdata import SpatialData -from spatialdata._core._elements import Elements +from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables from spatialdata.models import PointsModel from spatialdata.models._utils import SpatialElement from spatialdata.transformations.operations import get_transformation +_Elements = Images | Labels | Shapes | Points | Tables + def assert_elements_dict_are_identical( - elements0: Elements, elements1: Elements, check_transformations: bool = True, check_metadata: bool = True + elements0: _Elements, + elements1: _Elements, + check_transformations: bool = True, + check_metadata: bool = True, ) -> None: """ Compare two dictionaries of elements and assert that they are identical (except for the order of the keys). @@ -55,7 +60,10 @@ def assert_elements_dict_are_identical( element0 = elements0[k] element1 = elements1[k] assert_elements_are_identical( - element0, element1, check_transformations=check_transformations, check_metadata=check_metadata + element0, + element1, + check_transformations=check_transformations, + check_metadata=check_metadata, ) @@ -125,7 +133,10 @@ def assert_elements_are_identical( def assert_spatial_data_objects_are_identical( - sdata0: SpatialData, sdata1: SpatialData, check_transformations: bool = True, check_metadata: bool = True + sdata0: SpatialData, + sdata1: SpatialData, + check_transformations: bool = True, + check_metadata: bool = True, ) -> None: """ Compare two SpatialData objects and assert that they are identical. @@ -169,5 +180,8 @@ def assert_spatial_data_objects_are_identical( element0 = sdata0[element_name] element1 = sdata1[element_name] assert_elements_are_identical( - element0, element1, check_transformations=check_transformations, check_metadata=check_metadata + element0, + element1, + check_transformations=check_transformations, + check_metadata=check_metadata, )