DM-54408: Add support for sources withdrawal to Apdb (#137)
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #137 +/- ##
==========================================
+ Coverage 84.46% 84.79% +0.33%
==========================================
Files 73 73
Lines 7304 7601 +297
Branches 844 894 +50
==========================================
+ Hits 6169 6445 +276
- Misses 899 905 +6
- Partials 236 251 +15
☔ View full report in Codecov by Sentry.
There was a problem hiding this comment.
Pull request overview
This PR adds “withdrawal” support for DIA sources/forced sources in APDB, including schema updates (new withdrawal timestamp columns) and backend implementations (SQL + Cassandra), along with new unit tests to validate the behavior.
Changes:
- Add `time_withdrawn`/`timeWithdrawnMjdTai` columns to DiaSource and DiaForcedSource schemas.
- Introduce new APDB API methods `withdrawDiaSources()` and `withdrawDiaForcedSources()` and implement them in SQL and Cassandra backends (including replication update records and `nDiaSources` handling).
- Add test coverage for withdrawing sources and forced sources, and update schema column-count assertions.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 12 comments.
Show a summary per file
| File | Description |
|---|---|
| `tests/test_apdbSqlSchema.py` | Updates expected column counts for updated SQL schemas. |
| `tests/config/schema-datetime.yaml` | Adds nullable `time_withdrawn` timestamp columns for DiaSource/DiaForcedSource. |
| `tests/config/schema-apdb.yaml` | Adds nullable `timeWithdrawnMjdTai` columns for DiaSource/DiaForcedSource. |
| `tests/config/schema-apdb+sso.yaml` | Adds nullable `timeWithdrawnMjdTai` columns for DiaSource/DiaForcedSource in the SSO-combined schema. |
| `python/lsst/dax/apdb/tests/_apdb.py` | Adds unit tests for withdrawing sources/forced sources; updates schema column-count expectations. |
| `python/lsst/dax/apdb/sql/apdbSql.py` | Implements withdrawal methods; refactors `setValidityEnd` to share a connection; updates replication storage call signature. |
| `python/lsst/dax/apdb/cassandra/apdbCassandra.py` | Implements withdrawal methods; adds helper query method for forced sources; adds replication update record emission. |
| `python/lsst/dax/apdb/apdb.py` | Extends the abstract APDB interface with withdrawal methods and docstrings. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| config = context.config | ||
|
|
||
| if timeWithdrawn is None: | ||
| timeWithdrawn = astropy.time.Time.now() |
| statements.append(context.stmt_factory.with_params(query.where(clause), prepare=True)) | ||
|
|
||
| with self._timer( | ||
| "select_time", tags={"table": "DiaForcedSource", "method": "_get_diasource_data"} |
| # Check that there are 12 update records in replica tables. | ||
| assert replica_chunks is not None | ||
|
|
||
| # There could be one or two chunks. | ||
| self.assertTrue(1 <= len(replica_chunks) <= 2) | ||
|
|
||
| update_records = apdb_replica.getUpdateRecordChunks([chunk.id for chunk in replica_chunks]) | ||
| self.assertEqual(len(update_records), 12) | ||
|
|
| # Check that there are 2 update records in replica tables. | ||
| assert replica_chunks is not None | ||
|
|
||
| # There could be one or two chunks. | ||
| self.assertTrue(1 <= len(replica_chunks) <= 2) | ||
|
|
||
| update_records = apdb_replica.getUpdateRecordChunks([chunk.id for chunk in replica_chunks]) | ||
| self.assertEqual(len(update_records), 2) |
| # Set time_withdrawn for sources. | ||
| table = self._schema.get_table(ApdbTables.DiaSource) | ||
| where = table.columns["diaSourceId"].in_(sorted(source_ids)) | ||
| update = table.update().where(where).values({column_name: time_value}) | ||
| conn.execute(update) |
| source_ids = {source.diaSourceId for source in diaSourceIds} | ||
|
|
||
| if timeWithdrawn is None: | ||
| timeWithdrawn = astropy.time.Time.now() |
| Parameters | ||
| ---------- | ||
| diaForcedSourceIds : `~collections.abc.Iterable` [`DiaForcedSourceId`] | ||
| Identifiers of DiaForcedSources to withdraw. | ||
| timeWithdrawn : `astropy.time.Time`, optional | ||
| Set the value of ``time_withdrawn`` column to this time, current | ||
| time by default. | ||
|
|
| config = context.config | ||
|
|
||
| if timeWithdrawn is None: | ||
| timeWithdrawn = astropy.time.Time.now() |
| # Find all DiaSources. | ||
| found_sources = self._get_diasource_data( | ||
| diaSourceIds, "apdb_part", "diaObjectId", "ra", "dec", "midpointMjdTai" | ||
| ) | ||
|
|
||
| if missing_ids := (source_ids - {row.diaSourceId for row in found_sources}): | ||
| raise LookupError(f"Some source IDs were not found in DiaSource table: {missing_ids}") | ||
|
|
||
| found_sources_by_id = {row.diaSourceId: row for row in found_sources} | ||
| original_object_ids = { | ||
| row.diaSourceId: row.diaObjectId for row in found_sources if row.diaObjectId is not None | ||
| } | ||
|
|
||
| update_records: list[ApdbUpdateRecord] = [] | ||
| update_order = 0 | ||
| current_time = self._current_time() | ||
| current_time_ns = int(current_time.unix_tai * 1e9) | ||
|
|
||
| # Update DiaSources. | ||
| statements: list[tuple] = [] | ||
| for source_id in diaSourceIds: | ||
| source_row = found_sources_by_id[source_id.diaSourceId] | ||
| apdb_part = source_row.apdb_part | ||
| time_part = context.partitioner.time_partition(source_row.midpointMjdTai) | ||
|
|
||
| if config.partitioning.time_partition_tables: | ||
| table_name = context.schema.tableName(ApdbTables.DiaSource, time_part) | ||
| update = ( | ||
| Update(self._keyspace, table_name) | ||
| .values(C(column_name).update(time_value)) | ||
| .where(C("apdb_part") == apdb_part) | ||
| .where(C("diaSourceId") == source_id.diaSourceId) | ||
| ) | ||
| else: | ||
| table_name = context.schema.tableName(ApdbTables.DiaSource) | ||
| update = ( | ||
| Update(self._keyspace, table_name) | ||
| .values(C(column_name).update(time_value)) | ||
| .where(C("apdb_part") == apdb_part) | ||
| .where(C("apdb_time_part") == time_part) | ||
| .where(C("diaSourceId") == source_id.diaSourceId) | ||
| ) | ||
| statements.append(context.stmt_factory.with_params(update, prepare=True)) | ||
|
|
||
| if context.schema.replication_enabled: | ||
| update_records.append( | ||
| ApdbWithdrawDiaSourceRecord( | ||
| diaSourceId=source_id.diaSourceId, | ||
| ra=source_id.ra, | ||
| dec=source_id.dec, | ||
| midpointMjdTai=source_id.midpointMjdTai, | ||
| update_time_ns=current_time_ns, | ||
| update_order=update_order, | ||
| timeWithdrawnMjdTai=float(timeWithdrawn.tai.mjd), | ||
| ) | ||
| ) | ||
| update_order += 1 | ||
|
|
||
| with self._timer("update_time", tags={"table": "DiaSource", "method": "withdrawDiaSources"}) as timer: | ||
| execute_concurrent(context.session, statements, execution_profile="write") | ||
| timer.add_values(num_queries=len(statements)) | ||
|
|
||
| if update_records: | ||
| replica_chunk = ReplicaChunk.make_replica_chunk(current_time, config.replica_chunk_seconds) | ||
| self._storeUpdateRecords(update_records, replica_chunk, store_chunk=True) | ||
|
|
||
| if decrement_nDiaSources: |
| found_fsources = self._get_diaforcedsource_data( | ||
| diaForcedSourceIds, "apdb_part", "ra", "dec", "midpointMjdTai" | ||
| ) | ||
|
|
||
| found_keys = {(row.diaObjectId, row.visit, row.detector) for row in found_fsources} | ||
| if missing_ids := (fsource_keys - found_keys): | ||
| raise LookupError(f"Some source IDs were not found in DiaForcedSource table: {missing_ids}") | ||
|
|
||
| statements: list[tuple] = [] | ||
| update_records = [] | ||
| update_order = 0 | ||
| current_time = self._current_time() | ||
| current_time_ns = int(current_time.unix_tai * 1e9) | ||
|
|
||
| for source_row in found_fsources: | ||
| apdb_part = source_row.apdb_part | ||
| time_part = context.partitioner.time_partition(source_row.midpointMjdTai) | ||
|
|
||
| if config.partitioning.time_partition_tables: | ||
| table_name = context.schema.tableName(ApdbTables.DiaForcedSource, time_part) | ||
| update = ( | ||
| Update(self._keyspace, table_name) | ||
| .values(C(column_name).update(time_value)) | ||
| .where(C("apdb_part") == apdb_part) | ||
| .where(C("diaObjectId") == source_row.diaObjectId) | ||
| .where(C("visit") == source_row.visit) | ||
| .where(C("detector") == source_row.detector) | ||
| ) | ||
| else: | ||
| table_name = context.schema.tableName(ApdbTables.DiaForcedSource) | ||
| update = ( | ||
| Update(self._keyspace, table_name) | ||
| .values(C(column_name).update(time_value)) | ||
| .where(C("apdb_part") == apdb_part) | ||
| .where(C("apdb_time_part") == time_part) | ||
| .where(C("diaObjectId") == source_row.diaObjectId) | ||
| .where(C("visit") == source_row.visit) | ||
| .where(C("detector") == source_row.detector) | ||
| ) | ||
| statements.append(context.stmt_factory.with_params(update, prepare=True)) | ||
|
|
||
| if context.schema.replication_enabled: | ||
| update_records.append( | ||
| ApdbWithdrawDiaForcedSourceRecord( | ||
| diaObjectId=source_row.diaObjectId, | ||
| visit=source_row.visit, | ||
| detector=source_row.detector, | ||
| ra=source_row.ra, | ||
| dec=source_row.dec, | ||
| midpointMjdTai=source_row.midpointMjdTai, | ||
| update_time_ns=current_time_ns, | ||
| update_order=update_order, | ||
| timeWithdrawnMjdTai=float(timeWithdrawn.tai.mjd), | ||
| ) | ||
| ) | ||
| update_order += 1 | ||
|
|
No description provided.