|
|
This work introduces trajectory-level priority sampling methods for stable offline-to-online reinforcement learning:
- Offline Cluster-balanced Trajectory Sampling: Ensures diverse trajectory coverage during offline training by clustering trajectories in feature space and sampling uniformly across clusters
- Online Trajectory-level Prioritized Experience Replay (T-PER): Prioritizes high-TD-error trajectories during online fine-tuning to focus learning on difficult transitions
These methods improve sample efficiency and training stability in sparse reward robotic manipulation tasks.
This implementation is built on top of Action Chunking with Flow Q-Learning (ACFQL), extending it with trajectory-level sampling strategies.
# Create conda environment
conda create -n tper python=3.10
conda activate tper
# Install dependencies
pip install -r requirements.txt
For robomimic environments, datasets should be located at:
~/.robomimic/{task}/mh/low_dim_v15.hdf5
Download from: https://robomimic.github.io/docs/datasets/robomimic_v0.1.html (Method 2: Direct Download Links - Multi-Human (MH))
There are a few candidates for K when using the elbow method.
Initially, we used K=7 for the square task and K=5 for the transport task.
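As a rough illustration of the elbow heuristic (the function `pick_elbow_k` and the inertia values below are illustrative, not part of this codebase), one can pick the K where the inertia curve bends most sharply, i.e. where the drop in inertia suddenly levels off:

```python
import numpy as np

def pick_elbow_k(inertias, ks):
    """Pick K at the 'elbow': the sharpest bend in the inertia curve,
    measured by the maximum second difference."""
    inertias = np.asarray(inertias, dtype=float)
    # curvature[j] = drop from ks[j] to ks[j+1] minus drop from ks[j+1] to ks[j+2]
    curvature = inertias[:-2] - 2 * inertias[1:-1] + inertias[2:]
    return ks[1 + int(np.argmax(curvature))]

# Synthetic example: inertia flattens out after K=5
ks = [2, 3, 4, 5, 6, 7, 8]
inertias = [100, 70, 45, 25, 22, 20, 19]
k = pick_elbow_k(inertias, ks)  # elbow at K=5 for this curve
```

In practice the candidate Ks from the elbow curve are then cross-checked against the tasks (e.g. the K=7 / K=5 choices above).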
# Baseline: ACFQL
MUJOCO_GL=egl python main_acfql.py \
--run_group=baseline \
--env_name=square-mh-low \
--sparse=False \
--horizon_length=5 \
--offline_steps=1000000 \
--online_steps=1000000 \
--entity=${YOUR_WANDB_ENTITY}
# With Cluster-balanced Sampling (Offline)
MUJOCO_GL=egl python main.py \
--env_name=square-mh-low \
--run_group=cluster_balanced \
--use_ptr_backward=True \
--use_ptr_online_priority=True \
--sparse=False \
--agent.alpha=100 \
--horizon_length=5 \
--metric=uniform \
--cluster_sampler=True \
--entity=${YOUR_WANDB_ENTITY}
# With T-PER (Online Priority Sampling)
MUJOCO_GL=egl python main.py \
--env_name=square-mh-low \
--run_group=tper \
--use_ptr_backward=True \
--use_ptr_online_priority=True \
--sparse=False \
--agent.alpha=100 \
--horizon_length=5 \
--metric=td_error_rank \
--cluster_sampler=False \
--entity=${YOUR_WANDB_ENTITY}
# Full Method: Cluster-balanced + T-PER
MUJOCO_GL=egl python main.py \
--env_name=square-mh-low \
--run_group=cluster_tper \
--use_ptr_backward=True \
--use_ptr_online_priority=True \
--sparse=False \
--agent.alpha=100 \
--horizon_length=5 \
--metric=td_error_rank \
--cluster_sampler=True \
--entity=${YOUR_WANDB_ENTITY}
Cluster-balanced Sampling:
--cluster_sampler: Enable cluster-balanced trajectory sampling during offline training
T-PER (Trajectory-level Priority Sampling):
--use_ptr_backward: Enable backward sampling from trajectory endpoints
--use_ptr_online_priority: Enable online priority updates based on TD-error
--metric: Priority metric (td_error_rank, success_binary, avg_reward)
--ptr_warmup_steps: Steps before enabling priority sampling (default: 20000)
--backward: Sample from the end of trajectories (not recommended on the ACFQL baseline)
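For intuition, rank-based prioritization (the `td_error_rank` metric) can be sketched as below. `rank_based_probs` and the exponent `alpha` are illustrative assumptions, not the repository's actual implementation; the key idea is that ranks are more robust to TD-error outliers than raw magnitudes:

```python
import numpy as np

def rank_based_probs(td_errors, alpha=0.7):
    """Convert per-trajectory TD errors into sampling probabilities via
    ranks (rank 1 = largest error), smoothed by exponent alpha."""
    order = np.argsort(-np.asarray(td_errors, dtype=float))  # descending error
    ranks = np.empty_like(order)
    ranks[order] = np.arange(1, len(order) + 1)  # rank 1 = highest TD error
    priorities = (1.0 / ranks) ** alpha
    return priorities / priorities.sum()

td = [0.1, 2.5, 0.4]
p = rank_based_probs(td)  # trajectory 1 gets the highest probability
```

A trajectory index can then be drawn with `np.random.default_rng().choice(len(td), p=p)`.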
SARSA-style Weighted Target (not recommended):
--use_weighted_target: Enable weighted combination of policy and trajectory targets
--beta: Weight for policy target (default: 0.5)
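Conceptually, the weighted target blends the standard policy bootstrap with a SARSA-style bootstrap that uses the action actually taken in the logged trajectory. `weighted_td_target` below is a hypothetical helper for illustration, not the repository's code:

```python
def weighted_td_target(r, gamma, q_policy_next, q_traj_next, beta=0.5, done=False):
    """Blend the policy bootstrap (weight beta) with the SARSA-style
    trajectory bootstrap (weight 1 - beta) inside a standard TD target."""
    bootstrap = beta * q_policy_next + (1.0 - beta) * q_traj_next
    return r + gamma * (1.0 - float(done)) * bootstrap

target = weighted_td_target(1.0, 0.99, q_policy_next=2.0, q_traj_next=4.0, beta=0.5)
# 1.0 + 0.99 * (0.5 * 2.0 + 0.5 * 4.0) = 3.97
```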
.
├── main.py # Main training script
├── agents/
│ └── acfql.py # ACFQL agent with priority sampling
├── utils/
│ └── datasets.py # Dataset and PriorityTrajectorySampler
├── cluster_vis.py # Trajectory clustering utilities
└── assets/
└── algorithm_visualizations_final.png
Located in utils/datasets.py, handles:
- Trajectory boundary tracking
- Priority computation (reward-based, success-based, TD-error-based)
- Rank-based sampling for stability
- Online priority updates
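The exact update rule lives in utils/datasets.py; as a minimal sketch under assumed details (the class name, EMA smoothing, and constants below are illustrative, not copied from the code), online updates can keep one scalar priority per trajectory, refreshed by an exponential moving average of the TD errors observed for its transitions:

```python
import numpy as np

class TrajectoryPriorities:
    """Illustrative per-trajectory priority store with EMA updates."""

    def __init__(self, n_traj, init=1.0, ema=0.9):
        self.p = np.full(n_traj, init, dtype=float)  # one priority per trajectory
        self.ema = ema

    def update(self, traj_id, td_errors):
        # Summarize the latest batch of TD errors for this trajectory,
        # then blend it into the running priority.
        new = float(np.mean(np.abs(td_errors)))
        self.p[traj_id] = self.ema * self.p[traj_id] + (1.0 - self.ema) * new

prio = TrajectoryPriorities(n_traj=3)
prio.update(1, [2.0, -4.0])  # mean |TD| = 3.0 -> p[1] = 0.9*1.0 + 0.1*3.0 = 1.2
```

The EMA keeps priorities from jumping on a single noisy batch, which matches the README's emphasis on rank-based sampling for stability.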
Located in cluster_vis.py, provides:
- K-means clustering in trajectory feature space
- Uniform sampling across clusters
- Automatic K selection based on return homogeneity
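A minimal sketch of uniform sampling across clusters (function and variable names below are illustrative, not the repository's API): first pick a cluster uniformly at random, then a trajectory within it, so large clusters cannot dominate the batch:

```python
import numpy as np

def cluster_balanced_sample(cluster_ids, batch_size, rng):
    """Draw trajectory indices with each cluster equally likely,
    regardless of how many trajectories it contains."""
    clusters = {}
    for idx, c in enumerate(cluster_ids):
        clusters.setdefault(c, []).append(idx)
    labels = list(clusters)
    picks = []
    for _ in range(batch_size):
        members = clusters[labels[rng.integers(len(labels))]]  # uniform over clusters
        picks.append(members[rng.integers(len(members))])      # uniform within cluster
    return picks

rng = np.random.default_rng(0)
ids = [0] * 90 + [1] * 10  # cluster 0 holds 9x more trajectories
picks = cluster_balanced_sample(ids, 1000, rng)
frac_cluster1 = np.mean([ids[i] == 1 for i in picks])  # ~0.5 rather than ~0.1
```

Plain uniform sampling over trajectories would visit cluster 1 only ~10% of the time; balancing over clusters restores roughly equal coverage.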
The code logs to Weights & Biases with the following key metrics:
Offline Phase:
offline_agent/critic_loss: Critic training loss
offline_agent/actor_loss: Actor training loss
cluster/offline/*: Cluster sampling statistics
Online Phase:
online_agent/q_mean: Average Q-values
online_agent/td_error_mean: TD-error statistics
ptr/sample_online/*: Priority sampling statistics
eval/success_rate: Task success rate
If you use this code or find it helpful, please consider citing:
@article{yourname2025tper,
author = {Your Name},
title = {Trajectory-level Priority Sampling for Stable Offline-to-Online RL},
year = {2025},
}
This codebase builds upon:
- ACFQL (Action Chunking with Flow Q-Learning)


