From 6373d0ba6146d3644b61f508d403ebf1be3b09bb Mon Sep 17 00:00:00 2001 From: Victor Guichard Date: Sat, 24 Jun 2023 10:47:48 +0000 Subject: [PATCH] fall back to float16 if bfloat16 is not supported --- src/train.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/train.py b/src/train.py index 0f387d6b..1355f9ad 100644 --- a/src/train.py +++ b/src/train.py @@ -84,6 +84,19 @@ def main(args, resume_preempt=False): device = torch.device('cuda:0') torch.cuda.set_device(device) + # Check if bfloat16 is supported, if not fall back to float16 if bfloat16 was requested + autocast_dtype = torch.bfloat16 if use_bfloat16 else torch.float32 + + bfloat16_supported = False + try: + bfloat16_supported = torch.cuda.is_bf16_supported() + except RuntimeError: + bfloat16_supported = False + + if not bfloat16_supported and use_bfloat16: + logger.info(f'Device does not support bfloat16, falling back to float16') + autocast_dtype = torch.float16 + # -- DATA use_gaussian_blur = args['data']['use_gaussian_blur'] use_horizontal_flip = args['data']['use_horizontal_flip'] @@ -313,7 +326,7 @@ def loss_fn(z, h): return loss # Step 1. Forward - with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16): + with torch.cuda.amp.autocast(dtype=autocast_dtype, enabled=use_bfloat16): h = forward_target() z = forward_context() loss = loss_fn(z, h)