Merged
1 change: 1 addition & 0 deletions .github/scripts/run_ci_tests.sh
@@ -15,5 +15,6 @@ cd /opt/torchfort/bin/tests/rl
./test_distributions
./test_replay_buffer
./test_rollout_buffer
./test_checkpoint_rl
./test_off_policy --gtest_filter=*L0*
./test_on_policy --gtest_filter=*L0*
16 changes: 16 additions & 0 deletions docs/api/c_api.rst
@@ -334,6 +334,14 @@ torchfort_rl_off_policy_load_checkpoint

------

.. _torchfort_rl_off_policy_load_model-ref:

torchfort_rl_off_policy_load_model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. doxygenfunction:: torchfort_rl_off_policy_load_model

------


Weights and Biases Logging
--------------------------
@@ -483,6 +491,14 @@ torchfort_rl_on_policy_load_checkpoint

------

.. _torchfort_rl_on_policy_load_model-ref:

torchfort_rl_on_policy_load_model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. doxygenfunction:: torchfort_rl_on_policy_load_model

------


Weights and Biases Logging
--------------------------
49 changes: 45 additions & 4 deletions docs/api/f_api.rst
@@ -534,11 +534,32 @@ torchfort_rl_off_policy_load_checkpoint

.. f:function:: torchfort_rl_off_policy_load_checkpoint(name, checkpoint_dir)
Restores a reinforcement learning system from a checkpoint.
This method restores all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler
states. It also fully restores the state of the replay buffer, but not the current RNG seed.
This function should be used in conjunction with :code:`torchfort_rl_off_policy_save_checkpoint`.


:p character(:) name [in]: The name of the system instance to use, as defined during system creation.
:p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load.
:r torchfort_result res: :code:`TORCHFORT_RESULT_SUCCESS` on success or error code on failure.

------

.. _torchfort_rl_off_policy_load_model-f-ref:

torchfort_rl_off_policy_load_model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. f:function:: torchfort_rl_off_policy_load_model(name, checkpoint_dir)
Restores only the network weights of a reinforcement learning system from a checkpoint.
In contrast to :code:`torchfort_rl_off_policy_load_checkpoint`, this method only restores the weights of the online policy and
critic networks. For algorithms that use target networks (e.g. DDPG and TD3), the corresponding target networks are also restored,
by re-initializing them from the loaded online networks so that they start consistent with the restored weights. The optimizers,
LR schedulers, replay buffer, normalizer statistics and step counters are left in their freshly created state. This is intended for
fine-tuning or transfer-learning workflows, where a pretrained model is used as the starting point for a new training run (e.g. with
a modified reward function or new environment data). The checkpoint is the one produced by :code:`torchfort_rl_off_policy_save_checkpoint`.

:p character(:) name [in]: The name of the system instance to use, as defined during system creation.
:p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load.
:r torchfort_result res: :code:`TORCHFORT_RESULT_SUCCESS` on success or error code on failure.
@@ -800,11 +821,31 @@ torchfort_rl_on_policy_load_checkpoint

.. f:function:: torchfort_rl_on_policy_load_checkpoint(name, checkpoint_dir)
Restores a reinforcement learning system from a checkpoint.
This method restores all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler
states. It also fully restores the state of the rollout buffer, but not the current RNG seed.
This function should be used in conjunction with :code:`torchfort_rl_on_policy_save_checkpoint`.


:p character(:) name [in]: The name of the system instance to use, as defined during system creation.
:p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load.
:r torchfort_result res: :code:`TORCHFORT_RESULT_SUCCESS` on success or error code on failure.

------

.. _torchfort_rl_on_policy_load_model-f-ref:

torchfort_rl_on_policy_load_model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. f:function:: torchfort_rl_on_policy_load_model(name, checkpoint_dir)
Restores only the network weights of a reinforcement learning system from a checkpoint.
In contrast to :code:`torchfort_rl_on_policy_load_checkpoint`, this method only restores the weights of the actor-critic network.
The optimizer, LR scheduler, rollout buffer, normalizer statistics and step counters are left in their freshly created state. This is
intended for fine-tuning or transfer-learning workflows, where a pretrained model is used as the starting point for a new training run
(e.g. with a modified reward function or new environment data). The checkpoint is the one produced by
:code:`torchfort_rl_on_policy_save_checkpoint`.

:p character(:) name [in]: The name of the system instance to use, as defined during system creation.
:p character(:) checkpoint_dir [in]: A filesystem path to a directory which contains the checkpoint data to load.
:r torchfort_result res: :code:`TORCHFORT_RESULT_SUCCESS` on success or error code on failure.
31 changes: 30 additions & 1 deletion docs/usage.rst
@@ -351,5 +351,34 @@ This function is only required if RL training from the checkpoint should be resu

istat = torchfort_load_model(model_name, policy_model_file);

can be used instead. The model instance should be created beforehand using the methods described in the :ref:`supervised_learning-ref` section.

Fine-Tuning / Transfer Learning
-------------------------------

Sometimes a previously trained system should be used as the *starting point* for a new training run, rather than fully resumed. A typical example is
refining a policy with new environment data or a modified reward function. In this case restoring the full checkpoint is undesirable, because it would
also restore the stale replay buffer, the optimizer momenta, the learning-rate schedule position and the training step counters of the previous run.

For this purpose, only the network weights can be restored from a checkpoint:

.. tabs::

.. code-tab:: fortran

istat = torchfort_rl_off_policy_load_model(system_name, directory_name)

.. code-tab:: c++

istat = torchfort_rl_off_policy_load_model(system_name, directory_name);

This loads only the weights of the online policy and critic networks from the checkpoint directory created by ``torchfort_rl_off_policy_save_checkpoint``.
For algorithms that use target networks (e.g. DDPG and TD3), the target networks are not read from the checkpoint. Instead, they are re-initialized from the
loaded online networks so that they start consistent with the restored weights. The optimizers, learning-rate schedulers, replay buffer, normalizer statistics and
step counters remain in their freshly created state. Training then proceeds from the pretrained weights with a clean training history, so that newly
collected transitions (e.g. generated under the modified reward function) are not mixed with stale experience. The system must be created beforehand with
``torchfort_rl_off_policy_create_system`` using a network architecture that matches the saved checkpoint.

On-policy systems provide the equivalent function ``torchfort_rl_on_policy_load_model``, which restores only the actor-critic network weights and leaves the
optimizer, learning-rate scheduler, rollout buffer, normalizer statistics and step counters in their freshly created state.
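
The difference between a full resume and a weights-only load described above can be illustrated with a small self-contained sketch. Everything in it (``ToySystem``, ``load_checkpoint``, ``load_model``, ``make_pretrained``, ``make_fresh``) is a hypothetical stand-in written for this illustration and is not part of the TorchFort API; it only mirrors which pieces of state each kind of load is documented to touch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an off-policy RL system; not a TorchFort type.
struct ToySystem {
  std::vector<float> policy;         // online policy weights
  std::vector<float> policy_target;  // target policy weights
  std::vector<float> momentum;       // optimizer state
  std::vector<float> replay_buffer;  // stored experience
  std::size_t step = 0;              // training step counter
};

// Full resume: every piece of state is taken from the checkpoint.
inline void load_checkpoint(ToySystem& sys, const ToySystem& ckpt) { sys = ckpt; }

// Weights only: take the online weights, rebuild the target from them, and
// leave optimizer state, buffer and step counter as they were created.
inline void load_model(ToySystem& sys, const ToySystem& ckpt) {
  sys.policy = ckpt.policy;
  sys.policy_target = sys.policy;  // the saved target weights are ignored
}

// A "previous run" whose target weights have drifted from the online ones.
inline ToySystem make_pretrained() {
  ToySystem ckpt;
  ckpt.policy = {0.5f, -1.0f};
  ckpt.policy_target = {0.4f, -0.9f};
  ckpt.momentum = {0.1f, 0.2f};
  ckpt.replay_buffer = {1.0f, 2.0f, 3.0f};
  ckpt.step = 1000;
  return ckpt;
}

// A freshly created system with the same (toy) architecture.
inline ToySystem make_fresh() {
  ToySystem sys;
  sys.policy = {0.0f, 0.0f};
  sys.policy_target = {0.0f, 0.0f};
  sys.momentum = {0.0f, 0.0f};
  return sys;
}
```

Calling ``load_model(sys, make_pretrained())`` on a system from ``make_fresh()`` leaves ``sys.step`` at zero and ``sys.replay_buffer`` empty, while ``load_checkpoint`` would carry over the step count and the three stored transitions.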

13 changes: 13 additions & 0 deletions src/csrc/include/internal/model_pack.h
@@ -23,6 +23,7 @@
#include "internal/base_loss.h"
#include "internal/base_lr_scheduler.h"
#include "internal/distributed.h"
#include "internal/exceptions.h"
#include "internal/model_state.h"
#include "internal/model_wrapper.h"
#ifdef ENABLE_GPU
@@ -51,4 +52,16 @@ struct ModelPack {
void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true);
void load_model_pack(ModelPack& model_pack, const std::string& fname, bool load_optimizer = true);

// Re-point an optimizer at the given parameter tensors, e.g. after (re)loading model weights (for JIT
// models a load replaces the underlying tensors, so the optimizer's references have to be refreshed).
// TorchFort constructs optimizers with a single parameter group, so we update group 0 directly; this
// avoids the deprecated torch::optim::Optimizer::parameters() accessor.
inline void reset_optimizer_parameters(const std::shared_ptr<torch::optim::Optimizer>& optimizer,
const std::vector<torch::Tensor>& parameters) {
if (optimizer->param_groups().size() != 1) {
THROW_NOT_SUPPORTED("reset_optimizer_parameters expects an optimizer with a single parameter group.");
}
optimizer->param_groups()[0].params() = parameters;
}

} // namespace torchfort
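
Why the optimizer has to be re-pointed after a load can be shown without libtorch. The sketch below is an illustration under stated assumptions: ``Param``, ``ToyModel``, ``ToyOptimizer`` and ``reset_parameters`` are hypothetical names, and a ``std::shared_ptr<float>`` stands in for a tensor handle. It assumes, as the comment on the helper above says for JIT models, that a load replaces the underlying parameter objects rather than writing into the existing ones.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins; none of these are libtorch or TorchFort types.
using Param = std::shared_ptr<float>;  // shared handle, playing the role of a tensor

struct ToyModel {
  std::vector<Param> params;
  // Loading allocates new handles instead of writing into the old storage.
  void load(const std::vector<float>& values) {
    params.clear();
    for (float v : values) params.push_back(std::make_shared<float>(v));
  }
};

struct ToyOptimizer {
  std::vector<std::vector<Param>> groups;  // parameter groups
  explicit ToyOptimizer(const std::vector<Param>& p) { groups.push_back(p); }
  // Toy update rule: subtract lr from every parameter the optimizer references.
  void step(float lr) {
    for (auto& g : groups)
      for (auto& p : g) *p -= lr;
  }
};

// Same idea as the helper above: refresh group 0, reject multi-group optimizers.
inline void reset_parameters(ToyOptimizer& opt, const std::vector<Param>& p) {
  if (opt.groups.size() != 1) throw std::logic_error("expected a single parameter group");
  opt.groups[0] = p;
}
```

In this toy setup, an optimizer built before ``load`` keeps updating the old handles, so the freshly loaded parameters never change until ``reset_parameters`` is called. The single-group check is kept because silently updating only group 0 of a multi-group optimizer would leave the other groups stale.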
3 changes: 3 additions & 0 deletions src/csrc/include/internal/rl/off_policy.h
@@ -67,6 +67,9 @@ class RLOffPolicySystem {
virtual void initSystemComm(MPI_Comm mpi_comm) = 0;
virtual void saveCheckpoint(const std::string& checkpoint_dir) const = 0;
virtual void loadCheckpoint(const std::string& checkpoint_dir) = 0;
// load only the network weights (e.g. for fine-tuning), leaving optimizers,
// LR schedulers, replay buffer and step counters in their freshly created state
virtual void loadModel(const std::string& checkpoint_dir) = 0;
virtual torch::Device modelDevice() const = 0;
virtual torch::Device rbDevice() const = 0;
virtual int getRank() const = 0;
1 change: 1 addition & 0 deletions src/csrc/include/internal/rl/off_policy/ddpg.h
@@ -271,6 +271,7 @@ class DDPGSystem : public RLOffPolicySystem, public std::enable_shared_from_this
// saving and loading
void saveCheckpoint(const std::string& checkpoint_dir) const;
void loadCheckpoint(const std::string& checkpoint_dir);
void loadModel(const std::string& checkpoint_dir);

// info printing
void printInfo() const;
1 change: 1 addition & 0 deletions src/csrc/include/internal/rl/off_policy/sac.h
@@ -401,6 +401,7 @@ class SACSystem : public RLOffPolicySystem, public std::enable_shared_from_this<
// saving and loading
void saveCheckpoint(const std::string& checkpoint_dir) const;
void loadCheckpoint(const std::string& checkpoint_dir);
void loadModel(const std::string& checkpoint_dir);

// info printing
void printInfo() const;
1 change: 1 addition & 0 deletions src/csrc/include/internal/rl/off_policy/td3.h
@@ -304,6 +304,7 @@ class TD3System : public RLOffPolicySystem, public std::enable_shared_from_this<
// saving and loading
void saveCheckpoint(const std::string& checkpoint_dir) const;
void loadCheckpoint(const std::string& checkpoint_dir);
void loadModel(const std::string& checkpoint_dir);

// info printing
void printInfo() const;
3 changes: 3 additions & 0 deletions src/csrc/include/internal/rl/on_policy.h
@@ -69,6 +69,9 @@ class RLOnPolicySystem {
virtual void initSystemComm(MPI_Comm mpi_comm) = 0;
virtual void saveCheckpoint(const std::string& checkpoint_dir) const = 0;
virtual void loadCheckpoint(const std::string& checkpoint_dir) = 0;
// load only the network weights (e.g. for fine-tuning), leaving optimizers,
// LR schedulers, rollout buffer and step counters in their freshly created state
virtual void loadModel(const std::string& checkpoint_dir) = 0;
virtual torch::Device modelDevice() const = 0;
virtual torch::Device rbDevice() const = 0;

1 change: 1 addition & 0 deletions src/csrc/include/internal/rl/on_policy/ppo.h
@@ -258,6 +258,7 @@ class PPOSystem : public RLOnPolicySystem, public std::enable_shared_from_this<R
// saving and loading
void saveCheckpoint(const std::string& checkpoint_dir) const;
void loadCheckpoint(const std::string& checkpoint_dir);
void loadModel(const std::string& checkpoint_dir);

// info printing
void printInfo() const;
32 changes: 32 additions & 0 deletions src/csrc/include/torchfort_rl.h
@@ -339,6 +339,23 @@ torchfort_result_t torchfort_rl_off_policy_save_checkpoint(const char* name, con
*/
torchfort_result_t torchfort_rl_off_policy_load_checkpoint(const char* name, const char* checkpoint_dir);

/**
* @brief Restores only the network weights of a reinforcement learning system from a checkpoint.
* @details In contrast to \p torchfort_rl_off_policy_load_checkpoint, this method only restores the weights of the
* online policy and critic networks from the checkpoint produced by \p torchfort_rl_off_policy_save_checkpoint. For
* algorithms that use target networks (e.g. DDPG and TD3), the corresponding target networks are also restored, by
* re-initializing them from the loaded online networks so that they start consistent with the restored weights. The
* optimizers, LR schedulers, replay buffer, normalizer statistics and step counters are left in their freshly created
* state. This is intended for fine-tuning or transfer-learning workflows, where a pretrained model is used as the
* starting point for a new training run (e.g. with a modified reward function or new environment data).
*
* @param[in] name The name of a system instance to restore the weights for, as defined during system creation.
* @param[in] checkpoint_dir A filesystem path to a directory which contains the checkpoint data to load.
*
* @return \p TORCHFORT_RESULT_SUCCESS on success or error code on failure.
*/
torchfort_result_t torchfort_rl_off_policy_load_model(const char* name, const char* checkpoint_dir);

// RL off-policy miscellaneous utility functions
/**
* @brief Queries a reinforcement learning system for readiness to start training
@@ -666,6 +683,21 @@ torchfort_result_t torchfort_rl_on_policy_save_checkpoint(const char* name, cons
*/
torchfort_result_t torchfort_rl_on_policy_load_checkpoint(const char* name, const char* checkpoint_dir);

/**
* @brief Restores only the network weights of a reinforcement learning system from a checkpoint.
* @details In contrast to \p torchfort_rl_on_policy_load_checkpoint, this method only restores the weights of the
* actor-critic network from the checkpoint produced by \p torchfort_rl_on_policy_save_checkpoint. The optimizer, LR
* scheduler, rollout buffer, normalizer statistics and step counters are left in their freshly created state. This is
* intended for fine-tuning or transfer-learning workflows, where a pretrained model is used as the starting point for a
* new training run (e.g. with a modified reward function or new environment data).
*
* @param[in] name The name of a system instance to restore the weights for, as defined during system creation.
* @param[in] checkpoint_dir A filesystem path to a directory which contains the checkpoint data to load.
*
* @return \p TORCHFORT_RESULT_SUCCESS on success or error code on failure.
*/
torchfort_result_t torchfort_rl_on_policy_load_model(const char* name, const char* checkpoint_dir);

// RL on-policy miscellaneous utility functions
/**
* @brief Queries a reinforcement learning system for readiness to start training
2 changes: 1 addition & 1 deletion src/csrc/model_pack.cpp
@@ -75,7 +75,7 @@ void load_model_pack(ModelPack& model_pack, const std::string& dir, bool load_op
// we need to check if the optimizer is initialized before doing so
// (some RL models do not have an optimizer attached to them):
if (model_pack.optimizer) {
- model_pack.optimizer->parameters() = model_pack.model->parameters();
+ reset_optimizer_parameters(model_pack.optimizer, model_pack.model->parameters());
}

if (load_optimizer) {
30 changes: 30 additions & 0 deletions src/csrc/rl/off_policy/ddpg.cpp
@@ -409,6 +409,36 @@ void DDPGSystem::loadCheckpoint(const std::string& checkpoint_dir) {
}
}

// loading only the network weights (e.g. for fine-tuning / transfer learning):
// this restores the online policy and critic weights from a checkpoint directory, but
// leaves optimizers, LR schedulers, replay buffer, normalizers and step counters untouched.
void DDPGSystem::loadModel(const std::string& checkpoint_dir) {
using namespace torchfort;
std::filesystem::path root_dir(checkpoint_dir);

// load only the model weights (model.pt) of a model pack, reconnecting the optimizer
// to the (possibly newly allocated) model parameters afterwards.
auto load_weights = [](auto& model_pack, const std::filesystem::path& dir) {
auto model_path = dir / "model.pt";
if (!std::filesystem::exists(model_path)) {
THROW_INVALID_USAGE("Could not find " + model_path.native() + ".");
}
model_pack.model->load(model_path.native());
if (model_pack.optimizer) {
reset_optimizer_parameters(model_pack.optimizer, model_pack.model->parameters());
}
};

// online policy and critic
load_weights(p_model_, root_dir / "policy");
load_weights(q_model_, root_dir / "critic");

// initialize the target networks from the freshly loaded online networks; the saved
// target weights are an artifact of the previous run and are intentionally ignored.
copy_parameters(p_model_target_.model, p_model_.model);
copy_parameters(q_model_target_.model, q_model_.model);
}
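
One property of the function above is worth noting: the policy weights are loaded before the critic's ``model.pt`` is checked, so a checkpoint with a missing critic file would leave the system with new policy weights and old critic weights when the exception is thrown. A possible way to avoid that, sketched here as a suggestion and not as something the diff does, is to verify both files before loading anything. ``require_model_file`` and ``has_online_weights`` are hypothetical helpers; only the ``policy``/``critic`` subdirectories and the ``model.pt`` file name are taken from the code above.

```cpp
#include <cassert>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>

namespace fs = std::filesystem;

// Hypothetical helper: returns <root>/<sub>/model.pt or throws if it is missing.
inline fs::path require_model_file(const fs::path& root, const std::string& sub) {
  fs::path p = root / sub / "model.pt";
  if (!fs::exists(p)) {
    throw std::runtime_error("Could not find " + p.string() + ".");
  }
  return p;
}

// Hypothetical pre-flight check: both online networks must be present
// before any weights are read, so a failure cannot leave a half-loaded system.
inline bool has_online_weights(const fs::path& root) {
  try {
    require_model_file(root, "policy");
    require_model_file(root, "critic");
    return true;
  } catch (const std::runtime_error&) {
    return false;
  }
}
```

The sketch uses ``std::runtime_error`` in place of ``THROW_INVALID_USAGE`` only to stay self-contained; in the real code the existing macro would be the natural choice.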

// we should pass a tuple (s, a, s', r, d)
void DDPGSystem::updateReplayBuffer(torch::Tensor s, torch::Tensor a, torch::Tensor sp, torch::Tensor r,
torch::Tensor d) {
13 changes: 13 additions & 0 deletions src/csrc/rl/off_policy/interface.cpp
@@ -159,6 +159,19 @@ torchfort_result_t torchfort_rl_off_policy_load_checkpoint(const char* name, con
return TORCHFORT_RESULT_SUCCESS;
}

// load network weights only (e.g. for fine-tuning)
torchfort_result_t torchfort_rl_off_policy_load_model(const char* name, const char* checkpoint_dir) {
using namespace torchfort;

try {
rl::off_policy::registry[name]->loadModel(checkpoint_dir);
} catch (const BaseException& e) {
std::cerr << e.what();
return e.getResult();
}
return TORCHFORT_RESULT_SUCCESS;
}
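
The wrapper above follows the usual pattern for a C entry point over C++ internals: no exception crosses the boundary, and failures are returned as result codes. A minimal self-contained sketch of that pattern follows. All names in it (``ToyResult``, ``ToyException``, ``ToyRegistry``, ``toy_load_model``) are hypothetical. The explicit ``find`` lookup is an addition of this sketch, made under the assumption that the registry behaves like a ``std::unordered_map``, whose ``operator[]`` inserts a default-constructed entry for an unknown key; the diff does not show the registry's type.

```cpp
#include <cassert>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical result codes and exception type, deliberately not TorchFort's names.
enum ToyResult { TOY_SUCCESS = 0, TOY_INVALID_USAGE = 1 };

class ToyException : public std::runtime_error {
public:
  ToyException(const std::string& msg, ToyResult r) : std::runtime_error(msg), result_(r) {}
  ToyResult getResult() const { return result_; }

private:
  ToyResult result_;
};

// Hypothetical registry: system name -> "load weights from this directory" action.
using ToyRegistry = std::unordered_map<std::string, std::function<void(const std::string&)>>;

// C-style entry point: failures are reported as result codes, not exceptions.
inline ToyResult toy_load_model(ToyRegistry& registry, const char* name, const char* dir) {
  try {
    auto it = registry.find(name);  // find() does not insert, unlike operator[]
    if (it == registry.end()) {
      throw ToyException(std::string("Unknown system: ") + name + "\n", TOY_INVALID_USAGE);
    }
    it->second(dir);
  } catch (const ToyException& e) {
    std::cerr << e.what();
    return e.getResult();
  }
  return TOY_SUCCESS;
}
```

With this shape, an unknown system name yields ``TOY_INVALID_USAGE`` and leaves the registry unchanged, and an error raised inside the load action is reported through the same code path.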

// ready check
torchfort_result_t torchfort_rl_off_policy_is_ready(const char* name, bool& ready) {
using namespace torchfort;