diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0e6507e..6e1b395 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -46,11 +46,9 @@ jobs: conda install -y -q pip wheel pip install uv - # Need cassandra-driver from conda-forge where it has a patch for 3.13 - name: Install dependencies shell: bash -l {0} run: | - conda install -y -q cassandra-driver uv pip install -r requirements.txt # We have two cores so we can speed up the testing with xdist diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 3add4dc..f55df9d 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -11,4 +11,4 @@ jobs: uses: lsst/rubin_workflows/.github/workflows/mypy.yaml@main with: folders: "python tests" - mypy_package: "mypy<1.20" + mypy_package: "mypy<2.0" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 855447c..5836f7a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: check-yaml args: @@ -10,7 +10,7 @@ repos: - id: check-toml - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.9.10 + rev: v0.14.5 hooks: - id: ruff args: [--fix] diff --git a/python/lsst/dax/apdb/cassandra/apdbCassandra.py b/python/lsst/dax/apdb/cassandra/apdbCassandra.py index 0c0e0ba..7cc2555 100644 --- a/python/lsst/dax/apdb/cassandra/apdbCassandra.py +++ b/python/lsst/dax/apdb/cassandra/apdbCassandra.py @@ -178,6 +178,7 @@ def init_database( replica_skips_diaobjects: bool = False, port: int | None = None, username: str | None = None, + dbauth_alias: str | None = None, prefix: str | None = None, part_pixelization: str | None = None, part_pix_level: int | None = None, @@ -220,6 +221,14 @@ def init_database( Port number to use for Cassandra connections. username : `str`, optional User name for Cassandra connections. 
+ dbauth_alias : `str`, optional + If specified then this string will be used as a host name when + checking credentials in db-auth.yaml in addition to regular host + names in contact_points. For example if + dbauth_alias='pp_apdb_prod_cluster' then the entry + 'cassandra://pp_apdb_prod_cluster/' will match. Port number should + not be used in that entry. Alias has higher priority than host + names. prefix : `str`, optional Optional prefix for all table names. part_pixelization : `str`, optional @@ -289,6 +298,8 @@ def init_database( config.connection_config.port = port if username is not None: config.connection_config.username = username + if dbauth_alias is not None: + config.connection_config.dbauth_alias = dbauth_alias if prefix is not None: config.prefix = prefix if part_pixelization is not None: diff --git a/python/lsst/dax/apdb/cassandra/config.py b/python/lsst/dax/apdb/cassandra/config.py index 0df973d..9749a7a 100644 --- a/python/lsst/dax/apdb/cassandra/config.py +++ b/python/lsst/dax/apdb/cassandra/config.py @@ -69,6 +69,16 @@ class ApdbCassandraConnectionConfig(BaseModel): ), ) + dbauth_alias: str = Field( + default="", + description=( + "If specified then this string will be used as a host name when checking credentials in " + "db-auth.yaml in addition to regular host names in contact_points. For example if " + "dbauth_alias='pp_apdb_prod_cluster' then the entry 'cassandra://pp_apdb_prod_cluster/' will " + "match. Port number should not be used in that entry. Alias has higher priority than host names." 
+ ), + ) + read_consistency: str = Field( default="QUORUM", description="Name for consistency level of read operations, default: QUORUM, can be ONE.", diff --git a/python/lsst/dax/apdb/cassandra/sessionFactory.py b/python/lsst/dax/apdb/cassandra/sessionFactory.py index 6d9f3cd..df6e3fb 100644 --- a/python/lsst/dax/apdb/cassandra/sessionFactory.py +++ b/python/lsst/dax/apdb/cassandra/sessionFactory.py @@ -160,15 +160,22 @@ def _make_auth_provider(self) -> AuthProvider | None: # Credentials file doesn't exist, use anonymous login. return None + # If dbauth_alias is defined then try it too without port number. + hosts: list[tuple[str, int | None]] = [ + (hostname, self._config.connection_config.port) for hostname in self._config.contact_points + ] + if self._config.connection_config.dbauth_alias: + hosts = [(self._config.connection_config.dbauth_alias, None)] + hosts + empty_username = True # Try every contact point in turn. - for hostname in self._config.contact_points: + for hostname, port in hosts: try: username, password = dbauth.getAuth( "cassandra", self._config.connection_config.username, hostname, - self._config.connection_config.port, + port, self._config.keyspace, ) if not username: diff --git a/python/lsst/dax/apdb/cli/options.py b/python/lsst/dax/apdb/cli/options.py index e14df66..14e1400 100644 --- a/python/lsst/dax/apdb/cli/options.py +++ b/python/lsst/dax/apdb/cli/options.py @@ -102,6 +102,7 @@ def cassandra_config_options(parser: argparse.ArgumentParser) -> None: _option_from_pydantic_field(group, ApdbCassandraConfig, "replica_skips_diaobjects", action="store_true") _option_from_pydantic_field(group, ApdbCassandraConnectionConfig, "port", metavar="PORT") _option_from_pydantic_field(group, ApdbCassandraConnectionConfig, "username", metavar="USER") + _option_from_pydantic_field(group, ApdbCassandraConnectionConfig, "dbauth_alias", metavar="NAME") _option_from_pydantic_field(group, ApdbCassandraConfig, "prefix") group.add_argument( "--replication-factor", 
help="Replication factor used when creating new keyspace.", type=int diff --git a/python/lsst/dax/apdb/tests/utils.py b/python/lsst/dax/apdb/tests/utils.py index c5de1fd..0b29541 100644 --- a/python/lsst/dax/apdb/tests/utils.py +++ b/python/lsst/dax/apdb/tests/utils.py @@ -21,7 +21,12 @@ from __future__ import annotations +__all__ = ["TestCaseMixin", "modified_environment"] + +import contextlib +import os import unittest +from collections.abc import Iterator from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -33,3 +38,22 @@ class TestCaseMixin(unittest.TestCase): class TestCaseMixin: """Do-nothing definition of mixin base class for regular execution.""" + + +# Stolen from daf_butler +@contextlib.contextmanager +def modified_environment(**environ: str) -> Iterator[None]: + """Temporarily set environment variables. + + Parameters + ---------- + **environ : `dict` + Key value pairs of environment variables to temporarily set. + """ + old_environ = dict(os.environ) + os.environ.update(environ) + try: + yield + finally: + os.environ.clear() + os.environ.update(old_environ) diff --git a/tests/test_apdbCassandra.py b/tests/test_apdbCassandra.py index 0e4b69b..38000cf 100644 --- a/tests/test_apdbCassandra.py +++ b/tests/test_apdbCassandra.py @@ -51,9 +51,11 @@ ) from lsst.dax.apdb.cassandra import ApdbCassandra, ApdbCassandraConfig from lsst.dax.apdb.cassandra.connectionContext import ConnectionContext +from lsst.dax.apdb.cassandra.sessionFactory import SessionFactory from lsst.dax.apdb.pixelization import Pixelization from lsst.dax.apdb.tests import ApdbSchemaUpdateTest, ApdbTest, cassandra_mixin from lsst.dax.apdb.tests.data_factory import makeObjectCatalog +from lsst.dax.apdb.tests.utils import modified_environment TEST_SCHEMA = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config/schema-apdb.yaml") TEST_SCHEMA_SSO = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config/schema-sso.yaml") @@ -217,6 +219,83 @@ def test_version_check(self) -> 
None: Apdb.from_config(self.config).metadata.items() +_DB_AUTH_JSON = """\ +[{ + "url": "cassandra://user1000@node1.slac.stanford.edu:9042/", + "username": "user01", + "password": "pass01" +}, { + "url": "cassandra://node2.slac.stanford.edu:9042/", + "username": "user02", + "password": "pass02" +}, { + "url": "cassandra://node1.slac.stanford.edu:9042/apdb_dev", + "username": "user03", + "password": "pass03" +}, { + "url": "cassandra://user2000@test_cluster/", + "username": "user04", + "password": "pass04" +}, { + "url": "cassandra://test_cluster/", + "username": "user05", + "password": "pass05" +}] +""" + + +class ApdbCassandraDbAuthTest(unittest.TestCase): + """A test case for extracting credentials from db-auth.yaml.""" + + def _make_config(self) -> ApdbCassandraConfig: + config = ApdbCassandraConfig( + contact_points=("node1.slac.stanford.edu", "node2.slac.stanford.edu"), + keyspace="apdb", + ) + return config + + @unittest.skipIf(not cassandra_mixin.CASSANDRA_IMPORTED, "cassandra_driver cannot be imported") + def test_dbauth(self) -> None: + """Check credentials access.""" + with modified_environment(LSST_DB_AUTH_CREDENTIALS=_DB_AUTH_JSON): + config = self._make_config() + + factory = SessionFactory(config) + + # Should match second entry. + auth = factory._make_auth_provider() + assert auth is not None + self.assertEqual(auth.username, "user02") + + config.keyspace = "apdb_dev" + # Should match third entry. + auth = factory._make_auth_provider() + assert auth is not None + self.assertEqual(auth.username, "user03") + + config.connection_config.username = "user1000" + # Should match first entry, returns original user name. + auth = factory._make_auth_provider() + assert auth is not None + self.assertEqual(auth.username, "user1000") + self.assertEqual(auth.password, "pass01") + + config.connection_config.username = "" + config.connection_config.dbauth_alias = "test_cluster" + # Should match fifth entry. 
+ auth = factory._make_auth_provider() + assert auth is not None + self.assertEqual(auth.username, "user05") + + config.connection_config.username = "user2000" + config.connection_config.dbauth_alias = "test_cluster" + # Should match fourth entry. + auth = factory._make_auth_provider() + assert auth is not None + self.assertEqual(auth.username, "user2000") + self.assertEqual(auth.password, "pass04") + + class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase): """Run file leak tests."""