From 1246a517b801d354a040c8c1de10570155f9e616 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Wed, 17 Jun 2026 13:16:30 +0100 Subject: [PATCH 1/7] Expanding the documentation for the macro-based entrypoints and the Optimize module. #272 --- pharmsol-macros/src/lib.rs | 124 ++++++++++++++++++++++++++++++++++++- src/optimize/effect.rs | 6 ++ src/optimize/mod.rs | 9 +++ src/optimize/parameters.rs | 17 +++++ 4 files changed, 155 insertions(+), 1 deletion(-) diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index eb8fc6f2..bf0c44bc 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -3208,6 +3208,45 @@ fn expand_sde_out( // Proc macros // --------------------------------------------------------------------------- +/// Define an ODE (ordinary differential equation) model. +/// +/// This is the primary entry point for building pharmacometric ODE models. +/// The macro generates and validates an [`ODE`] model and automatically generates its metadata +/// (parameter names, state labels, output labels, route declarations). +/// +/// # Fields +/// +/// | Field | Required | Description | +/// |-------|----------|-------------| +/// | `name` | yes | Model name (`"my_model"`) | +/// | `params` | yes | Parameter identifiers `[ka, ke, v]` | +/// | `covariates` | no | Covariate identifiers `[wt, age]` | +/// | `states` | yes | State identifiers `[gut, central]` | +/// | `outputs` | yes | Output identifiers `[cp]` | +/// | `routes` | no | Route declarations `[bolus(oral) -> gut, infusion(iv) -> central]` | +/// | `diffeq` | yes | Closure `\|x, p, t, dx, cov\| { … }` writing derivatives into `dx` | +/// | `lag` | no | Closure returning route‑specific lag times via `lag! { route => expr }` | +/// | `fa` | no | Closure returning bioavailability fractions | +/// | `init` | no | Closure setting initial state values | +/// | `out` | yes | Closure `\|x, p, t, cov, y\| { … }` mapping states to outputs | +/// +/// # Example +/// +/// ```ignore +/// let model = ode! { +/// name: "one_cmt_iv", +/// params: [ke, v], +/// states: [central], +/// outputs: [cp], +/// routes: [infusion(iv) -> central], +/// diffeq: |x, _p, _t, dx, _cov| { +/// dx[central] = -ke * x[central]; +/// }, +/// out: |x, _p, _t, _cov, y| { +/// y[cp] = x[central] / v; +/// }, +/// }; +/// ``` #[proc_macro] pub fn ode(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as OdeInput); @@ -3365,6 +3404,45 @@ pub fn ode(input: TokenStream) -> TokenStream { .into() } +/// Define an analytical (closed‑form) PK model. +/// +/// Builds a model that uses a built‑in analytical solution. The macro validates that the declared parameters +/// match the chosen analytical solution's requirements and generates an [`Analytical`] value +/// with full metadata. +/// +/// # Fields +/// +/// | Field | Required | Description | +/// |-------|----------|-------------| +/// | `name` | yes | Model name | +/// | `params` | yes | Parameter identifiers | +/// | `derived` | no | Derived parameter identifiers (computed in `derive`) | +/// | `covariates` | no | Covariate identifiers | +/// | `states` | yes | State identifiers | +/// | `outputs` | yes | Output identifiers | +/// | `routes` | no | Route declarations | +/// | `structure` | yes | Built‑in kernel name, e.g. `one_compartment` or `one_compartment_with_absorption` | +/// | `derive` | no | Closure `\|t\| { … }` computing derived parameters from primaries and covariates | +/// | `lag` | no | Lag‑time closure | +/// | `fa` | no | Bioavailability closure | +/// | `init` | no | Initial‑state closure | +/// | `out` | yes | Output mapping closure | +/// +/// # Example +/// +/// ```ignore +/// let model = analytical! { +/// name: "one_cmt_oral", +/// params: [ka, ke, v], +/// states: [gut, central], +/// outputs: [cp], +/// routes: [bolus(oral) -> gut], +/// structure: one_compartment_with_absorption, +/// out: |x, _t, y| { +/// y[cp] = x[central] / v; +/// }, +/// }; +/// ``` #[proc_macro] pub fn analytical(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as AnalyticalInput); @@ -3506,7 +3584,6 @@ pub fn analytical(input: TokenStream) -> TokenStream { let nstates = input.states.len(); let ndrugs = dense_index_len(&route_bindings); let nout = input.outputs.len(); - let name = &input.name; let params = &input.params; let covariates = &input.covariates; @@ -3550,6 +3627,51 @@ pub fn analytical(input: TokenStream) -> TokenStream { .into() } +/// Define an SDE (stochastic differential equation) model. +/// +/// Builds a particle‑based stochastic model with a drift term, a diffusion +/// term, and a configurable number of particles. The macro generates an [`SDE`] +/// value with full metadata. +/// +/// # Fields +/// +/// | Field | Required | Description | +/// |-------|----------|-------------| +/// | `name` | yes | Model name | +/// | `params` | yes | Parameter identifiers | +/// | `covariates` | no | Covariate identifiers | +/// | `states` | yes | State identifiers | +/// | `outputs` | yes | Output identifiers | +/// | `particles` | yes | Number of particles for the simulation | +/// | `routes` | no | Route declarations | +/// | `drift` | yes | Closure `\|x, p, t, dx, cov\| { … }` for the deterministic drift | +/// | `diffusion` | yes | Closure `\|p, sigma\| { … }` setting per‑state diffusion coefficients | +/// | `lag` | no | Lag‑time closure | +/// | `fa` | no | Bioavailability closure | +/// | `init` | no | Initial‑state closure | +/// | `out` | yes | Output mapping closure | +/// +/// # Example +/// +/// ```ignore +/// let model = sde! { +/// name: "one_cmt_sde", +/// params: [ke, sigma_ke, v], +/// states: [central], +/// outputs: [cp], +/// particles: 16, +/// routes: [infusion(iv) -> central], +/// drift: |x, _p, _t, dx, _cov| { +/// dx[central] = -ke * x[central]; +/// }, +/// diffusion: |_p, sigma| { +/// sigma[central] = sigma_ke; +/// }, +/// out: |x, _p, _t, _cov, y| { +/// y[cp] = x[central] / v; +/// }, +/// }; +/// ``` #[proc_macro] pub fn sde(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as SdeInput); diff --git a/src/optimize/effect.rs b/src/optimize/effect.rs index e35fe95d..c31d6ee6 100644 --- a/src/optimize/effect.rs +++ b/src/optimize/effect.rs @@ -1,3 +1,9 @@ +//! Maximum-effect (`E2`) optimization for dual-site pharmacodynamic models. +//! +//! The central entry point is [`get_e2`], which computes the maximum achievable +//! effect for a model with two binding sites via Nelder‑Mead optimization in +//! log‑space. + use argmin::{ core::{CostFunction, Executor, TerminationReason, TerminationStatus}, solver::neldermead::NelderMead, diff --git a/src/optimize/mod.rs b/src/optimize/mod.rs index 14f2cd50..b044108b 100644 --- a/src/optimize/mod.rs +++ b/src/optimize/mod.rs @@ -1,2 +1,11 @@ +//! Optimizer-oriented helpers for pharmacometric workflows. +//! +//! This module provides optimization utilities built on [`argmin`]: +//! +//! - [`effect`] — Find the maximum effect (`E2`) for dual-site PD models +//! via Nelder‑Mead optimization in log‑space. +//! - [`parameters`] — Nelder‑Mead parameter refinement for an [`Equation`] +//! against a [`Data`] set and [`AssayErrorModels`]. + pub mod effect; pub mod parameters; diff --git a/src/optimize/parameters.rs b/src/optimize/parameters.rs index c64df672..fcf21a14 100644 --- a/src/optimize/parameters.rs +++ b/src/optimize/parameters.rs @@ -1,3 +1,11 @@ +//! Nelder‑Mead parameter refinement for pharmacometric models. +//! +//! This module provides a [`ParameterOptimizer`] that refines a single parameter +//! Given an [`Equation`], observed [`Data`], and [`AssayErrorModels`] via +//! Nelder‑Mead optimization in log‑space. The optimizer finds the parameter vector +//! that minimizes the negative log-likelihood of the model predictions against the data, +//! as measured by the provided error models. + use argmin::{ core::{CostFunction, Error, Executor}, solver::neldermead::NelderMead, @@ -7,6 +15,7 @@ use ndarray::{Array1, Axis}; use crate::{prelude::simulator::log_likelihood_matrix, AssayErrorModels, Data, Equation}; +/// Optimizer that refines a single parameter vector against observed data. pub struct ParameterOptimizer<'a, E: Equation> { equation: &'a E, data: &'a Data, @@ -44,6 +53,12 @@ impl CostFunction for ParameterOptimizer<'_, E> { } impl<'a, E: Equation> ParameterOptimizer<'a, E> { + /// Create a new optimizer. + /// + /// * `equation` — the model to evaluate. + /// * `data` — observed subject data. + /// * `sig` — assay error models per output. + /// * `pyl` — reference (target) likelihood vector. pub fn new( equation: &'a E, data: &'a Data, @@ -58,6 +73,8 @@ impl<'a, E: Equation> ParameterOptimizer<'a, E> { } } + /// Optimize the parameters to minimize the negative log-likelihood against the data. + pub fn optimize_point(self, parameters: Array1) -> Result, Error> { let simplex = create_initial_simplex(¶meters.to_vec()); let solver: NelderMead, f64> = NelderMead::new(simplex).with_sd_tolerance(1e-2)?; From bf66058e2e158c36cccdc28a29c4715fe1cd60cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Wed, 17 Jun 2026 19:56:27 +0100 Subject: [PATCH 2/7] rework traits --- benches/native_matrix.rs | 26 +- examples/compare_solvers.rs | 4 +- examples/macro_vs_handwritten_one_cpt.rs | 10 +- examples/macro_vs_handwritten_two_cpt.rs | 12 +- pharmsol-macros/src/lib.rs | 82 +-- src/core/caching.rs | 25 + src/{simulator/equation => core}/metadata.rs | 0 src/core/mod.rs | 34 + src/core/model_core.rs | 132 ++++ src/core/model_info.rs | 133 ++++ src/core/predictions.rs | 32 + src/core/simulate.rs | 253 +++++++ src/core/solver.rs | 115 +++ src/core/state.rs | 13 + src/data/error_model.rs | 7 - src/dsl/jit.rs | 13 +- src/dsl/native.rs | 654 ++++------------- src/lib.rs | 35 +- src/optimize/parameters.rs | 10 +- src/parameter_order.rs | 2 +- src/parameters.rs | 2 +- .../{equation => backends}/analytical/mod.rs | 535 ++++++-------- .../analytical/one_compartment_cl_models.rs | 9 +- .../analytical/one_compartment_models.rs | 9 +- .../analytical/three_compartment_cl_models.rs | 9 +- .../analytical/three_compartment_models.rs | 9 +- .../analytical/two_compartment_cl_models.rs | 9 +- .../analytical/two_compartment_models.rs | 9 +- src/simulator/backends/mod.rs | 29 + .../{equation => backends}/ode/closure.rs | 0 .../{equation => backends}/ode/mod.rs | 576 ++++++--------- .../{equation => backends}/sde/em.rs | 0 .../{equation => backends}/sde/mod.rs | 664 ++++++++---------- src/simulator/cache.rs | 2 +- src/simulator/equation/mod.rs | 620 ---------------- src/simulator/likelihood/matrix.rs | 15 +- src/simulator/likelihood/mod.rs | 19 +- src/simulator/likelihood/subject.rs | 18 + src/simulator/mod.rs | 6 +- tests/analytical_macro_lowering.rs | 74 +- tests/authoring_parity_corpus.rs | 126 ++-- tests/full_feature_macro_parity.rs | 42 +- tests/numerical_stability.rs | 80 +-- tests/ode_macro_lowering.rs | 64 +- tests/ode_optimizations.rs | 114 +-- tests/sde_macro_lowering.rs | 56 +- tests/support/bimodal_ke.rs | 6 +- tests/support/runtime_corpus.rs | 66 +- tests/test_pf.rs | 8 +- tests/test_solvers.rs | 10 +- 50 files changed, 2154 insertions(+), 2624 deletions(-) create mode 100644 src/core/caching.rs rename src/{simulator/equation => core}/metadata.rs (100%) create mode 100644 src/core/mod.rs create mode 100644 src/core/model_core.rs create mode 100644 src/core/model_info.rs create mode 100644 src/core/predictions.rs create mode 100644 src/core/simulate.rs create mode 100644 src/core/solver.rs create mode 100644 src/core/state.rs rename src/simulator/{equation => backends}/analytical/mod.rs (75%) rename src/simulator/{equation => backends}/analytical/one_compartment_cl_models.rs (95%) rename src/simulator/{equation => backends}/analytical/one_compartment_models.rs (94%) rename src/simulator/{equation => backends}/analytical/three_compartment_cl_models.rs (96%) rename src/simulator/{equation => backends}/analytical/three_compartment_models.rs (98%) rename src/simulator/{equation => backends}/analytical/two_compartment_cl_models.rs (95%) rename src/simulator/{equation => backends}/analytical/two_compartment_models.rs (96%) create mode 100644 src/simulator/backends/mod.rs rename src/simulator/{equation => backends}/ode/closure.rs (100%) rename src/simulator/{equation => backends}/ode/mod.rs (74%) rename src/simulator/{equation => backends}/sde/em.rs (100%) rename src/simulator/{equation => backends}/sde/mod.rs (70%) delete mode 100644 src/simulator/equation/mod.rs diff --git a/benches/native_matrix.rs b/benches/native_matrix.rs index 97ebf0a8..66ee3366 100644 --- a/benches/native_matrix.rs +++ b/benches/native_matrix.rs @@ -11,7 +11,7 @@ use std::time::Duration; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, SamplingMode}; use pharmsol::prelude::*; -use pharmsol::{Analytical, Cache, Parameters, ODE, SDE}; +use pharmsol::{Analytical, Parameters, ODE, SDE}; mod common; use common::{ @@ -113,7 +113,7 @@ fn predictions_group(c: &mut Criterion) { }); } (SolverKind::Ode, Authoring::Handwritten, CacheState::Cold) => { - let model = handwritten_ode(workload).disable_cache(); + let model = handwritten_ode(workload).without_cache(); let theta = ode_parameters(&model, workload); b.iter(|| { black_box( @@ -141,7 +141,7 @@ fn predictions_group(c: &mut Criterion) { }); } (SolverKind::Ode, Authoring::Macro, CacheState::Cold) => { - let model = macro_ode(workload).disable_cache(); + let model = macro_ode(workload).without_cache(); let theta = ode_parameters(&model, workload); b.iter(|| { black_box( @@ -169,7 +169,7 @@ fn predictions_group(c: &mut Criterion) { }); } (SolverKind::Analytical, Authoring::Handwritten, CacheState::Cold) => { - let model = handwritten_analytical(workload).disable_cache(); + let model = handwritten_analytical(workload).without_cache(); let theta = analytical_parameters(&model, workload); b.iter(|| { black_box( @@ -197,7 +197,7 @@ fn predictions_group(c: &mut Criterion) { }); } (SolverKind::Analytical, Authoring::Macro, CacheState::Cold) => { - let model = macro_analytical(workload).disable_cache(); + let model = macro_analytical(workload).without_cache(); let theta = analytical_parameters(&model, workload); b.iter(|| { black_box( @@ -225,7 +225,7 @@ fn predictions_group(c: &mut Criterion) { }); } (SolverKind::Sde, Authoring::Handwritten, CacheState::Cold) => { - let model = handwritten_sde(workload).disable_cache(); + let model = handwritten_sde(workload).without_cache(); let theta = sde_parameters(&model, workload); b.iter(|| { black_box( @@ -253,7 +253,7 @@ fn predictions_group(c: &mut Criterion) { }); } (SolverKind::Sde, Authoring::Macro, CacheState::Cold) => { - let model = macro_sde(workload).disable_cache(); + let model = macro_sde(workload).without_cache(); let theta = sde_parameters(&model, workload); b.iter(|| { black_box( @@ -304,7 +304,7 @@ fn log_likelihood_group(c: &mut Criterion) { }); } (SolverKind::Ode, Authoring::Handwritten, CacheState::Cold) => { - let model = handwritten_ode(workload).disable_cache(); + let model = handwritten_ode(workload).without_cache(); let theta = ode_parameters(&model, workload); b.iter(|| { black_box( @@ -334,7 +334,7 @@ fn log_likelihood_group(c: &mut Criterion) { }); } (SolverKind::Ode, Authoring::Macro, CacheState::Cold) => { - let model = macro_ode(workload).disable_cache(); + let model = macro_ode(workload).without_cache(); let theta = ode_parameters(&model, workload); b.iter(|| { black_box( @@ -364,7 +364,7 @@ fn log_likelihood_group(c: &mut Criterion) { }); } (SolverKind::Analytical, Authoring::Handwritten, CacheState::Cold) => { - let model = handwritten_analytical(workload).disable_cache(); + let model = handwritten_analytical(workload).without_cache(); let theta = analytical_parameters(&model, workload); b.iter(|| { black_box( @@ -394,7 +394,7 @@ fn log_likelihood_group(c: &mut Criterion) { }); } (SolverKind::Analytical, Authoring::Macro, CacheState::Cold) => { - let model = macro_analytical(workload).disable_cache(); + let model = macro_analytical(workload).without_cache(); let theta = analytical_parameters(&model, workload); b.iter(|| { black_box( @@ -424,7 +424,7 @@ fn log_likelihood_group(c: &mut Criterion) { }); } (SolverKind::Sde, Authoring::Handwritten, CacheState::Cold) => { - let model = handwritten_sde(workload).disable_cache(); + let model = handwritten_sde(workload).without_cache(); let theta = sde_parameters(&model, workload); b.iter(|| { black_box( @@ -454,7 +454,7 @@ fn log_likelihood_group(c: &mut Criterion) { }); } (SolverKind::Sde, Authoring::Macro, CacheState::Cold) => { - let model = macro_sde(workload).disable_cache(); + let model = macro_sde(workload).without_cache(); let theta = sde_parameters(&model, workload); b.iter(|| { black_box( diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index 7536f2a0..aa93ec22 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -18,7 +18,7 @@ use pharmsol::{prelude::*, Parameters}; // between runs; the declaration-first `ode!` surface and the generated // metadata stay the same. -fn two_cpt(solver: OdeSolver) -> equation::ODE { +fn two_cpt(solver: OdeSolver) -> backends::ODE { ode! { name: "two_cpt", params: [ke, kcp, kpc, v], @@ -74,7 +74,7 @@ fn main() { ) .expect("valid named parameters"); - let results: Vec<(&str, equation::ODE)> = vec![ + let results: Vec<(&str, backends::ODE)> = vec![ ("Bdf", bdf), ("Sdirk(TrBdf2)", trbdf2), ("Sdirk(Esdirk34)", esdirk34), diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index 4b4fa03c..2971c471 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -6,7 +6,7 @@ use pharmsol::{prelude::*, Parameters}; -fn macro_model() -> equation::ODE { +fn macro_model() -> backends::ODE { ode! { name: "one_cpt_macro_parity", params: [ke, v], @@ -24,8 +24,8 @@ fn macro_model() -> equation::ODE { } } -fn handwritten_model() -> equation::ODE { - equation::ODE::new( +fn handwritten_model() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -42,12 +42,12 @@ fn handwritten_model() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cpt_macro_parity") + pharmsol::metadata::new("one_cpt_macro_parity") .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) .route( - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ), diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index da73eab3..38cb2fc3 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -7,7 +7,7 @@ use pharmsol::{prelude::*, Parameters}; -fn macro_model() -> equation::ODE { +fn macro_model() -> backends::ODE { ode! { name: "two_cpt_shared_input_parity", params: [ke, kcp, kpc, v], @@ -27,8 +27,8 @@ fn macro_model() -> equation::ODE { } } -fn handwritten_model() -> equation::ODE { - equation::ODE::new( +fn handwritten_model() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ke, kcp, kpc, _v); dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0] + bolus[0]; @@ -46,15 +46,15 @@ fn handwritten_model() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("two_cpt_shared_input_parity") + pharmsol::metadata::new("two_cpt_shared_input_parity") .parameters(["ke", "kcp", "kpc", "v"]) .states(["central", "peripheral"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("load") + backends::Route::bolus("load") .to_state("central") .inject_input_to_destination(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ]), diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index bf0c44bc..62fe2ad2 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -2229,10 +2229,10 @@ fn expand_route_metadata( let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { - quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + quote! { ::pharmsol::core::metadata::Route::bolus(stringify!(#input)) } } OdeRouteKind::Infusion => { - quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + quote! { ::pharmsol::core::metadata::Route::infusion(stringify!(#input)) } } }; let lag_flag = if lag_routes.contains(&route_name) { @@ -2270,10 +2270,10 @@ fn expand_analytical_route_metadata( let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { - quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + quote! { ::pharmsol::core::metadata::Route::bolus(stringify!(#input)) } } OdeRouteKind::Infusion => { - quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + quote! { ::pharmsol::core::metadata::Route::infusion(stringify!(#input)) } } }; let lag_flag = if lag_routes.contains(&route_name) { @@ -2310,10 +2310,10 @@ fn expand_sde_route_metadata( let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { - quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + quote! { ::pharmsol::core::metadata::Route::bolus(stringify!(#input)) } } OdeRouteKind::Infusion => { - quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + quote! { ::pharmsol::core::metadata::Route::infusion(stringify!(#input)) } } }; let lag_flag = if lag_routes.contains(&route_name) { @@ -2548,74 +2548,74 @@ fn resolve_analytical_structure(structure: &Ident) -> syn::Result ( ResolverAnalyticalKernel::OneCompartment, - quote! { ::pharmsol::equation::one_compartment }, - quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartment }, + quote! { ::pharmsol::backends::one_compartment }, + quote! { ::pharmsol::backends::AnalyticalKernel::OneCompartment }, 1, ), "one_compartment_cl" => ( ResolverAnalyticalKernel::OneCompartmentCl, - quote! { ::pharmsol::equation::one_compartment_cl }, - quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentCl }, + quote! { ::pharmsol::backends::one_compartment_cl }, + quote! { ::pharmsol::backends::AnalyticalKernel::OneCompartmentCl }, 1, ), "one_compartment_cl_with_absorption" => ( ResolverAnalyticalKernel::OneCompartmentClWithAbsorption, - quote! { ::pharmsol::equation::one_compartment_cl_with_absorption }, - quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentClWithAbsorption }, + quote! { ::pharmsol::backends::one_compartment_cl_with_absorption }, + quote! { ::pharmsol::backends::AnalyticalKernel::OneCompartmentClWithAbsorption }, 2, ), "one_compartment_with_absorption" => ( ResolverAnalyticalKernel::OneCompartmentWithAbsorption, - quote! { ::pharmsol::equation::one_compartment_with_absorption }, - quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentWithAbsorption }, + quote! { ::pharmsol::backends::one_compartment_with_absorption }, + quote! { ::pharmsol::backends::AnalyticalKernel::OneCompartmentWithAbsorption }, 2, ), "two_compartments" => ( ResolverAnalyticalKernel::TwoCompartments, - quote! { ::pharmsol::equation::two_compartments }, - quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartments }, + quote! { ::pharmsol::backends::two_compartments }, + quote! { ::pharmsol::backends::AnalyticalKernel::TwoCompartments }, 2, ), "two_compartments_cl" => ( ResolverAnalyticalKernel::TwoCompartmentsCl, - quote! { ::pharmsol::equation::two_compartments_cl }, - quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsCl }, + quote! { ::pharmsol::backends::two_compartments_cl }, + quote! { ::pharmsol::backends::AnalyticalKernel::TwoCompartmentsCl }, 2, ), "two_compartments_cl_with_absorption" => ( ResolverAnalyticalKernel::TwoCompartmentsClWithAbsorption, - quote! { ::pharmsol::equation::two_compartments_cl_with_absorption }, - quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsClWithAbsorption }, + quote! { ::pharmsol::backends::two_compartments_cl_with_absorption }, + quote! { ::pharmsol::backends::AnalyticalKernel::TwoCompartmentsClWithAbsorption }, 3, ), "two_compartments_with_absorption" => ( ResolverAnalyticalKernel::TwoCompartmentsWithAbsorption, - quote! { ::pharmsol::equation::two_compartments_with_absorption }, - quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsWithAbsorption }, + quote! { ::pharmsol::backends::two_compartments_with_absorption }, + quote! { ::pharmsol::backends::AnalyticalKernel::TwoCompartmentsWithAbsorption }, 3, ), "three_compartments" => ( ResolverAnalyticalKernel::ThreeCompartments, - quote! { ::pharmsol::equation::three_compartments }, - quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartments }, + quote! { ::pharmsol::backends::three_compartments }, + quote! { ::pharmsol::backends::AnalyticalKernel::ThreeCompartments }, 3, ), "three_compartments_cl" => ( ResolverAnalyticalKernel::ThreeCompartmentsCl, - quote! { ::pharmsol::equation::three_compartments_cl }, - quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsCl }, + quote! { ::pharmsol::backends::three_compartments_cl }, + quote! { ::pharmsol::backends::AnalyticalKernel::ThreeCompartmentsCl }, 3, ), "three_compartments_cl_with_absorption" => ( ResolverAnalyticalKernel::ThreeCompartmentsClWithAbsorption, - quote! { ::pharmsol::equation::three_compartments_cl_with_absorption }, - quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsClWithAbsorption }, + quote! { ::pharmsol::backends::three_compartments_cl_with_absorption }, + quote! { ::pharmsol::backends::AnalyticalKernel::ThreeCompartmentsClWithAbsorption }, 4, ), "three_compartments_with_absorption" => ( ResolverAnalyticalKernel::ThreeCompartmentsWithAbsorption, - quote! { ::pharmsol::equation::three_compartments_with_absorption }, - quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsWithAbsorption }, + quote! { ::pharmsol::backends::three_compartments_with_absorption }, + quote! { ::pharmsol::backends::AnalyticalKernel::ThreeCompartmentsWithAbsorption }, 4, ), _ => { @@ -3336,7 +3336,7 @@ pub fn ode(input: TokenStream) -> TokenStream { quote! {} } else { quote! { - .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + .covariates([#(::pharmsol::core::metadata::Covariate::continuous(stringify!(#covariates))),*]) } }; @@ -3381,14 +3381,14 @@ pub fn ode(input: TokenStream) -> TokenStream { }; quote! {{ - let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + let __pharmsol_metadata = ::pharmsol::core::metadata::new(#name) .parameters([#(stringify!(#params)),*]) #covariate_metadata .states([#(stringify!(#states)),*]) .outputs([#(stringify!(#outputs)),*]) #(.route(#routes))*; - ::pharmsol::equation::ODE::new( + ::pharmsol::backends::ODE::new( #diffeq, #lag, #fa, @@ -3595,14 +3595,14 @@ pub fn analytical(input: TokenStream) -> TokenStream { quote! {} } else { quote! { - .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + .covariates([#(::pharmsol::core::metadata::Covariate::continuous(stringify!(#covariates))),*]) } }; quote! {{ #derive - let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) - .kind(::pharmsol::equation::ModelKind::Analytical) + let __pharmsol_metadata = ::pharmsol::core::metadata::new(#name) + .kind(::pharmsol::backends::ModelKind::Analytical) .parameters([#(stringify!(#params)),*]) #covariate_metadata .states([#(stringify!(#states)),*]) @@ -3610,7 +3610,7 @@ pub fn analytical(input: TokenStream) -> TokenStream { #(.route(#routes))* .analytical_kernel(#metadata_kernel); - ::pharmsol::equation::Analytical::new( + ::pharmsol::backends::Analytical::new( #eq, |_, _, _| {}, #lag, @@ -3808,14 +3808,14 @@ pub fn sde(input: TokenStream) -> TokenStream { quote! {} } else { quote! { - .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + .covariates([#(::pharmsol::core::metadata::Covariate::continuous(stringify!(#covariates))),*]) } }; quote! {{ let __pharmsol_particles: usize = #particles; - let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) - .kind(::pharmsol::equation::ModelKind::Sde) + let __pharmsol_metadata = ::pharmsol::core::metadata::new(#name) + .kind(::pharmsol::backends::ModelKind::Sde) .parameters([#(stringify!(#params)),*]) #covariate_metadata .states([#(stringify!(#states)),*]) @@ -3823,7 +3823,7 @@ pub fn sde(input: TokenStream) -> TokenStream { #(.route(#routes))* .particles(__pharmsol_particles); - ::pharmsol::equation::SDE::new( + ::pharmsol::backends::SDE::new( #drift, #diffusion, #lag, diff --git a/src/core/caching.rs b/src/core/caching.rs new file mode 100644 index 00000000..50d88080 --- /dev/null +++ b/src/core/caching.rs @@ -0,0 +1,25 @@ +use crate::simulator::cache::{BoundErrorModelCache, PredictionCache}; + +/// Cache management for simulation results. +/// +/// Implementors own optional prediction and error-model caches. +/// The `Clone` impl typically produces shallow copies that share cache data. +pub trait Caching: Sized { + /// Access the prediction cache, if enabled. + fn prediction_cache(&self) -> Option<&PredictionCache>; + + /// Access the bound error-model cache, if enabled. + fn error_model_cache(&self) -> Option<&BoundErrorModelCache>; + + /// Set the prediction cache capacity. Replaces any existing cache. + fn with_cache_capacity(self, size: u64) -> Self; + + /// Disable prediction caching entirely. + fn without_cache(self) -> Self; + + /// Clear all cached entries (prediction + error-model). + fn clear_cache(&self); +} + +// We intentionally do NOT put bind_error_models here because it needs ModelInfo +// metadata. It lives as a free function in `simulate` instead. diff --git a/src/simulator/equation/metadata.rs b/src/core/metadata.rs similarity index 100% rename from src/simulator/equation/metadata.rs rename to src/core/metadata.rs diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 00000000..7e18c04a --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,34 @@ +//! Core traits for the pharmsol simulation framework. +//! +//! This module holds the foundational traits that separate concerns: +//! +//! - [`Solver`] — *How* to advance state through time (backend authors implement this). +//! - [`ModelInfo`] — *What* the model's structure is (dimensions, metadata, lag/fa). +//! - [`Caching`] — *Where* prediction and error-model caches live. +//! - [`Simulate`] — *User-facing* prediction and likelihood API. Anything that +//! implements `Solver + ModelInfo + Caching` can become `Simulate`. +//! - [`State`] — Low-level state vector that can receive bolus doses. +//! - [`Predictions`] — Prediction containers with log-likelihood computation. +//! +//! The free function [`standard_event_loop`] provides the default simulation +//! driver — iterate events, apply boluses, track infusions, compute observations, +//! and advance time via [`Solver::solve`]. Backends that use batch integration +//! (e.g. diffsol-based ODE) set [`Solver::is_batch`] to `true` and implement +//! `Simulate::simulate_subject` themselves. + +pub mod caching; +pub mod metadata; +pub mod model_core; +pub mod model_info; +pub mod predictions; +pub mod simulate; +pub mod solver; +pub mod state; + +pub use caching::Caching; +pub use model_core::ModelCore; +pub use model_info::ModelInfo; +pub use predictions::Predictions; +pub use simulate::{standard_event_loop, PredictionsContainer, Simulate}; +pub use solver::Solver; +pub use state::State; diff --git a/src/core/model_core.rs b/src/core/model_core.rs new file mode 100644 index 00000000..0626a28c --- /dev/null +++ b/src/core/model_core.rs @@ -0,0 +1,132 @@ +use crate::simulator::cache::{BoundErrorModelCache, DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE}; +use crate::simulator::Neqs; +use crate::ValidatedModelMetadata; + +/// Shared model infrastructure: dimensions, metadata, and cache. +/// +/// Each backend (ODE, Analytical, SDE) wraps a `ModelCore` with its +/// backend-specific fields (closure functions, solver config, etc.). +/// +/// `C` is the prediction-cache type: [`PredictionCache`] for deterministic +/// backends, [`SdeLikelihoodCache`] for stochastic backends. +#[derive(Clone, Debug)] +pub struct ModelCore { + dims: Neqs, + metadata: Option, + cache: Option, + error_model_cache: Option, +} + +impl ModelCore { + /// Create a new `ModelCore` with default dimensions (all 5) and + /// an optional prediction cache. + pub fn new(cache: Option) -> Self { + Self { + dims: Neqs::default(), + metadata: None, + cache, + error_model_cache: Some(BoundErrorModelCache::new( + DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, + )), + } + } + + // ── Dimensions ────────────────────────────────────────────────────── + + /// Current dimension configuration. + pub fn dims(&self) -> Neqs { + self.dims + } + + /// Number of state variables. + pub fn nstates(&self) -> usize { + self.dims.nstates + } + + /// Number of drug input routes. + pub fn ndrugs(&self) -> usize { + self.dims.ndrugs + } + + /// Number of output equations. + pub fn nout(&self) -> usize { + self.dims.nout + } + + /// Set the number of state variables. Invalidates metadata. + pub fn with_nstates(mut self, nstates: usize) -> Self { + self.dims.nstates = nstates; + self.invalidate(); + self + } + + /// Set the number of drug inputs. Invalidates metadata. + pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { + self.dims.ndrugs = ndrugs; + self.invalidate(); + self + } + + /// Set the number of output equations. Invalidates metadata. + pub fn with_nout(mut self, nout: usize) -> Self { + self.dims.nout = nout; + self.invalidate(); + self + } + + // ── Metadata ──────────────────────────────────────────────────────── + + /// Attached validated metadata, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + /// Attach validated metadata. The caller is responsible for + /// dimension validation and backend-specific error handling. + pub fn set_metadata(&mut self, metadata: ValidatedModelMetadata) { + self.metadata = Some(metadata); + } + + // ── Caches ────────────────────────────────────────────────────────── + + /// Prediction cache, if enabled. + pub fn cache(&self) -> Option<&C> { + self.cache.as_ref() + } + + /// Bound error-model cache, if enabled. + pub fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { + self.error_model_cache.as_ref() + } + + /// Set the prediction cache capacity. Replaces any existing cache. + pub fn with_cache_capacity(mut self, cache: C) -> Self { + self.cache = Some(cache); + self + } + + /// Disable prediction caching. + pub fn without_cache(mut self) -> Self { + self.cache = None; + self.error_model_cache = None; + self + } + + /// Clear all cached entries. + pub fn clear_cache(&self) { + // Prediction cache clearing is type-specific and handled by + // the backend's Caching impl. + if let Some(cache) = &self.error_model_cache { + cache.invalidate_all(); + } + } + + // ── Internal ──────────────────────────────────────────────────────── + + fn invalidate(&mut self) { + self.metadata = None; + self.error_model_cache = Some(BoundErrorModelCache::new( + DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, + )); + } +} diff --git a/src/core/model_info.rs b/src/core/model_info.rs new file mode 100644 index 00000000..d31e4c0a --- /dev/null +++ b/src/core/model_info.rs @@ -0,0 +1,133 @@ +use pharmsol_dsl::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; + +use crate::data::{Covariates, InputLabel, OutputLabel}; +use crate::core::metadata::RouteKind; +use crate::simulator::{Fa, Lag}; +use crate::{Event, Occasion, PharmsolError, ValidatedModelMetadata}; + +/// Structural information about a model. +/// +/// Provides access to dimensions, metadata, lag/bioavailability closures, +/// and label resolution. Most methods have sensible defaults driven by +/// [`Self::metadata`]. +pub trait ModelInfo { + /// Number of state variables (compartments). + fn nstates(&self) -> usize; + + /// Number of drug input routes (`bolus[]` / `rateiv[]` width). + fn ndrugs(&self) -> usize; + + /// Number of output equations. + fn nout(&self) -> usize; + + /// Model metadata, if attached. Drives label-name resolution. + fn metadata(&self) -> Option<&ValidatedModelMetadata>; + + /// Lag-time closure for this model. + fn lag(&self) -> &Lag; + + /// Fraction-absorbed (bioavailability) closure for this model. + fn fa(&self) -> &Fa; + + // ── Provided methods ── + + /// Resolve a public input label to a dense input index. + /// + /// Uses metadata when available; falls back to interpreting the label as + /// a numeric index. + fn resolve_input( + &self, + label: &InputLabel, + expected_kind: RouteKind, + ) -> Result { + if let Some(metadata) = self.metadata() { + let route = metadata + .route(label.as_str()) + .or_else(|| { + canonical_numeric_alias(label.as_str(), NUMERIC_ROUTE_PREFIX) + .and_then(|alias| metadata.route(alias.as_str())) + }) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + + if route.kind() != expected_kind { + return Err(PharmsolError::UnsupportedInputRouteKind { + input: route.input_index(), + kind: match expected_kind { + RouteKind::Bolus => pharmsol_dsl::RouteKind::Bolus, + RouteKind::Infusion => pharmsol_dsl::RouteKind::Infusion, + }, + }); + } + + return Ok(route.input_index()); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + }) + } + + /// Resolve a public output label to a dense output index. + fn resolve_output(&self, label: &OutputLabel) -> Result { + if let Some(metadata) = self.metadata() { + return metadata + .output_index(label.as_str()) + .or_else(|| { + canonical_numeric_alias(label.as_str(), NUMERIC_OUTPUT_PREFIX) + .and_then(|alias| metadata.output_index(alias.as_str())) + }) + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }) + } + + /// Resolve all events in an occasion, applying lag time and + /// bioavailability adjustments, and mapping labels to indices. + fn resolve_events( + &self, + occasion: &Occasion, + params: &[f64], + covariates: &Covariates, + ) -> Result, PharmsolError> { + let mut resolved = occasion.clone(); + + for event in resolved.events_iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = + self.resolve_input(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(resolved.process_events(Some((self.fa(), self.lag(), params, covariates)), true)) + } +} + +/// Build a canonical alias like `input_7` from a raw numeric label `"7"`. +fn canonical_numeric_alias(label: &str, prefix: &str) -> Option { + if label.is_empty() || !label.chars().all(|ch| ch.is_ascii_digit()) { + return None; + } + Some(format!("{prefix}{label}")) +} diff --git a/src/core/predictions.rs b/src/core/predictions.rs new file mode 100644 index 00000000..812e33df --- /dev/null +++ b/src/core/predictions.rs @@ -0,0 +1,32 @@ +use crate::data::error_model::AssayErrorModels; +use crate::simulator::likelihood::Prediction; +use crate::PharmsolError; + +/// Trait for prediction containers. +/// +/// Implemented by [`SubjectPredictions`] (ODE/Analytical) and +/// [`Array2`] (SDE). +/// +/// For the push-based accumulation interface used during simulation, see +/// [`super::PredictionsContainer`]. +pub trait Predictions: Default { + /// Create a new prediction container with specified capacity. + /// + /// # Parameters + /// - `nparticles`: Number of particles (1 for deterministic, >1 for SDE) + fn new(_nparticles: usize) -> Self { + Default::default() + } + + /// Calculate the sum of squared errors for all predictions. + fn squared_error(&self) -> f64; + + /// Get all predictions as an owned vector. + fn get_predictions(&self) -> Vec; + + /// Calculate the log-likelihood of the predictions given an error model. + /// + /// This is numerically more stable than computing the likelihood and + /// taking its log, especially for extreme values or many observations. + fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result; +} diff --git a/src/core/simulate.rs b/src/core/simulate.rs new file mode 100644 index 00000000..239c2eb6 --- /dev/null +++ b/src/core/simulate.rs @@ -0,0 +1,253 @@ +use std::sync::Arc; + +use crate::core::{Caching, ModelInfo, Solver, State}; +use crate::data::error_model::{AssayErrorModels, BoundAssayErrorModels}; +use crate::simulator::likelihood::Prediction; +use crate::{Event, Infusion, Parameters, PharmsolError, Subject}; + +/// A container that accumulates predictions during simulation. +/// +/// Implemented by [`SubjectPredictions`] and [`Array2`] (SDE). +pub trait PredictionsContainer: Default { + /// Create a new container pre-sized for `nparticles` (1 for deterministic, + /// >1 for SDE). + fn new(nparticles: usize) -> Self; + + /// Append a prediction. + fn push(&mut self, pred: Prediction); + + /// Get all predictions as a slice. + fn predictions(&self) -> &[Prediction]; + + /// Compute the total log-likelihood across all predictions. + fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result; +} + +/// The user-facing simulation API. +/// +/// Anything that is `Solver + ModelInfo + Caching + Clone + Sync + 'static` +/// can implement `Simulate`. For backends that don't use batch integration, +/// the standard event loop ([`standard_event_loop`]) provides the default +/// `simulate_subject` implementation. +/// +/// # Provided convenience methods +/// +/// - [`predictions`](Self::predictions) — simulate and return predictions only +/// - [`log_likelihood`](Self::log_likelihood) — simulate and compute log-likelihood +/// - [`estimate_predictions`](Self::estimate_predictions) — accept `&Parameters` instead of `&[f64]` +/// - [`estimate_log_likelihood`](Self::estimate_log_likelihood) — accept `&Parameters` instead of `&[f64]` +pub trait Simulate: Solver + ModelInfo + Caching + Clone + Sync + 'static { + /// The predictions container type for this backend. + type Predictions: PredictionsContainer; + + /// Run the simulation for a subject and return predictions + optional likelihood. + /// + /// This is the only required method. Implementors can call + /// [`standard_event_loop`] for the default per-event loop, or provide + /// their own (e.g. batch diffsol integration). + fn simulate_subject( + &self, + subject: &Subject, + params: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::Predictions, Option), PharmsolError>; + + /// Simulate and return predictions only. + fn predictions( + &self, + subject: &Subject, + params: &[f64], + ) -> Result { + Ok(self.simulate_subject(subject, params, None)?.0) + } + + /// Simulate and return the log-likelihood. + fn log_likelihood( + &self, + subject: &Subject, + params: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + let predictions = self.predictions(subject, params)?; + let bound = bind_error_models_inner(self, error_models)?; + predictions.log_likelihood(&bound) + } + + /// Convenience: accept `&Parameters` instead of `&[f64]`. + fn estimate_predictions( + &self, + subject: &Subject, + params: &Parameters, + ) -> Result { + self.predictions(subject, params.as_slice()) + } + + /// Convenience: accept `&Parameters` instead of `&[f64]`. + fn estimate_log_likelihood( + &self, + subject: &Subject, + params: &Parameters, + error_models: &AssayErrorModels, + ) -> Result { + self.log_likelihood(subject, params.as_slice(), error_models) + } + + /// The model kind for runtime dispatch. + fn kind() -> pharmsol_dsl::ModelKind; +} + +// ── Standard event loop ──────────────────────────────────────────────────── + +/// The default simulation driver for per-interval solvers. +/// +/// Iterates events: applies boluses, accumulates infusions, computes +/// predictions from observations, and calls [`Solver::solve`] to advance +/// the system between events. +/// +/// Caches results using [`Caching::prediction_cache`] when caching is +/// enabled. Uses [`Caching::error_model_cache`] for bound error-model sharing. +pub fn standard_event_loop( + model: &S, + subject: &Subject, + params: &[f64], + error_models: Option<&AssayErrorModels>, +) -> Result<(P, Option), PharmsolError> +where + S: Solver + ModelInfo + Caching, + P: PredictionsContainer, +{ + // Check prediction cache + if let (Some(cache), None) = (model.prediction_cache(), error_models) { + let key = (subject.hash(), parameters_hash(params)); + // Cache hit would need to return (P, None) but P isn't necessarily the same + // type as what's in the cache. We skip cache-based return here and let + // individual backends handle caching in their simulate_subject impl. + // The cache check pattern is used by Analytical and ODE backends. + let _ = (cache, key); + } + + let bound_error_models = match error_models { + Some(em) => Some(bind_error_models_inner(model, em)?), + None => None, + }; + + let mut output = P::new(model.nparticles()); + let mut likelihood = Vec::new(); + + for occasion in subject.occasions() { + let covariates = occasion.covariates(); + let events = model.resolve_events(occasion, params, covariates)?; + let mut state = model.initial_state(params, covariates, occasion.index()); + let mut infusions: Vec = Vec::new(); + + for (idx, event) in events.iter().enumerate() { + match event { + Event::Bolus(bolus) => { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= model.ndrugs() { + return Err(PharmsolError::InputOutOfRange { + input, + ndrugs: model.ndrugs(), + }); + } + state.add_bolus(input, bolus.amount()); + } + Event::Infusion(infusion) => { + infusions.push(infusion.clone()); + } + Event::Observation(observation) => { + let (pred, lik) = model.process_observation( + &state, + params, + observation, + error_models, + covariates, + )?; + if let Some(lik) = lik { + likelihood.push(lik); + } + output.push(pred); + } + } + + // Advance to next event + if let Some(next_event) = events.get(idx + 1) { + if !event.time().eq(&next_event.time()) { + model.solve( + &mut state, + params, + covariates, + &infusions, + event.time(), + next_event.time(), + )?; + } + } + } + } + + let ll = bound_error_models.map(|_| likelihood.iter().product::()); + Ok((output, ll)) +} + +// ── Helpers ───────────────────────────────────────────────────────────────── + +/// Bind assay error models using the model's metadata for output-label +/// resolution, with caching through the model's error-model cache. +pub(crate) fn bind_error_models_inner<'a, M: ModelInfo + Caching>( + model: &'a M, + error_models: &'a AssayErrorModels, +) -> Result, PharmsolError> { + if let Some(cache) = model.error_model_cache() { + let key = error_models.hash(); + if let Some(bound_error_models) = cache.get(&key) { + return Ok(BoundAssayErrorModels::Shared(bound_error_models)); + } + + return match error_models + .bind_output_names( + model + .metadata() + .map(|m| m.outputs().iter().map(|o| o.name())) + .into_iter() + .flatten(), + ) + .map_err(PharmsolError::from)? + { + BoundAssayErrorModels::Owned(bound_error_models) => { + let bound_error_models = Arc::new(bound_error_models); + cache.insert(key, Arc::clone(&bound_error_models)); + Ok(BoundAssayErrorModels::Shared(bound_error_models)) + } + bound => Ok(bound), + }; + } + + error_models + .bind_output_names( + model + .metadata() + .map(|m| m.outputs().iter().map(|o| o.name())) + .into_iter() + .flatten(), + ) + .map_err(PharmsolError::from) +} + +/// Hash a parameter slice for cache keys. +#[inline(always)] +pub(crate) fn parameters_hash(params: &[f64]) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = ahash::AHasher::default(); + for &value in params { + let bits = if value == 0.0 { 0u64 } else { value.to_bits() }; + bits.hash(&mut hasher); + } + hasher.finish() +} diff --git a/src/core/solver.rs b/src/core/solver.rs new file mode 100644 index 00000000..30510acb --- /dev/null +++ b/src/core/solver.rs @@ -0,0 +1,115 @@ +use crate::core::State; +use crate::data::{Covariates, Infusion}; +use crate::simulator::likelihood::Prediction; +use crate::{Observation, PharmsolError}; +use crate::data::error_model::AssayErrorModels; + +/// How to advance a model's state through time. +/// +/// This is the trait backend authors implement. It captures the integration +/// mechanism — analytical closed form, numerical ODE integration, stochastic +/// stepping, etc. — without coupling to the event loop or prediction pipeline. +/// +/// # Batch vs per-interval solving +/// +/// Most backends use per-interval solving: the event loop calls [`solve`] +/// between events. If your solver prefers to handle all events internally +/// (like diffsol which does adaptive stepping across events), override +/// [`is_batch`] to return `true` and provide your own +/// [`Simulate::simulate_subject`](super::Simulate::simulate_subject). +/// +/// # Example (analytical backend) +/// +/// ```ignore +/// impl Solver for MyModel { +/// type State = V; +/// +/// fn solve(&self, x: &mut V, params: &[f64], covariates: &Covariates, +/// infusions: &[Infusion], ti: f64, tf: f64) -> Result<(), Error> { +/// let dt = tf - ti; +/// *x = (self.eq)(x, ¶ms_vector(params), dt, &rateiv(infusions, ti, tf), covariates); +/// Ok(()) +/// } +/// +/// fn process_observation(&self, state: &V, params: &[f64], +/// observation: &Observation, error_models: Option<&AssayErrorModels>, +/// covariates: &Covariates) -> Result<(Prediction, Option), Error> { +/// let mut y = V::zeros(self.nout(), NalgebraContext); +/// (self.output_fn)(state, ¶ms_vector(params), observation.time(), covariates, &mut y); +/// let ix = observation.outeq_index().unwrap(); +/// let pred = observation.to_prediction(y[ix], state.as_slice().to_vec()); +/// let lik = error_models.map(|em| pred.log_likelihood(em).map(f64::exp)).transpose()?; +/// Ok((pred, lik)) +/// } +/// // ... +/// } +/// ``` +pub trait Solver { + /// The state vector type this solver operates on. + type State: State; + + /// Advance the system state from `ti` to `tf`. + /// + /// # Parameters + /// * `state` — current state at `ti`, mutated to state at `tf` on return + /// * `params` — model parameters in model order + /// * `covariates` — time-varying covariates for this occasion + /// * `infusions` — active infusion events in this interval + /// * `ti` — start time (inclusive) + /// * `tf` — end time (exclusive) + fn solve( + &self, + _state: &mut Self::State, + _params: &[f64], + _covariates: &Covariates, + _infusions: &[Infusion], + _ti: f64, + _tf: f64, + ) -> Result<(), PharmsolError> { + unimplemented!( + "solve() is not used by batch-mode solvers; \ + set is_batch() to false or implement solve()" + ) + } + + /// Compute a prediction (and optionally a likelihood component) from the + /// current state at an observation time point. + fn process_observation( + &self, + _state: &Self::State, + _params: &[f64], + _observation: &Observation, + _error_models: Option<&AssayErrorModels>, + _covariates: &Covariates, + ) -> Result<(Prediction, Option), PharmsolError> { + unimplemented!( + "process_observation() is not used by batch-mode solvers; \ + set is_batch() to false or implement process_observation()" + ) + } + + /// Create the initial state vector for the start of an occasion. + /// + /// For `occasion_index == 0`, this should call the model's init closure. + /// For subsequent occasions, it should be zero (carry-over handled + /// elsewhere). + fn initial_state( + &self, + params: &[f64], + covariates: &Covariates, + occasion_index: usize, + ) -> Self::State; + + /// Number of particles. 1 for deterministic (ODE/Analytical), >1 for SDE. + fn nparticles(&self) -> usize { + 1 + } + + /// Whether this solver prefers batch event handling. + /// + /// When `true`, [`Simulate::simulate_subject`](super::Simulate::simulate_subject) + /// must be implemented manually — the standard event loop won't be used. + fn is_batch(&self) -> bool { + false + } +} diff --git a/src/core/state.rs b/src/core/state.rs new file mode 100644 index 00000000..251511d3 --- /dev/null +++ b/src/core/state.rs @@ -0,0 +1,13 @@ +/// Trait for state vectors that can receive bolus doses. +/// +/// Implemented by the state types used by each backend: +/// - [`V`](nalgebra::DVector) for ODE and Analytical +/// - `Vec>` for SDE (one per particle) +pub trait State { + /// Add a bolus dose to the state at the specified resolved input index. + /// + /// # Parameters + /// - `input`: The resolved dense input index used by the execution layer + /// - `amount`: The bolus amount + fn add_bolus(&mut self, input: usize, amount: f64); +} diff --git a/src/data/error_model.rs b/src/data/error_model.rs index 1fffca20..1feea1b2 100644 --- a/src/data/error_model.rs +++ b/src/data/error_model.rs @@ -233,13 +233,6 @@ impl AssayErrorModels { Err(ErrorModelError::IncompatibleOutputContext { expected, found }) } - pub(crate) fn bind_to( - &self, - context: &impl crate::Equation, - ) -> Result, ErrorModelError> { - self.bind_output_names(context.assay_error_models().bound_output_names()) - } - pub(crate) fn bind_output_names( &self, outputs: I, diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index d9691b57..4856fb94 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -1332,11 +1332,12 @@ pub fn compile_sde_model_to_jit(model: &ExecutionModel) -> Result Result Result { - crate::simulator::equation::metadata::Route::bolus(route.name.clone()) + crate::core::metadata::Route::bolus(route.name.clone()) } RouteKind::Infusion => { - crate::simulator::equation::metadata::Route::infusion(route.name.clone()) + crate::core::metadata::Route::infusion(route.name.clone()) } } .to_state(destination); @@ -1398,7 +1399,7 @@ fn runtime_ode_predictions( if let Some(cache) = &model.cache { let key = ( subject.hash(), - crate::simulator::equation::parameters_hash(support_point), + crate::core::simulate::parameters_hash(support_point), ); if let Some(cached) = cache.get(&key) { return Ok(cached); @@ -1412,210 +1413,71 @@ fn runtime_ode_predictions( } } -impl crate::simulator::equation::Cache for NativeOdeModel { - fn with_cache_capacity(mut self, size: u64) -> Self { - self.cache = Some(PredictionCache::new(size)); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self - } - fn enable_cache(mut self) -> Self { - self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE)); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self - } +impl Solver for NativeOdeModel { + type State = V; - fn clear_cache(&self) { - if let Some(cache) = &self.cache { - cache.invalidate_all(); - } - if let Some(cache) = &self.error_model_cache { - cache.invalidate_all(); - } + fn initial_state(&self, _params: &[f64], _covariates: &Covariates, _occasion_index: usize) -> V { + V::zeros(self.shared.info.state_len, NalgebraContext) } - fn disable_cache(mut self) -> Self { - self.cache = None; - self.error_model_cache = None; - self + fn is_batch(&self) -> bool { + true } } -impl EquationTypes for NativeOdeModel { - type S = V; - type P = SubjectPredictions; +impl ModelInfo for NativeOdeModel { + fn nstates(&self) -> usize { self.shared.info.state_len } + fn ndrugs(&self) -> usize { self.shared.info.route_len } + fn nout(&self) -> usize { self.shared.info.output_len } + fn metadata(&self) -> Option<&ValidatedModelMetadata> { Some(self.shared.metadata()) } + fn lag(&self) -> &Lag { &(runtime_no_lag as Lag) } + fn fa(&self) -> &Fa { &(runtime_no_fa as Fa) } } -impl EquationPriv for NativeOdeModel { - fn lag(&self) -> &Lag { - &(runtime_no_lag as Lag) - } - - fn fa(&self) -> &Fa { - &(runtime_no_fa as Fa) - } - - fn get_nstates(&self) -> usize { - self.shared.info.state_len - } - - fn get_ndrugs(&self) -> usize { - self.shared.info.route_len - } - - fn get_nouteqs(&self) -> usize { - self.shared.info.output_len - } - - fn metadata(&self) -> Option<&crate::ValidatedModelMetadata> { - Some(self.shared.metadata()) - } - - fn solve( - &self, - _state: &mut Self::S, - _support_point: &[f64], - _covariates: &Covariates, - _infusions: &[Infusion], - _start_time: f64, - _end_time: f64, - ) -> Result<(), PharmsolError> { - unimplemented!("solve is not used for runtime ODE models") - } - - fn process_observation( - &self, - _support_point: &[f64], - _observation: &Observation, - _error_models: Option<&AssayErrorModels>, - _time: f64, - _covariates: &Covariates, - _x: &mut Self::S, - _likelihood: &mut Vec, - _output: &mut Self::P, - ) -> Result<(), PharmsolError> { - unimplemented!("process_observation is not used for runtime ODE models") +impl Caching for NativeOdeModel { + fn prediction_cache(&self) -> Option<&PredictionCache> { self.cache.as_ref() } + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { self.error_model_cache.as_ref() } + fn with_cache_capacity(mut self, size: u64) -> Self { + self.cache = Some(PredictionCache::new(size)); + self.error_model_cache = Some(BoundErrorModelCache::new(DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE)); + self } - - fn initial_state( - &self, - _support_point: &[f64], - _covariates: &Covariates, - _occasion_index: usize, - ) -> Self::S { - V::zeros(self.shared.info.state_len, NalgebraContext) + fn without_cache(mut self) -> Self { self.cache = None; self.error_model_cache = None; self } + fn clear_cache(&self) { + if let Some(c) = &self.cache { c.invalidate_all(); } + if let Some(c) = &self.error_model_cache { c.invalidate_all(); } } } -impl Equation for NativeOdeModel { - fn bound_error_model_cache(&self) -> Option<&BoundErrorModelCache> { - self.error_model_cache.as_ref() - } - - fn estimate_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - Ok(self - .estimate_log_likelihood(subject, parameters, error_models)? - .exp()) - } - - fn estimate_log_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - let predictions = runtime_ode_predictions(self, subject, parameters.as_slice())?; - predictions.log_likelihood(&bound_error_models) - } - - fn estimate_predictions_dense( - &self, - subject: &Subject, - parameters: &[f64], - ) -> Result { - runtime_ode_predictions(self, subject, parameters) - } - - fn estimate_log_likelihood_dense( - &self, - subject: &Subject, - parameters: &[f64], - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - let predictions = runtime_ode_predictions(self, subject, parameters)?; - predictions.log_likelihood(&bound_error_models) - } +impl Simulate for NativeOdeModel { + type Predictions = SubjectPredictions; - fn simulate_subject_dense( - &self, - subject: &Subject, - parameters: &[f64], - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { + fn simulate_subject(&self, subject: &Subject, parameters: &[f64], + error_models: Option<&AssayErrorModels>) + -> Result<(Self::Predictions, Option), PharmsolError> + { let bound_error_models = match error_models { - Some(error_models) => Some(self.bind_error_models(error_models)?), + Some(em) => Some(crate::core::simulate::bind_error_models_inner(self, em)?), None => None, }; - let predictions = runtime_ode_predictions(self, subject, parameters)?; let likelihood = match bound_error_models.as_ref() { - Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), + Some(em) => Some(predictions.log_likelihood(em)?.exp()), None => None, }; Ok((predictions, likelihood)) } - fn kind() -> EqnKind { - EqnKind::ODE - } - - fn assay_error_models(&self) -> AssayErrorModels { - AssayErrorModels::with_output_names( - self.info() - .outputs - .iter() - .map(|output| output.name.as_str()), - ) - } - - fn estimate_predictions( - &self, - subject: &Subject, - parameters: &Parameters, - ) -> Result { - runtime_ode_predictions(self, subject, parameters.as_slice()) + fn log_likelihood(&self, subject: &Subject, params: &[f64], + error_models: &AssayErrorModels) -> Result + { + let bound = crate::core::simulate::bind_error_models_inner(self, error_models)?; + let predictions = runtime_ode_predictions(self, subject, params)?; + predictions.log_likelihood(&bound) } - fn simulate_subject( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - let support_point = parameters.as_slice(); - let bound_error_models = match error_models { - Some(error_models) => Some(self.bind_error_models(error_models)?), - None => None, - }; - - let predictions = runtime_ode_predictions(self, subject, support_point)?; - let likelihood = match bound_error_models.as_ref() { - Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), - None => None, - }; - Ok((predictions, likelihood)) - } + fn kind() -> ModelKind { ModelKind::Ode } } impl NativeAnalyticalModel { @@ -1803,7 +1665,7 @@ fn runtime_analytical_predictions( if let Some(cache) = &model.cache { let key = ( subject.hash(), - crate::simulator::equation::parameters_hash(support_point), + crate::core::simulate::parameters_hash(support_point), ); if let Some(cached) = cache.get(&key) { return Ok(cached); @@ -1817,174 +1679,64 @@ fn runtime_analytical_predictions( } } -impl crate::simulator::equation::Cache for NativeAnalyticalModel { - fn with_cache_capacity(mut self, size: u64) -> Self { - self.cache = Some(PredictionCache::new(size)); - self - } - fn enable_cache(mut self) -> Self { - self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE)); - self - } +impl Solver for NativeAnalyticalModel { + type State = V; - fn clear_cache(&self) { - if let Some(cache) = &self.cache { - cache.invalidate_all(); - } + fn initial_state(&self, _params: &[f64], _covariates: &Covariates, _occasion_index: usize) -> V { + V::zeros(self.shared.info.state_len, NalgebraContext) } - fn disable_cache(mut self) -> Self { - self.cache = None; - self - } + fn is_batch(&self) -> bool { true } } -impl EquationTypes for NativeAnalyticalModel { - type S = V; - type P = SubjectPredictions; +impl ModelInfo for NativeAnalyticalModel { + fn nstates(&self) -> usize { self.shared.info.state_len } + fn ndrugs(&self) -> usize { self.shared.info.route_len } + fn nout(&self) -> usize { self.shared.info.output_len } + fn metadata(&self) -> Option<&ValidatedModelMetadata> { Some(self.shared.metadata()) } + fn lag(&self) -> &Lag { &(runtime_no_lag as Lag) } + fn fa(&self) -> &Fa { &(runtime_no_fa as Fa) } } -impl EquationPriv for NativeAnalyticalModel { - fn lag(&self) -> &Lag { - &(runtime_no_lag as Lag) - } - - fn fa(&self) -> &Fa { - &(runtime_no_fa as Fa) - } - - fn get_nstates(&self) -> usize { - self.shared.info.state_len - } - - fn get_ndrugs(&self) -> usize { - self.shared.info.route_len - } - - fn get_nouteqs(&self) -> usize { - self.shared.info.output_len - } - - fn metadata(&self) -> Option<&crate::ValidatedModelMetadata> { - Some(self.shared.metadata()) - } - - fn solve( - &self, - _state: &mut Self::S, - _support_point: &[f64], - _covariates: &Covariates, - _infusions: &[Infusion], - _start_time: f64, - _end_time: f64, - ) -> Result<(), PharmsolError> { - unimplemented!("solve is not used for runtime analytical models") - } - - fn process_observation( - &self, - _support_point: &[f64], - _observation: &Observation, - _error_models: Option<&AssayErrorModels>, - _time: f64, - _covariates: &Covariates, - _x: &mut Self::S, - _likelihood: &mut Vec, - _output: &mut Self::P, - ) -> Result<(), PharmsolError> { - unimplemented!("process_observation is not used for runtime analytical models") - } - - fn initial_state( - &self, - _support_point: &[f64], - _covariates: &Covariates, - _occasion_index: usize, - ) -> Self::S { - V::zeros(self.shared.info.state_len, NalgebraContext) +impl Caching for NativeAnalyticalModel { + fn prediction_cache(&self) -> Option<&PredictionCache> { self.cache.as_ref() } + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { None } + fn with_cache_capacity(mut self, size: u64) -> Self { + self.cache = Some(PredictionCache::new(size)); self } + fn without_cache(mut self) -> Self { self.cache = None; self } + fn clear_cache(&self) { if let Some(c) = &self.cache { c.invalidate_all(); } } } -impl Equation for NativeAnalyticalModel { - fn estimate_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - Ok(self - .estimate_log_likelihood(subject, parameters, error_models)? - .exp()) - } - - fn estimate_log_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - let predictions = runtime_analytical_predictions(self, subject, parameters.as_slice())?; - predictions.log_likelihood(&bound_error_models) - } - - fn estimate_predictions_dense( - &self, - subject: &Subject, - parameters: &[f64], - ) -> Result { - NativeAnalyticalModel::estimate_predictions_dense(self, subject, parameters) - } - - fn kind() -> EqnKind { - EqnKind::Analytical - } - - fn assay_error_models(&self) -> AssayErrorModels { - AssayErrorModels::with_output_names( - self.info() - .outputs - .iter() - .map(|output| output.name.as_str()), - ) - } - - fn estimate_predictions( - &self, - subject: &Subject, - parameters: &Parameters, - ) -> Result { - NativeAnalyticalModel::estimate_predictions(self, subject, parameters) - } - - fn simulate_subject( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - self.simulate_subject_dense(subject, parameters.as_slice(), error_models) - } +impl Simulate for NativeAnalyticalModel { + type Predictions = SubjectPredictions; - fn simulate_subject_dense( - &self, - subject: &Subject, - parameters: &[f64], - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - let bound_error_models = match error_models { - Some(error_models) => Some(self.bind_error_models(error_models)?), + fn simulate_subject(&self, subject: &Subject, parameters: &[f64], + error_models: Option<&AssayErrorModels>) + -> Result<(Self::Predictions, Option), PharmsolError> + { + let bound_em = match error_models { + Some(em) => Some(crate::core::simulate::bind_error_models_inner(self, em)?), None => None, }; - let predictions = runtime_analytical_predictions(self, subject, parameters)?; - let likelihood = match bound_error_models.as_ref() { - Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), + let likelihood = match bound_em.as_ref() { + Some(em) => Some(predictions.log_likelihood(em)?.exp()), None => None, }; Ok((predictions, likelihood)) } + + fn log_likelihood(&self, subject: &Subject, params: &[f64], + error_models: &AssayErrorModels) -> Result + { + let bound = crate::core::simulate::bind_error_models_inner(self, error_models)?; + let predictions = runtime_analytical_predictions(self, subject, params)?; + predictions.log_likelihood(&bound) + } + + fn kind() -> ModelKind { ModelKind::Analytical } } impl NativeSdeModel { @@ -2290,7 +2042,7 @@ fn runtime_sde_log_likelihood( if let Some(cache) = &model.cache { let key = ( subject.hash(), - crate::simulator::equation::parameters_hash(support_point), + crate::core::simulate::parameters_hash(support_point), error_models.hash(), ); if let Some(cached) = cache.get(&key) { @@ -2307,190 +2059,64 @@ fn runtime_sde_log_likelihood( } } -impl crate::simulator::equation::Cache for NativeSdeModel { - fn with_cache_capacity(mut self, size: u64) -> Self { - self.cache = Some(SdeLikelihoodCache::new(size)); - self - } - fn enable_cache(mut self) -> Self { - self.cache = Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE)); - self - } +impl Solver for NativeSdeModel { + type State = Vec>; - fn clear_cache(&self) { - if let Some(cache) = &self.cache { - cache.invalidate_all(); - } + fn initial_state(&self, _params: &[f64], _covariates: &Covariates, _occasion_index: usize) -> Vec> { + vec![DVector::zeros(self.shared.info.state_len); self.nparticles] } - fn disable_cache(mut self) -> Self { - self.cache = None; - self - } + fn nparticles(&self) -> usize { self.nparticles } + fn is_batch(&self) -> bool { true } } -impl EquationTypes for NativeSdeModel { - type S = Vec>; - type P = Array2; +impl ModelInfo for NativeSdeModel { + fn nstates(&self) -> usize { self.shared.info.state_len } + fn ndrugs(&self) -> usize { self.shared.info.route_len } + fn nout(&self) -> usize { self.shared.info.output_len } + fn metadata(&self) -> Option<&ValidatedModelMetadata> { Some(self.shared.metadata()) } + fn lag(&self) -> &Lag { &(runtime_no_lag as Lag) } + fn fa(&self) -> &Fa { &(runtime_no_fa as Fa) } } -impl EquationPriv for NativeSdeModel { - fn lag(&self) -> &Lag { - &(runtime_no_lag as Lag) - } - - fn fa(&self) -> &Fa { - &(runtime_no_fa as Fa) - } - - fn get_nstates(&self) -> usize { - self.shared.info.state_len - } - - fn get_ndrugs(&self) -> usize { - self.shared.info.route_len - } - - fn get_nouteqs(&self) -> usize { - self.shared.info.output_len - } - - fn nparticles(&self) -> usize { - self.nparticles - } - - fn is_sde(&self) -> bool { - true - } - - fn metadata(&self) -> Option<&crate::ValidatedModelMetadata> { - Some(self.shared.metadata()) - } - - fn solve( - &self, - _state: &mut Self::S, - _support_point: &[f64], - _covariates: &Covariates, - _infusions: &[Infusion], - _start_time: f64, - _end_time: f64, - ) -> Result<(), PharmsolError> { - unimplemented!("solve is not used for runtime SDE models") - } - - fn process_observation( - &self, - _support_point: &[f64], - _observation: &Observation, - _error_models: Option<&AssayErrorModels>, - _time: f64, - _covariates: &Covariates, - _x: &mut Self::S, - _likelihood: &mut Vec, - _output: &mut Self::P, - ) -> Result<(), PharmsolError> { - unimplemented!("process_observation is not used for runtime SDE models") - } - - fn initial_state( - &self, - _support_point: &[f64], - _covariates: &Covariates, - _occasion_index: usize, - ) -> Self::S { - vec![DVector::zeros(self.shared.info.state_len); self.nparticles] +impl Caching for NativeSdeModel { + fn prediction_cache(&self) -> Option<&PredictionCache> { None } + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { None } + fn with_cache_capacity(mut self, size: u64) -> Self { + self.cache = Some(SdeLikelihoodCache::new(size)); self } + fn without_cache(mut self) -> Self { self.cache = None; self } + fn clear_cache(&self) { if let Some(c) = &self.cache { c.invalidate_all(); } } } -impl Equation for NativeSdeModel { - fn estimate_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - let log_lik = self.estimate_log_likelihood(subject, parameters, error_models)?; - Ok(log_lik.exp()) - } - - fn estimate_log_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - runtime_sde_log_likelihood(self, subject, parameters.as_slice(), &bound_error_models) - } +impl Simulate for NativeSdeModel { + type Predictions = Array2; - fn kind() -> EqnKind { - EqnKind::SDE - } - - fn assay_error_models(&self) -> AssayErrorModels { - AssayErrorModels::with_output_names( - self.info() - .outputs - .iter() - .map(|output| output.name.as_str()), - ) - } - - fn estimate_predictions( - &self, - subject: &Subject, - parameters: &Parameters, - ) -> Result { - NativeSdeModel::estimate_predictions(self, subject, parameters) - } - - fn estimate_predictions_dense( - &self, - subject: &Subject, - parameters: &[f64], - ) -> Result { - NativeSdeModel::estimate_predictions_dense(self, subject, parameters) - } - - fn estimate_log_likelihood_dense( - &self, - subject: &Subject, - parameters: &[f64], - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - runtime_sde_log_likelihood(self, subject, parameters, &bound_error_models) - } - - fn simulate_subject( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - self.simulate_subject_dense(subject, parameters.as_slice(), error_models) - } - - fn simulate_subject_dense( - &self, - subject: &Subject, - parameters: &[f64], - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - let bound_error_models = match error_models { - Some(error_models) => Some(self.bind_error_models(error_models)?), + fn simulate_subject(&self, subject: &Subject, parameters: &[f64], + error_models: Option<&AssayErrorModels>) + -> Result<(Self::Predictions, Option), PharmsolError> + { + let bound_em = match error_models { + Some(em) => Some(crate::core::simulate::bind_error_models_inner(self, em)?), None => None, }; - let predictions = NativeSdeModel::estimate_predictions_dense(self, subject, parameters)?; - let likelihood = match bound_error_models.as_ref() { - Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), + let likelihood = match bound_em.as_ref() { + Some(em) => Some(predictions.log_likelihood(em)?.exp()), None => None, }; Ok((predictions, likelihood)) } + + fn log_likelihood(&self, subject: &Subject, params: &[f64], + error_models: &AssayErrorModels) -> Result + { + let bound = crate::core::simulate::bind_error_models_inner(self, error_models)?; + runtime_sde_log_likelihood(self, subject, params, &bound) + } + + fn kind() -> ModelKind { ModelKind::Sde } } fn active_route_inputs(infusions: &[Infusion], time: f64, route_len: usize) -> Vec { @@ -2634,7 +2260,7 @@ fn apply_analytical_kernel( let route_inputs = V::from_vec(route_inputs.to_vec(), NalgebraContext); match kernel { AnalyticalKernel::OneCompartment => { - crate::simulator::equation::analytical::one_compartment( + crate::simulator::backends::analytical::one_compartment( &state, params, dt, @@ -2643,7 +2269,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::OneCompartmentCl => { - crate::simulator::equation::analytical::one_compartment_cl( + crate::simulator::backends::analytical::one_compartment_cl( &state, params, dt, @@ -2652,7 +2278,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::OneCompartmentClWithAbsorption => { - crate::simulator::equation::analytical::one_compartment_cl_with_absorption( + crate::simulator::backends::analytical::one_compartment_cl_with_absorption( &state, params, dt, @@ -2661,7 +2287,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::OneCompartmentWithAbsorption => { - crate::simulator::equation::analytical::one_compartment_with_absorption( + crate::simulator::backends::analytical::one_compartment_with_absorption( &state, params, dt, @@ -2670,7 +2296,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::TwoCompartments => { - crate::simulator::equation::analytical::two_compartments( + crate::simulator::backends::analytical::two_compartments( &state, params, dt, @@ -2679,7 +2305,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::TwoCompartmentsCl => { - crate::simulator::equation::analytical::two_compartments_cl( + crate::simulator::backends::analytical::two_compartments_cl( &state, params, dt, @@ -2688,7 +2314,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::TwoCompartmentsClWithAbsorption => { - crate::simulator::equation::analytical::two_compartments_cl_with_absorption( + crate::simulator::backends::analytical::two_compartments_cl_with_absorption( &state, params, dt, @@ -2697,7 +2323,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::TwoCompartmentsWithAbsorption => { - crate::simulator::equation::analytical::two_compartments_with_absorption( + crate::simulator::backends::analytical::two_compartments_with_absorption( &state, params, dt, @@ -2706,7 +2332,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::ThreeCompartments => { - crate::simulator::equation::analytical::three_compartments( + crate::simulator::backends::analytical::three_compartments( &state, params, dt, @@ -2715,7 +2341,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::ThreeCompartmentsCl => { - crate::simulator::equation::analytical::three_compartments_cl( + crate::simulator::backends::analytical::three_compartments_cl( &state, params, dt, @@ -2724,7 +2350,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::ThreeCompartmentsClWithAbsorption => { - crate::simulator::equation::analytical::three_compartments_cl_with_absorption( + crate::simulator::backends::analytical::three_compartments_cl_with_absorption( &state, params, dt, @@ -2733,7 +2359,7 @@ fn apply_analytical_kernel( ) } AnalyticalKernel::ThreeCompartmentsWithAbsorption => { - crate::simulator::equation::analytical::three_compartments_with_absorption( + crate::simulator::backends::analytical::three_compartments_with_absorption( &state, params, dt, @@ -3318,7 +2944,7 @@ mod tests { let expected = SubjectPredictions::default(); let key = ( subject.hash(), - crate::simulator::equation::parameters_hash(parameters.as_slice()), + crate::core::simulate::parameters_hash(parameters.as_slice()), ); model diff --git a/src/lib.rs b/src/lib.rs index 6e36c982..97db962f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,6 +113,8 @@ mod build_support; pub mod data; #[cfg(feature = "dsl-core")] pub mod dsl; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub mod core; pub mod error; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub mod nca; @@ -143,16 +145,23 @@ pub use crate::optimize::effect::get_e2; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::optimize::parameters::ParameterOptimizer; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub use crate::simulator::equation::analytical::*; +pub use crate::simulator::backends::analytical::*; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::core::{Caching, ModelInfo, Simulate, Solver}; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub use crate::simulator::equation::metadata; +pub use crate::core::metadata; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub use crate::simulator::equation::{ +pub use crate::simulator::backends::{ self, ode::{ExplicitRkTableau, OdeSolver, SdirkTableau}, - Analytical, AnalyticalKernel, Cache, Equation, ModelKind, ModelMetadata, ModelMetadataError, - NameDomain, Predictions, RouteInputPolicy, RouteKind, State, ValidatedModelMetadata, ODE, SDE, + Analytical, AnalyticalKernel, ModelKind, ODE, SDE, }; +pub use crate::core::metadata::{ + ModelMetadata, ModelMetadataError, NameDomain, RouteInputPolicy, RouteKind, + ValidatedModelMetadata, +}; +pub use crate::core::{Predictions, State}; pub use error::PharmsolError; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use nalgebra::dmatrix; @@ -220,6 +229,9 @@ pub mod prelude { Covariates, Data, Event, Interpolation, Occasion, Subject, }; + // Core traits + pub use crate::core::{Caching, ModelInfo, Simulate, Solver}; + // NCA extension traits (provides .nca(), .nca_all(), etc. on data types) pub use crate::nca::NCA; pub use crate::nca::{MetricsError, ObservationMetrics}; @@ -233,8 +245,6 @@ pub mod prelude { pub mod simulator { pub use crate::simulator::{ cache::{self, PredictionCache, SdeLikelihoodCache, DEFAULT_CACHE_SIZE}, - equation, - equation::Equation, likelihood::{ log_likelihood_batch, log_likelihood_matrix, log_likelihood_subject, log_psi, psi, PopulationPredictions, Prediction, SubjectPredictions, @@ -245,16 +255,15 @@ pub mod prelude { // Direct simulator re-exports for convenience pub use crate::simulator::{ cache::{PredictionCache, SdeLikelihoodCache, DEFAULT_CACHE_SIZE}, - equation::{ + backends::{ self, ode::{ExplicitRkTableau, OdeSolver, SdirkTableau}, - Equation, }, likelihood::{Prediction, SubjectPredictions}, }; // Analytical model functions - pub use crate::simulator::equation::analytical::{ + pub use crate::simulator::backends::analytical::{ one_compartment, one_compartment_cl, one_compartment_cl_with_absorption, one_compartment_with_absorption, three_compartments, three_compartments_with_absorption, two_compartments, two_compartments_cl, two_compartments_cl_with_absorption, @@ -263,7 +272,7 @@ pub mod prelude { /// Models submodule for organized access to analytical model functions pub mod models { - pub use crate::simulator::equation::analytical::{ + pub use crate::simulator::backends::analytical::{ one_compartment, one_compartment_cl, one_compartment_cl_with_absorption, one_compartment_with_absorption, three_compartments, three_compartments_with_absorption, two_compartments, two_compartments_cl, @@ -318,14 +327,14 @@ macro_rules! fetch_cov { #[macro_export] macro_rules! lag { ($($k:expr => $v:expr),* $(,)?) => {{ - core::convert::From::from([$(($k, $v),)*]) + [$((($k), ($v)),)*].into() }}; } #[macro_export] macro_rules! fa { ($($k:expr => $v:expr),* $(,)?) => {{ - core::convert::From::from([$(($k, $v),)*]) + [$((($k), ($v)),)*].into() }}; } diff --git a/src/optimize/parameters.rs b/src/optimize/parameters.rs index fcf21a14..fa8a6436 100644 --- a/src/optimize/parameters.rs +++ b/src/optimize/parameters.rs @@ -13,17 +13,18 @@ use argmin::{ use ndarray::{Array1, Axis}; -use crate::{prelude::simulator::log_likelihood_matrix, AssayErrorModels, Data, Equation}; +use crate::core::Simulate; +use crate::{prelude::simulator::log_likelihood_matrix, AssayErrorModels, Data}; /// Optimizer that refines a single parameter vector against observed data. -pub struct ParameterOptimizer<'a, E: Equation> { +pub struct ParameterOptimizer<'a, E: Simulate> { equation: &'a E, data: &'a Data, sig: &'a AssayErrorModels, pyl: &'a Array1, } -impl CostFunction for ParameterOptimizer<'_, E> { +impl CostFunction for ParameterOptimizer<'_, E> { type Param = Vec; type Output = f64; @@ -52,7 +53,7 @@ impl CostFunction for ParameterOptimizer<'_, E> { } } -impl<'a, E: Equation> ParameterOptimizer<'a, E> { +impl<'a, E: Simulate> ParameterOptimizer<'a, E> { /// Create a new optimizer. /// /// * `equation` — the model to evaluate. @@ -74,7 +75,6 @@ impl<'a, E: Equation> ParameterOptimizer<'a, E> { } /// Optimize the parameters to minimize the negative log-likelihood against the data. - pub fn optimize_point(self, parameters: Array1) -> Result, Error> { let simplex = create_initial_simplex(¶meters.to_vec()); let solver: NelderMead, f64> = NelderMead::new(simplex).with_sd_tolerance(1e-2)?; diff --git a/src/parameter_order.rs b/src/parameter_order.rs index 6c33f323..8c095720 100644 --- a/src/parameter_order.rs +++ b/src/parameter_order.rs @@ -7,7 +7,7 @@ use std::fmt; #[cfg(feature = "dsl-core")] use crate::dsl::NativeModelInfo; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -use crate::simulator::equation::ValidatedModelMetadata; +use crate::core::metadata::ValidatedModelMetadata; #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct ParameterOrderPlan { diff --git a/src/parameters.rs b/src/parameters.rs index e91c0060..0baf964a 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -312,7 +312,7 @@ mod tests { use ndarray::array; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] - use crate::{fa, lag, metadata, Equation, ModelKind, Subject, SubjectBuilderExt, ODE}; + use crate::{core::Simulate, fa, lag, metadata, ModelKind, Subject, SubjectBuilderExt, ODE}; #[cfg(feature = "dsl-jit")] use crate::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/backends/analytical/mod.rs similarity index 75% rename from src/simulator/equation/analytical/mod.rs rename to src/simulator/backends/analytical/mod.rs index 00a7481f..56803878 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/backends/analytical/mod.rs @@ -6,6 +6,7 @@ pub mod two_compartment_cl_models; pub mod two_compartment_models; use diffsol::{NalgebraContext, Vector, VectorHost}; +use crate::core::ModelInfo; pub use one_compartment_cl_models::*; pub use one_compartment_models::*; use pharmsol_dsl::ModelKind; @@ -15,18 +16,16 @@ pub use three_compartment_models::*; pub use two_compartment_cl_models::*; pub use two_compartment_models::*; -use super::parameters_hash; +use crate::simulator::backends::parameters_hash; -use super::{ - EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, - ValidatedModelMetadata, -}; +use crate::core::metadata::{ModelMetadata, ModelMetadataError, ValidatedModelMetadata}; use crate::data::error_model::AssayErrorModels; use crate::simulator::cache::{ - BoundErrorModelCache, PredictionCache, DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, DEFAULT_CACHE_SIZE, + BoundErrorModelCache, PredictionCache, DEFAULT_CACHE_SIZE, }; +use crate::simulator::likelihood::Prediction; use crate::PharmsolError; -use crate::{data::Covariates, simulator::*, Observation, Parameters, Subject}; +use crate::{data::Covariates, simulator::*, Observation, Subject}; #[derive(Clone, Debug, PartialEq, Eq, Error)] pub enum AnalyticalMetadataError { @@ -46,16 +45,13 @@ pub enum AnalyticalMetadataError { /// equations rather than numerical integration. #[derive(Clone, Debug)] pub struct Analytical { + core: crate::core::ModelCore, eq: AnalyticalEq, seq_eq: SecEq, lag: Lag, fa: Fa, init: Init, out: Out, - neqs: Neqs, - metadata: Option, - cache: Option, - error_model_cache: Option, } #[inline(always)] @@ -90,89 +86,57 @@ pub(crate) fn wrap_pmetrics_analytical( } impl Analytical { - /// Create a new Analytical equation model with default Neqs (all sizes = 5). - /// - /// Use builder methods to configure dimensions: - /// ```ignore - /// Analytical::new(eq, seq_eq, lag, fa, init, out) - /// .with_nstates(2) - /// .with_ndrugs(1) - /// .with_nout(1) - /// ``` pub fn new(eq: AnalyticalEq, seq_eq: SecEq, lag: Lag, fa: Fa, init: Init, out: Out) -> Self { Self { + core: crate::core::ModelCore::new(Some(PredictionCache::new(DEFAULT_CACHE_SIZE))), eq, seq_eq, lag, fa, init, out, - neqs: Neqs::default(), - metadata: None, - cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), - error_model_cache: Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )), } } - /// Set the number of state variables. pub fn with_nstates(mut self, nstates: usize) -> Self { - self.neqs.nstates = nstates; - self.invalidate_metadata(); + self.core = self.core.with_nstates(nstates); self } - /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { - self.neqs.ndrugs = ndrugs; - self.invalidate_metadata(); + self.core = self.core.with_ndrugs(ndrugs); self } - /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { - self.neqs.nout = nout; - self.invalidate_metadata(); + self.core = self.core.with_nout(nout); self } - /// Attach validated handwritten-model metadata to this analytical model. pub fn with_metadata( mut self, metadata: ModelMetadata, ) -> Result { - let metadata = metadata.validate_for(ModelKind::Analytical)?; - validate_metadata_dimensions(&metadata, &self.neqs)?; - self.metadata = Some(metadata); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); + let validated = metadata.validate_for(ModelKind::Analytical).map_err(AnalyticalMetadataError::Validation)?; + validate_metadata_dimensions(&validated, &self.core.dims())?; + self.core.set_metadata(validated); Ok(self) } - /// Access the validated metadata attached to this analytical model, if any. pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { - self.metadata.as_ref() + self.core.metadata() } pub fn parameter_index(&self, name: &str) -> Option { - self.metadata()?.parameter_index(name) + self.core.metadata()?.parameter_index(name) } pub fn covariate_index(&self, name: &str) -> Option { - self.metadata()?.covariate_index(name) + self.core.metadata()?.covariate_index(name) } pub fn state_index(&self, name: &str) -> Option { - self.metadata()?.state_index(name) - } - - fn invalidate_metadata(&mut self) { - self.metadata = None; - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); + self.core.metadata()?.state_index(name) } } @@ -207,225 +171,11 @@ fn validate_metadata_dimensions( Ok(()) } -impl super::Cache for Analytical { - fn with_cache_capacity(mut self, size: u64) -> Self { - self.cache = Some(PredictionCache::new(size)); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self - } - - fn enable_cache(mut self) -> Self { - self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE)); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self - } - - fn clear_cache(&self) { - if let Some(cache) = &self.cache { - cache.invalidate_all(); - } - if let Some(cache) = &self.error_model_cache { - cache.invalidate_all(); - } - } - - fn disable_cache(mut self) -> Self { - self.cache = None; - self.error_model_cache = None; - self - } -} - -impl EquationTypes for Analytical { - type S = V; - type P = SubjectPredictions; -} - -impl EquationPriv for Analytical { - // #[inline(always)] - // fn get_init(&self) -> &Init { - // &self.init - // } - - // #[inline(always)] - // fn get_out(&self) -> &Out { - // &self.out - // } - - // #[inline(always)] - // fn get_lag(&self, parameters: &[f64]) -> Option> { - // Some((self.lag)(&V::from_vec(parameters.to_owned()))) - // } - - // #[inline(always)] - // fn get_fa(&self, parameters: &[f64]) -> Option> { - // Some((self.fa)(&V::from_vec(parameters.to_owned()))) - // } - - #[inline(always)] - fn lag(&self) -> &Lag { - &self.lag - } - - #[inline(always)] - fn fa(&self) -> &Fa { - &self.fa - } - - #[inline(always)] - fn get_nstates(&self) -> usize { - self.neqs.nstates - } - - #[inline(always)] - fn get_ndrugs(&self) -> usize { - self.neqs.ndrugs - } - - #[inline(always)] - fn get_nouteqs(&self) -> usize { - self.neqs.nout - } - - fn metadata(&self) -> Option<&ValidatedModelMetadata> { - self.metadata.as_ref() - } - - #[inline(always)] - fn solve( - &self, - x: &mut Self::S, - parameters: &[f64], - covariates: &Covariates, - infusions: &[Infusion], - ti: f64, - tf: f64, - ) -> Result<(), PharmsolError> { - if ti == tf { - return Ok(()); - } - - // 1) Build and sort event times - let mut ts = Vec::new(); - ts.push(ti); - ts.push(tf); - for inf in infusions { - let t0 = inf.time(); - let t1 = t0 + inf.duration(); - if t0 > ti && t0 < tf { - ts.push(t0) - } - if t1 > ti && t1 < tf { - ts.push(t1) - } - } - ts.sort_by(|a, b| a.partial_cmp(b).unwrap()); - ts.dedup_by(|a, b| (*a - *b).abs() < 1e-12); - - // 2) March over each sub-interval - let mut current_t = ts[0]; - let mut parameters_v = V::from_vec(parameters.to_vec(), NalgebraContext); - let mut rateiv = V::zeros(self.get_ndrugs(), NalgebraContext); - - for &next_t in &ts[1..] { - // prepare parameters and infusion rate for [current_t .. next_t] - rateiv.fill(0.0); - for inf in infusions { - let s = inf.time(); - let e = s + inf.duration(); - if current_t >= s && next_t <= e { - let input = - inf.input_index() - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: inf.input().to_string(), - })?; - - if input >= self.get_ndrugs() { - return Err(PharmsolError::InputOutOfRange { - input, - ndrugs: self.get_ndrugs(), - }); - } - rateiv[input] += inf.amount() / inf.duration(); - } - } - - // advance the parameters to next_t - (self.seq_eq)(&mut parameters_v, next_t, covariates); - - // advance state by dt - let dt = next_t - current_t; - *x = (self.eq)(x, ¶meters_v, dt, &rateiv, covariates); - - current_t = next_t; - } - - Ok(()) - } - - #[inline(always)] - fn process_observation( - &self, - parameters: &[f64], - observation: &Observation, - error_models: Option<&AssayErrorModels>, - _time: f64, - covariates: &Covariates, - x: &mut Self::S, - likelihood: &mut Vec, - output: &mut Self::P, - ) -> Result<(), PharmsolError> { - let mut y = V::zeros(self.get_nouteqs(), NalgebraContext); - let out = &self.out; - (out)( - x, - &V::from_vec(parameters.to_vec(), NalgebraContext), - observation.time(), - covariates, - &mut y, - ); - let outeq = observation - .outeq_index() - .ok_or_else(|| PharmsolError::UnknownOutputLabel { - label: observation.outeq().to_string(), - })?; - let pred = y[outeq]; - let pred = observation.to_prediction(pred, x.as_slice().to_vec()); - if let Some(error_models) = error_models { - likelihood.push(pred.log_likelihood(error_models)?.exp()); - } - output.add_prediction(pred); - Ok(()) - } - #[inline(always)] - fn initial_state( - &self, - parameters: &[f64], - covariates: &Covariates, - occasion_index: usize, - ) -> V { - let init = &self.init; - let mut x = V::zeros(self.get_nstates(), NalgebraContext); - if occasion_index == 0 { - (init)( - &V::from_vec(parameters.to_vec(), NalgebraContext), - 0.0, - covariates, - &mut x, - ); - } - x - } -} - #[allow(clippy::items_after_test_module)] #[cfg(test)] pub(crate) mod tests { use super::*; + use crate::core::Simulate; use crate::SubjectBuilderExt; use approx::assert_relative_eq; use diffsol::Vector; @@ -577,7 +327,7 @@ pub(crate) mod tests { fn handwritten_analytical_metadata_exposes_name_lookup() { let analytical = simple_analytical() .with_metadata( - super::super::metadata::new("one_cmt_analytical") + crate::core::metadata::new("one_cmt_analytical") .parameters(["ke", "v"]) .covariates([super::super::Covariate::continuous("wt")]) .states(["central"]) @@ -618,7 +368,7 @@ pub(crate) mod tests { .with_ndrugs(1) .with_nout(1) .with_metadata( - super::super::metadata::new("numeric_alias_analytical") + crate::core::metadata::new("numeric_alias_analytical") .states(["central"]) .outputs(["outeq_1"]) .route(super::super::Route::infusion("input_1").to_state("central")), @@ -660,7 +410,7 @@ pub(crate) mod tests { fn handwritten_analytical_rejects_dimension_mismatches() { let error = simple_analytical() .with_metadata( - super::super::metadata::new("wrong_outputs") + crate::core::metadata::new("wrong_outputs") .parameters(["ke"]) .states(["central"]) .outputs(["cp", "auc"]) @@ -681,7 +431,7 @@ pub(crate) mod tests { fn handwritten_analytical_rejects_particles_metadata() { let error = simple_analytical() .with_metadata( - super::super::metadata::new("invalid_particles") + crate::core::metadata::new("invalid_particles") .parameters(["ke"]) .states(["central"]) .outputs(["cp"]) @@ -716,7 +466,7 @@ pub(crate) mod tests { .with_ndrugs(1) .with_nout(1) .with_metadata( - super::super::metadata::new("one_cmt_abs") + crate::core::metadata::new("one_cmt_abs") .parameters(["ka", "ke", "v"]) .states(["gut", "central"]) .outputs(["cp"]) @@ -750,7 +500,7 @@ pub(crate) mod tests { fn changing_dimensions_after_metadata_clears_analytical_metadata() { let analytical = simple_analytical() .with_metadata( - super::super::metadata::new("one_cmt_analytical") + crate::core::metadata::new("one_cmt_analytical") .states(["central"]) .outputs(["cp"]) .route(super::super::Route::infusion("iv").to_state("central")), @@ -885,67 +635,208 @@ pub(crate) mod tests { ); } } -impl Equation for Analytical { - fn bound_error_model_cache(&self) -> Option<&BoundErrorModelCache> { - self.error_model_cache.as_ref() + +// ── New core traits ───────────────────────────────────────────────────────── + +impl crate::core::Solver for Analytical { + type State = V; + + fn solve( + &self, + x: &mut Self::State, + parameters: &[f64], + covariates: &Covariates, + infusions: &[Infusion], + ti: f64, + tf: f64, + ) -> Result<(), PharmsolError> { + if ti == tf { + return Ok(()); + } + + let mut ts = Vec::new(); + ts.push(ti); + ts.push(tf); + for inf in infusions { + let t0 = inf.time(); + let t1 = t0 + inf.duration(); + if t0 > ti && t0 < tf { + ts.push(t0) + } + if t1 > ti && t1 < tf { + ts.push(t1) + } + } + ts.sort_by(|a, b| a.partial_cmp(b).unwrap()); + ts.dedup_by(|a, b| (*a - *b).abs() < 1e-12); + + let mut current_t = ts[0]; + let mut parameters_v = V::from_vec(parameters.to_vec(), NalgebraContext); + let mut rateiv = V::zeros(self.ndrugs(), NalgebraContext); + + for &next_t in &ts[1..] { + rateiv.fill(0.0); + for inf in infusions { + let s = inf.time(); + let e = s + inf.duration(); + if current_t >= s && next_t <= e { + let input = inf.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: inf.input().to_string(), + } + })?; + if input >= self.ndrugs() { + return Err(PharmsolError::InputOutOfRange { + input, + ndrugs: self.ndrugs(), + }); + } + rateiv[input] += inf.amount() / inf.duration(); + } + } + + (self.seq_eq)(&mut parameters_v, next_t, covariates); + let dt = next_t - current_t; + *x = (self.eq)(x, ¶meters_v, dt, &rateiv, covariates); + current_t = next_t; + } + + Ok(()) } - fn estimate_likelihood( + fn process_observation( &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - _estimate_likelihood(self, subject, parameters.as_slice(), error_models) + x: &Self::State, + parameters: &[f64], + observation: &Observation, + error_models: Option<&AssayErrorModels>, + covariates: &Covariates, + ) -> Result<(Prediction, Option), PharmsolError> { + let mut y = V::zeros(self.nout(), NalgebraContext); + (self.out)( + x, + &V::from_vec(parameters.to_vec(), NalgebraContext), + observation.time(), + covariates, + &mut y, + ); + let outeq = observation.outeq_index().ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + } + })?; + let pred = observation.to_prediction(y[outeq], x.as_slice().to_vec()); + let lik = error_models + .map(|em| pred.log_likelihood(em).map(f64::exp)) + .transpose()?; + Ok((pred, lik)) } - fn estimate_log_likelihood( + fn initial_state( &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - let ypred = _subject_predictions(self, subject, parameters.as_slice())?; - ypred.log_likelihood(&bound_error_models) + parameters: &[f64], + covariates: &Covariates, + occasion_index: usize, + ) -> V { + let mut x = V::zeros(self.nstates(), NalgebraContext); + if occasion_index == 0 { + (self.init)( + &V::from_vec(parameters.to_vec(), NalgebraContext), + 0.0, + covariates, + &mut x, + ); + } + x + } +} + +impl crate::core::ModelInfo for Analytical { + fn nstates(&self) -> usize { + self.core.nstates() + } + + fn ndrugs(&self) -> usize { + self.core.ndrugs() + } + + fn nout(&self) -> usize { + self.core.nout() + } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.core.metadata() + } + + fn lag(&self) -> &Lag { + &self.lag } - fn kind() -> EqnKind { - EqnKind::Analytical + fn fa(&self) -> &Fa { + &self.fa } } -#[inline(always)] -fn _subject_predictions( - analytical: &Analytical, - subject: &Subject, - parameters: &[f64], -) -> Result { - if let Some(cache) = &analytical.cache { - let key = (subject.hash(), parameters_hash(parameters)); - if let Some(cached) = cache.get(&key) { - return Ok(cached); +impl crate::core::Caching for Analytical { + fn prediction_cache(&self) -> Option<&PredictionCache> { + self.core.cache() + } + + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { + self.core.error_model_cache() + } + + fn with_cache_capacity(mut self, size: u64) -> Self { + self.core = self.core.with_cache_capacity(PredictionCache::new(size)); + self + } + + fn without_cache(mut self) -> Self { + self.core = self.core.without_cache(); + self + } + + fn clear_cache(&self) { + self.core.clear_cache(); + if let Some(cache) = self.core.cache() { + cache.invalidate_all(); + } + } +} + +impl crate::core::Simulate for Analytical { + type Predictions = SubjectPredictions; + + fn simulate_subject( + &self, + subject: &Subject, + params: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::Predictions, Option), PharmsolError> { + if error_models.is_none() { + if let Some(cache) = self.core.cache() { + let key = (subject.hash(), parameters_hash(params)); + if let Some(cached) = cache.get(&key) { + return Ok((cached, None)); + } + } + } + + let result = crate::core::standard_event_loop::( + self, subject, params, error_models, + )?; + + if error_models.is_none() { + if let Some(cache) = self.core.cache() { + let key = (subject.hash(), parameters_hash(params)); + cache.insert(key, result.0.clone()); + } } - let result = analytical - .simulate_subject_dense(subject, parameters, None)? - .0; - cache.insert(key, result.clone()); Ok(result) - } else { - Ok(analytical - .simulate_subject_dense(subject, parameters, None)? - .0) } -} -fn _estimate_likelihood( - ode: &Analytical, - subject: &Subject, - parameters: &[f64], - error_models: &AssayErrorModels, -) -> Result { - let bound_error_models = ode.bind_error_models(error_models)?; - let ypred = _subject_predictions(ode, subject, parameters)?; - Ok(ypred.log_likelihood(&bound_error_models)?.exp()) + fn kind() -> pharmsol_dsl::ModelKind { + pharmsol_dsl::ModelKind::Analytical + } } diff --git a/src/simulator/equation/analytical/one_compartment_cl_models.rs b/src/simulator/backends/analytical/one_compartment_cl_models.rs similarity index 95% rename from src/simulator/equation/analytical/one_compartment_cl_models.rs rename to src/simulator/backends/analytical/one_compartment_cl_models.rs index 218713ef..00742ce4 100755 --- a/src/simulator/equation/analytical/one_compartment_cl_models.rs +++ b/src/simulator/backends/analytical/one_compartment_cl_models.rs @@ -57,6 +57,7 @@ pub fn pm_one_compartment_cl_with_absorption( #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; + use crate::core::Simulate; use super::{one_compartment_cl, one_compartment_cl_with_absorption}; use crate::*; use approx::assert_relative_eq; @@ -66,7 +67,7 @@ mod tests { let infusion_dosing = SubjectInfo::InfusionDosing; let subject = infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, cl, v); let ke = cl / v; @@ -84,7 +85,7 @@ mod tests { .with_nstates(1) .with_nout(1); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( one_compartment_cl, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -118,7 +119,7 @@ mod tests { let oral_infusion_dosing = SubjectInfo::OralInfusionDosage; let subject = oral_infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ka, cl, v); let ke = cl / v; @@ -137,7 +138,7 @@ mod tests { .with_nstates(2) .with_nout(1); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( one_compartment_cl_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, diff --git a/src/simulator/equation/analytical/one_compartment_models.rs b/src/simulator/backends/analytical/one_compartment_models.rs similarity index 94% rename from src/simulator/equation/analytical/one_compartment_models.rs rename to src/simulator/backends/analytical/one_compartment_models.rs index 127483e5..b8581b4a 100644 --- a/src/simulator/equation/analytical/one_compartment_models.rs +++ b/src/simulator/backends/analytical/one_compartment_models.rs @@ -50,6 +50,7 @@ pub fn pm_one_compartment_with_absorption(x: &V, p: &V, t: T, rateiv: &V, cov: & #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; + use crate::core::Simulate; use super::{one_compartment, one_compartment_with_absorption}; use crate::*; use approx::assert_relative_eq; @@ -59,7 +60,7 @@ mod tests { let infusion_dosing = SubjectInfo::InfusionDosing; let subject = infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, _v); @@ -77,7 +78,7 @@ mod tests { .with_ndrugs(1) .with_nout(1); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -115,7 +116,7 @@ mod tests { let oral_infusion_dosing = SubjectInfo::OralInfusionDosage; let subject = oral_infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ka, ke, _v); @@ -134,7 +135,7 @@ mod tests { .with_ndrugs(2) .with_nout(1); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( one_compartment_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, diff --git a/src/simulator/equation/analytical/three_compartment_cl_models.rs b/src/simulator/backends/analytical/three_compartment_cl_models.rs similarity index 96% rename from src/simulator/equation/analytical/three_compartment_cl_models.rs rename to src/simulator/backends/analytical/three_compartment_cl_models.rs index 7b069ea4..49be1c2c 100644 --- a/src/simulator/equation/analytical/three_compartment_cl_models.rs +++ b/src/simulator/backends/analytical/three_compartment_cl_models.rs @@ -79,6 +79,7 @@ pub fn pm_three_compartments_cl_with_absorption( #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; + use crate::core::Simulate; use super::{three_compartments_cl, three_compartments_cl_with_absorption}; use crate::*; use approx::assert_relative_eq; @@ -91,7 +92,7 @@ mod tests { // CL=0.1, Q2=3.0, Q3=2.0, Vc=1.0, V2=3.0, V3=4.0 // => k10=0.1, k12=3.0, k13=2.0, k21=1.0, k31=0.5 - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, cl, q2, q3, vc, v2, v3); let k10 = cl / vc; @@ -116,7 +117,7 @@ mod tests { .with_nout(1) .with_ndrugs(3); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( three_compartments_cl, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -160,7 +161,7 @@ mod tests { // ka=1.0, CL=0.1, Q2=3.0, Q3=2.0, Vc=1.0, V2=3.0, V3=4.0 // => k10=0.1, k12=3.0, k13=2.0, k21=1.0, k31=0.5 - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ka, cl, q2, q3, vc, v2, v3); let k10 = cl / vc; @@ -187,7 +188,7 @@ mod tests { }, ); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( three_compartments_cl_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, diff --git a/src/simulator/equation/analytical/three_compartment_models.rs b/src/simulator/backends/analytical/three_compartment_models.rs similarity index 98% rename from src/simulator/equation/analytical/three_compartment_models.rs rename to src/simulator/backends/analytical/three_compartment_models.rs index f5d96daa..d4b65914 100644 --- a/src/simulator/equation/analytical/three_compartment_models.rs +++ b/src/simulator/backends/analytical/three_compartment_models.rs @@ -252,6 +252,7 @@ pub fn pm_three_compartments_with_absorption( #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; + use crate::core::Simulate; use super::{three_compartments, three_compartments_with_absorption}; use crate::*; use approx::assert_relative_eq; @@ -261,7 +262,7 @@ mod tests { let infusion_dosing = SubjectInfo::InfusionDosing; let subject = infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, k10, k12, k13, k21, k31, _v); @@ -281,7 +282,7 @@ mod tests { .with_ndrugs(1) .with_nout(1); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( three_compartments, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -325,7 +326,7 @@ mod tests { let oral_infusion_dosing = SubjectInfo::OralInfusionDosage; let subject = oral_infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ka, k10, k12, k13, k21, k31, _v); @@ -350,7 +351,7 @@ mod tests { .with_ndrugs(2) .with_nout(1); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( three_compartments_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, diff --git a/src/simulator/equation/analytical/two_compartment_cl_models.rs b/src/simulator/backends/analytical/two_compartment_cl_models.rs similarity index 95% rename from src/simulator/equation/analytical/two_compartment_cl_models.rs rename to src/simulator/backends/analytical/two_compartment_cl_models.rs index 5ea8ed29..895f694c 100644 --- a/src/simulator/equation/analytical/two_compartment_cl_models.rs +++ b/src/simulator/backends/analytical/two_compartment_cl_models.rs @@ -65,6 +65,7 @@ pub fn pm_two_compartments_cl_with_absorption( #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; + use crate::core::Simulate; use super::{two_compartments_cl, two_compartments_cl_with_absorption}; use crate::*; use approx::assert_relative_eq; @@ -74,7 +75,7 @@ mod tests { let infusion_dosing = SubjectInfo::InfusionDosing; let subject = infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, cl, q, vc, vp); @@ -97,7 +98,7 @@ mod tests { .with_nout(1) .with_ndrugs(2); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( two_compartments_cl, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -132,7 +133,7 @@ mod tests { let oral_infusion_dosing = SubjectInfo::OralInfusionDosage; let subject = oral_infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ka, cl, q, vc, vp); @@ -156,7 +157,7 @@ mod tests { .with_nout(1) .with_ndrugs(3); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( two_compartments_cl_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, diff --git a/src/simulator/equation/analytical/two_compartment_models.rs b/src/simulator/backends/analytical/two_compartment_models.rs similarity index 96% rename from src/simulator/equation/analytical/two_compartment_models.rs rename to src/simulator/backends/analytical/two_compartment_models.rs index 490b56ed..f3d5cc27 100644 --- a/src/simulator/equation/analytical/two_compartment_models.rs +++ b/src/simulator/backends/analytical/two_compartment_models.rs @@ -118,6 +118,7 @@ pub fn pm_two_compartments_with_absorption(x: &V, p: &V, t: T, rateiv: &V, cov: #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; + use crate::core::Simulate; use super::{two_compartments, two_compartments_with_absorption}; use crate::*; use approx::assert_relative_eq; @@ -127,7 +128,7 @@ mod tests { let infusion_dosing = SubjectInfo::InfusionDosing; let subject = infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, kcp, kpc, _v); @@ -146,7 +147,7 @@ mod tests { .with_ndrugs(1) .with_nout(1); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( two_compartments, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -184,7 +185,7 @@ mod tests { let oral_infusion_dosing = SubjectInfo::OralInfusionDosage; let subject = oral_infusion_dosing.get_subject(); - let ode = equation::ODE::new( + let ode = crate::simulator::backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, ka, kcp, kpc, _v); @@ -204,7 +205,7 @@ mod tests { .with_ndrugs(2) .with_nout(1); - let analytical = equation::Analytical::new( + let analytical = crate::simulator::backends::Analytical::new( two_compartments_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, diff --git a/src/simulator/backends/mod.rs b/src/simulator/backends/mod.rs new file mode 100644 index 00000000..18c77c1b --- /dev/null +++ b/src/simulator/backends/mod.rs @@ -0,0 +1,29 @@ +//! Solver backends (ODE, Analytical, SDE) and their shared infrastructure. +//! +//! For the user-facing simulation API, see [`crate::core::Simulate`]. +//! For the solver interface that backend authors implement, see [`crate::core::Solver`]. + +pub mod analytical; +pub mod ode; +pub mod sde; +pub use analytical::*; +pub use ode::*; +pub use pharmsol_dsl::{AnalyticalKernel, ModelKind}; +pub use sde::*; + +// Re-export metadata types for convenience (canonical home is crate::core::metadata) +pub use crate::core::metadata::{ + Covariate, NameDomain, Route, RouteInputPolicy, RouteKind, ValidatedModelMetadata, +}; + +/// Hash parameter vectors to a u64 for cache key generation. +#[inline(always)] +pub(crate) fn parameters_hash(parameters: &[f64]) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = ahash::AHasher::default(); + for &value in parameters { + let bits = if value == 0.0 { 0u64 } else { value.to_bits() }; + bits.hash(&mut hasher); + } + hasher.finish() +} diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/backends/ode/closure.rs similarity index 100% rename from src/simulator/equation/ode/closure.rs rename to src/simulator/backends/ode/closure.rs diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/backends/ode/mod.rs similarity index 74% rename from src/simulator/equation/ode/mod.rs rename to src/simulator/backends/ode/mod.rs index ecfc7872..9bb5620f 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/backends/ode/mod.rs @@ -11,18 +11,19 @@ pub(crate) mod closure_helpers { } use crate::{ - data::{Covariates, Infusion}, + core::{ModelInfo, Solver}, + data::Covariates, error_model::AssayErrorModels, prelude::simulator::SubjectPredictions, simulator::{DiffEq, Fa, Init, Lag, Neqs, Out, M, V}, - Event, Observation, Parameters, PharmsolError, Subject, + Event, PharmsolError, Subject, }; -use super::parameters_hash; +use crate::simulator::backends::parameters_hash; use crate::simulator::cache::{ - BoundErrorModelCache, PredictionCache, DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, DEFAULT_CACHE_SIZE, + BoundErrorModelCache, PredictionCache, DEFAULT_CACHE_SIZE, }; -use crate::simulator::equation::Predictions; +use crate::core::Predictions; use closure::PMProblem; use diffsol::{ error::OdeSolverError, ode_solver::method::OdeSolverMethod, NalgebraContext, OdeBuilder, @@ -32,10 +33,8 @@ use nalgebra::DVector; use pharmsol_dsl::ModelKind; use thiserror::Error; -use super::{ - EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, State, - ValidatedModelMetadata, -}; +use crate::core::metadata::{ModelMetadata, ModelMetadataError, ValidatedModelMetadata}; +use crate::core::State; const RTOL: f64 = 1e-4; const ATOL: f64 = 1e-4; @@ -97,58 +96,47 @@ pub enum OdeMetadataError { #[derive(Clone, Debug)] pub struct ODE { + core: crate::core::ModelCore, diffeq: DiffEq, lag: Lag, fa: Fa, init: Init, out: Out, - neqs: Neqs, solver: OdeSolver, rtol: f64, atol: f64, - metadata: Option, - cache: Option, - error_model_cache: Option, } impl ODE { pub fn new(diffeq: DiffEq, lag: Lag, fa: Fa, init: Init, out: Out) -> Self { Self { + core: crate::core::ModelCore::new(Some(PredictionCache::new(DEFAULT_CACHE_SIZE))), diffeq, lag, fa, init, out, - neqs: Neqs::default(), solver: OdeSolver::default(), rtol: RTOL, atol: ATOL, - metadata: None, - cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), - error_model_cache: Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )), } } /// Set the number of state variables (ODE compartments). pub fn with_nstates(mut self, nstates: usize) -> Self { - self.neqs.nstates = nstates; - self.invalidate_metadata(); + self.core = self.core.with_nstates(nstates); self } /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { - self.neqs.ndrugs = ndrugs; - self.invalidate_metadata(); + self.core = self.core.with_ndrugs(ndrugs); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { - self.neqs.nout = nout; - self.invalidate_metadata(); + self.core = self.core.with_nout(nout); self } @@ -167,37 +155,137 @@ impl ODE { /// Attach validated handwritten-model metadata to this ODE. pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { - let metadata = metadata.validate_for(ModelKind::Ode)?; - validate_metadata_dimensions(&metadata, &self.neqs)?; - self.metadata = Some(metadata); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); + let validated = metadata.validate_for(ModelKind::Ode).map_err(OdeMetadataError::Validation)?; + validate_metadata_dimensions(&validated, &self.core.dims())?; + self.core.set_metadata(validated); Ok(self) } /// Access the validated metadata attached to this ODE, if any. pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { - self.metadata.as_ref() + self.core.metadata() } pub fn parameter_index(&self, name: &str) -> Option { - self.metadata()?.parameter_index(name) + self.core.metadata()?.parameter_index(name) } pub fn covariate_index(&self, name: &str) -> Option { - self.metadata()?.covariate_index(name) + self.core.metadata()?.covariate_index(name) } pub fn state_index(&self, name: &str) -> Option { - self.metadata()?.state_index(name) + self.core.metadata()?.state_index(name) } - fn invalidate_metadata(&mut self) { - self.metadata = None; - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); + pub(crate) fn run_events<'a, F, S>( + &self, + solver: &mut S, + events: &[Event], + parameters_v: &V, + covariates: &Covariates, + error_models: Option<&AssayErrorModels>, + bolus_v: &mut V, + zero_bolus: &V, + zero_rateiv: &V, + state_with_bolus: &mut V, + state_without_bolus: &mut V, + y_out: &mut V, + likelihood: &mut Vec, + output: &mut SubjectPredictions, + ) -> Result<(), PharmsolError> + where + F: Fn(&V, &V, f64, &mut V, &V, &V, &Covariates) + 'a, + S: OdeSolverMethod<'a, PMProblem<'a, F>>, + { + for (index, event) in events.iter().enumerate() { + let next_event = events.get(index + 1); + + match event { + Event::Bolus(bolus) => { + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; + if input >= bolus_v.len() { + return Err(PharmsolError::InputOutOfRange { + input, + ndrugs: bolus_v.len(), + }); + } + bolus_v.fill(0.0); + bolus_v[input] = bolus.amount(); + + state_with_bolus.fill(0.0); + state_without_bolus.fill(0.0); + + (self.diffeq)( + solver.state().y, parameters_v, event.time(), + state_without_bolus, zero_bolus, zero_rateiv, covariates, + ); + (self.diffeq)( + solver.state().y, parameters_v, event.time(), + state_with_bolus, bolus_v, zero_rateiv, covariates, + ); + state_with_bolus.axpy(-1.0, state_without_bolus, 1.0); + solver.state_mut().y.axpy(1.0, state_with_bolus, 1.0); + } + Event::Infusion(_) => {} + Event::Observation(observation) => { + y_out.fill(0.0); + (self.out)( + solver.state().y, parameters_v, observation.time(), covariates, y_out, + ); + let outeq = observation.outeq_index().ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + } + })?; + let pred = y_out[outeq]; + let pred = + observation.to_prediction(pred, solver.state().y.as_slice().to_vec()); + if let Some(em) = error_models { + likelihood.push(pred.log_likelihood(em)?.exp()); + } + output.add_prediction(pred); + } + } + + if let Some(next_event) = next_event { + if !event.time().eq(&next_event.time()) { + match solver.set_stop_time(next_event.time()) { + Ok(_) => loop { + match solver.step() { + Ok(OdeSolverStopReason::InternalTimestep) => continue, + Ok(OdeSolverStopReason::TstopReached) => break, + Err(diffsol::error::DiffsolError::OdeSolverError( + OdeSolverError::StepSizeTooSmall { time }, + )) => { + return Err(PharmsolError::OtherError(format!( + "ODE solver step size went to zero at t = {time:.4}", + ))); + } + Err(_) | Ok(_) => { + return Err(PharmsolError::OtherError( + "Unexpected solver error".to_string(), + )); + } + } + }, + Err(diffsol::error::DiffsolError::OdeSolverError( + OdeSolverError::StopTimeAtCurrentTime, + )) => continue, + Err(_) => { + return Err(PharmsolError::OtherError( + "Unexpected solver error".to_string(), + )); + } + } + } + } + } + Ok(()) } } @@ -232,39 +320,6 @@ fn validate_metadata_dimensions( Ok(()) } -impl super::Cache for ODE { - fn with_cache_capacity(mut self, size: u64) -> Self { - self.cache = Some(PredictionCache::new(size)); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self - } - - fn enable_cache(mut self) -> Self { - self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE)); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self - } - - fn clear_cache(&self) { - if let Some(cache) = &self.cache { - cache.invalidate_all(); - } - if let Some(cache) = &self.error_model_cache { - cache.invalidate_all(); - } - } - - fn disable_cache(mut self) -> Self { - self.cache = None; - self.error_model_cache = None; - self - } -} - impl State for V { #[inline(always)] fn add_bolus(&mut self, input: usize, amount: f64) { @@ -278,7 +333,7 @@ fn _estimate_likelihood( parameters: &[f64], error_models: &AssayErrorModels, ) -> Result { - let bound_error_models = ode.bind_error_models(error_models)?; + let bound_error_models = crate::core::simulate::bind_error_models_inner(ode, error_models)?; let ypred = _subject_predictions(ode, subject, parameters)?; Ok(ypred.log_likelihood(&bound_error_models)?.exp()) } @@ -289,7 +344,7 @@ fn _subject_predictions( subject: &Subject, parameters: &[f64], ) -> Result { - if let Some(cache) = &ode.cache { + if let Some(cache) = ode.core.cache() { let key = (subject.hash(), parameters_hash(parameters)); if let Some(cached) = cache.get(&key) { return Ok(cached); @@ -310,18 +365,18 @@ fn _simulate_subject_dense( error_models: Option<&AssayErrorModels>, ) -> Result<(SubjectPredictions, Option), PharmsolError> { let bound_error_models = match error_models { - Some(error_models) => Some(ode.bind_error_models(error_models)?), + Some(error_models) => Some(crate::core::simulate::bind_error_models_inner(ode, error_models)?), None => None, }; - let bound_error_models = bound_error_models.as_ref().map(|models| &**models); + let bound_error_models = bound_error_models.as_deref(); let mut output = SubjectPredictions::new(ode.nparticles()); let event_count: usize = subject.occasions().iter().map(|o| o.events().len()).sum(); let mut likelihood = Vec::with_capacity(event_count); - let nstates = ode.get_nstates(); - let ndrugs = ode.get_ndrugs(); + let nstates = ode.nstates(); + let ndrugs = ode.ndrugs(); let mut state_with_bolus = V::zeros(nstates, NalgebraContext); let mut state_without_bolus = V::zeros(nstates, NalgebraContext); @@ -331,11 +386,11 @@ fn _simulate_subject_dense( let parameters_vec = parameters.to_vec(); let parameters_v: V = DVector::from_vec(parameters_vec.clone()).into(); - let mut y_out = V::zeros(ode.get_nouteqs(), NalgebraContext); + let mut y_out = V::zeros(ode.nout(), NalgebraContext); for occasion in subject.occasions() { let covariates = occasion.covariates(); - let events = ode.resolve_occasion_events(occasion, parameters, covariates)?; + let events = ode.resolve_events(occasion, parameters, covariates)?; let problem = OdeBuilder::::new() .atol(vec![ode.atol]) @@ -442,79 +497,11 @@ fn _simulate_subject_dense( Ok((output, ll)) } -impl EquationTypes for ODE { - type S = V; - type P = SubjectPredictions; -} - -impl EquationPriv for ODE { - //#[inline(always)] - // fn get_lag(&self, parameters: &[f64]) -> Option> { - // let parameters = DVector::from_vec(parameters.to_vec()); - // Some((self.lag)(¶meters)) - // } - - // #[inline(always)] - // fn get_fa(&self, parameters: &[f64]) -> Option> { - // let parameters = DVector::from_vec(parameters.to_vec()); - // Some((self.fa)(¶meters)) - // } - #[inline(always)] - fn lag(&self) -> &Lag { - &self.lag - } - - #[inline(always)] - fn fa(&self) -> &Fa { - &self.fa - } - #[inline(always)] - fn get_nstates(&self) -> usize { - self.neqs.nstates - } - - #[inline(always)] - fn get_ndrugs(&self) -> usize { - self.neqs.ndrugs - } - - #[inline(always)] - fn get_nouteqs(&self) -> usize { - self.neqs.nout - } +// ── New core traits ───────────────────────────────────────────────────────── - fn metadata(&self) -> Option<&ValidatedModelMetadata> { - self.metadata.as_ref() - } - - #[inline(always)] - fn solve( - &self, - _state: &mut Self::S, - _parameters: &[f64], - _covariates: &Covariates, - _infusions: &[Infusion], - _start_time: f64, - _end_time: f64, - ) -> Result<(), PharmsolError> { - unimplemented!("solve not implemented for ODE"); - } - #[inline(always)] - fn process_observation( - &self, - _parameters: &[f64], - _observation: &Observation, - _error_models: Option<&AssayErrorModels>, - _time: f64, - _covariates: &Covariates, - _x: &mut Self::S, - _likelihood: &mut Vec, - _output: &mut Self::P, - ) -> Result<(), PharmsolError> { - unimplemented!("process_observation not implemented for ODE"); - } +impl crate::core::Solver for ODE { + type State = V; - #[inline(always)] fn initial_state( &self, parameters: &[f64], @@ -522,232 +509,131 @@ impl EquationPriv for ODE { occasion_index: usize, ) -> V { let init = &self.init; - let mut x = V::zeros(self.get_nstates(), NalgebraContext); + let mut x = V::zeros(self.nstates(), NalgebraContext); if occasion_index == 0 { - let parameters = DVector::from_vec(parameters.to_vec()); - (init)(¶meters.into(), 0.0, covariates, &mut x); + (init)( + &V::from_vec(parameters.to_vec(), NalgebraContext), + 0.0, + covariates, + &mut x, + ); } x } -} -impl ODE { - /// Generic event-loop runner, parameterized over the concrete solver type. - #[allow(clippy::too_many_arguments)] - fn run_events<'a, F, S>( - &self, - solver: &mut S, - events: &[Event], - parameters_v: &V, - covariates: &Covariates, - error_models: Option<&AssayErrorModels>, - bolus_v: &mut V, - zero_bolus: &V, - zero_rateiv: &V, - state_with_bolus: &mut V, - state_without_bolus: &mut V, - y_out: &mut V, - likelihood: &mut Vec, - output: &mut SubjectPredictions, - ) -> Result<(), PharmsolError> - where - F: Fn(&V, &V, f64, &mut V, &V, &V, &Covariates) + 'a, - S: OdeSolverMethod<'a, PMProblem<'a, F>>, - { - for (index, event) in events.iter().enumerate() { - let next_event = events.get(index + 1); + fn nparticles(&self) -> usize { + 1 + } - match event { - Event::Bolus(bolus) => { - let input = - bolus - .input_index() - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: bolus.input().to_string(), - })?; + fn is_batch(&self) -> bool { + true + } +} - if input >= bolus_v.len() { - return Err(PharmsolError::InputOutOfRange { - input, - ndrugs: bolus_v.len(), - }); - } - bolus_v.fill(0.0); - bolus_v[input] = bolus.amount(); +impl crate::core::ModelInfo for ODE { + fn nstates(&self) -> usize { + self.core.nstates() + } - state_with_bolus.fill(0.0); - state_without_bolus.fill(0.0); + fn ndrugs(&self) -> usize { + self.core.ndrugs() + } - (self.diffeq)( - solver.state().y, - parameters_v, - event.time(), - state_without_bolus, - zero_bolus, - zero_rateiv, - covariates, - ); + fn nout(&self) -> usize { + self.core.nout() + } - (self.diffeq)( - solver.state().y, - parameters_v, - event.time(), - state_with_bolus, - bolus_v, - zero_rateiv, - covariates, - ); + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.core.metadata() + } - state_with_bolus.axpy(-1.0, state_without_bolus, 1.0); - solver.state_mut().y.axpy(1.0, state_with_bolus, 1.0); - } - Event::Infusion(_) => { - // Infusions are handled within the ODE function itself - } - Event::Observation(observation) => { - y_out.fill(0.0); - (self.out)( - solver.state().y, - parameters_v, - observation.time(), - covariates, - y_out, - ); - let outeq = observation.outeq_index().ok_or_else(|| { - PharmsolError::UnknownOutputLabel { - label: observation.outeq().to_string(), - } - })?; - let pred = y_out[outeq]; - let pred = - observation.to_prediction(pred, solver.state().y.as_slice().to_vec()); - if let Some(error_models) = error_models { - likelihood.push(pred.log_likelihood(error_models)?.exp()); - } - output.add_prediction(pred); - } - } + fn lag(&self) -> &Lag { + &self.lag + } - // Advance to the next event time if it exists - if let Some(next_event) = next_event { - if !event.time().eq(&next_event.time()) { - match solver.set_stop_time(next_event.time()) { - Ok(_) => loop { - match solver.step() { - Ok(OdeSolverStopReason::InternalTimestep) => continue, - Ok(OdeSolverStopReason::TstopReached) => break, - Err(diffsol::error::DiffsolError::OdeSolverError( - OdeSolverError::StepSizeTooSmall { time }, - )) => { - return Err(PharmsolError::OtherError(format!( - "ODE solver step size went to zero at t = {time:.4} (target t = {:.4}). \ - A parameter is likely near 0 or infinite.", - next_event.time() - ))); - } - Err(_) | Ok(_) => { - return Err(PharmsolError::OtherError( - "Unexpected solver error".to_string(), - )); - } - } - }, - Err(diffsol::error::DiffsolError::OdeSolverError( - OdeSolverError::StopTimeAtCurrentTime, - )) => { - continue; - } - Err(_) => { - return Err(PharmsolError::OtherError( - "Unexpected solver error".to_string(), - )); - } - } - } - } - } - Ok(()) + fn fa(&self) -> &Fa { + &self.fa } } -impl Equation for ODE { - fn bound_error_model_cache(&self) -> Option<&BoundErrorModelCache> { - self.error_model_cache.as_ref() +impl crate::core::Caching for ODE { + fn prediction_cache(&self) -> Option<&PredictionCache> { + self.core.cache() } - fn estimate_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - _estimate_likelihood(self, subject, parameters.as_slice(), error_models) + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { + self.core.error_model_cache() } - fn estimate_predictions( - &self, - subject: &Subject, - parameters: &Parameters, - ) -> Result { - _subject_predictions(self, subject, parameters.as_slice()) + fn with_cache_capacity(mut self, size: u64) -> Self { + self.core = self.core.with_cache_capacity(PredictionCache::new(size)); + self } - fn estimate_log_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - let ypred = _subject_predictions(self, subject, parameters.as_slice())?; - ypred.log_likelihood(&bound_error_models) + fn without_cache(mut self) -> Self { + self.core = self.core.without_cache(); + self } - fn estimate_predictions_dense( - &self, - subject: &Subject, - parameters: &[f64], - ) -> Result { - _subject_predictions(self, subject, parameters) + fn clear_cache(&self) { + self.core.clear_cache(); + if let Some(cache) = self.core.cache() { + cache.invalidate_all(); + } } +} - fn estimate_log_likelihood_dense( - &self, - subject: &Subject, - parameters: &[f64], - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - let ypred = _subject_predictions(self, subject, parameters)?; - ypred.log_likelihood(&bound_error_models) - } +impl crate::core::Simulate for ODE { + type Predictions = SubjectPredictions; - fn simulate_subject_dense( + fn simulate_subject( &self, subject: &Subject, - parameters: &[f64], + params: &[f64], error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - _simulate_subject_dense(self, subject, parameters, error_models) - } + ) -> Result<(Self::Predictions, Option), PharmsolError> { + if error_models.is_none() { + if let Some(cache) = self.core.cache() { + let key = (subject.hash(), super::parameters_hash(params)); + if let Some(cached) = cache.get(&key) { + return Ok((cached, None)); + } + } + } + + let (predictions, likelihood) = + _simulate_subject_dense(self, subject, params, error_models)?; + + if error_models.is_none() { + if let Some(cache) = self.core.cache() { + let key = (subject.hash(), super::parameters_hash(params)); + cache.insert(key, predictions.clone()); + } + } - fn kind() -> EqnKind { - EqnKind::ODE + Ok((predictions, likelihood)) } - fn simulate_subject( + fn log_likelihood( &self, subject: &Subject, - parameters: &Parameters, - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - _simulate_subject_dense(self, subject, parameters.as_slice(), error_models) + params: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + let bound_error_models = crate::core::simulate::bind_error_models_inner(self, error_models)?; + let ypred = _subject_predictions(self, subject, params)?; + ypred.log_likelihood(&bound_error_models) + } + + fn kind() -> pharmsol_dsl::ModelKind { + pharmsol_dsl::ModelKind::Ode } } #[cfg(test)] mod tests { use super::*; + use crate::core::Simulate; use crate::{fa, lag, Subject, SubjectBuilderExt}; use approx::assert_relative_eq; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -830,7 +716,7 @@ mod tests { fn handwritten_ode_metadata_exposes_name_lookup() { let ode = simple_ode() .with_metadata( - super::super::metadata::new("bimodal_ke") + crate::core::metadata::new("bimodal_ke") .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) @@ -859,7 +745,7 @@ mod tests { fn handwritten_ode_rejects_dimension_mismatches() { let error = simple_ode() .with_metadata( - super::super::metadata::new("wrong_outputs") + crate::core::metadata::new("wrong_outputs") .parameters(["ke"]) .states(["central"]) .outputs(["cp", "auc"]) @@ -880,7 +766,7 @@ mod tests { fn handwritten_ode_rejects_invalid_metadata() { let error = simple_ode() .with_metadata( - super::super::metadata::new("missing_destination") + crate::core::metadata::new("missing_destination") .parameters(["ke"]) .states(["central"]) .outputs(["cp"]) @@ -909,7 +795,7 @@ mod tests { .with_ndrugs(1) .with_nout(1) .with_metadata( - super::super::metadata::new("explicit_routes") + crate::core::metadata::new("explicit_routes") .states(["central"]) .outputs(["cp"]) .routes([ @@ -953,7 +839,7 @@ mod tests { .with_ndrugs(1) .with_nout(1) .with_metadata( - super::super::metadata::new("injected_routes") + crate::core::metadata::new("injected_routes") .states(["central"]) .outputs(["cp"]) .routes([ @@ -992,7 +878,7 @@ mod tests { .with_ndrugs(1) .with_nout(1) .with_metadata( - super::super::metadata::new("numeric_alias_ode") + crate::core::metadata::new("numeric_alias_ode") .states(["central"]) .outputs(["outeq_1"]) .route(super::super::Route::infusion("input_1").to_state("central")), @@ -1028,7 +914,7 @@ mod tests { fn changing_dimensions_after_metadata_clears_route_metadata() { let ode = simple_ode() .with_metadata( - super::super::metadata::new("bimodal_ke") + crate::core::metadata::new("bimodal_ke") .states(["central"]) .outputs(["cp"]) .route(super::super::Route::infusion("iv").to_state("central")), diff --git a/src/simulator/equation/sde/em.rs b/src/simulator/backends/sde/em.rs similarity index 100% rename from src/simulator/equation/sde/em.rs rename to src/simulator/backends/sde/em.rs diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/backends/sde/mod.rs similarity index 70% rename from src/simulator/equation/sde/mod.rs rename to src/simulator/backends/sde/mod.rs index 38fc8490..9f3d7c7d 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/backends/sde/mod.rs @@ -8,28 +8,24 @@ use rand::{rng, RngExt}; use rayon::prelude::*; use thiserror::Error; +use crate::core::{ModelInfo, Simulate, Solver}; use crate::{ data::{Covariates, Infusion}, error_model::AssayErrorModels, prelude::simulator::Prediction, simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V}, - Parameters, Subject, + Event, Observation, PharmsolError, Subject, }; -use super::parameters_hash; +use crate::simulator::backends::parameters_hash; use crate::simulator::cache::{ - BoundErrorModelCache, SdeLikelihoodCache, DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - DEFAULT_CACHE_SIZE, + BoundErrorModelCache, PredictionCache, SdeLikelihoodCache, DEFAULT_CACHE_SIZE, }; use diffsol::VectorCommon; -use crate::PharmsolError; - -use super::{ - EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, Predictions, - State, ValidatedModelMetadata, -}; +use crate::core::metadata::{ModelMetadata, ModelMetadataError, ValidatedModelMetadata}; +use crate::core::{Predictions, State}; #[derive(Clone, Debug, PartialEq, Eq, Error)] pub enum SdeMetadataError { @@ -63,6 +59,7 @@ impl InjectedBolusMappings { mappings } + #[allow(dead_code)] fn invalidate_for_ndrugs(&mut self, ndrugs: usize) { *self = Self::explicit(ndrugs); } @@ -183,30 +180,86 @@ where /// realistic modeling of biological variability and uncertainty. #[derive(Clone, Debug)] pub struct SDE { + core: crate::core::ModelCore, drift: Drift, diffusion: Diffusion, lag: Lag, fa: Fa, init: Init, out: Out, - neqs: Neqs, nparticles: usize, - metadata: Option, injected_bolus_mappings: InjectedBolusMappings, - cache: Option, - error_model_cache: Option, +} + +impl Predictions for Array2 { + fn new(nparticles: usize) -> Self { + Array2::from_shape_fn((nparticles, 0), |_| Prediction::default()) + } + fn squared_error(&self) -> f64 { + unimplemented!(); + } + fn get_predictions(&self) -> Vec { + if self.is_empty() || self.ncols() == 0 { + return Vec::new(); + } + let mut result = Vec::with_capacity(self.ncols()); + for col in 0..self.ncols() { + let column = self.column(col); + let mean_prediction: f64 = column + .iter() + .map(|pred: &Prediction| pred.prediction()) + .sum::() + / self.nrows() as f64; + let mut prediction = column.first().unwrap().clone(); + prediction.set_prediction(mean_prediction); + result.push(prediction); + } + result + } + fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result { + let predictions = self.get_predictions(); + if predictions.is_empty() { + return Ok(0.0); + } + let log_liks: Result, _> = predictions + .iter() + .filter(|p| p.observation().is_some()) + .map(|p| p.log_likelihood(error_models)) + .collect(); + log_liks.map(|lls| lls.iter().sum()) + } +} + +impl crate::core::PredictionsContainer for Array2 { + fn new(nparticles: usize) -> Self { + Array2::from_shape_fn((nparticles, 0), |_| Prediction::default()) + } + + fn push(&mut self, pred: Prediction) { + let col = Array2::from_shape_vec((self.nrows(), 1), vec![pred]).unwrap(); + *self = ndarray::concatenate(ndarray::Axis(1), &[self.view(), col.view()]).unwrap(); + } + + fn predictions(&self) -> &[Prediction] { + // Array2 doesn't support slicing to &[Prediction] directly + unimplemented!("predictions() not supported for Array2 — use get_predictions()") + } + + fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result { + let predictions: Vec = Predictions::get_predictions(self); + if predictions.is_empty() { + return Ok(0.0); + } + let log_liks: Result, _> = predictions + .iter() + .filter(|p| p.observation().is_some()) + .map(|p| p.log_likelihood(error_models)) + .collect(); + log_liks.map(|lls| lls.iter().sum()) + } } impl SDE { - /// Creates a new stochastic differential equation solver with default Neqs. - /// - /// Use builder methods to configure dimensions: - /// ```ignore - /// SDE::new(drift, diffusion, lag, fa, init, out, nparticles) - /// .with_nstates(2) - /// .with_ndrugs(1) - /// .with_nout(1) - /// ``` pub fn new( drift: Drift, diffusion: Diffusion, @@ -217,65 +270,52 @@ impl SDE { nparticles: usize, ) -> Self { Self { + core: crate::core::ModelCore::new(Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE))), drift, diffusion, lag, fa, init, out, - neqs: Neqs::default(), nparticles, - metadata: None, injected_bolus_mappings: InjectedBolusMappings::default(), - cache: Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE)), - error_model_cache: Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )), } } - /// Set the number of state variables. pub fn with_nstates(mut self, nstates: usize) -> Self { - self.neqs.nstates = nstates; - self.invalidate_metadata(); + self.core = self.core.with_nstates(nstates); self } - /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { - self.neqs.ndrugs = ndrugs; - self.invalidate_metadata(); + self.core = self.core.with_ndrugs(ndrugs); self } - /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { - self.neqs.nout = nout; - self.invalidate_metadata(); + self.core = self.core.with_nout(nout); self } - /// Attach validated handwritten-model metadata to this SDE model. pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { - let metadata = metadata.validate_for_with_particles(ModelKind::Sde, self.nparticles)?; - validate_metadata_dimensions(&metadata, &self.neqs)?; - self.metadata = Some(metadata); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); + let validated = metadata + .validate_for_with_particles(ModelKind::Sde, self.nparticles) + .map_err(SdeMetadataError::Validation)?; + validate_metadata_dimensions(&validated, &self.core.dims())?; + self.core.set_metadata(validated); Ok(self) } #[doc(hidden)] pub fn with_injected_bolus_inputs(mut self, destinations: &[Option]) -> Self { self.injected_bolus_mappings = - InjectedBolusMappings::from_destinations(self.neqs.ndrugs, destinations); + InjectedBolusMappings::from_destinations(self.core.ndrugs(), destinations); self } /// Access the validated metadata attached to this SDE model, if any. pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { - self.metadata.as_ref() + self.core.metadata() } pub fn parameter_index(&self, name: &str) -> Option { @@ -289,14 +329,13 @@ impl SDE { pub fn state_index(&self, name: &str) -> Option { self.metadata()?.state_index(name) } +} - fn invalidate_metadata(&mut self) { - self.metadata = None; - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self.injected_bolus_mappings - .invalidate_for_ndrugs(self.neqs.ndrugs); +impl State for Vec> { + fn add_bolus(&mut self, input: usize, amount: f64) { + self.par_iter_mut().for_each(|particle| { + particle[input] += amount; + }); } } @@ -331,173 +370,21 @@ fn validate_metadata_dimensions( Ok(()) } -impl super::Cache for SDE { - fn with_cache_capacity(mut self, size: u64) -> Self { - self.cache = Some(SdeLikelihoodCache::new(size)); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self - } - - fn enable_cache(mut self) -> Self { - self.cache = Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE)); - self.error_model_cache = Some(BoundErrorModelCache::new( - DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, - )); - self - } - - fn clear_cache(&self) { - if let Some(cache) = &self.cache { - cache.invalidate_all(); - } - if let Some(cache) = &self.error_model_cache { - cache.invalidate_all(); - } - } - - fn disable_cache(mut self) -> Self { - self.cache = None; - self.error_model_cache = None; - self - } -} - -/// State trait implementation for particle-based SDE simulation. -/// -/// This implementation allows adding bolus doses to all particles in the system. -impl State for Vec> { - /// Adds a bolus dose to a specific input compartment across all particles. - /// - /// # Arguments - /// - /// * `input` - Index of the input compartment - /// * `amount` - Amount to add to the compartment - fn add_bolus(&mut self, input: usize, amount: f64) { - self.par_iter_mut().for_each(|particle| { - particle[input] += amount; - }); - } -} - -/// Predictions implementation for particle-based SDE simulation outputs. -/// -/// This implementation manages and processes predictions from multiple particles. -impl Predictions for Array2 { - fn new(nparticles: usize) -> Self { - Array2::from_shape_fn((nparticles, 0), |_| Prediction::default()) - } - fn squared_error(&self) -> f64 { - unimplemented!(); - } - fn get_predictions(&self) -> Vec { - // Make this return the mean prediction across all particles - if self.is_empty() || self.ncols() == 0 { - return Vec::new(); - } - - let mut result = Vec::with_capacity(self.ncols()); - - for col in 0..self.ncols() { - let column = self.column(col); - - let mean_prediction: f64 = column - .iter() - .map(|pred: &Prediction| pred.prediction()) - .sum::() - / self.nrows() as f64; - - let mut prediction = column.first().unwrap().clone(); - prediction.set_prediction(mean_prediction); - result.push(prediction); - } - - result - } - fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result { - // For SDE, compute log-likelihood using mean predictions across particles - let predictions = self.get_predictions(); - if predictions.is_empty() { - return Ok(0.0); - } - - let log_liks: Result, _> = predictions - .iter() - .filter(|p| p.observation().is_some()) - .map(|p| p.log_likelihood(error_models)) - .collect(); - - log_liks.map(|lls| lls.iter().sum()) - } -} - -impl EquationTypes for SDE { - type S = Vec>; // Vec -> particles, DVector -> state - type P = Array2; // Rows -> particles, Columns -> time -} - -impl EquationPriv for SDE { - // #[inline(always)] - // fn get_init(&self) -> &Init { - // &self.init - // } - - // #[inline(always)] - // fn get_out(&self) -> &Out { - // &self.out - // } - - // #[inline(always)] - // fn get_lag(&self, parameters: &[f64]) -> Option> { - // Some((self.lag)(&V::from_vec(parameters.to_owned()))) - // } - - // #[inline(always)] - // fn get_fa(&self, parameters: &[f64]) -> Option> { - // Some((self.fa)(&V::from_vec(parameters.to_owned()))) - // } - - #[inline(always)] - fn lag(&self) -> &Lag { - &self.lag - } - - #[inline(always)] - fn fa(&self) -> &Fa { - &self.fa - } - - #[inline(always)] - fn get_nstates(&self) -> usize { - self.neqs.nstates - } - - #[inline(always)] - fn get_ndrugs(&self) -> usize { - self.neqs.ndrugs - } - - #[inline(always)] - fn get_nouteqs(&self) -> usize { - self.neqs.nout - } +// ── New core traits ───────────────────────────────────────────────────────── - fn metadata(&self) -> Option<&ValidatedModelMetadata> { - self.metadata.as_ref() - } +impl crate::core::Solver for SDE { + type State = Vec>; - #[inline(always)] fn solve( &self, - state: &mut Self::S, + state: &mut Self::State, parameters: &[f64], covariates: &Covariates, infusions: &[Infusion], ti: f64, tf: f64, ) -> Result<(), PharmsolError> { - let ndrugs = self.get_ndrugs(); + let ndrugs = self.ndrugs(); state.par_iter_mut().for_each(|particle| { *particle = simulate_sde_event( &self.drift, @@ -515,29 +402,21 @@ impl EquationPriv for SDE { }); Ok(()) } - fn nparticles(&self) -> usize { - self.nparticles - } - fn is_sde(&self) -> bool { - true - } - #[inline(always)] fn process_observation( &self, + x: &Self::State, parameters: &[f64], - observation: &crate::Observation, + observation: &Observation, error_models: Option<&AssayErrorModels>, - _time: f64, covariates: &Covariates, - x: &mut Self::S, - likelihood: &mut Vec, - output: &mut Self::P, - ) -> Result<(), PharmsolError> { - let mut pred = vec![Prediction::default(); self.nparticles]; + ) -> Result<(Prediction, Option), PharmsolError> { + // Compute predictions across all particles + let nparticles = self.nparticles; + let mut preds = vec![Prediction::default(); nparticles]; - pred.par_iter_mut().enumerate().for_each(|(i, p)| { - let mut y = V::zeros(self.get_nouteqs(), NalgebraContext); + preds.par_iter_mut().enumerate().for_each(|(i, p)| { + let mut y = V::zeros(self.nout(), NalgebraContext); (self.out)( &x[i].clone().into(), &V::from_vec(parameters.to_vec(), NalgebraContext), @@ -550,41 +429,38 @@ impl EquationPriv for SDE { .expect("resolved observations should use numeric output labels"); *p = observation.to_prediction(y[outeq], x[i].as_slice().to_vec()); }); - let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?; - *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap(); - //e = y[t] .- x[:,1] - // q = pdf.(Distributions.Normal(0, 0.5), e) - if let Some(em) = error_models { - let mut q: Vec = Vec::with_capacity(self.nparticles); - - pred.iter().for_each(|p| { - let lik = p.log_likelihood(em).map(f64::exp); - match lik { - Ok(l) => q.push(l), - Err(e) => panic!("Error in likelihood calculation: {:?}", e), - } - }); + + // Resampling and likelihood computation + let lik = if let Some(em) = error_models { + let q: Vec = preds + .iter() + .map(|p| p.log_likelihood(em).map(f64::exp).unwrap_or(0.0)) + .collect(); let sum_q: f64 = q.iter().sum(); - let w: Vec = q.iter().map(|qi| qi / sum_q).collect(); - let i = sysresample(&w); - let a: Vec> = i.iter().map(|&i| x[i].clone()).collect(); - *x = a; - likelihood.push(sum_q / self.nparticles as f64); - // let qq: Vec = i.iter().map(|&i| q[i]).collect(); - // likelihood.push(qq.iter().sum::() / self.nparticles as f64); - } - Ok(()) + // Note: resampling is skipped here because state is borrowed. + // Full resampling happens in simulate_subject. + Some(sum_q / nparticles as f64) + } else { + None + }; + + // Return the mean prediction across particles + let mean_pred: f64 = preds.iter().map(|p| p.prediction()).sum::() / nparticles as f64; + let mut prediction = preds[0].clone(); + prediction.set_prediction(mean_pred); + + Ok((prediction, lik)) } - #[inline(always)] + fn initial_state( &self, parameters: &[f64], covariates: &Covariates, occasion_index: usize, - ) -> Self::S { + ) -> Vec> { let mut x = Vec::with_capacity(self.nparticles); for _ in 0..self.nparticles { - let mut state: V = DVector::zeros(self.get_nstates()).into(); + let mut state: V = DVector::zeros(self.nstates()).into(); if occasion_index == 0 { (self.init)( &V::from_vec(parameters.to_vec(), NalgebraContext), @@ -598,112 +474,185 @@ impl EquationPriv for SDE { x } - fn simulate_event( - &self, - parameters: &[f64], - event: &crate::Event, - next_event: Option<&crate::Event>, - error_models: Option<&AssayErrorModels>, - covariates: &Covariates, - x: &mut Self::S, - infusions: &mut Vec, - likelihood: &mut Vec, - output: &mut Self::P, - ) -> Result<(), PharmsolError> { - match event { - crate::Event::Bolus(bolus) => { - let input = - bolus - .input_index() - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: bolus.input().to_string(), - })?; + fn nparticles(&self) -> usize { + self.nparticles + } +} - if input >= self.get_ndrugs() { - return Err(PharmsolError::InputOutOfRange { - input, - ndrugs: self.get_ndrugs(), - }); - } - if !self.injected_bolus_mappings.apply(x, input, bolus.amount()) { - x.add_bolus(input, bolus.amount()); - } - } - crate::Event::Infusion(infusion) => { - infusions.push(infusion.clone()); - } - crate::Event::Observation(observation) => { - self.process_observation( - parameters, - observation, - error_models, - event.time(), - covariates, - x, - likelihood, - output, - )?; - } - } +impl crate::core::ModelInfo for SDE { + fn nstates(&self) -> usize { + self.core.nstates() + } - if let Some(next_event) = next_event { - self.solve( - x, - parameters, - covariates, - infusions, - event.time(), - next_event.time(), - )?; - } - Ok(()) + fn ndrugs(&self) -> usize { + self.core.ndrugs() + } + + fn nout(&self) -> usize { + self.core.nout() + } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.core.metadata() + } + + fn lag(&self) -> &Lag { + &self.lag + } + + fn fa(&self) -> &Fa { + &self.fa } } -impl Equation for SDE { - fn bound_error_model_cache(&self) -> Option<&BoundErrorModelCache> { - self.error_model_cache.as_ref() +impl crate::core::Caching for SDE { + fn prediction_cache(&self) -> Option<&PredictionCache> { + self.core.cache().map(|_| unimplemented!()) /* SDE uses SdeLikelihoodCache */ + // SDE uses SdeLikelihoodCache, not PredictionCache } - /// Estimates the likelihood of observed data given a model and parameters. - /// - /// # Arguments - /// - /// * `subject` - Subject data containing observations - /// * `parameters` - Parameter vector for the model - /// * `error_model` - Error model to use for likelihood calculations - /// - /// # Returns - /// - /// The log-likelihood of the observed data given the model and parameters. - fn estimate_likelihood( + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { + self.core.error_model_cache() + } + + fn with_cache_capacity(mut self, size: u64) -> Self { + self.core = self.core.with_cache_capacity(SdeLikelihoodCache::new(size)); + self + } + + fn without_cache(mut self) -> Self { + self.core = self.core.without_cache(); + self + } + + fn clear_cache(&self) { + self.core.clear_cache(); + if let Some(cache) = self.core.cache() { + cache.invalidate_all(); + } + } +} + +impl crate::core::Simulate for SDE { + type Predictions = Array2; + + fn simulate_subject( &self, subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result { - _estimate_likelihood(self, subject, parameters.as_slice(), error_models) + params: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::Predictions, Option), PharmsolError> { + let bound_em = match error_models { + Some(em) => Some(crate::core::simulate::bind_error_models_inner(self, em)?), + None => None, + }; + + let mut output = + Array2::::from_shape_fn((self.nparticles, 0), |_| Prediction::default()); + let mut likelihood = Vec::new(); + + for occasion in subject.occasions() { + let covariates = occasion.covariates(); + let events = self.resolve_events(occasion, params, covariates)?; + let mut state = self.initial_state(params, covariates, occasion.index()); + let mut infusions: Vec = Vec::new(); + + for (idx, event) in events.iter().enumerate() { + match event { + Event::Bolus(bolus) => { + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; + if input >= self.ndrugs() { + return Err(PharmsolError::InputOutOfRange { + input, + ndrugs: self.ndrugs(), + }); + } + if !self + .injected_bolus_mappings + .apply(&mut state, input, bolus.amount()) + { + state.add_bolus(input, bolus.amount()); + } + } + Event::Infusion(inf) => infusions.push(inf.clone()), + Event::Observation(obs) => { + // Compute predictions across particles + let mut preds = vec![Prediction::default(); self.nparticles]; + preds.par_iter_mut().enumerate().for_each(|(i, p)| { + let mut y = V::zeros(self.nout(), NalgebraContext); + (self.out)( + &state[i].clone().into(), + &V::from_vec(params.to_vec(), NalgebraContext), + obs.time(), + covariates, + &mut y, + ); + let outeq = obs.outeq_index().expect("resolved obs"); + *p = obs.to_prediction(y[outeq], state[i].as_slice().to_vec()); + }); + + // Resampling + if let Some(em) = &bound_em { + let q: Vec = preds + .iter() + .map(|p| p.log_likelihood(em).map(f64::exp).unwrap_or(0.0)) + .collect(); + let sum_q: f64 = q.iter().sum(); + let w: Vec = q.iter().map(|qi| qi / sum_q).collect(); + let indices = sysresample(&w); + state = indices.iter().map(|&i| state[i].clone()).collect(); + likelihood.push(sum_q / self.nparticles as f64); + } + + // Store mean prediction + let mean_pred: f64 = preds.iter().map(|p| p.prediction()).sum::() + / self.nparticles as f64; + let mut pred = preds[0].clone(); + pred.set_prediction(mean_pred); + let col = Array2::from_shape_vec( + (self.nparticles, 1), + vec![pred; self.nparticles], + ) + .unwrap(); + output = concatenate(Axis(1), &[output.view(), col.view()]).unwrap(); + } + } + + if let Some(next) = events.get(idx + 1) { + if !event.time().eq(&next.time()) { + self.solve( + &mut state, + params, + covariates, + &infusions, + event.time(), + next.time(), + )?; + } + } + } + } + + let ll = bound_em.map(|_| likelihood.iter().product::()); + Ok((output, ll)) } - fn estimate_log_likelihood( + fn log_likelihood( &self, subject: &Subject, - parameters: &Parameters, + params: &[f64], error_models: &AssayErrorModels, ) -> Result { - // For SDE, the particle filter computes likelihood in regular space. - // We compute it directly and then take the log. - let lik = _estimate_likelihood(self, subject, parameters.as_slice(), error_models)?; - - if lik > 0.0 { - Ok(lik.ln()) - } else { - Ok(f64::NEG_INFINITY) - } + // Use cached likelihood path + _estimate_likelihood(self, subject, params, error_models) } - fn kind() -> EqnKind { - EqnKind::SDE + fn kind() -> pharmsol_dsl::ModelKind { + pharmsol_dsl::ModelKind::Sde } } @@ -714,7 +663,7 @@ fn _estimate_likelihood( parameters: &[f64], error_models: &AssayErrorModels, ) -> Result { - if let Some(cache) = &sde.cache { + if let Some(cache) = sde.core.cache() { let key = ( subject.hash(), parameters_hash(parameters), @@ -724,12 +673,14 @@ fn _estimate_likelihood( return Ok(cached); } - let ypred = sde.simulate_subject_dense(subject, parameters, Some(error_models))?; + let ypred = + ::simulate_subject(sde, subject, parameters, Some(error_models))?; let result = ypred.1.unwrap(); cache.insert(key, result); Ok(result) } else { - let ypred = sde.simulate_subject_dense(subject, parameters, Some(error_models))?; + let ypred = + ::simulate_subject(sde, subject, parameters, Some(error_models))?; Ok(ypred.1.unwrap()) } } @@ -768,7 +719,8 @@ fn sysresample(q: &[f64]) -> Vec { #[cfg(test)] mod tests { use super::*; - use crate::simulator::equation::{self, Covariate, Route}; + use crate::core::metadata::{Covariate, Route}; + use crate::core::Simulate; use crate::SubjectBuilderExt; use crate::{fa, fetch_params, lag}; @@ -818,7 +770,7 @@ mod tests { fn handwritten_sde_metadata_exposes_name_lookup_and_particles() { let sde = simple_sde() .with_metadata( - equation::metadata::new("one_cmt_sde") + crate::core::metadata::new("one_cmt_sde") .parameters(["ke", "v"]) .covariates([Covariate::continuous("wt")]) .states(["central"]) @@ -862,7 +814,7 @@ mod tests { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("numeric_alias_sde") + crate::core::metadata::new("numeric_alias_sde") .states(["depot", "central"]) .outputs(["outeq_1"]) .route(Route::infusion("input_1").to_state("central")) @@ -905,7 +857,7 @@ mod tests { fn handwritten_sde_rejects_dimension_mismatches() { let error = simple_sde() .with_metadata( - equation::metadata::new("bad_sde") + crate::core::metadata::new("bad_sde") .parameters(["ke", "v"]) .states(["central", "peripheral"]) .outputs(["cp"]) @@ -927,7 +879,7 @@ mod tests { fn handwritten_sde_rejects_particle_mismatch() { let error = simple_sde() .with_metadata( - equation::metadata::new("particle_conflict") + crate::core::metadata::new("particle_conflict") .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) @@ -949,7 +901,7 @@ mod tests { fn changing_dimensions_after_metadata_clears_sde_metadata() { let sde = simple_sde() .with_metadata( - equation::metadata::new("one_cmt_sde") + crate::core::metadata::new("one_cmt_sde") .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) @@ -970,7 +922,7 @@ mod tests { let explicit = route_policy_sde(zero_drift) .with_metadata( - equation::metadata::new("explicit_bolus") + crate::core::metadata::new("explicit_bolus") .parameters(["theta"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -981,7 +933,7 @@ mod tests { let injected = route_policy_sde(zero_drift) .with_metadata( - equation::metadata::new("injected_bolus") + crate::core::metadata::new("injected_bolus") .parameters(["theta"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -1019,7 +971,7 @@ mod tests { let explicit = route_policy_sde(rateiv_drift) .with_metadata( - equation::metadata::new("explicit_infusion") + crate::core::metadata::new("explicit_infusion") .parameters(["theta"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -1030,7 +982,7 @@ mod tests { let injected = route_policy_sde(rateiv_drift) .with_metadata( - equation::metadata::new("injected_infusion") + crate::core::metadata::new("injected_infusion") .parameters(["theta"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -1070,7 +1022,7 @@ mod tests { let sde = route_policy_sde(zero_drift) .with_metadata( - equation::metadata::new("injected_bolus") + crate::core::metadata::new("injected_bolus") .parameters(["theta"]) .states(["depot", "central"]) .outputs(["cp"]) diff --git a/src/simulator/cache.rs b/src/simulator/cache.rs index 8876714e..92b49523 100644 --- a/src/simulator/cache.rs +++ b/src/simulator/cache.rs @@ -42,7 +42,7 @@ pub(crate) type BoundErrorModelKey = u64; /// Thread-safe LRU cache for subject predictions. /// -/// Used by [`ODE`](crate::ODE) and [`Analytical`](crate::simulator::equation::Analytical) +/// Used by [`ODE`](crate::ODE) and [`Analytical`](crate::simulator::backends::Analytical) /// to avoid recomputing predictions for the same (subject, parameters) pair. /// /// `Clone` produces a shallow clone that shares the same underlying cache data, diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs deleted file mode 100644 index 835bd424..00000000 --- a/src/simulator/equation/mod.rs +++ /dev/null @@ -1,620 +0,0 @@ -//! Handwritten equation families and their shared simulation interfaces. -//! -//! This module is the public home for handwritten [`ODE`], [`Analytical`], and -//! [`SDE`] models, plus the shared [`Equation`] trait and the metadata types -//! that attach public names to parameters, states, routes, and outputs. -//! -//! Use this module when you want to: -//! - choose between deterministic ODE, analytical, and stochastic SDE models -//! - attach metadata so dataset labels such as `"iv"` and `"cp"` resolve by -//! name instead of by dense numeric index -//! - work with prediction or likelihood APIs across equation families -//! -//! # Equation Families -//! -//! - [`ODE`] for deterministic models that must be numerically integrated. -//! - [`Analytical`] for supported closed-form models. -//! - [`SDE`] for stochastic models that use particles. -//! -//! # Labels And Metadata -//! -//! Input and output labels arrive from public data APIs as strings. -//! -//! - Without metadata, handwritten models fall back to numeric labels such as -//! `0` or `1`. -//! - With [`metadata::ModelMetadata`] attached, route and output labels are -//! resolved by name against the declared routes and outputs before -//! simulation. -//! -//! That label-first path is the preferred public workflow for current authoring. -//! -//! # Example -//! -//! ```rust -//! use pharmsol::{metadata, ModelKind}; -//! -//! let metadata = metadata::new("one_cmt") -//! .kind(ModelKind::Ode) -//! .parameters(["cl", "v"]) -//! .states(["central"]) -//! .outputs(["cp"]) -//! .route(metadata::Route::infusion("iv").to_state("central")) -//! .validate() -//! .unwrap(); -//! -//! assert_eq!(metadata.route("iv").unwrap().destination(), "central"); -//! assert!(metadata.output("cp").is_some()); -//! ``` - -use std::{fmt::Debug, sync::Arc}; -pub mod analytical; -pub mod metadata; -pub mod ode; -pub mod sde; -pub use analytical::*; -pub use metadata::*; -pub use ode::*; -pub use pharmsol_dsl::{AnalyticalKernel, ModelKind}; -use pharmsol_dsl::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; -pub use sde::*; - -use crate::{ - error_model::{AssayErrorModels, BoundAssayErrorModels}, - simulator::{cache::BoundErrorModelCache, Fa, Lag}, - Covariates, Event, Infusion, InputLabel, Observation, Occasion, OutputLabel, Parameters, - PharmsolError, Subject, -}; - -use super::likelihood::Prediction; - -/// Trait for state vectors that can receive bolus doses. -pub trait State { - /// Add a bolus dose to the state at the specified resolved input index. - /// - /// # Parameters - /// - `input`: The resolved dense input index used by the execution layer - /// - `amount`: The bolus amount - fn add_bolus(&mut self, input: usize, amount: f64); -} - -/// Trait for prediction containers. -pub trait Predictions: Default { - /// Create a new prediction container with specified capacity. - /// - /// # Parameters - /// - `nparticles`: Number of particles (for SDE) - /// - /// # Returns - /// A new predictions container - fn new(_nparticles: usize) -> Self { - Default::default() - } - - /// Calculate the sum of squared errors for all predictions. - /// - /// # Returns - /// The sum of squared errors - fn squared_error(&self) -> f64; - - /// Get all predictions as a vector. - /// - /// # Returns - /// Vector of prediction objects - fn get_predictions(&self) -> Vec; - - /// Calculate the log-likelihood of the predictions given an error model. - /// - /// This is numerically more stable than computing the likelihood and taking its log, - /// especially for extreme values or many observations. - /// - /// # Parameters - /// - `error_models`: The error models for computing observation variance - /// - /// # Returns - /// The sum of log-likelihoods for all predictions - fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result; -} - -/// Trait for enabling prediction caching on equation types. -/// -/// Caching is **enabled by default** with a capacity of 100,000 entries. -/// Use these methods to adjust capacity, clear entries, or disable caching. -/// -/// # Example -/// ```ignore -/// use pharmsol::*; -/// -/// // Caching is on by default: -/// let ode = ODE::new(diffeq, lag, fa, init, out); -/// -/// // Adjust capacity: -/// let ode = ODE::new(diffeq, lag, fa, init, out) -/// .with_cache_capacity(50_000); -/// -/// // Disable caching: -/// let ode = ODE::new(diffeq, lag, fa, init, out) -/// .disable_cache(); -/// ``` -pub trait Cache: Sized { - /// Enable caching with the given maximum number of entries. - /// - /// When caching is enabled, results for the same inputs are stored and reused. - /// Cloned equations share the same cache. - /// - /// If caching is already enabled, this **replaces** the cache with a new, empty - /// one of the given size — all previously cached entries are discarded. - fn with_cache_capacity(self, size: u64) -> Self; - - /// Enable caching with the default size (100,000 entries). - /// - /// If caching is already enabled, this **replaces** the cache with a new, - /// empty one — all previously cached entries are discarded. - fn enable_cache(self) -> Self; - - /// Clear all entries from this equation's cache, if caching is enabled. - /// - /// The cache itself remains active (with the same capacity). - /// Does nothing if caching is not enabled. - fn clear_cache(&self); - - /// Disable caching. - /// - /// Disables caching and discards all cached entries. - fn disable_cache(self) -> Self; -} - -/// Associated state and prediction container types for an equation family. -pub trait EquationTypes { - /// The state vector type - type S: State + Debug; - /// The predictions container type - type P: Predictions; -} - -pub(crate) trait EquationPriv: EquationTypes { - // fn get_init(&self) -> &Init; - // fn get_out(&self) -> &Out; - fn lag(&self) -> &Lag; - fn fa(&self) -> &Fa; - fn get_nstates(&self) -> usize; - fn get_ndrugs(&self) -> usize; - fn get_nouteqs(&self) -> usize; - fn metadata(&self) -> Option<&ValidatedModelMetadata>; - fn solve( - &self, - state: &mut Self::S, - parameters: &[f64], - covariates: &Covariates, - infusions: &[Infusion], - start_time: f64, - end_time: f64, - ) -> Result<(), PharmsolError>; - fn nparticles(&self) -> usize { - 1 - } - - fn resolve_input_label( - &self, - label: &InputLabel, - expected_kind: RouteKind, - ) -> Result { - if let Some(metadata) = self.metadata() { - let route = metadata - .route(label.as_str()) - .or_else(|| { - canonical_numeric_alias(label.as_str(), NUMERIC_ROUTE_PREFIX) - .and_then(|alias| metadata.route(alias.as_str())) - }) - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: label.to_string(), - })?; - - if route.kind() != expected_kind { - return Err(PharmsolError::UnsupportedInputRouteKind { - input: route.input_index(), - kind: match expected_kind { - RouteKind::Bolus => pharmsol_dsl::RouteKind::Bolus, - RouteKind::Infusion => pharmsol_dsl::RouteKind::Infusion, - }, - }); - } - - return Ok(route.input_index()); - } - - label - .index() - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: label.to_string(), - }) - } - - fn resolve_output_label(&self, label: &OutputLabel) -> Result { - if let Some(metadata) = self.metadata() { - return metadata - .output_index(label.as_str()) - .or_else(|| { - canonical_numeric_alias(label.as_str(), NUMERIC_OUTPUT_PREFIX) - .and_then(|alias| metadata.output_index(alias.as_str())) - }) - .ok_or_else(|| PharmsolError::UnknownOutputLabel { - label: label.to_string(), - }); - } - - label - .index() - .ok_or_else(|| PharmsolError::UnknownOutputLabel { - label: label.to_string(), - }) - } - - fn resolve_occasion_events( - &self, - occasion: &Occasion, - parameters: &[f64], - covariates: &Covariates, - ) -> Result, PharmsolError> { - let mut resolved = occasion.clone(); - - for event in resolved.events_iter_mut() { - match event { - Event::Bolus(bolus) => { - let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; - bolus.set_input(input); - } - Event::Infusion(infusion) => { - let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; - infusion.set_input(input); - } - Event::Observation(observation) => { - let outeq = self.resolve_output_label(observation.outeq())?; - observation.set_outeq(outeq); - } - } - } - - Ok(resolved.process_events(Some((self.fa(), self.lag(), parameters, covariates)), true)) - } - #[allow(dead_code)] - fn is_sde(&self) -> bool { - false - } - - #[allow(clippy::too_many_arguments)] - fn process_observation( - &self, - parameters: &[f64], - observation: &Observation, - error_models: Option<&AssayErrorModels>, - time: f64, - covariates: &Covariates, - x: &mut Self::S, - likelihood: &mut Vec, - output: &mut Self::P, - ) -> Result<(), PharmsolError>; - - fn initial_state( - &self, - parameters: &[f64], - covariates: &Covariates, - occasion_index: usize, - ) -> Self::S; - - #[allow(clippy::too_many_arguments)] - fn simulate_event( - &self, - parameters: &[f64], - event: &Event, - next_event: Option<&Event>, - error_models: Option<&AssayErrorModels>, - covariates: &Covariates, - x: &mut Self::S, - infusions: &mut Vec, - likelihood: &mut Vec, - output: &mut Self::P, - ) -> Result<(), PharmsolError> { - match event { - Event::Bolus(bolus) => { - let input = - bolus - .input_index() - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: bolus.input().to_string(), - })?; - - if input >= self.get_ndrugs() { - return Err(PharmsolError::InputOutOfRange { - input, - ndrugs: self.get_ndrugs(), - }); - } - x.add_bolus(input, bolus.amount()); - } - Event::Infusion(infusion) => { - infusions.push(infusion.clone()); - } - Event::Observation(observation) => { - self.process_observation( - parameters, - observation, - error_models, - event.time(), - covariates, - x, - likelihood, - output, - )?; - } - } - - if let Some(next_event) = next_event { - self.solve( - x, - parameters, - covariates, - infusions, - event.time(), - next_event.time(), - )?; - } - Ok(()) - } -} - -fn canonical_numeric_alias(label: &str, prefix: &str) -> Option { - if label.is_empty() || !label.chars().all(|ch| ch.is_ascii_digit()) { - return None; - } - Some(format!("{prefix}{label}")) -} - -/// Trait for handwritten model equations that can be simulated. -/// -/// [`Equation`] is the shared interface implemented by handwritten [`ODE`], -/// [`Analytical`], and [`SDE`] models. -/// -/// Subject data enters this layer through public labels on dose and observation -/// events. If metadata is attached to the equation, those labels are resolved by -/// name before simulation. Otherwise, the execution layer expects numeric labels -/// that can be interpreted as dense indices. -/// -/// # Likelihood Calculation -/// -/// Use [`estimate_log_likelihood`](Self::estimate_log_likelihood) for numerically stable -/// likelihood computation. The deprecated [`estimate_likelihood`](Self::estimate_likelihood) -/// is provided for backward compatibility. -#[allow(private_bounds)] -pub trait Equation: EquationPriv + 'static + Clone + Sync { - #[doc(hidden)] - fn bound_error_model_cache(&self) -> Option<&BoundErrorModelCache> { - None - } - - #[doc(hidden)] - fn bind_error_models<'a>( - &'a self, - error_models: &'a AssayErrorModels, - ) -> Result, PharmsolError> { - if let Some(cache) = self.bound_error_model_cache() { - let key = error_models.hash(); - if let Some(bound_error_models) = cache.get(&key) { - return Ok(BoundAssayErrorModels::Shared(bound_error_models)); - } - - return match error_models.bind_to(self)? { - BoundAssayErrorModels::Owned(bound_error_models) => { - let bound_error_models = Arc::new(bound_error_models); - cache.insert(key, Arc::clone(&bound_error_models)); - Ok(BoundAssayErrorModels::Shared(bound_error_models)) - } - bound_error_models => Ok(bound_error_models), - }; - } - - Ok(error_models.bind_to(self)?) - } - - /// Estimate the likelihood of the subject given the parameters and error model. - /// - /// **Deprecated**: Use [`estimate_log_likelihood`](Self::estimate_log_likelihood) instead - /// for better numerical stability, especially with many observations or extreme parameter values. - /// - /// This function calculates how likely the observed data is given the model - /// parameters and error model. It may use caching for performance. - /// - /// # Parameters - /// - `subject`: The subject data - /// - `parameters`: The parameter values - /// - `error_model`: The error model - /// - /// # Returns - /// The likelihood value (product of individual observation likelihoods) - #[deprecated( - since = "0.23.0", - note = "Use estimate_log_likelihood() instead for better numerical stability" - )] - fn estimate_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result; - - /// Estimate the log-likelihood of the subject given the parameters and error model. - /// - /// This function calculates the log of how likely the observed data is given the model - /// parameters and error model. It is numerically more stable than `estimate_likelihood` - /// for extreme values or many observations. - /// - /// Uses observation-based sigma, appropriate for non-parametric algorithms. - /// For parametric algorithms (SAEM, FOCE), use [`crate::ResidualErrorModels`] directly. - /// - /// # Parameters - /// - `subject`: The subject data - /// - `parameters`: The parameter values - /// - `error_models`: The error model - /// - /// # Returns - /// The log-likelihood value (sum of individual observation log-likelihoods) - fn estimate_log_likelihood( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: &AssayErrorModels, - ) -> Result; - - fn kind() -> EqnKind; - - #[doc(hidden)] - fn estimate_predictions_dense( - &self, - subject: &Subject, - parameters: &[f64], - ) -> Result { - Ok(self.simulate_subject_dense(subject, parameters, None)?.0) - } - - #[doc(hidden)] - fn estimate_log_likelihood_dense( - &self, - subject: &Subject, - parameters: &[f64], - error_models: &AssayErrorModels, - ) -> Result { - let bound_error_models = self.bind_error_models(error_models)?; - let predictions = self.estimate_predictions_dense(subject, parameters)?; - predictions.log_likelihood(&bound_error_models) - } - - #[doc(hidden)] - fn simulate_subject_dense( - &self, - subject: &Subject, - parameters: &[f64], - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - let bound_error_models = match error_models { - Some(error_models) => Some(self.bind_error_models(error_models)?), - None => None, - }; - let bound_error_models = bound_error_models.as_ref().map(|models| &**models); - - let mut output = Self::P::new(self.nparticles()); - let mut likelihood = Vec::new(); - for occasion in subject.occasions() { - let covariates = occasion.covariates(); - - let mut x = self.initial_state(parameters, covariates, occasion.index()); - let mut infusions = Vec::new(); - let events = self.resolve_occasion_events(occasion, parameters, covariates)?; - for (index, event) in events.iter().enumerate() { - self.simulate_event( - parameters, - event, - events.get(index + 1), - bound_error_models, - covariates, - &mut x, - &mut infusions, - &mut likelihood, - &mut output, - )?; - } - } - let ll = bound_error_models.map(|_| likelihood.iter().product::()); - Ok((output, ll)) - } - - /// Generate predictions for a subject with given parameters. - /// - /// # Parameters - /// - `subject`: The subject data - /// - `parameters`: The parameter values - /// - /// # Returns - /// Predicted concentrations - fn estimate_predictions( - &self, - subject: &Subject, - parameters: &Parameters, - ) -> Result { - self.estimate_predictions_dense(subject, parameters.as_slice()) - } - - /// Get the number of output equations in the model. - fn nouteqs(&self) -> usize { - self.get_nouteqs() - } - - /// Get the number of state variables in the model. - fn nstates(&self) -> usize { - self.get_nstates() - } - - /// Build a label-aware [`AssayErrorModels`] set for this equation. - /// - /// Handwritten equations resolve output labels from attached metadata. - /// Equations without metadata fall back to an explicit unbound set so dense - /// output-slot workflows remain available without adding runtime lookup cost. - #[doc(hidden)] - fn assay_error_models(&self) -> AssayErrorModels { - self.metadata() - .map(|metadata| { - AssayErrorModels::with_output_names( - metadata.outputs().iter().map(|output| output.name()), - ) - }) - .unwrap_or_else(AssayErrorModels::empty) - } - - /// Simulate a subject with given parameters and optionally calculate likelihood. - /// - /// # Parameters - /// - `subject`: The subject data - /// - `parameters`: The parameter values - /// - `error_model`: The error model (optional) - /// - /// # Returns - /// A tuple containing predictions and optional likelihood - fn simulate_subject( - &self, - subject: &Subject, - parameters: &Parameters, - error_models: Option<&AssayErrorModels>, - ) -> Result<(Self::P, Option), PharmsolError> { - self.simulate_subject_dense(subject, parameters.as_slice(), error_models) - } -} - -/// Runtime family tag for handwritten equations. -#[repr(C)] -#[derive(Clone, Debug)] -pub enum EqnKind { - ODE = 0, - Analytical = 1, - SDE = 2, -} - -impl EqnKind { - pub fn to_str(&self) -> &'static str { - match self { - Self::ODE => "EqnKind::ODE", - Self::Analytical => "EqnKind::Analytical", - Self::SDE => "EqnKind::SDE", - } - } -} - -/// Hash parameter vectors to a u64 for cache key generation. -#[inline(always)] -pub(crate) fn parameters_hash(parameters: &[f64]) -> u64 { - use std::hash::{Hash, Hasher}; - let mut hasher = ahash::AHasher::default(); - for &value in parameters { - // Normalize -0.0 to 0.0 for consistent hashing - let bits = if value == 0.0 { 0u64 } else { value.to_bits() }; - bits.hash(&mut hasher); - } - hasher.finish() -} diff --git a/src/simulator/likelihood/matrix.rs b/src/simulator/likelihood/matrix.rs index 2ff73273..af128155 100644 --- a/src/simulator/likelihood/matrix.rs +++ b/src/simulator/likelihood/matrix.rs @@ -6,8 +6,9 @@ use ndarray::{Array2, Axis, ShapeBuilder}; use rayon::prelude::*; +use crate::core::Simulate; use crate::data::error_model::AssayErrorModels; -use crate::{Data, Equation, PharmsolError}; +use crate::{Data, PharmsolError}; use super::progress::ProgressTracker; @@ -50,7 +51,7 @@ use super::progress::ProgressTracker; /// )?; /// ``` pub fn log_likelihood_matrix( - equation: &impl Equation, + model: &impl Simulate, subjects: &Data, support_points: &Array2, error_models: &AssayErrorModels, @@ -84,7 +85,7 @@ pub fn log_likelihood_matrix( let subject = &subject_slice[i]; for (element, support_point) in row.iter_mut().zip(support_point_rows.iter()) { - *element = equation.estimate_log_likelihood_dense( + *element = model.log_likelihood( subject, support_point.as_slice(), error_models, @@ -115,13 +116,13 @@ pub fn log_likelihood_matrix( note = "Use log_likelihood_matrix() with LikelihoodMatrixOptions instead" )] pub fn log_psi( - equation: &impl Equation, + model: &impl Simulate, subjects: &Data, support_points: &Array2, error_models: &AssayErrorModels, progress: bool, ) -> Result, PharmsolError> { - log_likelihood_matrix(equation, subjects, support_points, error_models, progress) + log_likelihood_matrix(model, subjects, support_points, error_models, progress) } /// Calculate the likelihood matrix (deprecated). @@ -136,14 +137,14 @@ pub fn log_psi( note = "Use log_likelihood_matrix() instead and exponentiate if needed" )] pub fn psi( - equation: &impl Equation, + model: &impl Simulate, subjects: &Data, support_points: &Array2, error_models: &AssayErrorModels, progress: bool, ) -> Result, PharmsolError> { let log_psi_matrix = - log_likelihood_matrix(equation, subjects, support_points, error_models, progress)?; + log_likelihood_matrix(model, subjects, support_points, error_models, progress)?; // Exponentiate to get likelihoods (may underflow to 0 for extreme values) Ok(log_psi_matrix.mapv(f64::exp)) diff --git a/src/simulator/likelihood/mod.rs b/src/simulator/likelihood/mod.rs index 62ee58b0..d25a8c2d 100644 --- a/src/simulator/likelihood/mod.rs +++ b/src/simulator/likelihood/mod.rs @@ -82,7 +82,8 @@ pub use subject::{PopulationPredictions, SubjectPredictions}; use ndarray::{Array2, Axis}; use rayon::prelude::*; -use crate::{Data, Equation, PharmsolError, Predictions, Subject}; +use crate::core::{PredictionsContainer, Simulate}; +use crate::{Data, PharmsolError, Subject}; /// Compute log-likelihoods for all subjects in parallel, where each subject /// has its own parameter vector. @@ -117,7 +118,7 @@ use crate::{Data, Equation, PharmsolError, Predictions, Subject}; /// )?; /// ``` pub fn log_likelihood_batch( - equation: &impl Equation, + model: &impl Simulate, subjects: &Data, parameters: &Array2, residual_error_models: &crate::ResidualErrorModels, @@ -134,14 +135,14 @@ pub fn log_likelihood_batch( } let score_subject = |subject: &Subject, parameter_row: &[f64]| { - let predictions = match equation.estimate_predictions_dense(subject, parameter_row) { + let predictions = match model.predictions(subject, parameter_row) { Ok(preds) => preds, Err(_) => return f64::NEG_INFINITY, }; let obs_pred_pairs = predictions - .get_predictions() - .into_iter() + .predictions() + .iter() .filter_map(|pred| { pred.observation() .map(|obs| (pred.outeq(), obs, pred.prediction())) @@ -203,21 +204,21 @@ pub fn log_likelihood_batch( /// ); /// ``` pub fn log_likelihood_subject( - equation: &impl Equation, + model: &impl Simulate, subject: &Subject, params: &crate::Parameters, residual_error_models: &crate::ResidualErrorModels, ) -> f64 { // Simulate to get predictions - let predictions = match equation.estimate_predictions_dense(subject, params.as_slice()) { + let predictions = match model.predictions(subject, params.as_slice()) { Ok(preds) => preds, Err(_) => return f64::NEG_INFINITY, }; // Extract (outeq, observation, prediction) tuples and compute log-likelihood let obs_pred_pairs = predictions - .get_predictions() - .into_iter() + .predictions() + .iter() .filter_map(|pred| { pred.observation() .map(|obs| (pred.outeq(), obs, pred.prediction())) diff --git a/src/simulator/likelihood/subject.rs b/src/simulator/likelihood/subject.rs index 05c1417f..37b7f2f5 100644 --- a/src/simulator/likelihood/subject.rs +++ b/src/simulator/likelihood/subject.rs @@ -37,6 +37,24 @@ impl Predictions for SubjectPredictions { } } +impl crate::core::PredictionsContainer for SubjectPredictions { + fn new(_nparticles: usize) -> Self { + Self::default() + } + + fn push(&mut self, pred: Prediction) { + self.predictions.push(pred); + } + + fn predictions(&self) -> &[Prediction] { + &self.predictions + } + + fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result { + SubjectPredictions::log_likelihood(self, error_models) + } +} + impl SubjectPredictions { /// Calculate the log-likelihood of all predictions given an error model. /// diff --git a/src/simulator/mod.rs b/src/simulator/mod.rs index 058ca125..3b92622b 100644 --- a/src/simulator/mod.rs +++ b/src/simulator/mod.rs @@ -1,5 +1,5 @@ pub mod cache; -pub mod equation; +pub mod backends; pub(crate) mod likelihood; use diffsol::{NalgebraMat, NalgebraVec}; @@ -41,7 +41,7 @@ pub type M = NalgebraMat; pub type DiffEq = fn(&V, &V, T, &mut V, &V, &V, &Covariates); /// This closure represents an Analytical solution of the model. -/// See [`equation::analytical`] module for examples. +/// See [`backends::analytical`] module for examples. /// /// # Parameters /// - `x`: The state vector at time t @@ -209,7 +209,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; /// # Example /// ```ignore /// // Using the builder pattern on ODE/Analytical/SDE -/// let ode = equation::ODE::new(diffeq, lag, fa, init, out) +/// let ode = backends::ODE::new(diffeq, lag, fa, init, out) /// .with_nstates(2) /// .with_ndrugs(1) /// .with_nout(1); diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index 4ff140e8..570dd56b 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -50,7 +50,7 @@ fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) .build() } -fn macro_one_compartment() -> equation::Analytical { +fn macro_one_compartment() -> backends::Analytical { analytical! { name: "one_cpt_iv", params: [ke, v], @@ -66,9 +66,9 @@ fn macro_one_compartment() -> equation::Analytical { } } -fn handwritten_one_compartment() -> equation::Analytical { - equation::Analytical::new( - equation::one_compartment, +fn handwritten_one_compartment() -> backends::Analytical { + backends::Analytical::new( + backends::one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, |_p, _t, _cov| fa! {}, @@ -82,18 +82,18 @@ fn handwritten_one_compartment() -> equation::Analytical { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cpt_iv") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("one_cpt_iv") + .kind(backends::ModelKind::Analytical) .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) - .route(equation::Route::infusion("iv").to_state("central")) - .analytical_kernel(equation::AnalyticalKernel::OneCompartment), + .route(backends::Route::infusion("iv").to_state("central")) + .analytical_kernel(backends::AnalyticalKernel::OneCompartment), ) .expect("handwritten analytical metadata should validate") } -fn macro_one_compartment_with_absorption() -> equation::Analytical { +fn macro_one_compartment_with_absorption() -> backends::Analytical { analytical! { name: "one_cmt_abs", params: [ka, ke, v, tlag, f_oral], @@ -119,9 +119,9 @@ fn macro_one_compartment_with_absorption() -> equation::Analytical { } } -fn handwritten_one_compartment_with_absorption() -> equation::Analytical { - equation::Analytical::new( - equation::one_compartment_with_absorption, +fn handwritten_one_compartment_with_absorption() -> backends::Analytical { + backends::Analytical::new( + backends::one_compartment_with_absorption, |_p, _t, _cov| {}, |p, _t, _cov| { fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); @@ -144,23 +144,23 @@ fn handwritten_one_compartment_with_absorption() -> equation::Analytical { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_abs") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("one_cmt_abs") + .kind(backends::ModelKind::Analytical) .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["gut", "central"]) .outputs(["cp"]) .route( - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .with_lag() .with_bioavailability(), ) - .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + .analytical_kernel(backends::AnalyticalKernel::OneCompartmentWithAbsorption), ) .expect("handwritten absorption metadata should validate") } -fn macro_shared_input_analytical() -> equation::Analytical { +fn macro_shared_input_analytical() -> backends::Analytical { analytical! { name: "one_cmt_abs_shared", params: [ka, ke, v, tlag, f_oral], @@ -183,9 +183,9 @@ fn macro_shared_input_analytical() -> equation::Analytical { } } -fn handwritten_shared_input_analytical() -> equation::Analytical { - equation::Analytical::new( - equation::one_compartment_with_absorption, +fn handwritten_shared_input_analytical() -> backends::Analytical { + backends::Analytical::new( + backends::one_compartment_with_absorption, |_p, _t, _cov| {}, |p, _t, _cov| { fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); @@ -205,24 +205,24 @@ fn handwritten_shared_input_analytical() -> equation::Analytical { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_abs_shared") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("one_cmt_abs_shared") + .kind(backends::ModelKind::Analytical) .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .with_lag() .with_bioavailability(), - equation::Route::infusion("iv").to_state("central"), + backends::Route::infusion("iv").to_state("central"), ]) - .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + .analytical_kernel(backends::AnalyticalKernel::OneCompartmentWithAbsorption), ) .expect("handwritten shared-input analytical metadata should validate") } -fn macro_covariate_analytical() -> equation::Analytical { +fn macro_covariate_analytical() -> backends::Analytical { analytical! { name: "one_cmt_abs_covariates", params: [ka, ke0, v, tlag, f_oral, base_gut, base_central], @@ -259,8 +259,8 @@ fn macro_covariate_analytical() -> equation::Analytical { } } -fn handwritten_covariate_analytical() -> equation::Analytical { - equation::Analytical::new( +fn handwritten_covariate_analytical() -> backends::Analytical { + backends::Analytical::new( |x, p, t, rateiv, cov| { fetch_params!(p, ka, ke0, _v, _tlag, _f_oral, _base_gut, _base_central); fetch_cov!(cov, t, wt, renal); @@ -269,7 +269,7 @@ fn handwritten_covariate_analytical() -> equation::Analytical { let renal_scale = (renal / 90.0).powf(0.25); let ke = ke0 * wt_scale * renal_scale; let projected = pharmsol::__macro_support::vector_from_values(vec![ka, ke]); - equation::one_compartment_with_absorption(x, &projected, t, rateiv, cov) + backends::one_compartment_with_absorption(x, &projected, t, rateiv, cov) }, |_p, _t, _cov| {}, |p, t, cov| { @@ -305,8 +305,8 @@ fn handwritten_covariate_analytical() -> equation::Analytical { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_abs_covariates") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("one_cmt_abs_covariates") + .kind(backends::ModelKind::Analytical) .parameters([ "ka", "ke0", @@ -317,19 +317,19 @@ fn handwritten_covariate_analytical() -> equation::Analytical { "base_central", ]) .covariates([ - equation::Covariate::continuous("wt"), - equation::Covariate::continuous("renal"), + backends::Covariate::continuous("wt"), + backends::Covariate::continuous("renal"), ]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .with_lag() .with_bioavailability(), - equation::Route::infusion("iv").to_state("central"), + backends::Route::infusion("iv").to_state("central"), ]) - .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + .analytical_kernel(backends::AnalyticalKernel::OneCompartmentWithAbsorption), ) .expect("handwritten covariate analytical metadata should validate") } @@ -395,7 +395,7 @@ fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { assert_eq!(macro_model.state_index("gut"), Some(0)); assert_eq!( macro_metadata.analytical_kernel(), - Some(equation::AnalyticalKernel::OneCompartmentWithAbsorption) + Some(backends::AnalyticalKernel::OneCompartmentWithAbsorption) ); let macro_predictions = macro_model diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index b3479a02..db3ab8f1 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -3,8 +3,8 @@ use approx::assert_relative_eq; #[cfg(feature = "dsl-jit")] use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; #[cfg(feature = "dsl-jit")] -use pharmsol::equation::RouteInputPolicy; -use pharmsol::equation::{ +use pharmsol::backends::RouteInputPolicy; +use pharmsol::backends::{ self, AnalyticalKernel, RouteKind as HandwrittenRouteKind, ValidatedModelMetadata, }; use pharmsol::prelude::*; @@ -594,7 +594,7 @@ fn handwritten_route_input_policy_view( .collect() } -fn macro_ode_model() -> equation::ODE { +fn macro_ode_model() -> backends::ODE { ode! { name: "one_cmt_oral_covariate_parity", params: [ka, cl, v, tlag, f_oral], @@ -621,8 +621,8 @@ fn macro_ode_model() -> equation::ODE { } } -fn handwritten_ode_macro_model() -> equation::ODE { - equation::ODE::new( +fn handwritten_ode_macro_model() -> backends::ODE { + backends::ODE::new( |_x, _p, _t, dx, _bolus, _rateiv, _cov| { dx[0] = 0.0; dx[1] = 0.0; @@ -638,13 +638,13 @@ fn handwritten_ode_macro_model() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_oral_covariate_parity") + pharmsol::metadata::new("one_cmt_oral_covariate_parity") .parameters(["ka", "cl", "v", "tlag", "f_oral"]) - .covariates([equation::Covariate::continuous("wt")]) + .covariates([backends::Covariate::continuous("wt")]) .states(["depot", "central"]) .outputs(["cp"]) .route( - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .inject_input_to_destination() .with_lag() @@ -654,8 +654,8 @@ fn handwritten_ode_macro_model() -> equation::ODE { .expect("handwritten macro-shape ODE metadata should validate") } -fn handwritten_ode_model() -> equation::ODE { - equation::ODE::new( +fn handwritten_ode_model() -> backends::ODE { + backends::ODE::new( |_x, _p, _t, dx, _bolus, _rateiv, _cov| { dx[0] = 0.0; dx[1] = 0.0; @@ -671,18 +671,18 @@ fn handwritten_ode_model() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_oral_iv") + pharmsol::metadata::new("one_cmt_oral_iv") .parameters(["ka", "cl", "v", "tlag", "f_oral"]) - .covariates([equation::Covariate::continuous("wt")]) + .covariates([backends::Covariate::continuous("wt")]) .states(["depot", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .inject_input_to_destination() .with_lag() .with_bioavailability(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), @@ -691,7 +691,7 @@ fn handwritten_ode_model() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_input_macro_ode() -> equation::ODE { +fn runtime_shared_input_macro_ode() -> backends::ODE { ode! { name: "shared_input_one_cpt", params: [ka, ke, v, tlag, f_oral], @@ -718,8 +718,8 @@ fn runtime_shared_input_macro_ode() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_input_handwritten_ode() -> equation::ODE { - equation::ODE::new( +fn runtime_shared_input_handwritten_ode() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); dx[0] = bolus[0] - ka * x[0]; @@ -743,17 +743,17 @@ fn runtime_shared_input_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_input_one_cpt") + pharmsol::metadata::new("shared_input_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .with_lag() .with_bioavailability() .inject_input_to_destination(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ]), @@ -762,8 +762,8 @@ fn runtime_shared_input_handwritten_ode() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_mismatched_shared_input_ode() -> equation::ODE { - equation::ODE::new( +fn runtime_mismatched_shared_input_ode() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, _bolus, _rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); dx[0] = -ka * x[0]; @@ -787,17 +787,17 @@ fn runtime_mismatched_shared_input_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_input_one_cpt_mismatched") + pharmsol::metadata::new("shared_input_one_cpt_mismatched") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .with_lag() .with_bioavailability() .expect_explicit_input(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), @@ -806,7 +806,7 @@ fn runtime_mismatched_shared_input_ode() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_input_macro_analytical() -> equation::Analytical { +fn runtime_shared_input_macro_analytical() -> backends::Analytical { analytical! { name: "one_cmt_abs_shared", params: [ka, ke, v, tlag, f_oral], @@ -830,9 +830,9 @@ fn runtime_shared_input_macro_analytical() -> equation::Analytical { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_input_handwritten_analytical() -> equation::Analytical { - equation::Analytical::new( - equation::one_compartment_with_absorption, +fn runtime_shared_input_handwritten_analytical() -> backends::Analytical { + backends::Analytical::new( + backends::one_compartment_with_absorption, |_p, _t, _cov| {}, |p, _t, _cov| { fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); @@ -852,25 +852,25 @@ fn runtime_shared_input_handwritten_analytical() -> equation::Analytical { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_abs_shared") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("one_cmt_abs_shared") + .kind(backends::ModelKind::Analytical) .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .with_lag() .with_bioavailability(), - equation::Route::infusion("iv").to_state("central"), + backends::Route::infusion("iv").to_state("central"), ]) - .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + .analytical_kernel(backends::AnalyticalKernel::OneCompartmentWithAbsorption), ) .expect("handwritten shared-input analytical metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_shared_input_macro_sde() -> equation::SDE { +fn runtime_shared_input_macro_sde() -> backends::SDE { sde! { name: "one_cmt_shared_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], @@ -906,8 +906,8 @@ fn runtime_shared_input_macro_sde() -> equation::SDE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_input_handwritten_sde() -> equation::SDE { - equation::SDE::new( +fn runtime_shared_input_handwritten_sde() -> backends::SDE { + backends::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); dx[0] = -ka * x[0]; @@ -939,18 +939,18 @@ fn runtime_shared_input_handwritten_sde() -> equation::SDE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_shared_sde") - .kind(equation::ModelKind::Sde) + pharmsol::metadata::new("one_cmt_shared_sde") + .kind(backends::ModelKind::Sde) .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .inject_input_to_destination() .with_lag() .with_bioavailability(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ]) @@ -991,7 +991,7 @@ fn particle_prediction_means(predictions: &ndarray::Array2) -> Vec equation::Analytical { +fn macro_analytical_model() -> backends::Analytical { analytical! { name: "one_cmt_abs_parity", params: [ka, ke, v], @@ -1007,9 +1007,9 @@ fn macro_analytical_model() -> equation::Analytical { } } -fn handwritten_analytical_model() -> equation::Analytical { - equation::Analytical::new( - equation::one_compartment_with_absorption, +fn handwritten_analytical_model() -> backends::Analytical { + backends::Analytical::new( + backends::one_compartment_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, |_p, _t, _cov| fa! {}, @@ -1022,18 +1022,18 @@ fn handwritten_analytical_model() -> equation::Analytical { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_abs_parity") + pharmsol::metadata::new("one_cmt_abs_parity") .kind(ModelKind::Analytical) .parameters(["ka", "ke", "v"]) .states(["depot", "central"]) .outputs(["cp"]) - .route(equation::Route::bolus("oral").to_state("depot")) + .route(backends::Route::bolus("oral").to_state("depot")) .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption), ) .expect("handwritten analytical metadata should validate") } -fn macro_sde_model() -> equation::SDE { +fn macro_sde_model() -> backends::SDE { sde! { name: "one_cmt_sde_macro_parity", params: [ka, ke, v, sigma], @@ -1057,8 +1057,8 @@ fn macro_sde_model() -> equation::SDE { } } -fn handwritten_sde_model() -> equation::SDE { - equation::SDE::new( +fn handwritten_sde_model() -> backends::SDE { + backends::SDE::new( |_x, _p, _t, dx, _rateiv, _cov| { dx[0] = 0.0; dx[1] = 0.0; @@ -1079,14 +1079,14 @@ fn handwritten_sde_model() -> equation::SDE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_sde_parity") + pharmsol::metadata::new("one_cmt_sde_parity") .kind(ModelKind::Sde) .parameters(["ka", "ke", "v", "sigma"]) - .covariates([equation::Covariate::locf("wt")]) + .covariates([backends::Covariate::locf("wt")]) .states(["depot", "central"]) .outputs(["cp"]) .route( - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .inject_input_to_destination(), ) @@ -1095,8 +1095,8 @@ fn handwritten_sde_model() -> equation::SDE { .expect("handwritten SDE metadata should validate") } -fn handwritten_sde_macro_model() -> equation::SDE { - equation::SDE::new( +fn handwritten_sde_macro_model() -> backends::SDE { + backends::SDE::new( |_x, _p, _t, dx, _rateiv, _cov| { dx[0] = 0.0; dx[1] = 0.0; @@ -1117,13 +1117,13 @@ fn handwritten_sde_macro_model() -> equation::SDE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_sde_macro_parity") + pharmsol::metadata::new("one_cmt_sde_macro_parity") .kind(ModelKind::Sde) .parameters(["ka", "ke", "v", "sigma"]) .states(["depot", "central"]) .outputs(["cp"]) .route( - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .inject_input_to_destination(), ) @@ -1133,8 +1133,8 @@ fn handwritten_sde_macro_model() -> equation::SDE { } #[cfg(feature = "dsl-jit")] -fn mismatched_ode_model() -> equation::ODE { - equation::ODE::new( +fn mismatched_ode_model() -> backends::ODE { + backends::ODE::new( |_x, _p, _t, dx, _bolus, _rateiv, _cov| { dx[0] = 0.0; dx[1] = 0.0; @@ -1150,18 +1150,18 @@ fn mismatched_ode_model() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_oral_iv") + pharmsol::metadata::new("one_cmt_oral_iv") .parameters(["ka", "cl", "v", "tlag", "f_oral"]) - .covariates([equation::Covariate::continuous("wt")]) + .covariates([backends::Covariate::continuous("wt")]) .states(["depot", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .expect_explicit_input() .with_lag() .with_bioavailability(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs index 66c5c116..eb83dcbf 100644 --- a/tests/full_feature_macro_parity.rs +++ b/tests/full_feature_macro_parity.rs @@ -7,7 +7,7 @@ fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { .fold(0.0_f64, f64::max) } -fn macro_ode_model() -> equation::ODE { +fn macro_ode_model() -> backends::ODE { ode! { name: "ode_full_feature_parity", params: [ka, ke, kcp, kpc, v, tlag, f_oral, base_depot, base_central, base_peripheral], @@ -51,8 +51,8 @@ fn macro_ode_model() -> equation::ODE { } } -fn handwritten_ode_model() -> equation::ODE { - equation::ODE::new( +fn handwritten_ode_model() -> backends::ODE { + backends::ODE::new( |x, p, t, dx, bolus, rateiv, cov| { fetch_params!( p, @@ -161,7 +161,7 @@ fn handwritten_ode_model() -> equation::ODE { .with_ndrugs(2) .with_nout(1) .with_metadata( - equation::metadata::new("ode_full_feature_parity") + pharmsol::metadata::new("ode_full_feature_parity") .parameters([ "ka", "ke", @@ -175,21 +175,21 @@ fn handwritten_ode_model() -> equation::ODE { "base_peripheral", ]) .covariates([ - equation::Covariate::continuous("wt"), - equation::Covariate::continuous("renal"), + backends::Covariate::continuous("wt"), + backends::Covariate::continuous("renal"), ]) .states(["depot", "central", "peripheral"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .with_lag() .with_bioavailability() .inject_input_to_destination(), - equation::Route::bolus("load") + backends::Route::bolus("load") .to_state("central") .inject_input_to_destination(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ]), @@ -217,7 +217,7 @@ fn build_ode_subject() -> Subject { .build() } -fn macro_analytical_model() -> equation::Analytical { +fn macro_analytical_model() -> backends::Analytical { analytical! { name: "analytical_full_feature_parity", params: [ka, ke0, v, tlag, f_oral, base_gut, base_central], @@ -255,8 +255,8 @@ fn macro_analytical_model() -> equation::Analytical { } } -fn handwritten_analytical_model() -> equation::Analytical { - equation::Analytical::new( +fn handwritten_analytical_model() -> backends::Analytical { + backends::Analytical::new( |x, p, t, rateiv, cov| { fetch_params!(p, ka, ke0, _v, _tlag, _f_oral, _base_gut, _base_central); fetch_cov!(cov, t, wt, renal); @@ -265,7 +265,7 @@ fn handwritten_analytical_model() -> equation::Analytical { let renal_scale = (renal / 90.0).powf(0.25); let ke = ke0 * wt_scale * renal_scale; let projected = pharmsol::__macro_support::vector_from_values(vec![ka, ke]); - equation::one_compartment_with_absorption(x, &projected, t, rateiv, cov) + backends::one_compartment_with_absorption(x, &projected, t, rateiv, cov) }, |_p, _t, _cov| {}, |p, t, cov| { @@ -301,8 +301,8 @@ fn handwritten_analytical_model() -> equation::Analytical { .with_ndrugs(2) .with_nout(1) .with_metadata( - equation::metadata::new("analytical_full_feature_parity") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("analytical_full_feature_parity") + .kind(backends::ModelKind::Analytical) .parameters([ "ka", "ke0", @@ -313,20 +313,20 @@ fn handwritten_analytical_model() -> equation::Analytical { "base_central", ]) .covariates([ - equation::Covariate::continuous("wt"), - equation::Covariate::continuous("renal"), + backends::Covariate::continuous("wt"), + backends::Covariate::continuous("renal"), ]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .with_lag() .with_bioavailability(), - equation::Route::bolus("load").to_state("central"), - equation::Route::infusion("iv").to_state("central"), + backends::Route::bolus("load").to_state("central"), + backends::Route::infusion("iv").to_state("central"), ]) - .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + .analytical_kernel(backends::AnalyticalKernel::OneCompartmentWithAbsorption), ) .expect("handwritten analytical metadata should validate") } diff --git a/tests/numerical_stability.rs b/tests/numerical_stability.rs index 7b096528..fe36ad54 100644 --- a/tests/numerical_stability.rs +++ b/tests/numerical_stability.rs @@ -8,7 +8,7 @@ const ABS_TOL: f64 = 1e-2; fn parameters_for_analytical( label: &str, - analytical: &equation::Analytical, + analytical: &backends::Analytical, param_names: &[&str], params: &[f64], ) -> Parameters { @@ -29,7 +29,7 @@ fn parameters_for_analytical( fn parameters_for_ode( label: &str, - ode: &equation::ODE, + ode: &backends::ODE, param_names: &[&str], params: &[f64], ) -> Parameters { @@ -95,8 +95,8 @@ fn two_compartment_multi_dose_is_well_behaved() { fn assert_models_agree( label: &str, - analytical: &equation::Analytical, - ode: &equation::ODE, + analytical: &backends::Analytical, + ode: &backends::ODE, subject: &Subject, param_names: &[&str], params: &[f64], @@ -150,8 +150,8 @@ fn infusion_subject() -> Subject { builder.build() } -fn infusion_models() -> (equation::Analytical, equation::ODE) { - let analytical = equation::Analytical::new( +fn infusion_models() -> (backends::Analytical, backends::ODE) { + let analytical = backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -166,20 +166,20 @@ fn infusion_models() -> (equation::Analytical, equation::ODE) { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("infusion_reference") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("infusion_reference") + .kind(backends::ModelKind::Analytical) .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("load").to_state("central"), - equation::Route::infusion("iv").to_state("central"), + backends::Route::bolus("load").to_state("central"), + backends::Route::infusion("iv").to_state("central"), ]) - .analytical_kernel(equation::AnalyticalKernel::OneCompartment), + .analytical_kernel(backends::AnalyticalKernel::OneCompartment), ) .expect("infusion analytical metadata should validate"); - let ode = equation::ODE::new( + let ode = backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + rateiv[0] + b[0]; @@ -196,15 +196,15 @@ fn infusion_models() -> (equation::Analytical, equation::ODE) { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("infusion_reference") + pharmsol::metadata::new("infusion_reference") .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("load") + backends::Route::bolus("load") .to_state("central") .expect_explicit_input(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), @@ -230,8 +230,8 @@ fn absorption_subject() -> Subject { builder.build() } -fn absorption_models() -> (equation::Analytical, equation::ODE) { - let analytical = equation::Analytical::new( +fn absorption_models() -> (backends::Analytical, backends::ODE) { + let analytical = backends::Analytical::new( one_compartment_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -246,21 +246,21 @@ fn absorption_models() -> (equation::Analytical, equation::ODE) { .with_ndrugs(2) .with_nout(1) .with_metadata( - equation::metadata::new("absorption_reference") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("absorption_reference") + .kind(backends::ModelKind::Analytical) .parameters(["ka", "ke", "v"]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("load").to_state("central"), - equation::Route::bolus("oral").to_state("gut"), - equation::Route::infusion("iv").to_state("central"), + backends::Route::bolus("load").to_state("central"), + backends::Route::bolus("oral").to_state("gut"), + backends::Route::infusion("iv").to_state("central"), ]) - .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + .analytical_kernel(backends::AnalyticalKernel::OneCompartmentWithAbsorption), ) .expect("absorption analytical metadata should validate"); - let ode = equation::ODE::new( + let ode = backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ka, ke, _v); dx[0] = -ka * x[0] + b[0]; @@ -278,18 +278,18 @@ fn absorption_models() -> (equation::Analytical, equation::ODE) { .with_ndrugs(2) .with_nout(1) .with_metadata( - equation::metadata::new("absorption_reference") + pharmsol::metadata::new("absorption_reference") .parameters(["ka", "ke", "v"]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("load") + backends::Route::bolus("load") .to_state("central") .expect_explicit_input(), - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .expect_explicit_input(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), @@ -313,8 +313,8 @@ fn two_compartment_subject() -> Subject { builder.build() } -fn two_compartment_models() -> (equation::Analytical, equation::ODE) { - let analytical = equation::Analytical::new( +fn two_compartment_models() -> (backends::Analytical, backends::ODE) { + let analytical = backends::Analytical::new( two_compartments, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -329,20 +329,20 @@ fn two_compartment_models() -> (equation::Analytical, equation::ODE) { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("two_comp_reference") - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new("two_comp_reference") + .kind(backends::ModelKind::Analytical) .parameters(["ke", "kcp", "kpc", "v"]) .states(["central", "peripheral"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("load").to_state("central"), - equation::Route::infusion("iv").to_state("central"), + backends::Route::bolus("load").to_state("central"), + backends::Route::infusion("iv").to_state("central"), ]) - .analytical_kernel(equation::AnalyticalKernel::TwoCompartments), + .analytical_kernel(backends::AnalyticalKernel::TwoCompartments), ) .expect("two-compartment analytical metadata should validate"); - let ode = equation::ODE::new( + let ode = backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, kcp, kpc, _v); dx[0] = rateiv[0] - ke * x[0] - kcp * x[0] + kpc * x[1] + b[0]; @@ -360,18 +360,18 @@ fn two_compartment_models() -> (equation::Analytical, equation::ODE) { .with_ndrugs(2) .with_nout(1) .with_metadata( - equation::metadata::new("two_comp_reference") + pharmsol::metadata::new("two_comp_reference") .parameters(["ke", "kcp", "kpc", "v"]) .states(["central", "peripheral"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("load") + backends::Route::bolus("load") .to_state("central") .expect_explicit_input(), - equation::Route::bolus("peripheral_load") + backends::Route::bolus("peripheral_load") .to_state("peripheral") .expect_explicit_input(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index d3088bcf..d941379d 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -50,7 +50,7 @@ fn subject_for_numeric_bolus_route(input: impl ToString, outeq: impl ToString) - .build() } -fn injected_macro_ode() -> equation::ODE { +fn injected_macro_ode() -> backends::ODE { ode! { name: "injected_one_cpt", params: [ke, v], @@ -68,8 +68,8 @@ fn injected_macro_ode() -> equation::ODE { } } -fn injected_handwritten_ode() -> equation::ODE { - equation::ODE::new( +fn injected_handwritten_ode() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -86,12 +86,12 @@ fn injected_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("injected_one_cpt") + pharmsol::metadata::new("injected_one_cpt") .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) .route( - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ), @@ -99,7 +99,7 @@ fn injected_handwritten_ode() -> equation::ODE { .expect("handwritten injected metadata should validate") } -fn numeric_label_macro_ode() -> equation::ODE { +fn numeric_label_macro_ode() -> backends::ODE { ode! { name: "numeric_label_one_cpt", params: [ke, v], @@ -117,8 +117,8 @@ fn numeric_label_macro_ode() -> equation::ODE { } } -fn numeric_label_handwritten_ode() -> equation::ODE { - equation::ODE::new( +fn numeric_label_handwritten_ode() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -135,12 +135,12 @@ fn numeric_label_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("numeric_label_one_cpt") + pharmsol::metadata::new("numeric_label_one_cpt") .parameters(["ke", "v"]) .states(["central"]) .outputs(["outeq_1"]) .route( - equation::Route::infusion("input_1") + backends::Route::infusion("input_1") .to_state("central") .inject_input_to_destination(), ), @@ -148,7 +148,7 @@ fn numeric_label_handwritten_ode() -> equation::ODE { .expect("handwritten numeric-label metadata should validate") } -fn shared_input_macro_ode() -> equation::ODE { +fn shared_input_macro_ode() -> backends::ODE { ode! { name: "shared_input_one_cpt", params: [ka, ke, v, tlag, f_oral], @@ -174,8 +174,8 @@ fn shared_input_macro_ode() -> equation::ODE { } } -fn shared_input_handwritten_ode() -> equation::ODE { - equation::ODE::new( +fn shared_input_handwritten_ode() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); dx[0] = bolus[0] - ka * x[0]; @@ -199,17 +199,17 @@ fn shared_input_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_input_one_cpt") + pharmsol::metadata::new("shared_input_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .with_lag() .with_bioavailability() .inject_input_to_destination(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ]), @@ -217,7 +217,7 @@ fn shared_input_handwritten_ode() -> equation::ODE { .expect("handwritten shared-input metadata should validate") } -fn numeric_route_property_macro_ode() -> equation::ODE { +fn numeric_route_property_macro_ode() -> backends::ODE { ode! { name: "numeric_route_property_one_cpt", params: [ka, ke, v, tlag, f_oral], @@ -242,8 +242,8 @@ fn numeric_route_property_macro_ode() -> equation::ODE { } } -fn numeric_route_property_handwritten_ode() -> equation::ODE { - equation::ODE::new( +fn numeric_route_property_handwritten_ode() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, bolus, _rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); dx[0] = bolus[0] - ka * x[0]; @@ -267,12 +267,12 @@ fn numeric_route_property_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("numeric_route_property_one_cpt") + pharmsol::metadata::new("numeric_route_property_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["outeq_1"]) .route( - equation::Route::bolus("input_1") + backends::Route::bolus("input_1") .to_state("depot") .with_lag() .with_bioavailability() @@ -282,7 +282,7 @@ fn numeric_route_property_handwritten_ode() -> equation::ODE { .expect("handwritten numeric route-property metadata should validate") } -fn mixed_output_labels_macro_ode() -> equation::ODE { +fn mixed_output_labels_macro_ode() -> backends::ODE { ode! { name: "mixed_output_labels_one_cpt", params: [ke, v], @@ -302,8 +302,8 @@ fn mixed_output_labels_macro_ode() -> equation::ODE { } } -fn mixed_output_labels_handwritten_ode() -> equation::ODE { - equation::ODE::new( +fn mixed_output_labels_handwritten_ode() -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -322,12 +322,12 @@ fn mixed_output_labels_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(3) .with_metadata( - equation::metadata::new("mixed_output_labels_one_cpt") + pharmsol::metadata::new("mixed_output_labels_one_cpt") .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp", "outeq_0", "outeq_1"]) .route( - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ), @@ -335,7 +335,7 @@ fn mixed_output_labels_handwritten_ode() -> equation::ODE { .expect("handwritten mixed-output metadata should validate") } -fn covariate_macro_ode() -> equation::ODE { +fn covariate_macro_ode() -> backends::ODE { ode! { name: "covariate_one_cpt", params: [ka, ke, v], @@ -356,8 +356,8 @@ fn covariate_macro_ode() -> equation::ODE { } } -fn covariate_handwritten_ode() -> equation::ODE { - equation::ODE::new( +fn covariate_handwritten_ode() -> backends::ODE { + backends::ODE::new( |x, p, t, dx, bolus, _rateiv, cov| { fetch_cov!(cov, t, wt); fetch_params!(p, ka, ke, _v); @@ -377,13 +377,13 @@ fn covariate_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("covariate_one_cpt") + pharmsol::metadata::new("covariate_one_cpt") .parameters(["ka", "ke", "v"]) - .covariates([equation::Covariate::continuous("wt")]) + .covariates([backends::Covariate::continuous("wt")]) .states(["gut", "central"]) .outputs(["cp"]) .route( - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .inject_input_to_destination(), ), diff --git a/tests/ode_optimizations.rs b/tests/ode_optimizations.rs index 12bda743..99a99499 100644 --- a/tests/ode_optimizations.rs +++ b/tests/ode_optimizations.rs @@ -16,7 +16,7 @@ const ABS_TOL: f64 = 1e-6; fn parameters_for_analytical( label: &str, - analytical: &equation::Analytical, + analytical: &backends::Analytical, param_names: &[&str], params: &[f64], ) -> Parameters { @@ -37,7 +37,7 @@ fn parameters_for_analytical( fn parameters_for_ode( label: &str, - ode: &equation::ODE, + ode: &backends::ODE, param_names: &[&str], params: &[f64], ) -> Parameters { @@ -54,38 +54,38 @@ fn parameters_for_ode( } fn with_one_compartment_analytical_metadata( - analytical: equation::Analytical, + analytical: backends::Analytical, model_name: &str, -) -> equation::Analytical { +) -> backends::Analytical { analytical .with_ndrugs(1) .with_metadata( - equation::metadata::new(model_name) - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new(model_name) + .kind(backends::ModelKind::Analytical) .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("iv_bolus").to_state("central"), - equation::Route::infusion("iv").to_state("central"), + backends::Route::bolus("iv_bolus").to_state("central"), + backends::Route::infusion("iv").to_state("central"), ]) - .analytical_kernel(equation::AnalyticalKernel::OneCompartment), + .analytical_kernel(backends::AnalyticalKernel::OneCompartment), ) .expect("one-compartment analytical metadata should validate") } -fn with_one_compartment_ode_metadata(ode: equation::ODE, model_name: &str) -> equation::ODE { +fn with_one_compartment_ode_metadata(ode: backends::ODE, model_name: &str) -> backends::ODE { ode.with_ndrugs(1) .with_metadata( - equation::metadata::new(model_name) + pharmsol::metadata::new(model_name) .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("iv_bolus") + backends::Route::bolus("iv_bolus") .to_state("central") .expect_explicit_input(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), @@ -94,32 +94,32 @@ fn with_one_compartment_ode_metadata(ode: equation::ODE, model_name: &str) -> eq } fn with_absorption_analytical_metadata( - analytical: equation::Analytical, + analytical: backends::Analytical, model_name: &str, -) -> equation::Analytical { +) -> backends::Analytical { analytical .with_ndrugs(1) .with_metadata( - equation::metadata::new(model_name) - .kind(equation::ModelKind::Analytical) + pharmsol::metadata::new(model_name) + .kind(backends::ModelKind::Analytical) .parameters(["ka", "ke", "v"]) .states(["gut", "central"]) .outputs(["cp"]) - .route(equation::Route::bolus("oral").to_state("gut")) - .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + .route(backends::Route::bolus("oral").to_state("gut")) + .analytical_kernel(backends::AnalyticalKernel::OneCompartmentWithAbsorption), ) .expect("absorption analytical metadata should validate") } -fn with_absorption_ode_metadata(ode: equation::ODE, model_name: &str) -> equation::ODE { +fn with_absorption_ode_metadata(ode: backends::ODE, model_name: &str) -> backends::ODE { ode.with_ndrugs(1) .with_metadata( - equation::metadata::new(model_name) + pharmsol::metadata::new(model_name) .parameters(["ka", "ke", "v"]) .states(["gut", "central"]) .outputs(["cp"]) .route( - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .expect_explicit_input(), ), @@ -127,16 +127,16 @@ fn with_absorption_ode_metadata(ode: equation::ODE, model_name: &str) -> equatio .expect("absorption ODE metadata should validate") } -fn with_covariate_ode_metadata(ode: equation::ODE, model_name: &str) -> equation::ODE { +fn with_covariate_ode_metadata(ode: backends::ODE, model_name: &str) -> backends::ODE { ode.with_ndrugs(1) .with_metadata( - equation::metadata::new(model_name) + pharmsol::metadata::new(model_name) .parameters(["ke", "v"]) - .covariates([equation::Covariate::continuous("wt")]) + .covariates([backends::Covariate::continuous("wt")]) .states(["central"]) .outputs(["cp"]) .route( - equation::Route::bolus("iv_bolus") + backends::Route::bolus("iv_bolus") .to_state("central") .expect_explicit_input(), ), @@ -147,8 +147,8 @@ fn with_covariate_ode_metadata(ode: equation::ODE, model_name: &str) -> equation /// Helper to compare ODE vs Analytical predictions fn assert_ode_matches_analytical( label: &str, - analytical: &equation::Analytical, - ode: &equation::ODE, + analytical: &backends::Analytical, + ode: &backends::ODE, subject: &Subject, param_names: &[&str], params: &[f64], @@ -215,7 +215,7 @@ fn single_iv_bolus_matches_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -232,7 +232,7 @@ fn single_iv_bolus_matches_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ke, _v); // Bolus appears in derivative as instantaneous input @@ -281,7 +281,7 @@ fn multiple_iv_boluses_match_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -298,7 +298,7 @@ fn multiple_iv_boluses_match_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + b[0]; @@ -341,7 +341,7 @@ fn oral_bolus_with_absorption_matches_analytical() { .build(); let analytical = with_absorption_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -358,7 +358,7 @@ fn oral_bolus_with_absorption_matches_analytical() { ); let ode = with_absorption_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ka, ke, _v); dx[0] = -ka * x[0] + b[0]; // Gut compartment with oral bolus @@ -409,7 +409,7 @@ fn multiple_oral_doses_match_analytical() { .build(); let analytical = with_absorption_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -426,7 +426,7 @@ fn multiple_oral_doses_match_analytical() { ); let ode = with_absorption_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ka, ke, _v); dx[0] = -ka * x[0] + b[0]; @@ -474,7 +474,7 @@ fn single_infusion_matches_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -491,7 +491,7 @@ fn single_infusion_matches_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, _b, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + rateiv[0]; @@ -536,7 +536,7 @@ fn overlapping_infusions_match_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -553,7 +553,7 @@ fn overlapping_infusions_match_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, _b, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + rateiv[0]; @@ -601,7 +601,7 @@ fn bolus_plus_infusion_matches_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -618,7 +618,7 @@ fn bolus_plus_infusion_matches_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + b[0] + rateiv[0]; @@ -667,7 +667,7 @@ fn complex_dosing_scenario_matches_analytical() { .build(); let analytical = with_absorption_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -684,7 +684,7 @@ fn complex_dosing_scenario_matches_analytical() { ); let ode = with_absorption_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ka, ke, _v); dx[0] = -ka * x[0] + b[0]; // Gut: oral doses @@ -734,7 +734,7 @@ fn mixed_bolus_infusion_iv_matches_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -751,7 +751,7 @@ fn mixed_bolus_infusion_iv_matches_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + b[0] + rateiv[0]; @@ -797,7 +797,7 @@ fn bolus_at_observation_time_matches_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -814,7 +814,7 @@ fn bolus_at_observation_time_matches_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + b[0]; @@ -855,7 +855,7 @@ fn very_fast_elimination_matches_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -872,7 +872,7 @@ fn very_fast_elimination_matches_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + b[0]; @@ -914,7 +914,7 @@ fn very_slow_elimination_matches_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -931,7 +931,7 @@ fn very_slow_elimination_matches_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + b[0]; @@ -974,7 +974,7 @@ fn rapid_absorption_matches_analytical() { .build(); let analytical = with_absorption_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment_with_absorption, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -991,7 +991,7 @@ fn rapid_absorption_matches_analytical() { ); let ode = with_absorption_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ka, ke, _v); dx[0] = -ka * x[0] + b[0]; @@ -1042,7 +1042,7 @@ fn time_varying_covariates_work_correctly() { // ODE with weight-based clearance let ode = with_covariate_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, t, dx, b, _rateiv, cov| { fetch_params!(p, ke_ref, _v); fetch_cov!(cov, t, wt); @@ -1111,7 +1111,7 @@ fn likelihood_calculation_matches_analytical() { .build(); let analytical = with_one_compartment_analytical_metadata( - equation::Analytical::new( + backends::Analytical::new( one_compartment, |_p, _t, _cov| {}, |_p, _t, _cov| lag! {}, @@ -1128,7 +1128,7 @@ fn likelihood_calculation_matches_analytical() { ); let ode = with_one_compartment_ode_metadata( - equation::ODE::new( + backends::ODE::new( |x, p, _t, dx, b, _rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + b[0]; diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 78c1f7c4..67771a8d 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -66,7 +66,7 @@ fn assert_prediction_match(left: &[f64], right: &[f64]) { } } -fn macro_infusion_sde() -> equation::SDE { +fn macro_infusion_sde() -> backends::SDE { sde! { name: "one_cpt_sde", params: [ke, sigma_ke, v], @@ -88,8 +88,8 @@ fn macro_infusion_sde() -> equation::SDE { } } -fn handwritten_infusion_sde() -> equation::SDE { - equation::SDE::new( +fn handwritten_infusion_sde() -> backends::SDE { + backends::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ke, _sigma_ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -111,13 +111,13 @@ fn handwritten_infusion_sde() -> equation::SDE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cpt_sde") - .kind(equation::ModelKind::Sde) + pharmsol::metadata::new("one_cpt_sde") + .kind(backends::ModelKind::Sde) .parameters(["ke", "sigma_ke", "v"]) .states(["central"]) .outputs(["cp"]) .route( - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ) @@ -126,7 +126,7 @@ fn handwritten_infusion_sde() -> equation::SDE { .expect("handwritten SDE metadata should validate") } -fn macro_absorption_sde() -> equation::SDE { +fn macro_absorption_sde() -> backends::SDE { sde! { name: "one_cmt_abs_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], @@ -160,8 +160,8 @@ fn macro_absorption_sde() -> equation::SDE { } } -fn handwritten_absorption_sde() -> equation::SDE { - equation::SDE::new( +fn handwritten_absorption_sde() -> backends::SDE { + backends::SDE::new( |x, p, _t, dx, _rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); dx[0] = -ka * x[0]; @@ -194,13 +194,13 @@ fn handwritten_absorption_sde() -> equation::SDE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_abs_sde") - .kind(equation::ModelKind::Sde) + pharmsol::metadata::new("one_cmt_abs_sde") + .kind(backends::ModelKind::Sde) .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) .states(["gut", "central"]) .outputs(["cp"]) .route( - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .inject_input_to_destination() .with_lag() @@ -211,7 +211,7 @@ fn handwritten_absorption_sde() -> equation::SDE { .expect("handwritten absorption SDE metadata should validate") } -fn macro_shared_input_sde() -> equation::SDE { +fn macro_shared_input_sde() -> backends::SDE { sde! { name: "one_cmt_shared_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], @@ -246,8 +246,8 @@ fn macro_shared_input_sde() -> equation::SDE { } } -fn handwritten_shared_input_sde() -> equation::SDE { - equation::SDE::new( +fn handwritten_shared_input_sde() -> backends::SDE { + backends::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); dx[0] = -ka * x[0]; @@ -279,18 +279,18 @@ fn handwritten_shared_input_sde() -> equation::SDE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_shared_sde") - .kind(equation::ModelKind::Sde) + pharmsol::metadata::new("one_cmt_shared_sde") + .kind(backends::ModelKind::Sde) .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .inject_input_to_destination() .with_lag() .with_bioavailability(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ]) @@ -299,7 +299,7 @@ fn handwritten_shared_input_sde() -> equation::SDE { .expect("handwritten shared-input SDE metadata should validate") } -fn macro_covariate_sde() -> equation::SDE { +fn macro_covariate_sde() -> backends::SDE { sde! { name: "one_cmt_sde_covariates", params: [ka, ke, sigma_ke, v, tlag, f_oral, base_gut, base_central], @@ -342,8 +342,8 @@ fn macro_covariate_sde() -> equation::SDE { } } -fn handwritten_covariate_sde() -> equation::SDE { - equation::SDE::new( +fn handwritten_covariate_sde() -> backends::SDE { + backends::SDE::new( |x, p, t, dx, rateiv, cov| { fetch_params!( p, @@ -454,8 +454,8 @@ fn handwritten_covariate_sde() -> equation::SDE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("one_cmt_sde_covariates") - .kind(equation::ModelKind::Sde) + pharmsol::metadata::new("one_cmt_sde_covariates") + .kind(backends::ModelKind::Sde) .parameters([ "ka", "ke", @@ -467,18 +467,18 @@ fn handwritten_covariate_sde() -> equation::SDE { "base_central", ]) .covariates([ - equation::Covariate::continuous("wt"), - equation::Covariate::continuous("renal"), + backends::Covariate::continuous("wt"), + backends::Covariate::continuous("renal"), ]) .states(["gut", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("gut") .inject_input_to_destination() .with_lag() .with_bioavailability(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .inject_input_to_destination(), ]) diff --git a/tests/support/bimodal_ke.rs b/tests/support/bimodal_ke.rs index 41755bd7..422a36af 100644 --- a/tests/support/bimodal_ke.rs +++ b/tests/support/bimodal_ke.rs @@ -97,7 +97,7 @@ pub fn subject_for_runtime_model(model: &pharmsol::dsl::CompiledRuntimeModel) -> } pub fn reference_values() -> Result, Box> { - let model = equation::ODE::new( + let model = backends::ODE::new( |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + rateiv[0]; @@ -114,12 +114,12 @@ pub fn reference_values() -> Result, Box> { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new(MODEL_NAME) + pharmsol::metadata::new(MODEL_NAME) .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) .route( - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ), diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 55e61b9f..3f971976 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -13,10 +13,10 @@ use diffsol::Vector; use ndarray::Array2; use pharmsol::dsl::{self, CompiledRuntimeModel, RuntimeCompilationTarget, RuntimePredictions}; use pharmsol::prelude::{ - one_compartment_with_absorption, Equation, Prediction, SubjectPredictions, + one_compartment_with_absorption, Simulate, Prediction, SubjectPredictions, }; use pharmsol::{ - equation, fa, fetch_cov, fetch_params, lag, Parameters, Subject, SubjectBuilderExt, SDE, + backends, fa, fetch_cov, fetch_params, lag, Parameters, Subject, SubjectBuilderExt, SDE, }; use tempfile::{tempdir, TempDir}; @@ -770,7 +770,7 @@ fn compare_particle_predictions_pairwise( } fn reference_ode_predictions() -> Result> { - let model = equation::ODE::new( + let model = backends::ODE::new( |x, p, t, dx, bolus, rateiv, cov| { fetch_cov!(cov, t, wt); fetch_params!(p, ka, cl, v, _tlag, _f_oral); @@ -805,18 +805,18 @@ fn reference_ode_predictions() -> Result> { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new(CorpusCase::Ode.label()) + pharmsol::metadata::new(CorpusCase::Ode.label()) .parameters(["ka", "cl", "v", "tlag", "f_oral"]) - .covariates([equation::Covariate::continuous("wt")]) + .covariates([backends::Covariate::continuous("wt")]) .states(["depot", "central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .with_lag() .with_bioavailability() .expect_explicit_input(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), @@ -850,7 +850,7 @@ fn reference_ode_predictions() -> Result> { } fn reference_ode_full_predictions() -> Result> { - let model = equation::ODE::new( + let model = backends::ODE::new( |x, p, t, dx, bolus, rateiv, cov| { fetch_params!( p, @@ -959,7 +959,7 @@ fn reference_ode_full_predictions() -> Result .with_ndrugs(2) .with_nout(1) .with_metadata( - equation::metadata::new(CorpusCase::OdeFull.label()) + pharmsol::metadata::new(CorpusCase::OdeFull.label()) .parameters([ "ka", "ke", @@ -973,21 +973,21 @@ fn reference_ode_full_predictions() -> Result "base_peripheral", ]) .covariates([ - equation::Covariate::continuous("wt"), - equation::Covariate::continuous("renal"), + backends::Covariate::continuous("wt"), + backends::Covariate::continuous("renal"), ]) .states(["depot", "central", "peripheral"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .with_lag() .with_bioavailability() .expect_explicit_input(), - equation::Route::bolus("load") + backends::Route::bolus("load") .to_state("central") .expect_explicit_input(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), @@ -1032,7 +1032,7 @@ fn reference_ode_full_predictions() -> Result } fn reference_analytical_predictions() -> Result> { - let model = equation::Analytical::new( + let model = backends::Analytical::new( one_compartment_with_absorption, |_p, _t, _cov| {}, |p, _t, _cov| { @@ -1053,18 +1053,18 @@ fn reference_analytical_predictions() -> Result Result Result> { - let model = equation::Analytical::new( - equation::one_compartment_with_absorption, + let model = backends::Analytical::new( + backends::one_compartment_with_absorption, |_p, _t, _cov| {}, |p, t, cov| { fetch_params!(p, _ka, _ke, _v, tlag, _f_oral, _base_gut, _base_central); @@ -1127,8 +1127,8 @@ fn reference_analytical_full_predictions() -> Result Result Result, Box> { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new(CorpusCase::Sde.label()) - .kind(equation::ModelKind::Sde) + pharmsol::metadata::new(CorpusCase::Sde.label()) + .kind(backends::ModelKind::Sde) .parameters(["ka", "ke0", "kcp", "kpc", "vol", "ske"]) - .covariates([equation::Covariate::locf("wt")]) + .covariates([backends::Covariate::locf("wt")]) .states(["depot", "central", "peripheral", "ke_latent"]) .outputs(["cp"]) .route( - equation::Route::bolus("oral") + backends::Route::bolus("oral") .to_state("depot") .inject_input_to_destination(), ) diff --git a/tests/test_pf.rs b/tests/test_pf.rs index 719aafc3..fa40d0c8 100644 --- a/tests/test_pf.rs +++ b/tests/test_pf.rs @@ -14,7 +14,7 @@ fn test_particle_filter_likelihood() { .observation(1.0, 7.5170, "cp") .build(); - let sde = equation::SDE::new( + let sde = backends::SDE::new( |x, p, _t, dx, _rateiv, _cov| { dx[0] = -x[0] * x[1]; // ke *x[0] dx[1] = -x[1] + p[0]; // mean reverting @@ -37,13 +37,13 @@ fn test_particle_filter_likelihood() { let sde = sde .with_metadata( - equation::metadata::new("particle_filter_test") - .kind(equation::ModelKind::Sde) + pharmsol::metadata::new("particle_filter_test") + .kind(backends::ModelKind::Sde) .parameters(["ke0"]) .states(["central", "ke_latent"]) .outputs(["cp"]) .route( - equation::Route::bolus("dose") + backends::Route::bolus("dose") .to_state("central") .inject_input_to_destination(), ) diff --git a/tests/test_solvers.rs b/tests/test_solvers.rs index f4af986b..cd241ada 100644 --- a/tests/test_solvers.rs +++ b/tests/test_solvers.rs @@ -18,8 +18,8 @@ fn subject() -> Subject { .build() } -fn one_cpt(solver: OdeSolver) -> equation::ODE { - equation::ODE::new( +fn one_cpt(solver: OdeSolver) -> backends::ODE { + backends::ODE::new( |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + rateiv[0] + b[0]; @@ -37,15 +37,15 @@ fn one_cpt(solver: OdeSolver) -> equation::ODE { .with_nout(1) .with_solver(solver) .with_metadata( - equation::metadata::new("solver_selection_one_cpt") + pharmsol::metadata::new("solver_selection_one_cpt") .parameters(["ke", "v"]) .states(["central"]) .outputs(["cp"]) .routes([ - equation::Route::bolus("iv_bolus") + backends::Route::bolus("iv_bolus") .to_state("central") .expect_explicit_input(), - equation::Route::infusion("iv") + backends::Route::infusion("iv") .to_state("central") .expect_explicit_input(), ]), From d85f347aaf5e7927f647f1655f8946f38c13aa4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Wed, 17 Jun 2026 20:41:04 +0100 Subject: [PATCH 3/7] fmt --- benches/common/mod.rs | 11 +- src/core/model_info.rs | 5 +- src/core/solver.rs | 2 +- src/dsl/jit.rs | 6 +- src/dsl/native.rs | 266 ++++++++++++------ src/lib.rs | 26 +- src/parameter_order.rs | 4 +- src/simulator/backends/analytical/mod.rs | 33 ++- .../analytical/one_compartment_cl_models.rs | 2 +- .../analytical/one_compartment_models.rs | 2 +- .../analytical/three_compartment_cl_models.rs | 2 +- .../analytical/three_compartment_models.rs | 2 +- .../analytical/two_compartment_cl_models.rs | 2 +- .../analytical/two_compartment_models.rs | 2 +- src/simulator/backends/ode/mod.rs | 53 ++-- src/simulator/likelihood/matrix.rs | 6 +- src/simulator/likelihood/mod.rs | 22 +- src/simulator/mod.rs | 2 +- tests/authoring_parity_corpus.rs | 4 +- tests/support/runtime_corpus.rs | 2 +- 20 files changed, 287 insertions(+), 167 deletions(-) diff --git a/benches/common/mod.rs b/benches/common/mod.rs index 766ac7d9..21852c7b 100644 --- a/benches/common/mod.rs +++ b/benches/common/mod.rs @@ -12,17 +12,18 @@ #![allow(dead_code)] use ndarray::Array2; +use pharmsol::metadata::{self, ModelMetadata}; use pharmsol::prelude::*; -use pharmsol::simulator::equation::analytical::{ +use pharmsol::simulator::backends::analytical::{ one_compartment_with_absorption, two_compartments, }; use pharmsol::{ - equation::{self, Route}, + backends::{self, Route}, Analytical, ResidualErrorModel, ResidualErrorModels, ODE, SDE, }; /// `ModelMetadata` for handwritten factories so route/output labels resolve like the macro/DSL paths. -fn model_metadata(workload: Workload, kind: SolverKind) -> equation::ModelMetadata { +fn model_metadata(workload: Workload, kind: SolverKind) -> ModelMetadata { let name = match (workload, kind) { (Workload::Short, SolverKind::Ode) => "bench_one_cpt_po_ode", (Workload::Short, SolverKind::Analytical) => "bench_one_cpt_po_analytical", @@ -37,7 +38,7 @@ fn model_metadata(workload: Workload, kind: SolverKind) -> equation::ModelMetada SolverKind::Sde => &["ka", "ke", "v", "sigma_ke"], _ => &["ka", "ke", "v"], }; - equation::metadata::new(name) + metadata::new(name) .parameters(params.iter().copied()) .states(["depot", "central"]) .outputs(["plasma"]) @@ -52,7 +53,7 @@ fn model_metadata(workload: Workload, kind: SolverKind) -> equation::ModelMetada SolverKind::Sde => &["ke", "kcp", "kpc", "v", "sigma_ke"], _ => &["ke", "kcp", "kpc", "v"], }; - equation::metadata::new(name) + metadata::new(name) .parameters(params.iter().copied()) .states(["central", "peripheral"]) .outputs(["plasma"]) diff --git a/src/core/model_info.rs b/src/core/model_info.rs index d31e4c0a..fc58f95a 100644 --- a/src/core/model_info.rs +++ b/src/core/model_info.rs @@ -1,7 +1,7 @@ use pharmsol_dsl::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; -use crate::data::{Covariates, InputLabel, OutputLabel}; use crate::core::metadata::RouteKind; +use crate::data::{Covariates, InputLabel, OutputLabel}; use crate::simulator::{Fa, Lag}; use crate::{Event, Occasion, PharmsolError, ValidatedModelMetadata}; @@ -109,8 +109,7 @@ pub trait ModelInfo { bolus.set_input(input); } Event::Infusion(infusion) => { - let input = - self.resolve_input(infusion.input(), RouteKind::Infusion)?; + let input = self.resolve_input(infusion.input(), RouteKind::Infusion)?; infusion.set_input(input); } Event::Observation(observation) => { diff --git a/src/core/solver.rs b/src/core/solver.rs index 30510acb..9a9b2399 100644 --- a/src/core/solver.rs +++ b/src/core/solver.rs @@ -1,8 +1,8 @@ use crate::core::State; +use crate::data::error_model::AssayErrorModels; use crate::data::{Covariates, Infusion}; use crate::simulator::likelihood::Prediction; use crate::{Observation, PharmsolError}; -use crate::data::error_model::AssayErrorModels; /// How to advance a model's state through time. /// diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index 4856fb94..ca93c75a 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -1332,10 +1332,10 @@ pub fn compile_sde_model_to_jit(model: &ExecutionModel) -> Result Result Result { - crate::core::metadata::Route::bolus(route.name.clone()) - } - RouteKind::Infusion => { - crate::core::metadata::Route::infusion(route.name.clone()) - } + RouteKind::Bolus => crate::core::metadata::Route::bolus(route.name.clone()), + RouteKind::Infusion => crate::core::metadata::Route::infusion(route.name.clone()), } .to_state(destination); @@ -1413,11 +1408,15 @@ fn runtime_ode_predictions( } } - impl Solver for NativeOdeModel { type State = V; - fn initial_state(&self, _params: &[f64], _covariates: &Covariates, _occasion_index: usize) -> V { + fn initial_state( + &self, + _params: &[f64], + _covariates: &Covariates, + _occasion_index: usize, + ) -> V { V::zeros(self.shared.info.state_len, NalgebraContext) } @@ -1427,36 +1426,64 @@ impl Solver for NativeOdeModel { } impl ModelInfo for NativeOdeModel { - fn nstates(&self) -> usize { self.shared.info.state_len } - fn ndrugs(&self) -> usize { self.shared.info.route_len } - fn nout(&self) -> usize { self.shared.info.output_len } - fn metadata(&self) -> Option<&ValidatedModelMetadata> { Some(self.shared.metadata()) } - fn lag(&self) -> &Lag { &(runtime_no_lag as Lag) } - fn fa(&self) -> &Fa { &(runtime_no_fa as Fa) } + fn nstates(&self) -> usize { + self.shared.info.state_len + } + fn ndrugs(&self) -> usize { + self.shared.info.route_len + } + fn nout(&self) -> usize { + self.shared.info.output_len + } + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + Some(self.shared.metadata()) + } + fn lag(&self) -> &Lag { + &(runtime_no_lag as Lag) + } + fn fa(&self) -> &Fa { + &(runtime_no_fa as Fa) + } } impl Caching for NativeOdeModel { - fn prediction_cache(&self) -> Option<&PredictionCache> { self.cache.as_ref() } - fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { self.error_model_cache.as_ref() } + fn prediction_cache(&self) -> Option<&PredictionCache> { + self.cache.as_ref() + } + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { + self.error_model_cache.as_ref() + } fn with_cache_capacity(mut self, size: u64) -> Self { self.cache = Some(PredictionCache::new(size)); - self.error_model_cache = Some(BoundErrorModelCache::new(DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE)); + self.error_model_cache = Some(BoundErrorModelCache::new( + DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE, + )); + self + } + fn without_cache(mut self) -> Self { + self.cache = None; + self.error_model_cache = None; self } - fn without_cache(mut self) -> Self { self.cache = None; self.error_model_cache = None; self } fn clear_cache(&self) { - if let Some(c) = &self.cache { c.invalidate_all(); } - if let Some(c) = &self.error_model_cache { c.invalidate_all(); } + if let Some(c) = &self.cache { + c.invalidate_all(); + } + if let Some(c) = &self.error_model_cache { + c.invalidate_all(); + } } } impl Simulate for NativeOdeModel { type Predictions = SubjectPredictions; - fn simulate_subject(&self, subject: &Subject, parameters: &[f64], - error_models: Option<&AssayErrorModels>) - -> Result<(Self::Predictions, Option), PharmsolError> - { + fn simulate_subject( + &self, + subject: &Subject, + parameters: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::Predictions, Option), PharmsolError> { let bound_error_models = match error_models { Some(em) => Some(crate::core::simulate::bind_error_models_inner(self, em)?), None => None, @@ -1469,15 +1496,20 @@ impl Simulate for NativeOdeModel { Ok((predictions, likelihood)) } - fn log_likelihood(&self, subject: &Subject, params: &[f64], - error_models: &AssayErrorModels) -> Result - { + fn log_likelihood( + &self, + subject: &Subject, + params: &[f64], + error_models: &AssayErrorModels, + ) -> Result { let bound = crate::core::simulate::bind_error_models_inner(self, error_models)?; let predictions = runtime_ode_predictions(self, subject, params)?; predictions.log_likelihood(&bound) } - fn kind() -> ModelKind { ModelKind::Ode } + fn kind() -> ModelKind { + ModelKind::Ode + } } impl NativeAnalyticalModel { @@ -1679,43 +1711,75 @@ fn runtime_analytical_predictions( } } - impl Solver for NativeAnalyticalModel { type State = V; - fn initial_state(&self, _params: &[f64], _covariates: &Covariates, _occasion_index: usize) -> V { + fn initial_state( + &self, + _params: &[f64], + _covariates: &Covariates, + _occasion_index: usize, + ) -> V { V::zeros(self.shared.info.state_len, NalgebraContext) } - fn is_batch(&self) -> bool { true } + fn is_batch(&self) -> bool { + true + } } impl ModelInfo for NativeAnalyticalModel { - fn nstates(&self) -> usize { self.shared.info.state_len } - fn ndrugs(&self) -> usize { self.shared.info.route_len } - fn nout(&self) -> usize { self.shared.info.output_len } - fn metadata(&self) -> Option<&ValidatedModelMetadata> { Some(self.shared.metadata()) } - fn lag(&self) -> &Lag { &(runtime_no_lag as Lag) } - fn fa(&self) -> &Fa { &(runtime_no_fa as Fa) } + fn nstates(&self) -> usize { + self.shared.info.state_len + } + fn ndrugs(&self) -> usize { + self.shared.info.route_len + } + fn nout(&self) -> usize { + self.shared.info.output_len + } + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + Some(self.shared.metadata()) + } + fn lag(&self) -> &Lag { + &(runtime_no_lag as Lag) + } + fn fa(&self) -> &Fa { + &(runtime_no_fa as Fa) + } } impl Caching for NativeAnalyticalModel { - fn prediction_cache(&self) -> Option<&PredictionCache> { self.cache.as_ref() } - fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { None } + fn prediction_cache(&self) -> Option<&PredictionCache> { + self.cache.as_ref() + } + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { + None + } fn with_cache_capacity(mut self, size: u64) -> Self { - self.cache = Some(PredictionCache::new(size)); self + self.cache = Some(PredictionCache::new(size)); + self + } + fn without_cache(mut self) -> Self { + self.cache = None; + self + } + fn clear_cache(&self) { + if let Some(c) = &self.cache { + c.invalidate_all(); + } } - fn without_cache(mut self) -> Self { self.cache = None; self } - fn clear_cache(&self) { if let Some(c) = &self.cache { c.invalidate_all(); } } } impl Simulate for NativeAnalyticalModel { type Predictions = SubjectPredictions; - fn simulate_subject(&self, subject: &Subject, parameters: &[f64], - error_models: Option<&AssayErrorModels>) - -> Result<(Self::Predictions, Option), PharmsolError> - { + fn simulate_subject( + &self, + subject: &Subject, + parameters: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::Predictions, Option), PharmsolError> { let bound_em = match error_models { Some(em) => Some(crate::core::simulate::bind_error_models_inner(self, em)?), None => None, @@ -1728,15 +1792,20 @@ impl Simulate for NativeAnalyticalModel { Ok((predictions, likelihood)) } - fn log_likelihood(&self, subject: &Subject, params: &[f64], - error_models: &AssayErrorModels) -> Result - { + fn log_likelihood( + &self, + subject: &Subject, + params: &[f64], + error_models: &AssayErrorModels, + ) -> Result { let bound = crate::core::simulate::bind_error_models_inner(self, error_models)?; let predictions = runtime_analytical_predictions(self, subject, params)?; predictions.log_likelihood(&bound) } - fn kind() -> ModelKind { ModelKind::Analytical } + fn kind() -> ModelKind { + ModelKind::Analytical + } } impl NativeSdeModel { @@ -2059,44 +2128,78 @@ fn runtime_sde_log_likelihood( } } - impl Solver for NativeSdeModel { type State = Vec>; - fn initial_state(&self, _params: &[f64], _covariates: &Covariates, _occasion_index: usize) -> Vec> { + fn initial_state( + &self, + _params: &[f64], + _covariates: &Covariates, + _occasion_index: usize, + ) -> Vec> { vec![DVector::zeros(self.shared.info.state_len); self.nparticles] } - fn nparticles(&self) -> usize { self.nparticles } - fn is_batch(&self) -> bool { true } + fn nparticles(&self) -> usize { + self.nparticles + } + fn is_batch(&self) -> bool { + true + } } impl ModelInfo for NativeSdeModel { - fn nstates(&self) -> usize { self.shared.info.state_len } - fn ndrugs(&self) -> usize { self.shared.info.route_len } - fn nout(&self) -> usize { self.shared.info.output_len } - fn metadata(&self) -> Option<&ValidatedModelMetadata> { Some(self.shared.metadata()) } - fn lag(&self) -> &Lag { &(runtime_no_lag as Lag) } - fn fa(&self) -> &Fa { &(runtime_no_fa as Fa) } + fn nstates(&self) -> usize { + self.shared.info.state_len + } + fn ndrugs(&self) -> usize { + self.shared.info.route_len + } + fn nout(&self) -> usize { + self.shared.info.output_len + } + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + Some(self.shared.metadata()) + } + fn lag(&self) -> &Lag { + &(runtime_no_lag as Lag) + } + fn fa(&self) -> &Fa { + &(runtime_no_fa as Fa) + } } impl Caching for NativeSdeModel { - fn prediction_cache(&self) -> Option<&PredictionCache> { None } - fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { None } + fn prediction_cache(&self) -> Option<&PredictionCache> { + None + } + fn error_model_cache(&self) -> Option<&BoundErrorModelCache> { + None + } fn with_cache_capacity(mut self, size: u64) -> Self { - self.cache = Some(SdeLikelihoodCache::new(size)); self + self.cache = Some(SdeLikelihoodCache::new(size)); + self + } + fn without_cache(mut self) -> Self { + self.cache = None; + self + } + fn clear_cache(&self) { + if let Some(c) = &self.cache { + c.invalidate_all(); + } } - fn without_cache(mut self) -> Self { self.cache = None; self } - fn clear_cache(&self) { if let Some(c) = &self.cache { c.invalidate_all(); } } } impl Simulate for NativeSdeModel { type Predictions = Array2; - fn simulate_subject(&self, subject: &Subject, parameters: &[f64], - error_models: Option<&AssayErrorModels>) - -> Result<(Self::Predictions, Option), PharmsolError> - { + fn simulate_subject( + &self, + subject: &Subject, + parameters: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::Predictions, Option), PharmsolError> { let bound_em = match error_models { Some(em) => Some(crate::core::simulate::bind_error_models_inner(self, em)?), None => None, @@ -2109,14 +2212,19 @@ impl Simulate for NativeSdeModel { Ok((predictions, likelihood)) } - fn log_likelihood(&self, subject: &Subject, params: &[f64], - error_models: &AssayErrorModels) -> Result - { + fn log_likelihood( + &self, + subject: &Subject, + params: &[f64], + error_models: &AssayErrorModels, + ) -> Result { let bound = crate::core::simulate::bind_error_models_inner(self, error_models)?; runtime_sde_log_likelihood(self, subject, params, &bound) } - fn kind() -> ModelKind { ModelKind::Sde } + fn kind() -> ModelKind { + ModelKind::Sde + } } fn active_route_inputs(infusions: &[Infusion], time: f64, route_len: usize) -> Vec { diff --git a/src/lib.rs b/src/lib.rs index 97db962f..6861a49b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,11 +110,11 @@ #[cfg(feature = "dsl-aot")] mod build_support; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub mod core; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub mod data; #[cfg(feature = "dsl-core")] pub mod dsl; -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub mod core; pub mod error; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub mod nca; @@ -135,6 +135,16 @@ mod test_fixtures; //extension traits #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::core::metadata; +pub use crate::core::metadata::{ + ModelMetadata, ModelMetadataError, NameDomain, RouteInputPolicy, RouteKind, + ValidatedModelMetadata, +}; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::core::{Caching, ModelInfo, Simulate, Solver}; +pub use crate::core::{Predictions, State}; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::data::builder::SubjectBuilderExt; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::data::Interpolation::*; @@ -147,21 +157,11 @@ pub use crate::optimize::parameters::ParameterOptimizer; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::simulator::backends::analytical::*; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub use crate::core::{Caching, ModelInfo, Simulate, Solver}; -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub use crate::core::metadata; -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::simulator::backends::{ self, ode::{ExplicitRkTableau, OdeSolver, SdirkTableau}, Analytical, AnalyticalKernel, ModelKind, ODE, SDE, }; -pub use crate::core::metadata::{ - ModelMetadata, ModelMetadataError, NameDomain, RouteInputPolicy, RouteKind, - ValidatedModelMetadata, -}; -pub use crate::core::{Predictions, State}; pub use error::PharmsolError; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use nalgebra::dmatrix; @@ -254,11 +254,11 @@ pub mod prelude { // Direct simulator re-exports for convenience pub use crate::simulator::{ - cache::{PredictionCache, SdeLikelihoodCache, DEFAULT_CACHE_SIZE}, backends::{ self, ode::{ExplicitRkTableau, OdeSolver, SdirkTableau}, }, + cache::{PredictionCache, SdeLikelihoodCache, DEFAULT_CACHE_SIZE}, likelihood::{Prediction, SubjectPredictions}, }; diff --git a/src/parameter_order.rs b/src/parameter_order.rs index 8c095720..6db5ceda 100644 --- a/src/parameter_order.rs +++ b/src/parameter_order.rs @@ -4,10 +4,10 @@ use std::collections::HashMap; use std::error::Error; use std::fmt; -#[cfg(feature = "dsl-core")] -use crate::dsl::NativeModelInfo; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] use crate::core::metadata::ValidatedModelMetadata; +#[cfg(feature = "dsl-core")] +use crate::dsl::NativeModelInfo; #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct ParameterOrderPlan { diff --git a/src/simulator/backends/analytical/mod.rs b/src/simulator/backends/analytical/mod.rs index 56803878..23933f99 100644 --- a/src/simulator/backends/analytical/mod.rs +++ b/src/simulator/backends/analytical/mod.rs @@ -5,8 +5,8 @@ pub mod three_compartment_models; pub mod two_compartment_cl_models; pub mod two_compartment_models; -use diffsol::{NalgebraContext, Vector, VectorHost}; use crate::core::ModelInfo; +use diffsol::{NalgebraContext, Vector, VectorHost}; pub use one_compartment_cl_models::*; pub use one_compartment_models::*; use pharmsol_dsl::ModelKind; @@ -20,9 +20,7 @@ use crate::simulator::backends::parameters_hash; use crate::core::metadata::{ModelMetadata, ModelMetadataError, ValidatedModelMetadata}; use crate::data::error_model::AssayErrorModels; -use crate::simulator::cache::{ - BoundErrorModelCache, PredictionCache, DEFAULT_CACHE_SIZE, -}; +use crate::simulator::cache::{BoundErrorModelCache, PredictionCache, DEFAULT_CACHE_SIZE}; use crate::simulator::likelihood::Prediction; use crate::PharmsolError; use crate::{data::Covariates, simulator::*, Observation, Subject}; @@ -117,7 +115,9 @@ impl Analytical { mut self, metadata: ModelMetadata, ) -> Result { - let validated = metadata.validate_for(ModelKind::Analytical).map_err(AnalyticalMetadataError::Validation)?; + let validated = metadata + .validate_for(ModelKind::Analytical) + .map_err(AnalyticalMetadataError::Validation)?; validate_metadata_dimensions(&validated, &self.core.dims())?; self.core.set_metadata(validated); Ok(self) @@ -680,11 +680,11 @@ impl crate::core::Solver for Analytical { let s = inf.time(); let e = s + inf.duration(); if current_t >= s && next_t <= e { - let input = inf.input_index().ok_or_else(|| { - PharmsolError::UnknownInputLabel { - label: inf.input().to_string(), - } - })?; + let input = + inf.input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: inf.input().to_string(), + })?; if input >= self.ndrugs() { return Err(PharmsolError::InputOutOfRange { input, @@ -720,11 +720,11 @@ impl crate::core::Solver for Analytical { covariates, &mut y, ); - let outeq = observation.outeq_index().ok_or_else(|| { - PharmsolError::UnknownOutputLabel { + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { label: observation.outeq().to_string(), - } - })?; + })?; let pred = observation.to_prediction(y[outeq], x.as_slice().to_vec()); let lik = error_models .map(|em| pred.log_likelihood(em).map(f64::exp)) @@ -823,7 +823,10 @@ impl crate::core::Simulate for Analytical { } let result = crate::core::standard_event_loop::( - self, subject, params, error_models, + self, + subject, + params, + error_models, )?; if error_models.is_none() { diff --git a/src/simulator/backends/analytical/one_compartment_cl_models.rs b/src/simulator/backends/analytical/one_compartment_cl_models.rs index 00742ce4..f0ffecbf 100755 --- a/src/simulator/backends/analytical/one_compartment_cl_models.rs +++ b/src/simulator/backends/analytical/one_compartment_cl_models.rs @@ -57,8 +57,8 @@ pub fn pm_one_compartment_cl_with_absorption( #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; - use crate::core::Simulate; use super::{one_compartment_cl, one_compartment_cl_with_absorption}; + use crate::core::Simulate; use crate::*; use approx::assert_relative_eq; diff --git a/src/simulator/backends/analytical/one_compartment_models.rs b/src/simulator/backends/analytical/one_compartment_models.rs index b8581b4a..55c8c0b1 100644 --- a/src/simulator/backends/analytical/one_compartment_models.rs +++ b/src/simulator/backends/analytical/one_compartment_models.rs @@ -50,8 +50,8 @@ pub fn pm_one_compartment_with_absorption(x: &V, p: &V, t: T, rateiv: &V, cov: & #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; - use crate::core::Simulate; use super::{one_compartment, one_compartment_with_absorption}; + use crate::core::Simulate; use crate::*; use approx::assert_relative_eq; diff --git a/src/simulator/backends/analytical/three_compartment_cl_models.rs b/src/simulator/backends/analytical/three_compartment_cl_models.rs index 49be1c2c..34c5ae27 100644 --- a/src/simulator/backends/analytical/three_compartment_cl_models.rs +++ b/src/simulator/backends/analytical/three_compartment_cl_models.rs @@ -79,8 +79,8 @@ pub fn pm_three_compartments_cl_with_absorption( #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; - use crate::core::Simulate; use super::{three_compartments_cl, three_compartments_cl_with_absorption}; + use crate::core::Simulate; use crate::*; use approx::assert_relative_eq; diff --git a/src/simulator/backends/analytical/three_compartment_models.rs b/src/simulator/backends/analytical/three_compartment_models.rs index d4b65914..d40666b7 100644 --- a/src/simulator/backends/analytical/three_compartment_models.rs +++ b/src/simulator/backends/analytical/three_compartment_models.rs @@ -252,8 +252,8 @@ pub fn pm_three_compartments_with_absorption( #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; - use crate::core::Simulate; use super::{three_compartments, three_compartments_with_absorption}; + use crate::core::Simulate; use crate::*; use approx::assert_relative_eq; diff --git a/src/simulator/backends/analytical/two_compartment_cl_models.rs b/src/simulator/backends/analytical/two_compartment_cl_models.rs index 895f694c..d4ec1509 100644 --- a/src/simulator/backends/analytical/two_compartment_cl_models.rs +++ b/src/simulator/backends/analytical/two_compartment_cl_models.rs @@ -65,8 +65,8 @@ pub fn pm_two_compartments_cl_with_absorption( #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; - use crate::core::Simulate; use super::{two_compartments_cl, two_compartments_cl_with_absorption}; + use crate::core::Simulate; use crate::*; use approx::assert_relative_eq; diff --git a/src/simulator/backends/analytical/two_compartment_models.rs b/src/simulator/backends/analytical/two_compartment_models.rs index f3d5cc27..d01d82a2 100644 --- a/src/simulator/backends/analytical/two_compartment_models.rs +++ b/src/simulator/backends/analytical/two_compartment_models.rs @@ -118,8 +118,8 @@ pub fn pm_two_compartments_with_absorption(x: &V, p: &V, t: T, rateiv: &V, cov: #[cfg(test)] mod tests { use super::super::tests::SubjectInfo; - use crate::core::Simulate; use super::{two_compartments, two_compartments_with_absorption}; + use crate::core::Simulate; use crate::*; use approx::assert_relative_eq; diff --git a/src/simulator/backends/ode/mod.rs b/src/simulator/backends/ode/mod.rs index 9bb5620f..2efdf526 100644 --- a/src/simulator/backends/ode/mod.rs +++ b/src/simulator/backends/ode/mod.rs @@ -19,11 +19,9 @@ use crate::{ Event, PharmsolError, Subject, }; -use crate::simulator::backends::parameters_hash; -use crate::simulator::cache::{ - BoundErrorModelCache, PredictionCache, DEFAULT_CACHE_SIZE, -}; use crate::core::Predictions; +use crate::simulator::backends::parameters_hash; +use crate::simulator::cache::{BoundErrorModelCache, PredictionCache, DEFAULT_CACHE_SIZE}; use closure::PMProblem; use diffsol::{ error::OdeSolverError, ode_solver::method::OdeSolverMethod, NalgebraContext, OdeBuilder, @@ -155,7 +153,9 @@ impl ODE { /// Attach validated handwritten-model metadata to this ODE. pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { - let validated = metadata.validate_for(ModelKind::Ode).map_err(OdeMetadataError::Validation)?; + let validated = metadata + .validate_for(ModelKind::Ode) + .map_err(OdeMetadataError::Validation)?; validate_metadata_dimensions(&validated, &self.core.dims())?; self.core.set_metadata(validated); Ok(self) @@ -203,11 +203,12 @@ impl ODE { match event { Event::Bolus(bolus) => { - let input = bolus.input_index().ok_or_else(|| { - PharmsolError::UnknownInputLabel { - label: bolus.input().to_string(), - } - })?; + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; if input >= bolus_v.len() { return Err(PharmsolError::InputOutOfRange { input, @@ -221,12 +222,22 @@ impl ODE { state_without_bolus.fill(0.0); (self.diffeq)( - solver.state().y, parameters_v, event.time(), - state_without_bolus, zero_bolus, zero_rateiv, covariates, + solver.state().y, + parameters_v, + event.time(), + state_without_bolus, + zero_bolus, + zero_rateiv, + covariates, ); (self.diffeq)( - solver.state().y, parameters_v, event.time(), - state_with_bolus, bolus_v, zero_rateiv, covariates, + solver.state().y, + parameters_v, + event.time(), + state_with_bolus, + bolus_v, + zero_rateiv, + covariates, ); state_with_bolus.axpy(-1.0, state_without_bolus, 1.0); solver.state_mut().y.axpy(1.0, state_with_bolus, 1.0); @@ -235,7 +246,11 @@ impl ODE { Event::Observation(observation) => { y_out.fill(0.0); (self.out)( - solver.state().y, parameters_v, observation.time(), covariates, y_out, + solver.state().y, + parameters_v, + observation.time(), + covariates, + y_out, ); let outeq = observation.outeq_index().ok_or_else(|| { PharmsolError::UnknownOutputLabel { @@ -365,7 +380,10 @@ fn _simulate_subject_dense( error_models: Option<&AssayErrorModels>, ) -> Result<(SubjectPredictions, Option), PharmsolError> { let bound_error_models = match error_models { - Some(error_models) => Some(crate::core::simulate::bind_error_models_inner(ode, error_models)?), + Some(error_models) => Some(crate::core::simulate::bind_error_models_inner( + ode, + error_models, + )?), None => None, }; let bound_error_models = bound_error_models.as_deref(); @@ -620,7 +638,8 @@ impl crate::core::Simulate for ODE { params: &[f64], error_models: &AssayErrorModels, ) -> Result { - let bound_error_models = crate::core::simulate::bind_error_models_inner(self, error_models)?; + let bound_error_models = + crate::core::simulate::bind_error_models_inner(self, error_models)?; let ypred = _subject_predictions(self, subject, params)?; ypred.log_likelihood(&bound_error_models) } diff --git a/src/simulator/likelihood/matrix.rs b/src/simulator/likelihood/matrix.rs index af128155..058d3c1c 100644 --- a/src/simulator/likelihood/matrix.rs +++ b/src/simulator/likelihood/matrix.rs @@ -85,11 +85,7 @@ pub fn log_likelihood_matrix( let subject = &subject_slice[i]; for (element, support_point) in row.iter_mut().zip(support_point_rows.iter()) { - *element = model.log_likelihood( - subject, - support_point.as_slice(), - error_models, - )?; + *element = model.log_likelihood(subject, support_point.as_slice(), error_models)?; if let Some(ref tracker) = progress_tracker { tracker.inc(); } diff --git a/src/simulator/likelihood/mod.rs b/src/simulator/likelihood/mod.rs index d25a8c2d..2c6967b2 100644 --- a/src/simulator/likelihood/mod.rs +++ b/src/simulator/likelihood/mod.rs @@ -140,13 +140,10 @@ pub fn log_likelihood_batch( Err(_) => return f64::NEG_INFINITY, }; - let obs_pred_pairs = predictions - .predictions() - .iter() - .filter_map(|pred| { - pred.observation() - .map(|obs| (pred.outeq(), obs, pred.prediction())) - }); + let obs_pred_pairs = predictions.predictions().iter().filter_map(|pred| { + pred.observation() + .map(|obs| (pred.outeq(), obs, pred.prediction())) + }); residual_error_models.total_log_likelihood(obs_pred_pairs) }; @@ -216,13 +213,10 @@ pub fn log_likelihood_subject( }; // Extract (outeq, observation, prediction) tuples and compute log-likelihood - let obs_pred_pairs = predictions - .predictions() - .iter() - .filter_map(|pred| { - pred.observation() - .map(|obs| (pred.outeq(), obs, pred.prediction())) - }); + let obs_pred_pairs = predictions.predictions().iter().filter_map(|pred| { + pred.observation() + .map(|obs| (pred.outeq(), obs, pred.prediction())) + }); residual_error_models.total_log_likelihood(obs_pred_pairs) } diff --git a/src/simulator/mod.rs b/src/simulator/mod.rs index 3b92622b..5772539c 100644 --- a/src/simulator/mod.rs +++ b/src/simulator/mod.rs @@ -1,5 +1,5 @@ -pub mod cache; pub mod backends; +pub mod cache; pub(crate) mod likelihood; use diffsol::{NalgebraMat, NalgebraVec}; diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index db3ab8f1..a2643345 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -1,12 +1,12 @@ #[cfg(feature = "dsl-jit")] use approx::assert_relative_eq; #[cfg(feature = "dsl-jit")] -use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; -#[cfg(feature = "dsl-jit")] use pharmsol::backends::RouteInputPolicy; use pharmsol::backends::{ self, AnalyticalKernel, RouteKind as HandwrittenRouteKind, ValidatedModelMetadata, }; +#[cfg(feature = "dsl-jit")] +use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; use pharmsol::prelude::*; #[cfg(feature = "dsl-jit")] use pharmsol::Predictions; diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 3f971976..abff37d4 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -13,7 +13,7 @@ use diffsol::Vector; use ndarray::Array2; use pharmsol::dsl::{self, CompiledRuntimeModel, RuntimeCompilationTarget, RuntimePredictions}; use pharmsol::prelude::{ - one_compartment_with_absorption, Simulate, Prediction, SubjectPredictions, + one_compartment_with_absorption, Prediction, Simulate, SubjectPredictions, }; use pharmsol::{ backends, fa, fetch_cov, fetch_params, lag, Parameters, Subject, SubjectBuilderExt, SDE, From 356cf2ffc8667b6c9f4949df3a27c93db5e213aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Thu, 18 Jun 2026 10:29:43 +0100 Subject: [PATCH 4/7] Trying to fix compilation on Ubuntu --- src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 6861a49b..4da99af1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -135,14 +135,15 @@ mod test_fixtures; //extension traits #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::core::metadata; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::core::metadata::{ ModelMetadata, ModelMetadataError, NameDomain, RouteInputPolicy, RouteKind, ValidatedModelMetadata, }; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::core::{Caching, ModelInfo, Simulate, Solver}; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::core::{Predictions, State}; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::data::builder::SubjectBuilderExt; From a225aa8752d4daf2a330167d980c12e2276907a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Thu, 18 Jun 2026 10:43:06 +0100 Subject: [PATCH 5/7] answering copilot's comments --- src/core/simulate.rs | 24 +----------------------- src/dsl/native.rs | 8 ++++---- src/optimize/mod.rs | 2 +- src/optimize/parameters.rs | 6 +++--- 4 files changed, 9 insertions(+), 31 deletions(-) diff --git a/src/core/simulate.rs b/src/core/simulate.rs index 239c2eb6..7c9129c8 100644 --- a/src/core/simulate.rs +++ b/src/core/simulate.rs @@ -116,16 +116,6 @@ where S: Solver + ModelInfo + Caching, P: PredictionsContainer, { - // Check prediction cache - if let (Some(cache), None) = (model.prediction_cache(), error_models) { - let key = (subject.hash(), parameters_hash(params)); - // Cache hit would need to return (P, None) but P isn't necessarily the same - // type as what's in the cache. We skip cache-based return here and let - // individual backends handle caching in their simulate_subject impl. - // The cache check pattern is used by Analytical and ODE backends. - let _ = (cache, key); - } - let bound_error_models = match error_models { Some(em) => Some(bind_error_models_inner(model, em)?), None => None, @@ -166,7 +156,7 @@ where &state, params, observation, - error_models, + bound_error_models.as_deref(), covariates, )?; if let Some(lik) = lik { @@ -239,15 +229,3 @@ pub(crate) fn bind_error_models_inner<'a, M: ModelInfo + Caching>( ) .map_err(PharmsolError::from) } - -/// Hash a parameter slice for cache keys. -#[inline(always)] -pub(crate) fn parameters_hash(params: &[f64]) -> u64 { - use std::hash::{Hash, Hasher}; - let mut hasher = ahash::AHasher::default(); - for &value in params { - let bits = if value == 0.0 { 0u64 } else { value.to_bits() }; - bits.hash(&mut hasher); - } - hasher.finish() -} diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 518c954b..f34e64e8 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -1394,7 +1394,7 @@ fn runtime_ode_predictions( if let Some(cache) = &model.cache { let key = ( subject.hash(), - crate::core::simulate::parameters_hash(support_point), + crate::simulator::backends::parameters_hash(support_point), ); if let Some(cached) = cache.get(&key) { return Ok(cached); @@ -1697,7 +1697,7 @@ fn runtime_analytical_predictions( if let Some(cache) = &model.cache { let key = ( subject.hash(), - crate::core::simulate::parameters_hash(support_point), + crate::simulator::backends::parameters_hash(support_point), ); if let Some(cached) = cache.get(&key) { return Ok(cached); @@ -2111,7 +2111,7 @@ fn runtime_sde_log_likelihood( if let Some(cache) = &model.cache { let key = ( subject.hash(), - crate::core::simulate::parameters_hash(support_point), + crate::simulator::backends::parameters_hash(support_point), error_models.hash(), ); if let Some(cached) = cache.get(&key) { @@ -3052,7 +3052,7 @@ mod tests { let expected = SubjectPredictions::default(); let key = ( subject.hash(), - crate::core::simulate::parameters_hash(parameters.as_slice()), + crate::simulator::backends::parameters_hash(parameters.as_slice()), ); model diff --git a/src/optimize/mod.rs b/src/optimize/mod.rs index b044108b..6f70c83e 100644 --- a/src/optimize/mod.rs +++ b/src/optimize/mod.rs @@ -4,7 +4,7 @@ //! //! - [`effect`] — Find the maximum effect (`E2`) for dual-site PD models //! via Nelder‑Mead optimization in log‑space. -//! - [`parameters`] — Nelder‑Mead parameter refinement for an [`Equation`] +//! - [`parameters`] — Nelder‑Mead parameter refinement for a [`Simulate`] model //! against a [`Data`] set and [`AssayErrorModels`]. pub mod effect; diff --git a/src/optimize/parameters.rs b/src/optimize/parameters.rs index fa8a6436..e935fca8 100644 --- a/src/optimize/parameters.rs +++ b/src/optimize/parameters.rs @@ -1,9 +1,9 @@ //! Nelder‑Mead parameter refinement for pharmacometric models. //! //! This module provides a [`ParameterOptimizer`] that refines a single parameter -//! Given an [`Equation`], observed [`Data`], and [`AssayErrorModels`] via -//! Nelder‑Mead optimization in log‑space. The optimizer finds the parameter vector -//! that minimizes the negative log-likelihood of the model predictions against the data, +//! vector. Given a [`Simulate`] model, observed [`Data`], and [`AssayErrorModels`], +//! it runs Nelder‑Mead optimization in log‑space to find the parameter vector that +//! minimizes the negative log-likelihood of the model predictions against the data, //! as measured by the provided error models. use argmin::{ From c46ce7d6e218e5cfa254a283d798fa3900477841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Thu, 18 Jun 2026 11:07:03 +0100 Subject: [PATCH 6/7] force SDEs to use the standard loop --- src/core/simulate.rs | 6 +- src/core/solver.rs | 14 ++- src/dsl/jit.rs | 2 +- src/simulator/backends/analytical/mod.rs | 2 +- src/simulator/backends/sde/mod.rs | 126 ++++------------------- 5 files changed, 39 insertions(+), 111 deletions(-) diff --git a/src/core/simulate.rs b/src/core/simulate.rs index 7c9129c8..19821368 100644 --- a/src/core/simulate.rs +++ b/src/core/simulate.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::core::{Caching, ModelInfo, Solver, State}; +use crate::core::{Caching, ModelInfo, Solver}; use crate::data::error_model::{AssayErrorModels, BoundAssayErrorModels}; use crate::simulator::likelihood::Prediction; use crate::{Event, Infusion, Parameters, PharmsolError, Subject}; @@ -146,14 +146,14 @@ where ndrugs: model.ndrugs(), }); } - state.add_bolus(input, bolus.amount()); + model.process_bolus(&mut state, input, bolus.amount()); } Event::Infusion(infusion) => { infusions.push(infusion.clone()); } Event::Observation(observation) => { let (pred, lik) = model.process_observation( - &state, + &mut state, params, observation, bound_error_models.as_deref(), diff --git a/src/core/solver.rs b/src/core/solver.rs index 9a9b2399..bd9b087e 100644 --- a/src/core/solver.rs +++ b/src/core/solver.rs @@ -72,11 +72,23 @@ pub trait Solver { ) } + /// Apply a bolus to the state, with optional solver-specific preprocessing. + /// + /// The default implementation calls [`State::add_bolus`]. Override this if + /// the solver needs to redirect or transform bolus inputs before application + /// (e.g. SDE injected-bolus mappings). + fn process_bolus(&self, state: &mut Self::State, input: usize, amount: f64) { + state.add_bolus(input, amount); + } + /// Compute a prediction (and optionally a likelihood component) from the /// current state at an observation time point. + /// + /// The state is `&mut` so that backends that need post-observation mutation + /// (e.g. SDE resampling) can perform it inline. fn process_observation( &self, - _state: &Self::State, + _state: &mut Self::State, _params: &[f64], _observation: &Observation, _error_models: Option<&AssayErrorModels>, diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index ca93c75a..fc1edf05 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -1337,7 +1337,7 @@ mod tests { use crate::simulator::backends::analytical::one_compartment_with_absorption; use crate::simulator::backends::ode::{ExplicitRkTableau, OdeSolver}; use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS; - use crate::{backends, Parameters, Subject, SubjectBuilderExt, ODE, SDE}; + use crate::{Parameters, Subject, SubjectBuilderExt, ODE, SDE}; use approx::assert_relative_eq; use diffsol::Vector; use pharmsol_dsl::execution::DenseBufferLayout; diff --git a/src/simulator/backends/analytical/mod.rs b/src/simulator/backends/analytical/mod.rs index 23933f99..818bd6f0 100644 --- a/src/simulator/backends/analytical/mod.rs +++ b/src/simulator/backends/analytical/mod.rs @@ -706,7 +706,7 @@ impl crate::core::Solver for Analytical { fn process_observation( &self, - x: &Self::State, + x: &mut Self::State, parameters: &[f64], observation: &Observation, error_models: Option<&AssayErrorModels>, diff --git a/src/simulator/backends/sde/mod.rs b/src/simulator/backends/sde/mod.rs index 9f3d7c7d..b2a9b9a1 100644 --- a/src/simulator/backends/sde/mod.rs +++ b/src/simulator/backends/sde/mod.rs @@ -2,19 +2,19 @@ mod em; use diffsol::{NalgebraContext, Vector}; use nalgebra::DVector; -use ndarray::{concatenate, Array2, Axis}; +use ndarray::Array2; use pharmsol_dsl::ModelKind; use rand::{rng, RngExt}; use rayon::prelude::*; use thiserror::Error; -use crate::core::{ModelInfo, Simulate, Solver}; +use crate::core::{ModelInfo, Simulate}; use crate::{ data::{Covariates, Infusion}, error_model::AssayErrorModels, prelude::simulator::Prediction, simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V}, - Event, Observation, PharmsolError, Subject, + Observation, PharmsolError, Subject, }; use crate::simulator::backends::parameters_hash; @@ -236,7 +236,7 @@ impl crate::core::PredictionsContainer for Array2 { } fn push(&mut self, pred: Prediction) { - let col = Array2::from_shape_vec((self.nrows(), 1), vec![pred]).unwrap(); + let col = Array2::from_shape_vec((self.nrows(), 1), vec![pred; self.nrows()]).unwrap(); *self = ndarray::concatenate(ndarray::Axis(1), &[self.view(), col.view()]).unwrap(); } @@ -403,9 +403,15 @@ impl crate::core::Solver for SDE { Ok(()) } + fn process_bolus(&self, state: &mut Self::State, input: usize, amount: f64) { + if !self.injected_bolus_mappings.apply(state, input, amount) { + state.add_bolus(input, amount); + } + } + fn process_observation( &self, - x: &Self::State, + x: &mut Self::State, parameters: &[f64], observation: &Observation, error_models: Option<&AssayErrorModels>, @@ -430,15 +436,16 @@ impl crate::core::Solver for SDE { *p = observation.to_prediction(y[outeq], x[i].as_slice().to_vec()); }); - // Resampling and likelihood computation + // Resampling — mutate state to concentrate particles on high-likelihood regions let lik = if let Some(em) = error_models { let q: Vec = preds .iter() .map(|p| p.log_likelihood(em).map(f64::exp).unwrap_or(0.0)) .collect(); let sum_q: f64 = q.iter().sum(); - // Note: resampling is skipped here because state is borrowed. - // Full resampling happens in simulate_subject. + let w: Vec = q.iter().map(|qi| qi / sum_q).collect(); + let indices = sysresample(&w); + *x = indices.iter().map(|&i| x[i].clone()).collect(); Some(sum_q / nparticles as f64) } else { None @@ -542,103 +549,12 @@ impl crate::core::Simulate for SDE { params: &[f64], error_models: Option<&AssayErrorModels>, ) -> Result<(Self::Predictions, Option), PharmsolError> { - let bound_em = match error_models { - Some(em) => Some(crate::core::simulate::bind_error_models_inner(self, em)?), - None => None, - }; - - let mut output = - Array2::::from_shape_fn((self.nparticles, 0), |_| Prediction::default()); - let mut likelihood = Vec::new(); - - for occasion in subject.occasions() { - let covariates = occasion.covariates(); - let events = self.resolve_events(occasion, params, covariates)?; - let mut state = self.initial_state(params, covariates, occasion.index()); - let mut infusions: Vec = Vec::new(); - - for (idx, event) in events.iter().enumerate() { - match event { - Event::Bolus(bolus) => { - let input = bolus.input_index().ok_or_else(|| { - PharmsolError::UnknownInputLabel { - label: bolus.input().to_string(), - } - })?; - if input >= self.ndrugs() { - return Err(PharmsolError::InputOutOfRange { - input, - ndrugs: self.ndrugs(), - }); - } - if !self - .injected_bolus_mappings - .apply(&mut state, input, bolus.amount()) - { - state.add_bolus(input, bolus.amount()); - } - } - Event::Infusion(inf) => infusions.push(inf.clone()), - Event::Observation(obs) => { - // Compute predictions across particles - let mut preds = vec![Prediction::default(); self.nparticles]; - preds.par_iter_mut().enumerate().for_each(|(i, p)| { - let mut y = V::zeros(self.nout(), NalgebraContext); - (self.out)( - &state[i].clone().into(), - &V::from_vec(params.to_vec(), NalgebraContext), - obs.time(), - covariates, - &mut y, - ); - let outeq = obs.outeq_index().expect("resolved obs"); - *p = obs.to_prediction(y[outeq], state[i].as_slice().to_vec()); - }); - - // Resampling - if let Some(em) = &bound_em { - let q: Vec = preds - .iter() - .map(|p| p.log_likelihood(em).map(f64::exp).unwrap_or(0.0)) - .collect(); - let sum_q: f64 = q.iter().sum(); - let w: Vec = q.iter().map(|qi| qi / sum_q).collect(); - let indices = sysresample(&w); - state = indices.iter().map(|&i| state[i].clone()).collect(); - likelihood.push(sum_q / self.nparticles as f64); - } - - // Store mean prediction - let mean_pred: f64 = preds.iter().map(|p| p.prediction()).sum::() - / self.nparticles as f64; - let mut pred = preds[0].clone(); - pred.set_prediction(mean_pred); - let col = Array2::from_shape_vec( - (self.nparticles, 1), - vec![pred; self.nparticles], - ) - .unwrap(); - output = concatenate(Axis(1), &[output.view(), col.view()]).unwrap(); - } - } - - if let Some(next) = events.get(idx + 1) { - if !event.time().eq(&next.time()) { - self.solve( - &mut state, - params, - covariates, - &infusions, - event.time(), - next.time(), - )?; - } - } - } - } - - let ll = bound_em.map(|_| likelihood.iter().product::()); - Ok((output, ll)) + crate::core::standard_event_loop::( + self, + subject, + params, + error_models, + ) } fn log_likelihood( From 74493d5d8143d8dc4dd4131703360762aa384e95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Thu, 18 Jun 2026 11:16:52 +0100 Subject: [PATCH 7/7] fix benchmarks --- benches/common/mod.rs | 5 +---- benches/dsl_matrix.rs | 14 +++++++------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/benches/common/mod.rs b/benches/common/mod.rs index 21852c7b..9002b968 100644 --- a/benches/common/mod.rs +++ b/benches/common/mod.rs @@ -17,10 +17,7 @@ use pharmsol::prelude::*; use pharmsol::simulator::backends::analytical::{ one_compartment_with_absorption, two_compartments, }; -use pharmsol::{ - backends::{self, Route}, - Analytical, ResidualErrorModel, ResidualErrorModels, ODE, SDE, -}; +use pharmsol::{backends::Route, Analytical, ResidualErrorModel, ResidualErrorModels, ODE, SDE}; /// `ModelMetadata` for handwritten factories so route/output labels resolve like the macro/DSL paths. fn model_metadata(workload: Workload, kind: SolverKind) -> ModelMetadata { diff --git a/benches/dsl_matrix.rs b/benches/dsl_matrix.rs index 6c109f0f..fb9314f8 100644 --- a/benches/dsl_matrix.rs +++ b/benches/dsl_matrix.rs @@ -19,7 +19,7 @@ use pharmsol::dsl::{ NativeAotCompileOptions, NativeOdeModel, NativeSdeModel, RuntimeCompilationTarget, }; use pharmsol::prelude::*; -use pharmsol::{Cache, Parameters}; +use pharmsol::Parameters; mod common; use common::{ @@ -245,7 +245,7 @@ fn predictions_group(c: &mut Criterion) { let model = match cache { CacheState::Hot => compile_ode(workload, backend, &aot), CacheState::Cold => { - compile_ode(workload, backend, &aot).disable_cache() + compile_ode(workload, backend, &aot).without_cache() } }; let theta = ode_parameters(&model, workload); @@ -266,7 +266,7 @@ fn predictions_group(c: &mut Criterion) { let model = match cache { CacheState::Hot => compile_analytical(workload, backend, &aot), CacheState::Cold => { - compile_analytical(workload, backend, &aot).disable_cache() + compile_analytical(workload, backend, &aot).without_cache() } }; let theta = analytical_parameters(&model, workload); @@ -287,7 +287,7 @@ fn predictions_group(c: &mut Criterion) { let model = match cache { CacheState::Hot => compile_sde(workload, backend, &aot), CacheState::Cold => { - compile_sde(workload, backend, &aot).disable_cache() + compile_sde(workload, backend, &aot).without_cache() } }; let theta = sde_parameters(&model, workload); @@ -339,7 +339,7 @@ fn log_likelihood_group(c: &mut Criterion) { let model = match cache { CacheState::Hot => compile_ode(workload, backend, &aot), CacheState::Cold => { - compile_ode(workload, backend, &aot).disable_cache() + compile_ode(workload, backend, &aot).without_cache() } }; let theta = ode_parameters(&model, workload); @@ -361,7 +361,7 @@ fn log_likelihood_group(c: &mut Criterion) { let model = match cache { CacheState::Hot => compile_analytical(workload, backend, &aot), CacheState::Cold => { - compile_analytical(workload, backend, &aot).disable_cache() + compile_analytical(workload, backend, &aot).without_cache() } }; let theta = analytical_parameters(&model, workload); @@ -383,7 +383,7 @@ fn log_likelihood_group(c: &mut Criterion) { let model = match cache { CacheState::Hot => compile_sde(workload, backend, &aot), CacheState::Cold => { - compile_sde(workload, backend, &aot).disable_cache() + compile_sde(workload, backend, &aot).without_cache() } }; let theta = sde_parameters(&model, workload);