From 87584037722e3f3e8753b29c5159bbe8bbec5dbc Mon Sep 17 00:00:00 2001 From: LeSingh1 Date: Sun, 31 May 2026 14:45:16 -0700 Subject: [PATCH] Don't retain saved tensors in fused norm custom ops under no_grad The custom-op forward path for FusedRMSNorm/FusedLayerNorm registers an autograd setup_context that unconditionally calls save_for_backward. For torch.library custom ops these saved tensors are retained in autograd metadata that is not released after the call returns, so each forward under torch.no_grad() leaks the saved activation and the invvar tensor (two CUDA tensors per call), accumulating linearly in long-running inference (issue #1999). Skip the save_for_backward calls when grad is disabled, since backward can never run in that case. The grad-enabled training path is unchanged. Signed-off-by: LeSingh1 --- apex/normalization/fused_layer_norm.py | 42 ++++++++++++++++++-------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index a0f3833bc..41f95a95e 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -182,6 +182,13 @@ def _fused_layer_norm_affine_backward(ctx, grad_output, grad_mean, grad_invvar): def _fused_layer_norm_affine_setup_context(ctx, inputs, output): input, weight, bias, normalized_shape, eps, memory_efficient = inputs output, mean, invvar = output + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + # See _fused_rms_norm_affine_setup_context: skip the saves under no_grad to + # avoid leaking the retained tensors (issue #1999). + if not torch.is_grad_enabled(): + return input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() @@ -189,9 +196,6 @@ def _fused_layer_norm_affine_setup_context(ctx, inputs, output): ctx.save_for_backward(output, weight_, bias_, None, invvar) else: ctx.save_for_backward(input_, weight_, bias_, mean, invvar) - ctx.normalized_shape = normalized_shape - ctx.eps = eps - ctx.memory_efficient = memory_efficient fused_layer_norm_affine_fwd.register_autograd( _fused_layer_norm_affine_backward, @@ -337,15 +341,21 @@ def _fused_rms_norm_affine_backward(ctx, grad_output, grad_invvar): def _fused_rms_norm_affine_setup_context(ctx, inputs, output): input_, weight_, normalized_shape, eps, memory_efficient = inputs output_, invvar = output + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + # Under torch.no_grad() backward will never run, so don't retain the + # activation/invvar tensors. For custom ops these saves are held in autograd + # metadata that isn't released after the call returns, which otherwise leaks + # two CUDA tensors per forward in inference loops (issue #1999). + if not torch.is_grad_enabled(): + return input_ = input_.contiguous() weight_ = weight_.contiguous() if memory_efficient: ctx.save_for_backward(output_, weight_, invvar) else: ctx.save_for_backward(input_, weight_, invvar) - ctx.normalized_shape = normalized_shape - ctx.eps = eps - ctx.memory_efficient = memory_efficient fused_rms_norm_affine_fwd.register_autograd( _fused_rms_norm_affine_backward, @@ -515,14 +525,18 @@ def _fused_layer_norm_backward(ctx, grad_output, grad_mean, grad_invvar): def _fused_layer_norm_setup_context(ctx, inputs, output): input, normalized_shape, eps, memory_efficient = inputs output, mean, invvar = output + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + # See _fused_rms_norm_affine_setup_context: skip the saves under no_grad to + # avoid leaking the retained tensors (issue #1999). + if not torch.is_grad_enabled(): + return input_ = input.contiguous() if memory_efficient: ctx.save_for_backward(output, None, invvar) else: ctx.save_for_backward(input_, mean, invvar) - ctx.normalized_shape = normalized_shape - ctx.eps = eps - ctx.memory_efficient = memory_efficient fused_layer_norm_fwd.register_autograd( _fused_layer_norm_backward, @@ -653,14 +667,18 @@ def _fused_rms_norm_backward(ctx, grad_output, grad_invvar): def _fused_rms_norm_setup_context(ctx, inputs, output): input_, normalized_shape, eps, memory_efficient = inputs output_, invvar = output + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + # See _fused_rms_norm_affine_setup_context: skip the saves under no_grad to + # avoid leaking the retained tensors (issue #1999). + if not torch.is_grad_enabled(): + return input_ = input_.contiguous() if memory_efficient: ctx.save_for_backward(output_, invvar) else: ctx.save_for_backward(input_, invvar) - ctx.normalized_shape = normalized_shape - ctx.eps = eps - ctx.memory_efficient = memory_efficient fused_rms_norm_fwd.register_autograd( _fused_rms_norm_backward, setup_context=_fused_rms_norm_setup_context