From da76c29c465ada8a8e17706809e9b9513070cbae Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 28 May 2026 16:38:52 +0200 Subject: [PATCH 1/5] Implement functions for C++ interface --- metatomic-core/include/metatomic/model.hpp | 225 ++++++++++++++++ metatomic-core/include/metatomic/plugin.hpp | 133 ++++++++++ metatomic-core/include/metatomic/system.hpp | 199 ++++++++++++++ metatomic-core/include/metatomic/utils.hpp | 280 ++++++++++++++++++++ 4 files changed, 837 insertions(+) diff --git a/metatomic-core/include/metatomic/model.hpp b/metatomic-core/include/metatomic/model.hpp index 1cae91bdf..5a4d6bce5 100644 --- a/metatomic-core/include/metatomic/model.hpp +++ b/metatomic-core/include/metatomic/model.hpp @@ -1,7 +1,232 @@ #pragma once +#include +#include +#include + #include +#include + +#include "./system.hpp" +#include "./utils.hpp" namespace metatomic { +/// RAII wrapper around a `mta_model_t`. +class Model final { +public: + /// Create an empty, invalid model. + Model() { + model_ = empty_model(); + } + + /// Take ownership of a raw `mta_model_t`. + explicit Model(mta_model_t model): model_(model) {} + + ~Model() { + this->reset_noexcept(); + } + + Model(const Model&) = delete; + Model& operator=(const Model&) = delete; + + Model(Model&& other) noexcept: Model() { + *this = std::move(other); + } + + Model& operator=(Model&& other) noexcept { + if (this != &other) { + this->reset_noexcept(); + model_ = other.model_; + other.model_ = empty_model(); + } + return *this; + } + + /// Does this wrapper contain a model? + bool is_valid() const { + return model_.data != nullptr; + } + + /// Unload the model. + void unload() { + if (model_.data != nullptr && model_.unload != nullptr) { + details::check_status(model_.unload(model_.data)); + } + model_ = empty_model(); + } + + /// Get model metadata as a JSON string. + std::string metadata() const { + this->check_callback(model_.metadata, "metadata"); + + mta_string_t metadata = nullptr; + details::check_status(model_.metadata(model_.data, &metadata)); + return String(metadata).str(); + } + + /// Get supported outputs as a JSON string. + std::string supported_outputs() const { + this->check_callback(model_.supported_outputs, "supported_outputs"); + + mta_string_t outputs = nullptr; + details::check_status(model_.supported_outputs(model_.data, &outputs)); + return String(outputs).str(); + } + + /// Get all pair lists requested by this model, each one serialized as JSON. + std::vector requested_pair_lists() const { + this->check_callback(model_.requested_pair_lists_count, "requested_pair_lists_count"); + this->check_callback(model_.requested_pair_list, "requested_pair_list"); + + uintptr_t count = 0; + details::check_status(model_.requested_pair_lists_count(model_.data, &count)); + + auto result = std::vector(); + result.reserve(count); + for (uintptr_t i=0; i requested_inputs() const { + this->check_callback(model_.requested_inputs_count, "requested_inputs_count"); + this->check_callback(model_.requested_input, "requested_input"); + + uintptr_t count = 0; + details::check_status(model_.requested_inputs_count(model_.data, &count)); + + auto result = std::vector(); + result.reserve(count); + for (uintptr_t i=0; i execute( + const std::vector& systems, + const metatensor::Labels* selected_atoms, + const std::vector& requested_outputs + ) { + this->check_valid(); + + auto c_systems = std::vector(); + c_systems.reserve(systems.size()); + for (const auto* system: systems) { + details::check_pointer(system); + c_systems.push_back(system->as_mta_system_t()); + } + + auto c_requested_outputs = std::vector(); + c_requested_outputs.reserve(requested_outputs.size()); + for (const auto& output: requested_outputs) { + c_requested_outputs.push_back(output.c_str()); + } + + auto raw_outputs = std::vector(requested_outputs.size(), nullptr); + details::check_status(mta_execute_model( + model_, + c_systems.data(), + c_systems.size(), + selected_atoms == nullptr ? nullptr : selected_atoms->as_mts_labels_t(), + c_requested_outputs.data(), + c_requested_outputs.size(), + raw_outputs.data(), + raw_outputs.size() + )); + + auto outputs = std::vector(); + outputs.reserve(raw_outputs.size()); + + try { + for (auto*& output: raw_outputs) { + details::check_pointer(output); + outputs.emplace_back(output); + output = nullptr; + } + } catch (...) { + for (auto* output: raw_outputs) { + if (output != nullptr) { + (void)mts_tensormap_free(output); + } + } + throw; + } + + return outputs; + } + + /// Execute this model on all atoms. + std::vector execute( + const std::vector& systems, + const std::vector& requested_outputs + ) { + return this->execute(systems, nullptr, requested_outputs); + } + + /// Get the underlying `mta_model_t`. + const mta_model_t& as_mta_model_t() const & { + return model_; + } + + const mta_model_t& as_mta_model_t() && = delete; + + /// Release the underlying `mta_model_t` without unloading it. + mta_model_t release() { + auto model = model_; + model_ = empty_model(); + return model; + } + +private: + static mta_model_t empty_model() { + mta_model_t model; + model.data = nullptr; + model.unload = nullptr; + model.metadata = nullptr; + model.supported_outputs = nullptr; + model.requested_pair_lists_count = nullptr; + model.requested_pair_list = nullptr; + model.requested_inputs_count = nullptr; + model.requested_input = nullptr; + model.execute_inner = nullptr; + return model; + } + + void reset_noexcept() noexcept { + if (model_.data != nullptr && model_.unload != nullptr) { + (void)model_.unload(model_.data); + } + model_ = empty_model(); + } + + void check_valid() const { + if (model_.data == nullptr) { + throw Error("can not use an empty metatomic::Model"); + } + } + + template + void check_callback(Callback callback, const char* name) const { + this->check_valid(); + if (callback == nullptr) { + throw Error("metatomic::Model does not implement " + std::string(name)); + } + } + + mta_model_t model_; +}; + } // namespace metatomic diff --git a/metatomic-core/include/metatomic/plugin.hpp b/metatomic-core/include/metatomic/plugin.hpp index 1cae91bdf..1b0e7ce8d 100644 --- a/metatomic-core/include/metatomic/plugin.hpp +++ b/metatomic-core/include/metatomic/plugin.hpp @@ -1,7 +1,140 @@ #pragma once +#include + +#include +#include +#include +#include + #include +#include "./model.hpp" +#include "./utils.hpp" + namespace metatomic { +/// Abstract base class for metatomic plugins implemented in C++. +class Plugin { +public: + virtual ~Plugin() = default; + + /// Name used to identify this plugin. + virtual std::string name() const = 0; + + /// Load a model from `load_from`, using the provided key/value options. + virtual Model load_model( + const std::string& load_from, + const std::vector& options + ) = 0; +}; + +namespace details { + template + struct PluginRegistration { + static PluginT* plugin; + static const char* name; + + static mta_status_t load_model( + const char* load_from, + const mta_kv_pair_t* options, + uintptr_t options_count, + mta_model_t* model + ) { + return details::catch_exceptions([&]() { + details::check_pointer(plugin); + details::check_pointer(model); + + auto loaded = plugin->load_model( + load_from == nullptr ? "" : load_from, + details::from_c_options(options, options_count) + ); + + *model = loaded.release(); + }); + } + }; + + template + PluginT* PluginRegistration::plugin = nullptr; + + template + const char* PluginRegistration::name = nullptr; +} // namespace details + +/// Register a C++ plugin. +/// +/// Due to the current C plugin ABI, this stores one plugin instance per concrete +/// C++ plugin type. The registered object must outlive all model-loading calls. +template +void register_plugin(PluginT& plugin) { + static_assert( + std::is_base_of::value, + "register_plugin expects a class derived from metatomic::Plugin" + ); + + details::PluginRegistration::plugin = &plugin; + const auto name = plugin.name(); + // The C plugin registry keeps this pointer; allocate stable process-lifetime storage. + auto* name_storage = new char[name.size() + 1]; + std::memcpy(name_storage, name.c_str(), name.size() + 1); + details::PluginRegistration::name = name_storage; + + auto c_plugin = mta_plugin_t{ + details::PluginRegistration::name, + &details::PluginRegistration::load_model, + }; + + mta_register_plugin(c_plugin); +} + +/// Register a raw C plugin. +inline void register_plugin(mta_plugin_t plugin) { + mta_register_plugin(plugin); +} + +/// Load a plugin dynamic library from the given path. +inline void load_plugin(const std::string& path) { + details::check_status(mta_load_plugin(path.c_str())); +} + +/// Load a model using the given plugin. +inline Model load_model( + const std::string& plugin_name, + const std::string& load_from, + const std::vector& options = {} +) { + auto c_options = details::to_c_options(options); + + auto model = mta_model_t{}; + details::check_status(mta_load_model( + plugin_name.c_str(), + load_from.c_str(), + c_options.data(), + c_options.size(), + &model + )); + + return Model(model); +} + +/// Load a model, letting metatomic pick the plugin. +inline Model load_model( + const std::string& load_from, + const std::vector& options = {} +) { + auto c_options = details::to_c_options(options); + + auto model = mta_model_t{}; + details::check_status(mta_load_model( + nullptr, + load_from.c_str(), + c_options.data(), + c_options.size(), + &model + )); + + return Model(model); +} + } // namespace metatomic diff --git a/metatomic-core/include/metatomic/system.hpp b/metatomic-core/include/metatomic/system.hpp index 1cae91bdf..5fde00f3d 100644 --- a/metatomic-core/include/metatomic/system.hpp +++ b/metatomic-core/include/metatomic/system.hpp @@ -1,7 +1,206 @@ #pragma once +#include +#include + #include +#include + +#include "./utils.hpp" namespace metatomic { +/// A System contains all the information about an atomistic system, and should +/// be used as input of atomistic models. +class System final { +public: + /// Create a new `System` from DLPack tensors. + /// + /// Ownership of all DLPack tensors is transferred to the C API. + System( + const std::string& length_unit, + DLManagedTensorVersioned* types, + DLManagedTensorVersioned* positions, + DLManagedTensorVersioned* cell, + DLManagedTensorVersioned* pbc + ): system_(nullptr) { + details::check_status(mta_system_create( + length_unit.c_str(), + types, + positions, + cell, + pbc, + &system_ + )); + details::check_pointer(system_); + } + + ~System() { + if (system_ != nullptr) { + (void)mta_system_free(system_); + } + } + + System(const System&) = delete; + System& operator=(const System&) = delete; + + System(System&& other) noexcept: system_(nullptr) { + *this = std::move(other); + } + + System& operator=(System&& other) noexcept { + if (system_ != nullptr) { + (void)mta_system_free(system_); + } + + system_ = other.system_; + other.system_ = nullptr; + return *this; + } + + /// Get the number of particles in this system. + size_t size() const { + uintptr_t result = 0; + details::check_status(mta_system_size(system_, &result)); + return static_cast(result); + } + + /// Get the length unit used by positions and cell. + std::string length_unit() const { + mta_string_t length_unit = nullptr; + details::check_status(mta_system_get_length_unit(system_, &length_unit)); + return String(length_unit).str(); + } + + /// Get particle types for all particles in the system. + DLPackTensor types() const { + return this->data(MTA_SYSTEM_DATA_TYPES); + } + + /// Get the positions for all particles in the system. + DLPackTensor positions() const { + return this->data(MTA_SYSTEM_DATA_POSITIONS); + } + + /// Get the unit cell/bounding box of the system. + DLPackTensor cell() const { + return this->data(MTA_SYSTEM_DATA_CELL); + } + + /// Get the periodic boundary conditions for the system. + DLPackTensor pbc() const { + return this->data(MTA_SYSTEM_DATA_PBC); + } + + /// Add a new pair list in this system. + /// + /// Ownership of `pairs` is transferred to the C API. + void set_pairs(const std::string& options, mts_block_t* pairs) { + details::check_status(mta_system_set_pairs(system_, options.c_str(), pairs)); + } + + /// Retrieve a previously stored pair list with the given options. + const mts_block_t* pairs_raw(const std::string& options) const { + const mts_block_t* pairs = nullptr; + details::check_status(mta_system_get_pairs(system_, options.c_str(), &pairs)); + details::check_pointer(pairs); + return pairs; + } + + /// Retrieve a previously stored pair list with the given options as a + /// non-owning metatensor view. + metatensor::TensorBlock pairs(const std::string& options) const { + return metatensor::TensorBlock::unsafe_view_from_ptr( + const_cast(this->pairs_raw(options)) + ); + } + + /// Get the options for all pair lists registered with this `System`. + std::vector pairs_options() const { + uintptr_t count = 0; + details::check_status(mta_system_pairs_count(system_, &count)); + + auto result = std::vector(); + result.reserve(count); + for (uintptr_t i=0; i data_names() const { + uintptr_t count = 0; + details::check_status(mta_system_data_count(system_, &count)); + + auto result = std::vector(); + result.reserve(count); + for (uintptr_t i=0; i +#include + +#include +#include +#include +#include +#include + #include namespace metatomic { +/// Exception class used for all errors in metatomic. +class Error: public std::runtime_error { +public: + /// Create a new Error with the given `message`. + explicit Error(const std::string& message): std::runtime_error(message) {} +}; + +/// Key/value pair used when loading models from plugins. +struct KeyValuePair { + std::string key; + std::string value; +}; + +/// RAII wrapper around a `DLManagedTensorVersioned*`. +/// +/// This owns the DLPack managed tensor object, and calls its deleter when the +/// wrapper is destroyed. +class DLPackTensor final { +public: + /// Create an empty wrapper. + DLPackTensor(): tensor_(nullptr) {} + + /// Take ownership of an existing DLPack managed tensor. + explicit DLPackTensor(DLManagedTensorVersioned* tensor): tensor_(tensor) {} + + ~DLPackTensor() { + if (tensor_ != nullptr && tensor_->deleter != nullptr) { + tensor_->deleter(tensor_); + } + } + + DLPackTensor(const DLPackTensor&) = delete; + DLPackTensor& operator=(const DLPackTensor&) = delete; + + DLPackTensor(DLPackTensor&& other) noexcept: DLPackTensor() { + *this = std::move(other); + } + + DLPackTensor& operator=(DLPackTensor&& other) noexcept { + if (tensor_ != nullptr && tensor_->deleter != nullptr) { + tensor_->deleter(tensor_); + } + + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + /// Check if this wrapper contains a tensor. + explicit operator bool() const { + return tensor_ != nullptr; + } + + /// Get the underlying DLPack managed tensor. + DLManagedTensorVersioned* get() const { + return tensor_; + } + + /// Get the underlying DLPack managed tensor. + DLManagedTensorVersioned* as_dlpack() const { + return tensor_; + } + + /// Release the DLPack managed tensor without calling its deleter. + DLManagedTensorVersioned* release() { + auto* tensor = tensor_; + tensor_ = nullptr; + return tensor; + } + +private: + DLManagedTensorVersioned* tensor_; +}; + +namespace details { + /// Check if a return status from the C API indicates an error, and throw a + /// `metatomic::Error` with the last error message if this is the case. + inline void check_status(mta_status_t status) { + if (status == MTA_SUCCESS) { + return; + } + + const char* message = nullptr; + const char* origin = nullptr; + void* data = nullptr; + (void)mta_last_error(&message, &origin, &data); + + if (origin != nullptr && std::strcmp(origin, "C++ exception") == 0 && data != nullptr) { + std::rethrow_exception(*static_cast(data)); + } + + throw Error(message == nullptr ? "unknown error" : message); + } + + /// Call the given `function`, catching any C++ exception and translating it + /// to a metatomic error code. + /// + /// This is required to prevent callbacks unwinding through the C API. + template + inline mta_status_t catch_exceptions(Function function, Args ...args) { + try { + function(std::move(args)...); + return MTA_SUCCESS; + } catch (...) { + auto* exception_ptr = new std::exception_ptr(std::current_exception()); + + const char* message = nullptr; + try { + std::rethrow_exception(*exception_ptr); + } catch (const std::exception& e) { + message = e.what(); + } catch (...) { + message = "C++ code threw an exception that was not a std::exception"; + } + + auto status = mta_set_last_error( + message, + "C++ exception", + exception_ptr, + [](void* ptr) { delete static_cast(ptr); } + ); + + if (status != MTA_SUCCESS) { + std::fprintf( + stderr, + "INTERNAL ERROR: unable to set last error after C++ callback failure (status: %d). ", + static_cast(status) + ); + if (message != nullptr) { + std::fprintf(stderr, "C++ error was: %s\n", message); + } else { + std::fprintf(stderr, "Unknown C++ error\n"); + } + delete exception_ptr; + } + + return MTA_ERROR_OTHER; + } + } + + /// Check if a pointer allocated by the C API is null. + inline void check_pointer(const void* pointer) { + if (pointer != nullptr) { + return; + } + + const char* message = nullptr; + const char* origin = nullptr; + void* data = nullptr; + (void)mta_last_error(&message, &origin, &data); + + if (origin != nullptr && std::strcmp(origin, "C++ exception") == 0 && data != nullptr) { + std::rethrow_exception(*static_cast(data)); + } + + throw Error(message == nullptr ? "received a null pointer from the metatomic C API" : message); + } + + inline std::vector to_c_options(const std::vector& options) { + auto c_options = std::vector(); + c_options.reserve(options.size()); + + for (const auto& option: options) { + c_options.push_back(mta_kv_pair_t{option.key.c_str(), option.value.c_str()}); + } + + return c_options; + } + + inline std::vector from_c_options(const mta_kv_pair_t* options, uintptr_t count) { + auto result = std::vector(); + result.reserve(count); + + if (count != 0) { + check_pointer(options); + } + + for (uintptr_t i=0; ic_str()); + } + +private: + mta_string_t string_; +}; + +/// Get the runtime version of metatomic as a string. +inline std::string version() { + auto* raw = mta_version(); + details::check_pointer(raw); + return std::string(raw); +} + +/// Get the conversion factor from `from_unit` to `to_unit`. +inline double unit_conversion_factor(const std::string& from_unit, const std::string& to_unit) { + double conversion = 0.0; + details::check_status(mta_unit_conversion_factor(from_unit.c_str(), to_unit.c_str(), &conversion)); + return conversion; +} + +/// Format model metadata JSON for display. +inline std::string format_metadata(const std::string& metadata) { + mta_string_t printed = nullptr; + details::check_status(mta_format_metadata(metadata.c_str(), &printed)); + return String(printed).str(); +} + } // namespace metatomic From 398368c8047af9c6724dcc152c04bc1ccfafeece Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 29 May 2026 07:41:03 +0200 Subject: [PATCH 2/5] Implement review suggestions --- metatomic-core/CMakeLists.txt | 16 ++ .../cmake/metatomic-config.in.cmake | 14 +- metatomic-core/include/metatomic.hpp | 1 + metatomic-core/include/metatomic/metadata.hpp | 192 +++++++++++++++ metatomic-core/include/metatomic/model.hpp | 227 +++++++++++++++++- metatomic-core/include/metatomic/plugin.hpp | 70 +++--- metatomic-core/include/metatomic/system.hpp | 20 +- 7 files changed, 494 insertions(+), 46 deletions(-) create mode 100644 metatomic-core/include/metatomic/metadata.hpp diff --git a/metatomic-core/CMakeLists.txt b/metatomic-core/CMakeLists.txt index 717e52e81..5cb30446a 100644 --- a/metatomic-core/CMakeLists.txt +++ b/metatomic-core/CMakeLists.txt @@ -5,6 +5,11 @@ # an easier to use, idiomatic Rust API. cmake_minimum_required(VERSION 3.22) +if (POLICY CMP0135) + # Use download time as timestamp when extracting files from archives. + cmake_policy(SET CMP0135 NEW) +endif() + # Is metatomic the main project configured by the user? Or is this being used # as a submodule/subdirectory? if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR}) @@ -449,6 +454,17 @@ else() target_link_libraries(metatomic::shared INTERFACE metatensor) endif() +include(FetchContent) + +# JSON library from https://github.com/nlohmann/json +FetchContent_Declare(nlohmann_json + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz + URL_HASH SHA256=d6c65aca6b1ed68e7a182f4757257b107ae403032760ed6ef121c9d55e81757d +) +FetchContent_MakeAvailable(nlohmann_json) + +target_link_libraries(metatomic::shared INTERFACE nlohmann_json::nlohmann_json) +target_link_libraries(metatomic::static INTERFACE nlohmann_json::nlohmann_json) if (BUILD_SHARED_LIBS) add_library(metatomic ALIAS metatomic::shared) diff --git a/metatomic-core/cmake/metatomic-config.in.cmake b/metatomic-core/cmake/metatomic-config.in.cmake index 90fca167a..8f2fc5443 100644 --- a/metatomic-core/cmake/metatomic-config.in.cmake +++ b/metatomic-core/cmake/metatomic-config.in.cmake @@ -4,6 +4,7 @@ cmake_minimum_required(VERSION 3.22) include(CMakeFindDependencyMacro) include(FindPackageHandleStandardArgs) +include(FetchContent) if(metatomic_FOUND) return() @@ -15,6 +16,15 @@ enable_language(CXX) set(REQUIRED_METATENSOR_VERSION @REQUIRED_METATENSOR_VERSION@) find_package(metatensor ${REQUIRED_METATENSOR_VERSION} CONFIG REQUIRED) +if (NOT TARGET nlohmann_json::nlohmann_json) + # JSON library from https://github.com/nlohmann/json + FetchContent_Declare(nlohmann_json + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz + URL_HASH SHA256=d6c65aca6b1ed68e7a182f4757257b107ae403032760ed6ef121c9d55e81757d + ) + FetchContent_MakeAvailable(nlohmann_json) +endif() + get_filename_component(METATOMIC_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/@PACKAGE_RELATIVE_PATH@" ABSOLUTE) if (WIN32) @@ -46,7 +56,7 @@ if (@METATOMIC_INSTALL_BOTH_STATIC_SHARED@ OR @BUILD_SHARED_LIBS@) ) target_compile_features(metatomic::shared INTERFACE cxx_std_17) - target_link_libraries(metatomic::shared INTERFACE metatensor) + target_link_libraries(metatomic::shared INTERFACE metatensor nlohmann_json::nlohmann_json) if (WIN32) if (NOT EXISTS ${METATOMIC_IMPLIB_LOCATION}) @@ -75,7 +85,7 @@ if (@METATOMIC_INSTALL_BOTH_STATIC_SHARED@ OR NOT @BUILD_SHARED_LIBS@) ) target_compile_features(metatomic::static INTERFACE cxx_std_17) - target_link_libraries(metatomic::static INTERFACE metatensor) + target_link_libraries(metatomic::static INTERFACE metatensor nlohmann_json::nlohmann_json) endif() # Export either the shared or static library as the metatomic target diff --git a/metatomic-core/include/metatomic.hpp b/metatomic-core/include/metatomic.hpp index 3b5c8ac2a..9eaad1e57 100644 --- a/metatomic-core/include/metatomic.hpp +++ b/metatomic-core/include/metatomic.hpp @@ -1,4 +1,5 @@ #include "metatomic/utils.hpp" // IWYU pragma: export +#include "metatomic/metadata.hpp" // IWYU pragma: export #include "metatomic/system.hpp" // IWYU pragma: export #include "metatomic/model.hpp" // IWYU pragma: export #include "metatomic/plugin.hpp" // IWYU pragma: export diff --git a/metatomic-core/include/metatomic/metadata.hpp b/metatomic-core/include/metatomic/metadata.hpp new file mode 100644 index 000000000..35a92c75d --- /dev/null +++ b/metatomic-core/include/metatomic/metadata.hpp @@ -0,0 +1,192 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace metatomic { + +/// Options for the calculation of a pair list. +class PairListOptions final { +public: + PairListOptions() = default; + + PairListOptions(double cutoff_value, bool full_list_value, bool strict_value, std::vector requestors_list = {}): + cutoff(cutoff_value), + full_list(full_list_value), + strict(strict_value), + requestors(std::move(requestors_list)) + {} + + double cutoff = 0.0; + bool full_list = false; + bool strict = false; + std::vector requestors; + + std::string to_json() const; + static PairListOptions from_json(const std::string& json); +}; + +/// Metadata about a specific exported model. +class ModelMetadata final { +public: + ModelMetadata() = default; + + ModelMetadata( + std::string model_name, + std::string model_description, + std::vector model_authors, + std::map> model_references = {}, + std::map extra_metadata = {} + ): + name(std::move(model_name)), + description(std::move(model_description)), + authors(std::move(model_authors)), + references(std::move(model_references)), + extra(std::move(extra_metadata)) + {} + + std::string name; + std::string description; + std::vector authors; + std::map> references; + std::map extra; + + std::string to_json() const; + static ModelMetadata from_json(const std::string& json); +}; + +/// Description of a quantity used as model input or output. +class Quantity final { +public: + Quantity() = default; + + Quantity( + std::string quantity_name, + std::string quantity_unit, + std::vector quantity_gradients, + std::string quantity_sample_kind + ): + name(std::move(quantity_name)), + unit(std::move(quantity_unit)), + gradients(std::move(quantity_gradients)), + sample_kind(std::move(quantity_sample_kind)) + {} + + std::string name; + std::string unit; + std::vector gradients; + std::string sample_kind; + + std::string to_json() const; + static Quantity from_json(const std::string& json); +}; + +namespace details { + inline std::string double_to_hex(double value) { + uint64_t bits = 0; + static_assert(sizeof(bits) == sizeof(value), "unexpected double size"); + std::memcpy(&bits, &value, sizeof(bits)); + + auto stream = std::ostringstream(); + stream << "0x" << std::hex << bits; + return stream.str(); + } + + inline double hex_to_double(const std::string& value) { + auto bits = uint64_t(0); + auto stream = std::istringstream(value); + if (value.rfind("0x", 0) == 0 || value.rfind("0X", 0) == 0) { + stream.seekg(2); + } + stream >> std::hex >> bits; + + auto result = 0.0; + static_assert(sizeof(bits) == sizeof(result), "unexpected double size"); + std::memcpy(&result, &bits, sizeof(result)); + return result; + } + +} // namespace details + +inline std::string PairListOptions::to_json() const { + return nlohmann::json{ + {"type", "metatomic_pair_options"}, + {"cutoff", details::double_to_hex(cutoff)}, + {"full_list", full_list}, + {"strict", strict}, + {"requestors", requestors}, + }.dump(); +} + +inline PairListOptions PairListOptions::from_json(const std::string& string) { + auto json = nlohmann::json::parse(string); + auto options = PairListOptions(); + + auto cutoff = std::string("0x0"); + if (json.contains("cutoff")) { + cutoff = json.at("cutoff").get(); + } + + options.cutoff = details::hex_to_double(cutoff); + options.full_list = json.value("full_list", false); + options.strict = json.value("strict", false); + options.requestors = json.value("requestors", std::vector{}); + + return options; +} + +inline std::string ModelMetadata::to_json() const { + return nlohmann::json{ + {"type", "metatomic_model_metadata"}, + {"name", name}, + {"description", description}, + {"authors", authors}, + {"references", references}, + {"extra", extra}, + }.dump(); +} + +inline ModelMetadata ModelMetadata::from_json(const std::string& string) { + auto json = nlohmann::json::parse(string); + auto metadata = ModelMetadata(); + + metadata.name = json.value("name", ""); + metadata.description = json.value("description", ""); + metadata.authors = json.value("authors", std::vector{}); + metadata.references = json.value("references", std::map>{}); + metadata.extra = json.value("extra", std::map{}); + + return metadata; +} + +inline std::string Quantity::to_json() const { + return nlohmann::json{ + {"type", "metatomic_quantity"}, + {"name", name}, + {"unit", unit}, + {"gradients", gradients}, + {"sample_kind", sample_kind}, + }.dump(); +} + +inline Quantity Quantity::from_json(const std::string& string) { + auto json = nlohmann::json::parse(string); + auto quantity = Quantity(); + + quantity.name = json.value("name", ""); + quantity.unit = json.value("unit", ""); + quantity.gradients = json.value("gradients", std::vector{}); + quantity.sample_kind = json.value("sample_kind", ""); + + return quantity; +} + +} // namespace metatomic diff --git a/metatomic-core/include/metatomic/model.hpp b/metatomic-core/include/metatomic/model.hpp index 5a4d6bce5..ea0592dc8 100644 --- a/metatomic-core/include/metatomic/model.hpp +++ b/metatomic-core/include/metatomic/model.hpp @@ -1,17 +1,48 @@ #pragma once #include +#include #include #include #include #include +#include "./metadata.hpp" #include "./system.hpp" #include "./utils.hpp" namespace metatomic { +/// Abstract base class for atomistic models implemented in C++. +class ModelBase { +public: + virtual ~ModelBase() = default; + + /// Get metadata about this model. + virtual ModelMetadata metadata() const = 0; + + /// Get all quantities this model can compute. + virtual std::vector supported_outputs() const = 0; + + /// Get all pair lists this model requires. + virtual std::vector requested_pair_lists() const { + return {}; + } + + /// Get all custom inputs this model requires. + virtual std::vector requested_inputs() const { + return {}; + } + + /// Execute this model. + virtual std::vector execute( + const std::vector& systems, + const mts_labels_t* selected_atoms, + const std::vector& requested_outputs + ) = 0; +}; + /// RAII wrapper around a `mta_model_t`. class Model final { public: @@ -23,6 +54,24 @@ class Model final { /// Take ownership of a raw `mta_model_t`. explicit Model(mta_model_t model): model_(model) {} + /// Create a C API model wrapping a C++ model implementation. + explicit Model(std::unique_ptr model) { + if (model == nullptr) { + throw Error("can not create a metatomic::Model from a null ModelBase"); + } + + model_ = empty_model(); + model_.data = model.release(); + model_.unload = &Model::unload_callback; + model_.metadata = &Model::metadata_callback; + model_.supported_outputs = &Model::supported_outputs_callback; + model_.requested_pair_lists_count = &Model::requested_pair_lists_count_callback; + model_.requested_pair_list = &Model::requested_pair_list_callback; + model_.requested_inputs_count = &Model::requested_inputs_count_callback; + model_.requested_input = &Model::requested_input_callback; + model_.execute_inner = &Model::execute_callback; + } + ~Model() { this->reset_noexcept(); } @@ -56,8 +105,8 @@ class Model final { model_ = empty_model(); } - /// Get model metadata as a JSON string. - std::string metadata() const { + /// Get model metadata serialized as JSON. + std::string metadata_json() const { this->check_callback(model_.metadata, "metadata"); mta_string_t metadata = nullptr; @@ -65,8 +114,13 @@ class Model final { return String(metadata).str(); } - /// Get supported outputs as a JSON string. - std::string supported_outputs() const { + /// Get model metadata. + ModelMetadata metadata() const { + return ModelMetadata::from_json(this->metadata_json()); + } + + /// Get supported outputs serialized as JSON. + std::string supported_outputs_json() const { this->check_callback(model_.supported_outputs, "supported_outputs"); mta_string_t outputs = nullptr; @@ -74,8 +128,17 @@ class Model final { return String(outputs).str(); } + /// Get all quantities this model can compute. + std::vector supported_outputs() const { + auto outputs = std::vector(); + for (const auto& output: nlohmann::json::parse(this->supported_outputs_json())) { + outputs.push_back(Quantity::from_json(output.dump())); + } + return outputs; + } + /// Get all pair lists requested by this model, each one serialized as JSON. - std::vector requested_pair_lists() const { + std::vector requested_pair_lists_json() const { this->check_callback(model_.requested_pair_lists_count, "requested_pair_lists_count"); this->check_callback(model_.requested_pair_list, "requested_pair_list"); @@ -93,8 +156,17 @@ class Model final { return result; } + /// Get all pair lists requested by this model. + std::vector requested_pair_lists() const { + auto result = std::vector(); + for (const auto& options: this->requested_pair_lists_json()) { + result.push_back(PairListOptions::from_json(options)); + } + return result; + } + /// Get all custom inputs requested by this model, each one serialized as JSON. - std::vector requested_inputs() const { + std::vector requested_inputs_json() const { this->check_callback(model_.requested_inputs_count, "requested_inputs_count"); this->check_callback(model_.requested_input, "requested_input"); @@ -112,13 +184,22 @@ class Model final { return result; } + /// Get all custom inputs requested by this model. + std::vector requested_inputs() const { + auto result = std::vector(); + for (const auto& input: this->requested_inputs_json()) { + result.push_back(Quantity::from_json(input)); + } + return result; + } + /// Execute this model. /// /// The number of returned tensor maps is `requested_outputs.size()`. std::vector execute( const std::vector& systems, const metatensor::Labels* selected_atoms, - const std::vector& requested_outputs + const std::vector& requested_outputs ) { this->check_valid(); @@ -130,9 +211,12 @@ class Model final { } auto c_requested_outputs = std::vector(); + auto requested_outputs_json = std::vector(); + requested_outputs_json.reserve(requested_outputs.size()); c_requested_outputs.reserve(requested_outputs.size()); for (const auto& output: requested_outputs) { - c_requested_outputs.push_back(output.c_str()); + requested_outputs_json.push_back(output.to_json()); + c_requested_outputs.push_back(requested_outputs_json.back().c_str()); } auto raw_outputs = std::vector(requested_outputs.size(), nullptr); @@ -171,7 +255,7 @@ class Model final { /// Execute this model on all atoms. std::vector execute( const std::vector& systems, - const std::vector& requested_outputs + const std::vector& requested_outputs ) { return this->execute(systems, nullptr, requested_outputs); } @@ -205,6 +289,131 @@ class Model final { return model; } + static ModelBase* model_base(const void* data) { + details::check_pointer(data); + return static_cast(const_cast(data)); + } + + static mta_status_t unload_callback(void* data) { + return details::catch_exceptions([&]() { + delete model_base(data); + }); + } + + static mta_status_t metadata_callback(const void* data, mta_string_t* metadata_json) { + return details::catch_exceptions([&]() { + details::check_pointer(metadata_json); + *metadata_json = mta_string_create(model_base(data)->metadata().to_json().c_str()); + details::check_pointer(*metadata_json); + }); + } + + static mta_status_t supported_outputs_callback(const void* data, mta_string_t* outputs_json) { + return details::catch_exceptions([&]() { + details::check_pointer(outputs_json); + auto outputs = nlohmann::json::array(); + for (const auto& output: model_base(data)->supported_outputs()) { + outputs.push_back(nlohmann::json::parse(output.to_json())); + } + + *outputs_json = mta_string_create(outputs.dump().c_str()); + details::check_pointer(*outputs_json); + }); + } + + static mta_status_t requested_pair_lists_count_callback(const void* data, uintptr_t* count) { + return details::catch_exceptions([&]() { + details::check_pointer(count); + *count = model_base(data)->requested_pair_lists().size(); + }); + } + + static mta_status_t requested_pair_list_callback(const void* data, uintptr_t index, mta_string_t* pair_options_json) { + return details::catch_exceptions([&]() { + details::check_pointer(pair_options_json); + auto options = model_base(data)->requested_pair_lists(); + if (index >= options.size()) { + throw Error("pair list request index out of bounds"); + } + *pair_options_json = mta_string_create(options[index].to_json().c_str()); + details::check_pointer(*pair_options_json); + }); + } + + static mta_status_t requested_inputs_count_callback(const void* data, uintptr_t* count) { + return details::catch_exceptions([&]() { + details::check_pointer(count); + *count = model_base(data)->requested_inputs().size(); + }); + } + + static mta_status_t requested_input_callback(const void* data, uintptr_t index, mta_string_t* input_json) { + return details::catch_exceptions([&]() { + details::check_pointer(input_json); + auto inputs = model_base(data)->requested_inputs(); + if (index >= inputs.size()) { + throw Error("input request index out of bounds"); + } + *input_json = mta_string_create(inputs[index].to_json().c_str()); + details::check_pointer(*input_json); + }); + } + + static mta_status_t execute_callback( + void* data, + const mta_system_t* const* systems, + uintptr_t systems_count, + const mts_labels_t* selected_atoms, + const char* const* requested_outputs_json, + uintptr_t requested_outputs_count, + mts_tensormap_t** outputs, + uintptr_t outputs_count + ) { + return details::catch_exceptions([&]() { + if (systems_count != 0) { + details::check_pointer(systems); + } + if (requested_outputs_count != 0) { + details::check_pointer(requested_outputs_json); + } + if (outputs_count != 0) { + details::check_pointer(outputs); + } + if (requested_outputs_count != outputs_count) { + throw Error("expected one output storage slot for each requested output"); + } + + auto system_views = std::vector(); + system_views.reserve(systems_count); + for (uintptr_t i=0; i(); + cxx_systems.reserve(system_views.size()); + for (const auto& system: system_views) { + cxx_systems.push_back(&system); + } + + auto requested_outputs = std::vector(); + requested_outputs.reserve(requested_outputs_count); + for (uintptr_t i=0; iexecute(cxx_systems, selected_atoms, requested_outputs); + if (result.size() != outputs_count) { + throw Error("model returned the wrong number of outputs"); + } + + for (uintptr_t i=0; i - #include #include #include @@ -25,15 +23,47 @@ class Plugin { /// Load a model from `load_from`, using the provided key/value options. virtual Model load_model( const std::string& load_from, - const std::vector& options + const std::vector& options = {} ) = 0; }; +/// Handle to a plugin registered in metatomic's global plugin registry. +class PluginHandle final { +public: + explicit PluginHandle(std::string name): name_(std::move(name)) {} + + /// Name used to identify this plugin. + const std::string& name() const { + return name_; + } + + /// Load a model from `load_from`, using the provided key/value options. + Model load_model( + const std::string& load_from, + const std::vector& options = {} + ) const { + auto c_options = details::to_c_options(options); + + auto model = mta_model_t{}; + details::check_status(mta_load_model( + name_.c_str(), + load_from.c_str(), + c_options.data(), + c_options.size(), + &model + )); + + return Model(model); + } + +private: + std::string name_; +}; + namespace details { template struct PluginRegistration { static PluginT* plugin; - static const char* name; static mta_status_t load_model( const char* load_from, @@ -57,9 +87,6 @@ namespace details { template PluginT* PluginRegistration::plugin = nullptr; - - template - const char* PluginRegistration::name = nullptr; } // namespace details /// Register a C++ plugin. @@ -75,47 +102,32 @@ void register_plugin(PluginT& plugin) { details::PluginRegistration::plugin = &plugin; const auto name = plugin.name(); - // The C plugin registry keeps this pointer; allocate stable process-lifetime storage. - auto* name_storage = new char[name.size() + 1]; - std::memcpy(name_storage, name.c_str(), name.size() + 1); - details::PluginRegistration::name = name_storage; auto c_plugin = mta_plugin_t{ - details::PluginRegistration::name, + name.c_str(), &details::PluginRegistration::load_model, }; mta_register_plugin(c_plugin); } -/// Register a raw C plugin. -inline void register_plugin(mta_plugin_t plugin) { - mta_register_plugin(plugin); -} - /// Load a plugin dynamic library from the given path. inline void load_plugin(const std::string& path) { details::check_status(mta_load_plugin(path.c_str())); } +/// Get a handle to a plugin in metatomic's global plugin registry. +inline PluginHandle plugin(const std::string& name) { + return PluginHandle(name); +} + /// Load a model using the given plugin. inline Model load_model( const std::string& plugin_name, const std::string& load_from, const std::vector& options = {} ) { - auto c_options = details::to_c_options(options); - - auto model = mta_model_t{}; - details::check_status(mta_load_model( - plugin_name.c_str(), - load_from.c_str(), - c_options.data(), - c_options.size(), - &model - )); - - return Model(model); + return plugin(plugin_name).load_model(load_from, options); } /// Load a model, letting metatomic pick the plugin. diff --git a/metatomic-core/include/metatomic/system.hpp b/metatomic-core/include/metatomic/system.hpp index 5fde00f3d..55724d52e 100644 --- a/metatomic-core/include/metatomic/system.hpp +++ b/metatomic-core/include/metatomic/system.hpp @@ -23,7 +23,7 @@ class System final { DLManagedTensorVersioned* positions, DLManagedTensorVersioned* cell, DLManagedTensorVersioned* pbc - ): system_(nullptr) { + ): system_(nullptr), is_view_(false) { details::check_status(mta_system_create( length_unit.c_str(), types, @@ -36,7 +36,7 @@ class System final { } ~System() { - if (system_ != nullptr) { + if (system_ != nullptr && !is_view_) { (void)mta_system_free(system_); } } @@ -44,17 +44,19 @@ class System final { System(const System&) = delete; System& operator=(const System&) = delete; - System(System&& other) noexcept: system_(nullptr) { + System(System&& other) noexcept: system_(nullptr), is_view_(true) { *this = std::move(other); } System& operator=(System&& other) noexcept { - if (system_ != nullptr) { + if (system_ != nullptr && !is_view_) { (void)mta_system_free(system_); } system_ = other.system_; + is_view_ = other.is_view_; other.system_ = nullptr; + other.is_view_ = true; return *this; } @@ -180,7 +182,12 @@ class System final { /// Take ownership of a raw `mta_system_t*`. static System unsafe_from_ptr(mta_system_t* system) { - return System(system); + return System(system, false); + } + + /// Create a non-owning view of a raw `mta_system_t*`. + static System unsafe_view_from_ptr(const mta_system_t* system) { + return System(const_cast(system), true); } /// Release the raw `mta_system_t*` without freeing it. @@ -191,7 +198,7 @@ class System final { } private: - explicit System(mta_system_t* system): system_(system) {} + explicit System(mta_system_t* system, bool is_view): system_(system), is_view_(is_view) {} DLPackTensor data(mta_system_data_kind request) const { DLManagedTensorVersioned* data = nullptr; @@ -201,6 +208,7 @@ class System final { } mta_system_t* system_; + bool is_view_; }; } // namespace metatomic From 36993d8982301aeaf53a4e1d6448a1ce82b643f7 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sun, 31 May 2026 08:48:44 +0200 Subject: [PATCH 3/5] Update to new C API --- metatomic-core/include/metatomic/model.hpp | 115 +++++++------------- metatomic-core/include/metatomic/plugin.hpp | 34 +++--- metatomic-core/include/metatomic/system.hpp | 107 +++++++++++++----- metatomic-core/include/metatomic/utils.hpp | 36 ------ 4 files changed, 137 insertions(+), 155 deletions(-) diff --git a/metatomic-core/include/metatomic/model.hpp b/metatomic-core/include/metatomic/model.hpp index ea0592dc8..608f232b5 100644 --- a/metatomic-core/include/metatomic/model.hpp +++ b/metatomic-core/include/metatomic/model.hpp @@ -7,6 +7,7 @@ #include #include +#include #include "./metadata.hpp" #include "./system.hpp" @@ -65,10 +66,8 @@ class Model final { model_.unload = &Model::unload_callback; model_.metadata = &Model::metadata_callback; model_.supported_outputs = &Model::supported_outputs_callback; - model_.requested_pair_lists_count = &Model::requested_pair_lists_count_callback; - model_.requested_pair_list = &Model::requested_pair_list_callback; - model_.requested_inputs_count = &Model::requested_inputs_count_callback; - model_.requested_input = &Model::requested_input_callback; + model_.requested_pair_lists = &Model::requested_pair_lists_callback; + model_.requested_inputs = &Model::requested_inputs_callback; model_.execute_inner = &Model::execute_callback; } @@ -137,58 +136,38 @@ class Model final { return outputs; } - /// Get all pair lists requested by this model, each one serialized as JSON. - std::vector requested_pair_lists_json() const { - this->check_callback(model_.requested_pair_lists_count, "requested_pair_lists_count"); - this->check_callback(model_.requested_pair_list, "requested_pair_list"); + /// Get all pair lists requested by this model serialized as a JSON array. + std::string requested_pair_lists_json() const { + this->check_callback(model_.requested_pair_lists, "requested_pair_lists"); - uintptr_t count = 0; - details::check_status(model_.requested_pair_lists_count(model_.data, &count)); - - auto result = std::vector(); - result.reserve(count); - for (uintptr_t i=0; i requested_pair_lists() const { auto result = std::vector(); - for (const auto& options: this->requested_pair_lists_json()) { - result.push_back(PairListOptions::from_json(options)); + for (const auto& options: nlohmann::json::parse(this->requested_pair_lists_json())) { + result.push_back(PairListOptions::from_json(options.dump())); } return result; } - /// Get all custom inputs requested by this model, each one serialized as JSON. - std::vector requested_inputs_json() const { - this->check_callback(model_.requested_inputs_count, "requested_inputs_count"); - this->check_callback(model_.requested_input, "requested_input"); - - uintptr_t count = 0; - details::check_status(model_.requested_inputs_count(model_.data, &count)); + /// Get all custom inputs requested by this model serialized as a JSON array. + std::string requested_inputs_json() const { + this->check_callback(model_.requested_inputs, "requested_inputs"); - auto result = std::vector(); - result.reserve(count); - for (uintptr_t i=0; i requested_inputs() const { auto result = std::vector(); - for (const auto& input: this->requested_inputs_json()) { - result.push_back(Quantity::from_json(input)); + for (const auto& input: nlohmann::json::parse(this->requested_inputs_json())) { + result.push_back(Quantity::from_json(input.dump())); } return result; } @@ -199,7 +178,8 @@ class Model final { std::vector execute( const std::vector& systems, const metatensor::Labels* selected_atoms, - const std::vector& requested_outputs + const std::vector& requested_outputs, + bool check_consistency = true ) { this->check_valid(); @@ -227,6 +207,7 @@ class Model final { selected_atoms == nullptr ? nullptr : selected_atoms->as_mts_labels_t(), c_requested_outputs.data(), c_requested_outputs.size(), + check_consistency, raw_outputs.data(), raw_outputs.size() )); @@ -255,9 +236,10 @@ class Model final { /// Execute this model on all atoms. std::vector execute( const std::vector& systems, - const std::vector& requested_outputs + const std::vector& requested_outputs, + bool check_consistency = true ) { - return this->execute(systems, nullptr, requested_outputs); + return this->execute(systems, nullptr, requested_outputs, check_consistency); } /// Get the underlying `mta_model_t`. @@ -281,10 +263,8 @@ class Model final { model.unload = nullptr; model.metadata = nullptr; model.supported_outputs = nullptr; - model.requested_pair_lists_count = nullptr; - model.requested_pair_list = nullptr; - model.requested_inputs_count = nullptr; - model.requested_input = nullptr; + model.requested_pair_lists = nullptr; + model.requested_inputs = nullptr; model.execute_inner = nullptr; return model; } @@ -321,41 +301,29 @@ class Model final { }); } - static mta_status_t requested_pair_lists_count_callback(const void* data, uintptr_t* count) { - return details::catch_exceptions([&]() { - details::check_pointer(count); - *count = model_base(data)->requested_pair_lists().size(); - }); - } - - static mta_status_t requested_pair_list_callback(const void* data, uintptr_t index, mta_string_t* pair_options_json) { + static mta_status_t requested_pair_lists_callback(const void* data, mta_string_t* pair_options_json) { return details::catch_exceptions([&]() { details::check_pointer(pair_options_json); - auto options = model_base(data)->requested_pair_lists(); - if (index >= options.size()) { - throw Error("pair list request index out of bounds"); + auto options = nlohmann::json::array(); + for (const auto& option: model_base(data)->requested_pair_lists()) { + options.push_back(nlohmann::json::parse(option.to_json())); } - *pair_options_json = mta_string_create(options[index].to_json().c_str()); - details::check_pointer(*pair_options_json); - }); - } - static mta_status_t requested_inputs_count_callback(const void* data, uintptr_t* count) { - return details::catch_exceptions([&]() { - details::check_pointer(count); - *count = model_base(data)->requested_inputs().size(); + *pair_options_json = mta_string_create(options.dump().c_str()); + details::check_pointer(*pair_options_json); }); } - static mta_status_t requested_input_callback(const void* data, uintptr_t index, mta_string_t* input_json) { + static mta_status_t requested_inputs_callback(const void* data, mta_string_t* inputs_json) { return details::catch_exceptions([&]() { - details::check_pointer(input_json); - auto inputs = model_base(data)->requested_inputs(); - if (index >= inputs.size()) { - throw Error("input request index out of bounds"); + details::check_pointer(inputs_json); + auto inputs = nlohmann::json::array(); + for (const auto& input: model_base(data)->requested_inputs()) { + inputs.push_back(nlohmann::json::parse(input.to_json())); } - *input_json = mta_string_create(inputs[index].to_json().c_str()); - details::check_pointer(*input_json); + + *inputs_json = mta_string_create(inputs.dump().c_str()); + details::check_pointer(*inputs_json); }); } @@ -399,6 +367,7 @@ class Model final { auto requested_outputs = std::vector(); requested_outputs.reserve(requested_outputs_count); for (uintptr_t i=0; i #include #include -#include #include @@ -23,7 +22,7 @@ class Plugin { /// Load a model from `load_from`, using the provided key/value options. virtual Model load_model( const std::string& load_from, - const std::vector& options = {} + const std::string& options_json = "{}" ) = 0; }; @@ -40,16 +39,13 @@ class PluginHandle final { /// Load a model from `load_from`, using the provided key/value options. Model load_model( const std::string& load_from, - const std::vector& options = {} + const std::string& options_json = "{}" ) const { - auto c_options = details::to_c_options(options); - auto model = mta_model_t{}; details::check_status(mta_load_model( name_.c_str(), load_from.c_str(), - c_options.data(), - c_options.size(), + options_json.c_str(), &model )); @@ -64,11 +60,11 @@ namespace details { template struct PluginRegistration { static PluginT* plugin; + static std::string name; static mta_status_t load_model( const char* load_from, - const mta_kv_pair_t* options, - uintptr_t options_count, + const char* options_json, mta_model_t* model ) { return details::catch_exceptions([&]() { @@ -77,7 +73,7 @@ namespace details { auto loaded = plugin->load_model( load_from == nullptr ? "" : load_from, - details::from_c_options(options, options_count) + options_json == nullptr ? "{}" : options_json ); *model = loaded.release(); @@ -87,6 +83,9 @@ namespace details { template PluginT* PluginRegistration::plugin = nullptr; + + template + std::string PluginRegistration::name; } // namespace details /// Register a C++ plugin. @@ -101,10 +100,10 @@ void register_plugin(PluginT& plugin) { ); details::PluginRegistration::plugin = &plugin; - const auto name = plugin.name(); + details::PluginRegistration::name = plugin.name(); auto c_plugin = mta_plugin_t{ - name.c_str(), + details::PluginRegistration::name.c_str(), &details::PluginRegistration::load_model, }; @@ -125,24 +124,21 @@ inline PluginHandle plugin(const std::string& name) { inline Model load_model( const std::string& plugin_name, const std::string& load_from, - const std::vector& options = {} + const std::string& options_json = "{}" ) { - return plugin(plugin_name).load_model(load_from, options); + return plugin(plugin_name).load_model(load_from, options_json); } /// Load a model, letting metatomic pick the plugin. inline Model load_model( const std::string& load_from, - const std::vector& options = {} + const std::string& options_json = "{}" ) { - auto c_options = details::to_c_options(options); - auto model = mta_model_t{}; details::check_status(mta_load_model( nullptr, load_from.c_str(), - c_options.data(), - c_options.size(), + options_json.c_str(), &model )); diff --git a/metatomic-core/include/metatomic/system.hpp b/metatomic-core/include/metatomic/system.hpp index 55724d52e..15d136b83 100644 --- a/metatomic-core/include/metatomic/system.hpp +++ b/metatomic-core/include/metatomic/system.hpp @@ -5,7 +5,9 @@ #include #include +#include +#include "./metadata.hpp" #include "./utils.hpp" namespace metatomic { @@ -97,47 +99,97 @@ class System final { /// Add a new pair list in this system. /// /// Ownership of `pairs` is transferred to the C API. - void set_pairs(const std::string& options, mts_block_t* pairs) { - details::check_status(mta_system_set_pairs(system_, options.c_str(), pairs)); + void add_pairs(const PairListOptions& options, mts_block_t* pairs) { + this->add_pairs(options.to_json(), pairs); + } + + /// Add a new pair list in this system. + /// + /// Ownership of `pairs` is transferred to the C API. + void add_pairs(const std::string& options_json, mts_block_t* pairs) { + details::check_status(mta_system_add_pairs(system_, options_json.c_str(), pairs)); + } + + /// Add a new pair list in this system. + /// + /// Ownership of `pairs` is transferred to the C API. + void set_pairs(const PairListOptions& options, mts_block_t* pairs) { + this->add_pairs(options, pairs); + } + + /// Add a new pair list in this system. + /// + /// Ownership of `pairs` is transferred to the C API. + void set_pairs(const std::string& options_json, mts_block_t* pairs) { + this->add_pairs(options_json, pairs); } /// Retrieve a previously stored pair list with the given options. - const mts_block_t* pairs_raw(const std::string& options) const { + const mts_block_t* pairs_raw(const PairListOptions& options) const { + return this->pairs_raw(options.to_json()); + } + + /// Retrieve a previously stored pair list with the given options. + const mts_block_t* pairs_raw(const std::string& options_json) const { const mts_block_t* pairs = nullptr; - details::check_status(mta_system_get_pairs(system_, options.c_str(), &pairs)); + details::check_status(mta_system_get_pairs(system_, options_json.c_str(), &pairs)); details::check_pointer(pairs); return pairs; } /// Retrieve a previously stored pair list with the given options as a /// non-owning metatensor view. - metatensor::TensorBlock pairs(const std::string& options) const { + metatensor::TensorBlock pairs(const PairListOptions& options) const { + return this->pairs(options.to_json()); + } + + /// Retrieve a previously stored pair list with the given options as a + /// non-owning metatensor view. + metatensor::TensorBlock pairs(const std::string& options_json) const { return metatensor::TensorBlock::unsafe_view_from_ptr( - const_cast(this->pairs_raw(options)) + const_cast(this->pairs_raw(options_json)) ); } + /// Get the options for all pair lists registered with this `System`, + /// serialized as a JSON array. + std::string known_pairs_json() const { + mta_string_t pairs_options = nullptr; + details::check_status(mta_system_known_pairs(system_, &pairs_options)); + return String(pairs_options).str(); + } + /// Get the options for all pair lists registered with this `System`. - std::vector pairs_options() const { - uintptr_t count = 0; - details::check_status(mta_system_pairs_count(system_, &count)); + std::vector known_pairs() const { + auto result = std::vector(); + for (const auto& options: nlohmann::json::parse(this->known_pairs_json())) { + result.push_back(PairListOptions::from_json(options.dump())); + } + return result; + } + /// Get the options for all pair lists registered with this `System`, + /// each one serialized as JSON. + std::vector pairs_options() const { auto result = std::vector(); - result.reserve(count); - for (uintptr_t i=0; iknown_pairs_json())) { + result.push_back(options.dump()); } - return result; } + /// Add custom data to this system. + /// + /// Ownership of `data` is transferred to the C API. + void add_data(const std::string& name, mts_tensormap_t* data) { + details::check_status(mta_system_add_custom_data(system_, name.c_str(), data)); + } + /// Add custom data to this system. /// /// Ownership of `data` is transferred to the C API. void set_data(const std::string& name, mts_tensormap_t* data) { - details::check_status(mta_system_set_custom_data(system_, name.c_str(), data)); + this->add_data(name, data); } /// Retrieve custom data stored in this system. @@ -151,19 +203,20 @@ class System final { } /// Get the names of all custom data registered with this `System`. - std::vector data_names() const { - uintptr_t count = 0; - details::check_status(mta_system_data_count(system_, &count)); + std::string known_data_json() const { + mta_string_t names = nullptr; + details::check_status(mta_system_known_custom_data(system_, &names)); + return String(names).str(); + } - auto result = std::vector(); - result.reserve(count); - for (uintptr_t i=0; i known_data() const { + return nlohmann::json::parse(this->known_data_json()).get>(); + } - return result; + /// Get the names of all custom data registered with this `System`. + std::vector data_names() const { + return this->known_data(); } /// Get the underlying `mta_system_t` pointer. diff --git a/metatomic-core/include/metatomic/utils.hpp b/metatomic-core/include/metatomic/utils.hpp index 9ff181d18..056e18224 100644 --- a/metatomic-core/include/metatomic/utils.hpp +++ b/metatomic-core/include/metatomic/utils.hpp @@ -7,7 +7,6 @@ #include #include #include -#include #include @@ -20,12 +19,6 @@ class Error: public std::runtime_error { explicit Error(const std::string& message): std::runtime_error(message) {} }; -/// Key/value pair used when loading models from plugins. -struct KeyValuePair { - std::string key; - std::string value; -}; - /// RAII wrapper around a `DLManagedTensorVersioned*`. /// /// This owns the DLPack managed tensor object, and calls its deleter when the @@ -170,35 +163,6 @@ namespace details { throw Error(message == nullptr ? "received a null pointer from the metatomic C API" : message); } - - inline std::vector to_c_options(const std::vector& options) { - auto c_options = std::vector(); - c_options.reserve(options.size()); - - for (const auto& option: options) { - c_options.push_back(mta_kv_pair_t{option.key.c_str(), option.value.c_str()}); - } - - return c_options; - } - - inline std::vector from_c_options(const mta_kv_pair_t* options, uintptr_t count) { - auto result = std::vector(); - result.reserve(count); - - if (count != 0) { - check_pointer(options); - } - - for (uintptr_t i=0; i Date: Sun, 31 May 2026 09:21:19 +0200 Subject: [PATCH 4/5] Rename model classes for consistency with current `metatomic.torch` --- metatomic-core/include/metatomic/model.hpp | 44 ++++++++++----------- metatomic-core/include/metatomic/plugin.hpp | 16 ++++---- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/metatomic-core/include/metatomic/model.hpp b/metatomic-core/include/metatomic/model.hpp index 608f232b5..ee140760d 100644 --- a/metatomic-core/include/metatomic/model.hpp +++ b/metatomic-core/include/metatomic/model.hpp @@ -16,9 +16,9 @@ namespace metatomic { /// Abstract base class for atomistic models implemented in C++. -class ModelBase { +class ModelInterface { public: - virtual ~ModelBase() = default; + virtual ~ModelInterface() = default; /// Get metadata about this model. virtual ModelMetadata metadata() const = 0; @@ -45,44 +45,44 @@ class ModelBase { }; /// RAII wrapper around a `mta_model_t`. -class Model final { +class AtomisticModel final { public: /// Create an empty, invalid model. - Model() { + AtomisticModel() { model_ = empty_model(); } /// Take ownership of a raw `mta_model_t`. - explicit Model(mta_model_t model): model_(model) {} + explicit AtomisticModel(mta_model_t model): model_(model) {} /// Create a C API model wrapping a C++ model implementation. - explicit Model(std::unique_ptr model) { + explicit AtomisticModel(std::unique_ptr model) { if (model == nullptr) { - throw Error("can not create a metatomic::Model from a null ModelBase"); + throw Error("can not create a metatomic::AtomisticModel from a null ModelInterface"); } model_ = empty_model(); model_.data = model.release(); - model_.unload = &Model::unload_callback; - model_.metadata = &Model::metadata_callback; - model_.supported_outputs = &Model::supported_outputs_callback; - model_.requested_pair_lists = &Model::requested_pair_lists_callback; - model_.requested_inputs = &Model::requested_inputs_callback; - model_.execute_inner = &Model::execute_callback; + model_.unload = &AtomisticModel::unload_callback; + model_.metadata = &AtomisticModel::metadata_callback; + model_.supported_outputs = &AtomisticModel::supported_outputs_callback; + model_.requested_pair_lists = &AtomisticModel::requested_pair_lists_callback; + model_.requested_inputs = &AtomisticModel::requested_inputs_callback; + model_.execute_inner = &AtomisticModel::execute_callback; } - ~Model() { + ~AtomisticModel() { this->reset_noexcept(); } - Model(const Model&) = delete; - Model& operator=(const Model&) = delete; + AtomisticModel(const AtomisticModel&) = delete; + AtomisticModel& operator=(const AtomisticModel&) = delete; - Model(Model&& other) noexcept: Model() { + AtomisticModel(AtomisticModel&& other) noexcept: AtomisticModel() { *this = std::move(other); } - Model& operator=(Model&& other) noexcept { + AtomisticModel& operator=(AtomisticModel&& other) noexcept { if (this != &other) { this->reset_noexcept(); model_ = other.model_; @@ -269,9 +269,9 @@ class Model final { return model; } - static ModelBase* model_base(const void* data) { + static ModelInterface* model_base(const void* data) { details::check_pointer(data); - return static_cast(const_cast(data)); + return static_cast(const_cast(data)); } static mta_status_t unload_callback(void* data) { @@ -392,7 +392,7 @@ class Model final { void check_valid() const { if (model_.data == nullptr) { - throw Error("can not use an empty metatomic::Model"); + throw Error("can not use an empty metatomic::AtomisticModel"); } } @@ -400,7 +400,7 @@ class Model final { void check_callback(Callback callback, const char* name) const { this->check_valid(); if (callback == nullptr) { - throw Error("metatomic::Model does not implement " + std::string(name)); + throw Error("metatomic::AtomisticModel does not implement " + std::string(name)); } } diff --git a/metatomic-core/include/metatomic/plugin.hpp b/metatomic-core/include/metatomic/plugin.hpp index a806e71d7..f5aee0692 100644 --- a/metatomic-core/include/metatomic/plugin.hpp +++ b/metatomic-core/include/metatomic/plugin.hpp @@ -19,8 +19,8 @@ class Plugin { /// Name used to identify this plugin. virtual std::string name() const = 0; - /// Load a model from `load_from`, using the provided key/value options. - virtual Model load_model( + /// Load a model from `load_from`, using the provided JSON options. + virtual AtomisticModel load_model( const std::string& load_from, const std::string& options_json = "{}" ) = 0; @@ -36,8 +36,8 @@ class PluginHandle final { return name_; } - /// Load a model from `load_from`, using the provided key/value options. - Model load_model( + /// Load a model from `load_from`, using the provided JSON options. + AtomisticModel load_model( const std::string& load_from, const std::string& options_json = "{}" ) const { @@ -49,7 +49,7 @@ class PluginHandle final { &model )); - return Model(model); + return AtomisticModel(model); } private: @@ -121,7 +121,7 @@ inline PluginHandle plugin(const std::string& name) { } /// Load a model using the given plugin. -inline Model load_model( +inline AtomisticModel load_model( const std::string& plugin_name, const std::string& load_from, const std::string& options_json = "{}" @@ -130,7 +130,7 @@ inline Model load_model( } /// Load a model, letting metatomic pick the plugin. -inline Model load_model( +inline AtomisticModel load_model( const std::string& load_from, const std::string& options_json = "{}" ) { @@ -142,7 +142,7 @@ inline Model load_model( &model )); - return Model(model); + return AtomisticModel(model); } } // namespace metatomic From fc879295017b5ac609d3c782429c7a84e76a5153 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 4 Jun 2026 10:02:05 +0200 Subject: [PATCH 5/5] Add clear to venv --- metatomic-core/tests/utils/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/metatomic-core/tests/utils/mod.rs b/metatomic-core/tests/utils/mod.rs index e12e6897c..78540ba96 100644 --- a/metatomic-core/tests/utils/mod.rs +++ b/metatomic-core/tests/utils/mod.rs @@ -148,6 +148,7 @@ pub fn create_python_venv(build_dir: PathBuf) -> PathBuf { let mut cmd = Command::new(find_python()); cmd.arg("-m"); cmd.arg("venv"); + cmd.arg("--clear"); cmd.arg(&build_dir); run_command(cmd, "python to create virtualenv with `venv`");