diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..ccc5f2c --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,67 @@ +name: release + +on: + push: + tags: + - "v*" + +jobs: + build: + name: Build sdist and wheel + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install build + run: python -m pip install --upgrade build + + - name: Build distribution + run: python -m build + + - uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + publish-testpypi: + name: Publish to TestPyPI + needs: build + if: contains(github.ref, '-rc') || contains(github.ref, '-a') || contains(github.ref, '-b') + runs-on: ubuntu-latest + environment: + name: testpypi + url: https://test.pypi.org/p/baycomp_plotting + permissions: + id-token: write + steps: + - uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ + + publish-pypi: + name: Publish to PyPI + needs: build + # Only stable tags (no -rc, -a, -b suffix) go to real PyPI. + if: ${{ !contains(github.ref, '-') }} + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/baycomp_plotting + permissions: + id-token: write + steps: + - uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..a0d8b07 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,40 @@ +name: tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + # Weekly run on Mondays catches breakage from new transitive deps. + - cron: "0 6 * * 1" + +jobs: + test: + name: py${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + include: + - os: macos-latest + python-version: "3.12" + - os: windows-latest + python-version: "3.12" + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package with test extras + run: | + python -m pip install --upgrade pip + python -m pip install -e ".[test]" + + - name: Run tests + run: pytest --mpl diff --git a/baycomp_plotting/__init__.py b/baycomp_plotting/__init__.py deleted file mode 100644 index 7472c3d..0000000 --- a/baycomp_plotting/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .plotting import * diff --git a/baycomp_plotting/plotting.py b/baycomp_plotting/plotting.py deleted file mode 100644 index 217b39c..0000000 --- a/baycomp_plotting/plotting.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -Functions for plotting baycomp's posterior plots. - -``tern`` - A ternary plot. - -``dens`` - A density plot. - -Author: Mario Juez-Gil -""" - -import types -import numpy as np -import matplotlib.patches as patches -import matplotlib.colors as clrs -from matplotlib import pyplot as plt -from matplotlib.lines import Line2D -from matplotlib.path import Path -from matplotlib.colors import ListedColormap -from math import sqrt, sin, cos, pi -from iteround import saferound -from scipy.interpolate import interpn -from scipy import stats - -__all__ = ['Color', 'tern', 'dens'] - -# Custom colormap -> dark blue means high density, and light blue low -BLUES_CMAP = np.ones((256, 4)) -BLUES_CMAP[:, 0] = np.linspace(199/256, 8/256, 256) -BLUES_CMAP[:, 1] = np.linspace(224/256, 64/256, 256) -BLUES_CMAP[:, 2] = np.linspace(252/256, 129/256, 256) -BLUES_CMAP = ListedColormap(BLUES_CMAP) - -# Custom colors -class Color(): - BLUE = clrs.to_hex((0/255, 142/255, 206/255)) - GRAY = clrs.to_hex((77/255, 80/255, 94/255)) - BORDEAUX = clrs.to_hex((208/255, 33/255, 85/255)) - GREEN = clrs.to_hex((5/255, 126/255, 121/255)) - -def project(pts): - SQRT_3 = sqrt(3) - PI_6 = pi / 6 - - p1, p2, p3 = pts.T / SQRT_3 - x = (p2 - p1) * cos(PI_6) + .5 - y = p3 - (p1 + p2) * sin(PI_6) + 1 / (2 * SQRT_3) - - return np.vstack((x, y)).T - -def process_names(names): - for i in range(len(names)): - names[i] = names[i].replace("-", "{-}") - names[i] = names[i].replace(" ", "\\ ") - return names - -def tern(p, names=["L", "R"]): - names = process_names(names) - plt.style.use('classic') - - fig, ax = plt.subplots() - fig.patch.set_alpha(0) - ax.set_aspect('equal', 'box') - ax.axis('off') - ax.set_xlim(-.1, 1.1) - ax.set_ylim(-.1, 1.1) - - # inner lines - cx, cy = project(np.array([[1/3, 1/3, 1/3]]))[0] # central point - for x, y in project(np.array([[.5, .5, 0], [.5, 0, .5], [0, .5, .5]])): - ax.add_line(Line2D([cx, x], [cy, y], color='k', lw=2)) - - # outer border of the triangle - vert_coords = [(0., 0.), (.5, sqrt(3)/2), (1., 0.), (0., 0.)] - codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY] - triangle = Path(vert_coords, codes) - patch = patches.PathPatch(triangle, facecolor='none', lw=3) - ax.add_patch(patch) - - # vertices texts - probs = saferound(list(p.probs()), places=4) - # L - ax.text(-.04, -.02, r'$\mathrm{%s}$'%names[0], ha='center', va='top', fontsize=30) - ax.text(-.04, -.12, r'$\mathbf{(%.4f)}$'%probs[0], ha='center', va='top', fontsize=32) - # ROPE - ax.text(.5, 1, r'$\mathrm{ROPE}$', ha='center', va='bottom', fontsize=30) - ax.text(.5, .87, r'$\mathbf{(%.4f)}$'%probs[1], ha='center', va='bottom', fontsize=32) - # R - ax.text(1.04, -.02, r'$\mathrm{%s}$'%names[1], ha='center', va='top', fontsize=30) - ax.text(1.04, -.12, r'$\mathbf{(%.4f)}$'%probs[2], ha='center', va='top', fontsize=32) - - # draw points - tripts = project(p.sample[:, [0, 2, 1]]) - data, xe, ye = np.histogram2d(tripts[:, 0], tripts[:, 1], bins=30) - z = interpn((.5 * (xe[1:] + xe[:-1]), .5 * (ye[1:] + ye[:-1])), data, - np.vstack([tripts[:, 0], tripts[:, 1]]).T, method='splinef2d', - bounds_error=False) - idx = z.argsort() - ax.scatter(tripts[:, 0][idx], tripts[:, 1][idx], c=z[idx], clip_path=patch, - s=50, linewidth=0, cmap=BLUES_CMAP, rasterized=True) - - fig.tight_layout() - return fig - -def dens(p, label, ls='-', color=Color.BLUE): - def add_posterior(ax, p, label, ls, color): - def _update_yticks(): - tick = ax.max_y / 3 - yticks = [0, tick*1, tick*2, tick*3] - ax.set_ylim(0, ax.max_y) - ax.set_yticks(yticks) - ax.set_yticklabels([r'$%.3f$'%round(x, 3) for x in yticks]) - - targs = (p.df, p.mean, np.sqrt(p.var)) - x = np.linspace(min(stats.t.ppf(.005, *targs), -1.05 * p.rope), - max(stats.t.ppf(.995, *targs), 1.05 * p.rope), 100) - y = stats.t.pdf(x, *targs) - y = y / y.sum() # density - ax.plot(x, y, c=color, linestyle=ls, linewidth=2, - label=r'$\mathrm{%s}$'%label, zorder=ax.zo) - ax.fill_between(x, y, facecolor=color, alpha=.1, edgecolor='none', - zorder=ax.zo) - ax.zo = ax.zo - 1 - - if(ax.max_y == None): - ax.max_y = np.amax(y) + np.amax(y)*.02 - _update_yticks() - else: - curr_max_y = np.amax(y) + np.amax(y)*.02 - if(curr_max_y > ax.max_y): - ax.max_y = curr_max_y - _update_yticks() - - plt.style.use('classic') - plt.rc('legend', fontsize=25) - - # figure customization - fig, ax = plt.subplots() - fig.patch.set_alpha(0) - ax.axvline(.01, c='darkorange', linewidth=2, zorder=101) - ax.axvline(-.01, c='darkorange', linewidth=2, zorder=101) - ax.spines['right'].set_visible(False) - ax.spines['top'].set_visible(False) - ax.spines['bottom'].set_visible(False) - ax.yaxis.set_ticks_position('left') - ax.xaxis.set_ticks_position('none') - ax.tick_params(axis='both', which='major', labelsize=30, direction='inout', - width=1, length=5) - ax.set_xticks([]) - - # appending the posterior to the axes - ax.max_y = None - ax.zo = 100 - add_posterior(ax, p, label, ls, color) - fig.add_posterior = types.MethodType(add_posterior, ax) - - fig.tight_layout() - return fig diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..df09cbb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,73 @@ +[build-system] +requires = ["hatchling>=1.24"] +build-backend = "hatchling.build" + +[project] +name = "baycomp_plotting" +version = "1.2.0" +description = "Extra plotting functionality for baycomp's Bayesian classifier comparison posteriors." +readme = "README.md" +license = { file = "LICENSE" } +requires-python = ">=3.9" +authors = [ + { name = "Mario Juez-Gil", email = "mariojg@ubu.es" }, +] +keywords = [ + "bayesian", + "classifier-comparison", + "machine-learning", + "plotting", + "baycomp", +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Visualization", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "matplotlib>=3.5", + "numpy>=1.21", + "scipy>=1.7", +] + +[project.optional-dependencies] +test = [ + "pytest>=7", + "pytest-mpl>=0.17", + "baycomp>=1.0.3", +] + +[project.urls] +Homepage = "https://github.com/mjuez/baycomp_plotting" +Repository = "https://github.com/mjuez/baycomp_plotting" +Issues = "https://github.com/mjuez/baycomp_plotting/issues" + +[tool.hatch.build.targets.wheel] +packages = ["src/baycomp_plotting"] + +[tool.hatch.build.targets.sdist] +include = [ + "src", + "tests", + "README.md", + "LICENSE", + "pyproject.toml", +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-ra" +filterwarnings = [ + "error::DeprecationWarning", + "error::FutureWarning", + # Emitted by matplotlib's internal use of pyparsing; third-party. + "ignore::pyparsing.warnings.PyparsingDeprecationWarning", +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b88034e..0000000 --- a/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[metadata] -description-file = README.md diff --git a/setup.py b/setup.py deleted file mode 100644 index d0ac92a..0000000 --- a/setup.py +++ /dev/null @@ -1,35 +0,0 @@ -import setuptools - -with open('README.md') as f: - readme = f.read() - -setuptools.setup( - name='baycomp_plotting', - version='1.1.1', - description='This package provides some extra functionality for plotting baycomp\'s posteriors.', - long_description=readme, - long_description_content_type='text/markdown', - author='Mario Juez-Gil', - author_email='mariojg@ubu.es', - url='https://github.com/mjuez/baycomp_plotting', - download_url='https://github.com/mjuez/baycomp_plotting/archive/v1_1_1.tar.gz', - license='GPLv3', - install_requires=[ - 'matplotlib==3.3.2', - 'numpy==1.19.1', - 'iteround==1.0.2', - 'scipy==1.5.3' - ], - packages=setuptools.find_packages(), - include_package_data=True, - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Science/Research', - 'Topic :: Scientific/Engineering', - 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8' - ] -) diff --git a/src/baycomp_plotting/__init__.py b/src/baycomp_plotting/__init__.py new file mode 100644 index 0000000..a200f31 --- /dev/null +++ b/src/baycomp_plotting/__init__.py @@ -0,0 +1,11 @@ +from importlib.metadata import PackageNotFoundError as _PackageNotFoundError +from importlib.metadata import version as _version + +from .plotting import Color, dens, tern + +try: + __version__ = _version("baycomp_plotting") +except _PackageNotFoundError: + __version__ = "0.0.0+unknown" + +__all__ = ["Color", "dens", "tern", "__version__"] diff --git a/src/baycomp_plotting/plotting.py b/src/baycomp_plotting/plotting.py new file mode 100644 index 0000000..7958d25 --- /dev/null +++ b/src/baycomp_plotting/plotting.py @@ -0,0 +1,226 @@ +""" +Functions for plotting baycomp's posterior distributions. + +``tern`` + A ternary plot for posteriors over (left, rope, right) probabilities. + +``dens`` + A density plot for the posterior of a CorrelatedTTest. + +Author: Mario Juez-Gil +""" +from __future__ import annotations + +import types +from math import cos, pi, sin, sqrt +from typing import Sequence + +import matplotlib.colors as clrs +import matplotlib.patches as patches +import numpy as np +from matplotlib import pyplot as plt +from matplotlib.colors import ListedColormap +from matplotlib.lines import Line2D +from matplotlib.path import Path +from scipy import stats +from scipy.interpolate import interpn + +__all__ = ["Color", "tern", "dens"] + + +_BLUES_CMAP_RGBA = np.ones((256, 4)) +_BLUES_CMAP_RGBA[:, 0] = np.linspace(199 / 256, 8 / 256, 256) +_BLUES_CMAP_RGBA[:, 1] = np.linspace(224 / 256, 64 / 256, 256) +_BLUES_CMAP_RGBA[:, 2] = np.linspace(252 / 256, 129 / 256, 256) +BLUES_CMAP = ListedColormap(_BLUES_CMAP_RGBA) + + +class Color: + """Alternative palette of four colors for matplotlib plots.""" + + BLUE = clrs.to_hex((0 / 255, 142 / 255, 206 / 255)) + GRAY = clrs.to_hex((77 / 255, 80 / 255, 94 / 255)) + BORDEAUX = clrs.to_hex((208 / 255, 33 / 255, 85 / 255)) + GREEN = clrs.to_hex((5 / 255, 126 / 255, 121 / 255)) + + +def _safe_round(values: Sequence[float], places: int) -> list[float]: + """Round ``values`` to ``places`` decimals while preserving their sum. + + Largest-remainder method: distribute the rounding residual to the entries + with the largest fractional parts so that, for probabilities summing to 1, + the rounded triplet still sums to 1.0 to ``places`` decimals. + """ + if not values: + return [] + multiplier = 10 ** places + scaled = [v * multiplier for v in values] + floors = [int(s) for s in scaled] + diff = int(round(sum(scaled))) - sum(floors) + if diff > 0: + order = sorted(range(len(scaled)), key=lambda i: scaled[i] - floors[i], reverse=True) + for i in order[:diff]: + floors[i] += 1 + elif diff < 0: + order = sorted(range(len(scaled)), key=lambda i: scaled[i] - floors[i]) + for i in order[: -diff]: + floors[i] -= 1 + return [f / multiplier for f in floors] + + +def _project(pts: np.ndarray) -> np.ndarray: + """Project barycentric coordinates onto the 2D simplex triangle.""" + sqrt_3 = sqrt(3) + pi_6 = pi / 6 + p1, p2, p3 = pts.T / sqrt_3 + x = (p2 - p1) * cos(pi_6) + 0.5 + y = p3 - (p1 + p2) * sin(pi_6) + 1 / (2 * sqrt_3) + return np.vstack((x, y)).T + + +def _process_names(names: Sequence[str]) -> list[str]: + """Escape characters in ``names`` so they survive matplotlib's mathtext.""" + return [n.replace("-", "{-}").replace(" ", "\\ ") for n in names] + + +def _add_posterior(ax, p, label: str, ls="-", color: str = Color.BLUE) -> None: + """Render a CorrelatedTTest posterior on ``ax`` and update the y-scale. + + Reads and updates two ad-hoc attributes on ``ax``: ``max_y`` (current y-axis + upper bound) and ``zo`` (z-order counter, decremented for each posterior so + earlier ones stay on top). + """ + targs = (p.df, p.mean, np.sqrt(p.var)) + x = np.linspace( + min(stats.t.ppf(0.005, *targs), -1.05 * p.rope), + max(stats.t.ppf(0.995, *targs), 1.05 * p.rope), + 100, + ) + y = stats.t.pdf(x, *targs) + y = y / y.sum() + ax.plot( + x, y, c=color, linestyle=ls, linewidth=2, + label=r"$\mathrm{%s}$" % label, zorder=ax.zo, + ) + ax.fill_between(x, y, facecolor=color, alpha=0.1, edgecolor="none", zorder=ax.zo) + ax.zo -= 1 + + curr_max_y = float(y.max()) * 1.02 + if ax.max_y is None or curr_max_y > ax.max_y: + ax.max_y = curr_max_y + tick = ax.max_y / 3 + yticks = [0, tick, 2 * tick, 3 * tick] + ax.set_ylim(0, ax.max_y) + ax.set_yticks(yticks, labels=[r"$%.3f$" % v for v in yticks]) + + +def tern(p, names: Sequence[str] = ("L", "R")): + """Ternary plot for a posterior over (left, rope, right) probabilities. + + Parameters + ---------- + p + A baycomp posterior with ``probs()`` returning three values that sum to 1 + and ``sample`` of shape ``(n_samples, 3)`` (e.g. ``HierarchicalTest`` or + ``SignedRankTest``). + names + Two-element sequence with labels for the left and right vertices. + """ + names = _process_names(list(names)) + plt.style.use("classic") + + fig, ax = plt.subplots() + fig.patch.set_alpha(0) + ax.set_aspect("equal", "box") + ax.axis("off") + ax.set_xlim(-0.1, 1.1) + ax.set_ylim(-0.1, 1.1) + + cx, cy = _project(np.array([[1 / 3, 1 / 3, 1 / 3]]))[0] + for x, y in _project(np.array([[0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]])): + ax.add_line(Line2D([cx, x], [cy, y], color="k", lw=2)) + + vert_coords = [(0.0, 0.0), (0.5, sqrt(3) / 2), (1.0, 0.0), (0.0, 0.0)] + codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY] + triangle = Path(vert_coords, codes) + patch = patches.PathPatch(triangle, facecolor="none", lw=3) + ax.add_patch(patch) + + probs = _safe_round(list(p.probs()), places=4) + ax.text(-0.04, -0.02, r"$\mathrm{%s}$" % names[0], ha="center", va="top", fontsize=30) + ax.text(-0.04, -0.12, r"$\mathbf{(%.4f)}$" % probs[0], ha="center", va="top", fontsize=32) + ax.text(0.5, 1, r"$\mathrm{ROPE}$", ha="center", va="bottom", fontsize=30) + ax.text(0.5, 0.87, r"$\mathbf{(%.4f)}$" % probs[1], ha="center", va="bottom", fontsize=32) + ax.text(1.04, -0.02, r"$\mathrm{%s}$" % names[1], ha="center", va="top", fontsize=30) + ax.text(1.04, -0.12, r"$\mathbf{(%.4f)}$" % probs[2], ha="center", va="top", fontsize=32) + + tripts = _project(p.sample[:, [0, 2, 1]]) + data, xe, ye = np.histogram2d(tripts[:, 0], tripts[:, 1], bins=30) + z = interpn( + (0.5 * (xe[1:] + xe[:-1]), 0.5 * (ye[1:] + ye[:-1])), + data, + np.vstack([tripts[:, 0], tripts[:, 1]]).T, + method="splinef2d", + bounds_error=False, + ) + idx = z.argsort() + ax.scatter( + tripts[:, 0][idx], + tripts[:, 1][idx], + c=z[idx], + clip_path=patch, + s=50, + linewidth=0, + cmap=BLUES_CMAP, + rasterized=True, + ) + + fig.tight_layout() + return fig + + +def dens(p, label: str, ls="-", color: str = Color.BLUE): + """Density plot for the posterior of a ``baycomp.CorrelatedTTest``. + + Parameters + ---------- + p + A ``CorrelatedTTest`` posterior with attributes ``df``, ``mean``, ``var`` + and ``rope``. + label + Math-text label shown in the legend (LaTeX-style spacing must be escaped). + ls + Matplotlib line style. + color + Density colour, e.g. one of the :class:`Color` constants. + + Returns + ------- + matplotlib.figure.Figure + Figure with an extra ``add_posterior(p, label, ls, color)`` method that + appends another posterior on top of the existing axes. + """ + plt.style.use("classic") + plt.rc("legend", fontsize=25) + + fig, ax = plt.subplots() + fig.patch.set_alpha(0) + ax.axvline(0.01, c="darkorange", linewidth=2, zorder=101) + ax.axvline(-0.01, c="darkorange", linewidth=2, zorder=101) + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.yaxis.set_ticks_position("left") + ax.xaxis.set_ticks_position("none") + ax.tick_params( + axis="both", which="major", labelsize=30, direction="inout", width=1, length=5, + ) + ax.set_xticks([]) + + ax.max_y = None + ax.zo = 100 + _add_posterior(ax, p, label, ls, color) + fig.add_posterior = types.MethodType(_add_posterior, ax) + + fig.tight_layout() + return fig diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/baseline_images/test_dens_single_posterior.png b/tests/baseline_images/test_dens_single_posterior.png new file mode 100644 index 0000000..2cb0367 Binary files /dev/null and b/tests/baseline_images/test_dens_single_posterior.png differ diff --git a/tests/baseline_images/test_dens_two_posteriors.png b/tests/baseline_images/test_dens_two_posteriors.png new file mode 100644 index 0000000..eb23325 Binary files /dev/null and b/tests/baseline_images/test_dens_two_posteriors.png differ diff --git a/tests/baseline_images/test_tern_basic.png b/tests/baseline_images/test_tern_basic.png new file mode 100644 index 0000000..784bbaa Binary files /dev/null and b/tests/baseline_images/test_tern_basic.png differ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d99130d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,30 @@ +"""Shared fixtures for the baycomp_plotting test suite.""" +from __future__ import annotations + +import baycomp as bc +import matplotlib +import numpy as np +import pytest + +matplotlib.use("Agg") + + +@pytest.fixture +def rng() -> np.random.Generator: + return np.random.default_rng(42) + + +@pytest.fixture +def correlated_ttest_posterior(rng): + base = rng.normal(0.85, 0.015, 10) + acc_a = np.clip(base + rng.normal(0.0, 0.005, 10), 0, 1) + acc_b = np.clip(base + rng.normal(0.015, 0.005, 10), 0, 1) + return bc.CorrelatedTTest(acc_a, acc_b, rope=0.01, runs=1) + + +@pytest.fixture +def signed_rank_posterior(rng): + mu = rng.uniform(0.65, 0.92, 25) + mean_a = np.clip(mu + rng.normal(0.0, 0.005, 25), 0, 1) + mean_b = np.clip(mu + rng.normal(0.012, 0.005, 25), 0, 1) + return bc.SignedRankTest(mean_a, mean_b, rope=0.01, random_state=0) diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000..75c0c63 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,73 @@ +"""Unit tests for the pure helper functions.""" +from __future__ import annotations + +import numpy as np +import pytest + +from baycomp_plotting import Color +from baycomp_plotting.plotting import _process_names, _project, _safe_round + + +class TestColor: + def test_palette_has_four_hex_strings(self): + assert len({Color.BLUE, Color.GRAY, Color.BORDEAUX, Color.GREEN}) == 4 + for c in (Color.BLUE, Color.GRAY, Color.BORDEAUX, Color.GREEN): + assert isinstance(c, str) + assert c.startswith("#") and len(c) == 7 + + def test_blue_is_stable(self): + # Pinned for backwards compatibility — used to be public in plots. + assert Color.BLUE == "#008ece" + + +class TestSafeRound: + def test_rounds_to_requested_decimals(self): + out = _safe_round([0.12345, 0.34567, 0.53088], places=4) + assert all(round(v * 10_000) == v * 10_000 for v in out) + + def test_preserves_sum(self): + # Triplet that would round naively to 1.0001 + values = [0.33335, 0.33335, 0.33330] + out = _safe_round(values, places=4) + assert sum(out) == pytest.approx(1.0, abs=1e-9) + + def test_handles_empty(self): + assert _safe_round([], places=4) == [] + + def test_extreme_distribution_sums_to_one(self): + out = _safe_round([0.99996, 1e-5, 3e-5], places=4) + assert sum(out) == pytest.approx(1.0, abs=1e-9) + + +class TestProject: + def test_centroid_maps_to_triangle_center(self): + out = _project(np.array([[1 / 3, 1 / 3, 1 / 3]])) + # Center of the equilateral triangle drawn by tern() is (0.5, 1/(2*sqrt(3))). + np.testing.assert_allclose(out[0], (0.5, 1 / (2 * np.sqrt(3))), atol=1e-9) + + def test_vertex_left_maps_to_left(self): + # (1,0,0) is the "left" axis in the projection; we don't pin the exact + # corner but verify it sits on x<0.5 and y= 3 # 1 density + 2 ROPE verticals + assert hasattr(fig, "add_posterior") + plt.close(fig) + + def test_add_posterior_appends_a_line(self, correlated_ttest_posterior): + fig = bplt.dens(correlated_ttest_posterior, label="A", color=bplt.Color.BLUE) + n_before = len(fig.axes[0].lines) + fig.add_posterior(correlated_ttest_posterior, label="B", color=bplt.Color.BORDEAUX) + assert len(fig.axes[0].lines) == n_before + 1 + plt.close(fig) + + def test_tern_renders_three_vertex_labels(self, signed_rank_posterior): + fig = bplt.tern(signed_rank_posterior, names=["L", "R"]) + ax = fig.axes[0] + text_contents = [t.get_text() for t in ax.texts] + assert any("L" in t for t in text_contents) + assert any("R" in t for t in text_contents) + assert any("ROPE" in t for t in text_contents) + plt.close(fig)