[Train] Qwen-image-edit lora training#1151
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for Qwen-Image-Edit-2511 LoRA training and inference, adding the QwenImageEditModel along with corresponding configurations and run scripts. It also updates the dataset and inference pipelines to support source-image-conditioned generation. The review feedback highlights critical issues in the inference pipeline: the dataset-wide has_source_condition flag incorrectly disables classifier-free guidance (CFG) for samples without source images in mixed datasets, which should be resolved by checking conditions per-sample and lazily initializing static_neg_cond. Additionally, caching should be implemented in _load_dummy_sample to prevent redundant disk I/O, and prepare_denoiser_input in QwenImageEditModel needs to handle the absence of source latents gracefully to avoid potential KeyError exceptions.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| self.enable_cfg = self.infer_config.get("enable_cfg", True) | ||
| has_source_condition = any(_has_source_images(sample) for sample in samples) | ||
| if self.enable_cfg: | ||
| self.guidance_scale = self.infer_config.get("cfg_guidance_scale", 4.0) | ||
| negative_prompt = self.infer_config.get("negative_prompt", " ") | ||
| neg_cond = self.model.encode_condition({"prompt": negative_prompt}) | ||
| static_neg_cond = None if has_source_condition else self.model.encode_condition({"prompt": negative_prompt}) | ||
| else: | ||
| self.guidance_scale = None | ||
| neg_cond = None | ||
| negative_prompt = None | ||
| static_neg_cond = None |
There was a problem hiding this comment.
In a mixed dataset where some samples have source images and others do not, has_source_condition will be True. This causes static_neg_cond to be initialized to None. Consequently, for any sample without source images, neg_cond will incorrectly resolve to None instead of the encoded negative prompt, breaking classifier-free guidance (CFG). Removing has_source_condition and initializing static_neg_cond to None allows us to lazily encode it only when needed.
| self.enable_cfg = self.infer_config.get("enable_cfg", True) | |
| has_source_condition = any(_has_source_images(sample) for sample in samples) | |
| if self.enable_cfg: | |
| self.guidance_scale = self.infer_config.get("cfg_guidance_scale", 4.0) | |
| negative_prompt = self.infer_config.get("negative_prompt", " ") | |
| neg_cond = self.model.encode_condition({"prompt": negative_prompt}) | |
| static_neg_cond = None if has_source_condition else self.model.encode_condition({"prompt": negative_prompt}) | |
| else: | |
| self.guidance_scale = None | |
| neg_cond = None | |
| negative_prompt = None | |
| static_neg_cond = None | |
| self.enable_cfg = self.infer_config.get("enable_cfg", True) | |
| if self.enable_cfg: | |
| self.guidance_scale = self.infer_config.get("cfg_guidance_scale", 4.0) | |
| negative_prompt = self.infer_config.get("negative_prompt", " ") | |
| static_neg_cond = None | |
| else: | |
| self.guidance_scale = None | |
| negative_prompt = None | |
| static_neg_cond = None |
| if self.enable_cfg: | ||
| if has_source_condition: | ||
| neg_sample = dict(infer_sample) | ||
| neg_sample["prompt"] = negative_prompt | ||
| neg_cond = self.model.encode_condition(neg_sample) | ||
| else: | ||
| neg_cond = static_neg_cond | ||
| else: | ||
| neg_cond = None |
There was a problem hiding this comment.
Instead of using the dataset-wide has_source_condition flag, check if the current infer_sample has source images using _has_source_images(infer_sample). If it does not, lazily initialize and use static_neg_cond. This ensures correct CFG conditioning for mixed datasets and avoids redundant encoding.
| if self.enable_cfg: | |
| if has_source_condition: | |
| neg_sample = dict(infer_sample) | |
| neg_sample["prompt"] = negative_prompt | |
| neg_cond = self.model.encode_condition(neg_sample) | |
| else: | |
| neg_cond = static_neg_cond | |
| else: | |
| neg_cond = None | |
| if self.enable_cfg: | |
| if _has_source_images(infer_sample): | |
| neg_sample = dict(infer_sample) | |
| neg_sample["prompt"] = negative_prompt | |
| neg_cond = self.model.encode_condition(neg_sample) | |
| else: | |
| if static_neg_cond is None: | |
| static_neg_cond = self.model.encode_condition({"prompt": negative_prompt}) | |
| neg_cond = static_neg_cond | |
| else: | |
| neg_cond = None |
| def _load_dummy_sample(self, samples): | ||
| for index, sample in enumerate(samples): | ||
| if _has_source_images(sample): | ||
| return self._load_infer_sample(index, " ") | ||
| return {"prompt": " "} |
There was a problem hiding this comment.
The _load_dummy_sample method is called repeatedly for every dummy slot when has_sample is False. Since it loads the actual images from disk and processes them every time, this causes redundant disk I/O and CPU/GPU overhead. Caching the dummy sample on the inferencer instance avoids this bottleneck.
| def _load_dummy_sample(self, samples): | |
| for index, sample in enumerate(samples): | |
| if _has_source_images(sample): | |
| return self._load_infer_sample(index, " ") | |
| return {"prompt": " "} | |
| def _load_dummy_sample(self, samples): | |
| if hasattr(self, "_dummy_sample"): | |
| return self._dummy_sample | |
| for index, sample in enumerate(samples): | |
| if _has_source_images(sample): | |
| self._dummy_sample = self._load_infer_sample(index, " ") | |
| return self._dummy_sample | |
| self._dummy_sample = {"prompt": " "} | |
| return self._dummy_sample |
| def prepare_denoiser_input(self, noisy_latent, condition=None): | ||
| if condition is None: | ||
| raise ValueError("QwenImageEditModel.prepare_denoiser_input requires condition.") | ||
|
|
||
| n = noisy_latent.shape[0] | ||
| h, w = noisy_latent.shape[3], noisy_latent.shape[4] | ||
| packed = QwenImageEditPlusPipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w) | ||
| hidden_states = torch.cat([packed, condition["source_latents"]], dim=1) | ||
| img_shapes = [[(1, h // 2, w // 2), *condition["source_img_shapes"]]] * n | ||
| return QwenImageEditDenoiserInput( | ||
| hidden_states=hidden_states, | ||
| target_token_length=packed.shape[1], | ||
| img_shapes=img_shapes, | ||
| height=h, | ||
| width=w, | ||
| ) |
There was a problem hiding this comment.
The prepare_denoiser_input method assumes condition always contains "source_latents" and "source_img_shapes". If a sample does not have source images (e.g., during dummy evaluation or mixed datasets), this will raise a KeyError. Handling the absence of source latents gracefully ensures robustness.
def prepare_denoiser_input(self, noisy_latent, condition=None):
if condition is None:
raise ValueError("QwenImageEditModel.prepare_denoiser_input requires condition.")
n = noisy_latent.shape[0]
h, w = noisy_latent.shape[3], noisy_latent.shape[4]
packed = QwenImageEditPlusPipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w)
if "source_latents" in condition:
hidden_states = torch.cat([packed, condition["source_latents"]], dim=1)
img_shapes = [[(1, h // 2, w // 2), *condition["source_img_shapes"]]] * n
else:
hidden_states = packed
img_shapes = [[(1, h // 2, w // 2)]] * n
return QwenImageEditDenoiserInput(
hidden_states=hidden_states,
target_token_length=packed.shape[1],
img_shapes=img_shapes,
height=h,
width=w,
)
No description provided.