Skip to content

Diffusion models#54

Open
mtlaiu wants to merge 12 commits into
ORNL:mainfrom
mtlaiu:diffusion
Open

Diffusion models#54
mtlaiu wants to merge 12 commits into
ORNL:mainfrom
mtlaiu:diffusion

Conversation

@mtlaiu
Copy link
Copy Markdown
Collaborator

@mtlaiu mtlaiu commented May 20, 2026

  • Added EDM-based diffusion model option to avit, svit, vit, and turbt.
  • Added config yaml files for training conditional diffusion models on the miniweather dataset
  • Added scripts for training diffusion models and generating samples from trained diffusion models.

To be tested:

  • Backward compatibility with the diffusion or cond_diffusion options turned off
  • Effect of the persample_normalize flag
  • Applicability to other relevant datasets

@mtlaiu mtlaiu requested a review from pzhanggit May 20, 2026 20:28
@mtlaiu mtlaiu self-assigned this May 20, 2026
@mtlaiu
Copy link
Copy Markdown
Collaborator Author

mtlaiu commented May 20, 2026

Attach some generated results. These results are from models that are trained very briefly.

AViT results:
grid_visualization_batch_6_stat

SVit results:
grid_visualization_batch_6_stat

Vit results:
grid_visualization_batch_6_stat

TurbT results:
grid_visualization_batch_6_stat

Copy link
Copy Markdown
Collaborator

@pzhanggit pzhanggit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mtlaiu thanks! It looks good. I have a few general comments and clairification questions.

Comment thread examples/config/Demo_MW_diffusion_avit.yaml Outdated
Comment thread examples/config/Demo_MW_diffusion_svit.yaml Outdated
Comment thread examples/config/Demo_MW_diffusion_TT.yaml Outdated
Comment thread examples/config/Demo_MW_diffusion_vit.yaml Outdated
Comment thread examples/submit_batch_generate.sh
Comment thread matey/generate.py Outdated
Comment thread matey/train.py Outdated
Comment thread matey/train.py Outdated
Comment on lines +740 to +750
if getattr(self.params, "diffusion", False):
logs = {'valid_EDMloss': torch.zeros(1).to(self.device),
'valid_rmse': torch.zeros(1).to(self.device),
'valid_nrmse': torch.zeros(1).to(self.device),
'valid_l1': torch.zeros(1).to(self.device),
'valid_ssim': torch.zeros(1).to(self.device)}
else:
logs = {'valid_rmse': torch.zeros(1).to(self.device),
'valid_nrmse': torch.zeros(1).to(self.device),
'valid_l1': torch.zeros(1).to(self.device),
'valid_ssim': torch.zeros(1).to(self.device)}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment thread matey/train.py
Comment thread matey/generate.py Outdated

x_next = init_inp * t_steps[0]

for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we wrap this denoising loop into a standalone function and move it to diffusion_model.py?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a sampler class in generate.py so that it's easier to test other samplers if needed. I did not put it in diffusion_model.py because the sampler does not need to be tied to the diffusion model choice.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then maybe put it somewhere as a utils function? I'm thinking this piece of code is reusable generally.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add a file for the sampler class. There are other samplers that can be included. I don't immediately see where these samplers could be used other than generating samples from score-based diffusion models though.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants