Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ target_sources(${PROJECT_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/rl/common_models.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/models/rl/sac_model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/policy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/running_normalizer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/off_policy/interface.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/csrc/rl/off_policy/ddpg.cpp
Expand Down
282 changes: 174 additions & 108 deletions docs/api/config.rst

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions src/csrc/include/internal/rl/off_policy/ddpg.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "internal/rl/noise_actor.h"
#include "internal/rl/off_policy.h"
#include "internal/rl/replay_buffer.h"
#include "internal/rl/running_normalizer.h"
#include "internal/rl/utils.h"

namespace torchfort {
Expand Down Expand Up @@ -305,6 +306,12 @@ class DDPGSystem : public RLOffPolicySystem, public std::enable_shared_from_this
std::shared_ptr<NoiseActor> noise_actor_train_;
std::shared_ptr<NoiseActor> noise_actor_exploration_;

// state normalizer (optional, null if disabled)
std::unique_ptr<RunningNormalizer> state_normalizer_;

// reward normalizer (optional, null if disabled); scale_only=true so mean is preserved
std::unique_ptr<RunningNormalizer> reward_normalizer_;

// some parameters
int batch_size_;
int num_critics_;
Expand Down
7 changes: 7 additions & 0 deletions src/csrc/include/internal/rl/off_policy/sac.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "internal/rl/off_policy.h"
#include "internal/rl/policy.h"
#include "internal/rl/replay_buffer.h"
#include "internal/rl/running_normalizer.h"
#include "internal/rl/utils.h"

namespace torchfort {
Expand Down Expand Up @@ -428,6 +429,12 @@ class SACSystem : public RLOffPolicySystem, public std::enable_shared_from_this<
// system comm
std::shared_ptr<Comm> system_comm_;

// state normalizer (optional, null if disabled)
std::unique_ptr<RunningNormalizer> state_normalizer_;

// reward normalizer (optional, null if disabled); scale_only=true so mean is preserved
std::unique_ptr<RunningNormalizer> reward_normalizer_;

// some parameters
int batch_size_;
int num_critics_;
Expand Down
7 changes: 7 additions & 0 deletions src/csrc/include/internal/rl/off_policy/td3.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "internal/rl/noise_actor.h"
#include "internal/rl/off_policy.h"
#include "internal/rl/replay_buffer.h"
#include "internal/rl/running_normalizer.h"
#include "internal/rl/utils.h"

namespace torchfort {
Expand Down Expand Up @@ -338,6 +339,12 @@ class TD3System : public RLOffPolicySystem, public std::enable_shared_from_this<
std::shared_ptr<NoiseActor> noise_actor_train_;
std::shared_ptr<NoiseActor> noise_actor_exploration_;

// state normalizer (optional, null if disabled)
std::unique_ptr<RunningNormalizer> state_normalizer_;

// reward normalizer (optional, null if disabled); scale_only=true so mean is preserved
std::unique_ptr<RunningNormalizer> reward_normalizer_;

// some parameters
int batch_size_;
int num_critics_;
Expand Down
48 changes: 12 additions & 36 deletions src/csrc/include/internal/rl/on_policy/ppo.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "internal/rl/on_policy.h"
#include "internal/rl/policy.h"
#include "internal/rl/rollout_buffer.h"
#include "internal/rl/running_normalizer.h"
#include "internal/rl/utils.h"

namespace torchfort {
Expand All @@ -45,8 +46,8 @@ template <typename T>
void train_ppo(const ACPolicyPack& pq_model, torch::Tensor state_tensor, torch::Tensor action_tensor,
torch::Tensor q_tensor, torch::Tensor log_p_tensor, torch::Tensor adv_tensor, torch::Tensor ret_tensor,
const T& epsilon, const T& clip_q, const T& entropy_loss_coeff, const T& q_loss_coeff,
const T& target_kl_divergence, bool normalize_advantage, T& p_loss_val, T& q_loss_val, T& kl_divergence,
T& clip_fraction, T& explained_var) {
const T& target_kl_divergence, T& p_loss_val, T& q_loss_val, T& kl_divergence, T& clip_fraction,
T& explained_var) {

// nvtx marker
torchfort::nvtx::rangePush("torchfort_train_ppo");
Expand All @@ -65,40 +66,6 @@ void train_ppo(const ACPolicyPack& pq_model, torch::Tensor state_tensor, torch::
assert(adv_tensor.dim() == 1);
assert(ret_tensor.dim() == 1);

// normalize advantages if requested
if (normalize_advantage && (batch_size > 1)) {
// make sure we are not going to compute gradients
torch::NoGradGuard no_grad;

// compute mean
torch::Tensor adv_mean = torch::sum(adv_tensor);
auto options = torch::TensorOptions().dtype(torch::kLong).device(adv_mean.device());
torch::Tensor adv_count = torch::tensor({torch::numel(adv_tensor)}, options);

// average mean across all nodes
if (pq_model.comm) {
std::vector<torch::Tensor> means = {adv_mean, adv_count};
pq_model.comm->allreduce(means, false);
adv_mean = means[0];
adv_count = means[1];
}
adv_mean = adv_mean / adv_count;

// compute std
torch::Tensor adv_std = torch::sum(torch::square(adv_tensor - adv_mean));

// average std across all nodes
if (pq_model.comm) {
std::vector<torch::Tensor> stds = {adv_std};
pq_model.comm->allreduce(stds, false);
adv_std = stds[0];
}
adv_std = torch::sqrt(adv_std / (adv_count - 1));

// update advantage tensor
adv_tensor = (adv_tensor - adv_mean) / (adv_std + 1.e-8);
}

// set models to train
pq_model.model->train();

Expand Down Expand Up @@ -317,6 +284,12 @@ class PPOSystem : public RLOnPolicySystem, public std::enable_shared_from_this<R
// system comm
std::shared_ptr<Comm> system_comm_;

// state normalizer (optional, null if disabled)
std::unique_ptr<RunningNormalizer> state_normalizer_;

// return normalizer (optional, null if disabled); scale_only=true so mean is preserved
std::unique_ptr<RunningNormalizer> return_normalizer_;

// some parameters
int batch_size_;
float epsilon_, clip_q_;
Expand All @@ -326,6 +299,9 @@ class PPOSystem : public RLOnPolicySystem, public std::enable_shared_from_this<R
float clip_fraction_;
float a_low_, a_high_;
bool normalize_advantage_;
bool normalize_returns_;
bool advantage_normalized_; // tracks whether advantages have been normalized for the current rollout
bool returns_normalized_; // tracks whether returns have been normalized for the current rollout
ActorNormalizationMode actor_normalization_mode_;
};

Expand Down
88 changes: 88 additions & 0 deletions src/csrc/include/internal/rl/rollout_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
#include <torch/torch.h>

#include "internal/defines.h"
#include "internal/distributed.h"
#include "internal/rl/rl.h"
#include "internal/rl/running_normalizer.h"

namespace torchfort {

Expand Down Expand Up @@ -69,6 +71,8 @@ class RolloutBuffer {
virtual ExtendedBufferEntry getFull(int) = 0;
virtual bool isReady() const = 0;
virtual void reset() = 0;
virtual void normalizeReturns(std::shared_ptr<Comm> comm, RunningNormalizer& return_normalizer) = 0;
virtual void normalizeAdvantages(std::shared_ptr<Comm> comm) = 0;
virtual void setSeed(unsigned int seed) = 0;
virtual void printInfo() const = 0;
virtual void save(const std::string& fname) const = 0;
Expand Down Expand Up @@ -179,6 +183,90 @@ class GAELambdaRolloutBuffer : public RolloutBuffer, public std::enable_shared_f
return;
}

// Normalize all stored advantages to zero mean and unit variance over the full rollout.
// In distributed mode, statistics are combined across ranks via allreduce so that all
// ranks use the same normalization. Call this once after finalize() and before sampling.
void normalizeAdvantages(std::shared_ptr<Comm> comm) {
if (!finalized_) {
throw std::runtime_error(
"GAELambdaRolloutBuffer::normalizeAdvantages: buffer must be finalized before normalizing advantages.");
}

torch::NoGradGuard no_grad;

// stack all per-step advantages into [size_, n_envs_] and flatten to 1D
auto all_adv = torch::stack(advantages_, 0).flatten().to(torch::kFloat32);

// compute global sum and count for the mean
auto adv_sum = torch::sum(all_adv);
auto count_tensor = torch::tensor({static_cast<float>(all_adv.numel())}).to(all_adv.device());

if (comm) {
std::vector<torch::Tensor> stats = {adv_sum, count_tensor};
comm->allreduce(stats, false);
adv_sum = stats[0];
count_tensor = stats[1];
}
auto adv_mean = adv_sum / count_tensor;

// compute global sum of squared deviations for the std
auto adv_sq = torch::sum(torch::square(all_adv - adv_mean));
if (comm) {
std::vector<torch::Tensor> sq_stats = {adv_sq};
comm->allreduce(sq_stats, false);
adv_sq = sq_stats[0];
}
auto adv_std = torch::sqrt(adv_sq / (count_tensor - 1.) + 1e-8);

// normalize all stored advantages in-place
for (auto& adv : advantages_) {
adv = (adv - adv_mean) / adv_std;
}
}

// Scale returns and advantages by the running std of returns (no mean subtraction).
// Updates the provided return_normalizer with this rollout's returns, syncs statistics
// across MPI ranks, then divides both returns_ and advantages_ by the same return std.
// This ensures the value function regression target and the policy gradient use a
// consistent scale. Call this before normalizeAdvantages() if both are enabled.
void normalizeReturns(std::shared_ptr<Comm> comm, RunningNormalizer& return_normalizer) {
if (!finalized_) {
throw std::runtime_error(
"GAELambdaRolloutBuffer::normalizeReturns: buffer must be finalized before normalizing returns.");
}

torch::NoGradGuard no_grad;

// flatten all returns to [size_ * n_envs_, 1]: single scalar feature per sample
auto all_ret = torch::stack(returns_, 0).reshape({-1, 1}).to(torch::kFloat32);

// update running variance with this rollout's returns, then sync across ranks
return_normalizer.update(all_ret);
return_normalizer.sync(comm);

// apply scale-only normalization: R_norm = R / std(R)
// the same std is applied to advantages: A_scaled = A / std(R),
// preserving the relationship A = R - V when both are on the same scale
auto all_ret_norm = return_normalizer.normalize(all_ret);
auto all_adv = torch::stack(advantages_, 0).reshape({-1, 1}).to(torch::kFloat32);
auto all_adv_scaled = return_normalizer.normalize(all_adv);

// write normalized values back to per-step tensors
auto ret_reshaped = all_ret_norm.reshape({static_cast<int64_t>(size_), static_cast<int64_t>(n_envs_)});
auto adv_reshaped = all_adv_scaled.reshape({static_cast<int64_t>(size_), static_cast<int64_t>(n_envs_)});
for (size_t step = 0; step < size_; ++step) {
returns_[step] = ret_reshaped[step];
advantages_[step] = adv_reshaped[step];
}

// also scale the stored value estimates (q) by the same std so that A = R - V
// holds in normalized space: A_norm = R_norm - V_norm = (R - V) / std
for (auto& entry : buffer_) {
auto& q = std::get<3>(entry);
q = return_normalizer.normalize(q.reshape({-1, 1})).reshape(q.sizes());
}
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(int batch_size) {

Expand Down
87 changes: 87 additions & 0 deletions src/csrc/include/internal/rl/running_normalizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <memory>
#include <string>

#include <torch/torch.h>

#include "internal/distributed.h"

namespace torchfort {

namespace rl {

// Online per-feature normalizer using Welford's parallel algorithm.
//
// Running statistics (mean, M2, count) are stored on CPU. normalize() moves them
// to the input tensor's device on-the-fly so the normalization arithmetic runs on
// GPU when called with device tensors.
//
// Two normalization modes are supported via the scale_only constructor flag:
//
// scale_only = false (default): x_norm = (x - mean) / sqrt(var + eps)
// Use for observations/states where zero-centering is desirable.
//
// scale_only = true: x_norm = x / sqrt(var + eps)
// Use for returns, where the mean must be preserved so the value function
// can learn the correct absolute level. The mean is still tracked internally
// (for distributed sync via Chan's algorithm) but not subtracted during normalization.
//
// Distributed sync: call sync() once per training step to combine per-rank running
// statistics across MPI ranks using Chan's parallel algorithm via two allreduce calls:
// 1. allreduce(count, weighted_mean) -> global count and mean
// 2. allreduce(local M2 contribution) -> global M2
class RunningNormalizer {
public:
  // eps:        stabilizer added to the variance before taking the square root.
  // scale_only: when true, normalize() divides by the std but does not subtract the mean.
  explicit RunningNormalizer(float eps = 1e-8f, bool scale_only = false) : eps_(eps), scale_only_(scale_only) {}

  // Fold a batch of samples into the running statistics.
  // x is laid out as [batch, feature...]; one set of statistics is kept per feature element.
  // x can live on any device, the statistics themselves stay on CPU.
  void update(torch::Tensor x);

  // Apply the current running statistics to x and return the result.
  // With fewer than 2 samples seen so far, x is returned unchanged.
  // The statistics are moved to x.device() for the computation.
  // In scale_only mode x is only divided by the std, the mean is not subtracted.
  torch::Tensor normalize(torch::Tensor x) const;

  // Merge the running statistics of all MPI ranks using Chan's parallel algorithm.
  // Does nothing when comm is null or when no samples have been seen (count_ == 0).
  void sync(std::shared_ptr<Comm> comm);

  // Checkpointing: write the normalizer state to path / read it back from path.
  void save(const std::string& path) const;
  void load(const std::string& path);

  // True as soon as the sample count is positive.
  bool isInitialized() const { return count_ > 0; }

private:
  torch::Tensor mean_; // running mean per feature element (CPU, float32)
  torch::Tensor M2_;   // running sum of squared deviations per feature element (CPU, float32)
  int64_t count_ = 0;  // sample count; starts at zero
  float eps_;          // stabilizer used in sqrt(var + eps)
  bool scale_only_;    // true: divide by std only, keep the mean
};

} // namespace rl

} // namespace torchfort
Loading
Loading