diff --git a/examples/nlp/gpt/conf/gpt_rpo.yaml b/examples/nlp/gpt/conf/gpt_rpo.yaml new file mode 100644 index 000000000..68407adcc --- /dev/null +++ b/examples/nlp/gpt/conf/gpt_rpo.yaml @@ -0,0 +1,140 @@ +defaults: + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: + +trainer: + num_nodes: 1 + devices: 8 + accelerator: gpu + precision: bf16 + + # rpo specific args + rpo: + max_epochs: 1 + max_steps: -1 + val_check_interval: 100 + save_interval: 100 + limit_train_batches: 1.0 + + # how many GBS we loop over + limit_val_batches: 10 + gradient_clip_val: 1.0 + num_responses: 4 + + # do not change these + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_time: null + max_epochs: ${.rpo.max_epochs} + max_steps: ${.rpo.max_steps} + +exp_manager: + explicit_log_dir: /results + exp_dir: null + name: megatron_gpt + max_time_per_run: ${trainer.max_time} + create_wandb_logger: False + wandb_logger_kwargs: + project: nemo_aligner_rpo + name: rlhf_gpt3_rpo + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 5 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: "megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}" + model_parallel_size: 4 + +pretrained_checkpoint: + restore_from_path: /models/llama3_8b_sft_alpha_nodes8_tp4_3e-6_bs384_rerun_1200.nemo + +model: + mcore_gpt: True + micro_batch_size: 4 + global_batch_size: 8 + megatron_amp_O2: True + + rpo: + # This default value ensures there are no numeric differences beween trained and reference policies when computing log probs. + # A higher value can be used to speed-up log probs computations, but may cause numeric differences. + log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, trainer.rpo.num_responses} + preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss + sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss + gt_reward_scale: 1. # the scale of the rewards in RPO + preference_loss: rpo # the preference loss, we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl + preference_loss_weight: 1 # the coefficient of the preference loss + sft_loss_weight: 0.05 # the coefficient of the SFT loss + beta: 0.2 + eta: 0.2 + num_responses: ${trainer.rpo.num_responses} + + #encoder_seq_length: 4096 + #max_position_embeddings: ${model.encoder_seq_length} + + # miscellaneous + seed: 1234 + + #peft + peft: + peft_scheme: "none" # ["lora", "none"] + restore_from_path: null + restore_from_ckpt: + checkpoint_dir: null + checkpoint_name: null + + lora_tuning: + target_modules: ["attention_qkv"] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: "xavier" # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: "zero" # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + optim: + name: distributed_fused_adam + bucket_cap_mb: 200 + overlap_grad_sync: False + contiguous_grad_buffer: True + lr: 9e-6 + weight_decay: 0.1 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 10 + constant_steps: 1000 + min_lr: 9e-7 + + data: + data_impl: jsonl + splits_string: null + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_prefix: + train: + - /data/responses_general.rm.formatted.4resp.jsonl + test: + - /data/rpo_test_set.jsonl + validation: + - /data/rpo_test_set.jsonl + + # define fields from the base model's config that should be ignored when merging with this config. + overwrite_base_config: + data: + data_prefix: True + +precision: ${trainer.precision} diff --git a/examples/nlp/gpt/train_gpt_rpo.py b/examples/nlp/gpt/train_gpt_rpo.py new file mode 100644 index 000000000..0f850ff56 --- /dev/null +++ b/examples/nlp/gpt/train_gpt_rpo.py @@ -0,0 +1,162 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo_aligner.algorithms.rpo import RPOTrainer, rpo_custom_collate +from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_rpo_datasets +from nemo_aligner.models.nlp.gpt.megatron_gpt_rpo_model import MegatronGPTRPOModel +from nemo_aligner.utils.distributed import Timer +from nemo_aligner.utils.train_script_utils import ( + CustomLoggerWrapper, + add_custom_checkpoint_callback, + extract_optimizer_scheduler_from_ptl_model, + init_distributed, + init_peft, + init_using_ptl, + resolve_and_create_trainer, + retrieve_custom_trainer_state_dict, +) +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo, retrieve_model_state_dict_in_cpu + +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) +OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True) + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="gpt_rpo") +def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.pretrained_checkpoint.restore_from_path, cfg.model) + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + + trainer = resolve_and_create_trainer(cfg, "rpo") + exp_manager(trainer, cfg.exp_manager) + logger = CustomLoggerWrapper(trainer.loggers) + + ptl_model = load_from_nemo( + MegatronGPTRPOModel, + cfg.model, + trainer, + strict=True, + load_base_model_only=False, + restore_path=cfg.pretrained_checkpoint.restore_from_path, + ) + + init_peft(ptl_model, cfg.model) + + if cfg.model.peft.peft_scheme == "none": + ref_policy_state_dict = retrieve_model_state_dict_in_cpu( + ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False) + ) + ptl_model.ref_policy_state_dict = ref_policy_state_dict + + # pull values from checkpoint + trainer_restore_path = trainer.ckpt_path + + # TODO: log this restore path + if trainer_restore_path is not None: + custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) + consumed_samples = custom_trainer_state_dict["consumed_samples"] + else: + custom_trainer_state_dict = None + consumed_samples = 0 + + init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) + + # use the entire dataset + train_valid_test_num_samples = [-1 * cfg.model.global_batch_size] * 3 + + train_ds, validation_ds, test_ds = build_train_valid_test_rpo_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl=cfg.model.data.data_impl, + splits_string=cfg.model.data.splits_string, + train_valid_test_num_samples=train_valid_test_num_samples, + seq_length=cfg.model.data.seq_length, + seed=cfg.model.seed, + tokenizer=ptl_model.tokenizer, + ) + + train_dataloader = build_dataloader( + cfg=cfg, + dataset=train_ds, + consumed_samples=consumed_samples, + mbs=cfg.model.micro_batch_size, + gbs=cfg.model.global_batch_size, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=partial( + rpo_custom_collate, + eos_id=ptl_model.tokenizer.eos_id, + reset_position_ids=cfg.model.data.get("reset_position_ids", False), + reset_attention_mask=cfg.model.data.get("reset_attention_mask", False), + eod_mask_loss=cfg.model.data.get("eod_mask_loss", False), + ), + ) + + val_dataloader = build_dataloader( + cfg=cfg, + dataset=validation_ds, + consumed_samples=0, + mbs=cfg.model.micro_batch_size, + gbs=cfg.model.global_batch_size, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=partial( + rpo_custom_collate, + eos_id=ptl_model.tokenizer.eos_id, + reset_position_ids=cfg.model.data.get("reset_position_ids", False), + reset_attention_mask=cfg.model.data.get("reset_attention_mask", False), + eod_mask_loss=cfg.model.data.get("eod_mask_loss", False), + ), + use_random_sampler=False, + ) + + init_using_ptl(trainer, ptl_model, train_dataloader, train_ds) + optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) + + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + + logger.log_hyperparams(OmegaConf.to_container(cfg)) + + timer = Timer(cfg.exp_manager.get("max_time_per_run")) + dpo_trainer = RPOTrainer( + cfg=cfg.trainer.rpo, + model=ptl_model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + test_dataloader=None, + logger=logger, + ckpt_callback=ckpt_callback, + run_timer=timer, + ) + + if custom_trainer_state_dict is not None: + dpo_trainer.load_state_dict(custom_trainer_state_dict) + + dpo_trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/nemo_aligner/algorithms/rpo.py b/nemo_aligner/algorithms/rpo.py new file mode 100644 index 000000000..991f0b8e2 --- /dev/null +++ b/nemo_aligner/algorithms/rpo.py @@ -0,0 +1,333 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from statistics import mean + +import torch +from omegaconf.dictconfig import DictConfig +from tqdm import tqdm + +from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( + MegatronPretrainingRandomBatchSampler, +) +from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids +from nemo.utils import logging +from nemo_aligner.utils.distributed import SyncTimer +from nemo_aligner.utils.train_utils import clip_gradients +from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch +from nemo_aligner.utils.utils import clear_memory + + +def rpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False): + resp_outputs = {} + + ## assume len 4 for responses + for k in batch[0].keys(): + if k.startswith("response_"): + # get response_i + resp_outputs[k] = torch.nn.utils.rnn.pad_sequence( + [item[k] for item in batch], batch_first=True, padding_value=eos_id + ) + elif k.startswith("labels_"): + # get labels_i + resp_outputs[k] = torch.nn.utils.rnn.pad_sequence( + [item[k] for item in batch], batch_first=True, padding_value=-100 + ) + elif k.startswith("lengths_"): + # get lens_i + resp_outputs[k] = torch.LongTensor([item[k] for item in batch]) + elif k.startswith("rewards_"): + # get r_i + resp_outputs[k] = torch.FloatTensor([item[k] for item in batch]) + + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + resp_outputs["response_1"], eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss, + ) + assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate" + if attention_mask.shape[0] == 1: + # using .expand() here causes errors from pin_memory=True, so need to use .repeat() + # attention_mask = attention_mask.expand(len(batch), *((-1,) * (len(attention_mask.shape) - 1))) + attention_mask = attention_mask.repeat(4, *((1,) * (len(attention_mask.shape) - 1))) + + resp_outputs["attention_mask"] = attention_mask + resp_outputs["position_ids"] = position_ids + + return resp_outputs + + +class RPOTrainer: + """Trainer to coordinate DPO training + """ + + def __init__( + self, + cfg: DictConfig, + model, + optimizer, + scheduler, + train_dataloader, + val_dataloader, + test_dataloader, + logger, + ckpt_callback, + run_timer, + ): + self.model = model + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.test_dataloader = test_dataloader + self.logger = logger + self.cfg = cfg + self.optimizer = optimizer + self.scheduler = scheduler + + # this timer checks if we should stop training + self.run_timer = run_timer + + self.step = 0 + self.consumed_samples = 0 + + self.ckpt_callback = ckpt_callback + + # compute `max_steps` + self.num_steps_per_epoch = compute_num_steps_per_epoch( + self.train_dataloader.batch_sampler, self.cfg.get("limit_train_batches", 1.0) + ) + + self.limit_val_batches = compute_limit_batches(len(val_dataloader), self.cfg.limit_val_batches) + self.val_check_interval = ( + int(self.cfg.val_check_interval * self.num_steps_per_epoch) + if isinstance(self.cfg.val_check_interval, float) + else self.cfg.val_check_interval + ) + self.set_max_steps() + + self.timer = SyncTimer( + reduction="mean", sync_cuda=True, buffer_size=1, reduce_op=torch.distributed.ReduceOp.MAX + ) + self.k_len = int(self.cfg.num_responses) + + def validation_step(self, global_batch): + # these things should go into a GPTModel wrapper + self.model.prepare_for_validation_step() + + loss_mean, metrics = self.model.get_loss_and_metrics(batch=global_batch, forward_only=True) + + self.model.finish_validation_step() + return loss_mean, metrics + + @torch.no_grad() + def run_validation(self): + loss_means = [] + val_metrics = defaultdict(list) + + val_pbar = tqdm( + zip(range(self.limit_val_batches), self.augment_dataloader(self.val_dataloader)), + total=self.limit_val_batches, + leave=True, + desc="Validation steps", + ) + + for _, batch in val_pbar: + self.timer.start("validation_step_time") + loss_mean, metrics = self.validation_step(batch) + self.timer.stop("validation_step_time") + validation_step_time = self.timer.get("validation_step_time") + + metrics["validation_step_time"] = validation_step_time + + loss_means.append(loss_mean) + for k, v in metrics.items(): + val_metrics[k].append(v) + log_val_metrics = {f"val_{k}": v for k, v in metrics.items()} + val_pbar.set_postfix(log_val_metrics) + + val_metrics = {k: mean(v) for k, v in val_metrics.items()} + return mean(loss_means), val_metrics + + def train_single_step(self, global_batch): + self.optimizer.zero_grad() + + self.model.prepare_for_training_step() + + # NOTE: assume backward is called on the loss already + loss_mean, metrics = self.model.get_loss_and_metrics(batch=global_batch, forward_only=False) + + self.model.finish_training_step() + + grad_norm = clip_gradients(self.model, self.cfg.gradient_clip_val) + grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm + lr = self.optimizer.param_groups[0]["lr"] + + self.optimizer.step() + self.scheduler.step() + + trainer_metrics = {} + if grad_norm is not None: + trainer_metrics["grad_norm"] = grad_norm + trainer_metrics.update({"lr": lr, "loss": loss_mean}) + + return loss_mean, {**metrics, **trainer_metrics} + + def fit(self): + if (not isinstance(self.train_dataloader.batch_sampler, MegatronPretrainingRandomBatchSampler)) and ( + self.cfg.max_epochs is not None and self.cfg.max_epochs > 1 + ): + # if you use MegatronPretrainingBatchSampler as the batch_sampler passed to your train dataloader (in builders.py) + # then each epoch will repeat all your samples in the same order as the previous epoch, there is no shuffling + # to fix this, you should use MegatronPretrainingRandomBatchSampler instead, which alleviates this issue and allows + # random shuffling for each epoch. + raise ValueError( + "max_epochs > 1 is not supported unless using `MegatronPretrainingRandomBatchSampler` as the batch_sampler for your train dataloader" + ) + + epoch_iter = range(self.epoch, self.cfg.max_epochs) + if len(epoch_iter) <= 0: + # epoch done + return + + self.run_timer.start_time() + + for _ in epoch_iter: + num_steps_in_epoch = min( + self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch + ) + loop_iter = range(num_steps_in_epoch) + + if not loop_iter: + return # training ended + + global_pbar = tqdm( + self.augment_dataloader(self.train_dataloader), + initial=self.step, + total=self.max_steps, + leave=True, + desc="Training steps", + ) + + for _, global_batch in zip(loop_iter, global_pbar): + self.timer.start("train_step_time") + loss, metrics = self.train_single_step(global_batch) + self.timer.stop("train_step_time") + train_step_time = self.timer.get("train_step_time") + # to help avoid fragmentation + clear_memory() + + # TODO(geshen): maybe use the dataloader instead + # bump up the consumed samples but not the step + self.consumed_samples += self.model.cfg.global_batch_size + metrics["consumed_samples"] = self.consumed_samples + metrics["step_time"] = train_step_time + metrics["epoch"] = self.epoch + 1 + self.logger.log_metrics( + metrics, step=self.step, prefix="train/", + ) + metrics = {f"train_{k}": v for k, v in metrics.items()} + + self.step += 1 + + run_time_exceeded = self.run_timer.is_finished() + run_val, save_model, is_train_end = check_progress( + self.step, + self.max_steps, + self.val_check_interval, + self.cfg.save_interval, + self.limit_val_batches, + run_time_exceeded=run_time_exceeded, + ) + + if run_val: + val_loss, val_metrics = self.run_validation() + # validation is done on the UPDATED weights + # so we use the incremented self.step + self.logger.log_metrics(val_metrics, step=self.step, prefix="val/") + val_metrics = {f"val_{k}": v for k, v in val_metrics.items()} + metrics.update(val_metrics) + + global_pbar.set_postfix(metrics) + + if save_model: + # PTL save wants tensors only + metrics = {k: torch.as_tensor(v) for k, v in metrics.items()} + self.save(metrics, is_train_end=is_train_end) + + if run_time_exceeded: + logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run") + return + + metrics.clear() + + self.logger.finalize() + + def save(self, extra_candidates=None, is_train_end=False): + """PTL based save""" + torch.distributed.barrier() + + if extra_candidates is None: + extra_candidates = {} + + monitor_candidates = {k: torch.tensor(v, dtype=torch.int32) for k, v in self.state_dict().items()} + monitor_candidates.update(extra_candidates) + + self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end) + + def set_max_steps(self): + self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs + + if (max_steps := self.cfg.get("max_steps", -1)) >= 0: + self.max_steps = min(self.max_steps, max_steps) + + def state_dict(self): + return { + "step": self.step, + "consumed_samples": self.consumed_samples, + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.consumed_samples = state_dict["consumed_samples"] + + loaded_values = [self.step, self.consumed_samples] + + # make sure everyone loaded the same checkpoint as rank 0 + to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.broadcast(to_broadcast, 0) + + assert loaded_values == to_broadcast.tolist() + # restore max steps we need to run for + self.set_max_steps() + + def augment_dataloader(self, dataloader): + """Augment dataloader with ref policy log prob""" + iter_dataloader = iter(dataloader) + while True: + try: + batch = next(iter_dataloader) + logprobs = self.model.get_ref_policy_logprobs(batch).cpu() + ind = 1 + + for logps in torch.split(logprobs, len(logprobs) // self.k_len, dim=0): + batch["ref_policy_log_probs_response_" + str(ind)] = logps + ind += 1 + + yield batch + del logprobs, logps + except StopIteration: + break + + @property + def epoch(self): + return self.step // self.num_steps_per_epoch diff --git a/nemo_aligner/data/nlp/builders.py b/nemo_aligner/data/nlp/builders.py index a61fb46f9..e90fb3aeb 100644 --- a/nemo_aligner/data/nlp/builders.py +++ b/nemo_aligner/data/nlp/builders.py @@ -47,6 +47,7 @@ RegressionRewardModelDataset, RewardModelDataset, RLHFDataset, + RPOModelDataset, ) from nemo_aligner.utils import parallel_state from nemo_aligner.utils.utils import collate_with_batch_max_sequence_length @@ -262,6 +263,7 @@ def build_dataset(index, name): build_train_valid_test_rlhf_datasets = partial(build_train_valid_test_datasets, RLHFDataset) build_train_valid_test_rm_datasets = partial(build_train_valid_test_datasets, RewardModelDataset) build_train_valid_test_dpo_datasets = partial(build_train_valid_test_datasets, DPOModelDataset) +build_train_valid_test_rpo_datasets = partial(build_train_valid_test_datasets, RPOModelDataset) build_train_valid_test_kto_datasets = partial(build_train_valid_test_datasets, KTOModelDataset) build_train_valid_test_regression_rm_datasets = partial(build_train_valid_test_datasets, RegressionRewardModelDataset) @@ -355,7 +357,7 @@ def build_dataloader( "data_parallel_size": parallel_state.get_data_parallel_world_size(), "drop_last": drop_last, "global_batch_size": gbs, - "pad_samples_to_global_batch_size": pad_samples_to_global_batch_size, + # "pad_samples_to_global_batch_size": pad_samples_to_global_batch_size, } if use_random_sampler: diff --git a/nemo_aligner/data/nlp/datasets.py b/nemo_aligner/data/nlp/datasets.py index 5b8acaeb6..2a2a4e0fa 100644 --- a/nemo_aligner/data/nlp/datasets.py +++ b/nemo_aligner/data/nlp/datasets.py @@ -15,6 +15,7 @@ """Custom datasets for RLHF training""" import os +import random import numpy as np import scipy @@ -353,6 +354,113 @@ def __getitem__(self, idx): return output +class RPOModelDataset(Dataset): + """This class works only with jsonl files. It assumes each line of the json file is a dictionary + with the prompt, along with the chosen response (response only, no prompt), and the rejected response + (response only, no prompt). This Dataset will combine the prompt with each corresponding chosen and + rejected response, and then tokenize it. It also returns the labels for each, which is the response tokens + with -100 for the prompt part. + + WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations! + Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix + strings such as , it would not know where it is safe to truncate for each model. Therefore, + the user must do all truncation logic in their preprocessing step when generating the jsonl + used by this class. Put all special truncation logic there specific to your model. + """ + + def __init__( + self, cfg, tokenizer, name, data_prefix, documents, data, seq_length, seed, drop_last=True, + ): + super().__init__() + self.cfg = cfg + self.name = name + self.data = data + self.drop_last = drop_last + self.seq_length = seq_length + self.tokenizer = tokenizer + + self.reset_position_ids = cfg.data.get("reset_position_ids", False) + self.reset_attention_mask = cfg.data.get("reset_attention_mask", False) + self.eod_mask_loss = cfg.data.get("eod_mask_loss", False) + self.eos_id = tokenizer.eos_id + + self.nograd_length = 32 + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < len(self.data) + + def __len__(self): + return len(self.data) + + def encode(self, text, append_eod=False): + if self.cfg.data.get("apply_ftfy", False): + import ftfy + + text = ftfy.fix_text(text) + + text_ids = self.tokenizer.text_to_ids(text) + + if len(text_ids) > 0 and append_eod: + text_ids.append(self.tokenizer.eos_id) + + return text_ids, len(text_ids) + + def __getitem__(self, idx): + """Returns a pair of chosen/rejected pairs, their respective lengths, and labels. + """ + payload = self.data[idx] + prompt, prompt_len = self.encode(payload["prompt"], append_eod=False) + responses = [] + labels = [] + + # loop on responses of the given prompt to encode them + for resp in payload["responses"]: + resp_tokens, resp_len = self.encode(resp, append_eod=self.cfg.data.get("append_eod", False)) + + resp_tokens = prompt + resp_tokens + resp_len = len(resp_tokens) + + responses.append((resp_tokens, resp_len)) + labels.append(([-100] * prompt_len) + resp_tokens[prompt_len:]) + + assert ( + resp_tokens[0:prompt_len] == prompt + ), "the tokenizer for DPO has merged tokens between prompt and response" + + max_curr_seq_len = max([i[1] for i in responses]) + if max_curr_seq_len > self.seq_length: + logging.warning( + f"WARNING: Tokenized text exceeds max seq length ({max_curr_seq_len} vs {self.seq_length})." + + f"The example will be ignored." + ) + + rewards = payload.get("rewards", [random.random() for _ in range(len(responses))]) + resp_dict = {} + + for ind, (resp, resp_len) in enumerate(responses): + resp_tokens = torch.nn.functional.pad( + torch.LongTensor(resp), (0, max_curr_seq_len - resp_len), mode="constant", value=self.eos_id + ) + label = labels[ind] + label_tokens = torch.nn.functional.pad( + torch.LongTensor(label), (0, max_curr_seq_len - len(label)), mode="constant", value=-100 + ) + + # slice if necessary + if max_curr_seq_len > self.seq_length: + resp_tokens = resp_tokens[: self.nograd_length] + label_tokens = torch.ones_like(resp_tokens) * (-100) + resp_len = self.nograd_length + + resp_dict["response_" + str(ind + 1)] = resp_tokens + resp_dict["labels_" + str(ind + 1)] = label_tokens + resp_dict["lengths_" + str(ind + 1)] = resp_len + resp_dict["rewards_" + str(ind + 1)] = rewards[ind] + + return resp_dict + + class KTOModelDataset(Dataset): """This class works only with jsonl files. It assumes each line of the json file is a dictionary with the prompt, along with the response (response only, no prompt), and the status denoting whether the response is diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py new file mode 100644 index 000000000..430e48ca6 --- /dev/null +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -0,0 +1,441 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial + +import torch +from apex.transformer.pipeline_parallel.utils import get_num_microbatches +from megatron.core import parallel_state +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo_aligner.models.alignable_interface import SupervisedInterface +from nemo_aligner.utils.distributed import broadcast_2d_tensor, from_parallel_logits_to_logprobs +from nemo_aligner.utils.train_utils import ( + finish_validation_step, + grad_reductions, + prepare_for_training_step, + prepare_for_validation_step, + set_sync_funcs, +) +from nemo_aligner.utils.utils import adapter_control, cpu_weight_swap + + +class MegatronGPTRPOModel(NLPAdapterModelMixin, MegatronGPTModel, SupervisedInterface): + """ + Megatron GPT RPO Model Training. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) + + if self.cfg.pipeline_model_parallel_size > 1 and not self.cfg.megatron_amp_O2: + warnings.warn( + "when using pipeline parallelism, it is recommended to set megatron_amp_O2 to be True to " + "avoid explicit casting for pipeline communication" + ) + self.automatic_optimization = False + self.ref_policy_state_dict = None + + self.preference_avg_log_probs = self.cfg.rpo.get("preference_average_log_probs", False) + self.sft_avg_log_probs = self.cfg.rpo.get("sft_average_log_probs", self.preference_avg_log_probs) + + self.preference_loss_weight = float(self.cfg.rpo.get("preference_loss_weight", 1)) + self.sft_loss_weight = float(self.cfg.rpo.get("sft_loss_weight", 0)) + assert ( + self.preference_loss_weight != 0 or self.sft_loss_weight != 0 + ), "sft loss weight and dpo loss weight cannot both be 0" + + # variants of preference losses, by default RPO. + self.preference_loss = self.cfg.rpo.get("preference_loss", "rpo") + self.gt_reward_scale = float(self.cfg.rpo.get("gt_reward_scale", 1.0)) + + self.beta = float(self.cfg.rpo.get("beta", 0.01)) + self.eta = float(self.cfg.rpo.get("eta", 0.01)) + self.k_len = int(self.cfg.rpo.get("num_responses", 2)) + + @torch.no_grad() + def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False): + pi_logprobs = pi_logprobs.detach() + + dp_group = parallel_state.get_data_parallel_group() + + batch_logs = self.get_reduced_masked_logps( + pi_logprobs - ref_logprobs, labels[:, 1:], average_log_probs=average_log_probs + ) + + output_list = [torch.zeros_like(batch_logs) for _ in range(dp_group.size())] + + torch.distributed.all_gather(output_list, batch_logs, group=dp_group) + + split_iter = map(self.split_output_tensor, output_list) + outputs = map(torch.cat, zip(*split_iter)) + flat_outputs = list(map(torch.flatten, outputs)) + + return flat_outputs + + def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False): + def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): + batch = next(dataloader_iter) + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + # there is a problem with apex ignoring the mask on the older models + # so we will always give the attention mask + required_keys.add("attention_mask") + + if parallel_state.is_pipeline_first_stage(): + required_keys.update(["response_" + str(i) for i in range(1, self.k_len + 1)]) + required_keys.update(("position_ids")) + + if parallel_state.is_pipeline_last_stage(): + required_keys.update(["response_" + str(i) for i in range(1, self.k_len + 1)]) + required_keys.update(["ref_policy_log_probs_response_" + str(i) for i in range(1, self.k_len + 1)]) + required_keys.update(["labels_" + str(i) for i in range(1, self.k_len + 1)]) + required_keys.update(["rewards_" + str(i) for i in range(1, self.k_len + 1)]) + + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + # creating tokens and labels tensor batches + tokens, labels, ref_logprobs, gt_rewards = None, None, None, None + if batch["response_1"] is not None: + tokens = torch.cat(tuple(batch["response_" + str(i + 1)] for i in range(self.k_len)), dim=0) + + if batch["labels_1"] is not None: + labels = torch.cat(tuple(batch["labels_" + str(i + 1)] for i in range(self.k_len)), dim=0) + + if batch["rewards_1"] is not None: + gt_rewards = torch.cat(tuple(batch["rewards_" + str(i + 1)] for i in range(self.k_len)), dim=0) + + if batch.get("ref_policy_log_probs_response_1") is not None: + ref_logprobs = torch.cat( + tuple(batch["ref_policy_log_probs_response_" + str(i + 1)] for i in range(self.k_len)), dim=0 + ) + + # this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs + # these two lines ensure your position_ids and attn_mask are always B=1 + attention_mask = batch["attention_mask"][0:1] + + # Model forward pass + forward_args = { + "input_ids": tokens, + "position_ids": batch["position_ids"], + "attention_mask": attention_mask, + "labels": None, + "loss_mask": None, + } + + # TODO: we can remove this someday when we no longer support legacy models + if not self.mcore_gpt: + forward_args["checkpoint_activations_all_layers"] = checkpoint_activations_all_layers + if not self.use_loss_mask: + forward_args.pop("loss_mask") + else: + forward_args.pop("loss_mask") + + output_tensor = model(**forward_args) + + # in this nemo version the model and autocast dtypes are not synced + # so we need to explicitly cast it + if not parallel_state.is_pipeline_last_stage(): + output_tensor = output_tensor.to(dtype=self.autocast_dtype) + + def logprobs_func(output_tensor, non_loss_data=True): + # This function is expected to be used only when `collect_non_loss_data=True` in the fwd_bwd_function of Megatron-LM. + # See https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/megatron/core/pipeline_parallel/schedules.py#L228 + assert non_loss_data + logprobs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, + ) + return {"logprobs": logprobs} + + def loss_func(output_tensor): + if validation_step and not self.cfg.data.get("validation_drop_last", True): + raise NotImplementedError("RPO does not support validation when cfg.data.drop_last=False") + + per_token_logps = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, + target=labels, + inference_only=validation_step, + higher_stability=True, + ) + + preference_loss, acc_best_resp = self.loss_func( + per_token_logps, + ref_logprobs, + labels[:, 1:], + gt_rewards, + average_log_probs=self.preference_avg_log_probs, + ) + + sft_loss = torch.zeros_like(preference_loss) + if self.sft_loss_weight != 0: + sft_loss = self.sft_loss_func( + per_token_logps, labels[:, 1:], gt_rewards, average_log_probs=self.sft_avg_log_probs + ) + loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss + + ( + reduced_loss, + reduced_preference_loss, + reduced_sft_loss, + reduced_acc, + ) = average_losses_across_data_parallel_group([loss, preference_loss, sft_loss, acc_best_resp]) + + out_responses = self.gather_and_split_rewards( + per_token_logps, ref_logprobs, labels, average_log_probs=self.preference_avg_log_probs + ) + + return ( + loss, + { + "avg": reduced_loss, + "avg_sft_loss": reduced_sft_loss, + "avg_preference_loss": reduced_preference_loss, + "acc": reduced_acc, + "out_responses": out_responses, + }, + ) + + if logprobs_only: + return output_tensor, logprobs_func + else: + return output_tensor, loss_func + + return fwd_output_and_loss_func + + def split_output_tensor(self, output_tensor): + responses_logps = torch.split(output_tensor.float(), len(output_tensor) // self.k_len, dim=0) + return responses_logps + + def get_reduced_masked_logps(self, logps, labels, average_log_probs=False): + assert logps.shape == labels.shape, "logps and labels shape mismatch" + + loss_mask = (labels > -1).float() + + if average_log_probs: + # need to guard against divide by zero in case labels are all -100 + return (logps * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1) + else: + return (logps * loss_mask).sum(-1) + + def log_sum_exp(self, x): + max_x = torch.max(x) + return max_x + torch.log(torch.sum(torch.exp(x - max_x))) + + def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): + if self.preference_loss == "rpo": + # estimated rewards + rewards_pred = torch.stack( + self.split_output_tensor( + self.get_reduced_masked_logps( + self.beta * (pi_logprobs - ref_logprobs), labels, average_log_probs=average_log_probs + ) + ) + ) + + # based on GT rewards + gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) + p_star = self.eta * gt_rewards + else: + raise ValueError("Unknown RPO Loss") + + loss = ( + ( + torch.nn.functional.softmax(p_star, dim=0) + * ( + torch.nn.functional.log_softmax(p_star, dim=0) + - torch.nn.functional.log_softmax(rewards_pred, dim=0) + ) + ) + .sum(0) + .mean(0) + ) + + # adding accuracy for the best rewards -> MSE or best accuracy? + acc_best_resp = (torch.argmax(rewards_pred, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean() + + return loss, acc_best_resp + + def sft_loss_func(self, pi_logprobs, labels, gt_rewards, average_log_probs=False): + logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) # [16] + all_log_probs = torch.stack( + self.split_output_tensor(logprobs) + ) # [4, 4] -> each has several responses which we select the best? + gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) # same, we split the rewards + chosen_best = torch.argmax(gt_rewards, dim=0) + + chosen_logprobs = all_log_probs[chosen_best, torch.arange(all_log_probs.size(1))] + return -chosen_logprobs.mean(0) + + def get_loss_and_metrics(self, batch, forward_only): + seq_length = batch["response_1"].shape[1] + + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only, logprobs_only=False), + data_iterator=data_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=self.cfg.micro_batch_size + * self.k_len, # each minibatch has K comparisons so tensor shape will be mbs * num_responses + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + # NOTE: assume that the returned values are already gathered across the DP workers + collected_rewards_per_resp = [] + for i in range(self.k_len): + collected_rewards_per_resp.append( + torch.cat([item["out_responses"][i] for item in losses_reduced_per_micro_batch]) + ) + + rewards_all = torch.cat(tuple(collected_rewards_per_resp)) + rewards_all_mean = rewards_all.mean() + rewards_all_std = rewards_all.std() + + loss_mean = torch.as_tensor( + [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + sft_loss_mean = torch.as_tensor( + [loss_reduced["avg_sft_loss"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + preference_loss_mean = torch.as_tensor( + [loss_reduced["avg_preference_loss"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + acc_mean = torch.as_tensor( + [loss_reduced["acc"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + else: + loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + sft_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + preference_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + acc_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + rewards_all_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + rewards_all_std = torch.tensor(0.0, device=torch.cuda.current_device()) + + # we can only log on one rank if it is rank zero so we broadcast from last rank + torch.distributed.broadcast(loss_mean, get_last_rank()) + torch.distributed.broadcast(sft_loss_mean, get_last_rank()) + torch.distributed.broadcast(preference_loss_mean, get_last_rank()) + torch.distributed.broadcast(acc_mean, get_last_rank()) + + torch.distributed.broadcast(rewards_all_mean, get_last_rank()) + torch.distributed.broadcast(rewards_all_std, get_last_rank()) + + metrics = { + "loss": loss_mean, + "sft_loss": sft_loss_mean, + "preference_loss": preference_loss_mean, + "acc": acc_mean, + "rewards_all_mean": rewards_all_mean, + "rewards_all_std": rewards_all_std, + } + + # move to CPU + metrics = {k: v.item() for k, v in metrics.items()} + + return loss_mean.item(), metrics + + def prepare_for_training_step(self): + # custom trainers will always zero grad for us + prepare_for_training_step(self, zero_grad=False) + + def finish_training_step(self): + grad_reductions(self) + + def prepare_for_validation_step(self): + prepare_for_validation_step(self) + + def finish_validation_step(self): + finish_validation_step(self) + + @torch.no_grad() + def get_logprob_batch(self, batch): + seq_length = batch["response_1"].shape[1] + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + set_sync_funcs(self, forward_only=True) + + fwd_bwd_function = get_forward_backward_func() + + logprobs_list = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(logprobs_only=True), + data_iterator=data_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=True, + seq_length=seq_length, + micro_batch_size=self.cfg.rpo.log_prob_forward_micro_batch_size, + collect_non_loss_data=True, + ) + + each_response_list = [[] for _ in range(self.k_len)] + + if len(logprobs_list) > 0: + for item in logprobs_list: + all_log_probs = self.split_output_tensor(item["logprobs"]) + for ind in range(self.k_len): + each_response_list[ind].extend(all_log_probs[ind]) + each_response_list = [torch.stack(b, dim=0) for b in each_response_list] + logprobs = torch.cat(each_response_list, dim=0) + + else: + logprobs = None + + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + # broadcast it from last PP stage to everything else + logprobs = broadcast_2d_tensor( + logprobs, + parallel_state.get_pipeline_model_parallel_last_rank(), + parallel_state.get_pipeline_model_parallel_group(), + ) + + return logprobs + + def get_ref_policy_logprobs(self, batch): + + if self.use_peft and self.ref_policy_state_dict is None: + # when using adapters instead of full-tuning, the actor is reference model + adapters + with adapter_control(self): + # With adapters disabled (meaning using the reference model), calculate ref_log_probs + ref_log_probs = self.get_logprob_batch(batch) + else: + with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2): + ref_log_probs = self.get_logprob_batch(batch) + + # return in GPU, trainer needs to move to cpu + return ref_log_probs