diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 4233d7a5..71963f08 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -4,7 +4,8 @@ use super::ast::*; use super::diagnostic::{Applicability, DiagnosticSuggestion, ParseError, Span, TextEdit}; use super::parser::{parse_expr_fragment, parse_place_fragment}; use crate::name_match::{ - common_prefix_len, edit_distance, is_high_confidence_match, is_single_adjacent_transposition, + bare_numeric_label, canonical_numeric_suffix, common_prefix_len, edit_distance, + is_high_confidence_match, is_single_adjacent_transposition, }; use crate::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX, RATE_FUNCTION_NAME}; @@ -1146,15 +1147,6 @@ impl LabelKind { } } -fn bare_numeric_label(src: &str) -> Option<&str> { - (!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src) -} - -fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> { - let suffix = src.strip_prefix(prefix)?; - (!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix) -} - fn parse_place_at(src: &str, abs_start: usize) -> Result { let mut place = parse_place_fragment(src).map_err(|error| error.shifted(abs_start))?; shift_place(&mut place, abs_start); diff --git a/pharmsol-dsl/src/name_match.rs b/pharmsol-dsl/src/name_match.rs index 33d90953..7cadfeb5 100644 --- a/pharmsol-dsl/src/name_match.rs +++ b/pharmsol-dsl/src/name_match.rs @@ -1,3 +1,12 @@ +pub(crate) fn bare_numeric_label(src: &str) -> Option<&str> { + (!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src) +} + +pub(crate) fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> { + let suffix = src.strip_prefix(prefix)?; + (!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix) +} + pub(crate) fn is_high_confidence_match( needle: &str, candidate: &str, diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index 981db69d..7b09efb4 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -9,7 +9,8 @@ use crate::diagnostic::{ }; use crate::ir::*; use crate::name_match::{ - common_prefix_len, edit_distance, is_high_confidence_match, is_single_adjacent_transposition, + bare_numeric_label, canonical_numeric_suffix, common_prefix_len, edit_distance, + is_high_confidence_match, is_single_adjacent_transposition, }; use crate::{ModelKind, NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX, RATE_FUNCTION_NAME}; @@ -2341,15 +2342,6 @@ fn allows_route_output_name_overlap(existing: SymbolKind, new: SymbolKind) -> bo ) } -fn bare_numeric_label(src: &str) -> Option<&str> { - (!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src) -} - -fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> { - let suffix = src.strip_prefix(prefix)?; - (!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix) -} - fn numeric_label_literal_suffix(value: f64) -> Option { (value.is_finite() && value >= 0.0 && value.fract() == 0.0 && value <= usize::MAX as f64) .then(|| (value as usize).to_string()) diff --git a/src/dsl/aot.rs b/src/dsl/aot.rs index 0e76c4cc..50abde21 100644 --- a/src/dsl/aot.rs +++ b/src/dsl/aot.rs @@ -21,7 +21,9 @@ use super::compiled_backend_abi::{ OUTPUTS_SYMBOL, ROUTE_BIOAVAILABILITY_SYMBOL, ROUTE_LAG_SYMBOL, }; #[cfg(feature = "dsl-aot-load")] -use super::native::{CompiledNativeModel, DenseKernelFn, NativeExecutionArtifact, NativeModelInfo}; +use super::native::{ + CompiledNativeModel, DenseKernelFn, NativeExecutionArtifact, NativeKernels, NativeModelInfo, +}; #[cfg(feature = "dsl-aot")] use super::rust_backend::{emit_rust_backend_source, RustBackendFlavor}; #[cfg(feature = "dsl-aot")] @@ -320,18 +322,17 @@ pub fn load_aot_model(path: impl AsRef) -> Result usize { +pub(super) fn kernel_output_len(info: &NativeModelInfo, role: KernelRole) -> usize { match role { KernelRole::Derive => info.derived_len, KernelRole::Dynamics | KernelRole::Init | KernelRole::Drift | KernelRole::Diffusion => { diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index d9691b57..aef6fcc4 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -261,98 +261,84 @@ pub fn compile_execution_artifact( let mut ctx = module.make_context(); let mut builder_context = FunctionBuilderContext::new(); - let derive = compile_role_kernel( - &mut module, - &mut ctx, - &mut builder_context, - ptr_ty, - externs, - model, - KernelRole::Derive, - )?; - let dynamics = compile_role_kernel( - &mut module, - &mut ctx, - &mut builder_context, - ptr_ty, - externs, - model, - KernelRole::Dynamics, - )?; - let outputs = compile_role_kernel( - &mut module, - &mut ctx, - &mut builder_context, - ptr_ty, - externs, - model, - KernelRole::Outputs, - )? - .ok_or_else(|| JitCompileError::new("missing outputs kernel", Some(model.span)))?; - let init = compile_role_kernel( - &mut module, - &mut ctx, - &mut builder_context, - ptr_ty, - externs, - model, - KernelRole::Init, - )?; - let drift = compile_role_kernel( - &mut module, - &mut ctx, - &mut builder_context, - ptr_ty, - externs, - model, - KernelRole::Drift, - )?; - let diffusion = compile_role_kernel( - &mut module, - &mut ctx, - &mut builder_context, - ptr_ty, - externs, - model, - KernelRole::Diffusion, - )?; - let route_lag = compile_role_kernel( - &mut module, - &mut ctx, - &mut builder_context, - ptr_ty, - externs, - model, - KernelRole::RouteLag, - )?; - let route_bioavailability = compile_role_kernel( - &mut module, - &mut ctx, - &mut builder_context, - ptr_ty, - externs, - model, - KernelRole::RouteBioavailability, - )?; + let mut compiled: [Option; 8] = [None; 8]; + for role in ROLES_IN_TABLE_ORDER { + compiled[role_index(role)] = compile_role_kernel( + &mut module, + &mut ctx, + &mut builder_context, + ptr_ty, + externs, + model, + role, + )?; + } + + let outputs_id = compiled[role_index(KernelRole::Outputs)] + .ok_or_else(|| JitCompileError::new("missing outputs kernel", Some(model.span)))?; module .finalize_definitions() .map_err(|error| JitCompileError::new(error.to_string(), Some(model.span)))?; + let mut lookup = + |role: KernelRole| compiled[role_index(role)].map(|id| function_pointer(&mut module, id)); + // Borrow checker forbids reusing the closure once `module` is moved into the + // artifact, so resolve every pointer up front. + let derive = lookup(KernelRole::Derive); + let dynamics = lookup(KernelRole::Dynamics); + let init = lookup(KernelRole::Init); + let drift = lookup(KernelRole::Drift); + let diffusion = lookup(KernelRole::Diffusion); + let route_lag = lookup(KernelRole::RouteLag); + let route_bioavailability = lookup(KernelRole::RouteBioavailability); + let outputs = function_pointer(&mut module, outputs_id); + Ok(NativeExecutionArtifact::from_jit_module( model.name.clone(), - derive.map(|id| function_pointer(&mut module, id)), - dynamics.map(|id| function_pointer(&mut module, id)), - function_pointer(&mut module, outputs), - init.map(|id| function_pointer(&mut module, id)), - drift.map(|id| function_pointer(&mut module, id)), - diffusion.map(|id| function_pointer(&mut module, id)), - route_lag.map(|id| function_pointer(&mut module, id)), - route_bioavailability.map(|id| function_pointer(&mut module, id)), + super::native::NativeKernels { + derive, + dynamics, + outputs, + init, + drift, + diffusion, + route_lag, + route_bioavailability, + }, module, )) } +/// The eight kernel roles materialised by the JIT/AoT/WASM backends, in the +/// canonical order used by [`NativeKernels`]. +const ROLES_IN_TABLE_ORDER: [KernelRole; 8] = [ + KernelRole::Derive, + KernelRole::Dynamics, + KernelRole::Outputs, + KernelRole::Init, + KernelRole::Drift, + KernelRole::Diffusion, + KernelRole::RouteLag, + KernelRole::RouteBioavailability, +]; + +fn role_index(role: KernelRole) -> usize { + match role { + KernelRole::Derive => 0, + KernelRole::Dynamics => 1, + KernelRole::Outputs => 2, + KernelRole::Init => 3, + KernelRole::Drift => 4, + KernelRole::Diffusion => 5, + KernelRole::RouteLag => 6, + KernelRole::RouteBioavailability => 7, + KernelRole::Analytical => { + unreachable!("analytical kernels are not stored in the JIT table") + } + } +} + fn declare_externs(module: &mut JITModule, span: Span) -> Result { let declare_unary = |module: &mut JITModule, symbol: &str| -> Result<_, JitCompileError> { let mut signature = module.make_signature(); @@ -1519,7 +1505,7 @@ out(cp) = central / v ~ continuous() let covariates = [70.0]; let routes = [0.0, 0.0]; - let derive = artifact.derive.expect("derive kernel present"); + let derive = artifact.kernels.derive.expect("derive kernel present"); unsafe { derive( 0.0, @@ -1530,7 +1516,7 @@ out(cp) = central / v ~ continuous() derived.as_ptr(), derived.as_mut_ptr(), ); - artifact.dynamics.expect("dynamics kernel present")( + artifact.kernels.dynamics.expect("dynamics kernel present")( 0.0, states.as_ptr(), params.as_ptr(), @@ -1539,7 +1525,7 @@ out(cp) = central / v ~ continuous() derived.as_ptr(), dx.as_mut_ptr(), ); - (artifact.outputs)( + (artifact.kernels.outputs)( 0.0, states.as_ptr(), params.as_ptr(), diff --git a/src/dsl/mod.rs b/src/dsl/mod.rs index 0f4cb0ab..07adbe2e 100644 --- a/src/dsl/mod.rs +++ b/src/dsl/mod.rs @@ -142,7 +142,9 @@ pub use jit::{ compile_ode_model_to_jit, compile_sde_model_to_jit, CompiledJitModel, JitAnalyticalModel, JitCompileError, JitExecutionArtifact, JitOdeModel, JitSdeModel, }; -pub use model_info::{NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo}; +pub use model_info::{ + NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, NativeStateInfo, +}; #[cfg(any( feature = "dsl-jit", feature = "dsl-aot-load", diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 634dddb1..56f238f6 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -109,8 +109,12 @@ impl std::fmt::Debug for NativeArtifactOwner { } } -pub struct NativeExecutionArtifact { - pub model_name: String, +/// Table of compiled native kernels indexed by [`KernelRole`]. +/// +/// `outputs` is the only required kernel; every other role is optional and +/// absent when the source model does not declare it. +#[derive(Clone, Copy)] +pub struct NativeKernels { pub derive: Option, pub dynamics: Option, pub outputs: DenseKernelFn, @@ -119,6 +123,49 @@ pub struct NativeExecutionArtifact { pub diffusion: Option, pub route_lag: Option, pub route_bioavailability: Option, +} + +impl NativeKernels { + /// Look up the kernel registered for `role`, treating `Outputs` as always + /// present and `Analytical` as never having a dense kernel. + pub(crate) fn get(&self, role: KernelRole) -> Option { + match role { + KernelRole::Derive => self.derive, + KernelRole::Dynamics => self.dynamics, + KernelRole::Outputs => Some(self.outputs), + KernelRole::Init => self.init, + KernelRole::Drift => self.drift, + KernelRole::Diffusion => self.diffusion, + KernelRole::RouteLag => self.route_lag, + KernelRole::RouteBioavailability => self.route_bioavailability, + KernelRole::Analytical => None, + } + } + + pub(crate) fn has(&self, role: KernelRole) -> bool { + self.get(role).is_some() + } +} + +impl std::fmt::Debug for NativeKernels { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let as_ptr = |k: Option| k.map(|ptr| ptr as *const ()); + f.debug_struct("NativeKernels") + .field("derive", &as_ptr(self.derive)) + .field("dynamics", &as_ptr(self.dynamics)) + .field("outputs", &(self.outputs as *const ())) + .field("init", &as_ptr(self.init)) + .field("drift", &as_ptr(self.drift)) + .field("diffusion", &as_ptr(self.diffusion)) + .field("route_lag", &as_ptr(self.route_lag)) + .field("route_bioavailability", &as_ptr(self.route_bioavailability)) + .finish() + } +} + +pub struct NativeExecutionArtifact { + pub model_name: String, + pub kernels: NativeKernels, _owner: Option, } @@ -129,74 +176,34 @@ impl std::fmt::Debug for NativeExecutionArtifact { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("NativeExecutionArtifact") .field("model_name", &self.model_name) - .field("derive", &self.derive.map(|ptr| ptr as *const ())) - .field("dynamics", &self.dynamics.map(|ptr| ptr as *const ())) - .field("outputs", &(self.outputs as *const ())) - .field("init", &self.init.map(|ptr| ptr as *const ())) - .field("drift", &self.drift.map(|ptr| ptr as *const ())) - .field("diffusion", &self.diffusion.map(|ptr| ptr as *const ())) - .field("route_lag", &self.route_lag.map(|ptr| ptr as *const ())) - .field( - "route_bioavailability", - &self.route_bioavailability.map(|ptr| ptr as *const ()), - ) + .field("kernels", &self.kernels) .finish() } } impl NativeExecutionArtifact { #[cfg(feature = "dsl-jit")] - #[allow(clippy::too_many_arguments)] pub(crate) fn from_jit_module( model_name: String, - derive: Option, - dynamics: Option, - outputs: DenseKernelFn, - init: Option, - drift: Option, - diffusion: Option, - route_lag: Option, - route_bioavailability: Option, + kernels: NativeKernels, module: JITModule, ) -> Self { Self { model_name, - derive, - dynamics, - outputs, - init, - drift, - diffusion, - route_lag, - route_bioavailability, + kernels, _owner: Some(NativeArtifactOwner::Jit(Box::new(module))), } } #[cfg(feature = "dsl-aot-load")] - #[allow(clippy::too_many_arguments)] pub(crate) fn from_library( model_name: String, - derive: Option, - dynamics: Option, - outputs: DenseKernelFn, - init: Option, - drift: Option, - diffusion: Option, - route_lag: Option, - route_bioavailability: Option, + kernels: NativeKernels, library: Library, ) -> Self { Self { model_name, - derive, - dynamics, - outputs, - init, - drift, - diffusion, - route_lag, - route_bioavailability, + kernels, _owner: Some(NativeArtifactOwner::Library(library)), } } @@ -218,18 +225,7 @@ impl KernelSession for NativeKernelSession<'_> { derived: *const f64, out: *mut f64, ) -> Result<(), PharmsolError> { - let kernel = match role { - KernelRole::Derive => self.artifact.derive, - KernelRole::Dynamics => self.artifact.dynamics, - KernelRole::Outputs => Some(self.artifact.outputs), - KernelRole::Init => self.artifact.init, - KernelRole::Drift => self.artifact.drift, - KernelRole::Diffusion => self.artifact.diffusion, - KernelRole::RouteLag => self.artifact.route_lag, - KernelRole::RouteBioavailability => self.artifact.route_bioavailability, - KernelRole::Analytical => None, - } - .ok_or_else(|| { + let kernel = self.artifact.kernels.get(role).ok_or_else(|| { PharmsolError::OtherError(format!( "model `{}` does not provide a {:?} kernel", self.artifact.model_name, role @@ -253,17 +249,7 @@ impl RuntimeArtifact for NativeExecutionArtifact { } fn has_kernel(&self, role: KernelRole) -> bool { - match role { - KernelRole::Derive => self.derive.is_some(), - KernelRole::Dynamics => self.dynamics.is_some(), - KernelRole::Outputs => true, - KernelRole::Init => self.init.is_some(), - KernelRole::Drift => self.drift.is_some(), - KernelRole::Diffusion => self.diffusion.is_some(), - KernelRole::RouteLag => self.route_lag.is_some(), - KernelRole::RouteBioavailability => self.route_bioavailability.is_some(), - KernelRole::Analytical => false, - } + self.kernels.has(role) } fn start_session(&self) -> Result, PharmsolError> { diff --git a/src/dsl/wasm.rs b/src/dsl/wasm.rs index a09b0247..966407ff 100644 --- a/src/dsl/wasm.rs +++ b/src/dsl/wasm.rs @@ -2,15 +2,13 @@ use std::ops::Range; use std::path::Path; use std::sync::Mutex; -#[allow(unused)] -use crate::dsl::model_info; use wasmtime::{Engine, Instance, Linker, Memory, Module, Store, TypedFunc}; use super::compiled_backend_abi::{ - decode_compiled_model_info, CompiledKernelAvailability, ALLOC_F64_BUFFER_SYMBOL, - API_VERSION_SYMBOL, DERIVE_SYMBOL, DIFFUSION_SYMBOL, DRIFT_SYMBOL, DYNAMICS_SYMBOL, - FREE_F64_BUFFER_SYMBOL, INIT_SYMBOL, MODEL_INFO_JSON_LEN_SYMBOL, MODEL_INFO_JSON_PTR_SYMBOL, - OUTPUTS_SYMBOL, ROUTE_BIOAVAILABILITY_SYMBOL, ROUTE_LAG_SYMBOL, + decode_compiled_model_info, kernel_output_len, CompiledKernelAvailability, + ALLOC_F64_BUFFER_SYMBOL, API_VERSION_SYMBOL, DERIVE_SYMBOL, DIFFUSION_SYMBOL, DRIFT_SYMBOL, + DYNAMICS_SYMBOL, FREE_F64_BUFFER_SYMBOL, INIT_SYMBOL, MODEL_INFO_JSON_LEN_SYMBOL, + MODEL_INFO_JSON_PTR_SYMBOL, OUTPUTS_SYMBOL, ROUTE_BIOAVAILABILITY_SYMBOL, ROUTE_LAG_SYMBOL, }; use super::native::{KernelSession, NativeModelInfo, RuntimeArtifact, RuntimeBackend}; use super::wasm_compile::{WasmError, WASM_API_VERSION}; @@ -20,52 +18,11 @@ use super::wasm_direct_emitter::{ use crate::PharmsolError; use pharmsol_dsl::execution::KernelRole; -#[derive(Clone, Copy, Debug, Default)] -struct WasmKernelAvailability { - derive: bool, - dynamics: bool, - outputs: bool, - init: bool, - drift: bool, - diffusion: bool, - route_lag: bool, - route_bioavailability: bool, -} - -impl WasmKernelAvailability { - fn has(self, role: KernelRole) -> bool { - match role { - KernelRole::Derive => self.derive, - KernelRole::Dynamics => self.dynamics, - KernelRole::Outputs => self.outputs, - KernelRole::Init => self.init, - KernelRole::Drift => self.drift, - KernelRole::Diffusion => self.diffusion, - KernelRole::RouteLag => self.route_lag, - KernelRole::RouteBioavailability => self.route_bioavailability, - KernelRole::Analytical => false, - } - } - - fn compiled(self) -> CompiledKernelAvailability { - CompiledKernelAvailability { - derive: self.derive, - dynamics: self.dynamics, - outputs: self.outputs, - init: self.init, - drift: self.drift, - diffusion: self.diffusion, - route_lag: self.route_lag, - route_bioavailability: self.route_bioavailability, - } - } -} - pub(crate) struct WasmExecutionArtifact { info: NativeModelInfo, engine: Engine, module: Module, - kernels: WasmKernelAvailability, + kernels: CompiledKernelAvailability, session_pool: Mutex>, } @@ -120,7 +77,7 @@ impl WasmKernelSession { info: &NativeModelInfo, engine: &Engine, module: &Module, - kernels: WasmKernelAvailability, + kernels: CompiledKernelAvailability, ) -> Result { let mut store = Store::new(engine, ()); let linker = configured_wasm_linker(engine)?; @@ -147,42 +104,21 @@ impl WasmKernelSession { .max(info.route_len), )?; - let derive = if kernels.derive { - optional_typed_func(&instance, &mut store, DERIVE_SYMBOL)? - } else { - None - }; - let dynamics = if kernels.dynamics { - optional_typed_func(&instance, &mut store, DYNAMICS_SYMBOL)? - } else { - None - }; + let derive = maybe_typed_func(&instance, &mut store, kernels.derive, DERIVE_SYMBOL)?; + let dynamics = maybe_typed_func(&instance, &mut store, kernels.dynamics, DYNAMICS_SYMBOL)?; let outputs = typed_func(&instance, &mut store, OUTPUTS_SYMBOL)?; - let init = if kernels.init { - optional_typed_func(&instance, &mut store, INIT_SYMBOL)? - } else { - None - }; - let drift = if kernels.drift { - optional_typed_func(&instance, &mut store, DRIFT_SYMBOL)? - } else { - None - }; - let diffusion = if kernels.diffusion { - optional_typed_func(&instance, &mut store, DIFFUSION_SYMBOL)? - } else { - None - }; - let route_lag = if kernels.route_lag { - optional_typed_func(&instance, &mut store, ROUTE_LAG_SYMBOL)? - } else { - None - }; - let route_bioavailability = if kernels.route_bioavailability { - optional_typed_func(&instance, &mut store, ROUTE_BIOAVAILABILITY_SYMBOL)? - } else { - None - }; + let init = maybe_typed_func(&instance, &mut store, kernels.init, INIT_SYMBOL)?; + let drift = maybe_typed_func(&instance, &mut store, kernels.drift, DRIFT_SYMBOL)?; + let diffusion = + maybe_typed_func(&instance, &mut store, kernels.diffusion, DIFFUSION_SYMBOL)?; + let route_lag = + maybe_typed_func(&instance, &mut store, kernels.route_lag, ROUTE_LAG_SYMBOL)?; + let route_bioavailability = maybe_typed_func( + &instance, + &mut store, + kernels.route_bioavailability, + ROUTE_BIOAVAILABILITY_SYMBOL, + )?; Ok(Self { info: info.clone(), @@ -485,7 +421,7 @@ fn load_wasm_artifact_from_module( .get_memory(&mut store, "memory") .ok_or(WasmError::MissingExport("memory"))?; let (info, expected_kernels) = read_model_info_envelope(&instance, &mut store, &memory)?; - let kernels = WasmKernelAvailability { + let kernels = CompiledKernelAvailability { derive: instance.get_func(&mut store, DERIVE_SYMBOL).is_some(), dynamics: instance.get_func(&mut store, DYNAMICS_SYMBOL).is_some(), outputs: instance.get_func(&mut store, OUTPUTS_SYMBOL).is_some(), @@ -497,12 +433,11 @@ fn load_wasm_artifact_from_module( .get_func(&mut store, ROUTE_BIOAVAILABILITY_SYMBOL) .is_some(), }; - let found_kernels = kernels.compiled(); - if found_kernels != expected_kernels { + if kernels != expected_kernels { return Err(WasmError::KernelMetadataMismatch { model: info.name.clone(), expected: expected_kernels, - found: found_kernels, + found: kernels, }); } @@ -571,18 +506,6 @@ fn alloc_buffer( Ok(WasmBuffer { ptr, len }) } -fn kernel_output_len(info: &NativeModelInfo, role: KernelRole) -> usize { - match role { - KernelRole::Derive => info.derived_len, - KernelRole::Dynamics | KernelRole::Init | KernelRole::Drift | KernelRole::Diffusion => { - info.state_len - } - KernelRole::Outputs => info.output_len, - KernelRole::RouteLag | KernelRole::RouteBioavailability => info.route_len, - KernelRole::Analytical => 0, - } -} - fn typed_func( instance: &Instance, store: &mut Store<()>, @@ -612,6 +535,23 @@ where } } +fn maybe_typed_func( + instance: &Instance, + store: &mut Store<()>, + available: bool, + name: &'static str, +) -> Result>, WasmError> +where + Params: wasmtime::WasmParams, + Results: wasmtime::WasmResults, +{ + if available { + optional_typed_func(instance, store, name) + } else { + Ok(None) + } +} + unsafe fn raw_slice<'a>(ptr: *const f64, len: usize) -> &'a [f64] { if len == 0 { &[] @@ -768,7 +708,7 @@ mod tests { use super::*; use crate::dsl::{ compile_execution_artifact, CompiledKernelAvailability, CompiledModelInfoEnvelope, - NativeModelInfo, NativeOutputInfo, NativeRouteInfo, + NativeModelInfo, NativeOutputInfo, NativeRouteInfo, NativeStateInfo, }; use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS; use approx::assert_relative_eq; @@ -802,7 +742,7 @@ mod tests { parameters: vec!["ka".to_string()], derived: Vec::new(), covariates: Vec::new(), - states: vec![super::model_info::NativeStateInfo { + states: vec![NativeStateInfo { name: "depot".to_string(), offset: 0, }], @@ -1107,7 +1047,7 @@ mod tests { let mut jit_dynamics = vec![0.0; info.state_len]; unsafe { - jit.derive.expect("jit derive")( + jit.kernels.derive.expect("jit derive")( 0.0, states.as_ptr(), params.as_ptr(), @@ -1116,7 +1056,7 @@ mod tests { jit_derived.as_ptr(), jit_derived.as_mut_ptr(), ); - (jit.outputs)( + (jit.kernels.outputs)( 0.0, states.as_ptr(), params.as_ptr(), @@ -1125,7 +1065,7 @@ mod tests { jit_derived.as_ptr(), jit_outputs.as_mut_ptr(), ); - jit.dynamics.expect("jit dynamics")( + jit.kernels.dynamics.expect("jit dynamics")( 0.0, states.as_ptr(), params.as_ptr(), @@ -1272,7 +1212,7 @@ mod tests { let mut derived = vec![0.0; info.derived_len]; unsafe { - jit.derive.expect("jit derive")( + jit.kernels.derive.expect("jit derive")( 0.0, expected.as_ptr(), params.as_ptr(), @@ -1281,7 +1221,7 @@ mod tests { derived.as_ptr(), derived.as_mut_ptr(), ); - jit.dynamics.expect("jit dynamics")( + jit.kernels.dynamics.expect("jit dynamics")( 0.0, expected.as_ptr(), params.as_ptr(), @@ -1327,7 +1267,7 @@ mod tests { let mut actual = vec![42.0; info.state_len]; unsafe { - jit.diffusion.expect("jit diffusion")( + jit.kernels.diffusion.expect("jit diffusion")( 0.0, states.as_ptr(), params.as_ptr(), @@ -1376,7 +1316,7 @@ mod tests { let mut actual_bioavailability = vec![f64::NAN; info.route_len]; unsafe { - jit.route_lag.expect("jit route lag")( + jit.kernels.route_lag.expect("jit route lag")( 0.0, states.as_ptr(), params.as_ptr(), @@ -1385,7 +1325,8 @@ mod tests { derived.as_ptr(), expected_lag.as_mut_ptr(), ); - jit.route_bioavailability + jit.kernels + .route_bioavailability .expect("jit route bioavailability")( 0.0, states.as_ptr(),