From c713cecb4b56c3c5dee2a681734102f6f00422f1 Mon Sep 17 00:00:00 2001 From: Shengyang Sun Date: Wed, 3 Jul 2024 15:16:25 -0700 Subject: [PATCH 01/17] fix log probs mismatch Signed-off-by: Shengyang Sun --- examples/nlp/gpt/conf/gpt_dpo.yaml | 1 - .../models/nlp/gpt/megatron_gpt_dpo_model.py | 90 ++++++++----------- 2 files changed, 39 insertions(+), 52 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 07c5fe690..8e4f8addc 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -59,7 +59,6 @@ model: megatron_amp_O2: True dpo: - log_prob_forward_micro_batch_size: 1 ref_policy_kl_penalty: 0.2 preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index d5174fa4b..1f186b066 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -92,7 +92,7 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_lo return out_chosen.flatten(), out_rejected.flatten() - def get_forward_output_and_loss_func(self, validation_step=False): + def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): batch = next(dataloader_iter) @@ -128,7 +128,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) - if batch["ref_policy_log_probs_chosen"] is not None and 
batch["ref_policy_log_probs_rejected"] is not None: + if ("ref_policy_log_probs_chosen" in batch + and "ref_policy_log_probs_rejected" in batch + and batch["ref_policy_log_probs_chosen"] is not None + and batch["ref_policy_log_probs_rejected"] is not None): ref_logprobs = torch.cat( (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 ) @@ -165,6 +168,13 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if not parallel_state.is_pipeline_last_stage(): output_tensor = output_tensor.to(dtype=self.autocast_dtype) + def logprobs_func(output_tensor, non_loss_data=True): + logprobs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, + ) + + return {"logprobs": logprobs} + def loss_func(output_tensor): if validation_step and not self.cfg.data.get("validation_drop_last", True): raise NotImplementedError("DPO does not support validation when cfg.data.drop_last=False") @@ -211,7 +221,10 @@ def loss_func(output_tensor): }, ) - return output_tensor, loss_func + if logprobs_only: + return output_tensor, logprobs_func + else: + return output_tensor, loss_func return fwd_output_and_loss_func @@ -295,7 +308,7 @@ def get_loss_and_metrics(self, batch, forward_only): fwd_bwd_function = get_forward_backward_func() losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(forward_only), + forward_step_func=self.get_forward_output_and_loss_func(forward_only, logprobs_only=False), data_iterator=data_iter, model=self.model, num_microbatches=get_num_microbatches(), @@ -384,57 +397,37 @@ def prepare_for_validation_step(self): def finish_validation_step(self): finish_validation_step(self) - def get_logprob_output_only_func(self, inference_only=True): - fwd_output_only_func = self.get_forward_output_only_func() - - def log_prob_output_only_func(dataloader_iter, model): - batch = 
next(dataloader_iter) - logits, _ = fwd_output_only_func(iter([batch[0:3],]), model) - - def id_func(logits, non_loss_data=True): - logprobs = from_parallel_logits_to_logprobs( - vocab_parallel_logits=logits, - target=batch[-1].cuda() if len(batch) == 4 else batch[0].cuda(), - inference_only=inference_only, - higher_stability=True, - ) - return {"logprobs": logprobs} - - return logits, id_func - - return log_prob_output_only_func - @torch.no_grad() - def get_logprob_batch(self, global_batch): - set_sync_funcs(self, forward_only=True) - - # assumes we pad to seq length before going into the model - # response_tokens = sequences.cuda() - # labels = labels.cuda() if labels is not None else None - - dp_size = parallel_state.get_data_parallel_world_size() - local_batch_size, seq_len = global_batch[0].shape - global_batch_size = local_batch_size * dp_size - - forward_mbs = self.cfg.dpo.log_prob_forward_micro_batch_size - forward_mbs_times_dp = forward_mbs * dp_size + def get_logprob_batch(self, batch): + seq_length = batch["chosen"].shape[1] - data_iter = get_iterator_k_split(global_batch, global_batch_size // forward_mbs_times_dp) + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only=True) fwd_bwd_function = get_forward_backward_func() + logprobs_list = fwd_bwd_function( - forward_step_func=self.get_logprob_output_only_func(inference_only=True), + forward_step_func=self.get_forward_output_and_loss_func(logprobs_only=True), data_iterator=data_iter, model=self.model, - num_microbatches=global_batch_size // forward_mbs_times_dp, + num_microbatches=get_num_microbatches(), forward_only=True, - seq_length=seq_len, - micro_batch_size=forward_mbs, + seq_length=seq_length, + micro_batch_size=self.cfg.micro_batch_size + * 2, # each minibatch has 2 comparisons so tensor shape will be mbs * 2 collect_non_loss_data=True, ) if len(logprobs_list) > 0: - logprobs = torch.cat([item["logprobs"] for item in logprobs_list]) + 
chosen_logprobs_list = [] + rejected_logprobs_list = [] + for item in logprobs_list: + chosen_logprobs, rejected_logprobs = self.split_output_tensor(item["logprobs"]) + chosen_logprobs_list.append(chosen_logprobs) + rejected_logprobs_list.append(rejected_logprobs) + + logprobs = torch.cat([torch.cat(chosen_logprobs_list), + torch.cat(rejected_logprobs_list)], dim=0) else: logprobs = None @@ -449,20 +442,15 @@ def get_logprob_batch(self, global_batch): return logprobs def get_ref_policy_logprobs(self, batch): - tokens = torch.cat((batch["chosen"], batch["rejected"]), dim=0) - masks = torch.cat((batch["attention_mask"], batch["attention_mask"]), dim=0) - pos_ids = torch.cat((batch["position_ids"], batch["position_ids"]), dim=0) - labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) - global_batch = [tokens, masks, pos_ids, labels] - + if self.use_peft and self.ref_policy_state_dict is None: # when using adapters instead of full-tuning, the actor is reference model + adapters with adapter_control(self): # With adapters disabled (meaning using the reference model), calculate ref_log_probs - ref_log_probs = self.get_logprob_batch(global_batch) + ref_log_probs = self.get_logprob_batch(batch) else: with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2): - ref_log_probs = self.get_logprob_batch(global_batch) + ref_log_probs = self.get_logprob_batch(batch) # return in GPU, trainer needs to move to cpu return ref_log_probs From 0b11982452864e1e026fd5e5428cddeb0f1ed11c Mon Sep 17 00:00:00 2001 From: Shengyang Sun Date: Wed, 3 Jul 2024 15:16:25 -0700 Subject: [PATCH 02/17] fix log probs mismatch Signed-off-by: Shengyang Sun --- CHANGELOG.md | 3 +- examples/nlp/gpt/conf/gpt_dpo.yaml | 1 - .../models/nlp/gpt/megatron_gpt_dpo_model.py | 90 ++++++++----------- 3 files changed, 41 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ee097ea0..306df1786 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -5,7 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Next Version] - Implement reward-aware preference optimization. - +- Fix log probs mismatch issue between policy and reference policy in DPO & variants. + ### New features and optimizations ### Breaking changes diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 07c5fe690..8e4f8addc 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -59,7 +59,6 @@ model: megatron_amp_O2: True dpo: - log_prob_forward_micro_batch_size: 1 ref_policy_kl_penalty: 0.2 preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index d5174fa4b..1f186b066 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -92,7 +92,7 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_lo return out_chosen.flatten(), out_rejected.flatten() - def get_forward_output_and_loss_func(self, validation_step=False): + def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): batch = next(dataloader_iter) @@ -128,7 +128,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) - if batch["ref_policy_log_probs_chosen"] is not None and batch["ref_policy_log_probs_rejected"] is not None: + if 
("ref_policy_log_probs_chosen" in batch + and "ref_policy_log_probs_rejected" in batch + and batch["ref_policy_log_probs_chosen"] is not None + and batch["ref_policy_log_probs_rejected"] is not None): ref_logprobs = torch.cat( (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 ) @@ -165,6 +168,13 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if not parallel_state.is_pipeline_last_stage(): output_tensor = output_tensor.to(dtype=self.autocast_dtype) + def logprobs_func(output_tensor, non_loss_data=True): + logprobs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, + ) + + return {"logprobs": logprobs} + def loss_func(output_tensor): if validation_step and not self.cfg.data.get("validation_drop_last", True): raise NotImplementedError("DPO does not support validation when cfg.data.drop_last=False") @@ -211,7 +221,10 @@ def loss_func(output_tensor): }, ) - return output_tensor, loss_func + if logprobs_only: + return output_tensor, logprobs_func + else: + return output_tensor, loss_func return fwd_output_and_loss_func @@ -295,7 +308,7 @@ def get_loss_and_metrics(self, batch, forward_only): fwd_bwd_function = get_forward_backward_func() losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(forward_only), + forward_step_func=self.get_forward_output_and_loss_func(forward_only, logprobs_only=False), data_iterator=data_iter, model=self.model, num_microbatches=get_num_microbatches(), @@ -384,57 +397,37 @@ def prepare_for_validation_step(self): def finish_validation_step(self): finish_validation_step(self) - def get_logprob_output_only_func(self, inference_only=True): - fwd_output_only_func = self.get_forward_output_only_func() - - def log_prob_output_only_func(dataloader_iter, model): - batch = next(dataloader_iter) - logits, _ = 
fwd_output_only_func(iter([batch[0:3],]), model) - - def id_func(logits, non_loss_data=True): - logprobs = from_parallel_logits_to_logprobs( - vocab_parallel_logits=logits, - target=batch[-1].cuda() if len(batch) == 4 else batch[0].cuda(), - inference_only=inference_only, - higher_stability=True, - ) - return {"logprobs": logprobs} - - return logits, id_func - - return log_prob_output_only_func - @torch.no_grad() - def get_logprob_batch(self, global_batch): - set_sync_funcs(self, forward_only=True) - - # assumes we pad to seq length before going into the model - # response_tokens = sequences.cuda() - # labels = labels.cuda() if labels is not None else None - - dp_size = parallel_state.get_data_parallel_world_size() - local_batch_size, seq_len = global_batch[0].shape - global_batch_size = local_batch_size * dp_size - - forward_mbs = self.cfg.dpo.log_prob_forward_micro_batch_size - forward_mbs_times_dp = forward_mbs * dp_size + def get_logprob_batch(self, batch): + seq_length = batch["chosen"].shape[1] - data_iter = get_iterator_k_split(global_batch, global_batch_size // forward_mbs_times_dp) + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only=True) fwd_bwd_function = get_forward_backward_func() + logprobs_list = fwd_bwd_function( - forward_step_func=self.get_logprob_output_only_func(inference_only=True), + forward_step_func=self.get_forward_output_and_loss_func(logprobs_only=True), data_iterator=data_iter, model=self.model, - num_microbatches=global_batch_size // forward_mbs_times_dp, + num_microbatches=get_num_microbatches(), forward_only=True, - seq_length=seq_len, - micro_batch_size=forward_mbs, + seq_length=seq_length, + micro_batch_size=self.cfg.micro_batch_size + * 2, # each minibatch has 2 comparisons so tensor shape will be mbs * 2 collect_non_loss_data=True, ) if len(logprobs_list) > 0: - logprobs = torch.cat([item["logprobs"] for item in logprobs_list]) + chosen_logprobs_list = [] + rejected_logprobs_list = 
[] + for item in logprobs_list: + chosen_logprobs, rejected_logprobs = self.split_output_tensor(item["logprobs"]) + chosen_logprobs_list.append(chosen_logprobs) + rejected_logprobs_list.append(rejected_logprobs) + + logprobs = torch.cat([torch.cat(chosen_logprobs_list), + torch.cat(rejected_logprobs_list)], dim=0) else: logprobs = None @@ -449,20 +442,15 @@ def get_logprob_batch(self, global_batch): return logprobs def get_ref_policy_logprobs(self, batch): - tokens = torch.cat((batch["chosen"], batch["rejected"]), dim=0) - masks = torch.cat((batch["attention_mask"], batch["attention_mask"]), dim=0) - pos_ids = torch.cat((batch["position_ids"], batch["position_ids"]), dim=0) - labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) - global_batch = [tokens, masks, pos_ids, labels] - + if self.use_peft and self.ref_policy_state_dict is None: # when using adapters instead of full-tuning, the actor is reference model + adapters with adapter_control(self): # With adapters disabled (meaning using the reference model), calculate ref_log_probs - ref_log_probs = self.get_logprob_batch(global_batch) + ref_log_probs = self.get_logprob_batch(batch) else: with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2): - ref_log_probs = self.get_logprob_batch(global_batch) + ref_log_probs = self.get_logprob_batch(batch) # return in GPU, trainer needs to move to cpu return ref_log_probs From 78ef52911bf3a9e55f4c258d96c7f4b2652766fa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Jul 2024 22:24:31 +0000 Subject: [PATCH 03/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py 
b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 1f186b066..b36e5bc2d 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -128,10 +128,12 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) - if ("ref_policy_log_probs_chosen" in batch + if ( + "ref_policy_log_probs_chosen" in batch and "ref_policy_log_probs_rejected" in batch - and batch["ref_policy_log_probs_chosen"] is not None - and batch["ref_policy_log_probs_rejected"] is not None): + and batch["ref_policy_log_probs_chosen"] is not None + and batch["ref_policy_log_probs_rejected"] is not None + ): ref_logprobs = torch.cat( (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 ) @@ -172,7 +174,7 @@ def logprobs_func(output_tensor, non_loss_data=True): logprobs = from_parallel_logits_to_logprobs( vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, ) - + return {"logprobs": logprobs} def loss_func(output_tensor): @@ -426,8 +428,7 @@ def get_logprob_batch(self, batch): chosen_logprobs_list.append(chosen_logprobs) rejected_logprobs_list.append(rejected_logprobs) - logprobs = torch.cat([torch.cat(chosen_logprobs_list), - torch.cat(rejected_logprobs_list)], dim=0) + logprobs = torch.cat([torch.cat(chosen_logprobs_list), torch.cat(rejected_logprobs_list)], dim=0) else: logprobs = None @@ -442,7 +443,7 @@ def get_logprob_batch(self, batch): return logprobs def get_ref_policy_logprobs(self, batch): - + if self.use_peft and self.ref_policy_state_dict is None: # when using adapters instead of full-tuning, the actor is reference model + adapters with adapter_control(self): From faee17ddb18e0355feae180ba367fe5416a52e8f Mon Sep 17 00:00:00 2001 From: Shengyang Sun Date: Fri, 
5 Jul 2024 12:14:00 -0700 Subject: [PATCH 04/17] enable log_prob_forward_micro_batch_size Signed-off-by: Shengyang Sun --- examples/nlp/gpt/conf/gpt_dpo.yaml | 1 + .../models/nlp/gpt/megatron_gpt_dpo_model.py | 13 +++++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 8e4f8addc..5958dcb49 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -59,6 +59,7 @@ model: megatron_amp_O2: True dpo: + log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, 2} ref_policy_kl_penalty: 0.2 preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 1f186b066..3df3cb25b 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -128,10 +128,8 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) - if ("ref_policy_log_probs_chosen" in batch - and "ref_policy_log_probs_rejected" in batch - and batch["ref_policy_log_probs_chosen"] is not None - and batch["ref_policy_log_probs_rejected"] is not None): + if (batch.get("ref_policy_log_probs_chosen") is not None + and batch.get("ref_policy_log_probs_rejected") is not None): ref_logprobs = torch.cat( (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 ) @@ -180,7 +178,7 @@ def loss_func(output_tensor): raise NotImplementedError("DPO does not support validation when 
cfg.data.drop_last=False") per_token_logps = from_parallel_logits_to_logprobs( - vocab_parallel_logits=output_tensor, target=labels, higher_stability=True + vocab_parallel_logits=output_tensor, target=labels, inference_only=validation_step, higher_stability=True ) preference_loss, acc_chosen = self.loss_func( @@ -249,7 +247,7 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p ) chosen_rewards, reject_rewards = self.split_output_tensor(rewards) rewards_delta = chosen_rewards - reject_rewards - + if self.preference_loss == "dpo": loss = -torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta).mean(0) elif self.preference_loss == "rpo_bwd_kl": @@ -413,8 +411,7 @@ def get_logprob_batch(self, batch): num_microbatches=get_num_microbatches(), forward_only=True, seq_length=seq_length, - micro_batch_size=self.cfg.micro_batch_size - * 2, # each minibatch has 2 comparisons so tensor shape will be mbs * 2 + micro_batch_size=self.cfg.dpo.log_prob_forward_micro_batch_size, collect_non_loss_data=True, ) From dd6c271dfe7eb9831d3075417a4bd3f163c23a3c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Jul 2024 19:18:51 +0000 Subject: [PATCH 05/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index aa3ce4916..c968e41d1 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -128,8 +128,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: labels = torch.cat((batch["chosen_labels"], 
batch["rejected_labels"]), dim=0) - if (batch.get("ref_policy_log_probs_chosen") is not None - and batch.get("ref_policy_log_probs_rejected") is not None): + if ( + batch.get("ref_policy_log_probs_chosen") is not None + and batch.get("ref_policy_log_probs_rejected") is not None + ): ref_logprobs = torch.cat( (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 ) @@ -177,7 +179,10 @@ def loss_func(output_tensor): raise NotImplementedError("DPO does not support validation when cfg.data.drop_last=False") per_token_logps = from_parallel_logits_to_logprobs( - vocab_parallel_logits=output_tensor, target=labels, inference_only=validation_step, higher_stability=True + vocab_parallel_logits=output_tensor, + target=labels, + inference_only=validation_step, + higher_stability=True, ) preference_loss, acc_chosen = self.loss_func( @@ -246,7 +251,7 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p ) chosen_rewards, reject_rewards = self.split_output_tensor(rewards) rewards_delta = chosen_rewards - reject_rewards - + if self.preference_loss == "dpo": loss = -torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta).mean(0) elif self.preference_loss == "rpo_bwd_kl": @@ -437,7 +442,7 @@ def get_logprob_batch(self, batch): return logprobs def get_ref_policy_logprobs(self, batch): - + if self.use_peft and self.ref_policy_state_dict is None: # when using adapters instead of full-tuning, the actor is reference model + adapters with adapter_control(self): From bfe31353f1e39242206b361faf4f87f2e5365cd2 Mon Sep 17 00:00:00 2001 From: Shengyang Sun Date: Thu, 11 Jul 2024 06:19:23 -0700 Subject: [PATCH 06/17] add comments Signed-off-by: Shengyang Sun --- examples/nlp/gpt/conf/gpt_dpo.yaml | 2 ++ nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 5958dcb49..dc745c726 100644 --- 
a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -59,6 +59,8 @@ model: megatron_amp_O2: True dpo: + # This default value ensures there are no numeric differences beween trained and reference policies when computing log probs. + # A higher value can be used to speed-up log probs computations, but may cause numeric differences. log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, 2} ref_policy_kl_penalty: 0.2 preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index aa3ce4916..5b2c15721 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -167,6 +167,9 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ output_tensor = output_tensor.to(dtype=self.autocast_dtype) def logprobs_func(output_tensor, non_loss_data=True): + # This function is expected to be used only when `collect_non_loss_data=True` in the fwd_bwd_function of Megatron-LM. 
+ # See https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/megatron/core/pipeline_parallel/schedules.py#L228 + assert non_loss_data logprobs = from_parallel_logits_to_logprobs( vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, ) From 0a8128468e21e8a9c54112444498366aa591b0fa Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Fri, 6 Sep 2024 08:57:18 -0700 Subject: [PATCH 07/17] added RPO for multiple responses (4) Signed-off-by: David Mosallanezhad --- examples/nlp/gpt/conf/gpt_dpo.yaml | 22 +- examples/nlp/gpt/conf/gpt_rpo.yaml | 140 +++++++ examples/nlp/gpt/train_gpt_rpo.py | 162 +++++++ nemo_aligner/algorithms/rpo.py | 332 +++++++++++++++ nemo_aligner/data/nlp/builders.py | 6 +- nemo_aligner/data/nlp/datasets.py | 109 +++++ .../models/nlp/gpt/megatron_gpt_dpo_model.py | 10 +- .../models/nlp/gpt/megatron_gpt_rpo_model.py | 396 ++++++++++++++++++ nemo_aligner/utils/train_utils.py | 1 + nemo_aligner/utils/utils.py | 1 + 10 files changed, 1168 insertions(+), 11 deletions(-) create mode 100644 examples/nlp/gpt/conf/gpt_rpo.yaml create mode 100644 examples/nlp/gpt/train_gpt_rpo.py create mode 100644 nemo_aligner/algorithms/rpo.py create mode 100644 nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index dc745c726..88be4f3fe 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -2,8 +2,8 @@ defaults: - optional tp_overlap@model.ub_tp_comm_overlap_cfg: trainer: - num_nodes: 8 - devices: 8 + num_nodes: 1 + devices: 4 accelerator: gpu precision: bf16 @@ -12,7 +12,7 @@ trainer: max_epochs: 1 max_steps: -1 val_check_interval: 0.1 - save_interval: 100 + save_interval: 5 limit_train_batches: 1.0 # how many GBS we loop over @@ -50,12 +50,12 @@ exp_manager: model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} 
pretrained_checkpoint: - restore_from_path: null + restore_from_path: /data/GPT-2B-001_bf16_tp1.nemo model: - mcore_gpt: True - micro_batch_size: 1 - global_batch_size: 64 + mcore_gpt: False + micro_batch_size: 4 + global_batch_size: 32 megatron_amp_O2: True dpo: @@ -121,7 +121,13 @@ model: reset_attention_mask: False # Reset attention mask after end-of-document token eod_mask_loss: False # Mask loss for the end of document tokens index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix - data_prefix: null + data_prefix: + train: + - /data/sample_dpo_dataset.jsonl + test: + - /data/sample_dpo_dataset.jsonl + validation: + - /data/sample_dpo_dataset.jsonl default_chosen_reward: 1. # the default reward for the chosen response in RPO default_rejected_reward: 0. # the default reward for the rejected response in RPO diff --git a/examples/nlp/gpt/conf/gpt_rpo.yaml b/examples/nlp/gpt/conf/gpt_rpo.yaml new file mode 100644 index 000000000..aa8fd535e --- /dev/null +++ b/examples/nlp/gpt/conf/gpt_rpo.yaml @@ -0,0 +1,140 @@ +defaults: + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: + +trainer: + num_nodes: 1 + devices: 8 + accelerator: gpu + precision: bf16 + + # rpo specific args + rpo: + max_epochs: 100 + max_steps: -1 + val_check_interval: 0.1 + save_interval: 100 + limit_train_batches: 1.0 + + # how many GBS we loop over + limit_val_batches: 1.0 + gradient_clip_val: 1.0 + + # do not change these + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_time: null + max_epochs: ${.rpo.max_epochs} + max_steps: ${.rpo.max_steps} + +exp_manager: + explicit_log_dir: /results + exp_dir: null + name: megatron_gpt + max_time_per_run: ${trainer.max_time} + create_wandb_logger: False + wandb_logger_kwargs: + project: nemo_aligner_rpo + name: rlhf_gpt3_rpo + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores 
the whole state including the epoch, step, LR schedulers, apex, etc. + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}' + model_parallel_size: 4 + +pretrained_checkpoint: + restore_from_path: /models/llama3_8b_sft_alpha_nodes8_tp4_3e-6_bs384_rerun_1200.nemo + +model: + mcore_gpt: True + micro_batch_size: 1 + global_batch_size: 16 + megatron_amp_O2: True + + rpo: + # This default value ensures there are no numeric differences beween trained and reference policies when computing log probs. + # A higher value can be used to speed-up log probs computations, but may cause numeric differences. + log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, 2} + preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss + sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss + gt_reward_scale: 1. 
# the scale of the rewards in RPO + preference_loss: rpo_multi_response # the preference loss, we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl + preference_loss_weight: 1 # the coefficient of the preference loss + sft_loss_weight: 0.05 # the coefficient of the SFT loss + beta: 0.2 + eta: 0.2 + + #encoder_seq_length: 4096 + #max_position_embeddings: ${model.encoder_seq_length} + + # miscellaneous + seed: 1234 + + #peft + peft: + peft_scheme: "none" # ["lora", "none"] + restore_from_path: null + restore_from_ckpt: + checkpoint_dir: null + checkpoint_name: null + + lora_tuning: + target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. 
null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + optim: + name: distributed_fused_adam + bucket_cap_mb: 200 + overlap_grad_sync: False + contiguous_grad_buffer: True + lr: 9e-6 + weight_decay: 0.1 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 10 + constant_steps: 1000 + min_lr: 9e-7 + + data: + data_impl: jsonl + splits_string: null + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_prefix: + train: + - /data/responses_general.rm.formatted.jsonl + test: + - /data/small_test_formatted.jsonl + validation: + - /data/small_test_formatted.jsonl + default_chosen_reward: 1. # the default reward for the chosen response in RPO + default_rejected_reward: 0. # the default reward for the rejected response in RPO + + # define fields from the base model's config that should be ignored when merging with this config. + overwrite_base_config: + data: + data_prefix: True + +precision: ${trainer.precision} diff --git a/examples/nlp/gpt/train_gpt_rpo.py b/examples/nlp/gpt/train_gpt_rpo.py new file mode 100644 index 000000000..0f850ff56 --- /dev/null +++ b/examples/nlp/gpt/train_gpt_rpo.py @@ -0,0 +1,162 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo_aligner.algorithms.rpo import RPOTrainer, rpo_custom_collate +from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_rpo_datasets +from nemo_aligner.models.nlp.gpt.megatron_gpt_rpo_model import MegatronGPTRPOModel +from nemo_aligner.utils.distributed import Timer +from nemo_aligner.utils.train_script_utils import ( + CustomLoggerWrapper, + add_custom_checkpoint_callback, + extract_optimizer_scheduler_from_ptl_model, + init_distributed, + init_peft, + init_using_ptl, + resolve_and_create_trainer, + retrieve_custom_trainer_state_dict, +) +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo, retrieve_model_state_dict_in_cpu + +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) +OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True) + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="gpt_rpo") +def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.pretrained_checkpoint.restore_from_path, cfg.model) + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + + trainer = resolve_and_create_trainer(cfg, "rpo") + exp_manager(trainer, cfg.exp_manager) + logger = 
CustomLoggerWrapper(trainer.loggers) + + ptl_model = load_from_nemo( + MegatronGPTRPOModel, + cfg.model, + trainer, + strict=True, + load_base_model_only=False, + restore_path=cfg.pretrained_checkpoint.restore_from_path, + ) + + init_peft(ptl_model, cfg.model) + + if cfg.model.peft.peft_scheme == "none": + ref_policy_state_dict = retrieve_model_state_dict_in_cpu( + ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False) + ) + ptl_model.ref_policy_state_dict = ref_policy_state_dict + + # pull values from checkpoint + trainer_restore_path = trainer.ckpt_path + + # TODO: log this restore path + if trainer_restore_path is not None: + custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) + consumed_samples = custom_trainer_state_dict["consumed_samples"] + else: + custom_trainer_state_dict = None + consumed_samples = 0 + + init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) + + # use the entire dataset + train_valid_test_num_samples = [-1 * cfg.model.global_batch_size] * 3 + + train_ds, validation_ds, test_ds = build_train_valid_test_rpo_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl=cfg.model.data.data_impl, + splits_string=cfg.model.data.splits_string, + train_valid_test_num_samples=train_valid_test_num_samples, + seq_length=cfg.model.data.seq_length, + seed=cfg.model.seed, + tokenizer=ptl_model.tokenizer, + ) + + train_dataloader = build_dataloader( + cfg=cfg, + dataset=train_ds, + consumed_samples=consumed_samples, + mbs=cfg.model.micro_batch_size, + gbs=cfg.model.global_batch_size, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=partial( + rpo_custom_collate, + eos_id=ptl_model.tokenizer.eos_id, + reset_position_ids=cfg.model.data.get("reset_position_ids", False), + reset_attention_mask=cfg.model.data.get("reset_attention_mask", False), + eod_mask_loss=cfg.model.data.get("eod_mask_loss", False), + ), + ) + + val_dataloader = build_dataloader( + 
cfg=cfg, + dataset=validation_ds, + consumed_samples=0, + mbs=cfg.model.micro_batch_size, + gbs=cfg.model.global_batch_size, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=partial( + rpo_custom_collate, + eos_id=ptl_model.tokenizer.eos_id, + reset_position_ids=cfg.model.data.get("reset_position_ids", False), + reset_attention_mask=cfg.model.data.get("reset_attention_mask", False), + eod_mask_loss=cfg.model.data.get("eod_mask_loss", False), + ), + use_random_sampler=False, + ) + + init_using_ptl(trainer, ptl_model, train_dataloader, train_ds) + optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) + + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + + logger.log_hyperparams(OmegaConf.to_container(cfg)) + + timer = Timer(cfg.exp_manager.get("max_time_per_run")) + dpo_trainer = RPOTrainer( + cfg=cfg.trainer.rpo, + model=ptl_model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + test_dataloader=None, + logger=logger, + ckpt_callback=ckpt_callback, + run_timer=timer, + ) + + if custom_trainer_state_dict is not None: + dpo_trainer.load_state_dict(custom_trainer_state_dict) + + dpo_trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/nemo_aligner/algorithms/rpo.py b/nemo_aligner/algorithms/rpo.py new file mode 100644 index 000000000..221e2b7ed --- /dev/null +++ b/nemo_aligner/algorithms/rpo.py @@ -0,0 +1,332 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from statistics import mean + +import torch +from omegaconf.dictconfig import DictConfig +from tqdm import tqdm + +from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( + MegatronPretrainingRandomBatchSampler, +) +from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids +from nemo.utils import logging +from nemo_aligner.utils.distributed import SyncTimer +from nemo_aligner.utils.train_utils import clip_gradients +from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch +from nemo_aligner.utils.utils import clear_memory + + +def rpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False): + resp_outputs = {} + ## assume len 4 for responses + for k in batch[0].keys(): + if k.startswith('response_'): + # get response_i + resp_outputs[k] = torch.nn.utils.rnn.pad_sequence( + [item[k] for item in batch], + batch_first=True, padding_value=eos_id) + elif k.startswith('labels_'): + # get labels_i + resp_outputs[k] = torch.nn.utils.rnn.pad_sequence( + [item[k] for item in batch], + batch_first=True, padding_value=-100) + elif k.startswith('lengths_'): + # get lens_i + resp_outputs[k] = torch.LongTensor([item[k] for item in batch]) + elif k.startswith('rewards_'): + # get r_i + resp_outputs[k] = torch.FloatTensor([item[k] for item in batch]) + + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + resp_outputs['response_1'], eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss, + ) + assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate" + if attention_mask.shape[0] == 1: + # using .expand() here causes errors from pin_memory=True, so need to use .repeat() + # attention_mask = 
attention_mask.expand(len(batch), *((-1,) * (len(attention_mask.shape) - 1))) + attention_mask = attention_mask.repeat(len(batch), *((1,) * (len(attention_mask.shape) - 1))) + + + resp_outputs["attention_mask"] = attention_mask + resp_outputs["position_ids"] = position_ids + + return resp_outputs + + +class RPOTrainer: + """Trainer to coordinate RPO training + """ + + def __init__( + self, + cfg: DictConfig, + model, + optimizer, + scheduler, + train_dataloader, + val_dataloader, + test_dataloader, + logger, + ckpt_callback, + run_timer, + ): + self.model = model + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.test_dataloader = test_dataloader + self.logger = logger + self.cfg = cfg + self.optimizer = optimizer + self.scheduler = scheduler + + # this timer checks if we should stop training + self.run_timer = run_timer + + self.step = 0 + self.consumed_samples = 0 + + self.ckpt_callback = ckpt_callback + + # compute `max_steps` + self.num_steps_per_epoch = compute_num_steps_per_epoch( + self.train_dataloader.batch_sampler, self.cfg.get("limit_train_batches", 1.0) + ) + + self.limit_val_batches = compute_limit_batches(len(val_dataloader), self.cfg.limit_val_batches) + self.val_check_interval = ( + int(self.cfg.val_check_interval * self.num_steps_per_epoch) + if isinstance(self.cfg.val_check_interval, float) + else self.cfg.val_check_interval + ) + self.set_max_steps() + + self.timer = SyncTimer( + reduction="mean", sync_cuda=True, buffer_size=1, reduce_op=torch.distributed.ReduceOp.MAX + ) + + def validation_step(self, global_batch): + # these things should go into a GPTModel wrapper + self.model.prepare_for_validation_step() + + loss_mean, metrics = self.model.get_loss_and_metrics(batch=global_batch, forward_only=True) + + self.model.finish_validation_step() + return loss_mean, metrics + + @torch.no_grad() + def run_validation(self): + loss_means = [] + val_metrics = defaultdict(list) + + val_pbar = tqdm( + 
zip(range(self.limit_val_batches), self.augment_dataloader(self.val_dataloader)), + total=self.limit_val_batches, + leave=True, + desc="Validation steps", + ) + + for _, batch in val_pbar: + self.timer.start("validation_step_time") + loss_mean, metrics = self.validation_step(batch) + self.timer.stop("validation_step_time") + validation_step_time = self.timer.get("validation_step_time") + + metrics["validation_step_time"] = validation_step_time + + loss_means.append(loss_mean) + for k, v in metrics.items(): + val_metrics[k].append(v) + log_val_metrics = {f"val_{k}": v for k, v in metrics.items()} + val_pbar.set_postfix(log_val_metrics) + + val_metrics = {k: mean(v) for k, v in val_metrics.items()} + return mean(loss_means), val_metrics + + def train_single_step(self, global_batch): + self.optimizer.zero_grad() + + self.model.prepare_for_training_step() + + # NOTE: assume backward is called on the loss already + loss_mean, metrics = self.model.get_loss_and_metrics(batch=global_batch, forward_only=False) + + self.model.finish_training_step() + + grad_norm = clip_gradients(self.model, self.cfg.gradient_clip_val) + grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm + lr = self.optimizer.param_groups[0]["lr"] + + self.optimizer.step() + self.scheduler.step() + + trainer_metrics = {} + if grad_norm is not None: + trainer_metrics["grad_norm"] = grad_norm + trainer_metrics.update({"lr": lr, "loss": loss_mean}) + + return loss_mean, {**metrics, **trainer_metrics} + + def fit(self): + if (not isinstance(self.train_dataloader.batch_sampler, MegatronPretrainingRandomBatchSampler)) and ( + self.cfg.max_epochs is not None and self.cfg.max_epochs > 1 + ): + # if you use MegatronPretrainingBatchSampler as the batch_sampler passed to your train dataloader (in builders.py) + # then each epoch will repeat all your samples in the same order as the previous epoch, there is no shuffling + # to fix this, you should use MegatronPretrainingRandomBatchSampler instead, 
which alleviates this issue and allows + # random shuffling for each epoch. + raise ValueError( + "max_epochs > 1 is not supported unless using `MegatronPretrainingRandomBatchSampler` as the batch_sampler for your train dataloader" + ) + + epoch_iter = range(self.epoch, self.cfg.max_epochs) + if len(epoch_iter) <= 0: + # epoch done + return + + self.run_timer.start_time() + + for _ in epoch_iter: + num_steps_in_epoch = min( + self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch + ) + loop_iter = range(num_steps_in_epoch) + + if not loop_iter: + return # training ended + + global_pbar = tqdm( + self.augment_dataloader(self.train_dataloader), + initial=self.step, + total=self.max_steps, + leave=True, + desc="Training steps", + ) + + for _, global_batch in zip(loop_iter, global_pbar): + self.timer.start("train_step_time") + loss, metrics = self.train_single_step(global_batch) + self.timer.stop("train_step_time") + train_step_time = self.timer.get("train_step_time") + # to help avoid fragmentation + clear_memory() + + # TODO(geshen): maybe use the dataloader instead + # bump up the consumed samples but not the step + self.consumed_samples += self.model.cfg.global_batch_size + metrics["consumed_samples"] = self.consumed_samples + metrics["step_time"] = train_step_time + metrics["epoch"] = self.epoch + 1 + self.logger.log_metrics( + metrics, step=self.step, prefix="train/", + ) + metrics = {f"train_{k}": v for k, v in metrics.items()} + + self.step += 1 + + run_time_exceeded = self.run_timer.is_finished() + run_val, save_model, is_train_end = check_progress( + self.step, + self.max_steps, + self.val_check_interval, + self.cfg.save_interval, + self.limit_val_batches, + run_time_exceeded=run_time_exceeded, + ) + + if run_val: + val_loss, val_metrics = self.run_validation() + # validation is done on the UPDATED weights + # so we use the incremented self.step + self.logger.log_metrics(val_metrics, step=self.step, prefix="val/") + 
val_metrics = {f"val_{k}": v for k, v in val_metrics.items()} + metrics.update(val_metrics) + + global_pbar.set_postfix(metrics) + + if save_model: + # PTL save wants tensors only + metrics = {k: torch.as_tensor(v) for k, v in metrics.items()} + self.save(metrics, is_train_end=is_train_end) + + if run_time_exceeded: + logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run") + return + + metrics.clear() + + self.logger.finalize() + + def save(self, extra_candidates=None, is_train_end=False): + """PTL based save""" + torch.distributed.barrier() + + if extra_candidates is None: + extra_candidates = {} + + monitor_candidates = {k: torch.tensor(v, dtype=torch.int32) for k, v in self.state_dict().items()} + monitor_candidates.update(extra_candidates) + + self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end) + + def set_max_steps(self): + self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs + + if (max_steps := self.cfg.get("max_steps", -1)) >= 0: + self.max_steps = min(self.max_steps, max_steps) + + def state_dict(self): + return { + "step": self.step, + "consumed_samples": self.consumed_samples, + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.consumed_samples = state_dict["consumed_samples"] + + loaded_values = [self.step, self.consumed_samples] + + # make sure everyone loaded the same checkpoint as rank 0 + to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.broadcast(to_broadcast, 0) + + assert loaded_values == to_broadcast.tolist() + # restore max steps we need to run for + self.set_max_steps() + + def augment_dataloader(self, dataloader): + """Augment dataloader with ref policy log prob""" + iter_dataloader = iter(dataloader) + while True: + try: + batch = next(iter_dataloader) + logprobs = self.model.get_ref_policy_logprobs(batch).cpu() + ind = 1 + + for 
logps in torch.split(logprobs, len(logprobs) // 4, dim=0): + batch["ref_policy_log_probs_response_" + str(ind)] = logps + ind += 1 + + yield batch + del logprobs, logps + except StopIteration: + break + + @property + def epoch(self): + return self.step // self.num_steps_per_epoch diff --git a/nemo_aligner/data/nlp/builders.py b/nemo_aligner/data/nlp/builders.py index 3a4a25823..a38f19bf5 100644 --- a/nemo_aligner/data/nlp/builders.py +++ b/nemo_aligner/data/nlp/builders.py @@ -44,6 +44,7 @@ from nemo.utils import logging from nemo_aligner.data.nlp.datasets import ( DPOModelDataset, + RPOModelDataset, RegressionRewardModelDataset, RewardModelDataset, RLHFDataset, @@ -261,6 +262,7 @@ def build_dataset(index, name): build_train_valid_test_rlhf_datasets = partial(build_train_valid_test_datasets, RLHFDataset) build_train_valid_test_rm_datasets = partial(build_train_valid_test_datasets, RewardModelDataset) build_train_valid_test_dpo_datasets = partial(build_train_valid_test_datasets, DPOModelDataset) +build_train_valid_test_rpo_datasets = partial(build_train_valid_test_datasets, RPOModelDataset) build_train_valid_test_regression_rm_datasets = partial(build_train_valid_test_datasets, RegressionRewardModelDataset) @@ -336,14 +338,14 @@ def build_dataloader( "data_parallel_size": parallel_state.get_data_parallel_world_size(), "drop_last": drop_last, "global_batch_size": gbs, - "pad_samples_to_global_batch_size": pad_samples_to_global_batch_size, + # "pad_samples_to_global_batch_size": pad_samples_to_global_batch_size, } # Megatron sampler if hasattr(cfg.model.data, "dataloader_type") and cfg.model.data.dataloader_type == "single": if use_random_sampler: cls = MegatronPretrainingRandomBatchSampler if load_gbs else MegatronPretrainingRandomSampler - common_params["seed"] = cfg.model.seed + # common_params["seed"] = cfg.model.seed else: cls = MegatronPretrainingBatchSampler if load_gbs else MegatronPretrainingSampler batch_sampler = cls(**common_params) diff --git 
a/nemo_aligner/data/nlp/datasets.py b/nemo_aligner/data/nlp/datasets.py index a452eec5e..3d7221788 100644 --- a/nemo_aligner/data/nlp/datasets.py +++ b/nemo_aligner/data/nlp/datasets.py @@ -19,6 +19,7 @@ import numpy as np import scipy import torch +import random from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import _create_ltor_masks_and_position_ids from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset @@ -353,6 +354,114 @@ def __getitem__(self, idx): return output +class RPOModelDataset(Dataset): + """This class works only with jsonl files. It assumes each line of the json file is a dictionary + with the prompt, along with the chosen response (response only, no prompt), and the rejected response + (response only, no prompt). This Dataset will combine the prompt with each corresponding chosen and + rejected response, and then tokenize it. It also returns the labels for each, which is the response tokens + with -100 for the prompt part. + + WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations! + Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix + strings such as , it would not know where it is safe to truncate for each model. Therefore, + the user must do all truncation logic in their preprocessing step when generating the jsonl + used by this class. Put all special truncation logic there specific to your model. 
+ """ + + def __init__( + self, cfg, tokenizer, name, data_prefix, documents, data, seq_length, seed, drop_last=True, + ): + super().__init__() + self.cfg = cfg + self.name = name + self.data = data + self.drop_last = drop_last + self.seq_length = seq_length + self.tokenizer = tokenizer + + self.reset_position_ids = cfg.data.get("reset_position_ids", False) + self.reset_attention_mask = cfg.data.get("reset_attention_mask", False) + self.eod_mask_loss = cfg.data.get("eod_mask_loss", False) + self.eos_id = tokenizer.eos_id + + self.nograd_length = 32 + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < len(self.data) + + def __len__(self): + return len(self.data) + + def encode(self, text, append_eod=False): + if self.cfg.data.get("apply_ftfy", False): + import ftfy + + text = ftfy.fix_text(text) + + text_ids = self.tokenizer.text_to_ids(text) + + if len(text_ids) > 0 and append_eod: + text_ids.append(self.tokenizer.eos_id) + + return text_ids, len(text_ids) + + def __getitem__(self, idx): + """Returns a pair of chosen/rejected pairs, their respective lengths, and labels. + """ + payload = self.data[idx] + prompt, prompt_len = self.encode(payload["prompt"], append_eod=False) + responses = [] + labels = [] + + # loop on responses of the given prompt to encode them + for resp in payload['responses']: + resp_tokens, resp_len = self.encode( + resp, append_eod=self.cfg.data.get("append_eod", False) + ) + + resp_tokens = prompt + resp_tokens + resp_len = len(resp_tokens) + + responses.append((resp_tokens, resp_len)) + labels.append(([-100] * prompt_len) + resp_tokens[prompt_len:]) + + assert resp_tokens[0:prompt_len] == prompt, "the tokenizer for DPO has merged tokens between prompt and response" + + max_curr_seq_len = max([i[1] for i in responses]) + if max_curr_seq_len > self.seq_length: + logging.warning( + f"WARNING: Tokenized text exceeds max seq length ({max_curr_seq_len} vs {self.seq_length})." + + f"The example will be ignored." 
+ ) + + rewards = payload.get("rewards", [random.random() for _ in range(len(responses))]) + resp_dict = {} + + for ind, (resp, resp_len) in enumerate(responses): + resp_tokens = torch.nn.functional.pad( + torch.LongTensor(resp), (0, max_curr_seq_len - resp_len), mode="constant", value=self.eos_id + ) + label = labels[ind] + label_tokens = torch.nn.functional.pad( + torch.LongTensor(label), (0, max_curr_seq_len - len(label)), mode="constant", value=-100 + ) + + # slice if necessary + if max_curr_seq_len > self.seq_length: + resp_tokens = resp_tokens[: self.nograd_length] + label_tokens = torch.ones_like(resp_tokens) * (-100) + resp_len = self.nograd_length + + resp_dict['response_' + str(ind+1)] = resp_tokens + resp_dict['labels_' + str(ind+1)] = label_tokens + resp_dict['lengths_' + str(ind+1)] = resp_len + resp_dict['rewards_' + str(ind+1)] = rewards[ind] + + return resp_dict + + + class RegressionRewardModelDataset(RewardModelDataset): """This class assumes each line of the dataset file is a dictionary with "text" and "label" field, where "text" is a string representing the input prompt, and "label" is a list of float or int values. 
diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 152b31f6c..6752a00f2 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -63,7 +63,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.sft_avg_log_probs = self.cfg.dpo.get("sft_average_log_probs", self.preference_avg_log_probs) self.preference_loss_weight = self.cfg.dpo.get("preference_loss_weight", 1) - self.sft_loss_weight = self.cfg.dpo.get("sft_loss_weight", 0) + self.sft_loss_weight = self.cfg.dpo.get("sft_loss_weight", 0.7) assert ( self.preference_loss_weight != 0 or self.sft_loss_weight != 0 ), "sft loss weight and dpo loss weight cannot both be 0" 
diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py new file mode 100644 index 000000000..955446fe2 --- /dev/null +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -0,0 +1,396 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings +from functools import partial + +import torch +from apex.transformer.pipeline_parallel.utils import get_num_microbatches +from megatron.core import parallel_state +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo_aligner.models.alignable_interface import SupervisedInterface +from nemo_aligner.utils.distributed import broadcast_2d_tensor, from_parallel_logits_to_logprobs +from nemo_aligner.utils.train_utils import ( + finish_validation_step, + grad_reductions, + prepare_for_training_step, + prepare_for_validation_step, + set_sync_funcs, +) +from nemo_aligner.utils.utils import adapter_control, cpu_weight_swap + + +class MegatronGPTRPOModel(NLPAdapterModelMixin, MegatronGPTModel, SupervisedInterface): + """ + Megatron GPT RPO Model Training. 
+ """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) + + if self.cfg.pipeline_model_parallel_size > 1 and not self.cfg.megatron_amp_O2: + warnings.warn( + "when using pipeline parallelism, it is recommended to set megatron_amp_O2 to be True to " + "avoid explicit casting for pipeline communication" + ) + self.automatic_optimization = False + self.ref_policy_state_dict = None + + self.preference_avg_log_probs = self.cfg.rpo.get("preference_average_log_probs", False) + self.sft_avg_log_probs = self.cfg.rpo.get("sft_average_log_probs", self.preference_avg_log_probs) + + self.preference_loss_weight = self.cfg.rpo.get("preference_loss_weight", 1) + self.sft_loss_weight = self.cfg.rpo.get("sft_loss_weight", 0) + assert ( + self.preference_loss_weight != 0 or self.sft_loss_weight != 0 + ), "sft loss weight and dpo loss weight cannot both be 0" + + # variants of preference losses; the default must be a value loss_func accepts. + self.preference_loss = self.cfg.rpo.get("preference_loss", "rpo_multi_response") + self.gt_reward_scale = self.cfg.rpo.get("gt_reward_scale", 1.0) + + self.beta = self.cfg.rpo.get("beta", 0.01) + self.eta = self.cfg.rpo.get("eta", 0.01) + self.k_len = 4 + + @torch.no_grad() + def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False): + pi_logprobs = pi_logprobs.detach() + + dp_group = parallel_state.get_data_parallel_group() + + batch_logs = self.get_reduced_masked_logps( + pi_logprobs - ref_logprobs, labels[:, 1:], average_log_probs=average_log_probs + ) + + output_list = [torch.zeros_like(batch_logs) for _ in range(dp_group.size())] + + torch.distributed.all_gather(output_list, batch_logs, group=dp_group) + + split_iter = map(self.split_output_tensor, output_list) + outputs = map(torch.cat, zip(*split_iter)) # one tensor per response slot, gathered across DP ranks + + return torch.cat(tuple(outputs)).flatten() + + def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False): + def fwd_output_and_loss_func(dataloader_iter, model, 
checkpoint_activations_all_layers=None): + batch = next(dataloader_iter) + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + # there is a problem with apex ignoring the mask on the older models + # so we will always give the attention mask + required_keys.add("attention_mask") + + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("response_1", "response_2", "response_3", "response_4", "position_ids")) + + if parallel_state.is_pipeline_last_stage(): + required_keys.update( + ( + "ref_policy_log_probs_response_1", + "ref_policy_log_probs_response_2", + "ref_policy_log_probs_response_3", + "ref_policy_log_probs_response_4", + "response_1", "response_2", "response_3", "response_4", + "labels_1", "labels_2", "labels_3", "labels_4", + "rewards_1", "rewards_2", "rewards_3", "rewards_4", + ) + ) + + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + # creating tokens and labels tensor batches + tokens, labels, ref_logprobs, gt_rewards = None, None, None, None + if batch["response_1"] is not None: + tokens = torch.cat(tuple(batch['response_' + str(i+1)] for i in range(self.k_len)), dim=0) + + if batch["labels_1"] is not None: + labels = torch.cat(tuple(batch['labels_' + str(i+1)] for i in range(self.k_len)), dim=0) + + if batch["rewards_1"] is not None: + gt_rewards = torch.cat(tuple(batch['rewards_' + str(i+1)] for i in range(self.k_len)), dim=0) + + if batch.get("ref_policy_log_probs_response_1") is not None: + ref_logprobs = torch.cat(tuple(batch['ref_policy_log_probs_response_' + str(i+1)] for i in range(self.k_len)), dim=0) + + # this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs + # these two lines ensure your position_ids and attn_mask are always B=1 + attention_mask = batch['attention_mask'][0:1] + + # Model forward pass + forward_args = { + "input_ids": 
tokens, + "position_ids": batch["position_ids"], + "attention_mask": attention_mask, + "labels": None, + "loss_mask": None, + } + + # TODO: we can remove this someday when we no longer support legacy models + if not self.mcore_gpt: + forward_args["checkpoint_activations_all_layers"] = checkpoint_activations_all_layers + if not self.use_loss_mask: + forward_args.pop("loss_mask") + else: + forward_args.pop("loss_mask") + + output_tensor = model(**forward_args) + + # in this nemo version the model and autocast dtypes are not synced + # so we need to explicitly cast it + if not parallel_state.is_pipeline_last_stage(): + output_tensor = output_tensor.to(dtype=self.autocast_dtype) + + def logprobs_func(output_tensor, non_loss_data=True): + # This function is expected to be used only when `collect_non_loss_data=True` in the fwd_bwd_function of Megatron-LM. + # See https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/megatron/core/pipeline_parallel/schedules.py#L228 + assert non_loss_data + logprobs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, + ) + return {"logprobs": logprobs} + + def loss_func(output_tensor): + if validation_step and not self.cfg.data.get("validation_drop_last", True): + raise NotImplementedError("DPO does not support validation when cfg.data.drop_last=False") + + per_token_logps = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, + target=labels, + inference_only=validation_step, + higher_stability=True, + ) + + # print('Before Loss:', per_token_logps.shape) + preference_loss = self.loss_func( + per_token_logps, + ref_logprobs, + labels[:, 1:], + gt_rewards, + average_log_probs=self.preference_avg_log_probs, + ) + + sft_loss = torch.zeros_like(preference_loss) + if self.sft_loss_weight != 0: + sft_loss = self.sft_loss_func( + per_token_logps, labels[:, 1:], gt_rewards, average_log_probs=self.sft_avg_log_probs + 
) + loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss + + print(loss, preference_loss, sft_loss) + + ( + reduced_loss, + reduced_preference_loss, + reduced_sft_loss, + ) = average_losses_across_data_parallel_group([loss, preference_loss, sft_loss]) + + return ( + loss, + { + "avg": reduced_loss, + "avg_sft_loss": reduced_sft_loss, + "avg_preference_loss": reduced_preference_loss, + }, + ) + + if logprobs_only: + return output_tensor, logprobs_func + else: + return output_tensor, loss_func + + return fwd_output_and_loss_func + + def split_output_tensor(self, output_tensor): + responses_logps = torch.split(output_tensor.float(), len(output_tensor) // self.k_len, dim=0) + return responses_logps + + def get_reduced_masked_logps(self, logps, labels, average_log_probs=False): + assert logps.shape == labels.shape, "logps and labels shape mismatch" + + loss_mask = (labels > -1).float() + + if average_log_probs: + # need to guard against divide by zero in case labels are all -100 + return (logps * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1) + else: + return (logps * loss_mask).sum(-1) + + def log_sum_exp(self, x): + max_x = torch.max(x) + return max_x + torch.log(torch.sum(torch.exp(x - max_x))) + + def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): + if self.preference_loss == 'rpo_multi_response': + # estimated rewards + rewards = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps( + pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs + ))) + rewards = torch.nn.functional.softmax(self.beta * rewards, dim=0) + + + # based on GT rewards + gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) + p_star = torch.nn.functional.softmax(self.eta * gt_rewards, dim=0) + + print(rewards.shape, p_star.shape) + else: + raise ValueError("Unknown RPO Loss") + + loss = (p_star * (torch.log( p_star + 1e-8 ) - torch.log( rewards + 1e-8 ))).sum(0).unsqueeze(0) + 
print('>>', loss) + return loss.mean(0) + + def sft_loss_func(self, pi_logprobs, labels, gt_rewards, average_log_probs=False): + logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) + all_log_probs = torch.stack(self.split_output_tensor(logprobs)) + gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) + chosen_best = torch.argmax(gt_rewards, dim=0) + + chosen_logprobs = all_log_probs[chosen_best] + return -chosen_logprobs.mean(0) + + def get_loss_and_metrics(self, batch, forward_only): + seq_length = batch["response_1"].shape[1] + + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only, logprobs_only=False), + data_iterator=data_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=self.cfg.micro_batch_size * 2, # each minibatch has 2 comparisons so tensor shape will be mbs * 2 + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + # NOTE: assume that the returned values are already gathered across the DP workers + loss_mean = torch.as_tensor( + [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + sft_loss_mean = torch.as_tensor( + [loss_reduced["avg_sft_loss"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + preference_loss_mean = torch.as_tensor( + [loss_reduced["avg_preference_loss"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + else: + loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + sft_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + 
preference_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + # we can only log on one rank if it is rank zero so we broadcast from last rank + torch.distributed.broadcast(loss_mean, get_last_rank()) + torch.distributed.broadcast(sft_loss_mean, get_last_rank()) + torch.distributed.broadcast(preference_loss_mean, get_last_rank()) + + metrics = { + "loss": loss_mean, + "sft_loss": sft_loss_mean, + "preference_loss": preference_loss_mean, + } + + # move to CPU + metrics = {k: v.item() for k, v in metrics.items()} + + return loss_mean.item(), metrics + + def prepare_for_training_step(self): + # custom trainers will always zero grad for us + prepare_for_training_step(self, zero_grad=False) + + def finish_training_step(self): + grad_reductions(self) + + def prepare_for_validation_step(self): + prepare_for_validation_step(self) + + def finish_validation_step(self): + finish_validation_step(self) + + @torch.no_grad() + def get_logprob_batch(self, batch): + seq_length = batch["response_1"].shape[1] + + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + set_sync_funcs(self, forward_only=True) + + fwd_bwd_function = get_forward_backward_func() + + logprobs_list = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(logprobs_only=True), + data_iterator=data_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=True, + seq_length=seq_length, + micro_batch_size=self.cfg.rpo.log_prob_forward_micro_batch_size, + collect_non_loss_data=True, + ) + + if len(logprobs_list) > 0: + logprobs_list_cat = [] + for item in logprobs_list: + all_log_probs = self.split_output_tensor(item["logprobs"]) + logprobs_list_cat.extend(all_log_probs) + logprobs = torch.cat(logprobs_list_cat) + else: + logprobs = None + + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + # broadcast it from last PP stage to everything else + logprobs = broadcast_2d_tensor( + logprobs, + 
parallel_state.get_pipeline_model_parallel_last_rank(), + parallel_state.get_pipeline_model_parallel_group(), + ) + + return logprobs + + def get_ref_policy_logprobs(self, batch): + + if self.use_peft and self.ref_policy_state_dict is None: + # when using adapters instead of full-tuning, the actor is reference model + adapters + with adapter_control(self): + # With adapters disabled (meaning using the reference model), calculate ref_log_probs + ref_log_probs = self.get_logprob_batch(batch) + else: + with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2): + ref_log_probs = self.get_logprob_batch(batch) + + # return in GPU, trainer needs to move to cpu + return ref_log_probs diff --git a/nemo_aligner/utils/train_utils.py b/nemo_aligner/utils/train_utils.py index 82b1dd660..5c8b32087 100644 --- a/nemo_aligner/utils/train_utils.py +++ b/nemo_aligner/utils/train_utils.py @@ -39,6 +39,7 @@ def set_sync_funcs(ptl_model, forward_only): # pipeline schedules will get these from ptl_model.model.config for module in ptl_model.get_model_module_list(): + # for module in ptl_model.get_gpt_module_list(): module.config.no_sync_func = no_sync_func module.config.grad_sync_func = grad_sync_func module.config.param_sync_func = param_sync_func diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index 78bc48b22..66968a4d0 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -128,6 +128,7 @@ def load_and_override_model_config(restore_path, model_cfg_to_overwrite, remove_ """load the config in the model checkpoint and then overwrite it with whatever is provided """ + print('>>>>', restore_path) checkpoint_cfg = load_checkpoint_model_config(restore_path) if remove_meta_info: From cb6cbb7023c79491ba9d00bc87bd2528a3592d1f Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Fri, 6 Sep 2024 11:41:00 -0700 Subject: [PATCH 08/17] updated RPO loss Signed-off-by: David Mosallanezhad --- 
examples/nlp/gpt/conf/gpt_dpo.yaml | 10 +++++----- examples/nlp/gpt/conf/gpt_rpo.yaml | 8 ++++---- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 14 ++++++++------ .../models/nlp/gpt/megatron_gpt_rpo_model.py | 17 +++++++---------- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 88be4f3fe..2bf46317b 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -3,7 +3,7 @@ defaults: trainer: num_nodes: 1 - devices: 4 + devices: 8 accelerator: gpu precision: bf16 @@ -50,12 +50,12 @@ exp_manager: model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} pretrained_checkpoint: - restore_from_path: /data/GPT-2B-001_bf16_tp1.nemo + restore_from_path: /models/llama3_8b_sft_alpha_nodes8_tp4_3e-6_bs384_rerun_1200.nemo model: - mcore_gpt: False - micro_batch_size: 4 - global_batch_size: 32 + mcore_gpt: True + micro_batch_size: 16 + global_batch_size: 64 megatron_amp_O2: True dpo: diff --git a/examples/nlp/gpt/conf/gpt_rpo.yaml b/examples/nlp/gpt/conf/gpt_rpo.yaml index aa8fd535e..6a636234e 100644 --- a/examples/nlp/gpt/conf/gpt_rpo.yaml +++ b/examples/nlp/gpt/conf/gpt_rpo.yaml @@ -9,9 +9,9 @@ trainer: # rpo specific args rpo: - max_epochs: 100 + max_epochs: 1000 max_steps: -1 - val_check_interval: 0.1 + val_check_interval: 100 save_interval: 100 limit_train_batches: 1.0 @@ -54,8 +54,8 @@ pretrained_checkpoint: model: mcore_gpt: True - micro_batch_size: 1 - global_batch_size: 16 + micro_batch_size: 4 + global_batch_size: 8 megatron_amp_O2: True rpo: diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 6752a00f2..a6fd9a3f7 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -144,7 +144,6 @@ def fwd_output_and_loss_func(dataloader_iter, model, 
checkpoint_activations_all_ # position_ids = batch["position_ids"][0:1] attention_mask = batch["attention_mask"][0:1] - print(attention_mask.shape, tokens.shape) # Model forward pass forward_args = { "input_ids": tokens, @@ -205,6 +204,8 @@ def loss_func(output_tensor): ) loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss + # print('IMP:', loss, preference_loss, sft_loss) + ( reduced_loss, reduced_preference_loss, @@ -251,17 +252,14 @@ def get_reduced_masked_logps(self, logps, labels, average_log_probs=False): return (logps * loss_mask).sum(-1) def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): - print('REF:', ref_logprobs.shape) rewards = self.get_reduced_masked_logps( pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs ) chosen_rewards, reject_rewards = self.split_output_tensor(rewards) - print('PI:', pi_logprobs.shape, rewards, chosen_rewards) - print('======') - exit(0) rewards_delta = chosen_rewards - reject_rewards + if self.preference_loss == "dpo": loss = -torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta).mean(0) elif self.preference_loss == "rpo_bwd_kl": @@ -307,9 +305,13 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p return loss, acc_chosen def sft_loss_func(self, pi_logprobs, labels, average_log_probs=False): + print(pi_logprobs.shape) logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) chosen_logprobs, _ = self.split_output_tensor(logprobs) - print(chosen_logprobs) + + print(chosen_logprobs, chosen_logprobs.shape) + exit(0) + return -chosen_logprobs.mean(0) def get_loss_and_metrics(self, batch, forward_only): diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py index 955446fe2..1ca86366c 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py +++ 
b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -185,7 +185,6 @@ def loss_func(output_tensor): higher_stability=True, ) - # print('Before Loss:', per_token_logps.shape) preference_loss = self.loss_func( per_token_logps, ref_logprobs, @@ -201,7 +200,6 @@ def loss_func(output_tensor): ) loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss - print(loss, preference_loss, sft_loss) ( reduced_loss, @@ -257,21 +255,20 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) p_star = torch.nn.functional.softmax(self.eta * gt_rewards, dim=0) - print(rewards.shape, p_star.shape) else: raise ValueError("Unknown RPO Loss") - loss = (p_star * (torch.log( p_star + 1e-8 ) - torch.log( rewards + 1e-8 ))).sum(0).unsqueeze(0) - print('>>', loss) - return loss.mean(0) + loss = (p_star * (torch.log( p_star + 1e-8 ) - torch.log( rewards + 1e-8 ))).sum(0).mean(0) + + return loss def sft_loss_func(self, pi_logprobs, labels, gt_rewards, average_log_probs=False): - logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) - all_log_probs = torch.stack(self.split_output_tensor(logprobs)) - gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) + logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) # [16] + all_log_probs = torch.stack(self.split_output_tensor(logprobs)) # [4, 4] -> each has several responses which we select the best? 
+ gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) # same, we split the rewards chosen_best = torch.argmax(gt_rewards, dim=0) - chosen_logprobs = all_log_probs[chosen_best] + chosen_logprobs = all_log_probs[chosen_best, torch.arange(all_log_probs.size(1))] return -chosen_logprobs.mean(0) def get_loss_and_metrics(self, batch, forward_only): From 2b34846aa0424c3ef10da67560dc711545773ad5 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Fri, 6 Sep 2024 16:36:26 -0700 Subject: [PATCH 09/17] added metrics Signed-off-by: David Mosallanezhad --- examples/nlp/gpt/conf/gpt_rpo.yaml | 34 +++++----- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 11 +--- .../models/nlp/gpt/megatron_gpt_rpo_model.py | 66 ++++++++++++++----- 3 files changed, 70 insertions(+), 41 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_rpo.yaml b/examples/nlp/gpt/conf/gpt_rpo.yaml index 6a636234e..906a22b02 100644 --- a/examples/nlp/gpt/conf/gpt_rpo.yaml +++ b/examples/nlp/gpt/conf/gpt_rpo.yaml @@ -6,7 +6,7 @@ trainer: devices: 8 accelerator: gpu precision: bf16 - + # rpo specific args rpo: max_epochs: 1000 @@ -16,7 +16,7 @@ trainer: limit_train_batches: 1.0 # how many GBS we loop over - limit_val_batches: 1.0 + limit_val_batches: 0.01 gradient_clip_val: 1.0 # do not change these @@ -46,13 +46,13 @@ exp_manager: mode: min always_save_nemo: False # saves nemo file during validation, not implemented for model parallel save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits - filename: 'megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}' + filename: "megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}" model_parallel_size: 4 pretrained_checkpoint: restore_from_path: /models/llama3_8b_sft_alpha_nodes8_tp4_3e-6_bs384_rerun_1200.nemo -model: +model: mcore_gpt: True micro_batch_size: 4 global_batch_size: 8 @@ -70,28 +70,28 @@ model: sft_loss_weight: 0.05 # the coefficient of the SFT loss beta: 0.2 eta: 0.2 
- + #encoder_seq_length: 4096 #max_position_embeddings: ${model.encoder_seq_length} # miscellaneous seed: 1234 - #peft + #peft peft: - peft_scheme: "none" # ["lora", "none"] + peft_scheme: "none" # ["lora", "none"] restore_from_path: null restore_from_ckpt: checkpoint_dir: null checkpoint_name: null lora_tuning: - target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + target_modules: ["attention_qkv"] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' adapter_dim: 32 adapter_dropout: 0.0 - column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal - row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal - layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + column_init_method: "xavier" # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: "zero" # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. 
null will apply adapters to all layers weight_tying: False position_embedding_strategy: null # used only when weight_tying is True @@ -101,10 +101,10 @@ model: overlap_grad_sync: False contiguous_grad_buffer: True lr: 9e-6 - weight_decay: 0.1 - betas: - - 0.9 - - 0.98 + weight_decay: 0.1 + betas: + - 0.9 + - 0.98 sched: name: CosineAnnealing warmup_steps: 10 @@ -122,7 +122,7 @@ model: reset_attention_mask: False # Reset attention mask after end-of-document token eod_mask_loss: False # Mask loss for the end of document tokens index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix - data_prefix: + data_prefix: train: - /data/responses_general.rm.formatted.jsonl test: @@ -131,7 +131,7 @@ model: - /data/small_test_formatted.jsonl default_chosen_reward: 1. # the default reward for the chosen response in RPO default_rejected_reward: 0. # the default reward for the rejected response in RPO - + # define fields from the base model's config that should be ignored when merging with this config. 
overwrite_base_config: data: diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index a6fd9a3f7..3f2f67334 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -87,7 +87,6 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_lo torch.distributed.all_gather(output_list, batch_logs, group=dp_group) split_iter = map(self.split_output_tensor, output_list) - out_chosen, out_rejected = map(torch.cat, zip(*split_iter)) return out_chosen.flatten(), out_rejected.flatten() @@ -204,8 +203,6 @@ def loss_func(output_tensor): ) loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss - # print('IMP:', loss, preference_loss, sft_loss) - ( reduced_loss, reduced_preference_loss, @@ -256,7 +253,7 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs ) chosen_rewards, reject_rewards = self.split_output_tensor(rewards) - + rewards_delta = chosen_rewards - reject_rewards @@ -305,13 +302,9 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p return loss, acc_chosen def sft_loss_func(self, pi_logprobs, labels, average_log_probs=False): - print(pi_logprobs.shape) logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) chosen_logprobs, _ = self.split_output_tensor(logprobs) - - print(chosen_logprobs, chosen_logprobs.shape) - exit(0) - + return -chosen_logprobs.mean(0) def get_loss_and_metrics(self, batch, forward_only): diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py index 1ca86366c..20d70a533 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -70,7 +70,7 @@ def 
__init__(self, cfg: DictConfig, trainer: Trainer): # variants of preference losses, by default RPO. self.preference_loss = self.cfg.rpo.get("preference_loss", "rpo") self.gt_reward_scale = self.cfg.rpo.get("gt_reward_scale", 1.0) - + self.beta = self.cfg.rpo.get("beta", 0.01) self.eta = self.cfg.rpo.get("eta", 0.01) self.k_len = 4 @@ -91,8 +91,9 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_lo split_iter = map(self.split_output_tensor, output_list) outputs = map(torch.cat, zip(*split_iter)) + flat_outputs = list(map(torch.flatten, outputs)) - return outputs.flatten() + return flat_outputs def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): @@ -130,7 +131,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if batch["labels_1"] is not None: labels = torch.cat(tuple(batch['labels_' + str(i+1)] for i in range(self.k_len)), dim=0) - + if batch["rewards_1"] is not None: gt_rewards = torch.cat(tuple(batch['rewards_' + str(i+1)] for i in range(self.k_len)), dim=0) @@ -140,7 +141,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ # this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs # these two lines ensure your position_ids and attn_mask are always B=1 attention_mask = batch['attention_mask'][0:1] - + # Model forward pass forward_args = { "input_ids": tokens, @@ -185,7 +186,7 @@ def loss_func(output_tensor): higher_stability=True, ) - preference_loss = self.loss_func( + preference_loss, acc_best_resp = self.loss_func( per_token_logps, ref_logprobs, labels[:, 1:], @@ -205,7 +206,12 @@ def loss_func(output_tensor): reduced_loss, reduced_preference_loss, reduced_sft_loss, - ) = average_losses_across_data_parallel_group([loss, preference_loss, sft_loss]) + reduced_acc, + ) = 
average_losses_across_data_parallel_group([loss, preference_loss, sft_loss, acc_best_resp]) + + out_responses = self.gather_and_split_rewards( + per_token_logps, ref_logprobs, labels, average_log_probs=self.preference_avg_log_probs + ) return ( loss, @@ -213,6 +219,8 @@ def loss_func(output_tensor): "avg": reduced_loss, "avg_sft_loss": reduced_sft_loss, "avg_preference_loss": reduced_preference_loss, + "acc": reduced_acc, + "out_responses": out_responses, }, ) @@ -248,19 +256,22 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p rewards = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps( pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs ))) - rewards = torch.nn.functional.softmax(self.beta * rewards, dim=0) + rewards_pred = torch.nn.functional.softmax(self.beta * rewards, dim=0) + - # based on GT rewards gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) p_star = torch.nn.functional.softmax(self.eta * gt_rewards, dim=0) - + else: raise ValueError("Unknown RPO Loss") - loss = (p_star * (torch.log( p_star + 1e-8 ) - torch.log( rewards + 1e-8 ))).sum(0).mean(0) - - return loss + loss = (p_star * (torch.log( p_star + 1e-8 ) - torch.log( rewards_pred + 1e-8 ))).sum(0).mean(0) + + # adding accuracy for the best rewards -> MSE or best accuracy? 
+ acc_best_resp = (torch.argmax(rewards, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean() + + return loss, acc_best_resp def sft_loss_func(self, pi_logprobs, labels, gt_rewards, average_log_probs=False): logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) # [16] @@ -273,7 +284,7 @@ def sft_loss_func(self, pi_logprobs, labels, gt_rewards, average_log_probs=False def get_loss_and_metrics(self, batch, forward_only): seq_length = batch["response_1"].shape[1] - + data_iter = get_iterator_k_split(batch, get_num_microbatches()) set_sync_funcs(self, forward_only) @@ -292,6 +303,16 @@ def get_loss_and_metrics(self, batch, forward_only): # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: # NOTE: assume that the returned values are already gathered across the DP workers + collected_rewards_per_resp = [] + for i in range(self.k_len): + collected_rewards_per_resp.append( + torch.cat([item["out_responses"][i] for item in losses_reduced_per_micro_batch]) + ) + + rewards_all = torch.cat(tuple(collected_rewards_per_resp)) + rewards_all_mean = rewards_all.mean() + rewards_all_std = rewards_all.std() + loss_mean = torch.as_tensor( [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch], device=torch.cuda.current_device(), @@ -304,20 +325,35 @@ def get_loss_and_metrics(self, batch, forward_only): [loss_reduced["avg_preference_loss"] for loss_reduced in losses_reduced_per_micro_batch], device=torch.cuda.current_device(), ).mean() + acc_mean = torch.as_tensor( + [loss_reduced["acc"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() else: loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) sft_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) preference_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + acc_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + 
rewards_all_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + rewards_all_std = torch.tensor(0.0, device=torch.cuda.current_device()) # we can only log on one rank if it is rank zero so we broadcast from last rank torch.distributed.broadcast(loss_mean, get_last_rank()) torch.distributed.broadcast(sft_loss_mean, get_last_rank()) torch.distributed.broadcast(preference_loss_mean, get_last_rank()) + torch.distributed.broadcast(acc_mean, get_last_rank()) + + torch.distributed.broadcast(rewards_all_mean, get_last_rank()) + torch.distributed.broadcast(rewards_all_std, get_last_rank()) metrics = { "loss": loss_mean, "sft_loss": sft_loss_mean, "preference_loss": preference_loss_mean, + "acc": acc_mean, + "rewards_all_mean": rewards_all_mean, + "rewards_all_std": rewards_all_std, } # move to CPU @@ -343,7 +379,7 @@ def get_logprob_batch(self, batch): seq_length = batch["response_1"].shape[1] data_iter = get_iterator_k_split(batch, get_num_microbatches()) - + set_sync_funcs(self, forward_only=True) fwd_bwd_function = get_forward_backward_func() @@ -358,7 +394,7 @@ def get_logprob_batch(self, batch): micro_batch_size=self.cfg.rpo.log_prob_forward_micro_batch_size, collect_non_loss_data=True, ) - + if len(logprobs_list) > 0: logprobs_list_cat = [] for item in logprobs_list: From d855f5f80fc5d43b5fe48a90edcae39138f2475f Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Mon, 9 Sep 2024 21:08:37 -0700 Subject: [PATCH 10/17] bug fixes Signed-off-by: David Mosallanezhad --- examples/nlp/gpt/conf/gpt_dpo.yaml | 8 ++++---- examples/nlp/gpt/conf/gpt_rpo.yaml | 6 +++--- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 3 +-- .../models/nlp/gpt/megatron_gpt_rpo_model.py | 12 ++++++------ 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 2bf46317b..b4a979f79 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -55,7 +55,7 @@ 
pretrained_checkpoint: model: mcore_gpt: True micro_batch_size: 16 - global_batch_size: 64 + global_batch_size: 128 megatron_amp_O2: True dpo: @@ -123,11 +123,11 @@ model: index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix data_prefix: train: - - /data/sample_dpo_dataset.jsonl + - /data/responses_general.rm.formatted.dpo.jsonl test: - - /data/sample_dpo_dataset.jsonl + - /data/dpo_test_set.jsonl validation: - - /data/sample_dpo_dataset.jsonl + - /data/dpo_test_set.jsonl default_chosen_reward: 1. # the default reward for the chosen response in RPO default_rejected_reward: 0. # the default reward for the rejected response in RPO diff --git a/examples/nlp/gpt/conf/gpt_rpo.yaml b/examples/nlp/gpt/conf/gpt_rpo.yaml index 906a22b02..1f286b5f9 100644 --- a/examples/nlp/gpt/conf/gpt_rpo.yaml +++ b/examples/nlp/gpt/conf/gpt_rpo.yaml @@ -9,14 +9,14 @@ trainer: # rpo specific args rpo: - max_epochs: 1000 + max_epochs: 1 max_steps: -1 val_check_interval: 100 save_interval: 100 limit_train_batches: 1.0 # how many GBS we loop over - limit_val_batches: 0.01 + limit_val_batches: 10 gradient_clip_val: 1.0 # do not change these @@ -42,7 +42,7 @@ exp_manager: create_checkpoint_callback: True checkpoint_callback_params: monitor: val_loss - save_top_k: 3 + save_top_k: 5 mode: min always_save_nemo: False # saves nemo file during validation, not implemented for model parallel save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 3f2f67334..37d19ceb1 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -63,7 +63,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.sft_avg_log_probs = self.cfg.dpo.get("sft_average_log_probs", 
self.preference_avg_log_probs) self.preference_loss_weight = self.cfg.dpo.get("preference_loss_weight", 1) - self.sft_loss_weight = self.cfg.dpo.get("sft_loss_weight", 0.7) + self.sft_loss_weight = self.cfg.dpo.get("sft_loss_weight", 0) assert ( self.preference_loss_weight != 0 or self.sft_loss_weight != 0 ), "sft loss weight and dpo loss weight cannot both be 0" @@ -196,7 +196,6 @@ def loss_func(output_tensor): ) sft_loss = torch.zeros_like(preference_loss) - self.sft_loss_weight = 0.7 if self.sft_loss_weight != 0: sft_loss = self.sft_loss_func( per_token_logps, labels[:, 1:], average_log_probs=self.sft_avg_log_probs diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py index 20d70a533..bde197ae8 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -61,18 +61,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.preference_avg_log_probs = self.cfg.rpo.get("preference_average_log_probs", False) self.sft_avg_log_probs = self.cfg.rpo.get("sft_average_log_probs", self.preference_avg_log_probs) - self.preference_loss_weight = self.cfg.rpo.get("preference_loss_weight", 1) - self.sft_loss_weight = self.cfg.rpo.get("sft_loss_weight", 0) + self.preference_loss_weight = float(self.cfg.rpo.get("preference_loss_weight", 1)) + self.sft_loss_weight = float(self.cfg.rpo.get("sft_loss_weight", 0)) assert ( self.preference_loss_weight != 0 or self.sft_loss_weight != 0 ), "sft loss weight and dpo loss weight cannot both be 0" # variants of preference losses, by default RPO. 
self.preference_loss = self.cfg.rpo.get("preference_loss", "rpo") - self.gt_reward_scale = self.cfg.rpo.get("gt_reward_scale", 1.0) + self.gt_reward_scale = float(self.cfg.rpo.get("gt_reward_scale", 1.0)) - self.beta = self.cfg.rpo.get("beta", 0.01) - self.eta = self.cfg.rpo.get("eta", 0.01) + self.beta = float(self.cfg.rpo.get("beta", 0.01)) + self.eta = float(self.cfg.rpo.get("eta", 0.01)) self.k_len = 4 @torch.no_grad() @@ -297,7 +297,7 @@ def get_loss_and_metrics(self, batch, forward_only): num_microbatches=get_num_microbatches(), forward_only=forward_only, seq_length=seq_length, - micro_batch_size=self.cfg.micro_batch_size * 2, # each minibatch has 2 comparisons so tensor shape will be mbs * 2 + micro_batch_size=self.cfg.micro_batch_size * self.k_len, # each minibatch has K comparisons so tensor shape will be mbs * num_responses ) # only the last stages of the pipeline return losses From 780b13433e4f9b6037c36defc0a399671199b518 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Wed, 11 Sep 2024 17:52:24 -0700 Subject: [PATCH 11/17] bug fix Signed-off-by: David Mosallanezhad --- examples/nlp/gpt/conf/gpt_rpo.yaml | 10 +++--- nemo_aligner/algorithms/rpo.py | 4 ++- .../models/nlp/gpt/megatron_gpt_rpo_model.py | 32 +++++++------------ 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_rpo.yaml b/examples/nlp/gpt/conf/gpt_rpo.yaml index 1f286b5f9..7aec6ffe0 100644 --- a/examples/nlp/gpt/conf/gpt_rpo.yaml +++ b/examples/nlp/gpt/conf/gpt_rpo.yaml @@ -18,6 +18,7 @@ trainer: # how many GBS we loop over limit_val_batches: 10 gradient_clip_val: 1.0 + num_responses: 2 # do not change these logger: False # logger provided by exp_manager @@ -70,6 +71,7 @@ model: sft_loss_weight: 0.05 # the coefficient of the SFT loss beta: 0.2 eta: 0.2 + num_responses: 2 #encoder_seq_length: 4096 #max_position_embeddings: ${model.encoder_seq_length} @@ -124,13 +126,11 @@ model: index_mapping_dir: null # path to save index mapping .npy files, by 
default will save in the same location as data_prefix data_prefix: train: - - /data/responses_general.rm.formatted.jsonl + - /data/responses_general.rm.formatted.2resp.jsonl test: - - /data/small_test_formatted.jsonl + - /data/small_test_formatted_2resp.jsonl validation: - - /data/small_test_formatted.jsonl - default_chosen_reward: 1. # the default reward for the chosen response in RPO - default_rejected_reward: 0. # the default reward for the rejected response in RPO + - /data/small_test_formatted_2resp.jsonl # define fields from the base model's config that should be ignored when merging with this config. overwrite_base_config: diff --git a/nemo_aligner/algorithms/rpo.py b/nemo_aligner/algorithms/rpo.py index 221e2b7ed..a7de3d2ac 100644 --- a/nemo_aligner/algorithms/rpo.py +++ b/nemo_aligner/algorithms/rpo.py @@ -32,6 +32,7 @@ def rpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False): resp_outputs = {} + ## assume len 4 for responses for k in batch[0].keys(): if k.startswith('response_'): @@ -117,6 +118,7 @@ def __init__( self.timer = SyncTimer( reduction="mean", sync_cuda=True, buffer_size=1, reduce_op=torch.distributed.ReduceOp.MAX ) + self.k_len = int(self.cfg.num_responses) def validation_step(self, global_batch): # these things should go into a GPTModel wrapper @@ -318,7 +320,7 @@ def augment_dataloader(self, dataloader): logprobs = self.model.get_ref_policy_logprobs(batch).cpu() ind = 1 - for logps in torch.split(logprobs, len(logprobs) // 4, dim=0): + for logps in torch.split(logprobs, len(logprobs) // self.k_len, dim=0): batch["ref_policy_log_probs_response_" + str(ind)] = logps ind += 1 diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py index bde197ae8..d2a50871a 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -73,7 +73,7 @@ def __init__(self, cfg: 
DictConfig, trainer: Trainer): self.beta = float(self.cfg.rpo.get("beta", 0.01)) self.eta = float(self.cfg.rpo.get("eta", 0.01)) - self.k_len = 4 + self.k_len = int(self.cfg.rpo.get("num_responses", 2)) @torch.no_grad() def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False): @@ -107,20 +107,14 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ required_keys.add("attention_mask") if parallel_state.is_pipeline_first_stage(): - required_keys.update(("response_1", "response_2", "response_3", "response_4", "position_ids")) + required_keys.update(["response_" + str(i) for i in range(1, self.k_len + 1)]) + required_keys.update(("position_ids")) if parallel_state.is_pipeline_last_stage(): - required_keys.update( - ( - "ref_policy_log_probs_response_1" - "ref_policy_log_probs_response_2" - "ref_policy_log_probs_response_3" - "ref_policy_log_probs_response_4" - "response_1", "response_2", "response_3", "response_4", - "labels_1", "labels_2", "labels_3", "labels_4", - "rewards_1", "rewards_2", "rewards_3", "rewards_4", - ) - ) + required_keys.update(["response_" + str(i) for i in range(1, self.k_len + 1)]) + required_keys.update(["ref_policy_log_probs_response_" + str(i) for i in range(1, self.k_len + 1)]) + required_keys.update(["labels_" + str(i) for i in range(1, self.k_len + 1)]) + required_keys.update(["rewards_" + str(i) for i in range(1, self.k_len + 1)]) batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} @@ -177,7 +171,7 @@ def logprobs_func(output_tensor, non_loss_data=True): def loss_func(output_tensor): if validation_step and not self.cfg.data.get("validation_drop_last", True): - raise NotImplementedError("DPO does not support validation when cfg.data.drop_last=False") + raise NotImplementedError("RPO does not support validation when cfg.data.drop_last=False") per_token_logps = from_parallel_logits_to_logprobs( 
vocab_parallel_logits=output_tensor, @@ -256,18 +250,16 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p rewards = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps( pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs ))) - rewards_pred = torch.nn.functional.softmax(self.beta * rewards, dim=0) - + rewards_pred = self.beta * rewards # based on GT rewards gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) - p_star = torch.nn.functional.softmax(self.eta * gt_rewards, dim=0) - + p_star = self.eta * gt_rewards else: raise ValueError("Unknown RPO Loss") - loss = (p_star * (torch.log( p_star + 1e-8 ) - torch.log( rewards_pred + 1e-8 ))).sum(0).mean(0) - + loss = ( p_star * (torch.nn.functional.log_softmax( p_star, dim=0 ) - torch.nn.functional.log_softmax( rewards_pred, dim=0 )) ).sum(0).mean(0) + # adding accuracy for the best rewards -> MSE or best accuracy? acc_best_resp = (torch.argmax(rewards, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean() From 19d7882386544b2f7ddb78f22668ae3a7cf77ed9 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Sun, 15 Sep 2024 21:45:11 -0700 Subject: [PATCH 12/17] bugfix - numerical error Signed-off-by: David Mosallanezhad --- examples/nlp/gpt/conf/gpt_rpo.yaml | 14 +++++++------- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 2 ++ .../models/nlp/gpt/megatron_gpt_rpo_model.py | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_rpo.yaml b/examples/nlp/gpt/conf/gpt_rpo.yaml index 7aec6ffe0..bda090bb3 100644 --- a/examples/nlp/gpt/conf/gpt_rpo.yaml +++ b/examples/nlp/gpt/conf/gpt_rpo.yaml @@ -18,7 +18,7 @@ trainer: # how many GBS we loop over limit_val_batches: 10 gradient_clip_val: 1.0 - num_responses: 2 + num_responses: 4 # do not change these logger: False # logger provided by exp_manager @@ -62,16 +62,16 @@ model: rpo: # This default value ensures there are no numeric differences beween trained 
and reference policies when computing log probs. # A higher value can be used to speed-up log probs computations, but may cause numeric differences. - log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, 2} + log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, trainer.rpo.num_responses} preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss gt_reward_scale: 1. # the scale of the rewards in RPO - preference_loss: rpo_multi_response # the preference loss, we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl + preference_loss: rpo # the preference loss, we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl preference_loss_weight: 1 # the coefficient of the preference loss sft_loss_weight: 0.05 # the coefficient of the SFT loss beta: 0.2 eta: 0.2 - num_responses: 2 + num_responses: ${trainer.rpo.num_responses} #encoder_seq_length: 4096 #max_position_embeddings: ${model.encoder_seq_length} @@ -126,11 +126,11 @@ model: index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix data_prefix: train: - - /data/responses_general.rm.formatted.2resp.jsonl + - /data/train_rpo_4resps.jsonl test: - - /data/small_test_formatted_2resp.jsonl + - /data/val_rpo_4resps.jsonl validation: - - /data/small_test_formatted_2resp.jsonl + - /data/val_rpo_4resps.jsonl # define fields from the base model's config that should be ignored when merging with this config. 
overwrite_base_config: diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 37d19ceb1..282837ac1 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -253,6 +253,8 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p ) chosen_rewards, reject_rewards = self.split_output_tensor(rewards) + print('Rewards:', rewards, gt_rewards) + rewards_delta = chosen_rewards - reject_rewards diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py index d2a50871a..a452e55d4 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -245,7 +245,7 @@ def log_sum_exp(self, x): return max_x + torch.log(torch.sum(torch.exp(x - max_x))) def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): - if self.preference_loss == 'rpo_multi_response': + if self.preference_loss == 'rpo': # estimated rewards rewards = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps( pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs @@ -258,7 +258,7 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p else: raise ValueError("Unknown RPO Loss") - loss = ( p_star * (torch.nn.functional.log_softmax( p_star, dim=0 ) - torch.nn.functional.log_softmax( rewards_pred, dim=0 )) ).sum(0).mean(0) + loss = ( torch.nn.functional.softmax(p_star, dim=0) * (torch.nn.functional.log_softmax( p_star, dim=0 ) - torch.nn.functional.log_softmax( rewards_pred, dim=0 )) ).sum(0).mean(0) # adding accuracy for the best rewards -> MSE or best accuracy? 
acc_best_resp = (torch.argmax(rewards, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean() From 6b03c8c8bd5192ed0a4426a21816b9f642ee6f53 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 24 Sep 2024 10:09:24 -0700 Subject: [PATCH 13/17] updated RPO -> bug fixes Signed-off-by: David Mosallanezhad --- examples/nlp/gpt/conf/gpt_dpo.yaml | 23 +++++++------------ examples/nlp/gpt/conf/gpt_rpo.yaml | 6 ++--- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 23 +++++++++---------- .../models/nlp/gpt/megatron_gpt_rpo_model.py | 18 ++++++++------- 4 files changed, 32 insertions(+), 38 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index b4a979f79..cf7215346 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -2,7 +2,7 @@ defaults: - optional tp_overlap@model.ub_tp_comm_overlap_cfg: trainer: - num_nodes: 1 + num_nodes: 8 devices: 8 accelerator: gpu precision: bf16 @@ -12,7 +12,7 @@ trainer: max_epochs: 1 max_steps: -1 val_check_interval: 0.1 - save_interval: 5 + save_interval: 100 limit_train_batches: 1.0 # how many GBS we loop over @@ -50,18 +50,18 @@ exp_manager: model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} pretrained_checkpoint: - restore_from_path: /models/llama3_8b_sft_alpha_nodes8_tp4_3e-6_bs384_rerun_1200.nemo + restore_from_path: null model: mcore_gpt: True - micro_batch_size: 16 - global_batch_size: 128 + micro_batch_size: 1 + global_batch_size: 64 megatron_amp_O2: True dpo: # This default value ensures there are no numeric differences beween trained and reference policies when computing log probs. # A higher value can be used to speed-up log probs computations, but may cause numeric differences. 
- log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, 2} + log_prob_forward_micro_batch_size: ${model.micro_batch_size} ref_policy_kl_penalty: 0.2 preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss @@ -116,18 +116,11 @@ model: seq_length: ${model.encoder_seq_length} skip_warmup: True num_workers: 0 - dataloader_type: single # cyclic reset_position_ids: False # Reset position ids after end-of-document token reset_attention_mask: False # Reset attention mask after end-of-document token eod_mask_loss: False # Mask loss for the end of document tokens index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix - data_prefix: - train: - - /data/responses_general.rm.formatted.dpo.jsonl - test: - - /data/dpo_test_set.jsonl - validation: - - /data/dpo_test_set.jsonl + data_prefix: null default_chosen_reward: 1. # the default reward for the chosen response in RPO default_rejected_reward: 0. 
# the default reward for the rejected response in RPO @@ -136,4 +129,4 @@ model: data: data_prefix: True -precision: ${trainer.precision} +precision: ${trainer.precision} \ No newline at end of file diff --git a/examples/nlp/gpt/conf/gpt_rpo.yaml b/examples/nlp/gpt/conf/gpt_rpo.yaml index bda090bb3..68407adcc 100644 --- a/examples/nlp/gpt/conf/gpt_rpo.yaml +++ b/examples/nlp/gpt/conf/gpt_rpo.yaml @@ -126,11 +126,11 @@ model: index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix data_prefix: train: - - /data/train_rpo_4resps.jsonl + - /data/responses_general.rm.formatted.4resp.jsonl test: - - /data/val_rpo_4resps.jsonl + - /data/rpo_test_set.jsonl validation: - - /data/val_rpo_4resps.jsonl + - /data/rpo_test_set.jsonl # define fields from the base model's config that should be ignored when merging with this config. overwrite_base_config: diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 282837ac1..90f14a322 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -16,9 +16,9 @@ from functools import partial import torch -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from megatron.core import parallel_state +from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.utils import divide from omegaconf.dictconfig import DictConfig from pytorch_lightning.trainer.trainer import Trainer @@ -31,6 +31,7 @@ from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo_aligner.models.alignable_interface import SupervisedInterface +from nemo_aligner.utils import parallel_state from nemo_aligner.utils.distributed import 
broadcast_2d_tensor, from_parallel_logits_to_logprobs from nemo_aligner.utils.train_utils import ( finish_validation_step, @@ -66,7 +67,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.sft_loss_weight = self.cfg.dpo.get("sft_loss_weight", 0) assert ( self.preference_loss_weight != 0 or self.sft_loss_weight != 0 - ), "sft loss weight and dpo loss weight cannot both be 0" + ), "sft loss weight and preference loss weight cannot both be 0" # variants of preference losses, by default DPO. self.preference_loss = self.cfg.dpo.get("preference_loss", "dpo") @@ -87,6 +88,7 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_lo torch.distributed.all_gather(output_list, batch_logs, group=dp_group) split_iter = map(self.split_output_tensor, output_list) + out_chosen, out_rejected = map(torch.cat, zip(*split_iter)) return out_chosen.flatten(), out_rejected.flatten() @@ -252,12 +254,8 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs ) chosen_rewards, reject_rewards = self.split_output_tensor(rewards) - - print('Rewards:', rewards, gt_rewards) - rewards_delta = chosen_rewards - reject_rewards - if self.preference_loss == "dpo": loss = -torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta).mean(0) elif self.preference_loss == "rpo_bwd_kl": @@ -305,7 +303,6 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p def sft_loss_func(self, pi_logprobs, labels, average_log_probs=False): logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) chosen_logprobs, _ = self.split_output_tensor(logprobs) - return -chosen_logprobs.mean(0) def get_loss_and_metrics(self, batch, forward_only): @@ -409,8 +406,10 @@ def finish_validation_step(self): @torch.no_grad() def get_logprob_batch(self, batch): seq_length = batch["chosen"].shape[1] + batch_size = 
batch["chosen"].shape[0] - data_iter = get_iterator_k_split(batch, get_num_microbatches()) + num_microbatches = divide(batch_size, self.cfg.dpo.log_prob_forward_micro_batch_size) + data_iter = get_iterator_k_split(batch, num_microbatches) set_sync_funcs(self, forward_only=True) fwd_bwd_function = get_forward_backward_func() @@ -419,10 +418,10 @@ def get_logprob_batch(self, batch): forward_step_func=self.get_forward_output_and_loss_func(logprobs_only=True), data_iterator=data_iter, model=self.model, - num_microbatches=get_num_microbatches(), + num_microbatches=num_microbatches, forward_only=True, seq_length=seq_length, - micro_batch_size=self.cfg.dpo.log_prob_forward_micro_batch_size, + micro_batch_size=self.cfg.dpo.log_prob_forward_micro_batch_size * 2, collect_non_loss_data=True, ) @@ -460,4 +459,4 @@ def get_ref_policy_logprobs(self, batch): ref_log_probs = self.get_logprob_batch(batch) # return in GPU, trainer needs to move to cpu - return ref_log_probs + return ref_log_probs \ No newline at end of file diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py index a452e55d4..d4ab4a87b 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -247,10 +247,9 @@ def log_sum_exp(self, x): def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): if self.preference_loss == 'rpo': # estimated rewards - rewards = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps( - pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs + rewards_pred = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps( + self.beta * (pi_logprobs - ref_logprobs), labels, average_log_probs=average_log_probs ))) - rewards_pred = self.beta * rewards # based on GT rewards gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) @@ -261,7 +260,7 @@ def loss_func(self, pi_logprobs, 
ref_logprobs, labels, gt_rewards, average_log_p loss = ( torch.nn.functional.softmax(p_star, dim=0) * (torch.nn.functional.log_softmax( p_star, dim=0 ) - torch.nn.functional.log_softmax( rewards_pred, dim=0 )) ).sum(0).mean(0) # adding accuracy for the best rewards -> MSE or best accuracy? - acc_best_resp = (torch.argmax(rewards, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean() + acc_best_resp = (torch.argmax(rewards_pred, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean() return loss, acc_best_resp @@ -369,7 +368,6 @@ def finish_validation_step(self): @torch.no_grad() def get_logprob_batch(self, batch): seq_length = batch["response_1"].shape[1] - data_iter = get_iterator_k_split(batch, get_num_microbatches()) set_sync_funcs(self, forward_only=True) @@ -387,12 +385,16 @@ def get_logprob_batch(self, batch): collect_non_loss_data=True, ) + each_response_list = [ [] for _ in range(self.k_len) ] + if len(logprobs_list) > 0: - logprobs_list_cat = [] for item in logprobs_list: all_log_probs = self.split_output_tensor(item["logprobs"]) - logprobs_list_cat.extend(all_log_probs) - logprobs = torch.cat(logprobs_list_cat) + for ind in range(self.k_len): + each_response_list[ind].extend(all_log_probs[ind]) + each_response_list = [ torch.stack(b, dim=0) for b in each_response_list ] + logprobs = torch.cat(each_response_list, dim=0) + else: logprobs = None From a0764b225abb8390161d9b541df36ddd77a8d806 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 24 Sep 2024 10:16:21 -0700 Subject: [PATCH 14/17] updated Signed-off-by: David Mosallanezhad --- examples/nlp/gpt/conf/gpt_rpo_top4.yaml | 140 ++++++++++++++++++++++++ examples/nlp/gpt/conf/gpt_sft.yaml | 10 +- nemo_aligner/utils/train_utils.py | 1 - nemo_aligner/utils/utils.py | 1 - 4 files changed, 145 insertions(+), 7 deletions(-) create mode 100644 examples/nlp/gpt/conf/gpt_rpo_top4.yaml diff --git a/examples/nlp/gpt/conf/gpt_rpo_top4.yaml b/examples/nlp/gpt/conf/gpt_rpo_top4.yaml new file mode 100644 
index 000000000..82fae29a0 --- /dev/null +++ b/examples/nlp/gpt/conf/gpt_rpo_top4.yaml @@ -0,0 +1,140 @@ +defaults: + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: + +trainer: + num_nodes: 1 + devices: 8 + accelerator: gpu + precision: bf16 + + # rpo specific args + rpo: + max_epochs: 1 + max_steps: -1 + val_check_interval: 100 + save_interval: 100 + limit_train_batches: 1.0 + + # how many GBS we loop over + limit_val_batches: 10 + gradient_clip_val: 1.0 + num_responses: 4 + + # do not change these + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_time: null + max_epochs: ${.rpo.max_epochs} + max_steps: ${.rpo.max_steps} + +exp_manager: + explicit_log_dir: /results + exp_dir: null + name: megatron_gpt + max_time_per_run: ${trainer.max_time} + create_wandb_logger: False + wandb_logger_kwargs: + project: nemo_aligner_rpo + name: rlhf_gpt3_rpo + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 5 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: "megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}" + model_parallel_size: 4 + +pretrained_checkpoint: + restore_from_path: /models/llama3_8b_sft_alpha_nodes8_tp4_3e-6_bs384_rerun_1200.nemo + +model: + mcore_gpt: True + micro_batch_size: 4 + global_batch_size: 8 + megatron_amp_O2: True + + rpo: + # This default value ensures there are no numeric differences beween trained and reference policies when computing log probs. 
+ # A higher value can be used to speed-up log probs computations, but may cause numeric differences. + log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, trainer.rpo.num_responses} + preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss + sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss + gt_reward_scale: 1. # the scale of the rewards in RPO + preference_loss: rpo # the preference loss, we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl + preference_loss_weight: 1 # the coefficient of the preference loss + sft_loss_weight: 0.05 # the coefficient of the SFT loss + beta: 0.2 + eta: 0.2 + num_responses: ${trainer.rpo.num_responses} + + #encoder_seq_length: 4096 + #max_position_embeddings: ${model.encoder_seq_length} + + # miscellaneous + seed: 1234 + + #peft + peft: + peft_scheme: "none" # ["lora", "none"] + restore_from_path: null + restore_from_ckpt: + checkpoint_dir: null + checkpoint_name: null + + lora_tuning: + target_modules: ["attention_qkv"] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: "xavier" # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: "zero" # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. 
null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + optim: + name: distributed_fused_adam + bucket_cap_mb: 200 + overlap_grad_sync: False + contiguous_grad_buffer: True + lr: 9e-6 + weight_decay: 0.1 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 10 + constant_steps: 1000 + min_lr: 9e-7 + + data: + data_impl: jsonl + splits_string: null + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_prefix: + train: + - /data/train_rpo_4resps_top4.jsonl + test: + - /data/val_rpo_4resps_top4.jsonl + validation: + - /data/val_rpo_4resps_top4.jsonl + + # define fields from the base model's config that should be ignored when merging with this config. 
+ overwrite_base_config: + data: + data_prefix: True + +precision: ${trainer.precision} diff --git a/examples/nlp/gpt/conf/gpt_sft.yaml b/examples/nlp/gpt/conf/gpt_sft.yaml index 9e880faa2..b6b88dcb0 100644 --- a/examples/nlp/gpt/conf/gpt_sft.yaml +++ b/examples/nlp/gpt/conf/gpt_sft.yaml @@ -14,7 +14,7 @@ trainer: save_interval: ${.val_check_interval} limit_train_batches: 1.0 - limit_val_batches: 1.0 + limit_val_batches: 10 gradient_clip_val: 1.0 # can be used to register any custom metrics that require token-by-token generation @@ -161,13 +161,13 @@ model: # - 0.5 # - 0.25 # - 0.25 - label_key: 'output' - add_eos: True + label_key: 'response' + add_eos: False add_sep: False add_bos: False - truncation_field: "input" # # Can be multiple keys separated with ',' Options: keys in prompt_template + truncation_field: "prompt" # # Can be multiple keys separated with ',' Options: keys in prompt_template index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + prompt_template: "{prompt} {response}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" hf_dataset: False # Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. 
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] diff --git a/nemo_aligner/utils/train_utils.py b/nemo_aligner/utils/train_utils.py index 5c8b32087..82b1dd660 100644 --- a/nemo_aligner/utils/train_utils.py +++ b/nemo_aligner/utils/train_utils.py @@ -39,7 +39,6 @@ def set_sync_funcs(ptl_model, forward_only): # pipeline schedules will get these from ptl_model.model.config for module in ptl_model.get_model_module_list(): - # for module in ptl_model.get_gpt_module_list(): module.config.no_sync_func = no_sync_func module.config.grad_sync_func = grad_sync_func module.config.param_sync_func = param_sync_func diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index 66968a4d0..78bc48b22 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -128,7 +128,6 @@ def load_and_override_model_config(restore_path, model_cfg_to_overwrite, remove_ """load the config in the model checkpoint and then overwrite it with whatever is provided """ - print('>>>>', restore_path) checkpoint_cfg = load_checkpoint_model_config(restore_path) if remove_meta_info: From 539abdb603f7717f8f93e0ae7528ee0006e211bb Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 24 Sep 2024 10:24:42 -0700 Subject: [PATCH 15/17] style fixes --- examples/nlp/gpt/conf/gpt_dpo.yaml | 2 +- examples/nlp/gpt/conf/gpt_rpo_top4.yaml | 140 ------------------ .../models/nlp/gpt/megatron_gpt_dpo_model.py | 2 +- 3 files changed, 2 insertions(+), 142 deletions(-) delete mode 100644 examples/nlp/gpt/conf/gpt_rpo_top4.yaml diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index cf7215346..2a165bf9d 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -129,4 +129,4 @@ model: data: data_prefix: True -precision: ${trainer.precision} \ No newline at end of file +precision: ${trainer.precision} diff --git a/examples/nlp/gpt/conf/gpt_rpo_top4.yaml 
b/examples/nlp/gpt/conf/gpt_rpo_top4.yaml deleted file mode 100644 index 82fae29a0..000000000 --- a/examples/nlp/gpt/conf/gpt_rpo_top4.yaml +++ /dev/null @@ -1,140 +0,0 @@ -defaults: - - optional tp_overlap@model.ub_tp_comm_overlap_cfg: - -trainer: - num_nodes: 1 - devices: 8 - accelerator: gpu - precision: bf16 - - # rpo specific args - rpo: - max_epochs: 1 - max_steps: -1 - val_check_interval: 100 - save_interval: 100 - limit_train_batches: 1.0 - - # how many GBS we loop over - limit_val_batches: 10 - gradient_clip_val: 1.0 - num_responses: 4 - - # do not change these - logger: False # logger provided by exp_manager - enable_checkpointing: False - use_distributed_sampler: False - max_time: null - max_epochs: ${.rpo.max_epochs} - max_steps: ${.rpo.max_steps} - -exp_manager: - explicit_log_dir: /results - exp_dir: null - name: megatron_gpt - max_time_per_run: ${trainer.max_time} - create_wandb_logger: False - wandb_logger_kwargs: - project: nemo_aligner_rpo - name: rlhf_gpt3_rpo - resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. - resume_if_exists: True - resume_ignore_no_checkpoint: True - create_checkpoint_callback: True - checkpoint_callback_params: - monitor: val_loss - save_top_k: 5 - mode: min - always_save_nemo: False # saves nemo file during validation, not implemented for model parallel - save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits - filename: "megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}" - model_parallel_size: 4 - -pretrained_checkpoint: - restore_from_path: /models/llama3_8b_sft_alpha_nodes8_tp4_3e-6_bs384_rerun_1200.nemo - -model: - mcore_gpt: True - micro_batch_size: 4 - global_batch_size: 8 - megatron_amp_O2: True - - rpo: - # This default value ensures there are no numeric differences beween trained and reference policies when computing log probs. 
- # A higher value can be used to speed-up log probs computations, but may cause numeric differences. - log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, trainer.rpo.num_responses} - preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss - sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss - gt_reward_scale: 1. # the scale of the rewards in RPO - preference_loss: rpo # the preference loss, we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl - preference_loss_weight: 1 # the coefficient of the preference loss - sft_loss_weight: 0.05 # the coefficient of the SFT loss - beta: 0.2 - eta: 0.2 - num_responses: ${trainer.rpo.num_responses} - - #encoder_seq_length: 4096 - #max_position_embeddings: ${model.encoder_seq_length} - - # miscellaneous - seed: 1234 - - #peft - peft: - peft_scheme: "none" # ["lora", "none"] - restore_from_path: null - restore_from_ckpt: - checkpoint_dir: null - checkpoint_name: null - - lora_tuning: - target_modules: ["attention_qkv"] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' - adapter_dim: 32 - adapter_dropout: 0.0 - column_init_method: "xavier" # IGNORED if linear_adapter is used, options: xavier, zero or normal - row_init_method: "zero" # IGNORED if linear_adapter is used, options: xavier, zero or normal - layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. 
null will apply adapters to all layers - weight_tying: False - position_embedding_strategy: null # used only when weight_tying is True - - optim: - name: distributed_fused_adam - bucket_cap_mb: 200 - overlap_grad_sync: False - contiguous_grad_buffer: True - lr: 9e-6 - weight_decay: 0.1 - betas: - - 0.9 - - 0.98 - sched: - name: CosineAnnealing - warmup_steps: 10 - constant_steps: 1000 - min_lr: 9e-7 - - data: - data_impl: jsonl - splits_string: null - seq_length: ${model.encoder_seq_length} - skip_warmup: True - num_workers: 0 - dataloader_type: single # cyclic - reset_position_ids: False # Reset position ids after end-of-document token - reset_attention_mask: False # Reset attention mask after end-of-document token - eod_mask_loss: False # Mask loss for the end of document tokens - index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix - data_prefix: - train: - - /data/train_rpo_4resps_top4.jsonl - test: - - /data/val_rpo_4resps_top4.jsonl - validation: - - /data/val_rpo_4resps_top4.jsonl - - # define fields from the base model's config that should be ignored when merging with this config. 
- overwrite_base_config: - data: - data_prefix: True - -precision: ${trainer.precision} diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 90f14a322..952b4e897 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -459,4 +459,4 @@ def get_ref_policy_logprobs(self, batch): ref_log_probs = self.get_logprob_batch(batch) # return in GPU, trainer needs to move to cpu - return ref_log_probs \ No newline at end of file + return ref_log_probs From b2145fd5c8223f20e97663fcb97362f1eed4bab5 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 24 Sep 2024 10:28:12 -0700 Subject: [PATCH 16/17] style fixes --- examples/nlp/gpt/conf/gpt_sft.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_sft.yaml b/examples/nlp/gpt/conf/gpt_sft.yaml index 231c814c8..bdd757f31 100644 --- a/examples/nlp/gpt/conf/gpt_sft.yaml +++ b/examples/nlp/gpt/conf/gpt_sft.yaml @@ -14,7 +14,7 @@ trainer: save_interval: ${.val_check_interval} limit_train_batches: 1.0 - limit_val_batches: 10 + limit_val_batches: 1.0 gradient_clip_val: 1.0 # can be used to register any custom metrics that require token-by-token generation @@ -160,13 +160,13 @@ model: # - 0.5 # - 0.25 # - 0.25 - label_key: 'response' - add_eos: False + label_key: 'output' + add_eos: True add_sep: False add_bos: False - truncation_field: "prompt" # # Can be multiple keys separated with ',' Options: keys in prompt_template + truncation_field: "input" # # Can be multiple keys separated with ',' Options: keys in prompt_template index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: "{prompt} {response}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + prompt_template: "{input} {output}" # fstring to use for assistant prompt. 
Example: "Q: {input}\nA: {output}" hf_dataset: False # Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] From bf90fdb0bb4b4ee33ea1f93a9ddceba3d2c73945 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Sep 2024 17:39:23 +0000 Subject: [PATCH 17/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo_aligner/algorithms/rpo.py | 29 +++++----- nemo_aligner/data/nlp/builders.py | 2 +- nemo_aligner/data/nlp/datasets.py | 35 ++++++------ .../models/nlp/gpt/megatron_gpt_rpo_model.py | 56 ++++++++++++------- 4 files changed, 69 insertions(+), 53 deletions(-) diff --git a/nemo_aligner/algorithms/rpo.py b/nemo_aligner/algorithms/rpo.py index a7de3d2ac..991f0b8e2 100644 --- a/nemo_aligner/algorithms/rpo.py +++ b/nemo_aligner/algorithms/rpo.py @@ -32,28 +32,28 @@ def rpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False): resp_outputs = {} - + ## assume len 4 for responses for k in batch[0].keys(): - if k.startswith('response_'): + if k.startswith("response_"): # get response_i resp_outputs[k] = torch.nn.utils.rnn.pad_sequence( - [item[k] for item in batch], - batch_first=True, padding_value=eos_id) - elif k.startswith('labels_'): + [item[k] for item in batch], batch_first=True, padding_value=eos_id + ) + elif k.startswith("labels_"): # get labels_i resp_outputs[k] = torch.nn.utils.rnn.pad_sequence( - [item[k] for item in batch], - batch_first=True, padding_value=-100) - elif k.startswith('lengths_'): + [item[k] for item in batch], batch_first=True, padding_value=-100 + ) + elif k.startswith("lengths_"): # get lens_i resp_outputs[k] = torch.LongTensor([item[k] for item in batch]) - elif k.startswith('rewards_'): + elif 
k.startswith("rewards_"): # get r_i resp_outputs[k] = torch.FloatTensor([item[k] for item in batch]) - + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - resp_outputs['response_1'], eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss, + resp_outputs["response_1"], eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss, ) assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate" if attention_mask.shape[0] == 1: @@ -61,10 +61,9 @@ def rpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_ # attention_mask = attention_mask.expand(len(batch), *((-1,) * (len(attention_mask.shape) - 1))) attention_mask = attention_mask.repeat(4, *((1,) * (len(attention_mask.shape) - 1))) - resp_outputs["attention_mask"] = attention_mask resp_outputs["position_ids"] = position_ids - + return resp_outputs @@ -319,8 +318,8 @@ def augment_dataloader(self, dataloader): batch = next(iter_dataloader) logprobs = self.model.get_ref_policy_logprobs(batch).cpu() ind = 1 - - for logps in torch.split(logprobs, len(logprobs) // self.k_len, dim=0): + + for logps in torch.split(logprobs, len(logprobs) // self.k_len, dim=0): batch["ref_policy_log_probs_response_" + str(ind)] = logps ind += 1 diff --git a/nemo_aligner/data/nlp/builders.py b/nemo_aligner/data/nlp/builders.py index ef2bcb54f..e90fb3aeb 100644 --- a/nemo_aligner/data/nlp/builders.py +++ b/nemo_aligner/data/nlp/builders.py @@ -43,11 +43,11 @@ from nemo.utils import logging from nemo_aligner.data.nlp.datasets import ( DPOModelDataset, - RPOModelDataset, KTOModelDataset, RegressionRewardModelDataset, RewardModelDataset, RLHFDataset, + RPOModelDataset, ) from nemo_aligner.utils import parallel_state from nemo_aligner.utils.utils import collate_with_batch_max_sequence_length diff --git a/nemo_aligner/data/nlp/datasets.py b/nemo_aligner/data/nlp/datasets.py index 73679af0a..2a2a4e0fa 100644 --- a/nemo_aligner/data/nlp/datasets.py +++ 
b/nemo_aligner/data/nlp/datasets.py @@ -15,11 +15,11 @@ """Custom datasets for RLHF training""" import os +import random import numpy as np import scipy import torch -import random from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import _create_ltor_masks_and_position_ids from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset @@ -413,20 +413,20 @@ def __getitem__(self, idx): prompt, prompt_len = self.encode(payload["prompt"], append_eod=False) responses = [] labels = [] - + # loop on responses of the given prompt to encode them - for resp in payload['responses']: - resp_tokens, resp_len = self.encode( - resp, append_eod=self.cfg.data.get("append_eod", False) - ) - + for resp in payload["responses"]: + resp_tokens, resp_len = self.encode(resp, append_eod=self.cfg.data.get("append_eod", False)) + resp_tokens = prompt + resp_tokens resp_len = len(resp_tokens) - + responses.append((resp_tokens, resp_len)) labels.append(([-100] * prompt_len) + resp_tokens[prompt_len:]) - assert resp_tokens[0:prompt_len] == prompt, "the tokenizer for DPO has merged tokens between prompt and response" + assert ( + resp_tokens[0:prompt_len] == prompt + ), "the tokenizer for DPO has merged tokens between prompt and response" max_curr_seq_len = max([i[1] for i in responses]) if max_curr_seq_len > self.seq_length: @@ -437,7 +437,7 @@ def __getitem__(self, idx): rewards = payload.get("rewards", [random.random() for _ in range(len(responses))]) resp_dict = {} - + for ind, (resp, resp_len) in enumerate(responses): resp_tokens = torch.nn.functional.pad( torch.LongTensor(resp), (0, max_curr_seq_len - resp_len), mode="constant", value=self.eos_id @@ -446,20 +446,19 @@ def __getitem__(self, idx): label_tokens = torch.nn.functional.pad( torch.LongTensor(label), (0, max_curr_seq_len - len(label)), mode="constant", value=-100 ) - + # slice if necessary if max_curr_seq_len > self.seq_length: resp_tokens = resp_tokens[: 
self.nograd_length] label_tokens = torch.ones_like(resp_tokens) * (-100) resp_len = self.nograd_length - - resp_dict['response_' + str(ind+1)] = resp_tokens - resp_dict['labels_' + str(ind+1)] = label_tokens - resp_dict['lengths_' + str(ind+1)] = resp_len - resp_dict['rewards_' + str(ind+1)] = rewards[ind] - - return resp_dict + resp_dict["response_" + str(ind + 1)] = resp_tokens + resp_dict["labels_" + str(ind + 1)] = label_tokens + resp_dict["lengths_" + str(ind + 1)] = resp_len + resp_dict["rewards_" + str(ind + 1)] = rewards[ind] + + return resp_dict class KTOModelDataset(Dataset): diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py index d4ab4a87b..430e48ca6 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rpo_model.py @@ -121,20 +121,22 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ # creating tokens and labels tensor batches tokens, labels, ref_logprobs, gt_rewards = None, None, None, None if batch["response_1"] is not None: - tokens = torch.cat(tuple(batch['response_' + str(i+1)] for i in range(self.k_len)), dim=0) + tokens = torch.cat(tuple(batch["response_" + str(i + 1)] for i in range(self.k_len)), dim=0) if batch["labels_1"] is not None: - labels = torch.cat(tuple(batch['labels_' + str(i+1)] for i in range(self.k_len)), dim=0) + labels = torch.cat(tuple(batch["labels_" + str(i + 1)] for i in range(self.k_len)), dim=0) if batch["rewards_1"] is not None: - gt_rewards = torch.cat(tuple(batch['rewards_' + str(i+1)] for i in range(self.k_len)), dim=0) + gt_rewards = torch.cat(tuple(batch["rewards_" + str(i + 1)] for i in range(self.k_len)), dim=0) if batch.get("ref_policy_log_probs_response_1") is not None: - ref_logprobs = torch.cat(tuple(batch['ref_policy_log_probs_response_' + str(i+1)] for i in range(self.k_len)), dim=0) + ref_logprobs = torch.cat( + 
tuple(batch["ref_policy_log_probs_response_" + str(i + 1)] for i in range(self.k_len)), dim=0 + ) # this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs # these two lines ensure your position_ids and attn_mask are always B=1 - attention_mask = batch['attention_mask'][0:1] + attention_mask = batch["attention_mask"][0:1] # Model forward pass forward_args = { @@ -195,7 +197,6 @@ def loss_func(output_tensor): ) loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss - ( reduced_loss, reduced_preference_loss, @@ -245,11 +246,15 @@ def log_sum_exp(self, x): return max_x + torch.log(torch.sum(torch.exp(x - max_x))) def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): - if self.preference_loss == 'rpo': + if self.preference_loss == "rpo": # estimated rewards - rewards_pred = torch.stack(self.split_output_tensor(self.get_reduced_masked_logps( - self.beta * (pi_logprobs - ref_logprobs), labels, average_log_probs=average_log_probs - ))) + rewards_pred = torch.stack( + self.split_output_tensor( + self.get_reduced_masked_logps( + self.beta * (pi_logprobs - ref_logprobs), labels, average_log_probs=average_log_probs + ) + ) + ) # based on GT rewards gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) @@ -257,17 +262,29 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p else: raise ValueError("Unknown RPO Loss") - loss = ( torch.nn.functional.softmax(p_star, dim=0) * (torch.nn.functional.log_softmax( p_star, dim=0 ) - torch.nn.functional.log_softmax( rewards_pred, dim=0 )) ).sum(0).mean(0) - + loss = ( + ( + torch.nn.functional.softmax(p_star, dim=0) + * ( + torch.nn.functional.log_softmax(p_star, dim=0) + - torch.nn.functional.log_softmax(rewards_pred, dim=0) + ) + ) + .sum(0) + .mean(0) + ) + # adding accuracy for the best rewards -> MSE or best accuracy? 
acc_best_resp = (torch.argmax(rewards_pred, dim=0) == torch.argmax(gt_rewards, dim=0)).float().mean() return loss, acc_best_resp def sft_loss_func(self, pi_logprobs, labels, gt_rewards, average_log_probs=False): - logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) # [16] - all_log_probs = torch.stack(self.split_output_tensor(logprobs)) # [4, 4] -> each has several responses which we select the best? - gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) # same, we split the rewards + logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) # [16] + all_log_probs = torch.stack( + self.split_output_tensor(logprobs) + ) # [4, 4] -> each has several responses which we select the best? + gt_rewards = torch.stack(self.split_output_tensor(gt_rewards)) # same, we split the rewards chosen_best = torch.argmax(gt_rewards, dim=0) chosen_logprobs = all_log_probs[chosen_best, torch.arange(all_log_probs.size(1))] @@ -288,7 +305,8 @@ def get_loss_and_metrics(self, batch, forward_only): num_microbatches=get_num_microbatches(), forward_only=forward_only, seq_length=seq_length, - micro_batch_size=self.cfg.micro_batch_size * self.k_len, # each minibatch has K comparisons so tensor shape will be mbs * num_responses + micro_batch_size=self.cfg.micro_batch_size + * self.k_len, # each minibatch has K comparisons so tensor shape will be mbs * num_responses ) # only the last stages of the pipeline return losses @@ -385,16 +403,16 @@ def get_logprob_batch(self, batch): collect_non_loss_data=True, ) - each_response_list = [ [] for _ in range(self.k_len) ] + each_response_list = [[] for _ in range(self.k_len)] if len(logprobs_list) > 0: for item in logprobs_list: all_log_probs = self.split_output_tensor(item["logprobs"]) for ind in range(self.k_len): each_response_list[ind].extend(all_log_probs[ind]) - each_response_list = [ torch.stack(b, dim=0) for b in each_response_list ] + 
each_response_list = [torch.stack(b, dim=0) for b in each_response_list] logprobs = torch.cat(each_response_list, dim=0) - + else: logprobs = None