diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index e7de4118..fb224576 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -109,6 +109,65 @@ typedef enum mta_status_t { MTA_INTERNAL_ERROR = 255, } mta_status_t; +/** + * Section of the references stored in a `mta_model_metadata_t`. + */ +typedef enum mta_references_section_t { + /** + * References describing the model as a whole (e.g. the paper that + * introduced the model) + */ + MTA_REFERENCES_MODEL = 0, + /** + * References describing the model architecture (e.g. papers that + * describe the mathematical form) + */ + MTA_REFERENCES_ARCHITECTURE = 1, + /** + * References describing the model implementation (e.g. a link to the + * source-code repository) + */ + MTA_REFERENCES_IMPLEMENTATION = 2, +} mta_references_section_t; + +/** + * Data type for all inputs and outputs. + */ +typedef enum mta_dtype_t { + /** + * 32-bit floating point + */ + MTA_DTYPE_FLOAT32 = 0, + /** + * 64-bit floating point + */ + MTA_DTYPE_FLOAT64 = 1, +} mta_dtype_t; + +/** + * Device on which a model can run. + * + * Match DLpack dtypes + */ +typedef enum mta_device_t { + /** + * CPU device + */ + MTA_DEVICE_CPU = 0, + /** + * CUDA-capable NVIDIA GPU + */ + MTA_DEVICE_CUDA = 1, + /** + * ROCm-capable AMD GPU + */ + MTA_DEVICE_ROCM = 2, + /** + * Apple Metal GPU + */ + MTA_DEVICE_METAL = 3, +} mta_device_t; + /** * TODO */ @@ -119,6 +178,23 @@ typedef enum mta_system_data_kind { MTA_SYSTEM_DATA_PBC = 3, } mta_system_data_kind; +/** + * Opaque handle to model capabilities of a model: which outputs it can compute, which atomic types + * it supports, its interaction range, supported devices, and data type. + */ +typedef struct mta_model_capabilities_t mta_model_capabilities_t; + +/** + * Opaque handle to metadata describing a model: name, authors, description, references, and + * arbitrary extra key-value pairs. + */ +typedef struct mta_model_metadata_t mta_model_metadata_t; + +/** + * Opaque handle for a `PairListOptionsOptions` pair list (neighbor list) requested by a model. + */ +typedef struct mta_pair_list_options_t mta_pair_list_options_t; + /** * TODO */ @@ -403,6 +479,229 @@ enum mta_status_t mta_unit_conversion_factor(const char *from_unit, const char *to_unit, double *conversion); +/** + * Create a new `mta_pair_list_options_t` with an empty requestors list. + */ +enum mta_status_t mta_pair_list_options_create(double cutoff, + bool full_list, + bool strict, + struct mta_pair_list_options_t **options); + +/** + * Deserialize a `mta_pair_list_options_t` from a JSON string. + */ +enum mta_status_t mta_pair_list_options_from_json(const char *json, + struct mta_pair_list_options_t **options); + +/** + * Free a `mta_pair_list_options_t` object + */ +enum mta_status_t mta_pair_list_options_free(struct mta_pair_list_options_t *options); + +/** + * Serialize a `mta_pair_list_options_t` to a JSON string. + */ +enum mta_status_t mta_pair_list_options_to_json(const struct mta_pair_list_options_t *options, + mta_string_t *json); + +/** + * Get the type discriminator of a `mta_pair_list_options_t`. + */ +enum mta_status_t mta_pair_list_options_get_type(const struct mta_pair_list_options_t *options, + mta_string_t *type_); + +/** + * Get the cutoff radius from a `mta_pair_list_options_t`. + */ +enum mta_status_t mta_pair_list_options_get_cutoff(const struct mta_pair_list_options_t *options, + double *cutoff); + +/** + * Get the `full_list` flag from a `mta_pair_list_options_t`. + */ +enum mta_status_t mta_pair_list_options_get_full_list(const struct mta_pair_list_options_t *options, + bool *full_list); + +/** + * Get the `strict` flag from a `mta_pair_list_options_t`. + */ +enum mta_status_t mta_pair_list_options_get_strict(const struct mta_pair_list_options_t *options, + bool *strict); + +/** + * Get the number of requestors stored in a `mta_pair_list_options_t`. + */ +enum mta_status_t mta_pair_list_options_requestors_count(const struct mta_pair_list_options_t *options, + uintptr_t *count); + +/** + * Get a requestor string by index from a `mta_pair_list_options_t`. + */ +enum mta_status_t mta_pair_list_options_get_requestor(const struct mta_pair_list_options_t *options, + uintptr_t index, + mta_string_t *requestor); + +/** + * Add a requestor string to a `mta_pair_list_options_t`. + */ +enum mta_status_t mta_pair_list_options_add_requestor(struct mta_pair_list_options_t *options, + const char *requestor); + +/** + * Get `mta_model_metadata_t` from a JSON string. + */ +enum mta_status_t mta_model_metadata_from_json(const char *json, + struct mta_model_metadata_t **metadata); + +/** + * Free a `mta_model_metadata_t` previously created by any + * `mta_model_metadata_*` function. + */ +enum mta_status_t mta_model_metadata_free(struct mta_model_metadata_t *metadata); + +/** + * Serialize `mta_model_metadata_t` to a JSON string. + */ +enum mta_status_t mta_model_metadata_to_json(const struct mta_model_metadata_t *metadata, + mta_string_t *json); + +/** + * Get the name of a model from a `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_get_name(const struct mta_model_metadata_t *metadata, + mta_string_t *name); + +/** + * Get the description of a model from a `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_get_description(const struct mta_model_metadata_t *metadata, + mta_string_t *description); + +/** + * Get the number of authors of a model from a `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_authors_count(const struct mta_model_metadata_t *metadata, + uintptr_t *count); + +/** + * Get an author string by index from a `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_get_author(const struct mta_model_metadata_t *metadata, + uintptr_t index, + mta_string_t *author); + +/** + * Get the number of references in a section of a `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_references_count(const struct mta_model_metadata_t *metadata, + enum mta_references_section_t section, + uintptr_t *count); + +/** + * Get a reference string by index from a section of a `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_get_reference(const struct mta_model_metadata_t *metadata, + enum mta_references_section_t section, + uintptr_t index, + mta_string_t *reference); + +/** + * Get the number of entries in the `extra` key-value map of a + * `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_extra_count(const struct mta_model_metadata_t *metadata, + uintptr_t *count); + +/** + * Get an extra metadata key by position from a `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_get_extra_key(const struct mta_model_metadata_t *metadata, + uintptr_t index, + mta_string_t *key); + +/** + * Get an extra metadata value by key from a `mta_model_metadata_t`. + */ +enum mta_status_t mta_model_metadata_get_extra_value(const struct mta_model_metadata_t *metadata, + const char *key, + mta_string_t *value); + +/** + * Deserialize a `mta_model_capabilities_t` from a JSON string. + */ +enum mta_status_t mta_model_capabilities_from_json(const char *json, + struct mta_model_capabilities_t **capabilities); + +/** + * Free a `mta_model_capabilities_t` previously created by any + * `mta_model_capabilities_*` function. + */ +enum mta_status_t mta_model_capabilities_free(struct mta_model_capabilities_t *capabilities); + +/** + * Serialize a `mta_model_capabilities_t` to a JSON string. + */ +enum mta_status_t mta_model_capabilities_to_json(const struct mta_model_capabilities_t *capabilities, + mta_string_t *json); + +/** + * Get the interaction range of a model. + */ +enum mta_status_t mta_model_capabilities_get_interaction_range(const struct mta_model_capabilities_t *capabilities, + double *interaction_range); + +/** + * Get the length unit of a model. + */ +enum mta_status_t mta_model_capabilities_get_length_unit(const struct mta_model_capabilities_t *capabilities, + mta_string_t *length_unit); + +/** + * Get the data type of a model from a `mta_model_capabilities_t`. + */ +enum mta_status_t mta_model_capabilities_get_dtype(const struct mta_model_capabilities_t *capabilities, + enum mta_dtype_t *dtype); + +/** + * Get the number of outputs a model can compute from a + * `mta_model_capabilities_t`. + */ +enum mta_status_t mta_model_capabilities_outputs_count(const struct mta_model_capabilities_t *capabilities, + uintptr_t *count); + +/** + * Get a JSON-serialized `Quantity` by index from a `mta_model_capabilities_t`. + */ +enum mta_status_t mta_model_capabilities_get_output_json(const struct mta_model_capabilities_t *capabilities, + uintptr_t index, + mta_string_t *output_json); + +/** + * Get the number of supported atomic types. + */ +enum mta_status_t mta_model_capabilities_atomic_types_count(const struct mta_model_capabilities_t *capabilities, + uintptr_t *count); + +/** + * Get an atomic type by index. + */ +enum mta_status_t mta_model_capabilities_get_atomic_type(const struct mta_model_capabilities_t *capabilities, + uintptr_t index, + int64_t *atomic_type); + +/** + * Get the number of supported devices. + */ +enum mta_status_t mta_model_capabilities_supported_devices_count(const struct mta_model_capabilities_t *capabilities, + uintptr_t *count); + +/** + * Get a supported device by index. + */ +enum mta_status_t mta_model_capabilities_get_supported_device(const struct mta_model_capabilities_t *capabilities, + uintptr_t index, + enum mta_device_t *device); + /** * TODO */ diff --git a/metatomic-core/src/c_api/metadata.rs b/metatomic-core/src/c_api/metadata.rs new file mode 100644 index 00000000..9f18bced --- /dev/null +++ b/metatomic-core/src/c_api/metadata.rs @@ -0,0 +1,725 @@ +use std::ffi::{c_char, CStr}; + +use crate::{DType, Error, ModelCapabilities, ModelMetadata, PairListOptions}; +use crate::metadata::References; +use super::{catch_unwind, mta_status_t, mta_string_t}; + +/// Data type for all inputs and outputs. +#[allow(non_camel_case_types)] +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum mta_dtype_t { + /// 32-bit floating point + MTA_DTYPE_FLOAT32 = 0, + /// 64-bit floating point + MTA_DTYPE_FLOAT64 = 1, +} + +impl From for mta_dtype_t { + fn from(dtype: DType) -> Self { + match dtype { + DType::Float32 => mta_dtype_t::MTA_DTYPE_FLOAT32, + DType::Float64 => mta_dtype_t::MTA_DTYPE_FLOAT64, + } + } +} + +impl From for DType { + fn from(dtype: mta_dtype_t) -> Self { + match dtype { + mta_dtype_t::MTA_DTYPE_FLOAT32 => DType::Float32, + mta_dtype_t::MTA_DTYPE_FLOAT64 => DType::Float64, + } + } +} + +/// Device on which a model can run. +/// +/// Match DLpack dtypes +#[allow(non_camel_case_types)] +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum mta_device_t { + /// CPU device + MTA_DEVICE_CPU = 0, + /// CUDA-capable NVIDIA GPU + MTA_DEVICE_CUDA = 1, + /// ROCm-capable AMD GPU + MTA_DEVICE_ROCM = 2, + /// Apple Metal GPU + MTA_DEVICE_METAL = 3, +} + +/// Section of the references stored in a `mta_model_metadata_t`. +#[allow(non_camel_case_types)] +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum mta_references_section_t { + /// References describing the model as a whole (e.g. the paper that + /// introduced the model) + MTA_REFERENCES_MODEL = 0, + /// References describing the model architecture (e.g. papers that + /// describe the mathematical form) + MTA_REFERENCES_ARCHITECTURE = 1, + /// References describing the model implementation (e.g. a link to the + /// source-code repository) + MTA_REFERENCES_IMPLEMENTATION = 2, +} + +/// Opaque handle for a `PairListOptionsOptions` pair list (neighbor list) requested by a model. +#[allow(non_camel_case_types)] +pub struct mta_pair_list_options_t(PairListOptions); + +/// Create a new `mta_pair_list_options_t` with an empty requestors list. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_create( + cutoff: f64, + full_list: bool, + strict: bool, + options: *mut *mut mta_pair_list_options_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options); + + if !cutoff.is_finite() || cutoff <= 0.0 { + return Err(Error::InvalidParameter( + "cutoff must be a finite positive number".into(), + )); + } + + let inner = PairListOptions { + cutoff, + full_list, + strict, + requestors: Vec::new(), + }; + + *options = Box::into_raw(Box::new(mta_pair_list_options_t(inner))); + Ok(()) + }) +} + + +/// Deserialize a `mta_pair_list_options_t` from a JSON string. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_from_json( + json: *const c_char, + options: *mut *mut mta_pair_list_options_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(json); + + let s = CStr::from_ptr(json).to_str().map_err(|_| { + Error::InvalidParameter("json is not valid UTF-8".into()) + })?; + + let json_val = json::parse(s).map_err(|e| { + Error::Serialization(format!("invalid JSON: {e}")) + })?; + + let inner = PairListOptions::try_from(&json_val)?; + + *options = Box::into_raw(Box::new(mta_pair_list_options_t(inner))); + Ok(()) + }) +} + +/// Free a `mta_pair_list_options_t` object +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_free( + options: *mut mta_pair_list_options_t, +) -> mta_status_t { + catch_unwind(|| { + if !options.is_null() { + let _ = Box::from_raw(options); + } + Ok(()) + }) +} + +/// Serialize a `mta_pair_list_options_t` to a JSON string. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_to_json( + options: *const mta_pair_list_options_t, + json: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options, json); + + let json_val = json::JsonValue::from((*options).0.clone()); + *json = mta_string_t::new(json_val.dump()); + Ok(()) + }) +} + +/// Get the type discriminator of a `mta_pair_list_options_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_get_type( + options: *const mta_pair_list_options_t, + type_: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options, type_); + + *type_ = mta_string_t::new("metatomic_pair_options"); + Ok(()) + }) +} + +/// Get the cutoff radius from a `mta_pair_list_options_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_get_cutoff( + options: *const mta_pair_list_options_t, + cutoff: *mut f64, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options, cutoff); + *cutoff = (*options).0.cutoff; + Ok(()) + }) +} + +/// Get the `full_list` flag from a `mta_pair_list_options_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_get_full_list( + options: *const mta_pair_list_options_t, + full_list: *mut bool, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options, full_list); + *full_list = (*options).0.full_list; + Ok(()) + }) +} + +/// Get the `strict` flag from a `mta_pair_list_options_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_get_strict( + options: *const mta_pair_list_options_t, + strict: *mut bool, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options, strict); + *strict = (*options).0.strict; + Ok(()) + }) +} + +/// Get the number of requestors stored in a `mta_pair_list_options_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_requestors_count( + options: *const mta_pair_list_options_t, + count: *mut usize, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options, count); + *count = (*options).0.requestors.len(); + Ok(()) + }) +} + +/// Get a requestor string by index from a `mta_pair_list_options_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_get_requestor( + options: *const mta_pair_list_options_t, + index: usize, + requestor: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options, requestor); + + let requestors = &(*options).0.requestors; + if index >= requestors.len() { + return Err(Error::InvalidParameter(format!( + "requestor index {} is out of bounds, there are {} requestors", + index, + requestors.len() + ))); + } + + *requestor = mta_string_t::new(requestors[index].clone()); + Ok(()) + }) +} + +/// Add a requestor string to a `mta_pair_list_options_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_pair_list_options_add_requestor( + options: *mut mta_pair_list_options_t, + requestor: *const c_char, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(options, requestor); + + let s = CStr::from_ptr(requestor).to_str().map_err(|_| { + Error::InvalidParameter("requestor is not valid UTF-8".into()) + })?; + + if !s.is_empty() { + let requestors = &mut (*options).0.requestors; + if !requestors.iter().any(|r| r == s) { + requestors.push(s.to_string()); + } + } + + Ok(()) + }) +} + +/// Opaque handle to metadata describing a model: name, authors, description, references, and +/// arbitrary extra key-value pairs. +#[allow(non_camel_case_types)] +pub struct mta_model_metadata_t(ModelMetadata); + +/// Get `mta_model_metadata_t` from a JSON string. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_from_json( + json: *const c_char, + metadata: *mut *mut mta_model_metadata_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(json, metadata); + + let s = CStr::from_ptr(json).to_str().map_err(|_| { + Error::InvalidParameter("json is not valid UTF-8".into()) + })?; + + let json_val = json::parse(s).map_err(|e| { + Error::Serialization(format!("invalid JSON for ModelMetadata: {e}")) + })?; + + let inner = ModelMetadata::try_from(&json_val)?; + + *metadata = Box::into_raw(Box::new(mta_model_metadata_t(inner))); + Ok(()) + }) +} + +/// Free a `mta_model_metadata_t` previously created by any +/// `mta_model_metadata_*` function. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_free( + metadata: *mut mta_model_metadata_t, +) -> mta_status_t { + catch_unwind(|| { + if !metadata.is_null() { + let _ = Box::from_raw(metadata); + } + Ok(()) + }) +} + +/// Serialize `mta_model_metadata_t` to a JSON string. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_to_json( + metadata: *const mta_model_metadata_t, + json: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, json); + + let json_val = json::JsonValue::from((*metadata).0.clone()); + *json = mta_string_t::new(json_val.dump()); + Ok(()) + }) +} + +/// Get the name of a model from a `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_get_name( + metadata: *const mta_model_metadata_t, + name: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, name); + *name = mta_string_t::new((*metadata).0.name.clone()); + Ok(()) + }) +} + +/// Get the description of a model from a `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_get_description( + metadata: *const mta_model_metadata_t, + description: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, description); + *description = mta_string_t::new((*metadata).0.description.clone()); + Ok(()) + }) +} + +/// Get the number of authors of a model from a `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_authors_count( + metadata: *const mta_model_metadata_t, + count: *mut usize, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, count); + *count = (*metadata).0.authors.len(); + Ok(()) + }) +} + +/// Get an author string by index from a `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_get_author( + metadata: *const mta_model_metadata_t, + index: usize, + author: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, author); + + let authors = &(*metadata).0.authors; + if index >= authors.len() { + return Err(Error::InvalidParameter(format!( + "author index {} is out of bounds, there are {} authors", + index, + authors.len() + ))); + } + + *author = mta_string_t::new(authors[index].clone()); + Ok(()) + }) +} + +/// Get the number of references in a section of a `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_references_count( + metadata: *const mta_model_metadata_t, + section: mta_references_section_t, + count: *mut usize, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, count); + *count = references_section(&(*metadata).0.references, section).len(); + Ok(()) + }) +} + +/// Get a reference string by index from a section of a `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_get_reference( + metadata: *const mta_model_metadata_t, + section: mta_references_section_t, + index: usize, + reference: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, reference); + + let refs = references_section(&(*metadata).0.references, section); + if index >= refs.len() { + return Err(Error::InvalidParameter(format!( + "reference index {} is out of bounds, there are {} references in this section", + index, + refs.len() + ))); + } + + *reference = mta_string_t::new(refs[index].clone()); + Ok(()) + }) +} + +/// Get the number of entries in the `extra` key-value map of a +/// `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_extra_count( + metadata: *const mta_model_metadata_t, + count: *mut usize, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, count); + *count = (*metadata).0.extra.len(); + Ok(()) + }) +} + +/// Get an extra metadata key by position from a `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_get_extra_key( + metadata: *const mta_model_metadata_t, + index: usize, + key: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, key); + + let extra = &(*metadata).0.extra; + let k = extra.keys().nth(index).ok_or_else(|| { + Error::InvalidParameter(format!( + "extra key index {} is out of bounds, there are {} extra entries", + index, + extra.len() + )) + })?; + + *key = mta_string_t::new(k.clone()); + Ok(()) + }) +} + +/// Get an extra metadata value by key from a `mta_model_metadata_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_metadata_get_extra_value( + metadata: *const mta_model_metadata_t, + key: *const c_char, + value: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(metadata, key, value); + + let key_str = CStr::from_ptr(key).to_str().map_err(|_| { + Error::InvalidParameter("key is not valid UTF-8".into()) + })?; + + let v = (*metadata).0.extra.get(key_str).ok_or_else(|| { + Error::InvalidParameter(format!( + "key '{}' not found in extra metadata", + key_str + )) + })?; + + *value = mta_string_t::new(v.clone()); + Ok(()) + }) +} + +/// Opaque handle to model capabilities of a model: which outputs it can compute, which atomic types +/// it supports, its interaction range, supported devices, and data type. +#[allow(non_camel_case_types)] +pub struct mta_model_capabilities_t(ModelCapabilities); + +/// Deserialize a `mta_model_capabilities_t` from a JSON string. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_from_json( + json: *const c_char, + capabilities: *mut *mut mta_model_capabilities_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(json, capabilities); + + let s = CStr::from_ptr(json).to_str().map_err(|_| { + Error::InvalidParameter("json is not valid UTF-8".into()) + })?; + + let json_val = json::parse(s).map_err(|e| { + Error::Serialization(format!("invalid JSON for ModelCapabilities: {e}")) + })?; + + let inner = ModelCapabilities::try_from(&json_val)?; + + *capabilities = Box::into_raw(Box::new(mta_model_capabilities_t(inner))); + Ok(()) + }) +} + +/// Free a `mta_model_capabilities_t` previously created by any +/// `mta_model_capabilities_*` function. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_free( + capabilities: *mut mta_model_capabilities_t, +) -> mta_status_t { + catch_unwind(|| { + if !capabilities.is_null() { + let _ = Box::from_raw(capabilities); + } + Ok(()) + }) +} + +/// Serialize a `mta_model_capabilities_t` to a JSON string. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_to_json( + capabilities: *const mta_model_capabilities_t, + json: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, json); + + let json_val = json::JsonValue::from((*capabilities).0.clone()); + *json = mta_string_t::new(json_val.dump()); + Ok(()) + }) +} + +/// Get the interaction range of a model. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_get_interaction_range( + capabilities: *const mta_model_capabilities_t, + interaction_range: *mut f64, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, interaction_range); + *interaction_range = (*capabilities).0.interaction_range; + Ok(()) + }) +} + +/// Get the length unit of a model. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_get_length_unit( + capabilities: *const mta_model_capabilities_t, + length_unit: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, length_unit); + *length_unit = mta_string_t::new((*capabilities).0.length_unit.clone()); + Ok(()) + }) +} + +/// Get the data type of a model from a `mta_model_capabilities_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_get_dtype( + capabilities: *const mta_model_capabilities_t, + dtype: *mut mta_dtype_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, dtype); + *dtype = mta_dtype_t::from((*capabilities).0.dtype); + Ok(()) + }) +} + +/// Get the number of outputs a model can compute from a +/// `mta_model_capabilities_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_outputs_count( + capabilities: *const mta_model_capabilities_t, + count: *mut usize, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, count); + *count = (*capabilities).0.outputs.len(); + Ok(()) + }) +} + +/// Get a JSON-serialized `Quantity` by index from a `mta_model_capabilities_t`. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_get_output_json( + capabilities: *const mta_model_capabilities_t, + index: usize, + output_json: *mut mta_string_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, output_json); + + let outputs = &(*capabilities).0.outputs; + if index >= outputs.len() { + return Err(Error::InvalidParameter(format!( + "output index {} is out of bounds, there are {} outputs", + index, + outputs.len() + ))); + } + + let json_val = json::JsonValue::from(outputs[index].clone()); + *output_json = mta_string_t::new(json_val.dump()); + Ok(()) + }) +} + +/// Get the number of supported atomic types. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_atomic_types_count( + capabilities: *const mta_model_capabilities_t, + count: *mut usize, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, count); + *count = (*capabilities).0.atomic_types.len(); + Ok(()) + }) +} + +/// Get an atomic type by index. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_get_atomic_type( + capabilities: *const mta_model_capabilities_t, + index: usize, + atomic_type: *mut i64, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, atomic_type); + + let atomic_types = &(*capabilities).0.atomic_types; + if index >= atomic_types.len() { + return Err(Error::InvalidParameter(format!( + "atomic type index {} is out of bounds, there are {} atomic types", + index, + atomic_types.len() + ))); + } + + *atomic_type = atomic_types[index]; + Ok(()) + }) +} + +/// Get the number of supported devices. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_supported_devices_count( + capabilities: *const mta_model_capabilities_t, + count: *mut usize, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, count); + *count = (*capabilities).0.supported_devices.len(); + Ok(()) + }) +} + +/// Get a supported device by index. +#[no_mangle] +pub unsafe extern "C" fn mta_model_capabilities_get_supported_device( + capabilities: *const mta_model_capabilities_t, + index: usize, + device: *mut mta_device_t, +) -> mta_status_t { + catch_unwind(|| { + check_pointers_non_null!(capabilities, device); + + let devices = &(*capabilities).0.supported_devices; + if index >= devices.len() { + return Err(Error::InvalidParameter(format!( + "device index {} is out of bounds, there are {} supported devices", + index, + devices.len() + ))); + } + + let json_val = json::JsonValue::from(devices[index]); + let s = json_val.as_str().ok_or_else(|| { + Error::Serialization("Device JSON serialization did not produce a string".into()) + })?; + + let mta_dev = match s { + "cpu" => mta_device_t::MTA_DEVICE_CPU, + "cuda" => mta_device_t::MTA_DEVICE_CUDA, + "rocm" => mta_device_t::MTA_DEVICE_ROCM, + "metal" => mta_device_t::MTA_DEVICE_METAL, + _ => return Err(Error::InvalidParameter(format!( + "unknown device type '{}'", s + ))), + }; + + *device = mta_dev; + Ok(()) + }) +} + +/// Return reference section within a `References` struct based on the `mta_references_section_t` enum. +fn references_section(refs: &References, section: mta_references_section_t) -> &[String] { + match section { + mta_references_section_t::MTA_REFERENCES_MODEL => &refs.model, + mta_references_section_t::MTA_REFERENCES_ARCHITECTURE => &refs.architecture, + mta_references_section_t::MTA_REFERENCES_IMPLEMENTATION => &refs.implementation, + } +} diff --git a/metatomic-core/src/c_api/mod.rs b/metatomic-core/src/c_api/mod.rs index 235b0029..57d42135 100644 --- a/metatomic-core/src/c_api/mod.rs +++ b/metatomic-core/src/c_api/mod.rs @@ -8,6 +8,13 @@ mod utils; pub use self::utils::mta_string_t; pub use self::utils::{mta_string_create, mta_string_free, mta_string_view}; +mod metadata; +pub use self::metadata::{ + mta_references_section_t, + mta_pair_list_options_t, + mta_model_metadata_t, +}; + mod system; pub use self::system::mta_system_t; diff --git a/metatomic-core/src/metadata.rs b/metatomic-core/src/metadata.rs index 3bf1ce00..c197c8b8 100644 --- a/metatomic-core/src/metadata.rs +++ b/metatomic-core/src/metadata.rs @@ -134,13 +134,13 @@ impl<'a> TryFrom<&'a JsonValue> for PairListOptions { pub struct References { /// The references about the model as a whole, e.g. a paper describing the /// model or a website presenting it. - model: Vec, + pub(crate) model: Vec, /// The references about the architecture of the model, e.g. papers /// describing the mathematical form of the model. - architecture: Vec, + pub(crate) architecture: Vec, /// The references about the implementation of the model, e.g. a link to /// the source code repository or a paper describing the software. - implementation: Vec, + pub(crate) implementation: Vec, } impl From for JsonValue { diff --git a/metatomic-core/tests/CMakeLists.txt b/metatomic-core/tests/CMakeLists.txt index 77e251dd..f72e607a 100644 --- a/metatomic-core/tests/CMakeLists.txt +++ b/metatomic-core/tests/CMakeLists.txt @@ -2,19 +2,19 @@ cmake_minimum_required(VERSION 3.22) project(metatomic-tests) if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR}) - if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "") - message(STATUS "Setting build type to 'release' as none was specified.") - set(CMAKE_BUILD_TYPE "release" + if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "") + message(STATUS "Setting build type to 'release' as none was specified.") + set(CMAKE_BUILD_TYPE "release" CACHE STRING "Choose the type of build, options are: debug or release" FORCE) - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug) - endif() + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug) + endif() endif() if (MINGW) - # CI can't find libsdc++, so we statically link it - set(CMAKE_EXE_LINKER_FLAGS "-static-libstdc++") + # CI can't find libsdc++, so we statically link it + set(CMAKE_EXE_LINKER_FLAGS "-static-libstdc++") endif() add_subdirectory(../ metatomic) @@ -25,30 +25,30 @@ add_subdirectory(external) find_program(VALGRIND valgrind) if (VALGRIND) - if (NOT "$ENV{METATOMIC_DISABLE_VALGRIND}" EQUAL "1") - message(STATUS "Running tests using valgrind") - set(TEST_COMMAND + if (NOT "$ENV{METATOMIC_DISABLE_VALGRIND}" EQUAL "1") + message(STATUS "Running tests using valgrind") + set(TEST_COMMAND "${VALGRIND}" "--tool=memcheck" "--dsymutil=yes" "--error-exitcode=125" "--leak-check=full" "--show-leak-kinds=definite,indirect,possible" "--track-origins=yes" "--gen-suppressions=all" ) - endif() + endif() else() - set(TEST_COMMAND "") + set(TEST_COMMAND "") endif() if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Weverything") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++98-compat") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++98-compat-pedantic") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-equal") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-prototypes") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-shadow-uncaptured-local") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unsafe-buffer-usage") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-poison-system-directories") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-allocator-wrappers") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Weverything") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++98-compat") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++98-compat-pedantic") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-equal") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-prototypes") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-shadow-uncaptured-local") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unsafe-buffer-usage") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-poison-system-directories") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-allocator-wrappers") endif() @@ -57,11 +57,11 @@ add_subdirectory(test-plugins) file(GLOB ALL_TESTS *.cpp) foreach(_file_ ${ALL_TESTS}) - get_filename_component(_name_ ${_file_} NAME_WE) - add_executable(${_name_} ${_file_}) - target_link_libraries(${_name_} metatomic catch) + get_filename_component(_name_ ${_file_} NAME_WE) + add_executable(${_name_} ${_file_}) + target_link_libraries(${_name_} metatomic catch) - set_target_properties(${_name_} PROPERTIES + set_target_properties(${_name_} PROPERTIES # Ensure that the binaries find the right shared library. # # Without this, when configuring with cmake before the library is built, @@ -71,19 +71,19 @@ foreach(_file_ ${ALL_TESTS}) NO_SYSTEM_FROM_IMPORTED ON ) - target_compile_definitions(${_name_} PRIVATE PLUGIN_DIR="${CMAKE_CURRENT_BINARY_DIR}/test-plugins") + target_compile_definitions(${_name_} PRIVATE PLUGIN_DIR="${CMAKE_CURRENT_BINARY_DIR}/test-plugins") - add_test( + add_test( NAME ${_name_} COMMAND ${TEST_COMMAND} $ ) - if(WIN32) - # We need to set the path to allow access to metatomic.dll - # this does a similar job to the BUILD_RPATH above - STRING(REPLACE ";" "\\;" PATH_STRING "$ENV{PATH}") - set_tests_properties(${_name_} PROPERTIES + if(WIN32) + # We need to set the path to allow access to metatomic.dll + # this does a similar job to the BUILD_RPATH above + STRING(REPLACE ";" "\\;" PATH_STRING "$ENV{PATH}") + set_tests_properties(${_name_} PROPERTIES ENVIRONMENT "PATH=${PATH_STRING}\;$" ) - endif() + endif() endforeach() diff --git a/metatomic-core/tests/metadata.cpp b/metatomic-core/tests/metadata.cpp new file mode 100644 index 00000000..cace7713 --- /dev/null +++ b/metatomic-core/tests/metadata.cpp @@ -0,0 +1,449 @@ +#include + +#include "metatomic.h" +#include "metatomic.hpp" + +TEST_CASE("model metadata formatting") { + std::string json =R"({ + "type": "metatomic_model_metadata", + "name": "name", + "description": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation.", + "authors": ["Short author", "Some extremely long author that will take more than one line in the printed output"], + "references": { + "architecture": ["ref-2", "ref-3"], + "model": ["a very long reference that will take more than one line in the printed output"], + "implementation": [] + }, + "extra": {} +})"; + auto* mta_string = mta_string_create(""); + REQUIRE(mta_string != nullptr); + auto status = mta_format_metadata(json.c_str(), &mta_string); + REQUIRE(status == MTA_SUCCESS); + const auto* expected = R"(This is the name model +====================== + +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor +incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis +nostrud exercitation. + +Model authors +------------- + +- Short author +- Some extremely long author that will take more than one line in the printed + output + +Model references +---------------- + +Please cite the following references when using this model: +- about this specific model: + * a very long reference that will take more than one line in the printed + output +- about the architecture of this model: + * ref-2 + * ref-3 +)"; + CHECK(std::string(mta_string_view(mta_string)) == expected); + mta_string_free(mta_string); +} + +TEST_CASE("pair list metadata", "C API"){ + SECTION("JSON serialization"){ + mta_pair_list_options_t *metadata = nullptr; + auto status = mta_pair_list_options_create(0.42, true, true, &metadata); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(metadata != nullptr); + + mta_string_t json = nullptr; + status = mta_pair_list_options_to_json(metadata, &json); + REQUIRE(status == MTA_SUCCESS); + + // cutoff is 0.42 in double precision converted to hex + CHECK(std::string(mta_string_view(json)) == R"({"type":"metatomic_pair_options","cutoff":"0x3fdae147ae147ae1","full_list":true,"strict":true,"requestors":[]})"); + mta_string_free(json); + mta_pair_list_options_free(metadata); + } + SECTION("JSON deserialization"){ + const char* json = "{\"type\":\"metatomic_pair_options\",\"cutoff\":\"0x3fdae147ae147ae1\",\"full_list\":true,\"strict\":true,\"requestors\":[]}"; + mta_pair_list_options_t *options = nullptr; + auto status = mta_pair_list_options_from_json(json, &options); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(options != nullptr); + + mta_string_t type = nullptr; + status = mta_pair_list_options_get_type(options, &type); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(type)) == "metatomic_pair_options"); + mta_string_free(type); + + double cutoff = -1.0; + status = mta_pair_list_options_get_cutoff(options, &cutoff); + REQUIRE(status == MTA_SUCCESS); + CHECK(cutoff == Approx(0.42).epsilon(1e-15)); + + bool full_list = false; + status = mta_pair_list_options_get_full_list(options, &full_list); + REQUIRE(status == MTA_SUCCESS); + CHECK(full_list == true); + + bool strict = false; + status = mta_pair_list_options_get_strict(options, &strict); + REQUIRE(status == MTA_SUCCESS); + CHECK(strict == true); + + size_t num_requestors = 0; + status = mta_pair_list_options_requestors_count(options, &num_requestors); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_requestors == 0); + + mta_pair_list_options_free(options); + } + SECTION("JSON deserialization with wrong JSON"){ + // boolean values are erroneously stored as strings + const char* json = "{\"type\":\"metatomic_pair_options\",\"cutoff\":\"0x3fdae147ae147ae1\",\"full_list\":\"true\",\"strict\":\"true\",\"requestors\":[]}"; + + mta_pair_list_options_t *options = nullptr; + auto status = mta_pair_list_options_from_json(json, &options); + CHECK(status == MTA_SERIALIZATION_ERROR); + } + SECTION("requestors"){ + mta_pair_list_options_t *options = nullptr; + auto status = mta_pair_list_options_create(0.42, true, true, &options); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(options != nullptr); + + const char* requestor1 = "requestor1"; + status = mta_pair_list_options_add_requestor(options, requestor1); + REQUIRE(status == MTA_SUCCESS); + + const char* requestor2 = "requestor2"; + status = mta_pair_list_options_add_requestor(options, requestor2); + REQUIRE(status == MTA_SUCCESS); + + mta_string_t requestor = nullptr; + size_t num_requestors = 0; + status = mta_pair_list_options_requestors_count(options, &num_requestors); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_requestors == 2); + + status = mta_pair_list_options_get_requestor(options, 0, &requestor); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(requestor)) == requestor1); + mta_string_free(requestor); + + status = mta_pair_list_options_get_requestor(options, 1, &requestor); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(requestor)) == requestor2); + mta_string_free(requestor); + + mta_pair_list_options_free(options); + } +} + +TEST_CASE("model metadata", "C API"){ + const char* json = R"({ + "type": "metatomic_model_metadata", + "name": "model name", + "description": "model name is awesome", + "authors": ["Author One", "Author Two"], + "references": { + "architecture": ["reference one", "reference two", "refrerence three"], + "model": ["model reference"], + "implementation": [] + }, + "extra": { + "foo": "bar" + } + })"; + + SECTION("JSON deserialization"){ + mta_model_metadata_t *metadata = nullptr; + auto status = mta_model_metadata_from_json(json, &metadata); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(metadata != nullptr); + + mta_string_t name = nullptr; + status = mta_model_metadata_get_name(metadata, &name); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(name)) == "model name"); + mta_string_free(name); + + mta_string_t description = nullptr; + status = mta_model_metadata_get_description(metadata, &description); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(description)) == "model name is awesome"); + mta_string_free(description); + + size_t num_authors = 0; + status = mta_model_metadata_authors_count(metadata, &num_authors); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_authors == 2); + + mta_string_t author0 = nullptr; + status = mta_model_metadata_get_author(metadata, 0, &author0); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(author0)) == "Author One"); + mta_string_free(author0); + + mta_string_t author1 = nullptr; + status = mta_model_metadata_get_author(metadata, 1, &author1); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(author1)) == "Author Two"); + mta_string_free(author1); + + size_t num_arch_refs = 0; + status = mta_model_metadata_references_count(metadata, MTA_REFERENCES_ARCHITECTURE, &num_arch_refs); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_arch_refs == 3); + + size_t num_model_refs = 0; + status = mta_model_metadata_references_count(metadata, MTA_REFERENCES_MODEL, &num_model_refs); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_model_refs == 1); + + size_t num_impl_refs = 0; + status = mta_model_metadata_references_count(metadata, MTA_REFERENCES_IMPLEMENTATION, &num_impl_refs); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_impl_refs == 0); + + mta_string_t arch_ref0 = nullptr; + status = mta_model_metadata_get_reference(metadata, MTA_REFERENCES_ARCHITECTURE, 0, &arch_ref0); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(arch_ref0)) == "reference one"); + mta_string_free(arch_ref0); + + mta_string_t arch_ref2 = nullptr; + status = mta_model_metadata_get_reference(metadata, MTA_REFERENCES_ARCHITECTURE, 2, &arch_ref2); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(arch_ref2)) == "refrerence three"); + mta_string_free(arch_ref2); + + mta_string_t model_ref0 = nullptr; + status = mta_model_metadata_get_reference(metadata, MTA_REFERENCES_MODEL, 0, &model_ref0); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(model_ref0)) == "model reference"); + mta_string_free(model_ref0); + + size_t num_extra = 0; + status = mta_model_metadata_extra_count(metadata, &num_extra); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_extra == 1); + + mta_string_t extra_key0 = nullptr; + status = mta_model_metadata_get_extra_key(metadata, 0, &extra_key0); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(extra_key0)) == "foo"); + mta_string_free(extra_key0); + + mta_string_t extra_value0 = nullptr; + status = mta_model_metadata_get_extra_value(metadata, "foo", &extra_value0); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(extra_value0)) == "bar"); + mta_string_free(extra_value0); + + mta_model_metadata_free(metadata); + } + + SECTION("JSON serialization"){ + mta_model_metadata_t *metadata = nullptr; + auto status = mta_model_metadata_from_json(json, &metadata); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(metadata != nullptr); + + mta_string_t serialized = nullptr; + status = mta_model_metadata_to_json(metadata, &serialized); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(serialized != nullptr); + + CHECK(std::string(mta_string_view(serialized)) == + R"({"type":"metatomic_model_metadata","name":"model name","authors":["Author One","Author Two"],"description":"model name is awesome","references":{"model":["model reference"],"architecture":["reference one","reference two","refrerence three"],"implementation":[]},"extra":{"foo":"bar"}})"); + + mta_string_free(serialized); + mta_model_metadata_free(metadata); + } + + SECTION("Check out of bound requests"){ + mta_model_metadata_t *metadata = nullptr; + auto status = mta_model_metadata_from_json(json, &metadata); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(metadata != nullptr); + + mta_string_t author = nullptr; + status = mta_model_metadata_get_author(metadata, 2, &author); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + mta_string_t ref = nullptr; + status = mta_model_metadata_get_reference(metadata, MTA_REFERENCES_MODEL, 1, &ref); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + status = mta_model_metadata_get_reference(metadata, MTA_REFERENCES_IMPLEMENTATION, 0, &ref); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + mta_string_t extra_key = nullptr; + status = mta_model_metadata_get_extra_key(metadata, 1, &extra_key); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + mta_model_metadata_free(metadata); + } +} + +TEST_CASE("model capabilities", "C API"){ + const char* json = R"({ + "type": "metatomic_model_capabilities", + "outputs": [{ + "type": "metatomic_quantity", + "name": "energy", + "unit": "eV", + "description": "total energy", + "gradients": ["positions"], + "sample_kind": "system" + }, { + "type": "metatomic_quantity", + "name": "charge", + "unit": "e", + "gradients": [], + "sample_kind": "atom" + }], + "atomic_types": [1, 6, 8], + "interaction_range": 5.5, + "length_unit": "Angstrom", + "supported_devices": ["cpu", "cuda"], + "dtype": "float32" + })"; + + SECTION("JSON deserialization"){ + mta_model_capabilities_t *capabilities = nullptr; + auto status = mta_model_capabilities_from_json(json, &capabilities); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(capabilities != nullptr); + + double interaction_range = -1.0; + status = mta_model_capabilities_get_interaction_range(capabilities, &interaction_range); + REQUIRE(status == MTA_SUCCESS); + CHECK(interaction_range == Approx(5.5).epsilon(1e-15)); + + mta_string_t length_unit = nullptr; + status = mta_model_capabilities_get_length_unit(capabilities, &length_unit); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(length_unit)) == "Angstrom"); + mta_string_free(length_unit); + + mta_dtype_t dtype = MTA_DTYPE_FLOAT64; + status = mta_model_capabilities_get_dtype(capabilities, &dtype); + REQUIRE(status == MTA_SUCCESS); + CHECK(dtype == MTA_DTYPE_FLOAT32); + + size_t num_outputs = 0; + status = mta_model_capabilities_outputs_count(capabilities, &num_outputs); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_outputs == 2); + + mta_string_t output0 = nullptr; + status = mta_model_capabilities_get_output_json(capabilities, 0, &output0); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(output0)) == + R"({"type":"metatomic_quantity","name":"energy","unit":"eV","description":"total energy","gradients":["positions"],"sample_kind":"system"})"); + mta_string_free(output0); + + mta_string_t output1 = nullptr; + status = mta_model_capabilities_get_output_json(capabilities, 1, &output1); + REQUIRE(status == MTA_SUCCESS); + CHECK(std::string(mta_string_view(output1)) == + R"({"type":"metatomic_quantity","name":"charge","unit":"e","gradients":[],"sample_kind":"atom"})"); + mta_string_free(output1); + + size_t num_atomic_types = 0; + status = mta_model_capabilities_atomic_types_count(capabilities, &num_atomic_types); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_atomic_types == 3); + + int64_t atomic_type0 = -1; + status = mta_model_capabilities_get_atomic_type(capabilities, 0, &atomic_type0); + REQUIRE(status == MTA_SUCCESS); + CHECK(atomic_type0 == 1); + + int64_t atomic_type1 = -1; + status = mta_model_capabilities_get_atomic_type(capabilities, 1, &atomic_type1); + REQUIRE(status == MTA_SUCCESS); + CHECK(atomic_type1 == 6); + + int64_t atomic_type2 = -1; + status = mta_model_capabilities_get_atomic_type(capabilities, 2, &atomic_type2); + REQUIRE(status == MTA_SUCCESS); + CHECK(atomic_type2 == 8); + + size_t num_devices = 0; + status = mta_model_capabilities_supported_devices_count(capabilities, &num_devices); + REQUIRE(status == MTA_SUCCESS); + CHECK(num_devices == 2); + + mta_device_t device0 = MTA_DEVICE_CUDA; + status = mta_model_capabilities_get_supported_device(capabilities, 0, &device0); + REQUIRE(status == MTA_SUCCESS); + CHECK(device0 == MTA_DEVICE_CPU); + + mta_device_t device1 = MTA_DEVICE_CPU; + status = mta_model_capabilities_get_supported_device(capabilities, 1, &device1); + REQUIRE(status == MTA_SUCCESS); + CHECK(device1 == MTA_DEVICE_CUDA); + + mta_model_capabilities_free(capabilities); + } + + SECTION("JSON serialization"){ + mta_model_capabilities_t *capabilities = nullptr; + auto status = mta_model_capabilities_from_json(json, &capabilities); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(capabilities != nullptr); + + mta_string_t serialized = nullptr; + status = mta_model_capabilities_to_json(capabilities, &serialized); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(serialized != nullptr); + + CHECK(std::string(mta_string_view(serialized)) == + R"({"type":"metatomic_model_capabilities","outputs":[{"type":"metatomic_quantity","name":"energy","unit":"eV","description":"total energy","gradients":["positions"],"sample_kind":"system"},{"type":"metatomic_quantity","name":"charge","unit":"e","gradients":[],"sample_kind":"atom"}],"atomic_types":[1,6,8],"interaction_range":5.5,"length_unit":"Angstrom","supported_devices":["cpu","cuda"],"dtype":"float32"})"); + + mta_string_free(serialized); + mta_model_capabilities_free(capabilities); + } + + SECTION("Check out of bound requests"){ + mta_model_capabilities_t *capabilities = nullptr; + auto status = mta_model_capabilities_from_json(json, &capabilities); + REQUIRE(status == MTA_SUCCESS); + REQUIRE(capabilities != nullptr); + + mta_string_t output = nullptr; + status = mta_model_capabilities_get_output_json(capabilities, 2, &output); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + int64_t atomic_type = -1; + status = mta_model_capabilities_get_atomic_type(capabilities, 3, &atomic_type); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + mta_device_t device = MTA_DEVICE_CPU; + status = mta_model_capabilities_get_supported_device(capabilities, 2, &device); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + mta_model_capabilities_free(capabilities); + } + + SECTION("JSON deserialization with wrong type"){ + const char* wrong_json = R"({ + "type": "something-else", + "outputs": [], + "atomic_types": [], + "interaction_range": 0.0, + "length_unit": "Angstrom", + "supported_devices": ["cpu"], + "dtype": "float32" + })"; + + mta_model_capabilities_t *capabilities = nullptr; + auto status = mta_model_capabilities_from_json(wrong_json, &capabilities); + CHECK(status == MTA_SERIALIZATION_ERROR); + } +} diff --git a/metatomic-core/tests/misc.cpp b/metatomic-core/tests/misc.cpp index 0028f1b8..c014a369 100644 --- a/metatomic-core/tests/misc.cpp +++ b/metatomic-core/tests/misc.cpp @@ -92,50 +92,3 @@ TEST_CASE("unit conversion factor") { } } } - - -TEST_CASE("metatdata formatting") { - std::string json =R"({ - "type": "metatomic_model_metadata", - "name": "name", - "description": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation.", - "authors": ["Short author", "Some extremely long author that will take more than one line in the printed output"], - "references": { - "architecture": ["ref-2", "ref-3"], - "model": ["a very long reference that will take more than one line in the printed output"], - "implementation": [] - }, - "extra": {} -})"; - auto* mta_string = mta_string_create(""); - REQUIRE(mta_string != nullptr); - auto status = mta_format_metadata(json.c_str(), &mta_string); - REQUIRE(status == MTA_SUCCESS); - const auto* expected = R"(This is the name model -====================== - -Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor -incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis -nostrud exercitation. - -Model authors -------------- - -- Short author -- Some extremely long author that will take more than one line in the printed - output - -Model references ----------------- - -Please cite the following references when using this model: -- about this specific model: - * a very long reference that will take more than one line in the printed - output -- about the architecture of this model: - * ref-2 - * ref-3 -)"; - CHECK(std::string(mta_string_view(mta_string)) == expected); - mta_string_free(mta_string); -}