Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions acestep/streaming/ace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,13 @@
Capabilities,
TickContext,
)
from acestep.steering import SteeringController
from acestep.streaming.knobs import (
CHANNEL_GROUPS,
KEYSTONE_CHANNELS,
knob_specs as registry_knob_specs,
manual_slot_specs,
steering_axis_spec,
)

# Audio sample rate the ACE-Step v1.5 family is trained on, and the
Expand Down Expand Up @@ -121,6 +124,41 @@ def _curve_from_spec(spec, T):
return None


def steering_knob_specs(steering: "SteeringController") -> list:
"""Project a SteeringController's live surface into registry specs.

Empty when no vector bundle is reachable (the knobs would be dead).
The spec SHAPES come from the registry factories in
``acestep.streaming.knobs``; only the axis/catalog metadata is
filled in here, where the steering policy lives.
"""
if not steering.is_loaded:
return []
specs: list = []
for ax in steering.auto_axes:
inject_layer = max(
0, min(steering.MANUAL_MAX_LAYER, ax.probe_layer + ax.layer_offset),
)
specs.append(steering_axis_spec(
ax.name,
axis=ax.axis,
inject_layer=inject_layer,
probe_step=ax.probe_step,
probe_n=steering._probe_n,
blurb=ax.blurb,
))
src_max = max(0, len(steering.catalog) - 1)
for slot in steering.active_slots():
specs.extend(manual_slot_specs(
slot,
src_max=src_max,
catalog_len=len(steering.catalog),
layer_max=steering.MANUAL_MAX_LAYER,
step_max=steering.MANUAL_MAX_STEP,
))
return specs


class ACEStepBackend(DiffusionBackend):
"""ACE-Step v1.5 diffusion generation behind the GeneratorBackend seam.

Expand All @@ -141,6 +179,7 @@ def __init__(
walk_window=False,
walk_window_s=60.0,
neg_conditioning=None,
steering: SteeringController | None = None,
):
# The family codec is the engine Session: its windowed VAE
# decode is what render_window()/render_full() drive. The
Expand Down Expand Up @@ -205,6 +244,19 @@ def __init__(
# ``None`` on the first tick just seeds the baseline.
self._last_rebuild_keys = None

# Activation steering. The controller is the source of truth for
# the slot count and vector catalog; the session mirrors its
# slot ops into KnobState / the knob manifest. ``None`` (e.g. a
# bare-construction test fixture) degrades to an unloaded
# controller so every consumer can read it unconditionally.
self.steering = (
steering if steering is not None else SteeringController(None)
)
# (pipeline, snapshot) change-detection key for _sync_steering;
# None forces a push on the first tick and after a
# steps_override-driven pipeline rebuild.
self._last_steering = None

# ----- per-tick translation state (the old run() locals) -----
self._last_latent = None
# Previous fresh latent for the full-buffer MSE skip. Tracked
Expand Down Expand Up @@ -275,6 +327,7 @@ def capabilities(self) -> Capabilities:
depth=True,
curves=True,
notes_conditioning=False,
steering=self.steering.is_loaded,
)

def geometry(self) -> AudioGeometry:
Expand All @@ -291,11 +344,13 @@ def geometry(self) -> AudioGeometry:
def knob_specs(self, lora_ids=()) -> list:
"""The ACE-family manifest: the shared registry's spec list,
parameterized by this session's SDE mode and the enabled-LoRA
set the session passes in (see the protocol docstring)."""
set the session passes in (see the protocol docstring), plus
the activation-steering surface (auto axes + the live manual
slots) when this session's checkpoint has a vector bundle."""
return registry_knob_specs(
self.use_sde,
loras=list(lora_ids) if self.use_lora else [],
)
) + steering_knob_specs(self.steering)

# ---- public hooks reachable from session ops ---------------------------

Expand Down Expand Up @@ -377,6 +432,28 @@ def _sync_channel_guidance(self, raw: dict, last: list) -> list:
self.stream.model.handler._channel_guidance = configs
return ch_gains[:]

def _sync_steering(self, raw: dict, last):
"""Push activation-steering configs when the snapshot changes.

``last`` is ``(pipeline, snapshot_tuple)`` or ``None``. Pipeline
identity is part of the key because ``steps_override`` rebuilds
the StreamPipeline (fresh, empty steering state) without
changing ``raw`` — without the identity check the new pipeline
would never receive ``set_steering``.
"""
if not self.steering.is_loaded:
return last
pipe = self.stream.pipeline
if pipe is None:
return last
n = max(1, int(raw.get("steps_override", 8)))
snapshot = self.steering.snapshot_key(raw, n)
last_pipe, last_snapshot = last if last is not None else (None, None)
if pipe is last_pipe and snapshot == last_snapshot:
return last
pipe.set_steering(self.steering.build_configs(raw, n))
return (pipe, snapshot)

# ---- GeneratorBackend hot loop -----------------------------------------

def sync_source(self, ctx: TickContext) -> None:
Expand Down Expand Up @@ -712,6 +789,9 @@ def _prepare_tick(self, knobs: dict, ctx: TickContext) -> dict:
self._last_channel_gains = self._sync_channel_guidance(
raw, self._last_channel_gains,
)
self._last_steering = self._sync_steering(
raw, self._last_steering,
)

# Route every curve-capable parameter through the shared
# mutable curve system so knob changes take effect on ALL
Expand Down
9 changes: 9 additions & 0 deletions acestep/streaming/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,15 @@ class DepthApplied:
value: int


@dataclass(frozen=True)
class ManualSlotCount:
"""Manual steering slot count after a manual_slot_add / manual_slot_pop
(published on success AND refusal so the client's +/- UI resyncs
either way). ``count`` is the controller's live slot count."""

count: int


@dataclass(frozen=True)
class SwapReady:
"""Source swap completed. Carries enough state for the transport to
Expand Down
47 changes: 45 additions & 2 deletions acestep/streaming/families.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,16 @@


def _make_acestep(ss):
from acestep.steering import SteeringController, ensure_steering_vectors
from acestep.streaming.ace_backend import ACEStepBackend

# SteeringController is the source of truth for slot_count and the
# vector catalog; ensure_steering_vectors fetches/caches the
# checkpoint's probe bundle (None for checkpoints without one — XL,
# fetch failures — which degrades the controller to is_loaded=False
# and drops the steering capability/knobs for the session).
steering = SteeringController(ensure_steering_vectors(ss.checkpoint))

return ACEStepBackend(
ss.session, ss.stream,
state=ss.state,
Expand All @@ -34,6 +42,7 @@ def _make_acestep(ss):
walk_window=ss.walk_window,
walk_window_s=ss.walk_window_s,
neg_conditioning=ss.cond_negative,
steering=steering,
)


Expand All @@ -43,14 +52,48 @@ def _make_acestep(ss):


def _acestep_knob_universe():
from acestep.streaming.knobs import knob_specs
from acestep.steering.policy import (
AUTO_AXES,
MANUAL_MAX_LAYER,
MANUAL_MAX_STEP,
PROBE_N,
)
from acestep.streaming.knobs import (
knob_specs,
manual_slot_specs,
steering_axis_spec,
)

# Every spec the family can ever expose: both SDE-mode variants plus
# a representative LoRA-strength knob (the per-id specs all come from
# lora_strength_spec, so one placeholder id covers the pattern).
# lora_strength_spec, so one placeholder id covers the pattern), plus
# the steering surface — the four auto axes and one representative
# manual slot (per-slot specs all come from manual_slot_specs).
# Catalog geometry uses the canonical v15-turbo bundle's 144 cells;
# no network fetch happens here (policy tables only).
steering = [
steering_axis_spec(
ax.name,
axis=ax.axis,
inject_layer=max(
0, min(MANUAL_MAX_LAYER, ax.probe_layer + ax.layer_offset),
),
probe_step=ax.probe_step,
probe_n=PROBE_N,
blurb=ax.blurb,
)
for ax in AUTO_AXES
] + manual_slot_specs(
1,
src_max=143,
catalog_len=144,
layer_max=MANUAL_MAX_LAYER,
step_max=MANUAL_MAX_STEP,
)
return (
knob_specs(False, loras=["<lora_id>"])
+ knob_specs(True, loras=["<lora_id>"])
+ steering
)


Expand Down
5 changes: 5 additions & 0 deletions acestep/streaming/generator_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ class Capabilities:
depth: bool = False
curves: bool = False
notes_conditioning: bool = False
# Activation steering (per-layer residual shifts driven by the
# steer_* / man_*_<N> knobs and the manual_slot_add/pop commands).
# True only when the backend has a steering controller with a
# reachable vector bundle for its checkpoint.
steering: bool = False


@dataclass(frozen=True)
Expand Down
95 changes: 95 additions & 0 deletions acestep/streaming/knobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,101 @@ def lora_strength_spec(lora_id: str) -> KnobSpec:
)


# Activation-steering alpha range. Bipolar so the operator can invert an
# axis without leaving the surface; useful magnitude is roughly 2..15 by
# ear, breakage above that.
STEERING_ALPHA_MAX = 30.0


def steering_axis_spec(
name: str,
*,
axis: str = "",
inject_layer: int = 0,
probe_step: int = 0,
probe_n: int = 8,
blurb: str = "",
) -> KnobSpec:
"""The registry spec for one auto-path activation-steering knob.

Shaped here (range, group, bank) so every transport projects the
same contract; the axis metadata (where the vector injects, what it
does) arrives as plain values from the backend that owns the
steering policy — this module stays torch-free / acestep-free.
"""
return KnobSpec(
name, default=0.0,
min_val=-STEERING_ALPHA_MAX, max_val=STEERING_ALPHA_MAX,
group="steering",
description=(
f"Activation-steering ({axis}) injected at DiT layer "
f"{inject_layer}, step round({probe_step}/{probe_n} * inject_n) "
f"of the current schedule. 0 = off, negative inverts the axis "
f"direction. {blurb}."
" Useful magnitude roughly 2..15 by ear; breakage above that."
),
)


def manual_slot_specs(
slot_id: int,
*,
src_max: int,
catalog_len: int,
layer_max: int,
step_max: int,
) -> list:
"""The four registry specs for one manual steering slot.

Like :func:`lora_strength_spec`, factored so the runtime slot
add path and the session-start manifest both shape the knobs from
the registry. Manual slots bypass the auto path's fractional step
mapping, layer offset, and sign correction — the vector lands at
the operator's chosen cell with the operator's chosen sign.
"""
return [
KnobSpec(
f"man_src_{slot_id}", default=0.0, min_val=0.0,
max_val=float(src_max), type="int", group="manual",
description=(
f"Manual slot {slot_id}: vector catalog index. Resolves to "
f"a (axis, build_layer, build_step) cell on disk; call "
f"list_manual_steering_vectors for the table. Index "
f"0..{src_max} ({catalog_len} cells)."
),
),
KnobSpec(
f"man_layer_{slot_id}", default=9.0, min_val=0.0,
max_val=float(layer_max), type="int", group="manual",
description=(
f"Manual slot {slot_id}: DiT inject layer (0..{layer_max}). "
"Passed verbatim to the engine; no automatic offset."
),
),
KnobSpec(
f"man_step_{slot_id}", default=0.0, min_val=0.0,
max_val=float(step_max), type="int", group="manual",
description=(
f"Manual slot {slot_id}: diffusion inject step "
f"(0..{step_max}). No fractional mapping. Values past the "
"current steps_override - 1 silently no-op (the engine only "
"fires when step equals the active diffusion step)."
),
),
KnobSpec(
f"man_alpha_{slot_id}", default=0.0,
min_val=-STEERING_ALPHA_MAX, max_val=STEERING_ALPHA_MAX,
group="manual",
description=(
f"Manual slot {slot_id}: injection strength. 0 = slot off. "
"Bipolar: negative alpha inverts the chosen vector's "
"direction at injection (no sign correction is applied). "
"Useful magnitude roughly 2..15 by ear; breakage above that."
),
),
]


def knob_catalog(sde: bool, loras=None) -> dict:
"""Project the full registry into a transport-agnostic catalog:
``name -> {type, default, min?, max, group, options?, description?,
Expand Down
Loading