diff --git a/src/scope/core/pipelines/longlive/modules/causal_model.py b/src/scope/core/pipelines/longlive/modules/causal_model.py index 8607855bc..96e731484 100644 --- a/src/scope/core/pipelines/longlive/modules/causal_model.py +++ b/src/scope/core/pipelines/longlive/modules/causal_model.py @@ -1,5 +1,6 @@ # Modified from https://github.com/NVlabs/LongLive # SPDX-License-Identifier: CC-BY-NC-SA-4.0 +import logging import math import torch @@ -13,6 +14,7 @@ ) from scope.core.pipelines.wan2_1.modules.attention import attention + from .model import ( WAN_CROSSATTENTION_CLASSES, MLPProj, @@ -30,6 +32,8 @@ flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs" ) +logger = logging.getLogger(__name__) + def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): n, c = x.size(2), x.size(3) // 2 @@ -331,6 +335,15 @@ def qkv_fn(x): temp_v[:, write_start_index:local_end_index] = v[ :, roped_offset : roped_offset + write_len ] + else: + logger.warning( + "KV cache write skipped (roll_and_insert): " + "write_start_index (%d) >= local_end_index (%d) — " + "cache indices may be stale after a rapid mode transition. " + "(See daydreamlive/scope#921)", + write_start_index, + local_end_index, + ) # Save cache update info for later use cache_update_info = { @@ -377,6 +390,15 @@ def qkv_fn(x): temp_v[:, write_start_index:local_end_index] = v[ :, roped_offset : roped_offset + write_len ] + else: + logger.warning( + "KV cache write skipped (direct_insert): " + "write_start_index (%d) >= local_end_index (%d) — " + "cache indices may be stale after a rapid mode transition. " + "(See daydreamlive/scope#921)", + write_start_index, + local_end_index, + ) # Save cache update info for later use cache_update_info = {