Skip to content

Add distributed intra-graph parallelism for GNN from unstructured mesh (graph splitting)#55

Open
pzhanggit wants to merge 2 commits into
ORNL:mainfrom
pzhanggit:dev-graphpartition
Open

Add distributed intra-graph parallelism for GNN from unstructured mesh (graph splitting)#55
pzhanggit wants to merge 2 commits into
ORNL:mainfrom
pzhanggit:dev-graphpartition

Conversation

@pzhanggit
Copy link
Copy Markdown
Collaborator

@pzhanggit pzhanggit commented May 20, 2026

Splits a single large graph across multiple GPUs so that graphs too large to fit on one device can be trained at scale. Each rank owns a contiguous subgraph (METIS or random partition) and communicates boundary node embeddings with neighbouring ranks via halo exchange (set to be optional) before each GNN message-passing layer.

  • Graph partitioning: METIS and random; partition computed once and saved as per-rank .pt shards
  • GhostInfo: static communication descriptor (send/recv indices, counts); stored alongside each shard and reused every forward pass
  • HaloExchange_sync: non-blocking all-to-all boundary exchange using torch.distributed.nn.functional.all_to_all_single
  • GraphhMLP_stem / GraphhMLP_output: updated to accept (x, batch, edge_index, ghost_info, comm) tuple and call HaloExchange_sync inside (set to be optional) the conv loop before each aggregation
  • @jqyin extended the ringx capabilities to handle the varying sequence lengths at different ranks inside each sequence parallelism group,

Tested on MeshGraphNets airfoil dataset (about 5000 nodes) with 4-split on Frontier. Checked the ghost sync function works as expected. The slightly slower convergence is probably due to the per-sample based operations. In the small test (toy test that does not really require graph partition and random split with 60~70% ghost node ratio) the runtime also goes up about 6x when 4-split and another 4x when 4-split + ghostsync.

Baseline

Train loss: 0.07630394399166107. Valid loss: 0.052530430257320404
Train loss: 0.03503976762294769. Valid loss: 0.016517285257577896
Train loss: 0.018825851380825043. Valid loss: 0.015362713485956192
Train loss: 0.01457530353218317. Valid loss: 0.012161660008132458
Train loss: 0.01393199898302555. Valid loss: 0.01502196118235588
Train loss: 0.010915150865912437. Valid loss: 0.010785265825688839
Train loss: 0.01096295565366745. Valid loss: 0.012610268779098988
Train loss: 0.009878791868686676. Valid loss: 0.012079717591404915
Train loss: 0.010431858710944653. Valid loss: 0.010026732459664345
Train loss: 0.009767167270183563. Valid loss: 0.009851831011474133
Train loss: 0.009984089061617851. Valid loss: 0.012373477220535278
Train loss: 0.008688466623425484. Valid loss: 0.007659319322556257
Train loss: 0.009710071608424187. Valid loss: 0.010486176237463951
Train loss: 0.009276134893298149. Valid loss: 0.007898691110312939
Train loss: 0.008403142914175987. Valid loss: 0.0070739202201366425
Train loss: 0.00783892534673214. Valid loss: 0.009797191247344017
Train loss: 0.008047484792768955. Valid loss: 0.0077331215143203735
Train loss: 0.00839265063405037. Valid loss: 0.00820506177842617
Train loss: 0.008258790709078312. Valid loss: 0.006615082733333111
Train loss: 0.007981297560036182. Valid loss: 0.007329833693802357

4-split without ghostsync

Train loss: 0.07002869993448257. Valid loss: 0.05199467018246651
Train loss: 0.03797893226146698. Valid loss: 0.021329935640096664
Train loss: 0.027292316779494286. Valid loss: 0.024242661893367767
Train loss: 0.021580709144473076. Valid loss: 0.01791365258395672
Train loss: 0.019156362861394882. Valid loss: 0.01984536275267601
Train loss: 0.014413462951779366. Valid loss: 0.01518427487462759
Train loss: 0.014195697382092476. Valid loss: 0.015402238816022873
Train loss: 0.012909915298223495. Valid loss: 0.016513165086507797
Train loss: 0.013953088782727718. Valid loss: 0.013102341443300247
Train loss: 0.012600645422935486. Valid loss: 0.012878330424427986
Train loss: 0.013121137395501137. Valid loss: 0.015406733378767967
Train loss: 0.011871391907334328. Valid loss: 0.01048492081463337
Train loss: 0.012379081919789314. Valid loss: 0.013197503983974457
Train loss: 0.011682969518005848. Valid loss: 0.010015368461608887
Train loss: 0.011048511601984501. Valid loss: 0.009631349705159664
Train loss: 0.010559621267020702. Valid loss: 0.013247327879071236
Train loss: 0.010678818449378014. Valid loss: 0.010132926516234875
Train loss: 0.011013306677341461. Valid loss: 0.01070511806756258
Train loss: 0.011161084286868572. Valid loss: 0.008446738123893738
Train loss: 0.010009515099227428. Valid loss: 0.00967256911098957

4-split with ghossync

Train loss: 0.06637820601463318. Valid loss: 0.042067885398864746
Train loss: 0.03029005602002144. Valid loss: 0.01729542389512062
Train loss: 0.02310142107307911. Valid loss: 0.02151530049741268
Train loss: 0.020736943930387497. Valid loss: 0.018721411004662514
Train loss: 0.02033010683953762. Valid loss: 0.02171349711716175
Train loss: 0.01594850979745388. Valid loss: 0.017006253823637962
Train loss: 0.01592416688799858. Valid loss: 0.018114691600203514
Train loss: 0.014570271596312523. Valid loss: 0.017547519877552986
Train loss: 0.015512395650148392. Valid loss: 0.014922119677066803
Train loss: 0.014379918575286865. Valid loss: 0.014739586971700191
Train loss: 0.014310665428638458. Valid loss: 0.016797099262475967
Train loss: 0.012336082756519318. Valid loss: 0.011441249400377274
Train loss: 0.01386372372508049. Valid loss: 0.014565189369022846
Train loss: 0.012717802077531815. Valid loss: 0.011043708771467209
Train loss: 0.011677439324557781. Valid loss: 0.010698249563574791
Train loss: 0.011000201106071472. Valid loss: 0.013033013790845871
Train loss: 0.011258008889853954. Valid loss: 0.010900972411036491
Train loss: 0.01149249542504549. Valid loss: 0.011500910855829716
Train loss: 0.011140567250549793. Valid loss: 0.008320937864482403
Train loss: 0.01032090000808239. Valid loss: 0.010040022432804108

@pzhanggit pzhanggit requested a review from TsChala May 20, 2026 20:58
@pzhanggit pzhanggit self-assigned this May 20, 2026
Copy link
Copy Markdown
Collaborator

@TsChala TsChala left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting! Overall it looks good, I found a few small typos/bugs. I also got a question regarding the norms, see the comments.

I think we should update the inference script accordingly as well, to have ghost info in the opts etc.

if self.use_MPI:
if self.group_rank == 0:
self._run_partitioning()
_MPI.COMM_WORLD.Barrier()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a bug, remove _ before MPI?

Comment thread matey/data_utils/utils.py
Comment on lines +18 to +23
from torch.utils.data import Dataset
from torch_geometric.utils import coalesce
import torch.nn.functional as F
from collections import defaultdict
from typing import Optional, List, Dict, Tuple, NamedTuple
import warnings
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many of these are not used: Dataset, coalesce, Optional, Tuple, warnings

shard_path = rank_shard
else:
# fallback: full graph (single-rank mode)
shard_path = full_path
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe raise a warning/print out something when this happens for the user to be aware that it's falling back to single-rank mode.

Comment thread matey/data_utils/utils.py
maxdiff_ghost = 0.0

assert maxdiff_all<1e-6 and maxdiff_own<1e-6 and maxdiff_ghost<1e-6, (
f"Pei debugging ghost0 [rank {dist.get_rank(sequence_parallel_group)}] first halo maxdiff_all={maxdiff_all:.6e}, "
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sequence_parallel_group is not defined here. sequence_parallel_group --> comm

Comment on lines +407 to 410
h = HaloExchange_sync(h, ghost_info, comm)
h_in = h
h = conv(h, edge_index)
h = norm(h, batch)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand things correctly, if the HaloExchange_sync happens, then we will normalize based on the owned nodes and the ghost nodes as well. Is this the intended behavior? Wouldn't it make more sense to ignore the ghost nodes for normalization?

This also happens in the GraphhMLP_stem.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's intended. Strictly speaking, we probably should apply all the local operations on the owned nodes ONLY. But the implementation did not separate owned nodes and ghost nodes across all the operations, except for the sync here, for simplicity. I was expecting this would not cause much performance difference. What do you think?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. I think it's fine as well. I just wanted to make sure it's intended and things are not falling through cracks.

Comment on lines +574 to +576
leadtime = self.leadtime_max #max(self.leadtime_max//2, 1)])
#else:
# raise ValueError(f"Fix leadtime for now but got {self.leadtime_fixed}")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The autoregressive model forward is still not in main branch for graphs, so I'm not sure if this part make sense to be added now?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants