From e3a3d5bc6be7509903b3c6244809d7ffc3043b60 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 3 Jun 2026 02:43:16 -0700 Subject: [PATCH 1/6] adding a model only restore function to the RL path Signed-off-by: Thorsten Kurth --- docs/api/c_api.rst | 16 + docs/api/f_api.rst | 48 ++- docs/usage.rst | 30 +- src/csrc/include/internal/model_pack.h | 13 + src/csrc/include/internal/rl/off_policy.h | 3 + .../include/internal/rl/off_policy/ddpg.h | 1 + src/csrc/include/internal/rl/off_policy/sac.h | 1 + src/csrc/include/internal/rl/off_policy/td3.h | 1 + src/csrc/include/internal/rl/on_policy.h | 3 + src/csrc/include/internal/rl/on_policy/ppo.h | 1 + src/csrc/include/torchfort_rl.h | 31 ++ src/csrc/model_pack.cpp | 2 +- src/csrc/rl/off_policy/ddpg.cpp | 30 ++ src/csrc/rl/off_policy/interface.cpp | 13 + src/csrc/rl/off_policy/sac.cpp | 39 ++- src/csrc/rl/off_policy/td3.cpp | 34 ++ src/csrc/rl/on_policy/interface.cpp | 13 + src/csrc/rl/on_policy/ppo.cpp | 22 +- src/csrc/torchfort.cpp | 2 +- src/fsrc/torchfort_m.F90 | 32 ++ tests/rl/CMakeLists.txt | 15 + tests/rl/test_checkpoint.cpp | 303 ++++++++++++++++++ tests/rl/test_interface.cpp | 6 +- tests/rl/test_off_policy.cpp | 3 +- tests/rl/test_on_policy.cpp | 3 +- tests/test_utils.h | 35 ++ 26 files changed, 686 insertions(+), 14 deletions(-) create mode 100644 tests/rl/test_checkpoint.cpp diff --git a/docs/api/c_api.rst b/docs/api/c_api.rst index 14ecb901..093f0124 100644 --- a/docs/api/c_api.rst +++ b/docs/api/c_api.rst @@ -334,6 +334,14 @@ torchfort_rl_off_policy_load_checkpoint ------ +.. _torchfort_rl_off_policy_load_model-ref: + +torchfort_rl_off_policy_load_model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. doxygenfunction:: torchfort_rl_off_policy_load_model + +------ + Weights and Biases Logging -------------------------- @@ -483,6 +491,14 @@ torchfort_rl_on_policy_load_checkpoint ------ +.. _torchfort_rl_on_policy_load_model-ref: + +torchfort_rl_on_policy_load_model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. 
doxygenfunction:: torchfort_rl_on_policy_load_model + +------ + Weights and Biases Logging -------------------------- diff --git a/docs/api/f_api.rst b/docs/api/f_api.rst index 92a72e3d..01b04445 100644 --- a/docs/api/f_api.rst +++ b/docs/api/f_api.rst @@ -534,11 +534,31 @@ torchfort_rl_off_policy_load_checkpoint .. f:function:: torchfort_rl_off_policy_load_checkpoint(name, checkpoint_dir) - Restores a reinforcement learning system from a checkpoint. + Restores a reinforcement learning system from a checkpoint. This method restores all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler states. It also fully restores the state of the replay buffer, but not the current RNG seed. This function should be used in conjunction with :code:`torchfort_rl_off_policy_save_checkpoint`. - + + :p character(:) name [in]: The name of system instance to use, as defined during system creation. + :p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load. + :r torchfort_result res: :code:`TORCHFORT_RESULT_SUCCESS` on success or error code on failure. + +------ + +.. _torchfort_rl_off_policy_load_model-f-ref: + +torchfort_rl_off_policy_load_model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. f:function:: torchfort_rl_off_policy_load_model(name, checkpoint_dir) + + Restores only the network weights of a reinforcement learning system from a checkpoint. + In contrast to :code:`torchfort_rl_off_policy_load_checkpoint`, this method only restores the weights of the online policy and + critic networks. The optimizers, LR schedulers, replay buffer, normalizer statistics and step counters are left in their freshly + created state, and the target networks are re-initialized from the loaded online networks. This is intended for fine-tuning or + transfer-learning workflows, where a pretrained model is used as the starting point for a new training run (e.g. 
with a modified + reward function or new environment data). The checkpoint is the one produced by :code:`torchfort_rl_off_policy_save_checkpoint`. + :p character(:) name [in]: The name of system instance to use, as defined during system creation. :p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load. :r torchfort_result res: :code:`TORCHFORT_RESULT_SUCCESS` on success or error code on failure. @@ -800,11 +820,31 @@ torchfort_rl_on_policy_load_checkpoint .. f:function:: torchfort_rl_on_policy_load_checkpoint(name, checkpoint_dir) - Restores a reinforcement learning system from a checkpoint. + Restores a reinforcement learning system from a checkpoint. This method restores all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler states. It also fully restores the state of the rollout buffer, but not the current RNG seed. This function should be used in conjunction with :code:`torchfort_rl_on_policy_save_checkpoint`. - + + :p character(:) name [in]: The name of system instance to use, as defined during system creation. + :p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load. + :r torchfort_result res: :code:`TORCHFORT_RESULT_SUCCESS` on success or error code on failure. + +------ + +.. _torchfort_rl_on_policy_load_model-f-ref: + +torchfort_rl_on_policy_load_model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. f:function:: torchfort_rl_on_policy_load_model(name, checkpoint_dir) + + Restores only the network weights of a reinforcement learning system from a checkpoint. + In contrast to :code:`torchfort_rl_on_policy_load_checkpoint`, this method only restores the weights of the actor-critic network. + The optimizer, LR scheduler, rollout buffer, normalizer statistics and step counters are left in their freshly created state. 
This is + intended for fine-tuning or transfer-learning workflows, where a pretrained model is used as the starting point for a new training run + (e.g. with a modified reward function or new environment data). The checkpoint is the one produced by + :code:`torchfort_rl_on_policy_save_checkpoint`. + :p character(:) name [in]: The name of system instance to use, as defined during system creation. :p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load. :r torchfort_result res: :code:`TORCHFORT_RESULT_SUCCESS` on success or error code on failure. diff --git a/docs/usage.rst b/docs/usage.rst index 1f92ace0..9c7562b3 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -351,5 +351,33 @@ This function is only required if RL training from the checkpoint should be resu istat = torchfort_load_model(model_name, policy_model_file); -can be used instead. The model instance should be created beforehand using the methods described in the :ref:`supervised_learning-ref` section. +can be used instead. The model instance should be created beforehand using the methods described in the :ref:`supervised_learning-ref` section. + +Fine-Tuning / Transfer Learning +------------------------------- + +Sometimes a previously trained system should be used as the *starting point* for a new training run, rather than fully resumed. A typical example is +refining a policy with new environment data or a modified reward function. In this case restoring the full checkpoint is undesirable, because it would +also restore the stale replay buffer, the optimizer momenta, the learning-rate schedule position and the training step counters of the previous run. + +For this purpose, only the network weights can be restored from a checkpoint: + +.. tabs:: + + .. code-tab:: fortran + + istat = torchfort_rl_off_policy_load_model(system_name, directory_name) + + .. 
code-tab:: c++ + + istat = torchfort_rl_off_policy_load_model(system_name, directory_name); + +This loads only the weights of the online policy and critic networks from the checkpoint directory created by ``torchfort_rl_off_policy_save_checkpoint``. +The optimizers, learning-rate schedulers, replay buffer, normalizer statistics and step counters remain in their freshly created state, and the target +networks are re-initialized from the loaded online networks. Training then proceeds from the pretrained weights with a clean training history, so that newly +collected transitions (e.g. generated under the modified reward function) are not mixed with stale experience. The system must be created beforehand with +``torchfort_rl_off_policy_create_system`` using a network architecture matching the saved checkpoint. + +On-policy systems provide the equivalent function ``torchfort_rl_on_policy_load_model``, which restores only the actor-critic network weights and leaves the +optimizer, learning-rate scheduler, rollout buffer, normalizer statistics and step counters in their freshly created state. diff --git a/src/csrc/include/internal/model_pack.h b/src/csrc/include/internal/model_pack.h index f09b331e..db74068b 100644 --- a/src/csrc/include/internal/model_pack.h +++ b/src/csrc/include/internal/model_pack.h @@ -23,6 +23,7 @@ #include "internal/base_loss.h" #include "internal/base_lr_scheduler.h" #include "internal/distributed.h" +#include "internal/exceptions.h" #include "internal/model_state.h" #include "internal/model_wrapper.h" #ifdef ENABLE_GPU @@ -51,4 +52,16 @@ struct ModelPack { void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true); void load_model_pack(ModelPack& model_pack, const std::string& fname, bool load_optimizer = true); +// Re-point an optimizer at the given parameter tensors, e.g. 
after (re)loading model weights (for JIT +// models a load replaces the underlying tensors, so the optimizer's references have to be refreshed). +// TorchFort constructs optimizers with a single parameter group, so we update group 0 directly; this +// avoids the deprecated torch::optim::Optimizer::parameters() accessor. +inline void reset_optimizer_parameters(const std::shared_ptr& optimizer, + const std::vector& parameters) { + if (optimizer->param_groups().size() != 1) { + THROW_NOT_SUPPORTED("reset_optimizer_parameters expects an optimizer with a single parameter group."); + } + optimizer->param_groups()[0].params() = parameters; +} + } // namespace torchfort diff --git a/src/csrc/include/internal/rl/off_policy.h b/src/csrc/include/internal/rl/off_policy.h index d40fd3a9..a05d1584 100644 --- a/src/csrc/include/internal/rl/off_policy.h +++ b/src/csrc/include/internal/rl/off_policy.h @@ -67,6 +67,9 @@ class RLOffPolicySystem { virtual void initSystemComm(MPI_Comm mpi_comm) = 0; virtual void saveCheckpoint(const std::string& checkpoint_dir) const = 0; virtual void loadCheckpoint(const std::string& checkpoint_dir) = 0; + // load only the network weights (e.g. 
for fine-tuning), leaving optimizers, + // LR schedulers, replay buffer and step counters in their freshly created state + virtual void loadModel(const std::string& checkpoint_dir) = 0; virtual torch::Device modelDevice() const = 0; virtual torch::Device rbDevice() const = 0; virtual int getRank() const = 0; diff --git a/src/csrc/include/internal/rl/off_policy/ddpg.h b/src/csrc/include/internal/rl/off_policy/ddpg.h index 4c0190a4..98944209 100644 --- a/src/csrc/include/internal/rl/off_policy/ddpg.h +++ b/src/csrc/include/internal/rl/off_policy/ddpg.h @@ -271,6 +271,7 @@ class DDPGSystem : public RLOffPolicySystem, public std::enable_shared_from_this // saving and loading void saveCheckpoint(const std::string& checkpoint_dir) const; void loadCheckpoint(const std::string& checkpoint_dir); + void loadModel(const std::string& checkpoint_dir); // info printing void printInfo() const; diff --git a/src/csrc/include/internal/rl/off_policy/sac.h b/src/csrc/include/internal/rl/off_policy/sac.h index 677059f2..cc821709 100644 --- a/src/csrc/include/internal/rl/off_policy/sac.h +++ b/src/csrc/include/internal/rl/off_policy/sac.h @@ -401,6 +401,7 @@ class SACSystem : public RLOffPolicySystem, public std::enable_shared_from_this< // saving and loading void saveCheckpoint(const std::string& checkpoint_dir) const; void loadCheckpoint(const std::string& checkpoint_dir); + void loadModel(const std::string& checkpoint_dir); // info printing void printInfo() const; diff --git a/src/csrc/include/internal/rl/off_policy/td3.h b/src/csrc/include/internal/rl/off_policy/td3.h index 1d267028..35b4c8f1 100644 --- a/src/csrc/include/internal/rl/off_policy/td3.h +++ b/src/csrc/include/internal/rl/off_policy/td3.h @@ -304,6 +304,7 @@ class TD3System : public RLOffPolicySystem, public std::enable_shared_from_this< // saving and loading void saveCheckpoint(const std::string& checkpoint_dir) const; void loadCheckpoint(const std::string& checkpoint_dir); + void loadModel(const std::string& 
checkpoint_dir); // info printing void printInfo() const; diff --git a/src/csrc/include/internal/rl/on_policy.h b/src/csrc/include/internal/rl/on_policy.h index 65e0f496..e63a94d2 100644 --- a/src/csrc/include/internal/rl/on_policy.h +++ b/src/csrc/include/internal/rl/on_policy.h @@ -69,6 +69,9 @@ class RLOnPolicySystem { virtual void initSystemComm(MPI_Comm mpi_comm) = 0; virtual void saveCheckpoint(const std::string& checkpoint_dir) const = 0; virtual void loadCheckpoint(const std::string& checkpoint_dir) = 0; + // load only the network weights (e.g. for fine-tuning), leaving optimizers, + // LR schedulers, rollout buffer and step counters in their freshly created state + virtual void loadModel(const std::string& checkpoint_dir) = 0; virtual torch::Device modelDevice() const = 0; virtual torch::Device rbDevice() const = 0; diff --git a/src/csrc/include/internal/rl/on_policy/ppo.h b/src/csrc/include/internal/rl/on_policy/ppo.h index ea74992e..48c7d0a8 100644 --- a/src/csrc/include/internal/rl/on_policy/ppo.h +++ b/src/csrc/include/internal/rl/on_policy/ppo.h @@ -258,6 +258,7 @@ class PPOSystem : public RLOnPolicySystem, public std::enable_shared_from_thisparameters() = model_pack.model->parameters(); + reset_optimizer_parameters(model_pack.optimizer, model_pack.model->parameters()); } if (load_optimizer) { diff --git a/src/csrc/rl/off_policy/ddpg.cpp b/src/csrc/rl/off_policy/ddpg.cpp index c6aeb9ca..3ec7fb19 100644 --- a/src/csrc/rl/off_policy/ddpg.cpp +++ b/src/csrc/rl/off_policy/ddpg.cpp @@ -409,6 +409,36 @@ void DDPGSystem::loadCheckpoint(const std::string& checkpoint_dir) { } } +// loading only the network weights (e.g. for fine-tuning / transfer learning): +// this restores the online policy and critic weights from a checkpoint directory, but +// leaves optimizers, LR schedulers, replay buffer, normalizers and step counters untouched. 
+void DDPGSystem::loadModel(const std::string& checkpoint_dir) { + using namespace torchfort; + std::filesystem::path root_dir(checkpoint_dir); + + // load only the model weights (model.pt) of a model pack, reconnecting the optimizer + // to the (possibly newly allocated) model parameters afterwards. + auto load_weights = [](auto& model_pack, const std::filesystem::path& dir) { + auto model_path = dir / "model.pt"; + if (!std::filesystem::exists(model_path)) { + THROW_INVALID_USAGE("Could not find " + model_path.native() + "."); + } + model_pack.model->load(model_path.native()); + if (model_pack.optimizer) { + reset_optimizer_parameters(model_pack.optimizer, model_pack.model->parameters()); + } + }; + + // online policy and critic + load_weights(p_model_, root_dir / "policy"); + load_weights(q_model_, root_dir / "critic"); + + // initialize the target networks from the freshly loaded online networks; the saved + // target weights are an artifact of the previous run and are intentionally ignored. + copy_parameters(p_model_target_.model, p_model_.model); + copy_parameters(q_model_target_.model, q_model_.model); +} + // we should pass a tuple (s, a, s', r, d) void DDPGSystem::updateReplayBuffer(torch::Tensor s, torch::Tensor a, torch::Tensor sp, torch::Tensor r, torch::Tensor d) { diff --git a/src/csrc/rl/off_policy/interface.cpp b/src/csrc/rl/off_policy/interface.cpp index 967dee58..010e2442 100644 --- a/src/csrc/rl/off_policy/interface.cpp +++ b/src/csrc/rl/off_policy/interface.cpp @@ -159,6 +159,19 @@ torchfort_result_t torchfort_rl_off_policy_load_checkpoint(const char* name, con return TORCHFORT_RESULT_SUCCESS; } +// load network weights only (e.g. 
for fine-tuning)
+torchfort_result_t torchfort_rl_off_policy_load_model(const char* name, const char* checkpoint_dir) {
+  using namespace torchfort;
+  // delegate to loadModel() of the system stored under `name`; a BaseException is printed and mapped to its code
+  try {
+    rl::off_policy::registry[name]->loadModel(checkpoint_dir);
+  } catch (const BaseException& e) {
+    std::cerr << e.what();
+    return e.getResult();
+  }
+  return TORCHFORT_RESULT_SUCCESS;
+}
+
 // ready check
 torchfort_result_t torchfort_rl_off_policy_is_ready(const char* name, bool& ready) {
   using namespace torchfort;
diff --git a/src/csrc/rl/off_policy/sac.cpp b/src/csrc/rl/off_policy/sac.cpp
index 99e7ac19..fb8deaa0 100644
--- a/src/csrc/rl/off_policy/sac.cpp
+++ b/src/csrc/rl/off_policy/sac.cpp
@@ -484,7 +484,7 @@ void SACSystem::loadCheckpoint(const std::string& checkpoint_dir) {
     p_model_.model->load(model_path.native());
 
     // connect model and optimizer parameters:
-    p_model_.optimizer->parameters() = p_model_.model->parameters();
+    reset_optimizer_parameters(p_model_.optimizer, p_model_.model->parameters());
 
     auto optimizer_path = root_dir / "policy" / "optimizer.pt";
     if (!std::filesystem::exists(optimizer_path)) {
@@ -530,7 +530,7 @@ void SACSystem::loadCheckpoint(const std::string& checkpoint_dir) {
     alpha_model_->to(model_device_);
 
     // connect model and optimizer parameters:
-    alpha_optimizer_->parameters() = alpha_model_->parameters();
+    reset_optimizer_parameters(alpha_optimizer_, alpha_model_->parameters());
 
     auto optimizer_path = root_dir / "alpha" / "optimizer.pt";
     if (!std::filesystem::exists(optimizer_path)) {
@@ -590,6 +590,41 @@ void SACSystem::loadCheckpoint(const std::string& checkpoint_dir) {
   }
 }
 
+// loading only the network weights (e.g. for fine-tuning / transfer learning):
+// this restores the online policy and critic weights from a checkpoint directory, but
+// leaves optimizers, LR schedulers, the entropy temperature (alpha), replay buffer,
+// normalizers and step counters untouched.
+void SACSystem::loadModel(const std::string& checkpoint_dir) { + using namespace torchfort; + std::filesystem::path root_dir(checkpoint_dir); + + // load only the model weights (model.pt) of a model pack, reconnecting the optimizer + // to the (possibly newly allocated) model parameters afterwards. + auto load_weights = [](auto& model_pack, const std::filesystem::path& dir) { + auto model_path = dir / "model.pt"; + if (!std::filesystem::exists(model_path)) { + THROW_INVALID_USAGE("Could not find " + model_path.native() + "."); + } + model_pack.model->load(model_path.native()); + if (model_pack.optimizer) { + reset_optimizer_parameters(model_pack.optimizer, model_pack.model->parameters()); + } + }; + + // online policy and critics + load_weights(p_model_, root_dir / "policy"); + for (int i = 0; i < q_models_.size(); ++i) { + load_weights(q_models_[i], root_dir / ("critic_" + std::to_string(i))); + } + + // initialize the critic target networks from the freshly loaded online critics; the saved + // target weights are an artifact of the previous run and are intentionally ignored. + // SAC has no policy target network. + for (int i = 0; i < q_models_target_.size(); ++i) { + copy_parameters(q_models_target_[i].model, q_models_[i].model); + } +} + // we should pass a tuple (s, a, s', r, d) void SACSystem::updateReplayBuffer(torch::Tensor s, torch::Tensor a, torch::Tensor sp, torch::Tensor r, torch::Tensor d) { diff --git a/src/csrc/rl/off_policy/td3.cpp b/src/csrc/rl/off_policy/td3.cpp index 8ad9bee9..5afbe83d 100644 --- a/src/csrc/rl/off_policy/td3.cpp +++ b/src/csrc/rl/off_policy/td3.cpp @@ -481,6 +481,40 @@ void TD3System::loadCheckpoint(const std::string& checkpoint_dir) { } } +// loading only the network weights (e.g. for fine-tuning / transfer learning): +// this restores the online policy and critic weights from a checkpoint directory, but +// leaves optimizers, LR schedulers, replay buffer, normalizers and step counters untouched. 
+void TD3System::loadModel(const std::string& checkpoint_dir) { + using namespace torchfort; + std::filesystem::path root_dir(checkpoint_dir); + + // load only the model weights (model.pt) of a model pack, reconnecting the optimizer + // to the (possibly newly allocated) model parameters afterwards. + auto load_weights = [](auto& model_pack, const std::filesystem::path& dir) { + auto model_path = dir / "model.pt"; + if (!std::filesystem::exists(model_path)) { + THROW_INVALID_USAGE("Could not find " + model_path.native() + "."); + } + model_pack.model->load(model_path.native()); + if (model_pack.optimizer) { + reset_optimizer_parameters(model_pack.optimizer, model_pack.model->parameters()); + } + }; + + // online policy and critics + load_weights(p_model_, root_dir / "policy"); + for (int i = 0; i < q_models_.size(); ++i) { + load_weights(q_models_[i], root_dir / ("critic_" + std::to_string(i))); + } + + // initialize the target networks from the freshly loaded online networks; the saved + // target weights are an artifact of the previous run and are intentionally ignored. + copy_parameters(p_model_target_.model, p_model_.model); + for (int i = 0; i < q_models_target_.size(); ++i) { + copy_parameters(q_models_target_[i].model, q_models_[i].model); + } +} + // we should pass a tuple (s, a, s', r, d) void TD3System::updateReplayBuffer(torch::Tensor s, torch::Tensor a, torch::Tensor sp, torch::Tensor r, torch::Tensor d) { diff --git a/src/csrc/rl/on_policy/interface.cpp b/src/csrc/rl/on_policy/interface.cpp index 8e7daa11..7a62c18e 100644 --- a/src/csrc/rl/on_policy/interface.cpp +++ b/src/csrc/rl/on_policy/interface.cpp @@ -148,6 +148,19 @@ torchfort_result_t torchfort_rl_on_policy_load_checkpoint(const char* name, cons return TORCHFORT_RESULT_SUCCESS; } +// load network weights only (e.g. 
for fine-tuning)
+torchfort_result_t torchfort_rl_on_policy_load_model(const char* name, const char* checkpoint_dir) {
+  using namespace torchfort;
+  // delegate to loadModel() of the system stored under `name`; a BaseException is printed and mapped to its code
+  try {
+    rl::on_policy::registry[name]->loadModel(checkpoint_dir);
+  } catch (const BaseException& e) {
+    std::cerr << e.what();
+    return e.getResult();
+  }
+  return TORCHFORT_RESULT_SUCCESS;
+}
+
 // ready check
 torchfort_result_t torchfort_rl_on_policy_is_ready(const char* name, bool& ready) {
   using namespace torchfort;
diff --git a/src/csrc/rl/on_policy/ppo.cpp b/src/csrc/rl/on_policy/ppo.cpp
index 28ea0499..75b6b2e9 100644
--- a/src/csrc/rl/on_policy/ppo.cpp
+++ b/src/csrc/rl/on_policy/ppo.cpp
@@ -19,6 +19,7 @@
 #include
 
 #include "internal/exceptions.h"
+#include "internal/model_pack.h"
 #include "internal/rl/distributions.h"
 #include "internal/rl/on_policy/ppo.h"
 
@@ -293,7 +294,7 @@ void PPOSystem::loadCheckpoint(const std::string& checkpoint_dir) {
     pq_model_.model->load(model_path.native());
 
     // connect model and optimizer parameters:
-    pq_model_.optimizer->parameters() = pq_model_.model->parameters();
+    reset_optimizer_parameters(pq_model_.optimizer, pq_model_.model->parameters());
 
     auto optimizer_path = root_dir / "actor_critic" / "optimizer.pt";
     if (!std::filesystem::exists(optimizer_path)) {
@@ -357,6 +358,25 @@ void PPOSystem::loadCheckpoint(const std::string& checkpoint_dir) {
   }
 }
 
+// loading only the network weights (e.g. for fine-tuning / transfer learning):
+// this restores the actor-critic network weights from a checkpoint directory, but leaves the optimizer,
+// LR scheduler, rollout buffer, normalizers and step counters in their freshly created state.
+void PPOSystem::loadModel(const std::string& checkpoint_dir) { + using namespace torchfort; + std::filesystem::path root_dir(checkpoint_dir); + + auto model_path = root_dir / "actor_critic" / "model.pt"; + if (!std::filesystem::exists(model_path)) { + THROW_INVALID_USAGE("Could not find " + model_path.native() + "."); + } + pq_model_.model->load(model_path.native()); + + // reconnect the optimizer to the (possibly newly allocated) model parameters + if (pq_model_.optimizer) { + reset_optimizer_parameters(pq_model_.optimizer, pq_model_.model->parameters()); + } +} + // convenience function for n_envs=1: if this is not the case, the error will be captured // in the replay buffer update function, so no need to check it here void PPOSystem::updateRolloutBuffer(torch::Tensor stens, torch::Tensor atens, float r, bool d) { diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp index 324738ae..5dd437e0 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -332,7 +332,7 @@ torchfort_result_t torchfort_load_model(const char* name, const char* fname) { } models[name].model->load(fname); if (models[name].optimizer) { - models[name].optimizer->parameters() = models[name].model->parameters(); + reset_optimizer_parameters(models[name].optimizer, models[name].model->parameters()); } } catch (const BaseException& e) { std::cerr << e.what(); diff --git a/src/fsrc/torchfort_m.F90 b/src/fsrc/torchfort_m.F90 index a59d3f29..f932d1c1 100644 --- a/src/fsrc/torchfort_m.F90 +++ b/src/fsrc/torchfort_m.F90 @@ -280,6 +280,14 @@ function torchfort_rl_off_policy_load_checkpoint_c(mname, checkpoint_dir) result integer(c_int) :: res end function torchfort_rl_off_policy_load_checkpoint_c + function torchfort_rl_off_policy_load_model_c(mname, checkpoint_dir) result(res) & + bind(C, name="torchfort_rl_off_policy_load_model") + import + type(*) :: mname(*) + type(*) :: checkpoint_dir(*) + integer(c_int) :: res + end function torchfort_rl_off_policy_load_model_c + ! 
training
     function torchfort_rl_off_policy_update_replay_buffer_c(mname, &
          state_old, state_new, state_dim, state_shape, &
@@ -450,6 +458,14 @@ function torchfort_rl_on_policy_load_checkpoint_c(mname, checkpoint_dir) result(
       integer(c_int) :: res
     end function torchfort_rl_on_policy_load_checkpoint_c
 
+    function torchfort_rl_on_policy_load_model_c(mname, checkpoint_dir) result(res) &
+      bind(C, name="torchfort_rl_on_policy_load_model")
+      import
+      type(*) :: mname(*)            ! system name, handed over as a NUL-terminated C string
+      type(*) :: checkpoint_dir(*)   ! checkpoint directory, handed over as a NUL-terminated C string
+      integer(c_int) :: res          ! result code returned by the C routine
+    end function torchfort_rl_on_policy_load_model_c
+
     ! training
     function torchfort_rl_on_policy_update_rollout_buffer_c(mname, &
          state, state_dim, state_shape, &
@@ -7394,6 +7410,14 @@ function torchfort_rl_off_policy_load_checkpoint(mname, checkpoint_dir) result(r
          [trim(checkpoint_dir) // C_NULL_CHAR])
   end function torchfort_rl_off_policy_load_checkpoint
 
+  function torchfort_rl_off_policy_load_model(mname, checkpoint_dir) result(res)
+    character(len=*) :: mname            ! name of the off-policy system instance
+    character(len=*) :: checkpoint_dir   ! directory that contains the checkpoint data to load
+    integer(c_int) :: res                ! result code returned by the C routine
+    res = torchfort_rl_off_policy_load_model_c([trim(mname) // C_NULL_CHAR], &
+                                               [trim(checkpoint_dir) // C_NULL_CHAR])
+  end function torchfort_rl_off_policy_load_model
+
   ! Training routines
   function torchfort_rl_off_policy_update_replay_buffer_float_1d_1d(mname, state_old, act_old, state_new, &
                                                                     reward, final_state, stream) result(res)
@@ -8715,6 +8739,14 @@ function torchfort_rl_on_policy_load_checkpoint(mname, checkpoint_dir) result(re
          [trim(checkpoint_dir) // C_NULL_CHAR])
   end function torchfort_rl_on_policy_load_checkpoint
 
+  function torchfort_rl_on_policy_load_model(mname, checkpoint_dir) result(res)
+    character(len=*) :: mname            ! name of the on-policy system instance
+    character(len=*) :: checkpoint_dir   ! directory that contains the checkpoint data to load
+    integer(c_int) :: res                ! result code returned by the C routine
+    res = torchfort_rl_on_policy_load_model_c([trim(mname) // C_NULL_CHAR], &
+                                              [trim(checkpoint_dir) // C_NULL_CHAR])
+  end function torchfort_rl_on_policy_load_model
+
   ! Training routines
   !
single env tollout buffer updates
   function torchfort_rl_on_policy_update_rollout_buffer_float_1d_1d(mname, state, act, &
diff --git a/tests/rl/CMakeLists.txt b/tests/rl/CMakeLists.txt
index d9c51fb4..240bbca6 100644
--- a/tests/rl/CMakeLists.txt
+++ b/tests/rl/CMakeLists.txt
@@ -8,6 +8,7 @@ set(test_targets
   test_interface
   test_off_policy
   test_on_policy
+  test_checkpoint_rl
 )
 
 add_executable(test_replay_buffer)
@@ -52,12 +53,19 @@ target_sources(test_on_policy
   test_on_policy.cpp
   )
 
+add_executable(test_checkpoint_rl)
+target_sources(test_checkpoint_rl
+  PRIVATE
+  test_checkpoint.cpp
+  )
+
 find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
 
 foreach(tgt ${test_targets})
   target_include_directories(${tgt}
     PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/../
     ${YAML_CPP_INCLUDE_DIR}
     ${MPI_CXX_INCLUDE_DIRS}
     ${CMAKE_BINARY_DIR}/include
@@ -82,6 +90,13 @@ foreach(tgt ${test_targets})
   # discover tests: we have an issue with the work dir of gtest so disable that for now
   #gtest_discover_tests(${tgt})
   add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+
+  # copy the config files next to the built test binary so the tests can be run directly from the
+  # build tree (the tests resolve their configs relative to the executable location)
+  add_custom_command(TARGET ${tgt} POST_BUILD
+    COMMAND ${CMAKE_COMMAND} -E copy_directory
+    ${CMAKE_CURRENT_SOURCE_DIR}/configs $<TARGET_FILE_DIR:${tgt}>/configs
+    )
 endforeach()
 
 # installation
diff --git a/tests/rl/test_checkpoint.cpp b/tests/rl/test_checkpoint.cpp
new file mode 100644
index 00000000..9d61e9a9
--- /dev/null
+++ b/tests/rl/test_checkpoint.cpp
@@ -0,0 +1,303 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_GPU +#include +#endif + +#include +#include +#include +#include +#include +#include + +#include "environments.h" +#include "internal/defines.h" +#include "torchfort.h" +#include +#include + +#include "test_utils.h" + +namespace { + +// RAII helper to silence TorchFort's (std::cout) logging during the test bodies. This does not hide +// gtest output, which is written via C stdio rather than std::cout. +struct CoutRedirect { + CoutRedirect(std::streambuf* new_buffer) : old(std::cout.rdbuf(new_buffer)) {} + ~CoutRedirect() { std::cout.rdbuf(old); } + +private: + std::streambuf* old; +}; + +// run a deterministic policy prediction for a fixed scalar state and return the scalar action +float predict_fixed(const std::string& name, float state_value) { + std::vector state_shape{1}, action_shape{1}; + std::vector state_batch_shape{1, 1}, action_batch_shape{1, 1}; + torch::Tensor state = torch::full(state_shape, state_value, torch::TensorOptions().dtype(torch::kFloat32)); + torch::Tensor action = torch::zeros(action_shape, torch::TensorOptions().dtype(torch::kFloat32)); + CHECK_TORCHFORT(torchfort_rl_off_policy_predict(name.c_str(), state.data_ptr(), 2, state_batch_shape.data(), + action.data_ptr(), 2, action_batch_shape.data(), TORCHFORT_FLOAT, 0)); + return action.item(); +} + +// fill the replay buffer with random transitions until the system is ready, then perform a few training steps +void train_and_fill(const std::string& name, int num_train_steps) { + std::vector state_shape{1}, action_shape{1}; + torch::Tensor state = 
torch::zeros(state_shape, torch::TensorOptions().dtype(torch::kFloat32)); + torch::Tensor state_new = torch::zeros(state_shape, torch::TensorOptions().dtype(torch::kFloat32)); + torch::Tensor action = torch::zeros(action_shape, torch::TensorOptions().dtype(torch::kFloat32)); + + bool ready = false; + int pushed = 0; + const int max_push = 8192; // safety cap, well above any configured min_size + while (!ready && pushed < max_push) { + { + torch::NoGradGuard no_grad; + state.uniform_(-1., 1.); + state_new.uniform_(-1., 1.); + action.uniform_(-1., 1.); + } + float reward = 0.5f; + bool done = (pushed % 8 == 7); + CHECK_TORCHFORT(torchfort_rl_off_policy_update_replay_buffer( + name.c_str(), state.data_ptr(), state_new.data_ptr(), 1, state_shape.data(), action.data_ptr(), 1, + action_shape.data(), &reward, done, TORCHFORT_FLOAT, 0)); + ++pushed; + CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(name.c_str(), ready)); + } + ASSERT_TRUE(ready); + + float p_loss, q_loss; + for (int i = 0; i < num_train_steps; ++i) { + CHECK_TORCHFORT(torchfort_rl_off_policy_train_step(name.c_str(), &p_loss, &q_loss, 0)); + } +} + +// Round-trip test for both checkpoint restore mechanisms of an off-policy system: +// * torchfort_rl_off_policy_load_checkpoint -> full restore (weights + replay buffer) +// * torchfort_rl_off_policy_load_model -> weights-only restore (e.g. 
for fine-tuning) +void checkpoint_roundtrip(const std::string& system) { + // silence TorchFort's verbose/logging output for the duration of the test + std::stringstream cout_buffer; + CoutRedirect cout_redirect(cout_buffer.rdbuf()); + + torch::manual_seed(666); + + const std::string config = get_config_path(system + ".yaml"); + const std::string src = system + "_src"; + const std::string full = system + "_restore_full"; + const std::string weights = system + "_restore_weights"; + const std::string ckpt_dir = "/tmp/torchfort_rl_ckpt_" + system; + + std::filesystem::remove_all(ckpt_dir); + + // create and train the source system, filling the replay buffer past its min_size + CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(src.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, + TORCHFORT_DEVICE_CPU)); + train_and_fill(src, /* num_train_steps = */ 5); + + // reference deterministic prediction for a fixed evaluation state + const float eval_state = 0.5f; + const float action_ref = predict_fixed(src, eval_state); + + // sanity: the source system has a filled (ready) replay buffer + bool src_ready = false; + CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(src.c_str(), src_ready)); + EXPECT_TRUE(src_ready); + + // save the full checkpoint + CHECK_TORCHFORT(torchfort_rl_off_policy_save_checkpoint(src.c_str(), ckpt_dir.c_str())); + + // ---------------- full checkpoint restore ---------------- + CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(full.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, + TORCHFORT_DEVICE_CPU)); + // a freshly created system has an empty buffer and is not ready + bool full_ready_before = true; + CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(full.c_str(), full_ready_before)); + EXPECT_FALSE(full_ready_before); + + CHECK_TORCHFORT(torchfort_rl_off_policy_load_checkpoint(full.c_str(), ckpt_dir.c_str())); + + // the policy weights are restored -> identical deterministic prediction + const float action_full = predict_fixed(full, eval_state); 
+ EXPECT_NEAR(action_full, action_ref, 1e-5); + + // the replay buffer is restored as well -> the system is ready right away + bool full_ready_after = false; + CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(full.c_str(), full_ready_after)); + EXPECT_TRUE(full_ready_after); + + // ---------------- weights-only restore (load_model) ---------------- + CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(weights.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, + TORCHFORT_DEVICE_CPU)); + // prediction before loading stems from a fresh (independent) initialization and differs from the reference, + // which makes the post-load match below a meaningful check that loading actually happened + const float action_fresh = predict_fixed(weights, eval_state); + EXPECT_GT(std::abs(action_ref - action_fresh), 1e-6); + + CHECK_TORCHFORT(torchfort_rl_off_policy_load_model(weights.c_str(), ckpt_dir.c_str())); + + // the policy weights are restored -> identical deterministic prediction + const float action_weights = predict_fixed(weights, eval_state); + EXPECT_NEAR(action_weights, action_ref, 1e-5); + + // the replay buffer is NOT restored -> the freshly created buffer is still empty and the system is not ready + bool weights_ready = true; + CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(weights.c_str(), weights_ready)); + EXPECT_FALSE(weights_ready); + + // cleanup + std::filesystem::remove_all(ckpt_dir); +} + +// ===================================== on-policy (PPO) ===================================== + +// run a deterministic policy prediction for a fixed scalar state and return the scalar action +float predict_fixed_on_policy(const std::string& name, float state_value) { + std::vector state_shape{1}, action_shape{1}; + std::vector state_batch_shape{1, 1}, action_batch_shape{1, 1}; + torch::Tensor state = torch::full(state_shape, state_value, torch::TensorOptions().dtype(torch::kFloat32)); + torch::Tensor action = torch::zeros(action_shape, 
torch::TensorOptions().dtype(torch::kFloat32)); + CHECK_TORCHFORT(torchfort_rl_on_policy_predict(name.c_str(), state.data_ptr(), 2, state_batch_shape.data(), + action.data_ptr(), 2, action_batch_shape.data(), TORCHFORT_FLOAT, 0)); + return action.item(); +} + +// step a simple environment, pushing transitions into the rollout buffer until the system is ready +// (the rollout buffer becomes ready once it is full and finalized) +void rollout_until_ready(const std::string& name) { + std::vector state_shape{1}, action_shape{1}; + std::vector state_batch_shape{1, 1}, action_batch_shape{1, 1}; + auto env = std::make_shared(1u, state_shape, action_shape); + + torch::Tensor state, state_new, action; + action = torch::zeros(action_shape, torch::TensorOptions().dtype(torch::kFloat32)); + float reward; + bool done; + std::tie(state, reward) = env->initialize(); + + bool ready = false; + int guard = 0; + const int max_iter = 100000; // safety cap, well above any configured rollout buffer size + while (!ready && guard < max_iter) { + CHECK_TORCHFORT(torchfort_rl_on_policy_predict_explore(name.c_str(), state.data_ptr(), 2, state_batch_shape.data(), + action.data_ptr(), 2, action_batch_shape.data(), + TORCHFORT_FLOAT, 0)); + std::tie(state_new, reward, done) = env->step(action); + CHECK_TORCHFORT(torchfort_rl_on_policy_update_rollout_buffer(name.c_str(), state.data_ptr(), 1, state_shape.data(), + action.data_ptr(), 1, action_shape.data(), &reward, + done, TORCHFORT_FLOAT, 0)); + CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(name.c_str(), ready)); + state = state_new; + ++guard; + } + ASSERT_TRUE(ready); +} + +// Round-trip test for both checkpoint restore mechanisms of the on-policy (PPO) system: +// * torchfort_rl_on_policy_load_checkpoint -> full restore (weights + rollout buffer) +// * torchfort_rl_on_policy_load_model -> weights-only restore (e.g. 
for fine-tuning) +void checkpoint_roundtrip_on_policy(const std::string& system) { + // silence TorchFort's verbose/logging output for the duration of the test + std::stringstream cout_buffer; + CoutRedirect cout_redirect(cout_buffer.rdbuf()); + + torch::manual_seed(666); + + const std::string config = get_config_path(system + ".yaml"); + const std::string src = system + "_src"; + const std::string full = system + "_restore_full"; + const std::string weights = system + "_restore_weights"; + const std::string ckpt_dir = "/tmp/torchfort_rl_ckpt_" + system; + + std::filesystem::remove_all(ckpt_dir); + + // create the source system and fill its rollout buffer until it is ready + CHECK_TORCHFORT(torchfort_rl_on_policy_create_system(src.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, + TORCHFORT_DEVICE_CPU)); + rollout_until_ready(src); + + // reference deterministic prediction for a fixed evaluation state + const float eval_state = 0.5f; + const float action_ref = predict_fixed_on_policy(src, eval_state); + + // sanity: the source system has a filled (ready) rollout buffer + bool src_ready = false; + CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(src.c_str(), src_ready)); + EXPECT_TRUE(src_ready); + + // save the full checkpoint + CHECK_TORCHFORT(torchfort_rl_on_policy_save_checkpoint(src.c_str(), ckpt_dir.c_str())); + + // ---------------- full checkpoint restore ---------------- + CHECK_TORCHFORT(torchfort_rl_on_policy_create_system(full.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, + TORCHFORT_DEVICE_CPU)); + // a freshly created system has an empty rollout buffer and is not ready + bool full_ready_before = true; + CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(full.c_str(), full_ready_before)); + EXPECT_FALSE(full_ready_before); + + CHECK_TORCHFORT(torchfort_rl_on_policy_load_checkpoint(full.c_str(), ckpt_dir.c_str())); + + // the network weights are restored -> identical deterministic prediction + const float action_full = predict_fixed_on_policy(full, 
eval_state); + EXPECT_NEAR(action_full, action_ref, 1e-5); + + // the rollout buffer is restored as well -> the system is ready right away + bool full_ready_after = false; + CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(full.c_str(), full_ready_after)); + EXPECT_TRUE(full_ready_after); + + // ---------------- weights-only restore (load_model) ---------------- + CHECK_TORCHFORT(torchfort_rl_on_policy_create_system(weights.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, + TORCHFORT_DEVICE_CPU)); + // prediction before loading stems from a fresh (independent) initialization and differs from the reference + const float action_fresh = predict_fixed_on_policy(weights, eval_state); + EXPECT_GT(std::abs(action_ref - action_fresh), 1e-6); + + CHECK_TORCHFORT(torchfort_rl_on_policy_load_model(weights.c_str(), ckpt_dir.c_str())); + + // the network weights are restored -> identical deterministic prediction + const float action_weights = predict_fixed_on_policy(weights, eval_state); + EXPECT_NEAR(action_weights, action_ref, 1e-5); + + // the rollout buffer is NOT restored -> the freshly created buffer is still empty and the system is not ready + bool weights_ready = true; + CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(weights.c_str(), weights_ready)); + EXPECT_FALSE(weights_ready); + + // cleanup + std::filesystem::remove_all(ckpt_dir); +} + +} // namespace + +TEST(Checkpoint, TD3) { checkpoint_roundtrip("td3"); } + +TEST(Checkpoint, DDPG) { checkpoint_roundtrip("ddpg"); } + +TEST(Checkpoint, SAC) { checkpoint_roundtrip("sac"); } + +TEST(Checkpoint, PPO) { checkpoint_roundtrip_on_policy("ppo"); } + +int main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/tests/rl/test_interface.cpp b/tests/rl/test_interface.cpp index 838f627f..1301646d 100644 --- a/tests/rl/test_interface.cpp +++ b/tests/rl/test_interface.cpp @@ -26,6 +26,8 @@ #include #include +#include "test_utils.h" + struct CoutRedirect { 
CoutRedirect(std::streambuf* new_buffer) : old(std::cout.rdbuf(new_buffer)) {} @@ -38,7 +40,7 @@ struct CoutRedirect { // Function to modify TD3 config and write to temporary file std::string createModifiedTD3Config(int state_size, int action_size) { // Read the original TD3 config - std::string config_path = "configs/td3.yaml"; + std::string config_path = get_config_path("td3.yaml"); YAML::Node config = YAML::LoadFile(config_path); // Modify the policy model layer sizes @@ -78,7 +80,7 @@ std::string createModifiedTD3Config(int state_size, int action_size) { // Function to modify PPO config and write to temporary file std::string createModifiedPPOConfig(int state_size, int action_size) { // Read the original TD3 config - std::string config_path = "configs/ppo.yaml"; + std::string config_path = get_config_path("ppo.yaml"); YAML::Node config = YAML::LoadFile(config_path); // Modify the policy model layer sizes diff --git a/tests/rl/test_off_policy.cpp b/tests/rl/test_off_policy.cpp index 9b1a9b02..ce807409 100644 --- a/tests/rl/test_off_policy.cpp +++ b/tests/rl/test_off_policy.cpp @@ -21,6 +21,7 @@ #include #include "internal/defines.h" +#include "test_utils.h" enum EnvMode { Constant, Predictable, Delayed, Action, ActionState }; @@ -88,7 +89,7 @@ std::tuple TestSystem(const EnvMode mode, const std::string num_episodes = (num_train_iters + num_eval_iters) / episode_length; // set up td3 learning systems - std::string filename = "configs/" + system + ".yaml"; + std::string filename = get_config_path(system + ".yaml"); CHECK_TORCHFORT( torchfort_rl_off_policy_create_system("test", filename.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); diff --git a/tests/rl/test_on_policy.cpp b/tests/rl/test_on_policy.cpp index 86dc4309..3d996f61 100644 --- a/tests/rl/test_on_policy.cpp +++ b/tests/rl/test_on_policy.cpp @@ -21,6 +21,7 @@ #include #include "internal/defines.h" +#include "test_utils.h" enum EnvMode { Constant, Predictable, Delayed, Action, ActionState }; @@ 
-89,7 +90,7 @@ std::tuple TestSystem(const EnvMode mode, const std::string num_episodes = (num_train_iters + num_eval_iters) / episode_length; // set up td3 learning systems - std::string filename = "configs/" + system + ".yaml"; + std::string filename = get_config_path(system + ".yaml"); CHECK_TORCHFORT( torchfort_rl_on_policy_create_system("test", filename.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); diff --git a/tests/test_utils.h b/tests/test_utils.h index 9222a97e..cb87843b 100644 --- a/tests/test_utils.h +++ b/tests/test_utils.h @@ -16,17 +16,52 @@ */ #pragma once +// Keep this header self-contained: it references TORCHFORT_DEVICE_CPU (from torchfort.h) and, in GPU +// builds, CHECK_CUDA (from internal/defines.h). Including them here lets test_utils.h be included in +// any order relative to those headers. +#include "torchfort.h" +#ifdef ENABLE_GPU +#include "internal/defines.h" +#endif + #include +#include +#include #include #include #include #include +#include #include #ifdef ENABLE_GPU #include #endif +// Resolve the path to a test config file independently of the current working directory. +// Resolution order: +// 1. $TORCHFORT_TEST_CONFIG_DIR/<name> if that environment variable is set, +// 2. <directory of the running test executable>/configs/<name>, +// 3. configs/<name> (legacy behavior, relative to the current working directory). +// This keeps the tests runnable from any directory, in both the build tree and the install tree. 
+inline std::string get_config_path(const std::string& name) { + namespace fs = std::filesystem; + + if (const char* env_dir = std::getenv("TORCHFORT_TEST_CONFIG_DIR")) { + return (fs::path(env_dir) / name).string(); + } + +#if defined(__linux__) + std::error_code ec; + fs::path exe = fs::read_symlink("/proc/self/exe", ec); + if (!ec) { + return (exe.parent_path() / "configs" / name).string(); + } +#endif + + return (fs::path("configs") / name).string(); +} + // Generate random vector data for testing template std::vector generate_random(const std::vector& shape) { From 77cc5a30297ee339b849945e7d95b170e67f08a9 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 3 Jun 2026 02:48:26 -0700 Subject: [PATCH 2/6] fixing formatting Signed-off-by: Thorsten Kurth --- tests/rl/test_checkpoint.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/rl/test_checkpoint.cpp b/tests/rl/test_checkpoint.cpp index 9d61e9a9..c456d55f 100644 --- a/tests/rl/test_checkpoint.cpp +++ b/tests/rl/test_checkpoint.cpp @@ -108,8 +108,8 @@ void checkpoint_roundtrip(const std::string& system) { std::filesystem::remove_all(ckpt_dir); // create and train the source system, filling the replay buffer past its min_size - CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(src.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, - TORCHFORT_DEVICE_CPU)); + CHECK_TORCHFORT( + torchfort_rl_off_policy_create_system(src.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); train_and_fill(src, /* num_train_steps = */ 5); // reference deterministic prediction for a fixed evaluation state @@ -125,8 +125,8 @@ void checkpoint_roundtrip(const std::string& system) { CHECK_TORCHFORT(torchfort_rl_off_policy_save_checkpoint(src.c_str(), ckpt_dir.c_str())); // ---------------- full checkpoint restore ---------------- - CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(full.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, - TORCHFORT_DEVICE_CPU)); + 
CHECK_TORCHFORT( + torchfort_rl_off_policy_create_system(full.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); // a freshly created system has an empty buffer and is not ready bool full_ready_before = true; CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(full.c_str(), full_ready_before)); @@ -229,8 +229,8 @@ void checkpoint_roundtrip_on_policy(const std::string& system) { std::filesystem::remove_all(ckpt_dir); // create the source system and fill its rollout buffer until it is ready - CHECK_TORCHFORT(torchfort_rl_on_policy_create_system(src.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, - TORCHFORT_DEVICE_CPU)); + CHECK_TORCHFORT( + torchfort_rl_on_policy_create_system(src.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); rollout_until_ready(src); // reference deterministic prediction for a fixed evaluation state @@ -246,8 +246,8 @@ void checkpoint_roundtrip_on_policy(const std::string& system) { CHECK_TORCHFORT(torchfort_rl_on_policy_save_checkpoint(src.c_str(), ckpt_dir.c_str())); // ---------------- full checkpoint restore ---------------- - CHECK_TORCHFORT(torchfort_rl_on_policy_create_system(full.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, - TORCHFORT_DEVICE_CPU)); + CHECK_TORCHFORT( + torchfort_rl_on_policy_create_system(full.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); // a freshly created system has an empty rollout buffer and is not ready bool full_ready_before = true; CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(full.c_str(), full_ready_before)); From bab4c1bb04fef6c056662178c3ef64cd15288c14 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 3 Jun 2026 02:55:14 -0700 Subject: [PATCH 3/6] adding clarification that target networks are reloaded as well when using load model Signed-off-by: Thorsten Kurth --- docs/api/f_api.rst | 9 +++++---- docs/usage.rst | 5 +++-- src/csrc/include/torchfort_rl.h | 9 +++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git 
a/docs/api/f_api.rst b/docs/api/f_api.rst index 01b04445..1c6a0135 100644 --- a/docs/api/f_api.rst +++ b/docs/api/f_api.rst @@ -554,10 +554,11 @@ torchfort_rl_off_policy_load_model Restores only the network weights of a reinforcement learning system from a checkpoint. In contrast to :code:`torchfort_rl_off_policy_load_checkpoint`, this method only restores the weights of the online policy and - critic networks. The optimizers, LR schedulers, replay buffer, normalizer statistics and step counters are left in their freshly - created state, and the target networks are re-initialized from the loaded online networks. This is intended for fine-tuning or - transfer-learning workflows, where a pretrained model is used as the starting point for a new training run (e.g. with a modified - reward function or new environment data). The checkpoint is the one produced by :code:`torchfort_rl_off_policy_save_checkpoint`. + critic networks. For algorithms that use target networks (e.g. DDPG and TD3), the corresponding target networks are also restored, + by re-initializing them from the loaded online networks so that they start consistent with the restored weights. The optimizers, + LR schedulers, replay buffer, normalizer statistics and step counters are left in their freshly created state. This is intended for + fine-tuning or transfer-learning workflows, where a pretrained model is used as the starting point for a new training run (e.g. with + a modified reward function or new environment data). The checkpoint is the one produced by :code:`torchfort_rl_off_policy_save_checkpoint`. :p character(:) name [in]: The name of system instance to use, as defined during system creation. :p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load. 
diff --git a/docs/usage.rst b/docs/usage.rst index 9c7562b3..8e7f2222 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -373,8 +373,9 @@ For this purpose, only the network weights can be restored from a checkpoint: istat = torchfort_rl_off_policy_load_model(system_name, directory_name); This loads only the weights of the online policy and critic networks from the checkpoint directory created by ``torchfort_rl_off_policy_save_checkpoint``. -The optimizers, learning-rate schedulers, replay buffer, normalizer statistics and step counters remain in their freshly created state, and the target -networks are re-initialized from the loaded online networks. Training then proceeds from the pretrained weights with a clean training history, so that newly +For algorithms that use target networks (e.g. DDPG and TD3), the corresponding target networks are also restored, by re-initializing them from the loaded +online networks so that they start consistent with the restored weights. The optimizers, learning-rate schedulers, replay buffer, normalizer statistics and +step counters remain in their freshly created state. Training then proceeds from the pretrained weights with a clean training history, so that newly collected transitions (e.g. generated under the modified reward function) are not mixed with stale experience. The system must be created beforehand with ``torchfort_rl_off_policy_create_system`` using a network architecture matching the saved checkpoint. diff --git a/src/csrc/include/torchfort_rl.h b/src/csrc/include/torchfort_rl.h index e49da90b..f50ca55c 100644 --- a/src/csrc/include/torchfort_rl.h +++ b/src/csrc/include/torchfort_rl.h @@ -342,11 +342,12 @@ torchfort_result_t torchfort_rl_off_policy_load_checkpoint(const char* name, con /** * @brief Restores only the network weights of a reinforcement learning system from a checkpoint. 
* @details In contrast to \p torchfort_rl_off_policy_load_checkpoint, this method only restores the weights of the - * online policy and critic networks from the checkpoint produced by \p torchfort_rl_off_policy_save_checkpoint. The + * online policy and critic networks from the checkpoint produced by \p torchfort_rl_off_policy_save_checkpoint. For + * algorithms that use target networks (e.g. DDPG and TD3), the corresponding target networks are also restored, by + * re-initializing them from the loaded online networks so that they start consistent with the restored weights. The * optimizers, LR schedulers, replay buffer, normalizer statistics and step counters are left in their freshly created - * state, and the target networks are re-initialized from the loaded online networks. This is intended for fine-tuning - * or transfer-learning workflows, where a pretrained model is used as the starting point for a new training run (e.g. - * with a modified reward function or new environment data). + * state. This is intended for fine-tuning or transfer-learning workflows, where a pretrained model is used as the + * starting point for a new training run (e.g. with a modified reward function or new environment data). * * @param[in] name The name of a system instance to restore the weights for, as defined during system creation. * @param[in] checkpoint_dir A filesystem path to a directory which contains the checkpoint data to load. 
From 0554d8304056570de78f4c2906e78c749b9f1f6a Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 3 Jun 2026 23:24:25 -0700 Subject: [PATCH 4/6] adding device awareness in save restore tests, fixing some relative directory path error in test Signed-off-by: Thorsten Kurth --- src/csrc/rl/off_policy/sac.cpp | 4 +- src/csrc/rl/on_policy/ppo.cpp | 2 +- tests/rl/test_checkpoint.cpp | 109 +++++++++++++++++++++++++-------- tests/rl/test_interface.cpp | 14 ++--- 4 files changed, 94 insertions(+), 35 deletions(-) diff --git a/src/csrc/rl/off_policy/sac.cpp b/src/csrc/rl/off_policy/sac.cpp index fb8deaa0..a43ba5f4 100644 --- a/src/csrc/rl/off_policy/sac.cpp +++ b/src/csrc/rl/off_policy/sac.cpp @@ -490,7 +490,7 @@ void SACSystem::loadCheckpoint(const std::string& checkpoint_dir) { if (!std::filesystem::exists(optimizer_path)) { THROW_INVALID_USAGE("Could not find " + optimizer_path.native() + "."); } - torch::load(*(p_model_.optimizer), optimizer_path.native()); + torch::load(*(p_model_.optimizer), optimizer_path.native(), p_model_.model->device()); auto lr_path = root_dir / "policy" / "lr.pt"; if (!std::filesystem::exists(lr_path)) { @@ -536,7 +536,7 @@ void SACSystem::loadCheckpoint(const std::string& checkpoint_dir) { if (!std::filesystem::exists(optimizer_path)) { THROW_INVALID_USAGE("Could not find " + optimizer_path.native() + "."); } - torch::load(*(alpha_optimizer_), optimizer_path.native()); + torch::load(*(alpha_optimizer_), optimizer_path.native(), model_device_); if (alpha_lr_scheduler_) { auto lr_path = root_dir / "alpha" / "lr.pt"; diff --git a/src/csrc/rl/on_policy/ppo.cpp b/src/csrc/rl/on_policy/ppo.cpp index 75b6b2e9..0e56e4c5 100644 --- a/src/csrc/rl/on_policy/ppo.cpp +++ b/src/csrc/rl/on_policy/ppo.cpp @@ -300,7 +300,7 @@ void PPOSystem::loadCheckpoint(const std::string& checkpoint_dir) { if (!std::filesystem::exists(optimizer_path)) { THROW_INVALID_USAGE("Could not find " + optimizer_path.native() + "."); } - torch::load(*(pq_model_.optimizer), 
optimizer_path.native()); + torch::load(*(pq_model_.optimizer), optimizer_path.native(), pq_model_.model->device()); auto lr_path = root_dir / "actor_critic" / "lr.pt"; if (!std::filesystem::exists(lr_path)) { diff --git a/tests/rl/test_checkpoint.cpp b/tests/rl/test_checkpoint.cpp index c456d55f..7e36ee51 100644 --- a/tests/rl/test_checkpoint.cpp +++ b/tests/rl/test_checkpoint.cpp @@ -92,7 +92,12 @@ void train_and_fill(const std::string& name, int num_train_steps) { // Round-trip test for both checkpoint restore mechanisms of an off-policy system: // * torchfort_rl_off_policy_load_checkpoint -> full restore (weights + replay buffer) // * torchfort_rl_off_policy_load_model -> weights-only restore (e.g. for fine-tuning) -void checkpoint_roundtrip(const std::string& system) { +std::string device_suffix(int device) { + return device == TORCHFORT_DEVICE_CPU ? "cpu" : "gpu" + std::to_string(device); +} + +void checkpoint_roundtrip(const std::string& system, int src_device = TORCHFORT_DEVICE_CPU, + int restore_device = TORCHFORT_DEVICE_CPU) { // silence TorchFort's verbose/logging output for the duration of the test std::stringstream cout_buffer; CoutRedirect cout_redirect(cout_buffer.rdbuf()); @@ -100,16 +105,16 @@ void checkpoint_roundtrip(const std::string& system) { torch::manual_seed(666); const std::string config = get_config_path(system + ".yaml"); - const std::string src = system + "_src"; - const std::string full = system + "_restore_full"; - const std::string weights = system + "_restore_weights"; - const std::string ckpt_dir = "/tmp/torchfort_rl_ckpt_" + system; + const std::string dev_tag = device_suffix(src_device) + "_to_" + device_suffix(restore_device); + const std::string src = system + "_" + dev_tag + "_src"; + const std::string full = system + "_" + dev_tag + "_restore_full"; + const std::string weights = system + "_" + dev_tag + "_restore_weights"; + const std::string ckpt_dir = "/tmp/torchfort_rl_ckpt_" + system + "_" + dev_tag; 
std::filesystem::remove_all(ckpt_dir); // create and train the source system, filling the replay buffer past its min_size - CHECK_TORCHFORT( - torchfort_rl_off_policy_create_system(src.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); + CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(src.c_str(), config.c_str(), src_device, src_device)); train_and_fill(src, /* num_train_steps = */ 5); // reference deterministic prediction for a fixed evaluation state @@ -126,7 +131,7 @@ void checkpoint_roundtrip(const std::string& system) { // ---------------- full checkpoint restore ---------------- CHECK_TORCHFORT( - torchfort_rl_off_policy_create_system(full.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); + torchfort_rl_off_policy_create_system(full.c_str(), config.c_str(), restore_device, restore_device)); // a freshly created system has an empty buffer and is not ready bool full_ready_before = true; CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(full.c_str(), full_ready_before)); @@ -143,9 +148,21 @@ void checkpoint_roundtrip(const std::string& system) { CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(full.c_str(), full_ready_after)); EXPECT_TRUE(full_ready_after); + // training is functional after full checkpoint restore: the policy must change after a few steps. + // TD3 applies policy gradients every policy_lag (default: 2) critic steps; two steps guarantee + // at least one policy update regardless of the internal step phase. 
+ { + float p_loss, q_loss; + for (int i = 0; i < 2; ++i) { + CHECK_TORCHFORT(torchfort_rl_off_policy_train_step(full.c_str(), &p_loss, &q_loss, 0)); + } + } + const float action_full_trained = predict_fixed(full, eval_state); + EXPECT_GT(std::abs(action_full_trained - action_full), 1e-7f); + // ---------------- weights-only restore (load_model) ---------------- - CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(weights.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, - TORCHFORT_DEVICE_CPU)); + CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(weights.c_str(), config.c_str(), restore_device, + restore_device)); // prediction before loading stems from a fresh (independent) initialization and differs from the reference, // which makes the post-load match below a meaningful check that loading actually happened const float action_fresh = predict_fixed(weights, eval_state); @@ -162,6 +179,18 @@ void checkpoint_roundtrip(const std::string& system) { CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(weights.c_str(), weights_ready)); EXPECT_FALSE(weights_ready); + // fill the buffer so training is possible, then verify the policy changes after enough steps + // (two steps guarantee at least one policy update given TD3's policy_lag=2) + train_and_fill(weights, 0); + { + float p_loss, q_loss; + for (int i = 0; i < 2; ++i) { + CHECK_TORCHFORT(torchfort_rl_off_policy_train_step(weights.c_str(), &p_loss, &q_loss, 0)); + } + } + const float action_weights_trained = predict_fixed(weights, eval_state); + EXPECT_GT(std::abs(action_weights_trained - action_weights), 1e-7f); + // cleanup std::filesystem::remove_all(ckpt_dir); } @@ -213,7 +242,8 @@ void rollout_until_ready(const std::string& name) { // Round-trip test for both checkpoint restore mechanisms of the on-policy (PPO) system: // * torchfort_rl_on_policy_load_checkpoint -> full restore (weights + rollout buffer) // * torchfort_rl_on_policy_load_model -> weights-only restore (e.g. 
for fine-tuning) -void checkpoint_roundtrip_on_policy(const std::string& system) { +void checkpoint_roundtrip_on_policy(const std::string& system, int src_device = TORCHFORT_DEVICE_CPU, + int restore_device = TORCHFORT_DEVICE_CPU) { // silence TorchFort's verbose/logging output for the duration of the test std::stringstream cout_buffer; CoutRedirect cout_redirect(cout_buffer.rdbuf()); @@ -221,16 +251,16 @@ void checkpoint_roundtrip_on_policy(const std::string& system) { torch::manual_seed(666); const std::string config = get_config_path(system + ".yaml"); - const std::string src = system + "_src"; - const std::string full = system + "_restore_full"; - const std::string weights = system + "_restore_weights"; - const std::string ckpt_dir = "/tmp/torchfort_rl_ckpt_" + system; + const std::string dev_tag = device_suffix(src_device) + "_to_" + device_suffix(restore_device); + const std::string src = system + "_" + dev_tag + "_src"; + const std::string full = system + "_" + dev_tag + "_restore_full"; + const std::string weights = system + "_" + dev_tag + "_restore_weights"; + const std::string ckpt_dir = "/tmp/torchfort_rl_ckpt_" + system + "_" + dev_tag; std::filesystem::remove_all(ckpt_dir); // create the source system and fill its rollout buffer until it is ready - CHECK_TORCHFORT( - torchfort_rl_on_policy_create_system(src.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); + CHECK_TORCHFORT(torchfort_rl_on_policy_create_system(src.c_str(), config.c_str(), src_device, src_device)); rollout_until_ready(src); // reference deterministic prediction for a fixed evaluation state @@ -246,8 +276,7 @@ void checkpoint_roundtrip_on_policy(const std::string& system) { CHECK_TORCHFORT(torchfort_rl_on_policy_save_checkpoint(src.c_str(), ckpt_dir.c_str())); // ---------------- full checkpoint restore ---------------- - CHECK_TORCHFORT( - torchfort_rl_on_policy_create_system(full.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU)); + 
CHECK_TORCHFORT(torchfort_rl_on_policy_create_system(full.c_str(), config.c_str(), restore_device, restore_device)); // a freshly created system has an empty rollout buffer and is not ready bool full_ready_before = true; CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(full.c_str(), full_ready_before)); @@ -264,9 +293,17 @@ void checkpoint_roundtrip_on_policy(const std::string& system) { CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(full.c_str(), full_ready_after)); EXPECT_TRUE(full_ready_after); + // training is functional after full checkpoint restore: one gradient step must change the policy + { + float p_loss, q_loss; + CHECK_TORCHFORT(torchfort_rl_on_policy_train_step(full.c_str(), &p_loss, &q_loss, 0)); + } + const float action_full_trained = predict_fixed_on_policy(full, eval_state); + EXPECT_GT(std::abs(action_full_trained - action_full), 1e-7f); + // ---------------- weights-only restore (load_model) ---------------- - CHECK_TORCHFORT(torchfort_rl_on_policy_create_system(weights.c_str(), config.c_str(), TORCHFORT_DEVICE_CPU, - TORCHFORT_DEVICE_CPU)); + CHECK_TORCHFORT( + torchfort_rl_on_policy_create_system(weights.c_str(), config.c_str(), restore_device, restore_device)); // prediction before loading stems from a fresh (independent) initialization and differs from the reference const float action_fresh = predict_fixed_on_policy(weights, eval_state); EXPECT_GT(std::abs(action_ref - action_fresh), 1e-6); @@ -282,6 +319,15 @@ void checkpoint_roundtrip_on_policy(const std::string& system) { CHECK_TORCHFORT(torchfort_rl_on_policy_is_ready(weights.c_str(), weights_ready)); EXPECT_FALSE(weights_ready); + // fill the rollout buffer so training is possible, then verify one gradient step changes the policy + rollout_until_ready(weights); + { + float p_loss, q_loss; + CHECK_TORCHFORT(torchfort_rl_on_policy_train_step(weights.c_str(), &p_loss, &q_loss, 0)); + } + const float action_weights_trained = predict_fixed_on_policy(weights, eval_state); + 
EXPECT_GT(std::abs(action_weights_trained - action_weights), 1e-7f); + // cleanup std::filesystem::remove_all(ckpt_dir); } @@ -289,13 +335,28 @@ void checkpoint_roundtrip_on_policy(const std::string& system) { } // namespace TEST(Checkpoint, TD3) { checkpoint_roundtrip("td3"); } - TEST(Checkpoint, DDPG) { checkpoint_roundtrip("ddpg"); } - TEST(Checkpoint, SAC) { checkpoint_roundtrip("sac"); } - TEST(Checkpoint, PPO) { checkpoint_roundtrip_on_policy("ppo"); } +#ifdef ENABLE_GPU +TEST(Checkpoint, TD3GPUtoGPU) { checkpoint_roundtrip("td3", 0, 0); } +TEST(Checkpoint, TD3CPUtoGPU) { checkpoint_roundtrip("td3", TORCHFORT_DEVICE_CPU, 0); } +TEST(Checkpoint, TD3GPUtoCPU) { checkpoint_roundtrip("td3", 0, TORCHFORT_DEVICE_CPU); } + +TEST(Checkpoint, DDPGGPUtoGPU) { checkpoint_roundtrip("ddpg", 0, 0); } +TEST(Checkpoint, DDPGCPUtoGPU) { checkpoint_roundtrip("ddpg", TORCHFORT_DEVICE_CPU, 0); } +TEST(Checkpoint, DDPGGPUtoCPU) { checkpoint_roundtrip("ddpg", 0, TORCHFORT_DEVICE_CPU); } + +TEST(Checkpoint, SACGPUtoGPU) { checkpoint_roundtrip("sac", 0, 0); } +TEST(Checkpoint, SACCPUtoGPU) { checkpoint_roundtrip("sac", TORCHFORT_DEVICE_CPU, 0); } +TEST(Checkpoint, SACGPUtoCPU) { checkpoint_roundtrip("sac", 0, TORCHFORT_DEVICE_CPU); } + +TEST(Checkpoint, PPOGPUtoGPU) { checkpoint_roundtrip_on_policy("ppo", 0, 0); } +TEST(Checkpoint, PPOCPUtoGPU) { checkpoint_roundtrip_on_policy("ppo", TORCHFORT_DEVICE_CPU, 0); } +TEST(Checkpoint, PPOGPUtoCPU) { checkpoint_roundtrip_on_policy("ppo", 0, TORCHFORT_DEVICE_CPU); } +#endif + int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv); diff --git a/tests/rl/test_interface.cpp b/tests/rl/test_interface.cpp index 1301646d..93a31dc8 100644 --- a/tests/rl/test_interface.cpp +++ b/tests/rl/test_interface.cpp @@ -66,10 +66,9 @@ std::string createModifiedTD3Config(int state_size, int action_size) { critic_layer_sizes[0] = state_size + action_size; } - // Create temporary file - std::string temp_filename("./tmpconfig.yaml"); - - // 
Write modified config to temporary file + // Write to /tmp/ so the file is always writable (install dirs may be read-only) + std::string temp_filename("/tmp/torchfort_tmpconfig_td3_" + std::to_string(state_size) + "_" + + std::to_string(action_size) + ".yaml"); std::ofstream temp_file(temp_filename); temp_file << config; temp_file.close(); @@ -102,10 +101,9 @@ std::string createModifiedPPOConfig(int state_size, int action_size) { } } - // Create temporary file - std::string temp_filename("./tmpconfig.yaml"); - - // Write modified config to temporary file + // Write to /tmp/ so the file is always writable (install dirs may be read-only) + std::string temp_filename("/tmp/torchfort_tmpconfig_ppo_" + std::to_string(state_size) + "_" + + std::to_string(action_size) + ".yaml"); std::ofstream temp_file(temp_filename); temp_file << config; temp_file.close(); From 1c85ab48f137dd7dde1940b50e19b7de8bae52ca Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 3 Jun 2026 23:26:19 -0700 Subject: [PATCH 5/6] fixing formatting Signed-off-by: Thorsten Kurth --- tests/rl/test_checkpoint.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/rl/test_checkpoint.cpp b/tests/rl/test_checkpoint.cpp index 7e36ee51..09080c82 100644 --- a/tests/rl/test_checkpoint.cpp +++ b/tests/rl/test_checkpoint.cpp @@ -130,8 +130,7 @@ void checkpoint_roundtrip(const std::string& system, int src_device = TORCHFORT_ CHECK_TORCHFORT(torchfort_rl_off_policy_save_checkpoint(src.c_str(), ckpt_dir.c_str())); // ---------------- full checkpoint restore ---------------- - CHECK_TORCHFORT( - torchfort_rl_off_policy_create_system(full.c_str(), config.c_str(), restore_device, restore_device)); + CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(full.c_str(), config.c_str(), restore_device, restore_device)); // a freshly created system has an empty buffer and is not ready bool full_ready_before = true; CHECK_TORCHFORT(torchfort_rl_off_policy_is_ready(full.c_str(), 
full_ready_before)); @@ -161,8 +160,8 @@ void checkpoint_roundtrip(const std::string& system, int src_device = TORCHFORT_ EXPECT_GT(std::abs(action_full_trained - action_full), 1e-7f); // ---------------- weights-only restore (load_model) ---------------- - CHECK_TORCHFORT(torchfort_rl_off_policy_create_system(weights.c_str(), config.c_str(), restore_device, - restore_device)); + CHECK_TORCHFORT( + torchfort_rl_off_policy_create_system(weights.c_str(), config.c_str(), restore_device, restore_device)); // prediction before loading stems from a fresh (independent) initialization and differs from the reference, // which makes the post-load match below a meaningful check that loading actually happened const float action_fresh = predict_fixed(weights, eval_state); From 8f6ea5d28ab2228ea567c52894741b15d8a87f3b Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 17 Jun 2026 06:54:06 -0700 Subject: [PATCH 6/6] adding checkpoint_rl test to workflow script Signed-off-by: Thorsten Kurth --- .github/scripts/run_ci_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/run_ci_tests.sh b/.github/scripts/run_ci_tests.sh index 7baa46da..4245d35e 100755 --- a/.github/scripts/run_ci_tests.sh +++ b/.github/scripts/run_ci_tests.sh @@ -15,5 +15,6 @@ cd /opt/torchfort/bin/tests/rl ./test_distributions ./test_replay_buffer ./test_rollout_buffer +./test_checkpoint_rl ./test_off_policy --gtest_filter=*L0* ./test_on_policy --gtest_filter=*L0*