From 339f110515c73b27fd3227a2015741243f6701e2 Mon Sep 17 00:00:00 2001 From: JulesCollenne Date: Thu, 22 Jun 2023 10:41:59 +0200 Subject: [PATCH] Added support for GPUs that can't use bfloat16 and non-distributed training --- src/train.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/train.py b/src/train.py index 0f387d6b..d2fbb274 100644 --- a/src/train.py +++ b/src/train.py @@ -218,9 +218,12 @@ def main(args, resume_preempt=False): num_epochs=num_epochs, ipe_scale=ipe_scale, use_bfloat16=use_bfloat16) - encoder = DistributedDataParallel(encoder, static_graph=True) - predictor = DistributedDataParallel(predictor, static_graph=True) - target_encoder = DistributedDataParallel(target_encoder) + + if world_size != 1: + encoder = DistributedDataParallel(encoder, static_graph=True) + predictor = DistributedDataParallel(predictor, static_graph=True) + target_encoder = DistributedDataParallel(target_encoder) + for p in target_encoder.parameters(): p.requires_grad = False @@ -312,18 +315,21 @@ def loss_fn(z, h): loss = AllReduce.apply(loss) return loss - # Step 1. Forward - with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16): - h = forward_target() - z = forward_context() - loss = loss_fn(z, h) - - # Step 2. Backward & step if use_bfloat16: + # Step 1. Forward + with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16): + h = forward_target() + z = forward_context() + loss = loss_fn(z, h) + + # Step 2. Backward & step scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: + h = forward_target() + z = forward_context() + loss = loss_fn(z, h) loss.backward() optimizer.step() grad_stats = grad_logger(encoder.named_parameters())