Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/spatialdata/_core/_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from collections import UserDict
from collections.abc import Iterable, KeysView, ValuesView
from typing import Any
from typing import TypeVar
from warnings import warn

from anndata import AnnData
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from xarray import DataArray, DataTree

from spatialdata._core.validation import check_key_is_case_insensitively_unique, check_valid_name
from spatialdata._types import Raster_T
Expand All @@ -25,8 +26,10 @@
get_model,
)

T = TypeVar("T")

class Elements(UserDict[str, Any]):

class Elements(UserDict[str, T]):
def __init__(self, shared_keys: set[str | None]) -> None:
self._shared_keys = shared_keys
super().__init__()
Expand All @@ -49,7 +52,7 @@ def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | Non
# Validation raises ValueError, but inappropriate mapping key must raise KeyError.
raise KeyError(*e.args) from e

def __setitem__(self, key: str, value: Any) -> None:
def __setitem__(self, key: str, value: T) -> None:
self._add_shared_key(key)
super().__setitem__(key, value)

Expand All @@ -61,12 +64,12 @@ def keys(self) -> KeysView[str]:
"""Return the keys of the Elements."""
return self.data.keys()

def values(self) -> ValuesView[Any]:
def values(self) -> ValuesView[T]:
"""Return the values of the Elements."""
return self.data.values()


class Images(Elements):
class Images(Elements[DataArray | DataTree]):
def __setitem__(self, key: str, value: Raster_T) -> None:
self._check_key(key, self.keys(), self._shared_keys)
schema = get_model(value)
Expand All @@ -83,7 +86,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
NotImplementedError("TODO: implement for ndim > 4.")


class Labels(Elements):
class Labels(Elements[DataArray | DataTree]):
def __setitem__(self, key: str, value: Raster_T) -> None:
self._check_key(key, self.keys(), self._shared_keys)
schema = get_model(value)
Expand All @@ -100,7 +103,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
NotImplementedError("TODO: implement for ndim > 3.")


class Shapes(Elements):
class Shapes(Elements[GeoDataFrame]):
def __setitem__(self, key: str, value: GeoDataFrame) -> None:
self._check_key(key, self.keys(), self._shared_keys)
schema = get_model(value)
Expand All @@ -110,7 +113,7 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None:
super().__setitem__(key, value)


class Points(Elements):
class Points(Elements[DaskDataFrame]):
def __setitem__(self, key: str, value: DaskDataFrame) -> None:
self._check_key(key, self.keys(), self._shared_keys)
schema = get_model(value)
Expand All @@ -120,7 +123,7 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None:
super().__setitem__(key, value)


class Tables(Elements):
class Tables(Elements[AnnData]):
def __setitem__(self, key: str, value: AnnData) -> None:
self._check_key(key, self.keys(), self._shared_keys)
schema = get_model(value)
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2336,7 +2336,7 @@ def subset(
)
return SpatialData(**elements_dict, tables=tables, attrs=self.attrs)

def __getitem__(self, item: str) -> SpatialElement:
def __getitem__(self, item: str) -> SpatialElement | AnnData:
"""
Return the element with the given name.

Expand Down
24 changes: 19 additions & 5 deletions src/spatialdata/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
from xarray.testing import assert_equal

from spatialdata import SpatialData
from spatialdata._core._elements import Elements
from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
from spatialdata.models import PointsModel
from spatialdata.models._utils import SpatialElement
from spatialdata.transformations.operations import get_transformation

_Elements = Images | Labels | Shapes | Points | Tables


def assert_elements_dict_are_identical(
elements0: Elements, elements1: Elements, check_transformations: bool = True, check_metadata: bool = True
elements0: _Elements,
elements1: _Elements,
check_transformations: bool = True,
check_metadata: bool = True,
) -> None:
"""
Compare two dictionaries of elements and assert that they are identical (except for the order of the keys).
Expand Down Expand Up @@ -55,7 +60,10 @@ def assert_elements_dict_are_identical(
element0 = elements0[k]
element1 = elements1[k]
assert_elements_are_identical(
element0, element1, check_transformations=check_transformations, check_metadata=check_metadata
element0,
element1,
check_transformations=check_transformations,
check_metadata=check_metadata,
)


Expand Down Expand Up @@ -125,7 +133,10 @@ def assert_elements_are_identical(


def assert_spatial_data_objects_are_identical(
sdata0: SpatialData, sdata1: SpatialData, check_transformations: bool = True, check_metadata: bool = True
sdata0: SpatialData,
sdata1: SpatialData,
check_transformations: bool = True,
check_metadata: bool = True,
) -> None:
"""
Compare two SpatialData objects and assert that they are identical.
Expand Down Expand Up @@ -169,5 +180,8 @@ def assert_spatial_data_objects_are_identical(
element0 = sdata0[element_name]
element1 = sdata1[element_name]
assert_elements_are_identical(
element0, element1, check_transformations=check_transformations, check_metadata=check_metadata
element0,
element1,
check_transformations=check_transformations,
check_metadata=check_metadata,
)
Loading