Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions optimized/mlx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ Apple Silicon only (MLX is Metal-backed). Python 3.10+. `./install.sh
./sa3 --prompt "ambient drone" --cfg 3.0 --negative-prompt "drums, vocals" \
--dit sm-music --decoder same-s --out drone.wav

# Apply a LoRA fine-tune (merged into the DiT at load; the LoRA's base model must match --dit)
./sa3 --prompt "arabic maqam oud taqsim" --dit medium --decoder same-l \
--lora ./my_lora.safetensors --lora-strength 1.0 --out maqam.wav

# Generate + play immediately (afplay; Ctrl-C stops both)
./sa3 --prompt "rainforest" --dit sm-sfx --decoder same-s --play

Expand Down Expand Up @@ -178,6 +182,8 @@ Sample run on **M4 Pro / 48 GB**:
| `--init-noise-level` | 1.0 | σmax; 0.4–0.8 typical for variation, 1.0 = full regen, >1 = overshoot |
| `--inpaint-range` | — | `START,END` seconds; regenerate that span, keep the rest |
| `--dit-dtype` | fp16 | DiT compute dtype (decoder always FP32; T5Gemma always fp16) |
| `--lora` | — | One or more `.safetensors` LoRA adapters (SA3-native or PEFT format) merged into the DiT at load. Pickle-based `.ckpt`/`.pt` files are refused. The adapter's base model must match `--dit` |
| `--lora-strength` | 1.0 | Scale applied to every `--lora` adapter's delta; 0 = bit-exact bypass, >1 amplifies the effect |
| `--free-models` | on | Progressive model freeing; `--no-free-models` keeps them resident |
| `--out` | out.wav | Relative → `output/<file>`; absolute → as-is. 16-bit PCM stereo @ 44.1 kHz, trimmed to exactly `--seconds` |
| `--play` | off | After writing, play via `afplay`; Ctrl-C stops both processes |
Expand Down
12 changes: 11 additions & 1 deletion optimized/mlx/models/defs/dit_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,19 +304,29 @@ def convert_weights_from_torch_ckpt(ckpt_path):
return out


def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False):
def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False,
lora_paths=None, lora_strength=1.0, lora_log=print):
"""Build MLX DiT and load weights.

weights_path can be either:
- the sa3-sm-music torch ckpt (slow; converts at load time), OR
- a pre-converted MLX file (.npz or .safetensors — fast path).

lora_paths: optional list of LoRA adapters (.safetensors / PEFT dir) to merge
into the weights at load time. lora_strength scales every adapter's delta.
See models/defs/lora_merge.py.
"""
p = str(weights_path)
if p.endswith(".npz") or p.endswith(".safetensors"):
wd = dict(mx.load(p))
else:
wd = convert_weights_from_torch_ckpt(p)

if lora_paths:
from .lora_merge import merge_loras_into_weights
stats = merge_loras_into_weights(wd, lora_paths, strength=lora_strength, log=lora_log)
lora_log(f"lora: merged {stats['merged']} layer(s) from {stats['adapters']} adapter(s)")

model = DiT(T_lat=T_lat)
wd_list = [(k, v.astype(dtype)) for k, v in wd.items()]
model.load_weights(wd_list, strict=False)
Expand Down
12 changes: 11 additions & 1 deletion optimized/mlx/models/defs/dit_mlx_medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,16 @@ def convert_weights(safetensors_path, out_path=None):
return out


def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False):
def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False,
lora_paths=None, lora_strength=1.0, lora_log=print):
"""Build MLX DiT and load weights.

weights_path: either the .safetensors (we'll convert in-memory) or a
pre-converted .safetensors-mlx file.

lora_paths: optional list of LoRA adapters (.safetensors / PEFT dir) to merge
into the weights at load time. lora_strength scales every adapter's delta.
See models/defs/lora_merge.py.
"""
weights_path = str(weights_path)
if weights_path.endswith(".safetensors") and ("medium-ARC" in weights_path):
Expand All @@ -418,6 +423,11 @@ def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False):
else:
wd = dict(mx.load(weights_path))

if lora_paths:
from .lora_merge import merge_loras_into_weights
stats = merge_loras_into_weights(wd, lora_paths, strength=lora_strength, log=lora_log)
lora_log(f"lora: merged {stats['merged']} layer(s) from {stats['adapters']} adapter(s)")

model = DiT(T_lat=T_lat)

# Cast to target dtype (no-op when already at `dtype`).
Expand Down
Loading
Loading