diff --git a/.dockerignore b/.dockerignore
index 574bedd..6925a64 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,36 +1,25 @@
-# Ignore everything (use whitelist approach rather to ensure only important files kept)
+# Ignore everything then whitelist what the image actually needs
*
-# Allow files and directories
+# Package source and build artefacts
+!/decider
!/dist
-!/spockflow
-!/requirements
!/pyproject.toml
+!/uv.lock
-# Ignore unnecessary files inside allowed directories
-# This should go after the allowed directories
-**/*~
-**/*.log
-**/.DS_Store
-**/Thumbs.db
+# Strip noise from allowed directories
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
-**/.Python
-**/env
-**/pip-log.txt
-**/pip-delete-this-directory.txt
-**/.tox
-**/.coverage
-**/.coverage.*
-**/.cache
-**/nosetests.xml
-**/coverage.xml
-**/*.cover
+**/*.egg-info
+**/.DS_Store
+**/Thumbs.db
**/*.log
**/.git
**/.mypy_cache
**/.pytest_cache
**/.hypothesis
-**/.DS_Store
\ No newline at end of file
+**/.coverage
+**/.coverage.*
+**/.tox
diff --git a/.gitattributes b/.gitattributes
index 3a2b045..77de7a6 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1 +1,5 @@
-spockflow/_version.py export-subst
+# Ensure consistent line endings
+* text=auto
+
+# export-subst: let `git archive` expand `$Format:...$` placeholders in the version file
+decider/_version.py export-subst
diff --git a/.github/workflows/feature.yml b/.github/workflows/feature.yml
deleted file mode 100644
index 3d4b2c5..0000000
--- a/.github/workflows/feature.yml
+++ /dev/null
@@ -1,44 +0,0 @@
-name: Build and Test
-
-on:
- pull_request:
- branches:
- - dev
- - main
- - release/*
-
-jobs:
- build-package:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
-
- - name: Set up Python
- uses: actions/setup-python@v5
- with:
- python-version: '3.12.4'
-
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install -r requirements/dev.txt
- pip install -r requirements/all.txt
-
- - name: Format code with Black
- run: black --check .
-
- - name: Run tests
- run: pytest
-
- - name: Build Python package
- run: python -m build
-
- - name: Upload Python package
- uses: actions/upload-artifact@v4
- with:
- name: python-package
- path: dist/*
- retention-days: 1
diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml
new file mode 100644
index 0000000..6d2a537
--- /dev/null
+++ b/.github/workflows/pr.yml
@@ -0,0 +1,31 @@
+name: PR Checks
+
+on:
+ pull_request:
+ branches:
+ - main
+
+permissions:
+ contents: read
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+
+ - name: Install dependencies
+ run: uv sync --group dev
+
+ - name: Run tests
+ run: uv run pytest
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 14e231f..713db3e 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -1,236 +1,220 @@
-name: Build and Release
+name: Release
on:
push:
branches:
- main
- - dev
- - release/*
env:
- DOCKER_IMAGE_NAME: sjnarmstrong/spockflow
+ DOCKER_IMAGE: capitecbankltd/decider
jobs:
- get-version:
+ # ── 1. Resolve version from git tags (hatch-vcs) ──────────────────────────
+ version:
runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
-
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install versioneer setuptools
-
- - name: Get version
- id: get-package-version
- run: |
- version=$(python setup.py --version)
- echo "PACKAGE_VERSION=$(python setup.py --version)" >> $GITHUB_OUTPUT
outputs:
- PACKAGE_VERSION: ${{ steps.get-package-version.outputs.PACKAGE_VERSION }}
-
- build-package:
- needs: ["get-version"]
- runs-on: ubuntu-latest
+ version: ${{ steps.get.outputs.version }}
steps:
- - name: Checkout code
- uses: actions/checkout@v4
+ - uses: actions/checkout@v4
with:
- fetch-depth: 0
+ fetch-depth: 0 # full history so hatch-vcs can read tags
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
- python-version: '3.12.4'
-
- - name: Tag Release
- env:
- PACKAGE_VERSION: ${{needs.get-version.outputs.PACKAGE_VERSION}}
- run: |
- git config --global user.email "githubbot@donotreply.com"
- git config --global user.name "Sholto Armstrong"
- git tag -a v${PACKAGE_VERSION} -m "Release v${PACKAGE_VERSION}"
+ python-version: "3.12"
- - name: Install dependencies
+ - name: Resolve version
+ id: get
run: |
- python -m pip install --upgrade pip
- pip install -r requirements/dev.txt
- pip install -r requirements/all.txt
-
- - name: Format code with Black
- run: black --check .
-
- - name: Run tests
- run: pytest
-
- - name: Build Python package
- run: python -m build
+ uv sync --group dev
+ VERSION=$(uv run python -c "from importlib.metadata import version; print(version('decider'))")
+ echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
- - name: Upload Python package
- uses: actions/upload-artifact@v4
- with:
- name: python-package
- path: dist/*
-
- build-docs:
- needs: ["get-version"]
+ # ── 2a. Build wheel → PyPI ─────────────────────────────────────────────────
+ publish-pypi:
+ needs: version
runs-on: ubuntu-latest
+ environment:
+ name: release
+ url: https://pypi.org/p/decider
+ permissions:
+ id-token: write # OIDC trusted publishing — no token secret needed
steps:
- - name: Checkout code
- uses: actions/checkout@v4
+ - uses: actions/checkout@v4
with:
fetch-depth: 0
+ - name: Install uv
+ uses: astral-sh/setup-uv@v4
+
- name: Set up Python
uses: actions/setup-python@v5
with:
- python-version: '3.12.4'
-
- - name: Tag Release
- env:
- PACKAGE_VERSION: ${{needs.get-version.outputs.PACKAGE_VERSION}}
- run: |
- git config --global user.email "githubbot@donotreply.com"
- git config --global user.name "Sholto Armstrong"
- git tag -a v${PACKAGE_VERSION} -m "Release v${PACKAGE_VERSION}"
+ python-version: "3.12"
- - name: Install documentation dependencies
+ - name: Run tests
run: |
- python -m pip install --upgrade pip
- pip install -r requirements/all.txt
- pip install .
+ uv sync --group dev
+ uv run pytest
- - name: Build documentation
- run: |
- cd docs
- make html
+ - name: Build wheel + sdist
+ run: uv build
- - name: Archive artifact
- shell: sh
- run: |
- echo ::group::Archive artifact
- tar \
- --dereference --hard-dereference \
- --directory "$INPUT_PATH" \
- -cvf "$RUNNER_TEMP/artifact.tar" \
- --exclude=.git \
- --exclude=.github \
- .
- echo ::endgroup::
- env:
- INPUT_PATH: docs/_build/html/
-
- - name: Upload docs
+ - name: Upload wheel artifact
uses: actions/upload-artifact@v4
with:
- name: 'github-pages'
- path: ${{ runner.temp }}/artifact.tar
- if-no-files-found: error
+ name: dist
+ path: dist/
+ retention-days: 1
- release:
- needs: ["build-docs", "build-package", "get-version"]
+ - name: Publish to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ skip-existing: true
+
+ # ── 2b. Build and push Docker image ───────────────────────────────────────
+ publish-docker:
+ needs: [version, publish-pypi]
runs-on: ubuntu-latest
environment:
- name: pypi
- url: https://pypi.org/p/spockflow
+ name: release
permissions:
- id-token: write
- pages: write
- contents: write
- pull-requests: write
- repository-projects: write
+ contents: read
steps:
- - name: Checkout code
- uses: actions/checkout@v4
+ - uses: actions/checkout@v4
with:
fetch-depth: 0
-
-
- - name: Tag Release
- env:
- PACKAGE_VERSION: ${{needs.get-version.outputs.PACKAGE_VERSION}}
- run: |
- git config --global user.email "githubbot@donotreply.com"
- git config --global user.name "Github Bot"
- git tag -a v${PACKAGE_VERSION} -m "Release v${PACKAGE_VERSION}"
- - name: Set up Python
- uses: actions/setup-python@v5
+ - name: Download wheel artifact
+ uses: actions/download-artifact@v4
with:
- python-version: "3.12.4"
+ name: dist
+ path: dist/
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
-
+
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- - name: Login to Docker Hub
+ - name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install versioneer setuptools
-
- - name: Get version
- run: |
- version=$(python setup.py --version)
- echo "PACKAGE_VERSION=$(python setup.py --version)" >> $GITHUB_ENV
-
- - name: Download Python Artifact
- uses: actions/download-artifact@v4
- with:
- name: python-package
- path: dist
- - name: Build Docker image
- uses: docker/build-push-action@v3
+ - name: Build and push versioned image
+ uses: docker/build-push-action@v6
with:
context: .
- push: true
- tags: ${{ env.DOCKER_IMAGE_NAME }}:${{ env.PACKAGE_VERSION }}
file: docker/Dockerfile
+ push: true
+ tags: ${{ env.DOCKER_IMAGE }}:${{ needs.version.outputs.version }}
+ cache-from: type=gha
+ cache-to: type=gha,mode=max
- - name: Upload Latest Docker image
- uses: docker/build-push-action@v3
- if: github.ref == 'refs/heads/main'
+ - name: Re-tag as latest
+ uses: docker/build-push-action@v6
with:
context: .
- push: true
- tags: ${{ env.DOCKER_IMAGE_NAME }}:latest
file: docker/Dockerfile
+ push: true
+ tags: ${{ env.DOCKER_IMAGE }}:latest
+ cache-from: type=gha
- - name: Publish distribution 📦 to PyPI
- uses: pypa/gh-action-pypi-publish@release/v1
+ # ── 2c. Build and deploy versioned docs to GitHub Pages ───────────────────
+ publish-docs:
+ needs: version
+ runs-on: ubuntu-latest
+ environment:
+ name: release
+ url: https://capitecbankltd.github.io/dsp_north-polrs
+ permissions:
+ pages: write
+ id-token: write
+ steps:
+ - uses: actions/checkout@v4
with:
- skip-existing: true
+ fetch-depth: 0
- - name: Deploy documentation to GitHub Pages
- if: github.ref == 'refs/heads/main'
- uses: actions/deploy-pages@v4
-
- - name: Push Tag
+ - name: Install uv
+ uses: astral-sh/setup-uv@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+
+ - name: Install docs dependencies
+ run: uv sync --extra docs
+
+ - name: Restore previous gh-pages (for versioned history)
+ # Pull whatever is already deployed so we can layer this release on top.
+ # If the branch doesn't exist yet, fall back to an empty _site directory.
+ run: |
+ git fetch origin gh-pages:gh-pages 2>/dev/null || true
+ if git show-ref --quiet refs/heads/gh-pages; then
+ git worktree add _site gh-pages
+ else
+ mkdir -p _site
+ fi
+
+ - name: Build docs for this version
env:
- PACKAGE_VERSION: ${{needs.get-version.outputs.PACKAGE_VERSION}}
+ VERSION: ${{ needs.version.outputs.version }}
run: |
- git push origin tag v${PACKAGE_VERSION}
+ uv run python -m sphinx docs docs/_build/html -b html
+ # Copy into a versioned subdirectory so older releases stay accessible.
+ mkdir -p _site/$VERSION
+ cp -r docs/_build/html/. _site/$VERSION/
+ # Also update the root (latest) copy.
+ cp -r docs/_build/html/. _site/
+
+ - name: Upload Pages artifact
+ uses: actions/upload-pages-artifact@v3
+ with:
+ path: _site/
+
+ - name: Deploy to GitHub Pages
+ uses: actions/deploy-pages@v4
- - name: Create Release
- if: github.ref == 'refs/heads/main'
- id: release-snapshot
- uses: actions/create-release@latest
+ # ── 3. Tag and create GitHub Release (after all three deploys succeed) ─────
+ github-release:
+ needs: [version, publish-pypi, publish-docker, publish-docs]
+ runs-on: ubuntu-latest
+ environment:
+ name: release
+ permissions:
+ contents: write
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Download wheel artifact
+ uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist/
+
+ - name: Push version tag
env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ VERSION: ${{ needs.version.outputs.version }}
+ run: |
+ git config user.email "github-actions[bot]@users.noreply.github.com"
+ git config user.name "github-actions[bot]"
+ # Skip if tag already exists (e.g. re-run on same commit).
+ git tag "v${VERSION}" || true
+ git push origin "v${VERSION}" || true
+
+ - name: Create GitHub Release
+ uses: softprops/action-gh-release@v2
with:
- tag_name: v${{needs.get-version.outputs.PACKAGE_VERSION}}
- release_name: Release ${{needs.get-version.outputs.PACKAGE_VERSION}}
- draft: false
- prerelease: false
-
\ No newline at end of file
+ tag_name: v${{ needs.version.outputs.version }}
+ name: Release ${{ needs.version.outputs.version }}
+ generate_release_notes: true
+ files: dist/*
diff --git a/.gitignore b/.gitignore
index 4abe45e..724fd88 100644
--- a/.gitignore
+++ b/.gitignore
@@ -71,6 +71,7 @@ instance/
# Sphinx documentation
docs/_build/
+docs/_autosummary/
# PyBuilder
.pybuilder/
@@ -166,3 +167,14 @@ cython_debug/
.DS_Store
.vscode
+.claude
+
+# uv
+# uv.lock is committed for reproducible installs — do not ignore it
+.python-version
+
+# hatch-vcs generated version file
+decider/_version.py
+
+# project-level extension configs (generated by decider_extensions)
+projects/*/configs/
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1aa35ba..c2628f0 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -17,22 +17,18 @@ Please read and follow our [Code of Conduct](CODE_OF_CONDUCT.md).
Engagement starts with an Issue where conversations and debates can occur around [bugs](#bugs) and [feature requests](#feature-requests):
- ✅ **Do** search for a similar or existing Issue prior to submitting a new one.
-- ❌ **Do not** use Issues for personal support. Use [Discussions](https://github.com/your-repo/discussions) or [StackOverflow](https://stackoverflow.com/) instead.
+- ❌ **Do not** use Issues for personal support. Use [Discussions](https://github.com/capitecbankltd/dsp_north-polrs/discussions) or [StackOverflow](https://stackoverflow.com/) instead.
- ❌ **Do not** side-track or derail Issue threads. Stick to the topic, please.
- ❌ **Do not** post comments using just "+1", "++" or "👍". Use [Reactions](https://github.blog/2016-03-10-add-reactions-to-pull-requests-issues-and-comments/) instead.
👾 Bugs
-A bug is an error, flaw, or fault associated with *any part* of the project:
-
- ✅ **Do** search for a similar or existing Issue prior to submitting a new one.
- ✅ **Do** describe the bug concisely. **Avoid** adding extraneous code, logs, or screenshots.
- ✅ **Do** attach a minimal test or example to demonstrate the bug.
💡 Feature Requests
-A feature request is an improvement or new capability associated with *any part* of the project:
-
- ✅ **Do** search for a similar or existing Issue prior to submitting a new one.
- ✅ **Do** provide sufficient motivation and use case(s) for the feature.
- ❌ **Do not** submit multiple unrelated requests within one request.
@@ -41,94 +37,80 @@ A feature request is an improvement or new capability associated with *any part*
## 3. Vulnerabilities
-A vulnerability is a security-related risk associated with *any part* of the project or its dependencies:
-
-- ✅ **Do** refer to our [Security Policy](https://github.com/your-repo/security/policy) for more information.
-- ✅ **Do** report vulnerabilities via this [link](https://github.com/your-repo/security/advisories/new).
-- ❌ **Do not** report any Issues or mention vulnerabilities in public Discussions for discretionary purposes.
+- ✅ **Do** refer to our [Security Policy](https://github.com/capitecbankltd/dsp_north-polrs/security/policy) for more information.
+- ✅ **Do** report vulnerabilities via this [link](https://github.com/capitec/ml-decision-engine/security/advisories/new).
+- ❌ **Do not** open a public Issue or Discussion for security vulnerabilities.
## 4. Development
🌱 Branches
-- `develop` - Default branch for all feature development and Pull Requests.
-- `main` - Stable branch for all periodic releases.
+- `feature/*` — feature development; PR target is `main`.
+- `main` — stable branch; tagged releases are cut from here.
🔒 Dependencies
-* Python (>= 3.8)
-* `pip` for package management. Use `pip install -r requirements/all.txt` to install dependencies.
-* Optional: Set up your environment using `conda`, `virtualenv`, or another method. Refer to [Python virtual environments](https://docs.python.org/3/tutorial/venv.html) for guidance.
+- Python >= 3.10
+- [uv](https://docs.astral.sh/uv/) for environment and dependency management.
📦 Project Setup
-1. [Fork](https://github.com/your-repo/fork) the repository and create a branch from `develop`.
-2. Clone the forked repo, checkout your branch, and install the dependencies with `pip install -r requirements/all.txt`.
-3. Run tests using `pytest` to ensure everything is working correctly.
+```bash
+# 1. Clone and enter the repo
+git clone https://github.com/capitecbankltd/dsp_north-polrs.git
+cd dsp_north-polrs
+
+# 2. Install all dependencies (creates .venv automatically)
+uv sync --all-extras
+
+# 3. Run the tests
+uv run pytest
+```
+
+If you are behind a corporate proxy that uses a private CA, `uv` is already
+configured to use the system certificate store (`system-certs = true` in
+`pyproject.toml`).
📂 Directory Structure
-When contributing, please note the following key files and directories:
-├── docker
-│ ├── Dockerfile
-├── docs
-│ ├── index.rst
-│ ├── ...
-├── requirements
-│ ├── all.txt
-│ ├── ...
-├── spockflow
-│ ├── components
-│ │ ├── scorecard
-│ │ ├── tree
-│ │ ├── dtable
-│ ├── inference
-│ ├── ...
-├── core.py
-├── exceptions.py
-├── nodes.py
-├── tests
-│ ├── test_example.py
-│ ├── ...
-
-
-* `docker` - Contains all files related to Docker images.
-* `docs` - Documentation files.
-* `requirements` - Directory containing `.txt` files for different optional requirements.
-* `spockflow/components` - Contains all components for the Hamilton DAG, including:
- * `scorecard` - For scorecards.
- * `tree` - For decision trees.
- * `dtable` - For decision tables.
-* `spockflow/inference` - Files needed to serve the module as a live endpoint.
-* `core.py` - Contains code to inject components into the Hamilton DAG.
-* `exceptions.py` - Base for exceptions produced by various components.
-* `nodes.py` - Core module for all components.
-* `tests` - Contains unit tests for the project.
+```
+decider/ Core library
+ cli/ `decider` CLI (click)
+ config/ Versioned config management
+ modules/ Module primitives (expression, join, sequential, union)
+ serving/ HTTP servers (Starlette, Sanic)
+ templates/ Scaffolding templates and scaffold.py renderer
+ magics/ Jupyter %%module magic
+docs/examples/ Example notebooks
+projects/ End-to-end project examples
+tests/ pytest suite
+pyproject.toml Project metadata, dependencies, tool config
+uv.lock Locked dependency graph (committed)
+```
🏷 Naming Conventions
-- ✅ **Do** follow PEP 8 for naming conventions.
-- ✅ **Do** use descriptive names for files and modules.
-- ✅ **Do** name Python classes in `CamelCase` and functions in `snake_case`.
+- ✅ **Do** follow PEP 8.
+- ✅ **Do** name classes in `CamelCase` and functions/modules in `snake_case`.
🔍 Code Quality
- ✅ **Do** adhere to PEP 8 style guidelines.
-- ✅ **Do** use `black` for automatic code formatting.
+- ✅ **Do** use `ruff` or `black` for formatting before opening a PR.
🧪 Testing
-- ✅ **Do** write tests using `pytest`.
-- ✅ **Do** ensure all tests pass before submitting a Pull Request.
+- ✅ **Do** write tests under `tests/` using `pytest`.
+- ✅ **Do** ensure all tests pass before submitting a Pull Request (`uv run pytest`).
## 5. Pull Requests
-- ✅ **Do** ensure your branch is up to date with the `develop` branch.
-- ✅ **Do** ensure there are no conflicts with the `develop` branch.
-- ✅ **Do** make sure all tests pass and code is formatted using `black`.
-- ✅ **Do** provide a clear description of the changes and the purpose of the Pull Request.
+- ✅ **Do** ensure your branch is up to date with `main`.
+- ✅ **Do** ensure there are no merge conflicts.
+- ✅ **Do** make sure all tests pass.
+- ✅ **Do** provide a clear description of the changes and their purpose.
-> **TIP:** Make sure to review the existing codebase and follow the conventions used throughout the project.
+> **TIP:** Review the existing codebase and follow the conventions used throughout the project.
---
diff --git a/LICENSE b/LICENSE
index bf6b516..53f0701 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2024 Capitec
+Copyright (c) 2024-2026 Capitec
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index 5498e76..6bcc7aa 100644
--- a/README.md
+++ b/README.md
@@ -1,84 +1,212 @@
-# DSP Decision Engine
+# Decider
-[](https://www.python.org/downloads/)
+[](https://www.python.org/downloads/)
[](LICENSE)
-**SpockFlow** is a Python framework designed to create standalone micro-services that enrich data with actionable outputs. It supports both batch and live inference modes, and extends existing frameworks to simplify data flows, including policy rules and scoring. Leveraging Hamilton for traceability, SpockFlow provides a powerful, modular approach for data enrichment and model deployment.
+**Decider** is a Python framework for building, serving, and inspecting decision pipelines as versioned, deployable micro-services. Define pipelines from plain Python functions, compose them with `|` and `&`, save them as versioned JSON configs, and serve them over HTTP — all with a single consistent API.
## Table of Contents
-- [Introduction](docs/intro.md)
-- [Installation](docs/getting_started/install.md)
-- [Concepts](docs/concepts/index.md)
+- [Introduction](#introduction)
+- [Installation](#installation)
+- [Concepts](#concepts)
- [Usage Examples](#usage-examples)
+- [CLI](#cli)
- [Contributing](#contributing)
- [License](#license)
## Introduction
-SpockFlow is built to be extensible and modular, allowing the reuse of pipelines and configurations across multiple data flows. Its emphasis on runtime traceability and explainability is empowered by Hamilton, which helps track and visualize data lineage and identify process steps leading to specific outcomes.
+Decider is built around a few core ideas:
-
-
-For a more detailed introduction, see [Introduction](docs/main.rst).
+- **Functions as nodes** — plain Python functions that accept and return `polars.Expr` are wired into executable DAGs automatically, with no decorators or registries required.
+- **Composable pipelines** — modules chain with `|` (sequential) or merge with `&` (parallel union), making it easy to build complex pipelines from simple parts.
+- **Versioned configs** — every pipeline is serialisable to JSON. Configs are versioned, loadable by the server at startup, and hot-swappable without redeployment.
+- **Pluggable extensions** — new module types are registered into a discriminated union at runtime, so the server always knows how to reconstruct any module from its config.
## Installation
-To get started with SpockFlow, you need to install the required dependencies. Follow the instructions in the [Installation Guide](docs/getting_started/install.md) to set up your environment.
+```bash
+# Core library and CLI
+pip install decider
+
+# With serving dependencies
+pip install "decider[serve-starlette]" # uvicorn + starlette
+pip install "decider[serve-sanic]" # sanic
+
+# With the interactive graph visualiser
+pip install "decider[visualise]"
+
+# Everything
+pip install "decider[all]"
+```
+
+**With uv (recommended):**
```bash
-pip install spockflow[all]
+git clone https://github.com/capitecbankltd/dsp_north-polrs.git
+cd dsp_north-polrs
+uv sync --all-extras
```
## Concepts
-Explore the foundational principles and components of SpockFlow in the [Concepts](docs/concepts/index.md) section. This guide covers:
+### Modules
+
+A module is the basic unit of computation. The simplest way to create one is `generate_from_functions`, which turns plain functions into an executable module:
+
+- **Function name → output column** — `def dti_ratio(...)` produces a `dti_ratio` column.
+- **Parameter name → input column or sibling output** — parameters are resolved from the input DataFrame or from the output of another function in the same module.
+- **`config` parameter → Pydantic model injection** — declare `config: MyConfig` and the config fields are promoted onto the module itself.
+
+### Pipelines
+
+Modules compose with two operators:
+
+- `|` — **sequential**: `step_a | step_b | step_c` passes each step's output as the next step's input.
+- `&` — **union**: `module_a & module_b` merges both modules into a single compilation pass, computing all columns in one frame scan.
+
+### Versioned Configs
+
+Every module can be saved to a versioned JSON config and reconstructed from it:
-- **Decision Trees**: Automate decision-making processes based on defined conditions.
-- **Decision Tables**: Map input values to outputs based on conditions.
-- **Score Cards**: Assign scores to entities based on parameters.
-- **API Customization**: Customize and extend SpockFlow functionalities.
+```python
+await module.asave("main", config_manager)
+await config_manager.save_version(overwrite=True)
+```
+
+The server reads the latest version on startup and watches for updates.
+
+### Extensions
+
+Custom module types are registered with `register_graph_module` and auto-discovered by `initialize_decider` from an `extension_path` directory. This keeps the core library small while allowing domain-specific modules to be developed and shipped independently.
## Usage Examples
-Here are some examples of how to use SpockFlow:
+### Your first module
+
+```python
+import polars as pl
+from decider.modules.functional import generate_from_functions
+
+def dti_ratio(debt: pl.Expr, income: pl.Expr) -> pl.Expr:
+ return debt / income
+
+def credit_score(dti_ratio: pl.Expr) -> pl.Expr:
+ return pl.lit(800) - dti_ratio * 200
+
+Scorer = generate_from_functions("credit_scorer", dti_ratio, credit_score)
+scorer = Scorer(name="scorer")
+
+df = pl.DataFrame({"debt": [25_000.0], "income": [50_000.0]})
+result = scorer({"input": df})
+# shape: (1, 4) — debt, income, dti_ratio, credit_score
+```
+
+### Config injection
+
+```python
+from pydantic import BaseModel
+
+class ScorerConfig(BaseModel):
+ dti_weight: float = 200.0
+ score_base: float = 800.0
+
+def credit_score(dti_ratio: pl.Expr, config: ScorerConfig) -> pl.Expr:
+ return pl.lit(config.score_base) - dti_ratio * config.dti_weight
+
+Scorer = generate_from_functions("credit_scorer", dti_ratio, credit_score)
+scorer = Scorer(name="scorer", dti_weight=150.0) # config fields on the module
+```
+
+### Sequential pipeline
+
+```python
+features = FeatureModule(name="features")
+scorer = ScorerModule(name="scorer")
+flags = FlagModule(name="flags")
+
+pipeline = features | scorer | flags
+result = pipeline({"input": df})
+```
+
+### Join then score
+
+```python
+from decider.modules.primitives.join import JoinModule
+
+join = JoinModule(name="enrich", left="transactions", right="users", on="user_id", how="left")
+pipeline = join | scorer
-### Decision Trees
+result = pipeline({"transactions": txns_df, "users": users_df})
+```
-Create and use decision trees in SpockFlow:
+### Save and load
```python
-from spockflow.components.tree import Tree, Action
-from spockflow.core import initialize_spock_module
-import pandas as pd
-from typing_extensions import TypedDict
+import asyncio
+from decider.config.file import JsonFileConfigManager
+from decider.modules import GraphModule
+
+mgr = JsonFileConfigManager(basepath="./configs")
+asyncio.run(scorer.asave("main", mgr))
+asyncio.run(mgr.save_version(overwrite=True))
+
+# Reconstruct from disk
+fresh = JsonFileConfigManager(basepath="./configs")
+loaded = asyncio.run(fresh.get_latest())
+module = GraphModule.model_validate(loaded.config["main"]).root
+```
+
+## CLI
+
+```bash
+# Scaffold a new module
+decider template module CreditScorer
+
+# Scaffold a module into a shareable package
+decider template module CreditScorer --package mylib
+
+# Scaffold a new project
+decider template project fraud_detection
-class Reject(TypedDict):
- code: int
- description: str
+# Start a server
+decider serve # starlette on :8080
+decider serve --engine sanic --workers 4
+decider serve --port 9000 --reload
-RejectAction = Action[Reject]
+# Launch the interactive graph visualiser
+decider visualise --project-dir projects/loan_scoring
+```
+
+### Jupyter magic
+
+```python
+%load_ext decider.magics
+```
+
+```python
+%%module CreditScorer
-# Initialize Tree
-tree = Tree()
+class CreditScorerConfig(BaseModel):
+ weight: float = 200.0
-# Define conditions and actions
-@tree.condition(output=RejectAction(code=102, description="My first condition"))
-def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:
- return (d > 5) & (e > 5) & (f > 5)
+def dti_ratio(debt: pl.Expr, income: pl.Expr) -> pl.Expr:
+ return debt / income
-tree.visualize(get_value_name=lambda x: x["description"][0])
+def credit_score(dti_ratio: pl.Expr, config: CreditScorerConfig) -> pl.Expr:
+ return pl.lit(800) - dti_ratio * config.weight
```
-For more details and advanced usage, check out the [Concepts](docs/concepts/index.md) section.
+This writes `decider_extensions/credit_scorer/__init__.py`, reloads it, and injects `CreditScorer` into the notebook namespace. Add `--package mylib` to write into a proper uv src-layout package instead.
## Contributing
-We welcome contributions to SpockFlow! Please refer to our [Contributing Guide](CONTRIBUTING.md) for information on how to contribute.
+We welcome contributions to Decider! Please refer to our [Contributing Guide](CONTRIBUTING.md) for full details.
-- **Fork the repository** and create a branch from `develop`.
-- **Install dependencies** using `pip install -r requirements/all.txt`.
-- **Run tests** with `pytest` to ensure everything is working.
+- **Fork** the repository and create a branch from `main`.
+- **Install dependencies** with `uv sync --all-extras`.
+- **Run tests** with `uv run pytest`.
- **Submit a Pull Request** with a clear description of your changes.
## License
@@ -87,4 +215,4 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file
---
-Thank you for your interest in SpockFlow! We look forward to your contributions and feedback.
+Thank you for your interest in Decider! We look forward to your contributions and feedback.
diff --git a/SECURITY.md b/SECURITY.md
index d5d8c43..7479ed9 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -13,5 +13,7 @@ Older versions will not be supported if the vulnerability has been fixed in late
## Reporting a Vulnerability
-This repository makes use of GitHub's [private vulnerability reporting feature](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing/privately-reporting-a-security-vulnerability).
+This repository uses GitHub's [private vulnerability reporting](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing/privately-reporting-a-security-vulnerability).
Vulnerabilities can be reported using the [repository's security advisories](https://github.com/capitec/ml-decision-engine/security/advisories/new).
+
+Please do **not** open a public issue for security vulnerabilities.
diff --git a/Tasks.md b/Tasks.md
new file mode 100644
index 0000000..93963a9
--- /dev/null
+++ b/Tasks.md
@@ -0,0 +1,15 @@
+# Where to begin
+1. Create a Hamilton DAG with polars and execute it
+2. Add a decorator that supports both polars and pandas, so that we can create a function in Hamilton that is a pandas function and still executes
+
+Example module:
+
+```python
+def func1(df: pl.Polars) -> pl.Polars
+@convert_types
+def func2(func1: pd.DF) -> pd.DF
+@convert_types
+def func3(func2: pl.Polars) -> pl.Polars
+```
+3. Config for components
+4. Higher-level Hamilton to build DAGs
diff --git a/decider/__init__.py b/decider/__init__.py
new file mode 100644
index 0000000..b064796
--- /dev/null
+++ b/decider/__init__.py
@@ -0,0 +1,7 @@
+from .initialization import initialize_decider
+from importlib.metadata import PackageNotFoundError, version
+
+try:
+    __version__ = version("decider")
+except PackageNotFoundError:  # no installed distribution named "decider"
+    __version__ = "unknown"
diff --git a/decider/_ext.py b/decider/_ext.py
new file mode 100644
index 0000000..93f3693
--- /dev/null
+++ b/decider/_ext.py
@@ -0,0 +1,106 @@
+import inspect
+import typing as t
+from abc import ABC
+from pydantic import create_model, RootModel, BaseModel, Field, model_validator
+from warnings import warn
+
+
+class TypeDiscriminatedBaseModule(BaseModel, ABC):
+    type: str
+    # Filled in by __init_subclass__ with the single Literal value a concrete subclass declares for `type`.
+    _CLASS_TYPE_IDENTIFIER: t.ClassVar[str]
+
+    def __init_subclass__(cls, **kwargs):
+        """
+        Validate the `type` annotation of each non-abstract subclass and record its value. This ensures:
+        1. the class declares type: Literal['value'], so it can be used as the discriminator for the union of all implementations of this class
+        2. the class does NOT declare type: Literal['value'] = 'value'; we use pydantic's model_dump(exclude_defaults=True) when saving out modules, and a `type` with a default value would be left out of the dumped dict, which breaks loading it back in
+        3. the Literal value is stored so it can be filled in automatically when constructing Model() rather than needing Model(type='value') every time
+        Raises TypeError when a non-abstract subclass breaks any of these rules.
+        """
+        super().__init_subclass__(**kwargs)
+
+        # Skip abstract classes, as these are used as a base for multiple implementations.
+        if inspect.isabstract(cls):
+            return
+
+        # Ensure `type` is declared
+        if "type" not in cls.__annotations__:
+            raise TypeError(f"{cls.__name__} must define a 'type' annotation")
+
+        annotation = cls.__annotations__["type"]
+        if t.get_origin(annotation) is not t.Literal:
+            raise TypeError(
+                f"{cls.__name__}.type must be typing.Literal[...]"
+            )
+
+        literal_values = t.get_args(annotation)
+        if len(literal_values) != 1:
+            raise TypeError(
+                f"{cls.__name__}.type must be a single-value Literal"
+            )
+
+        if "type" in cls.__dict__:
+            raise TypeError(
+                f"{cls.__name__}.type must not define a default value"
+            )
+
+        cls._CLASS_TYPE_IDENTIFIER = literal_values[0]
+
+    @model_validator(mode="before")
+    @classmethod
+    def auto_set_type(cls, values):
+        """Return dict input with the class's `type` filled in when omitted, without mutating the caller's dict."""
+        if isinstance(values, dict) and "type" not in values and not inspect.isabstract(cls):
+            values = {**values, "type": cls._CLASS_TYPE_IDENTIFIER}
+        return values
+
+
+_TExtenableRootType = t.TypeVar("_TExtenableRootType")
+
+class TExtendableModel(RootModel[_TExtenableRootType], t.Generic[_TExtenableRootType]):
+ root: _TExtenableRootType
+
+
+def create_extendable_model(
+    base_class: t.Type,
+    discriminator_field: str = "type",
+    model_name: str = "ExtendableModel"
+) -> t.Tuple[t.Type[TExtendableModel], t.Callable[[t.Type], None]]:
+    """
+    Creates an extendable model pattern that allows external packages to register
+    new types without creating hard dependencies. Every registration rebuilds the
+    model so its root accepts any registered class, discriminated on `discriminator_field`.
+    Returns a tuple of (Model class, register function).
+    """
+    # The root is declared via the string annotation "RootType", which _rebuild() supplies through _types_namespace.
+    ExtendableModel = create_model(
+        model_name,
+        __base__=RootModel,
+        root=("RootType", ...)
+    )
+    # type id -> provider class; registering the same id again replaces the earlier class
+    _registered: t.Dict[str, t.Type] = {}
+
+    def _rebuild():  # re-point the root at all registered classes, then rebuild the model
+        classes = list(_registered.values())
+        if not classes:
+            return
+        union = classes[0] if len(classes) == 1 else t.Union[tuple(classes)]
+        annotated = t.Annotated[union, Field(discriminator=discriminator_field)]
+        ExtendableModel.__annotations__["root"] = annotated
+        ExtendableModel.model_fields["root"].annotation = annotated
+        was_rebuilt = ExtendableModel.model_rebuild(
+            force=True,
+            _types_namespace={"RootType": annotated},
+        )
+        if was_rebuilt is not True:
+            warn(f"model_rebuild did not return True for {ExtendableModel.__name__}")
+
+    def register_provider(provider_class: t.Type):  # keyed by _CLASS_TYPE_IDENTIFIER when truthy, else the class name
+        assert issubclass(provider_class, base_class), f"Provider must be a subclass of {base_class.__name__}"
+        type_id = getattr(provider_class, "_CLASS_TYPE_IDENTIFIER", None) or provider_class.__name__
+        _registered[type_id] = provider_class
+        _rebuild()
+
+    return ExtendableModel, register_provider
diff --git a/decider/cli/__init__.py b/decider/cli/__init__.py
new file mode 100644
index 0000000..840c8c7
--- /dev/null
+++ b/decider/cli/__init__.py
@@ -0,0 +1,15 @@
+import click
+
+from .template import template
+from .serve import serve
+from .visualise import visualise
+
+
+@click.group()
+def cli():
+ """Decider — build, serve and inspect decision pipelines."""
+
+
+cli.add_command(template)
+cli.add_command(serve)
+cli.add_command(visualise)
diff --git a/decider/cli/_graph.py b/decider/cli/_graph.py
new file mode 100644
index 0000000..cf9dcde
--- /dev/null
+++ b/decider/cli/_graph.py
@@ -0,0 +1,310 @@
+"""
+Walk a BaseModule tree and produce graph structures for the visualiser.
+
+Two graph types:
+ build_graph(module) — module-level structural graph (pipeline view)
+ build_expression_graph(module) — expression node DAG inside one ExpressionModule
+
+Both return ModuleGraph. The module_ref on each GraphNode holds the live
+module object so the app can drill into it.
+"""
+
+import typing as t
+from dataclasses import dataclass, field
+
+
+@dataclass
+class GraphNode:
+ id: str
+ label: str
+ kind: str # "expression" | "sequential" | "join" | "union" | "col" | "config" | "unknown"
+ type_id: str
+ parent: t.Optional[str] = None
+ fields: t.Dict[str, t.Any] = field(default_factory=dict)
+ module_ref: t.Any = None # live BaseModule if drillable
+ drillable: bool = False
+
+
+@dataclass
+class GraphEdge:
+ source: str
+ target: str
+ label: str = ""
+
+
+@dataclass
+class ModuleGraph:
+    """Node and edge lists for one visualiser graph, with a Graphviz renderer."""
+    nodes: t.List[GraphNode] = field(default_factory=list)
+    edges: t.List[GraphEdge] = field(default_factory=list)
+
+    def to_graphviz(self) -> "graphviz.Digraph":
+        """Render every node and edge into a top-to-bottom graphviz.Digraph."""
+        import graphviz
+
+        palette = {
+            "expression": "#4C9BE8",
+            "sequential": "#E8884C",
+            "join": "#4CE8A0",
+            "union": "#9B4CE8",
+            "col": "#888888",
+            "config": "#C8A850",
+            "unknown": "#AAAAAA",
+        }
+        digraph = graphviz.Digraph(graph_attr={"rankdir": "TB", "splines": "ortho"})
+        for node in self.nodes:
+            field_lines = [f"{k}: {v}" for k, v in node.fields.items()]
+            # "bold" is appended only for drillable nodes; strip() drops the dangling comma otherwise
+            style = "filled,rounded," + ("bold" if node.drillable else "")
+            digraph.node(
+                node.id,
+                label=node.label,
+                shape="ellipse" if node.kind in ("col", "config") else "box",
+                style=style.strip(","),
+                fillcolor=palette.get(node.kind, "#AAAAAA"),
+                fontcolor="white",
+                tooltip="\n".join(field_lines) or node.type_id,
+            )
+        for edge in self.edges:
+            digraph.edge(edge.source, edge.target, label=edge.label)
+        return digraph
+
+
+# ── helpers ───────────────────────────────────────────────────────────────────
+
+def _kind(module) -> str:
+    """Classify a module for the visualiser.
+
+    A `type` of "sequential", "join" or "union" is returned as-is; any other
+    module is "expression" if it has `expand_nodes`, otherwise "unknown".
+    """
+    declared = getattr(module, "type", "")
+    for structural in ("sequential", "join", "union"):
+        if declared == structural:
+            return structural
+    return "expression" if hasattr(module, "expand_nodes") else "unknown"
+
+
+def _config_fields(module) -> t.Dict[str, t.Any]:
+ _SKIP = {"type", "name", "steps", "modules", "left", "right", "on", "how"}
+ try:
+ raw = module.model_dump(exclude_defaults=False)
+ except Exception:
+ return {}
+ return {k: v for k, v in raw.items() if k not in _SKIP and not k.startswith("_")}
+
+
+# ── module-level graph ────────────────────────────────────────────────────────
+
+def _walk(
+    module, graph: ModuleGraph, parent_id: t.Optional[str] = None, counter: t.Optional[t.List[int]] = None,
+) -> str:
+    """Add `module` and, recursively, its children to `graph`; return the new node's id.
+
+    `counter` is a one-element list shared by the whole walk so generated ids stay unique.
+    """
+    if counter is None:
+        counter = [0]
+    counter[0] += 1
+    node_id = f"node_{counter[0]}"
+    type_id = getattr(module, "type", type(module).__name__)
+    name = getattr(module, "name", type_id)
+    kind = _kind(module)
+    drillable = kind in ("expression", "sequential", "join", "union")
+
+    graph.nodes.append(GraphNode(
+        id=node_id,
+        label=name,
+        kind=kind,
+        type_id=type_id,
+        parent=parent_id,
+        fields=_config_fields(module),
+        module_ref=module,
+        drillable=drillable,
+    ))
+
+    if parent_id is not None:
+        graph.edges.append(GraphEdge(source=parent_id, target=node_id))
+
+    if kind == "sequential":
+        prev = node_id
+        for step in module.steps:
+            edge_idx = len(graph.edges)  # the child appends its own parent edge first, at this index
+            child_id = _walk(step, graph, parent_id=node_id, counter=counter)
+            graph.edges[edge_idx] = GraphEdge(source=prev, target=child_id, label="then")
+            prev = child_id
+
+    elif kind == "join":
+        for side, ref in (("left", module.left), ("right", module.right)):
+            if hasattr(ref, "type"):
+                edge_idx = len(graph.edges)  # index of the join→child edge, not a descendant's edge
+                _walk(ref, graph, parent_id=node_id, counter=counter)
+                graph.edges[edge_idx].label = side
+            else:
+                fid = f"frame_{ref}_{counter[0]}"
+                counter[0] += 1
+                graph.nodes.append(GraphNode(
+                    id=fid, label=f'"{ref}"', kind="col", type_id="frame", parent=node_id,
+                ))
+                graph.edges.append(GraphEdge(source=node_id, target=fid, label=side))
+
+    elif kind == "union":
+        for child_mod in module.modules:
+            _walk(child_mod, graph, parent_id=node_id, counter=counter)
+
+    return node_id
+
+
+def build_graph(module) -> ModuleGraph:
+ """Module-level structural graph for any BaseModule tree."""
+ g = ModuleGraph()
+ _walk(module, g)
+ return g
+
+
+# ── expression node DAG ───────────────────────────────────────────────────────
+
+def build_expression_graph(module) -> ModuleGraph:
+    """Return a computation DAG for an ExpressionModule showing individual
+    function nodes, their column inputs, and config injections.
+    """
+    from decider.modules.expression import ExternalInputNode, StaticValueNode, Node as ExprNode
+
+    g = ModuleGraph()
+    nodes = module.expand_nodes()
+
+    # add a function node for every expression node
+    for name in nodes:
+        g.nodes.append(GraphNode(
+            id=f"fn_{name}",
+            label=name,
+            kind="expression",
+            type_id="function",
+            drillable=False,
+        ))
+    seen_ids = {n.id for n in g.nodes}  # O(1) duplicate checks instead of rescanning g.nodes
+
+    # add edges: inputs → function nodes
+    for name, expr_node in nodes.items():
+        for param, ref in expr_node.input_map.items():
+            if isinstance(ref, ExprNode):
+                g.edges.append(GraphEdge(source=f"fn_{ref.name}", target=f"fn_{name}", label=param))
+            elif isinstance(ref, ExternalInputNode):
+                col_id = f"col_{ref.input_name}"
+                if col_id not in seen_ids:  # one shared node per external column
+                    seen_ids.add(col_id)
+                    g.nodes.append(GraphNode(
+                        id=col_id,
+                        label=ref.input_name,
+                        kind="col",
+                        type_id="column",
+                        drillable=False,
+                    ))
+                g.edges.append(GraphEdge(source=col_id, target=f"fn_{name}", label=param))
+            elif isinstance(ref, StaticValueNode):
+                val = ref.value
+                cfg_id = f"cfg_{name}_{param}"
+                if cfg_id not in seen_ids:
+                    seen_ids.add(cfg_id)
+                    g.nodes.append(GraphNode(
+                        id=cfg_id,
+                        label=type(val).__name__,  # show the config type name, not the full repr
+                        kind="config",
+                        type_id="config",
+                        drillable=False,
+                        fields=val.model_dump() if hasattr(val, "model_dump") else {},
+                    ))
+                g.edges.append(GraphEdge(source=cfg_id, target=f"fn_{name}", label=param))
+
+    return g
+
+
+# ── intermediate value extraction ─────────────────────────────────────────────
+
+def run_with_intermediates(
+ module,
+ inputs: t.Dict[str, "pl.DataFrame"],
+) -> t.List[t.Tuple[str, "pl.DataFrame"]]:
+ """
+ Execute module and return a list of (label, DataFrame) pairs, one per
+ logical step, in execution order.
+
+ - ExpressionModule → one entry per compiled expression column, accumulated
+ - SequentialModule → one entry per step
+ - Others → single entry with final output
+ """
+ import polars as pl
+
+ kind = _kind(module)
+
+ if kind == "expression":
+ return _run_expression_intermediates(module, inputs)
+ elif kind == "sequential":
+ return _run_sequential_intermediates(module, inputs)
+ elif kind == "join":
+ return _run_join_intermediates(module, inputs)
+ else:
+ out = module(inputs)
+ if isinstance(out, pl.LazyFrame):
+ out = out.collect()
+ return [(getattr(module, "name", "output"), out)]
+
+
+def _run_expression_intermediates(
+    module, inputs: t.Dict[str, "pl.DataFrame"],
+) -> t.List[t.Tuple[str, "pl.DataFrame"]]:
+    """Apply the module's compiled expressions one at a time, collecting the frame after each.
+
+    Returns (column name, collected frame) pairs; each frame is built on top of the previous one.
+    """
+    import polars as pl
+
+    module.compile_expressions()
+    compiled = module._compiled_expressions
+    source = inputs.get(compiled.input_frame)
+    if source is None:
+        source = next(iter(inputs.values()))  # named frame absent: fall back to the first input
+    running = source.lazy() if isinstance(source, pl.DataFrame) else source
+
+    snapshots = []
+    for col_name, expr in compiled.expressions.items():
+        running = running.with_columns(expr.alias(col_name))
+        snapshots.append((col_name, running.collect()))
+
+    return snapshots
+
+
+def _run_sequential_intermediates(
+    module, inputs: t.Dict[str, "pl.DataFrame"],
+) -> t.List[t.Tuple[str, "pl.DataFrame"]]:
+    """Run `module.steps` in order, feeding each step's output to the next as frame "input".
+
+    Starts from the "input" frame when present, else the first frame; returns one (label, output) pair per step.
+    """
+    import polars as pl
+
+    frames = {}
+    for key, frame in inputs.items():
+        frames[key] = frame.lazy() if isinstance(frame, pl.DataFrame) else frame
+    current = frames["input"] if "input" in frames else next(iter(frames.values()))
+
+    per_step = []
+    for step in module.steps:
+        frames["input"] = current
+        produced = step(frames)
+        collected = produced.collect() if isinstance(produced, pl.LazyFrame) else produced
+        current = collected.lazy()
+        per_step.append((getattr(step, "name", step.type), collected))
+    return per_step
+
+
+def _run_join_intermediates(
+ module,
+ inputs: t.Dict[str, "pl.DataFrame"],
+) -> t.List[t.Tuple[str, "pl.DataFrame"]]:
+ import polars as pl
+
+ out = module(inputs)
+ if isinstance(out, pl.LazyFrame):
+ out = out.collect()
+ return [(getattr(module, "name", "join"), out)]
diff --git a/decider/cli/_visualise_app.py b/decider/cli/_visualise_app.py
new file mode 100644
index 0000000..301d991
--- /dev/null
+++ b/decider/cli/_visualise_app.py
@@ -0,0 +1,236 @@
+"""
+Streamlit app — launched by `decider visualise`.
+
+Env vars:
+ DECIDER_VISUALISE_PROJECT_DIR
+ DECIDER_VISUALISE_EXT_DIR (optional)
+ DECIDER_VISUALISE_CONFIG_DIR (optional)
+ DECIDER_VISUALISE_ROOT_MODULE (optional, default "main")
+"""
+
+import json
+import os
+import sys
+from pathlib import Path
+
+import polars as pl
+import streamlit as st
+
+# ── bootstrap ─────────────────────────────────────────────────────────────────
+
+_project_dir = Path(os.environ.get("DECIDER_VISUALISE_PROJECT_DIR", ".")).resolve()
+_repo_root = _project_dir.parent.parent
+if str(_repo_root) not in sys.path:
+ sys.path.insert(0, str(_repo_root))
+
+_ext_dir = os.environ.get("DECIDER_VISUALISE_EXT_DIR",
+ str(_project_dir / "decider_extensions"))
+_configs_dir = os.environ.get("DECIDER_VISUALISE_CONFIG_DIR",
+ str(_project_dir / "configs"))
+_root_module = os.environ.get("DECIDER_VISUALISE_ROOT_MODULE", "main")
+
+from decider.initialization import initialize_decider
+from decider.config.file import JsonFileConfigManager
+from decider.modules import GraphModule
+from decider.cli._graph import (
+ build_graph,
+ build_expression_graph,
+ run_with_intermediates,
+ _kind,
+)
+
+# ── page setup ────────────────────────────────────────────────────────────────
+
+st.set_page_config(
+ page_title="Decider Visualise",
+ layout="wide",
+ initial_sidebar_state="expanded",
+)
+
+# ── session state init ────────────────────────────────────────────────────────
+
+if "breadcrumb" not in st.session_state:
+ # Each entry: {"label": str, "module": BaseModule}
+ st.session_state.breadcrumb = []
+
+if "run_inputs" not in st.session_state:
+ st.session_state.run_inputs = None # Dict[str, pl.DataFrame] when set
+
+# ── load root module (cached) ─────────────────────────────────────────────────
+
+@st.cache_resource
+def _load_root(root_key: str):
+ initialize_decider(extension_path=_ext_dir)
+ import asyncio
+ mgr = JsonFileConfigManager(basepath=_configs_dir)
+ versioned = asyncio.run(mgr.get_latest())
+ module = GraphModule.model_validate(versioned.config[root_key]).root
+ return module, versioned
+
+
+# ── sidebar ───────────────────────────────────────────────────────────────────
+
+with st.sidebar:
+ st.header("Project")
+ st.caption(str(_project_dir))
+
+ root_key = st.text_input("Root module key", value=_root_module)
+ if st.button("↺ Reload config"):
+ st.cache_resource.clear()
+ st.session_state.breadcrumb = []
+ st.session_state.run_inputs = None
+ st.rerun()
+
+try:
+ root_module, versioned = _load_root(root_key)
+except Exception as e:
+ st.error(f"Could not load module: {e}")
+ st.stop()
+
+with st.sidebar:
+ st.divider()
+ st.caption(f"version {versioned.version}")
+ st.caption(f"type {root_module.type}")
+
+ # ── input data entry ──────────────────────────────────────────────────────
+ st.divider()
+ st.subheader("Run data")
+ st.caption("Paste JSON (column-oriented) to push data through the pipeline.")
+
+ default_cols = root_module.get_input_frame_keys()
+ json_placeholder = json.dumps(
+ {k: ["value1", "value2"] for k in
+ (root_module._compute_input_frame_keys() if hasattr(root_module, '_compute_input_frame_keys') else ["input"])},
+ indent=2,
+ )
+ raw_json = st.text_area("Input JSON", value="", height=180,
+ placeholder=json_placeholder)
+ if st.button("▶ Run"):
+ try:
+ parsed = json.loads(raw_json)
+ # support both {col: [...]} (single frame) and {"frame": {col: [...]}}
+ if parsed and isinstance(next(iter(parsed.values())), dict):
+ st.session_state.run_inputs = {
+ k: pl.DataFrame(v) for k, v in parsed.items()
+ }
+ else:
+ st.session_state.run_inputs = {"input": pl.DataFrame(parsed)}
+ except Exception as e:
+ st.error(f"Invalid JSON: {e}")
+
+ if st.session_state.run_inputs is not None:
+ if st.button("✕ Clear run"):
+ st.session_state.run_inputs = None
+ st.rerun()
+
+
+# ── breadcrumb navigation ─────────────────────────────────────────────────────
+
+# current module is root unless the user has drilled in
+crumb_stack = st.session_state.breadcrumb
+current_module = crumb_stack[-1]["module"] if crumb_stack else root_module
+
+# render breadcrumb bar
+crumb_parts = [{"label": root_key, "module": root_module}] + crumb_stack
+cols = st.columns([1] * len(crumb_parts) + [8])
+for i, crumb in enumerate(crumb_parts):
+ with cols[i]:
+ is_last = i == len(crumb_parts) - 1
+ if is_last:
+ st.markdown(f"**{crumb['label']}**")
+ else:
+ if st.button(crumb["label"], key=f"crumb_{i}"):
+ st.session_state.breadcrumb = crumb_stack[: i] # pop back to i
+ st.rerun()
+
+if crumb_stack:
+ st.caption(f"type: {current_module.type} · name: {current_module.name}")
+
+st.divider()
+
+# ── main content: tabs ────────────────────────────────────────────────────────
+
+tab_graph, tab_run, tab_config = st.tabs(["Graph", "Run output", "Config"])
+
+# ── TAB: Graph ────────────────────────────────────────────────────────────────
+
+with tab_graph:
+ kind = _kind(current_module)
+
+ if kind == "expression":
+ # show the intra-module expression DAG
+ st.caption("Expression node DAG — functions, column inputs and config injections")
+ eg = build_expression_graph(current_module)
+ dot = eg.to_graphviz()
+ st.graphviz_chart(dot.source, use_container_width=True)
+
+ # node table
+ rows = [{"node": n.label, "kind": n.kind,
+ **{f"cfg:{k}": v for k, v in n.fields.items()}}
+ for n in eg.nodes]
+ if rows:
+ st.dataframe(pl.DataFrame(rows, infer_schema_length=len(rows)),
+ use_container_width=True)
+
+ else:
+ # show the module-level structural graph
+ g = build_graph(current_module)
+ col_g, col_d = st.columns([2, 1])
+
+ with col_g:
+ dot = g.to_graphviz()
+ st.graphviz_chart(dot.source, use_container_width=True)
+
+ with col_d:
+ st.subheader("Modules")
+ for n in g.nodes:
+ if not n.drillable:
+ continue
+ c1, c2 = st.columns([4, 1])
+ with c1:
+ tag = f"`{n.type_id}`"
+ cfg = " · " + " ".join(f"{k}={v}" for k, v in n.fields.items()) if n.fields else ""
+ st.markdown(f"**{n.label}** {tag}{cfg}")
+ with c2:
+ if st.button("→", key=f"drill_{n.id}",
+ help=f"Drill into {n.label}"):
+ st.session_state.breadcrumb = crumb_stack + [
+ {"label": n.label, "module": n.module_ref}
+ ]
+ st.rerun()
+
+# ── TAB: Run output ───────────────────────────────────────────────────────────
+
+with tab_run:
+ if st.session_state.run_inputs is None:
+ st.info("Paste input data in the sidebar and click **▶ Run** to see intermediate outputs.")
+ else:
+ inputs = st.session_state.run_inputs
+
+ st.subheader("Input")
+ for frame_key, df in inputs.items():
+ st.caption(f"frame: `{frame_key}`")
+ st.dataframe(df, use_container_width=True)
+
+ st.subheader("Intermediates")
+ try:
+ intermediates = run_with_intermediates(current_module, inputs)
+ except Exception as e:
+ st.error(f"Execution error: {e}")
+ intermediates = []
+
+ for label, df in intermediates:
+ with st.expander(f"after **{label}**", expanded=True):
+ # highlight newly-added columns vs the input
+ input_cols = set(next(iter(inputs.values())).columns)
+ new_cols = [c for c in df.columns if c not in input_cols]
+ st.caption(f"new columns: {', '.join(new_cols) if new_cols else '(none)'}")
+ st.dataframe(df, use_container_width=True)
+
+# ── TAB: Config ───────────────────────────────────────────────────────────────
+
+with tab_config:
+ try:
+ st.json(current_module.model_dump())
+ except Exception:
+ st.json(versioned.config)
diff --git a/decider/cli/serve.py b/decider/cli/serve.py
new file mode 100644
index 0000000..78adbee
--- /dev/null
+++ b/decider/cli/serve.py
@@ -0,0 +1,92 @@
+import typing as t
+import click
+
+
+_ENGINES = ("starlette", "sanic")
+
+
+def _settings():
+ from decider.settings import settings
+ return settings.serve
+
+
+def _default_host() -> str:
+ return _settings().host
+
+
+def _default_port() -> int:
+ return _settings().port
+
+
+def _default_workers() -> int:
+ from decider.settings import _default_workers as _dw
+ return _settings().workers or _dw()
+
+
+@click.command()
+@click.option("--engine", "-e", default="starlette", type=click.Choice(_ENGINES), show_default=True)
+@click.option("--host", "-h", default=_default_host, show_default=True)
+@click.option("--port", "-p", default=_default_port, show_default=True, type=int)
+@click.option("--workers", "-w", default=_default_workers, show_default=True, type=int)
+@click.option("--reload", is_flag=True, help="Enable hot-reload (dev only).")
+def serve(engine: str, host: str, port: int, workers: int, reload: bool):
+ """Start a Decider inference server.
+
+ \b
+ host, port and workers read from DECIDER_SERVE__HOST / _PORT / _WORKERS
+ when not supplied on the command line. workers falls back to nproc*2+1.
+
+ Examples:
+ decider serve
+ decider serve --engine sanic --workers 4
+ decider serve --port 9000 --reload
+ """
+ if engine == "starlette":
+ _serve_starlette(host, port, workers, reload)
+ else:
+ _serve_sanic(host, port, workers, reload)
+
+
+def _serve_starlette(host: str, port: int, workers: int, reload: bool):
+ try:
+ import uvicorn
+ except ImportError:
+ raise click.ClickException(
+ "uvicorn is required for the starlette engine.\n"
+ "Install it with: pip install uvicorn"
+ )
+ click.echo(f"Starting Starlette server on {host}:{port} workers={workers}")
+ uvicorn.run(
+ "decider.serving.servers.starlette:get_app",
+ host=host,
+ port=port,
+ workers=workers,
+ reload=reload,
+ factory=True,
+ )
+
+
+def _serve_sanic(host: str, port: int, workers: int, reload: bool):
+    """Build the Sanic app through an AppLoader and serve it on host:port with `workers` processes."""
+    try:
+        from sanic import Sanic
+        from sanic.worker.loader import AppLoader
+    except ImportError:
+        raise click.ClickException(
+            "sanic is required for the sanic engine.\n"
+            "Install it with: pip install sanic"
+        )
+    click.echo(f"Starting Sanic server on {host}:{port} workers={workers}")
+    from decider.serving.servers.sanic import create_app
+    # AppLoader ensures the factory is called inside each worker process so
+    # Sanic's _app_registry is populated correctly in every worker — not just
+    # the main process.
+    loader = AppLoader(factory=create_app)
+    app = loader.load()
+    single = workers == 1
+    app.prepare(host=host, port=port, workers=workers, auto_reload=reload, single_process=single)
+    # app.run() accepts no app_loader argument; the loader has to be handed to Sanic.serve().
+    if single:
+        Sanic.serve_single(primary=app)
+    else:
+        Sanic.serve(primary=app, app_loader=loader)
diff --git a/decider/cli/template.py b/decider/cli/template.py
new file mode 100644
index 0000000..53c462d
--- /dev/null
+++ b/decider/cli/template.py
@@ -0,0 +1,86 @@
+import os
+import sys
+from pathlib import Path
+
+import click
+
+
+@click.group()
+def template():
+ """Scaffold modules and projects from templates."""
+
+
+@template.command("module")
+@click.argument("class_name")
+@click.option(
+ "--package", "-p",
+ default=None,
+ metavar="PKG",
+ help="Write into a uv src-layout package instead of a flat extension.",
+)
+@click.option(
+ "--ext-dir", "-e",
+ default=None,
+ metavar="DIR",
+ help="Extension directory (default: ./decider_extensions).",
+)
+def module_cmd(class_name: str, package: str, ext_dir: str):
+ """Scaffold a new module called CLASS_NAME.
+
+ \b
+ Flat (inline):
+ decider template module CreditScorer
+
+ As a shareable package:
+ decider template module CreditScorer --package mylib
+ """
+ from decider.templates.scaffold import write_inline_module, write_package_module
+
+ ext_path = Path(ext_dir).resolve() if ext_dir else Path.cwd() / "decider_extensions"
+
+ placeholder = (
+ "def my_function(column_name: pl.Expr) -> pl.Expr:\n"
+ " return column_name * 1.0\n"
+ )
+
+ if package:
+ mod_file, init_file = write_package_module(ext_path, class_name, package, placeholder)
+ click.echo(click.style("created", fg="green") + f" {mod_file}")
+ click.echo(click.style("updated", fg="cyan") + f" {init_file}")
+ pyproject = ext_path / package / "pyproject.toml"
+ if pyproject.exists():
+ click.echo(click.style("created", fg="green") + f" {pyproject}")
+ else:
+ init_file = write_inline_module(ext_path, class_name, placeholder)
+ click.echo(click.style("created", fg="green") + f" {init_file}")
+
+ click.echo(
+ "\nEdit the file above, then load it with:\n"
+ f" from decider.initialization import initialize_decider\n"
+ f" initialize_decider(extension_path={str(ext_path)!r})"
+ )
+
+
+@template.command("project")
+@click.argument("project_name")
+@click.option(
+ "--dir", "-d",
+ "projects_dir",
+ default=None,
+ metavar="DIR",
+ help="Parent directory for the new project (default: ./projects).",
+)
+def project_cmd(project_name: str, projects_dir: str):
+ """Scaffold a new project directory called PROJECT_NAME."""
+ from decider.templates.scaffold import write_project
+
+ parent = Path(projects_dir).resolve() if projects_dir else Path.cwd() / "projects"
+ try:
+ project_dir = write_project(parent, project_name)
+ except FileExistsError as e:
+ raise click.ClickException(str(e))
+
+ click.echo(click.style("created", fg="green") + f" {project_dir}/")
+ for f in sorted(project_dir.rglob("*")):
+ if f.is_file():
+ click.echo(f" {f.relative_to(project_dir.parent)}")
diff --git a/decider/cli/visualise.py b/decider/cli/visualise.py
new file mode 100644
index 0000000..ab6e242
--- /dev/null
+++ b/decider/cli/visualise.py
@@ -0,0 +1,92 @@
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+import click
+
+
+@click.command()
+@click.option(
+    "--project-dir", "-d",
+    default=None,
+    metavar="DIR",
+    help="Project directory (default: current directory).",
+)
+@click.option(
+    "--ext-dir", "-e",
+    default=None,
+    metavar="DIR",
+    help="Extension directory (default: PROJECT_DIR/decider_extensions).",
+)
+@click.option(
+    "--config-dir", "-c",
+    default=None,
+    metavar="DIR",
+    help="Config directory (default: PROJECT_DIR/configs).",
+)
+@click.option(
+    "--root-module", "-r",
+    default="main",
+    show_default=True,
+    metavar="KEY",
+    help="Root module key inside the versioned config.",
+)
+@click.option(
+    "--port", "-p",
+    default=8501,
+    show_default=True,
+    type=int,
+    help="Port for the Streamlit server.",
+)
+def visualise(
+    project_dir: str,
+    ext_dir: str,
+    config_dir: str,
+    root_module: str,
+    port: int,
+):
+    """Launch the interactive module graph browser (Streamlit).
+
+    \b
+    Examples:
+      decider visualise
+      decider visualise --project-dir projects/loan_scoring
+      decider visualise --root-module main --port 8502
+    """
+    try:
+        import streamlit  # noqa: F401
+    except ImportError:
+        raise click.ClickException(
+            "streamlit is required for this command.\n"
+            "Install it with: pip install streamlit"
+        )
+
+    app_file = Path(__file__).parent / "_visualise_app.py"
+    resolved_project = Path(project_dir).resolve() if project_dir else Path.cwd()
+
+    env = {
+        **os.environ,
+        "DECIDER_VISUALISE_PROJECT_DIR": str(resolved_project),
+        "DECIDER_VISUALISE_ROOT_MODULE": root_module,
+    }
+    if ext_dir:  # left unset otherwise, so the app falls back to its own default
+        env["DECIDER_VISUALISE_EXT_DIR"] = str(Path(ext_dir).resolve())
+    if config_dir:  # same: only exported when given explicitly
+        env["DECIDER_VISUALISE_CONFIG_DIR"] = str(Path(config_dir).resolve())
+
+    click.echo(f"Launching visualiser for {resolved_project.name} on http://localhost:{port}")
+
+    cmd = [
+        sys.executable, "-m", "streamlit", "run",
+        str(app_file),
+        "--server.port", str(port),
+        "--server.headless", "true",
+        "--browser.gatherUsageStats", "false",
+    ]
+    try:
+        result = subprocess.run(cmd, env=env)
+        if result.returncode not in (0, -2):  # -2 = SIGINT
+            sys.exit(result.returncode)
+    except KeyboardInterrupt:
+        pass
diff --git a/decider/config/__init__.py b/decider/config/__init__.py
new file mode 100644
index 0000000..f5cb1b2
--- /dev/null
+++ b/decider/config/__init__.py
@@ -0,0 +1,22 @@
+from .core import CoreConfigManager, VersionedConfig
+from .file import JsonFileConfigManager, YamlFileConfigManager, TomlFileConfigManager
+from ._ext import ConfigManager, register_config_manager
+
+
+def _register_builtins() -> None:
+ register_config_manager(JsonFileConfigManager)
+ register_config_manager(YamlFileConfigManager)
+ register_config_manager(TomlFileConfigManager)
+
+
+_register_builtins()
+
+__all__ = [
+ "CoreConfigManager",
+ "VersionedConfig",
+ "JsonFileConfigManager",
+ "YamlFileConfigManager",
+ "TomlFileConfigManager",
+ "ConfigManager",
+ "register_config_manager",
+]
diff --git a/decider/config/_ext.py b/decider/config/_ext.py
new file mode 100644
index 0000000..8bf9ff1
--- /dev/null
+++ b/decider/config/_ext.py
@@ -0,0 +1,17 @@
+"""
+Enables external packages to register custom ConfigManager implementations
+without creating hard dependencies. Follows the same _ext.py pattern as
+decider/modules/_ext.py — a pydantic discriminated union keyed on `type`.
+"""
+import typing as t
+from .core import CoreConfigManager
+from decider._ext import create_extendable_model
+
+
+ConfigManager, register_config_manager = create_extendable_model(
+ CoreConfigManager,
+ discriminator_field="type",
+ model_name="ConfigManagerModel",
+)
+
+__all__ = ["ConfigManager", "register_config_manager"]
diff --git a/decider/config/_ext.pyi b/decider/config/_ext.pyi
new file mode 100644
index 0000000..d880d46
--- /dev/null
+++ b/decider/config/_ext.pyi
@@ -0,0 +1,16 @@
+"""
+Stub file for static type checkers.
+ConfigManagerModel is generated dynamically; this gives pyright/mypy a
+concrete type.
+"""
+import typing as t
+from .core import CoreConfigManager
+from decider._ext import TExtendableModel
+
+ConfigManager = TExtendableModel[CoreConfigManager]
+
+
+def register_config_manager(provider_class: t.Type[CoreConfigManager]) -> None: ...
+
+
+__all__: list[str]
diff --git a/decider/config/base.py b/decider/config/base.py
new file mode 100644
index 0000000..cbf9c1a
--- /dev/null
+++ b/decider/config/base.py
@@ -0,0 +1,113 @@
+import typing as t
+from pydantic import BaseModel, ConfigDict, PrivateAttr, model_serializer, model_validator, SerializerFunctionWrapHandler, SerializationInfo
+from .versioned import Version, get_current_versioned_config
+
+if t.TYPE_CHECKING:
+ from decider.executor import Executor, FrameNode
+
+# Type of the pydantic model wrapped by BaseConfig.
+T = t.TypeVar("T", bound=BaseModel)
+
+# Key looked up in the serialization context; when its value is truthy,
+# BaseConfig.serialize writes the wrapped model's dump into the current versioned config.
+DUMP_TRIGGER_KEY = "dump_to_versioned_config"
+
+
+class BaseConfig(BaseModel, t.Generic[T]):
+ model_config = ConfigDict(extra="allow")
+ config_key: str
+ _constructed_model: T
+ _loaded_version: Version
+ _MODEL_CLASS: t.ClassVar[t.Type[T]]
+
+ @model_validator(mode="after")
+ def _construct_model(self) -> "t.Self":
+ """Construct the internal model from the config dict."""
+ constructed_model = self.model_extra.pop("_constructed_model", None)
+ if constructed_model is not None:
+ assert isinstance(constructed_model, self._MODEL_CLASS), (
+ f"_constructed_model must be an instance of {self._MODEL_CLASS.__name__}, got {type(constructed_model).__name__}"
+ )
+ loaded_version = self.model_extra.pop("_loaded_version", Version(-float('inf'),0,0))
+ assert isinstance(loaded_version, Version), (
+ f"_loaded_version must be an instance of Version or None, got {type(loaded_version).__name__}"
+ )
+
+ self._constructed_model = constructed_model
+ self._loaded_version = loaded_version
+ return self
+
+ self.reload(force=True)
+ return self
+
+ @classmethod
+ def from_model(cls, model: T, config_key:str) -> "BaseConfig[T]":
+ """Create a BaseConfig instance from a constructed model. This is the preferred way to create a new config, as it ensures the internal state is consistent."""
+ return cls(
+ config_key=config_key,
+ _constructed_model=model,
+ )
+
+ def reload(self, force=False) -> None:
+ """Reload the internal model from the current versioned config. Call this if you know the underlying config has changed and you want to refresh the model."""
+ versioned_config = get_current_versioned_config()
+ if versioned_config is None:
+ raise RuntimeError("No versioned config found in context. Make sure to use current_version_context() when accessing the config.")
+ if not force and versioned_config.version == self._loaded_version:
+ return # No need to reload if the version hasn't changed
+ config_data = versioned_config.config.get(self.config_key)
+ if config_data is None:
+ raise RuntimeError(f"Config key {self.config_key!r} not found in versioned config.")
+ self._constructed_model = self._MODEL_CLASS.model_validate(config_data)
+ self._loaded_version = versioned_config.version
+
+ @model_serializer(mode="wrap")
+ def serialize(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> dict:
+ result = handler(self)
+ ctx = info.context
+ if ctx.get(DUMP_TRIGGER_KEY, False):
+ versioned_config = get_current_versioned_config()
+ if versioned_config is None:
+ raise RuntimeError("No versioned config found in context. Make sure to use current_version_context() when accessing the config.")
+ # Update the versioned config with the current model's data
+ versioned_config.config[self.config_key] = self._constructed_model.model_dump(context=ctx)
+ return result
+
+
+# Type of the module wrapped by ConfigModule.
+TModule = t.TypeVar("TModule", bound="BaseModuleT")
+# Forward ref — resolved at runtime to avoid circular imports
+# NOTE(review): currently aliased to Any, so the bound above is not enforced.
+BaseModuleT = t.Any
+
+
+class ConfigModule(BaseConfig[TModule]):
+ """A BaseConfig whose _constructed_model is a BaseModule.
+
+ Delegates get_frame_nodes (and therefore __call__ / compile) to the
+ underlying module so the config layer is transparent at execution time.
+ The type discriminator required by TypeDiscriminatedBaseModule is
+ inherited from the wrapped module.
+ """
+
+ _MODEL_CLASS: t.ClassVar[t.Type] # set by subclasses or inferred
+
+ @classmethod
+ def for_module_class(cls, module_class: t.Type[TModule]) -> t.Type["ConfigModule[TModule]"]:
+ """Factory that produces a ConfigModule subclass bound to a specific module class."""
+ return t.cast(
+ t.Type["ConfigModule[TModule]"],
+ type(
+ f"{module_class.__name__}Config",
+ (ConfigModule,),
+ {"_MODEL_CLASS": module_class, "__module__": module_class.__module__},
+ ),
+ )
+
+ # ------------------------------------------------------------------
+ # Delegate module interface to _constructed_model
+ # ------------------------------------------------------------------
+
+ def get_frame_nodes(self, executor: "Executor") -> t.List["FrameNode"]:
+ return self._constructed_model.get_frame_nodes(executor)
+
+ def compile(self, executor: "Executor"):
+ return self._constructed_model.compile(executor)
+
+ def __call__(self, inputs, executor=None):
+ return self._constructed_model(inputs, executor=executor)
\ No newline at end of file
diff --git a/decider/config/core.py b/decider/config/core.py
new file mode 100644
index 0000000..dcbdf50
--- /dev/null
+++ b/decider/config/core.py
@@ -0,0 +1,156 @@
+import asyncio
+import typing as t
+from abc import abstractmethod
+
+from pydantic import PrivateAttr
+from decider._ext import TypeDiscriminatedBaseModule
+from .versioned import VersionedConfig, Version, VersionPart, with_versioned_config
+
+
+class CoreConfigManager(TypeDiscriminatedBaseModule):
+ """Base pydantic model for versioned config managers.
+
+ Runtime state (_current, _lock) is stored in PrivateAttr so pydantic
+ doesn't include it in serialisation. Subclasses implement the four
+ storage primitives; the public API is fully implemented here.
+
+ _dirty state lives on VersionedConfig.is_dirty; _lock guards all reads
+ and writes of _current.
+ """
+
+ _current: t.Optional[VersionedConfig] = PrivateAttr(default=None)
+ _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
+
+ # ------------------------------------------------------------------
+ # Abstract storage primitives
+ # ------------------------------------------------------------------
+
+ @abstractmethod
+ async def _load_version(self, version: str) -> VersionedConfig: ...
+
+ @abstractmethod
+ async def _write_version(self, versioned_config: VersionedConfig) -> None: ...
+
+ @abstractmethod
+ async def _version_exists(self, version: str) -> bool: ...
+
+ @abstractmethod
+ async def _latest_version(self) -> t.Optional[Version]: ...
+
+ # ------------------------------------------------------------------
+ # Public API
+ # ------------------------------------------------------------------
+
+ def get(self) -> VersionedConfig:
+ if self._current is None:
+ raise RuntimeError("No versioned config loaded. Call pull_version() or create_version() first.")
+ return self._current
+
+ async def get_latest(self) -> VersionedConfig:
+ latest_version = await self._latest_version()
+ if latest_version is None:
+ raise RuntimeError("No versions available in the store.")
+ if self._current is not None and self._current.version >= latest_version:
+ return self._current
+ async with self._lock:
+ if self._current is not None and self._current.version >= latest_version:
+ return self._current
+ _current = await self._load_version(latest_version)
+ self._current = _current
+ return _current
+
+
+ async def latest_version_context(self) -> t.ContextManager[VersionedConfig]:
+ """Context manager to get the latest versioned config. Use this in any code that needs access to the config."""
+ config = await self.get_latest()
+ return with_versioned_config(config)
+
+ def current_version_context(self) -> t.ContextManager[VersionedConfig]:
+ """Context manager to get the current versioned config. Use this in any code that needs access to the config."""
+ config = self.get()
+ return with_versioned_config(config)
+
+ async def create_version(self, bump: VersionPart = VersionPart.MINOR, force: bool = False) -> VersionedConfig:
+ async with self._lock:
+ if self._current is None:
+ self._current = VersionedConfig(version=Version(0,0,0), config={})
+ return self._current
+ if not force:
+ latest_version = await self._latest_version()
+ if latest_version is not None and latest_version != self._current.version:
+ raise ValueError(
+ f"Current version {self._current.version!r} should exactly match latest version {latest_version!r} before creating a new version."
+ + (
+ "You are behind the latest version please run check_for_updates to fetch the latest version, or pass force=True to ignore this check."
+ if latest_version > self._current.version else
+ "You are ahead of the latest version, which likely means you have un-pushed changes. Please push them before creating a new version, or pass force=True to ignore this check."
+ )
+ )
+ self._current = VersionedConfig(
+ version=self._current.version.bump(bump),
+ config=self._current.config.copy()
+ )
+ return self._current
+
+ async def save_version(self, overwrite: bool = False) -> None:
+ async with self._lock:
+ if self._current is None:
+ raise RuntimeError("No current version to save. Call create_version() first.")
+
+ exists = await self._version_exists(self._current.version)
+ if exists and not overwrite:
+ raise FileExistsError(
+ f"Version {self._current.version!r} already exists. "
+ "Pass overwrite=True to overwrite."
+ )
+
+ await self._write_version(self._current)
+
+ async def check_for_updates(self) -> t.Tuple[t.Optional[Version], bool]:
+ latest = await self._latest_version()
+ if latest is None:
+ return None, False
+ async with self._lock:
+ current = self._current
+ if current is None:
+ return latest, True
+ has_update = latest > current.version
+ return latest, has_update
+
+ async def pull_version(
+ self,
+ version: t.Optional[str] = None,
+ force: bool = False,
+ ) -> VersionedConfig:
+ async with self._lock:
+
+ target = version or await self._latest_version()
+ if target is None:
+ raise RuntimeError("No versions available in the store.")
+ if not force and self._current is not None:
+ if self._current.version >= target.version:
+ raise ValueError(
+ f"Current version {self._current.version!r} is newer or equal to target version {target!r}. "
+ "Pass force=True to ignore this check."
+ )
+
+ self._current = await self._load_version(target)
+ return self._current
+
+ async def subscribe_version_updates(self, force: bool = True) -> None:
+ """Poll for new versions and auto-pull. Safe to run as a background task."""
+ from decider.settings import settings, SETTINGS_DEFAULT_CONFIG_POLL_DURATION_S
+
+ while True:
+ try:
+ poll_seconds: int = getattr(
+ settings, "config_poll_duration_s", SETTINGS_DEFAULT_CONFIG_POLL_DURATION_S
+ )
+ await asyncio.sleep(poll_seconds)
+ _, has_update = await self.check_for_updates()
+ if has_update:
+ await self.pull_version(force=force)
+ except asyncio.CancelledError:
+ raise
+ except Exception:
+ pass
diff --git a/decider/config/file.py b/decider/config/file.py
new file mode 100644
index 0000000..1a2197b
--- /dev/null
+++ b/decider/config/file.py
@@ -0,0 +1,162 @@
+import os
+import typing as t
+from abc import abstractmethod
+
+from pydantic import Field
+from .core import CoreConfigManager, VersionedConfig
+from .versioned import Version
+
+
+class BaseFileConfigManager(CoreConfigManager):
+ """Shared logic for all file-backed config managers.
+
+ Files are stored as:
+ {basepath}/{version}/{key}.{ext}
+
+ All keys within a version directory are merged into VersionedConfig.config.
+
+ Example:
+ configs/1.0.0/tree.json
+ configs/1.0.0/scorecard.json
+ → VersionedConfig(version="1.0.0", config={"tree": {...}, "scorecard": {...}})
+ """
+
+ basepath: str = Field(default="configs")
+
+ # Subclasses set this to the file extension (without leading dot).
+ _file_ext: t.ClassVar[str]
+
+ def _version_dir(self, version) -> str:
+ return os.path.join(self.basepath, str(version))
+
+ def _file_path(self, version, key: str) -> str:
+ return os.path.join(self._version_dir(version), f"{key}.{self._file_ext}")
+
+ def _list_versions(self) -> t.List[str]:
+ if not os.path.isdir(self.basepath):
+ return []
+ valid = []
+ for entry in os.listdir(self.basepath):
+ if not os.path.isdir(os.path.join(self.basepath, entry)):
+ continue
+ try:
+ parts = entry.split(".")
+ if len(parts) == 3:
+ int(parts[0]); int(parts[1]); int(parts[2])
+ valid.append(entry)
+ except (ValueError, AttributeError):
+ pass
+ return sorted(valid, key=lambda v: tuple(int(x) for x in v.split(".")))
+
+ # ------------------------------------------------------------------
+ # Format-specific primitives (override in subclasses)
+ # ------------------------------------------------------------------
+
+ @abstractmethod
+ def _deserialize(self, _raw: bytes) -> t.Any: ...
+
+ @abstractmethod
+ def _serialize(self, _data: t.Any) -> bytes: ...
+
+ # ------------------------------------------------------------------
+ # ConfigManager storage primitives
+ # ------------------------------------------------------------------
+
+ async def _load_version(self, version: str) -> VersionedConfig:
+ version_dir = self._version_dir(version)
+ if not os.path.isdir(version_dir):
+ raise FileNotFoundError(f"Version directory not found: {version_dir!r}")
+
+ ext = f".{self._file_ext}"
+ config: t.Dict[str, t.Any] = {}
+ for filename in sorted(os.listdir(version_dir)):
+ if filename.endswith(ext):
+ key = filename[: -len(ext)]
+ with open(os.path.join(version_dir, filename), "rb") as f:
+ config[key] = self._deserialize(f.read())
+
+ return VersionedConfig(version=version, config=config)
+
+ async def _write_version(self, versioned_config: VersionedConfig) -> None:
+ for key, value in versioned_config.config.items():
+ path = self._file_path(versioned_config.version, key)
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ with open(path, "wb") as f:
+ f.write(self._serialize(value))
+
+ async def _version_exists(self, version: str) -> bool:
+ return os.path.isdir(self._version_dir(version))
+
+ async def _latest_version(self) -> t.Optional[Version]:
+ versions = self._list_versions()
+ if not versions:
+ return None
+ v = versions[-1]
+ parts = v.split(".")
+ return Version(int(parts[0]), int(parts[1]), int(parts[2]))
+
+
+class JsonFileConfigManager(BaseFileConfigManager):
+ type: t.Literal["file:json"]
+ _file_ext: t.ClassVar[str] = "json"
+
+ def _deserialize(self, raw: bytes) -> t.Any:
+ import json
+ return json.loads(raw)
+
+ def _serialize(self, data: t.Any) -> bytes:
+ import json
+ return json.dumps(data, indent=2).encode()
+
+
+class YamlFileConfigManager(BaseFileConfigManager):
+ type: t.Literal["file:yaml"]
+ _file_ext: t.ClassVar[str] = "yaml"
+
+ def _deserialize(self, raw: bytes) -> t.Any:
+ try:
+ import yaml
+ except ImportError:
+ raise ImportError(
+ "PyYAML is required to use yaml format. "
+ "Install it with: pip install pyyaml"
+ )
+ return yaml.safe_load(raw)
+
+ def _serialize(self, data: t.Any) -> bytes:
+ try:
+ import yaml
+ except ImportError:
+ raise ImportError(
+ "PyYAML is required to use yaml format. "
+ "Install it with: pip install pyyaml"
+ )
+ return yaml.dump(data, default_flow_style=False, allow_unicode=True).encode()
+
+
+class TomlFileConfigManager(BaseFileConfigManager):
+ type: t.Literal["file:toml"]
+ _file_ext: t.ClassVar[str] = "toml"
+
+ def _deserialize(self, raw: bytes) -> t.Any:
+ try:
+ import tomllib # Python 3.11+
+ except ImportError:
+ try:
+ import tomli as tomllib # type: ignore
+ except ImportError:
+ raise ImportError(
+ "A TOML library is required to use toml format. "
+ "Install tomli with: pip install tomli (Python < 3.11)"
+ )
+ return tomllib.loads(raw.decode())
+
+ def _serialize(self, data: t.Any) -> bytes:
+ try:
+ import tomli_w # type: ignore
+ except ImportError:
+ raise ImportError(
+ "tomli_w is required to write toml files. "
+ "Install it with: pip install tomli-w"
+ )
+ return tomli_w.dumps(data).encode()
diff --git a/decider/config/versioned.py b/decider/config/versioned.py
new file mode 100644
index 0000000..7673135
--- /dev/null
+++ b/decider/config/versioned.py
@@ -0,0 +1,57 @@
+import enum
+import typing as t
+from contextlib import contextmanager
+from contextvars import ContextVar
+
+
+class VersionPart(enum.Enum):
+    """Selects which component of a Version is incremented by Version.bump."""
+    MAJOR = "major"
+    MINOR = "minor"
+    PATCH = "patch"
+
+class Version(t.NamedTuple):
+ major: int
+ minor: int
+ patch: int
+
+ @classmethod
+ def parse(cls, version_str: str) -> "Version":
+ parts = version_str.split(".")
+ if len(parts) != 3:
+ raise ValueError(f"Invalid version string: {version_str!r}")
+ return cls(major=int(parts[0]), minor=int(parts[1]), patch=int(parts[2]))
+
+ def bump(self, part: VersionPart) -> "Version":
+ if part == VersionPart.MAJOR:
+ return Version(self.major + 1, 0, 0)
+ elif part == VersionPart.MINOR:
+ return Version(self.major, self.minor + 1, 0)
+ elif part == VersionPart.PATCH:
+ return Version(self.major, self.minor, self.patch + 1)
+ else:
+ raise ValueError(f"Invalid version part: {part!r}")
+
+
+ def __str__(self) -> str:
+ return f"{self.major}.{self.minor}.{self.patch}"
+
+
+class VersionedConfig(t.NamedTuple):
+    """A config mapping paired with the version it belongs to."""
+    # Version this config snapshot is tagged with.
+    version: Version
+    # Config values keyed by config key; the dict itself is mutable.
+    config: t.Dict[str, t.Any]
+
+
+# Context-local holder for the active VersionedConfig; None when no config is active.
+_current_versioned_config: ContextVar[t.Optional[VersionedConfig]] = ContextVar("_current_versioned_config", default=None)
+
+@contextmanager
+def with_versioned_config(versioned_config: VersionedConfig) -> t.Iterator[VersionedConfig]:
+ """Context manager to set the current versioned config for the duration of a block."""
+ token = _current_versioned_config.set(versioned_config)
+ try:
+ yield versioned_config
+ finally:
+ _current_versioned_config.reset(token)
+
+def get_current_versioned_config() -> t.Optional[VersionedConfig]:
+ """Get the current versioned config from the context variable."""
+ return _current_versioned_config.get()
diff --git a/decider/exceptions.py b/decider/exceptions.py
new file mode 100644
index 0000000..a76e718
--- /dev/null
+++ b/decider/exceptions.py
@@ -0,0 +1,113 @@
+from dataclasses import dataclass
+from contextlib import contextmanager
+from warnings import warn
+
+
+@dataclass
+class DeciderErrorResponse:
+    """Payload returned by DeciderError.get_response_body."""
+    # Short description of the error.
+    message: str
+    # Optional additional detail text.
+    details: str | None = None
+    # Do we include a stack trace here? Maybe only in debug mode?
+
+class DeciderError(Exception):
+ """Base class for all Decider-related exceptions."""
+ _STATUS_CODE = 500 # Default to Internal Server Error, can be overridden by subclasses
+ _MESSAGE = "An error occurred in the Decider system."
+
+ def __init__(self, message=None, *args):
+ self.message = message or self._MESSAGE
+ super().__init__(self.message, *args)
+
+ def get_status_code(self) -> int:
+ """Return the HTTP status code associated with this error."""
+ return self._STATUS_CODE
+
+ def get_response_body(self) -> DeciderErrorResponse:
+ """Return the response body to be sent to the client."""
+ return DeciderErrorResponse(
+ message=self.message,
+ details=str(self) # Include the exception message as details
+ )
+
+
+class DeciderMissingDependencyError(DeciderError, ModuleNotFoundError):
+ """Raised when there is an error importing a module or source."""
+ _STATUS_CODE = 500
+ _MESSAGE = "Failed to import a required module or source. Please ensure the optional package is installed and available."
+ # TODO make this more dynamic by reading from the pyproject.toml
+ _KNOWN_MODULES = {
+ "pandera": "pandera>=0.29.0<1.0.0",
+ }
+
+ def __init__(self, package_name: str = None, optional_source: str = None, *args):
+ self.optional_source = optional_source
+
+ package_name = package_name or 'a package'
+ if self.optional_source:
+ self.message = (
+ f"Failed to import {package_name} provided in {self.optional_source}. "
+ f"Please ensure you install decider with pip install decider[{self.optional_source}] "
+ f"or install {package_name} directly with pip install '{self._KNOWN_MODULES.get(package_name, package_name)}'."
+ )
+ else:
+ self.message = (
+ f"Failed to import {package_name}. "
+ f"Please ensure you install {package_name} directly with pip install '{self._KNOWN_MODULES.get(package_name, package_name)}'."
+ )
+ super().__init__(self.message, *args)
+
+
+class BaseConfigurationError(DeciderError):
+    """Configuration error; reported with status 500."""
+    _STATUS_CODE = 500
+    _MESSAGE = "A configuration error occurred."
+
+
+class ModuleLoadError(DeciderError):
+ _STATUS_CODE = 500
+ _MESSAGE = "Failed to load the decider module."
+
+ @classmethod
+ def from_value_error(cls, e: ValueError) -> "ModuleLoadError":
+ return cls(str(e))
+
+
+class UnsupportedContentTypeError(DeciderError):
+    """Unsupported request content type; reported with status 415."""
+    _STATUS_CODE = 415
+    _MESSAGE = "Unsupported content type."
+
+
+class InputParsingError(DeciderError):
+    """Request input could not be parsed; reported with status 400."""
+    _STATUS_CODE = 400
+    _MESSAGE = "Failed to parse request input."
+
+
+class UnsupportedAcceptError(DeciderError):
+    """Unsupported Accept media type; reported with status 406."""
+    _STATUS_CODE = 406
+    _MESSAGE = "Unsupported Accept media type."
+
+
+class OutputFormattingError(DeciderError):
+    """Output could not be formatted; reported with status 500."""
+    _STATUS_CODE = 500
+    _MESSAGE = "Failed to format the output."
+
+
+class DeciderRuntimeError(DeciderError):
+    """Generic runtime error; reported with status 500."""
+    _STATUS_CODE = 500
+    _MESSAGE = "A runtime error occurred."
+
+
+@contextmanager
+def wrap_import_errors(optional_source: str = None, raise_error=True):
+ try:
+ yield
+ # We only care about module not found error not import errors
+ # as if we importing pandora it will raise module not found if the dependency is missing
+ except ModuleNotFoundError as e:
+ error = DeciderMissingDependencyError(
+ optional_source=optional_source,
+ package_name=e.name,
+ )
+ if raise_error:
+ raise error from e
+ else:
+ warn(error.message + " Some functionality may not work properly.", ImportWarning)
\ No newline at end of file
diff --git a/decider/executor.py b/decider/executor.py
new file mode 100644
index 0000000..deae723
--- /dev/null
+++ b/decider/executor.py
@@ -0,0 +1,109 @@
+import typing as t
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from dataclasses import dataclass
+
+import polars as pl
+
+from decider._ext import TypeDiscriminatedBaseModule
+from decider.types import TInputType, TOutputType
+from decider.graphutil import topological_sort
+
+if t.TYPE_CHECKING:
+ from decider.modules.expression import Node, CompiledExpressions
+ from decider.modules.core import BaseModule
+
+
+# ── FrameNode ─────────────────────────────────────────────────────────────────
+
+@dataclass
+class FrameNode:
+    """A named frame-level computation with declared dependencies.
+
+    Stores a bound execute method rather than the module itself, so
+    CompiledFrameGraph can call it directly without re-entering executor.execute.
+    """
+    # Key under which this node's result is stored in the frames mapping.
+    name: str
+    # Invoked as callable(frames, executor) by CompiledFrameGraph.execute.
+    callable: t.Callable[[TInputType, "Executor"], TOutputType]
+    # Names of the nodes this node depends on.
+    depends_on: t.List[str]
+
+    def get_dependencies(self) -> t.List[str]:
+        """Return ``depends_on``; this is the accessor topological_sort calls."""
+        return self.depends_on
+
+
+@dataclass
+class CompiledFrameGraph:
+ """Sorted list of FrameNodes ready to execute in dependency order."""
+ nodes: t.List[FrameNode]
+
+ def execute(
+ self,
+ inputs: TInputType,
+ executor: "Executor",
+ ) -> TOutputType:
+ frames = {**inputs}
+ for node in self.nodes:
+ result = node.callable(frames, executor)
+ frames[node.name] = result
+ return frames
+
+
+# ── Executor ABC ──────────────────────────────────────────────────────────────
+
+class Executor(TypeDiscriminatedBaseModule, ABC):
+ debug: bool = False
+ collect: bool = True
+
+ @abstractmethod
+ def compile_expression_graph(self, nodes: t.List["Node"]) -> "CompiledExpressions":
+ """Topologically sort expression nodes into a CompiledExpressions artifact."""
+ ...
+
+ def compile_frame_graph(self, frame_nodes: t.List[FrameNode]) -> CompiledFrameGraph:
+ """Topologically sort frame nodes into an executable graph."""
+ return CompiledFrameGraph(nodes=topological_sort(frame_nodes))
+
+ def execute(self, module: "BaseModule", inputs: TInputType) -> TOutputType:
+ graph = module.compile(self)
+ result = graph.execute(inputs, self)
+ if self.debug:
+ if self.collect:
+ result = {k: v.collect() if isinstance(v, pl.LazyFrame) else v for k, v in result.items()}
+ return result
+ res = result[graph.nodes[-1].name] # Return only the final output frame
+ if self.collect and isinstance(res, pl.LazyFrame):
+ res = res.collect()
+ return res
+
+
+# ── SimpleExecutor ────────────────────────────────────────────────────────────
+
+class SimpleExecutor(Executor):
+ type: t.Literal["simple"]
+
+ def compile_expression_graph(self, nodes: t.List["Node"]) -> "CompiledExpressions":
+ from decider.modules.expression import CompiledExpressions, ExternalInputNode
+
+ sorted_nodes = topological_sort(nodes)
+
+ expressions: OrderedDict[str, pl.Expr] = OrderedDict()
+ for node in sorted_nodes:
+ try:
+ expr = node.callable(**node.get_input_expressions())
+ except Exception as e:
+ missing = [
+ f"'{k}' (column '{v.input_name}')"
+ for k, v in node.input_map.items()
+ if isinstance(v, ExternalInputNode)
+ ]
+ hint = (
+ f"\nHint: node '{node.name}' expected columns: "
+ + ", ".join(missing)
+ ) if missing else ""
+ raise ValueError(
+ f"Error building expression for '{node.name}': {e}{hint}"
+ ) from e
+ expressions[node.name] = expr
+
+ return CompiledExpressions(expressions=expressions)
+
diff --git a/decider/graphutil.py b/decider/graphutil.py
new file mode 100644
index 0000000..3cf24ef
--- /dev/null
+++ b/decider/graphutil.py
@@ -0,0 +1,40 @@
+import typing as t
+from collections import deque
+
+
+class HasDependencies(t.Protocol):
+    """Structural type accepted by topological_sort."""
+    # Identifier used to match this node against other nodes' dependencies.
+    name: str
+
+    # Returns the names of the nodes this node depends on.
+    def get_dependencies(self) -> t.List[str]: ...
+
+
+# Node type handled by topological_sort; the output list has the same element type as the input.
+T = t.TypeVar("T", bound=HasDependencies)
+
+
+def topological_sort(nodes: t.List[T]) -> t.List[T]:
+ """Kahn's algorithm. Works on any node type that implements HasDependencies."""
+ node_map = {n.name: n for n in nodes}
+ in_degree = {n.name: 0 for n in nodes}
+ adjacency: t.Dict[str, t.List[str]] = {n.name: [] for n in nodes}
+
+ for node in nodes:
+ for dep in node.get_dependencies():
+ if dep in node_map:
+ adjacency[dep].append(node.name)
+ in_degree[node.name] += 1
+
+ queue = deque(name for name, deg in in_degree.items() if deg == 0)
+ result = []
+ while queue:
+ current = queue.popleft()
+ result.append(current)
+ for neighbour in adjacency[current]:
+ in_degree[neighbour] -= 1
+ if in_degree[neighbour] == 0:
+ queue.append(neighbour)
+
+ if len(result) != len(nodes):
+ cycle = [n for n, d in in_degree.items() if d > 0]
+ raise ValueError(f"Circular dependency detected: {cycle}")
+
+ return [node_map[name] for name in result]
diff --git a/decider/initialization.py b/decider/initialization.py
new file mode 100644
index 0000000..6a7461b
--- /dev/null
+++ b/decider/initialization.py
@@ -0,0 +1,52 @@
+"""
+This module handles loading any external extensions to decider
+"""
+import glob
+import importlib
+import logging
+import sys
+from pathlib import Path
+
+
+
+# Module-level logger, named after this module.
+logger = logging.getLogger(__name__)
+
+
+def initialize_decider(extension_path: str = None) -> None:
+ """Initialise Decider by loading extensions from the configured extension path
+ and importing any explicitly listed extension modules.
+
+ Args:
+ extension_path: Override the extension path from settings. Useful when
+ calling from a script before the settings singleton is configured.
+ """
+ from .settings import settings
+ ext_settings = settings.ext
+ ext_path = Path(extension_path).resolve() if extension_path else Path(ext_settings.extension_path).resolve()
+
+ # Add the extension path to sys.path so packages inside it are importable
+ ext_path_str = str(ext_path)
+ if ext_path_str not in sys.path:
+ sys.path.insert(0, ext_path_str)
+ logger.debug("Added extension path to sys.path: %s", ext_path_str)
+
+ # Discover flat packages: ext_path//__init__.py
+ for init_file in glob.glob(str(ext_path / "*" / "__init__.py")):
+ module_name = Path(init_file).parent.name
+ logger.debug("Initialising flat extension: %s", module_name)
+ importlib.import_module(module_name)
+
+ # Discover src-layout packages: ext_path//src//__init__.py
+ for init_file in glob.glob(str(ext_path / "*" / "src" / "*" / "__init__.py")):
+ src_dir = str(Path(init_file).parent.parent)
+ module_name = Path(init_file).parent.name
+ if src_dir not in sys.path:
+ sys.path.insert(0, src_dir)
+ logger.debug("Added src-layout path to sys.path: %s", src_dir)
+ logger.debug("Initialising src-layout extension: %s", module_name)
+ importlib.import_module(module_name)
+
+ # Import any explicitly listed extension modules
+ for module_name in ext_settings.extension_imports:
+ logger.debug("Initialising extension import: %s", module_name)
+ importlib.import_module(module_name)
diff --git a/decider/magics/__init__.py b/decider/magics/__init__.py
new file mode 100644
index 0000000..e53d4f2
--- /dev/null
+++ b/decider/magics/__init__.py
@@ -0,0 +1,162 @@
+"""
+Jupyter magic for interactive module development.
+
+Load in a notebook with:
+ %load_ext decider.magics
+
+Inline module (writes decider_extensions/credit_scorer/__init__.py):
+ %%module CreditScorer
+
+ def score(amount: pl.Expr) -> pl.Expr:
+ return amount * 2
+
+Shared package module (writes decider_extensions/mylib/src/mylib/credit_scorer.py):
+ %%module CreditScorer --package mylib
+
+ def score(amount: pl.Expr) -> pl.Expr:
+ return amount * 2
+
+Override extension directory for the session:
+ DECIDER_EXTENSIONS_DIR = "/path/to/decider_extensions"
+"""
+
+import argparse
+import importlib
+import os
+import sys
+import typing as t
+from pathlib import Path
+
+from decider.templates.scaffold import (
+ to_snake,
+ write_inline_module,
+ write_package_module,
+)
+
+
+# ── extension directory resolution ───────────────────────────────────────────
+
+def _find_ext_dir(ip) -> Path:
+ """
+ Priority:
+ 1. ip.user_ns['DECIDER_EXTENSIONS_DIR']
+ 2. decider_extensions/ next to the notebook
+ 3. decider_extensions/ in cwd
+ """
+ if ip is not None and "DECIDER_EXTENSIONS_DIR" in ip.user_ns:
+ return Path(ip.user_ns["DECIDER_EXTENSIONS_DIR"]).resolve()
+ nb_dir = Path(getattr(ip, "starting_dir", None) or os.getcwd()) if ip else Path(os.getcwd())
+ candidate = nb_dir / "decider_extensions"
+ if candidate.exists():
+ return candidate.resolve()
+ return (Path(os.getcwd()) / "decider_extensions").resolve()
+
+
+# ── import helpers ────────────────────────────────────────────────────────────
+
+def _ensure_on_path(directory: Path) -> None:
+ s = str(directory)
+ if s not in sys.path:
+ sys.path.insert(0, s)
+
+
+def _import_or_reload(module_name: str) -> t.Any:
+ if module_name in sys.modules:
+ return importlib.reload(sys.modules[module_name])
+ return importlib.import_module(module_name)
+
+
+def _load_class_from_inline(ext_dir: Path, class_name: str) -> type:
+ _ensure_on_path(ext_dir)
+ mod = _import_or_reload(to_snake(class_name))
+ cls = getattr(mod, class_name, None)
+ if cls is None:
+ raise ImportError(f"Could not find {class_name!r} in {mod.__file__}")
+ return cls
+
+
+def _load_class_from_package(ext_dir: Path, class_name: str, package_name: str) -> type:
+ # The importable root for a src-layout package is ext_dir//src/
+ src_root = ext_dir / package_name / "src"
+ _ensure_on_path(src_root)
+ snake = to_snake(class_name)
+ # import the submodule directly so we get the freshest version
+ full_name = f"{package_name}.{snake}"
+ if full_name in sys.modules:
+ mod = importlib.reload(sys.modules[full_name])
+ else:
+ mod = importlib.import_module(full_name)
+ cls = getattr(mod, class_name, None)
+ if cls is None:
+ raise ImportError(f"Could not find {class_name!r} in {mod.__file__}")
+ return cls
+
+
+def _register(cls: type) -> None:
+    """Register ``cls`` through ``decider.modules.register_graph_module``."""
+    # Imported inside the function, so decider.modules is only loaded when a class is registered.
+    from decider.modules import register_graph_module
+    register_graph_module(cls)
+
+
+# ── argument parsing ──────────────────────────────────────────────────────────
+
+def _parse_line(line: str) -> tuple[str, t.Optional[str]]:
+ """Return (class_name, package_name_or_None)."""
+ parts = line.split()
+ if not parts:
+ raise ValueError("Usage: %%module ClassName [--package pkg_name]")
+ class_name = parts[0]
+ package_name = None
+ rest = parts[1:]
+ i = 0
+ while i < len(rest):
+ if rest[i] in ("--package", "-p") and i + 1 < len(rest):
+ package_name = rest[i + 1]
+ i += 2
+ else:
+ i += 1
+ return class_name, package_name
+
+
+# ── magic implementation ──────────────────────────────────────────────────────
+
+def module_magic(line: str, cell: str = None):
+ """
+ %%module ClassName [--package pkg]
+
+ Writes/updates the extension file, reloads it, registers the class with
+ GraphModule, and injects it into the notebook namespace.
+ """
+ try:
+ from IPython import get_ipython
+ ip = get_ipython()
+ except ImportError:
+ ip = None
+
+ class_name, package_name = _parse_line(line)
+
+ if cell is None:
+ print(f"Usage:\n %%%%module {class_name} [--package pkg_name]\n ")
+ return
+
+ ext_dir = _find_ext_dir(ip)
+
+ if package_name:
+ module_file, init_file = write_package_module(ext_dir, class_name, package_name, cell)
+ print(f"[module] Written: {module_file}")
+ print(f"[module] Updated: {init_file}")
+ cls = _load_class_from_package(ext_dir, class_name, package_name)
+ else:
+ init_file = write_inline_module(ext_dir, class_name, cell)
+ print(f"[module] Written: {init_file}")
+ cls = _load_class_from_inline(ext_dir, class_name)
+
+ _register(cls)
+ print(f"[module] {class_name} registered type={cls._CLASS_TYPE_IDENTIFIER!r}")
+
+ if ip is not None:
+ ip.user_ns[class_name] = cls
+ print(f"[module] {class_name} injected into namespace")
+
+
+def load_ipython_extension(ip):
+    """Register ``module_magic`` on ``ip`` as a line/cell magic named ``module``."""
+    ip.register_magic_function(module_magic, magic_kind="line_cell", magic_name="module")
diff --git a/decider/modules/__init__.py b/decider/modules/__init__.py
new file mode 100644
index 0000000..d47e0c7
--- /dev/null
+++ b/decider/modules/__init__.py
@@ -0,0 +1,22 @@
+from ._ext import register_graph_module, GraphModule
+from .expression import Node
+
+def _load_core_modules():
+    """Register the built-in module classes with the ``GraphModule`` union.
+
+    The imports are performed inside the function rather than at module top
+    level; the function is invoked once, right after its definition.
+    """
+    from .primitives.sequential import SequentialModule
+    # NOTE(review): FrameModule is imported here but is not in the registration
+    # loop below — confirm whether leaving it unregistered is intentional.
+    from .primitives.join import JoinModule, FrameRef, FrameModule
+    from .credit import register_credit_modules
+    from .rules import register_rule_modules
+
+    for cls in (SequentialModule, JoinModule, FrameRef):
+        register_graph_module(cls)
+
+    # Domain packages register their own module classes.
+    register_credit_modules()
+    register_rule_modules()
+
+_load_core_modules()
+
+__all__: list[str] = [
+ "register_graph_module",
+ "GraphModule",
+ "Node",
+]
\ No newline at end of file
diff --git a/decider/modules/_ext.py b/decider/modules/_ext.py
new file mode 100644
index 0000000..da80f5e
--- /dev/null
+++ b/decider/modules/_ext.py
@@ -0,0 +1,15 @@
+"""
+This module enables the set of graph modules to be extended by external packages without creating hard dependencies on those packages. It does this by maintaining a global union type of all registered module classes, which can be extended by calling the `register_graph_module` function with a new module class. The `GraphModule` model is then rebuilt to include the new module class in its union.
+"""
+import typing as t
+from .core import BaseModule
+from decider._ext import create_extendable_model
+
+
+# Build the extendable `GraphModule` model over `BaseModule`, discriminated on the
+# `type` field, together with the function used to register further classes.
+# NOTE(review): the exact rebuild semantics live in decider._ext.create_extendable_model.
+GraphModule, register_graph_module = create_extendable_model(
+    BaseModule,
+    discriminator_field="type",
+    model_name="GraphModule"
+)
+
+__all__ = ['GraphModule', 'register_graph_module', ]
\ No newline at end of file
diff --git a/decider/modules/_ext.pyi b/decider/modules/_ext.pyi
new file mode 100644
index 0000000..0f05533
--- /dev/null
+++ b/decider/modules/_ext.pyi
@@ -0,0 +1,16 @@
+"""
+This file is needed because static type checkers don't like the
+fact that GraphModule is generated from a function. So we define a
+stub file that declares the types of GraphModule and the register_graph_module function,
+which static type checkers can use to understand the types of these objects.
+"""
+import typing as t
+from .core import BaseModule
+from decider._ext import TExtendableModel
+
+GraphModule = TExtendableModel[BaseModule]
+
+
+def register_graph_module(provider_class: t.Type[BaseModule]) -> None: ...
+
+__all__: list[str]
\ No newline at end of file
diff --git a/decider/modules/core.py b/decider/modules/core.py
new file mode 100644
index 0000000..3cf1cbe
--- /dev/null
+++ b/decider/modules/core.py
@@ -0,0 +1,92 @@
+import typing as t
+from abc import ABC, abstractmethod
+from pydantic import PrivateAttr
+from decider.types import TInputType, TOutputType
+from decider._ext import TypeDiscriminatedBaseModule
+
+if t.TYPE_CHECKING:
+ from decider.executor import Executor, FrameNode, CompiledFrameGraph
+ from decider.config.base import BaseConfig
+ from decider.config.versioned import VersionedConfig
+
+
+class BaseModule(TypeDiscriminatedBaseModule, ABC):
+    """Abstract base class for graph modules.
+
+    Subclasses implement ``get_frame_nodes``; this class adds lazy compilation,
+    execution through an ``Executor``, ``|`` composition into a
+    ``SequentialModule``, and saving to a versioned config.
+    """
+    name: str
+    # Graph produced by the first compile() call and reused afterwards.
+    _compiled_graph: t.Optional["CompiledFrameGraph"] = PrivateAttr(default=None)
+    # Cached result of _compute_input_frame_keys().
+    _input_frame_keys: t.Optional[t.List[str]] = PrivateAttr(default=None)
+
+
+    def compile(self, executor: "Executor") -> "CompiledFrameGraph":
+        """Compile this module's frame nodes once and return the cached graph.
+
+        The graph is built with the executor passed on the first call; later
+        calls return the cached graph regardless of the executor they pass.
+        """
+        if self._compiled_graph is None:
+            frame_nodes = self.get_frame_nodes(executor)
+            self._compiled_graph = executor.compile_frame_graph(frame_nodes)
+        return self._compiled_graph
+
+    @abstractmethod
+    def get_frame_nodes(self, executor: "Executor") -> t.List["FrameNode"]:
+        """Return the frame nodes that make up this module."""
+        ...
+
+    def __call__(
+        self,
+        inputs: TInputType,
+        executor: t.Optional["Executor"] = None,
+    ) -> TOutputType:
+        """Execute the module on *inputs*, using the default executor when none is given."""
+        from decider.settings import get_default_executor
+        executor = executor or get_default_executor()
+        return executor.execute(self, inputs)
+
+    def __or__(self, other: "BaseModule") -> "BaseModule":
+        """Compose ``self | other`` into a ``SequentialModule`` named after ``self``.
+
+        When ``self`` is already sequential, *other* is appended to a copy of
+        its steps rather than nesting a new sequence.
+        """
+        from decider.modules.primitives.sequential import SequentialModule
+        if isinstance(self, SequentialModule):
+            return SequentialModule(name=self.name, steps=self.steps + [other])  # type: ignore[attr-defined]
+        return SequentialModule(name=self.name, steps=[self, other])
+
+    def get_input_frame_keys(self) -> t.List[str]:
+        """Return the names of input frames this module expects, computed once and cached."""
+        if self._input_frame_keys is None:
+            self._input_frame_keys = self._compute_input_frame_keys()
+        return self._input_frame_keys
+
+    def _compute_input_frame_keys(self) -> t.List[str]:
+        """Override in subclasses that consume named frames other than 'input'."""
+        return ["input"]
+
+    def to_config(self, config_key: str) -> "BaseConfig[t.Self]":
+        """Wrap this module in the config class that ``ConfigModule`` associates with its type."""
+        from decider.config.base import ConfigModule
+        config_class = ConfigModule.for_module_class(type(self))
+        return config_class.from_model(model=self, config_key=config_key)
+
+    async def asave(self, root_key: str, config_manager=None) -> "VersionedConfig":
+        """Dump this module into the latest versioned config and return that config.
+
+        Uses ``settings.config.get()`` when *config_manager* is omitted. If
+        ``get_latest()`` raises ``RuntimeError``, a new version is created.
+        """
+        from decider.config.base import DUMP_TRIGGER_KEY
+        from decider.config.versioned import with_versioned_config
+        if config_manager is None:
+            from decider.settings import settings
+            config_manager = settings.config.get()
+        try:
+            versioned_conf = await config_manager.get_latest()
+        except RuntimeError:
+            versioned_conf = await config_manager.create_version()
+        with with_versioned_config(versioned_conf):
+            config_mod = self.to_config(config_key=root_key)
+            # The return value is discarded; the dump is run with DUMP_TRIGGER_KEY set in its context.
+            config_mod.model_dump(context={DUMP_TRIGGER_KEY: True})
+        return versioned_conf
+
+    def save(self, root_key: str, config_manager=None) -> "VersionedConfig":
+        """Synchronous wrapper around ``asave``.
+
+        Uses ``asyncio.run``, which cannot be called while an event loop is
+        already running in the current thread; use ``asave`` there instead.
+        """
+        import asyncio
+        return asyncio.run(self.asave(root_key, config_manager))
+
+
+class BaseExecuteModule(BaseModule, ABC):
+    """Base for modules whose whole behaviour is a single ``execute`` callable."""
+    name: str
+
+    @abstractmethod
+    def execute(self, inputs: TInputType, executor: "Executor") -> TOutputType:
+        """Run the module on *inputs*."""
+        ...
+
+    def get_frame_nodes(self, executor: "Executor") -> t.List["FrameNode"]:
+        """Expose ``execute`` as one dependency-free frame node named after the module."""
+        from decider.executor import FrameNode
+        node = FrameNode(
+            name=self.name,
+            callable=self.execute,
+            depends_on=[],
+        )
+        return [node]
diff --git a/decider/modules/credit/__init__.py b/decider/modules/credit/__init__.py
new file mode 100644
index 0000000..7378f9b
--- /dev/null
+++ b/decider/modules/credit/__init__.py
@@ -0,0 +1,24 @@
+
+
+def register_credit_modules():
+    """Register the decision-table and scorecard module classes as graph modules."""
+    from .decision_table.module import DecisionTableModule
+    from .scorecard.module import (
+        ScoreCard,
+        ProbabilityDefault,
+        LogProbability,
+        ScoreFromPDO,
+        MergeScorecardValues,
+    )
+    from decider.modules import register_graph_module
+
+    credit_module_classes = (
+        DecisionTableModule,
+        ScoreCard,
+        ProbabilityDefault,
+        LogProbability,
+        ScoreFromPDO,
+        MergeScorecardValues,
+    )
+    for credit_cls in credit_module_classes:
+        register_graph_module(credit_cls)
\ No newline at end of file
diff --git a/decider/modules/credit/decision_table/__init__.py b/decider/modules/credit/decision_table/__init__.py
new file mode 100644
index 0000000..9ee7f35
--- /dev/null
+++ b/decider/modules/credit/decision_table/__init__.py
@@ -0,0 +1,19 @@
+from .module import (
+    DecisionTableModule,
+    ParametersConfig,
+    Expression,
+)
+
+from .impl import (
+    default_form_output_struct_from_row,
+    calculate_decision_table_output,
+)
+
+# Every name listed here must be bound in this module: a name in __all__ that is
+# never imported makes `from ... import *` fail with AttributeError.
+__all__ = [
+    "DecisionTableModule",
+    "ParametersConfig",
+    "Expression",
+    "default_form_output_struct_from_row",
+    "calculate_decision_table_output",
+]
\ No newline at end of file
diff --git a/decider/modules/credit/decision_table/config.py b/decider/modules/credit/decision_table/config.py
new file mode 100644
index 0000000..1dc3c0e
--- /dev/null
+++ b/decider/modules/credit/decision_table/config.py
@@ -0,0 +1,387 @@
+import ast
+import typing as t
+import functools
+from dataclasses import dataclass
+from enum import Enum
+import polars as pl
+from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from decider.exceptions import wrap_import_errors
+from decider._ext import TypeDiscriminatedBaseModule
+
+
+if t.TYPE_CHECKING:
+ import pandera.pandas as pa
+
+# Names a dtype string may reference; anything else is rejected by _TypeEvaluator.
+_ALLOWED_NAMES: t.Dict[str, t.Any] = {
+    "Optional": t.Optional,
+    "List": t.List,
+    "Dict": t.Dict,
+    "str": str,
+    "string": str,  # pandera-style alias
+    "int": int,
+    "float": float,
+    "bool": bool,
+    "list": list,
+}
+
+
+class _TypeEvaluator(ast.NodeVisitor):
+    """Evaluate a type-expression AST using only names from ``_ALLOWED_NAMES``.
+
+    Names, subscripts, tuples and constants are supported; every other node
+    type falls through to ``generic_visit`` and is rejected with ``ValueError``.
+    """
+
+    def visit_Name(self, node: ast.Name) -> t.Any:
+        if node.id not in _ALLOWED_NAMES:
+            raise ValueError(f"Disallowed name: {node.id!r}")
+        return _ALLOWED_NAMES[node.id]
+
+    def visit_Subscript(self, node: ast.Subscript) -> t.Any:
+        base = self.visit(node.value)
+        sub = self.visit(node.slice)
+        return base[sub]
+
+    # Annotation kept as a string so that defining this class never has to
+    # resolve the deprecated ``ast.Index`` attribute.
+    def visit_Index(self, node: "ast.Index") -> t.Any:  # Python < 3.9
+        return self.visit(node.value)
+
+    def visit_Tuple(self, node: ast.Tuple) -> tuple:
+        return tuple(self.visit(elt) for elt in node.elts)
+
+    def visit_Constant(self, node: ast.Constant) -> t.Any:
+        return node.value
+
+    def generic_visit(self, node: ast.AST) -> t.Any:
+        raise ValueError(f"Unsupported syntax: {ast.dump(node)}")
+
+
+def _safe_eval_type(expr: str) -> t.Any:
+    """Parse a type-expression string into a real type using AST (no eval).
+
+    Raises:
+        ValueError: for any invalid expression. This covers disallowed names
+            and unsupported syntax, and also text that does not parse
+            (``SyntaxError``) or subscripts the base type rejects
+            (``TypeError``, e.g. ``int[str]``), which are converted so callers
+            only need to handle one exception type.
+    """
+    try:
+        tree = ast.parse(expr.strip(), mode="eval")
+        return _TypeEvaluator().visit(tree.body)
+    except (SyntaxError, TypeError) as exc:
+        raise ValueError(f"Invalid type expression {expr!r}: {exc}") from exc
+
+
+@dataclass(frozen=True)
+class ParsedDtype:
+    """Immutable result of parsing a dtype string (produced by ``_parse_dtype``)."""
+    # Base Python type; ``list`` for both bare ``list`` and ``List[X]``.
+    type: type
+    # True when the dtype string was wrapped in ``Optional[...]``.
+    optional: bool
+    list_inner_type: t.Optional[type] = None  # set when type is list
+
+    @property
+    def is_list(self) -> bool:
+        """True when the parsed base type is ``list``."""
+        return self.type is list
+
+
+@functools.lru_cache(maxsize=None)
+def _parse_dtype(dtype_str: str) -> ParsedDtype:
+    """Turn a dtype string into a ``ParsedDtype``.
+
+    The string is resolved by ``_safe_eval_type`` (AST walk over an allowlist,
+    no ``eval()``). Handles ``int``, ``float``, ``str`` / ``string``, ``bool``,
+    ``list`` / ``List[X]`` and the ``Optional[X]`` form of each. Results are
+    memoised per input string.
+    """
+    resolved = _safe_eval_type(dtype_str)
+
+    # Optional[X] is Union[X, None]: strip the None member and remember it.
+    is_optional = t.get_origin(resolved) is t.Union and type(None) in t.get_args(resolved)
+    if is_optional:
+        resolved = next(arg for arg in t.get_args(resolved) if arg is not type(None))
+
+    if resolved is not list and t.get_origin(resolved) is not list:
+        return ParsedDtype(type=resolved, optional=is_optional)
+
+    # list / List[X]: keep the element type when one was given.
+    type_args = t.get_args(resolved)
+    element_type = type_args[0] if type_args else None
+    return ParsedDtype(type=list, optional=is_optional, list_inner_type=element_type)
+
+class ParametersConfig(BaseModel):
+    """Column-oriented parameter table for a decision table.
+
+    ``values`` is a list of rows, each aligned with ``columns``; ``dtypes``
+    maps a column name to a dtype string understood by ``_parse_dtype``.
+    After validation the table is available as a polars DataFrame in
+    ``_parameters_df`` and the parsed dtypes in ``_parsed_dtypes``.
+    """
+    columns: t.List[str]
+    values: t.List[t.List[t.Any]]
+    dtypes: t.Dict[str, str]
+
+    _parameters_df: pl.DataFrame = PrivateAttr()
+    _pandera_schema: "pa.DataFrameSchema" = PrivateAttr()
+    _parsed_dtypes: t.Dict[str, "ParsedDtype"] = PrivateAttr()
+
+    @model_validator(mode='after')
+    def validate_and_create_dataframe(self):
+        """Check the table's shape and dtypes, then build and validate the DataFrame.
+
+        Raises:
+            ValueError: if a dtype names an unknown column, a row has the wrong
+                number of values, a dtype string is unsupported or malformed,
+                or the pandera schema validation fails.
+        """
+        with wrap_import_errors(optional_source="dt"):
+            import pandera.pandas as pa
+        # TODO we could make the pandera validation optional?
+        # Validate that all dtypes correspond to existing columns
+        for col in self.dtypes:
+            if col not in self.columns:
+                raise ValueError(f"Dtype specified for column '{col}' but column not found in parameters columns")
+
+        # Validate that values matrix has correct dimensions
+        if self.values:
+            expected_cols = len(self.columns)
+            for i, row in enumerate(self.values):
+                if len(row) != expected_cols:
+                    raise ValueError(f"Row {i} has {len(row)} values but {expected_cols} columns expected")
+
+        # Create DataFrame from values
+        df_dict = {}
+        for i, col in enumerate(self.columns):
+            df_dict[col] = [row[i] for row in self.values]
+
+        self._parameters_df = pl.DataFrame(df_dict)
+
+        # Parse dtypes once and cache them, then build the pandera schema
+        parsed_dtypes: t.Dict[str, ParsedDtype] = {}
+        schema_dict = {}
+        for col, dtype_str in self.dtypes.items():
+            try:
+                parsed = _parse_dtype(dtype_str)
+            except (ValueError, SyntaxError, TypeError) as exc:
+                # SyntaxError (unparsable text) and TypeError (e.g. ``int[str]``)
+                # are reported the same way as a disallowed name, instead of
+                # escaping the validator as a non-validation error.
+                raise ValueError(f"Unsupported dtype '{dtype_str}' for column '{col}'") from exc
+            parsed_dtypes[col] = parsed
+            pa_dtype = pa.PythonList(parsed.list_inner_type) if parsed.is_list else parsed.type
+            schema_dict[col] = pa.Column(pa_dtype, nullable=parsed.optional)
+
+        self._parsed_dtypes = parsed_dtypes
+        self._pandera_schema = pa.DataFrameSchema(schema_dict)
+
+        # Convert to pandas for pandera validation, then back to polars
+        pandas_df = self._parameters_df.to_pandas()
+        try:
+            validated_df = self._pandera_schema.validate(pandas_df)
+            self._parameters_df = pl.from_pandas(validated_df)
+        except pa.errors.SchemaError as e:
+            raise ValueError(f"Schema validation failed: {e}") from e
+
+        return self
+
+
+def _safe_list_get(lst: t.List, i: int) -> t.Optional[t.Any]:
+    """Bounds-checked index: ``lst[i]`` for ``0 <= i < len(lst)``, otherwise ``None``.
+
+    Negative indices are treated as out of range rather than wrapping around.
+    """
+    if i < 0 or i >= len(lst):
+        return None
+    return lst[i]
+
+
+class BoundMode(str, Enum):
+    """Interval convention for BetweenExpression.
+
+    lower_inclusive → [lower, upper) i.e. lower <= x < upper
+    upper_inclusive → (lower, upper] i.e. lower < x <= upper
+
+    ``both_inclusive`` / ``both_exclusive`` are intentionally omitted:
+    they create overlaps or gaps at shared boundaries, making contiguous
+    range tables ambiguous.
+    """
+    # Members subclass ``str``, so each compares equal to its plain string value.
+    lower_inclusive = "lower_inclusive"
+    upper_inclusive = "upper_inclusive"
+
+
+class BaseExpression(BaseModel):
+    """Base class for all decision table expressions
+
+    Describes the three-method interface (``__call__``, ``validate_parameters``,
+    ``get_variables``). Every method here raises ``NotImplementedError``.
+    NOTE(review): the concrete expression classes in this file subclass
+    ``TypeDiscriminatedBaseModule`` rather than this class — confirm whether
+    this base is still meant to be used.
+    """
+
+    def __call__(self, parameters: pl.DataFrame, **kwargs: pl.Expr) -> t.List[pl.Expr]:
+        """Return one boolean pl.Expr per row of *parameters*.
+
+        Each element is a scalar-broadcast expression that evaluates to True when
+        the corresponding parameter row matches the input variables in **kwargs.
+        """
+        raise NotImplementedError
+
+    def validate_parameters(self, parameters: "ParametersConfig") -> None:
+        """Validate that required variables exist in parameters with correct types"""
+        raise NotImplementedError
+
+    def get_variables(self) -> t.List[str]:
+        """Return the list of external input variable names this expression consumes."""
+        raise NotImplementedError
+
+
+class AndExpression(TypeDiscriminatedBaseModule):
+    """Logical AND of child expressions, evaluated per parameter row."""
+    type: t.Literal["and"]
+    expressions: t.List["Expression"]
+
+    def __call__(self, parameters: pl.DataFrame, **kwargs: pl.Expr) -> t.List[pl.Expr]:
+        """Return one condition per parameter row: all children must match that row."""
+        # Each child yields a list with one condition per parameter row.
+        child_rows = [child(parameters, **kwargs) for child in self.expressions]
+        combined: t.List[pl.Expr] = []
+        # zip(*...) regroups the conditions by row, one entry per child.
+        for conds in zip(*child_rows):
+            combined.append(pl.all_horizontal(*conds))
+        return combined
+
+    def validate_parameters(self, parameters: "ParametersConfig") -> None:
+        """Validate every child expression against *parameters*."""
+        for child in self.expressions:
+            child.validate_parameters(parameters)
+
+    def get_variables(self) -> t.List[str]:
+        """Return the children's input variables, de-duplicated in first-seen order."""
+        ordered: t.Dict[str, None] = {}
+        for child in self.expressions:
+            for var_name in child.get_variables():
+                ordered.setdefault(var_name, None)
+        return list(ordered)
+
+
+class OrExpression(TypeDiscriminatedBaseModule):
+    """Logical OR of child expressions, evaluated per parameter row."""
+    type: t.Literal["or"]
+    expressions: t.List["Expression"]
+
+    def __call__(self, parameters: pl.DataFrame, **kwargs: pl.Expr) -> t.List[pl.Expr]:
+        """Return one condition per parameter row: any child may match that row."""
+        # Each child yields a list with one condition per parameter row.
+        child_rows = [child(parameters, **kwargs) for child in self.expressions]
+        combined: t.List[pl.Expr] = []
+        # zip(*...) regroups the conditions by row, one entry per child.
+        for conds in zip(*child_rows):
+            combined.append(pl.any_horizontal(*conds))
+        return combined
+
+    def validate_parameters(self, parameters: "ParametersConfig") -> None:
+        """Validate every child expression against *parameters*."""
+        for child in self.expressions:
+            child.validate_parameters(parameters)
+
+    def get_variables(self) -> t.List[str]:
+        """Return the children's input variables, de-duplicated in first-seen order."""
+        ordered: t.Dict[str, None] = {}
+        for child in self.expressions:
+            for var_name in child.get_variables():
+                ordered.setdefault(var_name, None)
+        return list(ordered)
+
+
+class BetweenExpression(TypeDiscriminatedBaseModule):
+    """Range match of an input variable against per-row lower/upper bound columns.
+
+    A missing bound on a row is filled from the neighbouring row: a missing
+    lower bound takes the previous row's upper bound, and a missing upper
+    bound takes the next row's lower bound. ``mode`` selects which end of the
+    interval is inclusive (see ``BoundMode``).
+    """
+    type: t.Literal["between"]
+    variable: str
+    lower_bound_column: t.Optional[str] = None
+    upper_bound_column: t.Optional[str] = None
+    mode: BoundMode = BoundMode.lower_inclusive
+    allow_gaps: bool = False
+
+    def __call__(self, parameters: pl.DataFrame, **kwargs: pl.Expr) -> t.List[pl.Expr]:
+        """Return one range condition per parameter row for ``kwargs[self.variable]``.
+
+        Raises:
+            ValueError: if the variable is missing from *kwargs*, or a row has
+                neither bound after neighbour resolution.
+        """
+        if self.variable not in kwargs:
+            raise ValueError(f"Variable '{self.variable}' not found in expression arguments")
+
+        var_expr = kwargs[self.variable]
+        n = len(parameters)
+        lower_vals = parameters[self.lower_bound_column].to_list() if self.lower_bound_column else [None] * n
+        upper_vals = parameters[self.upper_bound_column].to_list() if self.upper_bound_column else [None] * n
+        row_conditions: t.List[pl.Expr] = []
+
+        for i, (lower, upper) in enumerate(zip(lower_vals, upper_vals)):
+            # Fill a missing bound from the adjacent row's opposite bound.
+            lower = lower if lower is not None else _safe_list_get(upper_vals, i - 1)
+            upper = upper if upper is not None else _safe_list_get(lower_vals, i + 1)
+            if lower is None and upper is None:
+                raise ValueError(f"Row {i} has no lower or upper bound after resolution")
+            conditions: t.List[pl.Expr] = []
+            if lower is not None:
+                lb = pl.lit(lower)
+                conditions.append(var_expr >= lb if self.mode == BoundMode.lower_inclusive else var_expr > lb)
+            if upper is not None:
+                ub = pl.lit(upper)
+                conditions.append(var_expr < ub if self.mode == BoundMode.lower_inclusive else var_expr <= ub)
+            row_conditions.append(
+                pl.all_horizontal(*conditions) if len(conditions) > 1 else conditions[0]
+            )
+
+        return row_conditions
+
+    def validate_parameters(self, parameters: "ParametersConfig") -> None:
+        """Check the bound columns exist, are numeric, and form a usable range table.
+
+        Raises:
+            ValueError: if no bound column is configured; a bound column is
+                missing, has no declared dtype, or is not ``int``/``float``; an
+                interior row has an unresolvable bound; or (unless
+                ``allow_gaps``) adjacent rows are not contiguous.
+        """
+        if not self.lower_bound_column and not self.upper_bound_column:
+            raise ValueError("At least one of lower_bound_column or upper_bound_column must be specified")
+
+        for col_attr, label in (
+            (self.lower_bound_column, "Lower"),
+            (self.upper_bound_column, "Upper"),
+        ):
+            if col_attr is None:
+                continue
+            if col_attr not in parameters.columns:
+                raise ValueError(f"{label} bound column '{col_attr}' not found in parameters columns")
+            # A column may exist without an entry in `dtypes`; report that as a
+            # validation error rather than letting a bare KeyError escape.
+            dtype = parameters._parsed_dtypes.get(col_attr)
+            if dtype is None:
+                raise ValueError(f"{label} bound column '{col_attr}' has no dtype declared in parameters dtypes")
+            if dtype.type not in (float, int):
+                raise ValueError(f"{label} bound column '{col_attr}' must be numeric, got {dtype.type}")
+
+        df = parameters._parameters_df
+        n = len(df)
+        lower_vals = df[self.lower_bound_column].to_list() if self.lower_bound_column else [None] * n
+        upper_vals = df[self.upper_bound_column].to_list() if self.upper_bound_column else [None] * n
+
+        for i, (lower, upper) in enumerate(zip(lower_vals, upper_vals)):
+            # Same neighbour resolution as __call__.
+            lower = lower if lower is not None else _safe_list_get(upper_vals, i - 1)
+            upper = upper if upper is not None else _safe_list_get(lower_vals, i + 1)
+            if lower is None and upper is None:
+                raise ValueError(
+                    f"Row {i}: both bounds are unresolvable. "
+                    f"Only row 0's lower and the last row's upper may be None (open edges)."
+                )
+            if lower is None and i > 0:
+                raise ValueError(f"Row {i}: lower bound unresolvable — only row 0 may have an open lower edge.")
+            if upper is None and i < n - 1:
+                raise ValueError(f"Row {i}: upper bound unresolvable — only row {n-1} may have an open upper edge.")
+            if not self.allow_gaps and i < n - 1:
+                # A missing next-row lower would itself resolve to this row's upper.
+                next_lower = lower_vals[i + 1] if lower_vals[i + 1] is not None else upper
+                if upper is not None and next_lower is not None and upper != next_lower:
+                    raise ValueError(
+                        f"Row {i} upper ({upper}) != row {i+1} lower ({next_lower}): "
+                        f"ranges are not contiguous. Set allow_gaps=True to permit this."
+                    )
+
+    def get_variables(self) -> t.List[str]:
+        """Return the single input variable this expression reads."""
+        return [self.variable]
+
+
+class InExpression(TypeDiscriminatedBaseModule):
+    """Membership match: the input variable must be in the row's list of values."""
+    type: t.Literal["in"]
+    variable: str
+    values_column: str
+
+    def __call__(self, parameters: pl.DataFrame, **kwargs: pl.Expr) -> t.List[pl.Expr]:
+        """Return one ``is_in`` condition per parameter row.
+
+        Raises:
+            ValueError: if the variable is missing from *kwargs*.
+        """
+        if self.variable not in kwargs:
+            raise ValueError(f"Variable '{self.variable}' not found in expression arguments")
+
+        var_expr = kwargs[self.variable]
+        # NOTE(review): the literal is always typed as List(Utf8), even when the
+        # column was declared with another element type (e.g. List[int]) —
+        # confirm non-string value lists behave as intended.
+        return [
+            var_expr.is_in(pl.lit(row[self.values_column], dtype=pl.List(pl.Utf8)))
+            for row in parameters.iter_rows(named=True)
+        ]
+
+    def validate_parameters(self, parameters: "ParametersConfig") -> None:
+        """Check that ``values_column`` exists and is declared as a list dtype.
+
+        Raises:
+            ValueError: if the column is missing, has no declared dtype, or is
+                not a list type.
+        """
+        # Note: self.variable is an input variable, not a column in parameters
+        # We only need to validate that the values column exists
+
+        # Check that values column exists and is a list type
+        if self.values_column not in parameters.columns:
+            raise ValueError(f"Values column '{self.values_column}' not found in parameters columns")
+
+        # A column may exist without an entry in `dtypes`; report that as a
+        # validation error rather than letting a bare KeyError escape.
+        values_dtype = parameters._parsed_dtypes.get(self.values_column)
+        if values_dtype is None:
+            raise ValueError(f"Values column '{self.values_column}' has no dtype declared in parameters dtypes")
+        if not values_dtype.is_list:
+            raise ValueError(f"Values column '{self.values_column}' must be a list type, got {values_dtype.type}")
+
+    def get_variables(self) -> t.List[str]:
+        """Return the single input variable this expression reads."""
+        return [self.variable]
+
+
+class IsTrueExpression(TypeDiscriminatedBaseModule):
+    """Condition that is simply the input variable's own expression."""
+    type: t.Literal["is_true"]
+    variable: str
+
+    def __call__(self, parameters: pl.DataFrame, **kwargs: pl.Expr) -> t.List[pl.Expr]:
+        """Return the variable's expression repeated once per parameter row."""
+        if self.variable in kwargs:
+            flag = kwargs[self.variable]
+            # The same expression object is reused for every row.
+            return [flag] * len(parameters)
+        raise ValueError(f"Variable '{self.variable}' not found in expression arguments")
+
+    def validate_parameters(self, parameters: "ParametersConfig") -> None:  # noqa: ARG002
+        """Nothing to check: the variable is an input, not a parameters column."""
+        return None
+
+    def get_variables(self) -> t.List[str]:
+        """Return the single input variable this expression reads."""
+        return [self.variable]
+
+
+# Discriminated union of every concrete expression class; pydantic selects the
+# member from the value of each class's `type` literal.
+Expression = t.Annotated[
+    t.Union[
+        AndExpression,
+        OrExpression,
+        BetweenExpression,
+        InExpression,
+        IsTrueExpression,
+    ],
+    Field(discriminator="type")
+]
+
+# Update forward references
+# AndExpression / OrExpression annotate their children as the string "Expression",
+# which is only defined at this point in the module.
+AndExpression.model_rebuild()
+OrExpression.model_rebuild()
\ No newline at end of file
diff --git a/decider/modules/credit/decision_table/impl.py b/decider/modules/credit/decision_table/impl.py
new file mode 100644
index 0000000..9efadfd
--- /dev/null
+++ b/decider/modules/credit/decision_table/impl.py
@@ -0,0 +1,66 @@
+import typing as t
+import polars as pl
+from .config import Expression
+
+
+def default_form_output_struct_from_row(
+    row_values: t.Dict[str, t.Any],
+    output_columns: t.List[str]
+) -> pl.Expr:
+    """
+    Build a struct of literals from one parameter row.
+
+    Args:
+        row_values: Mapping of column name to value for the current row
+        output_columns: Column names to include, in struct field order
+
+    Returns:
+        pl.Expr: A struct with one literal field per output column
+    """
+    literal_fields = [
+        pl.lit(row_values[column]).alias(column)
+        for column in output_columns
+    ]
+    return pl.struct(*literal_fields)
+
+
+def calculate_decision_table_output(
+    parameters: pl.DataFrame,
+    expression: Expression,
+    output_columns: t.List[str],
+    default: t.Optional[t.List[t.Any]] = None,
+    output_fn: t.Callable[[t.Dict[str, t.Any], t.List[str]], pl.Expr] = default_form_output_struct_from_row,
+    **kwargs: pl.Expr
+) -> pl.Expr:
+    """
+    Build a when/then chain that yields the output of the first matching parameter row.
+
+    Args:
+        parameters: DataFrame containing decision table parameters
+        expression: Expression called with (parameters, **kwargs); returns one condition per row
+        output_columns: Columns to include in the output
+        default: Values (aligned with output_columns) returned when no row matches;
+            when omitted, every output field is a null literal
+        output_fn: Builds the output struct from a row's values
+        **kwargs: Input variable expressions (e.g., v1=pl.col("age"), v2=pl.col("income"))
+
+    Returns:
+        pl.Expr: The chained expression, or just the fallback struct when there are no rows
+    """
+    row_conditions = expression(parameters, **kwargs)
+
+    # Seed the chain with the `pl` module itself: the first `.when` resolves to
+    # `pl.when`, and each later one chains onto the previous `.then(...)`.
+    chain = pl
+    for matched, row in zip(row_conditions, parameters.iter_rows(named=True)):
+        chain = chain.when(matched).then(output_fn(row, output_columns))
+
+    if default is None:
+        fallback = pl.struct(*[pl.lit(None).alias(column) for column in output_columns])
+    else:
+        fallback_values = {column: default[idx] for idx, column in enumerate(output_columns)}
+        fallback = output_fn(fallback_values, output_columns)
+
+    # `chain is pl` means no condition was ever added.
+    return fallback if chain is pl else chain.otherwise(fallback)
diff --git a/decider/modules/credit/decision_table/module.py b/decider/modules/credit/decision_table/module.py
new file mode 100644
index 0000000..896a570
--- /dev/null
+++ b/decider/modules/credit/decision_table/module.py
@@ -0,0 +1,58 @@
+import typing as t
+from pydantic import model_validator
+from decider.modules.core import BaseModule
+from decider.modules.expression import Node, ExternalInputNode
+from .config import ParametersConfig, Expression
+
+
+
+class DecisionTableModule(BaseModule):
+    """Graph module that evaluates a decision table.
+
+    ``expression`` is matched against each row of ``parameters``; ``outputs``
+    names the parameter columns returned for the matching row, and ``default``
+    supplies the values used when no row matches.
+
+    NOTE(review): ``BaseModule`` (decider/modules/core.py) declares
+    ``get_frame_nodes`` abstract and this class does not define it here —
+    confirm where that implementation comes from.
+    """
+    type: t.Literal["decision_table"]
+    parameters: ParametersConfig
+    expression: Expression
+    outputs: t.List[str]
+    default: t.Optional[t.List[t.Any]] = None
+
+    @model_validator(mode='after')
+    def validate_config(self):
+        """Cross-check outputs, default and expression against the parameters table."""
+        # Validate that all output columns exist in parameters
+        for output in self.outputs:
+            if output not in self.parameters.columns:
+                raise ValueError(f"Output column '{output}' not found in parameters columns")
+
+        # Validate default has correct length if specified
+        if self.default is not None:
+            if len(self.default) != len(self.outputs):
+                raise ValueError(f"Default values length ({len(self.default)}) must match outputs length ({len(self.outputs)})")
+
+        # Validate expression parameters
+        self.expression.validate_parameters(self.parameters)
+
+        return self
+
+    def expand_nodes(self, config: t.Optional[t.Dict[str, t.Any]] = None) -> t.List[Node]:  # noqa: ARG002
+        """
+        Expand the decision table configuration into a single node named "output".
+
+        The node wraps ``calculate_decision_table_output``; each variable the
+        expression consumes is mapped to an input of the same name. *config*
+        is accepted but not used.
+
+        Returns:
+            List containing the one Node
+        """
+        from .impl import calculate_decision_table_output
+
+        variables = self.expression.get_variables()
+        input_map = {v: v for v in variables}
+
+        # Create the main decision table evaluation node
+        return [
+            Node.from_callable(
+                calculate_decision_table_output,
+                name="output",
+                input_map=input_map,
+                static_kwargs={
+                    "parameters": self.parameters._parameters_df,
+                    "expression": self.expression,
+                    "output_columns": self.outputs,
+                    "default": self.default
+                }
+            )
+        ]
diff --git a/decider/modules/credit/scorecard/__init__.py b/decider/modules/credit/scorecard/__init__.py
new file mode 100644
index 0000000..e96f328
--- /dev/null
+++ b/decider/modules/credit/scorecard/__init__.py
@@ -0,0 +1,42 @@
+from .module import (
+ ScoreCard,
+ ScoredVariable,
+ AdjustedVariable,
+ ConstantScore,
+ ProbabilityDefault,
+ LogProbability,
+ ScoreFromPDO,
+ MergeScorecardValues,
+)
+
+from .impl import (
+ BoundBin,
+ ValuesBin,
+ DefaultBin,
+ score_variable,
+ adjust_score,
+ calculate_score,
+ calculate_probability_of_default,
+ log_odds_from_score,
+ calculate_credit_score,
+)
+
+__all__ = [
+ "ScoreCard",
+ "ScoredVariable",
+ "AdjustedVariable",
+ "ConstantScore",
+ "ProbabilityDefault",
+ "LogProbability",
+ "ScoreFromPDO",
+ "MergeScorecardValues",
+ "BoundBin",
+ "ValuesBin",
+ "DefaultBin",
+ "score_variable",
+ "adjust_score",
+ "calculate_score",
+ "calculate_probability_of_default",
+ "log_odds_from_score",
+ "calculate_credit_score",
+]
diff --git a/decider/modules/credit/scorecard/impl.py b/decider/modules/credit/scorecard/impl.py
new file mode 100644
index 0000000..58a51a8
--- /dev/null
+++ b/decider/modules/credit/scorecard/impl.py
@@ -0,0 +1,312 @@
+import typing as t
+from dataclasses import dataclass, field
+import polars as pl
+import math
+
+
+@dataclass
+class DefaultBin:
+    """Fallback scorecard bin: a score value with an optional display name."""
+    value: float
+    name: t.Optional[str] = None
+
+
+@dataclass
+class BoundBin(DefaultBin):
+    """Scorecard bin selected by a numeric range; ``None`` means that side is unset."""
+    lower_bound: t.Optional[float] = None
+    upper_bound: t.Optional[float] = None
+
+    def with_bounds(self, lower_bound: t.Optional[float] = None, upper_bound: t.Optional[float] = None) -> "BoundBin":
+        """Return a copy with the given bounds; an omitted (``None``) bound keeps the current one.
+
+        The check is ``is not None`` rather than truthiness so that a bound of
+        ``0`` / ``0.0`` is honoured instead of being treated as "not given".
+        """
+        return BoundBin(
+            value=self.value,
+            name=self.name,
+            lower_bound=lower_bound if lower_bound is not None else self.lower_bound,
+            upper_bound=upper_bound if upper_bound is not None else self.upper_bound,
+        )
+
+@dataclass
+class ValuesBin(DefaultBin):
+    """Scorecard bin selected by membership in an explicit list of values."""
+    # Values that select this bin; defaults to a fresh empty list per instance.
+    items: t.List[int|float|str] = field(default_factory=list)
+
+
+def adjust_score(score: pl.Expr, offset: float = 0.0, scale: float = 1.0) -> pl.Expr:
+    """Linearly transform a score: multiply by *scale*, then add *offset*.
+
+    Args:
+        score (pl.Expr): The input score expression.
+        offset (float): Added after scaling.
+        scale (float): Multiplier applied first.
+
+    Returns:
+        pl.Expr: The transformed score expression.
+    """
+    # Identity steps (scale 1.0, offset 0.0) are skipped so they add no
+    # operations to the resulting expression.
+    scaled = score if scale == 1.0 else score * scale
+    return scaled if offset == 0.0 else scaled + offset
+
+def default_output_expr(
+    input: pl.Expr,
+    current_bin: t.Union[BoundBin, ValuesBin, DefaultBin],
+    input_name: t.Optional[str] = None,
+) -> pl.Expr:
+    """Build the output struct describing which bin an input fell into.
+
+    Every struct carries ``value``, ``input`` and ``bin`` (plus ``input_name``
+    when given) followed by a ``type`` tag and the bin-specific fields.
+
+    Raises:
+        ValueError: if *current_bin* is not one of the known bin classes.
+    """
+    fields = [
+        pl.lit(current_bin.value).alias("value"),
+        input.alias("input"),
+        pl.lit(current_bin.name).alias("bin"),
+    ]
+    if input_name is not None:
+        fields.append(pl.lit(input_name).alias("input_name"))
+
+    # Subclasses are tested before DefaultBin, which is the base of both.
+    if isinstance(current_bin, BoundBin):
+        fields += [
+            pl.lit("bound").alias("type"),
+            pl.lit(current_bin.upper_bound).alias("upper_bound"),
+            pl.lit(current_bin.lower_bound).alias("lower_bound"),
+        ]
+    elif isinstance(current_bin, ValuesBin):
+        fields += [
+            pl.lit("values").alias("type"),
+            pl.lit(current_bin.items).alias("values"),
+        ]
+    elif isinstance(current_bin, DefaultBin):
+        fields.append(pl.lit("default").alias("type"))
+    else:
+        raise ValueError(f"Unsupported bin type: {type(current_bin)}")
+    return pl.struct(*fields)
+
+
+def default_get_value_from_struct(struct: pl.Expr) -> pl.Expr:
+    """Return the ``value`` field of a struct expression."""
+    return struct.struct.field("value")
+
+
+def _get_bound_bin(bound_bins: t.List[BoundBin], i: int) -> BoundBin:
+    """Return ``bound_bins[i]``, or a placeholder bin with no bounds when *i* is out of range.
+
+    The placeholder has a NaN value and ``None`` for both bounds, so callers
+    can read a neighbour's bounds without special-casing the first and last bins.
+    """
+    if i < 0 or i >= len(bound_bins):
+        return BoundBin(value=float('nan'), lower_bound=None, upper_bound=None)
+    return bound_bins[i]
+
+
+def score_variable(
+    input: pl.Expr,
+    bound_bins: t.List[BoundBin],
+    value_bins: t.List[ValuesBin],
+    default_bin: DefaultBin,
+    input_name: t.Optional[str] = None,
+    output_expr_fn: t.Optional[t.Callable[..., pl.Expr]] = None,
+) -> pl.Expr:
+    """Build a when/then chain that maps *input* to the output of its bin.
+
+    Value bins are tested first, then bound bins in list order; the default
+    bin is used when nothing matches. Bound bins match ``lower < input <= upper``.
+    A bound left as ``None`` is filled from the adjacent bin (previous bin's
+    upper bound, next bin's lower bound); if still ``None`` that side is open.
+
+    Args:
+        input: Expression for the variable being scored.
+        bound_bins: Range bins.
+        value_bins: Explicit-value bins; take priority over bound bins.
+        default_bin: Bin used when no other bin matches.
+        input_name: Optional name passed through to the output builder.
+        output_expr_fn: Called as ``(input, bin, input_name)`` to build each
+            bin's output; defaults to ``default_output_expr``.
+
+    Returns:
+        pl.Expr: The chained expression, or just the default output when there are no bins.
+    """
+    if output_expr_fn is None:
+        output_expr_fn = default_output_expr
+    # Seed the chain with the `pl` module itself: the first `.when` resolves to
+    # `pl.when`, later ones chain onto the previous `.then(...)`, so no
+    # "is this the first condition?" check is needed.
+    expr_chain = pl
+
+    # We apply value bins first as they are higher priority than bound bins.
+    for b in value_bins:
+        condition = input.is_in(pl.lit(list(b.items)))
+        bin_expr = output_expr_fn(input, b, input_name)
+        expr_chain = expr_chain.when(condition).then(bin_expr)
+
+    for i, b in enumerate(bound_bins):
+        # Use `is None`, not truthiness: a bound of 0 / 0.0 is a real bound and
+        # must not be replaced by the neighbour's bound or dropped.
+        lower_bound = b.lower_bound
+        if lower_bound is None:
+            lower_bound = _get_bound_bin(bound_bins, i - 1).upper_bound
+        upper_bound = b.upper_bound
+        if upper_bound is None:
+            upper_bound = _get_bound_bin(bound_bins, i + 1).lower_bound
+
+        condition = pl.lit(True)
+        # This is the first condition so no need to do condition & ... (Optimize out the literal)
+        if lower_bound is not None:
+            condition = (input > lower_bound)
+        if upper_bound is not None:
+            condition = condition & (input <= upper_bound)
+
+        # Built directly so the resolved bounds (including 0.0) are reported exactly.
+        resolved_bin = BoundBin(
+            value=b.value,
+            name=b.name,
+            lower_bound=lower_bound,
+            upper_bound=upper_bound,
+        )
+        bin_expr = output_expr_fn(input, resolved_bin, input_name)
+        expr_chain = expr_chain.when(condition).then(bin_expr)
+
+    default_expr = output_expr_fn(input, default_bin, input_name)
+
+    if expr_chain is pl:
+        # We know we have no conditions above so we can just return the default value
+        return default_expr
+    else:
+        return expr_chain.otherwise(default_expr)
+
+
+def calculate_score(**kwargs: pl.Expr) -> pl.Expr:
+    """Sum the per-variable scores into a total score.
+
+    Args:
+        **kwargs: One keyword per scored variable, each a Polars expression holding
+            that variable's score, e.g.:
+                score_variable1=pl.col("score_variable1"),
+                score_variable2=pl.col("score_variable2"),
+    Returns:
+        pl.Expr: the row-wise sum of all supplied expressions.
+    """
+    component_scores = list(kwargs.values())
+    return pl.sum_horizontal(component_scores)
+
+
+def log_odds_from_score(
+    score: pl.Expr,
+    anchor_score: float,  # Or known as base_points
+    target_odds: float,  # Or known as base_odds
+    points_to_double_the_odds: float  # Or known as pdo
+) -> pl.Expr:
+    """
+    Convert a credit score to log odds with a linear transformation.
+
+        log_odds = (score - anchor_score) / factor + log(target_odds)
+
+    where:
+        factor = points_to_double_the_odds / log(2)
+
+    Args:
+        score (pl.Expr): The credit score expression.
+        anchor_score (float): The reference score (also known as base_points).
+        target_odds (float): The odds at the anchor score (also known as base_odds).
+        points_to_double_the_odds (float): Points needed to double the odds (also known as pdo).
+
+    Returns:
+        pl.Expr: A Polars expression representing the log odds.
+    """
+    # The constants are computed in Python with `math.log`, so they enter the
+    # expression as plain numbers rather than as extra expression nodes.
+    factor = points_to_double_the_odds / math.log(2)
+    scaled_distance = (score - anchor_score) / factor
+    return scaled_distance + math.log(target_odds)
+
+
+def probability_of_default_from_log_odds(
+    credit_log_odds: pl.Expr,
+    log_odds_safety_cap: float = 40.0,
+    financial_rounding: bool = True,
+) -> pl.Expr:
+    """
+    Calculate probability of default (PD) from log odds using a numerically stable form.
+
+    The textbook definition of probability of default is:
+        1 - (exp(log_odds) / (1 + exp(log_odds)))
+    where log_odds = (score - offset) / factor. Writing x = exp(log_odds):
+
+        1 - (x / (1 + x))
+          = (1 + x)/(1 + x) - x/(1 + x)
+          = 1/(1 + x)
+
+    Hence we use the form 1/(1 + exp(log_odds)), which is cheaper to compute and
+    more numerically stable.
+
+    Null log odds stay null in the result: the cap is applied with Expr.clip,
+    which leaves nulls untouched, so a missing input never yields a PD.
+
+    Args:
+        credit_log_odds (pl.Expr): The log odds expression.
+        log_odds_safety_cap (float, optional): Maximum log odds value to prevent
+            numerical overflow. Defaults to 40.0. For reference,
+            pl.LazyFrame().with_columns(pl.lit(89).cast(pl.Float32).exp()).collect() -> inf,
+            so capping at 40 keeps the exponential finite
+            (a lower value, closer to 8, may be needed for float16).
+        financial_rounding (bool, optional): Whether to apply financial rounding
+            (7 decimal places, Float32 cast). Defaults to True.
+
+    Returns:
+        pl.Expr: A Polars expression representing the probability of default.
+    """
+    # pl.min_horizontal ignores nulls, so min(cap, null) would return the cap and
+    # turn a missing score into PD ~ 0. Expr.clip caps the value but keeps nulls.
+    capped_log_odds = credit_log_odds.clip(upper_bound=log_odds_safety_cap)
+    probability = 1 / (1 + capped_log_odds.exp())
+    if financial_rounding:
+        probability = probability.round(7).cast(pl.Float32)
+    return probability
+
+
+def calculate_probability_of_default(
+    score: pl.Expr,
+    anchor_score: float = 660,
+    target_odds: float = 15,
+    points_to_double_the_odds: float = 20,
+    log_odds_safety_cap: float = 40.0,
+    financial_rounding: bool = True,
+) -> pl.Expr:
+    """
+    Turn a credit score into a probability of default (PD).
+
+    Chains the two sibling helpers:
+    1. log_odds_from_score: linear score -> log-odds transformation
+    2. probability_of_default_from_log_odds: capped logistic transformation
+    Args:
+        score (pl.Expr): The credit score expression.
+        anchor_score (float): The reference score (also known as base_points).
+        target_odds (float): The odds at the anchor score (also known as base_odds).
+        points_to_double_the_odds (float): Points needed to double the odds (also known as pdo).
+        log_odds_safety_cap (float, optional): Cap applied to the log odds. Defaults to 40.0.
+        financial_rounding (bool, optional): Round to 7 decimals and cast to Float32. Defaults to True.
+
+    Returns:
+        pl.Expr: A Polars expression representing the probability of default.
+    """
+    return probability_of_default_from_log_odds(
+        log_odds_from_score(score, anchor_score, target_odds, points_to_double_the_odds),
+        log_odds_safety_cap=log_odds_safety_cap,
+        financial_rounding=financial_rounding,
+    )
+
+
+
+def calculate_credit_score(
+    probability_of_default: pl.Expr,
+    points_to_double_the_odds: float = 20,
+    anchor_score: float = 660,
+    target_odds: float = 15,
+    safety_factor: float = 1e-10,
+    return_integer: bool = True,
+) -> pl.Expr:
+    """
+    Convert a probability of default (PD) back into a credit score.
+
+    This is the inverse of calculate_probability_of_default: it applies the
+    inverse logistic transformation and then the linear score scaling.
+
+    Steps:
+        1. Convert PD to odds: odds = (1 - pd) / pd
+        2. Take log odds: log_odds = log(odds)
+        3. Convert to score: score = offset + factor * log_odds
+
+    Where:
+        factor = points_to_double_the_odds / log(2)
+        offset = anchor_score - log(target_odds) * factor
+
+    Args:
+        probability_of_default (pl.Expr): A Polars expression representing the
+            probability of default (between 0 and 1).
+        points_to_double_the_odds (float, optional): Number of score points required
+            to double the odds (also known as pdo). Defaults to 20.
+        anchor_score (float, optional): The reference score (also known as base_points).
+            Defaults to 660.
+        target_odds (float, optional): The odds at the anchor score (also known as
+            base_odds). Defaults to 15.
+        safety_factor (float, optional): PD is clipped to
+            [safety_factor, 1 - safety_factor] for numerical stability. Defaults to 1e-10.
+        return_integer (bool, optional): Whether to round the score to the nearest
+            integer and cast to Int32. Defaults to True.
+
+    Returns:
+        pl.Expr: A Polars expression representing the calculated credit score,
+            optionally rounded to the nearest integer and cast to Int32.
+    """
+    # Scaling constants are computed once in Python, then lifted to literals.
+    factor_value = points_to_double_the_odds / math.log(2)
+    offset_value = anchor_score - (math.log(target_odds) * factor_value)
+    factor: pl.Expr = pl.lit(factor_value)
+    offset: pl.Expr = pl.lit(offset_value)
+    # Keep PD inside [safety_factor, 1 - safety_factor] so the odds ratio and its log stay finite.
+    clipped_pd = probability_of_default.clip(safety_factor, 1 - safety_factor)
+    # No explicit null check: per the original author's note, nulls flow through the arithmetic as null.
+    odds = (1 - clipped_pd) / clipped_pd
+    score = offset + factor * odds.log()
+    return score.round().cast(pl.Int32) if return_integer else score
diff --git a/decider/modules/credit/scorecard/module.py b/decider/modules/credit/scorecard/module.py
new file mode 100644
index 0000000..89cbd0c
--- /dev/null
+++ b/decider/modules/credit/scorecard/module.py
@@ -0,0 +1,357 @@
+import typing as t
+import polars as pl
+from pydantic import Discriminator, Tag, model_validator, PrivateAttr, field_validator
+from decider.modules.core import BaseModule
+from decider.modules.expression import Node
+from .impl import (
+ BoundBin,
+ ValuesBin,
+ DefaultBin,
+ score_variable,
+ default_get_value_from_struct,
+ adjust_score,
+ calculate_score,
+ calculate_probability_of_default,
+ log_odds_from_score,
+ calculate_credit_score,
+)
+from decider.serializable.function import DefinedFunction
+
+def get_bin_type(bin_obj: t.Union[BoundBin, ValuesBin, dict]) -> str:
+    # Discriminator callback: raw dicts are probed by key, model instances by attribute.
+    has_items = "items" in bin_obj if isinstance(bin_obj, dict) else hasattr(bin_obj, 'items')
+    return "values" if has_items else "bound"
+
+_TBin = t.Annotated[ # Tagged union of the two bin models; get_bin_type returns the tag to use
+ t.Union[
+ t.Annotated[BoundBin, Tag("bound")],
+ t.Annotated[ValuesBin, Tag("values")]
+ ],
+ Discriminator(get_bin_type)
+]
+
+
+class ScoredVariable(BaseModule): # Scores one input variable by matching it against value bins and bound bins
+ type: t.Literal["scored"]
+ variable_name: str
+ bins: t.List[_TBin]
+ default: DefaultBin
+ strict: bool = True
+ raw_output_name: t.Optional[str] = None
+ value_output_name: t.Optional[str] = "{variable_name}_score"
+ # Optional replacements for the struct-building and score-extraction callables used in expand_nodes
+ variable_struct_function: t.Optional[DefinedFunction] = None
+ struct_to_score_function: t.Optional[DefinedFunction] = None
+ # Filled by validate_bins: self.bins partitioned by concrete bin type, original order kept
+ _bound_bins: t.List[BoundBin] = PrivateAttr(default_factory=list)
+ _value_bins: t.List[ValuesBin] = PrivateAttr(default_factory=list)
+
+ @model_validator(mode='after')
+ def validate_bins(self):
+ """
+ Validates the bins configuration to ensure:
+ 1. No duplicate values across ValuesBins
+ 2. BoundBins are in ascending order without overlaps
+ 3. Proper boundary continuity between consecutive BoundBins
+ 4. Gap handling based on strict mode setting
+
+ In strict mode: consecutive BoundBins must have exact boundary matches (no gaps)
+ In non-strict mode: gaps are allowed and will use default values
+
+ Examples:
+ Valid configurations:
+ - [BoundBin(0, 1), BoundBin(1, 2)] (strict=True, continuous)
+ - [BoundBin(0, 0.5), BoundBin(1, 2)] (strict=False, gap 0.5-1 uses default)
+ - [BoundBin(None, 1), BoundBin(1, None)] (first=-inf to 1, second=1 to +inf)
+ - [ValuesBin({'A', 'B'}), BoundBin(0, 1)] (mixed types allowed)
+
+ Invalid configurations:
+ - [BoundBin(1, 2), BoundBin(0, 1)] (wrong order)
+ - [BoundBin(0, 1.5), BoundBin(1, 2)] (overlap)
+ - [ValuesBin({'A'}), ValuesBin({'A'})] (duplicate value 'A')
+ - [BoundBin(0, None), BoundBin(None, 2)] (unbounded gap)
+
+ Raises:
+ ValueError: If validation fails with detailed error messages
+ """
+ if not self.bins: # nothing to validate; the private bin lists keep their empty defaults
+ return self
+
+ # Track all values across ValuesBins to check for duplicates
+ all_values = set()
+
+ # Track bounds continuity and ordering for BoundBins
+ last_bin: t.Optional[BoundBin] = None
+ last_bin_idx = -1
+ highest_bound = float('-inf') # Tracks the highest bound seen so far
+ _bound_bins = []
+ _value_bins = []
+
+ for i, bin_obj in enumerate(self.bins):
+ if isinstance(bin_obj, ValuesBin):
+ # Check for duplicate values across all ValuesBins using set intersection
+ _value_bins.append(bin_obj)
+ intersection = set(bin_obj.items) & all_values
+ if intersection:
+ raise ValueError(f"Duplicate values {intersection} found in ValuesBin at index {i}")
+ all_values.update(bin_obj.items)
+
+ elif isinstance(bin_obj, BoundBin):
+ # Validate that lower_bound < upper_bound when both are defined
+ _bound_bins.append(bin_obj)
+ if bin_obj.lower_bound is not None and bin_obj.upper_bound is not None:
+ if bin_obj.lower_bound >= bin_obj.upper_bound:
+ raise ValueError(f"BoundBin at index {i} has lower_bound >= upper_bound")
+
+ # Check bin ordering: ensure bins are in ascending order
+ # First check lower_bound against highest_bound seen so far
+ if bin_obj.lower_bound is not None:
+ if bin_obj.lower_bound < highest_bound:
+ raise ValueError(f"BoundBin at index {i} has lower_bound {bin_obj.lower_bound} which is less than previous highest_bound {highest_bound}. Bound bins must be in order and non-overlapping.")
+ highest_bound = bin_obj.lower_bound # Update to current lower_bound
+
+ # Then check upper_bound - since lower < upper (validated above),
+ # this will effectively set highest_bound to max(lower, upper)
+ if bin_obj.upper_bound is not None:
+ if bin_obj.upper_bound < highest_bound:
+ raise ValueError(f"BoundBin at index {i} has upper_bound {bin_obj.upper_bound} which is less than previous highest_bound {highest_bound}. Bound bins must be in order and non-overlapping.")
+ highest_bound = bin_obj.upper_bound # Update to current upper_bound (the max)
+
+ # Check boundary continuity between consecutive BoundBins
+ if last_bin is not None:
+ # If both boundaries are defined, check for proper continuity/gaps
+ if last_bin.upper_bound is not None and bin_obj.lower_bound is not None:
+ # In strict mode: must be exactly continuous (no gaps)
+ # In non-strict mode: allow gaps but prevent overlaps
+ if (self.strict and bin_obj.lower_bound != last_bin.upper_bound) or bin_obj.lower_bound < last_bin.upper_bound:
+ raise ValueError(f"BoundBin at index {i} lower_bound ({bin_obj.lower_bound}) must be {'equal to (in strict mode)' if self.strict else 'greater than or equal to'} previous upper_bound ({last_bin.upper_bound}) defined in BoundBin at index {last_bin_idx}. Bound bins must be in order and non-overlapping.")
+
+ # Ensure at least one boundary is defined between consecutive bins
+ # This prevents unbounded gaps (e.g., both upper_bound=None and lower_bound=None)
+ if last_bin.upper_bound is None and bin_obj.lower_bound is None:
+ raise ValueError(f"Either BoundBin at index {last_bin_idx} must define a upper_bound or BoundBin at index {i} must have a lower_bound.")
+
+ # Update tracking variables for next iteration
+ last_bin = bin_obj
+ last_bin_idx = i
+ self._bound_bins = _bound_bins
+ self._value_bins = _value_bins
+ return self
+
+ def get_value_output_name(self) -> t.Optional[str]: # value_output_name with {variable_name} filled in, or None
+ if self.value_output_name is None:
+ return None
+ return self.value_output_name.format(variable_name=self.variable_name)
+
+ def expand_nodes(self) -> t.List[Node]: # at most two nodes: the struct node, then the score read from it
+ nodes = []
+
+ value_output_name = self.get_value_output_name()
+
+ # Determine the name of the intermediate struct node.
+ # If raw_output_name is explicitly set, use it (and always emit the node).
+ # If only value_output_name is set we still need the struct, so use a
+ # private internal name that won't be exposed as a graph output.
+ struct_node_name = self.raw_output_name
+ if struct_node_name is None and value_output_name is not None:
+ struct_node_name = f"_{self.variable_name}_raw"
+
+ if struct_node_name is not None:
+ nodes.append(Node.from_callable(
+ score_variable,
+ name=struct_node_name,
+ input_map={"input": self.variable_name},
+ static_kwargs={
+ "bound_bins": self._bound_bins,
+ "value_bins": self._value_bins,
+ "default_bin": self.default,
+ "input_name": self.variable_name,
+ "output_expr_fn": self.variable_struct_function.get_function() if self.variable_struct_function is not None else None
+ }
+ ))
+
+ if value_output_name is not None:
+ value_func = (self.struct_to_score_function.get_function()
+ if self.struct_to_score_function is not None
+ else default_get_value_from_struct)
+ nodes.append(Node.from_callable(
+ value_func,
+ name=value_output_name,
+ input_map={"struct": struct_node_name}
+ ))
+
+ return nodes
+
+
+class AdjustedVariable(BaseModule):  # Passes a wrapped ScoredVariable's score through adjust_score with offset/scale
+    type: t.Literal["adjusted"]
+    variable: ScoredVariable
+    offset: float = 0.0
+    scale: float = 1.0
+    variable_output_name: str = "{variable_name}_adjusted_score"
+
+    @field_validator("variable", mode="after")
+    @classmethod
+    def validate_variable(cls, variable: ScoredVariable) -> ScoredVariable:
+        # The adjustment node reads the wrapped variable's score column by name.
+        if variable.value_output_name is None:
+            raise ValueError("ScoredVariable used in AdjustedVariable must have a value_output_name defined to be referenced for adjustment.")
+        return variable
+
+    def get_value_output_name(self) -> t.Optional[str]:
+        # Template placeholders available: {variable_name} and {score_value_output_name}.
+        return self.variable_output_name.format(
+            variable_name=self.variable.variable_name,
+            score_value_output_name=self.variable.get_value_output_name(),
+        )
+
+    def expand_nodes(self) -> t.List[Node]:
+        # Emit the wrapped variable's own nodes first: the adjustment below reads
+        # its score column, which nothing else in the graph would produce.
+        adjusted = Node.from_callable(
+            adjust_score, name=self.get_value_output_name(),
+            input_map={"score": self.variable.get_value_output_name()},
+            static_kwargs={"offset": self.offset, "scale": self.scale})
+        return [*self.variable.expand_nodes(), adjusted]
+
+class ConstantScore(BaseModule):  # Contributes a fixed literal score column to a scorecard
+    type: t.Literal["constant"]
+    score: float
+    output_name: str = "constant_score"
+
+    def get_value_output_name(self) -> str:
+        return self.output_name  # the constant is exposed directly under output_name
+
+    def expand_nodes(self) -> t.List[Node]:
+        def constant_score_fn() -> pl.Expr:
+            return pl.lit(self.score)  # self.score is read when the node's callable runs
+        literal_node = Node.from_callable(constant_score_fn, name=self.output_name)
+        return [literal_node]
+
+_TScoredVariable = t.Annotated[ # Union of the entries a ScoreCard accepts, discriminated by each model's "type" field
+ t.Union[ScoredVariable, AdjustedVariable, ConstantScore],
+ Discriminator("type")
+]
+
+class ScoreCard(BaseModule):  # Sums the score columns of its variables into one output column
+    type: t.Literal["scorecard"]
+    variables: t.List[_TScoredVariable]
+    output_name: str = "score"
+
+    @field_validator("variables", mode="after")
+    @classmethod
+    def validate_variables(cls, variables: t.List[_TScoredVariable]) -> t.List[_TScoredVariable]:
+        """Reject duplicate variable names and variables without a score output name."""
+        variable_names = set()
+        for var in variables:
+            if var.type == "constant": continue
+            # AdjustedVariable declares no variable_name of its own; use the wrapped ScoredVariable's.
+            variable_name = (var.variable if var.type == "adjusted" else var).variable_name
+            if variable_name in variable_names:
+                raise ValueError(f"Duplicate variable_name '{variable_name}' found in ScoreCard variables. Each ScoredVariable must have a unique variable_name.")
+            variable_names.add(variable_name)
+            if var.get_value_output_name() is None:
+                raise ValueError(f"ScoredVariable with variable_name '{variable_name}' must have a value_output_name defined to be used in ScoreCard.")
+        return variables
+
+    def expand_nodes(self) -> t.List[Node]:
+        """Return every variable's nodes followed by the calculate_score summing node."""
+        from decider.modules.expression import ExternalInputNode
+
+        nodes = []
+        for variable in self.variables:
+            nodes.extend(variable.expand_nodes())
+        # calculate_score reads each variable's score column by name, so the inputs
+        # are wired as ExternalInputNode (column) references, not Node objects.
+        input_map = {
+            var_name: ExternalInputNode(var_name)
+            for var in self.variables
+            if (var_name := var.get_value_output_name())
+        }
+
+        nodes.append(Node(
+            name=self.output_name, callable=calculate_score, input_map=input_map,
+        ))
+        return nodes
+
+
+class ProbabilityDefault(BaseModule):  # Single node: score column -> probability of default
+    type: t.Literal["probability_default"]
+
+    def expand_nodes(self) -> t.List[Node]:
+        # Only "score" is wired; the callable's other parameters keep their defaults.
+        pd_node = Node.from_callable(
+            calculate_probability_of_default,
+            name=self.output_name,
+            input_map={"score": self.input_name},
+        )
+        return [pd_node]
+
+class LogProbability(BaseModule):  # Single node: score column -> log odds (via log_odds_from_score)
+    type: t.Literal["log_probability"]
+
+    def expand_nodes(self) -> t.List[Node]:
+        # The scorecard scaling constants are fixed here rather than configurable.
+        scaling = {"anchor_score": 660, "target_odds": 15, "points_to_double_the_odds": 20}
+        log_odds_node = Node.from_callable(
+            log_odds_from_score, name=self.output_name,
+            input_map={"score": self.input_name},
+            static_kwargs=scaling,
+        )
+        return [log_odds_node]
+
+class ScoreFromPDO(BaseModule):  # Single node: probability-of-default column -> credit score
+    type: t.Literal["score_from_pdo"]
+
+    def expand_nodes(self) -> t.List[Node]:
+        # Only the PD input is wired; calculate_credit_score's other parameters keep their defaults.
+        score_node = Node.from_callable(
+            calculate_credit_score,
+            name=self.output_name,
+            input_map={"probability_of_default": self.input_name},
+        )
+        return [score_node]
+
+
+class WeightedScore(t.NamedTuple): # One (score input, weight) term of a MergeScorecardValues weighted sum
+ score_name: str # name of the score input this weight applies to
+ weight: float # multiplier applied to that score
+
+class MergeScorecardValues(BaseModule):  # Blends several score columns into one weighted sum
+    type: t.Literal["merge_scorecard_values"]
+    weighted_scores: t.List[WeightedScore]
+    output_name: str = "merged_scorecard_values"
+
+    @field_validator("weighted_scores", mode="after")
+    @classmethod
+    def validate_weighted_scores(cls, weighted_scores: t.List[WeightedScore]) -> t.List[WeightedScore]:
+        """Require the weights to sum to 1.0, allowing for float rounding error."""
+        import math
+        total_weight = sum(ws.weight for ws in weighted_scores)
+        # sum([0.1] * 10) == 0.9999999999999999, so an exact `!= 1.0` test rejects valid weights.
+        if not math.isclose(total_weight, 1.0, rel_tol=0.0, abs_tol=1e-9):
+            raise ValueError(f"The sum of weights in weighted_scores must equal 1.0. Current sum is {total_weight}.")
+        return weighted_scores
+
+    def expand_nodes(self) -> t.List[Node]:
+        """Return a single node computing sum(score * weight) over the configured inputs."""
+        from decider.modules.expression import ExternalInputNode
+
+        # TODO i just made this up i think its more complex than this as pd is involved here. @christiaan
+        weighted_scores = self.weighted_scores
+        def merge_scorecard_values(**kwargs):
+            # kwargs holds one value per score_name, as wired in input_map below.
+            result = 0.0
+            for weighted_score in weighted_scores:
+                score_value = kwargs[weighted_score.score_name]
+                result += score_value * weighted_score.weight
+            return result
+
+        input_map = {
+            ws.score_name: ExternalInputNode(ws.score_name)
+            for ws in self.weighted_scores
+        }
+
+        return [
+            Node(name=self.output_name, callable=merge_scorecard_values, input_map=input_map)
+        ]
\ No newline at end of file
diff --git a/decider/modules/expression.py b/decider/modules/expression.py
new file mode 100644
index 0000000..4c6c346
--- /dev/null
+++ b/decider/modules/expression.py
@@ -0,0 +1,219 @@
+import typing as t
+import inspect
+import polars as pl
+from pydantic import PrivateAttr
+from dataclasses import dataclass, field
+from collections import OrderedDict
+from abc import abstractmethod
+
+from decider.types import TInputType, TOutputType
+from .core import BaseModule, BaseExecuteModule
+
+if t.TYPE_CHECKING:
+ from decider.executor import Executor
+ from decider.config.base import BaseConfig
+
+
+# ── Input ref types ───────────────────────────────────────────────────────────
+
+@dataclass(slots=True)
+class StaticValueNode:  # Input ref holding a fixed Python value, returned as-is
+    value: t.Any
+
+    def get_expr(self) -> t.Any:
+        return self.value
+
+    def get_frame_value(self, _frames: t.Dict[str, t.Any]) -> t.Any:
+        return self.value  # frames are ignored: the value does not come from a frame
+
+    def to_config(self, config_key: str) -> "ConfigValueNode":
+        """Wrap the held Pydantic model in a dynamically created BaseConfig subclass."""
+        from pydantic import BaseModel
+        from decider.config.base import BaseConfig
+        model = self.value
+        if not isinstance(model, BaseModel):
+            message = (
+                f"StaticValueNode.to_config requires the value to be a Pydantic BaseModel, "
+                f"got {type(model).__name__}."
+            )
+            raise TypeError(message)
+        model_class = type(model)
+        # Build "<Model>Config", bound to the model class and reported under its module.
+        namespace = {"_MODEL_CLASS": model_class, "__module__": model_class.__module__}
+        config_class = type(f"{model_class.__name__}Config", (BaseConfig,), namespace)
+        config = config_class.from_model(model=model, config_key=config_key)
+        return ConfigValueNode(config=config)
+
+
+@dataclass(slots=True)
+class ConfigValueNode:
+    """Input ref that resolves to the model currently held by a BaseConfig.
+
+    Every access reads config._constructed_model again, at evaluation time.
+    Call node.config.reload() externally to pick up a new config version.
+    """
+    config: "BaseConfig[t.Any]"
+
+    def get_expr(self) -> t.Any:
+        return self.config._constructed_model
+
+    def get_frame_value(self, _frames: t.Dict[str, t.Any]) -> t.Any:
+        return self.get_expr()  # same value regardless of the frames supplied
+
+
+@dataclass(slots=True)
+class ExternalInputNode:  # Input ref addressing a column (for expressions) or a named frame
+    input_name: str
+
+    def get_expr(self) -> pl.Expr:
+        return pl.col(self.input_name)
+
+    def get_frame_value(self, frames: t.Dict[str, t.Any]) -> t.Any:
+        wanted = self.input_name
+        if wanted in frames:
+            return frames[wanted]
+        available = list(frames.keys())
+        raise ValueError(f"Frame '{wanted}' not found. Available: {available}")
+
+
+# ── Node ──────────────────────────────────────────────────────────────────────
+
+@dataclass(slots=True)
+class Node:
+    """One named step of an expression graph: a callable plus the refs feeding its parameters."""
+    name: str
+    callable: t.Callable
+    input_map: t.Dict[str, t.Union["Node", StaticValueNode, ExternalInputNode, ConfigValueNode]] = field(
+        default_factory=dict
+    )
+
+    def get_expr(self) -> pl.Expr:
+        return pl.col(self.name)
+
+    def get_input_expressions(self) -> t.Dict[str, t.Any]:
+        return {k: ref.get_expr() for k, ref in self.input_map.items()}
+
+    def get_frame_kwargs(self, frames: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
+        return {k: ref.get_frame_value(frames) for k, ref in self.input_map.items()}
+
+    @property
+    def frame_dependencies(self) -> t.List[str]:
+        return [
+            ref.input_name for ref in self.input_map.values()
+            if isinstance(ref, ExternalInputNode)
+        ]
+
+    def get_dependencies(self) -> t.List[str]:
+        return [ref.name for ref in self.input_map.values() if isinstance(ref, Node)]
+
+    @classmethod
+    def from_callable(
+        cls,
+        func: t.Callable,
+        name: t.Optional[str] = None,
+        input_map: t.Optional[t.Dict[str, str]] = None,
+        static_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
+    ) -> "Node":
+        """Build a Node by inspecting func's signature.
+
+        Args:
+            func: Callable to wrap.
+            name: Node name; defaults to func.__name__.
+            input_map: Parameter name -> input name. Required parameters that
+                are not listed read the input named after the parameter.
+            static_kwargs: Parameter name -> fixed value; takes precedence over input_map.
+        Raises:
+            ValueError: If input_map names a parameter func lacks and func has no **kwargs.
+        """
+        name = name or func.__name__
+        params = inspect.signature(func).parameters
+        static_kwargs = static_kwargs or {}
+        input_map = input_map or {}
+
+        has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
+        named_params = {
+            k for k, p in params.items()
+            if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+        }
+        required = {
+            k for k in named_params
+            if params[k].default is inspect.Parameter.empty and k not in static_kwargs
+        }
+        resolved = {k: input_map.get(k, k) for k in required}
+        # Honour explicit mappings for parameters that have defaults too; they used to be dropped silently.
+        resolved.update((k, v) for k, v in input_map.items() if k in named_params and k not in static_kwargs)
+
+        extra = {k for k in input_map if k not in named_params and k not in static_kwargs}
+        if extra and not has_var_keyword:
+            raise ValueError(
+                f"Parameters {sorted(extra)} are in input_map but '{func.__name__}' "
+                "has no matching parameter and no **kwargs."
+            )
+        resolved.update((k, input_map[k]) for k in extra)  # unmatched names are forwarded through **kwargs
+        refs = {k: ExternalInputNode(input_name=v) for k, v in resolved.items()}
+        return cls(name=name, callable=func, input_map=refs | {k: StaticValueNode(value=v) for k, v in static_kwargs.items()})
+
+
+# ── CompiledExpressions ───────────────────────────────────────────────────────
+
+@dataclass(slots=True)
+class CompiledExpressions:  # Ordered name -> expression plan applied to one input frame
+    expressions: OrderedDict  # OrderedDict[str, pl.Expr]
+    input_frame: str = "input"
+    drop_inputs: bool = False
+
+    def execute(self, inputs: TInputType) -> TOutputType:
+        source = inputs[self.input_frame]
+        result = source.lazy() if isinstance(source, pl.DataFrame) else source
+        # One with_columns call per expression, applied in the stored order;
+        # each result column is aliased to its key.
+        for column_name, expression in self.expressions.items():
+            result = result.with_columns(expression.alias(column_name))
+        output_columns = list(self.expressions.keys())
+        return result.select(output_columns) if self.drop_inputs else result
+
+
+# ── ExpressionModule ──────────────────────────────────────────────────────────
+
+class ExpressionModule(BaseExecuteModule):  # Base for modules that expand into a graph of expression Nodes
+    name: str
+    _compiled_expressions: t.Optional[CompiledExpressions] = PrivateAttr(None)
+
+    @abstractmethod
+    def expand_nodes(self) -> t.Dict[str, Node]:
+        ...
+
+    def compile_expressions(self, executor: t.Optional["Executor"] = None) -> CompiledExpressions:
+        # Compiled once and cached; later calls return the cache whatever executor is passed.
+        if self._compiled_expressions is not None:
+            return self._compiled_expressions
+        from decider.settings import get_default_executor
+        executor = executor or get_default_executor()
+        graph_nodes = list(self.expand_nodes().values())
+        self._compiled_expressions = executor.compile_expression_graph(graph_nodes)
+        return self._compiled_expressions
+
+    def execute(self, inputs: TInputType, _executor: "Executor") -> TOutputType:
+        compiled = self._compiled_expressions
+        if compiled is not None:
+            return compiled.execute(inputs)
+        raise RuntimeError(
+            f"Module '{self.name}' has not been compiled. "
+            "Call .compile_expressions() first, or use mod(inputs) which compiles automatically."
+        )
+
+    def __call__(self, inputs: TInputType, executor: t.Optional["Executor"] = None) -> TOutputType:
+        from decider.settings import get_default_executor
+        executor = executor or get_default_executor()
+        self.compile_expressions(executor)
+        return executor.execute(self, inputs)
+
+    def __and__(self, other: "ExpressionModule") -> "ExpressionModule":  # a & b: one flattened UnionExpressionModule
+        from decider.modules.primitives.union import UnionExpressionModule
+        left_modules = self.modules if isinstance(self, UnionExpressionModule) else [self]  # type: ignore[attr-defined]
+        right_modules = other.modules if isinstance(other, UnionExpressionModule) else [other]  # type: ignore[attr-defined]
+        return UnionExpressionModule(name=f"{self.name}__{other.name}", modules=left_modules + right_modules)
+
+    def __or__(self, other: "BaseModule") -> "SequentialModule":  # a | b: SequentialModule with steps [a, b]
+        from decider.modules.primitives.sequential import SequentialModule
+        return SequentialModule(name=self.name, steps=[self, other])
diff --git a/decider/modules/functional.py b/decider/modules/functional.py
new file mode 100644
index 0000000..7300298
--- /dev/null
+++ b/decider/modules/functional.py
@@ -0,0 +1,126 @@
+import typing as t
+import inspect
+import polars as pl
+from pydantic import BaseModel
+from types import ModuleType
+from .core import BaseModule
+from .expression import ExpressionModule, Node
+
+
+def generate_from_functions(module_name: str, *functions: t.Callable) -> t.Type[BaseModule]:
+ """Create a module class from plain Python functions.
+
+ Three conventions wire everything together automatically:
+
+ 1. **Function name → output column name**
+ Each function produces a new column whose name matches the function's
+ ``__name__``. ``def risk_score(...) → pl.Expr`` adds a ``risk_score``
+ column to the frame.
+
+ 2. **Parameter name → input column lookup**
+ Parameters are resolved in order:
+ a. If a function passed in this call is named like the parameter, that
+ function's node is wired in as the input (dependency wiring).
+ b. Otherwise the parameter is read from the column of that name in the
+ input dataframe.
+
+ Example — ``amount_centered`` receives the output of ``amount_mean``:
+ ::
+
+ def amount_mean(amount: pl.Expr) -> pl.Expr:
+ return amount.mean()
+
+ def amount_centered(amount: pl.Expr, amount_mean: pl.Expr) -> pl.Expr:
+ return amount - amount_mean
+
+ 3. **Optional ``config`` parameter → module config injection**
+ If a function declares a ``config`` parameter annotated with a Pydantic
+ model, the module itself acts as that config (fields are defined on the
+ generated class) and the current instance is injected at call time.
+
+ Args:
+ module_name: Type discriminator string used for serialisation. Also
+ shows up in debug output and error messages. It is lower-cased before
+ use and should be unique within your project (e.g. ``"income_scorer"``).
+ *functions: One or more plain functions returning ``pl.Expr``. Order
+ only matters when two functions share the same dependency graph
+ level — prefer topological clarity over positional ordering.
+
+ Returns:
+ A new ``BaseModule`` subclass. Instantiate it with ``name=`` to get a
+ module you can call directly or compose with ``|`` and ``&``::
+
+ Scorer = generate_from_functions("scorer", risk_score, tier_flag)
+ scorer = Scorer(name="my_scorer")
+ result = scorer({"input": df})
+
+ Raises:
+ TypeError: If a ``config`` parameter is unannotated, if functions disagree
+ on its type, or if that type is not a ``pydantic.BaseModel`` subclass.
+
+ Quickstart::
+
+ import polars as pl
+ from decider.modules.functional import generate_from_functions
+
+ def score(amount: pl.Expr) -> pl.Expr:
+ return amount * 100
+
+ Scorer = generate_from_functions("scorer", score)
+ result = Scorer(name="s")({"input": pl.DataFrame({"amount": [1, 2, 3]})})
+ # result is a LazyFrame; call .collect() to materialise it
+ """
+ # Work out the shared config type (if any), then build the module class around the functions
+
+ # 1. look in all functions for a keyword argument named config and ensure they are all the same
+ config_class = None
+ requires_injection: t.List[bool] = []
+ for func in functions:
+ sig = inspect.signature(func)
+ if "config" not in sig.parameters:
+ requires_injection.append(False)
+ continue
+ requires_injection.append(True)
+ param = sig.parameters["config"]
+ if param.annotation is inspect.Parameter.empty:
+ raise TypeError(f"Function {func.__name__} has a 'config' parameter but it is not annotated with a type")
+ if config_class is None:
+ config_class = param.annotation
+ elif config_class != param.annotation:
+ raise TypeError(f"All functions must have the same type for the 'config' parameter. Found {config_class} and {param.annotation}")
+
+ # check config is a pydantic model
+ if config_class is not None:
+ if not issubclass(config_class, BaseModel):
+ raise TypeError(f"The 'config' parameter must be a subclass of pydantic.BaseModel. Found {config_class}")
+
+ if config_class is None: config_class = BaseModel # no function takes config: fall back to plain BaseModel as the second base
+
+ module_name = module_name.lower()
+
+ class TModule(ExpressionModule, config_class): # the config model's fields become fields of the module itself
+ type: t.Literal[module_name]
+
+ def expand_nodes(self) -> t.Dict[str, "Node"]:
+ nonlocal functions, requires_injection # only read here; neither name is rebound
+ internal_nodes: t.Dict[str, Node] = {}
+ for inject, func in zip(requires_injection, functions):
+ node = Node.from_callable(
+ func,
+ static_kwargs={"config": self} if inject else None
+ )
+ internal_nodes[node.name] = node
+ # Second pass: an input keyed by another function's name is re-pointed at that function's node
+ for node in internal_nodes.values():
+ for k in node.input_map.keys():
+ if k in internal_nodes:
+ node.input_map[k] = internal_nodes[k]
+
+ return internal_nodes
+ return TModule
+
+def generate_from_module(module_name: str, module: ModuleType) -> t.Type[BaseModule]:
+    """Build a module class from the public functions defined in ``module``; imported names, classes and ``_private`` helpers are skipped."""
+    functions = [fn for fn_name, fn in inspect.getmembers(module, inspect.isfunction) if fn.__module__ == module.__name__ and not fn_name.startswith("_")]
+    return generate_from_functions(module_name, *functions)
\ No newline at end of file
diff --git a/decider/modules/primitives/__init__.py b/decider/modules/primitives/__init__.py
new file mode 100644
index 0000000..e2012fb
--- /dev/null
+++ b/decider/modules/primitives/__init__.py
@@ -0,0 +1,5 @@
+from .join import FrameModule, JoinModule, FrameRef
+from .union import UnionExpressionModule
+from .sequential import SequentialModule
+
+__all__ = ["FrameModule", "JoinModule", "FrameRef", "UnionExpressionModule", "SequentialModule"]
diff --git a/decider/modules/primitives/join.py b/decider/modules/primitives/join.py
new file mode 100644
index 0000000..4b0c0ab
--- /dev/null
+++ b/decider/modules/primitives/join.py
@@ -0,0 +1,90 @@
+import typing as t
+from abc import abstractmethod
+
+import polars as pl
+from pydantic import field_validator
+
+from decider.types import TInputType, TOutputType
+from decider.modules.core import BaseModule, BaseExecuteModule
+
+
+if t.TYPE_CHECKING:
+ from decider.executor import Executor
+
+
+FrameInput = t.Union[str, BaseModule]
+
+
+def _deserialise_frame_input(v: t.Any) -> FrameInput:
+    """Turn a raw ``left``/``right`` value into a frame reference.
+
+    A dict is validated through ``GraphModule`` and replaced by the resulting
+    ``root``. Every other value (frame-name strings, module instances, or
+    anything else) is handed back untouched.
+    """
+    needs_validation = isinstance(v, dict) and not isinstance(v, (str, BaseModule))
+    if not needs_validation:
+        return v
+    # Imported lazily; presumably to avoid a circular import -- TODO confirm.
+    from decider.modules._ext import GraphModule
+    return GraphModule.model_validate(v).root
+
+
+def _resolve_frame(
+    ref: FrameInput,
+    inputs: TInputType,
+    executor: "Executor",
+) -> pl.LazyFrame:
+    """Produce the frame designated by ``ref``.
+
+    A string is looked up in ``inputs``; anything else is invoked as
+    ``ref(inputs, executor=executor)``. An eager ``pl.DataFrame`` result is
+    converted with ``.lazy()``; any other result is returned unchanged.
+    """
+    if isinstance(ref, str):
+        resolved = inputs[ref]
+    else:
+        resolved = ref(inputs, executor=executor)
+    if isinstance(resolved, pl.DataFrame):
+        return resolved.lazy()
+    return resolved
+
+
+class FrameRef(BaseExecuteModule):
+    """Selects the input frame stored under this module's ``name``.
+
+    Used to feed a specific named input into a sub-pipeline:
+        FrameRef("input1") | scorer  → routes input1 through scorer
+    """
+
+    type: t.Literal["frame_ref"]
+
+    def execute(self, inputs: TInputType, executor: "Executor") -> TOutputType:
+        """Return ``inputs[self.name]``, made lazy when it is a ``pl.DataFrame``."""
+        selected = inputs[self.name]
+        if isinstance(selected, pl.DataFrame):
+            selected = selected.lazy()
+        return selected
+
+
+class FrameModule(BaseExecuteModule):
+    """Abstract base for modules that combine named input frames into a new LazyFrame."""
+
+    @abstractmethod
+    def execute(self, inputs: TInputType, executor: "Executor") -> TOutputType:
+        """Combine ``inputs`` into the output frame; concrete subclasses implement this."""
+
+
+class JoinModule(FrameModule):
+    """Joins two frames, each given as an input name or as a nested module.
+
+    ``left`` / ``right`` accept a frame-name string, a module instance, or a
+    dict that is deserialised through ``GraphModule``. ``on`` and ``how`` are
+    passed straight to the left frame's ``join`` call.
+    """
+
+    type: t.Literal["join"]
+    left: t.Any  # FrameInput; Any allows discriminated deserialisation
+    right: t.Any
+    on: t.Union[str, t.List[str]]
+    how: str = "left"
+
+    model_config = {"arbitrary_types_allowed": True}
+
+    @field_validator("left", "right", mode="before")
+    @classmethod
+    def _deserialise_frame_input(cls, v: t.Any) -> FrameInput:
+        # The bare name below resolves to the module-level helper, not to this
+        # method: class attributes are not in scope inside method bodies.
+        return _deserialise_frame_input(v)
+
+    def _compute_input_frame_keys(self) -> t.List[str]:
+        """List the input frame names needed by both sides, left first, without duplicates."""
+        collected: t.List[str] = []
+        for side in (self.left, self.right):
+            if isinstance(side, str):
+                collected.append(side)
+            else:
+                collected.extend(side.get_input_frame_keys())
+        # dict.fromkeys drops repeats while keeping first-seen order.
+        return list(dict.fromkeys(collected))
+
+    def execute(self, inputs: TInputType, executor: "Executor") -> TOutputType:
+        """Resolve both sides (left, then right) and join them."""
+        lhs = _resolve_frame(self.left, inputs, executor)
+        rhs = _resolve_frame(self.right, inputs, executor)
+        return lhs.join(rhs, on=self.on, how=self.how)
diff --git a/decider/modules/primitives/sequential.py b/decider/modules/primitives/sequential.py
new file mode 100644
index 0000000..b473888
--- /dev/null
+++ b/decider/modules/primitives/sequential.py
@@ -0,0 +1,55 @@
+import typing as t
+
+import polars as pl
+from pydantic import field_validator
+
+from decider.types import TInputType, TOutputType
+from decider.modules.core import BaseModule, BaseExecuteModule
+
+
+if t.TYPE_CHECKING:
+ from decider.executor import Executor
+
+
+class SequentialModule(BaseExecuteModule):
+    """Chains modules so each step receives the previous step's output as 'input'.
+
+    Created via the | operator: mod_a | mod_b | mod_c
+    """
+
+    type: t.Literal["sequential"]
+    steps: t.List[t.Any]  # BaseModule; use Any to allow discriminated deserialisation
+
+    @field_validator("steps", mode="before")
+    @classmethod
+    def _deserialise_steps(cls, v: t.Any) -> t.List[BaseModule]:
+        """Validate dict entries through ``GraphModule``; keep every other entry as given."""
+        from decider.modules._ext import GraphModule
+
+        result = []
+        for item in v:
+            if isinstance(item, dict) and not isinstance(item, BaseModule):
+                result.append(GraphModule.model_validate(item).root)
+            else:
+                result.append(item)
+        return result
+
+    def _compute_input_frame_keys(self) -> t.List[str]:
+        """Input keys of the first step, or ``["input"]`` when there are no steps."""
+        return self.steps[0].get_input_frame_keys() if self.steps else ["input"]
+
+    def execute(self, inputs: TInputType, executor: "Executor") -> TOutputType:
+        """Run every step in order, feeding each one the previous result as ``"input"``.
+
+        The starting frame is ``inputs["input"]`` when present (and not None),
+        otherwise the first frame in ``inputs``. All original named frames stay
+        visible to every step; only the ``"input"`` entry is replaced.
+
+        Raises:
+            ValueError: if ``inputs`` contains no frames at all.
+        """
+        # Work on a lazy copy so the caller's mapping is never mutated.
+        frames: t.Dict[str, pl.LazyFrame] = {
+            k: v.lazy() if isinstance(v, pl.DataFrame) else v
+            for k, v in inputs.items()
+        }
+        current = frames.get("input")
+        if current is None:
+            if not frames:
+                # Previously this surfaced as a bare StopIteration from next().
+                raise ValueError(
+                    f"SequentialModule '{self.name}' requires at least one input frame"
+                )
+            current = next(iter(frames.values()))
+
+        for step in self.steps:
+            frames["input"] = current
+            current = step(frames, executor=executor)
+
+        return current
+
+    def __or__(self, other: BaseModule) -> "SequentialModule":
+        """Return a new SequentialModule with ``other`` appended to the steps."""
+        return SequentialModule(name=self.name, steps=self.steps + [other])
diff --git a/decider/modules/primitives/union.py b/decider/modules/primitives/union.py
new file mode 100644
index 0000000..7ce6ea4
--- /dev/null
+++ b/decider/modules/primitives/union.py
@@ -0,0 +1,22 @@
+import typing as t
+
+from decider.modules.expression import ExpressionModule, Node
+
+
+class UnionExpressionModule(ExpressionModule):
+    """Merges multiple ExpressionModules into a single compilation pass.
+
+    The nodes of every child module are flattened into one mapping, so the
+    combined module applies in a single frame pass.
+
+    Created via the & operator: mod_a & mod_b
+    """
+
+    type: t.Literal["union"]
+    modules: t.List[ExpressionModule]
+
+    def expand_nodes(self) -> t.Dict[str, Node]:
+        """Flatten the children's nodes; on a name clash the later module's node wins."""
+        return {
+            node_name: node
+            for child in self.modules
+            for node_name, node in child.expand_nodes().items()
+        }
diff --git a/decider/modules/rules/__init__.py b/decider/modules/rules/__init__.py
new file mode 100644
index 0000000..c321149
--- /dev/null
+++ b/decider/modules/rules/__init__.py
@@ -0,0 +1,160 @@
+# Tree formats
+from .tree.v1.tree import Tree as V1Tree
+from .tree.v2.tree import Tree as V2Tree
+from .tree.v3.tree import Tree as V3Tree
+from .tree.tree import Tree
+
+# Flat rule module (the main execution unit)
+from .flat_rules.module import (
+ FlatRuleModule,
+ PrioritizedFlatRuleModule,
+ PrioritizationMode,
+ RunPolarsExpression,
+ OptimRunPolarsExpression,
+)
+
+# Rule node types
+from .flat_rules.nodes import (
+ LeafRule,
+ UnaryRule,
+ CompositeRule,
+ CasesRanges,
+ CasesStringMatch,
+ CasesIsIn,
+ CasesRule,
+ RuleRoot,
+ RuleMeta,
+ FlatRuleTree,
+ BuilderConfig,
+ RuleType,
+)
+
+# Condition/operator primitives
+from .common.nodes import (
+ TUnaryOp,
+ TCondition,
+ CompositeCondition,
+ RangeCondition,
+ StringMatchCondition,
+ IsInCondition,
+ CasesBranch,
+ BaseUnaryNode,
+ BaseCasesRanges,
+ BaseCasesStringMatch,
+ BaseCasesIsIn,
+ BaseCompositeNode,
+ UnaryLessThanEqual,
+ UnaryLessThan,
+ UnaryEqual,
+ UnaryGreaterThan,
+ UnaryGreaterThanEqual,
+ UnaryNotEqual,
+ UnaryBetween,
+ UnaryIsIn,
+ UnaryStringMatch,
+ UnaryIsNull,
+ UnaryIsNotNull,
+ UnaryIsTrue,
+ UnaryIsFalse,
+)
+
+# Shared types
+from .common.shared import InputRef, TreeOutput, WithTreeOutput
+from .common.parameters import WithParameters, ParameterInfo
+from .common.feature import Feature
+from .common.nodetypes import (
+ BaseRule,
+ TLogicOp,
+ TNodeType,
+ TStringMatchType,
+ RangeEndLogic,
+ NodeMeta,
+ NodePosition,
+)
+
+# Serializable / schema types
+# Discriminated union for all module types
+from .modules import TModule, TTreeFormat, ModuleWrapper, module_discriminator
+
+
+def register_rule_modules():
+    """Register all rule module types into the global GraphModule union.
+
+    Each tree format and flat-rule module class is passed, in order, to
+    ``decider.modules.register_graph_module``.
+    """
+    from decider.modules import register_graph_module
+
+    rule_module_classes = (
+        V1Tree,
+        V2Tree,
+        V3Tree,
+        FlatRuleModule,
+        PrioritizedFlatRuleModule,
+    )
+    for rule_cls in rule_module_classes:
+        register_graph_module(rule_cls)
+
+
+__all__ = [
+ # Trees
+ "V1Tree",
+ "V2Tree",
+ "V3Tree",
+ "Tree",
+ # Flat rules
+ "FlatRuleModule",
+ "PrioritizedFlatRuleModule",
+ "PrioritizationMode",
+ "RunPolarsExpression",
+ "OptimRunPolarsExpression",
+ # Rule nodes
+ "LeafRule",
+ "UnaryRule",
+ "CompositeRule",
+ "CasesRanges",
+ "CasesStringMatch",
+ "CasesIsIn",
+ "CasesRule",
+ "RuleRoot",
+ "RuleMeta",
+ "FlatRuleTree",
+ "BuilderConfig",
+ "RuleType",
+ # Conditions & operators
+ "TUnaryOp",
+ "TCondition",
+ "CompositeCondition",
+ "RangeCondition",
+ "StringMatchCondition",
+ "IsInCondition",
+ "CasesBranch",
+ "BaseUnaryNode",
+ "BaseCasesRanges",
+ "BaseCasesStringMatch",
+ "BaseCasesIsIn",
+ "BaseCompositeNode",
+ "UnaryLessThanEqual",
+ "UnaryLessThan",
+ "UnaryEqual",
+ "UnaryGreaterThan",
+ "UnaryGreaterThanEqual",
+ "UnaryNotEqual",
+ "UnaryBetween",
+ "UnaryIsIn",
+ "UnaryStringMatch",
+ "UnaryIsNull",
+ "UnaryIsNotNull",
+ "UnaryIsTrue",
+ "UnaryIsFalse",
+ # Shared types
+ "InputRef",
+ "TreeOutput",
+ "WithTreeOutput",
+ "WithParameters",
+ "ParameterInfo",
+ "Feature",
+ "BaseRule",
+ "TLogicOp",
+ "TNodeType",
+ "TStringMatchType",
+ "RangeEndLogic",
+ "NodeMeta",
+ "NodePosition",
+ # Module union
+ "TModule",
+ "TTreeFormat",
+ "ModuleWrapper",
+ "module_discriminator",
+ # Registration
+ "register_rule_modules",
+]
diff --git a/__init__.py b/decider/modules/rules/common/__init__.py
similarity index 100%
rename from __init__.py
rename to decider/modules/rules/common/__init__.py
diff --git a/decider/modules/rules/common/feature.py b/decider/modules/rules/common/feature.py
new file mode 100644
index 0000000..9d83e86
--- /dev/null
+++ b/decider/modules/rules/common/feature.py
@@ -0,0 +1,128 @@
+import ast
+import types
+import typing as t
+from pydantic import BaseModel, Field, PrivateAttr, model_validator, RootModel
+import polars as pl
+
+
+def extract_names_and_parameters(expression: str) -> t.Tuple[t.Set[str], t.Set[str]]:
+    """
+    Extract variable names and parameter accesses from a Python expression.
+
+    Parameters are attribute accesses on the reserved name ``p``: ``p.income``
+    yields the parameter ``income``. Attribute access on any other name
+    (e.g. ``params.income``) is NOT a parameter; its base name is reported as
+    a variable instead.
+
+    Args:
+        expression: source text of a single Python expression.
+
+    Returns:
+        (variable_names, parameter_names) where:
+        - variable_names: set of standalone names like 'b', 'age'. Names used
+          directly as a call target (``lit`` in ``lit(1)``) and the ``p`` of a
+          ``p.<attr>`` access are excluded.
+        - parameter_names: set of attribute names accessed on ``p``.
+
+    Raises:
+        ValueError: if ``expression`` is not a syntactically valid expression.
+    """
+    # Only parsing can raise SyntaxError, so keep the try block that narrow.
+    try:
+        tree = ast.parse(expression, mode="eval")
+    except SyntaxError as e:
+        raise ValueError(f"Invalid expression syntax: {e}") from e
+
+    name_nodes: t.List[ast.Name] = []
+    parameter_names: t.Set[str] = set()
+    # Identity (not name) based exclusion: the same identifier may appear both
+    # as a call target and as a plain variable elsewhere in the expression.
+    excluded_node_ids: t.Set[int] = set()
+
+    for node in ast.walk(tree):
+        if isinstance(node, ast.Name):
+            name_nodes.append(node)
+        elif (
+            isinstance(node, ast.Attribute)
+            and isinstance(node.value, ast.Name)
+            and node.value.id == "p"
+        ):
+            parameter_names.add(node.attr)
+            excluded_node_ids.add(id(node.value))
+        elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
+            excluded_node_ids.add(id(node.func))
+
+    variable_names = {
+        name.id for name in name_nodes if id(name) not in excluded_node_ids
+    }
+    return variable_names, parameter_names
+
+
+# Whitelist of callables exposed to computed-feature expressions. The mapping
+# is passed as ``functions=`` to ``simple_eval`` in
+# ``_ComputedFeature.build_expression``; each key is the name usable inside an
+# expression and each value is the polars function it calls.
+ALLOWED_POLARS_FUNCTIONS = {
+ "duration": pl.duration,
+ "datetime": pl.datetime,
+ "date": pl.date,
+ "time": pl.time,
+ "lit": pl.lit,
+ # ... To Be Extended with more allowed functions as needed
+}
+
+
+class _ComputedFeature(BaseModel):
+    """Definition of a computed feature based on existing features.
+
+    ``expression`` is evaluated with ``simpleeval``: input features are
+    available by name, parameters through the ``p`` namespace (``p.<name>``),
+    and only the callables in ``ALLOWED_POLARS_FUNCTIONS`` may be called.
+    """
+
+    type: t.Literal["computed"] = "computed"
+    expression: str = Field(
+        description="Expression to compute the feature (e.g., 'feature1 + feature2')"
+    )
+    # Names referenced by ``expression``; filled in by the validator below.
+    _features: t.Set[str] = PrivateAttr(default_factory=set)
+    _parameters: t.Set[str] = PrivateAttr(default_factory=set)
+
+    @model_validator(mode="after")
+    def extract_features_and_parameters(self):
+        """Parse ``expression`` once and cache the feature / parameter names it uses."""
+        features, parameters = extract_names_and_parameters(self.expression)
+        self._features = features
+        self._parameters = parameters
+        return self
+
+    def build_expression(
+        self,
+        inputs: t.Dict[str, pl.Expr],
+        parameters: t.Optional[pl.Expr],
+    ) -> pl.Expr:
+        """Evaluate ``expression`` into a polars expression.
+
+        Args:
+            inputs: feature name -> polars expression, exposed as plain names.
+            parameters: struct expression whose fields are the parameters, or
+                None when no parameters are available (``p`` is then None).
+
+        Returns:
+            The polars expression produced by evaluating ``expression``.
+
+        Raises:
+            ImportError: if ``simpleeval`` is not installed.
+            AssertionError: if the result is not a ``pl.Expr``.
+        """
+        try:
+            from simpleeval import simple_eval
+        except ImportError as e:
+            # Chain the original error so the real import failure stays visible.
+            raise ImportError(
+                "simpleeval is required to evaluate computed feature expressions. Please install it with 'pip install simpleeval'."
+            ) from e
+
+        # Expose only the parameters the expression actually references.
+        p = types.SimpleNamespace(
+            **{name: parameters.struct.field(name) for name in self._parameters}
+        ) if parameters is not None else None
+        res = simple_eval(
+            self.expression,
+            names={**inputs, "p": p},
+            functions=ALLOWED_POLARS_FUNCTIONS,
+        )
+        assert isinstance(
+            res, pl.Expr
+        ), f"Expression {self.expression} did not evaluate to a Polars expression: {res}.\nHint: if the intended output is as expected you may want to consider wrapping it in lit()"
+        return res
+
+
+class Feature(RootModel[t.Union[_ComputedFeature, str]]):
+    """A feature reference: either a plain feature name or a computed expression."""
+
+    root: t.Union[_ComputedFeature, str] = Field(description="Feature name to test")
+
+    def __str__(self) -> str:
+        """String representation returns the feature name or expression."""
+        if isinstance(self.root, str):
+            return self.root
+        return self.root.expression
+
+    def build_expression(
+        self,
+        inputs: t.Dict[str, pl.Expr],
+        parameters: t.Optional[pl.Expr],
+    ) -> pl.Expr:
+        """Return ``inputs[name]`` for a plain feature, else evaluate the computed expression."""
+        if isinstance(self.root, str):
+            return inputs[self.root]
+        return self.root.build_expression(inputs, parameters)
+
+    def get_required_features(self) -> t.Set[str]:
+        """Names of the input features this feature reads (always a fresh set)."""
+        if isinstance(self.root, str):
+            return {self.root}
+        # Return a copy: callers add to the result, which must never leak into
+        # the computed feature's cached private set.
+        return set(self.root._features)
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Names of the parameters this feature reads (always a fresh set)."""
+        if isinstance(self.root, str):
+            return set()
+        # Copy for the same reason as above; build_expression iterates over
+        # the cached set, so it must stay exactly what the expression uses.
+        return set(self.root._parameters)
diff --git a/decider/modules/rules/common/nodes/__init__.py b/decider/modules/rules/common/nodes/__init__.py
new file mode 100644
index 0000000..06a9258
--- /dev/null
+++ b/decider/modules/rules/common/nodes/__init__.py
@@ -0,0 +1,82 @@
+"""Common node definitions shared between flat_rules and tree v3."""
+
+from .operators import (
+ _BaseUnaryOp,
+ _ThresholdedUnaryOp,
+ UnaryLessThanEqual,
+ UnaryLessThan,
+ UnaryEqual,
+ UnaryGreaterThan,
+ UnaryGreaterThanEqual,
+ UnaryNotEqual,
+ UnaryBetween,
+ UnaryIsIn,
+ UnaryStringMatch,
+ UnaryIsNull,
+ UnaryIsNotNull,
+ UnaryIsTrue,
+ UnaryIsFalse,
+ TUnaryOp,
+)
+
+from .conditions import (
+ RangeCondition,
+ StringMatchCondition,
+ IsInCondition,
+ CasesBranch,
+ TCaseCondition,
+ CompositeCondition,
+ TCondition,
+ _UnaryOpConditionWrapper,
+)
+
+from .unary import BaseUnaryNode
+from .cases import (
+ _CasesRangesCore,
+ _CasesStringMatchCore,
+ _CasesIsInCore,
+ BaseCasesRanges,
+ BaseCasesStringMatch,
+ BaseCasesIsIn,
+ validate_range_conditions,
+)
+from .composite import BaseCompositeNode
+
+__all__ = [
+ # Operators
+ "_BaseUnaryOp",
+ "_ThresholdedUnaryOp",
+ "UnaryLessThanEqual",
+ "UnaryLessThan",
+ "UnaryEqual",
+ "UnaryGreaterThan",
+ "UnaryGreaterThanEqual",
+ "UnaryNotEqual",
+ "UnaryBetween",
+ "UnaryIsIn",
+ "UnaryStringMatch",
+ "UnaryIsNull",
+ "UnaryIsNotNull",
+ "UnaryIsTrue",
+ "UnaryIsFalse",
+ "TUnaryOp",
+ # Conditions
+ "RangeCondition",
+ "StringMatchCondition",
+ "IsInCondition",
+ "CasesBranch",
+ "TCaseCondition",
+ "CompositeCondition",
+ "TCondition",
+ "_UnaryOpConditionWrapper",
+ # Base nodes
+ "BaseUnaryNode",
+ "_CasesRangesCore",
+ "_CasesStringMatchCore",
+ "_CasesIsInCore",
+ "BaseCasesRanges",
+ "BaseCasesStringMatch",
+ "BaseCasesIsIn",
+ "validate_range_conditions",
+ "BaseCompositeNode",
+]
diff --git a/decider/modules/rules/common/nodes/cases.py b/decider/modules/rules/common/nodes/cases.py
new file mode 100644
index 0000000..b54778e
--- /dev/null
+++ b/decider/modules/rules/common/nodes/cases.py
@@ -0,0 +1,208 @@
+"""Base cases nodes — feature + condition lists + required feature/param traversal.
+
+The strict range validation lives here once as a standalone function callable
+by both flat_rules (conditions: List[CasesBranch]) and nodes_ui (conditions: List[RangeCondition]).
+
+Structure:
+ _CasesRangesCore — feature/op/end_logic/strict, get_required_features (no conditions)
+ _CasesStringMatchCore — feature/op/match settings, get_required_features (no conditions)
+ _CasesIsInCore — feature/op, get_required_features (no conditions)
+ validate_range_conditions() — standalone, accepts List[RangeCondition]
+ BaseCasesRanges — extends _CasesRangesCore with conditions: List[RangeCondition] (for nodes_ui)
+ BaseCasesStringMatch — extends _CasesStringMatchCore with conditions: List[StringMatchCondition]
+ BaseCasesIsIn — extends _CasesIsInCore with conditions: List[IsInCondition]
+
+flat_rules extends the *Core classes and defines its own CasesBranch-typed conditions.
+nodes_ui extends the Base* classes directly.
+"""
+
+import typing as t
+from pydantic import BaseModel, Field, model_validator
+import typing_extensions as t_ext
+
+from ..nodetypes import BaseRule, TStringMatchType, RangeEndLogic
+from ..feature import Feature as _Feature
+from ..shared import InputRef
+from .conditions import (
+ RangeCondition,
+ StringMatchCondition,
+ IsInCondition,
+)
+
+# =============================================================================
+# Shared validation
+# =============================================================================
+
+
+def validate_range_conditions(conditions: t.List[RangeCondition], strict: bool) -> None:
+    """Check that range conditions are sorted and, where bounds meet, continuous.
+
+    Nothing is checked when ``strict`` is false or ``conditions`` is empty.
+    Bounds given as ``InputRef`` are unknown until runtime and are skipped.
+
+    Raises:
+        ValueError: if static bounds are out of order, or a static range does
+            not start where the preceding static range ends.
+    """
+    if not (conditions and strict):
+        return
+
+    has_static_min = any(
+        rc.min is not None and not isinstance(rc.min, InputRef) for rc in conditions
+    )
+
+    if not has_static_min:
+        # Only upper bounds are static: they must be strictly increasing.
+        static_maxes = [
+            (idx, rc.max)
+            for idx, rc in enumerate(conditions)
+            if rc.max is not None and not isinstance(rc.max, InputRef)
+        ]
+        for (idx_a, max_a), (idx_b, max_b) in zip(static_maxes, static_maxes[1:]):
+            if max_a >= max_b:
+                raise ValueError(
+                    f"Ranges must be in sorted order. "
+                    f"Range {idx_a} has max={max_a}, "
+                    f"but range {idx_b} has max={max_b}."
+                )
+        return
+
+    # Ranges with no runtime-resolved bound, compared pairwise in list order.
+    fully_static = [
+        (idx, rc.min, rc.max)
+        for idx, rc in enumerate(conditions)
+        if not (isinstance(rc.min, InputRef) or isinstance(rc.max, InputRef))
+    ]
+
+    for (cur_idx, cur_min, cur_max), (nxt_idx, nxt_min, _nxt_max) in zip(
+        fully_static, fully_static[1:]
+    ):
+        if cur_min is not None and nxt_min is not None and cur_min >= nxt_min:
+            raise ValueError(
+                f"Ranges must be in sorted order. "
+                f"Range {cur_idx} has min={cur_min}, "
+                f"but range {nxt_idx} has min={nxt_min}."
+            )
+
+        if cur_max is not None and nxt_min is not None and cur_max != nxt_min:
+            raise ValueError(
+                f"Ranges are not continuous in strict mode. "
+                f"Range {cur_idx} ends at {cur_max} but range {nxt_idx} starts at {nxt_min}."
+            )
+
+
+# =============================================================================
+# Core base classes (no conditions field — subclasses define conditions)
+# =============================================================================
+
+
+class _CasesRangesCore(BaseRule):
+    """Fields common to range-based cases nodes; subclasses add ``conditions``."""
+
+    type: t.Literal["cases"] = "cases"
+    id: t.Optional[str] = None
+    feature: _Feature
+    op: t.Literal["ranges"] = "ranges"
+    end_logic: RangeEndLogic = Field(default=RangeEndLogic.lower_inclusive)
+    strict: bool = Field(default=True)
+
+    def get_required_features(self) -> t.Set[str]:
+        """Input features read by the tested feature."""
+        tested = self.feature
+        return tested.get_required_features()
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Parameters read by the tested feature."""
+        tested = self.feature
+        return tested.get_required_parameters()
+
+
+class _CasesStringMatchCore(BaseRule):
+    """Fields common to string-match cases nodes; subclasses add ``conditions``."""
+
+    type: t.Literal["cases"] = "cases"
+    id: t.Optional[str] = None
+    feature: _Feature
+    op: t.Literal["string_match"] = "string_match"
+    match_type: TStringMatchType = Field(default=TStringMatchType.exact)
+    case_sensitive: bool = True
+    trim_whitespace: bool = False
+
+    def get_required_features(self) -> t.Set[str]:
+        """Input features read by the tested feature."""
+        tested = self.feature
+        return tested.get_required_features()
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Parameters read by the tested feature."""
+        tested = self.feature
+        return tested.get_required_parameters()
+
+
+class _CasesIsInCore(BaseRule):
+    """Fields common to isin cases nodes; subclasses add ``conditions``."""
+
+    type: t.Literal["cases"] = "cases"
+    id: t.Optional[str] = None
+    feature: _Feature
+    op: t.Literal["isin"] = "isin"
+
+    def get_required_features(self) -> t.Set[str]:
+        """Input features read by the tested feature."""
+        tested = self.feature
+        return tested.get_required_features()
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Parameters read by the tested feature."""
+        tested = self.feature
+        return tested.get_required_parameters()
+
+
+# =============================================================================
+# nodes_ui base classes (conditions as plain condition objects, no branch indices)
+# =============================================================================
+
+
+class BaseCasesRanges(_CasesRangesCore):
+    """Multi-way range branching base — used by nodes_ui where edges define routing."""
+
+    conditions: t.List[RangeCondition] = Field(
+        description="List of range conditions (order matches source index)"
+    )
+
+    @model_validator(mode="after")
+    def _validate(self) -> t_ext.Self:
+        """Apply the shared sorted/continuous check (a no-op unless ``strict``)."""
+        validate_range_conditions(self.conditions, self.strict)
+        return self
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Parameters read by the feature plus every ``InputRef`` bound's key."""
+        # Work on a copy so the set owned by the feature is never mutated.
+        params = set(self.feature.get_required_parameters())
+        for cond in self.conditions:
+            if isinstance(cond.min, InputRef):
+                params.add(cond.min.key)
+            if isinstance(cond.max, InputRef):
+                params.add(cond.max.key)
+        return params
+
+
+class BaseCasesStringMatch(_CasesStringMatchCore):
+    """Multi-way string matching base — used by nodes_ui."""
+
+    conditions: t.List[StringMatchCondition] = Field(
+        description="List of pattern conditions (order matches source index)"
+    )
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Parameters read by the feature plus the key of every ``InputRef`` pattern."""
+        # Work on a copy so the set owned by the feature is never mutated.
+        params = set(self.feature.get_required_parameters())
+        for cond in self.conditions:
+            for pattern in cond.patterns:
+                if isinstance(pattern, InputRef):
+                    params.add(pattern.key)
+        return params
+
+
+class BaseCasesIsIn(_CasesIsInCore):
+    """Multi-way categorical branching base — used by nodes_ui."""
+
+    conditions: t.List[IsInCondition] = Field(
+        description="List of value sets (order matches source index)"
+    )
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Parameters read by the feature plus the key of every ``InputRef`` value set."""
+        # Work on a copy so the set owned by the feature is never mutated.
+        params = set(self.feature.get_required_parameters())
+        for cond in self.conditions:
+            if isinstance(cond.values, InputRef):
+                params.add(cond.values.key)
+        return params
diff --git a/decider/modules/rules/common/nodes/composite.py b/decider/modules/rules/common/nodes/composite.py
new file mode 100644
index 0000000..79aafa5
--- /dev/null
+++ b/decider/modules/rules/common/nodes/composite.py
@@ -0,0 +1,40 @@
+"""Base composite node — AND/OR/NOT logic + required feature/param traversal.
+
+Both flat_rules.CompositeRule and nodes_ui.CompositeNode extend this.
+"""
+
+import typing as t
+from pydantic import Field, model_validator
+import typing_extensions as t_ext
+
+from ..nodetypes import BaseRule, TLogicOp
+from .conditions import TCondition
+
+
+class BaseCompositeNode(BaseRule):
+    """Node combining child conditions with AND/OR/NOT — shared by both systems."""
+
+    type: t.Literal["composite"] = "composite"
+    id: t.Optional[str] = None
+    op: TLogicOp
+    conditions: t.List[TCondition] = Field(
+        description="Conditions to combine with AND/OR/NOT"
+    )
+
+    @model_validator(mode="after")
+    def validate_conditions(self) -> t_ext.Self:
+        """Reject a NOT node that does not wrap exactly one condition."""
+        is_negation = self.op == TLogicOp.NOT
+        if is_negation and len(self.conditions) != 1:
+            raise ValueError("NOT operator must have exactly 1 condition")
+        return self
+
+    def get_required_features(self) -> t.Set[str]:
+        """Union of the features required by every child condition."""
+        return set().union(
+            *(child.get_required_features() for child in self.conditions)
+        )
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Union of the parameters required by every child condition."""
+        return set().union(
+            *(child.get_required_parameters() for child in self.conditions)
+        )
diff --git a/decider/modules/rules/common/nodes/conditions.py b/decider/modules/rules/common/nodes/conditions.py
new file mode 100644
index 0000000..7ce3c3e
--- /dev/null
+++ b/decider/modules/rules/common/nodes/conditions.py
@@ -0,0 +1,279 @@
+"""Condition types shared between flat_rules and tree v3.
+
+RangeCondition, StringMatchCondition, IsInCondition, CasesBranch —
+all shared directly. CompositeCondition and TCondition are defined
+here so both systems use the exact same composite logic.
+"""
+
+import typing as t
+from pydantic import BaseModel, Field, model_validator, Discriminator, Tag
+import typing_extensions as t_ext
+import polars as pl
+
+from ..shared import InputRef
+from ..nodetypes import TStringMatchType, TLogicOp, RangeEndLogic
+from .operators import TUnaryOp
+
+# =============================================================================
+# Range
+# =============================================================================
+
+
+class RangeCondition(BaseModel):
+    """A range with an optional lower (``min``) and upper (``max``) bound.
+
+    Each bound is a number or an ``InputRef`` resolved at build time; at least
+    one bound must be given.
+    """
+
+    min: t.Optional[t.Union[float, int, InputRef]] = None
+    max: t.Optional[t.Union[float, int, InputRef]] = None
+
+    @model_validator(mode="after")
+    def validate_bounds(self) -> t_ext.Self:
+        """Reject a range that specifies neither bound."""
+        if not (self.min is not None or self.max is not None):
+            raise ValueError("At least one of min or max must be specified")
+        return self
+
+    def _resolve_bound(
+        self,
+        bound: t.Optional[t.Union[float, int, InputRef]],
+        parameters: t.Optional[pl.Expr],
+    ) -> t.Optional[pl.Expr]:
+        """Turn a bound into an expression: None stays None, numbers become literals."""
+        if bound is None:
+            return None
+        return bound.resolve(parameters) if isinstance(bound, InputRef) else pl.lit(bound)
+
+    def build_range_condition(
+        self,
+        feature_expr: pl.Expr,
+        end_logic: RangeEndLogic,
+        parameters: t.Optional[pl.Expr],
+    ) -> pl.Expr:
+        """Build the boolean expression testing ``feature_expr`` against this range.
+
+        With ``RangeEndLogic.lower_inclusive`` the test is ``min <= x < max``;
+        with any other end logic it is ``min < x <= max``. A missing bound
+        contributes no comparison.
+        """
+        lower = self._resolve_bound(self.min, parameters)
+        upper = self._resolve_bound(self.max, parameters)
+        lower_inclusive = end_logic == RangeEndLogic.lower_inclusive
+
+        parts = []
+        if lower is not None:
+            parts.append(feature_expr >= lower if lower_inclusive else feature_expr > lower)
+        if upper is not None:
+            parts.append(feature_expr < upper if lower_inclusive else feature_expr <= upper)
+
+        if not parts:
+            return None
+        if len(parts) == 2:
+            return parts[0] & parts[1]
+        return parts[0]
+
+
+# =============================================================================
+# String match
+# =============================================================================
+
+
+class StringMatchCondition(BaseModel):
+    """String match condition — patterns are a mix of static strings and InputRefs.
+
+    A value matches when ANY pattern matches (OR logic). Case-insensitive
+    matching lowercases both sides for literal match types; for ``regex`` it
+    uses the inline ``(?i:...)`` flag instead, because lowercasing the pattern
+    text would change regex semantics (``\\D`` would become ``\\d``).
+    """
+
+    patterns: t.List[t.Union[str, InputRef]] = Field(
+        description="Patterns to match (OR logic) - can mix static strings and InputRefs"
+    )
+
+    @model_validator(mode="after")
+    def validate_patterns(self) -> t_ext.Self:
+        """Reject an empty pattern list."""
+        if not self.patterns:
+            raise ValueError("patterns list must contain at least one pattern")
+        return self
+
+    def _handle_static_patterns(
+        self,
+        feature_expr: pl.Expr,
+        static_patterns: t.List[str],
+        match_type: str,
+        case_sensitive: bool,
+    ) -> t.Optional[pl.Expr]:
+        """Build one expression covering all static patterns, or None if there are none.
+
+        For literal match types the caller has already case-folded
+        ``feature_expr`` and ``static_patterns`` when needed; for ``regex`` the
+        case flag is applied here.
+        """
+        if not static_patterns:
+            return None
+
+        if match_type == "exact":
+            return feature_expr.is_in(static_patterns)
+        elif match_type == "regex":
+            combined_pattern = "|".join(f"(?:{p})" for p in static_patterns)
+            if not case_sensitive:
+                # Scoped flag group keeps the flag applied to every alternative.
+                combined_pattern = f"(?i:{combined_pattern})"
+            return feature_expr.str.contains(combined_pattern, literal=False)
+        else:
+            if match_type == "contains":
+                exprs = [
+                    feature_expr.str.contains(p, literal=True) for p in static_patterns
+                ]
+            elif match_type == "starts_with":
+                exprs = [feature_expr.str.starts_with(p) for p in static_patterns]
+            else:  # ends_with
+                exprs = [feature_expr.str.ends_with(p) for p in static_patterns]
+
+            result = exprs[0]
+            for expr in exprs[1:]:
+                result = result | expr
+            return result
+
+    def _handle_dynamic_patterns(
+        self,
+        feature_expr: pl.Expr,
+        dynamic_refs: t.List[InputRef],
+        match_type: str,
+        case_sensitive: bool,
+        parameters: t.Optional[pl.Expr],
+    ) -> t.List[pl.Expr]:
+        """Build one expression per ``InputRef`` pattern (resolved at runtime)."""
+        conditions = []
+        for ref in dynamic_refs:
+            pat_expr = ref.resolve(parameters)
+            if not case_sensitive:
+                if match_type == "regex":
+                    # Never lowercase regex text; wrap it in a case-insensitive group.
+                    pat_expr = pl.concat_str([pl.lit("(?i:"), pat_expr, pl.lit(")")])
+                else:
+                    pat_expr = pat_expr.str.to_lowercase()
+
+            if match_type == "exact":
+                conditions.append(feature_expr == pat_expr)
+            elif match_type == "contains":
+                conditions.append(feature_expr.str.contains(pat_expr, literal=True))
+            elif match_type == "starts_with":
+                conditions.append(feature_expr.str.starts_with(pat_expr))
+            elif match_type == "ends_with":
+                conditions.append(feature_expr.str.ends_with(pat_expr))
+            elif match_type == "regex":
+                conditions.append(feature_expr.str.contains(pat_expr, literal=False))
+        return conditions
+
+    def build_match_condition(
+        self,
+        feature_expr: pl.Expr,
+        match_type: str,
+        case_sensitive: bool,
+        parameters: t.Optional[pl.Expr],
+    ) -> pl.Expr:
+        """Build the boolean expression that is true when any pattern matches.
+
+        Args:
+            feature_expr: string expression being tested.
+            match_type: "exact", "contains", "starts_with", "ends_with" or "regex".
+            case_sensitive: when False, matching ignores case.
+            parameters: expression used to resolve ``InputRef`` patterns.
+
+        Returns:
+            OR of all pattern conditions, or ``pl.lit(False)`` when none were built.
+        """
+        # Case folding by lowercasing is only valid for literal comparisons.
+        fold_case = (not case_sensitive) and match_type != "regex"
+        feat = feature_expr.str.to_lowercase() if fold_case else feature_expr
+
+        static_patterns = [p for p in self.patterns if isinstance(p, str)]
+        dynamic_refs = [p for p in self.patterns if isinstance(p, InputRef)]
+
+        if fold_case:
+            static_patterns = [p.lower() for p in static_patterns]
+
+        conditions = []
+        static_cond = self._handle_static_patterns(
+            feat, static_patterns, match_type, case_sensitive
+        )
+        if static_cond is not None:
+            conditions.append(static_cond)
+
+        conditions.extend(
+            self._handle_dynamic_patterns(
+                feat, dynamic_refs, match_type, case_sensitive, parameters
+            )
+        )
+
+        if not conditions:
+            return pl.lit(False)
+
+        result = conditions[0]
+        for cond in conditions[1:]:
+            result = result | cond
+        return result
+
+
+# =============================================================================
+# IsIn
+# =============================================================================
+
+
+# Membership condition for isin cases nodes: ``values`` is either a static
+# list of numbers or an InputRef (whose ``key`` is then reported as a required
+# parameter by BaseCasesIsIn.get_required_parameters).
+class IsInCondition(BaseModel):
+ values: t.Union[t.List[t.Union[int, float]], InputRef]
+
+
+# =============================================================================
+# Cases branch wrapper
+# =============================================================================
+
+TCaseCondition = t.Union[RangeCondition, StringMatchCondition, IsInCondition]
+
+
+class CasesBranch(BaseModel):
+ """A single case: when condition → then branch_index."""
+
+ # One of RangeCondition, StringMatchCondition or IsInCondition (TCaseCondition).
+ when: TCaseCondition
+ then: int = Field(description="Index into branches array")
+
+
+# =============================================================================
+# Composite condition (shared, recursive)
+# =============================================================================
+
+
+def _condition_discriminator(value: t.Any) -> str:
+    """Pick the union tag from ``value``'s ``type``, defaulting to "unary".
+
+    Dicts are read by key; any other object is read by attribute.
+    """
+    fallback = "unary"
+    tag = (
+        value.get("type", fallback)
+        if isinstance(value, dict)
+        else getattr(value, "type", fallback)
+    )
+    return tag
+
+
+# Discriminated union of condition types. ``_condition_discriminator`` picks
+# the tag from the value's ``type`` (defaulting to "unary"): "unary" values
+# validate as TUnaryOp, "composite" values as CompositeCondition. The latter is
+# a string forward reference because the class is defined further down.
+TCondition = t.Annotated[
+ t.Union[
+ t.Annotated[TUnaryOp, Tag("unary")],
+ t.Annotated["CompositeCondition", Tag("composite")],
+ ],
+ Discriminator(_condition_discriminator),
+]
+
+# Backward-compat placeholder: the name is kept (set to None) so existing imports of it still resolve, but it is no longer a class.
+_UnaryOpConditionWrapper = None
+
+
+class CompositeCondition(BaseModel):
+    """AND/OR/NOT combination of other conditions; may nest recursively."""
+
+    type: t.Literal["composite"] = "composite"
+    id: t.Optional[str] = Field(default=None)
+    op: TLogicOp
+    conditions: t.List[TCondition] = Field(description="List of conditions to combine")
+
+    @model_validator(mode="after")
+    def validate_and_ensure_id(self) -> t_ext.Self:
+        """Check NOT arity, then assign a random UUID when no ``id`` was given."""
+        is_negation = self.op == TLogicOp.NOT
+        if is_negation and len(self.conditions) != 1:
+            raise ValueError("NOT operator must have exactly 1 condition")
+        if self.id is None:
+            import uuid
+
+            self.id = str(uuid.uuid4())
+        return self
+
+    def get_required_features(self) -> t.Set[str]:
+        """Union of the features required by every child condition."""
+        return set().union(
+            *(child.get_required_features() for child in self.conditions)
+        )
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Union of the parameters required by every child condition."""
+        return set().union(
+            *(child.get_required_parameters() for child in self.conditions)
+        )
+
+    def build_condition(
+        self,
+        inputs: t.Dict[str, pl.Expr],
+        parameters: t.Optional[pl.Expr],
+    ) -> pl.Expr:
+        """Combine the children's expressions according to ``op``.
+
+        An empty condition list yields ``pl.lit(False)``. NOT negates the first
+        child; AND folds the children with ``&``; any other op folds with ``|``.
+        """
+        if not self.conditions:
+            return pl.lit(False)
+
+        built = [child.build_condition(inputs, parameters) for child in self.conditions]
+        head, *rest = built
+
+        if self.op == TLogicOp.NOT:
+            return ~head
+
+        use_and = self.op == TLogicOp.AND
+        combined = head
+        for nxt in rest:
+            combined = (combined & nxt) if use_and else (combined | nxt)
+        return combined
+
+
+# Rebuild for forward ref
+CompositeCondition.model_rebuild()
diff --git a/decider/modules/rules/common/nodes/operators.py b/decider/modules/rules/common/nodes/operators.py
new file mode 100644
index 0000000..0ae3856
--- /dev/null
+++ b/decider/modules/rules/common/nodes/operators.py
@@ -0,0 +1,312 @@
+"""Unary operator definitions shared between flat_rules and tree v3.
+
+All operators have:
+ - feature: the Feature to test
+ - op: literal discriminator
+ - build_condition(): polars expression (used by flat_rules)
+ - get_required_features/parameters() (used by both systems)
+
+flat_rules wraps these directly.
+nodes_ui imports TUnaryOp and uses them inside UnaryNode/CompositeNode.
+"""
+
+import typing as t
+from abc import ABC, abstractmethod
+from pydantic import BaseModel, Field, model_validator
+import typing_extensions as t_ext
+import polars as pl
+
+from ..shared import InputRef
+from ..feature import Feature as _Feature
+from ..nodetypes import TStringMatchType
+
+
+class _BaseUnaryOp(BaseModel, ABC):
+ """Base class for all unary operators."""
+
+ type: t.Literal["unary"] = "unary"
+ feature: _Feature = Field(description="Feature name to test")
+
+ @abstractmethod
+ def build_condition(
+ self,
+ inputs: t.Dict[str, pl.Expr],
+ parameters: t.Optional[pl.Expr],
+ ) -> pl.Expr: ...
+
+ @abstractmethod
+ def _get_required_params(self) -> t.Set[str]: ...
+
+ def get_required_features(self) -> t.Set[str]:
+ return self.feature.get_required_features()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ params = self.feature.get_required_parameters()
+ params.update(self._get_required_params())
+ return params
+
+
+class _ThresholdedUnaryOp(_BaseUnaryOp, ABC):
+ """Base for numeric comparison operators."""
+
+ threshold: t.Union[float, int, InputRef] = Field(
+ description="Comparison value (number or InputRef for runtime variable)"
+ )
+
+ def _resolve_threshold(self, parameters: t.Optional[pl.Expr]) -> pl.Expr:
+ if isinstance(self.threshold, InputRef):
+ return self.threshold.resolve(parameters)
+ return pl.lit(self.threshold)
+
+ def _get_required_params(self) -> t.Set[str]:
+ if isinstance(self.threshold, InputRef):
+ return {self.threshold.key}
+ return set()
+
+
+class UnaryLessThanEqual(_ThresholdedUnaryOp):
+ op: t.Literal["<="] = "<="
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(
+ inputs, parameters
+ ) <= self._resolve_threshold(parameters)
+
+
+class UnaryLessThan(_ThresholdedUnaryOp):
+ op: t.Literal["<"] = "<"
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(
+ inputs, parameters
+ ) < self._resolve_threshold(parameters)
+
+
+class UnaryEqual(_ThresholdedUnaryOp):
+ op: t.Literal["=="] = "=="
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(
+ inputs, parameters
+ ) == self._resolve_threshold(parameters)
+
+
+class UnaryGreaterThan(_ThresholdedUnaryOp):
+ op: t.Literal[">"] = ">"
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(
+ inputs, parameters
+ ) > self._resolve_threshold(parameters)
+
+
+class UnaryGreaterThanEqual(_ThresholdedUnaryOp):
+ op: t.Literal[">="] = ">="
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(
+ inputs, parameters
+ ) >= self._resolve_threshold(parameters)
+
+
+class UnaryNotEqual(_ThresholdedUnaryOp):
+ op: t.Literal["!="] = "!="
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(
+ inputs, parameters
+ ) != self._resolve_threshold(parameters)
+
+
+class UnaryBetween(_BaseUnaryOp):
+ op: t.Literal["between"] = "between"
+ min: t.Optional[t.Union[float, int, InputRef]] = Field(default=None)
+ max: t.Optional[t.Union[float, int, InputRef]] = Field(default=None)
+
+ @model_validator(mode="after")
+ def validate_bounds(self) -> t_ext.Self:
+ if self.min is None and self.max is None:
+ raise ValueError("At least one of min or max must be specified")
+ return self
+
+ def _get_required_params(self) -> t.Set[str]:
+ params = set()
+ if isinstance(self.min, InputRef):
+ params.add(self.min.key)
+ if isinstance(self.max, InputRef):
+ params.add(self.max.key)
+ return params
+
+ def _resolve_bound(
+ self,
+ bound: t.Optional[t.Union[float, int, InputRef]],
+ parameters: t.Optional[pl.Expr],
+ ) -> t.Optional[pl.Expr]:
+ if bound is None:
+ return None
+ if isinstance(bound, InputRef):
+ return bound.resolve(parameters)
+ return pl.lit(bound)
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ feature_expr = self.feature.build_expression(inputs, parameters)
+ min_expr = self._resolve_bound(self.min, parameters)
+ max_expr = self._resolve_bound(self.max, parameters)
+
+ if min_expr is not None and max_expr is not None:
+ return (feature_expr >= min_expr) & (feature_expr <= max_expr)
+ elif min_expr is not None:
+ return feature_expr >= min_expr
+ else:
+ return feature_expr <= max_expr
+
+
+class UnaryIsIn(_BaseUnaryOp):
+ op: t.Literal["isin"] = "isin"
+ values: t.Union[t.List[t.Union[int, float]], InputRef] = Field(
+ description="List of acceptable values or InputRef for runtime variable"
+ )
+
+ @model_validator(mode="after")
+ def validate_values(self) -> t_ext.Self:
+ if isinstance(self.values, list) and not self.values:
+ raise ValueError("values list must contain at least one element")
+ return self
+
+ def _get_required_params(self) -> t.Set[str]:
+ if isinstance(self.values, InputRef):
+ return {self.values.key}
+ return set()
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ feature_expr = self.feature.build_expression(inputs, parameters)
+ if isinstance(self.values, InputRef):
+ return feature_expr == self.values.resolve(parameters)
+ return feature_expr.is_in(self.values)
+
+
+class UnaryStringMatch(_BaseUnaryOp):
+ op: t.Literal["string_match"] = "string_match"
+ patterns: t.List[t.Union[str, InputRef]] = Field(
+ description="Patterns to match (OR logic) - can mix static strings and InputRefs"
+ )
+ match_type: TStringMatchType = Field(default=TStringMatchType.exact)
+ case_sensitive: bool = Field(default=True)
+ trim_whitespace: bool = Field(default=False)
+
+ @model_validator(mode="after")
+ def validate_patterns(self) -> t_ext.Self:
+ if not self.patterns:
+ raise ValueError("patterns list must contain at least one pattern")
+ return self
+
+ def _get_required_params(self) -> t.Set[str]:
+ params = set()
+ for pattern in self.patterns:
+ if isinstance(pattern, InputRef):
+ params.add(pattern.key)
+ return params
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ from .conditions import StringMatchCondition
+
+ feature_expr = self.feature.build_expression(inputs, parameters)
+ if self.trim_whitespace:
+ feature_expr = feature_expr.str.strip_chars()
+ matcher = StringMatchCondition(patterns=self.patterns)
+ return matcher.build_match_condition(
+ feature_expr=feature_expr,
+ match_type=(
+ self.match_type.value
+ if isinstance(self.match_type, TStringMatchType)
+ else self.match_type
+ ),
+ case_sensitive=self.case_sensitive,
+ parameters=parameters,
+ )
+
+
+class UnaryIsNull(_BaseUnaryOp):
+ op: t.Literal["is_null"] = "is_null"
+
+ def _get_required_params(self) -> t.Set[str]:
+ return set()
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(inputs, parameters).is_null()
+
+
+class UnaryIsNotNull(_BaseUnaryOp):
+ op: t.Literal["is_not_null"] = "is_not_null"
+
+ def _get_required_params(self) -> t.Set[str]:
+ return set()
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(inputs, parameters).is_not_null()
+
+
+class UnaryIsTrue(_BaseUnaryOp):
+ op: t.Literal["is_true"] = "is_true"
+
+ def _get_required_params(self) -> t.Set[str]:
+ return set()
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(inputs, parameters) == pl.lit(True)
+
+
+class UnaryIsFalse(_BaseUnaryOp):
+ op: t.Literal["is_false"] = "is_false"
+
+ def _get_required_params(self) -> t.Set[str]:
+ return set()
+
+ def build_condition(
+ self, inputs: t.Dict[str, pl.Expr], parameters: t.Optional[pl.Expr]
+ ) -> pl.Expr:
+ return self.feature.build_expression(inputs, parameters) == pl.lit(False)
+
+
+# Union of every concrete unary operator model in this module. Each member
+# declares a distinct Literal `op` value, which pydantic uses as the
+# discriminator to pick the right class during validation.
+TUnaryOp = t.Annotated[
+ t.Union[
+ UnaryLessThanEqual,
+ UnaryLessThan,
+ UnaryEqual,
+ UnaryGreaterThan,
+ UnaryGreaterThanEqual,
+ UnaryNotEqual,
+ UnaryBetween,
+ UnaryIsIn,
+ UnaryStringMatch,
+ UnaryIsNull,
+ UnaryIsNotNull,
+ UnaryIsTrue,
+ UnaryIsFalse,
+ ],
+ Field(discriminator="op"),
+]
diff --git a/decider/modules/rules/common/nodes/unary.py b/decider/modules/rules/common/nodes/unary.py
new file mode 100644
index 0000000..84ae0a0
--- /dev/null
+++ b/decider/modules/rules/common/nodes/unary.py
@@ -0,0 +1,25 @@
+"""Base unary node — condition + required feature/param traversal.
+
+Both flat_rules.UnaryRule and nodes_ui.UnaryNode extend this.
+They add their own child-resolution mechanism (embedded vs edge-based).
+"""
+
+import typing as t
+from pydantic import Field
+
+from ..nodetypes import BaseRule
+from .operators import TUnaryOp
+
+
+class BaseUnaryNode(BaseRule):
+ """Unary condition node — holds the condition, delegates child logic to subclasses."""
+
+ type: t.Literal["unary"] = "unary"
+ id: t.Optional[str] = Field(default=None)
+ condition: TUnaryOp = Field(description="The condition to evaluate")
+
+ def get_required_features(self) -> t.Set[str]:
+ return self.condition.get_required_features()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ return self.condition.get_required_parameters()
diff --git a/decider/modules/rules/common/nodetypes.py b/decider/modules/rules/common/nodetypes.py
new file mode 100644
index 0000000..5d93fda
--- /dev/null
+++ b/decider/modules/rules/common/nodetypes.py
@@ -0,0 +1,115 @@
+"""Common node types and operators shared between flat rules and tree v3."""
+
+import typing as t
+import enum
+from pydantic import BaseModel, Field
+
+# =============================================================================
+# Base Metadata Types
+# =============================================================================
+
+
+class NodePosition(BaseModel):
+ """UI position for a node."""
+
+ # Both coordinates default to 0.0 when omitted.
+ x: float = 0.0
+ y: float = 0.0
+
+
+class NodeMeta(BaseModel):
+ """Execution-agnostic metadata attached to a node.
+
+ Currently used to preserve UI layout positions through round-trip conversion.
+ """
+
+ # None when no layout position was supplied.
+ position: t.Optional[NodePosition] = None
+
+
+class BaseRule(BaseModel):
+ """Base class for all rule/node types with optional metadata."""
+
+ # Inherited by every node model deriving from BaseRule (e.g. LeafNodeCore);
+ # defaults to None so nodes without metadata still validate.
+ meta: t.Optional[NodeMeta] = Field(
+ default=None,
+ description="Optional metadata (e.g., UI position)",
+ )
+
+
+# =============================================================================
+# Enums - Shared operators
+# =============================================================================
+
+
+class RangeEndLogic(str, enum.Enum):
+ """Logic for range boundary handling."""
+
+ # NOTE(review): presumably selects which end of a range is treated as
+ # inclusive — the consuming range-matching code is not in this module; confirm.
+ lower_inclusive = "lower_inclusive"
+ upper_inclusive = "upper_inclusive"
+
+
+class TPrimitiveOperators(str, enum.Enum):
+ """Primitive comparison operators for unary conditions."""
+
+ # Each value is the operator symbol itself; as a str-based enum the members
+ # compare equal to those plain strings.
+ LEQ = "<="
+ LT = "<"
+ EQ = "=="
+ GT = ">"
+ GEQ = ">="
+ NEQ = "!="
+
+
+class TStringMatchType(str, enum.Enum):
+ """String matching strategies."""
+
+ # Member values equal their names; being a str-based enum, `.value` is the
+ # plain string handed to the string-matching code.
+ exact = "exact"
+ starts_with = "starts_with"
+ contains = "contains"
+ ends_with = "ends_with"
+ regex = "regex"
+
+
+class TLogicOp(str, enum.Enum):
+ """Logical operators for composite conditions."""
+
+ # Serialised as the lowercase strings below.
+ AND = "and"
+ OR = "or"
+ NOT = "not"
+
+
+class TNodeType(str, enum.Enum):
+ """Node type discriminators."""
+
+ # The values match the Literal `type` tags declared on the node models
+ # (e.g. LeafNodeCore uses "leaf").
+ LEAF = "leaf"
+ UNARY = "unary"
+ CASES = "cases"
+ COMPOSITE = "composite"
+
+
+# =============================================================================
+# Core node types - Base models shared between systems
+# =============================================================================
+
+
+class LeafNodeCore(BaseRule):
+ """Core leaf node - terminal node that returns a result."""
+
+ # Constant tag identifying leaf nodes.
+ type: t.Literal["leaf"] = "leaf"
+ # Defaults to the -1 sentinel, i.e. "default / no match".
+ result_idx: int = Field(
+ default=-1,
+ description="Index into output table. -1 indicates default/no-match.",
+ )
+
+
+# =============================================================================
+# Core condition types - Used in both flat rules and tree v3
+# =============================================================================
+
+
+class MinMaxConditionCore(BaseModel):
+ """Core range condition with min/max bounds."""
+
+ # Both bounds are optional and default to None (unbounded on that side).
+ # No validation here requires at least one of them to be set.
+ min: t.Optional[t.Union[float, int]] = Field(
+ default=None, description="Minimum value (null for unbounded)"
+ )
+ max: t.Optional[t.Union[float, int]] = Field(
+ default=None, description="Maximum value (null for unbounded)"
+ )
diff --git a/decider/modules/rules/common/parameters.py b/decider/modules/rules/common/parameters.py
new file mode 100644
index 0000000..7032e51
--- /dev/null
+++ b/decider/modules/rules/common/parameters.py
@@ -0,0 +1,47 @@
+import typing as t
+import polars as pl
+from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from ....serializable.schema import PrimitiveSchema
+
+
+class ParameterInfo(BaseModel):
+ type: PrimitiveSchema
+ default_value: t.Optional[t.Any] = None
+ _polars_literal: t.Optional[pl.Expr] = PrivateAttr(default=None)
+
+ @model_validator(mode="after")
+ def validate_default_value(self):
+ if self.default_value is not None:
+ polars_type = self.type.polars_type
+ try:
+ self._polars_literal = pl.lit(self.default_value, polars_type)
+ except pl.exceptions.InvalidOperationError:
+ raise ValueError(
+ f"Default value {self.default_value} is not compatible with type {self.type}"
+ )
+
+ return self
+
+ @property
+ def polars_literal(self) -> t.Optional[pl.Expr]:
+ return self._polars_literal
+
+
+class WithParameters:
+ parameters_col: str = "parameters"
+ parameters: t.Dict[str, ParameterInfo] = Field(default_factory=dict)
+
+ @model_validator(mode="after")
+ def validate_parameters(self):
+ required_parameters = self.get_required_parameters()
+ known_parameters = set(self.parameters.keys())
+ if not required_parameters.issubset(known_parameters):
+ missing = required_parameters - known_parameters
+ raise ValueError(f"Missing parameter definitions for: {missing}")
+ return self
+
+ @property
+ def parameter_schema(self) -> pl.Schema:
+ return pl.Schema(
+ {name: info.type.polars_type for name, info in self.parameters.items()}
+ )
diff --git a/decider/modules/rules/common/shared.py b/decider/modules/rules/common/shared.py
new file mode 100644
index 0000000..c047c6c
--- /dev/null
+++ b/decider/modules/rules/common/shared.py
@@ -0,0 +1,184 @@
+import typing as t
+import polars as pl
+from pydantic import BaseModel, PrivateAttr, model_validator, Field
+from ....serializable.dataframe import TDataFrameData, TDataFrameRow
+from ....serializable.dtypes import ContainsDtypes, TTypeDef
+from ....serializable.schema import TType, ExplicitType
+
+
+class InputRef(BaseModel):
+ """Reference to a runtime parameter or variable.
+
+ Used to dynamically get parameters from the payload or dataframe at execution time.
+ Represents variables in the UI (displayed as #key).
+ """
+
+ key: str = Field(description="Parameter key from the graph execution context")
+
+ def resolve(self, parameters: t.Optional[pl.Expr]):
+ """Resolve the reference to a Polars expression at runtime."""
+ return parameters.struct.field(self.key)
+
+ def __str__(self) -> str:
+ """Display as #key for UI rendering."""
+ return f"#{self.key}"
+
+
+def expand_struct_type(
+ struct_data: TDataFrameRow,
+ dtype_keys: t.Iterable[t.Tuple[str, TTypeDef]],
+ type_defs: t.Dict[str, TTypeDef],
+) -> TDataFrameRow:
+ out = {
+ k: expand_data_row(struct_data[k], v, type_defs)
+ for k, v in dtype_keys
+ if k in struct_data
+ }
+ if set(out.keys()) != set(struct_data.keys()):
+ raise ValueError(
+ f"Miss-match in schema between struct data and dtype definition:\nStruct data keys: {set(struct_data.keys())}\nDtype keys: {set(out.keys())}"
+ )
+ return out
+
+
+def expand_str(
+ struct_data: t.Any, dtype: str, type_defs: t.Dict[str, TTypeDef]
+) -> t.Any:
+ # Dtypes given by a plain name need no expansion: the value is returned
+ # unchanged. `dtype` and `type_defs` are accepted only so the signature
+ # matches the other expand_* helpers.
+ return struct_data
+
+
+def expand_explicit_type(
+ struct_data: t.Any, dtype: ExplicitType, type_defs: t.Dict[str, TTypeDef]
+) -> t.Any:
+ lower_type = dtype.type.lower()
+ if lower_type == "custom":
+ type_id = dtype.model_extra.get("type_id")
+ if not type_id:
+ raise ValueError("ExplicitType of type 'Custom' must have a 'type_id'.")
+ type_def = type_defs.get(type_id)
+ if not type_def:
+ raise ValueError(f"Type definition for {type_id} not found in type_defs.")
+
+ if type_def.type == "categorical":
+ return struct_data
+ if type_def.type == "struct":
+ # Handle case where struct_data is already an index (int/str) instead of a dict
+ if isinstance(struct_data, (int, str)):
+ # If it's already an index/key, use it directly - return the full record
+ return type_def.get_value_for_key(struct_data)
+ # If it's already the full expanded record dict, return it as-is
+ # (Check if it matches the struct schema fields)
+ if isinstance(struct_data, dict):
+ # If it has all the fields from the struct schema, it's already expanded
+ struct_fields = set(
+ type_def.definition.fields.keys()
+ if isinstance(type_def.definition.fields, dict)
+ else [f[0] for f in type_def.definition.fields]
+ )
+ if struct_fields == set(struct_data.keys()):
+ # Already expanded - return as-is
+ return struct_data
+ # Otherwise, check for $key field
+ key = struct_data.get("$key")
+ if key is not None:
+ return type_def.get_value_for_key(key)
+ raise ValueError(
+ f"Struct data must be int, str, dict with '$key' field, or already-expanded struct dict, got: {struct_data}"
+ )
+ raise ValueError(
+ f"Unsupported custom type {type_def.type} for ExplicitType with type_id {type_id}"
+ )
+ if lower_type in ("list", "set", "array"):
+ if not isinstance(struct_data, t.Iterable):
+ return struct_data
+ # raise ValueError(f"Expected a list for {dtype.type} type, got {type(struct_data)}")
+ inner_dtype = dtype.model_extra.get(
+ "inner", dtype.model_extra.pop("fields", None)
+ )
+ return [expand_data_row(item, inner_dtype, type_defs) for item in struct_data]
+ return expand_str(struct_data, dtype.type, type_defs)
+
+
+def expand_data_row(
+ tree_output: t.Any, dtype: TType, type_defs: t.Dict[str, TTypeDef]
+) -> TDataFrameRow:
+ if isinstance(dtype, dict):
+ return expand_struct_type(tree_output, dtype.items(), type_defs)
+ if isinstance(dtype, list):
+ return expand_struct_type(tree_output, dtype, type_defs)
+ if isinstance(dtype, str):
+ return expand_str(tree_output, dtype, type_defs)
+ if isinstance(dtype, ExplicitType):
+ return expand_explicit_type(tree_output, dtype, type_defs)
+ # This case should already be handled by the pydantic validator
+ raise ValueError(
+ f"Unexpected value {t}. Expected either a dict, list, string or explicit type"
+ )
+
+
+def expand_data(
+ data: TDataFrameData, dtype: TType, type_defs: t.Dict[str, TTypeDef]
+) -> t.List[TDataFrameRow]:
+ return [expand_data_row(row, dtype, type_defs) for row in data]
+
+
+class TreeOutput(ContainsDtypes):
+ data: TDataFrameData
+ default: t.Optional[TDataFrameRow] = None
+
+ _output_literals: t.List[pl.Expr] = PrivateAttr()
+ _default_literal: t.Optional[pl.Expr] = PrivateAttr(default=None)
+
+ @model_validator(mode="after")
+ def construct_literals(self):
+ # Build df purely to validate and resolve the struct dtype
+ # Use parent's schema (from ContainsDtypes) which already handles type_defs
+
+ data = expand_data(self.data, self.dtypes, self.type_defs)
+ default = (
+ expand_data_row(self.default, self.dtypes, self.type_defs)
+ if self.default
+ else None
+ )
+ try:
+ pl_df = pl.DataFrame(data, schema=self.schema)
+ except Exception as e:
+ # Below was for when we didnt have schemas
+ # if not len(data) and self.default is not None:
+ # dtypes = pl.DataFrame([self.default]).schema
+ # pl_df = pl.DataFrame(data, schema=dtypes)
+ # else:
+ raise ValueError(f"Could not load output data into schema: {e}") from e
+
+ struct_dtype = pl.Struct(pl_df.schema)
+ self._output_literals = [
+ pl.lit(row, dtype=struct_dtype) for row in pl_df.to_dicts()
+ ]
+ if default is not None:
+ # Validate default is compatible with schema
+ try:
+ pl.DataFrame([default], schema=self.schema)
+ except Exception as e:
+ raise ValueError(f"Default row is incompatible with schema: {e}") from e
+ self._default_literal = pl.lit(default, dtype=struct_dtype)
+ return self
+
+ @property
+ def output_literals(self) -> t.List[pl.Expr]:
+ return self._output_literals
+
+ @property
+ def default_literal(self) -> t.Optional[pl.Expr]:
+ return self._default_literal
+
+
+class WithTreeOutput(BaseModel):
+ output: TreeOutput
+
+ @property
+ def output_literals(self) -> t.List[pl.Expr]:
+ return self.output.output_literals
+
+ @property
+ def default_literal(self) -> t.Optional[pl.Expr]:
+ return self.output.default_literal
diff --git a/spockflow/components/__init__.py b/decider/modules/rules/flat_rules/__init__.py
similarity index 100%
rename from spockflow/components/__init__.py
rename to decider/modules/rules/flat_rules/__init__.py
diff --git a/decider/modules/rules/flat_rules/impl.py b/decider/modules/rules/flat_rules/impl.py
new file mode 100644
index 0000000..da678fe
--- /dev/null
+++ b/decider/modules/rules/flat_rules/impl.py
@@ -0,0 +1,542 @@
+"""Execution logic for flat rules system.
+
+This module provides functions to execute flat rules on Polars expressions and
+DataFrames, with support for parameters, prioritized multi-rule execution, and
+custom output formatting.
+"""
+
+import typing as t
+import functools
+import polars as pl
+from .nodes import RuleRoot, FlatRuleTree, TBranchStack, RuleType, BuilderConfig
+from dataclasses import dataclass, field
+
+# =============================================================================
+# Result Builder Functions
+# =============================================================================
+
+
+def default_result_builder(
+ inputs: t.Dict[str, pl.Expr],
+ branch_stack: TBranchStack,
+ config: BuilderConfig,
+ result_idx: int,
+) -> pl.Expr:
+ """Default output function: returns the output literal at result_idx,
+ or the default literal when result_idx is -1 (the no-match leaf)."""
+ if result_idx == -1:
+ return config.default_expr
+ return config.output_literals[result_idx]
+
+
+# Single source of truth for OutputFn signature
+OutputFn = t.Type[default_result_builder]
+
+
+# =============================================================================
+# Parameter Building
+# =============================================================================
+
+
+def build_parameters_expr(
+ runtime_params: t.Optional[pl.Expr],
+ parameter_schema: pl.Schema,
+ default_literals: t.Dict[str, pl.Expr],
+) -> t.Optional[pl.Expr]:
+ """Build a parameters struct expression with smart fallback.
+
+ For each parameter in the schema:
+ - If runtime param exists and is not None, use it
+ - Otherwise, fallback to default_literals[key]
+
+ Runtime params are cast to the complete schema before coalescing
+ to handle missing fields in the runtime struct.
+
+ Args:
+ runtime_params: Optional expression for runtime parameters column
+ parameter_schema: Polars schema defining all parameter fields and types
+ default_literals: Dict mapping parameter names to default literal expressions
+
+ Returns:
+ Struct expression with merged runtime and default parameters, or None if no params
+ """
+ if not parameter_schema:
+ return None
+
+ if runtime_params is None:
+ # No runtime column, build struct from defaults only
+ if default_literals:
+ return pl.struct(
+ *[expr.alias(name) for name, expr in default_literals.items()]
+ )
+ return None
+
+ # Cast runtime params to complete schema (adds missing fields as None)
+ runtime_casted = runtime_params.cast(pl.Struct(parameter_schema))
+
+ # Build merged struct with per-field fallback
+ merged_fields = []
+ for name in parameter_schema:
+ default_expr = default_literals.get(name)
+ if default_expr is not None:
+ # Runtime[name] if not None, else default
+ merged_fields.append(
+ pl.coalesce([runtime_casted.struct.field(name), default_expr]).alias(
+ name
+ )
+ )
+ else:
+ # No default, use runtime value (may be None)
+ merged_fields.append(runtime_casted.struct.field(name).alias(name))
+
+ return pl.struct(*merged_fields)
+
+
+# =============================================================================
+# Single Rule Execution
+# =============================================================================
+
+
+def execute_rule(
+ rule: RuleType,
+ builder_config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ **inputs: pl.Expr,
+) -> pl.Expr:
+ """Execute a single rule and return the resulting Polars expression.
+
+ Args:
+ rule: The rule to execute (can be any RuleType)
+ builder_config: Configuration with result builder and output literals
+ parameters: Optional parameters struct expression
+ **inputs: Feature columns as pl.Expr (e.g., age=pl.col('age'))
+
+ Returns:
+ Polars expression representing the rule's output
+ """
+ return rule.build_expression(
+ inputs=inputs,
+ branch_stack=(),
+ config=builder_config,
+ parameters=parameters,
+ )
+
+
+def execute_rule_root(
+ rule_root: RuleRoot,
+ builder_config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ **inputs: pl.Expr,
+) -> pl.Expr:
+ """Execute a RuleRoot (rule with metadata) and return the resulting expression.
+
+ Args:
+ rule_root: RuleRoot containing meta and rule
+ builder_config: Configuration with result builder and output literals
+ parameters: Optional parameters struct expression
+ **inputs: Feature columns as pl.Expr
+
+ Returns:
+ Polars expression representing the rule's output
+ """
+ from dataclasses import replace
+
+ # Inject rule metadata into config
+ config_with_meta = replace(builder_config, root_meta=rule_root.meta)
+
+ return execute_rule(
+ rule=rule_root.rule,
+ builder_config=config_with_meta,
+ parameters=parameters,
+ **inputs,
+ )
+
+
+# =============================================================================
+# Multiple Rule Execution
+# =============================================================================
+
+
+def execute_rule_list(
+ rules: t.List[RuleRoot],
+ builder_config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ **inputs: pl.Expr,
+) -> t.List[pl.Expr]:
+ """Execute multiple rules and return one output expression per rule.
+
+ Args:
+ rules: List of RuleRoots to execute
+ builder_config: Configuration with result builder and output literals
+ parameters: Optional parameters struct expression
+ **inputs: Feature columns as pl.Expr
+
+ Returns:
+ List of Polars expressions, one per rule
+ """
+ from dataclasses import replace
+
+ return [
+ execute_rule_root(
+ rule_root=rule_root,
+ builder_config=replace(builder_config, rule_idx=i),
+ parameters=parameters,
+ **inputs,
+ )
+ for i, rule_root in enumerate(rules)
+ ]
+
+
+def execute_flat_rule_tree(
+ tree: FlatRuleTree,
+ builder_config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ **inputs: pl.Expr,
+) -> t.List[pl.Expr]:
+ """Execute a FlatRuleTree and return one expression per rule.
+
+ Args:
+ tree: FlatRuleTree containing multiple rules
+ builder_config: Configuration with result builder and output literals
+ parameters: Optional parameters struct expression
+ **inputs: Feature columns as pl.Expr
+
+ Returns:
+ List of Polars expressions, one per rule in the tree
+ """
+ return execute_rule_list(
+ rules=tree.rules,
+ builder_config=builder_config,
+ parameters=parameters,
+ **inputs,
+ )
+
+
+# =============================================================================
+# Prioritized Multi-Rule Execution
+# =============================================================================
+
+
+def _with_rule_idx(expr: pl.Expr, rule_idx: int) -> pl.Expr:
+ """Enrich an indexed result struct {idx, val} with the originating rule index.
+
+ Args:
+ expr: Expression returning {idx, val} struct
+ rule_idx: Index of the rule that produced this result
+
+ Returns:
+ Expression returning {idx, rule_idx, val} struct
+ """
+ return pl.struct(
+ expr.struct.field("idx").alias("idx"),
+ pl.lit(rule_idx).alias("rule_idx"),
+ expr.struct.field("val").alias("val"),
+ )
+
+
+def default_get_prioritized_result(
+ failed_rule_results: t.List[pl.Expr],
+ current_result: t.Optional[pl.Expr],
+ default_expr: pl.Expr,
+) -> pl.Expr:
+ """Default prioritization logic: return current result or last failed result.
+
+ Args:
+ failed_rule_results: List of results from rules that fell through (idx=-1)
+ current_result: Current rule's result (if it matched), or None (if we're in otherwise)
+ default_expr: Default fallback expression
+
+ Returns:
+ Expression to use as the prioritized result
+ """
+ if current_result is None:
+ if len(failed_rule_results) == 0:
+ return default_expr
+ return failed_rule_results[-1]
+ return current_result
+
+
+def prioritize_results(
+ results: t.List[pl.Expr],
+ default_expr: pl.Expr,
+ format_prioritized_fn: t.Optional[
+ t.Callable[[t.List[pl.Expr], t.Optional[pl.Expr], pl.Expr], pl.Expr]
+ ] = None,
+) -> pl.Expr:
+ """Return the first result (lowest rule index) where idx != -1.
+
+ Falls back to a sentinel default struct {idx=-1, rule_idx=-1, val=default_expr}
+ when every rule fell through.
+
+ Args:
+ results: List of result expressions (each returns {idx, val} struct)
+ default_expr: Default expression to use when all rules fall through
+ format_prioritized_fn: Optional custom prioritization logic
+
+ Returns:
+ Expression returning {idx, rule_idx, val} struct with the prioritized result
+ """
+ default_struct = pl.struct(
+ pl.lit(-1).alias("idx"),
+ pl.lit(-1).alias("rule_idx"),
+ default_expr.alias("val"),
+ )
+
+ if format_prioritized_fn is None:
+ format_prioritized_fn = default_get_prioritized_result
+
+ out_expr = pl
+ previous_results = []
+
+ for i, rule_res in enumerate(results):
+ indexed_res = _with_rule_idx(rule_res, i)
+ out_expr = out_expr.when(rule_res.struct.field("idx") != -1).then(
+ format_prioritized_fn(previous_results, indexed_res, default_struct)
+ )
+ previous_results.append(indexed_res)
+
+ if out_expr is pl:
+ return default_struct
+ return out_expr.otherwise(
+ format_prioritized_fn(previous_results, None, default_struct)
+ )
+
+
+def wrap_output_fn_for_index(fn: OutputFn) -> OutputFn:
+ """Wrap an OutputFn so its return value is embedded in a {idx, val} struct.
+
+ idx is the result_idx of the matched leaf (-1 for the default/no-match leaf).
+ Uses **kwargs so the wrapper never needs updating if OutputFn's signature changes.
+ """
+
+ @functools.wraps(fn)
+ def inner(**kwargs: t.Any) -> pl.Expr:
+ val = fn(**kwargs)
+ return pl.struct(
+ pl.lit(kwargs["result_idx"]).alias("idx"),
+ val.alias("val"),
+ )
+
+ return inner # type: ignore[return-value]
+
+
+def extract_value(prioritized: pl.Expr) -> pl.Expr:
+ """Default post-processor: unwrap the val field from a prioritized result struct.
+
+ Args:
+ prioritized: Expression returning {idx, rule_idx, val} struct
+
+ Returns:
+ Expression returning just the val field
+ """
+ return prioritized.struct.field("val")
+
+
+def execute_prioritized_rules(
+ rules: t.List[RuleRoot],
+ builder_config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ post_process_fn: t.Optional[t.Callable[[pl.Expr], pl.Expr]] = None,
+ format_prioritized_fn: t.Optional[
+ t.Callable[[t.List[pl.Expr], t.Optional[pl.Expr], pl.Expr], pl.Expr]
+ ] = None,
+ **inputs: pl.Expr,
+) -> pl.Expr:
+ """Execute multiple rules and return the result of the first one that matches.
+
+ Falls back to default_literal if none match.
+
+ Args:
+ rules: List of RuleRoots to execute in priority order
+ builder_config: Configuration with result builder and output literals
+ parameters: Optional parameters struct expression
+ post_process_fn: Applied to the full prioritized struct {idx, rule_idx, val}
+ (defaults to extract_value which returns just val)
+ format_prioritized_fn: Optional custom prioritization logic
+ **inputs: Feature columns as pl.Expr
+
+ Returns:
+ Expression with the prioritized result (post-processed)
+ """
+ from dataclasses import replace
+
+ # Process the default result first
+ default_result_expr = builder_config.build_result_function(
+ inputs=inputs,
+ branch_stack=(),
+ config=builder_config,
+ result_idx=-1,
+ )
+
+ # Wrap the result builder to inject {idx, val} structs
+ wrapped_config = replace(
+ builder_config,
+ build_result_function=wrap_output_fn_for_index(
+ builder_config.build_result_function
+ ),
+ )
+
+ # Execute all rules
+ results = execute_rule_list(
+ rules=rules,
+ builder_config=wrapped_config,
+ parameters=parameters,
+ **inputs,
+ )
+
+ # Prioritize results
+ prioritized = prioritize_results(
+ results=results,
+ default_expr=default_result_expr,
+ format_prioritized_fn=format_prioritized_fn,
+ )
+
+ # Post-process (default: extract value)
+ post_process_fn = post_process_fn or extract_value
+ return post_process_fn(prioritized)
+
+
+def execute_prioritized_flat_rule_tree(
+ tree: FlatRuleTree,
+ builder_config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ post_process_fn: t.Optional[t.Callable[[pl.Expr], pl.Expr]] = None,
+ format_prioritized_fn: t.Optional[
+ t.Callable[[t.List[pl.Expr], t.Optional[pl.Expr], pl.Expr], pl.Expr]
+ ] = None,
+ **inputs: pl.Expr,
+) -> pl.Expr:
+ """Execute a FlatRuleTree with prioritization (first matching rule wins).
+
+ Args:
+ tree: FlatRuleTree to execute
+ builder_config: Configuration with result builder and output literals
+ parameters: Optional parameters struct expression
+ post_process_fn: Applied to the full prioritized struct
+ format_prioritized_fn: Optional custom prioritization logic
+ **inputs: Feature columns as pl.Expr
+
+ Returns:
+ Expression with the prioritized result
+ """
+ return execute_prioritized_rules(
+ rules=tree.rules,
+ builder_config=builder_config,
+ parameters=parameters,
+ post_process_fn=post_process_fn,
+ format_prioritized_fn=format_prioritized_fn,
+ **inputs,
+ )
+
+
+# =============================================================================
+# DataFrame Integration
+# =============================================================================
+
+
+def execute_rule_on_frame(
+    frame: pl.LazyFrame,
+    rule: RuleRoot,
+    builder_config: BuilderConfig,
+    result_col: str = "result",
+    parameters_col: str = "parameters",
+    default_parameters: t.Optional[t.Dict[str, t.Any]] = None,
+) -> pl.LazyFrame:
+    """Execute a single rule over a LazyFrame, appending the result as a new column.
+
+    Args:
+        frame: Input LazyFrame
+        rule: RuleRoot to execute
+        builder_config: Configuration with result builder and output literals
+        result_col: Name of the output column
+        parameters_col: Name of the parameters column (never passed as a feature)
+        default_parameters: Literal values used only when parameters_col is absent
+
+    Returns:
+        LazyFrame with result_col added
+
+    Note:
+        No feature validation is performed here; all other columns are forwarded.
+    """
+    # Required features are not derived from the rule here, so every column is
+    # forwarded, except the parameters column: under its default name it would
+    # repeat the `parameters=` keyword in the call below and raise TypeError.
+    schema_names = frame.collect_schema().names()
+    inputs: t.Dict[str, pl.Expr] = {
+        col: pl.col(col) for col in schema_names if col != parameters_col
+    }
+
+    parameters = None
+    if parameters_col in schema_names:
+        parameters = pl.col(parameters_col)
+    elif default_parameters:
+        # Build a literal struct from the defaults
+        parameters = pl.struct(
+            *[pl.lit(v).alias(k) for k, v in default_parameters.items()]
+        )
+
+    result_expr = execute_rule_root(
+        rule_root=rule,
+        builder_config=builder_config,
+        parameters=parameters,
+        **inputs,
+    )
+
+    return frame.with_columns(result_expr.alias(result_col))
+
+
+def execute_flat_rule_tree_on_frame(
+    frame: pl.LazyFrame,
+    tree: FlatRuleTree,
+    builder_config: BuilderConfig,
+    result_col: str = "result",
+    use_prioritization: bool = True,
+    parameters_col: str = "parameters",
+    default_parameters: t.Optional[t.Dict[str, t.Any]] = None,
+) -> pl.LazyFrame:
+    """Execute a FlatRuleTree over a LazyFrame, appending the result as a new column.
+
+    Args:
+        frame: Input LazyFrame
+        tree: FlatRuleTree to execute
+        builder_config: Configuration with result builder and output literals
+        result_col: Name of the output column
+        use_prioritization: If True, first match wins; if False, list of all results
+        parameters_col: Name of the parameters column (never passed as a feature)
+        default_parameters: Literal values used only when parameters_col is absent
+
+    Returns:
+        LazyFrame with result_col added
+    """
+    schema_names = frame.collect_schema().names()
+    # parameters_col travels as `parameters=`; forwarding it too can repeat a keyword
+    inputs: t.Dict[str, pl.Expr] = {
+        col: pl.col(col) for col in schema_names if col != parameters_col
+    }
+
+    parameters = None
+    if parameters_col in schema_names:
+        parameters = pl.col(parameters_col)
+    elif default_parameters:
+        parameters = pl.struct(
+            *[pl.lit(v).alias(k) for k, v in default_parameters.items()]
+        )
+
+    if use_prioritization:
+        result_expr = execute_prioritized_flat_rule_tree(
+            tree=tree,
+            builder_config=builder_config,
+            parameters=parameters,
+            **inputs,
+        )
+    else:
+        # Execute all rules, return list
+        results = execute_flat_rule_tree(
+            tree=tree,
+            builder_config=builder_config,
+            parameters=parameters,
+            **inputs,
+        )
+        result_expr = pl.concat_list(results)
+
+    return frame.with_columns(result_expr.alias(result_col))
diff --git a/decider/modules/rules/flat_rules/module.py b/decider/modules/rules/flat_rules/module.py
new file mode 100644
index 0000000..8cf5dfd
--- /dev/null
+++ b/decider/modules/rules/flat_rules/module.py
@@ -0,0 +1,272 @@
+"""Flat rules module — compiles rules into executable Polars expressions."""
+
+import typing as t
+import enum
+import polars as pl
+from pydantic import Field
+from dataclasses import dataclass
+
+from ....serializable.function import DefinedFunction
+from ..common.shared import WithTreeOutput, InputRef
+from ..common.parameters import WithParameters
+from .nodes import BuilderConfig, RuleRoot, FlatRuleTree, RuleMeta
+from .impl import (
+ execute_rule_root,
+ execute_prioritized_rules,
+ execute_flat_rule_tree,
+ build_parameters_expr,
+ default_result_builder,
+ extract_value,
+)
+from ....serializable.schema import PolarsSchema
+from decider.modules.core import BaseExecuteModule
+
+if t.TYPE_CHECKING:
+ from decider.executor import Executor
+ from decider.types import TInputType, TOutputType
+
+
+class PrioritizationMode(str, enum.Enum):
+ """How to handle multiple rules in a FlatRuleTree."""
+
+ first_match = "first_match" # Return first rule that matches (prioritized)
+ all = "all" # Return all rule results as a struct
+
+
+@dataclass
+class OptimRunPolarsExpression:
+ """Optimized Polars expression executor (no feature extraction)."""
+
+ expr: pl.Expr
+
+ def get_output(
+ self,
+ input_frame: pl.DataFrame,
+ ) -> pl.DataFrame:
+ return input_frame.select(self.expr)
+
+
+@dataclass
+class RunPolarsExpression:
+ """Standard Polars expression executor with feature extraction."""
+
+ expr: pl.Expr
+ features: t.List[str]
+ parameters_expr: t.Optional[pl.Expr] = None
+
+ def execute(self, input_frame: pl.DataFrame) -> pl.DataFrame:
+ return input_frame.select(self.expr.struct.unnest())
+
+
+class FlatRuleModule(WithTreeOutput, BaseExecuteModule, WithParameters):
+    """Single rule compiled as a Polars expression."""
+
+    type: t.Literal["flat_rule"]
+    name: str = "output"
+    rule: RuleRoot
+    output_fn: t.Optional[DefinedFunction] = None
+    use_optimized_execution: bool = False
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Get all parameters required by this rule."""
+        return self.rule.rule.get_required_parameters()
+
+    def get_required_features(self) -> t.Set[str]:
+        """Get all features required by this rule."""
+        return self.rule.rule.get_required_features()
+
+    def build_expression(self) -> t.Union[RunPolarsExpression, OptimRunPolarsExpression]:
+        """Compile the rule; the wrapper type depends on use_optimized_execution."""
+        output_fn = (
+            self.output_fn.get_function() if self.output_fn else default_result_builder
+        )
+        config = BuilderConfig(
+            build_result_function=output_fn,
+            output_literals=self.output_literals,
+            default_literal=self.default_literal,
+        )
+
+        # One pl.col per feature the rule references
+        required_features = self.get_required_features()
+        inputs = {col: pl.col(col) for col in required_features}
+
+        # Build parameters expression if needed
+        parameters_expr = None
+        if self.parameters:
+            param_schema = self.parameter_schema
+            default_literals = {
+                name: (
+                    info._polars_literal
+                    if info._polars_literal is not None
+                    else pl.lit(None)
+                )
+                for name, info in self.parameters.items()
+            }
+            parameters_expr = build_parameters_expr(
+                runtime_params=(
+                    pl.col(self.parameters_col) if self.parameters_col else None
+                ),
+                parameter_schema=param_schema,
+                default_literals=default_literals,
+            )
+
+        result_expr = execute_rule_root(
+            rule_root=self.rule,
+            builder_config=config,
+            parameters=parameters_expr,
+            **inputs,
+        )
+
+        # parameters_col may be unset (see runtime_params above); never list None
+        needs_param_col = bool(self.parameters and self.parameters_col)
+        extra_features = [self.parameters_col] if needs_param_col else []
+
+        if self.use_optimized_execution:
+            return OptimRunPolarsExpression(expr=result_expr)
+
+        return RunPolarsExpression(
+            expr=result_expr,
+            features=list(required_features) + extra_features,
+            parameters_expr=parameters_expr,
+        )
+
+    def execute(self, inputs: "TInputType", _executor: "Executor") -> "TOutputType":
+        frame = inputs["input"]
+        if isinstance(frame, pl.LazyFrame):
+            frame = frame.collect()
+        param_col = self.parameters_col if self.parameters else None
+        if param_col and param_col not in frame.columns:
+            frame = frame.with_columns(pl.lit(None).alias(param_col))
+        compiled = self.build_expression()
+        return frame.select(compiled.expr.struct.unnest()).lazy()
+
+
+class PrioritizedFlatRuleModule(WithTreeOutput, BaseExecuteModule, WithParameters):
+    """Multiple flat rules; `mode` selects first-match or all-results output."""
+
+    type: t.Literal["prioritized_flat_rule"]
+    name: str = "output"
+
+    input_schema: t.Optional[PolarsSchema] = Field(
+        default=None, description="Input schema for casting inputs at runtime"
+    )
+    rules: t.List[RuleRoot] = Field(
+        description="List of rules to evaluate in priority order"
+    )
+    mode: PrioritizationMode = Field(
+        default=PrioritizationMode.first_match,
+        description="'first_match' returns the first rule that matches; 'all' returns all results.",
+    )
+    use_optimized_execution: bool = False
+    output_fn: t.Optional[DefinedFunction] = None
+    post_process_fn: t.Optional[DefinedFunction] = None
+    format_prioritized_fn: t.Optional[DefinedFunction] = None
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Get all parameters required by any rule."""
+        params = set()
+        for rule in self.rules:
+            params.update(rule.rule.get_required_parameters())
+        return params
+
+    def get_required_features(self) -> t.Set[str]:
+        """Get all features required by any rule."""
+        features = set()
+        for rule in self.rules:
+            features.update(rule.rule.get_required_features())
+        return features
+
+    def build_expression(self) -> t.Union[RunPolarsExpression, OptimRunPolarsExpression]:
+        """Compile the rules; the wrapper type depends on use_optimized_execution."""
+        output_fn = (
+            self.output_fn.get_function() if self.output_fn else default_result_builder
+        )
+        post_process_fn = (
+            self.post_process_fn.get_function()
+            if self.post_process_fn
+            else extract_value
+        )
+        format_prioritized_fn = (
+            self.format_prioritized_fn.get_function()
+            if self.format_prioritized_fn
+            else None
+        )
+
+        config = BuilderConfig(
+            build_result_function=output_fn,
+            output_literals=self.output_literals,
+            default_literal=self.default_literal,
+        )
+
+        # Collect all required features from all rules
+        required_features = self.get_required_features()
+        inputs = {col: pl.col(col) for col in required_features}
+
+        # Build parameters expression if needed
+        parameters_expr = None
+        if self.parameters:
+            param_schema = self.parameter_schema
+            default_literals = {
+                name: (
+                    info._polars_literal
+                    if info._polars_literal is not None
+                    else pl.lit(None)
+                )
+                for name, info in self.parameters.items()
+            }
+            parameters_expr = build_parameters_expr(
+                runtime_params=(
+                    pl.col(self.parameters_col) if self.parameters_col else None
+                ),
+                parameter_schema=param_schema,
+                default_literals=default_literals,
+            )
+
+        if self.mode == PrioritizationMode.first_match:
+            result_expr = execute_prioritized_rules(
+                rules=self.rules,
+                builder_config=config,
+                parameters=parameters_expr,
+                post_process_fn=post_process_fn,
+                format_prioritized_fn=format_prioritized_fn,
+                **inputs,
+            )
+        elif self.mode == PrioritizationMode.all:
+            # Execute all rules and return as struct
+            results = execute_flat_rule_tree(
+                tree=FlatRuleTree(rules=self.rules),
+                builder_config=config,
+                parameters=parameters_expr,
+                **inputs,
+            )
+            result_expr = pl.struct(
+                *[
+                    e.alias(self.rules[i].meta.name or f"rule_{i}")
+                    for i, e in enumerate(results)
+                ]
+            )
+        else:
+            raise ValueError(f"Unsupported prioritization mode: {self.mode}")
+
+        # parameters_col may be unset (see runtime_params above); never list None
+        needs_param_col = bool(self.parameters and self.parameters_col)
+        extra_features = [self.parameters_col] if needs_param_col else []
+
+        if self.use_optimized_execution:
+            return OptimRunPolarsExpression(expr=result_expr)
+
+        return RunPolarsExpression(
+            expr=result_expr,
+            features=list(required_features) + extra_features,
+            parameters_expr=parameters_expr,
+        )
+
+    def execute(self, inputs: "TInputType", _executor: "Executor") -> "TOutputType":
+        frame = inputs["input"]
+        if isinstance(frame, pl.LazyFrame):
+            frame = frame.collect()
+        param_col = self.parameters_col if self.parameters else None
+        if param_col and param_col not in frame.columns:
+            frame = frame.with_columns(pl.lit(None).alias(param_col))
+        compiled = self.build_expression()
+        return frame.select(compiled.expr.struct.unnest()).lazy()
diff --git a/decider/modules/rules/flat_rules/nodes.py b/decider/modules/rules/flat_rules/nodes.py
new file mode 100644
index 0000000..3d42816
--- /dev/null
+++ b/decider/modules/rules/flat_rules/nodes.py
@@ -0,0 +1,582 @@
+"""Flat rules execution nodes.
+
+All structural definitions (operators, conditions, base nodes) live in the
+sibling ``..common.nodes`` module. This module adds only execution logic:
+ - build_expression() for all rule nodes
+ - WithUnaryBranches: then/otherwise embedded children
+ - WithCasesBranches: branches list + otherwise index
+
+Re-exports common types so existing imports keep working.
+"""
+
+import typing as t
+import uuid
+from pydantic import BaseModel, Field, model_validator
+import typing_extensions as t_ext
+import polars as pl
+from dataclasses import dataclass
+
+from ..common.shared import InputRef
+from ..common.feature import Feature as _Feature
+from ..common.nodetypes import (
+ BaseRule,
+ TNodeType,
+ TLogicOp,
+ TStringMatchType,
+ RangeEndLogic as CommonRangeEndLogic,
+)
+
+# Import everything from common.nodes
+from ..common.nodes import (
+ TUnaryOp,
+ RangeCondition,
+ StringMatchCondition,
+ IsInCondition,
+ CasesBranch,
+ TCaseCondition,
+ CompositeCondition,
+ TCondition,
+ _UnaryOpConditionWrapper,
+ BaseUnaryNode,
+ _CasesRangesCore,
+ _CasesStringMatchCore,
+ _CasesIsInCore,
+ BaseCompositeNode,
+ validate_range_conditions,
+ UnaryLessThanEqual,
+ UnaryLessThan,
+ UnaryEqual,
+ UnaryGreaterThan,
+ UnaryGreaterThanEqual,
+ UnaryNotEqual,
+ UnaryBetween,
+ UnaryIsIn,
+ UnaryStringMatch,
+ UnaryIsNull,
+ UnaryIsNotNull,
+ UnaryIsTrue,
+ UnaryIsFalse,
+)
+
+# Re-export common enums for backward compatibility
+LogicOp = TLogicOp
+StringMatchType = TStringMatchType
+RangeEndLogic = CommonRangeEndLogic
+
+if t.TYPE_CHECKING:
+ from .tree import RuleMeta
+
+
+# =============================================================================
+# Branch Stack and Builder Config
+# =============================================================================
+
+
+class IndexedBranch(t.NamedTuple):
+ """Track which branch was taken in the decision tree.
+
+ For unary rules:
+ index=0 means 'then' branch (condition true)
+ index=None means 'otherwise' branch (condition false)
+
+ For cases rules:
+ index=0,1,2... means conditions[0], conditions[1], conditions[2]...
+ index=None means 'otherwise' branch (no conditions matched)
+ """
+
+ index: t.Optional[int]
+ rule: "RuleType"
+
+
+TBranchStack = t.Tuple[IndexedBranch, ...]
+
+
+@dataclass
+class BuilderConfig:
+ """Configuration for building rule expressions."""
+
+ build_result_function: t.Callable
+ output_literals: t.List[pl.Expr]
+ default_literal: t.Optional[pl.Expr] = None
+ rule_idx: int = 0
+ root_meta: "t.Optional[RuleMeta]" = None
+
+ @property
+ def default_expr(self):
+ return (
+ self.default_literal if self.default_literal is not None else pl.lit(None)
+ )
+
+
+# =============================================================================
+# Leaf Rule
+# =============================================================================
+
+
+class LeafRule(BaseRule):
+ """Terminal node that returns a result."""
+
+ type: t.Literal[TNodeType.LEAF] = TNodeType.LEAF
+ id: t.Optional[str] = Field(default=None)
+ result_idx: int = Field(default=-1)
+
+ @model_validator(mode="after")
+ def ensure_id(self) -> t_ext.Self:
+ if self.id is None:
+ self.id = str(uuid.uuid4())
+ return self
+
+ def build_expression(
+ self,
+ inputs: t.Dict[str, pl.Expr],
+ branch_stack: TBranchStack,
+ config: "BuilderConfig",
+ parameters: t.Optional[pl.Expr] = None,
+ ) -> pl.Expr:
+ return config.build_result_function(
+ inputs=inputs,
+ branch_stack=branch_stack,
+ config=config,
+ result_idx=self.result_idx,
+ )
+
+ def get_required_features(self) -> t.Set[str]:
+ return set()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ return set()
+
+
+# =============================================================================
+# Mixins for embedded children
+# =============================================================================
+
+
+class WithUnaryBranches(BaseModel):
+ """Mixin: adds embedded then/otherwise children (flat_rules style)."""
+
+ then: t.Optional["RuleType"] = Field(default=None)
+ otherwise: t.Optional["RuleType"] = Field(default=None)
+
+ def _get_then_rule(self) -> "RuleType":
+ return self.then if self.then is not None else LeafRule(result_idx=-1)
+
+ def _get_otherwise_rule(self) -> "RuleType":
+ return self.otherwise if self.otherwise is not None else LeafRule(result_idx=-1)
+
+ def get_branch_required_features(self) -> t.Set[str]:
+ features = set()
+ if self.then:
+ features.update(self.then.get_required_features())
+ if self.otherwise:
+ features.update(self.otherwise.get_required_features())
+ return features
+
+ def get_branch_required_parameters(self) -> t.Set[str]:
+ params = set()
+ if self.then:
+ params.update(self.then.get_required_parameters())
+ if self.otherwise:
+ params.update(self.otherwise.get_required_parameters())
+ return params
+
+
+class WithCasesBranches(BaseModel):
+ """Mixin: adds branches list + otherwise index (flat_rules style)."""
+
+ otherwise: int = Field(description="Default branch index if no conditions match")
+ branches: t.List["RuleType"] = Field(description="Array of branch rules")
+
+ def get_branch_required_features(self) -> t.Set[str]:
+ features = set()
+ for branch in self.branches:
+ features.update(branch.get_required_features())
+ return features
+
+ def get_branch_required_parameters(self) -> t.Set[str]:
+ params = set()
+ for branch in self.branches:
+ params.update(branch.get_required_parameters())
+ return params
+
+
+# =============================================================================
+# Unary Rule
+# =============================================================================
+
+
+class UnaryRule(BaseUnaryNode, WithUnaryBranches):
+ """Single condition with embedded then/otherwise branches."""
+
+ @model_validator(mode="after")
+ def ensure_id(self) -> t_ext.Self:
+ if self.id is None:
+ self.id = str(uuid.uuid4())
+ return self
+
+ def build_expression(
+ self,
+ inputs: t.Dict[str, pl.Expr],
+ branch_stack: TBranchStack,
+ config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ ) -> pl.Expr:
+ condition_expr = self.condition.build_condition(inputs, parameters)
+ then_expr = self._get_then_rule().build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=0, rule=self),),
+ config,
+ parameters,
+ )
+ otherwise_expr = self._get_otherwise_rule().build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=None, rule=self),),
+ config,
+ parameters,
+ )
+ return pl.when(condition_expr).then(then_expr).otherwise(otherwise_expr)
+
+ def get_required_features(self) -> t.Set[str]:
+ return super().get_required_features() | self.get_branch_required_features()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ return super().get_required_parameters() | self.get_branch_required_parameters()
+
+
+# =============================================================================
+# Cases Rules
+# =============================================================================
+
+
+class CasesRanges(_CasesRangesCore, WithCasesBranches):
+ """Multi-way range branching with embedded branches."""
+
+ conditions: t.List[CasesBranch] = Field(
+ description="List of range conditions mapped to branch indices"
+ )
+
+ @model_validator(mode="after")
+ def ensure_id_and_validate(self) -> t_ext.Self:
+ if self.id is None:
+ self.id = str(uuid.uuid4())
+ validate_range_conditions(
+ [cb.when for cb in self.conditions if isinstance(cb.when, RangeCondition)],
+ self.strict,
+ )
+ return self
+
+ def build_expression(
+ self,
+ inputs: t.Dict[str, pl.Expr],
+ branch_stack: TBranchStack,
+ config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ ) -> pl.Expr:
+ feature_expr = self.feature.build_expression(inputs, parameters)
+ out_expr = pl
+
+ for idx, case_branch in enumerate(self.conditions):
+ assert isinstance(case_branch.when, RangeCondition)
+ condition_expr = case_branch.when.build_range_condition(
+ feature_expr=feature_expr,
+ end_logic=self.end_logic,
+ parameters=parameters,
+ )
+ branch_expr = self.branches[case_branch.then].build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=idx, rule=self),),
+ config,
+ parameters,
+ )
+ out_expr = out_expr.when(condition_expr).then(branch_expr)
+
+ otherwise_expr = self.branches[self.otherwise].build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=None, rule=self),),
+ config,
+ parameters,
+ )
+
+ if out_expr is pl:
+ return pl.when(feature_expr.is_not_null()).then(otherwise_expr).otherwise(otherwise_expr)
+ return out_expr.otherwise(otherwise_expr)
+
+ def get_required_features(self) -> t.Set[str]:
+ return super().get_required_features() | self.get_branch_required_features()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ params = self.feature.get_required_parameters()
+ for cb in self.conditions:
+ if isinstance(cb.when, RangeCondition):
+ if isinstance(cb.when.min, InputRef):
+ params.add(cb.when.min.key)
+ if isinstance(cb.when.max, InputRef):
+ params.add(cb.when.max.key)
+ params.update(super().get_required_parameters() | self.get_branch_required_parameters())
+ return params
+
+
+class CasesStringMatch(_CasesStringMatchCore, WithCasesBranches):
+ """Multi-way string matching with embedded branches."""
+
+ conditions: t.List[CasesBranch] = Field(
+ description="List of pattern conditions mapped to branch indices"
+ )
+
+ @model_validator(mode="after")
+ def ensure_id(self) -> t_ext.Self:
+ if self.id is None:
+ self.id = str(uuid.uuid4())
+ return self
+
+ def build_expression(
+ self,
+ inputs: t.Dict[str, pl.Expr],
+ branch_stack: TBranchStack,
+ config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ ) -> pl.Expr:
+ feature_expr = self.feature.build_expression(inputs, parameters)
+ if self.trim_whitespace:
+ feature_expr = feature_expr.str.strip_chars()
+ out_expr = pl
+
+ for idx, case_branch in enumerate(self.conditions):
+ assert isinstance(case_branch.when, StringMatchCondition)
+ condition_expr = case_branch.when.build_match_condition(
+ feature_expr=feature_expr,
+ match_type=(
+ self.match_type.value
+ if isinstance(self.match_type, TStringMatchType)
+ else self.match_type
+ ),
+ case_sensitive=self.case_sensitive,
+ parameters=parameters,
+ )
+ branch_expr = self.branches[case_branch.then].build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=idx, rule=self),),
+ config,
+ parameters,
+ )
+ out_expr = out_expr.when(condition_expr).then(branch_expr)
+
+ otherwise_expr = self.branches[self.otherwise].build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=None, rule=self),),
+ config,
+ parameters,
+ )
+
+ if out_expr is pl:
+ return pl.when(feature_expr.is_not_null()).then(otherwise_expr).otherwise(otherwise_expr)
+ return out_expr.otherwise(otherwise_expr)
+
+ def get_required_features(self) -> t.Set[str]:
+ return super().get_required_features() | self.get_branch_required_features()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ params = self.feature.get_required_parameters()
+ for cb in self.conditions:
+ if isinstance(cb.when, StringMatchCondition):
+ for pattern in cb.when.patterns:
+ if isinstance(pattern, InputRef):
+ params.add(pattern.key)
+ params.update(super().get_required_parameters() | self.get_branch_required_parameters())
+ return params
+
+
+class CasesIsIn(_CasesIsInCore, WithCasesBranches):
+ """Multi-way categorical branching with embedded branches."""
+
+ conditions: t.List[CasesBranch] = Field(
+ description="List of value sets mapped to branch indices"
+ )
+
+ @model_validator(mode="after")
+ def ensure_id(self) -> t_ext.Self:
+ if self.id is None:
+ self.id = str(uuid.uuid4())
+ return self
+
+ def build_expression(
+ self,
+ inputs: t.Dict[str, pl.Expr],
+ branch_stack: TBranchStack,
+ config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ ) -> pl.Expr:
+ feature_expr = self.feature.build_expression(inputs, parameters)
+ out_expr = pl
+
+ for idx, case_branch in enumerate(self.conditions):
+ assert isinstance(case_branch.when, IsInCondition)
+ if isinstance(case_branch.when.values, InputRef):
+ condition_expr = feature_expr == case_branch.when.values.resolve(
+ parameters
+ )
+ else:
+ condition_expr = feature_expr.is_in(case_branch.when.values)
+
+ branch_expr = self.branches[case_branch.then].build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=idx, rule=self),),
+ config,
+ parameters,
+ )
+ out_expr = out_expr.when(condition_expr).then(branch_expr)
+
+ otherwise_expr = self.branches[self.otherwise].build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=None, rule=self),),
+ config,
+ parameters,
+ )
+
+ if out_expr is pl:
+ return pl.when(feature_expr.is_not_null()).then(otherwise_expr).otherwise(otherwise_expr)
+ return out_expr.otherwise(otherwise_expr)
+
+ def get_required_features(self) -> t.Set[str]:
+ return super().get_required_features() | self.get_branch_required_features()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ params = self.feature.get_required_parameters()
+ for cb in self.conditions:
+ if isinstance(cb.when, IsInCondition) and isinstance(cb.when.values, InputRef):
+ params.add(cb.when.values.key)
+ params.update(super().get_required_parameters() | self.get_branch_required_parameters())
+ return params
+
+
+# Discriminated union for cases rules
+TCasesVariant = t.Annotated[
+ t.Union[CasesRanges, CasesStringMatch, CasesIsIn],
+ Field(discriminator="op"),
+]
+
+
+from pydantic import RootModel
+
+
+class CasesRule(RootModel[TCasesVariant]):
+ """Wrapper for all Cases rule variants (discriminated by 'op' field)."""
+
+ root: TCasesVariant
+
+ @property
+ def type(self) -> TNodeType:
+ return TNodeType.CASES
+
+ def build_expression(
+ self,
+ inputs: t.Dict[str, pl.Expr],
+ branch_stack: TBranchStack,
+ config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ ) -> pl.Expr:
+ return self.root.build_expression(inputs, branch_stack, config, parameters)
+
+ def get_required_features(self) -> t.Set[str]:
+ return self.root.get_required_features()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ return self.root.get_required_parameters()
+
+
+# =============================================================================
+# Composite Rule
+# =============================================================================
+
+
+class CompositeRule(BaseCompositeNode, WithUnaryBranches):
+ """Composite AND/OR/NOT rule with embedded then/otherwise branches."""
+
+ @model_validator(mode="after")
+ def ensure_id(self) -> t_ext.Self:
+ if self.id is None:
+ self.id = str(uuid.uuid4())
+ return self
+
+ def build_expression(
+ self,
+ inputs: t.Dict[str, pl.Expr],
+ branch_stack: TBranchStack,
+ config: BuilderConfig,
+ parameters: t.Optional[pl.Expr] = None,
+ ) -> pl.Expr:
+ if not self.conditions:
+ # pl.lit(False) is scalar and won't broadcast per-row; use a per-row false
+ composite_condition = pl.int_range(pl.len()) < 0
+ else:
+ # Delegate to CompositeCondition's shared build_condition logic
+ _tmp = CompositeCondition(op=self.op, conditions=self.conditions)
+ composite_condition = _tmp.build_condition(inputs, parameters)
+
+ then_expr = self._get_then_rule().build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=0, rule=self),),
+ config,
+ parameters,
+ )
+ otherwise_expr = self._get_otherwise_rule().build_expression(
+ inputs,
+ branch_stack + (IndexedBranch(index=None, rule=self),),
+ config,
+ parameters,
+ )
+ return pl.when(composite_condition).then(then_expr).otherwise(otherwise_expr)
+
+ def get_required_features(self) -> t.Set[str]:
+ return super().get_required_features() | self.get_branch_required_features()
+
+ def get_required_parameters(self) -> t.Set[str]:
+ return super().get_required_parameters() | self.get_branch_required_parameters()
+
+
+# =============================================================================
+# Root Structure
+# =============================================================================
+
+
+class RuleMeta(BaseModel):
+ name: t.Optional[str] = None
+ description: t.Optional[str] = None
+
+
+class RuleRoot(BaseModel):
+ meta: RuleMeta = Field(default_factory=RuleMeta)
+ rule: "RuleType"
+
+
+class FlatRuleTree(BaseModel):
+ rules: t.List[RuleRoot] = Field(description="List of independent rule trees")
+
+
+# =============================================================================
+# Top-level Rule Union
+# =============================================================================
+
+RuleType = t.Annotated[
+ t.Union[
+ LeafRule,
+ "UnaryRule",
+ CasesRule,
+ "CompositeRule",
+ ],
+ Field(discriminator="type"),
+]
+
+
+# =============================================================================
+# Rebuild models with forward references
+# =============================================================================
+
+UnaryRule.model_rebuild()
+WithUnaryBranches.model_rebuild()
+WithCasesBranches.model_rebuild()
+CasesRanges.model_rebuild()
+CasesStringMatch.model_rebuild()
+CasesIsIn.model_rebuild()
+CompositeRule.model_rebuild()
+RuleRoot.model_rebuild()
diff --git a/decider/modules/rules/modules.py b/decider/modules/rules/modules.py
new file mode 100644
index 0000000..e554921
--- /dev/null
+++ b/decider/modules/rules/modules.py
@@ -0,0 +1,61 @@
+"""Unified module system for DSP Decider.
+
+Combines tree-based and flat rule-based execution modules.
+"""
+
+import typing as t
+from pydantic import RootModel, Field, Discriminator, Tag
+
+# Tree modules (v1, v2, v3)
+from .tree.v1.tree import Tree as V1Tree
+from .tree.v2.tree import Tree as V2Tree
+from .tree.v3.tree import Tree as V3Tree
+
+# Flat rule module
+from .flat_rules.module import PrioritizedFlatRuleModule
+
+
+# =============================================================================
+# Discriminated Union Types
+# =============================================================================
+
+
+# Union of all tree formats (v1, v2, v3)
+TTreeFormat = t.Union[V1Tree, V2Tree, V3Tree]
+
+# Union of all executable module types (tree formats + flat rules) — discriminated by type literal
+TModule = t.Annotated[
+    t.Union[V1Tree, V2Tree, V3Tree, PrioritizedFlatRuleModule],
+    Field(discriminator="type"),  # the payload's ``type`` value selects the member
+]
+
+
+def module_discriminator(value: t.Any) -> str:
+    """Read ``type`` from a dict or object; missing values yield ``"v3-tree"``."""
+    if not isinstance(value, dict):
+        return getattr(value, "type", "v3-tree")
+    return value.get("type", "v3-tree")
+
+
+class ModuleWrapper(RootModel):
+    """Wrapper for discriminated module union (tree or flat rule)."""
+
+    root: TModule  # concrete model is chosen by the payload's ``type`` value
+
+
+# =============================================================================
+# Exports
+# =============================================================================
+
+__all__ = [
+    # Tree versions
+    "V1Tree",
+    "V2Tree",
+    "V3Tree",
+    # Flat rules
+    "PrioritizedFlatRuleModule",
+    # Type unions
+    "TTreeFormat",
+    "TModule",
+    "ModuleWrapper",
+]  # NOTE(review): module_discriminator is defined in this module but not exported — confirm
diff --git a/decider/modules/rules/tree/__init__.py b/decider/modules/rules/tree/__init__.py
new file mode 100644
index 0000000..b84fc6f
--- /dev/null
+++ b/decider/modules/rules/tree/__init__.py
@@ -0,0 +1,31 @@
+"""Tree component - supports v1, v2, and v3 tree formats.
+
+For module types that combine trees and flat rules, see decider.modules.rules.modules
+"""
+
+import typing as t
+from pydantic import Field
+
+from .tree import Tree, _Tree
+from .v1.tree import Tree as V1Tree
+from .v2.tree import Tree as V2Tree
+from .v3.tree import Tree as V3Tree
+
+# Import flat rule module for TTree union
+from ..flat_rules.module import PrioritizedFlatRuleModule
+
+# Legacy alias
+UiTree = V2Tree
+
+# TTree union for API/DB usage - tree formats + flat rules
+# Note: Uses _Tree (the internal union) + PrioritizedFlatRuleModule
+TTree = t.Union[_Tree, PrioritizedFlatRuleModule]
+
+__all__ = [  # NOTE(review): _Tree and PrioritizedFlatRuleModule are imported above but not listed
+    "Tree",
+    "V1Tree",
+    "V2Tree",
+    "V3Tree",
+    "UiTree",
+    "TTree",
+]
diff --git a/decider/modules/rules/tree/tree.py b/decider/modules/rules/tree/tree.py
new file mode 100644
index 0000000..7cad7a6
--- /dev/null
+++ b/decider/modules/rules/tree/tree.py
@@ -0,0 +1,117 @@
+import typing as t
+from pydantic import (
+ BaseModel,
+ Field,
+ Tag,
+ RootModel,
+ Discriminator,
+ model_validator,
+)
+
+# from .v0.tree import Tree as V0Tree
+from .v1.tree import Tree as V1Tree
+from .v2.tree import Tree as V2Tree
+from .v3.tree import Tree as V3Tree
+
+
+class DeprecatedTree(BaseModel):  # placeholder model behind the "v0-tree" tag of _Tree
+    type: t.Literal["ui-tree"] = "ui-tree"
+
+    @model_validator(mode="before")  # runs on the raw input, so parsing always fails
+    def raise_error(value):
+        raise ValueError("This version of the tree has been deprecated")
+
+
+def obj_get(obj: t.Any, k: str, default: t.Any = None) -> t.Any:
+    """Read ``k`` from a dict by key or from any other object by attribute."""
+    is_mapping = isinstance(obj, dict)
+    return obj.get(k, default) if is_mapping else getattr(obj, k, default)
+
+
+def get_tree_version(obj: t.Any) -> str:
+    """Return the discriminator tag (e.g. ``"v2-tree"``) for a dict or object."""
+    # The camelCase key is consulted first, then the snake_case one.
+    version = obj_get(obj, "formatVersion")
+    if version is None:
+        version = obj_get(obj, "format_version")
+    if version is not None:
+        return f"v{version}-tree"
+    # No explicit version: a dict-valued ``nodes`` marks the v0 layout; anything
+    # else (list, missing) is treated as v1.
+    return "v0-tree" if isinstance(obj_get(obj, "nodes"), dict) else "v1-tree"
+
+
+_Tree = t.Annotated[
+    t.Union[
+        t.Annotated[DeprecatedTree, Tag("v0-tree")],  # always fails validation
+        t.Annotated[V1Tree, Tag("v1-tree")],
+        t.Annotated[V2Tree, Tag("v2-tree")],
+        t.Annotated[V3Tree, Tag("v3-tree")],
+    ],
+    Discriminator(get_tree_version),  # the callable returns the tag to match
+]
+
+
+class Tree(RootModel):
+    # Version-agnostic wrapper; ``_Tree`` picks the concrete model per payload.
+    root: _Tree
+
+    def upgrade(self):
+        """Return a ``Tree`` at the latest format version.
+
+        ``self`` is returned unchanged when no upgrade step altered the root.
+        """
+        latest = self.latest_format_version()
+        current = self.root
+        # Follow the per-version ``upgrade()`` chain until the latest format.
+        while hasattr(current, "upgrade") and current.format_version < latest:
+            current = current.upgrade()
+        if current != self.root:
+            return Tree(root=current)
+        return self
+
+    @classmethod
+    def latest_format_version(cls) -> int:
+        """Helper to get the latest format version for UI display."""
+        return 3
+
+    @classmethod
+    def default_tree(cls):
+        """Create a default V3 tree with a single leaf node."""
+        from .v3.nodes_ui import LeafNode, Position, PositionedNode
+        from ..common.shared import TreeOutput
+
+        return cls(
+            root=V3Tree(
+                type="v3-tree",
+                metadata=None,
+                edges=[],
+                nodes=[
+                    PositionedNode(
+                        id="node-1",
+                        position=Position(x=100.0, y=100.0),
+                        data=LeafNode(type="leaf", result_idx=-1),
+                    )
+                ],
+                subtrees=[],
+                input_schema=None,
+                format_version=3,
+                output=TreeOutput(
+                    data=[
+                        {
+                            "Action": "Alert",
+                            "Description": "A rule suggests that there is an alert",
+                        },
+                        {
+                            "Action": "Block",
+                            "Description": "This is more serious and must be blocked",
+                        },
+                    ],
+                    default={"Action": "Allow", "Description": "Default Action"},
+                    dtypes=[("Action", "string"), ("Description", "string")],
+                    type_defs={},
+                ),
+                parameters={},
+                parameters_col="parameters",
+            ),
+        )
diff --git a/spockflow/components/scorecard/v2/__init__.py b/decider/modules/rules/tree/v1/__init__.py
similarity index 100%
rename from spockflow/components/scorecard/v2/__init__.py
rename to decider/modules/rules/tree/v1/__init__.py
diff --git a/decider/modules/rules/tree/v1/edges.py b/decider/modules/rules/tree/v1/edges.py
new file mode 100644
index 0000000..41a5d38
--- /dev/null
+++ b/decider/modules/rules/tree/v1/edges.py
@@ -0,0 +1,37 @@
+import typing as t
+from uuid import uuid4
+from pydantic import BaseModel, field_validator, Field
+from collections.abc import Iterable
+
+T = t.TypeVar("T")
+
+
+class EdgeData(BaseModel):  # edge payload carrying a single source index
+    sourceIndex: int
+
+
+class MultiEdgeData(BaseModel):
+    sourceIndex: t.List[int]
+
+    @field_validator("sourceIndex", mode="before")
+    @classmethod
+    def ensure_list(cls, v: t.Any) -> t.List[int]:
+        # Scalars and strings are wrapped; other iterables are copied to a list.
+        is_collection = isinstance(v, Iterable) and not isinstance(v, str)
+        return list(v) if is_collection else [v]
+
+
+class GenericEdge(BaseModel, t.Generic[T]):  # edge model generic over its ``data`` payload
+    id: str = Field(default_factory=lambda: str(uuid4()))  # random UUID when omitted
+    source: str
+    target: str
+    data: T
+
+    @field_validator("id", mode="before")
+    @classmethod
+    def ensure_id(cls, v: t.Any) -> str:
+        return str(uuid4()) if v is None else v  # an explicit None also gets a fresh UUID
+
+
+MultiSourceEdge = GenericEdge[MultiEdgeData]  # data.sourceIndex is a list of ints
+Edge = GenericEdge[EdgeData]  # data.sourceIndex is a single int
diff --git a/decider/modules/rules/tree/v1/nodes.py b/decider/modules/rules/tree/v1/nodes.py
new file mode 100644
index 0000000..22653eb
--- /dev/null
+++ b/decider/modules/rules/tree/v1/nodes.py
@@ -0,0 +1,122 @@
+"""Data layer node types with variable support."""
+
+import typing as t
+import typing_extensions as t_ext
+from logging import getLogger
+from pydantic import BaseModel, Field, model_validator
+
+logger = getLogger(__name__)
+
+
+class HasVariablesMixin:  # plain mixin; combined with BaseModel by node classes in this module
+    variables: t.List[str] = Field(
+        default_factory=list, description="Variable IDs used by this node"
+    )
+
+
+class RangeNode(BaseModel):
+    """Range test node."""
+
+    NODE_TYPE: t.ClassVar[str] = "numerical_range_test_node"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    split_feature_id: int  # the v1 upgrade code uses this as an index into the feature list
+    default_left: bool = False
+    thresholds: t.List[float] = Field(
+        default_factory=list, description="Direct threshold values"
+    )
+
+
+class NumericalNode(BaseModel, HasVariablesMixin):
+    """Numerical test node."""
+
+    NODE_TYPE: t.ClassVar[str] = "numerical_test_node"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    split_feature_id: int
+    default_left: bool = False
+    comparison_op: t.Literal["<=", "<", "==", ">", ">="] = "<="
+    threshold: t.Optional[float] = Field(
+        description="Direct threshold value", default=None
+    )
+
+    @model_validator(mode="after")
+    def validate_threshold_or_variable(self) -> t_ext.Self:
+        if self.threshold is None and len(self.variables) != 1:  # need exactly one variable
+            raise ValueError(
+                "A single variable must be provided if no threshold is given"
+            )
+        return self
+
+
+class CategoricalNode(BaseModel, HasVariablesMixin):
+    """Categorical test node."""
+
+    NODE_TYPE: t.ClassVar[str] = "categorical_test_node"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    split_feature_id: int
+    default_left: bool = False
+    category_list: t.List[int] = Field(
+        default_factory=list, description="Direct category values"
+    )
+    category_list_right_child: bool = False
+
+    @model_validator(mode="after")
+    def validate_threshold_or_variable(self) -> t_ext.Self:
+        if not self.category_list and len(self.variables) != 1:  # exactly one variable
+            raise ValueError(
+                "Exactly one variable must be given if no categories are provided."
+            )
+        return self
+
+
+class StringMatchNode(BaseModel, HasVariablesMixin):
+    """String match node."""
+
+    NODE_TYPE: t.ClassVar[str] = "string_match_node"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    split_feature_id: int
+    default_left: bool = True  # the other test nodes in this module default to False
+    patterns: t.List[str] = Field(
+        default_factory=list, description="Direct pattern values"
+    )
+    match_type: t.Literal["exact", "starts_with", "contains", "ends_with", "regex"] = (
+        "exact"
+    )
+    case_sensitive: bool = True
+    match_any: bool = True
+    # The validator below is commented out, so none of its checks currently run.
+    # @model_validator(mode="after")
+    # def validate_threshold_or_variable(self) -> t_ext.Self:
+    #     if not self.match_any and len(self.variables):
+    #         raise ValueError("String match node only supports variables on match_any.")
+    #     if len(self.patterns) == 0 and len(self.variables) != 1:
+    #         raise ValueError(
+    #             "One or more variables must be given if no patterns are provided."
+    #         )
+    #     return self
+
+
+class LeafNode(BaseModel):
+    """Leaf node."""
+
+    NODE_TYPE: t.ClassVar[str] = "leaf"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    leaf_value: int = -1
+    output_data: t.Optional[t.Dict[str, t.Any]] = None  # per-leaf output embedded in the node
+
+
+# Union type for all data layer nodes
+NodeData = t_ext.Annotated[
+    t.Union[RangeNode, NumericalNode, CategoricalNode, StringMatchNode, LeafNode],
+    Field(discriminator="node_type"),  # every member declares a distinct node_type literal
+]
+
+
+class Position(BaseModel):  # x/y coordinates stored with a node
+    x: float
+    y: float
+
+
+class PositionedNode(BaseModel):  # a node id, its position and its typed payload
+    id: str
+    position: Position = Field(default_factory=lambda: Position(x=0, y=0))
+    data: NodeData
diff --git a/decider/modules/rules/tree/v1/schema.py b/decider/modules/rules/tree/v1/schema.py
new file mode 100644
index 0000000..a2f4ab4
--- /dev/null
+++ b/decider/modules/rules/tree/v1/schema.py
@@ -0,0 +1,383 @@
+"""
+Output schema definitions and validation for tree execution results.
+
+This module contains all the schema-related classes, types, and validation logic
+for the new output system (nodeOutputFormatVersion=1).
+"""
+
+import typing as t
+import typing_extensions as t_ext
+from enum import Enum
+from pydantic import (
+ BaseModel,
+ Field,
+ PrivateAttr,
+ model_validator,
+ field_validator,
+ Discriminator,
+ RootModel,
+ Tag,
+ ConfigDict,
+)
+
+
+class FieldType(str, Enum):
+    """Supported field types for output schema"""
+
+    STRING = "string"
+    NUMBER = "number"
+    BOOLEAN = "boolean"
+    CUSTOM = "custom"  # OutputField requires ``custom_type`` for this kind
+    LIST = "list"  # OutputField requires ``list_type`` for this kind
+
+
+class CustomTypeKind(str, Enum):
+    """Supported custom type kinds"""
+
+    ENUM = "enum"  # definition is an EnumDefinition (list of values)
+    RECORD = "record"  # definition is a RecordDefinition (fields plus records)
+
+
+class RecordDefinition(BaseModel):
+    """Record definition containing fields and records"""
+
+    fields: (
+        t.Dict[str, t.Literal[FieldType.STRING, FieldType.NUMBER, FieldType.BOOLEAN]]
+        | t.List[
+            t.Tuple[
+                str, t.Literal[FieldType.STRING, FieldType.NUMBER, FieldType.BOOLEAN]
+            ]
+        ]
+    )
+    records: t.List[t.Dict[str, t.Any]] = Field(default_factory=list)
+
+    def field_items(self):
+        """Iterate ``(name, type)`` pairs for either ``fields`` representation."""
+        declared = self.fields
+        # A dict yields its items view; the list form already holds the pairs.
+        return declared.items() if isinstance(declared, dict) else declared
+
+
+class EnumDefinition(BaseModel):
+    """Enum definition containing values"""
+
+    values: t.List[str]  # _CustomEnumFieldType accepts only values from this list
+
+
+class _RecordCustomTypeInfo(BaseModel):
+    """Snapshot of a custom record type definition at time of use"""
+
+    id: str
+    name: str
+    type_kind: t.Literal[CustomTypeKind.RECORD] = CustomTypeKind.RECORD
+    definition: RecordDefinition
+    display_format: t.Optional[str] = None
+    _parsed_records: t.Optional[t.List[t.Dict[str, t.Any]]] = PrivateAttr(default=None)
+
+    @property
+    def parsed_records(self) -> t.List[t.Dict[str, t.Any]]:
+        """Records indexed by integer ``id`` with coerced field values (cached).
+        Raises ``ValueError`` when a record ``id`` is missing, negative or repeated.
+        """
+        if self._parsed_records is not None:
+            return self._parsed_records
+        parsed_data = [{} for _ in self.definition.records]
+        for record in self.definition.records:
+            if record.get("id") is None:
+                raise ValueError("Record missing 'id' field")
+            record_id = int(record["id"])
+            if record_id < 0:  # would otherwise index from the end of the list
+                raise ValueError(f"Record id must be non-negative, got {record_id}")
+            parsed_data.extend({} for _ in range(len(parsed_data), record_id + 1))
+            if parsed_data[record_id]:
+                raise ValueError(
+                    f"Duplicate record id {record_id} in custom record type"
+                )
+            parsed_record = {}
+            for field_name, field_type in self.definition.field_items():
+                if field_name not in record:
+                    parsed_record[field_name] = None
+                    continue
+                parsed_field = _PrimitiveFieldTypes(
+                    field_type=field_type, value=record[field_name]
+                ).root
+                parsed_record[field_name] = parsed_field.parsed_value
+            parsed_data[record_id] = parsed_record
+        self._parsed_records = parsed_data
+        return self._parsed_records
+
+
+class _EnumCustomTypeInfo(BaseModel):
+    """Snapshot of a custom enum type definition at time of use"""
+
+    id: str
+    name: str
+    type_kind: t.Literal[CustomTypeKind.ENUM] = CustomTypeKind.ENUM  # union discriminator
+    definition: EnumDefinition
+
+
+_CustomTypeInfo = t.Annotated[
+    t.Union[_RecordCustomTypeInfo, _EnumCustomTypeInfo],
+    Field(discriminator="type_kind"),  # the type_kind literal selects the member
+]
+
+
+class OutputField(BaseModel):
+    """Output field definition for the schema"""
+
+    id: str
+    field_name: str
+    field_type: FieldType
+    list_type: t.Optional[
+        t.Literal[
+            FieldType.STRING, FieldType.NUMBER, FieldType.BOOLEAN, FieldType.CUSTOM
+        ]
+    ] = None
+    is_required: bool = True
+    custom_type: t.Optional[_CustomTypeInfo] = None  # Embedded snapshot
+    custom_type_id: t.Optional[str] = None  # Reference to custom type
+
+    @model_validator(mode="after")
+    def validate_type_fields(self) -> "t_ext.Self":
+        """Require the companion attributes that CUSTOM and LIST fields depend on."""
+        if self.field_type == FieldType.CUSTOM and self.custom_type is None:
+            raise ValueError("Custom type must be provided for CUSTOM field type")
+        if self.field_type == FieldType.LIST:
+            if self.list_type is None:
+                raise ValueError("List type must be provided for LIST field type")
+            # Raised rather than asserted: ``assert`` statements are stripped under
+            # ``python -O``, which would silently disable this check.
+            if self.list_type == FieldType.CUSTOM and self.custom_type is None:
+                raise ValueError("Custom type must be provided for CUSTOM list type")
+        return self
+
+
+# Private field validation types (each TypeVar is named after the variable it is bound to)
+_T = t.TypeVar("_T")
+_V = t.TypeVar("_V", bound=FieldType)
+
+
+class _BaseFieldValueType(BaseModel, t.Generic[_V, _T]):  # pairs a field type tag with a value
+    field_type: _V
+    value: _T
+
+    @property
+    def parsed_value(self):  # record and list subclasses override this with cached results
+        return self.value
+
+
+class _StringFieldValue(_BaseFieldValueType[t.Literal[FieldType.STRING], str]):
+    @field_validator("value", mode="before")
+    @classmethod
+    def convert_to_string(cls, v):
+        """Stringify any non-``None`` input before validation runs."""
+        if v is not None:
+            return str(v)
+        return None
+
+
+class _NumberFieldValue(
+    _BaseFieldValueType[t.Literal[FieldType.NUMBER], t.Union[int, float]]
+):
+    pass  # no custom coercion; validation follows the int/float annotation
+
+
+class _BooleanFieldValue(_BaseFieldValueType[t.Literal[FieldType.BOOLEAN], bool]):
+    @field_validator("value", mode="before")
+    @classmethod
+    def convert_to_bool(cls, v):
+        """Coerce raw input to ``bool`` before validation.
+
+        ``None`` passes through untouched. Strings are true only when they
+        equal ``true``/``1``/``yes``/``on`` ignoring case; every other value
+        goes through ``bool()``.
+        """
+        if v is None:
+            return v
+        truthy_words = ("true", "1", "yes", "on")
+        return v.lower() in truthy_words if isinstance(v, str) else bool(v)
+
+
+_TPrimitiveFieldTypes = t.Annotated[
+    t.Union[_StringFieldValue, _NumberFieldValue, _BooleanFieldValue],
+    Discriminator("field_type"),  # the field_type literal selects the member
+]
+
+
+class _PrimitiveFieldTypes(RootModel):  # parse entry point; the result is read from ``.root``
+    root: _TPrimitiveFieldTypes
+
+
+class _CustomEnumFieldType(_BaseFieldValueType[t.Literal[FieldType.CUSTOM], str]):
+    custom_type: _EnumCustomTypeInfo
+
+    @model_validator(mode="after")
+    def validate_values(self):
+        allowed = self.custom_type.definition.values
+        if self.value not in allowed:  # raised, not asserted: must survive ``python -O``
+            raise ValueError(f"Must be one of {allowed}, got '{self.value}'")
+        return self
+
+
+class _ReferencedFieldType(BaseModel):  # object form of a record reference
+    model_config = ConfigDict(extra="allow")  # keys other than ``id`` are kept, not rejected
+    id: int
+
+
+class _CustomRecordFieldType(
+    _BaseFieldValueType[t.Literal[FieldType.CUSTOM], t.Union[int, _ReferencedFieldType]]
+):
+    custom_type: _RecordCustomTypeInfo
+    _parsed_values: t.Dict[str, t.Union[str, int, float]] = PrivateAttr()
+
+    @model_validator(mode="after")
+    def validate_values(self):
+        value = self.value if isinstance(self.value, int) else self.value.id
+        records = self.custom_type.parsed_records
+        if not 0 <= value < len(records):  # explicit raise survives ``python -O``
+            raise ValueError(f"Record id must be between 0 and {len(records)-1}, got '{value}'")
+        self._parsed_values = records[value]
+        return self
+
+    @property
+    def parsed_value(self):
+        return self._parsed_values
+
+
+def _discriminate_custom_type(v):
+    """Return ``custom_type.type_kind`` from a dict or object, else ``None``."""
+    levels = (("custom_type", {}), ("type_kind", None))
+    current = v
+    for key, default in levels:
+        if isinstance(current, dict):
+            current = current.get(key, default)
+        else:
+            current = getattr(current, key, default)
+    return current
+
+
+_TCustomTypes = t.Annotated[
+    t.Union[
+        t.Annotated[_CustomEnumFieldType, Tag(CustomTypeKind.ENUM)],
+        t.Annotated[_CustomRecordFieldType, Tag(CustomTypeKind.RECORD)],
+    ],
+    Discriminator(_discriminate_custom_type),  # tag comes from custom_type.type_kind
+]
+
+_TListFieldTypes = t.Annotated[
+    t.Union[_StringFieldValue, _NumberFieldValue, _BooleanFieldValue, _TCustomTypes],
+    Discriminator("field_type"),  # item models allowed inside a list; no nested lists
+]
+
+
+class _ListFieldTypes(RootModel):  # parse entry point for one list item; read ``.root``
+    root: _TListFieldTypes
+
+
+class _ListFieldType(_BaseFieldValueType[t.Literal[FieldType.LIST], list]):
+    list_type: t.Literal[
+        FieldType.STRING, FieldType.BOOLEAN, FieldType.NUMBER, FieldType.CUSTOM
+    ]
+    custom_type: t.Optional[_CustomTypeInfo] = None
+    _parsed_values: t.List[t.Union[str, int, float]] = PrivateAttr()
+
+    @model_validator(mode="after")
+    def validate_values(self):  # parse each item as ``list_type`` and cache the results
+        kwargs = {"field_type": self.list_type}
+        if self.list_type == FieldType.CUSTOM:
+            if self.custom_type is None:  # explicit raise survives ``python -O``
+                raise ValueError(
+                    "Expected custom type to be an enum type if field type is CUSTOM"
+                )
+            kwargs["custom_type"] = self.custom_type
+        self._parsed_values = [
+            _ListFieldTypes(value=v, **kwargs).root.parsed_value for v in self.value
+        ]
+        return self
+
+    @property
+    def parsed_value(self):
+        return self._parsed_values
+
+
+_TFieldValueTypes = t.Annotated[
+    t.Union[
+        _StringFieldValue,
+        _NumberFieldValue,
+        _BooleanFieldValue,
+        _TCustomTypes,  # itself a union, split further by custom_type.type_kind
+        _ListFieldType,
+    ],
+    Discriminator("field_type"),  # the field_type value selects the member
+]
+
+
+class _FieldValueTypes(RootModel):  # used by OutputSchema.validate_data to parse one value
+    root: _TFieldValueTypes
+
+
+class OutputSchema(BaseModel):
+    """Schema definition for structured output"""
+
+    fields: t.List[OutputField] = Field(default_factory=list)
+    display_format: t.Optional[str] = None  # Template for node display
+    default_values: t.Optional[t.Dict[str, t.Any]] = None
+
+    def has_default_values(self) -> bool:
+        """Tell whether defaults are present and worth validating.
+
+        The UI still sends ``{}`` when no defaults were set, so an empty dict is
+        treated like ``None``; at least one required field is also needed.
+        """
+        if not self.default_values:
+            return False
+        return any(spec.is_required for spec in self.fields)
+
+    @model_validator(mode="after")
+    def validate_default_values(self) -> "t_ext.Self":
+        """Fail construction when the defaults violate the schema's own fields."""
+        if not self.has_default_values():
+            return self
+        errors = self.validate_data(self.default_values)
+        if errors:
+            raise ValueError(f"Default values validation errors: {errors}")
+        return self
+
+    def validate_data(self, data: t.Dict[str, t.Any]) -> t.List[str]:
+        """Validate ``data`` against the schema.
+
+        Values that parse are written back into ``data`` in parsed form (for
+        example ``"5"`` becomes ``5``), so the argument is modified in place.
+
+        Returns:
+            One message per failing field; an empty list means ``data`` is valid.
+        """
+        problems: t.List[str] = []
+        for spec in self.fields:
+            name = spec.field_name
+            if name not in data:
+                # Absent optional fields are fine; absent required ones are not.
+                if spec.is_required:
+                    problems.append(f"Field '{name}' is required but missing")
+                continue
+            try:
+                parsed = _FieldValueTypes(**spec.model_dump(), value=data[name])
+            except ValueError as exc:
+                problems.append(str(exc))
+            else:
+                data[name] = parsed.root.parsed_value
+        return problems
+
+    def columns(self):
+        """Field names in declaration order."""
+        return [spec.field_name for spec in self.fields]
+
+    def dtypes(self):
+        """Map every field name to a pandas dtype string."""
+        # Numbers and booleans get dedicated dtypes; strings, lists and custom
+        # types (enums, records) are all stored as ``object``.
+        dedicated = {FieldType.NUMBER: "float64", FieldType.BOOLEAN: "boolean"}
+        return {
+            spec.field_name: dedicated.get(spec.field_type, "object")
+            for spec in self.fields
+        }
diff --git a/decider/modules/rules/tree/v1/tree.py b/decider/modules/rules/tree/v1/tree.py
new file mode 100644
index 0000000..8ca3744
--- /dev/null
+++ b/decider/modules/rules/tree/v1/tree.py
@@ -0,0 +1,586 @@
+import typing as t
+import polars as pl
+from pydantic import BaseModel, Field, PrivateAttr, model_validator
+from decider.modules.core import BaseExecuteModule
+from decider.types import TInputType, TOutputType
+
+if t.TYPE_CHECKING:
+ from decider.executor import Executor
+from .nodes import (
+ NodeData,
+ PositionedNode,
+ LeafNode,
+ RangeNode,
+ NumericalNode,
+ CategoricalNode,
+ StringMatchNode,
+)
+from .variables import VariableMap, PlaceHolderVariable
+from .edges import MultiSourceEdge, MultiEdgeData
+from .schema import OutputSchema, FieldType
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Output collection helpers
+# ---------------------------------------------------------------------------
+
+
+def _collect_outputs_v0(
+    nodes: t.List[PositionedNode],
+    tree_output: t.Optional["TreeOutput"],
+) -> t.Tuple[t.List[t.Dict], t.Optional[t.Dict], t.Dict[str, int]]:
+    """v0: outputs live in a flat table (treeOutput); leaf_value is the row index.
+
+    Returns:
+        ``(rows after the first, first row or None, {leaf node id: leaf_value})``;
+        ``([], None, {})`` when there is no table.
+    """
+    if tree_output is None:
+        return [], None, {}
+    rows = [dict(zip(tree_output.columns, values)) for values in tree_output.data]
+    leaf_values = {
+        node.id: node.data.leaf_value
+        for node in nodes
+        if isinstance(node.data, LeafNode)
+    }
+    return rows[1:], (rows[0] if rows else None), leaf_values
+
+
+def _collect_outputs_v1(
+    nodes: t.List[PositionedNode],
+    output_schema: t.Optional[OutputSchema],
+) -> t.Tuple[t.List[t.Dict], t.Optional[t.Dict], t.Dict[str, int]]:
+    """v1: output data is embedded in each leaf via output_data; deduplicate into a list.
+
+    Leaves without ``output_data`` map to index ``-1``.
+    """
+    unique_outputs: t.List[t.Dict] = []
+    index_by_key: t.Dict[str, int] = {}
+    leaf_to_index: t.Dict[str, int] = {}
+
+    leaves = (node for node in nodes if isinstance(node.data, LeafNode))
+    for leaf in leaves:
+        payload = leaf.data.output_data
+        if payload is None:
+            leaf_to_index[leaf.id] = -1
+            continue
+        # Sorted items give a key that ignores dict insertion order.
+        key = str(sorted(payload.items()))
+        if key not in index_by_key:
+            index_by_key[key] = len(unique_outputs)
+            unique_outputs.append(payload)
+        leaf_to_index[leaf.id] = index_by_key[key]
+
+    use_defaults = output_schema and output_schema.has_default_values()
+    defaults = output_schema.default_values if use_defaults else None
+    return unique_outputs, defaults, leaf_to_index
+
+
+# ---------------------------------------------------------------------------
+# Schema conversion helper
+# ---------------------------------------------------------------------------
+
+_FIELD_TYPE_TO_POLARS: t.Dict[FieldType, str] = {  # primitive field type -> Polars type name
+    FieldType.STRING: "String",
+    FieldType.NUMBER: "Float64",
+    FieldType.BOOLEAN: "Boolean",
+}
+
+_LIST_ITEM_TYPE_TO_POLARS: t.Dict[FieldType, str] = {  # same entries, used for list items
+    FieldType.STRING: "String",
+    FieldType.NUMBER: "Float64",
+    FieldType.BOOLEAN: "Boolean",
+}
+
+
+_RECORD_FIELD_TYPE_TO_POLARS: t.Dict[str, str] = {  # keyed by the raw type string
+    "string": "String",
+    "number": "Float64",
+    "boolean": "Boolean",
+}
+
+
+def _record_to_polars_struct(record_def: "RecordDefinition") -> t.Dict[str, t.Any]:
+    """Convert a v1 RecordDefinition to a nested PolarsSchema-compatible struct dict."""
+    # Unmapped field types fall back to "String".
+    mapping = _RECORD_FIELD_TYPE_TO_POLARS
+    return {
+        name: mapping.get(kind, "String")
+        for name, kind in record_def.field_items()
+    }
+
+
+def _convert_custom_types_to_defs(output_schema: "OutputSchema") -> t.Dict[str, t.Any]:
+    """Extract custom types from OutputSchema and convert to StructTypeDef/CategoricalTypeDef.
+
+    Record types become ``StructTypeDef`` entries and enum types become
+    ``CategoricalTypeDef`` entries. A type shared by several fields is emitted
+    once. Fields without a custom type, or without a usable type id, are skipped.
+
+    Returns:
+        Dict mapping custom type ID to its type definition, suitable for the
+        v2 TreeOutput.type_defs field.
+    """
+    from .schema import CustomTypeKind
+    from .....serializable.dtypes import StructTypeDef, CategoricalTypeDef
+
+    type_defs: t.Dict[str, t.Any] = {}
+
+    for field in output_schema.fields:
+        is_custom_list = (
+            field.field_type == FieldType.LIST and field.list_type == FieldType.CUSTOM
+        )
+        custom_type = field.custom_type or None
+        if is_custom_list:
+            # List items of a custom type are keyed by the field's reference id,
+            # not by the id stored on the embedded snapshot.
+            type_id = field.custom_type_id
+        else:
+            type_id = custom_type.id if custom_type else None
+
+        # Nothing to emit without both a type and an id, or when another field
+        # already contributed this type.
+        if not custom_type or not type_id or type_id in type_defs:
+            continue
+
+        # Records carry their data rows; enums only their list of categories.
+        kind = custom_type.type_kind
+        if kind == CustomTypeKind.RECORD:
+            definition = custom_type.definition
+            type_defs[type_id] = StructTypeDef(
+                name=custom_type.name,
+                type="struct",
+                definition={
+                    "fields": _record_to_polars_struct(definition),
+                    "data": definition.records,
+                    "display_field": "id",
+                },
+            )
+        elif kind == CustomTypeKind.ENUM:
+            type_defs[type_id] = CategoricalTypeDef(
+                name=custom_type.name,
+                type="categorical",
+                definition={"categories": custom_type.definition.values},
+            )
+
+    return type_defs
+
+
+def _convert_schema(output_schema: "OutputSchema") -> t.Dict[str, t.Any]:
+    """Convert a v1 OutputSchema to the dict format accepted by PolarsSchema.
+
+    Primitive fields map to a Polars type name. Lists become
+    ``ExplicitType(type="List", inner=...)``. Custom types are referenced by id
+    through ``ExplicitType(type="Custom", type_id=...)`` rather than inlined.
+    Anything that cannot be mapped falls back to ``"String"``.
+    """
+    from .schema import CustomTypeKind
+    from .....serializable.schema import ExplicitType
+
+    def _custom_ref(type_id):
+        """Reference a custom type definition by its id."""
+        return ExplicitType(type="Custom", type_id=type_id)
+
+    def _convert_field(field):
+        """Return the PolarsSchema entry for one output field."""
+        kind = field.field_type
+        if kind in _FIELD_TYPE_TO_POLARS:
+            return _FIELD_TYPE_TO_POLARS[kind]
+
+        if kind == FieldType.LIST:
+            # A custom item type is only usable with the field's reference id.
+            if field.list_type == FieldType.CUSTOM and field.custom_type_id:
+                inner = _custom_ref(field.custom_type_id)
+            else:
+                inner = _LIST_ITEM_TYPE_TO_POLARS.get(field.list_type, "String")
+            return ExplicitType(type="List", inner=inner)
+
+        if kind == FieldType.CUSTOM:
+            snapshot = field.custom_type
+            # Enum and record types are both referenced by the snapshot's id.
+            if snapshot and snapshot.type_kind in (
+                CustomTypeKind.ENUM,
+                CustomTypeKind.RECORD,
+            ):
+                return _custom_ref(snapshot.id)
+
+        # Unknown field types and custom fields without a snapshot.
+        return "String"
+
+    return {field.field_name: _convert_field(field) for field in output_schema.fields}
+
+
+# ---------------------------------------------------------------------------
+# Node upgrade helper
+# ---------------------------------------------------------------------------
+
+
+def _resolve_variable(var_id: str, variables: VariableMap) -> t.Any:
+    """Return the value stored for ``var_id``; unknown ids log a warning and give ``None``."""
+    found = variables.get(var_id)
+    if found is not None:
+        return found.value
+    logger.warning(f"Variable ID '{var_id}' not found in variable map; using None.")
+    return None
+
+
+def _upgrade_node_data(
+    data: NodeData,
+    features: t.List[str],
+    variables: VariableMap,
+    node_id_output_map: t.Dict[str, int],
+    node_id: str,
+) -> NodeData:  # NOTE(review): the values returned are v2 node models — confirm annotation
+    from ..v2.nodes import (
+        RangeNode as V2RangeNode,
+        NumericalNode as V2NumericalNode,
+        CategoricalNode as V2CategoricalNode,
+        StringMatchNode as V2StringMatchNode,
+        LeafNode as V2LeafNode,
+    )
+    from ...common.shared import InputRef
+    # Translate one v1 node payload into its v2 counterpart, one branch per type.
+    if isinstance(data, RangeNode):
+        # split_feature_id indexes ``features``; the v2 node receives the entry itself.
+        return V2RangeNode(
+            feature=features[data.split_feature_id],
+            default_left=data.default_left,
+            thresholds=data.thresholds,  # already floats; no variable support in v1 RangeNode
+        )
+
+    if isinstance(data, NumericalNode):
+        if data.threshold is not None:
+            threshold: t.Union[float, InputRef] = data.threshold
+        else:
+            var_name = variables[data.variables[0]].name  # presumably raises for an unknown id — verify
+            threshold = InputRef(key=var_name)
+        return V2NumericalNode(
+            feature=features[data.split_feature_id],
+            default_left=data.default_left,
+            comparison_op=data.comparison_op,
+            threshold=threshold,
+        )
+
+    if isinstance(data, CategoricalNode):
+        # Note: V2 CategoricalNode doesn't have category_list_right_child
+        # We'll handle edge swapping in the upgrade() method
+        if data.variables:
+            if len(data.variables) > 1:
+                logger.warning(
+                    f"CategoricalNode with ID '{node_id}' has multiple variables; only the first will be used."
+                )
+            category_list: t.Union[InputRef, t.List[int]] = InputRef(
+                key=variables[data.variables[0]].name
+            )
+        else:
+            category_list = list(
+                data.category_list
+            )  # literal categories, copied into a new list
+        return V2CategoricalNode(
+            feature=features[data.split_feature_id],
+            category_list=category_list,
+        )  # NOTE(review): default_left is not forwarded here, unlike the other test nodes — confirm
+
+    if isinstance(data, StringMatchNode):
+        if data.variables:
+            if len(data.variables) > 1:
+                logger.warning(
+                    f"StringMatchNode with ID '{node_id}' has multiple variables; only the first will be used."
+                )
+            patterns = InputRef(key=variables[data.variables[0]].name)
+        else:
+            patterns = (
+                data.patterns
+            )  # literal patterns; the v1 list object is passed through as-is
+        return V2StringMatchNode(
+            feature=features[data.split_feature_id],
+            default_left=data.default_left,
+            patterns=patterns,
+            match_type=data.match_type,
+            case_sensitive=data.case_sensitive,
+            match_any=data.match_any,
+        )
+
+    if isinstance(data, LeafNode):
+        return V2LeafNode(
+            leaf_value=node_id_output_map.get(node_id, -1),  # -1 when the node id is not in the map
+        )
+
+    raise ValueError(f"Unknown node type: {type(data)}")
+
+
+def _upgrade_nodes(
+    nodes: t.List[PositionedNode],
+    features: t.List[str],
+    variables: VariableMap,
+    node_id_output_map: t.Dict[str, int],
+) -> t.Tuple[t.List[PositionedNode], t.Set[str]]:
+    """Upgrade nodes from v1 to v2 format.
+
+    Returns:
+        Tuple of (upgraded_nodes, nodes_needing_edge_swap)
+        nodes_needing_edge_swap contains IDs of nodes that had category_list_right_child=True
+    """
+    from ..v2.nodes import PositionedNode as V2PositionedNode
+
+    def _to_v2(node):
+        """Rebuild one node around its upgraded payload, keeping id and position."""
+        position = node.position.model_dump(mode="python")
+        payload = _upgrade_node_data(
+            node.data, features, variables, node_id_output_map, node.id
+        )
+        return V2PositionedNode(
+            id=node.id,
+            position=position,
+            data=payload,
+        )
+
+    # Categorical nodes flagged with category_list_right_child are reported to
+    # the caller, which handles the edge swap for them.
+    swap_ids = {
+        node.id
+        for node in nodes
+        if isinstance(node.data, CategoricalNode)
+        and node.data.category_list_right_child
+    }
+
+    return [_to_v2(node) for node in nodes], swap_ids
+
+
+class TreeOutput(BaseModel):
+    """Tabular output definition stored under the ``treeOutput`` key of a v1 tree.
+
+    ``Tree.upgrade`` passes it to ``_collect_outputs_v0`` when the legacy
+    table output format is selected.
+    """
+
+    # Names of the table columns.
+    columns: t.List[str]
+    # Table rows; every cell is held as a string.
+    data: t.List[t.List[str]]
+    # Optional dtype names -- presumably one per column; TODO confirm against _collect_outputs_v0.
+    dtype: t.Optional[t.List[str]] = None
+
+
+class TreeMetadata(BaseModel):
+    """Optional descriptive metadata for a v1 tree; both fields may be omitted."""
+
+    name: t.Optional[str] = None
+    description: t.Optional[str] = None
+
+
+class SubTree(BaseModel):
+    """A sub-tree of a v1 tree, identified by the ID of its root node.
+
+    Payloads may carry the legacy ``priority`` key instead of ``order``; it is
+    translated before validation and stays readable through :attr:`priority`.
+    """
+
+    id: t.Optional[str] = None
+    name: t.Optional[str] = None
+    order: int  # Frontend uses "order" instead of "priority"
+    rootNodeId: str
+    isActive: t.Optional[bool] = None
+    hidden: t.Optional[bool] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def translate_priority_to_order(cls, values: t.Any) -> t.Any:
+        """Convert a legacy ``priority`` key to ``order`` for backward compatibility.
+
+        Args:
+            values: Raw input handed to the model before field validation.
+
+        Returns:
+            ``values`` unchanged when it is not a dict or has no ``priority``
+            key; otherwise a shallow copy without ``priority`` in which
+            ``order`` is filled from ``priority`` if it was missing.
+        """
+        if isinstance(values, dict) and "priority" in values:
+            # Copy before touching anything so the caller's payload is never
+            # mutated, including when both "priority" and "order" are present.
+            values = values.copy()
+            legacy_priority = values.pop("priority")
+            # An explicit "order" always wins over the legacy key.
+            values.setdefault("order", legacy_priority)
+        return values
+
+    @property
+    def priority(self) -> int:
+        """Alias of ``order`` kept for compatibility with older versions."""
+        return self.order
+
+
+class Tree(BaseExecuteModule):
+    """Version 1 tree definition (``type == "v1-tree"``).
+
+    ``upgrade`` converts the tree to the v2 format. ``execute`` upgrades
+    v1 -> v2 -> v3 and evaluates the resulting expression on the ``"input"``
+    frame.
+    """
+
+    type: t.Literal["v1-tree"]
+    name: str = "output"
+    # Feature names; nodes refer to them by index (split_feature_id).
+    features: t.List[str]
+    nodes: t.List[PositionedNode]
+    metadata: TreeMetadata | None = None
+    tree_output: TreeOutput | None = Field(alias="treeOutput", default=None)
+    edges: t.List[MultiSourceEdge]
+    subtrees: t.List[SubTree] = Field(default_factory=list)
+    variables: VariableMap = Field(
+        default_factory=dict, description="Variable definitions keyed by ID"
+    )
+    output_schema: t.Optional[OutputSchema] = Field(
+        alias="outputSchema",
+        default=None,
+        description="Structured output schema with global types support",
+    )
+    node_output_format_version: int = Field(
+        alias="nodeOutputFormatVersion",
+        default=-1,
+        description="Output format version: 0=legacy table, 1=global types schema -1 Try infer",
+    )
+    format_version: t.Literal[1] = Field(alias="formatVersion", default=1)
+    # Passed to the v2 tree as ``parameters_col``.
+    variable_input_name: str = "variables"
+
+    def upgrade(self):
+        """Upgrade to the latest v2 tree format.
+
+        Steps: collect the leaf outputs (legacy table or schema based),
+        derive the output dtypes, convert the nodes, swap the edges of
+        categorical nodes that used ``category_list_right_child``, build the
+        input schema and turn v1 variables into v2 parameters.
+
+        Returns:
+            The ``V2Tree`` built from this tree.
+        """
+        from ..v2.tree import (
+            Tree as V2Tree,
+            SubTree as V2SubTree,
+            TreeOutput as V2TreeOutput,
+        )
+        # NOTE(review): PolarsSchema is not referenced anywhere in this method.
+        from .....serializable.schema import PolarsSchema
+
+        # Used below to drop subtrees whose root node does not exist.
+        keyed_nodes = {n.id: n for n in self.nodes}
+
+        # Determine which output collection strategy to use
+        # (-1 means "infer": fall back to v0 when a legacy treeOutput table is present)
+        use_v0 = self.node_output_format_version == 0 or (
+            self.node_output_format_version == -1 and self.tree_output is not None
+        )
+
+        if use_v0:
+            collected_output, default_value, node_id_output_map = _collect_outputs_v0(
+                self.nodes, self.tree_output
+            )
+            # v0 format: infer schema from collected output data
+            # If there's data, infer from first row; otherwise use empty struct
+            _schema_collected_output = (
+                collected_output + [default_value]
+                if default_value
+                else collected_output
+            )
+            if _schema_collected_output and len(_schema_collected_output) > 0:
+                from .....serializable.schema import convert_schema
+                import polars as pl
+
+                # Infer schema from the first output row
+                temp_df = pl.DataFrame([_schema_collected_output[0]])
+                dtypes_struct = convert_schema(temp_df.schema)
+            else:
+                dtypes_struct = {}  # Empty struct
+            # The legacy table format carries no custom type definitions.
+            type_defs = {}
+        else:
+            from .....serializable.schema import ExplicitType
+
+            collected_output, default_value, node_id_output_map = _collect_outputs_v1(
+                self.nodes, self.output_schema
+            )
+            if self.output_schema:
+                dtypes_struct = _convert_schema(self.output_schema)
+                # Extract custom type definitions from the output schema
+                type_defs = _convert_custom_types_to_defs(self.output_schema)
+            else:
+                # No schema: infer from collected output data
+                if collected_output and len(collected_output) > 0:
+                    from .....serializable.schema import convert_schema
+                    import polars as pl
+
+                    temp_df = pl.DataFrame([collected_output[0]])
+                    dtypes_struct = convert_schema(temp_df.schema)
+                else:
+                    dtypes_struct = {}  # Empty struct
+                type_defs = {}
+
+            # For fields typed as a custom "struct" type, replace dict values
+            # with a {"$key": <display_field value>} reference.
+            for field_key, field_dtype in dtypes_struct.items():
+                if (
+                    isinstance(field_dtype, ExplicitType)
+                    and field_dtype.type == "Custom"
+                ):
+                    custom_type = type_defs[field_dtype.type_id]
+                    if custom_type.type == "struct":
+                        for row in collected_output:
+                            # Handle case where value is already an index (int) or a dict
+                            field_value = row[field_key]
+                            if isinstance(field_value, dict):
+                                row[field_key] = {
+                                    "$key": field_value[
+                                        custom_type.definition.display_field
+                                    ]
+                                }
+                            # If it's already an int (index), leave it as is
+
+        # Upgrade nodes and get the set of nodes that need edge swapping
+        upgraded_nodes, nodes_needing_edge_swap = _upgrade_nodes(
+            self.nodes, self.features, self.variables, node_id_output_map
+        )
+
+        # Swap edges for categorical nodes with category_list_right_child=True
+        # In v2, matching categories always go to sourceIndex 0, non-matching to sourceIndex 1
+        # In v1 with category_list_right_child=True, it was reversed
+        upgraded_edges = []
+        for edge in self.edges:
+            if edge.source in nodes_needing_edge_swap:
+                # Swap sourceIndex: 0 <-> 1
+                swapped_indices = [1 - idx for idx in edge.data.sourceIndex]
+                upgraded_edges.append(
+                    MultiSourceEdge(
+                        id=edge.id,
+                        source=edge.source,
+                        target=edge.target,
+                        data=MultiEdgeData(sourceIndex=swapped_indices),
+                    )
+                )
+            else:
+                upgraded_edges.append(edge)
+
+        # Every feature starts out as float32; nodes may override the dtype below.
+        input_schema = [[ft, "float32"] for ft in self.features]
+
+        # NOTE(review): input_schema is a list of [name, dtype] lists, so the
+        # string membership tests below are always true. Every node therefore
+        # overwrites its feature's entry and the last node referencing a
+        # feature decides its dtype. Confirm whether that is intended.
+        for node in self.nodes:
+            if isinstance(node.data, NumericalNode):
+                feature_name = self.features[node.data.split_feature_id]
+                if feature_name not in input_schema:
+                    input_schema[node.data.split_feature_id] = [feature_name, "float32"]
+            elif isinstance(node.data, CategoricalNode):
+                feature_name = self.features[node.data.split_feature_id]
+                if feature_name not in input_schema:
+                    input_schema[node.data.split_feature_id] = [feature_name, "string"]
+            elif isinstance(node.data, StringMatchNode):
+                feature_name = self.features[node.data.split_feature_id]
+                if feature_name not in input_schema:
+                    input_schema[node.data.split_feature_id] = [feature_name, "string"]
+
+        # Convert v1 variables to v2 parameters
+        from ...common.parameters import ParameterInfo
+        from .....serializable.schema import PrimitiveSchema
+
+        # Parameters are keyed by variable name, not by the v1 variable ID.
+        parameters = {}
+        if self.variables:
+            for var in self.variables.values():
+                # Map v1 variable types to v2 parameter types
+                if var.var_type == "numeric":
+                    param_type = PrimitiveSchema(type="Float64")
+                elif var.var_type == "string":
+                    param_type = PrimitiveSchema(type="String")
+                else:
+                    logger.warning(
+                        f"Unsupported variable type '{var.var_type}' for variable '{var.name}'; defaulting to String"
+                    )
+                    param_type = PrimitiveSchema(type="String")
+
+                parameters[var.name] = ParameterInfo(
+                    type=param_type, default_value=var.value
+                )
+
+        return V2Tree(
+            metadata=self.metadata and self.metadata.model_dump(mode="python"),
+            edges=upgraded_edges,
+            nodes=upgraded_nodes,
+            input_schema=input_schema,
+            # Subtrees are emitted in ascending ``order`` and keyed by their
+            # root node ID; entries whose root node is missing are dropped.
+            subtrees=[
+                V2SubTree(id=s.rootNodeId, name=s.name)
+                for s in (
+                    sorted(self.subtrees, key=lambda st: st.order)
+                    if self.subtrees
+                    else []
+                )
+                if s.rootNodeId in keyed_nodes
+            ],
+            output=V2TreeOutput(
+                data=collected_output,
+                default=default_value,
+                dtypes=dtypes_struct,
+                type_defs=type_defs,
+            ),
+            parameters=parameters,
+            parameters_col=self.variable_input_name,
+        )
+
+    def execute(self, inputs: TInputType, _executor: "Executor") -> TOutputType:
+        """Evaluate the tree on ``inputs["input"]``.
+
+        A ``LazyFrame`` input is collected first. The tree is upgraded
+        v1 -> v2 -> v3, compiled into an expression, and the unnested struct
+        result is selected and returned as a lazy frame. ``_executor`` is not
+        used.
+        """
+        frame = inputs["input"]
+        if isinstance(frame, pl.LazyFrame):
+            frame = frame.collect()
+        v2 = self.upgrade()
+        v3 = v2.upgrade()
+        compiled = v3.to_tree_module().build_expression()
+        return frame.select(compiled.expr.struct.unnest()).lazy()
diff --git a/decider/modules/rules/tree/v1/variables.py b/decider/modules/rules/tree/v1/variables.py
new file mode 100644
index 0000000..1a4d6ba
--- /dev/null
+++ b/decider/modules/rules/tree/v1/variables.py
@@ -0,0 +1,71 @@
+"""Variable types for placeholder support in tree nodes."""
+
+import typing as t
+from pydantic import BaseModel, Field, BeforeValidator
+from enum import Enum
+
+
+class VariableType(str, Enum):
+    """Supported variable types matching frontend VariableType."""
+
+    STRING = "string"
+    NUMERIC = "numeric"
+    # No placeholder-variable class in this module is declared for this type.
+    NUMERIC_EXPR = "numeric_expr"
+
+
+_VT = t.TypeVar("_VT", bound=VariableType)
+_T = t.TypeVar("_T")
+
+
+def _normalize_var_type(value: t.Any) -> t.Any:
+    """Normalize 'type' field to 'var_type' for discriminated union compatibility.
+
+    Accepts either a single variable dict or a VariableMap-style dict whose
+    values are variable dicts. The input is never mutated; renamed entries
+    are rebuilt as new dicts.
+
+    Args:
+        value: Raw value handed to the validator.
+
+    Returns:
+        ``value`` itself when it is not a dict; otherwise a dict in which
+        every variable that had ``type`` but no ``var_type`` carries the
+        value under ``var_type`` instead.
+    """
+
+    def _renamed(var: t.Any) -> t.Any:
+        """Return ``var`` with ``type`` renamed to ``var_type`` when applicable."""
+        if isinstance(var, dict) and "type" in var and "var_type" not in var:
+            # Rebuild the dict so the caller's object stays untouched.
+            var = dict(var)
+            var["var_type"] = var.pop("type")
+        return var
+
+    if not isinstance(value, dict):
+        return value
+
+    # A single variable stores a scalar under "type". In a VariableMap the
+    # key "type" can only be a variable ID, and then its value is a dict.
+    if "type" in value and not isinstance(value["type"], dict):
+        return _renamed(value)
+
+    # VariableMap dict (keys are IDs, values are variables)
+    return {var_id: _renamed(var) for var_id, var in value.items()}
+
+
+class _BasePlaceHolderVariable(BaseModel, t.Generic[_VT, _T]):
+    """Represents a placeholder variable reference in tree nodes.
+
+    This matches the frontend VariableInfo structure but simplified for backend use.
+    Generic over the variable-type literal (``_VT``) and the value type (``_T``).
+    """
+
+    id: t.Optional[str] = Field(description="Variable ID", default=None)
+    # Names must start with a lowercase letter and contain only a-z, 0-9 and "_".
+    name: str = Field(description="Variable name", pattern="^[a-z][a-z0-9_]*$")
+    var_type: _VT = Field(description="Variable type")
+    value: _T = Field(description="Variable value")
+
+    def __str__(self) -> str:
+        """Render the variable as ``#<name>``."""
+        return f"#{self.name}"
+
+
+class NumericVariable(
+    _BasePlaceHolderVariable[t.Literal[VariableType.NUMERIC], t.Union[int, float]]
+):
+    """Placeholder variable of type ``numeric`` holding an int or float value."""
+
+    pass
+
+
+class StringVariable(_BasePlaceHolderVariable[t.Literal[VariableType.STRING], str]):
+    """Placeholder variable of type ``string`` holding a str value."""
+
+    pass
+
+
+PlaceHolderVariable = t.Annotated[
+ t.Union[NumericVariable, StringVariable], Field(discriminator="var_type")
+]
+
+# NOTE: The dict keys are variable IDs, which are what the nodes store.
+# Keeping the ID separate from the variable name means a variable can be renamed without updating every node that references it.
+VariableMap = t.Annotated[
+ t.Dict[str, PlaceHolderVariable], BeforeValidator(_normalize_var_type)
+]
diff --git a/spockflow/components/tree/v1/__init__.py b/decider/modules/rules/tree/v2/__init__.py
similarity index 100%
rename from spockflow/components/tree/v1/__init__.py
rename to decider/modules/rules/tree/v2/__init__.py
diff --git a/decider/modules/rules/tree/v2/nodes.py b/decider/modules/rules/tree/v2/nodes.py
new file mode 100644
index 0000000..39b57f6
--- /dev/null
+++ b/decider/modules/rules/tree/v2/nodes.py
@@ -0,0 +1,230 @@
+"""Data layer node types with variable support."""
+
+import typing as t
+from uuid import UUID
+import typing_extensions as t_ext
+from logging import getLogger
+from pydantic import BaseModel, Field, model_validator
+from ...common.nodetypes import RangeEndLogic
+from ...common.shared import InputRef
+
+logger = getLogger(__name__)
+
+
+class RangeNode(BaseModel):
+    """Range test node: splits ``feature`` using a list of thresholds."""
+
+    NODE_TYPE: t.ClassVar[str] = "numerical_range_test_node"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    feature: str
+    # Either one variable reference for the whole value, or a list in which
+    # each threshold is a literal float or a variable reference
+    thresholds: t.Union[InputRef, t.List[t.Union[float, InputRef]]] = Field(
+        default_factory=list,
+        description="Threshold values (floats or variable references)",
+    )
+    end_logic: RangeEndLogic = RangeEndLogic.lower_inclusive
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Return the keys of all variable references used as thresholds."""
+        entries = (
+            self.thresholds if isinstance(self.thresholds, list) else [self.thresholds]
+        )
+        return {entry.key for entry in entries if isinstance(entry, InputRef)}
+
+    def to_v3_node(self) -> "t.Any":
+        """Convert v2 RangeNode to v3 CasesRanges."""
+        from ..v3.nodes_ui import CasesRanges, RangeCondition
+
+        if isinstance(self.thresholds, list):
+            bounds = self.thresholds
+        else:
+            bounds = [self.thresholds]
+
+        # Condition i runs from bounds[i-1] to bounds[i]; the first condition
+        # has no lower bound.
+        lower_bounds = [None, *bounds[:-1]]
+        range_conditions = [
+            RangeCondition(min=lower, max=upper)
+            for lower, upper in zip(lower_bounds, bounds)
+        ]
+
+        return CasesRanges(
+            feature=self.feature,
+            conditions=range_conditions,
+            end_logic=self.end_logic,
+            strict=False,
+        )
+
+
+class NumericalNode(BaseModel):
+    """Numerical test node: compares ``feature`` with a single threshold."""
+
+    NODE_TYPE: t.ClassVar[str] = "numerical_test_node"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    feature: str
+    comparison_op: t.Literal["<=", "<", "==", ">", ">=", "!="] = "<="
+    threshold: t.Union[float, InputRef] = Field(
+        description="Direct threshold value or a variable reference",
+    )
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Return the referenced parameter key, or an empty set for a literal."""
+        return {self.threshold.key} if isinstance(self.threshold, InputRef) else set()
+
+    def to_v3_node(self) -> "t.Any":
+        """Convert v2 NumericalNode to v3 UnaryNode."""
+        from ..v3.nodes_ui import (
+            UnaryNode,
+            UnaryLeq,
+            UnaryLt,
+            UnaryEq,
+            UnaryGt,
+            UnaryGeq,
+            UnaryNeq,
+        )
+
+        # One v3 unary operator class per comparison symbol
+        operator_by_symbol = {
+            "<=": UnaryLeq,
+            "<": UnaryLt,
+            "==": UnaryEq,
+            ">": UnaryGt,
+            ">=": UnaryGeq,
+            "!=": UnaryNeq,
+        }
+        operator_cls = operator_by_symbol[self.comparison_op]
+
+        return UnaryNode(
+            condition=operator_cls(feature=self.feature, threshold=self.threshold)
+        )
+
+
+class CategoricalNode(BaseModel):
+    """Categorical test node: checks ``feature`` against a list of categories."""
+
+    NODE_TYPE: t.ClassVar[str] = "categorical_test_node"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    feature: str
+
+    category_list: t.Union[InputRef, t.List[t.Union[int, float]]] = Field(
+        default_factory=list, description="Direct category values"
+    )
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Return the keys of any variable references found in ``category_list``."""
+        if isinstance(self.category_list, list):
+            candidates = self.category_list
+        else:
+            candidates = [self.category_list]
+        return {item.key for item in candidates if isinstance(item, InputRef)}
+
+    def to_v3_node(self) -> "t.Any":
+        """Convert v2 CategoricalNode to v3 UnaryNode with IsInOp."""
+        from ..v3.nodes_ui import UnaryNode, IsInOp
+
+        # The category list (or variable reference) is handed over unchanged.
+        return UnaryNode(
+            condition=IsInOp(feature=self.feature, values=self.category_list)
+        )
+
+
+class StringMatchNode(BaseModel):
+    """String match node: tests ``feature`` against one or more patterns."""
+
+    NODE_TYPE: t.ClassVar[str] = "string_match_node"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    feature: str
+
+    # Either a literal list of patterns or a single variable reference
+    patterns: t.Union[InputRef, t.List[str]] = Field(
+        default_factory=list, description="Direct pattern values"
+    )
+    match_type: t.Literal["exact", "starts_with", "contains", "ends_with", "regex"] = (
+        "exact"
+    )
+    case_sensitive: bool = True
+    match_any: bool = True
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Return the parameter key referenced by ``patterns``, if any.
+
+        Returns:
+            A one-element set when ``patterns`` is a variable reference,
+            otherwise an empty set.
+        """
+        if isinstance(self.patterns, InputRef):
+            return {self.patterns.key}
+        # ``{}`` is an empty dict; the declared return type is a set.
+        return set()
+
+    def to_v3_node(self) -> "t.Any":
+        """Convert v2 StringMatchNode to v3 UnaryNode or CasesStringMatch.
+
+        Returns:
+            A ``UnaryNode`` wrapping a ``StringMatchOp`` when ``match_any`` is
+            true, otherwise a ``CasesStringMatch`` with one condition per
+            pattern.
+        """
+        from ..v3.nodes_ui import (
+            UnaryNode,
+            StringMatchOp,
+            CasesStringMatch,
+            StringMatchCondition,
+        )
+        from ...common.nodetypes import TStringMatchType
+
+        # Normalize: old V2 allowed a bare InputRef; new unified schema uses List[Union[str, InputRef]]
+        pattern_list = (
+            self.patterns if isinstance(self.patterns, list) else [self.patterns]
+        )
+
+        if self.match_any:
+            # match_any=True: Use UnaryNode with StringMatchOp
+            condition = StringMatchOp(
+                feature=self.feature,
+                patterns=pattern_list,
+                match_type=TStringMatchType(self.match_type),
+                case_sensitive=self.case_sensitive,
+            )
+            return UnaryNode(condition=condition)
+
+        # match_any=False: Use CasesStringMatch for separate branches per pattern
+        # NOTE(review): a variable reference is passed to StringMatchCondition
+        # bare rather than wrapped in a list -- confirm that schema accepts it.
+        conditions = [
+            StringMatchCondition(patterns=[p] if isinstance(p, str) else p)
+            for p in pattern_list
+        ]
+        return CasesStringMatch(
+            feature=self.feature,
+            match_type=TStringMatchType(self.match_type),
+            case_sensitive=self.case_sensitive,
+            conditions=conditions,
+        )
+
+class LeafNode(BaseModel):
+    """Leaf node."""
+
+    NODE_TYPE: t.ClassVar[str] = "leaf"
+    node_type: t.Literal[NODE_TYPE] = NODE_TYPE  # type: ignore[valid-type]
+    # Passed to the v3 leaf as ``result_idx``; defaults to -1.
+    leaf_value: int = -1
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Leaves reference no parameters; always returns an empty set."""
+        return set()
+
+    def to_v3_node(self) -> "t.Any":
+        """Convert v2 LeafNode to v3 LeafNode."""
+        from ..v3.nodes_ui import LeafNode as V3LeafNode
+
+        return V3LeafNode(result_idx=self.leaf_value)
+
+
+# Union type for all data layer nodes
+NodeData = t_ext.Annotated[
+ t.Union[RangeNode, NumericalNode, CategoricalNode, StringMatchNode, LeafNode],
+ Field(discriminator="node_type"),
+]
+
+
+class Position(BaseModel):
+    """2-D coordinates of a node -- presumably its place on the editor canvas; TODO confirm."""
+
+    x: float
+    y: float
+
+
+class PositionedNode(BaseModel):
+    """A v2 node payload together with its ID and position."""
+
+    id: str
+    # Defaults to the origin (0, 0) when no position is supplied.
+    position: Position = Field(default_factory=lambda: Position(x=0, y=0))
+    # Concrete node type is selected by the ``node_type`` discriminator.
+    data: NodeData
diff --git a/decider/modules/rules/tree/v2/tree.py b/decider/modules/rules/tree/v2/tree.py
new file mode 100644
index 0000000..cc21e8a
--- /dev/null
+++ b/decider/modules/rules/tree/v2/tree.py
@@ -0,0 +1,113 @@
+import typing as t
+import re
+import polars as pl
+from pydantic import BaseModel, Field, PrivateAttr, model_validator
+from .nodes import NodeData, PositionedNode
+from ..v1.edges import MultiSourceEdge
+from logging import getLogger
+from ...common.shared import WithTreeOutput, TreeOutput
+from .....serializable.schema import PolarsSchema
+from ...common.parameters import WithParameters
+from decider.modules.core import BaseExecuteModule
+from decider.types import TInputType, TOutputType
+
+if t.TYPE_CHECKING:
+ from ..v3.tree import Tree as V3Tree
+ from decider.executor import Executor
+
+
+logger = getLogger(__name__)
+
+
+class TreeMetadata(BaseModel):
+    """Optional descriptive metadata for a v2 tree; both fields may be omitted."""
+
+    name: t.Optional[str] = None
+    description: t.Optional[str] = None
+
+
+class SubTree(BaseModel):
+    """Subtree identifier for a v2 tree; both fields are optional."""
+
+    id: t.Optional[str] = None
+    name: t.Optional[str] = None
+
+
+class Tree(WithTreeOutput, BaseExecuteModule, WithParameters):
+    """Version 2 tree definition (``type == "v2-tree"``).
+
+    Nodes and edges are stored separately. The tree is evaluated by
+    upgrading it to the v3 format first.
+    """
+
+    type: t.Literal["v2-tree"]
+    name: str = "output"
+    metadata: TreeMetadata | None = None
+    edges: t.List[MultiSourceEdge]
+    nodes: t.List[PositionedNode]
+    subtrees: t.List[SubTree] = Field(default_factory=list)
+    input_schema: t.Optional[PolarsSchema] = Field(
+        default=None, description="Input schema for casting inputs at runtime"
+    )
+
+    format_version: t.Literal[2] = Field(alias="formatVersion", default=2)
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Return the union of the parameter keys required by all nodes."""
+        return set().union(
+            *(entry.data.get_required_parameters() for entry in self.nodes)
+        )
+
+    def to_tree_module(self) -> "t.Any":
+        """Convert to FlatRuleModule via v3 upgrade.
+
+        Kept for backward compatibility with tests; the conversion path is
+        v2 → v3 → flat rules → FlatRuleModule.
+        """
+        return self.upgrade().to_tree_module()
+
+    def upgrade(self) -> "V3Tree":
+        """Upgrade v2 tree to v3 format.
+
+        Each node is converted through its own ``to_v3_node()``. Edges,
+        output, parameters and the input schema are carried over unchanged.
+        """
+        from ..v3.tree import Tree as V3Tree
+        from ..v3.tree import TreeMetadata as V3TreeMetadata
+        from ..v3.tree import SubTree as V3SubTree
+        from ..v3.nodes_ui import PositionedNode as V3PositionedNode
+        from ..v3.nodes_ui import Position as V3Position
+
+        def _as_v3(entry: PositionedNode) -> "t.Any":
+            """Wrap the converted node data in a v3 positioned node."""
+            converted = entry.data.to_v3_node()
+            return V3PositionedNode(
+                id=entry.id,
+                position=V3Position(x=entry.position.x, y=entry.position.y),
+                data=converted,
+            )
+
+        nodes_v3 = [_as_v3(entry) for entry in self.nodes]
+
+        # Metadata is copied field by field when present
+        metadata_v3 = (
+            V3TreeMetadata(
+                name=self.metadata.name,
+                description=self.metadata.description,
+            )
+            if self.metadata
+            else None
+        )
+
+        subtrees_v3 = [V3SubTree(id=sub.id, name=sub.name) for sub in self.subtrees]
+
+        # Edges remain the same format (MultiSourceEdge)
+        return V3Tree(
+            metadata=metadata_v3,
+            edges=self.edges,
+            nodes=nodes_v3,
+            subtrees=subtrees_v3,
+            output=self.output,
+            parameters=self.parameters,
+            parameters_col=self.parameters_col,
+            input_schema=self.input_schema,
+        )
+
+    def execute(self, inputs: TInputType, _executor: "Executor") -> TOutputType:
+        """Evaluate the tree on ``inputs["input"]`` and return a lazy frame.
+
+        A ``LazyFrame`` input is collected before the compiled expression is
+        applied. ``_executor`` is not used.
+        """
+        data = inputs["input"]
+        if isinstance(data, pl.LazyFrame):
+            data = data.collect()
+        expression = self.upgrade().to_tree_module().build_expression()
+        return data.select(expression.expr.struct.unnest()).lazy()
diff --git a/decider/modules/rules/tree/v3/__init__.py b/decider/modules/rules/tree/v3/__init__.py
new file mode 100644
index 0000000..9c64da0
--- /dev/null
+++ b/decider/modules/rules/tree/v3/__init__.py
@@ -0,0 +1,94 @@
+"""V3 Tree nodes following flat rules conventions.
+
+Unified type system that shares core types with flat_rules.
+Nodes are managed via edges (tree.py), not embedded in node definitions.
+"""
+
+from ...common.nodetypes import (
+ BaseRule,
+ NodeMeta,
+ NodePosition,
+)
+
+from .nodes_ui import (
+ # Node types
+ LeafNode,
+ UnaryNode,
+ CasesRanges,
+ CasesStringMatch,
+ CasesIsIn,
+ CompositeNode,
+ NodeData,
+ PositionedNode,
+ Position,
+ # Unary operators (aliased in nodes_ui from common.nodes)
+ UnaryLeq,
+ UnaryLt,
+ UnaryEq,
+ UnaryGt,
+ UnaryGeq,
+ UnaryNeq,
+ BetweenOp,
+ IsInOp,
+ StringMatchOp,
+ IsNullOp,
+ IsNotNullOp,
+ IsTrueOp,
+ IsFalseOp,
+)
+
+# TUnaryOp and condition types live in common.nodes
+from ...common.nodes import (
+ TUnaryOp,
+ RangeCondition,
+ StringMatchCondition,
+ IsInCondition,
+ CompositeCondition,
+)
+
+from .tree import (
+ Tree,
+ TreeMetadata,
+ SubTree,
+)
+
+__all__ = [
+ # Base types
+ "BaseRule",
+ "NodeMeta",
+ "NodePosition",
+ # Nodes
+ "LeafNode",
+ "UnaryNode",
+ "CasesRanges",
+ "CasesStringMatch",
+ "CasesIsIn",
+ "CompositeNode",
+ "NodeData",
+ "PositionedNode",
+ "Position",
+ # Unary operators
+ "UnaryLeq",
+ "UnaryLt",
+ "UnaryEq",
+ "UnaryGt",
+ "UnaryGeq",
+ "UnaryNeq",
+ "BetweenOp",
+ "IsInOp",
+ "StringMatchOp",
+ "IsNullOp",
+ "IsNotNullOp",
+ "IsTrueOp",
+ "IsFalseOp",
+ "TUnaryOp",
+ # Conditions
+ "RangeCondition",
+ "StringMatchCondition",
+ "IsInCondition",
+ "CompositeCondition",
+ # Tree structure
+ "Tree",
+ "TreeMetadata",
+ "SubTree",
+]
diff --git a/decider/modules/rules/tree/v3/nodes_ui.py b/decider/modules/rules/tree/v3/nodes_ui.py
new file mode 100644
index 0000000..82f0016
--- /dev/null
+++ b/decider/modules/rules/tree/v3/nodes_ui.py
@@ -0,0 +1,307 @@
+"""V3 Tree nodes — UI/tree representation.
+
+All structural definitions (operators, conditions, base nodes) live in
+decider.modules.rules.common.nodes. This module adds the to_flat_rule_node()
+methods that convert the edge-based tree representation to flat rules.
+"""
+
+import typing as t
+from pydantic import BaseModel, Field, RootModel
+
+from ...common.nodetypes import (
+ LeafNodeCore,
+ BaseRule,
+ NodeMeta,
+ NodePosition,
+ TLogicOp,
+ RangeEndLogic,
+)
+from ...common.nodes import (
+ TUnaryOp,
+ RangeCondition,
+ StringMatchCondition,
+ IsInCondition,
+ CasesBranch,
+ TCaseCondition,
+ CompositeCondition,
+ TCondition,
+ BaseUnaryNode,
+ BaseCasesRanges,
+ BaseCasesStringMatch,
+ BaseCasesIsIn,
+ BaseCompositeNode,
+ # Re-export operators so v3/__init__.py keeps working
+ UnaryLessThanEqual as UnaryLeq,
+ UnaryLessThan as UnaryLt,
+ UnaryEqual as UnaryEq,
+ UnaryGreaterThan as UnaryGt,
+ UnaryGreaterThanEqual as UnaryGeq,
+ UnaryNotEqual as UnaryNeq,
+ UnaryBetween as BetweenOp,
+ UnaryIsIn as IsInOp,
+ UnaryStringMatch as StringMatchOp,
+ UnaryIsNull as IsNullOp,
+ UnaryIsNotNull as IsNotNullOp,
+ UnaryIsTrue as IsTrueOp,
+ UnaryIsFalse as IsFalseOp,
+)
+
+if t.TYPE_CHECKING:
+ from ...flat_rules.nodes import RuleType
+
+
+# =============================================================================
+# Leaf Node
+# =============================================================================
+
+
+class LeafNode(LeafNodeCore):
+    """Leaf node for tree v3."""
+
+    NODE_TYPE: t.ClassVar[str] = "leaf"
+    id: t.Optional[str] = Field(default=None)
+
+    def to_flat_rule_node(
+        self, node_id: str, get_child: t.Callable[[int], "RuleType"]
+    ) -> "RuleType":
+        """Convert to a flat ``LeafRule`` carrying the same ``result_idx``.
+
+        ``node_id`` and ``get_child`` are accepted for signature parity with
+        the other node types and are not used.
+        """
+        from ...flat_rules.nodes import LeafRule
+
+        return LeafRule(result_idx=self.result_idx)
+
+    def get_required_features(self) -> t.Set[str]:
+        """Leaves read no features; always returns an empty set."""
+        return set()
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Leaves reference no parameters; always returns an empty set."""
+        return set()
+
+
+# =============================================================================
+# Unary Node
+# =============================================================================
+
+
+class UnaryNode(BaseUnaryNode):
+    """Node with a single condition; its children come from the graph edges.
+
+    Edge mapping:
+    - sourceIndex=0 -> 'then' branch (condition true)
+    - sourceIndex=1 -> 'otherwise' branch (condition false)
+    """
+
+    NODE_TYPE: t.ClassVar[str] = "unary"
+
+    def to_flat_rule_node(
+        self, node_id: str, get_child: t.Callable[[int], "RuleType"]
+    ) -> "RuleType":
+        """Build the flat ``UnaryRule``, resolving both branches via ``get_child``."""
+        from ...flat_rules.nodes import UnaryRule
+
+        when_true = get_child(0)
+        when_false = get_child(1)
+        return UnaryRule(
+            id=node_id,
+            condition=self.condition,
+            then=when_true,
+            otherwise=when_false,
+        )
+
+
+# =============================================================================
+# Cases Nodes
+# =============================================================================
+
+
+class CasesRanges(BaseCasesRanges):
+    """Multi-way branching on ranges; its children come from the graph edges.
+
+    Edge mapping:
+    - sourceIndex=0..N-1 -> branch for conditions[0..N-1]
+    - sourceIndex=N -> 'otherwise' branch
+    """
+
+    NODE_TYPE: t.ClassVar[str] = "cases"
+
+    def to_flat_rule_node(
+        self, node_id: str, get_child: t.Callable[[int], "RuleType"]
+    ) -> "RuleType":
+        """Build the flat cases rule; branch ``i`` is the child at sourceIndex ``i``."""
+        from ...flat_rules.nodes import (
+            CasesRanges as FlatCasesRanges,
+            CasesBranch as FlatCasesBranch,
+        )
+
+        branch_conditions = [
+            FlatCasesBranch(when=condition, then=position)
+            for position, condition in enumerate(self.conditions)
+        ]
+        # The fallback branch sits right after the last condition branch.
+        fallback_idx = len(self.conditions)
+        children = [get_child(position) for position in range(fallback_idx + 1)]
+
+        return FlatCasesRanges(
+            id=node_id,
+            feature=self.feature,
+            conditions=branch_conditions,
+            end_logic=self.end_logic,
+            strict=self.strict,
+            otherwise=fallback_idx,
+            branches=children,
+        )
+
+
+class CasesStringMatch(BaseCasesStringMatch):
+    """Multi-way branching on string matches; its children come from the graph edges.
+
+    Edge mapping:
+    - sourceIndex=0..N-1 -> branch for conditions[0..N-1]
+    - sourceIndex=N -> 'otherwise' branch
+    """
+
+    NODE_TYPE: t.ClassVar[str] = "cases"
+
+    def to_flat_rule_node(
+        self, node_id: str, get_child: t.Callable[[int], "RuleType"]
+    ) -> "RuleType":
+        """Build the flat cases rule; branch ``i`` is the child at sourceIndex ``i``."""
+        from ...flat_rules.nodes import (
+            CasesStringMatch as FlatCasesStringMatch,
+            CasesBranch as FlatCasesBranch,
+        )
+
+        branch_conditions = [
+            FlatCasesBranch(when=condition, then=position)
+            for position, condition in enumerate(self.conditions)
+        ]
+        # The fallback branch sits right after the last condition branch.
+        fallback_idx = len(self.conditions)
+        children = [get_child(position) for position in range(fallback_idx + 1)]
+
+        return FlatCasesStringMatch(
+            id=node_id,
+            feature=self.feature,
+            match_type=self.match_type,
+            case_sensitive=self.case_sensitive,
+            trim_whitespace=self.trim_whitespace,
+            conditions=branch_conditions,
+            otherwise=fallback_idx,
+            branches=children,
+        )
+
+
+class CasesIsIn(BaseCasesIsIn):
+    """Multi-way branching on category membership; its children come from the graph edges.
+
+    Edge mapping:
+    - sourceIndex=0..N-1 -> branch for conditions[0..N-1]
+    - sourceIndex=N -> 'otherwise' branch
+    """
+
+    NODE_TYPE: t.ClassVar[str] = "cases"
+
+    def to_flat_rule_node(
+        self, node_id: str, get_child: t.Callable[[int], "RuleType"]
+    ) -> "RuleType":
+        """Build the flat cases rule; branch ``i`` is the child at sourceIndex ``i``."""
+        from ...flat_rules.nodes import (
+            CasesIsIn as FlatCasesIsIn,
+            CasesBranch as FlatCasesBranch,
+        )
+
+        branch_conditions = [
+            FlatCasesBranch(when=condition, then=position)
+            for position, condition in enumerate(self.conditions)
+        ]
+        # The fallback branch sits right after the last condition branch.
+        fallback_idx = len(self.conditions)
+        children = [get_child(position) for position in range(fallback_idx + 1)]
+
+        return FlatCasesIsIn(
+            id=node_id,
+            feature=self.feature,
+            conditions=branch_conditions,
+            otherwise=fallback_idx,
+            branches=children,
+        )
+
+
+_TCasesVariant = t.Annotated[
+ t.Union[CasesRanges, CasesStringMatch, CasesIsIn],
+ Field(discriminator="op"),
+]
+
+
+class CasesNode(RootModel[_TCasesVariant]):
+    """Wrapper for all Cases node variants (discriminated by 'op' field).
+
+    Every property and method below forwards to the wrapped variant in
+    ``root``, except ``type``, which is the constant ``"cases"``.
+    """
+
+    root: _TCasesVariant
+
+    @property
+    def type(self) -> str:
+        """Node type tag; always ``"cases"``."""
+        return "cases"
+
+    @property
+    def id(self) -> t.Optional[str]:
+        """ID of the wrapped variant."""
+        return self.root.id
+
+    @property
+    def feature(self):
+        """Feature of the wrapped variant."""
+        return self.root.feature
+
+    @property
+    def op(self) -> str:
+        """Discriminator value of the wrapped variant."""
+        return self.root.op
+
+    def to_flat_rule_node(
+        self, node_id: str, get_child: t.Callable[[int], "RuleType"]
+    ) -> "RuleType":
+        """Delegate the flat-rule conversion to the wrapped variant."""
+        return self.root.to_flat_rule_node(node_id, get_child)
+
+    def get_required_features(self) -> t.Set[str]:
+        """Delegate to the wrapped variant."""
+        return self.root.get_required_features()
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Delegate to the wrapped variant."""
+        return self.root.get_required_parameters()
+
+
+# =============================================================================
+# Composite Node
+# =============================================================================
+
+
+class CompositeNode(BaseCompositeNode):
+    """Composite AND/OR/NOT node; its children come from the graph edges.
+
+    Edge mapping:
+    - sourceIndex=0 -> 'then' branch (composite true)
+    - sourceIndex=1 -> 'otherwise' branch (composite false)
+    """
+
+    NODE_TYPE: t.ClassVar[str] = "composite"
+
+    def to_flat_rule_node(
+        self, node_id: str, get_child: t.Callable[[int], "RuleType"]
+    ) -> "RuleType":
+        """Build the flat ``CompositeRule``, resolving both branches via ``get_child``."""
+        from ...flat_rules.nodes import CompositeRule
+
+        when_true = get_child(0)
+        when_false = get_child(1)
+        return CompositeRule(
+            id=node_id,
+            op=self.op,
+            conditions=self.conditions,
+            then=when_true,
+            otherwise=when_false,
+        )
+
+
+# =============================================================================
+# Node Data Union and Positioned Node
+# =============================================================================
+
+NodeData = t.Annotated[
+ t.Union[LeafNode, UnaryNode, CasesNode, CompositeNode],
+ Field(discriminator="type"),
+]
+
+
+class Position(BaseModel):
+    """2-D coordinates of a node -- presumably its place on the editor canvas; TODO confirm."""
+
+    x: float
+    y: float
+
+
+class PositionedNode(BaseModel):
+    """A v3 node payload together with its ID and position."""
+
+    id: str
+    # Defaults to the origin (0, 0) when no position is supplied.
+    position: Position = Field(default_factory=lambda: Position(x=0, y=0))
+    # Declared as a union discriminated on ``type``.
+    data: NodeData
diff --git a/decider/modules/rules/tree/v3/tree.py b/decider/modules/rules/tree/v3/tree.py
new file mode 100644
index 0000000..c021593
--- /dev/null
+++ b/decider/modules/rules/tree/v3/tree.py
@@ -0,0 +1,177 @@
+"""V3 Tree structure - manages nodes and edges separately.
+
+Following v2 conventions where tree structure is defined by edges,
+not embedded in nodes. This allows the same node type definitions to work
+with both inline (for flat rules) and graph-based (for UI) representations.
+"""
+
+import typing as t
+import polars as pl
+from pydantic import BaseModel, Field
+from .nodes_ui import NodeData, PositionedNode
+from ..v1.edges import MultiSourceEdge
+from ...common.shared import WithTreeOutput, TreeOutput
+from .....serializable.schema import PolarsSchema
+from ...common.parameters import WithParameters
+from decider.modules.core import BaseExecuteModule
+from decider.types import TInputType, TOutputType
+
+if t.TYPE_CHECKING:
+ from ...flat_rules.nodes import RuleType
+ from decider.executor import Executor
+
+
+class TreeMetadata(BaseModel):
+ """Metadata for the tree."""
+
+ # Both fields are optional free-form text.
+ name: t.Optional[str] = None
+ description: t.Optional[str] = None
+
+
+class SubTree(BaseModel):
+ """Subtree identifier."""
+
+ # Optional identifier of the subtree.
+ id: t.Optional[str] = None
+ # Optional display name of the subtree.
+ name: t.Optional[str] = None
+
+
+class Tree(WithTreeOutput, BaseExecuteModule, WithParameters):
+    """V3 Tree structure with nodes and edges.
+
+    Structure:
+    - Nodes contain data (conditions, operators, leaf values)
+    - Edges define the tree structure via sourceIndex
+        - For UnaryNode: sourceIndex=0 is 'then', sourceIndex=1 is 'otherwise'
+        - For CasesNode: sourceIndex maps to branch indices
+        - For CompositeNode: sourceIndex=0 is 'then', sourceIndex=1 is 'otherwise'
+    """
+
+    type: t.Literal["v3-tree"]
+    name: str = "output"
+    metadata: t.Optional[TreeMetadata] = None
+    edges: t.List[MultiSourceEdge]
+    nodes: t.List[PositionedNode]
+    subtrees: t.List[SubTree] = Field(default_factory=list)
+    # NOTE(review): input_schema is declared but not read by any method of
+    # this class — confirm where the runtime cast is meant to happen.
+    input_schema: t.Optional[PolarsSchema] = Field(
+        default=None, description="Input schema for casting inputs at runtime"
+    )
+
+    format_version: t.Literal[3] = Field(alias="formatVersion", default=3)
+
+    def get_required_parameters(self) -> t.Set[str]:
+        """Get all parameters required by nodes in this tree."""
+        required_parameters: t.Set[str] = set()
+        for node in self.nodes:
+            required_parameters.update(node.data.get_required_parameters())
+        return required_parameters
+
+    def to_tree_module(self) -> "t.Any":
+        """Convert to FlatRuleModule for backward compatibility.
+
+        This method exists for compatibility with tests that expect to_tree_module().
+        The new architecture uses to_flat_rule_tree() → FlatRuleModule.
+
+        Returns:
+            A ``FlatRuleModule`` wrapping the flat rule tree of this tree.
+        """
+        from ...flat_rules.nodes import RuleRoot, RuleMeta
+        from ...flat_rules.module import FlatRuleModule
+
+        # Convert to flat rule tree and wrap it in a RuleRoot
+        rule_root = RuleRoot(meta=RuleMeta(), rule=self.to_flat_rule_tree())
+
+        return FlatRuleModule(
+            rule=rule_root,
+            output=self.output,
+            parameters=self.parameters if hasattr(self, "parameters") else {},
+            parameters_col=(
+                self.parameters_col if hasattr(self, "parameters_col") else "parameters"
+            ),
+        )
+
+    def to_flat_rule_tree(self) -> "RuleType":
+        """Convert this v3 UI tree to flat rules.
+
+        The tree structure is reconstructed from the edges and every node is
+        converted to its flat rule equivalent. A converted node whose ``meta``
+        is ``None`` receives a ``NodeMeta`` built from the node's position.
+        Source indices without an outgoing edge resolve to
+        ``LeafRule(result_idx=-1)``.
+
+        Raises:
+            ValueError: If the tree has no nodes, has no root (every node has
+                an incoming edge), contains a cycle reachable from the chosen
+                root, or an edge targets a node id that is not in ``nodes``.
+        """
+        from ...flat_rules.nodes import LeafRule
+        from ...common.nodetypes import NodeMeta, NodePosition
+
+        if not self.nodes:
+            raise ValueError("Tree has no nodes")
+
+        keyed_nodes = {n.id: n for n in self.nodes}
+
+        # Build adjacency: node_id -> {source_index -> target_node_id}
+        children: t.Dict[str, t.Dict[int, str]] = {}
+        for edge in self.edges:
+            for si in edge.data.sourceIndex:
+                children.setdefault(edge.source, {})[si] = edge.target
+
+        # Node ids currently being converted on the active recursion path.
+        # Meeting one of them again means the edges form a cycle, which would
+        # otherwise recurse until RecursionError.
+        in_progress: t.Set[str] = set()
+
+        def _build(node_id: str) -> "RuleType":
+            """Recursively build flat rule tree from node + edges."""
+            node = keyed_nodes.get(node_id)
+            if node is None:
+                raise ValueError(f"Edge targets unknown node '{node_id}'")
+            if node_id in in_progress:
+                raise ValueError(f"Tree contains a cycle through node '{node_id}'")
+
+            node_children = children.get(node_id, {})
+
+            def get_child(idx: int) -> "RuleType":
+                """Get child at source index, or default leaf if not connected."""
+                target_id = node_children.get(idx)
+                return (
+                    LeafRule(result_idx=-1) if target_id is None else _build(target_id)
+                )
+
+            in_progress.add(node_id)
+            try:
+                flat_node = node.data.to_flat_rule_node(node_id, get_child)
+            finally:
+                in_progress.discard(node_id)
+
+            # Fall back to the UI position when the node carries no meta
+            if flat_node.meta is None:
+                pos = node.position
+                flat_node.meta = NodeMeta(position=NodePosition(x=pos.x, y=pos.y))
+            return flat_node
+
+        # Find root nodes (nodes with no incoming edges)
+        root_keys = set(keyed_nodes) - {e.target for e in self.edges}
+        if not root_keys:
+            raise ValueError("Tree has no root nodes (circular structure)")
+
+        # With several roots, prefer the one listed first in ``subtrees``,
+        # then the one listed first in ``nodes``.
+        subtree_mapping = {
+            st.id: i for i, st in enumerate(self.subtrees) if st.id is not None
+        }
+        node_mapping = {n.id: i for i, n in enumerate(self.nodes)}
+        root_id = min(
+            root_keys,
+            key=lambda rk: (
+                subtree_mapping.get(rk, float("inf")),
+                node_mapping.get(rk, float("inf")),
+            ),
+        )
+
+        return _build(root_id)
+
+    def execute(self, inputs: TInputType, _executor: "Executor") -> TOutputType:
+        """Evaluate the tree against ``inputs["input"]``.
+
+        The tree is converted to a ``FlatRuleModule`` on every call; a
+        ``LazyFrame`` input is collected before the compiled expression is
+        selected, and the unnested result is returned lazily.
+        """
+        module = self.to_tree_module()
+        frame = inputs["input"]
+        if isinstance(frame, pl.LazyFrame):
+            frame = frame.collect()
+        compiled = module.build_expression()
+        return frame.select(compiled.expr.struct.unnest()).lazy()
diff --git a/decider/modules/util.py b/decider/modules/util.py
new file mode 100644
index 0000000..a2f5eeb
--- /dev/null
+++ b/decider/modules/util.py
@@ -0,0 +1,88 @@
+import typing as t
+from hamilton import node
+import polars as pl
+
+
+def create_node_with_mapping(
+    func: t.Callable,
+    name: t.Optional[str] = None,
+    input_mapping: t.Optional[t.Dict[str, str]] = None,
+    partial_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
+) -> node.Node:
+    """Creates a Hamilton node from a function with parameter mapping and partial application.
+
+    Args:
+        func: The function to wrap in a node
+        name: Optional name for the resulting node; defaults to the name of the
+            node created from ``func``
+        input_mapping: Dictionary mapping external parameter names to internal function parameter names
+            e.g., {"variable_name": "input"} maps external "variable_name" to function's "input" param
+        partial_kwargs: Dictionary of keyword arguments to partially apply to the function
+
+    Returns:
+        A Hamilton node with appropriately mapped inputs and proper type information
+
+    Raises:
+        ValueError: If a partial argument or a mapped internal name is not a
+            parameter of ``func`` and ``func`` exposes no ``kwargs`` input.
+
+    Example:
+        # Function expects parameter 'input', but DAG will provide 'my_variable'
+        node = create_node_with_mapping(
+            score_variable,
+            input_mapping={"my_variable": "input"},
+            partial_kwargs={"bins": [...], "default": DefaultBin(...)}
+        )
+    """
+    if partial_kwargs is None:
+        partial_kwargs = {}
+    if input_mapping is None:
+        input_mapping = {}
+
+    # Create the original node to get input types
+    original_node = node.Node.from_fn(func)
+
+    # Nothing to remap or bind: reuse the original node (optionally renamed)
+    if not input_mapping and not partial_kwargs:
+        if name is not None:
+            return original_node.copy_with(name=name)
+        return original_node
+
+    def wrapper_function(**kwargs):
+        # Apply input mapping: map external names to internal parameter names
+        mapped_kwargs = {
+            input_mapping.get(external_name, external_name): value
+            for external_name, value in kwargs.items()
+        }
+        # Merge with partial kwargs (partial kwargs take precedence)
+        return func(**{**mapped_kwargs, **partial_kwargs})
+
+    # NOTE(review): a catch-all parameter is only recognised when the node
+    # exposes an input literally named "kwargs" — confirm this matches how
+    # Hamilton reports **kwargs parameters.
+    function_allows_kwargs = "kwargs" in original_node.input_types
+
+    # Validate partial kwargs against the original function's parameters
+    if not function_allows_kwargs:
+        for kwarg in partial_kwargs:
+            if kwarg not in original_node.input_types:
+                raise ValueError(f"Partial argument '{kwarg}' is not a valid parameter of the function.")
+
+    # Build new input types based on mapping
+    new_input_types = {}
+    for external_name, internal_name in input_mapping.items():
+        if internal_name in original_node.input_types:
+            # Map the type from internal to external parameter
+            new_input_types[external_name] = original_node.input_types[internal_name]
+        else:
+            if not function_allows_kwargs:
+                raise ValueError(f"Original function must contain parameters in the mapping or have a **kwargs parameter to use for unmapped parameters. Missing parameter: {internal_name}")
+            new_input_types[external_name] = original_node.input_types["kwargs"]
+
+    # Add any remaining input types that weren't mapped and aren't in partial_kwargs
+    mapped_internal_names = set(input_mapping.values())
+    for param_name, param_info in original_node.input_types.items():
+        if (
+            param_name not in partial_kwargs
+            and param_name not in mapped_internal_names
+            and param_name != "kwargs"
+        ):
+            new_input_types[param_name] = param_info
+
+    return original_node.copy_with(
+        callabl=wrapper_function,
+        input_types=new_input_types,
+        name=name or original_node.name
+    )
\ No newline at end of file
diff --git a/spockflow/inference/__init__.py b/decider/serializable/__init__.py
similarity index 100%
rename from spockflow/inference/__init__.py
rename to decider/serializable/__init__.py
diff --git a/decider/serializable/dataframe.py b/decider/serializable/dataframe.py
new file mode 100644
index 0000000..60f0b94
--- /dev/null
+++ b/decider/serializable/dataframe.py
@@ -0,0 +1,61 @@
+import typing as t
+import polars as pl
+from pydantic import BaseModel, model_validator, model_serializer, PrivateAttr
+from .schema import PolarsSchema
+
+TDataFrameRow = t.Dict[str, t.Any]
+TDataFrameData = t.List[TDataFrameRow]
+
+
+def build_polars_df(
+    data: TDataFrameData,
+    schema: t.Optional[PolarsSchema] = None,
+) -> pl.DataFrame:
+    """Build a Polars DataFrame from row dicts, optionally enforcing a schema.
+
+    Any failure while constructing the frame is re-raised as a ``ValueError``
+    that names the schema used (or "inferred" when none was given).
+    """
+    try:
+        frame_kwargs = {"schema": schema.schema} if schema else {}
+        return pl.DataFrame(data, **frame_kwargs)
+    except Exception as e:
+        if schema:
+            schema_str = str(schema.root)
+        else:
+            schema_str = "inferred"
+        raise ValueError(f"Could not load data into schema ({schema_str}): {e}") from e
+
+
+class DataFrame(BaseModel):
+    # Row-oriented data: one dict per row.
+    data: TDataFrameData
+    # Optional column schema; serialized under the key "schema".
+    dtypes: t.Optional[PolarsSchema] = None
+
+    # Polars frame built from ``data``/``dtypes`` after validation.
+    _pl_df: pl.DataFrame = PrivateAttr()
+
+    @classmethod
+    def from_dataframe(cls, df: pl.DataFrame):
+        """Create an instance from a Polars frame's rows and record its schema.
+
+        The schema is inferred from the frame rebuilt out of the row dicts.
+        """
+        ret = cls(data=df.to_dicts())
+        ret.infer_schema()
+        return ret
+
+    @model_validator(mode="before")
+    @staticmethod
+    def enable_raw_data(data: t.Any):
+        """Accept a bare list of rows and the serialized ``"schema"`` key.
+
+        ``serialize`` writes the schema under ``"schema"`` while the field is
+        named ``dtypes``; without this mapping a dumped model would lose its
+        schema when validated again.
+        """
+        if isinstance(data, list):
+            return {"data": data}
+        if isinstance(data, dict) and "schema" in data and "dtypes" not in data:
+            remapped = dict(data)  # do not mutate the caller's dict
+            remapped["dtypes"] = remapped.pop("schema")
+            return remapped
+        return data
+
+    @model_validator(mode="after")
+    def construct_polars_df(self):
+        """Build the backing Polars frame once the fields are validated."""
+        self._pl_df = build_polars_df(self.data, self.dtypes)
+        return self
+
+    @model_serializer
+    def serialize(self) -> t.Any:
+        """Serialize to a bare row list, or to data + schema when a schema is set."""
+        if self.dtypes is None:
+            return self.data
+        return {"data": self.data, "schema": self.dtypes.model_dump()}
+
+    def infer_schema(self):
+        """Set ``dtypes`` from the schema of the backing Polars frame."""
+        self.dtypes = PolarsSchema.from_polars_schema(self.df.collect_schema())
+
+    @property
+    def df(self):
+        """The backing Polars DataFrame."""
+        return self._pl_df
diff --git a/decider/serializable/dtypes.py b/decider/serializable/dtypes.py
new file mode 100644
index 0000000..b985158
--- /dev/null
+++ b/decider/serializable/dtypes.py
@@ -0,0 +1,148 @@
+import typing as t
+import polars as pl
+import polars.datatypes.classes as polars_dtypes
+from pydantic import BaseModel, model_validator, PrivateAttr, Field
+from .schema import TStruct, handle_type, convert_schema
+from .dataframe import TDataFrameData, TDataFrameRow
+
+
+class StructDefinition(BaseModel):
+ # name: str
+ # Field layout of the struct.
+ fields: TStruct
+ # Rows of predefined struct values.
+ data: TDataFrameData
+ # Name of the field whose value is used as the lookup key for each row.
+ display_field: str
+
+
+class StructTypeDef(BaseModel):
+    name: str
+    type: t.Literal["struct"] = "struct"
+    definition: StructDefinition
+    # data: TDataFrameData
+
+    _schema: pl.Schema = PrivateAttr()
+    _dtype: polars_dtypes.DataType = PrivateAttr()
+    _pl_df: pl.DataFrame = PrivateAttr()
+    _indexed_values: t.Dict[str, TDataFrameRow] = PrivateAttr()
+
+    @model_validator(mode="after")
+    def construct_polars_df(self):
+        """Derive the struct dtype, schema, data frame and key index.
+
+        Raises:
+            ValueError: If the data cannot be loaded into the struct schema, a
+                record has no value for the display field, or two records share
+                the same display-field value.
+        """
+        definition = self.definition
+
+        # type_defs is omitted on purpose to avoid circular lookups
+        struct_type = handle_type(definition.fields, type_defs=None)
+        assert isinstance(
+            struct_type, polars_dtypes.Struct
+        ), "Expected StructTypeDef.definition to produce a Struct type"
+        self._dtype = struct_type
+
+        def _name_and_dtype(f):
+            # Struct fields may be Field objects or plain (name, dtype) pairs
+            field_name = f.name if hasattr(f, "name") else f[0]
+            field_dtype = f.dtype if hasattr(f, "dtype") else f[1]
+            return (field_name, field_dtype)
+
+        self._schema = pl.Schema([_name_and_dtype(f) for f in struct_type.fields])
+
+        # Load the predefined rows using the derived schema
+        try:
+            self._pl_df = pl.DataFrame(definition.data, schema=self._schema)
+        except Exception as e:
+            raise ValueError(
+                f"Could not load data into struct schema ({self.definition.fields}): {e}"
+            ) from e
+
+        # Index every record by its display-field value
+        by_key = {}
+        for record in definition.data:
+            key = record.get(definition.display_field)
+            if key is None:
+                raise ValueError(
+                    f"Display field '{self.definition.display_field}' not found in record: {record}"
+                )
+            if key in by_key:
+                raise ValueError(
+                    f"Duplicate key '{key}' found for display field '{self.definition.display_field}' in struct data"
+                )
+            by_key[key] = record
+        self._indexed_values = by_key
+        return self
+
+    def get_value_for_key(self, key: str) -> TDataFrameRow:
+        """Return the record whose display-field value equals ``key``.
+
+        Raises:
+            ValueError: If no record is indexed under ``key``.
+        """
+        record = self._indexed_values.get(key)
+        if record is not None:
+            return record
+        raise ValueError(
+            f"Key '{key}' not found in struct definition '{self.name}'"
+        )
+
+    @property
+    def df(self):
+        """Polars frame holding the predefined rows."""
+        return self._pl_df
+
+    @property
+    def schema(self):
+        """Polars schema derived from the struct fields."""
+        return self._schema
+
+    @property
+    def dtype(self):
+        """Polars struct dtype derived from the definition."""
+        return self._dtype
+
+
+class _CatDefinition(BaseModel):
+ # Category labels, in the order they are passed to the Enum dtype.
+ categories: t.List[str]
+
+
+class CategoricalTypeDef(BaseModel):
+ # Name of the type definition.
+ name: str
+ # Discriminator value identifying this definition kind.
+ type: t.Literal["categorical"] = "categorical"
+ definition: _CatDefinition
+ # Polars Enum dtype built from the categories after validation.
+ _dtype: polars_dtypes.DataType = PrivateAttr()
+
+ @model_validator(mode="after")
+ def construct_schema(self):
+ # Create an Enum polars type with the pre-defined categories
+ # Categories are already a list, so order is preserved
+ self._dtype = polars_dtypes.Enum(self.definition.categories)
+ return self
+
+ @property
+ def dtype(self):
+ # Polars Enum dtype for this categorical definition.
+ return self._dtype
+
+
+TTypeDef = t.Annotated[
+ t.Union[StructTypeDef, CategoricalTypeDef], Field(discriminator="type")
+]
+
+
+class ContainsDtypes(BaseModel):
+    # Struct layout; may reference entries of ``type_defs`` via custom types.
+    dtypes: TStruct
+    # Type definitions available for resolving type references in ``dtypes``.
+    type_defs: t.Dict[str, TTypeDef]
+
+    # Polars schema derived from ``dtypes`` after validation.
+    _polars_schema: pl.Schema = PrivateAttr()
+
+    @model_validator(mode="after")
+    def _convert_schema(self) -> "t.Self":
+        """Resolve ``dtypes`` (with ``type_defs``) into a Polars schema.
+
+        Raises:
+            ValueError: If the struct contains duplicate field names.
+        """
+        # Pass type_defs to handle_type so it can resolve type references
+        schema = handle_type(self.dtypes, self.type_defs)
+        assert isinstance(
+            schema, polars_dtypes.Struct
+        ), "Expected upper level to be a struct."
+        try:
+            # Struct fields may be Field objects or plain (name, dtype) pairs
+            fields = [
+                (
+                    f.name if hasattr(f, "name") else f[0],
+                    f.dtype if hasattr(f, "dtype") else f[1],
+                )
+                for f in schema.fields
+            ]
+            self._polars_schema = pl.Schema(fields)
+        except pl.exceptions.DuplicateError as e:
+            # Chain explicitly so the polars error is kept as the direct cause
+            raise ValueError(
+                f"Found one or more duplicate keys in a struct field. Detail: {e}"
+            ) from e
+        return self
+
+    @property
+    def schema(self):
+        """The Polars schema derived from ``dtypes``."""
+        return self._polars_schema
diff --git a/decider/serializable/function.py b/decider/serializable/function.py
new file mode 100644
index 0000000..2367954
--- /dev/null
+++ b/decider/serializable/function.py
@@ -0,0 +1,24 @@
+import typing as t
+from pydantic import BaseModel, PrivateAttr
+
+
+class DefinedFunction(BaseModel):
+    # Importable module that holds the function.
+    module_name: str
+    # Attribute path of the function inside the module; dots separate nested
+    # attributes (e.g. "MyClass.my_static_method").
+    function_name: str
+
+    # Resolved callable, cached on first use.
+    _function: t.Optional[t.Callable[..., t.Any]] = PrivateAttr(default=None)
+
+    @classmethod
+    def from_function(cls, fn: t.Callable[..., t.Any]) -> "DefinedFunction":
+        """Create a reference to ``fn`` from its module and qualified name.
+
+        The qualified name is used so callables defined on a class can be
+        resolved again. Names containing ``<`` (``<locals>``, ``<lambda>``)
+        cannot be reached by attribute lookup, so ``__name__`` is used for
+        those.
+        """
+        qualified = getattr(fn, "__qualname__", fn.__name__)
+        function_name = fn.__name__ if "<" in qualified else qualified
+        return cls(module_name=fn.__module__, function_name=function_name)
+
+    def get_function(self) -> t.Callable[..., t.Any]:
+        """Import the module and resolve the function, caching the result.
+
+        Raises:
+            ImportError: If the module cannot be imported.
+            AttributeError: If the attribute path does not exist in the module.
+        """
+        if self._function is None:
+            from importlib import import_module
+
+            target: t.Any = import_module(self.module_name)
+            # Walk the dotted path; a plain name is a single lookup
+            for attribute in self.function_name.split("."):
+                target = getattr(target, attribute)
+            self._function = target
+        return self._function
+
+    def __call__(self, *args, **kwds):
+        """Call the referenced function with the given arguments."""
+        return self.get_function()(*args, **kwds)
diff --git a/decider/serializable/schema.py b/decider/serializable/schema.py
new file mode 100644
index 0000000..464f529
--- /dev/null
+++ b/decider/serializable/schema.py
@@ -0,0 +1,341 @@
+import json
+import inspect
+import typing as t
+import typing_extensions as te
+from functools import cache
+
+from pydantic import (
+ RootModel,
+ BaseModel,
+ ConfigDict,
+ model_validator,
+ PrivateAttr,
+ Discriminator,
+ Tag,
+)
+
+import polars as pl
+import polars.datatypes.classes as polars_dtypes
+from polars.datatypes import try_parse_into_dtype
+
+if t.TYPE_CHECKING:
+ from .dtypes import TTypeDef
+
+
+class ExplicitType(BaseModel):
+    model_config = ConfigDict(extra="allow")
+    type: str
+
+    # A dict that has a "type" key is always read as an explicit type; no check
+    # against the known polars types is made, so a typo fails loudly instead of
+    # silently turning the value into a struct. A struct that needs a field
+    # literally named "type" must use the ordered list form,
+    # e.g. [["type", "Int64"]].
+    @model_validator(mode="after")
+    def _validate_inner(self) -> "t.Self":
+        """Validate a nested ``inner``/``fields`` definition, if one is given."""
+        extras = self.model_extra
+        if "inner" in extras:
+            nested_key = "inner"
+        elif "fields" in extras:
+            nested_key = "fields"
+        else:
+            return self
+
+        extras[nested_key] = _RootTType.model_validate(extras[nested_key]).root
+        return self
+
+
+# This is needed to avoid recursion error
+# see: https://docs.pydantic.dev/2.11/concepts/types/#named-recursive-types
+TOrderedStructType = te.TypeAliasType(
+ "TOrderedStructType", "t.List[t.Tuple[str,TType]]"
+)
+TUnorderedStructType = te.TypeAliasType("TUnorderedStructType", "t.Dict[str,TType]")
+TStruct = te.TypeAliasType(
+ "TStruct",
+ "t.Union[TOrderedStructType, TUnorderedStructType]",
+)
+
+
+def _type_def_discriminator(type_def: t.Any) -> str:
+    """Pick the ``TType`` union tag ("str", "struct" or "explicit") for a value.
+
+    Strings map to "str", lists to "struct", ``ExplicitType`` instances to
+    "explicit"; a dict is "explicit" when it has a "type" key and "struct"
+    otherwise.
+
+    Raises:
+        ValueError: If the value is none of the supported kinds.
+    """
+    if isinstance(type_def, str):
+        tag = "str"
+    elif isinstance(type_def, list):
+        tag = "struct"
+    elif isinstance(type_def, ExplicitType):
+        tag = "explicit"
+    elif isinstance(type_def, dict):
+        tag = "explicit" if "type" in type_def else "struct"
+    else:
+        raise ValueError(
+            f"Could not determine type definition for {type_def}. Expected either a string, list, dict or explicit type definition."
+        )
+    return tag
+
+
+TType = t.Annotated[
+ t.Union[
+ t.Annotated[str, Tag("str")],
+ t.Annotated[ExplicitType, Tag("explicit")],
+ t.Annotated["TStruct", Tag("struct")],
+ ],
+ Discriminator(
+ _type_def_discriminator
+ ), # Callable discriminator: strings -> "str", lists -> "struct", ExplicitType instances and dicts with a "type" key -> "explicit", other dicts -> "struct". The "type" value is not checked against the known polars types here.
+]
+
+TPrimitiveType = t.Union[str, ExplicitType]
+
+
+class PrimitiveSchema(RootModel):
+ # A single type given as a type-name string or an explicit type definition.
+ root: TPrimitiveType
+ # Result of handle_type(root): a polars data type, not a Schema.
+ _polars_type: pl.DataType = PrivateAttr()
+
+ @model_validator(mode="after")
+ def _convert_schema(self) -> "t.Self":
+ # Convert the declared type into its polars data type once, after validation.
+ self._polars_type = handle_type(self.root)
+ return self
+
+ @property
+ def polars_type(self):
+ # The polars data type derived from ``root``.
+ return self._polars_type
+
+
+class _RootTType(RootModel):
+ # Wrapper model used to validate a bare TType value (see ExplicitType._validate_inner).
+ root: TType
+
+
+class PolarsSchema(RootModel):
+    # Struct layout, as an ordered list of (name, type) pairs or a dict.
+    root: TStruct
+
+    # Polars schema derived from ``root`` after validation.
+    _polars_schema: pl.Schema = PrivateAttr()
+
+    @model_validator(mode="after")
+    def _convert_schema(self) -> "t.Self":
+        """Convert ``root`` into a Polars schema.
+
+        Raises:
+            ValueError: If the struct contains duplicate field names.
+        """
+        schema = handle_type(self.root)
+        assert isinstance(
+            schema, polars_dtypes.Struct
+        ), "Expected upper level to be a struct."
+        try:
+            # Struct fields may be Field objects or plain (name, dtype) pairs
+            fields = [
+                (
+                    f.name if hasattr(f, "name") else f[0],
+                    f.dtype if hasattr(f, "dtype") else f[1],
+                )
+                for f in schema.fields
+            ]
+            self._polars_schema = pl.Schema(fields)
+        except pl.exceptions.DuplicateError as e:
+            # Chain explicitly so the polars error is kept as the direct cause
+            raise ValueError(
+                f"Found one or more duplicate keys in a struct field. Detail: {e}"
+            ) from e
+        return self
+
+    @classmethod
+    def from_polars_schema(cls, schema: pl.Schema) -> "PolarsSchema":
+        """Create a PolarsSchema from an existing Polars Schema object."""
+        return cls(root=convert_schema(schema))
+
+    @property
+    def schema(self):
+        """The Polars schema derived from ``root``."""
+        return self._polars_schema
+
+
+"""
+The following are constants that are used throughout the type conversion
+"""
+
+CUSTOM_TYPE_MAPPINGS: t.Dict[str, str] = {
+ # Aliases from user-friendly type names to Polars type names.
+ # Both sides are matched case-insensitively by get_allowed_types().
+ # "Set" is accepted as an alias for the Polars "List" type.
+ "Set": "List",
+}
+
+
+@cache
+def get_allowed_types():
+    """Return the lower-cased name -> polars DataType class lookup (cached).
+
+    Covers every ``DataType`` subclass found in ``polars.datatypes.classes``
+    plus the aliases declared in ``CUSTOM_TYPE_MAPPINGS``.
+    """
+
+    def _is_dtype_class(candidate):
+        return inspect.isclass(candidate) and issubclass(
+            candidate, polars_dtypes.DataType
+        )
+
+    by_lower_name = {}
+    for member_name, member in inspect.getmembers(polars_dtypes, _is_dtype_class):
+        by_lower_name[member_name.lower()] = member
+
+    # Resolve aliases against the classes collected above
+    aliases = {}
+    for alias, target in CUSTOM_TYPE_MAPPINGS.items():
+        aliases[alias.lower()] = by_lower_name[target.lower()]
+    by_lower_name.update(aliases)
+    return by_lower_name
+
+
+@cache
+def get_type_properties():
+    """Map DataType class names to their constructor parameter names (cached).
+
+    The mapping is built by inspecting each class's ``__init__`` signature, so
+    it follows whatever polars version is installed. ``self``, ``*args`` and
+    ``**kwargs`` are excluded, and classes without remaining parameters or
+    without an inspectable signature are left out.
+    """
+    variadic_kinds = (
+        inspect.Parameter.VAR_POSITIONAL,
+        inspect.Parameter.VAR_KEYWORD,
+    )
+    properties: t.Dict[str, t.Tuple[str, ...]] = {}
+    for dtype_cls in get_allowed_types().values():
+        try:
+            signature = inspect.signature(dtype_cls.__init__)
+        except (ValueError, TypeError):
+            continue
+        param_names = tuple(
+            param.name
+            for param in signature.parameters.values()
+            if param.name != "self" and param.kind not in variadic_kinds
+        )
+        if param_names:
+            properties[dtype_cls.__name__] = param_names
+    return properties
+
+
+"""
+The Following functions are used to convert from the json format into a polars schema
+"""
+
+
+def handle_type(
+    t: TType | "TStruct", type_defs: "t.Optional[t.Dict[str, TTypeDef]]" = None
+) -> pl.DataType:
+    """Convert a JSON-style type definition into a polars data type.
+
+    Dicts and lists become structs, strings are resolved by name and
+    ``ExplicitType`` instances are built from their extra arguments.
+
+    Raises:
+        ValueError: If the value is none of the supported kinds.
+    """
+    if isinstance(t, (dict, list)):
+        # dict: unordered struct (insertion-ordered in practice); list: ordered struct
+        pairs = t.items() if isinstance(t, dict) else t
+        return polars_dtypes.Struct(handle_kv_pair(pairs, type_defs))
+    if isinstance(t, str):
+        return handle_str(t)
+    if isinstance(t, ExplicitType):
+        return handle_explicit_type(t, type_defs)
+    # Normally unreachable: the pydantic validator rejects other values first
+    raise ValueError(
+        f"Unexpected value {t}. Expected either a dict, list, string or explicit type"
+    )
+
+
+def handle_kv_pair(
+    it: t.Iterable[t.Tuple[str, TType | "TStruct"]],
+    type_defs: "t.Optional[t.Dict[str, TTypeDef]]" = None,
+):
+    """Convert (name, type definition) pairs into a list of polars ``Field``s."""
+    converted = []
+    for field_name, field_type in it:
+        converted.append(pl.Field(field_name, handle_type(field_type, type_defs)))
+    return converted
+
+
+def get_type_from_str(t: str):
+ """Look up a polars DataType class by case-insensitive name; ``None`` if unknown."""
+ return get_allowed_types().get(t.lower())
+
+
+def handle_str(t: str):
+    """Resolve a type name into a polars data type.
+
+    A known DataType class is instantiated without arguments; if that is not
+    possible, or the name is unknown, polars' own parser is tried.
+
+    Raises:
+        ValueError: If neither approach yields a type.
+    """
+    dtype_cls = get_type_from_str(t)
+    if dtype_cls is not None:
+        try:
+            return dtype_cls()
+        except TypeError:
+            # The class needs constructor arguments; fall through to the parser
+            pass
+    parsed = try_parse_into_dtype(t)
+    if parsed is not None:
+        return parsed
+    raise ValueError(f"Could not convert {t} into a polars type")
+
+
+def handle_explicit_type(
+    t: ExplicitType, type_defs: "t.Optional[t.Dict[str, TTypeDef]]" = None
+):
+    """Build a polars data type from an explicit type definition.
+
+    A type of ``"custom"`` (case-insensitive) is a reference: its ``type_id``
+    extra field is looked up in ``type_defs`` and that definition's dtype is
+    returned. Any other type is resolved by name and constructed with the
+    remaining extra fields as keyword arguments; nested types take their
+    ``inner`` (or ``fields``) definition as the first positional argument.
+
+    Raises:
+        ValueError: If a custom reference is incomplete or unknown, the type
+            name matches no polars type, a nested type has no inner
+            definition, or the type cannot be constructed from its arguments.
+    """
+    # Check if this is a type reference (has type_id in extra fields)
+    if t.type.lower() == "custom":
+        if t.model_extra is None or "type_id" not in t.model_extra:
+            raise ValueError(
+                "ExplicitType of type 'Custom' must have a 'type_id' in its extra fields."
+            )
+        type_id = t.model_extra["type_id"]
+        if type_defs is None:
+            raise ValueError(
+                f"Type reference with type_id {type_id} requires type_defs to be provided"
+            )
+
+        if type_id not in type_defs:
+            raise ValueError(f"Type ID {type_id} not found in type_defs")
+
+        # Simply return the dtype from the type definition
+        return type_defs[type_id].dtype
+
+    pl_type = get_type_from_str(t.type)
+    if pl_type is None:
+        raise ValueError(
+            f"Could not convert {t} into a polars type. No polars type matching {t.type}."
+        )
+    args = tuple()
+    extra_dict = {**t.model_extra}  # Make a shallow copy
+    if issubclass(pl_type, polars_dtypes.NestedType):
+        # Both keys are always removed from the keyword arguments; "inner"
+        # wins when both are present.
+        fields_definition = extra_dict.pop("fields", None)
+        inner_definition = extra_dict.pop("inner", fields_definition)
+        if inner_definition is None:
+            raise ValueError(
+                f"For nested type {t} expected either an 'inner' or a 'fields' config."
+            )
+        inner_schema = handle_type(inner_definition, type_defs)
+        # Always the first arg for now we can maybe map inner to inner and
+        # fields to fields if needs be
+        args = (inner_schema,)
+    try:
+        return pl_type(*args, **extra_dict)
+    except TypeError as e:
+        # Chain explicitly so the constructor error is kept as the direct cause
+        raise ValueError(
+            f"Could not construct type {t} from args {t.model_extra}. Got error: {e}."
+        ) from e
+
+
+"""
+The following functions do the reverse and convert from a polars schema back into a json format
+"""
+
+
+def convert_kv_types(
+    itr: t.Iterable[t.Tuple[str, polars_dtypes.DataType]],
+) -> TOrderedStructType:
+    """Convert (name, polars dtype) pairs into ordered (name, JSON type) pairs."""
+    converted = []
+    for field_name, field_dtype in itr:
+        converted.append((field_name, convert_dtype(field_dtype)))
+    return converted
+
+
+def convert_schema(schema: pl.Schema) -> "TStruct":
+    """Convert a Polars Schema to the ordered TStruct representation."""
+    # A schema is a sequence of (name, dtype) pairs, same as struct fields
+    return convert_kv_types(schema.items())
+
+
+def convert_dtype(dtype: polars_dtypes.DataType) -> "TType | TStruct":
+ """Convert a Polars DataType instance back to its TType | TStruct JSON representation.
+
+ Returns the class name for dtype classes and for instances without
+ serialisable constructor values, an ordered list of (name, type) pairs for
+ structs, and an ``ExplicitType`` carrying the collected values otherwise.
+ """
+ # Uninstantiated dtype class (e.g. pl.Int64 rather than pl.Int64()) → class name
+ if isinstance(dtype, polars_dtypes.DataTypeClass):
+ return dtype.__name__
+ # Struct → ordered list of (field_name, converted_dtype) pairs
+ if isinstance(dtype, polars_dtypes.Struct):
+ return convert_kv_types((field.name, field.dtype) for field in dtype.fields)
+
+ cls = type(dtype)
+ cls_name = cls.__name__
+ props = get_type_properties().get(cls_name)
+
+ # No constructor parameters → simple string e.g. "Int64"
+ if not props:
+ return cls_name
+
+ # Collect non-None property values, recursing into nested DataTypes
+ extra: t.Dict[str, t.Any] = {}
+ for prop in props:
+ val = getattr(dtype, prop, None)
+ if val is None:
+ continue
+ if isinstance(val, (polars_dtypes.DataType, polars_dtypes.DataTypeClass)):
+ # e.g. List.inner, Array.inner
+ extra[prop] = convert_dtype(val)
+ else:
+ try:
+ json.dumps(val)
+ except TypeError:
+ # Probe with json.dumps to keep only values that can be serialised.
+ # Some options (e.g. on Categorical) are not JSON-dumpable and are
+ # silently dropped here, so the conversion never breaks when such
+ # types are used; support for them can be added over time.
+ pass
+ else:
+ extra[prop] = val
+
+ # If all params were None/default, still just return the type name
+ if not extra:
+ return cls_name
+
+ return ExplicitType(type=cls_name, **extra)
diff --git a/spockflow/inference/config/__init__.py b/decider/serving/__init__.py
similarity index 100%
rename from spockflow/inference/config/__init__.py
rename to decider/serving/__init__.py
diff --git a/decider/serving/format.py b/decider/serving/format.py
new file mode 100644
index 0000000..904b16c
--- /dev/null
+++ b/decider/serving/format.py
@@ -0,0 +1,47 @@
+import typing as t
+from io import BytesIO
+import polars as pl
+from .media_types import MediaType
+
+class Response(t.NamedTuple):
+ # Encoded response body.
+ content: bytes
+ # Media type of ``content``; ``None`` when the formatter does not set one.
+ media_type: t.Optional[str] = None
+
+
+
+def format_application_json(result: pl.DataFrame) -> "Response":
+    """Serialize ``result`` as JSON.
+
+    A single-row frame is returned as one JSON object (the surrounding list
+    brackets are stripped); any other row count is returned as a JSON list.
+    """
+    payload = result.write_json()
+    if len(result) == 1:
+        payload = payload.removeprefix('[').removesuffix(']')
+    return Response(
+        content=payload.encode("utf-8"),
+        media_type=MediaType.APPLICATION_JSON.value,
+    )
+
+def format_application_jsonl(result: pl.DataFrame) -> "Response":
+    """Serialize ``result`` as newline-delimited JSON."""
+    body = result.write_ndjson().encode("utf-8")
+    return Response(content=body, media_type=MediaType.APPLICATION_JSONL.value)
+
+def format_application_x_parquet(result: pl.DataFrame) -> "Response":
+    """Serialize ``result`` as Parquet bytes."""
+    with BytesIO() as buffer:
+        result.write_parquet(buffer)
+        body = buffer.getvalue()
+    return Response(content=body, media_type=MediaType.APPLICATION_X_PARQUET.value)
+
+def format_text_csv(result: pl.DataFrame) -> "Response":
+    """Serialize ``result`` as UTF-8 encoded CSV."""
+    body = result.write_csv().encode("utf-8")
+    return Response(content=body, media_type=MediaType.TEXT_CSV.value)
+
+DEFAULT_OUTPUT_FORMATTERS = {
+ MediaType.ANY.value: format_application_json, # Default to JSON for any Accept header
+ MediaType.APPLICATION_JSON.value: format_application_json,
+ MediaType.APPLICATION_JSONL.value: format_application_jsonl,
+ MediaType.APPLICATION_X_PARQUET.value: format_application_x_parquet,
+ MediaType.TEXT_CSV.value: format_text_csv,
+}
\ No newline at end of file
diff --git a/decider/serving/handler.py b/decider/serving/handler.py
new file mode 100644
index 0000000..1c92e97
--- /dev/null
+++ b/decider/serving/handler.py
@@ -0,0 +1,113 @@
+from dataclasses import dataclass
+import asyncio
+import importlib.util
+import typing as t
+
+import polars as pl
+
+from decider.config import ConfigManager
+from decider.config.versioned import Version
+from decider.modules import GraphModule
+import decider.exceptions as exc
+from .parse import DEFAULT_INPUT_HANDLERS, ParserConfig
+from .format import DEFAULT_OUTPUT_FORMATTERS, Response
+from .media_types import MediaType
+
+
+@dataclass
+class RequestHandler:
+    """Parses a request, runs the configured root module and formats the result."""
+
+    config_manager: "ConfigManager"
+    # Key of the module config to serve within the versioned config.
+    root_module: str = "main"
+    # Background task following version updates; set by ``init_fn``.
+    _update_task: t.Optional[asyncio.Task] = None
+    # Module built for ``_constructed_version``, reused until the version changes.
+    _constructed_module: t.Optional[GraphModule] = None
+    _constructed_version: t.Optional[Version] = None
+
+    async def init_fn(self):
+        """Load the latest config and start following version updates."""
+        await self.config_manager.get_latest()
+        self._update_task = asyncio.create_task(self.config_manager.subscribe_version_updates())
+
+    def module_fn(self) -> t.Tuple[GraphModule, ParserConfig]:
+        """Return the module for the current config version and its parser config.
+
+        The module is rebuilt only when the config version differs from the
+        one it was last built for.
+
+        Raises:
+            exc.BaseConfigurationError: Propagated unchanged.
+            exc.ModuleLoadError: For any other failure while loading the module.
+        """
+        try:
+            with self.config_manager.current_version_context() as versioned_config:
+                is_cached = (
+                    self._constructed_version is not None
+                    and versioned_config.version == self._constructed_version
+                )
+                if not is_cached:
+                    module_config = versioned_config.config.get(self.root_module)
+                    if module_config is None:
+                        raise ValueError(f"No config found for root module '{self.root_module}' in the current versioned config.")
+                    self._constructed_module = GraphModule.model_validate(module_config).root
+                    self._constructed_version = versioned_config.version
+                module = self._constructed_module
+                return module, ParserConfig(input_frame_keys=module.get_input_frame_keys())
+
+        except exc.BaseConfigurationError:
+            raise
+        except ValueError as e:
+            raise exc.ModuleLoadError.from_value_error(e) from e
+        except Exception as e:
+            raise exc.ModuleLoadError(str(e)) from e
+
+    async def input_fn(self, data: bytes, content_type: str, parse_config: t.Optional[ParserConfig] = None):
+        """Parse the request body with the handler registered for its content type.
+
+        The content type is matched exactly first; if that fails, parameters
+        such as ``; charset=utf-8`` are stripped and the bare media type is
+        matched case-insensitively.
+
+        Raises:
+            exc.UnsupportedContentTypeError: If no handler matches.
+            exc.InputParsingError: If the handler fails to parse the body.
+        """
+        handler = DEFAULT_INPUT_HANDLERS.get(content_type)
+        if handler is None and content_type:
+            bare_media_type = content_type.split(";", 1)[0].strip().lower()
+            handler = DEFAULT_INPUT_HANDLERS.get(bare_media_type)
+        if handler is None:
+            raise exc.UnsupportedContentTypeError(f"Unsupported content type: {content_type!r}")
+        try:
+            return await handler(data, parse_config or ParserConfig())
+        except Exception as e:
+            raise exc.InputParsingError(str(e)) from e
+
+    def output_fn(self, output: pl.DataFrame, accept: str) -> Response:
+        """Format the result with the formatter registered for ``accept``.
+
+        A formatter that returns no media type gets ``accept`` filled in,
+        except for ``*/*`` where a concrete media type is required.
+
+        Raises:
+            exc.UnsupportedAcceptError: If no formatter is registered.
+            exc.DeciderError: Propagated unchanged.
+            exc.OutputFormattingError: For any other formatting failure.
+        """
+        formatter = DEFAULT_OUTPUT_FORMATTERS.get(accept)
+        if formatter is None:
+            raise exc.UnsupportedAcceptError(f"Unsupported Accept type: {accept!r}")
+        try:
+            response = formatter(output)
+            if response.media_type is None:
+                if accept == MediaType.ANY.value:
+                    raise exc.DeciderRuntimeError("Configured Format for MediaType.ANY must return a Response with a specific media_type, got None")
+                response = Response(content=response.content, media_type=accept)
+            return response
+        except exc.DeciderError:
+            raise
+        except Exception as e:
+            raise exc.OutputFormattingError(str(e)) from e
+
+    async def process_fn(self, data: bytes, accept: str, content_type: str) -> Response:
+        """Run a full request: load module, parse input, execute, format output."""
+        module, parse_config = self.module_fn()
+        input_data = await self.input_fn(data, content_type, parse_config)
+        # NOTE(review): the module call is synchronous inside this coroutine —
+        # confirm it is acceptable for it to run on the event loop.
+        result_df = module(input_data)
+        return self.output_fn(result_df, accept)
+
+    async def shutdown_fn(self):
+        """Cancel the version-update task, if one was started."""
+        update_task = self._update_task
+        if update_task is None:
+            # init_fn was never called (or shutdown already ran)
+            return
+        self._update_task = None
+        update_task.cancel()
+        try:
+            await update_task
+        except asyncio.CancelledError:
+            pass
+
+
+def construct_handler_from_settings() -> RequestHandler:
+    """Build the request handler described by the application settings.
+
+    ``settings.api.handler`` has the form ``"module:Class"`` (default
+    ``"inference:Handler"``). If ``<module>.py`` exists in the current working
+    directory it is imported and ``Class`` is used as the handler constructor;
+    when the file or the class is missing, ``RequestHandler`` is used.
+
+    Raises:
+        ImportError: If no import spec can be created for the module file.
+        Exception: Whatever the custom module raises while being executed.
+    """
+    import sys
+    import os
+    from decider.settings import settings
+
+    config_manager = settings.config.get()
+
+    handler_path: str = getattr(settings.api, "handler", "inference:Handler")
+    module_name, _, class_name = handler_path.partition(":")
+
+    handler_constructor = None
+
+    # Try to load a custom handler class from the working directory
+    module_file = os.path.join(os.getcwd(), f"{module_name}.py")
+    if os.path.exists(module_file):
+        spec = importlib.util.spec_from_file_location(module_name, module_file)
+        if spec is None or spec.loader is None:
+            raise ImportError(f"Could not create an import spec for handler module {module_file!r}")
+        mod = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = mod
+        try:
+            spec.loader.exec_module(mod)  # raises if file has errors
+        except BaseException:
+            # Do not leave a half-initialised module importable
+            sys.modules.pop(module_name, None)
+            raise
+        handler_constructor = getattr(mod, class_name, None)
+
+    if handler_constructor is None:
+        handler_constructor = RequestHandler
+
+    return handler_constructor(config_manager=config_manager)
\ No newline at end of file
diff --git a/decider/serving/media_types.py b/decider/serving/media_types.py
new file mode 100644
index 0000000..5a7af78
--- /dev/null
+++ b/decider/serving/media_types.py
@@ -0,0 +1,11 @@
+from enum import Enum
+
+class MediaType(Enum):
+    """Media (MIME) type strings used to select body parsers and formatters."""
+    # Wildcard media range.
+    ANY = "*/*"
+    APPLICATION_JSON = "application/json"
+    APPLICATION_JSONL = "application/jsonl"
+    APPLICATION_X_PARQUET = "application/x-parquet"
+    TEXT_CSV = "text/csv"
+    # OOXML spreadsheet type.
+    APPLICATION_EXCEL = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+    # Older Excel media type.
+    APPLICATION_VND_MS_EXCEL = "application/vnd.ms-excel"
diff --git a/decider/serving/parse.py b/decider/serving/parse.py
new file mode 100644
index 0000000..8a57e91
--- /dev/null
+++ b/decider/serving/parse.py
@@ -0,0 +1,49 @@
+import json
+import typing as t
+import polars as pl
+from dataclasses import dataclass, field
+from .media_types import MediaType
+
+
+@dataclass(slots=True)
+class ParserConfig:
+    """Options handed to the body parsers defined in this module."""
+    # Names of the frames a parser must produce. The JSONL, Parquet, CSV and
+    # Excel parsers assert that exactly one key is configured.
+    input_frame_keys: t.List[str] = field(default_factory=lambda: ["input"])
+
+
+TParsedMessage = t.Dict[str, pl.DataFrame]
+
+
+async def parse_application_json(data: bytes, parse_config: ParserConfig) -> TParsedMessage:
+    """Parse a JSON request body into named polars frames.
+
+    With exactly one configured frame key the whole body is read as a single
+    frame and stored under that key (matching the other parsers, which all
+    key their result by ``input_frame_keys[0]``). Otherwise the body must be
+    a JSON object holding one entry per configured key.
+
+    Args:
+        data: Raw JSON bytes.
+        parse_config: Parser options; ``input_frame_keys`` names the frames.
+
+    Returns:
+        Mapping of frame key to ``pl.DataFrame``.
+
+    Raises:
+        KeyError: In multi-frame mode, when a configured key is missing.
+    """
+    frame_keys = parse_config.input_frame_keys
+    if len(frame_keys) == 1:
+        # Use the configured key rather than a hard-coded "input".
+        return {frame_keys[0]: pl.read_json(data)}
+    json_dict = json.loads(data)
+    return {key: pl.from_dict(json_dict[key]) for key in frame_keys}
+
+
+async def parse_application_jsonl(data: bytes, parse_config: ParserConfig) -> TParsedMessage:
+    """Parse a newline-delimited JSON body into one frame.
+
+    Asserts that exactly one input frame key is configured; the frame is
+    returned under that key.
+    """
+    frame_keys = parse_config.input_frame_keys
+    assert len(frame_keys) == 1, "JSONL parsing only supports a single input frame key"
+    frame_name = frame_keys[0]
+    return {frame_name: pl.read_ndjson(data)}
+
+async def parse_application_x_parquet(data: bytes, parse_config: ParserConfig) -> TParsedMessage:
+    """Parse a Parquet body into one frame.
+
+    Asserts that exactly one input frame key is configured; the frame is
+    returned under that key.
+    """
+    frame_keys = parse_config.input_frame_keys
+    assert len(frame_keys) == 1, "Parquet parsing only supports a single input frame key"
+    frame_name = frame_keys[0]
+    return {frame_name: pl.read_parquet(data)}
+
+async def parse_text_csv(data: bytes, parse_config: ParserConfig) -> TParsedMessage:
+    """Parse a CSV body into one frame.
+
+    Asserts that exactly one input frame key is configured; the frame is
+    returned under that key.
+    """
+    frame_keys = parse_config.input_frame_keys
+    assert len(frame_keys) == 1, "CSV parsing only supports a single input frame key"
+    frame_name = frame_keys[0]
+    return {frame_name: pl.read_csv(data)}
+
+async def parse_application_excel(data: bytes, parse_config: ParserConfig) -> TParsedMessage:
+    """Parse an Excel workbook body into one frame.
+
+    Asserts that exactly one input frame key is configured; the frame is
+    returned under that key.
+    """
+    frame_keys = parse_config.input_frame_keys
+    assert len(frame_keys) == 1, "Excel parsing only supports a single input frame key"
+    frame_name = frame_keys[0]
+    return {frame_name: pl.read_excel(data)}
+
+
+
+# Media type string -> async parser coroutine. Both Excel media types share
+# the same parser.
+DEFAULT_INPUT_HANDLERS = {
+    MediaType.APPLICATION_JSON.value: parse_application_json,
+    MediaType.APPLICATION_JSONL.value: parse_application_jsonl,
+    MediaType.APPLICATION_X_PARQUET.value: parse_application_x_parquet,
+    MediaType.TEXT_CSV.value: parse_text_csv,
+    MediaType.APPLICATION_EXCEL.value: parse_application_excel,
+    MediaType.APPLICATION_VND_MS_EXCEL.value: parse_application_excel,
+}
diff --git a/spockflow/inference/config/loader/__init__.py b/decider/serving/servers/__init__.py
similarity index 100%
rename from spockflow/inference/config/loader/__init__.py
rename to decider/serving/servers/__init__.py
diff --git a/decider/serving/servers/core.py b/decider/serving/servers/core.py
new file mode 100644
index 0000000..48fd48d
--- /dev/null
+++ b/decider/serving/servers/core.py
@@ -0,0 +1,19 @@
+import json
+import typing as t
+
+from decider.exceptions import DeciderError
+
+_INITIALIZING = b'{"message": "Server is initializing, please try again shortly."}'
+
+
+def error_response(error: DeciderError) -> t.Tuple[int, bytes, str]:
+    """Convert a DeciderError into ``(status_code, body_bytes, media_type)``.
+
+    The body is a JSON object with a "message" entry and, when the error's
+    response body has truthy details, a "details" entry.
+    """
+    response_body = error.get_response_body()
+    fields: t.Dict[str, str] = {"message": response_body.message}
+    if response_body.details:
+        fields["details"] = response_body.details
+    status = error.get_status_code()
+    encoded = json.dumps(fields).encode()
+    return status, encoded, "application/json"
+
+
+def parse_content_headers(headers: t.Mapping[str, str]) -> t.Tuple[str, str]:
+    """Return (content_type, accept) from a headers mapping.
+
+    A missing content-type yields "" and a missing accept yields "*/*".
+    """
+    content_type = headers.get("content-type", "")
+    accept = headers.get("accept", "*/*")
+    return content_type, accept
diff --git a/decider/serving/servers/sanic.py b/decider/serving/servers/sanic.py
new file mode 100644
index 0000000..8178f8d
--- /dev/null
+++ b/decider/serving/servers/sanic.py
@@ -0,0 +1,56 @@
+import typing as t
+from decider.exceptions import DeciderError, wrap_import_errors
+
+with wrap_import_errors("sanic"):
+ from sanic import Sanic
+ from sanic.request import Request
+ from sanic.response import HTTPResponse, raw
+
+from decider.serving.handler import RequestHandler, construct_handler_from_settings
+from .core import error_response, parse_content_headers, _INITIALIZING
+
+
+handler: t.Optional[RequestHandler] = None
+
+
+async def predict(request: Request) -> HTTPResponse:
+    """Serve an inference request, or 503 while the handler is not yet set."""
+    active_handler = handler
+    if active_handler is None:
+        return raw(_INITIALIZING, status=503, content_type="application/json")
+    content_type, accept = parse_content_headers(request.headers)
+    outcome = await active_handler.process_fn(request.body, accept, content_type)
+    return raw(outcome.content, status=200, content_type=outcome.media_type)
+
+
+async def ping(_request: Request) -> HTTPResponse:
+    """Health check: empty 200 once the handler is set, otherwise 503."""
+    if handler is not None:
+        return raw(b"", status=200)
+    return raw(_INITIALIZING, status=503, content_type="application/json")
+
+
+def create_app(name: str = "decider") -> Sanic:
+    """Build the Sanic application: routes, error handler and lifecycle hooks.
+
+    Args:
+        name: Name given to the Sanic application.
+
+    Returns:
+        The configured Sanic app.
+    """
+    # Must be called inside each worker process — do NOT call at module level.
+    # Sanic's multi-process AppLoader invokes this factory once per worker.
+    app = Sanic(name)
+    app.add_route(predict, "/invocations", methods=["POST"])
+    app.add_route(ping, "/ping", methods=["GET"])
+
+    # Turn any DeciderError raised by a route into the shared JSON error payload.
+    @app.exception(DeciderError)
+    async def decider_error_handler(_request: Request, exc: DeciderError) -> HTTPResponse:
+        status_code, body, media_type = error_response(exc)
+        return raw(body, status=status_code, content_type=media_type)
+
+    @app.before_server_start
+    async def startup(_app) -> None:
+        from decider.initialization import initialize_decider
+        global handler
+        initialize_decider()
+        _handler = construct_handler_from_settings()
+        await _handler.init_fn()
+        # Publish the handler only after init_fn has completed; until then the
+        # routes answer 503.
+        handler = _handler
+
+    @app.after_server_stop
+    async def shutdown(_app) -> None:
+        # handler stays None when startup never completed.
+        if handler is not None:
+            await handler.shutdown_fn()
+
+    return app
diff --git a/decider/serving/servers/starlette.py b/decider/serving/servers/starlette.py
new file mode 100644
index 0000000..45d2085
--- /dev/null
+++ b/decider/serving/servers/starlette.py
@@ -0,0 +1,68 @@
+import typing as t
+from contextlib import asynccontextmanager
+from decider.exceptions import DeciderError, wrap_import_errors
+
+with wrap_import_errors("starlette"):
+ from starlette.applications import Starlette
+ from starlette.requests import Request
+ from starlette.responses import Response
+ from starlette.routing import Route
+
+from decider.serving.handler import RequestHandler, construct_handler_from_settings
+from .core import error_response, parse_content_headers, _INITIALIZING
+
+
+handler: t.Optional[RequestHandler] = None
+
+
+async def decider_error_handler(request: Request, exc: DeciderError) -> Response:
+    """Render a DeciderError as an HTTP response via ``error_response``."""
+    status, payload, content_type = error_response(exc)
+    return Response(content=payload, status_code=status, media_type=content_type)
+
+
+async def predict(request: Request) -> Response:
+    """Serve an inference request, or 503 while the handler is not yet set."""
+    active_handler = handler
+    if active_handler is None:
+        return Response(content=_INITIALIZING, status_code=503, media_type="application/json")
+    content_type, accept = parse_content_headers(request.headers)
+    payload = await request.body()
+    outcome = await active_handler.process_fn(payload, accept, content_type)
+    return Response(content=outcome.content, media_type=outcome.media_type)
+
+
+async def ping(_request: Request) -> Response:
+    """Health check: empty 200 once the handler is set, otherwise 503."""
+    if handler is not None:
+        return Response(status_code=200)
+    return Response(content=_INITIALIZING, status_code=503, media_type="application/json")
+
+
+@asynccontextmanager
+async def lifespan(app: "Starlette"):
+    """Starlette lifespan: build the handler on startup, release it on shutdown.
+
+    Startup initialises decider, constructs the handler from settings, awaits
+    its ``init_fn`` and only then publishes it in the module-level ``handler``.
+    Shutdown runs in a ``finally`` block, so ``shutdown_fn`` is awaited even
+    when an exception is thrown into the context while the app is running.
+    """
+    from decider.initialization import initialize_decider
+    global handler
+    initialize_decider()
+    _handler = construct_handler_from_settings()
+    await _handler.init_fn()
+    handler = _handler
+    try:
+        yield
+    finally:
+        # Stop routing requests to a handler that is being torn down, then
+        # cancel its background work.
+        handler = None
+        await _handler.shutdown_fn()
+
+
+def create_app() -> "Starlette":
+    """Build the Starlette app with its two routes, lifespan and error handler."""
+    routes = [
+        Route("/invocations", predict, methods=["POST"]),
+        Route("/ping", ping, methods=["GET"]),
+    ]
+    error_handlers = {DeciderError: decider_error_handler}
+    return Starlette(
+        routes=routes,
+        lifespan=lifespan,
+        exception_handlers=error_handlers,
+    )
+
+
+# Stays None until get_app() is first called.
+app: t.Optional["Starlette"] = None
+
+
+def get_app() -> "Starlette":
+    """Return the module-level Starlette app, creating it on first call."""
+    global app
+    if app is not None:
+        return app
+    app = create_app()
+    return app
diff --git a/decider/settings.py b/decider/settings.py
new file mode 100644
index 0000000..0f2a49d
--- /dev/null
+++ b/decider/settings.py
@@ -0,0 +1,93 @@
+import os
+import typing as t
+from pydantic import BaseModel, Field, field_validator, ConfigDict
+from pydantic_settings import BaseSettings, SettingsConfigDict
+from functools import cache
+
+
+def _default_workers() -> int:
+    """nproc * 2 + 1 — sensible default for I/O-bound async workers.
+
+    ``os.cpu_count()`` returns ``None`` when the CPU count cannot be
+    determined; that case is treated as a single CPU (result 3).
+    """
+    return (os.cpu_count() or 1) * 2 + 1
+
+
+class ServeSettings(BaseModel):
+    """Settings for the Decider HTTP server."""
+    # Bind address.
+    host: str = "0.0.0.0"
+    # Bind port.
+    port: int = 8080
+    # None means use _default_workers() at serve time so nproc is evaluated
+    # on the target machine, not at settings-parse time.
+    workers: t.Optional[int] = None
+
+
+class APISettings(BaseModel):
+    """Settings for the Decider API."""
+    # NOTE(review): presumably the project root that holds the user's code;
+    # confirm against the code that consumes it.
+    root_path: str = "./model/code"
+    # NOTE(review): looks like a sub-path under root_path; confirm.
+    flow_subpath: str = ""
+    # NOTE(review): construct_handler_from_settings reads an optional
+    # `handler` attribute from this object that is not declared here; confirm
+    # whether a field is missing.
+    init_module: t.Optional[str] = "inference"
+
+class DeciderAppExtensionSettings(BaseModel):
+    """Settings for the Decider application extensions."""
+    extension_path: str = "decider_extensions"
+    extension_imports: t.List[str] = []
+
+    @field_validator("extension_imports", mode="before")
+    @classmethod
+    def split_comma_separated(cls, v: t.Any) -> t.Any:
+        """Turn a comma-separated string into a list of non-empty, stripped items.
+
+        Non-string values are returned unchanged.
+        """
+        if not isinstance(v, str):
+            return v
+        stripped_parts = (part.strip() for part in v.split(","))
+        return [part for part in stripped_parts if part]
+
+SETTINGS_DEFAULT_CONFIG_POLL_DURATION_S: int = 10
+
+class DeciderConfigSettings(BaseModel):
+    """Config-manager selection; extra fields are accepted and kept."""
+    model_config = ConfigDict(extra='allow')
+    type: str = "file:json"
+
+    def get(self):
+        """Validate this model's dump as a ``ConfigManager`` and return its ``root``."""
+        from decider.config import ConfigManager
+
+        raw_config = self.model_dump()
+        manager = ConfigManager.model_validate(raw_config)
+        return manager.root
+
+
+class DeciderSettings(BaseSettings):
+    """Main settings for the Decider application."""
+
+    # Environment variables use the "Decider_" prefix, matched
+    # case-insensitively, with "__" separating nested fields
+    # (e.g. Decider_serve__port).
+    model_config = SettingsConfigDict(
+        env_prefix="Decider_",
+        env_nested_delimiter="__",
+        case_sensitive=False,
+    )
+
+    # Each group defaults to a freshly constructed instance of its model.
+    serve: ServeSettings = Field(default_factory=ServeSettings)
+    ext: DeciderAppExtensionSettings = Field(default_factory=DeciderAppExtensionSettings)
+    api: APISettings = Field(default_factory=APISettings)
+    config: DeciderConfigSettings = Field(default_factory=DeciderConfigSettings)
+
+
+settings = DeciderSettings()
+
+
+# ========================================
+# Executor Configuration
+# ========================================
+
+if t.TYPE_CHECKING:
+ from decider.executor import Executor
+
+
+
+@cache
+def get_default_executor() -> "Executor":
+    """Return the shared default executor.
+
+    The function is wrapped in ``functools.cache``: a ``SimpleExecutor`` is
+    created on the first call and that same instance is returned afterwards.
+
+    Returns:
+        The default executor instance
+
+    Example:
+        >>> executor = get_default_executor()
+        >>> compiled = module.compile(executor)
+    """
+    from decider.executor import SimpleExecutor
+
+    default_executor = SimpleExecutor()
+    return default_executor
diff --git a/decider/templates/__init__.py b/decider/templates/__init__.py
new file mode 100644
index 0000000..3033adb
--- /dev/null
+++ b/decider/templates/__init__.py
@@ -0,0 +1,3 @@
+from .scaffold import write_project, write_inline_module, write_package_module
+
+__all__ = ["write_project", "write_inline_module", "write_package_module"]
diff --git a/decider/templates/scaffold.py b/decider/templates/scaffold.py
new file mode 100644
index 0000000..8db03ec
--- /dev/null
+++ b/decider/templates/scaffold.py
@@ -0,0 +1,154 @@
+"""
+Template renderer and project scaffolder.
+
+All public functions take explicit paths and return Path objects of what was
+written. They are deliberately side-effect-free beyond filesystem writes so
+they can be called from the magic, a CLI, or a test.
+"""
+
+import re
+import textwrap
+from pathlib import Path
+from string import Template
+
+_STATIC = Path(__file__).parent / "static"
+
+
+# ── low-level helpers ─────────────────────────────────────────────────────────
+
+def _render(template_path: Path, **kwargs) -> str:
+    """Read ``template_path`` and fill its placeholders from ``kwargs``.
+
+    Uses ``string.Template.substitute``, so a placeholder with no matching
+    keyword raises ``KeyError``.
+    """
+    template = Template(template_path.read_text())
+    return template.substitute(**kwargs)
+
+
+def to_snake(name: str) -> str:
+    """Convert a CamelCase name to snake_case (e.g. "HTTPServer" -> "http_server")."""
+    # Split an acronym from the capitalised word that follows it.
+    acronym_boundary = re.compile(r"([A-Z]+)([A-Z][a-z])")
+    # Split a lowercase letter or digit from the capital that follows it.
+    word_boundary = re.compile(r"([a-z\d])([A-Z])")
+    separated = acronym_boundary.sub(r"\1_\2", name)
+    separated = word_boundary.sub(r"\1_\2", separated)
+    return separated.lower()
+
+
+def extract_function_names(source: str) -> list[str]:
+    """Names of every ``def`` that starts at column 0 of a line, in source order."""
+    top_level_def = re.compile(r"^def ([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", re.MULTILINE)
+    return [match.group(1) for match in top_level_def.finditer(source)]
+
+
+# ── inline extension (no package) ────────────────────────────────────────────
+
+def write_inline_module(ext_dir: Path, class_name: str, user_code: str) -> Path:
+    """
+    Write ``ext_dir / to_snake(class_name) / "__init__.py"`` from the
+    extension-module template.
+
+    The user code is dedented and stripped before its top-level function
+    names are extracted, so a uniformly indented cell body is handled the
+    same way as an unindented one.
+
+    Args:
+        ext_dir: Directory holding the extensions.
+        class_name: Name of the generated module class.
+        user_code: Source text containing at least one top-level ``def``.
+
+    Returns the path written.
+
+    Raises:
+        ValueError: If the code contains no top-level function.
+    """
+    snake = to_snake(class_name)
+    # Extract names from the same text that is rendered into the file.
+    code = textwrap.dedent(user_code).strip()
+    fn_names = extract_function_names(code)
+    if not fn_names:
+        raise ValueError(
+            f"No top-level functions found in cell body for {class_name!r}. "
+            "Define at least one `def` returning pl.Expr."
+        )
+
+    pkg_dir = ext_dir / snake
+    pkg_dir.mkdir(parents=True, exist_ok=True)
+    out = pkg_dir / "__init__.py"
+    out.write_text(
+        _render(
+            _STATIC / "extension_module.py",
+            user_code=code,
+            class_name=class_name,
+            type_id=snake,
+            function_names=",\n ".join(fn_names),
+        )
+    )
+    return out
+
+
+# ── package extension (uv src-layout) ────────────────────────────────────────
+
+def _package_src_dir(ext_dir: Path, package_name: str) -> Path:
+    """Source directory of a src-layout package: ext_dir/package/src/package."""
+    package_root = ext_dir / package_name
+    return package_root / "src" / package_name
+
+
+def _package_init_path(ext_dir: Path, package_name: str) -> Path:
+    """Path of the package's __init__.py inside its src-layout source directory."""
+    return ext_dir / package_name / "src" / package_name / "__init__.py"
+
+
+def _rebuild_package_init(src_dir: Path, package_name: str) -> None:
+    """Rewrite __init__.py with a star-import for every sibling .py module.
+
+    Modules are listed in sorted order; with no sibling modules the file is
+    written empty. ``package_name`` is accepted but not used.
+    """
+    module_names = [path.stem for path in src_dir.glob("*.py") if path.stem != "__init__"]
+    module_names.sort()
+    import_lines = [f"from .{module} import *" for module in module_names]
+    if import_lines:
+        content = "\n".join(import_lines) + "\n"
+    else:
+        content = ""
+    (src_dir / "__init__.py").write_text(content)
+
+
+def write_package_module(
+    ext_dir: Path,
+    class_name: str,
+    package_name: str,
+    user_code: str,
+) -> tuple[Path, Path]:
+    """
+    Ensure ``ext_dir / package_name`` exists as a uv src-layout package, write
+    (or overwrite) the module file for class_name, and regenerate __init__.py.
+
+    The user code is dedented and stripped before its top-level function
+    names are extracted, so a uniformly indented cell body is handled the
+    same way as an unindented one. An existing pyproject.toml is left alone.
+
+    Returns (module_file, init_file).
+
+    Raises:
+        ValueError: If the code contains no top-level function.
+    """
+    snake_class = to_snake(class_name)
+    # Extract names from the same text that is rendered into the file.
+    code = textwrap.dedent(user_code).strip()
+    fn_names = extract_function_names(code)
+    if not fn_names:
+        raise ValueError(
+            f"No top-level functions found in cell body for {class_name!r}."
+        )
+
+    src_dir = _package_src_dir(ext_dir, package_name)
+    src_dir.mkdir(parents=True, exist_ok=True)
+
+    # pyproject.toml — only create if missing
+    pyproject = ext_dir / package_name / "pyproject.toml"
+    if not pyproject.exists():
+        pyproject.write_text(
+            _render(_STATIC / "extension_package" / "pyproject.toml", package_name=package_name)
+        )
+
+    module_file = src_dir / f"{snake_class}.py"
+    module_file.write_text(
+        _render(
+            _STATIC / "extension_package" / "module.py",
+            user_code=code,
+            class_name=class_name,
+            type_id=snake_class,
+            function_names=",\n ".join(fn_names),
+        )
+    )
+
+    _rebuild_package_init(src_dir, package_name)
+    return module_file, src_dir / "__init__.py"
+
+
+# ── new project scaffold ──────────────────────────────────────────────────────
+
+def write_project(projects_dir: Path, project_name: str) -> Path:
+    """
+    Create a new project directory and fill it from the project templates.
+
+    The directory is named after the snake_case form of ``project_name`` and
+    receives empty ``decider_extensions`` and ``configs`` sub-directories plus
+    one rendered file per template file.
+
+    Returns the created project directory.
+
+    Raises:
+        FileExistsError: If the project directory already exists.
+    """
+    snake = to_snake(project_name)
+    project_dir = projects_dir / snake
+    if project_dir.exists():
+        raise FileExistsError(f"Project directory already exists: {project_dir}")
+
+    project_dir.mkdir(parents=True)
+    for subdir in ("decider_extensions", "configs"):
+        (project_dir / subdir).mkdir()
+
+    template_vars = {
+        "project_title": project_name.replace("_", " ").title(),
+        "project_dir": f"projects/{snake}",
+    }
+
+    for tmpl in (_STATIC / "project").iterdir():
+        # Only plain files are templates; skip anything else.
+        if not tmpl.is_file():
+            continue
+        rendered = _render(tmpl, **template_vars)
+        (project_dir / tmpl.name).write_text(rendered)
+
+    return project_dir
diff --git a/decider/templates/static/extension_module.py b/decider/templates/static/extension_module.py
new file mode 100644
index 0000000..a40a968
--- /dev/null
+++ b/decider/templates/static/extension_module.py
@@ -0,0 +1,13 @@
+# Template for a generated inline extension module; the placeholders are
+# filled in by decider.templates.scaffold.
+import polars as pl
+from pydantic import BaseModel
+from decider.modules.functional import generate_from_functions
+from decider.modules import register_graph_module
+
+# User-supplied function definitions are substituted here.
+$user_code
+
+# generate_from_functions receives the type id followed by the functions
+# defined above; its result is bound to the class name.
+$class_name = generate_from_functions(
+    "$type_id",
+    $function_names,
+)
+
+register_graph_module($class_name)
diff --git a/decider/templates/static/extension_package/module.py b/decider/templates/static/extension_package/module.py
new file mode 100644
index 0000000..a40a968
--- /dev/null
+++ b/decider/templates/static/extension_package/module.py
@@ -0,0 +1,13 @@
+# Template for a generated extension-package module; the placeholders are
+# filled in by decider.templates.scaffold.
+import polars as pl
+from pydantic import BaseModel
+from decider.modules.functional import generate_from_functions
+from decider.modules import register_graph_module
+
+# User-supplied function definitions are substituted here.
+$user_code
+
+# generate_from_functions receives the type id followed by the functions
+# defined above; its result is bound to the class name.
+$class_name = generate_from_functions(
+    "$type_id",
+    $function_names,
+)
+
+register_graph_module($class_name)
diff --git a/decider/templates/static/extension_package/pyproject.toml b/decider/templates/static/extension_package/pyproject.toml
new file mode 100644
index 0000000..c62778f
--- /dev/null
+++ b/decider/templates/static/extension_package/pyproject.toml
@@ -0,0 +1,19 @@
+[project]
+name = "$package_name"
+version = "0.1.0"
+description = "Decider extension package"
+requires-python = ">=3.10"
+dependencies = [
+ "polars",
+ "pydantic",
+]
+
+[project.optional-dependencies]
+dev = ["decider"]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/$package_name"]
diff --git a/decider/templates/static/project/generate.py b/decider/templates/static/project/generate.py
new file mode 100644
index 0000000..afedbb9
--- /dev/null
+++ b/decider/templates/static/project/generate.py
@@ -0,0 +1,65 @@
+"""
+$project_title — generates config artifacts and verifies the round-trip.
+
+Run from the project root:
+ python $project_dir/generate.py
+"""
+
+import sys
+import os
+import asyncio
+import json
+import polars as pl
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+PROJECT_DIR = os.path.dirname(__file__)
+sys.path.insert(0, PROJECT_ROOT)
+
+EXTENSIONS_DIR = os.path.join(PROJECT_DIR, "decider_extensions")
+
+from decider.initialization import initialize_decider
+from decider.config.file import JsonFileConfigManager
+from decider.modules import GraphModule
+
+CONFIGS_DIR = os.path.join(PROJECT_DIR, "configs")
+ROOT_KEY = "main"
+
+BATCH = pl.DataFrame({
+ # TODO: replace with a sample batch for your module
+ "id": ["row_1", "row_2"],
+})
+
+
+async def main():
+    """Load the project's extensions and print the environment for serving.
+
+    The module-building and saving steps are left as commented TODOs for the
+    project author to fill in.
+    """
+    print("=" * 60)
+    print("$project_title — config generation & serve round-trip")
+    print("=" * 60)
+
+    initialize_decider(extension_path=EXTENSIONS_DIR)
+    print("[1] Extensions loaded")
+
+    # TODO: import and build your module here
+    # from my_extension import MyModule
+    # module = MyModule(name="main")
+
+    config_manager = JsonFileConfigManager(basepath=CONFIGS_DIR)
+    # versioned = await module.asave(ROOT_KEY, config_manager)
+    # await config_manager.save_version(overwrite=True)
+    # print(f"[2] Saved version {versioned.version}")
+
+    abs_configs = os.path.abspath(CONFIGS_DIR)
+    # The Starlette module only exposes app factories (its module-level `app`
+    # is None until get_app() runs), so uvicorn must be started with --factory.
+    print(f"""
+{"=" * 60}
+SERVING SETUP
+{"=" * 60}
+ export Decider_config__type=file:json
+ export Decider_config__basepath={abs_configs}
+ export Decider_api__root_module={ROOT_KEY}
+ export Decider_ext__extension_path={os.path.abspath(EXTENSIONS_DIR)}
+
+ uvicorn decider.serving.servers.starlette:get_app --factory --host 0.0.0.0 --port 8080
+""")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/decider/templates/static/project/speedtest.py b/decider/templates/static/project/speedtest.py
new file mode 100644
index 0000000..2523fa6
--- /dev/null
+++ b/decider/templates/static/project/speedtest.py
@@ -0,0 +1,66 @@
+"""
+Speed test for $project_title.
+
+Usage (server must already be running):
+ python $project_dir/speedtest.py [--url URL] [--n N] [--concurrency C]
+"""
+
+import argparse
+import asyncio
+import json
+import statistics
+import time
+
+import httpx
+
+# TODO: replace with your module's input columns
+PAYLOAD = json.dumps({"id": ["row_1"]}).encode()
+HEADERS = {"Content-Type": "application/json"}
+
+
+async def run(url: str, n: int, concurrency: int):
+    """Send ``n`` POST requests to ``url`` and print latency statistics.
+
+    At most ``concurrency`` requests are in flight at once. A request counts
+    as an error when it raises or returns a status other than 200; its
+    latency is recorded either way. One untimed warm-up request is sent first.
+    """
+    timings = []
+    errors = 0
+    gate = asyncio.Semaphore(concurrency)
+
+    async def one(client: httpx.AsyncClient):
+        nonlocal errors
+        async with gate:
+            started = time.perf_counter()
+            try:
+                response = await client.post(url, content=PAYLOAD, headers=HEADERS)
+            except Exception:
+                errors += 1
+            else:
+                if response.status_code != 200:
+                    errors += 1
+            elapsed_ms = (time.perf_counter() - started) * 1000
+            timings.append(elapsed_ms)
+
+    print(f"Sending {n} requests concurrency={concurrency} → {url}")
+    async with httpx.AsyncClient(timeout=10.0) as client:
+        await client.post(url, content=PAYLOAD, headers=HEADERS)  # warm-up
+        wall_start = time.perf_counter()
+        pending = [one(client) for _ in range(n)]
+        await asyncio.gather(*pending)
+        wall_ms = (time.perf_counter() - wall_start) * 1000
+
+    ok = n - errors
+    p = sorted(timings)
+
+    def pct(q):
+        # Index of the q-th percentile, clamped to the last sample.
+        position = int(len(p) * q / 100)
+        return p[min(position, len(p) - 1)]
+
+    print(f"\n{'─'*40}")
+    print(f" Requests {n} ({ok} OK, {errors} errors)")
+    print(f" Wall time {wall_ms:.0f} ms")
+    print(f" Throughput {ok / (wall_ms / 1000):.0f} req/s")
+    print(f"{'─'*40}")
+    print(f" Latency (ms) min={min(timings):.1f} p50={pct(50):.1f} p90={pct(90):.1f} p99={pct(99):.1f} max={max(timings):.1f}")
+    print(f"{'─'*40}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # The bundled Sanic and Starlette servers both serve inference at
+    # POST /invocations.
+    parser.add_argument("--url", default="http://localhost:8080/invocations")
+    parser.add_argument("--n", type=int, default=1000)
+    parser.add_argument("--concurrency", type=int, default=50)
+    args = parser.parse_args()
+    asyncio.run(run(args.url, args.n, args.concurrency))
diff --git a/decider/types.py b/decider/types.py
new file mode 100644
index 0000000..629f427
--- /dev/null
+++ b/decider/types.py
@@ -0,0 +1,6 @@
+import typing as t
+import polars as pl
+
+
+TInputType = t.Dict[str, pl.LazyFrame]
+TOutputType = pl.LazyFrame
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 6b5b7cd..8608a3f 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,43 +1,56 @@
-# docker build --build-arg="PIP_CONF=$(cat $HOME/.config/pip/pip.conf)" -t capitec/engine:latest -f docker/Dockerfile .
-ARG BASE_IMG="python:3.12.0-slim-bookworm"
+# Build:
+# docker build \
+# --build-arg="PIP_CONF=$(cat $HOME/.config/pip/pip.conf)" \
+# -t decider:latest -f docker/Dockerfile .
+#
+# Run:
+# docker run -p 8080:8080 \
+# -v $(pwd)/model:/opt/ml/model:ro \
+# decider:latest
+#
+# Environment variables (all optional):
+# DECIDER_SERVE__HOST bind address (default: 0.0.0.0)
+# DECIDER_SERVE__PORT bind port (default: 8080)
+# DECIDER_SERVE__WORKERS worker processes (default: nproc*2+1)
+# DECIDER_API__ROOT_PATH project root (default: /opt/ml/model/code)
+# DECIDER_EXT__EXTENSION_PATH extensions dir (default: /opt/ml/model/extensions)
+
+ARG BASE_IMG="python:3.14-rc-slim-bookworm"
FROM ${BASE_IMG}
ARG PIP_CONF=
-# MAINTAINER Amazon AI
RUN set -aex && \
- apt-get -y update && apt-get install -y --no-install-recommends \
- wget \
- nginx \
- ca-certificates \
+ apt-get -y update && \
+ apt-get install -y --no-install-recommends \
+ ca-certificates \
+ procps \
&& rm -rf /var/lib/apt/lists/*
-# Install requirements first to reduce build times
-COPY requirements /tmp/requirements
-# Dont do set aex here to avoid precious creds being leaked to the console
-RUN mkdir /opt/program && \
- echo "$PIP_CONF" >> /etc/pip.conf && \
- pip --no-cache-dir install -r /tmp/requirements/all-prod.txt && \
- pip --no-cache-dir install "setuptools>=40.8.0" && \
- rm /etc/pip.conf && \
- rm -rf /tmp/requirements
+# Drop the pre-built wheel into the image and install with the sanic extra.
+# Build the wheel first: python -m build --wheel -o dist/
+COPY dist /tmp/decider/
-COPY dist /tmp/spockflow/
-RUN set -aex && \
- ls /tmp/spockflow && \
- echo "$PIP_CONF" >> /etc/pip.conf && \
- pip --no-cache-dir install --no-deps --no-build-isolation /tmp/spockflow/*.whl && \
+# Deliberately no `set -x` here: command tracing would print the expanded
+# $PIP_CONF (which may contain index credentials) into the build log.
+RUN set -e && \
+    echo "$PIP_CONF" >> /etc/pip.conf && \
+    WHEEL=$(ls /tmp/decider/*.whl | head -1) && \
+    pip --no-cache-dir install "${WHEEL}[serve-sanic]" && \
 rm /etc/pip.conf && \
- rm -rf /tmp/spockflow
+    rm -rf /tmp/decider
-# Set some environment variables. PYTHONUNBUFFERED keeps Python from buffering our standard
-# output stream, which means that logs can be delivered to the user quickly. PYTHONDONTWRITEBYTECODE
-# keeps Python from writing the .pyc files which are unnecessary in this case. We also update
-# PATH so that the train and serve programs are found when the container is invoked.
ENV PYTHONUNBUFFERED=TRUE
ENV PYTHONDONTWRITEBYTECODE=TRUE
-ENV PATH="/opt/program:${PATH}"
-ENV HAMILTON_TELEMETRY_ENABLED=FALSE
-# Set up the program in the image
-WORKDIR /opt/program
-ENTRYPOINT ["python3", "-m", "spockflow.inference.server.run"]
-CMD ["serve"]
+
+# /opt/ml/model — mounted at runtime; mirrors SageMaker & DSP path conventions.
+# model/code/ project root (generate.py, configs/, …)
+# model/extensions/ decider extensions discovered at startup
+ENV DECIDER_API__ROOT_PATH="/opt/ml/model/code"
+ENV DECIDER_EXT__EXTENSION_PATH="/opt/ml/model/extensions"
+
+# Expose the default port; override with DECIDER_SERVE__PORT + -p at runtime.
+EXPOSE 8080
+
+WORKDIR /opt/ml
+
+# `decider serve --engine sanic` picks up host/port/workers from
+# DECIDER_SERVE__* env vars; workers falls back to nproc*2+1 when unset.
+ENTRYPOINT ["decider", "serve", "--engine", "sanic"]
diff --git a/docker/ui.Dockerfile b/docker/ui.Dockerfile
deleted file mode 100644
index 4529de5..0000000
--- a/docker/ui.Dockerfile
+++ /dev/null
@@ -1,35 +0,0 @@
-FROM node:18-alpine AS builder
-
-WORKDIR /app
-
-COPY package.json package-lock.json ./
-RUN npm ci
-
-COPY src/ ./src/
-COPY public/ ./public/
-COPY tsconfig*.json ./*.js ./*.json ./
-
-RUN npm run build
-# This is useful for testing
-# RUN npm i serve
-# EXPOSE 3000
-# CMD ["npx", "serve", "-s", "build"]
-
-FROM nginx:stable-alpine
-
-COPY --from=builder /app/build /usr/share/nginx/html
-
-RUN addgroup -g 1000 -S appgroup && \
- adduser -u 1000 -S appuser -G appgroup && \
- mkdir -p /var/cache/nginx/client_temp && \
- mkdir -p /tmp/nginx && \
- chown -R 1000:1000 /var/cache/nginx && \
- chown -R 1000:1000 /usr/share/nginx/html && \
- chown -R 1000:1000 /tmp/nginx
-
-COPY nginx.conf /etc/nginx/nginx.conf
-
-USER 1000
-
-EXPOSE 8000
-CMD ["nginx", "-g", "daemon off;"]
\ No newline at end of file
diff --git a/docs/.gitignore b/docs/.gitignore
deleted file mode 100644
index 1888793..0000000
--- a/docs/.gitignore
+++ /dev/null
@@ -1,5 +0,0 @@
-_build
-_autosummary
-confluence.json
-config
-source_dir
\ No newline at end of file
diff --git a/docs/Makefile b/docs/Makefile
index d4bb2cb..fe8e88c 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -1,20 +1,12 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line, and also
-# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
-# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/_static/concepts/inference.drawio.svg b/docs/_static/concepts/inference.drawio.svg
deleted file mode 100644
index c83e20e..0000000
--- a/docs/_static/concepts/inference.drawio.svg
+++ /dev/null
@@ -1,245 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/_static/concepts/inference_light.drawio.svg b/docs/_static/concepts/inference_light.drawio.svg
deleted file mode 100644
index 8cc431c..0000000
--- a/docs/_static/concepts/inference_light.drawio.svg
+++ /dev/null
@@ -1,241 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/_static/getting-started/example_pipeline.drawio.svg b/docs/_static/getting-started/example_pipeline.drawio.svg
deleted file mode 100644
index 7d2771a..0000000
--- a/docs/_static/getting-started/example_pipeline.drawio.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
\ No newline at end of file
diff --git a/docs/_static/getting-started/example_pipeline_light.drawio.svg b/docs/_static/getting-started/example_pipeline_light.drawio.svg
deleted file mode 100644
index 84feb97..0000000
--- a/docs/_static/getting-started/example_pipeline_light.drawio.svg
+++ /dev/null
@@ -1,786 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/api.rst b/docs/api.rst
index bc91609..e1c6cdb 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -1,6 +1,8 @@
+API Reference
+=============
+
.. autosummary::
:toctree: _autosummary
- :template: custom-module-template.rst
:recursive:
- spockflow
\ No newline at end of file
+ decider
diff --git a/docs/concepts/config.md b/docs/concepts/config.md
new file mode 100644
index 0000000..5b2dcbf
--- /dev/null
+++ b/docs/concepts/config.md
@@ -0,0 +1,51 @@
+# Config & Versioning
+
+Decider modules are serialised to and from JSON configs, enabling versioned, auditable deployments.
+
+## Loading a module from config
+
+Every module exposes `.load(config_dict)` which returns a bound, runnable instance:
+
+```python
+config = {
+ "type": "Scorer",
+ "scorer_config": {"threshold": 4.5},
+}
+module = GraphModule.load(config)
+result = module.run(df)
+```
+
+## Saving a config
+
+```python
+import json
+
+config_dict = module.model_dump()
+print(json.dumps(config_dict, indent=2))
+```
+
+## Versioning with git tags
+
+Configs stored on disk are snapshots. Pair them with a git tag to pin the exact code version:
+
+```bash
+git tag v1.2.0
+# store configs/v1.2.0/loan_decision.json alongside the tag
+```
+
+## Config directory layout
+
+The default layout expected by `decider serve`:
+
+```
+project/
+├── configs/
+│ └── loan_decision.json
+├── extensions/
+│ └── my_scorer/
+│ └── __init__.py
+└── generate.py
+```
+
+## Environment-specific overrides
+
+Use Pydantic `BaseSettings` inside your config models to pull values from environment variables at load time, keeping secrets out of version-controlled JSON files.
diff --git a/docs/concepts/decision_tables.ipynb b/docs/concepts/decision_tables.ipynb
deleted file mode 100644
index 3a61652..0000000
--- a/docs/concepts/decision_tables.ipynb
+++ /dev/null
@@ -1,297 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## SpockFlow Decision Table Component\n",
- "\n",
- "The Decision Table component in SpockFlow allows users to define and execute decision tables, which are structured mappings of input conditions to output values. Decision tables are particularly useful for explicit and deterministic rule-based decision-making.\n",
- "\n",
- "### Usage\n",
- "\n",
- "To begin using the Decision Table component, import the necessary packages and instantiate a `DecisionTable` object:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "from spockflow.components import dtable\n",
- "\n",
- "input_v1 = \"input_v1\"\n",
- "input_v2 = \"input_v2\"\n",
- "\n",
- "example_dt = dtable.DecisionTable()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Adding Conditions and Outputs\n",
- "\n",
- "Conditions and corresponding outputs can be added to the `DecisionTable` using the `add` and `output` methods respectively. Each condition specifies an operation or comparison involving input variables, while outputs define the values or descriptions associated with matched conditions.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DecisionTable(operations=[DTMin(predicate=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], op='MIN'), DTMax(predicate=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], op='MAX'), DTMin(predicate=[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], op='MIN'), DTMax(predicate=[1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], op='MAX')], operation_inputs=['input_v1', 'input_v1', 'input_v2', 'input_v2'], outputs={'value': [1, 2, 0, None, -1, 20, 1, 2, 3, 4, 5], 'description': ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k']}, allow_multi_result=False, default_value= value description\n",
- "0 999 NA)"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "example_dt.add(dtable.DTMin, input_v1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).add(\n",
- " dtable.DTMax, input_v1, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]\n",
- ").add(dtable.DTMin, input_v2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]).add(\n",
- " dtable.DTMax, input_v2, [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]\n",
- ").set_default(\n",
- " pd.DataFrame({\"value\": [999], \"description\": [\"NA\"]})\n",
- ").output(\n",
- " \"value\", [1, 2, 0, None, -1, 20, 1, 2, 3, 4, 5]\n",
- ").output(\n",
- " \"description\", [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\", \"h\", \"i\", \"j\", \"k\"]\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Execution\n",
- "\n",
- "Execute the Decision Table on input data using the `execute` method:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
value
\n",
- "
description
\n",
- "
\n",
- " \n",
- " \n",
- "
\n",
- "
0
\n",
- "
20.0
\n",
- "
f
\n",
- "
\n",
- "
\n",
- "
1
\n",
- "
NaN
\n",
- "
d
\n",
- "
\n",
- "
\n",
- "
2
\n",
- "
999.0
\n",
- "
NA
\n",
- "
\n",
- "
\n",
- "
3
\n",
- "
1.0
\n",
- "
a
\n",
- "
\n",
- "
\n",
- "
4
\n",
- "
999.0
\n",
- "
NA
\n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " value description\n",
- "0 20.0 f\n",
- "1 NaN d\n",
- "2 999.0 NA\n",
- "3 1.0 a\n",
- "4 999.0 NA"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "input_data = pd.DataFrame({input_v1: [5, 3, 8, 0, 10], input_v2: [0, 0, 0, 0, 4]})\n",
- "result_df = example_dt.execute(inputs=input_data)\n",
- "result_df"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The `result_df` DataFrame will contain columns for each output defined, with rows corresponding to the matched conditions based on the input data.\n",
- "\n",
- "### Saving and Loading Configurations\n",
- "\n",
- "Similar to other components in SpockFlow, configurations of Decision Tables can be saved and loaded using configuration managers such as `YamlConfigManager`. This allows for easy deployment and reuse of decision rules.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/cp371651/.pyenv/versions/3.12.2/envs/spock/lib/python3.12/site-packages/pydantic/main.py:314: UserWarning: Pydantic serializer warnings:\n",
- " Expected `generator` but got `list` - serialized value may not be as expected\n",
- " Expected `generator` but got `list` - serialized value may not be as expected\n",
- " return self.__pydantic_serializer__.to_python(\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
description
\n",
- "
value
\n",
- "
\n",
- " \n",
- " \n",
- "
\n",
- "
0
\n",
- "
f
\n",
- "
20.0
\n",
- "
\n",
- "
\n",
- "
1
\n",
- "
d
\n",
- "
NaN
\n",
- "
\n",
- "
\n",
- "
2
\n",
- "
NA
\n",
- "
999.0
\n",
- "
\n",
- "
\n",
- "
3
\n",
- "
a
\n",
- "
1.0
\n",
- "
\n",
- "
\n",
- "
4
\n",
- "
NA
\n",
- "
999.0
\n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " description value\n",
- "0 f 20.0\n",
- "1 d NaN\n",
- "2 NA 999.0\n",
- "3 a 1.0\n",
- "4 NA 999.0"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from spockflow.inference.config.loader.yamlmanager import YamlConfigManager\n",
- "\n",
- "conf_manager = YamlConfigManager()\n",
- "conf_manager.save_to_config(\n",
- " model_name=\"demo_spock_model\",\n",
- " model_version=\"1.0.0\",\n",
- " namespace=\"decision_table_config\",\n",
- " config=example_dt.model_dump(mode=\"json\"),\n",
- ")\n",
- "\n",
- "# Load configuration\n",
- "config = conf_manager.get_config(\"demo_spock_model\", \"1.0.0\")\n",
- "dt_loaded = dtable.DecisionTable.from_config(\"decision_table_config\").load(config)\n",
- "\n",
- "# Execute with loaded configuration\n",
- "result_df_loaded = dt_loaded.execute(inputs=input_data)\n",
- "result_df_loaded"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "spock",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/concepts/decision_trees.ipynb b/docs/concepts/decision_trees.ipynb
deleted file mode 100644
index c12be6e..0000000
--- a/docs/concepts/decision_trees.ipynb
+++ /dev/null
@@ -1,1600 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Decision Trees in SpockFlow: A Practical Guide\n",
- "\n",
- "Welcome to the world of Decision Trees in SpockFlow! Decision Trees are powerful tools for automating decision-making processes based on defined conditions and actions. Let's dive into how you can leverage Decision Trees effectively within SpockFlow, including integration with Hamilton for creating executable pipelines.\n",
- "\n",
- "#### Importing Necessary Packages\n",
- "\n",
- "To begin, let's import the essential packages from SpockFlow:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "from spockflow.components.tree import Tree, Action\n",
- "from spockflow.core import initialize_spock_module\n",
- "import pandas as pd"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The `Tree` class allows us to construct decision trees, and `Action` serves as a wrapper for defining standardized outputs.\n",
- "\n",
- "#### Defining Action Types\n",
- "\n",
- "To maintain consistency in the outputs of our decision tree, we define an action type using `TypedDict` from the `typing_extensions` module:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "from typing_extensions import TypedDict\n",
- "\n",
- "\n",
- "class Reject(TypedDict):\n",
- " code: int\n",
- " description: str\n",
- "\n",
- "\n",
- "RejectAction = Action[Reject]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Here, `RejectAction` specifies the structure of actions that our decision tree can produce.\n",
- "\n",
- "#### Creating an Instance of the Tree\n",
- "\n",
- "Next, let's create an instance of the `Tree`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "tree = Tree()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This `tree` object will be our canvas for defining conditions and their corresponding actions.\n",
- "\n",
- "#### Adding Conditions to the Tree\n",
- "\n",
- "Conditions in the decision tree are defined using decorators (`@tree.condition`). These conditions evaluate input data and trigger specified actions when conditions are met. Here’s an example of adding a condition:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "@tree.condition(output=RejectAction(code=102, description=\"My first condition\"))\n",
- "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n",
- " return (d > 5) & (e > 5) & (f > 5)\n",
- "\n",
- "\n",
- "tree.visualize(get_value_name=lambda x: x[\"description\"][0])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In this example, `first_condition` triggers a rejection with code `102` and description \"My first condition\" when certain criteria (`d > 5`, `e > 5`, `f > 5`) are fulfilled.\n",
- "\n",
- "#### Nesting Conditions\n",
- "\n",
- "You can nest conditions under parent conditions using decorators. Here’s how to nest conditions under `condition_a`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "@tree.condition()\n",
- "def condition_a(a: pd.Series) -> pd.Series:\n",
- " return a > 5\n",
- "\n",
- "\n",
- "@condition_a.condition(\n",
- " output=RejectAction(code=100, description=\"a and b are out of range\")\n",
- ")\n",
- "def condition_b(b: pd.Series) -> pd.Series:\n",
- " return b > 5\n",
- "\n",
- "\n",
- "@condition_a.condition(\n",
- " output=RejectAction(code=101, description=\"a and c are out of range\")\n",
- ")\n",
- "def condition_c(c: pd.Series) -> pd.Series:\n",
- " return c > 5\n",
- "\n",
- "\n",
- "tree.visualize(get_value_name=lambda x: x[\"description\"][0])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note the usage of `@condition_a.condition` to define nested conditions (`condition_b` and `condition_c`) under `condition_a`.\n",
- "\n",
- "#### Setting Default Actions\n",
- "\n",
- "It’s essential to define a default action for cases where none of the conditions match. This is done using `set_default`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tree.set_default(output=RejectAction(code=-1, description=\"N/A\"))\n",
- "\n",
- "tree.visualize(get_value_name=lambda x: x[\"description\"][0])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "By default, all values are set to `pd.NA` when no specific conditions are met. The tree can now be executed with the following code:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
"
- ],
- "text/plain": [
- " code description\n",
- "0 101 Input condition\n",
- "1 101 a and c are out of range\n",
- "2 102 My first condition\n",
- "3 -1 N/A\n",
- "4 -1 N/A\n",
- "5 -1 N/A\n",
- "6 101 Input condition\n",
- "7 101 Input condition"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tree.execute(inputs=test_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Including Subtrees\n",
- "\n",
- "In SpockFlow, you can amplify the power of your Decision Trees by including subtrees. This feature allows you to nest complex decision logic within your main tree structure seamlessly. Here’s how you can include a subtree:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tree = Tree()\n",
- "tree.condition(output=Action(value=10), condition=\"A\")\n",
- "tree.condition(output=Action(value=20), condition=\"B\")\n",
- "\n",
- "subtree = Tree()\n",
- "subtree.condition(output=Action(value=100), condition=\"SubA\")\n",
- "subtree.condition(output=Action(value=200), condition=\"SubB\")\n",
- "\n",
- "cond_subtree = Tree()\n",
- "cond_subtree.condition(output=Action(value=1000), condition=\"SubD\")\n",
- "cond_subtree.condition(output=Action(value=2000), condition=\"SubE\")\n",
- "\n",
- "tree.include_subtree(subtree)\n",
- "tree.include_subtree(cond_subtree, condition=\"SubC\")\n",
- "tree.visualize(get_value_name=lambda x: str(x[\"value\"][0]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In this setup:\n",
- "- The main `tree` defines conditions 'A' and 'B' with outputs 10 and 20 respectively.\n",
- "- `subtree`, a separate instance of `Tree`, defines its own conditions 'SubA' and 'SubB' with outputs 100 and 200.\n",
- "- `tree.include(subtree)` integrates `subtree` into `tree`, enabling `subtree`'s conditions to function as part of `tree`'s decision-making process.\n",
- "\n",
- "This approach allows for hierarchical structuring of decision logic, making your Decision Trees in SpockFlow even more versatile and powerful."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Integration with Hamilton\n",
- "\n",
- "SpockFlow seamlessly integrates with Hamilton to execute decision trees as part of data processing pipelines. Hamilton allows converting sequences of Python functions into executable Directed Acyclic Graphs (DAGs), enabling clear data flow management. Here's how you can incorporate our decision tree into a Hamilton pipeline:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext hamilton.plugins.jupyter_magic"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%cell_to_module --display -m demo_tree\n",
- "\n",
- "# Import necessary packages\n",
- "from typing import TypedDict\n",
- "import pandas as pd\n",
- "\n",
- "from spockflow.components.tree import Tree, Action\n",
- "from spockflow.core import initialize_spock_module\n",
- "\n",
- "# Define Reject action type\n",
- "class Reject(TypedDict):\n",
- " code: int\n",
- " description: str\n",
- "\n",
- "RejectAction = Action[Reject]\n",
- "\n",
- "def input_condition(nums: pd.Series) -> pd.Series:\n",
- " \"\"\"input condition can be a value calculated just like any other value in a hamilton dag\"\"\"\n",
- " return nums % 2 == 0\n",
- "\n",
- "# Initialize Tree\n",
- "tree = Tree()\n",
- "\n",
- "# Define conditions and actions\n",
- "@tree.condition()\n",
- "def condition_a(a: pd.Series) -> pd.Series:\n",
- " return a > 5\n",
- "\n",
- "@condition_a.condition(output=RejectAction(code=100, description=\"a and b are out of range\"))\n",
- "def condition_b(b: pd.Series) -> pd.Series:\n",
- " return b > 5\n",
- "\n",
- "@condition_a.condition(output=RejectAction(code=101, description=\"a and c are out of range\"))\n",
- "def condition_c(c: pd.Series) -> pd.Series:\n",
- " return c > 5\n",
- "\n",
- "@tree.condition(output=RejectAction(code=102, description=\"My first condition\"))\n",
- "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n",
- " return (d > 5) & (e > 5) & (f > 5)\n",
- "\n",
- "# Set condition with reference to external hamilton element\n",
- "tree.condition(condition=\"input_condition\", output=RejectAction(code=101, description=\"Input condition\"))\n",
- "# Set default action\n",
- "tree.set_default(output=RejectAction(code=-1, description=\"N/A\"))\n",
- "\n",
- "# Initialize SpockFlow module for Hamilton\n",
- "initialize_spock_module(__name__, output_names=[\"tree\"])\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In the above example:\n",
- "- We define our decision tree and associated conditions.\n",
- "- The initialize_spock_module function hooks into Hamilton's DAG creation system, enabling the construction of SpockFlow-specific nodes and specifying default outputs.\n",
- "\n",
- "#### Executing the Hamilton DAG\n",
- "\n",
- "Once defined, you can execute the Hamilton DAG using the SpockFlow Driver (Driver):"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " code description\n",
- "0 -1 N/A\n",
- "1 101 a and c are out of range\n",
- "2 102 My first condition\n",
- "3 101 Input condition\n",
- "4 -1 N/A\n",
- "5 101 Input condition\n",
- "6 -1 N/A\n",
- "7 101 Input condition\n"
- ]
- }
- ],
- "source": [
- "from spockflow.core import Driver\n",
- "\n",
- "# Initialize Driver with the decision tree module\n",
- "dr = Driver({}, demo_tree)\n",
- "\n",
- "# Execute the DAG with test data\n",
- "df = dr.execute(inputs=test_data)\n",
- "print(df)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "spock",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/concepts/extensions.md b/docs/concepts/extensions.md
new file mode 100644
index 0000000..17d46b9
--- /dev/null
+++ b/docs/concepts/extensions.md
@@ -0,0 +1,67 @@
+# Extensions
+
+Extensions are Python packages or modules that live in an `extensions/` directory alongside your project. Decider discovers and imports them automatically at startup, making their module types available in the registry.
+
+## Inline extension (single file)
+
+The simplest form — a single `__init__.py` inside `extensions/<module_name>/`:
+
+```
+extensions/
+└── my_scorer/
+    └── __init__.py   # defines and registers MyScorer
+```
+
+## Package extension (shareable)
+
+For extensions you want to share or version independently, use a `uv` src-layout package:
+
+```
+extensions/
+└── my_pkg/
+    ├── pyproject.toml
+    └── src/
+        └── my_pkg/
+            ├── __init__.py
+            └── my_scorer.py
+```
+
+Decider discovers both layouts automatically.
+
+## Jupyter magic
+
+In a notebook, use the `%%module` magic to create or update an extension inline:
+
+```python
+%load_ext decider.magics
+```
+
+```
+%%module MyScorer
+
+def score(income: float, debt: float) -> float:
+ return income / (debt + 1)
+```
+
+This writes the module file to `extensions/my_scorer/__init__.py`, imports it, and registers it — all in one cell. Re-running the cell reloads and re-registers without duplicating the discriminator.
+
+### Package mode
+
+```
+%%module MyScorer --package my_pkg
+```
+
+Creates `extensions/my_pkg/src/my_pkg/my_scorer.py` as a proper installable package.
+
+## CLI scaffolding
+
+```bash
+# scaffold a standalone module
+decider template module MyScorer
+
+# scaffold inside a package
+decider template module MyScorer --package my_pkg
+
+# scaffold a full project
+decider template project my_project
+```
diff --git a/docs/concepts/index.md b/docs/concepts/index.md
index c1b1513..0210836 100644
--- a/docs/concepts/index.md
+++ b/docs/concepts/index.md
@@ -1,11 +1,12 @@
# Concepts
-Welcome to the Concepts section, your gateway to understanding the foundational principles and components of SpockFlow, a versatile framework designed for creating standalone micro-services that enrich data with actionable outputs. This section serves as a comprehensive resource where you can explore key concepts such as model deployment, inference handling, configuration management, and custom component development within SpockFlow. Whether you are a data scientist deploying models or a developer integrating custom functionalities, these concepts provide essential insights into leveraging SpockFlow's capabilities effectively. Dive into each concept to gain a deeper understanding of how SpockFlow empowers efficient and scalable data enrichment and deployment workflows.
+This section explains the core building blocks of Decider and how they compose into complete decision pipelines.
```{toctree}
+:maxdepth: 1
-Decision Trees
-Decision Tables
-Score Cards
-API Customization
-```
\ No newline at end of file
+Modules <modules>
+Pipelines <pipelines>
+Config & Versioning <config>
+Extensions <extensions>
+```
diff --git a/docs/concepts/inference.md b/docs/concepts/inference.md
deleted file mode 100644
index 8120182..0000000
--- a/docs/concepts/inference.md
+++ /dev/null
@@ -1,292 +0,0 @@
-# API Inference Handling
-
-In SpockFlow, method overloading plays a pivotal role in customizing the behavior of the inference handler, allowing users to tailor every aspect of API deployment to their specific needs. Several key methods and classes within the framework can be easily overridden to adjust functionality:
-
-- **Input Handling**: Customize data decoding and preprocessing using `input_fn`, `decoders`.
-- **Model Management**: Define how models are loaded and configured with `model_fn`, `dag_loader_fn`, `config_manager_loader_fn`, `model_loader_cls`, `model_cache_cls`, and `model_config_cls`.
-- **Prediction and Post-processing**: Modify prediction logic and result refinement via `predict_fn` and `post_process_fn`.
-- **Response Formatting**: Adjust how predictions are encoded into various output formats using `output_fn`, `encoders`.
-
-These methods not only provide flexibility but also enhance integration capabilities, optimize performance, and align seamlessly with existing systems. This section will delve into each function's role in tailoring the deployment of SpockFlow's API, offering insights into practical customization strategies and implementation guidelines.
-
-## Overview
-
-The `transform_fn` serves as the central entry point in SpockFlow's API deployment, where raw requests are processed and transformed into meaningful responses. This function orchestrates a sequence of operations, starting with data decoding and preprocessing, followed by model prediction and output formatting. Each step in this workflow can be tailored to specific requirements by defining corresponding functions in the `inference.py` module. The following diagram is a high-level overview of the `transform_fn`.
-
-
-```{figure} ../_static/concepts/inference.drawio.svg
-:scale: 100
-:align: center
-:class: only-dark
-```
-
-```{figure} ../_static/concepts/inference_light.drawio.svg
-:scale: 100
-:align: center
-:class: only-light
-```
-
-For instance, the input_fn method can be customized within `inference.py` to handle data decoding based on content type. Here’s an example:
-```python
-# inference.py
-
-def input_fn(input_data, content_type):
- # Custom decoding logic
- decoded_data = custom_decoder(input_data, content_type)
- return decoded_data
-```
-Additionally, methods defined with `self` as the first argument in `inference.py` allow the instance of the `ServingHandler` to be injected dynamically at runtime. This enables seamless integration of custom functionalities and ensures flexibility in adapting the inference pipeline to diverse application scenarios.
-
-The following sections will explore each method's role within `transform_fn`, illustrating how they can be leveraged to tailor the API deployment process in SpockFlow.
-
-## Methods
-
-This section describes all the configurable stages that can be overloaded in the inference.py and how they affect the predicted result.
-### Input Function
-The `input_fn` method in SpockFlow's ServingHandler class handles the initial step of data ingestion and decoding during API requests. It accepts two parameters:
-- **input_data**: The raw data received from the API request, typically in bytes.
-- **content_type**: Specifies the format or encoding of input_data, guiding the decoding process.
-The primary role of `input_fn` is to decode and pre-process incoming raw data into a structured format that downstream processes can handle effectively. This method plays a crucial role in ensuring compatibility between various data sources and the internal processing pipeline of SpockFlow.
-
-**Example:**
-
-```python
-# inference.py
-import json
-
-def input_fn(input_data, content_type):
- if content_type != 'application/json':
- raise ValueError("API only supports json data")
- return json.loads(input_data)
-```
-
-The default input_fn function in SpockFlow utilizes a decoders dictionary to decode incoming data based on the content-type header of API requests. By default, it supports decoding both CSV and JSON formats, converting them into Python objects. For instance, CSV data is parsed into a Pandas DataFrame, while JSON data is directly parsed into a dictionary. Users have the flexibility to extend or override these default decoders in the inference.py module. For example, adding support for Parquet format can be achieved by defining a custom decoder function and updating the decoders dictionary accordingly:
-
-```python
-# inference.py
-import pandas as pd
-from io import BytesIO
-from spockflow.inference.io.decoders import default_decoders
-
-decoders = {**default_decoders}
-
-def decode_parquet(data: bytes):
- return pd.read_parquet(BytesIO(data)).to_dict(orient='records')
-
-decoders["application/parquet"] = decode_parquet
-```
-
-### Preprocessing Function
-The pre_process_fn function in SpockFlow facilitates the transformation of JSON data into a format suitable for downstream processing, typically a Pandas DataFrame. Its primary role is to prepare input data for enrichment, scoring, or other operations supported by SpockFlow. A typical implementation is as follows:
-```python
-# inference.py
-import pandas as pd
-
-def pre_process_fn(input_data: dict) -> pd.DataFrame:
- return pd.json_normalize(input_data)
-```
-
-### Output Function
-
-The `output_fn` function in SpockFlow's API deployment handles the final step of encoding prediction results into various output formats based on the `accept` header of API responses. It accepts two parameters:
-
-- **prediction**: The dictionary containing the processed prediction results.
-- **accept**: Specifies the desired format or encoding of the response, guiding the encoding process.
-
-The primary role of `output_fn` is to format prediction results into the specified output format, ensuring compatibility with client expectations. This function is crucial for delivering responses in formats such as JSON, CSV, or custom formats like Parquet.
-
-**Example Implementation:**
-
-```python
-import pandas as pd
-from starlette.responses import JSONResponse
-
-def output_fn(prediction: dict, accept: str) -> "JSONResponse":
- """
- Encodes prediction results into JSON format based on the accept header.
-
- Parameters:
- - prediction (dict): Processed prediction results to be encoded.
- - accept (str): Desired format or encoding specified by the accept header.
-
- Returns:
- - JSONResponse: Response object containing encoded prediction results.
- """
- if accept not in ["*/*", "application/json"]:
- raise ValueError("Invalid accept header")
-
- return JSONResponse(prediction)
-```
-
-
-#### Post-process Function
-
-In conjunction with `output_fn`, the `post_process_fn` function is commonly used to manipulate prediction results before they are encoded. It adjusts the structure or content of the output data to meet specific client needs.
-
-**Example Implementation:**
-
-```python
-def post_process_fn(result: dict) -> dict:
- """
- Post-processes prediction result before encoding.
-
- Parameters:
- - result (dict): Raw prediction result from the model.
-
- Returns:
- - dict: Processed prediction result ready for encoding.
- """
- return {"response": result["tree"]}
-```
-
-The `post_process_fn` function in SpockFlow allows developers to refine prediction outputs, ensuring they meet application-specific requirements before encoding them into the desired format using `output_fn`.
-
-#### Encoders Example
-
-Developers can extend the functionality of SpockFlow's default `output_fn` by defining custom encoders for additional output formats. The `output_fn` function utilizes the encoders dictionary to determine the output format based on the accept header in API responses. Below is an example of adding support for Parquet format using a custom encoder:
-
-```python
-import pandas as pd
-from starlette.responses import Response
-from spockflow.inference.io.encoders import default_encoders
-
-def to_parquet(result: dict) -> Response:
- """
- Encodes prediction result into Parquet format.
-
- Parameters:
- - result (dict): Processed prediction result to be encoded.
-
- Returns:
- - Response: Response object containing Parquet-encoded prediction result.
- """
- res = BytesIO()
- pd.json_normalize(result).to_parquet(res)
- res.seek(0)
- return Response(res.read(), media_type="application/vnd.apache.parquet")
-
-# Initialize encoders with default encoders
-encoders = {**default_encoders}
-
-# Add support for Parquet format
-encoders["application/vnd.apache.parquet"] = to_parquet
-```
-
-## Configuration Management in SpockFlow
-
-SpockFlow provides flexible configuration management capabilities to facilitate the setup and management of models and their associated configurations. Developers can utilize built-in managers like the `YamlConfigManager` or create custom managers tailored to specific needs, such as reading configurations from DynamoDB or other data sources.
-
-### Using YamlConfigManager
-
-The `YamlConfigManager` provided by SpockFlow simplifies configuration management using YAML files. It allows for straightforward handling of model configurations stored locally. Here's an example of configuring `model_config_cls` to use `YamlConfigManager`:
-
-```python
-from spockflow.inference.config.loader.yamlmanager import YamlConfigManager
-model_config_cls = YamlConfigManager
-```
-
-### Custom Config Managers
-
-Developers have the flexibility to implement custom config managers to suit unique requirements. For instance, a custom config manager could integrate with DynamoDB to dynamically fetch configurations. Below is a simplified outline of how a custom manager might be structured:
-
-```python
-# Example of a custom config manager (simplified)
-
-class CustomConfigManager(ConfigManager):
- def get_latest_version(self, model_name: str) -> str:
- # Implement logic to fetch the latest version from DynamoDB or other sources
- pass
-
- def get_config(self, model_name: str, model_version: str) -> TNamespacedConfig:
- # Implement logic to fetch config from DynamoDB or other sources
- pass
-
- def save_to_config(self, model_name: str, model_version: str, namespace: str, config: TNamespacedConfig, key: str | None = None):
- # Implement logic to save config to DynamoDB or other sources
- pass
-```
-
-### YamlConfigManager Implementation
-
-For reference, here is a simplified implementation of `YamlConfigManager` provided by SpockFlow, demonstrating its capability to manage YAML-based configurations locally:
-
-```python
-from yaml import dump, load
-try:
- from yaml import CDumper as Dumper, CLoader as Loader
-except ImportError:
- from yaml import Dumper, Loader
-import os
-from .base import ConfigManager, TNamespacedConfig
-from pydantic_settings import BaseSettings, SettingsConfigDict
-
-class YamlConfigManager(ConfigManager, BaseSettings):
- model_config = SettingsConfigDict(case_sensitive=False, env_prefix='CONFIG_MANAGER_')
- config_path: str = os.path.join(".", "config")
-
- def model_path(self, model_name: str) -> str:
- return os.path.join(self.config_path, model_name)
-
- def get_latest_version(self, model_name: str) -> str:
- if model_name == "__default__":
- paths = os.listdir(self.config_path)
- if model_name not in paths:
- assert len(paths) == 1, "can only use default when there is one model or an explicit __default__"
- model_name = paths[0]
- return sorted(os.listdir(self.model_path(model_name)))[-1]
-
- def get_config(self, model_name: str, model_version: str) -> TNamespacedConfig:
- r = {}
- for f in glob(os.path.join(self.model_path(model_name), model_version, "*.yml")):
- ns = os.path.splitext(os.path.split(f)[1])[0]
- with open(f) as fp:
- r[ns] = load(fp, Loader=Loader)
- return r
-
- def save_to_config(self, model_name: str, model_version: str, namespace: str, config: TNamespacedConfig, key: str | None = None):
- save_path = os.path.join(self.model_path(model_name), model_version, namespace) + ".yml"
- if key is not None:
- if os.path.isfile(save_path):
- with open(save_path) as fp:
- curr_config = load(fp, Loader=Loader)
- curr_config[key] = config
- config = curr_config
- os.makedirs(os.path.split(save_path)[0], exist_ok=True)
- with open(save_path, "w") as fp:
- dump(config, fp, Dumper=Dumper)
-```
-
-### Conclusion
-
-SpockFlow's configuration management capabilities, exemplified by the `YamlConfigManager` and the possibility of creating custom managers, provide robust support for handling model configurations across different use cases. Whether utilizing YAML files locally or integrating with external data sources like DynamoDB, developers have the tools necessary to effectively manage and deploy models within SpockFlow.
-
-
-## Advanced Methods
-While the previously discussed methods cover the majority of use cases, SpockFlow also offers advanced customization options for developers seeking more specialized functionalities. These advanced methods provide additional flexibility and control over API deployment processes, enabling tailored solutions for specific requirements.
-### Predict Function
-The `predict_fn` function in SpockFlow is pivotal for executing model predictions using a specified `Driver` object. By default, it leverages `raw_execute` to provide flexibility in handling prediction outputs, returning results as a dictionary. However, developers can override this behavior, as shown in the example below, to use `execute` instead. This alternative approach directly returns a DataFrame, eliminating the need for a separate `post_process_fn` for data transformation.
-
-### Example Implementation
-
-```python
-from spockflow.inference.handler import WrappedInputData
-from spockflow.core import Driver
-
-def predict_fn(input_data: "WrappedInputData", model: "Driver") -> pd.DataFrame:
- """
- Executes model prediction using SpockFlow's Driver object with direct DataFrame output.
-
- Parameters:
- - input_data (WrappedInputData): Wrapped input data containing model inputs.
- - model (Driver): SpockFlow's Driver object representing the model to execute.
-
- Returns:
- - pd.DataFrame: DataFrame containing the prediction results.
- """
- return model.execute(
- inputs=input_data.data,
- overrides=input_data.input_overrides
- )
-```
-
-In this revised example, `predict_fn` overrides the default behavior by using `execute` instead of `raw_execute`. This method directly returns a DataFrame containing prediction results. By adopting this approach, developers streamline the prediction process within SpockFlow, ensuring efficient data processing without the need for additional transformation steps typically handled by a `post_process_fn`.
diff --git a/docs/concepts/modules.md b/docs/concepts/modules.md
new file mode 100644
index 0000000..8183c99
--- /dev/null
+++ b/docs/concepts/modules.md
@@ -0,0 +1,54 @@
+# Modules
+
+A **module** is the fundamental unit of computation in Decider. It takes a Polars `DataFrame` as input and returns a `DataFrame` with new or transformed columns.
+
+## Defining a module from functions
+
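+Use `generate_from_functions` to turn plain Python functions into a module. Each function becomes a computed column named after the function; its parameter names map to input columns or to the outputs of sibling functions (below, `risk_band` consumes the `score` column produced by `score`).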
+
+```python
+import polars as pl
+from decider import generate_from_functions
+
+def score(income: float, debt: float) -> float:
+ return income / (debt + 1)
+
+def risk_band(score: float) -> str:
+ if score > 5:
+ return "low"
+ elif score > 2:
+ return "medium"
+ return "high"
+
+Scorer = generate_from_functions(score, risk_band, name="Scorer")
+```
+
+## Injecting configuration
+
+Add a `config` parameter typed with a Pydantic model to inject versioned parameters:
+
+```python
+from pydantic import BaseModel
+
+class ScorerConfig(BaseModel):
+ threshold: float = 5.0
+
+def risk_band(score: float, config: ScorerConfig) -> str:
+ return "low" if score > config.threshold else "high"
+```
+
+## Module types
+
+| Type | Description |
+|---|---|
+| `ExpressionModule` | Parallel column computation from functions |
+| `SequentialModule` | Ordered chain of modules (output feeds next) |
+| `JoinModule` | Merges two module outputs side-by-side |
+| `GraphModule` | JSON-serialisable wrapper for any module |
+
+## Running a module
+
+```python
+df = pl.DataFrame({"income": [50000, 30000], "debt": [10000, 25000]})
+result = Scorer.load({}).run(df)
+```
diff --git a/docs/concepts/pipelines.md b/docs/concepts/pipelines.md
new file mode 100644
index 0000000..37756e5
--- /dev/null
+++ b/docs/concepts/pipelines.md
@@ -0,0 +1,48 @@
+# Pipelines
+
+Modules compose into pipelines using `|` (sequential) and `&` (join) operators.
+
+## Sequential pipelines (`|`)
+
+`A | B` passes A's output as input to B. Use this when each stage enriches the data for the next.
+
+```python
+from decider import generate_from_functions, SequentialModule
+
+Enricher = generate_from_functions(enrich_fn, name="Enricher")
+Scorer = generate_from_functions(score_fn, risk_fn, name="Scorer")
+
+Pipeline = Enricher | Scorer
+result = Pipeline.load({}).run(df)
+```
+
+## Join pipelines (`&`)
+
+`A & B` runs both modules on the same input and merges their outputs column-wise.
+
+```python
+CreditScore = generate_from_functions(credit_fn, name="CreditScore")
+AffordabilityScore = generate_from_functions(afford_fn, name="AffordabilityScore")
+
+Combined = CreditScore & AffordabilityScore
+```
+
+## Nesting
+
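+Pipelines can be nested to arbitrary depth. Python's `&` binds more tightly than `|`, so the expression below is evaluated as `((StageA | StageB) & StageC) | FinalStage`; add explicit parentheses whenever a different grouping is intended: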
+
+```python
+pipeline = (StageA | StageB) & StageC | FinalStage
+```
+
+## Named pipelines
+
+Wrap any composed pipeline as a `GraphModule` to give it a name and make it JSON-serialisable:
+
+```python
+from decider import SequentialModule, register_graph_module
+
+@register_graph_module
+class LoanDecision(SequentialModule):
+ steps = [Enricher, Scorer]
+```
diff --git a/docs/concepts/scorecard.ipynb b/docs/concepts/scorecard.ipynb
deleted file mode 100644
index 0be9ac1..0000000
--- a/docs/concepts/scorecard.ipynb
+++ /dev/null
@@ -1,493 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## SpockFlow Scorecard Component\n",
- "\n",
- "The SpockFlow Scorecard component facilitates the creation of scoring rules for both numerical and categorical data, allowing users to define criteria that assign scores and descriptive labels based on specified conditions. This component is particularly useful for evaluating data against predefined thresholds or patterns, providing insights through structured outputs.\n",
- "\n",
- "### Usage\n",
- "\n",
- "To begin using the Scorecard component, import the necessary packages and instantiate the `ScoreCard` object:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "from spockflow.components import scorecard\n",
- "\n",
- "var_1 = \"var_1\"\n",
- "var_2 = \"var_2\"\n",
- "\n",
- "sc = scorecard.ScoreCard(\n",
- " bin_prefix=\"SCORE_BIN_\",\n",
- " score_prefix=\"SCORE_VALUE_\",\n",
- " description_prefix=\"SCORE_DESC_\",\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "### Adding Criteria\n",
- "\n",
- "Criteria can be added to the `ScoreCard` object using the `add_criteria` method. Each criterion defines how to evaluate a specific variable and assign scores based on conditions. There are two types of criteria: numerical and categorical.\n",
- "\n",
- "#### Numerical Criteria\n",
- "\n",
- "Numerical criteria evaluate numeric variables and can define score ranges, discrete values, and default behaviors:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "ScoreCardModel(bin_prefix='SCORE_BIN_', score_prefix='SCORE_VALUE_', description_prefix='SCORE_DESC_', variable_params=[ScoreCriteriaNumerical(variable='var_1', other_score=DefaultScorePattern(group_id=3, score=73.0, description='default'), discrete_scores=[NumericalDiscreteScorePattern(values=[nan], group_id=2, score=73.0, description='missing')], range_scores=[RangeScorePattern(range=MatchRange(start=0.0, end=1.0), group_id=0, score=10.0, description='First bound var_1'), RangeScorePattern(range=MatchRange(start=1.0, end=2.0), group_id=1, score=30.0, description='2nd bound var_1')], type='numerical', included_bounds=(,))], version='2.2.0', score_scaling_params=None)"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "sc.add_criteria(\n",
- " scorecard.ScoreCriteria(var_1, \"numerical\")\n",
- " .add_range_score(0, 1, 10, \"First bound var_1\")\n",
- " .add_range_score(1, 2, 30, \"2nd bound var_1\")\n",
- " .add_discrete_score([None], 73, \"missing\")\n",
- " .set_other_score(73, \"default\")\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "- **Range Scores**: Assign scores based on numeric ranges. For example, `add_range_score(0, 1, 10, \"First bound var_1\")` assigns a score of 10 to `SCORE_VALUE_var_1` when `0 <= var_1 < 1`.\n",
- "- **Discrete Scores**: Assign scores for specific values. `add_discrete_score([None], 73, \"missing\")` assigns a score of 73 to `SCORE_VALUE_var_1` when `var_1` is `None`.\n",
- "- **Default Score**: Set a default score and description for unmatched values. `set_other_score(73, \"default\")` assigns a score of 73 to `SCORE_VALUE_var_1` for all other cases."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Categorical Criteria\n",
- "\n",
- "Categorical criteria evaluate text variables and can define exact matches or patterns (regex):\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "ScoreCardModel(bin_prefix='SCORE_BIN_', score_prefix='SCORE_VALUE_', description_prefix='SCORE_DESC_', variable_params=[ScoreCriteriaNumerical(variable='var_1', other_score=DefaultScorePattern(group_id=3, score=73.0, description='default'), discrete_scores=[NumericalDiscreteScorePattern(values=[nan], group_id=2, score=73.0, description='missing')], range_scores=[RangeScorePattern(range=MatchRange(start=0.0, end=1.0), group_id=0, score=10.0, description='First bound var_1'), RangeScorePattern(range=MatchRange(start=1.0, end=2.0), group_id=1, score=30.0, description='2nd bound var_1')], type='numerical', included_bounds=(,)), ScoreCriteriaCategorical(variable='var_2', other_score=None, discrete_scores=[CategoricalDiscreteScorePattern(values=['a', 'b', 'c'], group_id=0, score=10.0, description='First pattern var_2'), CategoricalDiscreteScorePattern(values=['[b-z]'], group_id=1, score=20.0, description='Second pattern var_2')], type='categorical', default_behavior='regex')], version='2.2.0', score_scaling_params=None)"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "sc.add_criteria(\n",
- " scorecard.ScoreCriteria(var_2, \"categorical\", default_behavior=\"regex\")\n",
- " .add_discrete_score([\"a\", \"b\", \"c\"], 10, \"First pattern var_2\")\n",
- " .add_discrete_score([\"[b-z]\"], 20, \"Second pattern var_2\")\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "- **Exact Matches**: Assign scores based on exact text matches. `add_discrete_score(['a', 'b', 'c'], 10, \"First pattern var_2\")` assigns a score of 10 to `SCORE_VALUE_var_2` when `var_2` is 'a', 'b', or 'c'.\n",
- "- **Regex Matches**: Evaluate variables using regex patterns. `add_discrete_score(['[b-z]'], 20, \"Second pattern var_2\")` assigns a score of 20 to `SCORE_VALUE_var_2` when `var_2` matches the regex pattern `[b-z]`.\n",
- "\n",
- "### Automatic Binning\n",
- "\n",
- "By default, the Scorecard component automatically determines bin categories (`SCORE_BIN_var_1`, `SCORE_BIN_var_2`) based on the order in which criteria are added. The first criterion added determines `SCORE_BIN_var_1`, the second criterion determines `SCORE_BIN_var_2`, and so forth. These bin categories categorize input values based on the criteria matched.\n",
- "\n",
- "#### Overriding Bins\n",
- "\n",
- "Bins can be overridden if needed. This allows users to customize bin categories or reorder them for specific requirements. The `override_idx` parameter in the `add_range_score`, `add_discrete_score`, and `set_default` methods allows specifying the index of the bin to override:\n",
- "\n",
- "```python\n",
- "sc.add_criteria(\n",
- " scorecard.ScoreCriteria(var_1, \"numerical\")\n",
- " .add_range_score(0, 1, 10, \"First bound var_1\", override_idx=5)\n",
- ")\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "### Execution and Results\n",
- "\n",
- "To execute the Scorecard on a dataset, use the `execute` method:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
"
- ],
- "text/plain": [
- " SCORE_BIN_var_1 SCORE_VALUE_var_1 SCORE_DESC_var_1 SCORE_BIN_var_2 \\\n",
- "0 0 10.0 First bound var_1 0 \n",
- "1 1 30.0 2nd bound var_1 1 \n",
- "2 3 73.0 default 1 \n",
- "3 2 73.0 missing 0 \n",
- "4 0 10.0 First bound var_1 0 \n",
- "5 1 30.0 2nd bound var_1 1 \n",
- "6 3 73.0 default 1 \n",
- "7 2 73.0 missing -1 \n",
- "\n",
- " SCORE_VALUE_var_2 SCORE_DESC_var_2 SCORE_VALUE_SUM \n",
- "0 10.0 First pattern var_2 20.0 \n",
- "1 20.0 Second pattern var_2 50.0 \n",
- "2 20.0 Second pattern var_2 93.0 \n",
- "3 10.0 First pattern var_2 83.0 \n",
- "4 10.0 First pattern var_2 20.0 \n",
- "5 20.0 Second pattern var_2 50.0 \n",
- "6 20.0 Second pattern var_2 93.0 \n",
- "7 -1.0 None 72.0 "
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "result_df = sc.execute(inputs=test_data)\n",
- "result_df"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The resulting DataFrame (`result_df`) will contain columns for each score, bin, and description based on the evaluated criteria.\n",
- "\n",
- "### Saving and Loading Configurations\n",
- "\n",
- "To save the configuration of the Scorecard for future use or deployment, use a configuration manager such as `YamlConfigManager`:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "from spockflow.inference.config.loader.yamlmanager import YamlConfigManager\n",
- "\n",
- "conf_manager = YamlConfigManager()\n",
- "conf_manager.save_to_config(\n",
- " model_name=\"demo_spock_model\",\n",
- " model_version=\"1.0.0\",\n",
- " namespace=\"scorecard_config\",\n",
- " config=sc.model_dump(mode=\"json\"),\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Loading and Using Configurations\n",
- "\n",
- "Load a saved Scorecard configuration from a YAML file and instantiate the `ScoreCard` object using `from_config`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "46739ea3aa8241d6a8de38360150ed4d",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "VBox(children=(GridspecLayout(children=(Text(value='SCORE_BIN_', description='Bin Prefix:', layout=Layout(grid…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "config = conf_manager.get_config(\"demo_spock_model\", \"1.0.0\")[\"scorecard_config\"]\n",
- "sc_loaded = scorecard.ScoreCard.from_config(\"\").load(config)\n",
- "\n",
- "# Retrieve view model and display widget\n",
- "vm = sc_loaded.get_view_model()\n",
- "widget = vm.get_widget()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "spock",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/concepts/variable_nodes.md b/docs/concepts/variable_nodes.md
deleted file mode 100644
index d1ecd5e..0000000
--- a/docs/concepts/variable_nodes.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Variable Nodes
-Coming Soon. This section will discuss:
-- How Variable nodes fit into the Spockflow framework
-- Common functionality across variable nodes:
- - How to store and load from config
- - Cloning and aliasing
-- How to create your own custom node type
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index d42c9df..4dabef9 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,62 +1,47 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# For the full list of built-in configuration values, see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
import os
import sys
sys.path.insert(0, os.path.abspath(".."))
-
-# -- Project information -----------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
-
-project = "SpockFlow"
-copyright = "2024, Sholto Armstrong"
+project = "Decider"
+copyright = "2024-2026, Capitec"
author = "Sholto Armstrong"
-import spockflow
-version = str(spockflow.__version__)
+import decider
+version = decider.__version__
release = version
-# -- General configuration ---------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
-
extensions = [
"myst_nb",
+ "sphinx_design",
"sphinx.ext.napoleon",
- "sphinx.ext.autodoc", # Core library for html generation from docstrings
- "sphinx.ext.autosummary", # Create neat summary tables
- "sphinxcontrib.confluencebuilder",
- "sphinx_sitemap", # Welcome robots to the website
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx_sitemap",
]
-# AutoDoc Conf (Thanks https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion/blob/master/docs/conf.py)
-autosummary_generate = True # Turn on sphinx.ext.autosummary
-autoclass_content = "both" # Add __init__ doc (ie. params) to class summaries
-html_show_sourcelink = (
- False # Remove 'view source code' from top of page (for html, not python)
-)
-autodoc_inherit_docstrings = True # If no docstring, inherit from base class
-set_type_checking_flag = True # Enable 'expensive' imports for sphinx_autodoc_typehints
-nbsphinx_allow_errors = True # Continue through Jupyter errors
-# autodoc_typehints = "description" # Sphinx-native method. Not as good as sphinx_autodoc_typehints
-add_module_names = False # Remove namespaces from class/method signatures
-
+autosummary_generate = True
+autoclass_content = "both"
+html_show_sourcelink = False
+autodoc_inherit_docstrings = True
+add_module_names = False
+autodoc_mock_imports = ["hamilton", "streamlit"]
+nb_execution_mode = "off" # don't execute notebooks during build
+
+myst_enable_extensions = [
+ "colon_fence",
+ "deflist",
+ "dollarmath",
+]
templates_path = ["_templates"]
-exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
-
-# -- Options for HTML output -------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
-
-# html_theme = 'alabaster'
html_static_path = ["_static"]
-
html_theme = "furo"
-html_title = "SpockFlow"
+html_title = "Decider"
html_theme_options = {
+ "source_repository": "https://github.com/capitecbankltd/dsp_north-polrs",
"source_branch": "main",
"source_directory": "docs/",
"light_css_variables": {
@@ -69,23 +54,9 @@
},
}
+# anchor links inside {include}'d CONTRIBUTING.md don't resolve cross-doc
+suppress_warnings = ["myst.xref_missing"]
-# for the sitemap extension ---
-# check if the current commit is tagged as a release (vX.Y.Z) and set the version
language = "en"
-html_baseurl = "https://capitec.github.io/ml-decision-engine"
+html_baseurl = "https://capitecbankltd.github.io/dsp_north-polrs"
html_extra_path = ["robots.txt"]
-
-confluence_config_path = os.path.split(__file__)[0] + "/confluence.json"
-if os.path.isfile(confluence_config_path):
- import json
-
- with open(confluence_config_path) as fp:
- conf_config = json.load(fp)
- confluence_publish = conf_config["confluence_publish"]
- confluence_parent_page = conf_config["confluence_parent_page"]
- confluence_space_key = conf_config["confluence_space_key"]
- confluence_ask_password = conf_config["confluence_ask_password"]
- confluence_server_url = conf_config["confluence_server_url"]
- confluence_server_user = conf_config["confluence_server_user"]
- confluence_server_cookies = conf_config["confluence_server_cookies"]
diff --git a/docs/contributing/index.md b/docs/contributing/index.md
new file mode 100644
index 0000000..18b7c36
--- /dev/null
+++ b/docs/contributing/index.md
@@ -0,0 +1,4 @@
+# Contributing
+
+```{include} ../../CONTRIBUTING.md
+```
diff --git a/docs/examples/01_getting_started.ipynb b/docs/examples/01_getting_started.ipynb
new file mode 100644
index 0000000..d40fbd0
--- /dev/null
+++ b/docs/examples/01_getting_started.ipynb
@@ -0,0 +1,204 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 5,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "m1",
+ "metadata": {},
+ "source": [
+ "# Getting Started with Decider\n",
+ "\n",
+ "Decider is a Python framework for building composable, config-driven decision pipelines on top of Polars DataFrames. You define plain Python functions that return `pl.Expr`, and Decider automatically wires their dependencies, compiles them into an optimised expression graph, and executes them in a single Polars pass. Pipelines can be saved to versioned JSON configs and hot-swapped at serving time without any code changes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q decider\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "\n",
+ "import polars as pl\n",
+ "from decider.modules.functional import generate_from_functions\n",
+ "\n",
+ "print('Decider imported successfully')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m2",
+ "metadata": {},
+ "source": [
+ "## Your First Module\n",
+ "\n",
+ "Define one or more plain functions that return `pl.Expr`. Each function name becomes an output column. Parameter names map to input columns (or to sibling function outputs \u2014 more on that shortly)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# --- define functions -------------------------------------------------------\n",
+ "# Each function: name -> output column, params -> input columns\n",
+ "\n",
+ "def score(amount: pl.Expr, rate: pl.Expr) -> pl.Expr:\n",
+ " \"\"\"Raw score: amount adjusted by rate.\"\"\"\n",
+ " return amount * rate\n",
+ "\n",
+ "def risk_flag(score: pl.Expr) -> pl.Expr:\n",
+ " \"\"\"Flag rows where score exceeds 80.\"\"\"\n",
+ " return score > pl.lit(80.0)\n",
+ "\n",
+ "# --- generate module class --------------------------------------------------\n",
+ "# Pass a lowercase type-id string, then the functions.\n",
+ "BasicScorer = generate_from_functions('basic_scorer', score, risk_flag)\n",
+ "\n",
+ "# --- instantiate (name= is required) ----------------------------------------\n",
+ "scorer = BasicScorer(name='my_scorer')\n",
+ "\n",
+ "print('Module type identifier:', scorer.type)\n",
+ "print('Module name:', scorer.name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# --- run on a DataFrame -----------------------------------------------------\n",
+ "df = pl.DataFrame({\n",
+ " 'customer_id': ['C1', 'C2', 'C3', 'C4'],\n",
+ " 'amount': [10.0, 50.0, 90.0, 120.0],\n",
+ " 'rate': [0.5, 1.2, 1.0, 0.8],\n",
+ "})\n",
+ "\n",
+ "result = scorer({'input': df})\n",
+ "print(result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m3",
+ "metadata": {},
+ "source": [
+ "Notice that `score` and `risk_flag` are added as new columns alongside the original data. The module returns a Polars `DataFrame` (or `LazyFrame` depending on the executor \u2014 call `.collect()` if needed).\n",
+ "\n",
+ "## Dependency Wiring\n",
+ "\n",
+ "If one function's parameter name matches another function's `__name__`, Decider automatically injects the upstream expression \u2014 no manual wiring required."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def score(amount: pl.Expr, rate: pl.Expr) -> pl.Expr:\n",
+ " return amount * rate\n",
+ "\n",
+ "def risk_flag(score: pl.Expr) -> pl.Expr:\n",
+ " return score > pl.lit(80.0)\n",
+ "\n",
+ "# --- third function depends on 'score' (sibling output) ---------------------\n",
+ "def score_normalised(score: pl.Expr) -> pl.Expr:\n",
+ " \"\"\"Normalise score to [0, 1] using a fixed max of 200.\"\"\"\n",
+ " return score / pl.lit(200.0)\n",
+ "\n",
+ "# Decider resolves the 'score' parameter in risk_flag and score_normalised\n",
+ "# to the output of the 'score' function \u2014 automatically.\n",
+ "ScorerV2 = generate_from_functions(\n",
+ " 'scorer_v2',\n",
+ " score,\n",
+ " risk_flag,\n",
+ " score_normalised,\n",
+ ")\n",
+ "scorer_v2 = ScorerV2(name='scorer_v2')\n",
+ "\n",
+ "result_v2 = scorer_v2({'input': df})\n",
+ "print(result_v2.select(['customer_id', 'score', 'risk_flag', 'score_normalised']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m4",
+ "metadata": {},
+ "source": [
+ "## Config Injection\n",
+ "\n",
+ "Add a parameter named `config` annotated with a Pydantic model. Decider promotes all config fields onto the module itself, so you set them at instantiation time and they are injected automatically at call time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pydantic import BaseModel\n",
+ "\n",
+ "class PremiumConfig(BaseModel):\n",
+ " base_rate: float = 1.0\n",
+ " multiplier: float = 2.0\n",
+ "\n",
+ "def score(amount: pl.Expr, rate: pl.Expr) -> pl.Expr:\n",
+ " return amount * rate\n",
+ "\n",
+ "def premium_score(score: pl.Expr, config: PremiumConfig) -> pl.Expr:\n",
+ " \"\"\"Scale score by a configurable multiplier on top of a base rate.\"\"\"\n",
+ " return score * pl.lit(config.base_rate) * pl.lit(config.multiplier)\n",
+ "\n",
+ "PremiumScorer = generate_from_functions('premium_scorer', score, premium_score)\n",
+ "\n",
+ "# Config fields (base_rate, multiplier) appear directly on the module\n",
+ "m = PremiumScorer(name='premium', base_rate=1.5, multiplier=3.0)\n",
+ "\n",
+ "print('base_rate :', m.base_rate)\n",
+ "print('multiplier :', m.multiplier)\n",
+ "print('Model fields:', list(type(m).model_fields.keys()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "result_premium = m({'input': df})\n",
+ "print(result_premium.select(['customer_id', 'amount', 'rate', 'score', 'premium_score']))"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/examples/02_pipelines_and_joins.ipynb b/docs/examples/02_pipelines_and_joins.ipynb
new file mode 100644
index 0000000..2a24b69
--- /dev/null
+++ b/docs/examples/02_pipelines_and_joins.ipynb
@@ -0,0 +1,260 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 5,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "m1",
+ "metadata": {},
+ "source": [
+ "# Pipelines and Joins\n",
+ "\n",
+ "This notebook covers the two composition operators and the `JoinModule` for multi-frame pipelines:\n",
+ "\n",
+ "- `|` (pipe) \u2014 creates a `SequentialModule` where each step receives the previous step's output as `\"input\"`\n",
+ "- `&` (merge) \u2014 creates a `UnionExpressionModule` that compiles all expression nodes in a single Polars pass\n",
+ "- `JoinModule` \u2014 joins two named input frames before passing the result downstream"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q decider\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "\n",
+ "import polars as pl\n",
+ "from decider.modules.functional import generate_from_functions\n",
+ "from decider.modules.primitives.join import JoinModule\n",
+ "from decider.modules.primitives.sequential import SequentialModule\n",
+ "\n",
+ "print('Imports OK')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m2",
+ "metadata": {},
+ "source": [
+ "## Sequential Pipelines with `|`\n",
+ "\n",
+ "Use `|` to chain modules. The output of each step becomes the `\"input\"` frame for the next step, so downstream modules can reference columns produced upstream."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Step 1 \u2014 feature engineering\n",
+ "def log_amount(amount: pl.Expr) -> pl.Expr:\n",
+ " return amount.log(base=10)\n",
+ "\n",
+ "def amount_squared(amount: pl.Expr) -> pl.Expr:\n",
+ " return amount ** 2\n",
+ "\n",
+ "Features = generate_from_functions('features', log_amount, amount_squared)\n",
+ "features = Features(name='features')\n",
+ "\n",
+ "# Step 2 \u2014 scoring (reads log_amount produced by step 1)\n",
+ "def raw_score(log_amount: pl.Expr) -> pl.Expr:\n",
+ " return log_amount * pl.lit(50.0)\n",
+ "\n",
+ "Scoring = generate_from_functions('scoring', raw_score)\n",
+ "scoring = Scoring(name='scoring')\n",
+ "\n",
+ "# Step 3 \u2014 flags (reads raw_score produced by step 2)\n",
+ "def high_value_flag(raw_score: pl.Expr) -> pl.Expr:\n",
+ " return raw_score > pl.lit(80.0)\n",
+ "\n",
+ "def score_band(raw_score: pl.Expr) -> pl.Expr:\n",
+ " return (\n",
+ " pl.when(raw_score >= pl.lit(100.0)).then(pl.lit('platinum'))\n",
+ " .when(raw_score >= pl.lit(80.0)).then(pl.lit('gold'))\n",
+ " .otherwise(pl.lit('standard'))\n",
+ " )\n",
+ "\n",
+ "Flags = generate_from_functions('flags', high_value_flag, score_band)\n",
+ "flags = Flags(name='flags')\n",
+ "\n",
+ "# Chain all three steps\n",
+ "pipeline = features | scoring | flags\n",
+ "\n",
+ "print('Pipeline type :', type(pipeline).__name__)\n",
+ "print('Steps :', [s.name for s in pipeline.steps])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = pl.DataFrame({\n",
+ " 'customer_id': ['C1', 'C2', 'C3', 'C4'],\n",
+ " 'amount': [10.0, 100.0, 1000.0, 10000.0],\n",
+ "})\n",
+ "\n",
+ "result = pipeline({'input': df})\n",
+ "print(result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m3",
+ "metadata": {},
+ "source": [
+ "Each step's output columns are forwarded to the next step. `scoring` can read `log_amount` because step 1 produced it; `flags` can read `raw_score` because step 2 produced it.\n",
+ "\n",
+ "## Merging Parallel Branches with `&`\n",
+ "\n",
+ "Use `&` to merge two independent `ExpressionModule` branches. All expression nodes from both modules are compiled into a single Polars `.with_columns()` pass \u2014 more efficient than two separate steps."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Branch A \u2014 transaction features\n",
+ "def txn_velocity(txn_count: pl.Expr) -> pl.Expr:\n",
+ " \"\"\"How fast is this customer transacting?\"\"\"\n",
+ " return txn_count / pl.lit(30.0) # transactions per day over 30-day window\n",
+ "\n",
+ "def avg_txn_value(total_spend: pl.Expr, txn_count: pl.Expr) -> pl.Expr:\n",
+ " return total_spend / txn_count\n",
+ "\n",
+ "TxnFeatures = generate_from_functions('txn_features', txn_velocity, avg_txn_value)\n",
+ "txn_features = TxnFeatures(name='txn_features')\n",
+ "\n",
+ "# Branch B \u2014 behavioural features (independent of Branch A)\n",
+ "def login_rate(login_count: pl.Expr) -> pl.Expr:\n",
+ " return login_count / pl.lit(30.0)\n",
+ "\n",
+ "def engagement_score(login_count: pl.Expr, txn_count: pl.Expr) -> pl.Expr:\n",
+ " return (login_count * pl.lit(0.4)) + (txn_count * pl.lit(0.6))\n",
+ "\n",
+ "BehavFeatures = generate_from_functions('behav_features', login_rate, engagement_score)\n",
+ "behav_features = BehavFeatures(name='behav_features')\n",
+ "\n",
+ "# Merge both branches \u2014 single Polars pass\n",
+ "merged = txn_features & behav_features\n",
+ "\n",
+ "print('Merged type:', type(merged).__name__)\n",
+ "\n",
+ "customer_df = pl.DataFrame({\n",
+ " 'customer_id': ['C1', 'C2', 'C3'],\n",
+ " 'txn_count': [5, 20, 3],\n",
+ " 'total_spend': [250.0, 1800.0, 90.0],\n",
+ " 'login_count': [12, 30, 2],\n",
+ "})\n",
+ "\n",
+ "result_merged = merged({'input': customer_df})\n",
+ "print(result_merged)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m5",
+ "metadata": {},
+ "source": [
+ "## JoinModule\n",
+ "\n",
+ "`JoinModule` joins two named input frames before passing the combined frame downstream. In this example the `left` and `right` fields are string keys into the input dict."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Two separate frames\n",
+ "txns_df = pl.DataFrame({\n",
+ " 'user_id': ['U1', 'U1', 'U2', 'U3'],\n",
+ " 'txn_id': ['T1', 'T2', 'T3', 'T4'],\n",
+ " 'amount': [50.0, 30.0, 200.0, 15.0],\n",
+ "})\n",
+ "\n",
+ "users_df = pl.DataFrame({\n",
+ " 'user_id': ['U1', 'U2', 'U3'],\n",
+ " 'credit_tier': ['gold', 'platinum', 'standard'],\n",
+ " 'income': [60000.0, 120000.0, 35000.0],\n",
+ "})\n",
+ "\n",
+ "# Join step\n",
+ "join = JoinModule(\n",
+ " name='join_txns_users',\n",
+ " left='txns',\n",
+ " right='users',\n",
+ " on='user_id',\n",
+ " how='left',\n",
+ ")\n",
+ "\n",
+ "# Scorer runs on the joined frame\n",
+ "def spend_to_income_ratio(amount: pl.Expr, income: pl.Expr) -> pl.Expr:\n",
+ " return amount / income\n",
+ "\n",
+ "TxnScorer = generate_from_functions('txn_scorer', spend_to_income_ratio)\n",
+ "txn_scorer = TxnScorer(name='txn_scorer')\n",
+ "\n",
+ "# Chain: join first, then score\n",
+ "join_pipeline = join | txn_scorer\n",
+ "\n",
+ "result_join = join_pipeline({'txns': txns_df, 'users': users_df})\n",
+ "print(result_join)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m6",
+ "metadata": {},
+ "source": [
+ "## Multi-Frame Input: `get_input_frame_keys()`\n",
+ "\n",
+ "You can inspect which frame keys a pipeline expects from its input dict. This is useful for documentation and validation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print('Single scorer input keys :', txn_scorer.get_input_frame_keys())\n",
+ "print('Join module input keys :', join.get_input_frame_keys())\n",
+ "print('Full pipeline input keys :', join_pipeline.get_input_frame_keys())"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/examples/03_config_and_persistence.ipynb b/docs/examples/03_config_and_persistence.ipynb
new file mode 100644
index 0000000..74f48d6
--- /dev/null
+++ b/docs/examples/03_config_and_persistence.ipynb
@@ -0,0 +1,227 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 5,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "m1",
+ "metadata": {},
+ "source": [
+ "# Config and Persistence\n",
+ "\n",
+ "In production, decision models need to evolve without code deploys. Decider stores every module as a versioned JSON config, so you can update weights, thresholds, and logic by writing a new config version and pointing your server at it \u2014 zero downtime, full audit trail."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q decider\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "import asyncio, json, os, tempfile\n",
+ "import polars as pl\n",
+ "from pydantic import BaseModel\n",
+ "from decider.modules.functional import generate_from_functions\n",
+ "from decider.config.file import JsonFileConfigManager\n",
+ "from decider.modules import GraphModule\n",
+ "\n",
+ "print('Imports OK')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m2",
+ "metadata": {},
+ "source": [
+ "## Build a Scorer with Specific Config\n",
+ "\n",
+ "We will use a simple risk scorer whose weights are controlled by a Pydantic config."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class RiskConfig(BaseModel):\n",
+ " dti_weight: float = 150.0\n",
+ " base_score: float = 700.0\n",
+ "\n",
+ "def dti_ratio(debt: pl.Expr, income: pl.Expr) -> pl.Expr:\n",
+ " return debt / income\n",
+ "\n",
+ "def risk_score(dti_ratio: pl.Expr, config: RiskConfig) -> pl.Expr:\n",
+ " return pl.lit(config.base_score) - dti_ratio * pl.lit(config.dti_weight)\n",
+ "\n",
+ "RiskScorer = generate_from_functions('risk_scorer', dti_ratio, risk_score)\n",
+ "\n",
+ "# Instantiate with specific weights\n",
+ "scorer = RiskScorer(name='risk_scorer', dti_weight=200.0, base_score=800.0)\n",
+ "\n",
+ "df = pl.DataFrame({\n",
+ " 'applicant_id': ['A1', 'A2', 'A3'],\n",
+ " 'debt': [20000.0, 5000.0, 50000.0],\n",
+ " 'income': [40000.0, 60000.0, 80000.0],\n",
+ "})\n",
+ "\n",
+ "result = scorer({'input': df})\n",
+ "print(result.select(['applicant_id', 'dti_ratio', 'risk_score']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m3",
+ "metadata": {},
+ "source": [
+ "## Save to Versioned JSON Config\n",
+ "\n",
+ "`asave` stages the module in a `VersionedConfig` object. `save_version` writes it to disk under `{basepath}/{version}/main.json`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use a temporary directory so this notebook is self-contained\n",
+ "CONFIGS_DIR = os.path.join(tempfile.mkdtemp(), 'configs')\n",
+ "print('Saving configs to:', CONFIGS_DIR)\n",
+ "\n",
+ "mgr = JsonFileConfigManager(basepath=CONFIGS_DIR)\n",
+ "\n",
+ "versioned = asyncio.run(scorer.asave('main', mgr))\n",
+ "asyncio.run(mgr.save_version(overwrite=True))\n",
+ "\n",
+ "print('Saved version:', versioned.version)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m4",
+ "metadata": {},
+ "source": [
+ "## Load it Back\n",
+ "\n",
+ "Create a fresh `JsonFileConfigManager` (simulating what a server does at startup), call `get_latest()`, and reconstruct the module with `GraphModule.model_validate`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fresh_mgr = JsonFileConfigManager(basepath=CONFIGS_DIR)\n",
+ "loaded = asyncio.run(fresh_mgr.get_latest())\n",
+ "\n",
+ "print('Loaded version:', loaded.version)\n",
+ "\n",
+ "module = GraphModule.model_validate(loaded.config['main']).root\n",
+ "\n",
+ "print('Type :', module.type)\n",
+ "print('dti_weight :', module.dti_weight)\n",
+ "print('base_score :', module.base_score)\n",
+ "\n",
+ "# Verify the reconstructed module produces the same output\n",
+ "result2 = module({'input': df})\n",
+ "print(result2.select(['applicant_id', 'dti_ratio', 'risk_score']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m5",
+ "metadata": {},
+ "source": [
+ "## Inspect the JSON on Disk\n",
+ "\n",
+ "The config is plain JSON \u2014 human-readable, diff-friendly, and version-controllable."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "version_str = str(versioned.version)\n",
+ "json_path = os.path.join(CONFIGS_DIR, version_str, 'main.json')\n",
+ "\n",
+ "with open(json_path) as f:\n",
+ " on_disk = json.load(f)\n",
+ "\n",
+ "print(json.dumps(on_disk, indent=2))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "m6",
+ "metadata": {},
+ "source": [
+ "## Swap Config Values Without Redeploying\n",
+ "\n",
+ "Save two versions with different weights, load both, and confirm their outputs differ. This is the core of Decider's hot-swap capability."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "CONFIGS_DIR_V2 = os.path.join(tempfile.mkdtemp(), 'configs_v2')\n",
+ "\n",
+ "# --- Version 1: conservative weights ----------------------------------------\n",
+ "scorer_v1 = RiskScorer(name='risk_scorer', dti_weight=100.0, base_score=700.0)\n",
+ "mgr_v1 = JsonFileConfigManager(basepath=CONFIGS_DIR_V2)\n",
+ "versioned_1 = asyncio.run(scorer_v1.asave('main', mgr_v1))\n",
+ "asyncio.run(mgr_v1.save_version(overwrite=True))\n",
+ "print('Saved v1:', versioned_1.version, ' dti_weight=100 base_score=700')\n",
+ "\n",
+ "# --- Version 2: aggressive weights ------------------------------------------\n",
+ "scorer_v2 = RiskScorer(name='risk_scorer', dti_weight=300.0, base_score=850.0)\n",
+ "mgr_v2 = JsonFileConfigManager(basepath=CONFIGS_DIR_V2)\n",
+ "ver2 = asyncio.run(scorer_v2.asave('main', mgr_v2))\n",
+ "asyncio.run(mgr_v2.save_version())\n",
+ "print('Saved v2:', ver2.version, ' dti_weight=300 base_score=850')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c7",
+ "metadata": {},
+ "outputs": [],
+ "source": "# Load BOTH versions and compare outputs on the same data\nall_versions = sorted(\n [d for d in os.listdir(CONFIGS_DIR_V2) if os.path.isdir(os.path.join(CONFIGS_DIR_V2, d))]\n)\nprint('Versions on disk:', all_versions)\n\nfor ver_str in all_versions:\n r_mgr = JsonFileConfigManager(basepath=CONFIGS_DIR_V2)\n ver_cfg = asyncio.run(r_mgr._load_version(ver_str))\n mod = GraphModule.model_validate(ver_cfg.config['main']).root\n out = mod({'input': df}).select(['applicant_id', 'risk_score'])\n print('\\nVersion', ver_str, ' dti_weight=' + str(mod.dti_weight) + ' base_score=' + str(mod.base_score))\n print(out)"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/examples/projects/01_loan_scoring.ipynb b/docs/examples/projects/01_loan_scoring.ipynb
new file mode 100644
index 0000000..6d7cad6
--- /dev/null
+++ b/docs/examples/projects/01_loan_scoring.ipynb
@@ -0,0 +1,229 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 4,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Loan Scoring Pipeline\n",
+ "\n",
+ "This notebook builds an end-to-end loan scoring pipeline using `decider`.\n",
+ "Two modules are composed with `|` into a single sequential pipeline:\n",
+ "\n",
+ "- **IncomeEnricher** \u2014 derives `debt_to_income` and `monthly_surplus` from raw applicant data\n",
+ "- **RiskScorer** \u2014 assigns a `risk_score` and `risk_tier` using a configurable threshold\n",
+ "\n",
+ "We also demonstrate config serialisation and how to hot-swap a tighter threshold."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q decider\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings; warnings.filterwarnings('ignore')\n",
+ "import polars as pl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from decider.modules.functional import generate_from_functions\n",
+ "from decider.modules._ext import register_graph_module, GraphModule\n",
+ "from pydantic import BaseModel"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Sample data\n",
+ "\n",
+ "Five loan applicants with income, monthly debt obligations, and requested loan amount."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "applicants = pl.DataFrame({\n",
+ " \"customer_id\": [\"C001\", \"C002\", \"C003\", \"C004\", \"C005\"],\n",
+ " \"income\": [50_000, 35_000, 80_000, 22_000, 60_000],\n",
+ " \"monthly_debt\": [1_200, 900, 2_100, 700, 1_500],\n",
+ " \"loan_amount\": [10_000, 15_000, 25_000, 5_000, 20_000],\n",
+ "})\n",
+ "\n",
+ "print(applicants)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Module 1 \u2014 IncomeEnricher\n",
+ "\n",
+ "Derives two financial ratios from the raw columns."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def debt_to_income(monthly_debt: pl.Expr, income: pl.Expr) -> pl.Expr:\n",
+ " return (monthly_debt * 12) / income\n",
+ "\n",
+ "def monthly_surplus(income: pl.Expr, monthly_debt: pl.Expr) -> pl.Expr:\n",
+ " return income / 12 - monthly_debt\n",
+ "\n",
+ "IncomeEnricher = generate_from_functions(\"income_enricher\", debt_to_income, monthly_surplus)\n",
+ "register_graph_module(IncomeEnricher)\n",
+ "income_enricher = IncomeEnricher(name=\"income_enricher\")\n",
+ "\n",
+ "print(income_enricher({\"input\": applicants}))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Module 2 \u2014 RiskScorer\n",
+ "\n",
+ "Computes a `risk_score` from the enriched columns and assigns a `risk_tier`.\n",
+ "The threshold that separates high from low risk is injected from `RiskConfig`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class RiskConfig(BaseModel):\n",
+ " threshold: float = 0.4\n",
+ "\n",
+ "def risk_score(debt_to_income: pl.Expr, monthly_surplus: pl.Expr) -> pl.Expr:\n",
+ " return (debt_to_income * 0.7 + (1 / (monthly_surplus + 1)) * 0.3).round(4)\n",
+ "\n",
+ "def risk_tier(risk_score: pl.Expr, config: RiskConfig) -> pl.Expr:\n",
+ " return pl.when(risk_score > config.threshold).then(pl.lit(\"HIGH\")).otherwise(pl.lit(\"LOW\"))\n",
+ "\n",
+ "RiskScorer = generate_from_functions(\"risk_scorer\", risk_score, risk_tier)\n",
+ "register_graph_module(RiskScorer)\n",
+ "risk_scorer = RiskScorer(name=\"risk_scorer\", threshold=0.4)\n",
+ "\n",
+ "print(risk_scorer)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Composing the pipeline\n",
+ "\n",
+ "The `|` operator chains `IncomeEnricher` and `RiskScorer` into a `SequentialModule`.\n",
+ "Each step receives the previous step's output as its `\"input\"` frame."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline = income_enricher | risk_scorer\n",
+ "print(type(pipeline).__name__, \"\u2014\", pipeline.name)\n",
+ "\n",
+ "result = pipeline({\"input\": applicants})\n",
+ "print(result.select([\"customer_id\", \"debt_to_income\", \"monthly_surplus\", \"risk_score\", \"risk_tier\"]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Config roundtrip\n",
+ "\n",
+ "`model_dump()` serialises the full pipeline (including nested configs) to a plain dict.\n",
+ "`GraphModule.model_validate()` reconstructs a live module from that dict."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "cfg = pipeline.model_dump()\n",
+ "print(json.dumps(cfg, indent=2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "restored = GraphModule.model_validate(cfg).root\n",
+ "restored_result = restored({\"input\": applicants})\n",
+ "print(\"Roundtrip output matches original:\", result.equals(restored_result))\n",
+ "print(restored_result.select([\"customer_id\", \"risk_score\", \"risk_tier\"]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Tighter threshold\n",
+ "\n",
+ "Swap in a stricter `threshold=0.3` to see how the `risk_tier` assignments change."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "strict_scorer = RiskScorer(name=\"risk_scorer\", threshold=0.3)\n",
+ "strict_pipeline = income_enricher | strict_scorer\n",
+ "\n",
+ "strict_result = strict_pipeline({\"input\": applicants})\n",
+ "comparison = result.select([\"customer_id\", \"risk_tier\"]).rename({\"risk_tier\": \"tier_0.4\"}).with_columns(\n",
+ " strict_result[\"risk_tier\"].alias(\"tier_0.3\")\n",
+ ")\n",
+ "print(comparison)"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/examples/projects/02_fraud_detection.ipynb b/docs/examples/projects/02_fraud_detection.ipynb
new file mode 100644
index 0000000..772167c
--- /dev/null
+++ b/docs/examples/projects/02_fraud_detection.ipynb
@@ -0,0 +1,215 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 4,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fraud Detection\n",
+ "\n",
+ "This notebook builds a real-time fraud scoring pipeline with `decider`.\n",
+ "\n",
+ "Two modules are composed with `|`:\n",
+ "\n",
+ "- **VelocityFeatures** \u2014 normalises transaction volume and amount against baselines\n",
+ "- **FraudScorer** \u2014 combines the two signals into a `fraud_score` and a binary `is_fraud` flag\n",
+ "\n",
+ "The blending weights are controlled by `FraudConfig` and can be updated without touching the function logic."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q decider\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings; warnings.filterwarnings('ignore')\n",
+ "import polars as pl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from decider.modules.functional import generate_from_functions\n",
+ "from decider.modules._ext import register_graph_module, GraphModule\n",
+ "from pydantic import BaseModel"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Sample transactions\n",
+ "\n",
+ "Four transactions: `txn_id`, raw `amount`, rolling `avg_amount`, and `txn_count` over the last 24 hours."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transactions = pl.DataFrame({\n",
+ " \"txn_id\": [\"T001\", \"T002\", \"T003\", \"T004\"],\n",
+ " \"amount\": [120.0, 5_500.0, 45.0, 980.0],\n",
+ " \"avg_amount\": [100.0, 200.0, 50.0, 300.0],\n",
+ " \"txn_count\": [3, 48, 1, 12],\n",
+ "})\n",
+ "\n",
+ "print(transactions)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Module 1 \u2014 VelocityFeatures\n",
+ "\n",
+ "`txn_velocity` normalises the transaction count to an hourly rate. \n",
+ "`amount_spike` measures how far the current amount deviates from the customer average."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def txn_velocity(txn_count: pl.Expr) -> pl.Expr:\n",
+ " return txn_count / 24\n",
+ "\n",
+ "def amount_spike(amount: pl.Expr, avg_amount: pl.Expr) -> pl.Expr:\n",
+ " return amount / avg_amount\n",
+ "\n",
+ "VelocityFeatures = generate_from_functions(\"velocity_features\", txn_velocity, amount_spike)\n",
+ "register_graph_module(VelocityFeatures)\n",
+ "velocity_features = VelocityFeatures(name=\"velocity_features\")\n",
+ "\n",
+ "print(velocity_features({\"input\": transactions}))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Module 2 \u2014 FraudScorer\n",
+ "\n",
+ "`fraud_score` is a weighted combination of the two signals. \n",
+ "`is_fraud` fires when the score exceeds `1.0` \u2014 meaning the transaction looks anomalous on at least one dimension."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class FraudConfig(BaseModel):\n",
+ " velocity_weight: float = 0.6\n",
+ " spike_weight: float = 0.4\n",
+ "\n",
+ "def fraud_score(txn_velocity: pl.Expr, amount_spike: pl.Expr, config: FraudConfig) -> pl.Expr:\n",
+ " return (config.velocity_weight * txn_velocity + config.spike_weight * amount_spike).round(4)\n",
+ "\n",
+ "def is_fraud(fraud_score: pl.Expr) -> pl.Expr:\n",
+ " return fraud_score > 1.0\n",
+ "\n",
+ "FraudScorer = generate_from_functions(\"fraud_scorer\", fraud_score, is_fraud)\n",
+ "register_graph_module(FraudScorer)\n",
+ "fraud_scorer = FraudScorer(name=\"fraud_scorer\", velocity_weight=0.6, spike_weight=0.4)\n",
+ "\n",
+ "print(fraud_scorer)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Composing the pipeline\n",
+ "\n",
+ "Chain `VelocityFeatures | FraudScorer` so the scorer automatically receives the derived columns."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline = velocity_features | fraud_scorer\n",
+ "\n",
+ "result = pipeline({\"input\": transactions})\n",
+ "print(result.select([\"txn_id\", \"txn_velocity\", \"amount_spike\", \"fraud_score\", \"is_fraud\"]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Serialisation to JSON\n",
+ "\n",
+ "The entire pipeline \u2014 including the nested `FraudConfig` weights \u2014 serialises cleanly to JSON.\n",
+ "This string can be stored in a config store or version-controlled alongside model artefacts."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "pipeline_json = json.dumps(pipeline.model_dump(), indent=2)\n",
+ "print(pipeline_json)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Restore from JSON and verify\n",
+ "\n",
+ "`json.loads()` parses the JSON string back into a dict, and `GraphModule.model_validate(...).root` rebuilds a live pipeline from it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "restored = GraphModule.model_validate(json.loads(pipeline_json)).root\n",
+ "restored_result = restored({\"input\": transactions})\n",
+ "\n",
+ "print(\"Outputs match after roundtrip:\", result.equals(restored_result))\n",
+ "print(restored_result.select([\"txn_id\", \"fraud_score\", \"is_fraud\"]))"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/examples/projects/03_affordability_check.ipynb b/docs/examples/projects/03_affordability_check.ipynb
new file mode 100644
index 0000000..c051d24
--- /dev/null
+++ b/docs/examples/projects/03_affordability_check.ipynb
@@ -0,0 +1,239 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 4,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Affordability Check\n",
+ "\n",
+ "This notebook demonstrates `JoinModule` \u2014 used here to enrich a customer frame with\n",
+ "expense data that lives in a separate source.\n",
+ "\n",
+ "The pipeline:\n",
+ "\n",
+ "1. `ExpenseEnricher` runs on the `expenses` frame to derive `total_expense` and `expense_ratio`\n",
+ "2. `JoinModule` merges the enriched expense data back onto the customer frame on `customer_id`\n",
+ "3. `AffordabilityScorer` computes `disposable_income` and a boolean `can_afford` flag\n",
+ "\n",
+ "Both input frames are passed together as `{\"input\": customers, \"expenses\": expenses}`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q decider\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings; warnings.filterwarnings('ignore')\n",
+ "import polars as pl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from decider.modules.functional import generate_from_functions\n",
+ "from decider.modules._ext import register_graph_module, GraphModule\n",
+ "from decider.modules.primitives.join import JoinModule\n",
+ "from decider.modules.primitives.sequential import SequentialModule\n",
+ "from pydantic import BaseModel"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Sample data\n",
+ "\n",
+ "Two separate frames that share `customer_id` as a join key."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "customers = pl.DataFrame({\n",
+ " \"customer_id\": [\"C001\", \"C002\", \"C003\", \"C004\"],\n",
+ " \"income\": [48_000, 32_000, 75_000, 28_000],\n",
+ " \"loan_amount\": [12_000, 8_000, 30_000, 5_000],\n",
+ "})\n",
+ "\n",
+ "expenses = pl.DataFrame({\n",
+ " \"customer_id\": [\"C001\", \"C002\", \"C003\", \"C004\"],\n",
+ " \"income\": [48_000, 32_000, 75_000, 28_000],\n",
+ " \"rent\": [ 1_200, 900, 2_500, 750],\n",
+ " \"food\": [ 400, 300, 600, 250],\n",
+ " \"transport\": [ 150, 100, 200, 80],\n",
+ "})\n",
+ "\n",
+ "print(\"Customers:\")\n",
+ "print(customers)\n",
+ "print(\"\\nExpenses:\")\n",
+ "print(expenses)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Module 1 \u2014 ExpenseEnricher\n",
+ "\n",
+ "Runs on the `expenses` frame. Aggregates individual cost lines into `total_expense`\n",
+ "and computes `expense_ratio` as a proportion of annual income."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def total_expense(rent: pl.Expr, food: pl.Expr, transport: pl.Expr) -> pl.Expr:\n",
+ " return (rent + food + transport) * 12\n",
+ "\n",
+ "def expense_ratio(total_expense: pl.Expr, income: pl.Expr) -> pl.Expr:\n",
+ " return (total_expense / income).round(4)\n",
+ "\n",
+ "ExpenseEnricher = generate_from_functions(\"expense_enricher\", total_expense, expense_ratio)\n",
+ "register_graph_module(ExpenseEnricher)\n",
+ "expense_enricher = ExpenseEnricher(name=\"expense_enricher\")\n",
+ "\n",
+ "print(expense_enricher({\"input\": expenses}))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## JoinModule \u2014 merge enriched expenses back to customers\n",
+ "\n",
+ "`JoinModule` accepts module references as `left` or `right`, so we can point `right`\n",
+ "at a `FrameRef(name=\"expenses\") | expense_enricher` sub-pipeline rather than pre-computing a separate frame."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from decider.modules.primitives.join import FrameRef\n",
+ "\n",
+ "# Route the \"expenses\" frame into expense_enricher using FrameRef\n",
+ "expense_pipeline = FrameRef(name=\"expenses\") | expense_enricher\n",
+ "\n",
+ "joiner = JoinModule(\n",
+ " name=\"expense_join\",\n",
+ " left=\"input\",\n",
+ " right=expense_pipeline,\n",
+ " on=\"customer_id\",\n",
+ " how=\"left\",\n",
+ ")\n",
+ "\n",
+ "joined = joiner({\"input\": customers, \"expenses\": expenses})\n",
+ "print(joined)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Module 2 \u2014 AffordabilityScorer\n",
+ "\n",
+ "Runs on the joined frame. `disposable_income` is what is left after all annual expenses.\n",
+ "`can_afford` is `True` when the annual disposable income exceeds `loan_amount / 12` \u2014 one monthly repayment if the loan were spread over twelve months."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def disposable_income(income: pl.Expr, total_expense: pl.Expr) -> pl.Expr:\n",
+ " return income - total_expense\n",
+ "\n",
+ "def can_afford(disposable_income: pl.Expr, loan_amount: pl.Expr) -> pl.Expr:\n",
+ " return disposable_income > (loan_amount / 12)\n",
+ "\n",
+ "AffordabilityScorer = generate_from_functions(\"affordability_scorer\", disposable_income, can_afford)\n",
+ "register_graph_module(AffordabilityScorer)\n",
+ "affordability_scorer = AffordabilityScorer(name=\"affordability_scorer\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Full pipeline\n",
+ "\n",
+ "The final pipeline is `joiner | affordability_scorer`.\n",
+ "Both input frames are supplied at call time \u2014 the `JoinModule` internally routes\n",
+ "the `\"expenses\"` frame through `ExpenseEnricher` before the join."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline = joiner | affordability_scorer\n",
+ "\n",
+ "result = pipeline({\"input\": customers, \"expenses\": expenses})\n",
+ "print(result.select([\"customer_id\", \"income\", \"loan_amount\", \"total_expense\", \"expense_ratio\", \"disposable_income\", \"can_afford\"]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Config roundtrip\n",
+ "\n",
+ "The whole pipeline, including the embedded `ExpenseEnricher` inside the `JoinModule`, serialises correctly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "cfg = pipeline.model_dump()\n",
+ "print(json.dumps(cfg, indent=2))\n",
+ "\n",
+ "restored = GraphModule.model_validate(cfg).root\n",
+ "restored_result = restored({\"input\": customers, \"expenses\": expenses})\n",
+ "print(\"\\nOutputs match after roundtrip:\", result.equals(restored_result))"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/examples/projects/04_multi_bureau_pipeline.ipynb b/docs/examples/projects/04_multi_bureau_pipeline.ipynb
new file mode 100644
index 0000000..ae13800
--- /dev/null
+++ b/docs/examples/projects/04_multi_bureau_pipeline.ipynb
@@ -0,0 +1,239 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 4,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Multi-Bureau Credit Pipeline\n",
+ "\n",
+ "This notebook demonstrates a pipeline that scores applicants against two separate credit bureaus, joins the results, and makes a combined final decision.\n",
+ "\n",
+ "Key concepts: `FrameRef` for routing named frames, `JoinModule` with module operands, config serialisation.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q decider\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings; warnings.filterwarnings('ignore')\n",
+ "import polars as pl\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from decider.modules.functional import generate_from_functions\n",
+ "from decider.modules._ext import register_graph_module, GraphModule\n",
+ "from decider.modules.primitives.join import JoinModule, FrameRef\n",
+ "from decider.modules.primitives.sequential import SequentialModule\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Sample data\n",
+ "\n",
+ "Each bureau provides its own DataFrame for the same applicants.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bureau1_data = pl.DataFrame({\n",
+ " \"customer_id\": [\"C001\", \"C002\", \"C003\", \"C004\"],\n",
+ " \"b1_enquiries\": [2, 8, 1, 5],\n",
+ "})\n",
+ "\n",
+ "bureau2_data = pl.DataFrame({\n",
+ " \"customer_id\": [\"C001\", \"C002\", \"C003\", \"C004\"],\n",
+ " \"b2_balance\": [5_000, 18_000, 1_200, 9_500],\n",
+ " \"b2_limit\": [10_000, 20_000, 8_000, 12_000],\n",
+ "})\n",
+ "\n",
+ "print(\"Bureau 1:\"); print(bureau1_data)\n",
+ "print(\"\\nBureau 2:\"); print(bureau2_data)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Bureau scorers\n",
+ "\n",
+ "Each bureau has its own scoring module that takes the bureau's raw columns.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def b1_score(b1_enquiries: pl.Expr) -> pl.Expr:\n",
+ " return b1_enquiries * -2 + 700\n",
+ "\n",
+ "def b1_default_flag(b1_score: pl.Expr) -> pl.Expr:\n",
+ " return b1_score < 600\n",
+ "\n",
+ "Bureau1Scorer = generate_from_functions(\"bureau1_scorer\", b1_score, b1_default_flag)\n",
+ "register_graph_module(Bureau1Scorer)\n",
+ "bureau1_scorer = Bureau1Scorer(name=\"bureau1_scorer\")\n",
+ "\n",
+ "print(bureau1_scorer({\"input\": bureau1_data}))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def b2_score(b2_balance: pl.Expr, b2_limit: pl.Expr) -> pl.Expr:\n",
+ " return (b2_balance / b2_limit) * -300 + 700\n",
+ "\n",
+ "def b2_default_flag(b2_score: pl.Expr) -> pl.Expr:\n",
+ " return b2_score < 600\n",
+ "\n",
+ "Bureau2Scorer = generate_from_functions(\"bureau2_scorer\", b2_score, b2_default_flag)\n",
+ "register_graph_module(Bureau2Scorer)\n",
+ "bureau2_scorer = Bureau2Scorer(name=\"bureau2_scorer\")\n",
+ "\n",
+ "print(bureau2_scorer({\"input\": bureau2_data}))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Join bureau outputs\n",
+ "\n",
+ "`FrameRef` routes each named input frame to its scorer. `JoinModule` then joins the two resulting frames on `customer_id`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Wrap each scorer so it reads from its own named input frame\n",
+ "b1_pipeline = FrameRef(name=\"bureau1\") | bureau1_scorer\n",
+ "b2_pipeline = FrameRef(name=\"bureau2\") | bureau2_scorer\n",
+ "\n",
+ "bureau_join = JoinModule(\n",
+ " name=\"bureau_join\",\n",
+ " left=b1_pipeline,\n",
+ " right=b2_pipeline,\n",
+ " on=\"customer_id\",\n",
+ " how=\"left\",\n",
+ ")\n",
+ "\n",
+ "joined = bureau_join({\"bureau1\": bureau1_data, \"bureau2\": bureau2_data})\n",
+ "print(joined)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Combined decision\n",
+ "\n",
+ "Average the two bureau scores and apply a threshold for the final accept/decline.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def combined_score(b1_score: pl.Expr, b2_score: pl.Expr) -> pl.Expr:\n",
+ " return ((b1_score + b2_score) / 2).round(2)\n",
+ "\n",
+ "def final_decision(combined_score: pl.Expr) -> pl.Expr:\n",
+ " return combined_score >= 600\n",
+ "\n",
+ "CombinedScorer = generate_from_functions(\"combined_scorer\", combined_score, final_decision)\n",
+ "register_graph_module(CombinedScorer)\n",
+ "combined_scorer = CombinedScorer(name=\"combined_scorer\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline = bureau_join | combined_scorer\n",
+ "\n",
+ "result = pipeline({\"bureau1\": bureau1_data, \"bureau2\": bureau2_data})\n",
+ "print(result.select([\"customer_id\", \"b1_score\", \"b2_score\", \"combined_score\", \"final_decision\"]))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Config serialisation & roundtrip\n",
+ "\n",
+ "The full pipeline \u2014 including `FrameRef`, nested `JoinModule`, and all scorers \u2014 serialises to JSON and can be restored exactly.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "cfg = pipeline.model_dump()\n",
+ "print(json.dumps(cfg, indent=2))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "restored_pipeline = GraphModule.model_validate(cfg).root\n",
+ "\n",
+ "restored_result = restored_pipeline({\"bureau1\": bureau1_data, \"bureau2\": bureau2_data})\n",
+ "print(\"Outputs match after roundtrip:\", result.equals(restored_result))\n",
+ "print(restored_result.select([\"customer_id\", \"combined_score\", \"final_decision\"]))\n"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/examples/projects/05_notebook_to_serving/.gitignore b/docs/examples/projects/05_notebook_to_serving/.gitignore
new file mode 100644
index 0000000..dc3d233
--- /dev/null
+++ b/docs/examples/projects/05_notebook_to_serving/.gitignore
@@ -0,0 +1,2 @@
+model/
+decider_extensions/
diff --git a/docs/examples/projects/05_notebook_to_serving/05_notebook_to_serving.ipynb b/docs/examples/projects/05_notebook_to_serving/05_notebook_to_serving.ipynb
new file mode 100644
index 0000000..899f5d5
--- /dev/null
+++ b/docs/examples/projects/05_notebook_to_serving/05_notebook_to_serving.ipynb
@@ -0,0 +1,414 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# From Notebook to Production Service\n",
+ "\n",
+ "This notebook walks through the full `decider` workflow end-to-end:\n",
+ "\n",
+ "1. Define modules interactively using the `%%module` magic\n",
+ "2. Compose and test the pipeline in the notebook\n",
+ "3. Save a versioned JSON config\n",
+ "4. Start an HTTP server with `decider serve`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q \"decider[serve-sanic,notebook]\"\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import warnings; warnings.filterwarnings('ignore')\n",
+ "import os, json, asyncio\n",
+ "from pathlib import Path\n",
+ "import polars as pl\n",
+ "from pydantic import BaseModel\n",
+ "\n",
+ "# Base directory for generated files: the script's folder when `__file__` is defined,\n",
+ "# otherwise the kernel's current working directory (`Path('.')`) \u2014 so launch Jupyter from this notebook's folder.\n",
+ "HERE = Path(__file__).parent if '__file__' in dir() else Path('.')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load the `%%module` magic\n",
+ "\n",
+ "The `%%module` magic writes an extension file under `decider_extensions/` and\n",
+ "registers the class in the current session \u2014 no manual file management needed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Point the magic at this notebook's own extensions folder so generated\n",
+ "# files stay inside docs/examples/projects/05_notebook_to_serving/ and\n",
+ "# are covered by its .gitignore.\n",
+ "os.environ['DECIDER_EXTENSIONS_DIR'] = str(HERE / 'decider_extensions')\n",
+ "%load_ext decider.magics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define the feature engineering module\n",
+ "\n",
+ "Each top-level function in the cell body becomes an output column.\n",
+ "Parameter names map to input column names; a `config: MyConfig` parameter\n",
+ "injects versioned config values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[module] Written: /Users/cp371651/Documents/Workspace/Upskilling/dsp_north-polrs/docs/examples/projects/decider_extensions/loan_features/__init__.py\n",
+ "[module] LoanFeatures registered type='loan_features'\n",
+ "[module] LoanFeatures injected into namespace\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%module LoanFeatures\n",
+ "\n",
+ "def debt_to_income(monthly_debt: pl.Expr, income: pl.Expr) -> pl.Expr:\n",
+ " return (monthly_debt * 12) / income\n",
+ "\n",
+ "def monthly_surplus(income: pl.Expr, monthly_debt: pl.Expr) -> pl.Expr:\n",
+ " return income / 12 - monthly_debt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define the scoring module with config injection\n",
+ "\n",
+ "`RiskScorerConfig` fields are promoted directly onto the module \u2014 you can pass\n",
+ "`threshold=0.3` when instantiating, or load it from a saved JSON config."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[module] Written: /Users/cp371651/Documents/Workspace/Upskilling/dsp_north-polrs/docs/examples/projects/decider_extensions/risk_scorer/__init__.py\n",
+ "[module] RiskScorer registered type='risk_scorer'\n",
+ "[module] RiskScorer injected into namespace\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%module RiskScorer\n",
+ "\n",
+ "class RiskScorerConfig(BaseModel):\n",
+ " threshold: float = 0.4\n",
+ "\n",
+ "def risk_score(debt_to_income: pl.Expr, monthly_surplus: pl.Expr) -> pl.Expr:\n",
+ " return (debt_to_income * 0.7 + (1 / (monthly_surplus + 1)) * 0.3).round(4)\n",
+ "\n",
+ "def risk_tier(risk_score: pl.Expr, config: RiskScorerConfig) -> pl.Expr:\n",
+ " return pl.when(risk_score > config.threshold).then(pl.lit(\"HIGH\")).otherwise(pl.lit(\"LOW\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Test the pipeline in the notebook\n",
+ "\n",
+ "Modules registered by `%%module` are available immediately in the namespace.\n",
+ "Compose with `|` and call with a plain dict of Polars DataFrames."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "shape: (4, 4)\n",
+ "\u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n",
+ "\u2502 customer_id \u2506 debt_to_income \u2506 risk_score \u2506 risk_tier \u2502\n",
+ "\u2502 --- \u2506 --- \u2506 --- \u2506 --- \u2502\n",
+ "\u2502 str \u2506 f64 \u2506 f64 \u2506 str \u2502\n",
+ "\u255e\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u256a\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u256a\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u256a\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2561\n",
+ "\u2502 C001 \u2506 0.288 \u2506 0.2017 \u2506 LOW \u2502\n",
+ "\u2502 C002 \u2506 0.308571 \u2506 0.2161 \u2506 LOW \u2502\n",
+ "\u2502 C003 \u2506 0.315 \u2506 0.2206 \u2506 LOW \u2502\n",
+ "\u2502 C004 \u2506 0.381818 \u2506 0.2675 \u2506 LOW \u2502\n",
+ "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n"
+ ]
+ }
+ ],
+ "source": [
+ "applicants = pl.DataFrame({\n",
+ " \"customer_id\": [\"C001\", \"C002\", \"C003\", \"C004\"],\n",
+ " \"income\": [50_000, 35_000, 80_000, 22_000],\n",
+ " \"monthly_debt\": [ 1_200, 900, 2_100, 700],\n",
+ " \"loan_amount\": [10_000, 15_000, 25_000, 5_000],\n",
+ "})\n",
+ "\n",
+ "pipeline = LoanFeatures(name=\"features\") | RiskScorer(name=\"scorer\", threshold=0.4)\n",
+ "result = pipeline({\"input\": applicants})\n",
+ "print(result.select([\"customer_id\", \"debt_to_income\", \"risk_score\", \"risk_tier\"]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Save a versioned config\n",
+ "\n",
+ "`asave` stages the module in memory; `save_version` writes it to disk as a JSON\n",
+ "file under a version-named directory. The server reads this file at startup."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saved version: 0.0.0\n",
+ "Config directory: /Users/cp371651/Documents/Workspace/Upskilling/dsp_north-polrs/docs/examples/projects/model/configs\n"
+ ]
+ }
+ ],
+ "source": [
+ "from decider.config.file import JsonFileConfigManager\n",
+ "\n",
+ "CONFIGS_DIR = HERE / 'model' / 'configs'\n",
+ "CONFIGS_DIR.mkdir(parents=True, exist_ok=True)\n",
+ "\n",
+ "mgr = JsonFileConfigManager(basepath=str(CONFIGS_DIR))\n",
+ "versioned = await pipeline.asave('main', mgr)\n",
+ "await mgr.save_version(overwrite=True)\n",
+ "\n",
+ "print(f'Saved version: {versioned.version}')\n",
+ "print(f'Config directory: {CONFIGS_DIR}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Inspect the saved config\n",
+ "\n",
+ "The config is plain JSON \u2014 diff-friendly, human-readable, and\n",
+ "safe to check into version control."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"type\": \"sequential\",\n",
+ " \"name\": \"features\",\n",
+ " \"steps\": [\n",
+ " {\n",
+ " \"type\": \"loan_features\",\n",
+ " \"name\": \"features\"\n",
+ " },\n",
+ " {\n",
+ " \"threshold\": 0.4,\n",
+ " \"type\": \"risk_scorer\",\n",
+ " \"name\": \"scorer\"\n",
+ " }\n",
+ " ]\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "version_str = str(versioned.version)\n",
+ "config_path = CONFIGS_DIR / version_str / 'main.json'\n",
+ "\n",
+ "with open(config_path) as f:\n",
+ " print(json.dumps(json.load(f), indent=2))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Start the inference server\n",
+ "\n",
+ "`decider serve` picks up the config directory from the\n",
+ "`DECIDER_CONFIG__BASEPATH` environment variable.\n",
+ "\n",
+ "The server exposes two endpoints:\n",
+ "- `POST /predict` \u2014 accepts JSON or Arrow IPC, returns scored output\n",
+ "- `GET /ping` \u2014 health check\n",
+ "\n",
+ "> **Note:** `decider serve` blocks, so the cell below only sets the environment and prints the\n",
+ "> command. Run it in a terminal, or uncomment the `subprocess` variant to start it in the background."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "To start the server, run in a terminal:\n",
+ " DECIDER_CONFIG__BASEPATH=/Users/cp371651/Documents/Workspace/Upskilling/dsp_north-polrs/docs/examples/projects/model/configs decider serve --engine sanic --port 8080\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/miniconda/envs/spockappdev/bin/python: No module named decider.cli.__main__; 'decider.cli' is a package and cannot be directly executed\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Set the config path so the server knows where to load models from.\n",
+ "os.environ['DECIDER_CONFIG__BASEPATH'] = str(CONFIGS_DIR)\n",
+ "os.environ['DECIDER_CONFIG__TYPE'] = 'file:json'\n",
+ "\n",
+ "# Run in a terminal (from this notebook's directory):\n",
+ "# decider serve --engine sanic --port 8080\n",
+ "#\n",
+ "# Or start non-blocking from Python (uncomment to use):\n",
+ "# import subprocess, sys\n",
+ "# server = subprocess.Popen(\n",
+ "# ['decider', 'serve', '--engine', 'sanic', '--port', '8080'],\n",
+ "# env=os.environ.copy(),\n",
+ "# )\n",
+ "# import time; time.sleep(2) # give sanic a moment to start\n",
+ "\n",
+ "print('To start the server, run in a terminal:')\n",
+ "print(f' DECIDER_CONFIG__BASEPATH={CONFIGS_DIR} decider serve --engine sanic --port 8080')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Send a prediction request\n",
+ "\n",
+ "Once the server is running, call `/predict` with a JSON payload.\n",
+ "The server returns a JSON object with the scored columns appended."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import urllib.request\n",
+ "\n",
+ "payload = json.dumps({\n",
+ " \"customer_id\": [\"C001\", \"C002\"],\n",
+ " \"income\": [50_000, 35_000],\n",
+ " \"monthly_debt\": [ 1_200, 900],\n",
+ " \"loan_amount\": [10_000, 15_000],\n",
+ "}).encode()\n",
+ "\n",
+ "# Uncomment once the server is running:\n",
+ "# req = urllib.request.Request(\n",
+ "# \"http://localhost:8080/predict\",\n",
+ "# data=payload,\n",
+ "# headers={\"Content-Type\": \"application/json\"},\n",
+ "# )\n",
+ "# with urllib.request.urlopen(req) as resp:\n",
+ "# print(json.dumps(json.loads(resp.read()), indent=2))\n",
+ "\n",
+ "print(\"Prediction payload:\")\n",
+ "print(json.dumps(json.loads(payload), indent=2))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Deploying with Docker\n",
+ "\n",
+ "Build and run the Docker image (requires `docker/Dockerfile` and a\n",
+ "built wheel in `dist/`):\n",
+ "\n",
+ "```bash\n",
+ "# Build wheel\n",
+ "python -m build --wheel -o dist/\n",
+ "\n",
+ "# Build image\n",
+ "docker build -t decider:latest -f docker/Dockerfile .\n",
+ "\n",
+ "# Run \u2014 mount your model directory\n",
+ "docker run -p 8080:8080 \\\n",
+ " -v $(pwd)/model:/opt/ml/model:ro \\\n",
+ " -e DECIDER_SERVE__WORKERS=5 \\\n",
+ " decider:latest\n",
+ "```\n",
+ "\n",
+ "The server starts automatically using `decider serve --engine sanic`.\n",
+ "Workers default to `nproc * 2 + 1` when `DECIDER_SERVE__WORKERS` is not set."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "spockappdev",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
\ No newline at end of file
diff --git a/docs/getting_started/contributing/architecure/index.md b/docs/getting_started/contributing/architecure/index.md
deleted file mode 100644
index 9055374..0000000
--- a/docs/getting_started/contributing/architecure/index.md
+++ /dev/null
@@ -1,204 +0,0 @@
-## Spockflow Architecture Documentation - Landing Page
-
-
-### **Directory Structure**
-
-To better understand Spockflow’s architecture, let’s explore the key folders and their responsibilities within the package:
-
----
-
-#### **1. `spockflow/`**
-This is the core of the Spockflow framework, containing all of its primary modules and components.
-
----
-
-##### **core.py**
-- The main entry point for Hamilton integration.
-- Defines a custom decorator for injecting Spockflow logic into Hamilton’s DAG framework.
-- Expands Hamilton subdags with configurable components and generates nodes from these components.
-- Calls `initialize_spock_module` to inject the Spockflow functionality into a given module, allowing for the automatic generation of Hamilton nodes.
-
-Example of Hamilton subdag integration:
-```python
-@subdag(
- feature_modules,
- inputs={"path": source("source_path")},
- config={}
-)
-def feature_engineering(feature_df: pd.DataFrame) -> pd.DataFrame:
- return feature_df
-```
-- In Spockflow, the `initialize_spock_module` decorator ensures that subdags are expanded and executed according to the framework's configuration.
-
----
-
-##### **nodes.py**
-- Contains the definition of `VariableNode`, the core class responsible for transforming configuration-driven logic into executable Hamilton nodes.
-- Handles utilities such as `CloneVariableNode` (for duplicating nodes) and `AliasedVariableNode` (for renaming nodes without re-executing).
-- Uses Pydantic classes to serialize and deserialize configuration, making it easier to manage node definitions and configurations.
-- The `generate_nodes` function within `VariableNode` handles the actual creation of subnodes, ensuring that each node can be expanded within a Hamilton DAG.
-
-```python
-def _generate_nodes(self, ...):
- ...
- node_functions = inspect.getmembers(
- compiled_variable_node, predicate=self._does_define_node
- )
- ...
-```
-This method identifies and expands functions within a module as Hamilton nodes, ensuring that subcomponents can be injected into larger data pipelines.
-
----
-
-##### **_serializable.py**
-- Provides utilities to help with the serialization and deserialization of data, particularly for handling Pandas DataFrames and Series.
-- Ensures that data passed through Spockflow nodes can be properly transformed and maintained across different steps of the pipeline.
-
----
-
-#### **2. `components/`**
-Contains all the core components and decision-oriented modules in Spockflow, including:
-- **Decision Trees**: Build decision trees to enforce rules for data enrichment and transformations.
-- **Scorecards**: Create scoring systems for evaluating data based on multiple parameters.
-- **Decision Tables**: Define mappings of input values to outputs based on set conditions.
-
-Each of these components is built as reusable modules that can be configured and inserted into your data flows.
-
----
-
-#### **3. `inference/`**
-- Contains logic and tools to serve models via endpoints compatible with services like AWS SageMaker.
-
----
-
-
-### **How to Define a Custom Node in Spockflow**
-
-In Spockflow, custom nodes allow users to extend the framework's functionality by creating new components that integrate seamlessly into the Hamilton DAG-based architecture. A custom node is a class that inherits from `VariableNode` and defines its own behavior for node creation, input handling, and execution.
-
-Here, we'll define a custom `Tree` node as an example of how to create a custom decision-making process using Spockflow’s infrastructure.
-
-#### **Step 1: Define the Custom Node Class**
-
-To create a custom node, you need to subclass `VariableNode` and define several key components, such as input fields, the `compile()` method, and custom logic for handling inputs and outputs.
-
-```python
-class Tree(VariableNode):
- # This is used in visualisation by Hamilton
- doc: str = "This executes a user-defined decision tree"
-
- # Define fields using Pydantic (these can be any fields for configuration)
- execution_conditions: typing.List[str]
- execution_outputs: typing.List[str]
-
- # The compile function needs to be provided. By default, it will just return self.
- def compile(self):
- # This step may involve transforming or processing the input data into a usable format
- from .compiled import CompiledNumpyTree
- return CompiledNumpyTree(self)
-```
-
-- `execution_conditions` and `execution_outputs` are lists of strings that define the conditions and outputs associated with the decision tree.
-- The `compile()` function is responsible for transforming the raw input data into a format that can be used by the Hamilton DAG. In this case, it initializes a `CompiledNumpyTree`.
-
-#### **Step 2: Define a Compiled Representation for the Node**
-
-To optimize how the node’s logic is executed, we can define a compiled version of the node, such as `CompiledNumpyTree`. This compiled version will contain the logic to handle the execution and manage inputs dynamically.
-
-```python
-class CompiledNumpyTree:
- def __init__(self, tree: Tree) -> None:
- # This constructor will process and configure the tree into an executable form
- self.tree = tree
- # Additional processing logic for the tree can go here
-
- def _get_inputs(self, function: typing.Callable):
- # Returns the expected input types for the node
- node_input_types = {o: pd.DataFrame for o in self.tree.execution_outputs}
- node_input_types.update({c: typing.Union[np.ndarray, pd.Series] for c in self.tree.execution_conditions})
- return node_input_types
-```
-
-- The `CompiledNumpyTree` class is responsible for transforming the raw `Tree` object into an optimized version that can be used in a Hamilton DAG.
-- The `_get_inputs` function dynamically determines the input types required for this node’s execution.
-
-#### **Step 3: Define Node Functions with `@creates_node`**
-
-Next, we define the various operations that make up the logic of our custom `Tree` node. These operations are implemented as functions within the `Tree` class and are decorated with `@creates_node`. The `@creates_node` decorator tells Spockflow to treat these methods as subnodes within the Hamilton DAG.
-
-```python
-from spockflow.nodes import creates_node
-import numpy as np
-import pandas as pd
-
-class Tree(VariableNode):
- # Other fields and compile method defined previously
-
- @creates_node(kwarg_input_generator="_get_inputs") # Generates node inputs dynamically
- def format_inputs(
- self, **kwargs: typing.Union[pd.DataFrame, pd.Series]
- ) -> TFormatData:
- # Process inputs and return transformed data
- pass
-
- @creates_node() # Defines a subnode for conditions met
- def conditions_met(self, format_inputs: TFormatData) -> np.ndarray:
- # Logic for evaluating conditions based on inputs
- pass
-
- @creates_node() # Defines a subnode for prioritizing conditions
- def prioritized_conditions(self, conditions_met: np.ndarray) -> np.ndarray:
- # Logic for prioritizing conditions
- pass
-
- @creates_node() # Defines a subnode for generating condition names
- def condition_names(self, format_inputs: TFormatData) -> typing.List[str]:
- # Logic to generate the names of the conditions
- pass
-
- @creates_node() # Defines a subnode for the final decision logic
- def all(
- self,
- format_inputs: TFormatData,
- conditions_met: np.ndarray,
- ) -> pd.DataFrame:
- # Logic for making a decision based on inputs and conditions
- pass
-
- @creates_node(is_namespaced=False) # This node will be created outside the namespace
- def get_results(
- self,
- format_inputs: TFormatData,
- prioritized_conditions: np.ndarray,
- ) -> pd.DataFrame:
- # Final output of the decision tree process
- pass
-```
-
-- **`@creates_node`**: This decorator defines the function as a subnode in the DAG.
- - The `kwarg_input_generator="_get_inputs"` argument is used to specify how to dynamically determine the input types for this node.
- - Each method, such as `format_inputs()`, `conditions_met()`, etc., corresponds to a specific operation in the decision tree.
-
-The above tree when created as follows:
-
-```python
-# Example Tree node instance
-example_tree = Tree(execution_conditions=["a", "b"], execution_outputs=["c", "d"])
-
-```
-will create the following DAG:
-
-
-- The above relationships represent the connections between nodes, where each `@creates_node` function becomes part of the Hamilton DAG.
-- The `example_tree.format_inputs` node takes inputs `a`, `b`, `c`, and `d`, and feeds them into subsequent nodes like `conditions_met`, `prioritized_conditions`, and others.
-
----
-
-### **Summary of Custom Node Creation Steps**
-
-1. **Define the Node Class**: Inherit from `VariableNode` and specify fields such as conditions and outputs.
-2. **Compile the Node**: Provide a `compile()` method to transform the node into an optimized executable form (e.g., `CompiledNumpyTree`).
-3. **Define Operations as Subnodes**: Use `@creates_node` to define methods that represent different parts of the decision-making process.
-4. **Establish Dependencies**: The created subnodes will automatically link based on their input/output relationships, forming a complete DAG.
-
-By following these steps, you can define complex, decision-oriented nodes in Spockflow and integrate them seamlessly into a Hamilton-based data pipeline.
\ No newline at end of file
diff --git a/docs/getting_started/contributing/architecure/tree.drawio.svg b/docs/getting_started/contributing/architecure/tree.drawio.svg
deleted file mode 100644
index f8b56fd..0000000
--- a/docs/getting_started/contributing/architecure/tree.drawio.svg
+++ /dev/null
@@ -1,300 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/getting_started/contributing/backlog.md b/docs/getting_started/contributing/backlog.md
deleted file mode 100644
index 5a1ecda..0000000
--- a/docs/getting_started/contributing/backlog.md
+++ /dev/null
@@ -1,23 +0,0 @@
-# Backlog items
-## Configuration Management
-1. Automatic model Validation: We need a way to validate configuration as soon as its changed before deploying a model
-2. Parameter changing: We need better ways to abstract certain config parameters to be changed by exec levels
- 1. We potentially need a ui to enable quick parameter changing
- 2. We need to flesh out standards for parameter stores (not only file stores)
- 3. We need a way to enable validation on parameters so that they cannot be changed the erroneous values.
-3. Seamless deployment: We need a way to enable deployments without users needing to be familiar with github
- 1. We need this to integrate well with parameters (the process to update a value like vat from 14% to 15% should be a 1 min job not a 30 min job)
- 2. We need a way to verify results perhaps with backtesting or live data before deploying
-4. Integration with A/B testing: We must make sure that all the retraining and ab testing work that is done is compatible with spockflow
- 1. Node-level: Users may want to select a subsection of the model to alter and do ab testing with so that they are able to see the outcomes midflow and how it affected the process as a whole (multiple ab tests at the same time)
-5. Dashboarding/tracing monitoring
-## Tree Functionality
-1. We should make it easy to migrate from the old format to the new format
-2. We need a way to optimise tree structures using libraries such as sympy
-3. We need a way to make trees work better with different data types. Currently it only supports int and float values. We will need to extend it that if we want a categorical string we are able to draw out the compute of changing the value to categorical before.
-4. We should consider bringing an easy way to split the data and rejoin it after a tree. This way we can cull off parts of the pipeline where there is no data to compute.
-## Dashboarding
-1. We should consider creating easy to use dashboarding functionalities into the ui so that we can monitor the outputs of the flows.
-
-
-We need a way to show feature use. Make sure you are able to catalog a feature and see the downstream impact of a feature on the entire system
\ No newline at end of file
diff --git a/docs/getting_started/contributing/index.md b/docs/getting_started/contributing/index.md
deleted file mode 100644
index 89bc7c5..0000000
--- a/docs/getting_started/contributing/index.md
+++ /dev/null
@@ -1,12 +0,0 @@
-```{include} ../../../CONTRIBUTING.md
-:relative-docs: docs
-:relative-images:
-```
-
-# Useful Links
-```{toctree}
-
-Architecture
-Roadmap
-Backlog
-```
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_plugin/architecture.drawio.svg b/docs/getting_started/contributing/roadmap/as_a_plugin/architecture.drawio.svg
deleted file mode 100644
index 74ecf83..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_plugin/architecture.drawio.svg
+++ /dev/null
@@ -1,94 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_plugin/index.md b/docs/getting_started/contributing/roadmap/as_a_plugin/index.md
deleted file mode 100644
index 5ad88ca..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_plugin/index.md
+++ /dev/null
@@ -1,44 +0,0 @@
-# Description
-Integrating the UI as a plugin involves making an extra side menu item for dsp-rule-engine. This could be done in a similar fashion to the existing Redshift plugin or the git plugin. The final result could be similar to the mockup shown below.
-
-
-The system will rely on the same methods of model deployment as we currently use but can add a few convenience utilities such as automatically creating a folder with the project details and helping the user to link the project with git. It could potentially also facilitate creating the appropriate pull requests to kick off model deployment. A potential architecture of this solution could look as follows:
-
-# User Stories
-## The Developers (Advanced)
-For advanced rule developers this is likely the most convenient option as they will likely already be working in a sagemaker environment. The advanced rule developer will start their process off loading some data from redshift. Ideally there would be a convenient way to get this data directly into the ui so that as rules are added they can get a visual representation of the counts of data that fall within each segment and potentially add a set of criteria like accuracy with a target and detection rate that they can view. The advanced developer would likely switch back and fourth between this view and a notebook to develop more features or use other tools to discover better decision boundaries.
-
-Once the developer is done with their rules they could press a button that would generate a initial spock workflow from a template and add the config to the flow. They could then import this into their current notebook and run it on live data to generate a report. Some reporting templates can also be incorporated as part of the deploy template to make this process easier.
-
-As part of the template the user should specify a validation dataset and a set of metrics that the model should conform to before it is considered valid. This could be a set of exact requests and responses for specific cases but more generally it could be a target accuracy or alert rate that must be met on a dataset. This way in the future modifications to rules can quickly be validated.
-
-When the developer is happy with the process they could commit the code to a repo and begin the model deployment process. The above validations and metrics will be calculated on deployment and displayed to the user which they can inspect before approving the stage to finally deploy into the respective environments.
-## The Developers (Basic)
-For developers which are more used to a UI tool this may be a bit inconvenient as they will likely need to get familiar with a new environment (Sagemaker) and they will now need to spin up an instance before they can do any work rather than just going to a page and getting started. Once they have navigated the terrain their process will be quite similar to the advanced developer. They will open the tool and create a project. They will likely not switch too much between notebooks and the app and will likely stay within the confines of safety.
-
-For the basic developers it will be more essential to attempt to integrate the raw data into the flow creation process so that they can get real-time feedback as the develop rules and in the future entire workflows. These developers will need a way to automatically evaluate if things like features are not available in the source data. For an initial step data can be provided as a csv/parquet source. However in the future with a flow we could introduce a source in the form of either a request or a redshift datasource.
-
-The development of standard and custom metrics will be essential to enable these developers to be data driven rather than basing their decisions on pure intuition. Once these developers have met targets on their metrics they can export their models in an identical fashion to the advanced developers. These developers will likely initially struggle with the concept of a notebook environment to view reporting; However, with targeted templates we could assist them in gathering beneficial metrics to properly validate the rules that they have created. These templates would have to tie in well with the model validation process and potentially prompt the user to define thresholds to evaluate the model on.
-
-The ui can assist these developers greatly in getting git setup and assisting with the deployment process.
-## The Maintainers
-Maintainers are typically on the same skill-level as the Basic Developers and may even be less of a technical audience attempting to adjust a simple threshold. Due to this they will likely have a hard time opening the environment and cloning a repo to get started. Tight integration between the widget could assist the user as it may be possible to do a discovery process were the user can just select an existing model and it could clone it or open the folder. However this integration would be challenging for open-source as it would be very targeted to how our systems are setup unless a plugin system is used.
-
-Once the maintainer has an existing project cloned they will begin development in a similar fashion to the basic developer. A maintainer will generally stick to making minor adjustments to rules such as changing parameters, thresholds or occasionally adding or deleting a rule.
-
-The maintainer can rely on the tool to make a commit and push with changes however they would likely battle to make a pr and navigate the model deployment process. A link could be generated to submit the pr and docs could be linked to; however for an executive trying to change a simple threshold this might be quite a cumbersome task.
-## The Endpoint Users
-Endpoint users will be in tight communication with the developers to understand what input features need to be given into the model. They will at the end of the day be given a url to invoke to integrate with the final model.
-
-If new features are added to the model they will need to ensure that the upstream dependencies are updated in a way that is backwards compatible (always add new features and don't delete old features for a rollover period). That way upstream jobs can be deployed before the rules are changed and ideally a rollover can be smooth. If this is not a possibility they will need to communicate with the Developer that there needs to be a major version change and in this case the new model will be deployed to a new endpoint /v1 vs /v2 and should be deployed before the upstream job. The model can be configured to return a model version so that any downstream logic can handle the results accordingly.
-## The Monitors
-There will likely be people who are interested in monitoring the real-time performance of the model to determine the accuracy over time. For these cases it will be the responsibility of the developer or endpoint user to integrate the results back into redshift this can be done with the use of endpoint capture or just by storing the results at the end of the day. Ideally results will be enriched at some point with ground-truth values such as if the fraud is confirmed or not. Feature pipelines or prodbooks can be developed to calculate desired metrics and this can be used then in a report on PowerBI for teams to monitor. Setting up the entire monitoring process has quite a few moving parts and may be considered cumbersome.
-# Advantages
-- A plugin has more control than a widget
-- We do not not need to host it (use of existing compute)
-# Limitations
-- Hard to maintain (Plugin developers are niche)
-- Hard to setup (Not much documentation on creating sagemaker plugins is it the same as standard jupyter lab)
-- Clients need to login to sagemaker
-- Limits possibilities for open-source (niche)
-- Hard for managers just wanting to update a parameter
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_plugin/mockup.drawio.svg b/docs/getting_started/contributing/roadmap/as_a_plugin/mockup.drawio.svg
deleted file mode 100644
index 35f2bf4..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_plugin/mockup.drawio.svg
+++ /dev/null
@@ -1,120 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_service/architecture.drawio.svg b/docs/getting_started/contributing/roadmap/as_a_service/architecture.drawio.svg
deleted file mode 100644
index d4c686c..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_service/architecture.drawio.svg
+++ /dev/null
@@ -1,82 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_service/index.md b/docs/getting_started/contributing/roadmap/as_a_service/index.md
deleted file mode 100644
index 1033443..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_service/index.md
+++ /dev/null
@@ -1,28 +0,0 @@
-# Description
-Exposing the UI as a service means that the app will have its own url and be hosted on odin. It can be possible for clients to spin up their own instances or make use of a shared centralized instance. An example ui could look as follows:
-
-
-The system could make use of the current approach to deployments. It could also act as a hosting system of its own and create endpoints for flows automatically without the need for additional deployments. This would be advantageous as it opens the possibilities for one click deployments or more advanced AB testing as we can control everything about how the model is rolled out or how it is executed. The architecture could look as follows:
-
-# User Stories
-## The Developers (Advanced)
-Developers will be able to log in using the portal and navigate to the page. From there they can develop rules in a similar way to what they would as a plugin. To allow them to continue to work in their notebooks it would also be possible to create a widget that can integrate with the server. This will enable the advanced developer to quickly switch between notebook environments and rule development keeping it all in one environment.
-## The Developers (Basic)
-Developers with less technical backgrounds will find it easier to log into a dedicated app as they do not need to navigate through sagemaker to get into the app. The app and server will always be running so there will be less time required to get up and running. A hosted app will be ideal for basic developers as they will be used to low-code environments where they do not need to look at a single line of code for most usecases. They will log into the app develop a flow and press a button to deploy.
-## The Maintainers
-The flow of the maintainers will be identical to the developers. However it will be possible to even further simplify the maintenance of executive parameters and thresholds. Essential parameters can be extracted from flows and separated into their own interface. Here it will be possible for maintainers who aren't very technically inclined to update exposed thresholds based on descriptions and test their effects on certain key metrics without ever having to unpack the system as a whole. It will also be possible to create simplified ways to deploy the model using a UI so that maintainers do not need to be familiar with git.
-## The Endpoint Users
-Endpoint users will be able to use the model in an identical fashion to always. However as a service it may be possible to create a custom serving layer. With a custom serving layer we can handle cases such as automatic rerouting based on input payload capabilities, higher level orchestration (The output of one flow could determine which other flows need to be invoked before a result is returned.), and a prescriptive means for storing and retrieving versioned config.
-## The Monitors
-A custom service enables us to create a real-time dashboard where we are able to integrate with the serving layer and extract metrics on inputs, outputs or branches taken for each request. We will be able to provide advanced monitoring capacities that will be useful for cases like fraud where we would like to see which rules are affective and which branches are taken in real-time. It would also enable monitors to quickly react if they see that a given rule is producing too few or too many alerts.
-# Advantages
-- Can use additional protocols (not limited to HTTP requests for inference)
-- Easier to integrate with a standardized parameter store if we host the models
-# Limitations
-- Extra Overheads. Need to worry about:
- - authentication
- - Access control
- - Backups and DR
- - Code execution security (may be possible to extract information about other flows by running your flow and inspecting globals)
- - Processing capacity and autoscaling
-- Need to keep at least a pilot server constantly running.
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_service/mockup.drawio.svg b/docs/getting_started/contributing/roadmap/as_a_service/mockup.drawio.svg
deleted file mode 100644
index a173dca..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_service/mockup.drawio.svg
+++ /dev/null
@@ -1,6 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_widget/architecture.drawio.svg b/docs/getting_started/contributing/roadmap/as_a_widget/architecture.drawio.svg
deleted file mode 100644
index 8eea089..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_widget/architecture.drawio.svg
+++ /dev/null
@@ -1,106 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_widget/index.md b/docs/getting_started/contributing/roadmap/as_a_widget/index.md
deleted file mode 100644
index 8239440..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_widget/index.md
+++ /dev/null
@@ -1,20 +0,0 @@
-# User Stories
-A widget would be similar to a plugin but it would only run inside a running jupyter notebook. So it would have limited functionality and not integrate with the sidebar. An example could look as follows:
-
-
-The widget would act in a similar fashion to the plugin. There would be minimal differences apart from it would likely be harder to automate some tasks like repo creation. We would likely need to lean more into odin for the creation of repos and would require that developers of rules are more technical.
-
-## The Developers
-The user story would be identical to a plugin. A user would likely start their journey in odin to create a new repo. They would then go to github to get the link and then open sagemaker to clone the repo and start development. Once they are done they will push their changes, go back to github to deploy their model and use a tool like postman to test their model before deploying it to various environments.
-## The Maintainers
-Maintainers would need to be the most technical for this as they would need to be comfortable cloning a repo before making changes to config either using a widget by running in a notebook or by editing it directly in json format. They would then need to know how to push to a branch and go back to github to do the deployment before testing in dev.
-# Advantages
-- We do not not need to host it (use of existing compute)
-- Easier to maintain than a Plugin (less moving parts)
-# Limitations
-- Hard to maintain (Widget developers are niche)
-- Hard to setup (Not much documentation on creating sagemaker plugins is it the same as standard jupyter lab)
-- Clients need to login to sagemaker
-- Clients need to be familiar with github and jupyter
-- Limits possibilities for open-source (niche)
-- Hard for managers just wanting to update a parameters
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/as_a_widget/mockup.drawio.svg b/docs/getting_started/contributing/roadmap/as_a_widget/mockup.drawio.svg
deleted file mode 100644
index 7689615..0000000
--- a/docs/getting_started/contributing/roadmap/as_a_widget/mockup.drawio.svg
+++ /dev/null
@@ -1,7 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/getting_started/contributing/roadmap/index.md b/docs/getting_started/contributing/roadmap/index.md
deleted file mode 100644
index e69de29..0000000
diff --git a/docs/getting_started/index.md b/docs/getting_started/index.md
index 542798d..a168007 100644
--- a/docs/getting_started/index.md
+++ b/docs/getting_started/index.md
@@ -1,11 +1,9 @@
-# Getting Started with Spockflow
+# Getting Started
-Welcome to Spockflow! This guide will help you get up and running with Spockflow, a Python framework designed for creating standalone micro-services that enrich data with actionable outputs.
+Welcome to Decider! This guide will get you up and running with Decider, a Python framework for building, serving, and inspecting decision pipelines as versioned, deployable micro-services.
```{toctree}
-Installation Guide
+Installation
Quick Start
-Contributing
-License
```
diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md
index 7fcb73a..320c984 100644
--- a/docs/getting_started/install.md
+++ b/docs/getting_started/install.md
@@ -1,57 +1,46 @@
-# Spockflow Installation Guide
-
-Welcome to the installation guide for Spockflow. Spockflow is a powerful package designed to streamline decisioning processes with its cutting-edge features.
+# Installation
## Prerequisites
-Before you begin, ensure you have the following prerequisites installed:
-- Python (version 3.8 or higher recommended)
-- Git (if installing from source)
-- AWS CLI configured with appropriate credentials for AWS CodeArtifact (if applicable)
- - Alternatively be setup in a notebook environment on the Data Science Platform
-
-## Installation Options
-
-### Option 1: Pip Install
+- Python 3.10 or higher
+- [uv](https://docs.astral.sh/uv/) (recommended) or pip
-To install Spockflow via pip, use the following command:
+## Install with uv
```bash
-pip install spockflow
+uv add decider
```
-#### Installing with Optional Features
-**YAML Support**: To add YAML support, use the [yaml] extra option:
+### Optional extras
+
+| Extra | Installs |
+|---|---|
+| `serve-starlette` | uvicorn + starlette for async HTTP serving |
+| `serve-sanic` | sanic for high-throughput HTTP serving |
+| `visualise` | streamlit + graphviz for the pipeline visualiser |
+| `notebook` | IPython + Jupyter magic |
```bash
-pip install spockflow[yaml]
-```
-**Web Application (REST API)**: To enable local REST API serving, use the [webapp] extra option:
+# Example: install with starlette serving and notebook support
+uv add "decider[serve-starlette,notebook]"
```
-pip install spockflow[webapp]
-```
-
-## Option 2: Install from Source
-To install Spockflow from source, follow these steps:
-Clone the repository from GitHub:
+## Install with pip
```bash
-git clone https://github.com/spockflow/spockflow.git
-cd spockflow
-pip install -e .
+pip install decider
```
-## Verify Installation
-To verify that Spockflow has been installed correctly, you can run the following command to check the version:
+## Install from source
```bash
- python3 -c "import spockflow; print(spockflow.__version__)"
+git clone https://github.com/capitecbankltd/dsp_north-polrs.git
+cd dsp_north-polrs
+pip install -e ".[all]"
```
-## Usage
-Refer to the Spockflow documentation for detailed usage instructions and examples.
-
-## Troubleshooting
-If you encounter any issues during installation or usage, please refer to the FAQs or reach out to our support team at sholto.armstrong@capitecbank.co.za.
+## Verify installation
+```bash
+python -c "import decider; print(decider.__version__)"
+```
diff --git a/docs/getting_started/license.md b/docs/getting_started/license.md
deleted file mode 100644
index 8479381..0000000
--- a/docs/getting_started/license.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# License
-
-```{include} ../../LICENSE
-:relative-docs: docs
-:relative-images:
-```
\ No newline at end of file
diff --git a/docs/getting_started/quick_start.ipynb b/docs/getting_started/quick_start.ipynb
deleted file mode 100644
index 10b99d2..0000000
--- a/docs/getting_started/quick_start.ipynb
+++ /dev/null
@@ -1,1757 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Quick Start with Spockflow\n",
- "Welcome to the Quick Start guide for Spockflow! This guide will walk you through setting up and running your first data enrichment pipeline using Spockflow, designed to be straightforward and user-friendly, especially within a Jupyter notebook environment. Whether you’re a data scientist looking to deploy models quickly or a developer interested in streamlining data workflows, Spockflow offers powerful tools to enrich your data with actionable insights. Let’s get started!\n",
- "\n",
- "### Preface\n",
- "\n",
- "This guide assumes that you already have hamilton installed and have installed the GraphViz package for visualization. If not please follow the [Installation Guide](install.md) before continuing."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building a Tree\n",
- "To begin, we import the necessary packages. This guide focuses on exploring the Decision Tree component, a fundamental part of Spockflow."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "from spockflow.components.tree import Tree, Action"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The `Tree` class allows us to construct decision trees, while the `Action` class serves as a convenient wrapper for standardizing outputs.\n",
- "\n",
- "To enforce a schema for actions, we define a specific action type:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "from typing_extensions import TypedDict\n",
- "import pandas as pd\n",
- "\n",
- "\n",
- "class Reject(TypedDict):\n",
- " code: int\n",
- " description: str\n",
- "\n",
- "\n",
- "RejectAction = Action[Reject]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, we create an instance of the `Tree`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "tree = Tree()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Conditions can be added to the tree using decorators. Here's an example:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "@tree.condition(output=RejectAction(code=102, description=\"My first condition\"))\n",
- "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n",
- " return (d > 5) & (e > 5) & (f > 5)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This condition triggers a rejection with code 102 when certain criteria are met.\n",
- "Nested conditions can also be defined. Here's how to nest under a parent condition:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "@tree.condition()\n",
- "def condition_a(a: pd.Series) -> pd.Series:\n",
- " return a > 5\n",
- "\n",
- "\n",
- "@condition_a.condition(\n",
- " output=RejectAction(code=100, description=\"a and b is out of range\")\n",
- ")\n",
- "def condition_b(b: pd.Series) -> pd.Series:\n",
- " return b > 5\n",
- "\n",
- "\n",
- "@condition_a.condition(\n",
- " output=RejectAction(code=101, description=\"a and c is out of range\")\n",
- ")\n",
- "def condition_c(c: pd.Series) -> pd.Series:\n",
- " return c > 5"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Notice that when specifying a child node we make use of `@condition_a.condition` to specify conditions nested under `condition_a`. Furthermore, The base condition has no output set.\n",
- "\n",
- "It is important to be aware that tree construction is order dependant and rules placed first will be evaluated before rules coming after.\n",
- "\n",
- "Finally it is generally desired to set a default value if no condition is matched. This can be done as:\n",
- "\n",
- "*Note: The default behavior is to set all values to pd.NA when nothing is matched.*"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "tree.set_default(output=RejectAction(code=-1, description=\"N/A\"))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "It is also possible to set defaults on child elements with the following code:\n",
- "```python\n",
- "condition_a.set_default(output=RejectAction(code=-1, description=\"N/A\"))\n",
- "```\n",
- "We now define some dummy data to test the decision tree with as follows:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
"
- ],
- "text/plain": [
- " code description\n",
- "0 -1 N/A\n",
- "1 101 a and c is out of range\n",
- "2 102 My first condition\n",
- "3 -1 N/A\n",
- "4 -1 N/A\n",
- "5 -1 N/A\n",
- "6 -1 N/A\n",
- "7 -1 N/A"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tree.execute(test_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Execution as part of Hamilton\n",
- "We have successfully created our first Decision Tree using SpockFlow. Now we would like to combine it as part of a hamilton pipeline. Hamilton allows us to convert a sequence of python functions into an executable Directed Acyclic Graph (DAG). This allows us to trace the flow of data through a sequence of steps. For a detailed explination of Hamilton please refer to the [Hamilton docs](https://hamilton.dagworks.io/en/latest/get-started/your-first-dataflow/). \n",
- "Hamilton harnesses the power of functions contained in modules to construct pipelines. The following lines enable us to create Python modules directly within our notebook, empowering us to build Hamilton pipelines seamlessly without the need for external python modules. It is possible to define the above decision tree in a hamilton pipeline as shown below:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext hamilton.plugins.jupyter_magic"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%cell_to_module --display -m demo_tree\n",
- "from typing import TypedDict\n",
- "import pandas as pd\n",
- "\n",
- "from spockflow.components.tree import Tree, Action\n",
- "from spockflow.core import initialize_spock_module\n",
- "\n",
- "class Reject(TypedDict):\n",
- " code: int\n",
- " description: str\n",
- "\n",
- "RejectAction = Action[Reject]\n",
- "tree = Tree()\n",
- "\n",
- "def a(num: pd.Series) -> pd.Series:\n",
- " \"\"\" We can add a function to calculate a from an external input \"\"\"\n",
- " return num * 2\n",
- "\n",
- "\n",
- "def b(a: pd.Series, c: pd.Series) -> pd.Series:\n",
- " \"\"\"It is possible to take in more than one input as well for the transforms\"\"\"\n",
- " return a+c\n",
- "\n",
- "@tree.condition()\n",
- "def condition_a(a: pd.Series) -> pd.Series:\n",
- " return a>5\n",
- "\n",
- "@condition_a.condition(output=RejectAction(code=100, description=\"a and b is out of range\"))\n",
- "def condition_b(b: pd.Series) -> pd.Series:\n",
- " return b>5\n",
- "\n",
- "@condition_a.condition(output=RejectAction(code=101, description=\"a and c is out of range\"))\n",
- "def condition_c(c: pd.Series) -> pd.Series:\n",
- " return c>5\n",
- "\n",
- "@tree.condition(output=RejectAction(code=102, description=\"My first condition\"))\n",
- "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n",
- " return (d>5)&(e>5)&(f>5)\n",
- "\n",
- "tree.set_default(output=RejectAction(code=-1, description=\"N/A\"))\n",
- "\n",
- "initialize_spock_module(__name__, output_names=[\"tree\"])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The above code created a DAG from the decision tree. Note that additional pre-processing steps were inserted for the input values `a` and `b`. We also introduced a new line of code\n",
- "```python\n",
- "initialize_spock_module(__name__, output_names=[\"tree\"])\n",
- "```\n",
- "This hooks into the hamilton DAG creation system to construct the SpockFlow specific nodes. This is required as by default hamilton only creates nodes for functions in a module. This line also gives us the ability to specify default outputs as is discussed further in this tutorial.\n",
- "\n",
- "It is now possible to execute this DAG as follows:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
code
\n",
- "
description
\n",
- "
\n",
- " \n",
- " \n",
- "
\n",
- "
0
\n",
- "
-1
\n",
- "
N/A
\n",
- "
\n",
- "
\n",
- "
1
\n",
- "
-1
\n",
- "
N/A
\n",
- "
\n",
- "
\n",
- "
2
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
3
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
4
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
5
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
6
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
7
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " code description\n",
- "0 -1 N/A\n",
- "1 -1 N/A\n",
- "2 100 a and b is out of range\n",
- "3 100 a and b is out of range\n",
- "4 100 a and b is out of range\n",
- "5 100 a and b is out of range\n",
- "6 100 a and b is out of range\n",
- "7 100 a and b is out of range"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from spockflow.core import Driver\n",
- "\n",
- "dr = Driver({}, demo_tree)\n",
- "df = dr.execute(inputs=test_data)\n",
- "df"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "It can be seen in the above code that `spockflow.core.Driver` is used. This is a convinience wrapper to `hamilton.driver.Driver`. However, the SpockFlow driver has context of the default outputs set in the `initialize_spock_module` function. To execute this using the default driver the following code is needed:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
code
\n",
- "
description
\n",
- "
\n",
- " \n",
- " \n",
- "
\n",
- "
0
\n",
- "
-1
\n",
- "
N/A
\n",
- "
\n",
- "
\n",
- "
1
\n",
- "
-1
\n",
- "
N/A
\n",
- "
\n",
- "
\n",
- "
2
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
3
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
4
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
5
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
6
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
7
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " code description\n",
- "0 -1 N/A\n",
- "1 -1 N/A\n",
- "2 100 a and b is out of range\n",
- "3 100 a and b is out of range\n",
- "4 100 a and b is out of range\n",
- "5 100 a and b is out of range\n",
- "6 100 a and b is out of range\n",
- "7 100 a and b is out of range"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from hamilton.driver import Driver as HamiltonDriver\n",
- "\n",
- "dr_ham = HamiltonDriver({}, demo_tree)\n",
- "df = dr_ham.execute(inputs=test_data, final_vars=[\"tree\"])\n",
- "df"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Here is where Hamilton's true power begins to shine as it is possible to extract outputs from any stage in the execution and furthermore it is possible to override inputs to the different stages of the pipeline."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
"
- ],
- "text/plain": [
- " condition_b tree.code tree.description\n",
- "0 False -1 N/A\n",
- "1 False -1 N/A\n",
- "2 False 102 My first condition\n",
- "3 False -1 N/A\n",
- "4 False -1 N/A\n",
- "5 False 101 a and c is out of range\n",
- "6 False -1 N/A\n",
- "7 False -1 N/A"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df = dr.execute(\n",
- " inputs=test_data,\n",
- " final_vars=[\"condition_b\", \"tree\"],\n",
- " overrides={\"b\": test_data[\"b\"]},\n",
- ")\n",
- "df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dr.visualize_execution(\n",
- " inputs=test_data,\n",
- " final_vars=[\"condition_b\", \"tree\"],\n",
- " overrides={\"b\": test_data[\"b\"]},\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Running your flows as a server\n",
- "*NOTE: This section requires that you have installed SpockFlow with the `[webapp]` dependencies*\n",
- "\n",
- "Now that we have created a DAG it is possible to package the execution pipeline for execution. For this it is possible to adjust the jupyter magic command above to the `%%writefile` magic as shown below. We will write the contents to main.py under source_dir as this is the default entrypoint for the server."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Overwriting source_dir/main.py\n"
- ]
- }
- ],
- "source": [
- "%%writefile source_dir/main.py\n",
- "from typing import TypedDict\n",
- "import pandas as pd\n",
- "\n",
- "from spockflow.components.tree import Tree, Action\n",
- "from spockflow.core import initialize_spock_module\n",
- "\n",
- "class Reject(TypedDict):\n",
- " code: int\n",
- " description: str\n",
- "\n",
- "RejectAction = Action[Reject]\n",
- "tree = Tree()\n",
- "\n",
- "def a(num: pd.Series) -> pd.Series:\n",
- " \"\"\" We can add a function to calculate a from an external input \"\"\"\n",
- " return num * 2\n",
- "\n",
- "\n",
- "def b(a: pd.Series, c: pd.Series) -> pd.Series:\n",
- " \"\"\"It is possible to take in more than one input as well for the transforms\"\"\"\n",
- " return a+c\n",
- "\n",
- "@tree.condition()\n",
- "def condition_a(a: pd.Series) -> pd.Series:\n",
- " return a>5\n",
- "\n",
- "@condition_a.condition(output=RejectAction(code=100, description=\"a and b is out of range\"))\n",
- "def condition_b(b: pd.Series) -> pd.Series:\n",
- " return b>5\n",
- "\n",
- "@condition_a.condition(output=RejectAction(code=101, description=\"a and c is out of range\"))\n",
- "def condition_c(c: pd.Series) -> pd.Series:\n",
- " return c>5\n",
- "\n",
- "@tree.condition(output=RejectAction(code=102, description=\"My first condition\"))\n",
- "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n",
- " return (d>5)&(e>5)&(f>5)\n",
- "\n",
- "tree.set_default(output=RejectAction(code=-1, description=\"N/A\"))\n",
- "\n",
- "initialize_spock_module(__name__, output_names=[\"tree\"])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We will also create a inference.py file which allows us to override various aspects of our requests layer including how the data is processed before entering into the model. This is discussed in depth in the [API Customization](../concepts/inference.md) section. For now we are simply converting the json data-structures to a Pandas DataFrame before executing the model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Overwriting source_dir/inference.py\n"
- ]
- }
- ],
- "source": [
- "%%writefile source_dir/inference.py\n",
- "import typing\n",
- "import pandas as pd\n",
- "\n",
- "def pre_process_fn(input_data: typing.Any) -> pd.DataFrame:\n",
- " return pd.json_normalize(input_data)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Once executed, your code should be written to the file `source_dir/main.py`. We can now simply start a server as follows:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import json\n",
- "import requests\n",
- "\n",
- "os.environ[\"MODEL_PREFIX\"] = os.path.abspath(\".\")\n",
- "os.environ[\"MODEL_RELATIVE_PATH\"] = \"source_dir\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%script bash --bg --out OUTPUT_STREAM\n",
- "uvicorn spockflow.inference.server.asgi:app --reload"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- "
\n",
- "
\n",
- "
code
\n",
- "
description
\n",
- "
\n",
- " \n",
- " \n",
- "
\n",
- "
0
\n",
- "
-1
\n",
- "
N/A
\n",
- "
\n",
- "
\n",
- "
1
\n",
- "
-1
\n",
- "
N/A
\n",
- "
\n",
- "
\n",
- "
2
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
3
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
4
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
5
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
6
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- "
\n",
- "
7
\n",
- "
100
\n",
- "
a and b is out of range
\n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " code description\n",
- "0 -1 N/A\n",
- "1 -1 N/A\n",
- "2 100 a and b is out of range\n",
- "3 100 a and b is out of range\n",
- "4 100 a and b is out of range\n",
- "5 100 a and b is out of range\n",
- "6 100 a and b is out of range\n",
- "7 100 a and b is out of range"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "resp = requests.post(\n",
- " \"http://localhost:8000/invocations\", json=test_data.to_dict(orient=\"records\")\n",
- ")\n",
- "pd.DataFrame(resp.json()[\"tree\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "All background processes were killed.\n"
- ]
- }
- ],
- "source": [
- "%killbgscripts"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Advanced Deployment\n",
- "### Docker Deployment\n",
- "SpockFlow is also deployable as a docker container. The following command can be used to deploy a SpockFlow pipeline in a docker image:\n",
- "\n",
- "```bash\n",
- "docker run -v $(pwd)/source_dir:/opt/ml/model --network=\"host\" --entrypoint uvicorn -p 8000:8000 405458085848.dkr.ecr.af-south-1.amazonaws.com/sagemaker-rules-engine:0.1.8 spockflow.inference.server.asgi:app --reload\n",
- "```\n",
- "### Debugging in VSCode\n",
- "Spockflow makes use of a [Starlette](https://www.starlette.io/) app under the hood. It is therefore possible to launch a debug server directly in vscode with the following configuration:\n",
- "```json\n",
- "{\n",
- " \"version\": \"0.2.0\",\n",
- " \"configurations\": [\n",
- " {\n",
- " \"name\": \"Python Debugger: FastAPI\",\n",
- " \"type\": \"debugpy\",\n",
- " \"request\": \"launch\",\n",
- " \"module\": \"uvicorn\",\n",
- " \"args\": [\n",
- " \"spockflow.inference.server.asgi:app\",\n",
- " \"--reload\"\n",
- " ],\n",
- " \"jinja\": true,\n",
- " \"env\": {\n",
- " // Note the below config may need to change to suite your needs.\n",
- " \"MODEL_PREFIX\": \"docs/getting_started\",\n",
- " \"MODEL_RELATIVE_PATH\": \"source_dir\"\n",
- " }\n",
- " }\n",
- " ]\n",
- "}\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "spock",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/getting_started/quick_start.md b/docs/getting_started/quick_start.md
new file mode 100644
index 0000000..1a488c1
--- /dev/null
+++ b/docs/getting_started/quick_start.md
@@ -0,0 +1,24 @@
+# Quick Start
+
+The notebooks below walk through Decider end-to-end. Each notebook is self-contained, and they are intended to be worked through in order.
+
+```{toctree}
+:maxdepth: 1
+
+Getting Started <../examples/01_getting_started>
+Pipelines & Joins <../examples/02_pipelines_and_joins>
+Config & Persistence <../examples/03_config_and_persistence>
+```
+
+## Example Projects
+
+More complete, domain-specific pipeline examples:
+
+```{toctree}
+:maxdepth: 1
+
+Loan Scoring <../examples/projects/01_loan_scoring>
+Fraud Detection <../examples/projects/02_fraud_detection>
+Affordability Check <../examples/projects/03_affordability_check>
+Multi-Bureau Pipeline <../examples/projects/04_multi_bureau_pipeline>
+```
diff --git a/docs/index.rst b/docs/index.rst
index d17f350..344daad 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,21 +1,11 @@
-.. include:: main.rst
+.. include:: main.md
+ :parser: myst_parser.sphinx_
.. toctree::
:hidden:
+ :maxdepth: 2
- Home page
-
-.. toctree::
- :hidden:
- :caption: USER GUIDE
-
- Getting Started
+ getting_started/index
concepts/index
- API reference <_autosummary/spockflow>
-
-.. toctree::
- :hidden:
- :caption: EXTERNAL RESOURCES
-
- GitHub
- Hamilton
+ contributing/index
+ api
diff --git a/docs/main.md b/docs/main.md
new file mode 100644
index 0000000..514ec15
--- /dev/null
+++ b/docs/main.md
@@ -0,0 +1,32 @@
+# Decider
+
+**Decider** is a Python framework for building, serving, and inspecting decision pipelines as versioned, deployable micro-services.
+
+Define pipelines from plain Python functions, compose them with `|` and `&`, save them as versioned JSON configs, and serve them over HTTP — all with a single consistent API.
+
+::::{grid} 2
+:::{grid-item-card} Getting Started
+:link: getting_started/index
+:link-type: doc
+
+Install Decider and run your first pipeline in minutes.
+:::
+:::{grid-item-card} Concepts
+:link: concepts/index
+:link-type: doc
+
+Understand modules, pipelines, config versioning and extensions.
+:::
+:::{grid-item-card} API Reference
+:link: api
+:link-type: doc
+
+Full auto-generated API documentation.
+:::
+:::{grid-item-card} Contributing
+:link: contributing/index
+:link-type: doc
+
+How to contribute to Decider.
+:::
+::::
diff --git a/docs/main.rst b/docs/main.rst
deleted file mode 100644
index 26b0d74..0000000
--- a/docs/main.rst
+++ /dev/null
@@ -1,64 +0,0 @@
-Spockflow
-=========
-
-Spock is a python framework aimed at creating standalone micro-services that enrich data with actionable outputs. Spock supports both batch and live inference modes. Spock extends existing frameworks to provide a simplistic abstraction for common data flows including policy rules and scoring. Spock is designed to be extensible and modular allowing pipelines and config to be reused in multiple flows. Finally, a large emphasis is placed on runtime traceability and explainability by leveraging Hamilton which is a well established python framework. This allows the lineage of data to be tracked and visualised as well as enabling the steps in a process that lead to a certain outcome to be easily identifiable.
-
-.. figure:: ./_static/getting-started/example_pipeline.drawio.svg
- :scale: 100
- :align: center
- :class: only-dark
-
-.. figure:: ./_static/getting-started/example_pipeline_light.drawio.svg
- :scale: 100
- :align: center
- :class: only-light
-
-Components
-----------
-
-Spock provides the following tools
-
-- **Scorecards**: Providing an understandable way to assign scores to entities based on various parameters.
-- **Decision Tables**: Enables input values to be mapped to output values based on a set of conditions.
-- **Decision Trees**: Empowers the developer to specify a set of potentially nested rules which can be used for policies or to choose outputs based on the aforementioned rules.
-- **General Transformations**: Often times simplistic transformations are required as part of the process and can be done using simple python code.
-- **User driven components**: The framework provides a simple interface that can be used to enable the user to extend Spock and create custom reusable components.
-
-Why Spock
----------
-
-- Spock enables Data Scientists to own the deployment of their models enabling a quicker time to production.
-- Easy handover to maintenance teams due to the ability to export core components as config.
-- Hamilton provides a form of runtime traceability
-- Components in the Spock ecosystem can be easily adjusted to tweak performance, throughput and for general maintenance.
-- Spock is built into DSP (and follows the Sagemaker standard) allowing it to leverage all of the existing capabilities out-of-the-box including:
- - easy EDL, EDW and feature platform integration
- - Model lifecycle with CICD integration
- - EDA and a development environment
- - Monitoring
- - Methods for both batch and live predictions using a single codebase
-
-- The python programming ecosystem
- - Python is a popular language among Data Scientists
- - Many libraries exist for python
- - Python and the abstractions provided by spock make the development of custom components easy
- - Tooling including: IDEs, static code analysers (validation), code completion
-
-Use Cases
----------
-
-- **Credit Granting**: A general credit granting flow will consist of ARODS, Affordability, Risk assessment and Pricing. Spock provides the following:
- - abstractions over Decision trees to assist with ARODS,
- - Scorecards for affordability and risk assessments and decision trees to allow results to be selected from scorecard based on a set of criteria,
- - Decision tables are typically used to form a price from the resulting affordability and risk assessment.
-
-- **Fraud**: Fraud is typically a flat set of rules used to detect fraudulent transactions. The decision tree functionality can be used to provide binary labels to the data.
-
-When to look elsewhere
----------
-
-- Spock is a function level orchestrator but is not intended to orchestrate multiple micro-services.
-- Spock is intended to run on a record level and is not designed to aggregate multiple records
-- Spock is stateless so all information needed to process records must be provided as part of the inputs to the flow.
-- Spock does not provide any means to perform actions but merely enriches the output data
-- Spock does not intend to provide a encumbering ui for development of arbitrary flows.
diff --git a/docs/make.bat b/docs/make.bat
deleted file mode 100644
index 32bb245..0000000
--- a/docs/make.bat
+++ /dev/null
@@ -1,35 +0,0 @@
-@ECHO OFF
-
-pushd %~dp0
-
-REM Command file for Sphinx documentation
-
-if "%SPHINXBUILD%" == "" (
- set SPHINXBUILD=sphinx-build
-)
-set SOURCEDIR=.
-set BUILDDIR=_build
-
-%SPHINXBUILD% >NUL 2>NUL
-if errorlevel 9009 (
- echo.
- echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
- echo.installed, then set the SPHINXBUILD environment variable to point
- echo.to the full path of the 'sphinx-build' executable. Alternatively you
- echo.may add the Sphinx directory to PATH.
- echo.
- echo.If you don't have Sphinx installed, grab it from
- echo.https://www.sphinx-doc.org/
- exit /b 1
-)
-
-if "%1" == "" goto help
-
-%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-goto end
-
-:help
-%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-
-:end
-popd
diff --git a/docs/robots.txt b/docs/robots.txt
index d7f5a46..7dd6844 100644
--- a/docs/robots.txt
+++ b/docs/robots.txt
@@ -1,3 +1,2 @@
User-agent: *
-
-Sitemap: https://spockflow.capinet/en/latest/sitemap.xml
\ No newline at end of file
+Sitemap: https://capitecbankltd.github.io/dsp_north-polrs/sitemap.xml
diff --git a/dsp-re-ui/.gitignore b/dsp-re-ui/.gitignore
deleted file mode 100644
index e5c939f..0000000
--- a/dsp-re-ui/.gitignore
+++ /dev/null
@@ -1,39 +0,0 @@
-.DS_STORE
-node_modules
-scripts/flow/*/.flowconfig
-.flowconfig
-*~
-*.pyc
-.grunt
-_SpecRunner.html
-__benchmarks__
-build/
-remote-repo/
-coverage/
-.module-cache
-fixtures/dom/public/react-dom.js
-fixtures/dom/public/react.js
-test/the-files-to-test.generated.js
-*.log*
-chrome-user-data
-*.sublime-project
-*.sublime-workspace
-.idea
-*.iml
-.vscode
-*.swp
-*.swo
-
-packages/react-devtools-core/dist
-packages/react-devtools-extensions/chrome/build
-packages/react-devtools-extensions/chrome/*.crx
-packages/react-devtools-extensions/chrome/*.pem
-packages/react-devtools-extensions/firefox/build
-packages/react-devtools-extensions/firefox/*.xpi
-packages/react-devtools-extensions/firefox/*.pem
-packages/react-devtools-extensions/shared/build
-packages/react-devtools-extensions/.tempUserDataDir
-packages/react-devtools-fusebox/dist
-packages/react-devtools-inline/dist
-packages/react-devtools-shell/dist
-packages/react-devtools-timeline/dist
\ No newline at end of file
diff --git a/dsp-re-ui/README.md b/dsp-re-ui/README.md
deleted file mode 100644
index a6c55f6..0000000
--- a/dsp-re-ui/README.md
+++ /dev/null
@@ -1,82 +0,0 @@
-# Data Science Platform (DSP) Rule Engine UI
-
-## Overview
-
-The DSP Rule Engine UI is a React-based application that allows users to visually build decision trees for rule-based decision-making systems. Once the trees are built, users can save and execute them through endpoints or visualize and test them within a Jupyter notebook.
-
-## Prerequisites
-
-Before getting started, make sure you have the following installed:
-
-- **Node.js** (for running the React app)
-- **Yarn** or **npm** (for managing dependencies)
-- **Python** (for running the backend model and Jupyter notebooks)
-- **Graphviz** (for visualizing decision trees in the notebook)
-
-## Installation
-
-1. Clone the repository:
-
- ```bash
- git clone https://github.com/capitec/dsp-re-ui.git
- cd dsp-re-ui
- ```
-
-2. Install the dependencies using **Yarn** or **npm**:
-
- - Using **npm**:
- ```bash
- npm install
- ```
-
- - Or using **Yarn** (if you prefer Yarn):
- ```bash
- yarn install
- ```
-
-## Running the React App
-
-To start the React application in development mode:
-
-- Using **npm**:
- ```bash
- npm start
- ```
-
-- Using **Yarn**:
- ```bash
- yarn start
- ```
-
-This will start the UI on `http://localhost:3000`, where you can interact with the decision tree builder.
-
-## Visual Decision Tree Builder
-
-The UI allows you to visually build decision trees. Once a decision tree is created, you can save it to a configuration file.
-
-1. **Add Config to `config.json`**:
- - After building the tree in the UI, save the configuration to the file `dsp-re-ui/example/tree/source_dir/config.json`.
-
-2. **Run the Model with VSCode**:
- - You can use VSCode to launch the decision tree model as an endpoint based on the config file.
-
-3. **Visualize and Test in Jupyter Notebook**:
- - Alternatively, you can use the notebook `dsp-re-ui/example/tree/test.ipynb` to visualize the decision tree (make sure **Graphviz** is installed) and test it with batch data on the model.
-
-## Running the Project Manually
-
-If you'd prefer to run the project manually, you can use the following command:
-
-```bash
-MODEL_PREFIX="dsp-re-ui/example/tree" MODEL_RELATIVE_PATH="source_dir" MODEL_VERSION="0.0.0" uvicorn spockflow.inference.server.asgi:app --reload
-```
-
-Make sure to replace the relevant paths if needed, and this will start the backend server for your decision tree model.
-
-## License
-
-This project is licensed under the MIT License - see the [LICENSE](../LICENSE) file for details.
-
----
-
-Now, the README includes options for **Yarn** alongside the npm instructions, offering flexibility for users who prefer Yarn as their package manager.
\ No newline at end of file
diff --git a/dsp-re-ui/components.json b/dsp-re-ui/components.json
deleted file mode 100644
index 3aedb46..0000000
--- a/dsp-re-ui/components.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
- "$schema": "https://ui.shadcn.com/schema.json",
- "style": "new-york",
- "rsc": false,
- "tsx": true,
- "tailwind": {
- "config": "tailwind.config.js",
- "css": "src/index.css",
- "baseColor": "zinc",
- "cssVariables": true,
- "prefix": ""
- },
- "aliases": {
- "components": "@/components",
- "utils": "@/lib/utils",
- "ui": "@/components/ui",
- "lib": "@/lib",
- "hooks": "@/hooks"
- },
- "iconLibrary": "lucide"
- }
-
\ No newline at end of file
diff --git a/dsp-re-ui/craco.config.js b/dsp-re-ui/craco.config.js
deleted file mode 100644
index 533f8e9..0000000
--- a/dsp-re-ui/craco.config.js
+++ /dev/null
@@ -1,18 +0,0 @@
-const CracoAlias = require("craco-alias");
-
-module.exports = {
- plugins: [
- {
- plugin: CracoAlias,
- options: {
- source: "tsconfig",
- // baseUrl SHOULD be specified
- // plugin does not take it from tsconfig
- baseUrl: "./src",
- /* tsConfigPath should point to the file where "baseUrl" and "paths"
- are specified*/
- tsConfigPath: "./tsconfig.paths.json"
- }
- }
- ]
-};
\ No newline at end of file
diff --git a/dsp-re-ui/example/tree/.gitignore b/dsp-re-ui/example/tree/.gitignore
deleted file mode 100644
index 0c71fb9..0000000
--- a/dsp-re-ui/example/tree/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-source_dir/config.json
-data.json
\ No newline at end of file
diff --git a/dsp-re-ui/example/tree/dtreevis.ipynb b/dsp-re-ui/example/tree/dtreevis.ipynb
deleted file mode 100644
index e7719a6..0000000
--- a/dsp-re-ui/example/tree/dtreevis.ipynb
+++ /dev/null
@@ -1,1501 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install dtreeviz"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%config InlineBackend.figure_format = 'retina' # Make visualizations look good\n",
- "# %config InlineBackend.figure_format = 'svg'\n",
- "%matplotlib inline"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "from spockflow.components.treelite import Tree\n",
- "from spockflow.components.treelite.core import CompiledTreeliteTree, OutputEncoding\n",
- "import json\n",
- "import treelite.sklearn\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import dtreeviz"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "with open(\"source_dir/config.json\") as fp:\n",
- " tree = Tree.model_validate_json(fp.read())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "node_id_mapping, leaf_nodes = CompiledTreeliteTree._get_node_id_mapping(tree.nodes)\n",
- "root_nodes = CompiledTreeliteTree._identify_independent_tree_roots(tree.nodes)\n",
- "trees = [\n",
- " CompiledTreeliteTree._build_treelite_tree(\n",
- " root_nodes=[rn],\n",
- " tree=tree,\n",
- " node_id_mapping=node_id_mapping,\n",
- " output_encoding=OutputEncoding.ONE_HOT,\n",
- " ).commit()\n",
- " for rn in root_nodes\n",
- "]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "sklearn_tree = treelite.sklearn.export_model(trees[0])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "