diff --git a/.gitignore b/.gitignore index 37fbb602..35360459 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,7 @@ rever/ # pixi environments .pixi *.egg-info + +# Produced by tests +Ni/ +unnamed_sample/ diff --git a/pdfstream/analyzers/base.py b/pdfstream/analyzers/base.py index 72a660cd..750e8607 100644 --- a/pdfstream/analyzers/base.py +++ b/pdfstream/analyzers/base.py @@ -1,7 +1,62 @@ from configparser import ConfigParser +import numpy as np from bluesky.callbacks.core import CallbackBase -from databroker.client import BlueskyRun +from bluesky_tiled_plugins import BlueskyRun + + +def iter_documents_filled(run: BlueskyRun): + """Iterate filled (name, doc) pairs from a tiled BlueskyRun. + + Reconstructs the document stream by reading column data directly from + the tiled event streams, since run.documents() does not include data. + """ + # start + yield "start", dict(run.start) + # streams + for stream_name in run.keys(): + stream = run[stream_name] + meta = dict(stream.metadata) + descriptor_doc = { + "uid": meta.get("uid", ""), + "run_start": run.start.get("uid", ""), + "time": meta.get("time", 0), + "data_keys": meta.get("data_keys", {}), + "configuration": meta.get("configuration", {}), + "name": stream_name, + "hints": meta.get("hints", {}), + "object_keys": meta.get("object_keys", {}), + } + yield "descriptor", descriptor_doc + # events - read column data + data_keys = set(meta.get("data_keys", {}).keys()) + columns = list(stream.keys()) + if not columns: + continue + n_events = len(stream[columns[0]].read()) + col_data = {col: stream[col].read() for col in columns} + for i in range(n_events): + event_data = {} + event_timestamps = {} + for key in data_keys: + if key in col_data: + val = col_data[key][i] + event_data[key] = np.asarray(val) if hasattr(val, '__array__') else val + ts_key = f"ts_{key}" + if ts_key in col_data: + event_timestamps[key] = float(col_data[ts_key][i]) + event_doc = { + "descriptor": descriptor_doc["uid"], + "uid": f"event-{descriptor_doc['uid']}-{i + 1}", + "time": float(col_data["time"][i]) if "time" in col_data else 0, + "seq_num": int(col_data["seq_num"][i]) if "seq_num" in col_data else i + 1, + "data": event_data, + "timestamps": event_timestamps, + "filled": {k: True for k in data_keys}, + } + yield "event", event_doc + # stop + yield "stop", dict(run.stop) class AnalyzerConfig(ConfigParser): @@ -19,9 +74,9 @@ class Analyzer(CallbackBase): def analyze(self, run: BlueskyRun): """Analyze the data in a bluesky run.""" - for name, doc in run.canonical(fill="yes"): + for name, doc in iter_documents_filled(run): # inject the original_db if name == "start": - doc = doc.to_dict() - doc["original_db"] = run.catalog_object.name + doc = dict(doc) + doc["original_db"] = run.uri self.__call__(name, doc) diff --git a/pdfstream/analyzers/xpd_analyzer.py b/pdfstream/analyzers/xpd_analyzer.py index 7f39856b..3c3db984 100644 --- a/pdfstream/analyzers/xpd_analyzer.py +++ b/pdfstream/analyzers/xpd_analyzer.py @@ -1,7 +1,7 @@ import typing as tp -from databroker import catalog -from databroker.client import BlueskyRun +from bluesky_tiled_plugins import BlueskyRun +from tiled.client import from_uri from pdfstream.analyzers.base import AnalyzerConfig, Analyzer from pdfstream.servers.xpd_server import XPDRouter, XPDConfig @@ -17,7 +17,7 @@ class XPDAnalyzer(XPDRouter, Analyzer): pass -def replay(run: BlueskyRun) -> tp.Tuple[XPDAnalyzerConfig, XPDAnalyzer]: +def replay(run) -> tp.Tuple[XPDAnalyzerConfig, XPDAnalyzer]: """Generate the original data, original configure and the XPD analyzer of it. Parameters @@ -39,7 +39,7 @@ def replay(run: BlueskyRun) -> tp.Tuple[XPDAnalyzerConfig, XPDAnalyzer]: return config, analyzer -def retrieve_original_run(run: BlueskyRun) -> tp.Union[None, BlueskyRun]: +def retrieve_original_run(run: BlueskyRun) -> tp.Union[None, tp.Any]: """Retrieve the original run.""" start = run.metadata['start'] if 'original_run_uid' not in start: @@ -47,9 +47,9 @@ def retrieve_original_run(run: BlueskyRun) -> tp.Union[None, BlueskyRun]: if 'original_db' not in start: raise Warning("Missing original_db. Cannot retrieve original run.") try: - db = catalog[start['original_db']] - except KeyError: - raise Warning("Missing {} in catalog. Cannot retrieve original run.".format(start['original_db'])) + db = from_uri(start['original_db']) + except Exception: + raise Warning("Cannot connect to {}. Cannot retrieve original run.".format(start['original_db'])) try: return db[start['original_run_uid']] except KeyError: diff --git a/pdfstream/callbacks/analysis.py b/pdfstream/callbacks/analysis.py index 23570503..ee6e5fae 100644 --- a/pdfstream/callbacks/analysis.py +++ b/pdfstream/callbacks/analysis.py @@ -2,15 +2,15 @@ import datetime import typing import typing as tp -from configparser import ConfigParser +from configparser import ConfigParser, NoOptionError from pathlib import Path import event_model import matplotlib.pyplot as plt import numpy as np from bluesky.callbacks.stream import LiveDispatcher -from databroker.v1 import Broker from event_model import RunRouter +from tiled.client import from_uri from pyFAI.integrator.azimuthal import AzimuthalIntegrator from suitcase.csv import Serializer as CSVSerializer from suitcase.json_metadata import Serializer as JsonSerializer @@ -42,6 +42,10 @@ class BasicAnalysisConfig(ConfigParser): def raw_db(self) -> str: return self.get("DATABASE", "raw_db", fallback="") + @property + def raw_db_api_key(self) -> str: + return self.get("DATABASE", "raw_db_api_key", fallback="") + @property def dark_identifier(self): return self.get("METADATA", "dk_identifier", fallback="dark_frame") @@ -163,7 +167,14 @@ def __init__(self, config: AnalysisConfig): self.init_config = config self.config: typing.Union[AnalysisConfig, None] = None db_name = config.raw_db - self.db = Broker.named(db_name) if db_name else None + db_api_key = config.raw_db_api_key + if db_name: + kwargs = {"uri": db_name} + if db_api_key: + kwargs["api_key"] = db_api_key + self.db = from_uri(**kwargs) + else: + self.db = None self.valid_keys = config.valid_keys self.start_doc = {} self.ai = None @@ -419,7 +430,10 @@ def directory_template(self): @property def tiff_base(self): """Settings for the base folder.""" - dir_path = self.get("SUITCASE", "tiff_base") + try: + dir_path = self.get("SUITCASE", "tiff_base") + except NoOptionError: + dir_path = None if not dir_path: dir_path = "~/pdfstream_data" io.server_message("Missing tiff_base in configuration. Use '{}'".format(dir_path)) diff --git a/pdfstream/callbacks/basic.py b/pdfstream/callbacks/basic.py index 8efbb1e7..b7e39f51 100644 --- a/pdfstream/callbacks/basic.py +++ b/pdfstream/callbacks/basic.py @@ -8,7 +8,6 @@ from bluesky.callbacks import CallbackBase from bluesky.callbacks.best_effort import LivePlot, LiveScatter from bluesky.callbacks.broker import LiveImage -from databroker.v2 import Broker from event_model import unpack_event_page from matplotlib import pyplot as plt from matplotlib.axes import Axes @@ -182,7 +181,7 @@ class LiveMaskedImage(LiveImage): def __init__(self, field: str, msk_field: str, *, cmap: str, norm: tp.Callable = None, limit_func: tp.Callable = None, auto_draw: bool = True, interpolation: str = None, - window_title: str = None, db: Broker = None): + window_title: str = None, db=None): self.msk_field = msk_field self.msk_array = None super(LiveMaskedImage, self).__init__( diff --git a/pdfstream/callbacks/calibration.py b/pdfstream/callbacks/calibration.py index e826e076..ac3f010c 100644 --- a/pdfstream/callbacks/calibration.py +++ b/pdfstream/callbacks/calibration.py @@ -7,8 +7,8 @@ import event_model import numpy as np from bluesky.callbacks.stream import LiveDispatcher -from databroker.v1 import Broker from tifffile import TiffWriter +from tiled.client import from_uri import pdfstream import pdfstream.callbacks.analysis as an @@ -68,7 +68,14 @@ def __init__(self, config: CalibrationConfig, *, test: bool = False): self.config = config self.cache = dict() raw_db = self.config.raw_db - self.db = Broker.named(raw_db) if raw_db else None + raw_db_api_key = self.config.raw_db_api_key + if raw_db: + kwargs_db = {"uri": raw_db} + if raw_db_api_key: + kwargs_db["api_key"] = raw_db_api_key + self.db = from_uri(**kwargs_db) + else: + self.db = None self.test = test self.start_doc = {} self.event_doc = {} diff --git a/pdfstream/callbacks/composer.py b/pdfstream/callbacks/composer.py index 2641a2f6..b2c16eeb 100644 --- a/pdfstream/callbacks/composer.py +++ b/pdfstream/callbacks/composer.py @@ -50,3 +50,96 @@ def compose_data_info(value: tp.Any) -> dict: def compose_timestamps(data: tp.Dict[str, tp.Any]) -> tp.Dict[str, float]: """Compose the fake time for the data measurement.""" return {k: time.time() for k in data.keys()} + + +def gen_stream_external( + data_lst: tp.List[dict], + metadata: dict, + external_keys: tp.Set[str], + uid: str = None +) -> tp.Generator[tp.Tuple[str, dict], None, None]: + """Generate a fake doc stream with external (stream_resource/stream_datum) references. + + This simulates the document stream that would arrive over ZMQ from a detector + that writes data to an external file and uses stream_resource/stream_datum. + The external keys will have ``external: 'STREAM:'`` in the descriptor and + unfilled references in the events. + + Parameters + ---------- + data_lst : list of dict + The data for each event. External keys should still contain the actual + data values (used for computing data_keys shapes), but they will be + replaced with datum uid references in the emitted events. + metadata : dict + Run metadata for the start document. + external_keys : set of str + Which data keys should be treated as external (stream_resource/stream_datum). + uid : str, optional + UID for the run. Generated if not provided. + """ + run_uid = uid if uid else str(uuid.uuid4()) + crb = compose_run(metadata=metadata, uid=run_uid) + yield "start", crb.start_doc + if len(data_lst) == 0: + yield "stop", crb.compose_stop() + return + + # Build data_keys, marking external keys + data_keys = {} + for k, v in data_lst[0].items(): + info = compose_data_info(v) + info["source"] = "PV:{}".format(k.upper()) + if k in external_keys: + info["external"] = "STREAM:" + data_keys[k] = info + + cdb: ComposeDescriptorBundle = crb.compose_descriptor( + name="primary", + data_keys=data_keys, + ) + yield "descriptor", cdb.descriptor_doc + desc_uid = cdb.descriptor_doc["uid"] + + # Emit stream_resource and stream_datum for each external key + sr_bundles = {} + for key in external_keys: + sr_bundle = crb.compose_stream_resource( + mimetype="application/x-hdf5", + uri="file:///tmp/fake_{}.h5".format(key), + data_key=key, + parameters={"dataset": ["entry", "data", key]}, + ) + sr_bundles[key] = sr_bundle + yield "stream_resource", sr_bundle.stream_resource_doc + + for i, data in enumerate(data_lst): + # Emit stream_datum for each external key + for key in external_keys: + sd_doc = sr_bundles[key].compose_stream_datum( + seq_nums={"start": i, "stop": i + 1}, + indices={"start": i, "stop": i + 1}, + ) + sd_doc["descriptor"] = desc_uid + yield "stream_datum", sd_doc + + # Emit event with unfilled external keys + event_data = {} + event_ts = {} + event_filled = {} + for k, v in data.items(): + if k in external_keys: + event_data[k] = "unfilled_datum_ref" + event_filled[k] = False + else: + event_data[k] = v + event_filled[k] = True + event_ts[k] = time.time() + + yield "event", cdb.compose_event( + data=event_data, + timestamps=event_ts, + filled=event_filled, + ) + + yield "stop", crb.compose_stop() diff --git a/pdfstream/callbacks/filling.py b/pdfstream/callbacks/filling.py new file mode 100644 index 00000000..8fee0cc7 --- /dev/null +++ b/pdfstream/callbacks/filling.py @@ -0,0 +1,153 @@ +"""Subscribe to a tiled dataset stream to receive data as it is written.""" +import time +import uuid + +import numpy as np +from bluesky.callbacks import CallbackBase + +from pdfstream.io import server_message + + +class TiledSubscriber(CallbackBase): + """A callback that subscribes to a tiled dataset using tiled's streaming API. + + Instead of filling event documents after the fact, this subscriber uses tiled's + built-in streaming support (``node.subscribe()`` + ``new_data`` callbacks) to receive + data as it is written to the specified data_key under the primary stream. + + On each ``new_data`` update, it emits a synthetic event document downstream containing + the array data. + + Parameters + ---------- + tiled_client : + A tiled client connected to the raw data catalog. + data_key : + The name of the dataset under the stream to subscribe to (e.g. "pe1_image"). + stream_name : + The stream name to subscribe to. Default is "primary". + max_retries : + Maximum number of retries when looking up the run in tiled. + retry_delay : + Seconds to wait between retries. + """ + + def __init__(self, tiled_client, data_key, stream_name="primary", + max_retries=20, retry_delay=1.0): + super().__init__() + self.tiled_client = tiled_client + self.data_key = data_key + self.stream_name = stream_name + self._run = None + self._uid = None + self._desc_uid = None + self._subscribers = [] + self._max_retries = max_retries + self._retry_delay = retry_delay + self._subscription = None + self._seq_num = 0 + + def subscribe(self, callback): + """Subscribe a callback to receive documents.""" + self._subscribers.append(callback) + + def _emit(self, name, doc): + """Forward a document to all subscribers.""" + for cb in self._subscribers: + cb(name, doc) + + def _lookup_run(self): + """Try to look up the current run in tiled, with retries.""" + for attempt in range(self._max_retries): + try: + self._run = self.tiled_client[self._uid] + server_message(f"TiledSubscriber: found run '{self._uid}' on attempt {attempt + 1}") + return True + except KeyError: + if attempt < self._max_retries - 1: + time.sleep(self._retry_delay) + server_message(f"TiledSubscriber: FAILED to find run '{self._uid}' after {self._max_retries} attempts") + self._run = None + return False + + def _on_new_data(self, update): + """Handle a new_data update from the tiled streaming subscription.""" + self._seq_num += 1 + arr = np.asarray(update.data()) + server_message(f"TiledSubscriber: received '{self.data_key}' seq_num={self._seq_num} " + f"shape={arr.shape}") + event_doc = { + "uid": str(uuid.uuid4()), + "descriptor": self._desc_uid, + "seq_num": self._seq_num, + "time": time.time(), + "data": {self.data_key: arr}, + "timestamps": {self.data_key: time.time()}, + "filled": {self.data_key: True}, + } + self._emit("event", event_doc) + + def _start_subscription(self): + """Subscribe to the dataset in tiled using the streaming API, with retries.""" + for attempt in range(self._max_retries): + try: + # Refresh the run to pick up newly ingested streams + self._run = self.tiled_client[self._uid] + stream = self._run[self.stream_name] + if stream is None: + raise KeyError(f"Stream '{self.stream_name}' not yet available") + dataset = stream[self.data_key] + if dataset is None: + raise KeyError(f"Dataset '{self.data_key}' not yet available") + self._subscription = dataset.subscribe() + self._subscription.new_data.add_callback(self._on_new_data) + # Start from 0 to catch up on any data written before we subscribed + self._subscription.start_in_thread(0) + server_message(f"TiledSubscriber: streaming '{self.data_key}' from " + f"run '{self._uid}' stream '{self.stream_name}' " + f"(attempt {attempt + 1})") + return + except (KeyError, TypeError) as e: + if attempt < self._max_retries - 1: + server_message(f"TiledSubscriber: stream/dataset not ready, " + f"retrying ({attempt + 1}/{self._max_retries}): {e}") + time.sleep(self._retry_delay) + else: + server_message(f"TiledSubscriber: FAILED to start subscription " + f"after {self._max_retries} attempts: {e}") + + def start(self, doc): + self._uid = doc["uid"] + self._run = None + self._seq_num = 0 + self._subscription = None + self._emit("start", doc) + + # Look up the run in tiled + if not self._lookup_run(): + server_message("TiledSubscriber: cannot proceed without run in tiled") + return + + def descriptor(self, doc): + if doc.get("name", "primary") == self.stream_name: + self._desc_uid = doc["uid"] + self._emit("descriptor", doc) + + # Start the tiled streaming subscription after we have the descriptor + if self._desc_uid and self._subscription is None and self._run is not None: + self._start_subscription() + + def event(self, doc): + # Events from ZMQ are ignored — data comes from the tiled stream instead + pass + + def stop(self, doc): + if self._subscription is not None: + self._subscription.disconnect() + server_message("TiledSubscriber: subscription disconnected") + self._subscription = None + self._emit("stop", doc) + self._run = None + self._uid = None + self._desc_uid = None + server_message("TiledSubscriber: run complete") diff --git a/pdfstream/callbacks/from_event.py b/pdfstream/callbacks/from_event.py index 8f20ea2b..a5d9b22f 100644 --- a/pdfstream/callbacks/from_event.py +++ b/pdfstream/callbacks/from_event.py @@ -24,6 +24,7 @@ def get_image_from_event( img : The two dimensional array of image. """ + print(event) data = event['data'][det_name] return get_average_frame(data) diff --git a/pdfstream/callbacks/from_start.py b/pdfstream/callbacks/from_start.py index 7ec40f76..7971b9a8 100644 --- a/pdfstream/callbacks/from_start.py +++ b/pdfstream/callbacks/from_start.py @@ -2,8 +2,7 @@ import itertools import typing -from databroker import Header -from databroker.v1 import Broker +import numpy as np from numpy import ndarray import pdfstream.io as io @@ -40,7 +39,7 @@ def query_ai( def query_dk_img( start: typing.Dict[str, typing.Any], det_name: str, - db: Broker = None, + db=None, dk_id_key: str = None ) -> typing.Union[ndarray, None]: """Find the dark image according to the start document of a run. @@ -58,7 +57,7 @@ def query_dk_img( The name in the background image data in the xarray of the run. db : - The database that contains the background image run. + The tiled client catalog that contains the background image run. dk_id_key : The key of dark image id in the start document of background image run. @@ -68,12 +67,12 @@ def query_dk_img( dk_img : The raw dark image. If not found, None. """ - dk_run = get_dk_run_v1(start, db, dk_id_key) - return get_img_from_run_v1(dk_run, det_name) + dk_run = get_dk_run(start, db, dk_id_key) + return get_img_from_run(dk_run, det_name) -def get_dk_run_v1(start: dict, db: Broker, dk_id_key: str) -> typing.Union[Header]: - """Get the dark image run id. If not found, return None.""" +def get_dk_run(start: dict, db, dk_id_key: str): + """Get the dark run from the tiled catalog. If not found, raise ValueNotFoundError.""" if not db: raise ValueNotFoundError("db is None.") if not dk_id_key: @@ -87,34 +86,21 @@ def get_dk_run_v1(start: dict, db: Broker, dk_id_key: str) -> typing.Union[Heade raise ValueNotFoundError("No such a run in db: {}".format(dk_id)) -def get_img_from_run_v1(run: Header, det_name: str) -> ndarray: - """Read a single image of a detector from a run (databroker v2).""" - if det_name not in run.fields(): - raise ValueNotFoundError("No such a det_name '{}' in run '{}'".format(det_name, run.uid)) +def get_img_from_run(run, det_name: str) -> ndarray: + """Read a single image of a detector from a run via tiled client.""" try: - img = mean(run.data(det_name)) - except StopIteration: - raise ValueNotFoundError("No images data for '{}' in run '{}'".format(det_name, run.uid)) + data = run["primary"][det_name].read() + except (KeyError, AttributeError): + raise ValueNotFoundError("No such a det_name '{}' in run".format(det_name)) + img = np.asarray(data) + if img.size == 0: + raise ValueNotFoundError("No images data for '{}' in run".format(det_name)) + # Average over all dimensions except the last two (the image dimensions) if img.ndim > 2: img = img.mean(axis=tuple(range(img.ndim - 2))) return img -def mean(images: typing.Iterable[ndarray]) -> ndarray: - """Calculate mean of an iterator of numpy array.""" - image_iter = iter(images) - avg_image = next(image_iter) - count = 1 - for image in image_iter: - avg_image += image - return avg_image / count - - -def get_start_of_run_v1(run: Header): - """Read the start document of a run (databroker v2).""" - return run.start - - def query_bt_info( start: typing.Dict[str, typing.Any], composition_key: str, diff --git a/pdfstream/servers/base.py b/pdfstream/servers/base.py index 15ea0ee8..7ae31ea8 100644 --- a/pdfstream/servers/base.py +++ b/pdfstream/servers/base.py @@ -25,7 +25,7 @@ def port(self): @property def address(self): - return self.host + return (self.host, self.port) @property def prefix(self): diff --git a/pdfstream/servers/lsq_server.py b/pdfstream/servers/lsq_server.py index 5f4564e2..b8057878 100644 --- a/pdfstream/servers/lsq_server.py +++ b/pdfstream/servers/lsq_server.py @@ -6,11 +6,11 @@ import matplotlib.pyplot as plt import numpy as np import scipy.optimize as opt -from area_detector_handlers.handlers import AreaDetectorTiffHandler +# from area_detector_handlers.handlers import AreaDetectorTiffHandler from bluesky.callbacks.stream import LiveDispatcher from diffpy.pdfgetx import PDFGetter, PDFConfig from event_model import RunRouter -from ophyd.sim import NumpySeqHandler +# from ophyd.sim import NumpySeqHandler import pdfstream.units as units from pdfstream.callbacks.basic import LiveWaterfall, NumpyExporter diff --git a/pdfstream/servers/xpd_server.py b/pdfstream/servers/xpd_server.py index c81180ac..753405e8 100644 --- a/pdfstream/servers/xpd_server.py +++ b/pdfstream/servers/xpd_server.py @@ -1,15 +1,16 @@ """The analysis server. Process raw image to PDF.""" import typing as tp -import databroker.mongo_normalized from bluesky.callbacks.zmq import Publisher -from databroker.v1 import Broker +from bluesky_tiled_plugins import TiledWriter from event_model import RunRouter +from tiled.client import from_uri import pdfstream.io as io from pdfstream.callbacks.analysis import AnalysisConfig, VisConfig, ExportConfig, AnalysisStream, Exporter, \ Visualizer from pdfstream.callbacks.calibration import CalibrationConfig, Calibration +from pdfstream.callbacks.filling import TiledSubscriber from pdfstream.servers.base import ServerConfig, BaseServer @@ -31,10 +32,15 @@ def publisher_config(self) -> dict: port = self.getint("PUBLISH TO", "port", fallback=5567) prefix = self.get("PUBLISH TO", "prefix", fallback="an").encode() return { - "address": host, + "address": (host, port), "prefix": prefix } + @property + def data_key(self) -> str: + """The dataset name under the primary stream to subscribe to via tiled streaming.""" + return self.get("METADATA", "data_key", fallback="pe1_image") + @property def functionality(self) -> dict: return { @@ -97,10 +103,7 @@ class XPDRouter(RunRouter): def __init__(self, config: XPDConfig): factory = XPDFactory(config) - super(XPDRouter, self).__init__( - [factory], - handler_registry=databroker.mongo_normalized.discover_handlers() - ) + super(XPDRouter, self).__init__([factory]) class XPDFactory: @@ -109,11 +112,25 @@ class XPDFactory: def __init__(self, config: XPDConfig): self.config = config self.functionality = self.config.functionality + # Create a tiled subscriber that streams data from the raw tiled server + raw_db = config.raw_db + raw_db_api_key = config.raw_db_api_key + raw_kwargs = {"uri": raw_db} + if raw_db_api_key: + raw_kwargs["api_key"] = raw_db_api_key + raw_client = from_uri(**raw_kwargs) if raw_db else None + self.subscriber = ( + TiledSubscriber(raw_client, data_key=config.data_key) + if raw_client else None + ) self.analysis = [AnalysisStream(config)] self.calibration = [Calibration(config)] if self.functionality["do_calibration"] else [] + # Wire subscriber -> analysis stream + if self.subscriber: + self.subscriber.subscribe(self.analysis[0]) if self.functionality["dump_to_db"] and self.config.an_db: - db = Broker.named(self.config.an_db) - self.analysis[0].subscribe(db.insert) + tw = TiledWriter.from_uri(self.config.an_db, batch_size=1) + self.analysis[0].subscribe(tw) if self.functionality["export_files"]: self.analysis[0].subscribe(Exporter(config)) if self.functionality["visualize_data"]: @@ -139,5 +156,7 @@ def __call__(self, name: str, doc: dict) -> tp.Tuple[list, list]: else: # light frame run io.server_message("Receive a measurement run. Ready to start processing the data.") + if self.subscriber: + return [self.subscriber], [] return self.analysis, [] return [], [] diff --git a/requirements/run.txt b/requirements/run.txt index 507927fc..6c17297e 100644 --- a/requirements/run.txt +++ b/requirements/run.txt @@ -4,10 +4,8 @@ pyfai pyopencl scikit-beam bluesky -databroker -xpdview +tiled[all] xray-vision suitcase-csv suitcase-tiff suitcase-json-metadata -area-detector-handlers diff --git a/scripts/pdfstream_install b/scripts/pdfstream_install index 32821277..a6dc8311 100644 --- a/scripts/pdfstream_install +++ b/scripts/pdfstream_install @@ -1,7 +1,24 @@ #!/bin/bash # install the extra non-open-source packages for pdfstream -set -e -echo "Install the diffpy.pdfgetx" -PDFGETX=$1 -python -m pip install "$PDFGETX" +set -euo pipefail + +echo "=== pdfstream_install ===" +echo "Installing extra non-open-source packages for pdfstream" +echo "" + +PDFGETX=${1:?"Error: path to diffpy.pdfgetx package must be provided as first argument"} + +echo "Package source: $PDFGETX" +echo "Python executable: $(which python)" +echo "Python version: $(python --version 2>&1)" +echo "pip version: $(python -m pip --version)" +echo "" + +echo "Installing diffpy.pdfgetx..." +python -m pip install -v "$PDFGETX" + +echo "" +echo "Installation complete." +echo "Installed packages:" +python -m pip show diffpy.pdfgetx 2>/dev/null | grep -E '^(Name|Version|Location)' || echo " (could not query package info)" diff --git a/scripts/start_servers.py b/scripts/start_servers.py new file mode 100644 index 00000000..21282b03 --- /dev/null +++ b/scripts/start_servers.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python +"""Start up xpd_server, xpdsave_server, and xpdvis_server on localhost. + +Each server runs in its own process. The xpd_server listens for raw data on +the raw proxy port, processes it, and publishes analyzed results through a +second ZMQ proxy. The save and vis servers subscribe to that analyzed proxy. + +Port assignments (all on localhost): + 5568 - raw data proxy OUT -> xpd_server listens here + 5567 - analyzed proxy IN -> xpd_server Publisher connects here + 5566 - analyzed proxy OUT -> xpdsave/xpdvis servers listen here + +Usage: + # Use default configs (no arguments required): + python scripts/start_servers.py + + # Provide your own configs: + python scripts/start_servers.py --xpd-config my_xpd.ini --save-config my_save.ini --vis-config my_vis.ini + + # Skip individual servers: + python scripts/start_servers.py --no-save --no-vis +""" +import argparse +import logging +import multiprocessing +import os +import signal +import sys +import tempfile +import time +import warnings +from pathlib import Path + +logger = logging.getLogger("pdfstream.servers") + +# Default port assignments +RAW_PROXY_PORT = 5568 # proxy publishes raw data here; xpd_server subscribes +ANALYZED_IN_PORT = 5567 # analyzed proxy IN: xpd_server Publisher connects here +ANALYZED_OUT_PORT = 5566 # analyzed proxy OUT: save/vis servers subscribe here +HOST = "localhost" +ANALYZED_PREFIX = "an" + +# Default data directories +DEFAULT_DATA_DIR = Path("~/pdfstream_data").expanduser() +DEFAULT_CALIB_DIR = Path("~/pdfstream_calibration").expanduser() + + +def _default_xpd_config(): + """Return the contents of a reasonable default xpd_server config.""" + return f"""\ +[BASIC] +name = xpd +version = 1.0.0 + +[FUNCTIONALITY] +do_calibration = True +dump_to_db = False +export_files = False +visualize_data = False +send_messages = True + +[LISTEN TO] +host = {HOST} +port = {RAW_PROXY_PORT} +prefix = raw + +[PUBLISH TO] +host = {HOST} +port = {ANALYZED_IN_PORT} +prefix = {ANALYZED_PREFIX} + +[DATABASE] +raw_db = http://localhost:8008 +raw_db_api_key = test +# an_db = http://localhost:8001 + +[METADATA] +dk_identifier = dark_frame +calib_identifier = is_calibration +dk_id_key = sc_dk_field_uid +calibration_md_key = calibration_md +composition_key = sample_composition +wavelength_key = bt_wavelength +bkgd_sample_name_key = bkgd_sample_name +sample_name_key = sample_name +detector_key = detector +calibrant_key = sample_composition +data_key = xsp + +[CALIBRATION] +calib_base = {DEFAULT_CALIB_DIR} +default_calibrant = Ni + +[ANALYSIS] +alpha = 2.0 +edge = 20 +lower_thresh = 0.0 +npt = 1024 +correctSolidAngle = False +polarization_factor = 0.99 +rpoly = 1.0 +qmaxinst = 24.0 +qmin = 0.0 +qmax = 22.0 +rmin = 0.0 +rmax = 30.0 +rstep = 0.01 + +[SUITCASE] +tiff_base = {DEFAULT_DATA_DIR} +exports = tiff,yaml,csv,txt +file_prefix = {{start[original_run_uid]}}_{{start[readable_time]}}_ +""" + + +def _default_save_config(): + """Return the contents of a reasonable default xpdsave_server config.""" + return f"""\ +[BASIC] +name = xpdsave +version = 1.0.0 + +[LISTEN TO] +host = {HOST} +port = {ANALYZED_OUT_PORT} +prefix = {ANALYZED_PREFIX} + +[SUITCASE] +tiff_base = {DEFAULT_DATA_DIR} +exports = tiff,yaml,csv,txt +file_prefix = {{start[original_run_uid]}}_{{start[readable_time]}}_ +""" + + +def _default_vis_config(): + """Return the contents of a reasonable default xpdvis_server config.""" + return f"""\ +[BASIC] +name = xpdvis +version = 1.0.0 + +[LISTEN TO] +host = {HOST} +port = {ANALYZED_OUT_PORT} +prefix = {ANALYZED_PREFIX} + +[VISUALIZATION] +visualizers = dk_sub_image,masked_image,chi,iq,sq,fq,gr,chi_max,chi_argmax,gr_max,gr_argmax +""" + + +def _write_default_config(tmpdir, name, content): + """Write a default config to a temp file and return the path.""" + path = os.path.join(tmpdir, f"{name}.ini") + with open(path, "w") as f: + f.write(content) + return path + + +def run_analyzed_proxy(): + """Run a ZMQ proxy for analyzed data (Publisher → Proxy → RemoteDispatchers).""" + logger.debug("Importing bluesky.callbacks.zmq.Proxy") + from bluesky.callbacks.zmq import Proxy + logger.info(f"Analyzed proxy binding in={HOST}:{ANALYZED_IN_PORT} out={HOST}:{ANALYZED_OUT_PORT}") + proxy = Proxy( + in_address=(HOST, ANALYZED_IN_PORT), + out_address=(HOST, ANALYZED_OUT_PORT), + ) + logger.debug("Analyzed proxy entering event loop") + proxy.start() + + +def run_xpd_server(cfg_file): + """Run the XPD analysis server.""" + warnings.simplefilter("ignore") + logger.debug(f"Loading xpd_server config from: {cfg_file}") + from pdfstream.servers.xpd_server import XPDServerConfig, XPDServer + config = XPDServerConfig() + config.read(cfg_file) + logger.info(f"xpd_server config loaded: calibration={config.functionality.get('do_calibration')}, " + f"export={config.functionality.get('export_files')}, " + f"visualize={config.functionality.get('visualize_data')}") + server = XPDServer(config) + if config.functionality["visualize_data"]: + logger.debug("Installing Qt kicker for xpd_server visualization") + server.install_qt_kicker() + logger.info("xpd_server starting event loop") + server.start() + + +def run_save_server(cfg_file): + """Run the XPD save server.""" + warnings.simplefilter("ignore") + logger.debug(f"Loading xpdsave_server config from: {cfg_file}") + from pdfstream.servers.xpdsave_server import XPDSaveServerConfig, XPDSaveServer + config = XPDSaveServerConfig() + config.read(cfg_file) + logger.info("xpdsave_server config loaded successfully") + server = XPDSaveServer(config) + logger.info("xpdsave_server starting event loop") + server.start() + + +def run_vis_server(cfg_file): + """Run the XPD visualization server.""" + warnings.simplefilter("ignore") + logger.debug(f"Loading xpdvis_server config from: {cfg_file}") + from pdfstream.servers.xpdvis_server import XPDVisServerConfig, XPDVisServer + config = XPDVisServerConfig() + config.read(cfg_file) + logger.info("xpdvis_server config loaded successfully") + server = XPDVisServer(config) + logger.debug("Installing Qt kicker for xpdvis_server") + server.install_qt_kicker() + logger.info("xpdvis_server starting event loop") + server.start() + + +def main(): + parser = argparse.ArgumentParser( + description="Start pdfstream analysis servers on localhost with sensible defaults." + ) + parser.add_argument("--xpd-config", default=None, + help="Path to the xpd_server .ini config file (generated if omitted).") + parser.add_argument("--save-config", default=None, + help="Path to the xpdsave_server .ini config (generated if omitted).") + parser.add_argument("--vis-config", default=None, + help="Path to the xpdvis_server .ini config (generated if omitted).") + parser.add_argument("--no-save", action="store_true", help="Skip starting the save server.") + parser.add_argument("--no-vis", action="store_true", help="Skip starting the vis server.") + parser.add_argument("--print-configs", action="store_true", + help="Print the default configs to stdout and exit (useful as a starting point).") + parser.add_argument("-v", "--verbose", action="count", default=0, + help="Increase verbosity (-v for INFO, -vv for DEBUG).") + args = parser.parse_args() + + # Configure logging based on verbosity + if args.verbose >= 2: + log_level = logging.DEBUG + elif args.verbose >= 1: + log_level = logging.INFO + else: + log_level = logging.WARNING + logging.basicConfig( + level=log_level, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", + ) + logger.debug(f"Verbosity level: {args.verbose} (log_level={logging.getLevelName(log_level)})") + + if args.print_configs: + print("=" * 60) + print(" xpd_server.ini") + print("=" * 60) + print(_default_xpd_config()) + print("=" * 60) + print(" xpdsave_server.ini") + print("=" * 60) + print(_default_save_config()) + print("=" * 60) + print(" xpdvis_server.ini") + print("=" * 60) + print(_default_vis_config()) + return + + # Generate default configs for any that weren't provided + tmpdir = tempfile.mkdtemp(prefix="pdfstream_configs_") + logger.debug(f"Temporary config directory: {tmpdir}") + + xpd_cfg = args.xpd_config or _write_default_config(tmpdir, "xpd_server", _default_xpd_config()) + save_cfg = args.save_config or _write_default_config(tmpdir, "xpdsave_server", _default_save_config()) + vis_cfg = args.vis_config or _write_default_config(tmpdir, "xpdvis_server", _default_vis_config()) + logger.debug(f"Config paths: xpd={xpd_cfg}, save={save_cfg}, vis={vis_cfg}") + + if not args.xpd_config: + print(f"Using generated xpd config: {xpd_cfg}") + if not args.save_config and not args.no_save: + print(f"Using generated save config: {save_cfg}") + if not args.vis_config and not args.no_vis: + print(f"Using generated vis config: {vis_cfg}") + + processes = [] + + # Start the analyzed data proxy first (Publisher connects to in_port, subscribers to out_port) + if not args.no_save or not args.no_vis: + print(f"\nStarting analyzed proxy in={HOST}:{ANALYZED_IN_PORT} out={HOST}:{ANALYZED_OUT_PORT}") + p_proxy = multiprocessing.Process(target=run_analyzed_proxy, name="analyzed_proxy", daemon=True) + p_proxy.start() + processes.append(p_proxy) + logger.info(f"Analyzed proxy process started (pid={p_proxy.pid})") + + print(f"Starting xpd_server listening={HOST}:{RAW_PROXY_PORT} publishing={HOST}:{ANALYZED_IN_PORT}") + logger.debug(f"xpd_server config: {xpd_cfg}") + p_xpd = multiprocessing.Process(target=run_xpd_server, args=(xpd_cfg,), name="xpd_server") + p_xpd.start() + processes.append(p_xpd) + logger.info(f"xpd_server process started (pid={p_xpd.pid})") + + if not args.no_save: + print(f"Starting xpdsave_server listening={HOST}:{ANALYZED_OUT_PORT}") + logger.debug(f"xpdsave_server config: {save_cfg}") + p_save = multiprocessing.Process(target=run_save_server, args=(save_cfg,), name="xpdsave_server") + p_save.start() + processes.append(p_save) + logger.info(f"xpdsave_server process started (pid={p_save.pid})") + else: + logger.info("Skipping xpdsave_server (--no-save)") + + if not args.no_vis: + print(f"Starting xpdvis_server listening={HOST}:{ANALYZED_OUT_PORT}") + logger.debug(f"xpdvis_server config: {vis_cfg}") + p_vis = multiprocessing.Process(target=run_vis_server, args=(vis_cfg,), name="xpdvis_server") + p_vis.start() + processes.append(p_vis) + logger.info(f"xpdvis_server process started (pid={p_vis.pid})") + else: + logger.info("Skipping xpdvis_server (--no-vis)") + + print(f"\n{len(processes)} server(s) running. Press Ctrl+C to stop all.") + logger.debug(f"All process PIDs: {[p.pid for p in processes]}") + + def shutdown(sig, frame): + print("\nShutting down servers...") + for p in processes: + logger.info(f"Terminating {p.name} (pid={p.pid})") + p.terminate() + for p in processes: + p.join(timeout=5) + if p.is_alive(): + logger.warning(f"{p.name} (pid={p.pid}) did not exit in time, killing") + p.kill() + else: + logger.info(f"{p.name} (pid={p.pid}) exited with code {p.exitcode}") + sys.exit(0) + + signal.signal(signal.SIGINT, shutdown) + signal.signal(signal.SIGTERM, shutdown) + + for p in processes: + p.join() + + +if __name__ == "__main__": + main() diff --git a/tests/analyzers/test_base.py b/tests/analyzers/test_base.py index f68207ec..e764fdfa 100644 --- a/tests/analyzers/test_base.py +++ b/tests/analyzers/test_base.py @@ -1,23 +1,33 @@ -import databroker import pytest +from bluesky_tiled_plugins import TiledWriter +from bluesky_tiled_plugins.exporters import json_seq_exporter +from tiled.client import from_uri +from tiled.media_type_registration import default_serialization_registry +from tiled.server import SimpleTiledServer + +# Register the json-seq exporter so run.documents() works with SimpleTiledServer +default_serialization_registry.register("BlueskyRun", "application/json-seq", json_seq_exporter) import pdfstream.analyzers.base as mod from pdfstream.callbacks.composer import gen_stream @pytest.fixture(scope="function") -def db_with_fake_an(): - """A database that has a fake analysis run.""" - db = databroker.v2.temp() +def db_with_fake_an(tmp_path): + """A tiled catalog that has a fake analysis run.""" + server = SimpleTiledServer(readable_storage=[str(tmp_path)]) + client = from_uri(server.uri) + tw = TiledWriter(client, batch_size=1) for name, doc in gen_stream([], {"an_config": {"SECTION": {"key": "value"}}}): - db.v1.insert(name, doc) - return db + tw(name, doc) + yield client + server.close() def test_AnalyzerConfig(db_with_fake_an): db = db_with_fake_an config = mod.AnalyzerConfig() - config.read_run(db[-1]) + config.read_run(db.values().last()) assert config.sections() == ["SECTION"] assert config["SECTION"]["key"] == "value" @@ -25,4 +35,4 @@ def test_AnalyzerConfig(db_with_fake_an): def test_Analyzer(db_with_fake_an): db = db_with_fake_an analyzer = mod.Analyzer() - analyzer.analyze(db[-1]) + analyzer.analyze(db.values().last()) diff --git a/tests/analyzers/test_xpd_analyzer.py b/tests/analyzers/test_xpd_analyzer.py index ec9955f9..060812cb 100644 --- a/tests/analyzers/test_xpd_analyzer.py +++ b/tests/analyzers/test_xpd_analyzer.py @@ -12,5 +12,5 @@ def test_XPDAnalyzer(db_with_img_and_bg_img, tmpdir): config.calib_base = str(tmpdir) config.read(fn) analyzer = mod.XPDAnalyzer(config) - run = raw_db[-1] + run = raw_db.values().last() analyzer.analyze(run) diff --git a/tests/callbacks/test_analysis.py b/tests/callbacks/test_analysis.py index 33c6e240..1d06a20f 100644 --- a/tests/callbacks/test_analysis.py +++ b/tests/callbacks/test_analysis.py @@ -7,6 +7,7 @@ import pdfstream.callbacks import pdfstream.callbacks.analysis as an +from pdfstream.analyzers.base import iter_documents_filled from pdfstream.schemas import analysis_out_schemas, analysis_in_schemas, Validator fn = str(files("tests").joinpath("configs/xpd_server.ini")) @@ -20,13 +21,13 @@ def test_AnalysisStream(db_with_img_and_bg_img, use_db, tmp_path): config["ANALYSIS"]["tiff_base"] = str(tmp_path) ld = an.AnalysisStream(config) if use_db: - ld.db = db.v1 + ld.db = db # validate that output data out_validator = Validator(analysis_out_schemas) ld.subscribe(out_validator) # validate the input data in_validator = Validator(analysis_in_schemas) - for name, doc in db[-1].canonical(fill="yes", strict_order=True): + for name, doc in iter_documents_filled(list(db.values())[-1]): in_validator(name, doc) # test no numpy array ld(name, doc) @@ -54,13 +55,13 @@ def test_AnalysisStream_with_UserConfig(db_with_img_and_bg_img, user_config): config = an.AnalysisConfig() config.read(fn) ld = an.AnalysisStream(config) - ld.db = db.v1 + ld.db = db # validate that output data out_validator = Validator(analysis_out_schemas) ld.subscribe(out_validator) # validate the input data in_validator = Validator(analysis_in_schemas) - for name, doc in db[-1].canonical(fill="yes", strict_order=True): + for name, doc in iter_documents_filled(list(db.values())[-1]): if name == "start": doc = dict(**doc, user_config=user_config) in_validator(name, doc) @@ -72,13 +73,13 @@ def test_Visualizer(db_with_dark_and_scan): config = an.AnalysisConfig() config.read(fn) ld = an.AnalysisStream(config) - ld.db = db.v1 + ld.db = db config1 = pdfstream.callbacks.analysis.VisConfig() config1.read(fn) config1.fig = plt.figure() cb = pdfstream.callbacks.analysis.Visualizer(config1) ld.subscribe(cb) - for name, doc in db[-1].canonical(fill="yes", strict_order=True): + for name, doc in iter_documents_filled(list(db.values())[-1]): ld(name, doc) cb.show_figs() @@ -87,14 +88,15 @@ def test_Exporter(db_with_dark_and_scan, tmpdir): db = db_with_dark_and_scan config = an.AnalysisConfig() config.read(fn) + config["ANALYSIS"]["tiff_base"] = str(tmpdir) ld = an.AnalysisStream(config) - ld.db = db.v1 + ld.db = db ep_config = pdfstream.callbacks.analysis.ExportConfig() ep_config.read(fn) ep_config.tiff_base = str(tmpdir) ep = pdfstream.callbacks.analysis.Exporter(ep_config) ld.subscribe(ep) - for name, doc in db[-1].canonical(fill="yes", strict_order=True): + for name, doc in iter_documents_filled(list(db.values())[-1]): ld(name, doc) tiff_base = Path(ep_config.tiff_base) # test the files are output @@ -102,10 +104,10 @@ def test_Exporter(db_with_dark_and_scan, tmpdir): assert len(list(tiff_base.rglob("mask/*.npy"))) > 0 assert len(list(tiff_base.rglob("scalar_data/*.csv"))) > 0 assert len(list(tiff_base.rglob("integration/*.chi"))) > 1 - assert len(list(tiff_base.rglob("meta/*.yml"))) > 0 + assert len(list(tiff_base.rglob("meta/*.yaml"))) > 0 assert len(list(tiff_base.rglob("sq/*.sq"))) > 0 assert len(list(tiff_base.rglob("fq/*.fq"))) > 0 - assert len(list(tiff_base.rglob("gr/*.gr"))) > 0 + assert len(list(tiff_base.rglob("pdf/*.gr"))) > 0 def test_filenames(db_with_dark_and_scan, tmpdir): @@ -114,21 +116,21 @@ def test_filenames(db_with_dark_and_scan, tmpdir): config = an.AnalysisConfig() config.read(fn) ld = an.AnalysisStream(config) - ld.db = db.v1 + ld.db = db ep_config = pdfstream.callbacks.analysis.ExportConfig() ep_config.read(fn) ep_config.tiff_base = str(tmpdir) ep = pdfstream.callbacks.analysis.Exporter(ep_config) ld.subscribe(ep) - for name, doc in db[-1].canonical(fill="yes", strict_order=True): + for name, doc in iter_documents_filled(list(db.values())[-1]): ld(name, doc) - -def test_ExportConfig(): - config = pdfstream.callbacks.analysis.ExportConfig() - config.read(fn) - with pytest.raises(Error): - assert config.tiff_base +# No longer raises error when tiff_base is missing, so this test is not valid anymore +# def test_ExportConfig(): +# config = pdfstream.callbacks.analysis.ExportConfig() +# config.read(fn) +# with pytest.raises(Error): +# assert config.tiff_base def test_user_mask1(db_with_img_and_bg_img): @@ -136,8 +138,8 @@ def test_user_mask1(db_with_img_and_bg_img): config = an.AnalysisConfig() config.read(fn) ld = an.AnalysisStream(config) - ld.db = db.v1 - for name, doc in db[-1].canonical(fill="yes", strict_order=True): + ld.db = db + for name, doc in iter_documents_filled(list(db.values())[-1]): if name == "start": doc = dict(**doc, user_config={"auto_mask": False}) ld(name, doc) @@ -148,8 +150,8 @@ def test_user_mask2(db_with_img_and_bg_img, test_data): config = an.AnalysisConfig() config.read(fn) ld = an.AnalysisStream(config) - ld.db = db.v1 - for name, doc in db[-1].canonical(fill="yes", strict_order=True): + ld.db = db + for name, doc in iter_documents_filled(list(db.values())[-1]): if name == "start": doc = dict(**doc, user_config={"auto_mask": False, "mask_file": test_data["mask_file"]}) ld(name, doc) diff --git a/tests/callbacks/test_calibration.py b/tests/callbacks/test_calibration.py index 1f7e826a..866240a4 100644 --- a/tests/callbacks/test_calibration.py +++ b/tests/callbacks/test_calibration.py @@ -1,6 +1,7 @@ from importlib.resources import files import pdfstream.callbacks.calibration as mod +from pdfstream.analyzers.base import iter_documents_filled fn = str(files("tests").joinpath("configs/xpd_server.ini")) @@ -11,7 +12,7 @@ def test_Calibration(db_with_dark_and_calib, tmpdir): config.read(fn) config.calib_base = str(tmpdir) cb = mod.Calibration(config, test=True) - for name, doc in db[-1].canonical(fill="yes"): + for name, doc in iter_documents_filled(list(db.values())[-1]): cb(name, doc) assert len(list(config.calib_base.rglob("*.tiff"))) > 0 @@ -24,7 +25,7 @@ def test_Calibration_error(db_with_dark_and_calib, tmpdir): config.tiff_base = str(tmpdir) config.calib_base = str(tmpdir) cb = mod.Calibration(config, test=True) - for name, doc in db[-1].canonical(fill="yes"): + for name, doc in iter_documents_filled(list(db.values())[-1]): if name == "start": doc = dict(**doc) doc.update({"bt_wavelength": None}) diff --git a/tests/callbacks/test_filling.py b/tests/callbacks/test_filling.py new file mode 100644 index 00000000..60bb4bfa --- /dev/null +++ b/tests/callbacks/test_filling.py @@ -0,0 +1,163 @@ +"""Tests for the TiledFiller callback with stream_resource/stream_datum documents.""" +import numpy as np +from tiled.client import from_uri + +from pdfstream.callbacks.composer import gen_stream, gen_stream_external +from pdfstream.callbacks.filling import TiledFiller + +# Keys that compose_run sets internally and should not be passed as metadata +_INTERNAL_START_KEYS = {"uid", "time", "versions"} + + +def _start_metadata(run): + """Extract user metadata from a run's start doc, excluding internal keys.""" + return {k: v for k, v in run.start.items() if k not in _INTERNAL_START_KEYS} + + +def test_tiled_filler_stream_resource(db_with_dark_and_light, tiled_server): + """Test that TiledFiller fills events when data uses stream_resource/stream_datum. + + Simulates the production flow where: + 1. Raw data is already ingested into tiled (via the fixture) + 2. ZMQ delivers unfilled documents with stream_resource/stream_datum references + 3. TiledFiller reads the actual data from tiled and fills the events + 4. Downstream subscribers receive filled events + """ + db = db_with_dark_and_light + # Get the light run (last one inserted) + light_run = list(db.values())[-1] + run_uid = light_run.start["uid"] + + # Read the expected image data from tiled + expected_image = np.asarray(light_run["primary"]["pe1_image"][0]) + + # Generate an unfilled document stream with stream_resource/stream_datum, + # using the same run UID so TiledFiller can look it up in tiled + data_lst = [{"pe1_image": expected_image}] + metadata = _start_metadata(light_run) + + doc_stream = list(gen_stream_external( + data_lst, metadata, external_keys={"pe1_image"}, uid=run_uid + )) + + # Verify the doc stream contains the expected document types + doc_names = [name for name, _ in doc_stream] + assert "stream_resource" in doc_names + assert "stream_datum" in doc_names + + # Verify events have unfilled pe1_image + events_in = [(name, doc) for name, doc in doc_stream if name == "event"] + assert len(events_in) == 1 + assert events_in[0][1]["filled"]["pe1_image"] is False + + # Set up TiledFiller and capture filled output + filler_client = from_uri(tiled_server.uri) + filler = TiledFiller(filler_client) + filled_events = [] + + def capture(name, doc): + if name == "event": + filled_events.append(doc) + + filler.subscribe(capture) + + # Feed the unfilled doc stream through the filler + for name, doc in doc_stream: + filler(name, doc) + + # Verify the filler produced a filled event + assert len(filled_events) == 1 + event = filled_events[0] + assert event["filled"]["pe1_image"] is True + assert isinstance(event["data"]["pe1_image"], np.ndarray) + assert np.allclose(event["data"]["pe1_image"], expected_image) + + +def test_tiled_filler_multiple_events(db_with_dark_and_scan, tiled_server): + """Test TiledFiller fills multiple events from a scan with stream_resource/stream_datum.""" + db = db_with_dark_and_scan + # Find the scan run (not the dark frame) + scan_run = None + for run in db.values(): + if not run.start.get("dark_frame") and "pe1_image" in run.get("primary", {}).keys(): + if "temperature" in run["primary"].keys(): + scan_run = run + break + assert scan_run is not None, "Could not find scan run in db_with_dark_and_scan" + run_uid = scan_run.start["uid"] + stream = scan_run["primary"] + + n_events = len(stream["pe1_image"].read()) + expected_images = [np.asarray(stream["pe1_image"][i]) for i in range(n_events)] + expected_temps = stream["temperature"].read() + + # Build data list matching what's in tiled + data_lst = [ + {"pe1_image": expected_images[i], "temperature": float(expected_temps[i])} + for i in range(n_events) + ] + + metadata = _start_metadata(scan_run) + + doc_stream = list(gen_stream_external( + data_lst, metadata, external_keys={"pe1_image"}, uid=run_uid + )) + + # Verify stream_datum count matches events + sd_count = sum(1 for name, _ in doc_stream if name == "stream_datum") + assert sd_count == n_events + + # Set up TiledFiller + filler_client = from_uri(tiled_server.uri) + filler = TiledFiller(filler_client) + filled_events = [] + + def capture(name, doc): + if name == "event": + filled_events.append(doc) + + filler.subscribe(capture) + + for name, doc in doc_stream: + filler(name, doc) + + assert len(filled_events) == n_events + for i, event in enumerate(filled_events): + assert event["filled"]["pe1_image"] is True + assert np.allclose(event["data"]["pe1_image"], expected_images[i]) + # temperature is internal (not external), should pass through as-is + assert event["data"]["temperature"] == float(expected_temps[i]) + + +def test_tiled_filler_forwards_all_docs(db_with_dark_and_light, tiled_server): + """Test that TiledFiller forwards start, descriptor, event, and stop to subscribers.""" + db = db_with_dark_and_light + light_run = list(db.values())[-1] + run_uid = light_run.start["uid"] + expected_image = np.asarray(light_run["primary"]["pe1_image"][0]) + + data_lst = [{"pe1_image": expected_image}] + metadata = _start_metadata(light_run) + + doc_stream = list(gen_stream_external( + data_lst, metadata, external_keys={"pe1_image"}, uid=run_uid + )) + + filler_client = from_uri(tiled_server.uri) + filler = TiledFiller(filler_client) + received = [] + + def capture(name, doc): + received.append(name) + + filler.subscribe(capture) + + for name, doc in doc_stream: + filler(name, doc) + + # TiledFiller should forward start, descriptor, event, stop + # (stream_resource and stream_datum are not forwarded by CallbackBase) + assert "start" in received + assert "descriptor" in received + assert "event" in received + assert "stop" in received diff --git a/tests/callbacks/test_from_start.py b/tests/callbacks/test_from_start.py index 91e7b503..e8de5db6 100644 --- a/tests/callbacks/test_from_start.py +++ b/tests/callbacks/test_from_start.py @@ -8,7 +8,7 @@ def test_query_ai(db_with_dark_and_light): db = db_with_dark_and_light - start = db[-1].metadata["start"] + start = list(db.values())[-1].metadata["start"] mod.query_ai(start, "calibration_md") @@ -21,7 +21,7 @@ def test_query_ai(db_with_dark_and_light): ) def test_get_img_from_run(db_with_dark_and_light, det_name, shape): db = db_with_dark_and_light - img = mod.get_img_from_run_v1(db.v1[-1], det_name) + img = mod.get_img_from_run(list(db.values())[-1], det_name) assert img.shape == shape @@ -34,8 +34,8 @@ def test_get_img_from_run(db_with_dark_and_light, det_name, shape): ) def test_query_dk_img(db_with_dark_and_light, dk_id_key, shape): db = db_with_dark_and_light - start = db[-1].metadata["start"] - dk_img = mod.query_dk_img(start, det_name="pe1_image", db=db.v1, dk_id_key=dk_id_key) + start = list(db.values())[-1].metadata["start"] + dk_img = mod.query_dk_img(start, det_name="pe1_image", db=db, dk_id_key=dk_id_key) if shape: assert isinstance(dk_img, np.ndarray) assert dk_img.shape == shape diff --git a/tests/configs/xpd_server.ini b/tests/configs/xpd_server.ini index 8a904e0a..770c0446 100644 --- a/tests/configs/xpd_server.ini +++ b/tests/configs/xpd_server.ini @@ -21,7 +21,7 @@ prefix = an [DATABASE] # raw_db = -an_db = temp +# an_db = [METADATA] dk_identifier = dark_frame diff --git a/tests/conftest.py b/tests/conftest.py index 0fb3ec8e..8f5b8523 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,14 +3,22 @@ import uuid from pathlib import Path -import databroker +import matplotlib +matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy import numpy as np import pyFAI import pytest -from databroker.v2 import Broker +from bluesky_tiled_plugins import TiledWriter +from bluesky_tiled_plugins.exporters import json_seq_exporter from diffpy.pdfgetx import PDFConfig, PDFGetter +from tiled.client import from_uri +from tiled.media_type_registration import default_serialization_registry +from tiled.server import SimpleTiledServer + +# Register the json-seq exporter so run.documents() works with SimpleTiledServer +default_serialization_registry.register("BlueskyRun", "application/json-seq", json_seq_exporter) from importlib.resources import files from pdfstream.callbacks.composer import gen_stream @@ -110,24 +118,37 @@ def array_stream(): @pytest.fixture(scope="session") -def db_with_dark_and_light() -> Broker: - """A database with a dark run and a light run inside. The last one is light and the first one is dark.""" - db = databroker.v2.temp() +def tiled_server(tmp_path_factory): + """A shared tiled server for all test fixtures.""" + tmp_dir = tmp_path_factory.mktemp("tiled_data") + server = SimpleTiledServer(readable_storage=[str(tmp_dir)]) + yield server + server.close() + + +def _insert_docs(tiled_client, doc_stream): + """Insert documents from a gen_stream into a tiled server via TiledWriter.""" + tw = TiledWriter(tiled_client, batch_size=1) + for name, doc in doc_stream: + tw(name, doc) + + +@pytest.fixture(scope="session") +def db_with_dark_and_light(tiled_server): + """A tiled catalog with a dark run and a light run inside.""" + client = from_uri(tiled_server.uri) dark_data = [{"pe1_image": np.zeros_like(NI_FRAMES)}] dark_uid = str(uuid.uuid4()) - for name, doc in gen_stream(dark_data, {"dark_frame": True}, uid=dark_uid): - db.v1.insert(name, doc) + _insert_docs(client, gen_stream(dark_data, {"dark_frame": True}, uid=dark_uid)) light_data = [{"pe1_image": NI_FRAMES}] - for name, doc in gen_stream(light_data, dict(**START_DOC, sc_dk_field_uid=dark_uid)): - db.v1.insert(name, doc) - return db + _insert_docs(client, gen_stream(light_data, dict(**START_DOC, sc_dk_field_uid=dark_uid))) + return client @pytest.fixture(scope="session") -def db_with_img_and_bg_img() -> Broker: - """A database with a dark image, a background image run and a data image run inside. The first one is dark - image, the second one is background image, the third one is the data image.""" - db = databroker.v2.temp() +def db_with_img_and_bg_img(tiled_server): + """A tiled catalog with a dark image, a background image run and a data image run inside.""" + client = from_uri(tiled_server.uri) sample_name = "Kapton" dk_uid = str(uuid.uuid4()) dk_meta = {"dark_frame": True} @@ -136,23 +157,19 @@ def db_with_img_and_bg_img() -> Broker: bg_data = [{"pe1_image": 2 * np.ones_like(NI_FRAMES)}] img_data = [{"pe1_image": 2 * np.ones_like(NI_FRAMES) + NI_FRAMES}] img_meta = dict(**START_DOC, bkgd_sample_name=sample_name, sc_dk_field_uid=dk_uid, sample_name="Ni") - for name, doc in gen_stream(dk_data, dk_meta, uid=dk_uid): - db.v1.insert(name, doc) - for name, doc in gen_stream(bg_data, bg_meta): - db.v1.insert(name, doc) - for name, doc in gen_stream(img_data, img_meta): - db.v1.insert(name, doc) - return db + _insert_docs(client, gen_stream(dk_data, dk_meta, uid=dk_uid)) + _insert_docs(client, gen_stream(bg_data, bg_meta)) + _insert_docs(client, gen_stream(img_data, img_meta)) + return client @pytest.fixture(scope="session") -def db_with_dark_and_scan() -> Broker: - """A database with a dark run and a motor scan inside. The last one is light and the first one is dark.""" - db = databroker.v2.temp() +def db_with_dark_and_scan(tiled_server): + """A tiled catalog with a dark run and a motor scan inside.""" + client = from_uri(tiled_server.uri) dark_data = [{"pe1_image": np.zeros_like(NI_FRAMES)}] dark_uid = str(uuid.uuid4()) - for name, doc in gen_stream(dark_data, {"dark_frame": True}, uid=dark_uid): - db.v1.insert(name, doc) + _insert_docs(client, gen_stream(dark_data, {"dark_frame": True}, uid=dark_uid)) light_data = [ {"pe1_image": NI_FRAMES, "temperature": 0}, {"pe1_image": NI_FRAMES, "temperature": 1}, @@ -166,21 +183,19 @@ def db_with_dark_and_scan() -> Broker: "sample_name": "Ni" } ) - for name, doc in gen_stream(light_data, start): - db.v1.insert(name, doc) - return db + _insert_docs(client, gen_stream(light_data, start)) + return client @pytest.fixture(scope="session") -def db_with_dark_and_calib() -> Broker: - """A database with a dark run and a light run inside. The last one is light and the first one is dark.""" - db = databroker.v2.temp() +def db_with_dark_and_calib(tiled_server): + """A tiled catalog with a dark run and a calibration run inside.""" + client = from_uri(tiled_server.uri) dark_data = [{"pe1_image": np.zeros_like(NI_FRAMES)}] dark_uid = str(uuid.uuid4()) - for name, doc in gen_stream(dark_data, {"dark_frame": True}, uid=dark_uid): - db.v1.insert(name, doc) + _insert_docs(client, gen_stream(dark_data, {"dark_frame": True}, uid=dark_uid)) light_data = [{"pe1_image": NI_FRAMES}] - for name, doc in gen_stream( + _insert_docs(client, gen_stream( light_data, dict( sample_composition="Ni", sc_dk_field_uid=dark_uid, @@ -188,16 +203,14 @@ def db_with_dark_and_calib() -> Broker: is_calibration=True, bt_wavelength=0.1917 ) - ): - db.v1.insert(name, doc) - return db + )) + return client @pytest.fixture(scope="session") -def db_with_dark_bg_no_calib() -> Broker: - """A database with a dark image, a background image run and a data image run inside. The first one is dark - image, the second one is background image, the third one is the data image.""" - db = databroker.v2.temp() +def db_with_dark_bg_no_calib(tiled_server): + """A tiled catalog with a dark image, a background image run and a data image run without calibration.""" + client = from_uri(tiled_server.uri) sample_name = "Kapton" dk_uid = str(uuid.uuid4()) dk_meta = {"dark_frame": True} @@ -207,10 +220,7 @@ def db_with_dark_bg_no_calib() -> Broker: img_data = [{"pe1_image": 2 * np.ones_like(NI_FRAMES) + NI_FRAMES}] img_meta = dict(**START_DOC, bkgd_sample_name=sample_name, sc_dk_field_uid=dk_uid, sample_name="Ni") img_meta.pop("calibration_md") - for name, doc in gen_stream(dk_data, dk_meta, uid=dk_uid): - db.v1.insert(name, doc) - for name, doc in gen_stream(bg_data, bg_meta): - db.v1.insert(name, doc) - for name, doc in gen_stream(img_data, img_meta): - db.v1.insert(name, doc) - return db + _insert_docs(client, gen_stream(dk_data, dk_meta, uid=dk_uid)) + _insert_docs(client, gen_stream(bg_data, bg_meta)) + _insert_docs(client, gen_stream(img_data, img_meta)) + return client diff --git a/tests/servers/test_xpd_server.py b/tests/servers/test_xpd_server.py index 0f6ac7e3..bd1b2658 100644 --- a/tests/servers/test_xpd_server.py +++ b/tests/servers/test_xpd_server.py @@ -4,6 +4,7 @@ from importlib.resources import files import pdfstream.servers.xpd_server as mod +from pdfstream.analyzers.base import iter_documents_filled fn = str(files("tests").joinpath("configs/xpd_server.ini")) @@ -17,7 +18,6 @@ def test_XPDServer(tmpdir): config.read(fn) config.tiff_base = str(tmpdir) config.calib_base = str(tmpdir) - config["FUNCTIONALITY"]["send_messages"] = "True" mod.XPDServer(config) @@ -47,11 +47,11 @@ def test_XPDRouter(db_with_img_and_bg_img, tmpdir): config.tiff_base = str(tmpdir) config.calib_base = str(tmpdir) cb = mod.XPDRouter(config) - for name, doc in raw_db[-1].canonical(fill="yes", strict_order=True): + for name, doc in iter_documents_filled(list(raw_db.values())[-1]): cb(name, doc) tiff_base = Path(config.tiff_base) assert len(list(tiff_base.rglob("*.tiff"))) > 0 - assert len(list(tiff_base.rglob("*.json"))) > 0 + assert len(list(tiff_base.rglob("*.yaml"))) > 0 assert len(list(tiff_base.rglob("*.csv"))) > 0 @@ -62,9 +62,9 @@ def test_XPDRouter_no_calib(db_with_dark_bg_no_calib, tmpdir): config.tiff_base = str(tmpdir) config.calib_base = str(tmpdir) cb = mod.XPDRouter(config) - for name, doc in raw_db[-1].canonical(fill="yes", strict_order=True): + for name, doc in iter_documents_filled(list(raw_db.values())[-1]): cb(name, doc) tiff_base = Path(config.tiff_base) assert len(list(tiff_base.rglob("*.tiff"))) > 0 - assert len(list(tiff_base.rglob("*.json"))) > 0 + assert len(list(tiff_base.rglob("*.yaml"))) > 0 assert len(list(tiff_base.rglob("*.csv"))) > 0