diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py deleted file mode 100644 index 41a431e16e..0000000000 --- a/dimos/hardware/sensors/fake_zed_module.py +++ /dev/null @@ -1,292 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -FakeZEDModule - Replays recorded ZED data for testing without hardware. -""" - -import functools -import logging -from typing import Any - -from dimos_lcm.sensor_msgs import CameraInfo -import numpy as np - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import Out -from dimos.memory.timeseries.legacy import LegacyPickleStore -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.sensor_msgs.Image import Image, ImageFormat -from dimos.msgs.std_msgs.Header import Header -from dimos.protocol.tf.tf import TF -from dimos.utils.logging_config import setup_logger - -logger = setup_logger(level=logging.INFO) - - -class FakeZEDModuleConfig(ModuleConfig): - recording_path: str - frame_id: str = "zed_camera" - - -class FakeZEDModule(Module): - """ - Fake ZED module that replays recorded data instead of real camera. - """ - - config: FakeZEDModuleConfig - - # Define LCM outputs (same as ZEDModule) - color_image: Out[Image] - depth_image: Out[Image] - camera_info: Out[CameraInfo] - pose: Out[PoseStamped] - - def __init__(self, **kwargs: Any) -> None: - """ - Initialize FakeZEDModule with recording path. - - Args: - recording_path: Path to recorded data directory - """ - super().__init__(**kwargs) - - self.recording_path = self.config.recording_path - self._running = False - - # Initialize TF publisher - self.tf = TF() - - logger.info(f"FakeZEDModule initialized with recording: {self.recording_path}") - - @functools.cache - def _get_color_stream(self): # type: ignore[no-untyped-def] - """Get cached color image stream.""" - logger.info(f"Loading color image stream from {self.recording_path}/color") - - def image_autocast(x): # type: ignore[no-untyped-def] - """Convert raw numpy array to Image.""" - if isinstance(x, np.ndarray): - return Image(data=x, format=ImageFormat.RGB) - elif isinstance(x, Image): - return x - return x - - color_replay = LegacyPickleStore(f"{self.recording_path}/color", autocast=image_autocast) - return color_replay.stream() - - @functools.cache - def _get_depth_stream(self): # type: ignore[no-untyped-def] - """Get cached depth image stream.""" - logger.info(f"Loading depth image stream from {self.recording_path}/depth") - - def depth_autocast(x): # type: ignore[no-untyped-def] - """Convert raw numpy array to depth Image.""" - if isinstance(x, np.ndarray): - # Depth images are float32 - return Image(data=x, format=ImageFormat.DEPTH) - elif isinstance(x, Image): - return x - return x - - depth_replay = LegacyPickleStore(f"{self.recording_path}/depth", autocast=depth_autocast) - return depth_replay.stream() - - @functools.cache - def _get_pose_stream(self): # type: ignore[no-untyped-def] - """Get cached pose stream.""" - logger.info(f"Loading pose stream from {self.recording_path}/pose") - - def pose_autocast(x): # type: ignore[no-untyped-def] - """Convert raw pose dict to PoseStamped.""" - if isinstance(x, dict): - import time - - return PoseStamped( - position=x.get("position", [0, 0, 0]), - orientation=x.get("rotation", [0, 0, 0, 1]), - ts=time.time(), - ) - elif isinstance(x, PoseStamped): - return x - return x - - pose_replay = LegacyPickleStore(f"{self.recording_path}/pose", autocast=pose_autocast) - return pose_replay.stream() - - @functools.cache - def _get_camera_info_stream(self): # type: ignore[no-untyped-def] - """Get cached camera info stream.""" - logger.info(f"Loading camera info stream from {self.recording_path}/camera_info") - - def camera_info_autocast(x): # type: ignore[no-untyped-def] - """Convert raw camera info dict to CameraInfo message.""" - if isinstance(x, dict): - # Extract calibration parameters - left_cam = x.get("left_cam", {}) - resolution = x.get("resolution", {}) - - # Create CameraInfo message - header = Header(self.frame_id) - - # Create camera matrix K (3x3) - K = [ - left_cam.get("fx", 0), - 0, - left_cam.get("cx", 0), - 0, - left_cam.get("fy", 0), - left_cam.get("cy", 0), - 0, - 0, - 1, - ] - - # Distortion coefficients - D = [ - left_cam.get("k1", 0), - left_cam.get("k2", 0), - left_cam.get("p1", 0), - left_cam.get("p2", 0), - left_cam.get("k3", 0), - ] - - # Identity rotation matrix - R = [1, 0, 0, 0, 1, 0, 0, 0, 1] - - # Projection matrix P (3x4) - P = [ - left_cam.get("fx", 0), - 0, - left_cam.get("cx", 0), - 0, - 0, - left_cam.get("fy", 0), - left_cam.get("cy", 0), - 0, - 0, - 0, - 1, - 0, - ] - - return CameraInfo( - D_length=len(D), - header=header, - height=resolution.get("height", 0), - width=resolution.get("width", 0), - distortion_model="plumb_bob", - D=D, - K=K, - R=R, - P=P, - binning_x=0, - binning_y=0, - ) - elif isinstance(x, CameraInfo): - return x - return x - - info_replay = LegacyPickleStore( - f"{self.recording_path}/camera_info", autocast=camera_info_autocast - ) - return info_replay.stream() - - @rpc - def start(self) -> None: - """Start replaying recorded data.""" - super().start() - - if self._running: - logger.warning("FakeZEDModule already running") - return - - logger.info("Starting FakeZEDModule replay...") - - self._running = True - - # Subscribe to all streams and publish - try: - # Color image stream - unsub = self._get_color_stream().subscribe( - lambda msg: self.color_image.publish(msg) if self._running else None - ) - self.register_disposable(unsub) - logger.info("Started color image replay stream") - except Exception as e: - logger.warning(f"Color image stream not available: {e}") - - try: - # Depth image stream - unsub = self._get_depth_stream().subscribe( - lambda msg: self.depth_image.publish(msg) if self._running else None - ) - self.register_disposable(unsub) - logger.info("Started depth image replay stream") - except Exception as e: - logger.warning(f"Depth image stream not available: {e}") - - try: - # Pose stream - unsub = self._get_pose_stream().subscribe( - lambda msg: self._publish_pose(msg) if self._running else None - ) - self.register_disposable(unsub) - logger.info("Started pose replay stream") - except Exception as e: - logger.warning(f"Pose stream not available: {e}") - - try: - # Camera info stream - unsub = self._get_camera_info_stream().subscribe( - lambda msg: self.camera_info.publish(msg) if self._running else None - ) - self.register_disposable(unsub) - logger.info("Started camera info replay stream") - except Exception as e: - logger.warning(f"Camera info stream not available: {e}") - - logger.info("FakeZEDModule replay started") - - @rpc - def stop(self) -> None: - if not self._running: - return - - self._running = False - - super().stop() - - def _publish_pose(self, msg) -> None: # type: ignore[no-untyped-def] - """Publish pose and TF transform.""" - if msg: - self.pose.publish(msg) - - # Publish TF transform from world to camera - import time - - from dimos.msgs.geometry_msgs.Quaternion import Quaternion - from dimos.msgs.geometry_msgs.Transform import Transform - from dimos.msgs.geometry_msgs.Vector3 import Vector3 - - transform = Transform( - translation=Vector3(*msg.position), - rotation=Quaternion(*msg.orientation), - frame_id="world", - child_frame_id=self.frame_id, - ts=time.time(), - ) - self.tf.publish(transform) diff --git a/dimos/mapping/test_voxels.py b/dimos/mapping/test_voxels.py index 411da4ecb8..32e66721de 100644 --- a/dimos/mapping/test_voxels.py +++ b/dimos/mapping/test_voxels.py @@ -20,9 +20,9 @@ from dimos.core.transport import LCMTransport from dimos.mapping.voxels import VoxelGrid -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data +from dimos.utils.testing.legacy_pickle import LegacyPickleStore from dimos.utils.testing.moment import OutputMoment from dimos.utils.testing.test_moment import Go2Moment diff --git a/dimos/memory/embedding.py b/dimos/memory/embedding.py deleted file mode 100644 index 13409a5b11..0000000000 --- a/dimos/memory/embedding.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections.abc import Callable -from dataclasses import dataclass - -from pydantic import Field -import reactivex as rx -from reactivex import operators as ops -from reactivex.observable import Observable - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In -from dimos.models.embedding.base import Embedding, EmbeddingModel -from dimos.models.embedding.clip import CLIPModel -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid -from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier -from dimos.utils.reactive import getter_hot - - -class Config(ModuleConfig): - embedding_model: EmbeddingModel = Field(default_factory=CLIPModel) - - -@dataclass -class SpatialEntry: - image: Image - pose: PoseStamped - - -@dataclass -class SpatialEmbedding(SpatialEntry): - embedding: Embedding - - -class EmbeddingMemory(Module): - config: Config - color_image: In[Image] - global_costmap: In[OccupancyGrid] - - _costmap_getter: Callable[[], OccupancyGrid] | None = None - - def get_costmap(self) -> OccupancyGrid: - if self._costmap_getter is None: - self._costmap_getter = getter_hot(self.global_costmap.pure_observable()) - self.register_disposable(self._costmap_getter) - return self._costmap_getter() - - @rpc - def query_costmap(self, text: str) -> OccupancyGrid: - costmap = self.get_costmap() - # overlay costmap with embedding heat - return costmap - - @rpc - def start(self) -> None: - # would be cool if this sharpness_barrier was somehow self-calibrating - # - # we need a Governor system, sharpness_barrier frequency shouldn't - # be a fixed float but an observable that adjusts based on downstream load - # - # (also voxel size for mapper for example would benefit from this) - self.color_image.pure_observable().pipe( - sharpness_barrier(0.5), - ops.flat_map(self._try_create_spatial_entry), - ops.map(self._embed_spatial_entry), - ops.map(self._store_spatial_entry), - ).subscribe(print) - - def _try_create_spatial_entry(self, img: Image) -> Observable[SpatialEntry]: - pose = self.tf.get_pose("world", "base_link") - if not pose: - return rx.empty() - return rx.of(SpatialEntry(image=img, pose=pose)) - - def _embed_spatial_entry(self, spatial_entry: SpatialEntry) -> SpatialEmbedding: - embedding = self.config.embedding_model.embed(spatial_entry.image) - return SpatialEmbedding( - image=spatial_entry.image, - pose=spatial_entry.pose, - embedding=embedding, - ) - - @rpc - def stop(self) -> None: - super().stop() - - def _store_spatial_entry(self, spatial_embedding: SpatialEmbedding) -> SpatialEmbedding: - return spatial_embedding - - def query_text(self, query: str) -> list[SpatialEmbedding]: - self.config.embedding_model.embed_text(query) - results: list[SpatialEmbedding] = [] - return results diff --git a/dimos/memory/test_embedding.py b/dimos/memory/test_embedding.py deleted file mode 100644 index 01c76b93cf..0000000000 --- a/dimos/memory/test_embedding.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from dimos.memory.embedding import EmbeddingMemory, SpatialEntry -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.utils.testing.replay import TimedSensorReplay - - -@pytest.mark.skip -def test_embed_frame() -> None: - """Test embedding a single frame.""" - # Load a frame from recorded data - from dimos.msgs.sensor_msgs.Image import Image - - video: TimedSensorReplay[Image] = TimedSensorReplay("go2_bigoffice/color_image") - frame = video.find_closest_seek(10) - - # Create memory and embed - memory = EmbeddingMemory() - - try: - # Create a spatial entry with dummy pose (no TF needed for this test) - dummy_pose = PoseStamped( - position=[0, 0, 0], - orientation=[0, 0, 0, 1], # identity quaternion - ) - spatial_entry = SpatialEntry(image=frame, pose=dummy_pose) - - # Embed the frame - result = memory._embed_spatial_entry(spatial_entry) - - # Verify - assert result is not None - assert result.embedding is not None - assert result.embedding.vector is not None - print(f"Embedding shape: {result.embedding.vector.shape}") - print(f"Embedding vector (first 5): {result.embedding.vector[:5]}") - finally: - memory.stop() diff --git a/dimos/memory/timeseries/pickledir.py b/dimos/memory/timeseries/pickledir.py deleted file mode 100644 index 9e8cd5a249..0000000000 --- a/dimos/memory/timeseries/pickledir.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Pickle directory backend for TimeSeriesStore.""" - -import bisect -from collections.abc import Iterator -import glob -import os -from pathlib import Path -import pickle - -from dimos.memory.timeseries.base import T, TimeSeriesStore -from dimos.utils.data import get_data, get_data_dir - - -class PickleDirStore(TimeSeriesStore[T]): - """Pickle directory backend. Files named by timestamp. - - Directory structure: - {name}/ - 1704067200.123.pickle - 1704067200.456.pickle - ... - - Usage: - # Load existing recording (auto-downloads from LFS if needed) - store = PickleDirStore("unitree_go2_bigoffice/lidar") - data = store.find_closest_seek(10.0) - - # Create new recording (directory created on first save) - store = PickleDirStore("my_recording/images") - store.save(image) # saves using image.ts - """ - - def __init__(self, name: str) -> None: - """ - Args: - name: Data directory name (e.g. "unitree_go2_bigoffice/lidar") - """ - self._name = name - self._root_dir: Path | None = None - - # Cached sorted timestamps for find_closest - self._timestamps: list[float] | None = None - - def _get_root_dir(self, for_write: bool = False) -> Path: - """Get root directory, creating on first write if needed.""" - if self._root_dir is not None: - return self._root_dir - - # If absolute path, use directly - if Path(self._name).is_absolute(): - self._root_dir = Path(self._name) - if for_write: - self._root_dir.mkdir(parents=True, exist_ok=True) - elif for_write: - # For writing: use get_data_dir and create if needed - self._root_dir = get_data_dir(self._name) - self._root_dir.mkdir(parents=True, exist_ok=True) - else: - # For reading: use get_data (handles LFS download) - self._root_dir = get_data(self._name) - - return self._root_dir - - def _save(self, timestamp: float, data: T) -> None: - root_dir = self._get_root_dir(for_write=True) - full_path = root_dir / f"{timestamp}.pickle" - - if full_path.exists(): - raise RuntimeError(f"File {full_path} already exists") - - with open(full_path, "wb") as f: - pickle.dump(data, f) - - self._timestamps = None # Invalidate cache - - def _load(self, timestamp: float) -> T | None: - filepath = self._get_root_dir() / f"{timestamp}.pickle" - if filepath.exists(): - return self._load_file(filepath) - return None - - def _delete(self, timestamp: float) -> T | None: - filepath = self._get_root_dir() / f"{timestamp}.pickle" - if filepath.exists(): - data = self._load_file(filepath) - filepath.unlink() - self._timestamps = None # Invalidate cache - return data - return None - - def _iter_items( - self, start: float | None = None, end: float | None = None - ) -> Iterator[tuple[float, T]]: - for ts in self._get_timestamps(): - if start is not None and ts < start: - continue - if end is not None and ts >= end: - break - data = self._load(ts) - if data is not None: - yield (ts, data) - - def _find_closest_timestamp( - self, timestamp: float, tolerance: float | None = None - ) -> float | None: - timestamps = self._get_timestamps() - if not timestamps: - return None - - pos = bisect.bisect_left(timestamps, timestamp) - - # Check neighbors - candidates = [] - if pos > 0: - candidates.append(timestamps[pos - 1]) - if pos < len(timestamps): - candidates.append(timestamps[pos]) - - if not candidates: - return None - - closest = min(candidates, key=lambda ts: abs(ts - timestamp)) - - if tolerance is not None and abs(closest - timestamp) > tolerance: - return None - - return closest - - def _get_timestamps(self) -> list[float]: - """Get sorted list of all timestamps.""" - if self._timestamps is not None: - return self._timestamps - - timestamps: list[float] = [] - root_dir = self._get_root_dir() - for filepath in glob.glob(os.path.join(root_dir, "*.pickle")): - try: - ts = float(Path(filepath).stem) - timestamps.append(ts) - except ValueError: - continue - - timestamps.sort() - self._timestamps = timestamps - return timestamps - - def _count(self) -> int: - return len(self._get_timestamps()) - - def _last_timestamp(self) -> float | None: - timestamps = self._get_timestamps() - return timestamps[-1] if timestamps else None - - def _find_before(self, timestamp: float) -> tuple[float, T] | None: - timestamps = self._get_timestamps() - if not timestamps: - return None - pos = bisect.bisect_left(timestamps, timestamp) - if pos > 0: - ts = timestamps[pos - 1] - data = self._load(ts) - if data is not None: - return (ts, data) - return None - - def _find_after(self, timestamp: float) -> tuple[float, T] | None: - timestamps = self._get_timestamps() - if not timestamps: - return None - pos = bisect.bisect_right(timestamps, timestamp) - if pos < len(timestamps): - ts = timestamps[pos] - data = self._load(ts) - if data is not None: - return (ts, data) - return None - - def _load_file(self, filepath: Path) -> T | None: - """Load data from a pickle file (LRU cached).""" - try: - with open(filepath, "rb") as f: - data: T = pickle.load(f) - return data - except Exception: - return None diff --git a/dimos/memory/timeseries/postgres.py b/dimos/memory/timeseries/postgres.py deleted file mode 100644 index cf31bcdc4a..0000000000 --- a/dimos/memory/timeseries/postgres.py +++ /dev/null @@ -1,312 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PostgreSQL backend for TimeSeriesStore.""" - -from collections.abc import Iterator -import pickle -import re - -import psycopg2 -import psycopg2.extensions - -from dimos.core.resource import Resource -from dimos.memory.timeseries.base import T, TimeSeriesStore - -# Valid SQL identifier: alphanumeric and underscores, not starting with digit -_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") - - -def _validate_identifier(name: str) -> str: - """Validate SQL identifier to prevent injection.""" - if not _VALID_IDENTIFIER.match(name): - raise ValueError( - f"Invalid identifier '{name}': must be alphanumeric/underscore, not start with digit" - ) - if len(name) > 128: - raise ValueError(f"Identifier too long: {len(name)} > 128") - return name - - -class PostgresStore(TimeSeriesStore[T], Resource): - """PostgreSQL backend for sensor data. - - Multiple stores can share the same database with different tables. - Implements Resource for lifecycle management (start/stop/dispose). - - Usage: - # Create store - store = PostgresStore("lidar") - store.start() # open connection - - # Use store - store.save(data) # saves using data.ts - data = store.find_closest_seek(10.0) - - # Cleanup - store.stop() # close connection - - # Multiple sensors in same db - lidar = PostgresStore("lidar") - images = PostgresStore("images") - - # Manual run management via table naming - run1_lidar = PostgresStore("run1_lidar") - """ - - def __init__( - self, - table: str, - db: str = "dimensional", - host: str = "localhost", - port: int = 5432, - user: str | None = None, - ) -> None: - """ - Args: - table: Table name for this sensor's data (alphanumeric/underscore only). - db: Database name (alphanumeric/underscore only). - host: PostgreSQL host. - port: PostgreSQL port. - user: PostgreSQL user. Defaults to current system user. - """ - self._table = _validate_identifier(table) - self._db = _validate_identifier(db) - self._host = host - self._port = port - self._user = user - self._conn: psycopg2.extensions.connection | None = None - self._table_created = False - - def start(self) -> None: - """Open database connection.""" - if self._conn is not None: - return - self._conn = psycopg2.connect( - dbname=self._db, - host=self._host, - port=self._port, - user=self._user, - ) - - def stop(self) -> None: - """Close database connection.""" - if self._conn is not None: - self._conn.close() - self._conn = None - - def _get_conn(self) -> psycopg2.extensions.connection: - """Get connection, starting if needed.""" - if self._conn is None: - self.start() - assert self._conn is not None - return self._conn - - def _ensure_table(self) -> None: - """Create table if it doesn't exist.""" - if self._table_created: - return - conn = self._get_conn() - with conn.cursor() as cur: - cur.execute(f""" - CREATE TABLE IF NOT EXISTS {self._table} ( - timestamp DOUBLE PRECISION PRIMARY KEY, - data BYTEA NOT NULL - ) - """) - cur.execute(f""" - CREATE INDEX IF NOT EXISTS idx_{self._table}_ts - ON {self._table}(timestamp) - """) - conn.commit() - self._table_created = True - - def _save(self, timestamp: float, data: T) -> None: - self._ensure_table() - conn = self._get_conn() - blob = pickle.dumps(data) - with conn.cursor() as cur: - cur.execute( - f""" - INSERT INTO {self._table} (timestamp, data) VALUES (%s, %s) - ON CONFLICT (timestamp) DO UPDATE SET data = EXCLUDED.data - """, - (timestamp, psycopg2.Binary(blob)), - ) - conn.commit() - - def _load(self, timestamp: float) -> T | None: - self._ensure_table() - conn = self._get_conn() - with conn.cursor() as cur: - cur.execute(f"SELECT data FROM {self._table} WHERE timestamp = %s", (timestamp,)) - row = cur.fetchone() - if row is None: - return None - data: T = pickle.loads(row[0]) - return data - - def _delete(self, timestamp: float) -> T | None: - data = self._load(timestamp) - if data is not None: - conn = self._get_conn() - with conn.cursor() as cur: - cur.execute(f"DELETE FROM {self._table} WHERE timestamp = %s", (timestamp,)) - conn.commit() - return data - - def _iter_items( - self, start: float | None = None, end: float | None = None - ) -> Iterator[tuple[float, T]]: - self._ensure_table() - conn = self._get_conn() - - query = f"SELECT timestamp, data FROM {self._table}" - params: list[float] = [] - conditions = [] - - if start is not None: - conditions.append("timestamp >= %s") - params.append(start) - if end is not None: - conditions.append("timestamp < %s") - params.append(end) - - if conditions: - query += " WHERE " + " AND ".join(conditions) - query += " ORDER BY timestamp" - - with conn.cursor() as cur: - cur.execute(query, params) - for row in cur: - ts: float = row[0] - data: T = pickle.loads(row[1]) - yield (ts, data) - - def _find_closest_timestamp( - self, timestamp: float, tolerance: float | None = None - ) -> float | None: - self._ensure_table() - conn = self._get_conn() - - with conn.cursor() as cur: - # Get closest timestamp <= target - cur.execute( - f""" - SELECT timestamp FROM {self._table} - WHERE timestamp <= %s - ORDER BY timestamp DESC LIMIT 1 - """, - (timestamp,), - ) - before = cur.fetchone() - - # Get closest timestamp >= target - cur.execute( - f""" - SELECT timestamp FROM {self._table} - WHERE timestamp >= %s - ORDER BY timestamp ASC LIMIT 1 - """, - (timestamp,), - ) - after = cur.fetchone() - - candidates: list[float] = [] - if before: - candidates.append(before[0]) - if after: - candidates.append(after[0]) - - if not candidates: - return None - - closest = min(candidates, key=lambda ts: abs(ts - timestamp)) - - if tolerance is not None and abs(closest - timestamp) > tolerance: - return None - - return closest - - def _count(self) -> int: - self._ensure_table() - conn = self._get_conn() - with conn.cursor() as cur: - cur.execute(f"SELECT COUNT(*) FROM {self._table}") - row = cur.fetchone() - return row[0] if row else 0 - - def _last_timestamp(self) -> float | None: - self._ensure_table() - conn = self._get_conn() - with conn.cursor() as cur: - cur.execute(f"SELECT MAX(timestamp) FROM {self._table}") - row = cur.fetchone() - if row is None or row[0] is None: - return None - return row[0] # type: ignore[no-any-return] - - def _find_before(self, timestamp: float) -> tuple[float, T] | None: - self._ensure_table() - conn = self._get_conn() - with conn.cursor() as cur: - cur.execute( - f"SELECT timestamp, data FROM {self._table} WHERE timestamp < %s ORDER BY timestamp DESC LIMIT 1", - (timestamp,), - ) - row = cur.fetchone() - if row is None: - return None - return (row[0], pickle.loads(row[1])) - - def _find_after(self, timestamp: float) -> tuple[float, T] | None: - self._ensure_table() - conn = self._get_conn() - with conn.cursor() as cur: - cur.execute( - f"SELECT timestamp, data FROM {self._table} WHERE timestamp > %s ORDER BY timestamp ASC LIMIT 1", - (timestamp,), - ) - row = cur.fetchone() - if row is None: - return None - return (row[0], pickle.loads(row[1])) - - -def reset_db(db: str = "dimensional", host: str = "localhost", port: int = 5432) -> None: - """Drop and recreate database. Simple migration strategy. - - WARNING: This deletes all data in the database! - - Args: - db: Database name to reset (alphanumeric/underscore only). - host: PostgreSQL host. - port: PostgreSQL port. - """ - db = _validate_identifier(db) - # Connect to 'postgres' database to drop/create - conn = psycopg2.connect(dbname="postgres", host=host, port=port) - conn.autocommit = True - with conn.cursor() as cur: - # Terminate existing connections - cur.execute( - """ - SELECT pg_terminate_backend(pid) - FROM pg_stat_activity - WHERE datname = %s AND pid <> pg_backend_pid() - """, - (db,), - ) - cur.execute(f"DROP DATABASE IF EXISTS {db}") - cur.execute(f"CREATE DATABASE {db}") - conn.close() diff --git a/dimos/memory/timeseries/sqlite.py b/dimos/memory/timeseries/sqlite.py deleted file mode 100644 index 6e2ac7a7f5..0000000000 --- a/dimos/memory/timeseries/sqlite.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SQLite backend for TimeSeriesStore.""" - -from collections.abc import Iterator -from pathlib import Path -import pickle -import re -import sqlite3 - -from dimos.memory.timeseries.base import T, TimeSeriesStore -from dimos.utils.data import get_data, get_data_dir - -# Valid SQL identifier: alphanumeric and underscores, not starting with digit -_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") - - -def _validate_identifier(name: str) -> str: - """Validate SQL identifier to prevent injection.""" - if not _VALID_IDENTIFIER.match(name): - raise ValueError( - f"Invalid identifier '{name}': must be alphanumeric/underscore, not start with digit" - ) - if len(name) > 128: - raise ValueError(f"Identifier too long: {len(name)} > 128") - return name - - -class SqliteStore(TimeSeriesStore[T]): - """SQLite backend for sensor data. Good for indexed queries and single-file storage. - - Data is stored as pickled BLOBs with timestamp as indexed column. - - Usage: - # Named store (uses data/ directory, auto-downloads from LFS if needed) - store = SqliteStore("recordings/lidar") # -> data/recordings/lidar.db - store.save(data) # saves using data.ts - - # Absolute path - store = SqliteStore("/path/to/sensors.db") - - # In-memory (for testing) - store = SqliteStore(":memory:") - - # Multiple tables in one DB - store = SqliteStore("recordings/sensors", table="lidar") - """ - - def __init__(self, name: str | Path, table: str = "sensor_data") -> None: - """ - Args: - name: Data name (e.g. "recordings/lidar") resolved via get_data, - absolute path, or ":memory:" for in-memory. - table: Table name for this sensor's data (alphanumeric/underscore only). - """ - self._name = str(name) - self._table = _validate_identifier(table) - self._db_path: str | None = None - self._conn: sqlite3.Connection | None = None - - def _get_db_path(self, for_write: bool = False) -> str: - """Get database path, resolving via get_data if needed.""" - if self._db_path is not None: - return self._db_path - - # Special case for in-memory - if self._name == ":memory:": - self._db_path = ":memory:" - return self._db_path - - # If absolute path, use directly - if Path(self._name).is_absolute(): - self._db_path = self._name - elif for_write: - # For writing: use get_data_dir - db_file = get_data_dir(self._name + ".db") - db_file.parent.mkdir(parents=True, exist_ok=True) - self._db_path = str(db_file) - else: - # For reading: use get_data (handles LFS download) - # Try with .db extension first - try: - db_file = get_data(self._name + ".db") - self._db_path = str(db_file) - except FileNotFoundError: - # Fall back to get_data_dir for new databases - db_file = get_data_dir(self._name + ".db") - db_file.parent.mkdir(parents=True, exist_ok=True) - self._db_path = str(db_file) - - return self._db_path - - def _get_conn(self) -> sqlite3.Connection: - """Get or create database connection.""" - if self._conn is None: - db_path = self._get_db_path(for_write=True) - self._conn = sqlite3.connect(db_path, check_same_thread=False) - self._create_table() - return self._conn - - def _create_table(self) -> None: - """Create table if it doesn't exist.""" - conn = self._conn - assert conn is not None - conn.execute(f""" - CREATE TABLE IF NOT EXISTS {self._table} ( - timestamp REAL PRIMARY KEY, - data BLOB NOT NULL - ) - """) - conn.execute(f""" - CREATE INDEX IF NOT EXISTS idx_{self._table}_timestamp - ON {self._table}(timestamp) - """) - conn.commit() - - def _save(self, timestamp: float, data: T) -> None: - conn = self._get_conn() - blob = pickle.dumps(data) - conn.execute( - f"INSERT OR REPLACE INTO {self._table} (timestamp, data) VALUES (?, ?)", - (timestamp, blob), - ) - conn.commit() - - def _load(self, timestamp: float) -> T | None: - conn = self._get_conn() - cursor = conn.execute(f"SELECT data FROM {self._table} WHERE timestamp = ?", (timestamp,)) - row = cursor.fetchone() - if row is None: - return None - data: T = pickle.loads(row[0]) - return data - - def _delete(self, timestamp: float) -> T | None: - data = self._load(timestamp) - if data is not None: - conn = self._get_conn() - conn.execute(f"DELETE FROM {self._table} WHERE timestamp = ?", (timestamp,)) - conn.commit() - return data - - def _iter_items( - self, start: float | None = None, end: float | None = None - ) -> Iterator[tuple[float, T]]: - conn = self._get_conn() - - # Build query with optional range filters - query = f"SELECT timestamp, data FROM {self._table}" - params: list[float] = [] - conditions = [] - - if start is not None: - conditions.append("timestamp >= ?") - params.append(start) - if end is not None: - conditions.append("timestamp < ?") - params.append(end) - - if conditions: - query += " WHERE " + " AND ".join(conditions) - query += " ORDER BY timestamp" - - cursor = conn.execute(query, params) - for row in cursor: - ts: float = row[0] - data: T = pickle.loads(row[1]) - yield (ts, data) - - def _find_closest_timestamp( - self, timestamp: float, tolerance: float | None = None - ) -> float | None: - conn = self._get_conn() - - # Find closest timestamp using SQL - # Get the closest timestamp <= target - cursor = conn.execute( - f""" - SELECT timestamp FROM {self._table} - WHERE timestamp <= ? - ORDER BY timestamp DESC LIMIT 1 - """, - (timestamp,), - ) - before = cursor.fetchone() - - # Get the closest timestamp >= target - cursor = conn.execute( - f""" - SELECT timestamp FROM {self._table} - WHERE timestamp >= ? - ORDER BY timestamp ASC LIMIT 1 - """, - (timestamp,), - ) - after = cursor.fetchone() - - # Find the closest of the two - candidates: list[float] = [] - if before: - candidates.append(before[0]) - if after: - candidates.append(after[0]) - - if not candidates: - return None - - closest = min(candidates, key=lambda ts: abs(ts - timestamp)) - - if tolerance is not None and abs(closest - timestamp) > tolerance: - return None - - return closest - - def _count(self) -> int: - conn = self._get_conn() - cursor = conn.execute(f"SELECT COUNT(*) FROM {self._table}") - return cursor.fetchone()[0] # type: ignore[no-any-return] - - def _last_timestamp(self) -> float | None: - conn = self._get_conn() - cursor = conn.execute(f"SELECT MAX(timestamp) FROM {self._table}") - row = cursor.fetchone() - if row is None or row[0] is None: - return None - return row[0] # type: ignore[no-any-return] - - def _find_before(self, timestamp: float) -> tuple[float, T] | None: - conn = self._get_conn() - cursor = conn.execute( - f"SELECT timestamp, data FROM {self._table} WHERE timestamp < ? ORDER BY timestamp DESC LIMIT 1", - (timestamp,), - ) - row = cursor.fetchone() - if row is None: - return None - return (row[0], pickle.loads(row[1])) - - def _find_after(self, timestamp: float) -> tuple[float, T] | None: - conn = self._get_conn() - cursor = conn.execute( - f"SELECT timestamp, data FROM {self._table} WHERE timestamp > ? ORDER BY timestamp ASC LIMIT 1", - (timestamp,), - ) - row = cursor.fetchone() - if row is None: - return None - return (row[0], pickle.loads(row[1])) - - def close(self) -> None: - """Close the database connection.""" - if self._conn is not None: - self._conn.close() - self._conn = None - - def __del__(self) -> None: - self.close() diff --git a/dimos/memory2/test_e2e.py b/dimos/memory2/test_e2e.py index fa0e67517c..9ae9788630 100644 --- a/dimos/memory2/test_e2e.py +++ b/dimos/memory2/test_e2e.py @@ -21,7 +21,6 @@ import pytest -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.memory2.embed import EmbedImages from dimos.memory2.store.sqlite import SqliteStore from dimos.memory2.transform import QualityWindow @@ -30,6 +29,7 @@ from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data_dir +from dimos.utils.testing.legacy_pickle import LegacyPickleStore if TYPE_CHECKING: from collections.abc import Iterator diff --git a/dimos/models/vl/test_vlm.py b/dimos/models/vl/test_vlm.py index d03bf2faea..f310fee02b 100644 --- a/dimos/models/vl/test_vlm.py +++ b/dimos/models/vl/test_vlm.py @@ -219,7 +219,7 @@ def test_vlm_query_multi(model_class: "type[VlModel]", model_name: str) -> None: @pytest.mark.tool def test_vlm_query_batch(model_class: "type[VlModel]", model_name: str) -> None: """Test query_batch optimization - multiple images, same query.""" - from dimos.memory.timeseries.legacy import LegacyPickleStore + from dimos.utils.testing.legacy_pickle import LegacyPickleStore # Load 5 frames at 1-second intervals using LegacyPickleStore replay = LegacyPickleStore[Image]("unitree_go2_office_walk2/video") @@ -276,7 +276,7 @@ def test_vlm_resize( sizes: list[tuple[int, int] | None], ) -> None: """Test VLM auto_resize effect on performance.""" - from dimos.memory.timeseries.legacy import LegacyPickleStore + from dimos.utils.testing.legacy_pickle import LegacyPickleStore replay = LegacyPickleStore[Image]("unitree_go2_office_walk2/video") image = replay.find_closest_seek(0).to_rgb() diff --git a/dimos/msgs/sensor_msgs/test_image.py b/dimos/msgs/sensor_msgs/test_image.py index 502161755f..d679c6cb69 100644 --- a/dimos/msgs/sensor_msgs/test_image.py +++ b/dimos/msgs/sensor_msgs/test_image.py @@ -20,9 +20,9 @@ _IS_MACOS = sys.platform == "darwin" -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.Image import Image, ImageFormat, sharpness_barrier from dimos.utils.data import get_data +from dimos.utils.testing.legacy_pickle import LegacyPickleStore @pytest.fixture diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 5b991a1806..aa326804c0 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -21,7 +21,6 @@ import pytest from dimos.core.transport import LCMTransport -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image @@ -38,6 +37,7 @@ from dimos.robot.unitree.go2 import connection from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data +from dimos.utils.testing.legacy_pickle import LegacyPickleStore class Moment(TypedDict, total=False): diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index 035262f857..8f8ba4786d 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ b/dimos/perception/test_spatial_memory_module.py @@ -25,13 +25,13 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.core.transport import LCMTransport -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.spatial_perception import SpatialMemory from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger +from dimos.utils.testing.legacy_pickle import LegacyPickleStore logger = setup_logger() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 05c600d733..eb3d72b470 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -21,7 +21,6 @@ import threading import time -from dimos.memory.timeseries.inmemory import InMemoryStore from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.tf2_msgs.TFMessage import TFMessage @@ -30,6 +29,7 @@ from dimos.protocol.service.spec import BaseConfig, Service from dimos.types.timestamped import to_human_readable from dimos.utils.logging_config import setup_logger +from dimos.utils.timeseries.inmemory import InMemoryStore logger = setup_logger() diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 0c312a5491..51b91ee2eb 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -143,7 +143,6 @@ "drone-camera-module": "dimos.robot.drone.camera_module.DroneCameraModule", "drone-connection-module": "dimos.robot.drone.connection_module.DroneConnectionModule", "drone-tracking-module": "dimos.robot.drone.drone_tracking_module.DroneTrackingModule", - "embedding-memory": "dimos.memory.embedding.EmbeddingMemory", "emitter-module": "dimos.utils.demo_image_encoding.EmitterModule", "evaluator": "dimos.navigation.nav_stack.evaluator.evaluator.Evaluator", "far-planner": "dimos.navigation.nav_stack.modules.far_planner.far_planner.FarPlanner", diff --git a/dimos/robot/drone/dji_video_stream.py b/dimos/robot/drone/dji_video_stream.py index df153192e3..92afc1ce50 100644 --- a/dimos/robot/drone/dji_video_stream.py +++ b/dimos/robot/drone/dji_video_stream.py @@ -214,7 +214,7 @@ def get_stream(self) -> Observable[Image]: # type: ignore[override] """ from reactivex import operators as ops - from dimos.memory.timeseries.legacy import LegacyPickleStore + from dimos.utils.testing.legacy_pickle import LegacyPickleStore def _fix_format(img: Image) -> Image: if img.format == ImageFormat.BGR: diff --git a/dimos/robot/drone/mavlink_connection.py b/dimos/robot/drone/mavlink_connection.py index 6b833fcbd3..e908be4ab0 100644 --- a/dimos/robot/drone/mavlink_connection.py +++ b/dimos/robot/drone/mavlink_connection.py @@ -1030,8 +1030,8 @@ def __init__(self, connection_string: str) -> None: # Create fake mavlink object class FakeMavlink: def __init__(self) -> None: - from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.utils.data import get_data + from dimos.utils.testing.legacy_pickle import LegacyPickleStore get_data("drone") diff --git a/dimos/robot/drone/test_drone.py b/dimos/robot/drone/test_drone.py index df7617bf34..8d415e518d 100644 --- a/dimos/robot/drone/test_drone.py +++ b/dimos/robot/drone/test_drone.py @@ -196,7 +196,7 @@ class TestReplayMode(unittest.TestCase): def test_fake_mavlink_connection(self) -> None: """Test FakeMavlinkConnection replays messages correctly.""" - with patch("dimos.memory.timeseries.legacy.LegacyPickleStore") as mock_replay: + with patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") as mock_replay: # Mock the replay stream MagicMock() mock_messages = [ @@ -222,7 +222,7 @@ def test_fake_mavlink_connection(self) -> None: def test_fake_video_stream_no_throttling(self) -> None: """Test FakeDJIVideoStream returns replay stream with format fix.""" - with patch("dimos.memory.timeseries.legacy.LegacyPickleStore") as mock_replay: + with patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") as mock_replay: mock_stream = MagicMock() mock_replay.return_value.stream.return_value = mock_stream @@ -284,7 +284,7 @@ def test_connection_module_replay_with_messages(self) -> None: os.environ["DRONE_CONNECTION"] = "replay" - with patch("dimos.memory.timeseries.legacy.LegacyPickleStore") as mock_replay: + with patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") as mock_replay: # Set up MAVLink replay stream mavlink_messages = [ {"mavpackettype": "HEARTBEAT", "type": 2, "base_mode": 193}, @@ -432,7 +432,7 @@ def tearDown(self) -> None: self.pubsub_patch.stop() @patch("dimos.robot.drone.drone.ModuleCoordinator") - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") def test_full_system_with_replay(self, mock_replay, mock_coordinator_class) -> None: """Test full drone system initialization and operation with replay mode.""" # Set up mock replay data @@ -566,7 +566,7 @@ def deploy_side_effect(module_class, **kwargs): class TestDroneControlCommands(unittest.TestCase): """Test drone control commands with FakeMavlinkConnection.""" - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: """Test arm and disarm commands work with fake connection.""" @@ -585,7 +585,7 @@ def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: result = conn.disarm() self.assertIsInstance(result, bool) # Should return bool without crashing - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: """Test takeoff and land commands with fake connection.""" @@ -604,7 +604,7 @@ def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: result = conn.land() self.assertIsNotNone(result) - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_set_mode_command(self, mock_get_data, mock_replay) -> None: """Test flight mode setting with fake connection.""" @@ -625,7 +625,7 @@ def test_set_mode_command(self, mock_get_data, mock_replay) -> None: class TestDronePerception(unittest.TestCase): """Test drone perception capabilities.""" - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_video_stream_replay(self, mock_get_data, mock_replay) -> None: """Test video stream works with replay data.""" @@ -695,7 +695,7 @@ def piped_subscribe(callback): class TestDroneMovementAndOdometry(unittest.TestCase): """Test drone movement commands and odometry.""" - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: """Test movement commands are properly converted from ROS to NED.""" @@ -715,7 +715,7 @@ def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: # Movement should be converted to NED internally # The fake connection doesn't actually send commands, but it should not crash - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_odometry_from_replay(self, mock_get_data, mock_replay) -> None: """Test odometry is properly generated from replay messages.""" @@ -762,7 +762,7 @@ def replay_stream_subscribe(callback) -> None: self.assertIsNotNone(odom.orientation) self.assertEqual(odom.frame_id, "world") - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_position_integration_indoor(self, mock_get_data, mock_replay) -> None: """Test position integration for indoor flight without GPS.""" @@ -807,7 +807,7 @@ def replay_stream_subscribe(callback) -> None: class TestDroneStatusAndTelemetry(unittest.TestCase): """Test drone status and telemetry reporting.""" - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_status_extraction(self, mock_get_data, mock_replay) -> None: """Test status is properly extracted from MAVLink messages.""" @@ -852,7 +852,7 @@ def replay_stream_subscribe(callback) -> None: self.assertIn("altitude", status) self.assertIn("heading", status) - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_telemetry_json_publishing(self, mock_get_data, mock_replay) -> None: """Test full telemetry is published as JSON.""" @@ -906,7 +906,7 @@ def replay_stream_subscribe(callback) -> None: class TestFlyToErrorHandling(unittest.TestCase): """Test fly_to() error handling paths.""" - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: """flying_to_target=True rejects concurrent fly_to() calls.""" @@ -920,7 +920,7 @@ def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: result = conn.fly_to(37.0, -122.0, 10.0) self.assertIn("Already flying to target", result) - @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") + @patch("dimos.utils.testing.legacy_pickle.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_error_when_not_connected(self, mock_get_data, mock_replay) -> None: """connected=False returns error immediately.""" diff --git a/dimos/robot/unitree/modular/detect.py b/dimos/robot/unitree/modular/detect.py deleted file mode 100644 index d446f87668..0000000000 --- a/dimos/robot/unitree/modular/detect.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pickle - -from dimos_lcm.sensor_msgs import CameraInfo - -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.msgs.std_msgs.Header import Header -from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar -from dimos.robot.unitree.type.odometry import Odometry - -image_resize_factor = 1 -originalwidth, originalheight = (1280, 720) - - -def camera_info() -> CameraInfo: - fx, fy, cx, cy = list( - map( - lambda x: int(x / image_resize_factor), - [819.553492, 820.646595, 625.284099, 336.808987], - ) - ) - width, height = tuple( - map( - lambda x: int(x / image_resize_factor), - [originalwidth, originalheight], - ) - ) - - # Camera matrix K (3x3) - K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] - - # No distortion coefficients for now - D = [0.0, 0.0, 0.0, 0.0, 0.0] - - # Identity rotation matrix - R = [1, 0, 0, 0, 1, 0, 0, 0, 1] - - # Projection matrix P (3x4) - P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] - - base_msg = { - "D_length": len(D), - "height": height, - "width": width, - "distortion_model": "plumb_bob", - "D": D, - "K": K, - "R": R, - "P": P, - "binning_x": 0, - "binning_y": 0, - } - - return CameraInfo( - **base_msg, - header=Header("camera_optical"), - ) - - -def transform_chain(odom_frame: Odometry) -> list: # type: ignore[type-arg] - from dimos.msgs.geometry_msgs.Quaternion import Quaternion - from dimos.msgs.geometry_msgs.Transform import Transform - from dimos.msgs.geometry_msgs.Vector3 import Vector3 - from dimos.protocol.tf.tf import TF - - camera_link = Transform( - translation=Vector3(0.3, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", - ts=odom_frame.ts, - ) - - camera_optical = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), - frame_id="camera_link", - child_frame_id="camera_optical", - ts=camera_link.ts, - ) - - tf = TF() - tf.publish( - Transform.from_pose("base_link", odom_frame), - camera_link, - camera_optical, - ) - - return tf # type: ignore[return-value] - - -def broadcast( # type: ignore[no-untyped-def] - timestamp: float, - lidar_frame: PointCloud2, - video_frame: Image, - odom_frame: Odometry, - detections, -) -> None: - from dimos.core.transport import LCMTransport - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - - lidar_transport = LCMTransport("/lidar", PointCloud2) # type: ignore[var-annotated] - odom_transport = LCMTransport("/odom", PoseStamped) # type: ignore[var-annotated] - video_transport = LCMTransport("/image", Image) # type: ignore[var-annotated] - camera_info_transport = LCMTransport("/camera_info", CameraInfo) # type: ignore[var-annotated] - - lidar_transport.broadcast(None, lidar_frame) - video_transport.broadcast(None, video_frame) - odom_transport.broadcast(None, odom_frame) - camera_info_transport.broadcast(None, camera_info()) - - transform_chain(odom_frame) - - print(lidar_frame) - print(video_frame) - print(odom_frame) - video_transport = LCMTransport("/image", Image) - - -def process_data(): # type: ignore[no-untyped-def] - from dimos.memory.timeseries.legacy import LegacyPickleStore - from dimos.msgs.sensor_msgs.Image import Image - from dimos.perception.detection.module2D import ( # type: ignore[attr-defined] - Detection2DModule, - ) - from dimos.robot.unitree.type.odometry import Odometry - from dimos.utils.data import get_data - - get_data("unitree_office_walk") - target = 1751591272.9654856 - lidar_store = LegacyPickleStore( - "unitree_office_walk/lidar", autocast=pointcloud2_from_webrtc_lidar - ) - video_store = LegacyPickleStore("unitree_office_walk/video", autocast=Image.from_numpy) - odom_store = LegacyPickleStore("unitree_office_walk/odom", autocast=Odometry.from_msg) - - def attach_frame_id(image: Image) -> Image: - image.frame_id = "camera_optical" - return image - - lidar_frame = lidar_store.find_closest(target, tolerance=1) - video_frame = attach_frame_id(video_store.find_closest(target, tolerance=1)) # type: ignore[arg-type] - odom_frame = odom_store.find_closest(target, tolerance=1) - - detector = Detection2DModule() - detections = detector.detect(video_frame) # type: ignore[attr-defined] - - data = (target, lidar_frame, video_frame, odom_frame, detections) - - with open("filename.pkl", "wb") as file: - pickle.dump(data, file) - - return data - - -def main() -> None: - try: - with open("filename.pkl", "rb") as file: - data = pickle.load(file) - except FileNotFoundError: - print("Processing data and creating pickle file...") - data = process_data() # type: ignore[no-untyped-call] - broadcast(*data) - - -main() diff --git a/dimos/robot/unitree/testing/test_tooling.py b/dimos/robot/unitree/testing/test_tooling.py index c4f64c054f..cd47a8e903 100644 --- a/dimos/robot/unitree/testing/test_tooling.py +++ b/dimos/robot/unitree/testing/test_tooling.py @@ -16,10 +16,10 @@ import pytest -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.reactive import backpressure +from dimos.utils.testing.legacy_pickle import LegacyPickleStore @pytest.mark.tool diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index 6e9cb8be09..d61773eaf0 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -19,8 +19,6 @@ from reactivex import operators as ops from reactivex.scheduler import ThreadPoolScheduler -from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.Image import Image from dimos.types.timestamped import ( Timestamped, @@ -31,6 +29,8 @@ ) from dimos.utils.data import get_data from dimos.utils.reactive import backpressure +from dimos.utils.testing.legacy_pickle import LegacyPickleStore +from dimos.utils.timeseries.inmemory import InMemoryStore def test_timestamped_dt_method() -> None: diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index b229a2478e..284c56c1ea 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -22,9 +22,9 @@ # from dimos_lcm.std_msgs import Time as ROSTime from reactivex.observable import Observable -from dimos.memory.timeseries.inmemory import InMemoryStore from dimos.types.weaklist import WeakList from dimos.utils.logging_config import setup_logger +from dimos.utils.timeseries.inmemory import InMemoryStore logger = setup_logger() diff --git a/dimos/memory/timeseries/legacy.py b/dimos/utils/testing/legacy_pickle.py similarity index 99% rename from dimos/memory/timeseries/legacy.py rename to dimos/utils/testing/legacy_pickle.py index 27194462f7..9d4f92e4ef 100644 --- a/dimos/memory/timeseries/legacy.py +++ b/dimos/utils/testing/legacy_pickle.py @@ -26,8 +26,8 @@ from reactivex.observable import Observable -from dimos.memory.timeseries.base import T, TimeSeriesStore from dimos.utils.data import get_data, get_data_dir +from dimos.utils.timeseries.base import T, TimeSeriesStore class LegacyPickleStore(TimeSeriesStore[T]): diff --git a/dimos/utils/testing/moment.py b/dimos/utils/testing/moment.py index 0130a5b0a3..6dc8dc3be6 100644 --- a/dimos/utils/testing/moment.py +++ b/dimos/utils/testing/moment.py @@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.core.resource import Resource -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.types.timestamped import Timestamped +from dimos.utils.testing.legacy_pickle import LegacyPickleStore if TYPE_CHECKING: from dimos.core.stream import Transport diff --git a/dimos/utils/testing/replay.py b/dimos/utils/testing/replay.py index 4346cf3cbb..d35818d735 100644 --- a/dimos/utils/testing/replay.py +++ b/dimos/utils/testing/replay.py @@ -19,7 +19,7 @@ to be ``"/"``. Callers that still need to read from legacy pickle dirs should import -``LegacyPickleStore`` directly from ``dimos.memory.timeseries.legacy``. The +``LegacyPickleStore`` directly from ``dimos.utils.testing.legacy_pickle``. The write-side (``TimedSensorStorage``/``SensorStorage``) still points at ``LegacyPickleStore`` — out of scope for the memory2 migration. """ @@ -37,9 +37,9 @@ from reactivex.observable import Observable from reactivex.scheduler import TimeoutScheduler -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.memory2.store.sqlite import SqliteStore from dimos.utils.data import get_data +from dimos.utils.testing.legacy_pickle import LegacyPickleStore T = TypeVar("T") diff --git a/dimos/memory/timeseries/test_legacy.py b/dimos/utils/testing/test_legacy.py similarity index 96% rename from dimos/memory/timeseries/test_legacy.py rename to dimos/utils/testing/test_legacy.py index 00ca8357a6..23ac280428 100644 --- a/dimos/memory/timeseries/test_legacy.py +++ b/dimos/utils/testing/test_legacy.py @@ -15,7 +15,7 @@ import pytest -from dimos.memory.timeseries.legacy import LegacyPickleStore +from dimos.utils.testing.legacy_pickle import LegacyPickleStore class TestLegacyPickleStoreRealData: diff --git a/dimos/utils/testing/test_replay.py b/dimos/utils/testing/test_replay.py index 79c0da0404..bd333906f6 100644 --- a/dimos/utils/testing/test_replay.py +++ b/dimos/utils/testing/test_replay.py @@ -17,11 +17,11 @@ import pytest from reactivex import operators as ops -from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data +from dimos.utils.testing.legacy_pickle import LegacyPickleStore pytestmark = pytest.mark.self_hosted diff --git a/dimos/memory/timeseries/base.py b/dimos/utils/timeseries/base.py similarity index 100% rename from dimos/memory/timeseries/base.py rename to dimos/utils/timeseries/base.py diff --git a/dimos/memory/timeseries/inmemory.py b/dimos/utils/timeseries/inmemory.py similarity index 98% rename from dimos/memory/timeseries/inmemory.py rename to dimos/utils/timeseries/inmemory.py index b67faca644..beac31725a 100644 --- a/dimos/memory/timeseries/inmemory.py +++ b/dimos/utils/timeseries/inmemory.py @@ -17,7 +17,7 @@ from sortedcontainers import SortedKeyList # type: ignore[import-untyped] -from dimos.memory.timeseries.base import T, TimeSeriesStore +from dimos.utils.timeseries.base import T, TimeSeriesStore class InMemoryStore(TimeSeriesStore[T]): diff --git a/dimos/memory/timeseries/test_base.py b/dimos/utils/timeseries/test_base.py similarity index 69% rename from dimos/memory/timeseries/test_base.py rename to dimos/utils/timeseries/test_base.py index 61fbdefdfa..4be6d03350 100644 --- a/dimos/memory/timeseries/test_base.py +++ b/dimos/utils/timeseries/test_base.py @@ -14,19 +14,13 @@ """Tests for TimeSeriesStore implementations.""" from dataclasses import dataclass -from pathlib import Path -import tempfile -import uuid import pytest from reactivex import operators as ops -from dimos.memory.timeseries.base import TimeSeriesStore -from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.memory.timeseries.legacy import LegacyPickleStore -from dimos.memory.timeseries.pickledir import PickleDirStore -from dimos.memory.timeseries.sqlite import SqliteStore from dimos.types.timestamped import Timestamped +from dimos.utils.timeseries.base import TimeSeriesStore +from dimos.utils.timeseries.inmemory import InMemoryStore @dataclass @@ -45,85 +39,21 @@ def __eq__(self, other: object) -> bool: return False -@pytest.fixture -def temp_dir(): - """Create a temporary directory for file-based store tests.""" - with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir - - def make_in_memory_store() -> TimeSeriesStore[SampleData]: return InMemoryStore[SampleData]() -def make_pickle_dir_store(tmpdir: str) -> TimeSeriesStore[SampleData]: - return PickleDirStore[SampleData](tmpdir) - - -def make_sqlite_store(tmpdir: str) -> TimeSeriesStore[SampleData]: - return SqliteStore[SampleData](Path(tmpdir) / "test.db") - - -def make_legacy_pickle_store(tmpdir: str) -> TimeSeriesStore[SampleData]: - return LegacyPickleStore[SampleData](Path(tmpdir) / "legacy") - - -# Base test data (always available) testdata: list[tuple[object, str]] = [ - (lambda _: make_in_memory_store(), "InMemoryStore"), - (lambda tmpdir: make_pickle_dir_store(tmpdir), "PickleDirStore"), - (lambda tmpdir: make_sqlite_store(tmpdir), "SqliteStore"), - (lambda tmpdir: make_legacy_pickle_store(tmpdir), "LegacyPickleStore"), + (lambda: make_in_memory_store(), "InMemoryStore"), ] -# Track postgres tables to clean up -_postgres_tables: list[str] = [] - -try: - import psycopg2 - - from dimos.memory.timeseries.postgres import PostgresStore - - # Test connection - _test_conn = psycopg2.connect(dbname="dimensional") - _test_conn.close() - - def make_postgres_store(_tmpdir: str) -> TimeSeriesStore[SampleData]: - """Create PostgresStore with unique table name.""" - table = f"test_{uuid.uuid4().hex[:8]}" - _postgres_tables.append(table) - store = PostgresStore[SampleData](table) - store.start() - return store - - testdata.append((lambda tmpdir: make_postgres_store(tmpdir), "PostgresStore")) - - @pytest.fixture(autouse=True) - def cleanup_postgres_tables(): - """Clean up postgres test tables after each test.""" - yield - if _postgres_tables: - try: - conn = psycopg2.connect(dbname="dimensional") - conn.autocommit = True - with conn.cursor() as cur: - for table in _postgres_tables: - cur.execute(f"DROP TABLE IF EXISTS {table}") - conn.close() - except Exception: - pass # Ignore cleanup errors - _postgres_tables.clear() - -except Exception: - print("PostgreSQL not available") - @pytest.mark.parametrize("store_factory,store_name", testdata) class TestTimeSeriesStore: """Parametrized tests for all TimeSeriesStore implementations.""" - def test_save_and_load(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_save_and_load(self, store_factory, store_name): + store = store_factory() store.save(SampleData("data_at_1", 1.0)) store.save(SampleData("data_at_2", 2.0)) @@ -131,8 +61,8 @@ def test_save_and_load(self, store_factory, store_name, temp_dir): assert store.load(2.0) == SampleData("data_at_2", 2.0) assert store.load(3.0) is None - def test_find_closest_timestamp(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_find_closest_timestamp(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) # Exact match @@ -148,8 +78,8 @@ def test_find_closest_timestamp(self, store_factory, store_name, temp_dir): assert store._find_closest_timestamp(1.4, tolerance=0.5) == 1.0 assert store._find_closest_timestamp(1.4, tolerance=0.3) is None - def test_iter_items(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_iter_items(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) # Should iterate in timestamp order @@ -160,8 +90,8 @@ def test_iter_items(self, store_factory, store_name, temp_dir): (3.0, SampleData("c", 3.0)), ] - def test_iter_items_with_range(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_iter_items_with_range(self, store_factory, store_name): + store = store_factory() store.save( SampleData("a", 1.0), SampleData("b", 2.0), @@ -185,15 +115,15 @@ def test_iter_items_with_range(self, store_factory, store_name, temp_dir): items = list(store._iter_items(start=2.0, end=4.0)) assert items == [(2.0, SampleData("b", 2.0)), (3.0, SampleData("c", 3.0))] - def test_empty_store(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_empty_store(self, store_factory, store_name): + store = store_factory() assert store.load(1.0) is None assert store._find_closest_timestamp(1.0) is None assert list(store._iter_items()) == [] - def test_first_and_first_timestamp(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_first_and_first_timestamp(self, store_factory, store_name): + store = store_factory() # Empty store assert store.first() is None @@ -206,8 +136,8 @@ def test_first_and_first_timestamp(self, store_factory, store_name, temp_dir): assert store.first_timestamp() == 1.0 assert store.first() == SampleData("a", 1.0) - def test_find_closest(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_find_closest(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) # Exact match @@ -223,8 +153,8 @@ def test_find_closest(self, store_factory, store_name, temp_dir): assert store.find_closest(1.4, tolerance=0.5) == SampleData("a", 1.0) assert store.find_closest(1.4, tolerance=0.3) is None - def test_find_closest_seek(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_find_closest_seek(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 10.0), SampleData("b", 11.0), SampleData("c", 12.0)) # Seek 0 = first item (10.0) @@ -243,8 +173,8 @@ def test_find_closest_seek(self, store_factory, store_name, temp_dir): assert store.find_closest_seek(1.4, tolerance=0.5) == SampleData("b", 11.0) assert store.find_closest_seek(1.4, tolerance=0.3) is None - def test_iterate(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_iterate(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) # Should iterate in timestamp order, returning data only (not tuples) @@ -255,8 +185,8 @@ def test_iterate(self, store_factory, store_name, temp_dir): SampleData("c", 3.0), ] - def test_iterate_with_seek_and_duration(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_iterate_with_seek_and_duration(self, store_factory, store_name): + store = store_factory() store.save( SampleData("a", 10.0), SampleData("b", 11.0), @@ -284,8 +214,8 @@ def test_iterate_with_seek_and_duration(self, store_factory, store_name, temp_di items = list(store.iterate(from_timestamp=12.0)) assert items == [SampleData("c", 12.0), SampleData("d", 13.0)] - def test_variadic_save(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_variadic_save(self, store_factory, store_name): + store = store_factory() # Save multiple items at once store.save( @@ -298,10 +228,10 @@ def test_variadic_save(self, store_factory, store_name, temp_dir): assert store.load(2.0) == SampleData("b", 2.0) assert store.load(3.0) == SampleData("c", 3.0) - def test_pipe_save(self, store_factory, store_name, temp_dir): + def test_pipe_save(self, store_factory, store_name): import reactivex as rx - store = store_factory(temp_dir) + store = store_factory() # Create observable with test data source = rx.of( @@ -326,10 +256,10 @@ def test_pipe_save(self, store_factory, store_name, temp_dir): SampleData("c", 3.0), ] - def test_consume_stream(self, store_factory, store_name, temp_dir): + def test_consume_stream(self, store_factory, store_name): import reactivex as rx - store = store_factory(temp_dir) + store = store_factory() # Create observable with test data source = rx.of( @@ -348,8 +278,8 @@ def test_consume_stream(self, store_factory, store_name, temp_dir): disposable.dispose() - def test_iterate_items(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_iterate_items(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) items = list(store.iterate_items()) @@ -364,8 +294,8 @@ def test_iterate_items(self, store_factory, store_name, temp_dir): assert len(items) == 2 assert items[0] == (2.0, SampleData("b", 2.0)) - async def test_stream_basic(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + async def test_stream_basic(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) # Stream at high speed (essentially instant) @@ -382,52 +312,52 @@ async def test_stream_basic(self, store_factory, store_name, temp_dir): class TestCollectionAPI: """Test new collection API methods on all backends.""" - def test_len(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_len(self, store_factory, store_name): + store = store_factory() assert len(store) == 0 store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) assert len(store) == 3 - def test_iter(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_iter(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0)) items = list(store) assert items == [SampleData("a", 1.0), SampleData("b", 2.0)] - def test_last_timestamp(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_last_timestamp(self, store_factory, store_name): + store = store_factory() assert store.last_timestamp() is None store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) assert store.last_timestamp() == 3.0 - def test_last(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_last(self, store_factory, store_name): + store = store_factory() assert store.last() is None store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) assert store.last() == SampleData("c", 3.0) - def test_start_end_ts(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_start_end_ts(self, store_factory, store_name): + store = store_factory() assert store.start_ts is None assert store.end_ts is None store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) assert store.start_ts == 1.0 assert store.end_ts == 3.0 - def test_time_range(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_time_range(self, store_factory, store_name): + store = store_factory() assert store.time_range() is None store.save(SampleData("a", 1.0), SampleData("b", 5.0)) assert store.time_range() == (1.0, 5.0) - def test_duration(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_duration(self, store_factory, store_name): + store = store_factory() assert store.duration() == 0.0 store.save(SampleData("a", 1.0), SampleData("b", 5.0)) assert store.duration() == 4.0 - def test_find_before(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_find_before(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) assert store.find_before(0.5) is None @@ -436,8 +366,8 @@ def test_find_before(self, store_factory, store_name, temp_dir): assert store.find_before(2.5) == SampleData("b", 2.0) assert store.find_before(10.0) == SampleData("c", 3.0) - def test_find_after(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_find_after(self, store_factory, store_name): + store = store_factory() store.save(SampleData("a", 1.0), SampleData("b", 2.0), SampleData("c", 3.0)) assert store.find_after(0.5) == SampleData("a", 1.0) @@ -446,8 +376,8 @@ def test_find_after(self, store_factory, store_name, temp_dir): assert store.find_after(3.0) is None # strictly after assert store.find_after(10.0) is None - def test_slice_by_time(self, store_factory, store_name, temp_dir): - store = store_factory(temp_dir) + def test_slice_by_time(self, store_factory, store_name): + store = store_factory() store.save( SampleData("a", 1.0), SampleData("b", 2.0),