Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions pharmsol-dsl/src/authoring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use super::ast::*;
use super::diagnostic::{Applicability, DiagnosticSuggestion, ParseError, Span, TextEdit};
use super::parser::{parse_expr_fragment, parse_place_fragment};
use crate::name_match::{
common_prefix_len, edit_distance, is_high_confidence_match, is_single_adjacent_transposition,
bare_numeric_label, canonical_numeric_suffix, common_prefix_len, edit_distance,
is_high_confidence_match, is_single_adjacent_transposition,
};
use crate::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX, RATE_FUNCTION_NAME};

Expand Down Expand Up @@ -1146,15 +1147,6 @@ impl LabelKind {
}
}

fn bare_numeric_label(src: &str) -> Option<&str> {
(!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src)
}

fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> {
let suffix = src.strip_prefix(prefix)?;
(!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix)
}

fn parse_place_at(src: &str, abs_start: usize) -> Result<Place, ParseError> {
let mut place = parse_place_fragment(src).map_err(|error| error.shifted(abs_start))?;
shift_place(&mut place, abs_start);
Expand Down
9 changes: 9 additions & 0 deletions pharmsol-dsl/src/name_match.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
pub(crate) fn bare_numeric_label(src: &str) -> Option<&str> {
(!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src)
}

pub(crate) fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> {
let suffix = src.strip_prefix(prefix)?;
(!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix)
}

pub(crate) fn is_high_confidence_match(
needle: &str,
candidate: &str,
Expand Down
12 changes: 2 additions & 10 deletions pharmsol-dsl/src/semantic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use crate::diagnostic::{
};
use crate::ir::*;
use crate::name_match::{
common_prefix_len, edit_distance, is_high_confidence_match, is_single_adjacent_transposition,
bare_numeric_label, canonical_numeric_suffix, common_prefix_len, edit_distance,
is_high_confidence_match, is_single_adjacent_transposition,
};
use crate::{ModelKind, NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX, RATE_FUNCTION_NAME};

Expand Down Expand Up @@ -2341,15 +2342,6 @@ fn allows_route_output_name_overlap(existing: SymbolKind, new: SymbolKind) -> bo
)
}

fn bare_numeric_label(src: &str) -> Option<&str> {
(!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src)
}

fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> {
let suffix = src.strip_prefix(prefix)?;
(!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix)
}

fn numeric_label_literal_suffix(value: f64) -> Option<String> {
(value.is_finite() && value >= 0.0 && value.fract() == 0.0 && value <= usize::MAX as f64)
.then(|| (value as usize).to_string())
Expand Down
27 changes: 14 additions & 13 deletions src/dsl/aot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use super::compiled_backend_abi::{
OUTPUTS_SYMBOL, ROUTE_BIOAVAILABILITY_SYMBOL, ROUTE_LAG_SYMBOL,
};
#[cfg(feature = "dsl-aot-load")]
use super::native::{CompiledNativeModel, DenseKernelFn, NativeExecutionArtifact, NativeModelInfo};
use super::native::{
CompiledNativeModel, DenseKernelFn, NativeExecutionArtifact, NativeKernels, NativeModelInfo,
};
#[cfg(feature = "dsl-aot")]
use super::rust_backend::{emit_rust_backend_source, RustBackendFlavor};
#[cfg(feature = "dsl-aot")]
Expand Down Expand Up @@ -320,18 +322,17 @@ pub fn load_aot_model(path: impl AsRef<Path>) -> Result<CompiledNativeModel, Aot
let info = unsafe { read_model_info_from_library(&library)? };
let model_name = info.name.clone();
let artifact = unsafe {
NativeExecutionArtifact::from_library(
model_name,
load_optional_kernel(&library, DERIVE_SYMBOL),
load_optional_kernel(&library, DYNAMICS_SYMBOL),
load_required_kernel(&library, OUTPUTS_SYMBOL)?,
load_optional_kernel(&library, INIT_SYMBOL),
load_optional_kernel(&library, DRIFT_SYMBOL),
load_optional_kernel(&library, DIFFUSION_SYMBOL),
load_optional_kernel(&library, ROUTE_LAG_SYMBOL),
load_optional_kernel(&library, ROUTE_BIOAVAILABILITY_SYMBOL),
library,
)
let kernels = NativeKernels {
derive: load_optional_kernel(&library, DERIVE_SYMBOL),
dynamics: load_optional_kernel(&library, DYNAMICS_SYMBOL),
outputs: load_required_kernel(&library, OUTPUTS_SYMBOL)?,
init: load_optional_kernel(&library, INIT_SYMBOL),
drift: load_optional_kernel(&library, DRIFT_SYMBOL),
diffusion: load_optional_kernel(&library, DIFFUSION_SYMBOL),
route_lag: load_optional_kernel(&library, ROUTE_LAG_SYMBOL),
route_bioavailability: load_optional_kernel(&library, ROUTE_BIOAVAILABILITY_SYMBOL),
};
NativeExecutionArtifact::from_library(model_name, kernels, library)
};

Ok(match info.kind {
Expand Down
3 changes: 1 addition & 2 deletions src/dsl/compiled_backend_abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,7 @@ pub fn output_buffer_plan(
}
}

#[cfg(test)]
fn kernel_output_len(info: &NativeModelInfo, role: KernelRole) -> usize {
pub(super) fn kernel_output_len(info: &NativeModelInfo, role: KernelRole) -> usize {
match role {
KernelRole::Derive => info.derived_len,
KernelRole::Dynamics | KernelRole::Init | KernelRole::Drift | KernelRole::Diffusion => {
Expand Down
154 changes: 70 additions & 84 deletions src/dsl/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,98 +261,84 @@ pub fn compile_execution_artifact(
let mut ctx = module.make_context();
let mut builder_context = FunctionBuilderContext::new();

let derive = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
KernelRole::Derive,
)?;
let dynamics = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
KernelRole::Dynamics,
)?;
let outputs = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
KernelRole::Outputs,
)?
.ok_or_else(|| JitCompileError::new("missing outputs kernel", Some(model.span)))?;
let init = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
KernelRole::Init,
)?;
let drift = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
KernelRole::Drift,
)?;
let diffusion = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
KernelRole::Diffusion,
)?;
let route_lag = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
KernelRole::RouteLag,
)?;
let route_bioavailability = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
KernelRole::RouteBioavailability,
)?;
let mut compiled: [Option<cranelift_module::FuncId>; 8] = [None; 8];
for role in ROLES_IN_TABLE_ORDER {
compiled[role_index(role)] = compile_role_kernel(
&mut module,
&mut ctx,
&mut builder_context,
ptr_ty,
externs,
model,
role,
)?;
}

let outputs_id = compiled[role_index(KernelRole::Outputs)]
.ok_or_else(|| JitCompileError::new("missing outputs kernel", Some(model.span)))?;

module
.finalize_definitions()
.map_err(|error| JitCompileError::new(error.to_string(), Some(model.span)))?;

let mut lookup =
|role: KernelRole| compiled[role_index(role)].map(|id| function_pointer(&mut module, id));
// Borrow checker forbids reusing the closure once `module` is moved into the
// artifact, so resolve every pointer up front.
let derive = lookup(KernelRole::Derive);
let dynamics = lookup(KernelRole::Dynamics);
let init = lookup(KernelRole::Init);
let drift = lookup(KernelRole::Drift);
let diffusion = lookup(KernelRole::Diffusion);
let route_lag = lookup(KernelRole::RouteLag);
let route_bioavailability = lookup(KernelRole::RouteBioavailability);
let outputs = function_pointer(&mut module, outputs_id);

Ok(NativeExecutionArtifact::from_jit_module(
model.name.clone(),
derive.map(|id| function_pointer(&mut module, id)),
dynamics.map(|id| function_pointer(&mut module, id)),
function_pointer(&mut module, outputs),
init.map(|id| function_pointer(&mut module, id)),
drift.map(|id| function_pointer(&mut module, id)),
diffusion.map(|id| function_pointer(&mut module, id)),
route_lag.map(|id| function_pointer(&mut module, id)),
route_bioavailability.map(|id| function_pointer(&mut module, id)),
super::native::NativeKernels {
derive,
dynamics,
outputs,
init,
drift,
diffusion,
route_lag,
route_bioavailability,
},
module,
))
}

/// The eight kernel roles materialised by the JIT/AoT/WASM backends, in the
/// canonical order used by [`NativeKernels`].
const ROLES_IN_TABLE_ORDER: [KernelRole; 8] = [
KernelRole::Derive,
KernelRole::Dynamics,
KernelRole::Outputs,
KernelRole::Init,
KernelRole::Drift,
KernelRole::Diffusion,
KernelRole::RouteLag,
KernelRole::RouteBioavailability,
];

fn role_index(role: KernelRole) -> usize {
match role {
KernelRole::Derive => 0,
KernelRole::Dynamics => 1,
KernelRole::Outputs => 2,
KernelRole::Init => 3,
KernelRole::Drift => 4,
KernelRole::Diffusion => 5,
KernelRole::RouteLag => 6,
KernelRole::RouteBioavailability => 7,
KernelRole::Analytical => {
unreachable!("analytical kernels are not stored in the JIT table")
}
}
}

fn declare_externs(module: &mut JITModule, span: Span) -> Result<ExternIds, JitCompileError> {
let declare_unary = |module: &mut JITModule, symbol: &str| -> Result<_, JitCompileError> {
let mut signature = module.make_signature();
Expand Down Expand Up @@ -1519,7 +1505,7 @@ out(cp) = central / v ~ continuous()
let covariates = [70.0];
let routes = [0.0, 0.0];

let derive = artifact.derive.expect("derive kernel present");
let derive = artifact.kernels.derive.expect("derive kernel present");
unsafe {
derive(
0.0,
Expand All @@ -1530,7 +1516,7 @@ out(cp) = central / v ~ continuous()
derived.as_ptr(),
derived.as_mut_ptr(),
);
artifact.dynamics.expect("dynamics kernel present")(
artifact.kernels.dynamics.expect("dynamics kernel present")(
0.0,
states.as_ptr(),
params.as_ptr(),
Expand All @@ -1539,7 +1525,7 @@ out(cp) = central / v ~ continuous()
derived.as_ptr(),
dx.as_mut_ptr(),
);
(artifact.outputs)(
(artifact.kernels.outputs)(
0.0,
states.as_ptr(),
params.as_ptr(),
Expand Down
4 changes: 3 additions & 1 deletion src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ pub use jit::{
compile_ode_model_to_jit, compile_sde_model_to_jit, CompiledJitModel, JitAnalyticalModel,
JitCompileError, JitExecutionArtifact, JitOdeModel, JitSdeModel,
};
pub use model_info::{NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo};
pub use model_info::{
NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, NativeStateInfo,
};
#[cfg(any(
feature = "dsl-jit",
feature = "dsl-aot-load",
Expand Down
Loading
Loading