Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,9 @@ jobs:
conda install -y -q pip wheel
pip install uv

# Need cassandra-driver from conda-forge where it has a patch for 3.13
- name: Install dependencies
shell: bash -l {0}
run: |
conda install -y -q cassandra-driver
uv pip install -r requirements.txt

# We have two cores so we can speed up the testing with xdist
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ jobs:
uses: lsst/rubin_workflows/.github/workflows/mypy.yaml@main
with:
folders: "python tests"
mypy_package: "mypy<1.20"
mypy_package: "mypy<2.0"
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-yaml
args:
Expand All @@ -10,7 +10,7 @@ repos:
- id: check-toml
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.10
rev: v0.14.5
hooks:
- id: ruff
args: [--fix]
Expand Down
11 changes: 11 additions & 0 deletions python/lsst/dax/apdb/cassandra/apdbCassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def init_database(
replica_skips_diaobjects: bool = False,
port: int | None = None,
username: str | None = None,
dbauth_alias: str | None = None,
prefix: str | None = None,
part_pixelization: str | None = None,
part_pix_level: int | None = None,
Expand Down Expand Up @@ -220,6 +221,14 @@ def init_database(
Port number to use for Cassandra connections.
username : `str`, optional
User name for Cassandra connections.
dbauth_alias : `str`, optional
If specified then this string will be used to as a host name when
checking credentials in db-auth.yaml in addition to regular host
names in contact_points. For example if
dbauth_alias='pp_apdb_prod_cluster' then the entry
'cassandra://pp_apdb_prod_cluster/' will match. Port number should
not be used in that entry. Alias has higher priority than host
names.
Comment thread
andy-slac marked this conversation as resolved.
prefix : `str`, optional
Optional prefix for all table names.
part_pixelization : `str`, optional
Expand Down Expand Up @@ -289,6 +298,8 @@ def init_database(
config.connection_config.port = port
if username is not None:
config.connection_config.username = username
if dbauth_alias is not None:
config.connection_config.dbauth_alias = dbauth_alias
if prefix is not None:
config.prefix = prefix
if part_pixelization is not None:
Expand Down
10 changes: 10 additions & 0 deletions python/lsst/dax/apdb/cassandra/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ class ApdbCassandraConnectionConfig(BaseModel):
),
)

dbauth_alias: str = Field(
default="",
description=(
"If specified then this string will be used to as a host name when checking credentials in "
"db-auth.yaml in addition to regular host names in contact_points. For example if "
"dbauth_alias='pp_apdb_prod_cluster' then the entry 'cassandra://pp_apdb_prod_cluster/' will "
"match. Port number should not be used in that entry. Alias has higher priority than host names."
),
)

read_consistency: str = Field(
default="QUORUM",
description="Name for consistency level of read operations, default: QUORUM, can be ONE.",
Expand Down
11 changes: 9 additions & 2 deletions python/lsst/dax/apdb/cassandra/sessionFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,22 @@ def _make_auth_provider(self) -> AuthProvider | None:
# Credentials file doesn't exist, use anonymous login.
return None

# If dbauth_alias is defined then try it too without port number.
hosts: list[tuple[str, int | None]] = [
(hostname, self._config.connection_config.port) for hostname in self._config.contact_points
]
if self._config.connection_config.dbauth_alias:
hosts = [(self._config.connection_config.dbauth_alias, None)] + hosts

empty_username = True
# Try every contact point in turn.
for hostname in self._config.contact_points:
for hostname, port in hosts:
try:
username, password = dbauth.getAuth(
"cassandra",
self._config.connection_config.username,
hostname,
self._config.connection_config.port,
port,
self._config.keyspace,
)
if not username:
Expand Down
1 change: 1 addition & 0 deletions python/lsst/dax/apdb/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def cassandra_config_options(parser: argparse.ArgumentParser) -> None:
_option_from_pydantic_field(group, ApdbCassandraConfig, "replica_skips_diaobjects", action="store_true")
_option_from_pydantic_field(group, ApdbCassandraConnectionConfig, "port", metavar="PORT")
_option_from_pydantic_field(group, ApdbCassandraConnectionConfig, "username", metavar="USER")
_option_from_pydantic_field(group, ApdbCassandraConnectionConfig, "dbauth_alias", metavar="NAME")
_option_from_pydantic_field(group, ApdbCassandraConfig, "prefix")
group.add_argument(
"--replication-factor", help="Replication factor used when creating new keyspace.", type=int
Expand Down
24 changes: 24 additions & 0 deletions python/lsst/dax/apdb/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@

from __future__ import annotations

__all__ = ["TestCaseMixin", "modified_environment"]

import contextlib
import os
import unittest
from collections.abc import Iterator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -33,3 +38,22 @@ class TestCaseMixin(unittest.TestCase):

class TestCaseMixin:
"""Do-nothing definition of mixin base class for regular execution."""


# Stolen from daf_butler
@contextlib.contextmanager
def modified_environment(**environ: str) -> Iterator[None]:
"""Temporarily set environment variables.

Parameters
----------
**environ : `dict`
Key value pairs of environment variables to temporarily set.
"""
old_environ = dict(os.environ)
os.environ.update(environ)
try:
yield
finally:
os.environ.clear()
os.environ.update(old_environ)
79 changes: 79 additions & 0 deletions tests/test_apdbCassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@
)
from lsst.dax.apdb.cassandra import ApdbCassandra, ApdbCassandraConfig
from lsst.dax.apdb.cassandra.connectionContext import ConnectionContext
from lsst.dax.apdb.cassandra.sessionFactory import SessionFactory
from lsst.dax.apdb.pixelization import Pixelization
from lsst.dax.apdb.tests import ApdbSchemaUpdateTest, ApdbTest, cassandra_mixin
from lsst.dax.apdb.tests.data_factory import makeObjectCatalog
from lsst.dax.apdb.tests.utils import modified_environment

TEST_SCHEMA = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config/schema-apdb.yaml")
TEST_SCHEMA_SSO = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config/schema-sso.yaml")
Expand Down Expand Up @@ -217,6 +219,83 @@ def test_version_check(self) -> None:
Apdb.from_config(self.config).metadata.items()


_DB_AUTH_JSON = """\
[{
"url": "cassandra://user1000@node1.slac.stanford.edu:9042/",
"username": "user01",
"password": "pass01"
}, {
"url": "cassandra://node2.slac.stanford.edu:9042/",
"username": "user02",
"password": "pass02"
}, {
"url": "cassandra://node1.slac.stanford.edu:9042/apdb_dev",
"username": "user03",
"password": "pass03"
}, {
"url": "cassandra://user2000@test_cluster/",
"username": "user04",
"password": "pass04"
}, {
"url": "cassandra://test_cluster/",
"username": "user05",
"password": "pass05"
}]
"""


class ApdbCassandraDbAuthTest(unittest.TestCase):
"""A test case for extracting credentials from db-auth.yaml."""

def _make_config(self) -> ApdbCassandraConfig:
config = ApdbCassandraConfig(
contact_points=("node1.slac.stanford.edu", "node2.slac.stanford.edu"),
keyspace="apdb",
)
return config

@unittest.skipIf(not cassandra_mixin.CASSANDRA_IMPORTED, "cassandra_driver cannot be imported")
def test_dbauth(self) -> None:
"""Check credentials access."""
with modified_environment(LSST_DB_AUTH_CREDENTIALS=_DB_AUTH_JSON):
config = self._make_config()

factory = SessionFactory(config)

# Should match second entry.
auth = factory._make_auth_provider()
assert auth is not None
self.assertEqual(auth.username, "user02")

config.keyspace = "apdb_dev"
# Should match third entry.
auth = factory._make_auth_provider()
assert auth is not None
self.assertEqual(auth.username, "user03")

config.connection_config.username = "user1000"
# Should match first entry, returns original user name.
auth = factory._make_auth_provider()
assert auth is not None
self.assertEqual(auth.username, "user1000")
self.assertEqual(auth.password, "pass01")

config.connection_config.username = ""
config.connection_config.dbauth_alias = "test_cluster"
# Should match fifth entry.
auth = factory._make_auth_provider()
assert auth is not None
self.assertEqual(auth.username, "user05")

config.connection_config.username = "user2000"
config.connection_config.dbauth_alias = "test_cluster"
# Should match fourth entry.
auth = factory._make_auth_provider()
assert auth is not None
self.assertEqual(auth.username, "user2000")
self.assertEqual(auth.password, "pass04")


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
"""Run file leak tests."""

Expand Down
Loading