Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions lightx2v_train/configs/dopsd/flux2_klein_dopsd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
model:
name: flux2_klein
pretrained_model_name_or_path: /data/nvme7/HF/hub/models--black-forest-labs--FLUX.2-klein-9B/snapshots/92196c8e11f7b6cf2b7493e037d8c5345c559216
max_sequence_length: 512
text_encoder_out_layers: [9, 18, 27]
running_dtype: bf16

data:
train:
name: image_dataset
num_workers: 4
prompt_dropout_rate: 0.0
target_area: 1048576
shuffle: true
data_path:
- /data/nvme7/LightX2V_train_data_examples/dataset_v1/train.jsonl
val:
name: image_dataset
num_workers: 4
target_area: 1048576
shuffle: false
data_path:
- /data/nvme7/LightX2V_train_data_examples/dataset_v1/val.jsonl

scheduler:
num_train_timesteps: 1000
timestep_distribution: logitnormal
logitnormal_mean: 0.0
logitnormal_std: 1.0
min_t: 0.001
max_t: 1.0
time_shift_settings:
do_time_shift: true
shift_type: exponential
time_shift_power: 1.0
dynamic_shift: true
shift_mu_strategy: flux2_empirical
shift_mu_num_steps: 4
patch_size: [1, 1]

training:
method: dopsd
max_train_iters: 4000
gradient_accumulation_iters: 1
gradient_checkpointing: true
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_iters: 10
save_every_iters: 200
save_total_limit: 20
dopsd:
num_training_steps: 4
ema_decay: 0.9999
ema_decay_warmup: 0.999
ema_decay_warmup_iters: 200
# early-step loss weighting: t0, t1, t2, t3
step_loss_weights: [8.0, 4.0, 2.0, 1.0]
# teacher edit branch only; student uses dataset prompt (T2I)
edit_sys_prompt: "The output must be exactly the same as the reference image."
# teacher prompt: false = only edit_sys_prompt (reference image carries content)
# true = "{dataset_prompt} {edit_sys_prompt}" (original D-OPSD default)
teacher_use_dataset_prompt: false
# save a student/teacher x0 trajectory grid during training (not during inference)
trajectory_every_iters: ${training.save_every_iters}
lora:
rank: 64
alpha: 128
target_modules:
- to_q
- to_k
- to_v
- to_out.0
- add_q_proj
- add_k_proj
- add_v_proj
- to_add_out
- to_qkv_mlp_proj
optimizer:
learning_rate: 0.00005
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.0
adam_epsilon: 0.00000001
output_dir: ./output_train/dopsd_editcontext_ema0.9999_onpolicy_4steptrain_9b_corgi_bsz1_lora_lr5e-5_earlyw

inference:
method: image_infer
negative_prompt: ""
default_width: 1024
default_height: 1024
num_inference_steps: 4
enable_cfg: false
cfg_guidance_scale: 1.0
seed: 42
output_dir: ./output_infer/dopsd_editcontext_ema0.9999_onpolicy_4steptrain_9b_corgi_bsz1_lora_lr5e-5_earlyw
infer_every_iters: ${training.save_every_iters}

logging:
rank_zero_only: true
train_log_every_iters: 10
infer_log_every_steps: 10

resume:
auto_resume: true
18 changes: 9 additions & 9 deletions lightx2v_train/configs/lora/flux2_klein_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ data:
shuffle: true
# examples: https://github.com/ModelTC/LightX2V_train_data_examples
data_path:
- /data/nvme1/yongyang/kkk/LightX2V_train_data_examples/dataset_v1/train.jsonl
- /data/nvme7/LightX2V_train_data_examples/dataset_v1/train.jsonl
val:
name: image_dataset
num_workers: 8
shuffle: false
data_path:
- /data/nvme1/yongyang/kkk/LightX2V_train_data_examples/dataset_v1/val.jsonl
- /data/nvme7/LightX2V_train_data_examples/dataset_v1/val.jsonl

scheduler:
num_train_timesteps: 1000
Expand All @@ -35,23 +35,23 @@ scheduler:
time_shift_power: 1.0
dynamic_shift: true
shift_mu_strategy: flux2_empirical
shift_mu_num_steps: 50
shift_mu_num_steps: 4
# Flux2 latents are already 2x2-patchified before scheduler shift length calculation.
patch_size: [1, 1] # [H, W]

training:
method: lora
max_train_iters: 3000
gradient_accumulation_iters: 1
gradient_checkpointing: true
gradient_checkpointing: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_iters: 10
save_every_iters: 100
save_every_iters: 200
save_total_limit: 10
lora:
rank: 16
alpha: 16
alpha: 32
target_modules:
- to_q
- to_k
Expand All @@ -75,9 +75,9 @@ inference:
negative_prompt: ""
default_width: 1024
default_height: 1024
num_inference_steps: 50
enable_cfg: true
cfg_guidance_scale: 4.0
num_inference_steps: 4
enable_cfg: false
cfg_guidance_scale: 1.0
seed: 42
output_dir: ./output_infer/flux2_klein_lora
infer_every_iters: ${training.save_every_iters}
Expand Down
47 changes: 47 additions & 0 deletions lightx2v_train/lightx2v_train/infer/dopsd_trajectory_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from PIL import Image, ImageDraw, ImageFont


def _fit_height(image, height):
if image.height == height:
return image
width = max(1, int(image.width * height / image.height))
return image.resize((width, height), Image.Resampling.LANCZOS)


def save_student_teacher_trajectory_grid(student_step_images, teacher_step_images, save_path):
    """Save a two-column grid comparing student and teacher images step by step.

    Each row is one step, labelled ``t0``, ``t1``, ...; the student image is
    pasted in the left column and the teacher image in the right column. Every
    image is converted to RGB and scaled to the tallest input height. Column
    widths are taken from the widest image in each column, so rows with
    differing aspect ratios are neither clipped nor misaligned.

    Args:
        student_step_images: Sequence of images, one per step.
        teacher_step_images: Sequence of images with the same length.
        save_path: Destination path; converted with ``str()`` before saving.

    Raises:
        ValueError: If the two sequences have different lengths.

    Returns:
        None. Nothing is drawn or written when the sequences are empty.
    """
    num_steps = len(student_step_images)
    if num_steps != len(teacher_step_images):
        raise ValueError("student and teacher trajectory lengths must match")
    if num_steps == 0:
        # Bail out before touching fonts or allocating a canvas.
        return

    pad = 12
    header_h = 32
    row_label_w = 56
    font = ImageFont.load_default()

    row_h = max(
        img.height
        for images in (student_step_images, teacher_step_images)
        for img in images
    )
    student_cols = [_fit_height(img.convert("RGB"), row_h) for img in student_step_images]
    teacher_cols = [_fit_height(img.convert("RGB"), row_h) for img in teacher_step_images]

    # Size each column by its widest member rather than by the first row only.
    student_col_w = max(img.width for img in student_cols)
    teacher_col_w = max(img.width for img in teacher_cols)
    x_student = row_label_w + pad
    x_teacher = x_student + student_col_w + pad

    canvas_w = x_teacher + teacher_col_w + pad
    canvas_h = header_h + num_steps * (row_h + pad) + pad
    canvas = Image.new("RGB", (canvas_w, canvas_h), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)

    header_y = 8
    draw.text((x_student + 8, header_y), "Student", fill=(0, 0, 0), font=font)
    draw.text((x_teacher + 8, header_y), "Teacher", fill=(0, 0, 0), font=font)

    for step_idx, (student_img, teacher_img) in enumerate(zip(student_cols, teacher_cols)):
        y = header_h + step_idx * (row_h + pad)
        # Row label sits roughly at the vertical middle of the row.
        draw.text((8, y + (row_h - 10) // 2), f"t{step_idx}", fill=(0, 0, 0), font=font)
        canvas.paste(student_img, (x_student, y))
        canvas.paste(teacher_img, (x_teacher, y))

    canvas.save(str(save_path))
85 changes: 74 additions & 11 deletions lightx2v_train/lightx2v_train/model_zoo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,64 @@ def add_lora(self, rank, alpha, target_modules):
)
self.denoiser_module().add_adapter(lora_config)

def add_dual_lora(
    self,
    rank,
    alpha,
    target_modules,
    student_adapter="student",
    teacher_adapter="teacher",
    init_teacher_from_student=True,
):
    """Attach a student and a teacher LoRA adapter, built from one config, to the denoiser.

    All denoiser parameters get ``requires_grad=False`` first; then the student
    and teacher adapters are added in that order and the student adapter is
    activated.

    Args:
        rank: LoRA rank, forwarded as ``r``.
        alpha: LoRA scaling, forwarded as ``lora_alpha``.
        target_modules: Module names the adapters are applied to.
        student_adapter: Name of the adapter that is activated.
        teacher_adapter: Name of the second adapter.
        init_teacher_from_student: When true, the student adapter weights are
            copied onto the teacher adapter once both exist.
    """
    shared_config = LoraConfig(
        target_modules=target_modules,
        init_lora_weights="gaussian",
        lora_alpha=alpha,
        r=rank,
    )
    module = self.denoiser_module()
    module.requires_grad_(False)
    for adapter in (student_adapter, teacher_adapter):
        module.add_adapter(shared_config, adapter_name=adapter)
    module.set_adapter(student_adapter)
    if not init_teacher_from_student:
        return
    self.copy_lora_adapter_weights(student_adapter, teacher_adapter)

@torch.no_grad()
def copy_lora_adapter_weights(self, src_adapter, dst_adapter):
    """Copy each ``src_adapter`` parameter onto its ``dst_adapter`` counterpart.

    A parameter belongs to an adapter when the adapter name appears as a whole
    dotted component of the parameter name (e.g. ``...lora_A.student.weight``).
    Matching whole components, rather than raw substrings, keeps the lookup
    correct when the adapter name also occurs inside another component such as
    ``lora_A`` or inside the other adapter's name.

    Args:
        src_adapter: Adapter name to read weights from.
        dst_adapter: Adapter name to overwrite.

    Source parameters without a destination counterpart are skipped.
    """
    named_params = dict(self.denoiser_module().named_parameters())
    src_token = f".{src_adapter}."
    dst_token = f".{dst_adapter}."
    for name, param in named_params.items():
        # Pad with dots so a leading or trailing component matches as well.
        dotted = f".{name}."
        if src_token not in dotted:
            continue
        dst_name = dotted.replace(src_token, dst_token)[1:-1]
        dst_param = named_params.get(dst_name)
        if dst_param is not None:
            dst_param.data.copy_(param.data)

def set_active_adapter(self, adapter_name):
    """Make ``adapter_name`` the active adapter on the denoiser module."""
    denoiser = self.denoiser_module()
    denoiser.set_adapter(adapter_name)

def set_dual_lora_trainable(self, student_adapter="student", teacher_adapter="teacher"):
    """Put the denoiser in train mode with only the student LoRA weights trainable.

    A parameter is trainable when its name contains ``"lora"`` and has
    ``student_adapter`` as a whole dotted component. Matching the whole
    component keeps an adapter whose name merely starts with the student name
    (e.g. ``default`` vs ``default_ema``) frozen.

    Args:
        student_adapter: Name of the adapter whose LoRA weights are trained.
        teacher_adapter: Accepted for interface symmetry; not read here. The
            teacher's parameters stay frozen with everything else.
    """
    denoiser = self.denoiser_module()
    denoiser.requires_grad_(False)
    denoiser.train()
    student_token = f".{student_adapter}."
    for name, param in denoiser.named_parameters():
        # Pad with dots so a leading or trailing component matches as well.
        param.requires_grad = "lora" in name and student_token in f".{name}."

@torch.no_grad()
def ema_update_lora_adapter(self, src_adapter="student", dst_adapter="teacher", ema_decay=0.999):
    """Move ``dst_adapter`` weights toward ``src_adapter`` by an exponential moving average.

    Each destination parameter is updated in place as
    ``dst = ema_decay * dst + (1 - ema_decay) * src``.

    A parameter belongs to an adapter when the adapter name appears as a whole
    dotted component of the parameter name. Matching whole components, rather
    than raw substrings, keeps the update from silently becoming a no-op when
    the adapter name also occurs inside another component such as ``lora_A``.

    Args:
        src_adapter: Adapter name providing the new values.
        dst_adapter: Adapter name holding the running average.
        ema_decay: Weight kept from the current destination value.

    Source parameters without a destination counterpart are skipped.
    """
    named_params = dict(self.denoiser_module().named_parameters())
    src_token = f".{src_adapter}."
    dst_token = f".{dst_adapter}."
    for name, src_param in named_params.items():
        # Pad with dots so a leading or trailing component matches as well.
        dotted = f".{name}."
        if src_token not in dotted:
            continue
        dst_param = named_params.get(dotted.replace(src_token, dst_token)[1:-1])
        if dst_param is None:
            continue
        dst_param.data.mul_(ema_decay).add_(src_param.data, alpha=1.0 - ema_decay)

def set_lora_trainable(self):
denoiser = self.denoiser_module()
denoiser.requires_grad_(False)
Expand Down Expand Up @@ -130,43 +188,48 @@ def unload_lora_for_infer(self):
self.denoiser_module().delete_adapters(adapter_name)
self._infer_lora_adapter_name = None

# NOTE(review): this span is a diff hunk whose +/- markers and indentation were
# lost. Where two near-identical lines follow each other, the first appears to
# be the pre-change line and the second the post-change line; confirm against
# the real file before editing.
def save_lora_weights(self, save_dir):
peft_state_dict = self._get_lora_state_dict_for_save()
def save_lora_weights(self, save_dir, adapter_name=None, weights_subdir=None):
# The state dict is fetched before the main-process guard below, so the fetch
# runs on every process; only the main process continues to the write.
peft_state_dict = self._get_lora_state_dict_for_save(adapter_name=adapter_name)
if not is_main_process():
return

# Weights go to save_dir, or to save_dir/weights_subdir when a subdir is given.
output_dir = os.path.join(save_dir, weights_subdir) if weights_subdir else save_dir
os.makedirs(output_dir, exist_ok=True)
lora_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
# Use the pipeline class's own saver when it has one; otherwise write the
# safetensors file directly.
if hasattr(self.pipeline_cls, "save_lora_weights"):
self.pipeline_cls.save_lora_weights(save_dir, lora_state_dict, safe_serialization=True)
self.pipeline_cls.save_lora_weights(output_dir, lora_state_dict, safe_serialization=True)
else:
save_file(lora_state_dict, f"{save_dir}/pytorch_lora_weights.safetensors")
save_file(lora_state_dict, os.path.join(output_dir, "pytorch_lora_weights.safetensors"))

# NOTE(review): this span is a diff hunk whose +/- markers and indentation were
# lost. Where two near-identical lines follow each other, the first appears to
# be the pre-change line and the second the post-change line; confirm against
# the real file before editing.
def _get_lora_state_dict_for_save(self):
def _get_lora_state_dict_for_save(self, adapter_name=None):
denoiser = self.denoiser_module()
# Only forward adapter_name when the caller supplied one.
peft_kwargs = {} if adapter_name is None else {"adapter_name": adapter_name}
# Non-FSDP2 module: read the PEFT state dict straight from the module.
if not is_fsdp2_module(denoiser):
return get_peft_model_state_dict(denoiser)
return get_peft_model_state_dict(denoiser, **peft_kwargs)

# FSDP2 module: request a full state dict with CPU offload first.
options = StateDictOptions(
full_state_dict=True,
cpu_offload=True,
ignore_frozen_params=True,
ignore_frozen_params=False,
strict=False,
)
# get_state_dict runs before the main-process guard, so it runs on every
# process; non-main processes then return an empty dict.
state_dict, _ = get_state_dict(denoiser, (), options=options)
if not is_main_process():
return {}
return get_peft_model_state_dict(denoiser, state_dict=state_dict)
return get_peft_model_state_dict(denoiser, state_dict=state_dict, **peft_kwargs)

# NOTE(review): this span is a diff hunk whose +/- markers and indentation were
# lost. Where two near-identical lines follow each other, the first appears to
# be the pre-change line and the second the post-change line; confirm against
# the real file before editing.
def load_lora_weights_for_resume(self, lora_path):
raw = load_file(os.path.join(lora_path, "pytorch_lora_weights.safetensors"))
def load_lora_weights_for_resume(self, lora_path, adapter_name=None, weights_subdir=None):
# Weights are read from lora_path, or lora_path/weights_subdir when given.
weights_dir = os.path.join(lora_path, weights_subdir) if weights_subdir else lora_path
raw = load_file(os.path.join(weights_dir, "pytorch_lora_weights.safetensors"))
# Rename saved keys: drop a leading "transformer." and rewrite
# ".lora.down.weight" / ".lora.up.weight" as ".lora_A.weight" / ".lora_B.weight".
peft_state_dict = {}
for key, value in raw.items():
new_key = key.removeprefix("transformer.")
new_key = new_key.replace(".lora.down.weight", ".lora_A.weight")
new_key = new_key.replace(".lora.up.weight", ".lora_B.weight")
peft_state_dict[new_key] = value

incompatible = set_peft_model_state_dict(self.denoiser_module(), peft_state_dict)
# Only forward adapter_name when the caller supplied one.
load_kwargs = {} if adapter_name is None else {"adapter_name": adapter_name}
incompatible = set_peft_model_state_dict(self.denoiser_module(), peft_state_dict, **load_kwargs)
# Unexpected keys are logged as a warning rather than raised.
if incompatible and incompatible.unexpected_keys:
logger.warning("Unexpected keys when resuming LoRA: {}", incompatible.unexpected_keys)

Expand Down
Loading
Loading