Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/scope/core/pipelines/memflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,15 @@ def __init__(
):
from .modules.causal_model import CausalWanModel

# Validate resolution requirements
# VAE downsample (8) * patch embedding downsample (2) = 16
validate_resolution(
# Snap resolution to the nearest multiple of 16.
# VAE downsample (8) × patch embedding downsample (2) = 16.
# Instead of hard-failing, round down and log a warning so that
# non-standard input resolutions (e.g. 674×389) still work.
config.height, config.width = validate_resolution(
height=config.height,
width=config.width,
scale_factor=16,
snap=True,
)

model_dir = getattr(config, "model_dir", None)
Expand Down
56 changes: 56 additions & 0 deletions src/scope/core/pipelines/test_utils_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Tests for validate_resolution / snap_to_multiple in pipelines/utils.py."""
import pytest

from scope.core.pipelines.utils import snap_to_multiple, validate_resolution


class TestSnapToMultiple:
def test_already_aligned(self):
assert snap_to_multiple(672, 16) == 672

def test_rounds_down(self):
assert snap_to_multiple(674, 16) == 672
assert snap_to_multiple(389, 16) == 384

def test_smaller_than_multiple(self):
assert snap_to_multiple(10, 16) == 0


class TestValidateResolution:
# --- default snap=False behaviour (raises on invalid) ---

def test_valid_resolution_returns_unchanged(self):
h, w = validate_resolution(height=384, width=672, scale_factor=16)
assert (h, w) == (384, 672)

def test_invalid_raises_value_error(self):
with pytest.raises(ValueError, match="Invalid resolution"):
validate_resolution(height=389, width=674, scale_factor=16)

def test_error_message_contains_suggestion(self):
with pytest.raises(ValueError, match="672×384"):
validate_resolution(height=389, width=674, scale_factor=16)

# --- snap=True behaviour (rounds down, no exception) ---

def test_snap_invalid_resolution(self):
h, w = validate_resolution(height=389, width=674, scale_factor=16, snap=True)
assert (h, w) == (384, 672)

def test_snap_valid_resolution_unchanged(self):
h, w = validate_resolution(height=320, width=576, scale_factor=16, snap=True)
assert (h, w) == (320, 576)

def test_snap_only_height_unaligned(self):
h, w = validate_resolution(height=385, width=576, scale_factor=16, snap=True)
assert (h, w) == (384, 576)

def test_snap_only_width_unaligned(self):
h, w = validate_resolution(height=384, width=577, scale_factor=16, snap=True)
assert (h, w) == (384, 576)

def test_snap_logs_warning(self, caplog):
import logging
with caplog.at_level(logging.WARNING, logger="scope.core.pipelines.utils"):
validate_resolution(height=389, width=674, scale_factor=16, snap=True)
assert any("Snapping resolution" in r.message for r in caplog.records)
34 changes: 29 additions & 5 deletions src/scope/core/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,54 @@ def load_model_config(config, pipeline_file_path: str | Path) -> OmegaConf:
return model_config


def snap_to_multiple(val: int, multiple: int) -> int:
"""Round *val* down to the nearest multiple of *multiple*."""
return (val // multiple) * multiple


def validate_resolution(
height: int,
width: int,
scale_factor: int,
) -> None:
snap: bool = False,
) -> tuple[int, int]:
"""
Validate that resolution dimensions are divisible by the required scale factor.
Validate (and optionally snap) resolution dimensions to a required scale factor.

Args:
height: Height of the resolution
width: Width of the resolution
scale_factor: The factor that both dimensions must be divisible by
snap: If True, silently round down to the nearest valid multiple instead
of raising an error. A warning is logged when snapping occurs.

Returns:
A ``(height, width)`` tuple. When *snap* is False and the dimensions
are already valid the input values are returned unchanged. When *snap*
is True the (possibly adjusted) values are returned.

Raises:
ValueError: If height or width is not divisible by scale_factor
ValueError: If *snap* is False and height or width is not divisible by
*scale_factor*.
"""
if height % scale_factor != 0 or width % scale_factor != 0:
adjusted_width = (width // scale_factor) * scale_factor
adjusted_height = (height // scale_factor) * scale_factor
adjusted_width = snap_to_multiple(width, scale_factor)
adjusted_height = snap_to_multiple(height, scale_factor)
if snap:
import logging
logging.getLogger(__name__).warning(
"Snapping resolution from %d×%d to %d×%d "
"(both dimensions must be divisible by %d)",
width, height, adjusted_width, adjusted_height, scale_factor,
)
return adjusted_height, adjusted_width
raise ValueError(
f"Invalid resolution {width}×{height}. "
f"Both width and height must be divisible by {scale_factor} "
f"Please adjust to a valid resolution, e.g., {adjusted_width}×{adjusted_height}."
f"\nIf this error persists, consider removing the models directory and re-downloading models."
)
return height, width


def parse_jsonl_prompts(file_path: str) -> list[list[str]]:
Expand Down
Loading