feat(Init): Initial upload of array + interoperability package #1
Merged
13 commits
1a482dc Initial upload (Simpag)
49226bc fix docs (Simpag)
ce74739 fix broken links (Simpag)
7ff0352 fix pytest (Simpag)
6c617e7 update ci (Simpag)
4633489 add logical operators (Simpag)
19b5060 Merge branch 'main' into dev1 (Simpag)
8852fdb fix(PR): PR comments and fix MPS tests (Simpag)
f12d604 Merge branch 'dev1' of github.com:Simpag/decent-array into dev1 (Simpag)
609b939 fix(tests): Fix float64 error on MPS for torch (Simpag)
8d77916 PR comments (Simpag)
e4e9ea5 PR comments (Simpag)
6c95a2f pr changes (Simpag)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| * @Simpag @nicola-bastianello |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| **/__pycache__ | ||
| **/build | ||
| *.egg-info | ||
| *.so | ||
| .DS_Store | ||
| .mypy_cache | ||
| .tox | ||
| .vscode | ||
| dist | ||
| pyrightconfig.json | ||
| .claude | ||
| *.py[codz] | ||
| *$py.class | ||
benchmarks/bench_array.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| """ | ||
| Microbenchmark: ``decent_array.Array`` operator overhead vs native frameworks. | ||
| | ||
| Measures the wrapper cost added by routing operators through ``Array.__add__``, | ||
| ``Array.__neg__`` etc. against calling the framework's native operators | ||
| directly. Iterates over every framework whose package is importable; missing | ||
| optional dependencies are skipped silently. | ||
| | ||
| The overhead column is ``wrapped / native`` runtime; values close to 1.0x mean | ||
| the wrapper is essentially free. Large values at small sizes are expected | ||
| (operator dispatch dominates) and should converge toward 1.0x as elementwise | ||
| work grows. | ||
| | ||
| Run with:: | ||
| | ||
|     python benchmarks/bench_array.py | ||
| """ | ||
| | ||
| from __future__ import annotations | ||
| | ||
| from bench_common import ( | ||
|     SIZES, | ||
|     BackendCase, | ||
|     activate_backend, | ||
|     discover_backends, | ||
|     fmt_row, | ||
|     parse_backends_arg, | ||
|     print_preamble, | ||
|     print_size_header, | ||
|     time_us_safe, | ||
| ) | ||
| | ||
| from decent_array import Array | ||
| | ||
| | ||
| def _bench_case(case: BackendCase) -> None: | ||
|     activate_backend(case.name) | ||
|     print(f"## {case.name}\n") | ||
|     for n in SIZES: | ||
|         a = case.make(n) | ||
|         b = case.make(n) | ||
|         d_a, d_b = Array(a), Array(b) | ||
| | ||
|         print_size_header(n) | ||
|         rows = ( | ||
|             ("add", lambda a=a, b=b: a + b, lambda d_a=d_a, d_b=d_b: d_a + d_b), | ||
|             ("sub", lambda a=a, b=b: a - b, lambda d_a=d_a, d_b=d_b: d_a - d_b), | ||
|             ("mul", lambda a=a, b=b: a * b, lambda d_a=d_a, d_b=d_b: d_a * d_b), | ||
|             ("div", lambda a=a, b=b: a / b, lambda d_a=d_a, d_b=d_b: d_a / d_b), | ||
|             ("neg", lambda a=a: -a, lambda d_a=d_a: -d_a), | ||
|             ("abs", lambda a=a: abs(a), lambda d_a=d_a: abs(d_a)), | ||
|             ("pow", lambda a=a: a ** 2.0, lambda d_a=d_a: d_a ** 2.0), | ||
|         ) | ||
|         for op, native_fn, wrapped_fn in rows: | ||
|             n_us = time_us_safe(case, native_fn) | ||
|             w_us = time_us_safe(case, wrapped_fn) | ||
|             print(fmt_row(op, n_us, w_us)) | ||
|         print() | ||
|     print() | ||
| | ||
| | ||
| def main() -> None: | ||
|     print_preamble("Array operator overhead vs native frameworks") | ||
|     cases = discover_backends(only=parse_backends_arg()) | ||
|     print(f"available backends: {', '.join(c.name for c in cases)}\n") | ||
|     for case in cases: | ||
|         _bench_case(case) | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
|     main() | ||
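
For readers without the optional frameworks installed, the `wrapped / native` ratio described in the docstring above can be illustrated with a minimal, standard-library-only sketch. `Wrapped`, `native_add`, `min_time_us`, and `overhead` are hypothetical names used only for illustration and are not part of `decent_array`; a plain Python list stands in for a framework array, and a fixed iteration count replaces the auto-ranging used in the benchmark.

```python
import timeit
from collections.abc import Callable


def native_add(a: list[float], b: list[float]) -> list[float]:
    """Stand-in for a framework's native elementwise add."""
    return [x + y for x, y in zip(a, b)]


class Wrapped:
    """Hypothetical thin wrapper (an assumption, not decent_array.Array)."""

    __slots__ = ("value",)

    def __init__(self, value: list[float]) -> None:
        self.value = value

    def __add__(self, other: "Wrapped") -> "Wrapped":
        # Delegate to the native operation, then re-wrap the result.
        return Wrapped(native_add(self.value, other.value))


def min_time_us(fn: Callable[[], object], repeats: int = 3, number: int = 200) -> float:
    """Per-call time in microseconds: min over `repeats` runs of `number` calls."""
    return min(timeit.repeat(fn, repeat=repeats, number=number)) / number * 1e6


def overhead(native_us: float, wrapped_us: float) -> float:
    """Overhead ratio wrapped / native; inf when the native time is zero."""
    return wrapped_us / native_us if native_us > 0 else float("inf")


for n in (10, 1_000):
    a, b = [1.0] * n, [2.0] * n
    w_a, w_b = Wrapped(a), Wrapped(b)
    native_us = min_time_us(lambda: native_add(a, b))
    wrapped_us = min_time_us(lambda: w_a + w_b)
    print(f"size={n:_}  overhead={overhead(native_us, wrapped_us):.2f}x")
```

The ratio would typically be largest at the smallest size, where the extra method call and re-wrapping dominate, and shrink as the elementwise loop takes over; the exact numbers depend on the machine.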
bench_common.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| """ | ||
| Shared helpers for ``bench_array.py`` and ``bench_iop.py``. | ||
| | ||
| Three concerns live here so the benchmarks stay focused on the comparison logic: | ||
| | ||
| * :func:`discover_backends` returns the subset of frameworks whose package is | ||
|   importable; backends with a missing optional dependency are skipped silently. | ||
| * :func:`is_compiled` / :func:`print_preamble` report whether the user is | ||
|   running against a mypyc-compiled build of ``decent_array`` or the pure-Python | ||
|   source; this materially changes overhead numbers, so the result is printed | ||
|   at the top of every run. | ||
| * :func:`time_us` / :func:`time_us_safe` wrap :mod:`timeit` to take the | ||
|   ``min`` of several auto-ranged repeats. ``min`` is the canonical choice: it | ||
|   reports the lower bound of the machine's per-call cost and is the metric | ||
|   least sensitive to background activity. A warmup call precedes timing so | ||
|   JIT-style backends (JAX) don't skew the first iteration. | ||
| """ | ||
| | ||
| from __future__ import annotations | ||
| | ||
| import importlib | ||
| import timeit | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass | ||
| from typing import Any | ||
| | ||
| SIZES: tuple[int, ...] = (10, 100, 1_000, 10_000) | ||
| REPEATS: int = 7 | ||
| | ||
| | ||
| def _no_sync(_value: Any) -> None:  # noqa: ANN401 | ||
|     """No-op sync used for synchronous backends (numpy, torch CPU, tf eager CPU).""" | ||
| | ||
| | ||
| def _sync_jax(value: Any) -> None:  # noqa: ANN401 | ||
|     """Block until a JAX DeviceArray is materialized, unwrapping ``Array`` if needed.""" | ||
|     # Imported lazily so the module can load even when decent_array isn't yet importable. | ||
|     from decent_array import Array  # noqa: PLC0415 | ||
| | ||
|     raw = value.value if isinstance(value, Array) else value | ||
|     raw.block_until_ready() | ||
| | ||
| | ||
| @dataclass(slots=True) | ||
| class BackendCase: | ||
|     """A discovered backend plus the helpers needed to drive it in a benchmark.""" | ||
| | ||
|     name: str | ||
|     make: Callable[[int], Any] | ||
|     sync: Callable[[Any], None] | ||
| | ||
| | ||
| def activate_backend(name: str) -> None: | ||
|     """Activate ``name`` as the live backend, resetting any previously active one. | ||
| | ||
|     ``decent_array`` enforces a single-active-backend invariant per execution context; | ||
|     swapping between frameworks within one process requires resetting first. | ||
|     """ | ||
|     from decent_array.interoperability._backend_manager import reset_backends, set_backend  # noqa: PLC0415 | ||
| | ||
|     reset_backends() | ||
|     set_backend(name) | ||
| | ||
| | ||
| def discover_backends(only: list[str] | None = None) -> list[BackendCase]: | ||
|     """Return one :class:`BackendCase` per importable framework, in a stable order. | ||
| | ||
|     Args: | ||
|         only: Optional allowlist of backend names. When provided, frameworks not in the | ||
|             list are skipped entirely (their packages aren't even imported), and any | ||
|             requested name that isn't a known backend raises :class:`ValueError`. | ||
| | ||
|     """ | ||
|     import numpy as np  # always available: hard dependency  # noqa: PLC0415 | ||
| | ||
|     known = {"numpy", "pytorch", "jax", "tensorflow"} | ||
|     if only is not None: | ||
|         unknown = set(only) - known | ||
|         if unknown: | ||
|             raise ValueError(f"unknown backend(s): {sorted(unknown)}; known: {sorted(known)}") | ||
|         wanted = set(only) | ||
|     else: | ||
|         wanted = known | ||
| | ||
|     cases: list[BackendCase] = [] | ||
| | ||
|     if "numpy" in wanted: | ||
|         cases.append(BackendCase("numpy", lambda n: np.random.rand(n), _no_sync)) | ||
| | ||
|     if "pytorch" in wanted: | ||
|         try: | ||
|             import torch  # noqa: PLC0415 | ||
|         except ImportError: | ||
|             pass | ||
|         else: | ||
|             cases.append(BackendCase("pytorch", lambda n: torch.from_numpy(np.random.rand(n)), _no_sync)) | ||
| | ||
|     if "jax" in wanted: | ||
|         try: | ||
|             import jax.numpy as jnp  # noqa: PLC0415 | ||
|         except ImportError: | ||
|             pass | ||
|         else: | ||
|             cases.append(BackendCase("jax", lambda n: jnp.asarray(np.random.rand(n)), _sync_jax)) | ||
| | ||
|     if "tensorflow" in wanted: | ||
|         try: | ||
|             import tensorflow as tf  # noqa: PLC0415 | ||
|         except ImportError: | ||
|             pass | ||
|         else: | ||
|             cases.append(BackendCase("tensorflow", lambda n: tf.constant(np.random.rand(n)), _no_sync)) | ||
| | ||
|     return cases | ||
| | ||
| | ||
| def parse_backends_arg() -> list[str] | None: | ||
|     """Parse the shared ``--backends`` CLI flag; returns ``None`` if not given.""" | ||
|     import argparse  # noqa: PLC0415 | ||
| | ||
|     parser = argparse.ArgumentParser(add_help=True) | ||
|     parser.add_argument( | ||
|         "--backends", | ||
|         type=str, | ||
|         default=None, | ||
|         help="comma-separated allowlist of backends (numpy,pytorch,jax,tensorflow); default = all available", | ||
|     ) | ||
|     args = parser.parse_args() | ||
|     if args.backends is None: | ||
|         return None | ||
|     return [b.strip() for b in args.backends.split(",") if b.strip()] | ||
| | ||
| | ||
| def is_compiled() -> tuple[bool, str]: | ||
|     """Return ``(True, path)`` if the Array module loaded from a ``.so``/``.pyd``, else ``(False, .py path)``.""" | ||
|     module = importlib.import_module("decent_array._array") | ||
|     path = module.__file__ or "<unknown>" | ||
|     return path.endswith((".so", ".pyd")), path | ||
| | ||
| | ||
| def print_preamble(title: str) -> None: | ||
|     compiled, path = is_compiled() | ||
|     print(f"# {title}\n") | ||
|     print(f"decent_array compiled: {'yes' if compiled else 'no'}") | ||
|     print(f" Array loaded from: {path}") | ||
|     print(f" timing: min over {REPEATS} repeats, iterations per repeat auto-tuned to ~0.2s\n") | ||
| | ||
| | ||
| def time_us(case: BackendCase, fn: Callable[[], Any]) -> float: | ||
|     """Per-call runtime in µs; min over :data:`REPEATS` measurements with autoranged N.""" | ||
| | ||
|     def runner() -> None: | ||
|         case.sync(fn()) | ||
| | ||
|     runner()  # warmup: material for JAX's first-call compilation | ||
|     timer = timeit.Timer(runner) | ||
|     n, _ = timer.autorange() | ||
|     times = timer.repeat(repeat=REPEATS, number=n) | ||
|     return (min(times) / n) * 1e6 | ||
| | ||
| | ||
| def time_us_safe(case: BackendCase, fn: Callable[[], Any]) -> float | None: | ||
|     """Like :func:`time_us` but returns ``None`` if ``fn`` raises (e.g. TF 1D matmul).""" | ||
|     try: | ||
|         return time_us(case, fn) | ||
|     except Exception:  # noqa: BLE001 | ||
|         return None | ||
| | ||
| | ||
| def fmt_row(op: str, native_us: float | None, wrapped_us: float | None) -> str: | ||
|     if native_us is None or wrapped_us is None: | ||
|         return f" {op:<8} {'n/a':>13} {'n/a':>13} {'n/a':>8}" | ||
|     ratio = wrapped_us / native_us if native_us > 0 else float("inf") | ||
|     return f" {op:<8} {native_us:>10.3f} µs {wrapped_us:>10.3f} µs {ratio:>6.2f}x" | ||
| | ||
| | ||
| def print_size_header(n: int) -> None: | ||
|     print(f"size = {n:_}") | ||
|     print(f" {'op':<8} {'native':>13} {'wrapped':>13} {'overhead':>8}") | ||
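
The allowlist handling in `discover_backends` and `parse_backends_arg` above can be tried in isolation with the sketch below. It is a variant, not the PR's code: it probes for each package with `importlib.util.find_spec` rather than importing it inside `try`/`except ImportError`, and `KNOWN`, `parse_allowlist`, and `available_backends` are hypothetical names. The mapping of `pytorch` to the `torch` module follows the import used in the diff.

```python
import importlib.util
from typing import Optional

# Hypothetical mapping: benchmark backend name -> top-level module to probe.
KNOWN: dict[str, str] = {
    "numpy": "numpy",
    "pytorch": "torch",
    "jax": "jax",
    "tensorflow": "tensorflow",
}


def parse_allowlist(raw: Optional[str]) -> Optional[list[str]]:
    """Split a comma-separated flag value; None means 'no filter was given'."""
    if raw is None:
        return None
    return [part.strip() for part in raw.split(",") if part.strip()]


def available_backends(only: Optional[list[str]] = None) -> list[str]:
    """Requested backend names whose module can be located, in KNOWN's order."""
    if only is not None:
        unknown = set(only) - set(KNOWN)
        if unknown:
            raise ValueError(f"unknown backend(s): {sorted(unknown)}; known: {sorted(KNOWN)}")
    wanted = set(KNOWN) if only is None else set(only)
    # find_spec locates a top-level module without importing it and
    # returns None when the module cannot be found.
    return [
        name
        for name, module in KNOWN.items()
        if name in wanted and importlib.util.find_spec(module) is not None
    ]


# The printed list depends on which packages are installed locally.
print(available_backends(parse_allowlist("numpy, jax")))
```

Probing without importing keeps the check cheap, at the cost of not catching packages that are present but fail at import time; the `try`/`except ImportError` form in the diff is stricter in that respect.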