diff --git a/apex/contrib/test/clip_grad/test_clip_grad.py b/apex/contrib/test/clip_grad/test_clip_grad.py index fb5b3706d..272380bce 100644 --- a/apex/contrib/test/clip_grad/test_clip_grad.py +++ b/apex/contrib/test/clip_grad/test_clip_grad.py @@ -104,7 +104,7 @@ def test_matches_pytorch( ) def test_matches_pytorch_fp16(self): - self.test_matches_pytorch(num_params=11, dtypes=[torch.float16]) + self.test_matches_pytorch(num_params=11, dtypes=[torch.float16], rtol=5e-3) def test_matches_pytorch_fp32(self): self.test_matches_pytorch(dtypes=[torch.float32], rtol=1e-6)