Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion metatomic-torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ else()
find_package(metatensor_torch ${REQUIRED_METATENSOR_TORCH_VERSION} CONFIG REQUIRED)
endif()

set(REQUIRED_METATOMIC_VERSION "0.1.0")
if (TARGET metatomic)
get_target_property(METATOMIC_BUILD_VERSION metatomic BUILD_VERSION)
check_compatible_versions(${METATOMIC_BUILD_VERSION} ${REQUIRED_METATOMIC_VERSION})
else()
find_package(metatomic ${REQUIRED_METATOMIC_VERSION} CONFIG QUIET)
if (NOT metatomic_FOUND)
get_filename_component(_metatomic_root "${CMAKE_CURRENT_SOURCE_DIR}/.." ABSOLUTE)
if (EXISTS "${_metatomic_root}/metatomic-core/CMakeLists.txt")
add_subdirectory("${_metatomic_root}/metatomic-core" metatomic-core)
else()
find_package(metatomic ${REQUIRED_METATOMIC_VERSION} CONFIG REQUIRED)
endif()
endif()
endif()

# FindCUDNN.cmake distributed with PyTorch is a bit broken, so we have a
# fixed version in `cmake/FindCUDNN.cmake`
set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
Expand Down Expand Up @@ -123,7 +139,7 @@ set_target_properties(metatomic_torch PROPERTIES
BUILD_VERSION ${METATOMIC_TORCH_FULL_VERSION}
)

target_link_libraries(metatomic_torch PUBLIC torch metatensor_torch ${CMAKE_DL_LIBS})
target_link_libraries(metatomic_torch PUBLIC torch metatensor_torch metatomic ${CMAKE_DL_LIBS})
target_compile_features(metatomic_torch PUBLIC cxx_std_17)
target_include_directories(metatomic_torch PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
Expand Down
3 changes: 3 additions & 0 deletions metatomic-torch/cmake/metatomic_torch-config.in.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ include(CMakeFindDependencyMacro)
set(REQUIRED_METATENSOR_TORCH_VERSION @REQUIRED_METATENSOR_TORCH_VERSION@)
find_package(metatensor_torch ${REQUIRED_METATENSOR_TORCH_VERSION} CONFIG REQUIRED)

set(REQUIRED_METATOMIC_VERSION @REQUIRED_METATOMIC_VERSION@)
find_package(metatomic ${REQUIRED_METATOMIC_VERSION} CONFIG REQUIRED)

# We can only load metatomic_torch with the same minor version of Torch that
# was used to compile it (and is stored in BUILD_TORCH_VERSION)
set(BUILD_TORCH_VERSION @Torch_VERSION@)
Expand Down
21 changes: 21 additions & 0 deletions metatomic-torch/include/metatomic/torch/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <vector>
#include <string>

#include <metatomic.hpp>

#include <torch/script.h>

#include <metatensor/torch.hpp>
Expand Down Expand Up @@ -130,6 +132,15 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
/// Load a serialized `ModelOutput` from a JSON string.
static ModelOutput from_json(std::string_view json);

/// Convert this output to the Rust-backed C++ API representation.
///
/// `name` is the name used for this output in the model outputs map.
metatomic::Quantity as_quantity(const std::string& name) const;

/// Create a TorchScript model output from the Rust-backed C++ API
/// representation.
static ModelOutput from_quantity(const metatomic::Quantity& quantity);

private:
void set_per_atom_no_deprecation(bool per_atom);
bool get_per_atom_no_deprecation() const;
Expand Down Expand Up @@ -224,6 +235,10 @@ class METATOMIC_TORCH_EXPORT ModelCapabilitiesHolder: public torch::CustomClassH
/// Load a serialized `ModelCapabilities` from a JSON string.
static ModelCapabilities from_json(std::string_view json);

/// Convert the outputs in this capability object to the Rust-backed C++
/// API representation.
std::vector<metatomic::Quantity> supported_outputs() const;

private:
void set_outputs(torch::Dict<std::string, ModelOutput> outputs, bool warn_on_deprecated);

Expand Down Expand Up @@ -339,6 +354,12 @@ class METATOMIC_TORCH_EXPORT ModelMetadataHolder: public torch::CustomClassHolde
/// Load a serialized `ModelMetadata` from a JSON string.
static ModelMetadata from_json(std::string_view json);

/// Convert this metadata to the Rust-backed C++ API representation.
metatomic::ModelMetadata as_metatomic() const;

/// Create TorchScript metadata from the Rust-backed C++ API representation.
static ModelMetadata from_metatomic(const metatomic::ModelMetadata& metadata);

private:
/// validate the metadata before using it
void validate() const;
Expand Down
15 changes: 15 additions & 0 deletions metatomic-torch/include/metatomic/torch/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <vector>
#include <string>

#include <metatomic.hpp>

#include <torch/script.h>

#include <metatensor/torch.hpp>
Expand Down Expand Up @@ -79,6 +81,13 @@ class METATOMIC_TORCH_EXPORT NeighborListOptionsHolder final: public torch::Cust
/// Load a serialized `NeighborListOptions` from a JSON string.
static NeighborListOptions from_json(const std::string& json);

/// Convert these options to the Rust-backed C++ API representation.
metatomic::PairListOptions as_pair_list_options() const;

/// Create TorchScript neighbor list options from the Rust-backed C++ API
/// representation.
static NeighborListOptions from_pair_list_options(const metatomic::PairListOptions& options);

private:
// cutoff in the model units
double cutoff_;
Expand Down Expand Up @@ -301,6 +310,12 @@ class METATOMIC_TORCH_EXPORT SystemHolder final: public torch::CustomClassHolder
/// Implementation of `__str__` and `__repr__` for Python
std::string str() const;

/// Convert this system to the Rust-backed C++ API representation.
///
/// The returned system owns DLPack views of the current torch tensors and
/// deep copies of neighbor lists/custom data.
metatomic::System as_metatomic(const std::string& length_unit) const;

private:
struct nl_options_compare {
bool operator()(const NeighborListOptions& a, const NeighborListOptions& b) const {
Expand Down
72 changes: 72 additions & 0 deletions metatomic-torch/src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,25 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) {
return result;
}

metatomic::Quantity ModelOutputHolder::as_quantity(const std::string& name) const {
return metatomic::Quantity(
name,
this->unit(),
this->explicit_gradients,
this->sample_kind()
);
}

ModelOutput ModelOutputHolder::from_quantity(const metatomic::Quantity& quantity) {
return torch::make_intrusive<ModelOutputHolder>(
/*quantity=*/"",
quantity.unit,
quantity.sample_kind,
quantity.gradients,
/*description=*/""
);
}

static std::set<std::string> SUPPORTED_SAMPLE_KINDS = {
"system",
"atom",
Expand Down Expand Up @@ -503,6 +522,17 @@ ModelCapabilities ModelCapabilitiesHolder::from_json(std::string_view json) {
return result;
}

std::vector<metatomic::Quantity> ModelCapabilitiesHolder::supported_outputs() const {
auto result = std::vector<metatomic::Quantity>();
result.reserve(this->outputs_.size());

for (const auto& it: this->outputs_) {
result.emplace_back(it.value()->as_quantity(it.key()));
}

return result;
}

/******************************************************************************/

static void check_selected_atoms(const torch::optional<metatensor_torch::Labels>& selected_atoms) {
Expand Down Expand Up @@ -794,6 +824,48 @@ ModelMetadata ModelMetadataHolder::from_json(std::string_view json) {
return result;
}

metatomic::ModelMetadata ModelMetadataHolder::as_metatomic() const {
this->validate();

auto references = std::map<std::string, std::vector<std::string>>();
for (const auto& it: this->references) {
references.emplace(it.key(), it.value());
}

auto extra = std::map<std::string, std::string>();
for (const auto& it: this->extra) {
extra.emplace(it.key(), it.value());
}

return metatomic::ModelMetadata(
this->name,
this->description,
this->authors,
std::move(references),
std::move(extra)
);
}

ModelMetadata ModelMetadataHolder::from_metatomic(const metatomic::ModelMetadata& metadata) {
auto references = torch::Dict<std::string, std::vector<std::string>>();
for (const auto& it: metadata.references) {
references.insert(it.first, it.second);
}

auto extra = torch::Dict<std::string, std::string>();
for (const auto& it: metadata.extra) {
extra.insert(it.first, it.second);
}

return torch::make_intrusive<ModelMetadataHolder>(
metadata.name,
metadata.description,
metadata.authors,
references,
extra
);
}


// replace end of line characters and tabs with a single space
static std::string normalize_whitespace(std::string_view data) {
Expand Down
78 changes: 78 additions & 0 deletions metatomic-torch/src/system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include <sstream>
#include <algorithm>

#include <metatomic.hpp>

#include <ATen/DLConvertor.h>
#include <torch/torch.h>
#include <nlohmann/json.hpp>

Expand All @@ -18,6 +21,29 @@

using namespace metatomic_torch;

namespace {

DLManagedTensorVersioned* torch_to_versioned_dlpack(const torch::Tensor& tensor) {
auto* torch_tensor = at::toDLPack(tensor);

auto* result = new DLManagedTensorVersioned();
result->version = {DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION};
result->manager_ctx = torch_tensor;
result->deleter = [](DLManagedTensorVersioned* self) {
auto* torch_tensor = static_cast<DLManagedTensor*>(self->manager_ctx);
if (torch_tensor != nullptr && torch_tensor->deleter != nullptr) {
torch_tensor->deleter(torch_tensor);
}
delete self;
};
result->flags = 0;
result->dl_tensor = torch_tensor->dl_tensor;

return result;
}

} // namespace

NeighborListOptionsHolder::NeighborListOptionsHolder(
double cutoff,
bool full_list,
Expand Down Expand Up @@ -160,6 +186,29 @@ NeighborListOptions NeighborListOptionsHolder::from_json(const std::string& json
return options;
}

metatomic::PairListOptions NeighborListOptionsHolder::as_pair_list_options() const {
return metatomic::PairListOptions(
this->cutoff_,
this->full_list_,
this->strict_,
this->requestors_
);
}

NeighborListOptions NeighborListOptionsHolder::from_pair_list_options(const metatomic::PairListOptions& options) {
auto result = torch::make_intrusive<NeighborListOptionsHolder>(
options.cutoff,
options.full_list,
options.strict
);

for (const auto& requestor: options.requestors) {
result->add_requestor(requestor);
}

return result;
}

// ========================================================================== //

torch::Tensor NeighborsAutograd::forward(
Expand Down Expand Up @@ -991,3 +1040,32 @@ std::string SystemHolder::str() const {

return result.str();
}

metatomic::System SystemHolder::as_metatomic(const std::string& length_unit) const {
auto system = metatomic::System(
length_unit,
torch_to_versioned_dlpack(this->types_),
torch_to_versioned_dlpack(this->positions_),
torch_to_versioned_dlpack(this->cell_),
torch_to_versioned_dlpack(this->pbc_)
);

for (const auto& it: this->neighbors_) {
auto neighbors = it.second->copy(/*deep=*/true)->release();
auto* pairs = mts_block_copy(neighbors.as_mts_block_t());
if (pairs == nullptr) {
C10_THROW_ERROR(ValueError, "failed to copy neighbor list " + it.first->str());
}
system.add_pairs(it.first->as_pair_list_options(), pairs);
}

for (const auto& it: this->data_) {
auto* data = mts_tensormap_copy(it.second->as_metatensor().as_mts_tensormap_t());
if (data == nullptr) {
C10_THROW_ERROR(ValueError, "failed to copy custom data '" + it.first + "'");
}
system.add_data(it.first, data);
}

return system;
}
2 changes: 2 additions & 0 deletions python/metatomic_torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ else()
set(metatomic_install_rpath "${CMAKE_INSTALL_RPATH}")
# when loading the libraries from a Python installation:
# - $ORIGIN/../../../torch/lib is where libtorch.so will be
# - $ORIGIN/../../../metatomic/lib is where libmetatomic.so will be
# - $ORIGIN/../../../metatensor/lib is where libmetatensor.so will be
# - $ORIGIN/../../../metatensor_torch/torch-${Torch_VERSION_MAJOR}.${Torch_VERSION_MINOR}/lib is where libmetatensor_torch.so will be
set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../torch/lib")
set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../metatomic/lib")
set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../metatensor/lib")
set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../metatensor_torch/torch-${Torch_VERSION_MAJOR}.${Torch_VERSION_MINOR}/lib")

Expand Down
Loading