Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
name: CI

on:
push:
pull_request:

# Cancel superseded runs on the same ref so rapid pushes don't pile up.
concurrency:
group: ci-${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
lint:
runs-on: ubuntu-latest
# `push` fires on every branch. To avoid a duplicate run in the base repo for same-repo branches, only
# let the `pull_request` event add a run when the PR comes from a fork.
if: >-
github.event_name != 'pull_request' ||
github.event.pull_request.head.repo.full_name != github.repository
steps:
- uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1

- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
cache-suffix: lint
python-version: "3.12"

- run: uv sync

# The full repository has never been gated on complete ruff/black/mypy runs, so these report on
# pre-existing debt without blocking. Flip to blocking once clean.
- name: ruff (informational)
run: uv run ruff check --output-format=concise . | tee ruff.txt
continue-on-error: true

- name: black (informational)
run: uv run black --check . 2>&1 | tee black.txt
continue-on-error: true

- name: mypy (informational)
run: uv run mypy . | tee mypy.txt
continue-on-error: true

- name: Lint summary
if: always()
shell: bash
run: |
{
echo "## Lint (informational)"
echo ""
for tool in ruff black mypy; do
echo "<details><summary>$tool</summary>"
echo ""
echo '```'
cat "$tool.txt" 2>/dev/null || echo "(no output)"
echo '```'
echo ""
echo "</details>"
echo ""
done
} >> "$GITHUB_STEP_SUMMARY"

test:
runs-on: ubuntu-latest
timeout-minutes: 60
if: >-
github.event_name != 'pull_request' ||
github.event.pull_request.head.repo.full_name != github.repository
steps:
- uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1

- name: Compute pinned schema commit
id: schema
run: echo "sha=$(git ls-tree HEAD domain/netex/schema --object-only)" >> "$GITHUB_OUTPUT"

- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
cache-suffix: test
python-version: "3.12"

- run: uv sync

- name: Cache generated NeTEx model
id: model-cache
uses: actions/cache@caa296126883cff596d87d8935842f9db880ef25 # v5.1.0
with:
path: domain/netex/model
key: netex-model-${{ steps.schema.outputs.sha }}-${{ hashFiles('domain/netex/conf/xsdata.conf', 'uv.lock') }}-py3.12

# scripts/generate-schema.sh is deliberately not used here: its
# `git submodule update --remote` floats the submodule to the branch tip,
# which would make the build non-reproducible.
- name: Generate NeTEx model
if: steps.model-cache.outputs.cache-hit != 'true'
run: |
git submodule update --init domain/netex/schema
uv run xsdata generate -c domain/netex/conf/xsdata.conf domain/netex/schema/xsd/NeTEx_publication.xsd
uv run python -m compileall -q domain/netex/model

- name: Run tests
run: uv run python scripts/run_tests.py
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ domain/netex/model/**
*.txt

*.jar
uv.lock
local_configuration.py
tools/results.xlsx
tools/~$results.xlsx
Expand Down
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ sh scripts/setup.sh
```
*For Microsoft Windows users, see the commands in the [shell file](scripts/setup.sh).*

### Testing

```sh
uv run python -m unittest discover -s tests -t .
```

This runs the test suite in `tests/`; it only uses temporary databases created on the fly.

The NeTEx model must have been generated before the tests can import `domain.netex.model`
(`sh scripts/generate-schema.sh`).

### Update schemas
```sh
sh scripts/generate-schema.sh
Expand Down
20 changes: 6 additions & 14 deletions domain/netex/services/recursive_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import inspect

from utils.mro_attributes import list_attributes
from utils import netex_monkeypatching
from utils.mro_attributes import resolve_class, unembed
from utils import netex_monkeypatching # noqa: F401

from dataclasses import fields, MISSING

Expand Down Expand Up @@ -43,18 +43,10 @@ def _all_subclasses(cls: type[Any]) -> set[type[Any]]:

def get_all_geo_elements() -> Generator[Any, None, None]:
for clazz_parent in get_boring_classes():
attrs = list_attributes(clazz_parent)
for attr in attrs:
clazz = attr[3].type
if clazz is not None and hasattr(clazz, '_name'):
if (clazz._name == 'Optional' or clazz._name == 'Union') and not isinstance(clazz, str):
clazz_resolved = [x for x in clazz.__args__ if x is not None][0]
else:
clazz_resolved = clazz

if clazz_resolved in GEO_CLASSES:
yield clazz_parent
break
for _name, _field, field_type in unembed(clazz_parent):
if resolve_class(field_type) in GEO_CLASSES:
yield clazz_parent
break


netex.set_geo_types = frozenset(get_all_geo_elements()) # type: ignore[attr-defined]
Expand Down
File renamed without changes.
104 changes: 104 additions & 0 deletions fix/trenitalia/add_train_numbers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
Add TrainNumbers to the Trenitalia feed using ServiceJourney/Name

This fix gives every Trenitalia journey a real TrainNumber: it reads each ServiceJourney's
number from its name, creates one first-class TrainNumber per distinct (deduplicated) number,
and points the journey at it via `trainNumbers`. This aligns with the Austrian and Swiss datasets.

Usage:
uv run python -m fix.trenitalia.add_train_numbers path/to/it.lmdb
"""

import logging
import re
from pathlib import Path
from typing import Optional

from domain.netex.model import (
ServiceJourney,
TextType,
TrainNumber,
TrainNumberRef,
TrainNumberRefsRelStructure,
)
from storage.mdbx.core.implementation import MdbxStorage
from storage.mdbx.core.references import resolve_embeddings
from utils.aux_logging import log_all, prepare_logger

_DIGITS = re.compile(r"\d+")


def normalize_train_number(num: Optional[str]) -> Optional[str]:
if num is None:
return None
n = num.strip()
m = _DIGITS.search(n)
return str(int(m.group())) if m else None


def train_number_for_journey(sj: ServiceJourney) -> Optional[str]:
if sj.name is not None and sj.name.content:
first = sj.name.content[0]
text = first.value if isinstance(first, TextType) else str(first)
return normalize_train_number(text)
return None


def _tn_id(num: str) -> str:
return f"IT:TrainNumber:{num}"


def add_train_numbers(db: MdbxStorage) -> None:
# Read pass: journeys that don't yet carry a train number
with db.env.ro_transaction() as rtx:
todo: list[tuple[ServiceJourney, str]] = []
for sj in db.iter_only_objects(rtx, ServiceJourney):
if sj.train_numbers is not None:
continue
num = train_number_for_journey(sj)
if not num:
continue
todo.append((sj, num))

# One first-class TrainNumber per distinct number (many journeys share a number).
train_numbers: dict[str, TrainNumber] = {}
for sj, num in todo:
if num not in train_numbers:
train_numbers[num] = TrainNumber(id=_tn_id(num), version=sj.version, for_advertisement=num)

log_all(logging.INFO, f"[add_train_numbers] {len(todo)} Trenitalia journeys -> " f"{len(train_numbers)} distinct train numbers")

with db.env.rw_transaction() as wtx:
db.insert_any_object_on_queue(wtx, train_numbers.values())
for sj, num in todo:
train_number = train_numbers[num]
sj.train_numbers = TrainNumberRefsRelStructure(train_number_ref=[TrainNumberRef(ref=train_number.id, version=train_number.version)])
db.insert_any_object_on_queue(wtx, [sj for sj, _ in todo])
wtx.commit()

resolve_embeddings(db)
log_all(logging.INFO, f"[add_train_numbers] done: added {len(train_numbers)} TrainNumbers, " f"linked {len(todo)} Trenitalia journeys")


def main(source_database_file: str) -> None:
with MdbxStorage(Path(source_database_file), readonly=False) as db:
return add_train_numbers(db)


if __name__ == "__main__":
import argparse
import traceback

parser = argparse.ArgumentParser(
description="Add NeTEx TrainNumber objects to the Trenitalia journeys (read each "
"journey's number from its name) so it matches the Austrian/Swiss feeds."
)
parser.add_argument("source", type=str, help="Trenitalia NeTEx mdbx file")
parser.add_argument("--log_file", type=str, required=False, help="the logfile")
args = parser.parse_args()
prepare_logger(logging.INFO, args.log_file)
try:
main(args.source)
except Exception as e:
log_all(logging.ERROR, f"{e} {traceback.format_exc()}")
raise e
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ dependencies = [
]
requires-python = ">=3.12"

[dependency-groups]
dev = [
"mypy",
"black",
"ruff",
]

[tool.black]
extend-exclude = "domain/netex/model/*.py"
preview = true
Expand All @@ -45,4 +52,4 @@ exclude = ["domain/netex/model/*.py"]

[[tool.mypy.overrides]]
module = ["domain.netex.model.*"]
ignore_errors = true
ignore_errors = true
100 changes: 100 additions & 0 deletions scripts/run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python
"""Run the unittest suite and emit a GitHub-flavored Markdown results table.

Run from the repository root: ``uv run python scripts/run_tests.py``.
"""

import os
import time
import unittest
from collections import OrderedDict
from typing import Any


class _Tally:
__slots__ = ("tests", "passed", "failed", "skipped", "seconds")

def __init__(self) -> None:
self.tests = 0
self.passed = 0
self.failed = 0 # failures + errors
self.skipped = 0
self.seconds = 0.0


class MarkdownResult(unittest.TextTestResult):
"""A TextTestResult that additionally tallies outcomes per TestCase class."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._starts: dict[unittest.TestCase, float] = {}
self.suites: "OrderedDict[str, _Tally]" = OrderedDict()

def _bucket(self, test: unittest.TestCase) -> _Tally:
key = f"{type(test).__module__}.{type(test).__qualname__}"
return self.suites.setdefault(key, _Tally())

def startTest(self, test: unittest.TestCase) -> None:
super().startTest(test)
self._starts[test] = time.perf_counter()

def stopTest(self, test: unittest.TestCase) -> None:
super().stopTest(test)
bucket = self._bucket(test)
bucket.tests += 1
bucket.seconds += time.perf_counter() - self._starts.pop(test, time.perf_counter())

def addSuccess(self, test: unittest.TestCase) -> None:
super().addSuccess(test)
self._bucket(test).passed += 1

def addFailure(self, test: unittest.TestCase, err: Any) -> None:
super().addFailure(test, err)
self._bucket(test).failed += 1

def addError(self, test: unittest.TestCase, err: Any) -> None:
super().addError(test, err)
self._bucket(test).failed += 1

def addSkip(self, test: unittest.TestCase, reason: str) -> None:
super().addSkip(test, reason)
self._bucket(test).skipped += 1


def _render(result: MarkdownResult) -> str:
lines = [
"## Test results",
"",
"| Suite | Tests | ✅ | ❌ | ⎭ | Time |",
"|-------|------:|--:|--:|--:|-----:|",
]
total = _Tally()
for name, bucket in result.suites.items():
total.tests += bucket.tests
total.passed += bucket.passed
total.failed += bucket.failed
total.skipped += bucket.skipped
total.seconds += bucket.seconds
lines.append(f"| {name} | {bucket.tests} | {bucket.passed} | {bucket.failed} | {bucket.skipped} | {bucket.seconds:.2f}s |")
lines.append(f"| **TOTAL** | **{total.tests}** | **{total.passed}** | **{total.failed}** | **{total.skipped}** | **{total.seconds:.2f}s** |")
return "\n".join(lines) + "\n"


def main() -> int:
suite = unittest.TestLoader().discover(start_dir="tests", pattern="test*.py", top_level_dir=".")
result = unittest.TextTestRunner(resultclass=MarkdownResult, verbosity=2).run(suite)
assert isinstance(result, MarkdownResult)

table = _render(result)
summary_path = os.environ.get("GITHUB_STEP_SUMMARY")
if summary_path:
with open(summary_path, "a", encoding="utf-8") as handle:
handle.write(table)
else:
print("\n" + table)

return 0 if result.wasSuccessful() else 1


if __name__ == "__main__":
raise SystemExit(main())
Loading