Skip to content

[Train] Qwen-image-edit lora training#1151

Open
Musisoul wants to merge 3 commits into
mainfrom
edit_lora_train
Open

[Train] Qwen-image-edit lora training#1151
Musisoul wants to merge 3 commits into
mainfrom
edit_lora_train

Conversation

@Musisoul

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for Qwen-Image-Edit-2511 LoRA training and inference, adding the QwenImageEditModel along with corresponding configurations and run scripts. It also updates the dataset and inference pipelines to support source-image-conditioned generation. The review feedback highlights critical issues in the inference pipeline: the dataset-wide has_source_condition flag incorrectly disables classifier-free guidance (CFG) for samples without source images in mixed datasets, which should be resolved by checking conditions per-sample and lazily initializing static_neg_cond. Additionally, caching should be implemented in _load_dummy_sample to prevent redundant disk I/O, and prepare_denoiser_input in QwenImageEditModel needs to handle the absence of source latents gracefully to avoid potential KeyError exceptions.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines 61 to +70
self.enable_cfg = self.infer_config.get("enable_cfg", True)
has_source_condition = any(_has_source_images(sample) for sample in samples)
if self.enable_cfg:
self.guidance_scale = self.infer_config.get("cfg_guidance_scale", 4.0)
negative_prompt = self.infer_config.get("negative_prompt", " ")
neg_cond = self.model.encode_condition({"prompt": negative_prompt})
static_neg_cond = None if has_source_condition else self.model.encode_condition({"prompt": negative_prompt})
else:
self.guidance_scale = None
neg_cond = None
negative_prompt = None
static_neg_cond = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In a mixed dataset where some samples have source images and others do not, has_source_condition will be True. This causes static_neg_cond to be initialized to None. Consequently, for any sample without source images, neg_cond will incorrectly resolve to None instead of the encoded negative prompt, breaking classifier-free guidance (CFG). Removing has_source_condition and initializing static_neg_cond to None allows us to lazily encode it only when needed.

Suggested change
self.enable_cfg = self.infer_config.get("enable_cfg", True)
has_source_condition = any(_has_source_images(sample) for sample in samples)
if self.enable_cfg:
self.guidance_scale = self.infer_config.get("cfg_guidance_scale", 4.0)
negative_prompt = self.infer_config.get("negative_prompt", " ")
neg_cond = self.model.encode_condition({"prompt": negative_prompt})
static_neg_cond = None if has_source_condition else self.model.encode_condition({"prompt": negative_prompt})
else:
self.guidance_scale = None
neg_cond = None
negative_prompt = None
static_neg_cond = None
self.enable_cfg = self.infer_config.get("enable_cfg", True)
if self.enable_cfg:
self.guidance_scale = self.infer_config.get("cfg_guidance_scale", 4.0)
negative_prompt = self.infer_config.get("negative_prompt", " ")
static_neg_cond = None
else:
self.guidance_scale = None
negative_prompt = None
static_neg_cond = None

Comment on lines +88 to +96
if self.enable_cfg:
if has_source_condition:
neg_sample = dict(infer_sample)
neg_sample["prompt"] = negative_prompt
neg_cond = self.model.encode_condition(neg_sample)
else:
neg_cond = static_neg_cond
else:
neg_cond = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Instead of using the dataset-wide has_source_condition flag, check if the current infer_sample has source images using _has_source_images(infer_sample). If it does not, lazily initialize and use static_neg_cond. This ensures correct CFG conditioning for mixed datasets and avoids redundant encoding.

Suggested change
if self.enable_cfg:
if has_source_condition:
neg_sample = dict(infer_sample)
neg_sample["prompt"] = negative_prompt
neg_cond = self.model.encode_condition(neg_sample)
else:
neg_cond = static_neg_cond
else:
neg_cond = None
if self.enable_cfg:
if _has_source_images(infer_sample):
neg_sample = dict(infer_sample)
neg_sample["prompt"] = negative_prompt
neg_cond = self.model.encode_condition(neg_sample)
else:
if static_neg_cond is None:
static_neg_cond = self.model.encode_condition({"prompt": negative_prompt})
neg_cond = static_neg_cond
else:
neg_cond = None

Comment on lines +34 to +38
def _load_dummy_sample(self, samples):
for index, sample in enumerate(samples):
if _has_source_images(sample):
return self._load_infer_sample(index, " ")
return {"prompt": " "}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _load_dummy_sample method is called repeatedly for every dummy slot when has_sample is False. Since it loads the actual images from disk and processes them every time, this causes redundant disk I/O and CPU/GPU overhead. Caching the dummy sample on the inferencer instance avoids this bottleneck.

Suggested change
def _load_dummy_sample(self, samples):
for index, sample in enumerate(samples):
if _has_source_images(sample):
return self._load_infer_sample(index, " ")
return {"prompt": " "}
def _load_dummy_sample(self, samples):
if hasattr(self, "_dummy_sample"):
return self._dummy_sample
for index, sample in enumerate(samples):
if _has_source_images(sample):
self._dummy_sample = self._load_infer_sample(index, " ")
return self._dummy_sample
self._dummy_sample = {"prompt": " "}
return self._dummy_sample

Comment on lines +161 to +176
def prepare_denoiser_input(self, noisy_latent, condition=None):
if condition is None:
raise ValueError("QwenImageEditModel.prepare_denoiser_input requires condition.")

n = noisy_latent.shape[0]
h, w = noisy_latent.shape[3], noisy_latent.shape[4]
packed = QwenImageEditPlusPipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w)
hidden_states = torch.cat([packed, condition["source_latents"]], dim=1)
img_shapes = [[(1, h // 2, w // 2), *condition["source_img_shapes"]]] * n
return QwenImageEditDenoiserInput(
hidden_states=hidden_states,
target_token_length=packed.shape[1],
img_shapes=img_shapes,
height=h,
width=w,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The prepare_denoiser_input method assumes condition always contains "source_latents" and "source_img_shapes". If a sample does not have source images (e.g., during dummy evaluation or mixed datasets), this will raise a KeyError. Handling the absence of source latents gracefully ensures robustness.

    def prepare_denoiser_input(self, noisy_latent, condition=None):
        if condition is None:
            raise ValueError("QwenImageEditModel.prepare_denoiser_input requires condition.")

        n = noisy_latent.shape[0]
        h, w = noisy_latent.shape[3], noisy_latent.shape[4]
        packed = QwenImageEditPlusPipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w)
        if "source_latents" in condition:
            hidden_states = torch.cat([packed, condition["source_latents"]], dim=1)
            img_shapes = [[(1, h // 2, w // 2), *condition["source_img_shapes"]]] * n
        else:
            hidden_states = packed
            img_shapes = [[(1, h // 2, w // 2)]] * n
        return QwenImageEditDenoiserInput(
            hidden_states=hidden_states,
            target_token_length=packed.shape[1],
            img_shapes=img_shapes,
            height=h,
            width=w,
        )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant