diff --git a/fix/rewrite_sta_ssp_ids.py b/fix/rewrite_sta_ssp_ids.py new file mode 100644 index 0000000..c3d0a5e --- /dev/null +++ b/fix/rewrite_sta_ssp_ids.py @@ -0,0 +1,106 @@ +""" +Fix ScheduledStopPoint IDs to match SIRI feed format. + +Transforms IDs from: IT:ITH1:ScheduledStopPoint:it-22021-7010-51-32073: +to: IT:ITH10:ScheduledStopPoint:7010:51:32073 + +This is needed so that NeTEx and SIRI feeds reference the same stops. + +It is a mystery to me why this cannot be fixed at the source. +""" + +import dataclasses +import logging +import re +from collections.abc import Generator +from pathlib import Path +from typing import Any + +from domain.netex.model import ( + PassengerStopAssignment, + Route, + RoutePoint, + RoutePointRef, + ScheduledStopPoint, + ScheduledStopPointRef, + ServiceJourneyPattern, + ServiceLink, + TimingLink, +) +from domain.netex.services.recursive_attributes import recursive_attributes +from storage.mdbx.core.implementation import MdbxStorage +from utils.aux_logging import log_all, prepare_logger + +_PATTERN = re.compile(r'^.*:ScheduledStopPoint:it-22021-(.+):$') + + +def _new_id(old_id: str) -> str | None: + m = _PATTERN.match(old_id) + if m: + return 'IT:ITH10:ScheduledStopPoint:' + m.group(1).replace('-', ':') + return None + + +# Object types that may transitively contain ScheduledStopPointRef or RoutePointRef. +_REF_BEARING_TYPES = [ + ServiceJourneyPattern, + ServiceLink, + TimingLink, + PassengerStopAssignment, + Route, + RoutePoint, +] + + +def _iter_updated_objects( + db: MdbxStorage, txn: Any, +) -> Generator[Any, None, None]: + for cls in _REF_BEARING_TYPES: + for obj in db.iter_only_objects(txn, cls): + changed = False + for ref, _path in recursive_attributes(obj, []): + if isinstance(ref, (ScheduledStopPointRef, RoutePointRef)): + new_ref = _new_id(ref.ref) + if new_ref is not None: + ref.ref = new_ref + changed = True + if changed: + yield obj + + +def _iter_renamed_ssps( + db: MdbxStorage, txn: Any, +) -> Generator[ScheduledStopPoint, None, None]: + for ssp in db.iter_only_objects(txn, ScheduledStopPoint): + new_id = _new_id(ssp.id) + if new_id is not None: + yield dataclasses.replace(ssp, id=new_id) + + +def fix_ssp_ids(database: Path) -> None: + with MdbxStorage(database, readonly=False) as db: + with db.env.rw_transaction() as txn: + # TODO: delete the old ScheduledStopPoint objects (no delete API available yet) + db.insert_any_object_on_queue(txn, _iter_updated_objects(db, txn)) + db.insert_any_object_on_queue(txn, _iter_renamed_ssps(db, txn)) + txn.commit() + + +def main(source_database_file: str) -> None: + fix_ssp_ids(Path(source_database_file)) + + +if __name__ == "__main__": + import argparse + import traceback + + parser = argparse.ArgumentParser(description="Fix ScheduledStopPoint IDs to SIRI format") + parser.add_argument("source", type=str, help="mdbx file to fix in-place") + parser.add_argument("--log_file", type=str, required=False, help="log file path") + args = parser.parse_args() + prepare_logger(logging.INFO, args.log_file) + try: + main(args.source) + except Exception as e: + log_all(logging.ERROR, f"{e}") + raise e diff --git a/test/rewrite_sta_ssp_ids.py b/test/rewrite_sta_ssp_ids.py new file mode 100644 index 0000000..b482790 --- /dev/null +++ b/test/rewrite_sta_ssp_ids.py @@ -0,0 +1,10 @@ +import unittest + +from fix.rewrite_sta_ssp_ids import main + +class FixSSPTestCase(unittest.TestCase): + def test(self): + main("sta.lmdb") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file