diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index a0f3833bc..41f95a95e 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -182,6 +182,13 @@ def _fused_layer_norm_affine_backward(ctx, grad_output, grad_mean, grad_invvar): def _fused_layer_norm_affine_setup_context(ctx, inputs, output): input, weight, bias, normalized_shape, eps, memory_efficient = inputs output, mean, invvar = output + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + # See _fused_rms_norm_affine_setup_context: skip the saves under no_grad to + # avoid leaking the retained tensors (issue #1999). + if not torch.is_grad_enabled(): + return input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() @@ -189,9 +196,6 @@ def _fused_layer_norm_affine_setup_context(ctx, inputs, output): ctx.save_for_backward(output, weight_, bias_, None, invvar) else: ctx.save_for_backward(input_, weight_, bias_, mean, invvar) - ctx.normalized_shape = normalized_shape - ctx.eps = eps - ctx.memory_efficient = memory_efficient fused_layer_norm_affine_fwd.register_autograd( _fused_layer_norm_affine_backward, @@ -337,15 +341,21 @@ def _fused_rms_norm_affine_backward(ctx, grad_output, grad_invvar): def _fused_rms_norm_affine_setup_context(ctx, inputs, output): input_, weight_, normalized_shape, eps, memory_efficient = inputs output_, invvar = output + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + # Under torch.no_grad() backward will never run, so don't retain the + # activation/invvar tensors. For custom ops these saves are held in autograd + # metadata that isn't released after the call returns, which otherwise leaks + # two CUDA tensors per forward in inference loops (issue #1999). + if not torch.is_grad_enabled(): + return input_ = input_.contiguous() weight_ = weight_.contiguous() if memory_efficient: ctx.save_for_backward(output_, weight_, invvar) else: ctx.save_for_backward(input_, weight_, invvar) - ctx.normalized_shape = normalized_shape - ctx.eps = eps - ctx.memory_efficient = memory_efficient fused_rms_norm_affine_fwd.register_autograd( _fused_rms_norm_affine_backward, @@ -515,14 +525,18 @@ def _fused_layer_norm_backward(ctx, grad_output, grad_mean, grad_invvar): def _fused_layer_norm_setup_context(ctx, inputs, output): input, normalized_shape, eps, memory_efficient = inputs output, mean, invvar = output + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + # See _fused_rms_norm_affine_setup_context: skip the saves under no_grad to + # avoid leaking the retained tensors (issue #1999). + if not torch.is_grad_enabled(): + return input_ = input.contiguous() if memory_efficient: ctx.save_for_backward(output, None, invvar) else: ctx.save_for_backward(input_, mean, invvar) - ctx.normalized_shape = normalized_shape - ctx.eps = eps - ctx.memory_efficient = memory_efficient fused_layer_norm_fwd.register_autograd( _fused_layer_norm_backward, @@ -653,14 +667,18 @@ def _fused_rms_norm_backward(ctx, grad_output, grad_invvar): def _fused_rms_norm_setup_context(ctx, inputs, output): input_, normalized_shape, eps, memory_efficient = inputs output_, invvar = output + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + # See _fused_rms_norm_affine_setup_context: skip the saves under no_grad to + # avoid leaking the retained tensors (issue #1999). + if not torch.is_grad_enabled(): + return input_ = input_.contiguous() if memory_efficient: ctx.save_for_backward(output_, invvar) else: ctx.save_for_backward(input_, invvar) - ctx.normalized_shape = normalized_shape - ctx.eps = eps - ctx.memory_efficient = memory_efficient fused_rms_norm_fwd.register_autograd( _fused_rms_norm_backward, setup_context=_fused_rms_norm_setup_context