From 39a5862ed5d6a8d8b4ff61cd9e614cb374e42732 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:35 -0500 Subject: [PATCH 01/95] added basic GRIT code Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 57 ++++++ gridfm_graphkit/models/grit_transformer.py | 195 +++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 examples/config/grit_pretraining.yaml create mode 100644 gridfm_graphkit/models/grit_transformer.py diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml new file mode 100644 index 00000000..a6566e09 --- /dev/null +++ b/examples/config/grit_pretraining.yaml @@ -0,0 +1,57 @@ +callbacks: + patience: 100 + tol: 0 +data: + baseMVA: 100 + learn_mask: false + mask_dim: 6 + mask_ratio: 0.5 + mask_type: rnd + mask_value: -1.0 + networks: + # - Texas2k_case1_2016summerpeak + - case24_ieee_rts + # - case118_ieee + # - case300_ieee + - case89_pegase + # - case240_pserc + normalization: baseMVAnorm + scenarios: + # - 5000 + - 5000 + - 5000 + # - 30000 + # - 50000 + # - 50000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 +model: + attention_head: 8 + dropout: 0.1 + edge_dim: 2 + hidden_size: 123 + input_dim: 9 + num_layers: 14 + output_dim: 6 + pe_dim: 20 + type: GPSTransformer # +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.01 + - 0.99 + losses: + - MaskedMSE + - PBE + accelerator: auto + devices: auto + strategy: auto diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py new file mode 100644 index 00000000..3ee5e8ec --- /dev/null +++ b/gridfm_graphkit/models/grit_transformer.py @@ -0,0 +1,195 @@ +from gridfm_graphkit.io.registries import MODELS_REGISTRY +from torch import nn +import torch +import 
torch_geometric.graphgym.register as register +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.models.gnn import GNNPreMP +from torch_geometric.graphgym.models.layer import (new_layer_config, + BatchNorm1dNode) +from torch_geometric.graphgym.register import register_network +from torch_geometric.graphgym.models.layer import new_layer_config, MLP + + + +class FeatureEncoder(torch.nn.Module): + """ + Encoding node and edge features + + Args: + dim_in (int): Input feature dimension + """ + def __init__(self, dim_in): + super(FeatureEncoder, self).__init__() + self.dim_in = dim_in + if cfg.dataset.node_encoder: + # Encode integer node features via nn.Embeddings + NodeEncoder = register.node_encoder_dict[ + cfg.dataset.node_encoder_name] + self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) + if cfg.dataset.node_encoder_bn: + self.node_encoder_bn = BatchNorm1dNode( + new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, + has_bias=False, cfg=cfg)) + # Update dim_in to reflect the new dimension fo the node features + self.dim_in = cfg.gnn.dim_inner + if cfg.dataset.edge_encoder: + # Hard-limit max edge dim for PNA. + if 'PNA' in cfg.gt.layer_type: + cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner) + else: + cfg.gnn.dim_edge = cfg.gnn.dim_inner + # Encode integer edge features via nn.Embeddings + EdgeEncoder = register.edge_encoder_dict[ + cfg.dataset.edge_encoder_name] + self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge) + if cfg.dataset.edge_encoder_bn: + self.edge_encoder_bn = BatchNorm1dNode( + new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False, + has_bias=False, cfg=cfg)) + + def forward(self, batch): + for module in self.children(): + batch = module(batch) + return batch + + +@register_head('decoder_head') +class GNNDecoderHead(nn.Module): + """ + Predictoin head for encoder-decoder networks. + + Args: + dim_in (int): Input dimension # TODO update arg comments as needed + dim_out (int): Output dimension. 
For binary prediction, dim_out=1. + """ + + def __init__(self, dim_in, dim_out): + super(GNNDecoderHead, self).__init__() + + + + # note that the input and output dimensions are from the config file + # if we want this to be variable that will have to change with + # each layer + + # TODO consider use of a bottleneck + + # note the config is imported as in other modules + + # the number of config layers should apriori be different than the encoder + + + global_model_type = cfg.gt.get('layer_type', "GritTransformer") + + TransformerLayer = register.layer_dict.get(global_model_type) + + layers = [] + for l in range(cfg.gnn.layers_decode): + layers.append(TransformerLayer( + in_dim=cfg.gt.dim_hidden, + out_dim=cfg.gt.dim_hidden, + num_heads=cfg.gt.n_heads, + dropout=cfg.gt.dropout, # TODO could migrate this and others to gnn in config + act=cfg.gnn.act, + attn_dropout=cfg.gt.attn_dropout, + layer_norm=cfg.gt.layer_norm, + batch_norm=cfg.gt.batch_norm, + residual=True, + norm_e=cfg.gt.attn.norm_e, + O_e=cfg.gt.attn.O_e, + cfg=cfg.gt, + )) + # layers = [] + + self.layers = torch.nn.Sequential(*layers) + + + + self.layer_post_mp = MLP( + new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, + has_act=False, has_bias=True, cfg=cfg)) + + + + def _apply_index(self, batch): + return batch.x, batch.y + + def forward(self, batch): + batch = self.layers(batch) + + # follow GMAE here and make a final linear projection from the + # hiden dimension to the output dimension + batch = self.layer_post_mp(batch) + + pred, label = self._apply_index(batch) + #print('>>>>>>', pred.size(),label.size()) + return pred, label + + + +@MODELS_REGISTRY.register("GRIT") +class GritTransformer(torch.nn.Module): + ''' + The proposed GritTransformer (Graph Inductive Bias Transformer) + ''' + + def __init__(self, dim_in, dim_out): + super().__init__() + self.encoder = FeatureEncoder(dim_in) + dim_in = self.encoder.dim_in + + self.ablation = True + self.ablation = False + + if 
cfg.posenc_RRWP.enable: + self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ + (cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner) + rel_pe_dim = cfg.posenc_RRWP.ksteps + self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ + (rel_pe_dim, cfg.gnn.dim_edge, + pad_to_full_graph=cfg.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0. + ) + + + if cfg.gnn.layers_pre_mp > 0: + self.pre_mp = GNNPreMP( + dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) + dim_in = cfg.gnn.dim_inner + + assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ + "The inner and hidden dims must match." + + global_model_type = cfg.gt.get('layer_type', "GritTransformer") + # global_model_type = "GritTransformer" + + TransformerLayer = register.layer_dict.get(global_model_type) + + layers = [] + for l in range(cfg.gt.layers): + layers.append(TransformerLayer( + in_dim=cfg.gt.dim_hidden, + out_dim=cfg.gt.dim_hidden, + num_heads=cfg.gt.n_heads, + dropout=cfg.gt.dropout, + act=cfg.gnn.act, + attn_dropout=cfg.gt.attn_dropout, + layer_norm=cfg.gt.layer_norm, + batch_norm=cfg.gt.batch_norm, + residual=True, + norm_e=cfg.gt.attn.norm_e, + O_e=cfg.gt.attn.O_e, + cfg=cfg.gt, + )) + # layers = [] + + self.layers = torch.nn.Sequential(*layers) + GNNHead = register.head_dict[cfg.gnn.head] + self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) + + def forward(self, batch): + for module in self.children(): + batch = module(batch) + + return batch \ No newline at end of file From 922d6cefe51a19d31cff2b3cda09603e2600a349 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 02/95] initial connection of model to config Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 35 +++- gridfm_graphkit/models/grit_transformer.py | 202 ++++++++------------- 2 files changed, 113 insertions(+), 124 
deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index a6566e09..904d6dc4 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -32,10 +32,41 @@ model: edge_dim: 2 hidden_size: 123 input_dim: 9 - num_layers: 14 + num_layers: 10 output_dim: 6 pe_dim: 20 - type: GPSTransformer # + type: GRIT #GPSTransformer # + layers_pre_mp: 0 + act: relu + encoder: + node_encoder: True + edge_encoder: True + node_encoder_name: TODO + node_encoder_bn: True + gt: + layer_type: GritTransformer + # layers: 10 + # n_heads: 8 + dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + # dropout: 0.0 + layer_norm: False + batch_norm: True + update_e: True + attn_dropout: 0.2 + attn: + clamp: 5. + act: 'relu' + full_attn: True + edge_enhance: True + O_e: True + norm_e: True + signed_sqrt: True + posenc_RRWP: + enable: True + ksteps: 21 + add_identity: True + add_node_attr: False + add_inverse: False optimizer: beta1: 0.9 beta2: 0.999 diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 3ee5e8ec..b09f5272 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,12 +1,10 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY from torch import nn import torch -import torch_geometric.graphgym.register as register -from torch_geometric.graphgym.config import cfg + from torch_geometric.graphgym.models.gnn import GNNPreMP from torch_geometric.graphgym.models.layer import (new_layer_config, BatchNorm1dNode) -from torch_geometric.graphgym.register import register_network from torch_geometric.graphgym.models.layer import new_layer_config, MLP @@ -17,114 +15,49 @@ class FeatureEncoder(torch.nn.Module): Args: dim_in (int): Input feature dimension + + + TODO replace 'register' with local version of it + """ - def __init__(self, dim_in): + def __init__( + self, + dim_in, + dim_inner, + args 
+ ): super(FeatureEncoder, self).__init__() self.dim_in = dim_in - if cfg.dataset.node_encoder: + if args.node_encoder: # Encode integer node features via nn.Embeddings NodeEncoder = register.node_encoder_dict[ - cfg.dataset.node_encoder_name] - self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) - if cfg.dataset.node_encoder_bn: + args.node_encoder_name] + self.node_encoder = NodeEncoder(dim_inner) + if args.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode( - new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, + new_layer_config(dim_inner, -1, -1, has_act=False, has_bias=False, cfg=cfg)) # Update dim_in to reflect the new dimension fo the node features - self.dim_in = cfg.gnn.dim_inner - if cfg.dataset.edge_encoder: + self.dim_in = dim_inner + if args.edge_encoder: # Hard-limit max edge dim for PNA. - if 'PNA' in cfg.gt.layer_type: - cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner) + if 'PNA' in args.model.gt.layer_type: # TODO remove condition if PNA not needed + dim_edge = min(128, dim_inner) else: - cfg.gnn.dim_edge = cfg.gnn.dim_inner + dim_edge = dim_inner # Encode integer edge features via nn.Embeddings EdgeEncoder = register.edge_encoder_dict[ cfg.dataset.edge_encoder_name] - self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge) + self.edge_encoder = EdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: self.edge_encoder_bn = BatchNorm1dNode( - new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False, + new_layer_config(dim_edge, -1, -1, has_act=False, has_bias=False, cfg=cfg)) def forward(self, batch): for module in self.children(): batch = module(batch) return batch - - -@register_head('decoder_head') -class GNNDecoderHead(nn.Module): - """ - Predictoin head for encoder-decoder networks. - - Args: - dim_in (int): Input dimension # TODO update arg comments as needed - dim_out (int): Output dimension. For binary prediction, dim_out=1. 
- """ - - def __init__(self, dim_in, dim_out): - super(GNNDecoderHead, self).__init__() - - - - # note that the input and output dimensions are from the config file - # if we want this to be variable that will have to change with - # each layer - - # TODO consider use of a bottleneck - - # note the config is imported as in other modules - - # the number of config layers should apriori be different than the encoder - - - global_model_type = cfg.gt.get('layer_type', "GritTransformer") - - TransformerLayer = register.layer_dict.get(global_model_type) - - layers = [] - for l in range(cfg.gnn.layers_decode): - layers.append(TransformerLayer( - in_dim=cfg.gt.dim_hidden, - out_dim=cfg.gt.dim_hidden, - num_heads=cfg.gt.n_heads, - dropout=cfg.gt.dropout, # TODO could migrate this and others to gnn in config - act=cfg.gnn.act, - attn_dropout=cfg.gt.attn_dropout, - layer_norm=cfg.gt.layer_norm, - batch_norm=cfg.gt.batch_norm, - residual=True, - norm_e=cfg.gt.attn.norm_e, - O_e=cfg.gt.attn.O_e, - cfg=cfg.gt, - )) - # layers = [] - - self.layers = torch.nn.Sequential(*layers) - - - - self.layer_post_mp = MLP( - new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, - has_act=False, has_bias=True, cfg=cfg)) - - - - def _apply_index(self, batch): - return batch.x, batch.y - - def forward(self, batch): - batch = self.layers(batch) - - # follow GMAE here and make a final linear projection from the - # hiden dimension to the output dimension - batch = self.layer_post_mp(batch) - - pred, label = self._apply_index(batch) - #print('>>>>>>', pred.size(),label.size()) - return pred, label - @MODELS_REGISTRY.register("GRIT") @@ -133,60 +66,85 @@ class GritTransformer(torch.nn.Module): The proposed GritTransformer (Graph Inductive Bias Transformer) ''' - def __init__(self, dim_in, dim_out): + def __init__(self, args): super().__init__() - self.encoder = FeatureEncoder(dim_in) - dim_in = self.encoder.dim_in - self.ablation = True - self.ablation = False + # ### TODO remove default args 
not needed #### + # self.input_dim = + # self.hidden_dim = + # self.output_dim = + # self.edge_dim = + # self.num_layers = args.model.num_layers + # self.heads = getattr(args.model, "attention_head", 1) + # self.dropout = getattr(args.model, "dropout", 0.0) + # ### ### + + dim_in = args.model.input_dim + dim_out = args.model.output_dim + dim_inner = args.model.hidden_size + dim_edge = args.model.edge_dim + num_heads = args.model.attention_head + dropout = args.model.dropout + num_layers = args.model.num_layers + + self.encoder = FeatureEncoder( + dim_in, + dim_inner, + args.model.encoder + ) # TODO add args + dim_in = self.encoder.dim_in + - if cfg.posenc_RRWP.enable: + if args.model.posenc_RRWP.enable: + # TODO connect 'register' to local version self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ - (cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner) - rel_pe_dim = cfg.posenc_RRWP.ksteps + (args.model.posenc_RRWP.ksteps, dim_inner) + rel_pe_dim = args.model.posenc_RRWP.ksteps self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ - (rel_pe_dim, cfg.gnn.dim_edge, - pad_to_full_graph=cfg.gt.attn.full_attn, + (rel_pe_dim, dim_edge, + pad_to_full_graph=args.model.gt.attn.full_attn, add_node_attr_as_self_loop=False, fill_value=0. ) - if cfg.gnn.layers_pre_mp > 0: + if args.model.layers_pre_mp > 0: self.pre_mp = GNNPreMP( - dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) - dim_in = cfg.gnn.dim_inner + dim_in, dim_inner, args.model.layers_pre_mp) + dim_in = dim_inner - assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ + assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." 
- global_model_type = cfg.gt.get('layer_type', "GritTransformer") + global_model_type = args.model.gt.layer_type # global_model_type = "GritTransformer" - + # TODO replace this with local register logic TransformerLayer = register.layer_dict.get(global_model_type) layers = [] - for l in range(cfg.gt.layers): + for ll in range(num_layers): layers.append(TransformerLayer( - in_dim=cfg.gt.dim_hidden, - out_dim=cfg.gt.dim_hidden, - num_heads=cfg.gt.n_heads, - dropout=cfg.gt.dropout, - act=cfg.gnn.act, - attn_dropout=cfg.gt.attn_dropout, - layer_norm=cfg.gt.layer_norm, - batch_norm=cfg.gt.batch_norm, + in_dim=args.model.gt.dim_hidden, + out_dim=args.model.gt.dim_hidden, + num_heads=num_heads, + dropout=dropout, + act=args.model.act, + attn_dropout=args.model.gt.attn_dropout, + layer_norm=args.model.gt.layer_norm, + batch_norm=args.model.gt.batch_norm, residual=True, - norm_e=cfg.gt.attn.norm_e, - O_e=cfg.gt.attn.O_e, - cfg=cfg.gt, + norm_e=args.model.gt.attn.norm_e, + O_e=args.model.gt.attn.O_e, + cfg=args.model.gt, )) - # layers = [] - self.layers = torch.nn.Sequential(*layers) - GNNHead = register.head_dict[cfg.gnn.head] - self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) + self.layers = nn.Sequential(*layers) + + self.decoder = nn.Sequential( + nn.Linear(dim_inner, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, dim_out), + ) def forward(self, batch): for module in self.children(): From e8281ac03c1f97fed0b57419b96665fc4fb90020 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 03/95] collect model components and replace old register method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 346 +++++++++++++++++++++ gridfm_graphkit/models/grit_transformer.py | 79 +++-- gridfm_graphkit/models/rrwp_encoder.py | 192 ++++++++++++ 3 files changed, 583 insertions(+), 34 deletions(-) create 
mode 100644 gridfm_graphkit/models/grit_layer.py create mode 100644 gridfm_graphkit/models/rrwp_encoder.py diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py new file mode 100644 index 00000000..dc6cf976 --- /dev/null +++ b/gridfm_graphkit/models/grit_layer.py @@ -0,0 +1,346 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_geometric as pyg +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter, scatter_max, scatter_add + +from grit.utils import negate_edge_index +from torch_geometric.graphgym.register import * +import opt_einsum as oe + +from yacs.config import CfgNode as CN + +import warnings + +def pyg_softmax(src, index, num_nodes=None): + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The indices of elements for applying the softmax. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. 
(default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + + num_nodes = maybe_num_nodes(index, num_nodes) + + out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index] + out = out.exp() + out = out / ( + scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16) + + return out + + + +class MultiHeadAttentionLayerGritSparse(nn.Module): + """ + Proposed Attention Computation for GRIT + """ + + def __init__(self, in_dim, out_dim, num_heads, use_bias, + clamp=5., dropout=0., act=None, + edge_enhance=True, + sqrt_relu=False, + signed_sqrt=True, + cfg=CN(), + **kwargs): + super().__init__() + + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = nn.Dropout(dropout) + self.clamp = np.abs(clamp) if clamp is not None else None + self.edge_enhance = edge_enhance + + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True) + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + nn.init.xavier_normal_(self.Q.weight) + nn.init.xavier_normal_(self.K.weight) + nn.init.xavier_normal_(self.E.weight) + nn.init.xavier_normal_(self.V.weight) + + self.Aw = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True) + nn.init.xavier_normal_(self.Aw) + + if act is None: + self.act = nn.Identity() + else: + self.act = act_dict[act]() + + if self.edge_enhance: + self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) + nn.init.xavier_normal_(self.VeRow) + + def propagate_attention(self, batch): + src = batch.K_h[batch.edge_index[0]] # (num relative) x num_heads x out_dim + dest = batch.Q_h[batch.edge_index[1]] # (num relative) x num_heads x out_dim + score = src + dest # element-wise multiplication + + if batch.get("E", None) is not None: + batch.E = batch.E.view(-1, self.num_heads, self.out_dim * 2) + E_w, E_b = batch.E[:, :, :self.out_dim], 
batch.E[:, :, self.out_dim:] + # (num relative) x num_heads x out_dim + score = score * E_w + score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score)) + score = score + E_b + + score = self.act(score) + e_t = score + + # output edge + if batch.get("E", None) is not None: + batch.wE = score.flatten(1) + + # final attn + score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch") + if self.clamp is not None: + score = torch.clamp(score, min=-self.clamp, max=self.clamp) + + raw_attn = score + score = pyg_softmax(score, batch.edge_index[1]) # (num relative) x num_heads x 1 + score = self.dropout(score) + batch.attn = score + + # Aggregate with Attn-Score + msg = batch.V_h[batch.edge_index[0]] * score # (num relative) x num_heads x out_dim + batch.wV = torch.zeros_like(batch.V_h) # (num nodes in batch) x num_heads x out_dim + scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce='add') + + if self.edge_enhance and batch.E is not None: + rowV = scatter(e_t * score, batch.edge_index[1], dim=0, reduce="add") + rowV = oe.contract("nhd, dhc -> nhc", rowV, self.VeRow, backend="torch") + batch.wV = batch.wV + rowV + + def forward(self, batch): + Q_h = self.Q(batch.x) + K_h = self.K(batch.x) + + V_h = self.V(batch.x) + if batch.get("edge_attr", None) is not None: + batch.E = self.E(batch.edge_attr) + else: + batch.E = None + + batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim) + batch.K_h = K_h.view(-1, self.num_heads, self.out_dim) + batch.V_h = V_h.view(-1, self.num_heads, self.out_dim) + self.propagate_attention(batch) + h_out = batch.wV + e_out = batch.get('wE', None) + + return h_out, e_out + + +@register_layer("GritTransformer") +class GritTransformerLayer(nn.Module): + """ + Proposed Transformer Layer for GRIT + """ + def __init__(self, in_dim, out_dim, num_heads, + dropout=0.0, + attn_dropout=0.0, + layer_norm=False, batch_norm=True, + residual=True, + act='relu', + norm_e=True, + O_e=True, + cfg=dict(), + **kwargs): + 
super().__init__() + + self.debug = False + self.in_channels = in_dim + self.out_channels = out_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + # ------- + self.update_e = cfg.get("update_e", True) + self.bn_momentum = cfg.bn_momentum + self.bn_no_runner = cfg.bn_no_runner + self.rezero = cfg.get("rezero", False) + + self.act = act_dict[act]() if act is not None else nn.Identity() + if cfg.get("attn", None) is None: + cfg.attn = dict() + self.use_attn = cfg.attn.get("use", True) + # self.sigmoid_deg = cfg.attn.get("sigmoid_deg", False) + self.deg_scaler = cfg.attn.get("deg_scaler", True) + + self.attention = MultiHeadAttentionLayerGritSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=cfg.attn.get("use_bias", False), + dropout=attn_dropout, + clamp=cfg.attn.get("clamp", 5.), + act=cfg.attn.get("act", "relu"), + edge_enhance=cfg.attn.get("edge_enhance", True), + sqrt_relu=cfg.attn.get("sqrt_relu", False), + signed_sqrt=cfg.attn.get("signed_sqrt", False), + scaled_attn =cfg.attn.get("scaled_attn", False), + no_qk=cfg.attn.get("no_qk", False), + ) + + if cfg.attn.get('graphormer_attn', False): + self.attention = MultiHeadAttentionLayerGraphormerSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=cfg.attn.get("use_bias", False), + dropout=attn_dropout, + clamp=cfg.attn.get("clamp", 5.), + act=cfg.attn.get("act", "relu"), + edge_enhance=True, + sqrt_relu=cfg.attn.get("sqrt_relu", False), + signed_sqrt=cfg.attn.get("signed_sqrt", False), + scaled_attn =cfg.attn.get("scaled_attn", False), + no_qk=cfg.attn.get("no_qk", False), + ) + + + + self.O_h = nn.Linear(out_dim//num_heads * num_heads, out_dim) + if O_e: + self.O_e = nn.Linear(out_dim//num_heads * num_heads, out_dim) + else: + self.O_e = nn.Identity() + + # -------- Deg Scaler Option ------ + 
+ if self.deg_scaler: + self.deg_coef = nn.Parameter(torch.zeros(1, out_dim//num_heads * num_heads, 2)) + nn.init.xavier_normal_(self.deg_coef) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + self.layer_norm1_e = nn.LayerNorm(out_dim) if norm_e else nn.Identity() + + if self.batch_norm: + # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) + self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) if norm_e else nn.Identity() + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) + self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + + if self.rezero: + self.alpha1_h = nn.Parameter(torch.zeros(1,1)) + self.alpha2_h = nn.Parameter(torch.zeros(1,1)) + self.alpha1_e = nn.Parameter(torch.zeros(1,1)) + + def forward(self, batch): + h = batch.x + num_nodes = batch.num_nodes + log_deg = get_log_deg(batch) + + h_in1 = h # for first residual connection + e_in1 = batch.get("edge_attr", None) + e = None + # multi-head attention out + + h_attn_out, e_attn_out = self.attention(batch) + + h = h_attn_out.view(num_nodes, -1) + h = F.dropout(h, self.dropout, training=self.training) + + # degree scaler + if self.deg_scaler: + h = torch.stack([h, h * log_deg], dim=-1) + h = (h * self.deg_coef).sum(dim=-1) + + h = self.O_h(h) + if e_attn_out is not None: + e = e_attn_out.flatten(1) + e = F.dropout(e, self.dropout, training=self.training) + e = self.O_e(e) + + if self.residual: + if self.rezero: h = h * self.alpha1_h + h = h_in1 + h # residual connection + if e is not 
None: + if self.rezero: e = e * self.alpha1_e + e = e + e_in1 + + if self.layer_norm: + h = self.layer_norm1_h(h) + if e is not None: e = self.layer_norm1_e(e) + + if self.batch_norm: + h = self.batch_norm1_h(h) + if e is not None: e = self.batch_norm1_e(e) + + # FFN for h + h_in2 = h # for second residual connection + h = self.FFN_h_layer1(h) + h = self.act(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + if self.rezero: h = h * self.alpha2_h + h = h_in2 + h # residual connection + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + batch.x = h + if self.update_e: + batch.edge_attr = e + else: + batch.edge_attr = e_in1 + + return batch + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]'.format( + self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual, + super().__repr__(), + ) + + +@torch.no_grad() +def get_log_deg(batch): + if "log_deg" in batch: + log_deg = batch.log_deg + elif "deg" in batch: + deg = batch.deg + log_deg = torch.log(deg + 1).unsqueeze(-1) + else: + warnings.warn("Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs") + deg = pyg.utils.degree(batch.edge_index[1], + num_nodes=batch.num_nodes, + dtype=torch.float + ) + log_deg = torch.log(deg + 1) + log_deg = log_deg.view(batch.num_nodes, 1) + return log_deg + + diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index b09f5272..e3c60471 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -2,12 +2,42 @@ from torch import nn import torch +from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from grit_layer import GritTransformerLayer + +# TODO verify use from torch_geometric.graphgym.models.gnn import GNNPreMP from 
torch_geometric.graphgym.models.layer import (new_layer_config, BatchNorm1dNode) from torch_geometric.graphgym.models.layer import new_layer_config, MLP +class LinearNodeEncoder(torch.nn.Module): + def __init__(self, emb_dim): + super().__init__() + + self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim) + + def forward(self, batch): + batch.x = self.encoder(batch.x) + return batch + +class LinearEdgeEncoder(torch.nn.Module): + def __init__(self, emb_dim): + super().__init__() + if cfg.dataset.name in ['MNIST', 'CIFAR10']: + self.in_dim = 1 + elif cfg.dataset.name.startswith('attributed_triangle-'): + self.in_dim = 2 + else: + raise ValueError("Input edge feature dim is required to be hardset " + "or refactored to use a cfg option.") + self.encoder = torch.nn.Linear(self.in_dim, emb_dim) + + def forward(self, batch): + batch.edge_attr = self.encoder(batch.edge_attr.view(-1, self.in_dim)) + return batch + class FeatureEncoder(torch.nn.Module): """ @@ -16,9 +46,6 @@ class FeatureEncoder(torch.nn.Module): Args: dim_in (int): Input feature dimension - - TODO replace 'register' with local version of it - """ def __init__( self, @@ -30,9 +57,7 @@ def __init__( self.dim_in = dim_in if args.node_encoder: # Encode integer node features via nn.Embeddings - NodeEncoder = register.node_encoder_dict[ - args.node_encoder_name] - self.node_encoder = NodeEncoder(dim_inner) + self.node_encoder = LinearNodeEncoder(dim_inner) if args.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode( new_layer_config(dim_inner, -1, -1, has_act=False, @@ -46,9 +71,7 @@ def __init__( else: dim_edge = dim_inner # Encode integer edge features via nn.Embeddings - EdgeEncoder = register.edge_encoder_dict[ - cfg.dataset.edge_encoder_name] - self.edge_encoder = EdgeEncoder(dim_edge) + self.edge_encoder = LinearEdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: self.edge_encoder_bn = BatchNorm1dNode( new_layer_config(dim_edge, -1, -1, has_act=False, @@ -65,19 +88,9 @@ class 
GritTransformer(torch.nn.Module): ''' The proposed GritTransformer (Graph Inductive Bias Transformer) ''' - def __init__(self, args): super().__init__() - # ### TODO remove default args not needed #### - # self.input_dim = - # self.hidden_dim = - # self.output_dim = - # self.edge_dim = - # self.num_layers = args.model.num_layers - # self.heads = getattr(args.model, "attention_head", 1) - # self.dropout = getattr(args.model, "dropout", 0.0) - # ### ### dim_in = args.model.input_dim dim_out = args.model.output_dim @@ -96,16 +109,19 @@ def __init__(self, args): if args.model.posenc_RRWP.enable: - # TODO connect 'register' to local version - self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ - (args.model.posenc_RRWP.ksteps, dim_inner) + + self.rrwp_abs_encoder = RRWPLinearNodeEncoder( + args.model.posenc_RRWP.ksteps, + dim_inner + ) rel_pe_dim = args.model.posenc_RRWP.ksteps - self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ - (rel_pe_dim, dim_edge, - pad_to_full_graph=args.model.gt.attn.full_attn, - add_node_attr_as_self_loop=False, - fill_value=0. - ) + self.rrwp_rel_encoder = RRWPLinearNodeEncoder( + rel_pe_dim, + dim_edge, + pad_to_full_graph=args.model.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0. + ) if args.model.layers_pre_mp > 0: @@ -116,14 +132,9 @@ def __init__(self, args): assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." 
- global_model_type = args.model.gt.layer_type - # global_model_type = "GritTransformer" - # TODO replace this with local register logic - TransformerLayer = register.layer_dict.get(global_model_type) - layers = [] for ll in range(num_layers): - layers.append(TransformerLayer( + layers.append(GritTransformerLayer( in_dim=args.model.gt.dim_hidden, out_dim=args.model.gt.dim_hidden, num_heads=num_heads, diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py new file mode 100644 index 00000000..f98118e9 --- /dev/null +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -0,0 +1,192 @@ +''' + The RRWP encoder for GRIT (ours) +''' +import torch +from torch import nn +from torch.nn import functional as F +from ogb.utils.features import get_bond_feature_dims +import torch_sparse + +import torch_geometric as pyg +from torch_geometric.graphgym.register import ( + register_edge_encoder, + register_node_encoder, +) + +from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops +from torch_scatter import scatter +import warnings + +def full_edge_index(edge_index, batch=None): + """ + Retunr the Full batched sparse adjacency matrices given by edge indices. + Returns batched sparse adjacency matrices with exactly those edges that + are not in the input `edge_index` while ignoring self-loops. + Implementation inspired by `torch_geometric.utils.to_dense_adj` + Args: + edge_index: The edge indices. + batch: Batch vector, which assigns each node to a specific example. + Returns: + Complementary edge index. 
+ """ + + if batch is None: + batch = edge_index.new_zeros(edge_index.max().item() + 1) + + batch_size = batch.max().item() + 1 + one = batch.new_ones(batch.size(0)) + num_nodes = scatter(one, batch, + dim=0, dim_size=batch_size, reduce='add') + cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) + + negative_index_list = [] + for i in range(batch_size): + n = num_nodes[i].item() + size = [n, n] + adj = torch.ones(size, dtype=torch.short, + device=edge_index.device) + + adj = adj.view(size) + _edge_index = adj.nonzero(as_tuple=False).t().contiguous() + # _edge_index, _ = remove_self_loops(_edge_index) + negative_index_list.append(_edge_index + cum_nodes[i]) + + edge_index_full = torch.cat(negative_index_list, dim=1).contiguous() + return edge_index_full + + + +class RRWPLinearNodeEncoder(torch.nn.Module): + """ + FC_1(RRWP) + FC_2 (Node-attr) + note: FC_2 is given by the Typedict encoder of node-attr in some cases + Parameters: + num_classes - the number of classes for the embedding mapping to learn + """ + def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"): + super().__init__() + self.batchnorm = batchnorm + self.layernorm = layernorm + self.name = pe_name + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + # Encode just the first dimension if more exist + rrwp = batch[f"{self.name}"] + rrwp = self.fc(rrwp) + + if self.batchnorm: + rrwp = self.bn(rrwp) + + if self.layernorm: + rrwp = self.ln(rrwp) + + if "x" in batch: + batch.x = batch.x + rrwp + else: + batch.x = rrwp + + return batch + + +class RRWPLinearEdgeEncoder(torch.nn.Module): + ''' + Merge RRWP with given edge-attr and Zero-padding to all pairs of node + FC_1(RRWP) + FC_2(edge-attr) + - FC_2 given by the TypedictEncoder in same cases + - 
Zero-padding for non-existing edges in fully-connected graph + - (optional) add node-attr as the E_{i,i}'s attr + note: assuming node-attr and edge-attr is with the same dimension after Encoders + ''' + def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, + pad_to_full_graph=True, fill_value=0., + add_node_attr_as_self_loop=False, + overwrite_old_attr=False): + super().__init__() + # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info + self.emb_dim = emb_dim + self.out_dim = out_dim + self.add_node_attr_as_self_loop = add_node_attr_as_self_loop + self.overwrite_old_attr=overwrite_old_attr # remove the old edge-attr + + self.batchnorm = batchnorm + self.layernorm = layernorm + if self.batchnorm or self.layernorm: + warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + self.pad_to_full_graph = pad_to_full_graph + self.fill_value = 0. + + padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value + self.register_buffer("padding", padding) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + rrwp_idx = batch.rrwp_index + rrwp_val = batch.rrwp_val + edge_index = batch.edge_index + edge_attr = batch.edge_attr + rrwp_val = self.fc(rrwp_val) + + if edge_attr is None: + edge_attr = edge_index.new_zeros(edge_index.size(1), rrwp_val.size(1)) + # zero padding for non-existing edges + + if self.overwrite_old_attr: + out_idx, out_val = rrwp_idx, rrwp_val + else: + # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) + edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
+ + #print('-->>>>', edge_attr.size(), rrwp_val.size()) + out_idx, out_val = torch_sparse.coalesce( + torch.cat([edge_index, rrwp_idx], dim=1), + torch.cat([edge_attr, rrwp_val], dim=0), + batch.num_nodes, batch.num_nodes, + op="add" + ) + + + if self.pad_to_full_graph: + edge_index_full = full_edge_index(out_idx, batch=batch.batch) + edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) + # zero padding to fully-connected graphs + out_idx = torch.cat([out_idx, edge_index_full], dim=1) + out_val = torch.cat([out_val, edge_attr_pad], dim=0) + out_idx, out_val = torch_sparse.coalesce( + out_idx, out_val, batch.num_nodes, batch.num_nodes, + op="add" + ) + + if self.batchnorm: + out_val = self.bn(out_val) + + if self.layernorm: + out_val = self.ln(out_val) + + + batch.edge_index, batch.edge_attr = out_idx, out_val + return batch + + def __repr__(self): + return f"{self.__class__.__name__}" \ + f"(pad_to_full_graph={self.pad_to_full_graph}," \ + f"fill_value={self.fill_value}," \ + f"{self.fc.__repr__()})" + + + From a67e52285b69d5fae828b5c6231a586daa6d629e Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 04/95] clean up imported layers and encoders Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 6 +-- gridfm_graphkit/models/grit_layer.py | 3 -- gridfm_graphkit/models/grit_transformer.py | 50 +++++++++++----------- gridfm_graphkit/models/rrwp_encoder.py | 14 ++---- 4 files changed, 29 insertions(+), 44 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 904d6dc4..dc6f3a11 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -35,8 +35,7 @@ model: num_layers: 10 output_dim: 6 pe_dim: 20 - type: GRIT #GPSTransformer # - layers_pre_mp: 0 + type: GRIT act: relu encoder: node_encoder: True @@ 
-45,10 +44,7 @@ model: node_encoder_bn: True gt: layer_type: GritTransformer - # layers: 10 - # n_heads: 8 dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` - # dropout: 0.0 layer_norm: False batch_norm: True update_e: True diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index dc6cf976..b4779806 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -6,8 +6,6 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_scatter import scatter, scatter_max, scatter_add -from grit.utils import negate_edge_index -from torch_geometric.graphgym.register import * import opt_einsum as oe from yacs.config import CfgNode as CN @@ -141,7 +139,6 @@ def forward(self, batch): return h_out, e_out -@register_layer("GritTransformer") class GritTransformerLayer(nn.Module): """ Proposed Transformer Layer for GRIT diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index e3c60471..715c25f1 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,15 +1,29 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY -from torch import nn import torch - +from torch import nn from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from grit_layer import GritTransformerLayer -# TODO verify use -from torch_geometric.graphgym.models.gnn import GNNPreMP -from torch_geometric.graphgym.models.layer import (new_layer_config, - BatchNorm1dNode) -from torch_geometric.graphgym.models.layer import new_layer_config, MLP + + +class BatchNorm1dNode(torch.nn.Module): + r"""A batch normalization layer for node-level features. + + Args: + dim_in (int): BatchNorm input dimension. 
+ TODO fill in comments + """ + def __init__(self, dim_in, eps, momentum): + super().__init__() + self.bn = torch.nn.BatchNorm1d( + dim_in, + eps=eps, + momentum=momentum, + ) + + def forward(self, batch): + batch.x = self.bn(batch.x) + return batch class LinearNodeEncoder(torch.nn.Module): @@ -59,23 +73,16 @@ def __init__( # Encode integer node features via nn.Embeddings self.node_encoder = LinearNodeEncoder(dim_inner) if args.node_encoder_bn: - self.node_encoder_bn = BatchNorm1dNode( - new_layer_config(dim_inner, -1, -1, has_act=False, - has_bias=False, cfg=cfg)) + self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner if args.edge_encoder: - # Hard-limit max edge dim for PNA. - if 'PNA' in args.model.gt.layer_type: # TODO remove condition if PNA not needed - dim_edge = min(128, dim_inner) - else: - dim_edge = dim_inner + + dim_edge = dim_inner # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dNode( - new_layer_config(dim_edge, -1, -1, has_act=False, - has_bias=False, cfg=cfg)) + self.edge_encoder_bn = BatchNorm1dNode(dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): @@ -107,7 +114,6 @@ def __init__(self, args): ) # TODO add args dim_in = self.encoder.dim_in - if args.model.posenc_RRWP.enable: self.rrwp_abs_encoder = RRWPLinearNodeEncoder( @@ -123,12 +129,6 @@ def __init__(self, args): fill_value=0. ) - - if args.model.layers_pre_mp > 0: - self.pre_mp = GNNPreMP( - dim_in, dim_inner, args.model.layers_pre_mp) - dim_in = dim_inner - assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." 
diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index f98118e9..b73e463d 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -1,25 +1,18 @@ -''' +""" The RRWP encoder for GRIT (ours) -''' +""" import torch from torch import nn from torch.nn import functional as F -from ogb.utils.features import get_bond_feature_dims import torch_sparse -import torch_geometric as pyg -from torch_geometric.graphgym.register import ( - register_edge_encoder, - register_node_encoder, -) - from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops from torch_scatter import scatter import warnings def full_edge_index(edge_index, batch=None): """ - Retunr the Full batched sparse adjacency matrices given by edge indices. + Return the Full batched sparse adjacency matrices given by edge indices. Returns batched sparse adjacency matrices with exactly those edges that are not in the input `edge_index` while ignoring self-loops. Implementation inspired by `torch_geometric.utils.to_dense_adj` @@ -152,7 +145,6 @@ def forward(self, batch): # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
- #print('-->>>>', edge_attr.size(), rrwp_val.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), From 6966f5ffc56a2384a2cb730324aa1bbf8269efb7 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:37 -0500 Subject: [PATCH 05/95] flow in basic structure for RRWP calculation Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 13 +- gridfm_graphkit/datasets/posenc_stats.py | 423 ++++++++++++++++++ .../datasets/powergrid_datamodule.py | 14 + gridfm_graphkit/datasets/rrwp.py | 103 +++++ 4 files changed, 547 insertions(+), 6 deletions(-) create mode 100644 gridfm_graphkit/datasets/posenc_stats.py create mode 100644 gridfm_graphkit/datasets/rrwp.py diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index dc6f3a11..05f31be6 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -26,6 +26,12 @@ data: test_ratio: 0.1 val_ratio: 0.1 workers: 4 + posenc_RRWP: # TODO maybe better with data section... 
+ enable: True + ksteps: 21 + add_identity: True + add_node_attr: False + add_inverse: False model: attention_head: 8 dropout: 0.1 @@ -57,12 +63,7 @@ model: O_e: True norm_e: True signed_sqrt: True - posenc_RRWP: - enable: True - ksteps: 21 - add_identity: True - add_node_attr: False - add_inverse: False + optimizer: beta1: 0.9 beta2: 0.999 diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py new file mode 100644 index 00000000..492a0a69 --- /dev/null +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -0,0 +1,423 @@ +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F + +from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix, + to_undirected, to_dense_adj) +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter_add +from functools import partial +from gridfm_graphkit.datasets.rrwp import add_full_rrwp + + +def compute_posenc_stats(data, pe_types, is_undirected, cfg): + """Precompute positional encodings for the given graph. + Supported PE statistics to precompute, selected by `pe_types`: + 'LapPE': Laplacian eigen-decomposition. + 'RWSE': Random walk landing probabilities (diagonals of RW matrices). + 'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED) + 'HKdiagSE': Diagonals of heat kernel diffusion. + 'ElstaticSE': Kernel based on the electrostatic interaction between nodes. + 'RRWP': Relative Random Walk Probabilities PE (Ours, for GRIT) + Args: + data: PyG graph + pe_types: Positional encoding types to precompute statistics for. + This can also be a combination, e.g. 'eigen+rw_landing' + is_undirected: True if the graph is expected to be undirected + cfg: Main configuration node + + Returns: + Extended PyG Data object. + """ + # Verify PE types. 
+ for t in pe_types: + if t not in ['LapPE', 'EquivStableLapPE', 'SignNet', + 'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']: + raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") + + # Basic preprocessing of the input graph. + if hasattr(data, 'num_nodes'): + N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa + else: + N = data.x.shape[0] # Number of nodes, including disconnected nodes. + laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower() + if laplacian_norm_type == 'none': + laplacian_norm_type = None + if is_undirected: + undir_edge_index = data.edge_index + else: + undir_edge_index = to_undirected(data.edge_index) + + # Eigen values and vectors. + evals, evects = None, None + if 'LapPE' in pe_types or 'EquivStableLapPE' in pe_types: + # Eigen-decomposition with numpy, can be reused for Heat kernels. + L = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=laplacian_norm_type, + num_nodes=N) + ) + evals, evects = np.linalg.eigh(L.toarray()) + + if 'LapPE' in pe_types: + max_freqs=cfg.posenc_LapPE.eigen.max_freqs + eigvec_norm=cfg.posenc_LapPE.eigen.eigvec_norm + elif 'EquivStableLapPE' in pe_types: + max_freqs=cfg.posenc_EquivStableLapPE.eigen.max_freqs + eigvec_norm=cfg.posenc_EquivStableLapPE.eigen.eigvec_norm + + data.EigVals, data.EigVecs = get_lap_decomp_stats( + evals=evals, evects=evects, + max_freqs=max_freqs, + eigvec_norm=eigvec_norm) + + if 'SignNet' in pe_types: + # Eigen-decomposition with numpy for SignNet. 
+ norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower() + if norm_type == 'none': + norm_type = None + L = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=norm_type, + num_nodes=N) + ) + evals_sn, evects_sn = np.linalg.eigh(L.toarray()) + data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats( + evals=evals_sn, evects=evects_sn, + max_freqs=cfg.posenc_SignNet.eigen.max_freqs, + eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm) + + # Random Walks. + if 'RWSE' in pe_types: + kernel_param = cfg.posenc_RWSE.kernel + if len(kernel_param.times) == 0: + raise ValueError("List of kernel times required for RWSE") + rw_landing = get_rw_landing_probs(ksteps=kernel_param.times, + edge_index=data.edge_index, + num_nodes=N) + data.pestat_RWSE = rw_landing + + # Heat Kernels. + if 'HKdiagSE' in pe_types or 'HKfullPE' in pe_types: + # Get the eigenvalues and eigenvectors of the regular Laplacian, + # if they have not yet been computed for 'eigen'. + if laplacian_norm_type is not None or evals is None or evects is None: + L_heat = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=None, num_nodes=N) + ) + evals_heat, evects_heat = np.linalg.eigh(L_heat.toarray()) + else: + evals_heat, evects_heat = evals, evects + evals_heat = torch.from_numpy(evals_heat) + evects_heat = torch.from_numpy(evects_heat) + + # Get the full heat kernels. + if 'HKfullPE' in pe_types: + # The heat kernels can't be stored in the Data object without + # additional padding because in PyG's collation of the graphs the + # sizes of tensors must match except in dimension 0. Do this when + # the full heat kernels are actually used downstream by an Encoder. + raise NotImplementedError() + # heat_kernels, hk_diag = get_heat_kernels(evects_heat, evals_heat, + # kernel_times=kernel_param.times) + # data.pestat_HKdiagSE = hk_diag + # Get heat kernel diagonals in more efficient way. 
+ if 'HKdiagSE' in pe_types: + kernel_param = cfg.posenc_HKdiagSE.kernel + if len(kernel_param.times) == 0: + raise ValueError("Diffusion times are required for heat kernel") + hk_diag = get_heat_kernels_diag(evects_heat, evals_heat, + kernel_times=kernel_param.times, + space_dim=0) + data.pestat_HKdiagSE = hk_diag + + # Electrostatic interaction inspired kernel. + if 'ElstaticSE' in pe_types: + elstatic = get_electrostatic_function_encoding(undir_edge_index, N) + data.pestat_ElstaticSE = elstatic + + if 'RRWP' in pe_types: + param = cfg.posenc_RRWP + transform = partial(add_full_rrwp, + walk_length=param.ksteps, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + spd=param.spd, # by default False + ) + data = transform(data) + + return data + + +def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'): + """Compute Laplacian eigen-decomposition-based PE stats of the given graph. + + Args: + evals, evects: Precomputed eigen-decomposition + max_freqs: Maximum number of top smallest frequencies / eigenvecs to use + eigvec_norm: Normalization for the eigen vectors of the Laplacian + Returns: + Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node + Tensor (num_nodes, max_freqs) of eigenvector values per node + """ + N = len(evals) # Number of nodes, including disconnected nodes. + + # Keep up to the maximum desired number of frequencies. + idx = evals.argsort()[:max_freqs] + evals, evects = evals[idx], np.real(evects[:, idx]) + evals = torch.from_numpy(np.real(evals)).clamp_min(0) + + # Normalize and pad eigen vectors. + evects = torch.from_numpy(evects).float() + evects = eigvec_normalizer(evects, evals, normalization=eigvec_norm) + if N < max_freqs: + EigVecs = F.pad(evects, (0, max_freqs - N), value=float('nan')) + else: + EigVecs = evects + + # Pad and save eigenvalues. 
+ if N < max_freqs: + EigVals = F.pad(evals, (0, max_freqs - N), value=float('nan')).unsqueeze(0) + else: + EigVals = evals.unsqueeze(0) + EigVals = EigVals.repeat(N, 1).unsqueeze(2) + + return EigVals, EigVecs + + +def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, + num_nodes=None, space_dim=0): + """Compute Random Walk landing probabilities for given list of K steps. + + Args: + ksteps: List of k-steps for which to compute the RW landings + edge_index: PyG sparse representation of the graph + edge_weight: (optional) Edge weights + num_nodes: (optional) Number of nodes in the graph + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the random-walk diagonal by a factor `k^(space_dim/2)`. + In euclidean space, this correction means that the height of + the gaussian distribution stays almost constant across the number of + steps, if `space_dim` is the dimension of the euclidean space. + + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + if edge_weight is None: + edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) + num_nodes = maybe_num_nodes(edge_index, num_nodes) + source, dest = edge_index[0], edge_index[1] + deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. + deg_inv = deg.pow(-1.) 
+ deg_inv.masked_fill_(deg_inv == float('inf'), 0) + + if edge_index.numel() == 0: + P = edge_index.new_zeros((1, num_nodes, num_nodes)) + else: + # P = D^-1 * A + P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) + rws = [] + if ksteps == list(range(min(ksteps), max(ksteps) + 1)): + # Efficient way if ksteps are a consecutive sequence (most of the time the case) + Pk = P.clone().detach().matrix_power(min(ksteps)) + for k in range(min(ksteps), max(ksteps) + 1): + rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + Pk = Pk @ P + else: + # Explicitly raising P to power k for each k \in ksteps. + for k in ksteps: + rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) + + return rw_landing + + +def get_heat_kernels_diag(evects, evals, kernel_times=[], space_dim=0): + """Compute Heat kernel diagonal. + + This is a continuous function that represents a Gaussian in the Euclidean + space, and is the solution to the diffusion equation. + The random-walk diagonal should converge to this. + + Args: + evects: Eigenvectors of the Laplacian matrix + evals: Eigenvalues of the Laplacian matrix + kernel_times: Time for the diffusion. Analogous to the k-steps in random + walk. The time is equivalent to the variance of the kernel. + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the diffusion diagonal by a factor `t^(space_dim/2)`. In + euclidean space, this correction means that the height of the + gaussian stays constant across time, if `space_dim` is the dimension + of the euclidean space. 
+ + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + heat_kernels_diag = [] + if len(kernel_times) > 0: + evects = F.normalize(evects, p=2., dim=0) + + # Remove eigenvalues == 0 from the computation of the heat kernel + idx_remove = evals < 1e-8 + evals = evals[~idx_remove] + evects = evects[:, ~idx_remove] + + # Change the shapes for the computations + evals = evals.unsqueeze(-1) # lambda_{i, ..., ...} + evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node + + # Compute the heat kernels diagonal only for each time + eigvec_mul = evects ** 2 + for t in kernel_times: + # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j} * phi_{i, j}) + this_kernel = torch.sum(torch.exp(-t * evals) * eigvec_mul, + dim=0, keepdim=False) + + # Multiply by `t` to stabilize the values, since the gaussian height + # is proportional to `1/t` + heat_kernels_diag.append(this_kernel * (t ** (space_dim / 2))) + heat_kernels_diag = torch.stack(heat_kernels_diag, dim=0).transpose(0, 1) + + return heat_kernels_diag + + +def get_heat_kernels(evects, evals, kernel_times=[]): + """Compute full Heat diffusion kernels. + + Args: + evects: Eigenvectors of the Laplacian matrix + evals: Eigenvalues of the Laplacian matrix + kernel_times: Time for the diffusion. Analogous to the k-steps in random + walk. The time is equivalent to the variance of the kernel. 
+ """ + heat_kernels, rw_landing = [], [] + if len(kernel_times) > 0: + evects = F.normalize(evects, p=2., dim=0) + + # Remove eigenvalues == 0 from the computation of the heat kernel + idx_remove = evals < 1e-8 + evals = evals[~idx_remove] + evects = evects[:, ~idx_remove] + + # Change the shapes for the computations + evals = evals.unsqueeze(-1).unsqueeze(-1) # lambda_{i, ..., ...} + evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node + + # Compute the heat kernels for each time + eigvec_mul = (evects.unsqueeze(2) * evects.unsqueeze(1)) # (phi_{i, j1, ...} * phi_{i, ..., j2}) + for t in kernel_times: + # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j1, ...} * phi_{i, ..., j2}) + heat_kernels.append( + torch.sum(torch.exp(-t * evals) * eigvec_mul, + dim=0, keepdim=False) + ) + + heat_kernels = torch.stack(heat_kernels, dim=0) # (Num kernel times) x (Num nodes) x (Num nodes) + + # Take the diagonal of each heat kernel, + # i.e. the landing probability of each of the random walks + rw_landing = torch.diagonal(heat_kernels, dim1=-2, dim2=-1).transpose(0, 1) # (Num nodes) x (Num kernel times) + + return heat_kernels, rw_landing + + +def get_electrostatic_function_encoding(edge_index, num_nodes): + """Kernel based on the electrostatic interaction between nodes. 
+ """ + L = to_scipy_sparse_matrix( + *get_laplacian(edge_index, normalization=None, num_nodes=num_nodes) + ).todense() + L = torch.as_tensor(L) + Dinv = torch.eye(L.shape[0]) * (L.diag() ** -1) + A = deepcopy(L).abs() + A.fill_diagonal_(0) + DinvA = Dinv.matmul(A) + + electrostatic = torch.pinverse(L) + electrostatic = electrostatic - electrostatic.diag() + green_encoding = torch.stack([ + electrostatic.min(dim=0)[0], # Min of Vi -> j + electrostatic.max(dim=0)[0], # Max of Vi -> j + electrostatic.mean(dim=0), # Mean of Vi -> j + electrostatic.std(dim=0), # Std of Vi -> j + electrostatic.min(dim=1)[0], # Min of Vj -> i + electrostatic.max(dim=0)[0], # Max of Vj -> i + electrostatic.mean(dim=1), # Mean of Vj -> i + electrostatic.std(dim=1), # Std of Vj -> i + (DinvA * electrostatic).sum(dim=0), # Mean of interaction on direct neighbour + (DinvA * electrostatic).sum(dim=1), # Mean of interaction from direct neighbour + ], dim=1) + + return green_encoding + + +def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12): + """ + Implement different eigenvector normalizations. 
+ """ + + EigVals = EigVals.unsqueeze(0) + + if normalization == "L1": + # L1 normalization: eigvec / sum(abs(eigvec)) + denom = EigVecs.norm(p=1, dim=0, keepdim=True) + + elif normalization == "L2": + # L2 normalization: eigvec / sqrt(sum(eigvec^2)) + denom = EigVecs.norm(p=2, dim=0, keepdim=True) + + elif normalization == "abs-max": + # AbsMax normalization: eigvec / max|eigvec| + denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values + + elif normalization == "wavelength": + # AbsMax normalization, followed by wavelength multiplication: + # eigvec * pi / (2 * max|eigvec| * sqrt(eigval)) + denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = denom * eigval_denom * 2 / np.pi + + elif normalization == "wavelength-asin": + # AbsMax normalization, followed by arcsin and wavelength multiplication: + # arcsin(eigvec / max|eigvec|) / sqrt(eigval) + denom_temp = torch.max(EigVecs.abs(), dim=0, keepdim=True).values.clamp_min(eps).expand_as(EigVecs) + EigVecs = torch.asin(EigVecs / denom_temp) + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = eigval_denom + + elif normalization == "wavelength-soft": + # AbsSoftmax normalization, followed by wavelength multiplication: + # eigvec / (softmax|eigvec| * sqrt(eigval)) + denom = (F.softmax(EigVecs.abs(), dim=0) * EigVecs.abs()).sum(dim=0, keepdim=True) + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = denom * eigval_denom + + else: + raise ValueError(f"Unsupported normalization `{normalization}`") + + denom = denom.clamp_min(eps).expand_as(EigVecs) + EigVecs = EigVecs / denom + + return EigVecs + +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data, HeteroData + +class ComputePosencStat(BaseTransform): + def __init__(self, pe_types, is_undirected, cfg): + 
self.pe_types = pe_types + self.is_undirected = is_undirected + self.cfg = cfg + + def __call__(self, data: Data) -> Data: + data = compute_posenc_stats(data, pe_types=self.pe_types, + is_undirected=self.is_undirected, + cfg=self.cfg + ) + return data \ No newline at end of file diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index c18c3604..ad68f4f6 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -10,6 +10,9 @@ ) from gridfm_graphkit.datasets.utils import split_dataset from gridfm_graphkit.datasets.powergrid_dataset import GridDatasetDisk + +from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat + import numpy as np import random import warnings @@ -129,6 +132,17 @@ def setup(self, stage: str): mask_dim=self.args.data.mask_dim, transform=get_transform(args=self.args), ) + + if self.args.data.posenc_RRWP.enable: + pe_transform = ComputePosencStat(pe_types=pe_enabled_list, # TODO connect arguments + is_undirected=is_undirected, + cfg=cfg + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.compose([pe_transform, dataset.transform]) + self.datasets.append(dataset) num_scenarios = self.args.data.scenarios[i] diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py new file mode 100644 index 00000000..d88e3e7b --- /dev/null +++ b/gridfm_graphkit/datasets/rrwp.py @@ -0,0 +1,103 @@ +# ------------------------ : new rwpse ---------------- +from typing import Union, Any, Optional +import numpy as np +import torch +import torch.nn.functional as F +import torch_geometric as pyg +from torch_geometric.data import Data, HeteroData +from torch_geometric.transforms import BaseTransform +from torch_scatter import scatter, scatter_add, scatter_max + +from torch_geometric.graphgym.config import cfg + +from torch_geometric.utils import ( + get_laplacian, + 
get_self_loop_attr, + to_scipy_sparse_matrix, +) +import torch_sparse +from torch_sparse import SparseTensor + + +def add_node_attr(data: Data, value: Any, + attr_name: Optional[str] = None) -> Data: + if attr_name is None: + if 'x' in data: + x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x + data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1) + else: + data.x = value + else: + data[attr_name] = value + + return data + + + +@torch.no_grad() +def add_full_rrwp(data, + walk_length=8, + attr_name_abs="rrwp", # name: 'rrwp' + attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val') + add_identity=True, + spd=False, + **kwargs + ): + device=data.edge_index.device + ind_vec = torch.eye(walk_length, dtype=torch.float, device=device) + num_nodes = data.num_nodes + edge_index, edge_weight = data.edge_index, data.edge_weight + + adj = SparseTensor.from_edge_index(edge_index, edge_weight, + sparse_sizes=(num_nodes, num_nodes), + ) + + # Compute D^{-1} A: + deg = adj.sum(dim=1) + deg_inv = 1.0 / adj.sum(dim=1) + deg_inv[deg_inv == float('inf')] = 0 + adj = adj * deg_inv.view(-1, 1) + adj = adj.to_dense() + + pe_list = [] + i = 0 + if add_identity: + pe_list.append(torch.eye(num_nodes, dtype=torch.float)) + i = i + 1 + + out = adj + pe_list.append(adj) + + if walk_length > 2: + for j in range(i + 1, walk_length): + out = out @ adj + pe_list.append(out) + + pe = torch.stack(pe_list, dim=-1) # n x n x k + + abs_pe = pe.diagonal().transpose(0, 1) # n x k + + rel_pe = SparseTensor.from_dense(pe, has_value=True) + rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo() + # rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0) + rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0) + # the framework of GRIT performing right-mul while adj is row-normalized, + # need to switch the order or row and col. + # note: both can work but the current version is more reasonable. 
+ + + if spd: + spd_idx = walk_length - torch.arange(walk_length) + val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0) + val = torch.argmax(val, dim=-1) + rel_pe_val = F.one_hot(val, walk_length).type(torch.float) + abs_pe = torch.zeros_like(abs_pe) + + data = add_node_attr(data, abs_pe, attr_name=attr_name_abs) + data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index") + data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val") + data.log_deg = torch.log(deg + 1) + data.deg = deg.type(torch.long) + + return data + From a7bd51d1747ae3dd6f12a712b2962484abc59d6f Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:38 -0500 Subject: [PATCH 06/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 372 +----------------- .../datasets/powergrid_datamodule.py | 5 +- 2 files changed, 9 insertions(+), 368 deletions(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 492a0a69..049633cf 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -11,8 +11,10 @@ from functools import partial from gridfm_graphkit.datasets.rrwp import add_full_rrwp +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data -def compute_posenc_stats(data, pe_types, is_undirected, cfg): +def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. Supported PE statistics to precompute, selected by `pe_types`: 'LapPE': Laplacian eigen-decomposition. @@ -37,387 +39,27 @@ def compute_posenc_stats(data, pe_types, is_undirected, cfg): 'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']: raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") - # Basic preprocessing of the input graph. 
- if hasattr(data, 'num_nodes'): - N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa - else: - N = data.x.shape[0] # Number of nodes, including disconnected nodes. - laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower() - if laplacian_norm_type == 'none': - laplacian_norm_type = None - if is_undirected: - undir_edge_index = data.edge_index - else: - undir_edge_index = to_undirected(data.edge_index) - - # Eigen values and vectors. - evals, evects = None, None - if 'LapPE' in pe_types or 'EquivStableLapPE' in pe_types: - # Eigen-decomposition with numpy, can be reused for Heat kernels. - L = to_scipy_sparse_matrix( - *get_laplacian(undir_edge_index, normalization=laplacian_norm_type, - num_nodes=N) - ) - evals, evects = np.linalg.eigh(L.toarray()) - - if 'LapPE' in pe_types: - max_freqs=cfg.posenc_LapPE.eigen.max_freqs - eigvec_norm=cfg.posenc_LapPE.eigen.eigvec_norm - elif 'EquivStableLapPE' in pe_types: - max_freqs=cfg.posenc_EquivStableLapPE.eigen.max_freqs - eigvec_norm=cfg.posenc_EquivStableLapPE.eigen.eigvec_norm - - data.EigVals, data.EigVecs = get_lap_decomp_stats( - evals=evals, evects=evects, - max_freqs=max_freqs, - eigvec_norm=eigvec_norm) - - if 'SignNet' in pe_types: - # Eigen-decomposition with numpy for SignNet. - norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower() - if norm_type == 'none': - norm_type = None - L = to_scipy_sparse_matrix( - *get_laplacian(undir_edge_index, normalization=norm_type, - num_nodes=N) - ) - evals_sn, evects_sn = np.linalg.eigh(L.toarray()) - data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats( - evals=evals_sn, evects=evects_sn, - max_freqs=cfg.posenc_SignNet.eigen.max_freqs, - eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm) - - # Random Walks. 
- if 'RWSE' in pe_types: - kernel_param = cfg.posenc_RWSE.kernel - if len(kernel_param.times) == 0: - raise ValueError("List of kernel times required for RWSE") - rw_landing = get_rw_landing_probs(ksteps=kernel_param.times, - edge_index=data.edge_index, - num_nodes=N) - data.pestat_RWSE = rw_landing - - # Heat Kernels. - if 'HKdiagSE' in pe_types or 'HKfullPE' in pe_types: - # Get the eigenvalues and eigenvectors of the regular Laplacian, - # if they have not yet been computed for 'eigen'. - if laplacian_norm_type is not None or evals is None or evects is None: - L_heat = to_scipy_sparse_matrix( - *get_laplacian(undir_edge_index, normalization=None, num_nodes=N) - ) - evals_heat, evects_heat = np.linalg.eigh(L_heat.toarray()) - else: - evals_heat, evects_heat = evals, evects - evals_heat = torch.from_numpy(evals_heat) - evects_heat = torch.from_numpy(evects_heat) - - # Get the full heat kernels. - if 'HKfullPE' in pe_types: - # The heat kernels can't be stored in the Data object without - # additional padding because in PyG's collation of the graphs the - # sizes of tensors must match except in dimension 0. Do this when - # the full heat kernels are actually used downstream by an Encoder. - raise NotImplementedError() - # heat_kernels, hk_diag = get_heat_kernels(evects_heat, evals_heat, - # kernel_times=kernel_param.times) - # data.pestat_HKdiagSE = hk_diag - # Get heat kernel diagonals in more efficient way. - if 'HKdiagSE' in pe_types: - kernel_param = cfg.posenc_HKdiagSE.kernel - if len(kernel_param.times) == 0: - raise ValueError("Diffusion times are required for heat kernel") - hk_diag = get_heat_kernels_diag(evects_heat, evals_heat, - kernel_times=kernel_param.times, - space_dim=0) - data.pestat_HKdiagSE = hk_diag - - # Electrostatic interaction inspired kernel. 
- if 'ElstaticSE' in pe_types: - elstatic = get_electrostatic_function_encoding(undir_edge_index, N) - data.pestat_ElstaticSE = elstatic - if 'RRWP' in pe_types: param = cfg.posenc_RRWP transform = partial(add_full_rrwp, walk_length=param.ksteps, attr_name_abs="rrwp", attr_name_rel="rrwp", - add_identity=True, - spd=param.spd, # by default False + add_identity=True ) data = transform(data) return data -def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'): - """Compute Laplacian eigen-decomposition-based PE stats of the given graph. - - Args: - evals, evects: Precomputed eigen-decomposition - max_freqs: Maximum number of top smallest frequencies / eigenvecs to use - eigvec_norm: Normalization for the eigen vectors of the Laplacian - Returns: - Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node - Tensor (num_nodes, max_freqs) of eigenvector values per node - """ - N = len(evals) # Number of nodes, including disconnected nodes. - - # Keep up to the maximum desired number of frequencies. - idx = evals.argsort()[:max_freqs] - evals, evects = evals[idx], np.real(evects[:, idx]) - evals = torch.from_numpy(np.real(evals)).clamp_min(0) - - # Normalize and pad eigen vectors. - evects = torch.from_numpy(evects).float() - evects = eigvec_normalizer(evects, evals, normalization=eigvec_norm) - if N < max_freqs: - EigVecs = F.pad(evects, (0, max_freqs - N), value=float('nan')) - else: - EigVecs = evects - - # Pad and save eigenvalues. - if N < max_freqs: - EigVals = F.pad(evals, (0, max_freqs - N), value=float('nan')).unsqueeze(0) - else: - EigVals = evals.unsqueeze(0) - EigVals = EigVals.repeat(N, 1).unsqueeze(2) - - return EigVals, EigVecs - - -def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, - num_nodes=None, space_dim=0): - """Compute Random Walk landing probabilities for given list of K steps. 
- - Args: - ksteps: List of k-steps for which to compute the RW landings - edge_index: PyG sparse representation of the graph - edge_weight: (optional) Edge weights - num_nodes: (optional) Number of nodes in the graph - space_dim: (optional) Estimated dimensionality of the space. Used to - correct the random-walk diagonal by a factor `k^(space_dim/2)`. - In euclidean space, this correction means that the height of - the gaussian distribution stays almost constant across the number of - steps, if `space_dim` is the dimension of the euclidean space. - - Returns: - 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs - """ - if edge_weight is None: - edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) - num_nodes = maybe_num_nodes(edge_index, num_nodes) - source, dest = edge_index[0], edge_index[1] - deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. - deg_inv = deg.pow(-1.) - deg_inv.masked_fill_(deg_inv == float('inf'), 0) - - if edge_index.numel() == 0: - P = edge_index.new_zeros((1, num_nodes, num_nodes)) - else: - # P = D^-1 * A - P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) - rws = [] - if ksteps == list(range(min(ksteps), max(ksteps) + 1)): - # Efficient way if ksteps are a consecutive sequence (most of the time the case) - Pk = P.clone().detach().matrix_power(min(ksteps)) - for k in range(min(ksteps), max(ksteps) + 1): - rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ - (k ** (space_dim / 2))) - Pk = Pk @ P - else: - # Explicitly raising P to power k for each k \in ksteps. - for k in ksteps: - rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ - (k ** (space_dim / 2))) - rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) - - return rw_landing - - -def get_heat_kernels_diag(evects, evals, kernel_times=[], space_dim=0): - """Compute Heat kernel diagonal. 
- - This is a continuous function that represents a Gaussian in the Euclidean - space, and is the solution to the diffusion equation. - The random-walk diagonal should converge to this. - - Args: - evects: Eigenvectors of the Laplacian matrix - evals: Eigenvalues of the Laplacian matrix - kernel_times: Time for the diffusion. Analogous to the k-steps in random - walk. The time is equivalent to the variance of the kernel. - space_dim: (optional) Estimated dimensionality of the space. Used to - correct the diffusion diagonal by a factor `t^(space_dim/2)`. In - euclidean space, this correction means that the height of the - gaussian stays constant across time, if `space_dim` is the dimension - of the euclidean space. - - Returns: - 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs - """ - heat_kernels_diag = [] - if len(kernel_times) > 0: - evects = F.normalize(evects, p=2., dim=0) - - # Remove eigenvalues == 0 from the computation of the heat kernel - idx_remove = evals < 1e-8 - evals = evals[~idx_remove] - evects = evects[:, ~idx_remove] - - # Change the shapes for the computations - evals = evals.unsqueeze(-1) # lambda_{i, ..., ...} - evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node - - # Compute the heat kernels diagonal only for each time - eigvec_mul = evects ** 2 - for t in kernel_times: - # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j} * phi_{i, j}) - this_kernel = torch.sum(torch.exp(-t * evals) * eigvec_mul, - dim=0, keepdim=False) - - # Multiply by `t` to stabilize the values, since the gaussian height - # is proportional to `1/t` - heat_kernels_diag.append(this_kernel * (t ** (space_dim / 2))) - heat_kernels_diag = torch.stack(heat_kernels_diag, dim=0).transpose(0, 1) - - return heat_kernels_diag - - -def get_heat_kernels(evects, evals, kernel_times=[]): - """Compute full Heat diffusion kernels. 
- - Args: - evects: Eigenvectors of the Laplacian matrix - evals: Eigenvalues of the Laplacian matrix - kernel_times: Time for the diffusion. Analogous to the k-steps in random - walk. The time is equivalent to the variance of the kernel. - """ - heat_kernels, rw_landing = [], [] - if len(kernel_times) > 0: - evects = F.normalize(evects, p=2., dim=0) - - # Remove eigenvalues == 0 from the computation of the heat kernel - idx_remove = evals < 1e-8 - evals = evals[~idx_remove] - evects = evects[:, ~idx_remove] - - # Change the shapes for the computations - evals = evals.unsqueeze(-1).unsqueeze(-1) # lambda_{i, ..., ...} - evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node - - # Compute the heat kernels for each time - eigvec_mul = (evects.unsqueeze(2) * evects.unsqueeze(1)) # (phi_{i, j1, ...} * phi_{i, ..., j2}) - for t in kernel_times: - # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j1, ...} * phi_{i, ..., j2}) - heat_kernels.append( - torch.sum(torch.exp(-t * evals) * eigvec_mul, - dim=0, keepdim=False) - ) - - heat_kernels = torch.stack(heat_kernels, dim=0) # (Num kernel times) x (Num nodes) x (Num nodes) - - # Take the diagonal of each heat kernel, - # i.e. the landing probability of each of the random walks - rw_landing = torch.diagonal(heat_kernels, dim1=-2, dim2=-1).transpose(0, 1) # (Num nodes) x (Num kernel times) - - return heat_kernels, rw_landing - - -def get_electrostatic_function_encoding(edge_index, num_nodes): - """Kernel based on the electrostatic interaction between nodes. 
- """ - L = to_scipy_sparse_matrix( - *get_laplacian(edge_index, normalization=None, num_nodes=num_nodes) - ).todense() - L = torch.as_tensor(L) - Dinv = torch.eye(L.shape[0]) * (L.diag() ** -1) - A = deepcopy(L).abs() - A.fill_diagonal_(0) - DinvA = Dinv.matmul(A) - - electrostatic = torch.pinverse(L) - electrostatic = electrostatic - electrostatic.diag() - green_encoding = torch.stack([ - electrostatic.min(dim=0)[0], # Min of Vi -> j - electrostatic.max(dim=0)[0], # Max of Vi -> j - electrostatic.mean(dim=0), # Mean of Vi -> j - electrostatic.std(dim=0), # Std of Vi -> j - electrostatic.min(dim=1)[0], # Min of Vj -> i - electrostatic.max(dim=0)[0], # Max of Vj -> i - electrostatic.mean(dim=1), # Mean of Vj -> i - electrostatic.std(dim=1), # Std of Vj -> i - (DinvA * electrostatic).sum(dim=0), # Mean of interaction on direct neighbour - (DinvA * electrostatic).sum(dim=1), # Mean of interaction from direct neighbour - ], dim=1) - - return green_encoding - - -def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12): - """ - Implement different eigenvector normalizations. 
- """ - - EigVals = EigVals.unsqueeze(0) - - if normalization == "L1": - # L1 normalization: eigvec / sum(abs(eigvec)) - denom = EigVecs.norm(p=1, dim=0, keepdim=True) - - elif normalization == "L2": - # L2 normalization: eigvec / sqrt(sum(eigvec^2)) - denom = EigVecs.norm(p=2, dim=0, keepdim=True) - - elif normalization == "abs-max": - # AbsMax normalization: eigvec / max|eigvec| - denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values - - elif normalization == "wavelength": - # AbsMax normalization, followed by wavelength multiplication: - # eigvec * pi / (2 * max|eigvec| * sqrt(eigval)) - denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = denom * eigval_denom * 2 / np.pi - - elif normalization == "wavelength-asin": - # AbsMax normalization, followed by arcsin and wavelength multiplication: - # arcsin(eigvec / max|eigvec|) / sqrt(eigval) - denom_temp = torch.max(EigVecs.abs(), dim=0, keepdim=True).values.clamp_min(eps).expand_as(EigVecs) - EigVecs = torch.asin(EigVecs / denom_temp) - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = eigval_denom - - elif normalization == "wavelength-soft": - # AbsSoftmax normalization, followed by wavelength multiplication: - # eigvec / (softmax|eigvec| * sqrt(eigval)) - denom = (F.softmax(EigVecs.abs(), dim=0) * EigVecs.abs()).sum(dim=0, keepdim=True) - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = denom * eigval_denom - - else: - raise ValueError(f"Unsupported normalization `{normalization}`") - - denom = denom.clamp_min(eps).expand_as(EigVecs) - EigVecs = EigVecs / denom - - return EigVecs - -from torch_geometric.transforms import BaseTransform -from torch_geometric.data import Data, HeteroData - class ComputePosencStat(BaseTransform): - def __init__(self, pe_types, is_undirected, cfg): + def 
__init__(self, pe_types, cfg): self.pe_types = pe_types - self.is_undirected = is_undirected self.cfg = cfg def __call__(self, data: Data) -> Data: - data = compute_posenc_stats(data, pe_types=self.pe_types, - is_undirected=self.is_undirected, + data = compute_posenc_stats(data, + pe_types=self.pe_types, cfg=self.cfg ) return data \ No newline at end of file diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index ad68f4f6..d73f6faa 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -134,9 +134,8 @@ def setup(self, stage: str): ) if self.args.data.posenc_RRWP.enable: - pe_transform = ComputePosencStat(pe_types=pe_enabled_list, # TODO connect arguments - is_undirected=is_undirected, - cfg=cfg + pe_transform = ComputePosencStat(pe_types=['RRWP'], + cfg=self.args.data ) if dataset.transform is None: dataset.transform = pe_transform From 226f2a3098c4ed4499ac982d53f1e49338cce9fd Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:38 -0500 Subject: [PATCH 07/95] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 1 + gridfm_graphkit/datasets/posenc_stats.py | 3 +- .../datasets/powergrid_datamodule.py | 2 + gridfm_graphkit/datasets/rrwp.py | 1 - gridfm_graphkit/models/__init__.py | 3 +- gridfm_graphkit/models/grit_layer.py | 3 +- gridfm_graphkit/models/grit_transformer.py | 42 +++++++++---------- 7 files changed, 28 insertions(+), 27 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 05f31be6..98960571 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -48,6 +48,7 @@ model: edge_encoder: True node_encoder_name: TODO node_encoder_bn: True + .edge_encoder_bn: True gt: layer_type: 
GritTransformer dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 049633cf..8bb2b9dc 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -16,7 +16,8 @@ def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. - Supported PE statistics to precompute, selected by `pe_types`: + Supported PE statistics to precompute in original implementation, + selected by `pe_types`: 'LapPE': Laplacian eigen-decomposition. 'RWSE': Random walk landing probabilities (diagonals of RW matrices). 'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED) diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index d73f6faa..9960e084 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -13,6 +13,8 @@ from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat +import torch_geometric.transforms as T + import numpy as np import random import warnings diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index d88e3e7b..26218f0a 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -8,7 +8,6 @@ from torch_geometric.transforms import BaseTransform from torch_scatter import scatter, scatter_add, scatter_max -from torch_geometric.graphgym.config import cfg from torch_geometric.utils import ( get_laplacian, diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index de355d31..cc669363 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,5 @@ from gridfm_graphkit.models.gps_transformer import GPSTransformer from gridfm_graphkit.models.gnn_transformer import GNN_TransformerConv +from 
gridfm_graphkit.models.grit_transformer import GritTransformer -__all__ = ["GPSTransformer", "GNN_TransformerConv"] +__all__ = ["GPSTransformer", "GNN_TransformerConv", "GRIT"] diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index b4779806..53e72172 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -8,7 +8,6 @@ import opt_einsum as oe -from yacs.config import CfgNode as CN import warnings @@ -48,7 +47,7 @@ def __init__(self, in_dim, out_dim, num_heads, use_bias, edge_enhance=True, sqrt_relu=False, signed_sqrt=True, - cfg=CN(), + cfg={}, **kwargs): super().__init__() diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 715c25f1..49bfdf2e 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,8 +1,9 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY import torch from torch import nn -from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder -from grit_layer import GritTransformerLayer + +from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from gridfm_graphkit.models.grit_layer import GritTransformerLayer @@ -27,25 +28,21 @@ def forward(self, batch): class LinearNodeEncoder(torch.nn.Module): - def __init__(self, emb_dim): + def __init__(self, dim_in, emb_dim): super().__init__() - self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim) + self.encoder = torch.nn.Linear(dim_in, emb_dim) def forward(self, batch): batch.x = self.encoder(batch.x) return batch class LinearEdgeEncoder(torch.nn.Module): - def __init__(self, emb_dim): + def __init__(self, edge_dim, emb_dim): super().__init__() - if cfg.dataset.name in ['MNIST', 'CIFAR10']: - self.in_dim = 1 - elif cfg.dataset.name.startswith('attributed_triangle-'): - self.in_dim = 2 - else: - raise ValueError("Input edge feature dim is required to be hardset " - "or 
refactored to use a cfg option.") + + self.in_dim = edge_dim + self.encoder = torch.nn.Linear(self.in_dim, emb_dim) def forward(self, batch): @@ -69,20 +66,20 @@ def __init__( ): super(FeatureEncoder, self).__init__() self.dim_in = dim_in - if args.node_encoder: + if args.encoder.node_encoder: # Encode integer node features via nn.Embeddings - self.node_encoder = LinearNodeEncoder(dim_inner) - if args.node_encoder_bn: + self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) + if args.encoder.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner - if args.edge_encoder: - - dim_edge = dim_inner + if args.encoder.edge_encoder: + args.edge_dim + enc_dim_edge = dim_inner # Encode integer edge features via nn.Embeddings - self.edge_encoder = LinearEdgeEncoder(dim_edge) - if cfg.dataset.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dNode(dim_edge, 1e-5, 0.1) + self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) + if args.encoder.edge_encoder_bn: + self.edge_encoder_bn = BatchNorm1dNode(enc_dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): @@ -121,7 +118,7 @@ def __init__(self, args): dim_inner ) rel_pe_dim = args.model.posenc_RRWP.ksteps - self.rrwp_rel_encoder = RRWPLinearNodeEncoder( + self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, dim_edge, pad_to_full_graph=args.model.gt.attn.full_attn, @@ -158,6 +155,7 @@ def __init__(self, args): ) def forward(self, batch): + print('process--->>', batch) # TODO remove print for module in self.children(): batch = module(batch) From 88d9ca6a7d2cea3f6f935b123b236a203452809f Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:38 -0500 Subject: [PATCH 08/95] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- 
examples/config/grit_pretraining.yaml | 2 +- gridfm_graphkit/models/grit_transformer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 98960571..bd76278b 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -48,7 +48,7 @@ model: edge_encoder: True node_encoder_name: TODO node_encoder_bn: True - .edge_encoder_bn: True + edge_encoder_bn: True gt: layer_type: GritTransformer dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 49bfdf2e..e8746b84 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -74,7 +74,7 @@ def __init__( # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner if args.encoder.edge_encoder: - args.edge_dim + edge_dim = args.edge_dim enc_dim_edge = dim_inner # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) @@ -107,8 +107,8 @@ def __init__(self, args): self.encoder = FeatureEncoder( dim_in, dim_inner, - args.model.encoder - ) # TODO add args + args.model + ) dim_in = self.encoder.dim_in if args.model.posenc_RRWP.enable: From b7d9dcf035026c797c1d68577d8b04ed9d3247ee Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 09/95] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 ++ gridfm_graphkit/models/grit_layer.py | 4 ++-- gridfm_graphkit/models/grit_transformer.py | 6 +++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index bd76278b..73e5817e 100644 --- 
a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -64,6 +64,8 @@ model: O_e: True norm_e: True signed_sqrt: True + bn_momentum: 0.1 + bn_no_runner: False optimizer: beta1: 0.9 diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 53e72172..ffcf584b 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -166,10 +166,10 @@ def __init__(self, in_dim, out_dim, num_heads, self.batch_norm = batch_norm # ------- - self.update_e = cfg.get("update_e", True) + self.update_e = getattr(cfg, "update_e", True) self.bn_momentum = cfg.bn_momentum self.bn_no_runner = cfg.bn_no_runner - self.rezero = cfg.get("rezero", False) + self.rezero = getattr(cfg, "rezero", False) self.act = act_dict[act]() if act is not None else nn.Identity() if cfg.get("attn", None) is None: diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index e8746b84..8d4f6962 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -111,13 +111,13 @@ def __init__(self, args): ) dim_in = self.encoder.dim_in - if args.model.posenc_RRWP.enable: + if args.data.posenc_RRWP.enable: self.rrwp_abs_encoder = RRWPLinearNodeEncoder( - args.model.posenc_RRWP.ksteps, + args.data.posenc_RRWP.ksteps, dim_inner ) - rel_pe_dim = args.model.posenc_RRWP.ksteps + rel_pe_dim = args.data.posenc_RRWP.ksteps self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, dim_edge, From 38cc44a31d909cba09a8296784fc5f02e9637dba Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 10/95] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py 
b/gridfm_graphkit/models/grit_layer.py index ffcf584b..98d0b6c0 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -166,10 +166,10 @@ def __init__(self, in_dim, out_dim, num_heads, self.batch_norm = batch_norm # ------- - self.update_e = getattr(cfg, "update_e", True) - self.bn_momentum = cfg.bn_momentum - self.bn_no_runner = cfg.bn_no_runner - self.rezero = getattr(cfg, "rezero", False) + self.update_e = getattr(cfg.attn, "update_e", True) + self.bn_momentum = cfg.attn.bn_momentum + self.bn_no_runner = cfg.attn.bn_no_runner + self.rezero = getattr(cfg.attn, "rezero", False) self.act = act_dict[act]() if act is not None else nn.Identity() if cfg.get("attn", None) is None: From 7fded9556996acec2d55919e95059fe4faba93fa Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 11/95] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 98d0b6c0..9723304c 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -72,7 +72,7 @@ def __init__(self, in_dim, out_dim, num_heads, use_bias, if act is None: self.act = nn.Identity() else: - self.act = act_dict[act]() + self.act = nn.ReLU() if self.edge_enhance: self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) @@ -171,12 +171,15 @@ def __init__(self, in_dim, out_dim, num_heads, self.bn_no_runner = cfg.attn.bn_no_runner self.rezero = getattr(cfg.attn, "rezero", False) - self.act = act_dict[act]() if act is not None else nn.Identity() - if cfg.get("attn", None) is None: + if act is not None + self.act = nn.ReLU() + else: + self.act = nn.Identity() + 
+ if getattr(cfg, "attn", None) is None: cfg.attn = dict() - self.use_attn = cfg.attn.get("use", True) - # self.sigmoid_deg = cfg.attn.get("sigmoid_deg", False) - self.deg_scaler = cfg.attn.get("deg_scaler", True) + self.use_attn = getattr(cfg.attn, "use", True) + self.deg_scaler = getattr(cfg.attn, "deg_scaler", True) self.attention = MultiHeadAttentionLayerGritSparse( in_dim=in_dim, From 0f3b803a2a2ccfa69b11bac643b1ad14e2e24d38 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 12/95] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 9723304c..0bcdf730 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -171,7 +171,7 @@ def __init__(self, in_dim, out_dim, num_heads, self.bn_no_runner = cfg.attn.bn_no_runner self.rezero = getattr(cfg.attn, "rezero", False) - if act is not None + if act is not None: self.act = nn.ReLU() else: self.act = nn.Identity() From af8ad03ad589f781a35d5deaa557a616c5463e80 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 13/95] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 38 ++++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 0bcdf730..f95ffc78 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -185,31 +185,31 @@ def __init__(self, in_dim, out_dim, num_heads, in_dim=in_dim, 
out_dim=out_dim // num_heads, num_heads=num_heads, - use_bias=cfg.attn.get("use_bias", False), + use_bias=getattr(cfg.attn, "use_bias", False), dropout=attn_dropout, - clamp=cfg.attn.get("clamp", 5.), - act=cfg.attn.get("act", "relu"), - edge_enhance=cfg.attn.get("edge_enhance", True), - sqrt_relu=cfg.attn.get("sqrt_relu", False), - signed_sqrt=cfg.attn.get("signed_sqrt", False), - scaled_attn =cfg.attn.get("scaled_attn", False), - no_qk=cfg.attn.get("no_qk", False), + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), + edge_enhance=getattr(cfg.attn, "edge_enhance", True), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn,"scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), ) - if cfg.attn.get('graphormer_attn', False): + if getattr(cfg.attn, 'graphormer_attn', False): self.attention = MultiHeadAttentionLayerGraphormerSparse( in_dim=in_dim, out_dim=out_dim // num_heads, num_heads=num_heads, - use_bias=cfg.attn.get("use_bias", False), + use_bias=getattr(cfg.attn, "use_bias", False), dropout=attn_dropout, - clamp=cfg.attn.get("clamp", 5.), - act=cfg.attn.get("act", "relu"), + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), edge_enhance=True, - sqrt_relu=cfg.attn.get("sqrt_relu", False), - signed_sqrt=cfg.attn.get("signed_sqrt", False), - scaled_attn =cfg.attn.get("scaled_attn", False), - no_qk=cfg.attn.get("no_qk", False), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn, "scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), ) @@ -232,8 +232,8 @@ def __init__(self, in_dim, out_dim, num_heads, if self.batch_norm: # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) - self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not 
self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) - self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) if norm_e else nn.Identity() + self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) + self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if norm_e else nn.Identity() # FFN for h self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) @@ -243,7 +243,7 @@ def __init__(self, in_dim, out_dim, num_heads, self.layer_norm2_h = nn.LayerNorm(out_dim) if self.batch_norm: - self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if self.rezero: self.alpha1_h = nn.Parameter(torch.zeros(1,1)) From f430f2a7562c3db8b570b673422e7406446de593 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 14/95] matching up parameters in data module Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/powergrid_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index 9960e084..e67956e5 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -142,7 +142,7 @@ def setup(self, stage: str): if dataset.transform is None: dataset.transform = pe_transform else: - dataset.transform = T.compose([pe_transform, dataset.transform]) + dataset.transform = T.Compose([pe_transform, dataset.transform]) self.datasets.append(dataset) From 
e1c489050bd9995e10efe69d6f356515dc6b965d Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 15/95] flow over parameters from base model Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 8d4f6962..10af25a7 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -103,6 +103,19 @@ def __init__(self, args): num_heads = args.model.attention_head dropout = args.model.dropout num_layers = args.model.num_layers + self.mask_dim = getattr(args.data, "mask_dim", 6) + self.mask_value = getattr(args.data, "mask_value", -1.0) + self.learn_mask = getattr(args.data, "learn_mask", False) + if self.learn_mask: + self.mask_value = nn.Parameter( + torch.randn(self.mask_dim) + self.mask_value, + requires_grad=True, + ) + else: + self.mask_value = nn.Parameter( + torch.zeros(self.mask_dim) + self.mask_value, + requires_grad=False, + ) self.encoder = FeatureEncoder( dim_in, From 36dca0095cecea391a586b3ec205a229b3511e7c Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 16/95] verified encodings and data flow to model forward method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- .../tasks/feature_reconstruction_task.py | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 10af25a7..cf07a8c6 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -167,7 +167,7 
@@ def __init__(self, args): nn.Linear(dim_inner, dim_out), ) - def forward(self, batch): + def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters print('process--->>', batch) # TODO remove print for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index cb6963b6..da2f478f 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -74,11 +74,11 @@ def __init__(self, args, node_normalizers, edge_normalizers): self.edge_normalizers = edge_normalizers self.save_hyperparameters() - def forward(self, x, pe, edge_index, edge_attr, batch, mask=None): - if mask is not None: - mask_value_expanded = self.model.mask_value.expand(x.shape[0], -1) - x[:, : mask.shape[1]][mask] = mask_value_expanded[mask] - return self.model(x, pe, edge_index, edge_attr, batch) + def forward(self, batch): + if batch.mask is not None: + mask_value_expanded = self.model.mask_value.expand(batch.x.shape[0], -1) + batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[batch.mask] + return self.model(batch) @rank_zero_only def on_fit_start(self): @@ -111,12 +111,14 @@ def on_fit_start(self): def shared_step(self, batch): output = self.forward( - x=batch.x, - pe=batch.pe, - edge_index=batch.edge_index, - edge_attr=batch.edge_attr, - batch=batch.batch, - mask=batch.mask, + # TODO update args list in the GPS Transf. 
for consistency + # x=batch.x, + # pe=batch.pe, + # edge_index=batch.edge_index, + # edge_attr=batch.edge_attr, + # batch=batch.batch, + # mask=batch.mask, + batch ) loss_dict = self.loss_fn( From a8ec56efdbef89eef493f3e6f1b897f3e17140b6 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:41 -0500 Subject: [PATCH 17/95] match feature dimensions Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- gridfm_graphkit/models/rrwp_encoder.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index cf07a8c6..50d0fec2 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -133,7 +133,7 @@ def __init__(self, args): rel_pe_dim = args.data.posenc_RRWP.ksteps self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, - dim_edge, + dim_inner, pad_to_full_graph=args.model.gt.attn.full_attn, add_node_attr_as_self_loop=False, fill_value=0. diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index b73e463d..33c52157 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -114,6 +114,7 @@ def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias= if self.batchnorm or self.layernorm: warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + print('--------fc in and out:', emb_dim, out_dim) self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.pad_to_full_graph = pad_to_full_graph @@ -144,7 +145,8 @@ def forward(self, batch): else: # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) - + print('xxxx', edge_attr.size(), rrwp_val.size()) + print('yyyy', edge_index.size(), rrwp_idx.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), From 0868b96e7a5aed4e2648f6288d85073a260afd44 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:41 -0500 Subject: [PATCH 18/95] match feature dimensions Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 73e5817e..52c44c8f 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -36,7 +36,7 @@ model: attention_head: 8 dropout: 0.1 edge_dim: 2 - hidden_size: 123 + hidden_size: 64 # `gt.dim_hidden` must match `gnn.dim_inner` input_dim: 9 num_layers: 10 output_dim: 6 From 3cc21a38eea36b3dddca0e3a2d1817e15ac66b86 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:41 -0500 Subject: [PATCH 19/95] reformat decoder to handle batch format Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/cli.py | 5 ++ gridfm_graphkit/models/grit_transformer.py | 58 ++++++++++++++++++++-- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index a7507c11..79cb772e 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -77,6 +77,11 @@ def main_cli(args): max_epochs=config_args.training.epochs, callbacks=get_training_callbacks(config_args), ) + + # print('******model*****') + # print(model) + # print('******model*****') + if args.command == "train" or 
args.command == "finetune": trainer.fit(model=model, datamodule=litGrid) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 50d0fec2..2e85d4e0 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -85,6 +85,49 @@ def forward(self, batch): for module in self.children(): batch = module(batch) return batch + +class GraphHead(nn.Module): + """ + Prediction head for graph prediction tasks. + Args: + dim_in (int): Input dimension. + dim_out (int): Output dimension. For binary prediction, dim_out=1. + L (int): Number of hidden layers. + """ + + def __init__(self, dim_in, dim_out): + super().__init__() + # self.deg_scaler = False + # self.fwl = False + + # list_FC_layers = [ + # nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) + # for l in range(L)] + # list_FC_layers.append( + # nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) + self.FC_layers = nn.Sequential( + nn.Linear(dim_in, dim_in), + nn.LeakyReLU(), + nn.Linear(dim_in, dim_out), + ) #nn.ModuleList(list_FC_layers) + # self.L = L + # self.activation = register.act_dict[cfg.gnn.act]() + # note: modified to add () in the end from original code of 'GPS' + # potentially due to the change of PyG/GraphGym version + + def _apply_index(self, batch): + return batch.graph_feature, batch.y + + def forward(self, batch): + # graph_emb = self.pooling_fun(batch.x, batch.batch) + graph_emb = self.FC_layers(batch.x) + # for l in range(self.L): + # graph_emb = self.FC_layers[l](graph_emb) + # graph_emb = self.activation(graph_emb) + # graph_emb = self.FC_layers[self.L](graph_emb) + batch.graph_feature = graph_emb + pred, label = self._apply_index(batch) + return pred @MODELS_REGISTRY.register("GRIT") @@ -161,15 +204,20 @@ def __init__(self, args): self.layers = nn.Sequential(*layers) - self.decoder = nn.Sequential( - nn.Linear(dim_inner, dim_inner), - nn.LeakyReLU(), - nn.Linear(dim_inner, dim_out), - ) + # 
self.decoder = nn.Sequential( + # nn.Linear(dim_inner, dim_inner), + # nn.LeakyReLU(), + # nn.Linear(dim_inner, dim_out), + # ) + + self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters print('process--->>', batch) # TODO remove print for module in self.children(): + print('----------') + print(module) batch = module(batch) + print('--passed--') return batch \ No newline at end of file From 17830516f74d297db439b659a14edf8db99506f3 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 20/95] confirmed training loop functions Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 8 ++++---- gridfm_graphkit/models/rrwp_encoder.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 2e85d4e0..4e09de91 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -213,11 +213,11 @@ def __init__(self, args): self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters - print('process--->>', batch) # TODO remove print + # print('process--->>', batch) # TODO remove print for module in self.children(): - print('----------') - print(module) + # print('----------') + # print(module) batch = module(batch) - print('--passed--') + # print('--passed--') return batch \ No newline at end of file diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 33c52157..270ca866 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -145,8 +145,8 @@ def forward(self, batch): else: # edge_index, edge_attr = add_remaining_self_loops(edge_index, 
edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) - print('xxxx', edge_attr.size(), rrwp_val.size()) - print('yyyy', edge_index.size(), rrwp_idx.size()) + # print('xxxx', edge_attr.size(), rrwp_val.size()) + # print('yyyy', edge_index.size(), rrwp_idx.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), From c75012fa845f0775f47df55eda7f2fe2c85d25d9 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 21/95] update toml Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 22 +++++++++++----------- gridfm_graphkit/models/rrwp_encoder.py | 2 +- pyproject.toml | 1 + 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 52c44c8f..30ee2139 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -11,21 +11,21 @@ data: networks: # - Texas2k_case1_2016summerpeak - case24_ieee_rts - # - case118_ieee - # - case300_ieee + - case118_ieee + - case300_ieee - case89_pegase - # - case240_pserc + - case240_pserc normalization: baseMVAnorm scenarios: # - 5000 - - 5000 - - 5000 - # - 30000 - # - 50000 - # - 50000 + - 50000 + - 50000 + - 30000 + - 50000 + - 50000 test_ratio: 0.1 val_ratio: 0.1 - workers: 4 + workers: 8 posenc_RRWP: # TODO maybe better with data section... 
enable: True ksteps: 21 @@ -36,7 +36,7 @@ model: attention_head: 8 dropout: 0.1 edge_dim: 2 - hidden_size: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + hidden_size: 116 # `gt.dim_hidden` must match `gnn.dim_inner` input_dim: 9 num_layers: 10 output_dim: 6 @@ -51,7 +51,7 @@ model: edge_encoder_bn: True gt: layer_type: GritTransformer - dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner` layer_norm: False batch_norm: True update_e: True diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 270ca866..2dadd35c 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -114,7 +114,7 @@ def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias= if self.batchnorm or self.layernorm: warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") - print('--------fc in and out:', emb_dim, out_dim) + # print('--------fc in and out:', emb_dim, out_dim) self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.pad_to_full_graph = pad_to_full_graph diff --git a/pyproject.toml b/pyproject.toml index 51c86652..2250ae90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "pyyaml", "lightning", "seaborn", + "opt-einsum", ] [project.optional-dependencies] From 3d3f98b3123defb965144e8d50868c46bfdab455 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 22/95] added forward method to transform class Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 6 +++++- gridfm_graphkit/models/grit_transformer.py | 4 ++-- pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git 
a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 8bb2b9dc..21e7841b 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -13,6 +13,7 @@ from torch_geometric.transforms import BaseTransform from torch_geometric.data import Data +from typing import Any def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. @@ -58,9 +59,12 @@ def __init__(self, pe_types, cfg): self.pe_types = pe_types self.cfg = cfg + def forward(self, data: Any) -> Any: + pass + def __call__(self, data: Data) -> Data: data = compute_posenc_stats(data, pe_types=self.pe_types, cfg=self.cfg ) - return data \ No newline at end of file + return data diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 4e09de91..7caed0fb 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -213,11 +213,11 @@ def __init__(self, args): self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters - # print('process--->>', batch) # TODO remove print + #print('process--->>', batch) # TODO remove print for module in self.children(): # print('----------') # print(module) batch = module(batch) # print('--passed--') - return batch \ No newline at end of file + return batch diff --git a/pyproject.toml b/pyproject.toml index 2250ae90..4a17ed50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ classifiers = [ dependencies = [ - "torch>2.0", + "torch==2.6", "torch-geometric", "mlflow", "nbformat", From d238e7591d4f358ac2877f838debd4ab64aa7afe Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:43 -0500 Subject: [PATCH 23/95] update readme with install instructions Signed-off-by: Thomas Tolhurst 
<99353435+ttolhurst@users.noreply.github.com> --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 78e86614..5a096dc4 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ cd gridfm-graphkit python -m venv venv source venv/bin/activate pip install -e . +pip install torch_sparse torch_scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html ``` For documentation generation and unit testing, install with the optional `dev` and `test` extras: From 17b0889e339f0378d761afe047f5b08db0802d03 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:43 -0500 Subject: [PATCH 24/95] verifed compat with GPS and GNN Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 +- gridfm_graphkit/datasets/powergrid_datamodule.py | 2 +- gridfm_graphkit/models/gnn_transformer.py | 8 +++++++- gridfm_graphkit/models/gps_transformer.py | 8 +++++++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 30ee2139..8f11c933 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -25,7 +25,7 @@ data: - 50000 test_ratio: 0.1 val_ratio: 0.1 - workers: 8 + workers: 0 posenc_RRWP: # TODO maybe better with data section... 
enable: True ksteps: 21 diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index e67956e5..4b0320f3 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -135,7 +135,7 @@ def setup(self, stage: str): transform=get_transform(args=self.args), ) - if self.args.data.posenc_RRWP.enable: + if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable: pe_transform = ComputePosencStat(pe_types=['RRWP'], cfg=self.args.data ) diff --git a/gridfm_graphkit/models/gnn_transformer.py b/gridfm_graphkit/models/gnn_transformer.py index 9e1ab231..37d36323 100644 --- a/gridfm_graphkit/models/gnn_transformer.py +++ b/gridfm_graphkit/models/gnn_transformer.py @@ -74,7 +74,7 @@ def __init__(self, args): requires_grad=False, ) - def forward(self, x, pe, edge_index, edge_attr, batch): + def forward(self, data_batch): """ Forward pass for the GPSTransformer. @@ -88,6 +88,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. """ + x=data_batch.x + pe=data_batch.pe + edge_index=data_batch.edge_index + edge_attr=data_batch.edge_attr + batch=data_batch.batch + for conv in self.layers: x = conv(x, edge_index, edge_attr) x = nn.LeakyReLU()(x) diff --git a/gridfm_graphkit/models/gps_transformer.py b/gridfm_graphkit/models/gps_transformer.py index cc8b6489..ca45c5a4 100644 --- a/gridfm_graphkit/models/gps_transformer.py +++ b/gridfm_graphkit/models/gps_transformer.py @@ -105,7 +105,7 @@ def __init__(self, args): requires_grad=False, ) - def forward(self, x, pe, edge_index, edge_attr, batch): + def forward(self, data_batch): """ Forward pass for the GPSTransformer. @@ -119,6 +119,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. 
""" + x=data_batch.x + pe=data_batch.pe + edge_index=data_batch.edge_index + edge_attr=data_batch.edge_attr + batch=data_batch.batch + x_pe = self.pe_norm(pe) x = self.encoder(x) From 091f0849ce997ff6d820b2102fb4d631faa7f789 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:43 -0500 Subject: [PATCH 25/95] work on comments and clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/rrwp.py | 19 +------ gridfm_graphkit/models/grit_layer.py | 12 ++-- gridfm_graphkit/models/grit_transformer.py | 55 +++++++------------ gridfm_graphkit/models/rrwp_encoder.py | 27 ++++----- .../tasks/feature_reconstruction_task.py | 7 --- 5 files changed, 42 insertions(+), 78 deletions(-) diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index 26218f0a..acbe1120 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -1,20 +1,7 @@ -# ------------------------ : new rwpse ---------------- -from typing import Union, Any, Optional -import numpy as np +from typing import Any, Optional import torch import torch.nn.functional as F -import torch_geometric as pyg -from torch_geometric.data import Data, HeteroData -from torch_geometric.transforms import BaseTransform -from torch_scatter import scatter, scatter_add, scatter_max - - -from torch_geometric.utils import ( - get_laplacian, - get_self_loop_attr, - to_scipy_sparse_matrix, -) -import torch_sparse +from torch_geometric.data import Data from torch_sparse import SparseTensor @@ -42,8 +29,6 @@ def add_full_rrwp(data, spd=False, **kwargs ): - device=data.edge_index.device - ind_vec = torch.eye(walk_length, dtype=torch.float, device=device) num_nodes = data.num_nodes edge_index, edge_weight = data.edge_index, data.edge_weight diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index f95ffc78..a1ffc4a3 100644 --- 
a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -8,11 +8,12 @@ import opt_einsum as oe - import warnings + def pyg_softmax(src, index, num_nodes=None): - r"""Computes a sparsely evaluated softmax. + """ + Computes a sparsely evaluated softmax. Given a value tensor :attr:`src`, this function first groups the values along the first dimension based on the indices specified in :attr:`index`, and then proceeds to compute the softmax individually for each group. @@ -23,7 +24,8 @@ def pyg_softmax(src, index, num_nodes=None): num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) - :rtype: :class:`Tensor` + Returns: + out (Tensor) """ num_nodes = maybe_num_nodes(index, num_nodes) @@ -39,7 +41,7 @@ def pyg_softmax(src, index, num_nodes=None): class MultiHeadAttentionLayerGritSparse(nn.Module): """ - Proposed Attention Computation for GRIT + Attention Computation for GRIT """ def __init__(self, in_dim, out_dim, num_heads, use_bias, @@ -140,7 +142,7 @@ def forward(self, batch): class GritTransformerLayer(nn.Module): """ - Proposed Transformer Layer for GRIT + Transformer Layer for GRIT """ def __init__(self, in_dim, out_dim, num_heads, dropout=0.0, diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 7caed0fb..a1717d18 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -12,7 +12,8 @@ class BatchNorm1dNode(torch.nn.Module): Args: dim_in (int): BatchNorm input dimension. - TODO fill in comments + eps (float): BatchNorm eps. + momentum (float): BatchNorm momentum. """ def __init__(self, dim_in, eps, momentum): super().__init__() @@ -88,7 +89,7 @@ def forward(self, batch): class GraphHead(nn.Module): """ - Prediction head for graph prediction tasks. + Prediction head for decoding tasks. Args: dim_in (int): Input dimension. dim_out (int): Output dimension. 
For binary prediction, dim_out=1. @@ -97,34 +98,18 @@ class GraphHead(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - # self.deg_scaler = False - # self.fwl = False - - # list_FC_layers = [ - # nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) - # for l in range(L)] - # list_FC_layers.append( - # nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) + self.FC_layers = nn.Sequential( nn.Linear(dim_in, dim_in), nn.LeakyReLU(), nn.Linear(dim_in, dim_out), - ) #nn.ModuleList(list_FC_layers) - # self.L = L - # self.activation = register.act_dict[cfg.gnn.act]() - # note: modified to add () in the end from original code of 'GPS' - # potentially due to the change of PyG/GraphGym version + ) def _apply_index(self, batch): return batch.graph_feature, batch.y def forward(self, batch): - # graph_emb = self.pooling_fun(batch.x, batch.batch) graph_emb = self.FC_layers(batch.x) - # for l in range(self.L): - # graph_emb = self.FC_layers[l](graph_emb) - # graph_emb = self.activation(graph_emb) - # graph_emb = self.FC_layers[self.L](graph_emb) batch.graph_feature = graph_emb pred, label = self._apply_index(batch) return pred @@ -132,9 +117,12 @@ def forward(self, batch): @MODELS_REGISTRY.register("GRIT") class GritTransformer(torch.nn.Module): - ''' - The proposed GritTransformer (Graph Inductive Bias Transformer) - ''' + """ + The GritTransformer (Graph Inductive Bias Transformer) from + Graph Inductive Biases in Transformers without Message Passing, L. Ma et al., + 2023. 
+ + """ def __init__(self, args): super().__init__() @@ -204,20 +192,19 @@ def __init__(self, args): self.layers = nn.Sequential(*layers) - # self.decoder = nn.Sequential( - # nn.Linear(dim_inner, dim_inner), - # nn.LeakyReLU(), - # nn.Linear(dim_inner, dim_out), - # ) - self.decoder = GraphHead(dim_inner, dim_out) - def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters - #print('process--->>', batch) # TODO remove print + def forward(self, batch): + """ + Forward pass for GRIT. + + Args: + batch (Batch): Pytorch Geometric Batch object, with x, y encodings, etc. + + Returns: + output (Tensor): Output node features of shape [num_nodes, output_dim]. + """ for module in self.children(): - # print('----------') - # print(module) batch = module(batch) - # print('--passed--') return batch diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 2dadd35c..1f7fd10b 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -51,10 +51,10 @@ def full_edge_index(edge_index, batch=None): class RRWPLinearNodeEncoder(torch.nn.Module): """ - FC_1(RRWP) + FC_2 (Node-attr) - note: FC_2 is given by the Typedict encoder of node-attr in some cases - Parameters: - num_classes - the number of classes for the embedding mapping to learn + FC_1(RRWP) + FC_2 (Node-attr) + note: FC_2 is given by the Typedict encoder of node-attr in some cases + Parameters: + num_classes - the number of classes for the embedding mapping to learn """ def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"): super().__init__() @@ -90,14 +90,14 @@ def forward(self, batch): class RRWPLinearEdgeEncoder(torch.nn.Module): - ''' - Merge RRWP with given edge-attr and Zero-padding to all pairs of node - FC_1(RRWP) + FC_2(edge-attr) - - FC_2 given by the TypedictEncoder in same cases - - Zero-padding for non-existing edges in fully-connected graph - - 
(optional) add node-attr as the E_{i,i}'s attr - note: assuming node-attr and edge-attr is with the same dimension after Encoders - ''' + """ + Merge RRWP with given edge-attr and Zero-padding to all pairs of node + FC_1(RRWP) + FC_2(edge-attr) + - FC_2 given by the TypedictEncoder in same cases + - Zero-padding for non-existing edges in fully-connected graph + - (optional) add node-attr as the E_{i,i}'s attr + note: assuming node-attr and edge-attr is with the same dimension after Encoders + """ def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, pad_to_full_graph=True, fill_value=0., add_node_attr_as_self_loop=False, @@ -143,10 +143,7 @@ def forward(self, batch): if self.overwrite_old_attr: out_idx, out_val = rrwp_idx, rrwp_val else: - # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) - # print('xxxx', edge_attr.size(), rrwp_val.size()) - # print('yyyy', edge_index.size(), rrwp_idx.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index da2f478f..0d1743b6 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -111,13 +111,6 @@ def on_fit_start(self): def shared_step(self, batch): output = self.forward( - # TODO update args list in the GPS Transf. 
for consistency - # x=batch.x, - # pe=batch.pe, - # edge_index=batch.edge_index, - # edge_attr=batch.edge_attr, - # batch=batch.batch, - # mask=batch.mask, batch ) From 53d564415e4a8c99f10c7882c343bf38fdadf766 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:44 -0500 Subject: [PATCH 26/95] deep copy in test method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/feature_reconstruction_task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index 0d1743b6..e6a47494 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -5,6 +5,7 @@ import numpy as np import os import pandas as pd +import copy from lightning.pytorch.loggers import MLFlowLogger from gridfm_graphkit.io.param_handler import load_model, get_loss_function @@ -162,7 +163,7 @@ def validation_step(self, batch, batch_idx): return loss_dict["loss"] def test_step(self, batch, batch_idx, dataloader_idx=0): - output, loss_dict = self.shared_step(batch) + output, loss_dict = self.shared_step(copy.deepcopy(batch)) dataset_name = self.args.data.networks[dataloader_idx] From e23c9c62fe6698a397e8bc160d41e9fcbe332eef Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:13:19 -0500 Subject: [PATCH 27/95] basic RWSE flown over Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 18 +++-- gridfm_graphkit/datasets/posenc_stats.py | 67 +++++++++++++++++++ .../datasets/powergrid_datamodule.py | 8 +++ gridfm_graphkit/models/grit_transformer.py | 2 + .../tasks/feature_reconstruction_task.py | 2 + 5 files changed, 90 insertions(+), 7 deletions(-) diff --git 
a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 8f11c933..e0cdf3e2 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -18,20 +18,24 @@ data: normalization: baseMVAnorm scenarios: # - 5000 - - 50000 - - 50000 - - 30000 - - 50000 - - 50000 + - 5000 + - 5000 + - 3000 + - 5000 + - 5000 test_ratio: 0.1 val_ratio: 0.1 workers: 0 - posenc_RRWP: # TODO maybe better with data section... - enable: True + posenc_RRWP: + enable: False ksteps: 21 add_identity: True add_node_attr: False add_inverse: False + posenc_RWSE: # TODO verify functionality + enable: True + kernel: + times: 21 model: attention_head: 8 dropout: 0.1 diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 21e7841b..0588d878 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -15,6 +15,8 @@ from torch_geometric.data import Data from typing import Any +from torch_geometric.utils.num_nodes import maybe_num_nodes + def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. Supported PE statistics to precompute in original implementation, @@ -51,9 +53,74 @@ def compute_posenc_stats(data, pe_types, cfg): ) data = transform(data) + # Random Walks. + if 'RWSE' in pe_types: + kernel_param = cfg.posenc_RWSE.kernel + if hasattr(data, 'num_nodes'): + N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa + else: + N = data.x.shape[0] # Number of nodes, including disconnected nodes. 
+ if len(kernel_param.times) == 0: + raise ValueError("List of kernel times required for RWSE") + rw_landing = get_rw_landing_probs( + ksteps=[xx + 1 for xx in range(kernel_param.times)], + edge_index=data.edge_index, + num_nodes=N + ) + data.pestat_RWSE = rw_landing + return data + +def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, + num_nodes=None, space_dim=0): + """Compute Random Walk landing probabilities for given list of K steps. + + Args: + ksteps: List of k-steps for which to compute the RW landings + edge_index: PyG sparse representation of the graph + edge_weight: (optional) Edge weights + num_nodes: (optional) Number of nodes in the graph + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the random-walk diagonal by a factor `k^(space_dim/2)`. + In euclidean space, this correction means that the height of + the gaussian distribution stays almost constant across the number of + steps, if `space_dim` is the dimension of the euclidean space. + + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + if edge_weight is None: + edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) + num_nodes = maybe_num_nodes(edge_index, num_nodes) + source, dest = edge_index[0], edge_index[1] + deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. + deg_inv = deg.pow(-1.) 
+ deg_inv.masked_fill_(deg_inv == float('inf'), 0) + + if edge_index.numel() == 0: + P = edge_index.new_zeros((1, num_nodes, num_nodes)) + else: + # P = D^-1 * A + P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) + rws = [] + if ksteps == list(range(min(ksteps), max(ksteps) + 1)): + # Efficient way if ksteps are a consecutive sequence (most of the time the case) + Pk = P.clone().detach().matrix_power(min(ksteps)) + for k in range(min(ksteps), max(ksteps) + 1): + rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + Pk = Pk @ P + else: + # Explicitly raising P to power k for each k \in ksteps. + for k in ksteps: + rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) + + return rw_landing + class ComputePosencStat(BaseTransform): def __init__(self, pe_types, cfg): self.pe_types = pe_types diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index 4b0320f3..740ff513 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -143,6 +143,14 @@ def setup(self, stage: str): dataset.transform = pe_transform else: dataset.transform = T.Compose([pe_transform, dataset.transform]) + if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable: + pe_transform = ComputePosencStat(pe_types=['RWSE'], + cfg=self.args.data + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) self.datasets.append(dataset) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index a1717d18..e264d1b5 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -204,6 +204,8 @@ def 
forward(self, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. """ + # print('xxxx',batch.x.min(), batch.x.max()) + # print('yyyyy',batch.y.min(), batch.y.max()) for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index e6a47494..aa3370c1 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -77,6 +77,8 @@ def __init__(self, args, node_normalizers, edge_normalizers): def forward(self, batch): if batch.mask is not None: + # print('xxxx',batch.x.min(), batch.x.max()) + # print('yyyyy',batch.y.min(), batch.y.max()) mask_value_expanded = self.model.mask_value.expand(batch.x.shape[0], -1) batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[batch.mask] return self.model(batch) From bfe2af0f2a447af11fc244d48c49ea2c2ebf4aa0 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:24:31 -0500 Subject: [PATCH 28/95] tested addition of RWSE Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 2 +- gridfm_graphkit/models/grit_transformer.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 0588d878..e12d9f10 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -60,7 +60,7 @@ def compute_posenc_stats(data, pe_types, cfg): N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa else: N = data.x.shape[0] # Number of nodes, including disconnected nodes. 
- if len(kernel_param.times) == 0: + if kernel_param.times == 0: raise ValueError("List of kernel times required for RWSE") rw_landing = get_rw_landing_probs( ksteps=[xx + 1 for xx in range(kernel_param.times)], diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index e264d1b5..cab5ef9f 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -206,6 +206,7 @@ def forward(self, batch): """ # print('xxxx',batch.x.min(), batch.x.max()) # print('yyyyy',batch.y.min(), batch.y.max()) + print('>>>>', batch) for module in self.children(): batch = module(batch) From c1e572181a704a86dab158a3f81a6696de80b78f Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 14:26:42 -0500 Subject: [PATCH 29/95] flow over kernel encoders Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 +- gridfm_graphkit/models/grit_transformer.py | 7 +- gridfm_graphkit/models/kernel_pos_encoder.py | 123 +++++++++++++++++++ 3 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 gridfm_graphkit/models/kernel_pos_encoder.py diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index e0cdf3e2..f8654b99 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -50,7 +50,7 @@ model: encoder: node_encoder: True edge_encoder: True - node_encoder_name: TODO + node_encoder_name: RWSE node_encoder_bn: True edge_encoder_bn: True gt: diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index cab5ef9f..19351d9b 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -69,7 +69,10 @@ def __init__( self.dim_in = dim_in if args.encoder.node_encoder: # Encode integer node features via nn.Embeddings - 
self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) + if 'RWSE' in self.node_encoder_name: + self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner) + else: + self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) if args.encoder.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) # Update dim_in to reflect the new dimension fo the node features @@ -206,7 +209,7 @@ def forward(self, batch): """ # print('xxxx',batch.x.min(), batch.x.max()) # print('yyyyy',batch.y.min(), batch.y.max()) - print('>>>>', batch) + # print('>>>>', batch) for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/models/kernel_pos_encoder.py b/gridfm_graphkit/models/kernel_pos_encoder.py new file mode 100644 index 00000000..36c4c51a --- /dev/null +++ b/gridfm_graphkit/models/kernel_pos_encoder.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.register import register_node_encoder + + +class KernelPENodeEncoder(torch.nn.Module): + """Configurable kernel-based Positional Encoding node encoder. + + The choice of which kernel-based statistics to use is configurable through + setting of `kernel_type`. Based on this, the appropriate config is selected, + and also the appropriate variable with precomputed kernel stats is then + selected from PyG Data graphs in `forward` function. + E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'. + + PE of size `dim_pe` will get appended to each node feature vector. + If `expand_x` set True, original node features will be first linearly + projected to (dim_emb - dim_pe) size and the concatenated with PE. + + Args: + dim_emb: Size of final node embedding + expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe) + """ + + kernel_type = None # Instantiated type of the KernelPE, e.g. 
RWSE + + def __init__(self, dim_emb, expand_x=True): + super().__init__() + if self.kernel_type is None: + raise ValueError(f"{self.__class__.__name__} has to be " + f"preconfigured by setting 'kernel_type' class" + f"variable before calling the constructor.") + + dim_in = cfg.share.dim_in # Expected original input node features dim + + pecfg = getattr(cfg, f"posenc_{self.kernel_type}") + dim_pe = pecfg.dim_pe # Size of the kernel-based PE embedding + num_rw_steps = len(pecfg.kernel.times) + model_type = pecfg.model.lower() # Encoder NN model type for PEs + n_layers = pecfg.layers # Num. layers in PE encoder model + norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type + self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable + + if dim_emb - dim_pe < 1: + raise ValueError(f"PE dim size {dim_pe} is too large for " + f"desired embedding size of {dim_emb}.") + + if expand_x: + self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe) + self.expand_x = expand_x + + if norm_type == 'batchnorm': + self.raw_norm = nn.BatchNorm1d(num_rw_steps) + else: + self.raw_norm = None + + activation = nn.ReLU() # register.act_dict[cfg.gnn.act] + if model_type == 'mlp': + layers = [] + if n_layers == 1: + layers.append(nn.Linear(num_rw_steps, dim_pe)) + layers.append(activation) + else: + layers.append(nn.Linear(num_rw_steps, 2 * dim_pe)) + layers.append(activation) + for _ in range(n_layers - 2): + layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe)) + layers.append(activation) + layers.append(nn.Linear(2 * dim_pe, dim_pe)) + layers.append(activation) + self.pe_encoder = nn.Sequential(*layers) + elif model_type == 'linear': + self.pe_encoder = nn.Linear(num_rw_steps, dim_pe) + else: + raise ValueError(f"{self.__class__.__name__}: Does not support " + f"'{model_type}' encoder model.") + + def forward(self, batch): + pestat_var = f"pestat_{self.kernel_type}" + if not hasattr(batch, pestat_var): + raise ValueError(f"Precomputed '{pestat_var}' variable is " 
+ f"required for {self.__class__.__name__}; set " + f"config 'posenc_{self.kernel_type}.enable' to " + f"True, and also set 'posenc.kernel.times' values") + + pos_enc = getattr(batch, pestat_var) # (Num nodes) x (Num kernel times) + # pos_enc = batch.rw_landing # (Num nodes) x (Num kernel times) + if self.raw_norm: + pos_enc = self.raw_norm(pos_enc) + pos_enc = self.pe_encoder(pos_enc) # (Num nodes) x dim_pe + + # Expand node features if needed + if self.expand_x: + h = self.linear_x(batch.x) + else: + h = batch.x + # Concatenate final PEs to input embedding + batch.x = torch.cat((h, pos_enc), 1) + # Keep PE also separate in a variable (e.g. for skip connections to input) + if self.pass_as_var: + setattr(batch, f'pe_{self.kernel_type}', pos_enc) + return batch + + +@register_node_encoder('RWSE') +class RWSENodeEncoder(KernelPENodeEncoder): + """Random Walk Structural Encoding node encoder. + """ + kernel_type = 'RWSE' + + +@register_node_encoder('HKdiagSE') +class HKdiagSENodeEncoder(KernelPENodeEncoder): + """Heat kernel (diagonal) Structural Encoding node encoder. + """ + kernel_type = 'HKdiagSE' + + +@register_node_encoder('ElstaticSE') +class ElstaticSENodeEncoder(KernelPENodeEncoder): + """Electrostatic interactions Structural Encoding node encoder. 
+ """ + kernel_type = 'ElstaticSE' From 5e096832be119ffb9092af7d0f6dcd22fcf8f54e Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 14:39:47 -0500 Subject: [PATCH 30/95] basic match of parameters for new encoder Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 4 +- gridfm_graphkit/models/grit_transformer.py | 4 +- gridfm_graphkit/models/kernel_pos_encoder.py | 55 +++----------------- 3 files changed, 13 insertions(+), 50 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index f8654b99..24349db2 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -32,10 +32,12 @@ data: add_identity: True add_node_attr: False add_inverse: False - posenc_RWSE: # TODO verify functionality + posenc_RWSE: enable: True kernel: times: 21 + pe_dim: 20 # TODO unify with model.pe_dim + raw_norm_type: batchnorm model: attention_head: 8 dropout: 0.1 diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 19351d9b..5a01c396 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -4,7 +4,7 @@ from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from gridfm_graphkit.models.grit_layer import GritTransformerLayer - +from gridfm_graphkit.models.kernel_pos_encoder import RWSENodeEncoder class BatchNorm1dNode(torch.nn.Module): @@ -70,7 +70,7 @@ def __init__( if args.encoder.node_encoder: # Encode integer node features via nn.Embeddings if 'RWSE' in self.node_encoder_name: - self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner) + self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner, args.posenc_RWSE) else: self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) if args.encoder.node_encoder_bn: diff --git 
a/gridfm_graphkit/models/kernel_pos_encoder.py b/gridfm_graphkit/models/kernel_pos_encoder.py index 36c4c51a..4b6a654f 100644 --- a/gridfm_graphkit/models/kernel_pos_encoder.py +++ b/gridfm_graphkit/models/kernel_pos_encoder.py @@ -1,7 +1,5 @@ import torch import torch.nn as nn -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.register import register_node_encoder class KernelPENodeEncoder(torch.nn.Module): @@ -24,22 +22,17 @@ class KernelPENodeEncoder(torch.nn.Module): kernel_type = None # Instantiated type of the KernelPE, e.g. RWSE - def __init__(self, dim_emb, expand_x=True): + def __init__(self, dim_in, dim_emb, pecfg, expand_x=True): super().__init__() if self.kernel_type is None: raise ValueError(f"{self.__class__.__name__} has to be " f"preconfigured by setting 'kernel_type' class" f"variable before calling the constructor.") - dim_in = cfg.share.dim_in # Expected original input node features dim - - pecfg = getattr(cfg, f"posenc_{self.kernel_type}") dim_pe = pecfg.dim_pe # Size of the kernel-based PE embedding num_rw_steps = len(pecfg.kernel.times) - model_type = pecfg.model.lower() # Encoder NN model type for PEs - n_layers = pecfg.layers # Num. 
layers in PE encoder model norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type - self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable + # self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable if dim_emb - dim_pe < 1: raise ValueError(f"PE dim size {dim_pe} is too large for " @@ -54,26 +47,8 @@ def __init__(self, dim_emb, expand_x=True): else: self.raw_norm = None - activation = nn.ReLU() # register.act_dict[cfg.gnn.act] - if model_type == 'mlp': - layers = [] - if n_layers == 1: - layers.append(nn.Linear(num_rw_steps, dim_pe)) - layers.append(activation) - else: - layers.append(nn.Linear(num_rw_steps, 2 * dim_pe)) - layers.append(activation) - for _ in range(n_layers - 2): - layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe)) - layers.append(activation) - layers.append(nn.Linear(2 * dim_pe, dim_pe)) - layers.append(activation) - self.pe_encoder = nn.Sequential(*layers) - elif model_type == 'linear': - self.pe_encoder = nn.Linear(num_rw_steps, dim_pe) - else: - raise ValueError(f"{self.__class__.__name__}: Does not support " - f"'{model_type}' encoder model.") + self.pe_encoder = nn.Linear(num_rw_steps, dim_pe) + def forward(self, batch): pestat_var = f"pestat_{self.kernel_type}" @@ -97,27 +72,13 @@ def forward(self, batch): # Concatenate final PEs to input embedding batch.x = torch.cat((h, pos_enc), 1) # Keep PE also separate in a variable (e.g. for skip connections to input) - if self.pass_as_var: - setattr(batch, f'pe_{self.kernel_type}', pos_enc) + # if self.pass_as_var: + # setattr(batch, f'pe_{self.kernel_type}', pos_enc) + return batch -@register_node_encoder('RWSE') class RWSENodeEncoder(KernelPENodeEncoder): """Random Walk Structural Encoding node encoder. """ - kernel_type = 'RWSE' - - -@register_node_encoder('HKdiagSE') -class HKdiagSENodeEncoder(KernelPENodeEncoder): - """Heat kernel (diagonal) Structural Encoding node encoder. 
- """ - kernel_type = 'HKdiagSE' - - -@register_node_encoder('ElstaticSE') -class ElstaticSENodeEncoder(KernelPENodeEncoder): - """Electrostatic interactions Structural Encoding node encoder. - """ - kernel_type = 'ElstaticSE' + kernel_type = 'RWSE' \ No newline at end of file From 1378770b6d31ceb7644f0cea64801c0dde25f90c Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 14:48:54 -0500 Subject: [PATCH 31/95] tested functionality of new encoding Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 15 +++++++++------ gridfm_graphkit/models/grit_transformer.py | 4 ++-- gridfm_graphkit/models/kernel_pos_encoder.py | 4 ++-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 24349db2..bf56dffa 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -25,7 +25,7 @@ data: - 5000 test_ratio: 0.1 val_ratio: 0.1 - workers: 0 + workers: 4 posenc_RRWP: enable: False ksteps: 21 @@ -33,11 +33,9 @@ data: add_node_attr: False add_inverse: False posenc_RWSE: - enable: True - kernel: - times: 21 - pe_dim: 20 # TODO unify with model.pe_dim - raw_norm_type: batchnorm + enable: True + kernel: + times: 21 # TODO unify with model model: attention_head: 8 dropout: 0.1 @@ -55,6 +53,11 @@ model: node_encoder_name: RWSE node_encoder_bn: True edge_encoder_bn: True + posenc_RWSE: + kernel: + times: 21 + pe_dim: 20 # TODO unify with model.pe_dim + raw_norm_type: batchnorm gt: layer_type: GritTransformer dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner` diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 5a01c396..ab0d51bd 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -69,8 +69,8 @@ def __init__( 
self.dim_in = dim_in if args.encoder.node_encoder: # Encode integer node features via nn.Embeddings - if 'RWSE' in self.node_encoder_name: - self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner, args.posenc_RWSE) + if 'RWSE' in args.encoder.node_encoder_name: + self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner, args.encoder.posenc_RWSE) else: self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) if args.encoder.node_encoder_bn: diff --git a/gridfm_graphkit/models/kernel_pos_encoder.py b/gridfm_graphkit/models/kernel_pos_encoder.py index 4b6a654f..b24078de 100644 --- a/gridfm_graphkit/models/kernel_pos_encoder.py +++ b/gridfm_graphkit/models/kernel_pos_encoder.py @@ -29,8 +29,8 @@ def __init__(self, dim_in, dim_emb, pecfg, expand_x=True): f"preconfigured by setting 'kernel_type' class" f"variable before calling the constructor.") - dim_pe = pecfg.dim_pe # Size of the kernel-based PE embedding - num_rw_steps = len(pecfg.kernel.times) + dim_pe = pecfg.pe_dim # Size of the kernel-based PE embedding + num_rw_steps = pecfg.kernel.times norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type # self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable From 2eb3a10b8ea8731f58af2b4d496de03e733dbf2b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:26:09 -0400 Subject: [PATCH 32/95] settle final merge conflicts Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- .../datasets/hetero_powergrid_datamodule.py | 8 + .../datasets/powergrid_datamodule.py | 230 ----------- .../tasks/feature_reconstruction_task.py | 356 ------------------ 3 files changed, 8 insertions(+), 586 deletions(-) delete mode 100644 gridfm_graphkit/datasets/powergrid_datamodule.py delete mode 100644 gridfm_graphkit/tasks/feature_reconstruction_task.py diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py 
b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 4474c2e4..e6a4cfd1 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -163,6 +163,14 @@ def setup(self, stage: str): dataset.transform = pe_transform else: dataset.transform = T.Compose([pe_transform, dataset.transform]) + if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable: + pe_transform = ComputePosencStat(pe_types=['RWSE'], + cfg=self.args.data + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) self.datasets.append(dataset) diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py deleted file mode 100644 index 740ff513..00000000 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ /dev/null @@ -1,230 +0,0 @@ -import torch -from torch_geometric.loader import DataLoader -from torch.utils.data import ConcatDataset -from torch.utils.data import Subset -import torch.distributed as dist -from gridfm_graphkit.io.param_handler import ( - NestedNamespace, - load_normalizer, - get_transform, -) -from gridfm_graphkit.datasets.utils import split_dataset -from gridfm_graphkit.datasets.powergrid_dataset import GridDatasetDisk - -from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat - -import torch_geometric.transforms as T - -import numpy as np -import random -import warnings -import os -import lightning as L - - -class LitGridDataModule(L.LightningDataModule): - """ - PyTorch Lightning DataModule for power grid datasets. - - This datamodule handles loading, preprocessing, splitting, and batching - of power grid graph datasets (`GridDatasetDisk`) for training, validation, - testing, and prediction. It ensures reproducibility through fixed seeds. - - Args: - args (NestedNamespace): Experiment configuration. 
- data_dir (str, optional): Root directory for datasets. Defaults to "./data". - - Attributes: - batch_size (int): Batch size for all dataloaders. From ``args.training.batch_size`` - node_normalizers (list): List of node feature normalizers, one per dataset. - edge_normalizers (list): List of edge feature normalizers, one per dataset. - datasets (list): Original datasets for each network. - train_datasets (list): Train splits for each network. - val_datasets (list): Validation splits for each network. - test_datasets (list): Test splits for each network. - train_dataset_multi (ConcatDataset): Concatenated train datasets for multi-network training. - val_dataset_multi (ConcatDataset): Concatenated validation datasets for multi-network validation. - _is_setup_done (bool): Tracks whether `setup` has been executed to avoid repeated processing. - - Methods: - setup(stage): - Load and preprocess datasets, split into train/val/test, and store normalizers. - Handles distributed preprocessing safely. - train_dataloader(): - Returns a DataLoader for concatenated training datasets. - val_dataloader(): - Returns a DataLoader for concatenated validation datasets. - test_dataloader(): - Returns a list of DataLoaders, one per test dataset. - predict_dataloader(): - Returns a list of DataLoaders, one per test dataset for prediction. - - Notes: - - Preprocessing is only performed on rank 0 in distributed settings. - - Subsets and splits are deterministic based on the provided random seed. - - Normalizers are loaded for each network independently. - - Test and predict dataloaders are returned as lists, one per dataset. 
- - Example: - ```python - from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule - from gridfm_graphkit.io.param_handler import NestedNamespace - import yaml - - with open("config/config.yaml") as f: - base_config = yaml.safe_load(f) - args = NestedNamespace(**base_config) - - datamodule = LitGridDataModule(args, data_dir="./data") - - datamodule.setup("fit") - train_loader = datamodule.train_dataloader() - ``` - """ - - def __init__(self, args: NestedNamespace, data_dir: str = "./data"): - super().__init__() - self.data_dir = data_dir - self.batch_size = int(args.training.batch_size) - self.args = args - self.node_normalizers = [] - self.edge_normalizers = [] - self.datasets = [] - self.train_datasets = [] - self.val_datasets = [] - self.test_datasets = [] - self._is_setup_done = False - - def setup(self, stage: str): - if self._is_setup_done: - print(f"Setup already done for stage={stage}, skipping...") - return - - for i, network in enumerate(self.args.data.networks): - node_normalizer, edge_normalizer = load_normalizer(args=self.args) - self.node_normalizers.append(node_normalizer) - self.edge_normalizers.append(edge_normalizer) - - # Create torch dataset and split - data_path_network = os.path.join(self.data_dir, network) - - # Run preprocessing only on rank 0 - if dist.is_available() and dist.is_initialized() and dist.get_rank() == 0: - print(f"Pre-processing of {network} dataset on rank 0") - _ = GridDatasetDisk( # just to trigger processing - root=data_path_network, - norm_method=self.args.data.normalization, - node_normalizer=node_normalizer, - edge_normalizer=edge_normalizer, - pe_dim=self.args.model.pe_dim, - mask_dim=self.args.data.mask_dim, - transform=get_transform(args=self.args), - ) - - # All ranks wait here until processing is done - if torch.distributed.is_available() and torch.distributed.is_initialized(): - torch.distributed.barrier() - - dataset = GridDatasetDisk( - root=data_path_network, - 
norm_method=self.args.data.normalization, - node_normalizer=node_normalizer, - edge_normalizer=edge_normalizer, - pe_dim=self.args.model.pe_dim, - mask_dim=self.args.data.mask_dim, - transform=get_transform(args=self.args), - ) - - if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable: - pe_transform = ComputePosencStat(pe_types=['RRWP'], - cfg=self.args.data - ) - if dataset.transform is None: - dataset.transform = pe_transform - else: - dataset.transform = T.Compose([pe_transform, dataset.transform]) - if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable: - pe_transform = ComputePosencStat(pe_types=['RWSE'], - cfg=self.args.data - ) - if dataset.transform is None: - dataset.transform = pe_transform - else: - dataset.transform = T.Compose([pe_transform, dataset.transform]) - - self.datasets.append(dataset) - - num_scenarios = self.args.data.scenarios[i] - if num_scenarios > len(dataset): - warnings.warn( - f"Requested number of scenarios ({num_scenarios}) exceeds dataset size ({len(dataset)}). 
" - "Using the full dataset instead.", - ) - num_scenarios = len(dataset) - - # Create a subset - all_indices = list(range(len(dataset))) - # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order - random.seed(self.args.seed) - random.shuffle(all_indices) - subset_indices = all_indices[:num_scenarios] - dataset = Subset(dataset, subset_indices) - - # Random seed set before every split, same as above - np.random.seed(self.args.seed) - train_dataset, val_dataset, test_dataset = split_dataset( - dataset, - self.data_dir, - self.args.data.val_ratio, - self.args.data.test_ratio, - ) - - self.train_datasets.append(train_dataset) - self.val_datasets.append(val_dataset) - self.test_datasets.append(test_dataset) - - self.train_dataset_multi = ConcatDataset(self.train_datasets) - self.val_dataset_multi = ConcatDataset(self.val_datasets) - self._is_setup_done = True - - def train_dataloader(self): - return DataLoader( - self.train_dataset_multi, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.args.data.workers, - pin_memory=True, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset_multi, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, - ) - - def test_dataloader(self): - return [ - DataLoader( - i, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, - ) - for i in self.test_datasets - ] - - def predict_dataloader(self): - return [ - DataLoader( - i, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, - ) - for i in self.test_datasets - ] diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py deleted file mode 100644 index aa3370c1..00000000 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ /dev/null @@ -1,356 +0,0 @@ -import torch 
-from torch.optim.lr_scheduler import ReduceLROnPlateau -import lightning as L -from pytorch_lightning.utilities import rank_zero_only -import numpy as np -import os -import pandas as pd -import copy - -from lightning.pytorch.loggers import MLFlowLogger -from gridfm_graphkit.io.param_handler import load_model, get_loss_function -import torch.nn.functional as F -from gridfm_graphkit.datasets.globals import PQ, PV, REF, PD, QD, PG, QG, VM, VA - - -class FeatureReconstructionTask(L.LightningModule): - """ - PyTorch Lightning task for node feature reconstruction on power grid graphs. - - This task wraps a GridFM model inside a LightningModule and defines the full - training, validation, testing, and prediction logic. It is designed to - reconstruct masked node features from graph-structured input data, using - datasets and normalizers provided by `gridfm-graphkit`. - - Args: - args (NestedNamespace): Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, etc. - node_normalizers (list): One normalizer per dataset to (de)normalize node features. - edge_normalizers (list): One normalizer per dataset to (de)normalize edge features. - - Attributes: - model (torch.nn.Module): model loaded via `load_model`. - loss_fn (callable): Loss function resolved from configuration. - batch_size (int): Training batch size. From ``args.training.batch_size`` - node_normalizers (list): Dataset-wise node feature normalizers. - edge_normalizers (list): Dataset-wise edge feature normalizers. - - Methods: - forward(x, pe, edge_index, edge_attr, batch, mask=None): - Forward pass with optional feature masking. - training_step(batch): - One training step: computes loss, logs metrics, returns loss. - validation_step(batch, batch_idx): - One validation step: computes losses and logs metrics. - test_step(batch, batch_idx, dataloader_idx=0): - Evaluate on test data, compute per-node-type MSEs, and log per-dataset metrics. 
- predict_step(batch, batch_idx, dataloader_idx=0): - Run inference and return denormalized outputs + node masks. - configure_optimizers(): - Setup Adam optimizer and ReduceLROnPlateau scheduler. - on_fit_start(): - Save normalization statistics at the beginning of training. - on_test_end(): - Collect test metrics across datasets and export summary CSV reports. - - Notes: - - Node types are distinguished using the global constants (`PQ`, `PV`, `REF`). - - The datamodule must provide `batch.mask` for masking node features. - - Test metrics include per-node-type RMSE for [Pd, Qd, Pg, Qg, Vm, Va]. - - Reports are saved under `/test/.csv`. - - Example: - ```python - model = FeatureReconstructionTask(args, node_normalizers, edge_normalizers) - output = model(batch.x, batch.pe, batch.edge_index, batch.edge_attr, batch.batch) - ``` - """ - - def __init__(self, args, node_normalizers, edge_normalizers): - super().__init__() - self.model = load_model(args=args) - self.args = args - self.loss_fn = get_loss_function(args) - self.batch_size = int(args.training.batch_size) - self.node_normalizers = node_normalizers - self.edge_normalizers = edge_normalizers - self.save_hyperparameters() - - def forward(self, batch): - if batch.mask is not None: - # print('xxxx',batch.x.min(), batch.x.max()) - # print('yyyyy',batch.y.min(), batch.y.max()) - mask_value_expanded = self.model.mask_value.expand(batch.x.shape[0], -1) - batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[batch.mask] - return self.model(batch) - - @rank_zero_only - def on_fit_start(self): - # Determine save path - if isinstance(self.logger, MLFlowLogger): - log_dir = os.path.join( - self.logger.save_dir, - self.logger.experiment_id, - self.logger.run_id, - "artifacts", - "stats", - ) - else: - log_dir = os.path.join(self.logger.save_dir, "stats") - - os.makedirs(log_dir, exist_ok=True) - log_stats_path = os.path.join(log_dir, "normalization_stats.txt") - - # Collect normalization stats - with 
open(log_stats_path, "w") as log_file: - for i, normalizer in enumerate(self.node_normalizers): - log_file.write( - f"Node Normalizer {self.args.data.networks[i]} stats:\n{normalizer.get_stats()}\n\n", - ) - - for i, normalizer in enumerate(self.edge_normalizers): - log_file.write( - f"Edge Normalizer {self.args.data.networks[i]} stats:\n{normalizer.get_stats()}\n\n", - ) - - def shared_step(self, batch): - output = self.forward( - batch - ) - - loss_dict = self.loss_fn( - output, - batch.y, - batch.edge_index, - batch.edge_attr, - batch.mask, - ) - return output, loss_dict - - def training_step(self, batch): - _, loss_dict = self.shared_step(batch) - current_lr = self.optimizer.param_groups[0]["lr"] - metrics = {} - metrics["Training Loss"] = loss_dict["loss"].detach() - metrics["Learning Rate"] = current_lr - for metric, value in metrics.items(): - self.log( - metric, - value, - batch_size=batch.num_graphs, - sync_dist=True, - on_epoch=True, - prog_bar=True, - logger=True, - on_step=False, - ) - - return loss_dict["loss"] - - def validation_step(self, batch, batch_idx): - _, loss_dict = self.shared_step(batch) - loss_dict["loss"] = loss_dict["loss"].detach() - for metric, value in loss_dict.items(): - metric_name = f"Validation {metric}" - self.log( - metric_name, - value, - batch_size=batch.num_graphs, - sync_dist=True, - on_epoch=True, - prog_bar=True, - logger=True, - on_step=False, - ) - - return loss_dict["loss"] - - def test_step(self, batch, batch_idx, dataloader_idx=0): - output, loss_dict = self.shared_step(copy.deepcopy(batch)) - - dataset_name = self.args.data.networks[dataloader_idx] - - output_denorm = self.node_normalizers[dataloader_idx].inverse_transform(output) - target_denorm = self.node_normalizers[dataloader_idx].inverse_transform(batch.y) - - mask_PQ = batch.x[:, PQ] == 1 - mask_PV = batch.x[:, PV] == 1 - mask_REF = batch.x[:, REF] == 1 - - mse_PQ = F.mse_loss( - output_denorm[mask_PQ], - target_denorm[mask_PQ], - reduction="none", - ) - 
mse_PV = F.mse_loss( - output_denorm[mask_PV], - target_denorm[mask_PV], - reduction="none", - ) - mse_REF = F.mse_loss( - output_denorm[mask_REF], - target_denorm[mask_REF], - reduction="none", - ) - - mse_PQ = mse_PQ.mean(dim=0) - mse_PV = mse_PV.mean(dim=0) - mse_REF = mse_REF.mean(dim=0) - - loss_dict["MSE PQ nodes - PD"] = mse_PQ[PD] - loss_dict["MSE PV nodes - PD"] = mse_PV[PD] - loss_dict["MSE REF nodes - PD"] = mse_REF[PD] - - loss_dict["MSE PQ nodes - QD"] = mse_PQ[QD] - loss_dict["MSE PV nodes - QD"] = mse_PV[QD] - loss_dict["MSE REF nodes - QD"] = mse_REF[QD] - - loss_dict["MSE PQ nodes - PG"] = mse_PQ[PG] - loss_dict["MSE PV nodes - PG"] = mse_PV[PG] - loss_dict["MSE REF nodes - PG"] = mse_REF[PG] - - loss_dict["MSE PQ nodes - QG"] = mse_PQ[QG] - loss_dict["MSE PV nodes - QG"] = mse_PV[QG] - loss_dict["MSE REF nodes - QG"] = mse_REF[QG] - - loss_dict["MSE PQ nodes - VM"] = mse_PQ[VM] - loss_dict["MSE PV nodes - VM"] = mse_PV[VM] - loss_dict["MSE REF nodes - VM"] = mse_REF[VM] - - loss_dict["MSE PQ nodes - VA"] = mse_PQ[VA] - loss_dict["MSE PV nodes - VA"] = mse_PV[VA] - loss_dict["MSE REF nodes - VA"] = mse_REF[VA] - - loss_dict["Test loss"] = loss_dict.pop("loss").detach() - for metric, value in loss_dict.items(): - metric_name = f"{dataset_name}/{metric}" - if "p.u." in metric: - # Denormalize metrics expressed in p.u. 
- value *= self.node_normalizers[dataloader_idx].baseMVA - metric_name = metric_name.replace("in p.u.", "").strip() - self.log( - metric_name, - value, - batch_size=batch.num_graphs, - add_dataloader_idx=False, - sync_dist=True, - logger=False, - ) - return - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output, _ = self.shared_step(batch) - output_denorm = self.node_normalizers[dataloader_idx].inverse_transform(output) - - # Count buses and generate per-node scenario_id - bus_counts = batch.batch.unique(return_counts=True)[1] - scenario_ids = batch.scenario_id # shape: [num_graphs] - scenario_per_node = torch.cat( - [ - torch.full((count,), sid, dtype=torch.int32) - for count, sid in zip(bus_counts, scenario_ids) - ], - ) - - bus_numbers = np.concatenate([np.arange(count.item()) for count in bus_counts]) - - return { - "output": output_denorm.cpu().numpy(), - "scenario_id": scenario_per_node, - "bus_number": bus_numbers, - } - - @rank_zero_only - def on_test_end(self): - if isinstance(self.logger, MLFlowLogger): - artifact_dir = os.path.join( - self.logger.save_dir, - self.logger.experiment_id, - self.logger.run_id, - "artifacts", - ) - else: - artifact_dir = self.logger.save_dir - - final_metrics = self.trainer.callback_metrics - grouped_metrics = {} - - for full_key, value in final_metrics.items(): - try: - value = value.item() - except AttributeError: - pass - - if "/" in full_key: - dataset_name, metric = full_key.split("/", 1) - if dataset_name not in grouped_metrics: - grouped_metrics[dataset_name] = {} - grouped_metrics[dataset_name][metric] = value - - for dataset, metrics in grouped_metrics.items(): - rmse_PQ = [ - metrics.get(f"MSE PQ nodes - {label}", float("nan")) ** 0.5 - for label in ["PD", "QD", "PG", "QG", "VM", "VA"] - ] - rmse_PV = [ - metrics.get(f"MSE PV nodes - {label}", float("nan")) ** 0.5 - for label in ["PD", "QD", "PG", "QG", "VM", "VA"] - ] - rmse_REF = [ - metrics.get(f"MSE REF nodes - {label}", float("nan")) ** 0.5 - 
for label in ["PD", "QD", "PG", "QG", "VM", "VA"] - ] - - avg_active_res = metrics.get("Active Power Loss", " ") - avg_reactive_res = metrics.get("Reactive Power Loss", " ") - - data = { - "Metric": [ - "RMSE-PQ", - "RMSE-PV", - "RMSE-REF", - "Avg. active res. (MW)", - "Avg. reactive res. (MVar)", - ], - "Pd (MW)": [ - rmse_PQ[0], - rmse_PV[0], - rmse_REF[0], - avg_active_res, - avg_reactive_res, - ], - "Qd (MVar)": [rmse_PQ[1], rmse_PV[1], rmse_REF[1], " ", " "], - "Pg (MW)": [rmse_PQ[2], rmse_PV[2], rmse_REF[2], " ", " "], - "Qg (MVar)": [rmse_PQ[3], rmse_PV[3], rmse_REF[3], " ", " "], - "Vm (p.u.)": [rmse_PQ[4], rmse_PV[4], rmse_REF[4], " ", " "], - "Va (degree)": [rmse_PQ[5], rmse_PV[5], rmse_REF[5], " ", " "], - } - - df = pd.DataFrame(data) - - test_dir = os.path.join(artifact_dir, "test") - os.makedirs(test_dir, exist_ok=True) - csv_path = os.path.join(test_dir, f"{dataset}.csv") - df.to_csv(csv_path, index=False) - - def configure_optimizers(self): - self.optimizer = torch.optim.Adam( - self.model.parameters(), - lr=self.args.optimizer.learning_rate, - betas=(self.args.optimizer.beta1, self.args.optimizer.beta2), - ) - - self.scheduler = ReduceLROnPlateau( - self.optimizer, - mode="min", - factor=self.args.optimizer.lr_decay, - patience=self.args.optimizer.lr_patience, - ) - config_optim = { - "optimizer": self.optimizer, - "lr_scheduler": { - "scheduler": self.scheduler, - "monitor": "Validation loss", - "reduce_on_plateau": True, - }, - } - return config_optim From 6d89bba750c488b7e5e61561f4f491ed95b60da3 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:03:39 -0400 Subject: [PATCH 33/95] connect grit and encoders with hetero-adapter Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 33 ++++++- .../models/gnn_heterogeneous_gns.py | 18 ++-- gridfm_graphkit/models/grit_transformer.py | 89 ++++++++++++++++++- 
gridfm_graphkit/tasks/reconstruction_tasks.py | 11 +-- 4 files changed, 134 insertions(+), 17 deletions(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index e12d9f10..5263b488 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -12,7 +12,7 @@ from gridfm_graphkit.datasets.rrwp import add_full_rrwp from torch_geometric.transforms import BaseTransform -from torch_geometric.data import Data +from torch_geometric.data import Data, HeteroData from typing import Any from torch_geometric.utils.num_nodes import maybe_num_nodes @@ -129,9 +129,38 @@ def __init__(self, pe_types, cfg): def forward(self, data: Any) -> Any: pass - def __call__(self, data: Data) -> Data: + def __call__(self, data) -> Data: + if isinstance(data, HeteroData): + return self._call_hetero(data) + data = compute_posenc_stats(data, pe_types=self.pe_types, cfg=self.cfg ) return data + + def _call_hetero(self, data: HeteroData) -> HeteroData: + """Compute PE on the bus-only subgraph and store results on data['bus'].""" + bus_data = Data( + x=data["bus"].x, + edge_index=data["bus", "connects", "bus"].edge_index, + num_nodes=data["bus"].num_nodes, + ) + if hasattr(data["bus", "connects", "bus"], "edge_weight"): + bus_data.edge_weight = data["bus", "connects", "bus"].edge_weight + + bus_data = compute_posenc_stats( + bus_data, pe_types=self.pe_types, cfg=self.cfg, + ) + + # Copy computed PE attributes back onto the HeteroData bus store + pe_attrs = [ + "pestat_RWSE", # RWSE + "rrwp", "rrwp_index", "rrwp_val", # RRWP + "log_deg", "deg", # degree info from RRWP + ] + for attr in pe_attrs: + if hasattr(bus_data, attr): + data["bus"][attr] = getattr(bus_data, attr) + + return data diff --git a/gridfm_graphkit/models/gnn_heterogeneous_gns.py b/gridfm_graphkit/models/gnn_heterogeneous_gns.py index 10735603..e274a1c0 100644 --- a/gridfm_graphkit/models/gnn_heterogeneous_gns.py +++ 
b/gridfm_graphkit/models/gnn_heterogeneous_gns.py @@ -146,14 +146,20 @@ def __init__(self, args) -> None: # container for monitoring residual norms per layer and type self.layer_residuals = {} - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): + def forward(self, batch): """ - x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} - edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") - edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) - batch_dict: dict mapping node types to batch tensors (if using batching). Not used heavily here but kept for API parity. - mask: optional mask per node (applies when computing residuals) + Accepts a PyG HeteroData batch and extracts the required tensors. + + batch: HeteroData/Batch containing: + x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} + edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") + edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) + mask_dict: dict mapping node/bus types to mask tensors """ + x_dict = batch.x_dict + edge_index_dict = batch.edge_index_dict + edge_attr_dict = batch.edge_attr_dict + mask_dict = batch.mask_dict self.layer_residuals = {} diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index ab0d51bd..17b15398 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,6 +1,7 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY import torch from torch import nn +from torch_geometric.data import Data from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from gridfm_graphkit.models.grit_layer import GritTransformerLayer @@ -118,7 +119,6 @@ def forward(self, batch): return pred -@MODELS_REGISTRY.register("GRIT") class 
GritTransformer(torch.nn.Module): """ The GritTransformer (Graph Inductive Bias Transformer) from @@ -214,3 +214,90 @@ def forward(self, batch): batch = module(batch) return batch + + +@MODELS_REGISTRY.register("GRIT") +class GritHeteroAdapter(torch.nn.Module): + """Adapter that enables the homogeneous GRIT transformer to operate on + heterogeneous power-grid graphs. + + Extracts the bus-only homogeneous subgraph using PyG's native HeteroData + accessors, runs it through the GRIT encoder and transformer layers, and + produces per-node-type predictions. Generator output comes from a + lightweight standalone MLP (generators are not seen by the transformer). + + Returns: + dict: ``{"bus": Tensor[num_bus, output_bus_dim], + "gen": Tensor[num_gen, output_gen_dim]}`` + """ + + def __init__(self, args): + super().__init__() + + dim_inner = args.model.hidden_size + output_bus_dim = args.model.output_bus_dim + output_gen_dim = args.model.output_gen_dim + input_gen_dim = args.model.input_gen_dim + + # Ensure config keys expected by GritTransformer are present. + # input_dim = bus feature dimension (used by FeatureEncoder) + # output_dim = bus output dimension (used by the unused GraphHead) + if not hasattr(args.model, "input_dim"): + args.model.input_dim = args.model.input_bus_dim + if not hasattr(args.model, "output_dim"): + args.model.output_dim = output_bus_dim + + # The original homogeneous GRIT + # (encoder + optional PE encoders + transformer layers + GraphHead) + self.grit = GritTransformer(args) + + # Per-node-type output heads (replace GraphHead for hetero output) + self.bus_head = nn.Sequential( + nn.Linear(dim_inner, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, output_bus_dim), + ) + self.gen_head = nn.Sequential( + nn.Linear(input_gen_dim, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, output_gen_dim), + ) + + def forward(self, batch): + """Forward pass on a heterogeneous power-grid batch. 
+ + Args: + batch: A batched ``HeteroData`` with node types ``"bus"`` and + ``"gen"``, and edge type ``("bus", "connects", "bus")``. + + Returns: + dict with keys ``"bus"`` and ``"gen"``, each mapping to the + predicted output features. + """ + # --- Extract bus-only homogeneous subgraph --- + homo = Data( + x=batch["bus"].x, + y=batch["bus"].y, + edge_index=batch["bus", "connects", "bus"].edge_index, + edge_attr=batch["bus", "connects", "bus"].edge_attr, + batch=batch["bus"].batch, + ) + + # Forward positional-encoding attributes if present + for attr in ("pestat_RWSE", "rrwp", "rrwp_index", "rrwp_val", "log_deg", "deg"): + if hasattr(batch["bus"], attr): + setattr(homo, attr, getattr(batch["bus"], attr)) + + # --- Run GRIT encoder + PE encoders + transformer layers --- + homo = self.grit.encoder(homo) + if hasattr(self.grit, "rrwp_abs_encoder"): + homo = self.grit.rrwp_abs_encoder(homo) + if hasattr(self.grit, "rrwp_rel_encoder"): + homo = self.grit.rrwp_rel_encoder(homo) + homo = self.grit.layers(homo) + + # --- Per-type decoding --- + bus_out = self.bus_head(homo.x) + gen_out = self.gen_head(batch["gen"].x) + + return {"bus": bus_out, "gen": gen_out} diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 8742646b..43e243c0 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -39,16 +39,11 @@ def __init__(self, args, data_normalizers): self.batch_size = int(args.training.batch_size) self.test_outputs = {i: [] for i in range(len(args.data.networks))} - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): - return self.model(x_dict, edge_index_dict, edge_attr_dict, mask_dict) + def forward(self, batch): + return self.model(batch) def shared_step(self, batch): - output = self.forward( - x_dict=batch.x_dict, - edge_index_dict=batch.edge_index_dict, - edge_attr_dict=batch.edge_attr_dict, - mask_dict=batch.mask_dict, - ) + output = 
self.forward(batch) loss_dict = self.loss_fn( output, From 926dff5f89e5de06ef6c0bd180dfe75ad03f119b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:47:03 -0400 Subject: [PATCH 34/95] flow over and update PBE loss Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 116 +++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index d253d2b3..02df3bc7 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from gridfm_graphkit.io.registries import LOSS_REGISTRY from torch_scatter import scatter_add +from torch_geometric.utils import to_torch_coo_tensor from gridfm_graphkit.datasets.globals import ( # Bus feature indices @@ -19,6 +20,11 @@ PG_OUT, # Generator feature indices PG_H, + # Edge feature indices + YFF_TT_R, + YFF_TT_I, + YFT_TF_R, + YFT_TF_I, ) @@ -322,3 +328,113 @@ def forward( f"MSE loss {self.dim}": mse_loss.detach(), f"MAE loss {self.dim}": mae_loss.detach(), } + +@LOSS_REGISTRY.register("PBE") +class PBELoss(BaseLoss): + """ + Loss based on the Power Balance Equations. + + Adapted for the heterogeneous graph convention: predictions and targets + are passed as dicts (``{"bus": …, "gen": …}``). Generator active power + is aggregated onto bus nodes via the ``(gen, connected_to, bus)`` edge + index before computing the power balance. 
+ """ + + def __init__(self, loss_args, args): + super(PBELoss, self).__init__() + self.visualization = getattr(loss_args, "visualization", False) + + def forward( + self, + pred_dict, + target_dict, + edge_index_dict, + edge_attr_dict, + mask_dict, + model=None, + ): + pred_bus = pred_dict["bus"] # [N_bus, output_bus_dim] + target_bus = target_dict["bus"] # [N_bus, bus_feat_dim] + num_bus = target_bus.size(0) + + bus_edge_index = edge_index_dict[("bus", "connects", "bus")] + bus_edge_attr = edge_attr_dict[("bus", "connects", "bus")] + mask_bus = mask_dict["bus"] + + # --- Voltage: use prediction where masked, target where known --- + Vm_pred = pred_bus[:, VM_OUT] + Va_pred = pred_bus[:, VA_OUT] + Vm_target = target_bus[:, VM_H] + Va_target = target_bus[:, VA_H] + + mask_Vm = mask_bus[:, VM_H] + mask_Va = mask_bus[:, VA_H] + + V_m = torch.where(mask_Vm, Vm_pred, Vm_target) + V_a = torch.where(mask_Va, Va_pred, Va_target) + + # Complex voltage + V = V_m * torch.exp(1j * V_a) + V_conj = torch.conj(V) + + # --- Admittance matrix from bus-bus edge attrs --- + # Use Yff (diagonal-block) real/imag as the admittance entries + edge_complex = bus_edge_attr[:, YFF_TT_R] + 1j * bus_edge_attr[:, YFF_TT_I] + + Y_bus_sparse = to_torch_coo_tensor( + bus_edge_index, + edge_complex, + size=(num_bus, num_bus), + ) + Y_bus_conj = torch.conj(Y_bus_sparse) + + # Complex power injection: S_inj = diag(V) * conj(Y) * conj(V) + S_injection = torch.diag(V) @ Y_bus_conj @ V_conj + + # --- Net power from predictions/targets --- + # Pg: aggregate generator predictions onto buses + gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] + Pg_per_bus = scatter_add( + pred_dict["gen"].squeeze(-1), + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + + Pd = target_bus[:, PD_H] + Qd = target_bus[:, QD_H] + + # Qg: use prediction if the model predicts it, else use target + if pred_bus.size(1) > QG_OUT: + Qg = torch.where(mask_bus[:, QG_H], pred_bus[:, QG_OUT], target_bus[:, QG_H]) + else: 
+ Qg = target_bus[:, QG_H] + + net_P = Pg_per_bus - Pd + net_Q = Qg - Qd + S_net = net_P + 1j * net_Q + + # --- Loss --- + loss = torch.mean(torch.abs(S_net - S_injection)) + + real_loss = torch.mean( + torch.abs(torch.real(S_net - S_injection)), + ) + imag_loss = torch.mean( + torch.abs(torch.imag(S_net - S_injection)), + ) + + result = { + "loss": loss, + "Power loss in p.u.": loss.detach(), + "Active Power Loss in p.u.": real_loss.detach(), + "Reactive Power Loss in p.u.": imag_loss.detach(), + } + if self.visualization: + result["Nodal Active Power Loss in p.u."] = torch.abs( + torch.real(S_net - S_injection), + ) + result["Nodal Reactive Power Loss in p.u."] = torch.abs( + torch.imag(S_net - S_injection), + ) + return result From 9bcb0d1f73dede61aeeead8f0b22967a2647aaa4 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:47:47 -0400 Subject: [PATCH 35/95] added sample config Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 94 +++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 examples/config/GRIT_PF_datakit_case14.yaml diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml new file mode 100644 index 00000000..d119e13d --- /dev/null +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -0,0 +1,94 @@ +callbacks: + patience: 100 + tol: 0 +task: + task_name: PowerFlow +data: + baseMVA: 100 + mask_value: 0.0 + normalization: HeteroDataMVANormalizer + networks: + - case14_ieee + scenarios: + - 5000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 + posenc_RRWP: + enable: false + ksteps: 21 + posenc_RWSE: + enable: true + kernel: + times: 21 +model: + attention_head: 8 + dropout: 0.1 + # edge_dim must match the bus-bus edge feature count after transforms + # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) + edge_dim: 10 + 
hidden_size: 116 + # input_dim = bus feature count (used by GRIT core FeatureEncoder) + input_dim: 15 + # Hetero adapter head dimensions + input_bus_dim: 15 + input_gen_dim: 6 + output_bus_dim: 2 + output_gen_dim: 1 + num_layers: 10 + type: GRIT + act: relu + encoder: + node_encoder: true + edge_encoder: true + node_encoder_name: RWSE + node_encoder_bn: true + edge_encoder_bn: true + posenc_RWSE: + kernel: + times: 21 + pe_dim: 20 + raw_norm_type: batchnorm + gt: + layer_type: GritTransformer + dim_hidden: 116 # must match hidden_size + layer_norm: false + batch_norm: true + update_e: true + attn_dropout: 0.2 + attn: + clamp: 5. + act: relu + full_attn: true + edge_enhance: true + O_e: true + norm_e: true + signed_sqrt: true + bn_momentum: 0.1 + bn_no_runner: false +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.01 + - 0.09 + - 0.9 + losses: + - PBE + - MaskedGenMSE + - MaskedBusMSE + loss_args: + - {} + - {} + - {} + accelerator: auto + devices: auto + strategy: auto +verbose: true From b68ae5e12d9d92a8495efb960983040d47efce85 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:52:37 -0400 Subject: [PATCH 36/95] update project toml Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 2b6c523c..100b8753 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,11 +44,14 @@ dependencies = [ "nbformat>=5.10.4", "networkx>=3.4.2", "numpy>=2.2.6", + "opt-einsum>=3.3.0", "pandas>=2.3.0", "plotly>=6.1.2", "pyyaml>=6.0.2", "torch>=2.7.1,<2.9", "torch-geometric>=2.6.1", + "torch-scatter>=2.1.2", + "torch-sparse>=0.6.18", "torchaudio>=2.7.1", "torchvision>=0.22.1", "lightning", From 52942fa6aef4129ab9030b13fb685df3a5273537 Mon Sep 17 00:00:00 2001 From: 
Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:28:02 -0400 Subject: [PATCH 37/95] simplify configuration file Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 5 ++--- gridfm_graphkit/models/grit_transformer.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index d119e13d..d1ecdcaf 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -45,13 +45,12 @@ model: node_encoder_bn: true edge_encoder_bn: true posenc_RWSE: - kernel: - times: 21 + # kernel.times is synced automatically from data.posenc_RWSE.kernel.times pe_dim: 20 raw_norm_type: batchnorm gt: layer_type: GritTransformer - dim_hidden: 116 # must match hidden_size + # dim_hidden is synced automatically from model.hidden_size layer_norm: false batch_norm: true update_e: true diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 17b15398..aab8b939 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -247,6 +247,24 @@ def __init__(self, args): if not hasattr(args.model, "output_dim"): args.model.output_dim = output_bus_dim + # Sync PE kernel.times from data config into model encoder config so + # users only need to specify it once (under data.posenc_RWSE). 
+ if ( + hasattr(args.data, "posenc_RWSE") + and args.data.posenc_RWSE.enable + and hasattr(args.model, "encoder") + and hasattr(args.model.encoder, "posenc_RWSE") + ): + from gridfm_graphkit.io.param_handler import NestedNamespace + enc_rwse = args.model.encoder.posenc_RWSE + if not hasattr(enc_rwse, "kernel"): + enc_rwse.kernel = NestedNamespace() + enc_rwse.kernel.times = args.data.posenc_RWSE.kernel.times + + # Sync gt.dim_hidden from model.hidden_size so it is specified once. + if hasattr(args.model, "gt"): + args.model.gt.dim_hidden = args.model.hidden_size + # The original homogeneous GRIT # (encoder + optional PE encoders + transformer layers + GraphHead) self.grit = GritTransformer(args) From eba33e52aba4204c07d8dbc3b412e6d500f5e528 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:32:41 -0400 Subject: [PATCH 38/95] flow over time benchmarking Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- scripts/benchmark_model_inference.py | 447 +++++++++++++++++++++++++++ scripts/run_benchmark.sh | 39 +++ 2 files changed, 486 insertions(+) create mode 100644 scripts/benchmark_model_inference.py create mode 100755 scripts/run_benchmark.sh diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py new file mode 100644 index 00000000..9a80d14b --- /dev/null +++ b/scripts/benchmark_model_inference.py @@ -0,0 +1,447 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +A unified script for benchmarking and limited custom profiling. Benchmarking columns in the output csv are [batch_size,avg_time_per_sample_ms]. 
+ +Example usage (edge count is 2*E (branch count)): + +###################################### + +CONF_PATH=../examples/config +OUT_DIR=../scripts +mkdir $OUT_DIR + +python benchmark_model_inference.py --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true +python benchmark_model_inference.py --config $CONF_PATH/case57_ieee_base.yaml --num_nodes 57 --num_edges 160 --num_gens 7 --iterations 20 --output_csv $OUT_DIR/case57.csv || true +python benchmark_model_inference.py --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true +python benchmark_model_inference.py --config $CONF_PATH/case500_ieee_base.yaml --num_nodes 500 --num_edges 1466 --num_gens 224 --iterations 20 --output_csv $OUT_DIR/case500.csv || true +python benchmark_model_inference.py --config $CONF_PATH/case2000_ieee_base.yaml --num_nodes 2000 --num_edges 7278 --num_gens 384 --iterations 20 --output_csv $OUT_DIR/case2000.csv || true + +###################################### + +Author(s): Mangaliso M. - mngomezulum@ibm.com + Matteo M. 
- Not Available +""" + +import os +import time +import csv +import yaml +import torch +import argparse +import platform +from datetime import datetime +from torch_geometric.loader import DataLoader +from torch_geometric.data import HeteroData +from gridfm_graphkit.io.param_handler import NestedNamespace, load_model + +# Optional: tqdm (imported but not required for core flow) +try: + from tqdm import tqdm # noqa: F401 +except Exception: + pass + +# Compilation (kept from original) +import torch._dynamo as dynamo +dynamo.config.suppress_errors = False + +# ---------------------------- +# Argument Parsing +# ---------------------------- +parser = argparse.ArgumentParser(description="Benchmark GNS_final Heterogeneous Model with profiling CSV") +parser.add_argument("--config", type=str, required=True, help="Path to config YAML for model") +parser.add_argument("--num_nodes", type=int, required=True) +parser.add_argument("--num_gens", type=int, required=True) +parser.add_argument("--num_edges", type=int, required=True) +parser.add_argument("--output_csv", type=str, required=True) +parser.add_argument("--iterations", type=int, default=20) +parser.add_argument("--num_workers", type=int, default=0, help="DataLoader num_workers") +parser.add_argument("--pin_memory", action="store_true", help="Enable pin_memory in DataLoader when CUDA is available") +args = parser.parse_args() + +# --- Custom logging (ensure directory exists) +import logging +os.makedirs('logs', exist_ok=True) +logger = logging.getLogger('ibm_benchmark_logger') +logger.setLevel(logging.DEBUG) +logger.propagate = False +file_handler = logging.FileHandler('logs/ibm_bench_logs.log', mode='a') # 'a' for append, 'w' to overwrite +file_handler.setLevel(logging.INFO) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +file_handler.setFormatter(formatter) +if not logger.handlers: + logger.addHandler(file_handler) + +# ---------------------------- +# Load Model +# 
---------------------------- +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +with open(args.config, "r") as f: + base_config = yaml.safe_load(f) + +config_args = NestedNamespace(**base_config) +model = load_model(config_args).to(device).eval() + +# ---------------------------- +# Parameters +# ---------------------------- +N_BUS = args.num_nodes +N_GEN = args.num_gens +E = args.num_edges +BUS_FEATS = config_args.model.input_bus_dim +GEN_FEATS = config_args.model.input_gen_dim +EDGE_FEATS = config_args.model.edge_dim + +# Keep original batch sizes list +batch_sizes = [1, 2, 4, 8, 16, 32, 64, 96, 128, 256, 512, 640, 768, 1024, 2048, 2560, 3072, 3584, 4096, 6144, 9216, 13824, 17280, 20736, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 75000, 80000, 85000, 90000] +iterations = args.iterations + +# ---------------------------- +# Helpers +# ---------------------------- +def now_ms() -> float: + return time.perf_counter() * 1000.0 + +def maybe_cuda_sync(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + +def get_env_info(): + # CPU name detection + cpu_name = None + try: + cpu_name = platform.processor() or None + if not cpu_name and os.path.exists("/proc/cpuinfo"): + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "model name" in line: + cpu_name = line.strip().split(":", 1)[1].strip() + break + if not cpu_name: + cpu_name = platform.uname().machine + except Exception: + cpu_name = "unknown" + + # GPU names and device info + if torch.cuda.is_available(): + try: + gpu_names_list = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] + gpu_names = "; ".join(gpu_names_list) + except Exception: + gpu_names = "cuda_available_but_name_unreadable" + device_type = "cuda" + device_name = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "cuda" + cuda_version_in_torch = torch.version.cuda + cudnn_version = torch.backends.cudnn.version() if torch.backends.cudnn.is_available() 
else None + else: + # Apple Metal backend? + if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + device_type = "mps" + device_name = "Apple MPS" + gpu_names = "mps" + cuda_version_in_torch = None + cudnn_version = None + else: + device_type = "cpu" + device_name = "cpu" + gpu_names = "none" + cuda_version_in_torch = None + cudnn_version = None + + info = { + "device_type": device_type, + "device_name": device_name, + "gpu_names": gpu_names, + "cpu_name": cpu_name, + "torch_version": torch.__version__, + "cuda_version_in_torch": cuda_version_in_torch, + "cudnn_version": cudnn_version, + "python_version": platform.python_version(), + } + return info + +# ---------------------------- +# Generate Synthetic Hetero Graph +# ---------------------------- +def generate_hetero_graph(): + """ + Generates a dummy heterogeneous power network graph for benchmarking. + + Returns: + data (HeteroData): single self-contained heterogeneous graph with: + - data["bus"].x, data["gen"].x + - edge_index & edge_attr for all relations + - mask_dict inside data.mask_dict + """ + data = HeteroData() + + # Node features + data["bus"].x = torch.randn(N_BUS, BUS_FEATS) + data["gen"].x = torch.randn(N_GEN, GEN_FEATS) + + # Edges: Bus–Bus + src = torch.randint(0, N_BUS, (E,)) + dst = torch.randint(0, N_BUS, (E,)) + data["bus", "connects", "bus"].edge_index = torch.stack([src, dst], dim=0) + data["bus", "connects", "bus"].edge_attr = torch.randn(E, EDGE_FEATS) + + # Edges: Gen–Bus & Bus–Gen + gen_to_bus = torch.randint(0, N_BUS, (N_GEN,)) + + # Gen → Bus + data["gen", "connected_to", "bus"].edge_index = torch.stack( + [torch.arange(N_GEN), gen_to_bus], dim=0 + ) + + # Bus → Gen + data["bus", "connected_to", "gen"].edge_index = torch.stack( + [gen_to_bus, torch.arange(N_GEN)], dim=0 + ) + + # No edge features for these + data["gen", "connected_to", "bus"].edge_attr = None + data["bus", "connected_to", "gen"].edge_attr = None + + # Dummy masks (all True) + mask_bus = 
torch.ones_like(data["bus"].x, dtype=torch.bool) + mask_gen = torch.ones_like(data["gen"].x, dtype=torch.bool) + bus_types = torch.randint(0, 3, (N_BUS,)) + mask_branch = torch.ones_like(data["bus", "connects", "bus"].edge_attr, dtype=torch.bool) + + mask_PQ = bus_types == 0 + mask_PV = bus_types == 1 + mask_REF = bus_types == 2 + + data.mask_dict = { + "bus": mask_bus, + "gen": mask_gen, + "PQ": mask_PQ, + "PV": mask_PV, + "REF": mask_REF, + "branch": mask_branch + } + return data + +# ---------------------------- +# Benchmark Function +# ---------------------------- +def benchmark(): + # Environment/context info (constant per run) + env = get_env_info() + timestamp = datetime.now().isoformat(timespec='seconds') + + # Measure synthetic graph creation + t0 = now_ms() + data = generate_hetero_graph() + t1 = now_ms() + data_gen_time_ms = t1 - t0 + + # Move the base graph to device (preserve original behavior) + maybe_cuda_sync() + t2 = now_ms() + data = data.to(device) + maybe_cuda_sync() + t3 = now_ms() + graph_to_device_time_ms = t3 - t2 + + batch_sizes_used = [] + times = [] + + header = [ + # Keep original first two columns + "batch_size", + "avg_time_per_sample_ms", + + # Execution config + "num_iters", + "total_samples", + + # Data/IO timing + "data_gen_time_ms", + "graph_to_device_time_ms", + "clone_list_time_ms", + "dataloader_create_time_ms", + "dataloader_first_iter_time_ms", + "batch_to_device_time_ms", + + # Model timing + "warmup_time_ms", + "iter_total_wall_time_ms", + "iter_gpu_time_ms", + "gpu_idle_time_ms", + "gpu_busy_ratio", + "samples_per_sec_wall", + "samples_per_sec_gpu", + "timing_source", # "cuda_event" or "wall_clock" + + # Memory + "max_cuda_mem_alloc_bytes", + "max_cuda_mem_reserved_bytes", + + # Graph & model context + "n_bus", "n_gen", "n_edges", + "bus_feats", "gen_feats", "edge_feats", + + # Runtime context + "device_type", "device_name", + "torch_version", "cuda_version_in_torch", "cudnn_version", + "python_version", + "cpu_name", # 
NEW + "gpu_names", # NEW + "timestamp_iso", + "num_workers", + "pin_memory", + ] + + with open(args.output_csv, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow(header) + + for batch_size in batch_sizes: + # Build list of graphs (on device, preserving original flow) + maybe_cuda_sync() + t_clone_start = now_ms() + data_list = [data.clone() for _ in range(batch_size)] + maybe_cuda_sync() + t_clone_end = now_ms() + clone_list_time_ms = t_clone_end - t_clone_start + + # Create DataLoader + pin_mem = args.pin_memory and torch.cuda.is_available() + persistent = args.num_workers > 0 + t_dl_create_start = now_ms() + loader = DataLoader( + data_list, + batch_size=batch_size, + num_workers=args.num_workers, + pin_memory=pin_mem, + persistent_workers=persistent, + ) + t_dl_create_end = now_ms() + dataloader_create_time_ms = t_dl_create_end - t_dl_create_start + + # Fetch first batch (collate) + t_iter_start = now_ms() + batch = next(iter(loader)) + t_iter_end = now_ms() + dataloader_first_iter_time_ms = t_iter_end - t_iter_start + + # Ensure batch on device (likely ~0 if items already on device) + maybe_cuda_sync() + t_b2d_start = now_ms() + batch = batch.to(device, non_blocking=True) if torch.cuda.is_available() else batch.to(device) + maybe_cuda_sync() + t_b2d_end = now_ms() + batch_to_device_time_ms = t_b2d_end - t_b2d_start + + test_model = model + + # Warmup (excluded from main timing) + maybe_cuda_sync() + t_warmup_start = now_ms() + with torch.no_grad(): + for _ in range(5): + _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + maybe_cuda_sync() + t_warmup_end = now_ms() + warmup_time_ms = t_warmup_end - t_warmup_start + + num_iters = iterations + total_samples = batch_size * num_iters + + # Reset CUDA memory stats and set up timing + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(device) + start_event = torch.cuda.Event(enable_timing=True) + end_event = 
torch.cuda.Event(enable_timing=True) + + # Iteration timing + maybe_cuda_sync() + wall_start = now_ms() + with torch.no_grad(): + if torch.cuda.is_available(): + start_event.record() + for _ in range(num_iters): + _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + if torch.cuda.is_available(): + end_event.record() + maybe_cuda_sync() + wall_end = now_ms() + + iter_total_wall_time_ms = wall_end - wall_start + + if torch.cuda.is_available(): + iter_gpu_time_ms = float(start_event.elapsed_time(end_event)) # ms + timing_source = "cuda_event" + avg_time_per_sample_ms = iter_gpu_time_ms / total_samples + gpu_idle_time_ms = max(iter_total_wall_time_ms - iter_gpu_time_ms, 0.0) + gpu_busy_ratio = (iter_gpu_time_ms / iter_total_wall_time_ms) if iter_total_wall_time_ms > 0 else None + max_cuda_mem_alloc_bytes = int(torch.cuda.max_memory_allocated(device)) + max_cuda_mem_reserved_bytes = int(torch.cuda.max_memory_reserved(device)) + samples_per_sec_gpu = (total_samples / (iter_gpu_time_ms / 1000.0)) if iter_gpu_time_ms > 0 else None + else: + iter_gpu_time_ms = None + timing_source = "wall_clock" + avg_time_per_sample_ms = iter_total_wall_time_ms / total_samples + gpu_idle_time_ms = None + gpu_busy_ratio = None + max_cuda_mem_alloc_bytes = None + max_cuda_mem_reserved_bytes = None + samples_per_sec_gpu = None + + samples_per_sec_wall = (total_samples / (iter_total_wall_time_ms / 1000.0)) if iter_total_wall_time_ms > 0 else None + + # Prepare row + row = [ + batch_size, + avg_time_per_sample_ms, + + num_iters, + total_samples, + + data_gen_time_ms, + graph_to_device_time_ms, + clone_list_time_ms, + dataloader_create_time_ms, + dataloader_first_iter_time_ms, + batch_to_device_time_ms, + + warmup_time_ms, + iter_total_wall_time_ms, + iter_gpu_time_ms, + gpu_idle_time_ms, + gpu_busy_ratio, + samples_per_sec_wall, + samples_per_sec_gpu, + timing_source, + + max_cuda_mem_alloc_bytes, + max_cuda_mem_reserved_bytes, + + N_BUS, N_GEN, E, + 
BUS_FEATS, GEN_FEATS, EDGE_FEATS, + + env["device_type"], env["device_name"], + env["torch_version"], env["cuda_version_in_torch"], env["cudnn_version"], + env["python_version"], + env["cpu_name"], + env["gpu_names"], + timestamp, + args.num_workers, + bool(pin_mem), + ] + + writer.writerow(row) + csvfile.flush() + batch_sizes_used.append(batch_size) + times.append(avg_time_per_sample_ms) + + return batch_sizes_used, times + + +if __name__ == "__main__": + print(f"Starting benchmark for {os.path.basename(args.output_csv)} ..") + benchmark() + print(f"Finished benchmarking for {os.path.basename(args.output_csv)}\n ...") diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh new file mode 100755 index 00000000..52ae9818 --- /dev/null +++ b/scripts/run_benchmark.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +set +e # Do NOT exit on error + +CONFIGS=( + "gridfm01" + "gridfm02" +) + +CONFIG_PATHS=( + "../examples/config/gridFMv0.1_pretraining.yaml" + "../examples/config/gridFMv0.2_pretraining.yaml" +) + +GRAPH_SIZES=( + "30 110" + "300 1120" + "2000 9276" + "3022 11390" + "9241 41337" + "30000 100784" +) + +OUTPUT_DIR="benchmark_results" +mkdir -p $OUTPUT_DIR +for i in "${!CONFIGS[@]}"; do + config_name="${CONFIGS[$i]}" + config_path="${CONFIG_PATHS[$i]}" + for size in "${GRAPH_SIZES[@]}"; do + read -r nodes edges <<< "$size" + output_file="${OUTPUT_DIR}/${config_name}_${nodes}nodes_${edges}edges.csv" + echo "Running benchmark for $config_name with $nodes nodes and $edges edges..." 
+ python benchmark_model_inference.py \ + --config "$config_path" \ + --output_csv "$output_file" \ + --num_nodes "$nodes" \ + --num_edges "$edges" || echo "Failed for $config_name with $nodes nodes" + done +done \ No newline at end of file From 9dd35d6fcab24fa3b9b567a4b09971d181521730 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:20:17 -0400 Subject: [PATCH 39/95] add baseline grit support Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- scripts/benchmark_model_inference.py | 112 +++++++++++++++++++++++---- 1 file changed, 98 insertions(+), 14 deletions(-) diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index 9a80d14b..f98e686c 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -4,7 +4,11 @@ """ A unified script for benchmarking and limited custom profiling. Benchmarking columns in the output csv are [batch_size,avg_time_per_sample_ms]. 
-Example usagef (edge count is 2*E (branch count)): +Supports two model types via --model flag: + - "hetero" (default): GNS_heterogeneous with HeteroData (bus + gen nodes) + - "grit": GritTransformer with homogeneous Data (single node type) + +Example usage — Heterogeneous GNS (edge count is 2*E (branch count)): ###################################### @@ -12,11 +16,17 @@ OUT_DIR=../scripts mkdir $OUT_DIR -python benchmark_model_inference.py --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true -python benchmark_model_inference.py --config $CONF_PATH/case57_ieee_base.yaml --num_nodes 57 --num_edges 160 --num_gens 7 --iterations 20 --output_csv $OUT_DIR/case57.csv || true -python benchmark_model_inference.py --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true -python benchmark_model_inference.py --config $CONF_PATH/case500_ieee_base.yaml --num_nodes 500 --num_edges 1466 --num_gens 224 --iterations 20 --output_csv $OUT_DIR/case500.csv || true -python benchmark_model_inference.py --config $CONF_PATH/case2000_ieee_base.yaml --num_nodes 2000 --num_edges 7278 --num_gens 384 --iterations 20 --output_csv $OUT_DIR/case2000.csv || true +python benchmark_model_inference.py --model hetero --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true +python benchmark_model_inference.py --model hetero --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true + +###################################### + +Example usage — GRIT (homogeneous, --num_gens is ignored): + +###################################### + +python benchmark_model_inference.py --model grit --config $CONF_PATH/grit_pretraining.yaml --num_nodes 30 --num_edges 82 
--iterations 20 --output_csv $OUT_DIR/grit_case30.csv || true +python benchmark_model_inference.py --model grit --config $CONF_PATH/grit_pretraining.yaml --num_nodes 118 --num_edges 372 --iterations 20 --output_csv $OUT_DIR/grit_case118.csv || true ###################################### @@ -33,7 +43,7 @@ import platform from datetime import datetime from torch_geometric.loader import DataLoader -from torch_geometric.data import HeteroData +from torch_geometric.data import Data, HeteroData from gridfm_graphkit.io.param_handler import NestedNamespace, load_model # Optional: tqdm (imported but not required for core flow) @@ -49,10 +59,13 @@ # ---------------------------- # Argument Parsing # ---------------------------- -parser = argparse.ArgumentParser(description="Benchmark GNS_final Heterogeneous Model with profiling CSV") +parser = argparse.ArgumentParser(description="Benchmark GNN Model inference with profiling CSV") +parser.add_argument("--model", type=str, choices=["hetero", "grit"], default="hetero", + help="Model type: 'hetero' for GNS_heterogeneous, 'grit' for GritTransformer") parser.add_argument("--config", type=str, required=True, help="Path to config YAML for model") parser.add_argument("--num_nodes", type=int, required=True) -parser.add_argument("--num_gens", type=int, required=True) +parser.add_argument("--num_gens", type=int, default=0, + help="Number of generator nodes (required for hetero, ignored for grit)") parser.add_argument("--num_edges", type=int, required=True) parser.add_argument("--output_csv", type=str, required=True) parser.add_argument("--iterations", type=int, default=20) @@ -87,13 +100,30 @@ # ---------------------------- # Parameters # ---------------------------- +MODEL_TYPE = args.model N_BUS = args.num_nodes N_GEN = args.num_gens E = args.num_edges -BUS_FEATS = config_args.model.input_bus_dim -GEN_FEATS = config_args.model.input_gen_dim EDGE_FEATS = config_args.model.edge_dim +if MODEL_TYPE == "hetero": + BUS_FEATS = 
config_args.model.input_bus_dim + GEN_FEATS = config_args.model.input_gen_dim + NODE_FEATS = None # not used for hetero +else: + # GRIT homogeneous model + NODE_FEATS = config_args.model.input_dim + OUTPUT_DIM = config_args.model.output_dim + MASK_DIM = getattr(config_args.data, "mask_dim", 6) + # Positional encoding config + RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False + RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0 + RWSE_ENABLED = hasattr(config_args.model, "encoder") and getattr(config_args.model.encoder, "node_encoder", False) \ + and "RWSE" in getattr(config_args.model.encoder, "node_encoder_name", "") + RWSE_TIMES = getattr(config_args.model.encoder.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 + BUS_FEATS = NODE_FEATS # alias for CSV output compatibility + GEN_FEATS = 0 + # Keep original batch sizes list batch_sizes = [1, 2, 4, 8, 16, 32, 64, 96, 128, 256, 512, 640, 768, 1024, 2048, 2560, 3072, 3584, 4096, 6144, 9216, 13824, 17280, 20736, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 75000, 80000, 85000, 90000] iterations = args.iterations @@ -224,6 +254,51 @@ def generate_hetero_graph(): } return data + +# ---------------------------- +# Generate Synthetic Homogeneous Graph (GRIT) +# ---------------------------- +def generate_homo_graph(): + """ + Generates a dummy homogeneous power network graph for GRIT benchmarking. 
+ + Returns: + data (Data): single self-contained homogeneous graph with: + - data.x: node features [N_BUS, NODE_FEATS] + - data.y: target labels [N_BUS, OUTPUT_DIM] + - data.edge_index: [2, E] + - data.edge_attr: [E, EDGE_FEATS] + - data.pestat_RWSE (if RWSE enabled): [N_BUS, RWSE_TIMES] + - data.rrwp, rrwp_index, rrwp_val (if RRWP enabled) + """ + data = Data() + + # Node features: same layout as powergrid_dataset (Pd, Qd, Pg, Qg, Vm, Va, PQ, PV, REF) + data.x = torch.randn(N_BUS, NODE_FEATS) + data.y = data.x[:, :OUTPUT_DIM].clone() + + # Edges + src = torch.randint(0, N_BUS, (E,)) + dst = torch.randint(0, N_BUS, (E,)) + data.edge_index = torch.stack([src, dst], dim=0) + data.edge_attr = torch.randn(E, EDGE_FEATS) + + # RWSE positional encoding (diagonal of random-walk matrix powers) + if RWSE_ENABLED: + data.pestat_RWSE = torch.randn(N_BUS, RWSE_TIMES).abs() + + # RRWP positional / structural encoding + if RRWP_ENABLED: + data.rrwp = torch.randn(N_BUS, RRWP_KSTEPS) + # Sparse RRWP for edges: include existing edges + self-loops + self_loops = torch.arange(N_BUS).unsqueeze(0).repeat(2, 1) + rrwp_idx = torch.cat([data.edge_index, self_loops], dim=1) + rrwp_nnz = rrwp_idx.size(1) + data.rrwp_index = rrwp_idx + data.rrwp_val = torch.randn(rrwp_nnz, RRWP_KSTEPS) + + return data + # ---------------------------- # Benchmark Function # ---------------------------- @@ -234,7 +309,10 @@ def benchmark(): # Measure synthetic graph creation t0 = now_ms() - data = generate_hetero_graph() + if MODEL_TYPE == "hetero": + data = generate_hetero_graph() + else: + data = generate_homo_graph() t1 = now_ms() data_gen_time_ms = t1 - t0 @@ -343,7 +421,10 @@ def benchmark(): t_warmup_start = now_ms() with torch.no_grad(): for _ in range(5): - _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + if MODEL_TYPE == "hetero": + _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + else: + _ = 
test_model(batch.clone()) maybe_cuda_sync() t_warmup_end = now_ms() warmup_time_ms = t_warmup_end - t_warmup_start @@ -364,7 +445,10 @@ def benchmark(): if torch.cuda.is_available(): start_event.record() for _ in range(num_iters): - _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + if MODEL_TYPE == "hetero": + _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + else: + _ = test_model(batch.clone()) if torch.cuda.is_available(): end_event.record() maybe_cuda_sync() From 91047ccac0b8c8b86c2b84b44e0f2b51810440ca Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:25:31 -0400 Subject: [PATCH 40/95] update benchmarking for new grit format Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 97 ------------------------ scripts/benchmark_model_inference.py | 104 ++++++++++++-------------- scripts/run_benchmark.sh | 1 + 3 files changed, 49 insertions(+), 153 deletions(-) delete mode 100644 examples/config/grit_pretraining.yaml diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml deleted file mode 100644 index bf56dffa..00000000 --- a/examples/config/grit_pretraining.yaml +++ /dev/null @@ -1,97 +0,0 @@ -callbacks: - patience: 100 - tol: 0 -data: - baseMVA: 100 - learn_mask: false - mask_dim: 6 - mask_ratio: 0.5 - mask_type: rnd - mask_value: -1.0 - networks: - # - Texas2k_case1_2016summerpeak - - case24_ieee_rts - - case118_ieee - - case300_ieee - - case89_pegase - - case240_pserc - normalization: baseMVAnorm - scenarios: - # - 5000 - - 5000 - - 5000 - - 3000 - - 5000 - - 5000 - test_ratio: 0.1 - val_ratio: 0.1 - workers: 4 - posenc_RRWP: - enable: False - ksteps: 21 - add_identity: True - add_node_attr: False - add_inverse: False - posenc_RWSE: - enable: True - kernel: - times: 21 # TODO unify with model -model: - 
attention_head: 8 - dropout: 0.1 - edge_dim: 2 - hidden_size: 116 # `gt.dim_hidden` must match `gnn.dim_inner` - input_dim: 9 - num_layers: 10 - output_dim: 6 - pe_dim: 20 - type: GRIT - act: relu - encoder: - node_encoder: True - edge_encoder: True - node_encoder_name: RWSE - node_encoder_bn: True - edge_encoder_bn: True - posenc_RWSE: - kernel: - times: 21 - pe_dim: 20 # TODO unify with model.pe_dim - raw_norm_type: batchnorm - gt: - layer_type: GritTransformer - dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner` - layer_norm: False - batch_norm: True - update_e: True - attn_dropout: 0.2 - attn: - clamp: 5. - act: 'relu' - full_attn: True - edge_enhance: True - O_e: True - norm_e: True - signed_sqrt: True - bn_momentum: 0.1 - bn_no_runner: False - -optimizer: - beta1: 0.9 - beta2: 0.999 - learning_rate: 0.0001 - lr_decay: 0.7 - lr_patience: 10 -seed: 0 -training: - batch_size: 8 - epochs: 500 - loss_weights: - - 0.01 - - 0.99 - losses: - - MaskedMSE - - PBE - accelerator: auto - devices: auto - strategy: auto diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index f98e686c..f8c7b76e 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -6,7 +6,7 @@ Supports two model types via --model flag: - "hetero" (default): GNS_heterogeneous with HeteroData (bus + gen nodes) - - "grit": GritTransformer with homogeneous Data (single node type) + - "grit": GritHeteroAdapter with HeteroData (bus + gen nodes, optional PE attrs) Example usage — Heterogeneous GNS (edge count is 2*E (branch count)): @@ -21,12 +21,12 @@ ###################################### -Example usage — GRIT (homogeneous, --num_gens is ignored): +Example usage — GRIT (HeteroData with PE, --num_gens required): ###################################### -python benchmark_model_inference.py --model grit --config $CONF_PATH/grit_pretraining.yaml --num_nodes 30 --num_edges 82 --iterations 20 --output_csv $OUT_DIR/grit_case30.csv 
|| true -python benchmark_model_inference.py --model grit --config $CONF_PATH/grit_pretraining.yaml --num_nodes 118 --num_edges 372 --iterations 20 --output_csv $OUT_DIR/grit_case118.csv || true +python benchmark_model_inference.py --model grit --config $CONF_PATH/GRIT_PF_datakit_case14.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/grit_case30.csv || true +python benchmark_model_inference.py --model grit --config $CONF_PATH/GRIT_PF_datakit_case14.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/grit_case118.csv || true ###################################### @@ -43,7 +43,7 @@ import platform from datetime import datetime from torch_geometric.loader import DataLoader -from torch_geometric.data import Data, HeteroData +from torch_geometric.data import HeteroData from gridfm_graphkit.io.param_handler import NestedNamespace, load_model # Optional: tqdm (imported but not required for core flow) @@ -102,27 +102,31 @@ # ---------------------------- MODEL_TYPE = args.model N_BUS = args.num_nodes -N_GEN = args.num_gens E = args.num_edges -EDGE_FEATS = config_args.model.edge_dim -if MODEL_TYPE == "hetero": - BUS_FEATS = config_args.model.input_bus_dim - GEN_FEATS = config_args.model.input_gen_dim - NODE_FEATS = None # not used for hetero -else: - # GRIT homogeneous model - NODE_FEATS = config_args.model.input_dim - OUTPUT_DIM = config_args.model.output_dim - MASK_DIM = getattr(config_args.data, "mask_dim", 6) - # Positional encoding config +# Default num_gens when not provided (shell script omits --num_gens) +N_GEN = args.num_gens if args.num_gens > 0 else max(1, N_BUS // 5) + +EDGE_FEATS = getattr(config_args.model, "edge_dim", 10) + +# Both model types use HeteroData with bus + gen nodes. +# Fall back to input_dim / defaults for configs that lack the hetero keys. 
+BUS_FEATS = getattr(config_args.model, "input_bus_dim", + getattr(config_args.model, "input_dim", 15)) +GEN_FEATS = getattr(config_args.model, "input_gen_dim", 6) + +if MODEL_TYPE == "grit": + # Positional encoding config (only GRIT uses these) RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0 RWSE_ENABLED = hasattr(config_args.model, "encoder") and getattr(config_args.model.encoder, "node_encoder", False) \ and "RWSE" in getattr(config_args.model.encoder, "node_encoder_name", "") RWSE_TIMES = getattr(config_args.model.encoder.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 - BUS_FEATS = NODE_FEATS # alias for CSV output compatibility - GEN_FEATS = 0 +else: + RRWP_ENABLED = False + RRWP_KSTEPS = 0 + RWSE_ENABLED = False + RWSE_TIMES = 0 # Keep original batch sizes list batch_sizes = [1, 2, 4, 8, 16, 32, 64, 96, 128, 256, 512, 640, 768, 1024, 2048, 2560, 3072, 3584, 4096, 6144, 9216, 13824, 17280, 20736, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 75000, 80000, 85000, 90000] @@ -211,6 +215,10 @@ def generate_hetero_graph(): data["bus"].x = torch.randn(N_BUS, BUS_FEATS) data["gen"].x = torch.randn(N_GEN, GEN_FEATS) + # Dummy targets + data["bus"].y = torch.randn(N_BUS, BUS_FEATS) + data["gen"].y = torch.randn(N_GEN, GEN_FEATS) + # Edges: Bus–Bus src = torch.randint(0, N_BUS, (E,)) dst = torch.randint(0, N_BUS, (E,)) @@ -258,44 +266,34 @@ def generate_hetero_graph(): # ---------------------------- # Generate Synthetic Homogeneous Graph (GRIT) # ---------------------------- -def generate_homo_graph(): +def generate_grit_graph(): """ - Generates a dummy homogeneous power network graph for GRIT benchmarking. + Generates a dummy heterogeneous graph for GRIT benchmarking. 
+ + GritHeteroAdapter expects a HeteroData batch, so we generate the same + structure as generate_hetero_graph() but also attach PE attributes on + the bus node store when RWSE / RRWP are enabled. Returns: - data (Data): single self-contained homogeneous graph with: - - data.x: node features [N_BUS, NODE_FEATS] - - data.y: target labels [N_BUS, OUTPUT_DIM] - - data.edge_index: [2, E] - - data.edge_attr: [E, EDGE_FEATS] - - data.pestat_RWSE (if RWSE enabled): [N_BUS, RWSE_TIMES] - - data.rrwp, rrwp_index, rrwp_val (if RRWP enabled) + data (HeteroData): heterogeneous graph with bus & gen nodes, + plus optional PE attributes on data["bus"]. """ - data = Data() - - # Node features: same layout as powergrid_dataset (Pd, Qd, Pg, Qg, Vm, Va, PQ, PV, REF) - data.x = torch.randn(N_BUS, NODE_FEATS) - data.y = data.x[:, :OUTPUT_DIM].clone() - - # Edges - src = torch.randint(0, N_BUS, (E,)) - dst = torch.randint(0, N_BUS, (E,)) - data.edge_index = torch.stack([src, dst], dim=0) - data.edge_attr = torch.randn(E, EDGE_FEATS) + data = generate_hetero_graph() - # RWSE positional encoding (diagonal of random-walk matrix powers) + # RWSE positional encoding on bus nodes if RWSE_ENABLED: - data.pestat_RWSE = torch.randn(N_BUS, RWSE_TIMES).abs() + data["bus"].pestat_RWSE = torch.randn(N_BUS, RWSE_TIMES).abs() - # RRWP positional / structural encoding + # RRWP positional / structural encoding on bus nodes if RRWP_ENABLED: - data.rrwp = torch.randn(N_BUS, RRWP_KSTEPS) - # Sparse RRWP for edges: include existing edges + self-loops + data["bus"].rrwp = torch.randn(N_BUS, RRWP_KSTEPS) + # Sparse RRWP for edges: include existing bus-bus edges + self-loops + bb_ei = data["bus", "connects", "bus"].edge_index self_loops = torch.arange(N_BUS).unsqueeze(0).repeat(2, 1) - rrwp_idx = torch.cat([data.edge_index, self_loops], dim=1) + rrwp_idx = torch.cat([bb_ei, self_loops], dim=1) rrwp_nnz = rrwp_idx.size(1) - data.rrwp_index = rrwp_idx - data.rrwp_val = torch.randn(rrwp_nnz, RRWP_KSTEPS) + 
data["bus"].rrwp_index = rrwp_idx + data["bus"].rrwp_val = torch.randn(rrwp_nnz, RRWP_KSTEPS) return data @@ -312,7 +310,7 @@ def benchmark(): if MODEL_TYPE == "hetero": data = generate_hetero_graph() else: - data = generate_homo_graph() + data = generate_grit_graph() t1 = now_ms() data_gen_time_ms = t1 - t0 @@ -421,10 +419,7 @@ def benchmark(): t_warmup_start = now_ms() with torch.no_grad(): for _ in range(5): - if MODEL_TYPE == "hetero": - _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) - else: - _ = test_model(batch.clone()) + _ = test_model(batch.clone()) maybe_cuda_sync() t_warmup_end = now_ms() warmup_time_ms = t_warmup_end - t_warmup_start @@ -445,10 +440,7 @@ def benchmark(): if torch.cuda.is_available(): start_event.record() for _ in range(num_iters): - if MODEL_TYPE == "hetero": - _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) - else: - _ = test_model(batch.clone()) + _ = test_model(batch.clone()) if torch.cuda.is_available(): end_event.record() maybe_cuda_sync() diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index 52ae9818..6483b815 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -31,6 +31,7 @@ for i in "${!CONFIGS[@]}"; do output_file="${OUTPUT_DIR}/${config_name}_${nodes}nodes_${edges}edges.csv" echo "Running benchmark for $config_name with $nodes nodes and $edges edges..." 
python benchmark_model_inference.py \ + --model "grit" \ --config "$config_path" \ --output_csv "$output_file" \ --num_nodes "$nodes" \ From deaf640249b74e4b927cd068b7fcfcf7c81f2ce0 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:40:01 -0400 Subject: [PATCH 41/95] cleanup Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/__init__.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index b9222749..64d97279 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,10 +1,4 @@ -<<<<<<< HEAD -from gridfm_graphkit.models.gps_transformer import GPSTransformer -from gridfm_graphkit.models.gnn_transformer import GNN_TransformerConv -from gridfm_graphkit.models.grit_transformer import GritTransformer -__all__ = ["GPSTransformer", "GNN_TransformerConv", "GRIT"] -======= from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous from gridfm_graphkit.models.utils import ( PhysicsDecoderOPF, @@ -18,4 +12,4 @@ "PhysicsDecoderPF", "PhysicsDecoderSE", ] ->>>>>>> opensource/main + From b620847fae307c3308d0d96f9261c4db5b146473 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 08:33:31 -0400 Subject: [PATCH 42/95] finalize connections connection of model Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index 64d97279..b30680ec 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,5 +1,6 @@ from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous +from gridfm_graphkit.models.grit_transformer import 
GritHeteroAdapter from gridfm_graphkit.models.utils import ( PhysicsDecoderOPF, PhysicsDecoderPF, @@ -8,6 +9,7 @@ __all__ = [ "GNS_heterogeneous", + "GritHeteroAdapter", "PhysicsDecoderOPF", "PhysicsDecoderPF", "PhysicsDecoderSE", From 964cd5a8577598bfb6c01752572ce664612c3f5b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 10:21:24 -0400 Subject: [PATCH 43/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/compute_ac_dc_metrics.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py b/gridfm_graphkit/tasks/compute_ac_dc_metrics.py index 8dcfc8c0..3d8118cd 100644 --- a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py +++ b/gridfm_graphkit/tasks/compute_ac_dc_metrics.py @@ -4,10 +4,6 @@ import os import numpy as np import pandas as pd -from gridfm_datakit.utils.power_balance import ( - compute_branch_powers_vectorized, - compute_bus_balance, -) N_SCENARIO_PER_PARTITION = 200 NUM_PROCESSES = 64 @@ -132,6 +128,11 @@ def compute_ac_dc_metrics( bus_df, branch_df, runtime_df = _load_test_data(data_dir, test_ids) + from gridfm_datakit.utils.power_balance import ( + compute_branch_powers_vectorized, + compute_bus_balance, + ) + # ========================= # AC residuals # ========================= From 028d7c56d0253a059d0e4bfc674faf4f1ef280cb Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 13:41:06 -0400 Subject: [PATCH 44/95] flow over random masking Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 3 +- gridfm_graphkit/datasets/masking.py | 54 +++++++++++++++++++++ gridfm_graphkit/datasets/task_transforms.py | 17 ++++++- 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml 
b/examples/config/GRIT_PF_datakit_case14.yaml index d1ecdcaf..7f89c565 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -5,7 +5,8 @@ task: task_name: PowerFlow data: baseMVA: 100 - mask_value: 0.0 + mask_type: rnd # or determinstic + mask_ratio: 0.5 # for random masking only normalization: HeteroDataMVANormalizer networks: - case14_ieee diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index b615e0cb..6b9d3e4e 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -33,6 +33,60 @@ from torch_geometric.nn import MessagePassing +class AddRandomHeteroMask(BaseTransform): + """Creates random masks for self-supervised pretraining on heterogeneous power grid graphs. + + Each selected feature dimension is independently masked per node/edge with + probability ``mask_ratio``. Masked bus features: VM, VA, QG. Masked gen + features: PG. Masked branch features: P_E, Q_E. + + The output ``data.mask_dict`` has the same structure as the deterministic + PF / OPF masks so that downstream losses (``MaskedBusMSE``, ``MaskedGenMSE``, + ``PBELoss``, etc.) work without modification. 
+ """ + + def __init__(self, mask_ratio=0.5): + super().__init__() + self.mask_ratio = mask_ratio + + def forward(self, data): + bus_x = data.x_dict["bus"] + gen_x = data.x_dict["gen"] + + # Bus type indicators (needed by losses and test metrics) + mask_PQ = bus_x[:, PQ_H] == 1 + mask_PV = bus_x[:, PV_H] == 1 + mask_REF = bus_x[:, REF_H] == 1 + + # Random bus mask on variable features the model reconstructs + mask_bus = torch.zeros_like(bus_x, dtype=torch.bool) + n_bus = bus_x.size(0) + for feat_idx in (VM_H, VA_H, QG_H): + mask_bus[:, feat_idx] = torch.rand(n_bus) < self.mask_ratio + + # Random gen mask on PG + mask_gen = torch.zeros_like(gen_x, dtype=torch.bool) + mask_gen[:, PG_H] = torch.rand(gen_x.size(0)) < self.mask_ratio + + # Random branch mask on flow features + branch_attr = data.edge_attr_dict[("bus", "connects", "bus")] + mask_branch = torch.zeros_like(branch_attr, dtype=torch.bool) + n_edge = branch_attr.size(0) + for feat_idx in (P_E, Q_E): + mask_branch[:, feat_idx] = torch.rand(n_edge) < self.mask_ratio + + data.mask_dict = { + "bus": mask_bus, + "gen": mask_gen, + "branch": mask_branch, + "PQ": mask_PQ, + "PV": mask_PV, + "REF": mask_REF, + } + + return data + + class AddPFHeteroMask(BaseTransform): """Creates masks for a heterogeneous power flow graph.""" diff --git a/gridfm_graphkit/datasets/task_transforms.py b/gridfm_graphkit/datasets/task_transforms.py index eaaca66c..d6f1b8cb 100644 --- a/gridfm_graphkit/datasets/task_transforms.py +++ b/gridfm_graphkit/datasets/task_transforms.py @@ -8,6 +8,7 @@ from gridfm_graphkit.datasets.masking import ( AddOPFHeteroMask, AddPFHeteroMask, + AddRandomHeteroMask, SimulateMeasurements, ) from gridfm_graphkit.io.registries import TRANSFORM_REGISTRY @@ -20,7 +21,13 @@ def __init__(self, args): transforms.append(RemoveInactiveBranches()) transforms.append(RemoveInactiveGenerators()) - transforms.append(AddPFHeteroMask()) + + mask_type = getattr(args.data, "mask_type", None) + if mask_type == "rnd": + 
transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio)) + else: + transforms.append(AddPFHeteroMask()) + transforms.append(ApplyMasking(args=args)) # Pass the list of transforms to Compose @@ -34,7 +41,13 @@ def __init__(self, args): transforms.append(RemoveInactiveBranches()) transforms.append(RemoveInactiveGenerators()) - transforms.append(AddOPFHeteroMask()) + + mask_type = getattr(args.data, "mask_type", None) + if mask_type == "rnd": + transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio)) + else: + transforms.append(AddOPFHeteroMask()) + transforms.append(ApplyMasking(args=args)) # Pass the list of transforms to Compose From 78253bae6d0884426a6e65c0e2cd441b3f186e45 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 13:59:50 -0400 Subject: [PATCH 45/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- scripts/benchmark_model_inference.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index f8c7b76e..0199c6ef 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -16,8 +16,8 @@ OUT_DIR=../scripts mkdir $OUT_DIR -python benchmark_model_inference.py --model hetero --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true -python benchmark_model_inference.py --model hetero --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true +python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case30.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true +python benchmark_model_inference.py --model hetero --config 
$CONF_PATH/HGNS_PF_datakit_case118.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true ###################################### @@ -117,11 +117,11 @@ if MODEL_TYPE == "grit": # Positional encoding config (only GRIT uses these) + # Read enablement and dimensions from data config (canonical source). RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0 - RWSE_ENABLED = hasattr(config_args.model, "encoder") and getattr(config_args.model.encoder, "node_encoder", False) \ - and "RWSE" in getattr(config_args.model.encoder, "node_encoder_name", "") - RWSE_TIMES = getattr(config_args.model.encoder.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 + RWSE_ENABLED = hasattr(config_args.data, "posenc_RWSE") and getattr(config_args.data.posenc_RWSE, "enable", False) + RWSE_TIMES = getattr(config_args.data.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 else: RRWP_ENABLED = False RRWP_KSTEPS = 0 From 72a744930f3d91a3e11e4cd27e3ed3e9a7983cff Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:12:41 -0400 Subject: [PATCH 46/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- scripts/benchmark_model_inference.py | 2 +- scripts/run_benchmark.sh | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index f8c7b76e..878cbbd4 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -129,7 +129,7 @@ RWSE_TIMES = 0 # Keep original batch sizes list -batch_sizes = [1, 2, 4, 8, 16, 32, 64, 96, 128, 256, 512, 640, 768, 1024, 2048, 2560, 3072, 3584, 4096, 6144, 9216, 13824, 17280, 20736, 30000, 35000, 40000, 45000, 50000, 55000, 
60000, 65000, 70000, 75000, 80000, 85000, 90000] +batch_sizes = [1, 2, 4, 8, 16, 32] iterations = args.iterations # ---------------------------- diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index 6483b815..744cfa21 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -3,13 +3,11 @@ set +e # Do NOT exit on error CONFIGS=( - "gridfm01" - "gridfm02" + "grit01" ) CONFIG_PATHS=( - "../examples/config/gridFMv0.1_pretraining.yaml" - "../examples/config/gridFMv0.2_pretraining.yaml" + "../examples/config/r2-1_grit_pretraining_RWSE_multi.yaml" ) GRAPH_SIZES=( @@ -37,4 +35,4 @@ for i in "${!CONFIGS[@]}"; do --num_nodes "$nodes" \ --num_edges "$edges" || echo "Failed for $config_name with $nodes nodes" done -done \ No newline at end of file +done From 549a525199502f52c73e0a85234aae76b0a80688 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:32:33 -0400 Subject: [PATCH 47/95] adjust example parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 4 ++-- scripts/benchmark_model_inference.py | 2 ++ scripts/run_benchmark.sh | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 7f89c565..70813ceb 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -28,7 +28,7 @@ model: # edge_dim must match the bus-bus edge feature count after transforms # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) edge_dim: 10 - hidden_size: 116 + hidden_size: 496 # input_dim = bus feature count (used by GRIT core FeatureEncoder) input_dim: 15 # Hetero adapter head dimensions @@ -36,7 +36,7 @@ model: input_gen_dim: 6 output_bus_dim: 2 output_gen_dim: 1 - num_layers: 10 + num_layers: 7 type: GRIT act: relu encoder: diff --git 
a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index e9af4c6e..fe3010fc 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -96,6 +96,8 @@ config_args = NestedNamespace(**base_config) model = load_model(config_args).to(device).eval() +tot_params = sum(p.numel() for p in model.parameters() if p.requires_grad) +print("**Total model trainable params: {}".format(tot_params)) # ---------------------------- # Parameters diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index 744cfa21..2054f30b 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -7,7 +7,7 @@ CONFIGS=( ) CONFIG_PATHS=( - "../examples/config/r2-1_grit_pretraining_RWSE_multi.yaml" + "../examples/config/GRIT_PF_datakit_case14.yaml" ) GRAPH_SIZES=( From c181d1206b9c1364a557b5d2fa548f48e02fb52b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Fri, 27 Mar 2026 08:33:44 -0400 Subject: [PATCH 48/95] update GRIT wrapper Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index aab8b939..82fef9d3 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -126,7 +126,7 @@ class GritTransformer(torch.nn.Module): 2023. 
""" - def __init__(self, args): + def __init__(self, args, include_decoder=True): super().__init__() @@ -195,7 +195,8 @@ def __init__(self, args): self.layers = nn.Sequential(*layers) - self.decoder = GraphHead(dim_inner, dim_out) + if include_decoder: + self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): """ @@ -266,8 +267,9 @@ def __init__(self, args): args.model.gt.dim_hidden = args.model.hidden_size # The original homogeneous GRIT - # (encoder + optional PE encoders + transformer layers + GraphHead) - self.grit = GritTransformer(args) + # (encoder + optional PE encoders + transformer layers) + # Decoder is excluded — this adapter provides its own per-type heads. + self.grit = GritTransformer(args, include_decoder=False) # Per-node-type output heads (replace GraphHead for hetero output) self.bus_head = nn.Sequential( From d4405fd63d50069b6ecb3c3e90c697b582867c57 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Fri, 27 Mar 2026 09:25:33 -0400 Subject: [PATCH 49/95] update GRIT wrapper Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 82fef9d3..99519eb8 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -178,6 +178,11 @@ def __init__(self, args, include_decoder=True): layers = [] for ll in range(num_layers): + # The last layer's edge output is never consumed downstream + # (only node features feed into the output heads), so skip + # creating O_e / norm_e parameters to avoid DDP unused-parameter + # errors. 
+ is_last = (ll == num_layers - 1) layers.append(GritTransformerLayer( in_dim=args.model.gt.dim_hidden, out_dim=args.model.gt.dim_hidden, @@ -188,8 +193,8 @@ def __init__(self, args, include_decoder=True): layer_norm=args.model.gt.layer_norm, batch_norm=args.model.gt.batch_norm, residual=True, - norm_e=args.model.gt.attn.norm_e, - O_e=args.model.gt.attn.O_e, + norm_e=False if is_last else args.model.gt.attn.norm_e, + O_e=False if is_last else args.model.gt.attn.O_e, cfg=args.model.gt, )) From 2452df3e6798a9e39ac52f229b7eea328ec45273 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:54:21 -0400 Subject: [PATCH 50/95] add edge norm Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 1 + gridfm_graphkit/models/grit_transformer.py | 23 ++++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 70813ceb..8ce3a97f 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -7,6 +7,7 @@ data: baseMVA: 100 mask_type: rnd # or determinstic mask_ratio: 0.5 # for random masking only + mask_value: -1 normalization: HeteroDataMVANormalizer networks: - case14_ieee diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 99519eb8..eaaf11fa 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -29,6 +29,27 @@ def forward(self, batch): return batch +class BatchNorm1dEdge(torch.nn.Module): + r"""A batch normalization layer for edge-level features. + + Args: + dim_in (int): BatchNorm input dimension. + eps (float): BatchNorm eps. + momentum (float): BatchNorm momentum. 
+ """ + def __init__(self, dim_in, eps, momentum): + super().__init__() + self.bn = torch.nn.BatchNorm1d( + dim_in, + eps=eps, + momentum=momentum, + ) + + def forward(self, batch): + batch.edge_attr = self.bn(batch.edge_attr) + return batch + + class LinearNodeEncoder(torch.nn.Module): def __init__(self, dim_in, emb_dim): super().__init__() @@ -84,7 +105,7 @@ def __init__( # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) if args.encoder.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dNode(enc_dim_edge, 1e-5, 0.1) + self.edge_encoder_bn = BatchNorm1dEdge(enc_dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): From ef2ce5807dd90da87d3c82aa61bdae28509c7d2d Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:18:41 -0400 Subject: [PATCH 51/95] extend random masking to PD QD PG Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 6 +- gridfm_graphkit/datasets/globals.py | 2 + gridfm_graphkit/datasets/masking.py | 8 +-- gridfm_graphkit/datasets/normalizers.py | 2 + gridfm_graphkit/models/grit_transformer.py | 50 ++++++++++++- gridfm_graphkit/training/loss.py | 77 +++++++++++++++++++++ 6 files changed, 137 insertions(+), 8 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 8ce3a97f..909c616c 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -78,9 +78,9 @@ training: batch_size: 8 epochs: 500 loss_weights: - - 0.01 - - 0.09 - - 0.9 + - 0.99 + - 0.001 + - 0.009 losses: - PBE - MaskedGenMSE diff --git a/gridfm_graphkit/datasets/globals.py b/gridfm_graphkit/datasets/globals.py index ab3c7e3d..9627cd8e 100644 --- a/gridfm_graphkit/datasets/globals.py +++ b/gridfm_graphkit/datasets/globals.py @@ -24,6 +24,8 @@ VA_OUT = 1 
PG_OUT = 2 QG_OUT = 3 +PD_OUT = 4 # for random masking +QD_OUT = 5 # for random masking PG_OUT_GEN = 0 diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index 6b9d3e4e..01924a2d 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -37,11 +37,11 @@ class AddRandomHeteroMask(BaseTransform): """Creates random masks for self-supervised pretraining on heterogeneous power grid graphs. Each selected feature dimension is independently masked per node/edge with - probability ``mask_ratio``. Masked bus features: VM, VA, QG. Masked gen - features: PG. Masked branch features: P_E, Q_E. + probability ``mask_ratio``. Masked bus features: PD, QD, VM, VA, QG. + Masked gen features: PG. Masked branch features: P_E, Q_E. The output ``data.mask_dict`` has the same structure as the deterministic - PF / OPF masks so that downstream losses (``MaskedBusMSE``, ``MaskedGenMSE``, + PF / OPF masks so that downstream losses (``MaskedReconstructionMSE``, ``PBELoss``, etc.) work without modification. 
""" @@ -61,7 +61,7 @@ def forward(self, data): # Random bus mask on variable features the model reconstructs mask_bus = torch.zeros_like(bus_x, dtype=torch.bool) n_bus = bus_x.size(0) - for feat_idx in (VM_H, VA_H, QG_H): + for feat_idx in (PD_H, QD_H, VM_H, VA_H, QG_H): mask_bus[:, feat_idx] = torch.rand(n_bus) < self.mask_ratio # Random gen mask on PG diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 11601a66..61894045 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -306,6 +306,8 @@ def inverse_output(self, output, batch): gen_output = output["gen"] bus_output[:, PG_OUT] *= self.baseMVA bus_output[:, QG_OUT] *= self.baseMVA + bus_output[:, PD_OUT] *= self.baseMVA # for random masking + bus_output[:, QD_OUT] *= self.baseMVA # for random masking gen_output[:, PG_OUT_GEN] *= self.baseMVA def get_stats(self) -> dict: diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index eaaf11fa..5e20f18f 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -6,6 +6,8 @@ from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from gridfm_graphkit.models.grit_layer import GritTransformerLayer from gridfm_graphkit.models.kernel_pos_encoder import RWSENodeEncoder +from torch_scatter import scatter_add +from gridfm_graphkit.datasets.globals import PG_H class BatchNorm1dNode(torch.nn.Module): @@ -242,6 +244,48 @@ def forward(self, batch): return batch +def aggregate_pg(batch, mask_value=-1.0): + """Aggregate per-generator active power (PG) onto bus nodes. + + In the homogeneous reference, PG is a direct bus feature visible to the + transformer alongside Pd, Qd, Vm, Va, etc. In the heterogeneous + representation PG lives on separate generator nodes, so it must be + aggregated onto buses before the transformer can learn voltage-power + coupling. 
+ + Masked generators (where PG has been replaced by the mask value) are + excluded from the sum to avoid corrupting the aggregated signal. Buses + where *all* connected generators are masked receive the mask value + instead, preserving a consistent "unknown" indicator. + """ + gen_to_bus = batch["gen", "connected_to", "bus"].edge_index + gen_pg = batch["gen"].x[:, PG_H] + gen_masked = batch.mask_dict["gen"][:, PG_H] # True = masked + + # Zero out masked generators so they don't contribute to the sum + pg_clean = torch.where(gen_masked, torch.zeros_like(gen_pg), gen_pg) + + pg_per_bus = scatter_add( + pg_clean, + gen_to_bus[1], + dim=0, + dim_size=batch["bus"].x.size(0), + ) + + # Check which buses have ALL generators masked (or no generators at all) + unmasked_count = scatter_add( + (~gen_masked).float(), + gen_to_bus[1], + dim=0, + dim_size=batch["bus"].x.size(0), + ) + all_masked = unmasked_count == 0 + + # Set mask_value for fully-masked buses + pg_per_bus[all_masked] = mask_value + + return pg_per_bus + @MODELS_REGISTRY.register("GRIT") class GritHeteroAdapter(torch.nn.Module): @@ -321,8 +365,12 @@ def forward(self, batch): predicted output features. 
""" # --- Extract bus-only homogeneous subgraph --- + # Aggregate generator PG onto buses + pg_per_bus = aggregate_pg(batch, mask_value=self.grit.mask_value[0].item()) + bus_x = torch.cat([batch["bus"].x, pg_per_bus.unsqueeze(-1)], dim=-1) # 15 → 16D + homo = Data( - x=batch["bus"].x, + x=bus_x, y=batch["bus"].y, edge_index=batch["bus", "connects", "bus"].edge_index, edge_attr=batch["bus", "connects", "bus"].edge_attr, diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 02df3bc7..f28370d0 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -18,6 +18,8 @@ VA_OUT, QG_OUT, PG_OUT, + PD_OUT, + QD_OUT, # Generator feature indices PG_H, # Edge feature indices @@ -142,6 +144,81 @@ def forward( return {"loss": loss, "Masked bus MSE loss": loss.detach()} +@LOSS_REGISTRY.register("MaskedReconstructionMSE") +class MaskedReconstructionMSE(BaseLoss): + """Unified masked MSE over bus-level quantities [VM, VA, PG, QG, PD, QD]. + + Mirrors the homogeneous reference MaskedMSE by combining bus predictions + and aggregated generator PG into a single prediction/target/mask tensor. + PG targets are aggregated from generator ground truth onto buses via + scatter_add; the bus-level PG mask is True when any generator at the bus + is masked, indicating that the model must reconstruct that quantity. + + Replaces the separate MaskedBusMSE + MaskedGenMSE pair. + Requires output_bus_dim >= 6 so the bus head predicts + [VM, VA, PG, QG, PD, QD]. 
+ """ + + def __init__(self, loss_args, args): + super().__init__() + self.reduction = "mean" + + def forward( + self, + pred_dict, + target_dict, + edge_index_dict, + edge_attr_dict, + mask_dict, + model=None, + ): + pred_bus = pred_dict["bus"] + target_bus = target_dict["bus"] + num_bus = target_bus.size(0) + gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] + + # --- Build target: [VM, VA, PG_agg, QG, PD, QD] --- + target_pg_agg = scatter_add( + target_dict["gen"][:, PG_H], + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + target = torch.stack([ + target_bus[:, VM_H], + target_bus[:, VA_H], + target_pg_agg, + target_bus[:, QG_H], + target_bus[:, PD_H], + target_bus[:, QD_H], + ], dim=1) + + # --- Build mask: [N_bus, 6] --- + # PG bus-level mask: True if any generator at the bus has PG masked + gen_pg_masked = mask_dict["gen"][:, PG_H].float() + any_gen_masked = scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) > 0 + + mask = torch.stack([ + mask_dict["bus"][:, VM_H], + mask_dict["bus"][:, VA_H], + any_gen_masked, + mask_dict["bus"][:, QG_H], + mask_dict["bus"][:, PD_H], + mask_dict["bus"][:, QD_H], + ], dim=1) + + # --- Prediction: [VM, VA, PG, QG, PD, QD] from bus head --- + pred = pred_bus[:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT, PD_OUT, QD_OUT]] + + loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) + return {"loss": loss, "Masked reconstruction MSE loss": loss.detach()} + + @LOSS_REGISTRY.register("MSE") class MSELoss(BaseLoss): """Standard Mean Squared Error loss.""" From 55182a65f1dc6040e628945fa8418613711944eb Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:25:16 -0400 Subject: [PATCH 52/95] extend random masking to PD QD PG Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff 
--git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 909c616c..88857f1a 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -30,12 +30,12 @@ model: # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) edge_dim: 10 hidden_size: 496 - # input_dim = bus feature count (used by GRIT core FeatureEncoder) - input_dim: 15 + # input_dim = bus feature count + aggregated PG (used by GRIT core FeatureEncoder) + input_dim: 16 # Hetero adapter head dimensions - input_bus_dim: 15 + input_bus_dim: 16 input_gen_dim: 6 - output_bus_dim: 2 + output_bus_dim: 6 # [VM, VA, PG, QG, PD, QD] output_gen_dim: 1 num_layers: 7 type: GRIT @@ -79,16 +79,13 @@ training: epochs: 500 loss_weights: - 0.99 - - 0.001 - - 0.009 + - 0.01 losses: - PBE - - MaskedGenMSE - - MaskedBusMSE + - MaskedReconstructionMSE loss_args: - {} - {} - - {} accelerator: auto devices: auto strategy: auto From 85b4ddfea754f18f662a041c5153e18e066a2033 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:40:18 -0400 Subject: [PATCH 53/95] update PBLoss to support transformer wrapper Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index f28370d0..fe4106b8 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -469,14 +469,23 @@ def forward( S_injection = torch.diag(V) @ Y_bus_conj @ V_conj # --- Net power from predictions/targets --- - # Pg: aggregate generator predictions onto buses + # Pg: use bus head prediction where masked, ground truth where known. + # Ground truth is aggregated from generator targets onto buses. 
gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] - Pg_per_bus = scatter_add( - pred_dict["gen"].squeeze(-1), + target_pg_agg = scatter_add( + target_dict["gen"][:, PG_H], gen_to_bus_ei[1], dim=0, dim_size=num_bus, ) + gen_pg_masked = mask_dict["gen"][:, PG_H].float() + any_gen_masked = scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) > 0 + Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg) Pd = target_bus[:, PD_H] Qd = target_bus[:, QD_H] From 9e2a3133661f21e1efb1bfdfd8bd67a3542f7816 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:55:23 -0400 Subject: [PATCH 54/95] update PBLoss to support transformer wrapper Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 2 +- gridfm_graphkit/models/grit_transformer.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 88857f1a..cef4b503 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -36,7 +36,7 @@ model: input_bus_dim: 16 input_gen_dim: 6 output_bus_dim: 6 # [VM, VA, PG, QG, PD, QD] - output_gen_dim: 1 + output_gen_dim: 0 # PG predicted at bus level; no per-generator head needed num_layers: 7 type: GRIT act: relu diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 5e20f18f..3422a21e 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -347,11 +347,17 @@ def __init__(self, args): nn.LeakyReLU(), nn.Linear(dim_inner, output_bus_dim), ) - self.gen_head = nn.Sequential( - nn.Linear(input_gen_dim, dim_inner), - nn.LeakyReLU(), - nn.Linear(dim_inner, output_gen_dim), - ) + # gen_head is only needed for tasks that 
require per-generator + # predictions (e.g. OPF cost computation). When output_gen_dim is 0 + # or not set, skip it to avoid DDP unused-parameter errors. + if output_gen_dim and output_gen_dim > 0: + self.gen_head = nn.Sequential( + nn.Linear(input_gen_dim, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, output_gen_dim), + ) + else: + self.gen_head = None def forward(self, batch): """Forward pass on a heterogeneous power-grid batch. @@ -392,6 +398,6 @@ def forward(self, batch): # --- Per-type decoding --- bus_out = self.bus_head(homo.x) - gen_out = self.gen_head(batch["gen"].x) + gen_out = self.gen_head(batch["gen"].x) if self.gen_head is not None else batch["gen"].x return {"bus": bus_out, "gen": gen_out} From 6f112d89e4a99945ddd4a678afc9c63a86558b81 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 13:14:28 -0400 Subject: [PATCH 55/95] patch for admittance matrix indicies in PBLoss Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index fe4106b8..819b2e66 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -455,8 +455,13 @@ def forward( V_conj = torch.conj(V) # --- Admittance matrix from bus-bus edge attrs --- - # Use Yff (diagonal-block) real/imag as the admittance entries - edge_complex = bus_edge_attr[:, YFF_TT_R] + 1j * bus_edge_attr[:, YFF_TT_I] + # Off-diagonal entries of Y-bus: Y[from][to] = Yft, Y[to][from] = Ytf. + # The dataset stores forward edges with Yft at YFT_TF columns and + # reverse edges with Ytf at the same columns, so indexing YFT_TF_R/I + # gives the correct off-diagonal admittance for both directions. + # (YFF_TT columns hold diagonal-block entries Yff/Ytt which belong on + # the Y-bus diagonal, not at off-diagonal edge positions.) 
+ edge_complex = bus_edge_attr[:, YFT_TF_R] + 1j * bus_edge_attr[:, YFT_TF_I] Y_bus_sparse = to_torch_coo_tensor( bus_edge_index, From b5f218ee353270ace3b80d90ba2632deb872310b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 13:46:53 -0400 Subject: [PATCH 56/95] patch for admittance matrix indicies in PBLoss Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 50 ++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 819b2e66..bb8cf8d2 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -455,17 +455,49 @@ def forward( V_conj = torch.conj(V) # --- Admittance matrix from bus-bus edge attrs --- - # Off-diagonal entries of Y-bus: Y[from][to] = Yft, Y[to][from] = Ytf. - # The dataset stores forward edges with Yft at YFT_TF columns and - # reverse edges with Ytf at the same columns, so indexing YFT_TF_R/I - # gives the correct off-diagonal admittance for both directions. - # (YFF_TT columns hold diagonal-block entries Yff/Ytt which belong on - # the Y-bus diagonal, not at off-diagonal edge positions.) - edge_complex = bus_edge_attr[:, YFT_TF_R] + 1j * bus_edge_attr[:, YFT_TF_I] + # The Y-bus matrix has off-diagonal AND diagonal entries. + # + # Off-diagonal: Y[from][to] = Yft, Y[to][from] = Ytf, stored in the + # YFT_TF columns of the edge attributes. + # + # Diagonal: Y[k][k] = sum of Yff/Ytt for all branches at bus k. + # The dataset stores Yff (forward edges) and Ytt (reverse edges) in + # the YFF_TT columns. For every edge, YFF_TT at the *source* bus + # gives that branch's diagonal contribution at the source. Summing + # over all edges with source == k yields the full branch-diagonal. 
+ # + # The reference project loads a pre-built Y-bus (y_bus_data.parquet) + # that includes self-loops for diagonal entries. Here we reconstruct + # the same structure from per-branch pi-model parameters. + + # Off-diagonal admittance values + edge_offdiag = bus_edge_attr[:, YFT_TF_R] + 1j * bus_edge_attr[:, YFT_TF_I] + + # Diagonal: aggregate Yff/Ytt to source bus of each edge + Y_diag_r = scatter_add( + bus_edge_attr[:, YFF_TT_R], + bus_edge_index[0], + dim=0, + dim_size=num_bus, + ) + Y_diag_i = scatter_add( + bus_edge_attr[:, YFF_TT_I], + bus_edge_index[0], + dim=0, + dim_size=num_bus, + ) + Y_diag = Y_diag_r + 1j * Y_diag_i + + # Build complete Y-bus: off-diagonal edges + self-loops for diagonal + diag_idx = torch.arange(num_bus, device=bus_edge_index.device) + full_edge_index = torch.cat( + [bus_edge_index, torch.stack([diag_idx, diag_idx])], dim=1, + ) + full_edge_values = torch.cat([edge_offdiag, Y_diag]) Y_bus_sparse = to_torch_coo_tensor( - bus_edge_index, - edge_complex, + full_edge_index, + full_edge_values, size=(num_bus, num_bus), ) Y_bus_conj = torch.conj(Y_bus_sparse) From 823438f13ef8cc87fe601f45627883fbb814f7cf Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 14:21:47 -0400 Subject: [PATCH 57/95] PBLoss support for random masking Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index bb8cf8d2..1966a39d 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -524,8 +524,18 @@ def forward( ) > 0 Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg) - Pd = target_bus[:, PD_H] - Qd = target_bus[:, QD_H] + # Pd, Qd: use prediction where masked, target where known. 
+ # For deterministic PF/OPF masks PD/QD are never masked, so this + # is equivalent to always using target. For random masking this + # lets PBE provide gradient signal for PD/QD reconstruction. + if pred_bus.size(1) > PD_OUT: + Pd = torch.where(mask_bus[:, PD_H], pred_bus[:, PD_OUT], target_bus[:, PD_H]) + else: + Pd = target_bus[:, PD_H] + if pred_bus.size(1) > QD_OUT: + Qd = torch.where(mask_bus[:, QD_H], pred_bus[:, QD_OUT], target_bus[:, QD_H]) + else: + Qd = target_bus[:, QD_H] # Qg: use prediction if the model predicts it, else use target if pred_bus.size(1) > QG_OUT: From 821a40ca2e83051fe8eff49ab79d5f49887b1aec Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 31 Mar 2026 13:02:40 -0400 Subject: [PATCH 58/95] slice mse features Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/opf_task.py | 11 ++++++++--- gridfm_graphkit/tasks/pf_task.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index b28c5a0c..e2f0bcd1 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -233,18 +233,23 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Active Power Loss"] = final_residual_real_bus.detach() loss_dict["Reactive Power Loss"] = final_residual_imag_bus.detach() + # Slice output to the 4 target columns [VM, VA, PG, QG] so that + # models with wider bus output (e.g. GRIT with output_bus_dim=6) + # are compared correctly against the 4-column target. 
+ output_bus_metrics = output["bus"][:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT]] + mse_PQ = F.mse_loss( - output["bus"][mask_PQ], + output_bus_metrics[mask_PQ], target[mask_PQ], reduction="none", ) mse_PV = F.mse_loss( - output["bus"][mask_PV], + output_bus_metrics[mask_PV], target[mask_PV], reduction="none", ) mse_REF = F.mse_loss( - output["bus"][mask_REF], + output_bus_metrics[mask_REF], target[mask_REF], reduction="none", ) diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index cdc9d646..c17696e8 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -166,18 +166,23 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["PBE Mean"] = pbe_mean.detach() + # Slice output to the 4 target columns [VM, VA, PG, QG] so that + # models with wider bus output (e.g. GRIT with output_bus_dim=6) + # are compared correctly against the 4-column target. + output_bus_metrics = output["bus"][:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT]] + mse_PQ = F.mse_loss( - output["bus"][mask_PQ], + output_bus_metrics[mask_PQ], target[mask_PQ], reduction="none", ) mse_PV = F.mse_loss( - output["bus"][mask_PV], + output_bus_metrics[mask_PV], target[mask_PV], reduction="none", ) mse_REF = F.mse_loss( - output["bus"][mask_REF], + output_bus_metrics[mask_REF], target[mask_REF], reduction="none", ) From 9dabcbf8f11cd75c4e1e338e3806d1c4f0ba11da Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 31 Mar 2026 13:41:37 -0400 Subject: [PATCH 59/95] adjust denorm for GRIT Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/normalizers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 61894045..01936a21 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -20,6 +20,8 @@ 
# Output feature indices PG_OUT, QG_OUT, + PD_OUT, + QD_OUT, PG_OUT_GEN, # Generator feature indices PG_H, @@ -306,8 +308,10 @@ def inverse_output(self, output, batch): gen_output = output["gen"] bus_output[:, PG_OUT] *= self.baseMVA bus_output[:, QG_OUT] *= self.baseMVA - bus_output[:, PD_OUT] *= self.baseMVA # for random masking - bus_output[:, QD_OUT] *= self.baseMVA # for random masking + if bus_output.size(1) > PD_OUT: + bus_output[:, PD_OUT] *= self.baseMVA + if bus_output.size(1) > QD_OUT: + bus_output[:, QD_OUT] *= self.baseMVA gen_output[:, PG_OUT_GEN] *= self.baseMVA def get_stats(self) -> dict: @@ -608,6 +612,10 @@ def inverse_output(self, output, batch): # Scale per-unit power back to MW/Mvar bus_output[:, PG_OUT] *= b_bus bus_output[:, QG_OUT] *= b_bus + if bus_output.size(1) > PD_OUT: + bus_output[:, PD_OUT] *= b_bus + if bus_output.size(1) > QD_OUT: + bus_output[:, QD_OUT] *= b_bus gen_output[:, PG_OUT_GEN] *= b_gen def get_stats(self) -> dict: From f9611037cdba89f61ec332b9859c4e165fd2de6e Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:46:56 -0400 Subject: [PATCH 60/95] update mse Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 1966a39d..960e3ca1 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -100,9 +100,12 @@ def forward( mask_dict, model=None, ): + gen_pred = pred_dict["gen"][:, : (PG_H + 1)] + gen_target = target_dict["gen"][:, : (PG_H + 1)] + mask = mask_dict["gen"][:, : (PG_H + 1)] loss = F.mse_loss( - pred_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], - target_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], + gen_pred[mask], + gen_target[mask], reduction=self.reduction, ) return {"loss": loss, "Masked generator MSE loss": 
loss.detach()} From 5738c9533f54f78f8cf086dbf82518f956e4ad69 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 1 Apr 2026 09:51:47 -0400 Subject: [PATCH 61/95] clamp unsupervised values in evaluate Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/pf_task.py | 80 +++++++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index c17696e8..74d6bee9 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -5,6 +5,8 @@ QG_H, VM_H, VA_H, + # Generator feature indices + PG_H, # Output feature indices VM_OUT, VA_OUT, @@ -80,12 +82,35 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): # UN-COMMENT THIS TO CHECK PBE ON GROUND TRUTH # output["bus"] = target - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) + # Clamp known (unmasked) values to ground truth, matching the + # PBELoss convention used during training. The model is only + # responsible for predicting masked unknowns; using raw predictions + # for known quantities would inflate the residual. 
+ mask_bus = batch.mask_dict["bus"] + eval_bus = output["bus"].clone() + eval_bus[:, VM_OUT] = torch.where( + mask_bus[:, VM_H], output["bus"][:, VM_OUT], target[:, VM_OUT], + ) + eval_bus[:, VA_OUT] = torch.where( + mask_bus[:, VA_H], output["bus"][:, VA_OUT], target[:, VA_OUT], + ) + gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() + any_gen_masked = ( + scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 + ) + eval_bus[:, PG_OUT] = torch.where( + any_gen_masked, output["bus"][:, PG_OUT], target[:, PG_OUT], + ) + eval_bus[:, QG_OUT] = torch.where( + mask_bus[:, QG_H], output["bus"][:, QG_OUT], target[:, QG_OUT], + ) + + Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( P_in, Q_in, - output["bus"], + eval_bus, batch.x_dict["bus"], ) @@ -368,13 +393,52 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): num_bus = batch.x_dict["bus"].size(0) bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] + _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] + + agg_gen_on_bus = scatter_add( + batch.y_dict["gen"], + gen_to_bus_index, + dim=0, + dim_size=num_bus, + ) + + # Build target for clamping known (unmasked) values + target = torch.stack( + [ + batch.y_dict["bus"][:, VM_H], + batch.y_dict["bus"][:, VA_H], + agg_gen_on_bus.squeeze(), + batch.y_dict["bus"][:, QG_H], + ], + dim=1, + ) - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) + # Clamp known values to ground truth (same as PBELoss during training) + mask_bus = batch.mask_dict["bus"] + eval_bus = output["bus"].clone() + eval_bus[:, VM_OUT] = torch.where( + mask_bus[:, VM_H], output["bus"][:, VM_OUT], target[:, VM_OUT], + ) + eval_bus[:, VA_OUT] = torch.where( + mask_bus[:, VA_H], output["bus"][:, VA_OUT], target[:, 
VA_OUT], + ) + gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() + any_gen_masked = ( + scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 + ) + eval_bus[:, PG_OUT] = torch.where( + any_gen_masked, output["bus"][:, PG_OUT], target[:, PG_OUT], + ) + eval_bus[:, QG_OUT] = torch.where( + mask_bus[:, QG_H], output["bus"][:, QG_OUT], target[:, QG_OUT], + ) + + Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( P_in, Q_in, - output["bus"], + eval_bus, batch.x_dict["bus"], ) residual_P = torch.abs(residual_P) @@ -396,14 +460,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): mask_PV = batch.mask_dict["PV"] mask_REF = batch.mask_dict["REF"] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - return { "scenario": scenario_ids.cpu().numpy(), "bus": local_bus_idx.cpu().numpy(), From b661b85acc756badb29b3b15dd7093eedac11b56 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:42:17 -0400 Subject: [PATCH 62/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 3 - gridfm_graphkit/tasks/pf_task.py | 145 +++++++++------------ gridfm_graphkit/training/loss.py | 19 ++- 3 files changed, 76 insertions(+), 91 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 3422a21e..51767e24 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -236,9 +236,6 @@ def forward(self, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. 
""" - # print('xxxx',batch.x.min(), batch.x.max()) - # print('yyyyy',batch.y.min(), batch.y.max()) - # print('>>>>', batch) for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index 74d6bee9..46be6998 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -36,11 +36,67 @@ import pandas as pd +def _build_bus_target(batch, num_bus): + """Build a 4-column bus-level target tensor [VM, VA, PG_agg, QG]. + + Generator PG is aggregated onto buses via scatter_add so that the + target layout matches the bus head output columns. + """ + _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] + agg_gen_on_bus = scatter_add( + batch.y_dict["gen"], + gen_to_bus_index, + dim=0, + dim_size=num_bus, + ) + target = torch.stack( + [ + batch.y_dict["bus"][:, VM_H], + batch.y_dict["bus"][:, VA_H], + agg_gen_on_bus.squeeze(), + batch.y_dict["bus"][:, QG_H], + ], + dim=1, + ) + return target, gen_to_bus_index, agg_gen_on_bus + + +def _clamp_known_to_ground_truth(output_bus, target, batch, gen_to_bus_index, num_bus): + """Replace predicted values with ground truth for known (unmasked) quantities. + + During both training (PBELoss) and evaluation, the model is only + responsible for predicting masked unknowns. Known quantities (e.g. + VM at PV buses, VA at REF, PG at non-slack generators) are clamped to + ground truth so that prediction errors on non-target outputs do not + pollute the power-balance residual. 
+ """ + mask_bus = batch.mask_dict["bus"] + eval_bus = output_bus.clone() + eval_bus[:, VM_OUT] = torch.where( + mask_bus[:, VM_H], output_bus[:, VM_OUT], target[:, VM_OUT], + ) + eval_bus[:, VA_OUT] = torch.where( + mask_bus[:, VA_H], output_bus[:, VA_OUT], target[:, VA_OUT], + ) + gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() + any_gen_masked = ( + scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 + ) + eval_bus[:, PG_OUT] = torch.where( + any_gen_masked, output_bus[:, PG_OUT], target[:, PG_OUT], + ) + eval_bus[:, QG_OUT] = torch.where( + mask_bus[:, QG_H], output_bus[:, QG_OUT], target[:, QG_OUT], + ) + return eval_bus + + @TASK_REGISTRY.register("PowerFlow") class PowerFlowTask(ReconstructionTask): """ - Concrete Optimal Power Flow task. - Extends ReconstructionTask and adds OPF-specific metrics. + Concrete Power Flow task. + Extends ReconstructionTask and adds PF-specific evaluation metrics + (power balance residuals, per-bus-type RMSE). """ def __init__(self, args, data_normalizers): @@ -60,49 +116,10 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): num_bus = batch.x_dict["bus"].size(0) bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - # output_agg = torch.cat([batch.y_dict["bus"], agg_gen_on_bus], dim=1) - target = torch.stack( - [ - batch.y_dict["bus"][:, VM_H], - batch.y_dict["bus"][:, VA_H], - agg_gen_on_bus.squeeze(), - batch.y_dict["bus"][:, QG_H], - ], - dim=1, - ) - # UN-COMMENT THIS TO CHECK PBE ON GROUND TRUTH - # output["bus"] = target - - # Clamp known (unmasked) values to ground truth, matching the - # PBELoss convention used during training. 
The model is only - # responsible for predicting masked unknowns; using raw predictions - # for known quantities would inflate the residual. - mask_bus = batch.mask_dict["bus"] - eval_bus = output["bus"].clone() - eval_bus[:, VM_OUT] = torch.where( - mask_bus[:, VM_H], output["bus"][:, VM_OUT], target[:, VM_OUT], - ) - eval_bus[:, VA_OUT] = torch.where( - mask_bus[:, VA_H], output["bus"][:, VA_OUT], target[:, VA_OUT], - ) - gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() - any_gen_masked = ( - scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 - ) - eval_bus[:, PG_OUT] = torch.where( - any_gen_masked, output["bus"][:, PG_OUT], target[:, PG_OUT], - ) - eval_bus[:, QG_OUT] = torch.where( - mask_bus[:, QG_H], output["bus"][:, QG_OUT], target[:, QG_OUT], + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], target, batch, gen_to_bus_index, num_bus, ) Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) @@ -393,44 +410,10 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): num_bus = batch.x_dict["bus"].size(0) bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - - # Build target for clamping known (unmasked) values - target = torch.stack( - [ - batch.y_dict["bus"][:, VM_H], - batch.y_dict["bus"][:, VA_H], - agg_gen_on_bus.squeeze(), - batch.y_dict["bus"][:, QG_H], - ], - dim=1, - ) - # Clamp known values to ground truth (same as PBELoss during training) - mask_bus = batch.mask_dict["bus"] - eval_bus = output["bus"].clone() - eval_bus[:, VM_OUT] = torch.where( - mask_bus[:, VM_H], output["bus"][:, VM_OUT], target[:, VM_OUT], - ) - eval_bus[:, VA_OUT] = torch.where( 
- mask_bus[:, VA_H], output["bus"][:, VA_OUT], target[:, VA_OUT], - ) - gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() - any_gen_masked = ( - scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 - ) - eval_bus[:, PG_OUT] = torch.where( - any_gen_masked, output["bus"][:, PG_OUT], target[:, PG_OUT], - ) - eval_bus[:, QG_OUT] = torch.where( - mask_bus[:, QG_H], output["bus"][:, QG_OUT], target[:, QG_OUT], + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], target, batch, gen_to_bus_index, num_bus, ) Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 960e3ca1..d700b89b 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -441,7 +441,14 @@ def forward( bus_edge_attr = edge_attr_dict[("bus", "connects", "bus")] mask_bus = mask_dict["bus"] - # --- Voltage: use prediction where masked, target where known --- + # --- Clamp known values to ground truth --- + # In power flow, certain variables are "known" (unmasked) at each + # bus type (e.g. VM at PV buses, VA at REF). The model only needs + # to predict *masked* unknowns; for everything else we substitute + # the ground truth so that errors in non-target outputs do not + # pollute the physics loss. This matches the reference's + # ``temp_pred[unmasked] = target[unmasked]`` convention. + Vm_pred = pred_bus[:, VM_OUT] Va_pred = pred_bus[:, VA_OUT] Vm_target = target_bus[:, VM_H] @@ -527,10 +534,10 @@ def forward( ) > 0 Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg) - # Pd, Qd: use prediction where masked, target where known. - # For deterministic PF/OPF masks PD/QD are never masked, so this - # is equivalent to always using target. For random masking this - # lets PBE provide gradient signal for PD/QD reconstruction. 
+ # Pd, Qd, Qg: same clamp-to-ground-truth logic. The size guard + # (``pred_bus.size(1) > *_OUT``) handles models with a narrow bus + # head (e.g. output_bus_dim=4) that don't predict PD/QD/QG; in that + # case the target is always used. if pred_bus.size(1) > PD_OUT: Pd = torch.where(mask_bus[:, PD_H], pred_bus[:, PD_OUT], target_bus[:, PD_H]) else: @@ -539,8 +546,6 @@ def forward( Qd = torch.where(mask_bus[:, QD_H], pred_bus[:, QD_OUT], target_bus[:, QD_H]) else: Qd = target_bus[:, QD_H] - - # Qg: use prediction if the model predicts it, else use target if pred_bus.size(1) > QG_OUT: Qg = torch.where(mask_bus[:, QG_H], pred_bus[:, QG_OUT], target_bus[:, QG_H]) else: From 297857b1f76b2f14633db1ea36a76883d5820c00 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:49:02 -0400 Subject: [PATCH 63/95] include shunt admittance in pb loss Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/reconstruction_tasks.py | 1 + gridfm_graphkit/training/loss.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 43e243c0..f6c4501c 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -52,6 +52,7 @@ def shared_step(self, batch): batch.edge_attr_dict, batch.mask_dict, model=self.model, + x_dict=batch.x_dict, ) return output, loss_dict diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index d700b89b..a9b6e85f 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -13,6 +13,8 @@ VA_H, QD_H, PD_H, + GS, + BS, # Output feature indices VM_OUT, VA_OUT, @@ -271,6 +273,7 @@ def forward( edge_attr=None, mask=None, model=None, + **kwargs, ): """ Compute the weighted sum of all specified losses. 
@@ -297,6 +300,7 @@ def forward( edge_attr, mask, model, + **kwargs, ) # Assume each loss function returns a dictionary with a "loss" key @@ -432,6 +436,7 @@ def forward( edge_attr_dict, mask_dict, model=None, + x_dict=None, ): pred_bus = pred_dict["bus"] # [N_bus, output_bus_dim] target_bus = target_dict["bus"] # [N_bus, bus_feat_dim] @@ -498,6 +503,11 @@ def forward( ) Y_diag = Y_diag_r + 1j * Y_diag_i + # Add bus shunt admittance (Gs + jBs) to the diagonal + if x_dict is not None: + bus_orig = x_dict["bus"] + Y_diag = Y_diag + bus_orig[:, GS] + 1j * bus_orig[:, BS] + # Build complete Y-bus: off-diagonal edges + self-loops for diagonal diag_idx = torch.arange(num_bus, device=bus_edge_index.device) full_edge_index = torch.cat( From 03c59a02ff6ffe1e23de058ba8e9cbeb8291ec7f Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:49:03 -0400 Subject: [PATCH 64/95] include shunt admittance in pb loss Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index a9b6e85f..69bc10b6 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -46,6 +46,7 @@ def forward( edge_attr=None, mask=None, model=None, + **kwargs, ): """ Compute the loss. 
@@ -82,6 +83,7 @@ def forward( edge_attr=None, mask=None, model=None, + **kwargs, ): loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) return {"loss": loss, "Masked MSE loss": loss.detach()} @@ -101,6 +103,7 @@ def forward( edge_attr, mask_dict, model=None, + **kwargs, ): gen_pred = pred_dict["gen"][:, : (PG_H + 1)] gen_target = target_dict["gen"][:, : (PG_H + 1)] @@ -128,6 +131,7 @@ def forward( edge_attr, mask_dict, model=None, + **kwargs, ): if self.args.task == "OptimalPowerFlow": pred_cols = [VM_OUT, VA_OUT, QG_OUT] @@ -176,6 +180,7 @@ def forward( edge_attr_dict, mask_dict, model=None, + **kwargs, ): pred_bus = pred_dict["bus"] target_bus = target_dict["bus"] @@ -240,6 +245,7 @@ def forward( edge_attr=None, mask=None, model=None, + **kwargs, ): loss = F.mse_loss(pred, target, reduction=self.reduction) return {"loss": loss, "MSE loss": loss.detach()} @@ -331,6 +337,7 @@ def forward( edge_attr=None, mask=None, model=None, + **kwargs, ): total_loss = 0.0 loss_details = {} @@ -381,6 +388,7 @@ def forward( edge_attr, mask_dict, model=None, + **kwargs, ): if self.dim == "VM": temp_pred = pred_dict["bus"][:, VM_OUT] From a7a63cebce5ff045a683bb16f427a80c6f34efbc Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:54:25 -0400 Subject: [PATCH 65/95] pre-merge checks Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- .github/workflows/ci-build.yaml | 4 +- gridfm_graphkit/datasets/globals.py | 4 +- .../datasets/hetero_powergrid_datamodule.py | 14 +- gridfm_graphkit/datasets/posenc_stats.py | 97 +++++---- gridfm_graphkit/datasets/rrwp.py | 41 ++-- gridfm_graphkit/models/__init__.py | 2 - gridfm_graphkit/models/grit_layer.py | 191 +++++++++++------- gridfm_graphkit/models/grit_transformer.py | 118 ++++++----- gridfm_graphkit/models/kernel_pos_encoder.py | 33 +-- gridfm_graphkit/models/rrwp_encoder.py | 87 +++++--- gridfm_graphkit/tasks/pf_task.py | 28 ++- 
gridfm_graphkit/training/loss.py | 96 +++++---- pyproject.toml | 4 +- scripts/benchmark_model_inference.py | 182 ++++++++++++----- 14 files changed, 571 insertions(+), 330 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 01b8b422..453d25e0 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -67,11 +67,11 @@ jobs: python3 -m pip install --upgrade pip python3 -m pip install --upgrade "git+https://github.com/ibm/detect-secrets.git@master#egg=detect-secrets" python3 -m pip install boxsdk - + - name: Scan repository & write snapshot run: | mkdir -p security-outputs - + # Run detect-secrets while skipping binary files detect-secrets scan \ --exclude-files '.*\.ipynb$|.*\.(png|jpg|jpeg|gif|pdf|onnx|pt|pth|bin|zip)$' \ diff --git a/gridfm_graphkit/datasets/globals.py b/gridfm_graphkit/datasets/globals.py index 9627cd8e..c62cfae7 100644 --- a/gridfm_graphkit/datasets/globals.py +++ b/gridfm_graphkit/datasets/globals.py @@ -24,8 +24,8 @@ VA_OUT = 1 PG_OUT = 2 QG_OUT = 3 -PD_OUT = 4 # for random masking -QD_OUT = 5 # for random masking +PD_OUT = 4 # for random masking +QD_OUT = 5 # for random masking PG_OUT_GEN = 0 diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index e6a4cfd1..d14537fa 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -154,19 +154,15 @@ def setup(self, stage: str): data_normalizer=data_normalizer, transform=get_task_transforms(args=self.args), ) - - if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable: - pe_transform = ComputePosencStat(pe_types=['RRWP'], - cfg=self.args.data - ) + + if ("posenc_RRWP" in self.args.data) and self.args.data.posenc_RRWP.enable: + pe_transform = ComputePosencStat(pe_types=["RRWP"], cfg=self.args.data) if dataset.transform is None: dataset.transform = pe_transform 
else: dataset.transform = T.Compose([pe_transform, dataset.transform]) - if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable: - pe_transform = ComputePosencStat(pe_types=['RWSE'], - cfg=self.args.data - ) + if ("posenc_RWSE" in self.args.data) and self.args.data.posenc_RWSE.enable: + pe_transform = ComputePosencStat(pe_types=["RWSE"], cfg=self.args.data) if dataset.transform is None: dataset.transform = pe_transform else: diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 5263b488..cc780ee5 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -4,8 +4,12 @@ import torch import torch.nn.functional as F -from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix, - to_undirected, to_dense_adj) +from torch_geometric.utils import ( + get_laplacian, + to_scipy_sparse_matrix, + to_undirected, + to_dense_adj, +) from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_scatter import scatter_add from functools import partial @@ -17,9 +21,10 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes + def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. - Supported PE statistics to precompute in original implementation, + Supported PE statistics to precompute in original implementation, selected by `pe_types`: 'LapPE': Laplacian eigen-decomposition. 'RWSE': Random walk landing probabilities (diagonals of RW matrices). @@ -39,42 +44,55 @@ def compute_posenc_stats(data, pe_types, cfg): """ # Verify PE types. 
for t in pe_types: - if t not in ['LapPE', 'EquivStableLapPE', 'SignNet', - 'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']: + if t not in [ + "LapPE", + "EquivStableLapPE", + "SignNet", + "RWSE", + "HKdiagSE", + "HKfullPE", + "ElstaticSE", + "RRWP", + ]: raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") - if 'RRWP' in pe_types: + if "RRWP" in pe_types: param = cfg.posenc_RRWP - transform = partial(add_full_rrwp, - walk_length=param.ksteps, - attr_name_abs="rrwp", - attr_name_rel="rrwp", - add_identity=True - ) + transform = partial( + add_full_rrwp, + walk_length=param.ksteps, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + ) data = transform(data) # Random Walks. - if 'RWSE' in pe_types: + if "RWSE" in pe_types: kernel_param = cfg.posenc_RWSE.kernel - if hasattr(data, 'num_nodes'): + if hasattr(data, "num_nodes"): N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa else: N = data.x.shape[0] # Number of nodes, including disconnected nodes. if kernel_param.times == 0: raise ValueError("List of kernel times required for RWSE") rw_landing = get_rw_landing_probs( - ksteps=[xx + 1 for xx in range(kernel_param.times)], - edge_index=data.edge_index, - num_nodes=N - ) + ksteps=[xx + 1 for xx in range(kernel_param.times)], + edge_index=data.edge_index, + num_nodes=N, + ) data.pestat_RWSE = rw_landing return data - -def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, - num_nodes=None, space_dim=0): +def get_rw_landing_probs( + ksteps, + edge_index, + edge_weight=None, + num_nodes=None, + space_dim=0, +): """Compute Random Walk landing probabilities for given list of K steps. Args: @@ -96,31 +114,36 @@ def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, num_nodes = maybe_num_nodes(edge_index, num_nodes) source, dest = edge_index[0], edge_index[1] deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. - deg_inv = deg.pow(-1.) 
- deg_inv.masked_fill_(deg_inv == float('inf'), 0) + deg_inv = deg.pow(-1.0) + deg_inv.masked_fill_(deg_inv == float("inf"), 0) if edge_index.numel() == 0: P = edge_index.new_zeros((1, num_nodes, num_nodes)) else: # P = D^-1 * A - P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) + P = torch.diag(deg_inv) @ to_dense_adj( + edge_index, + max_num_nodes=num_nodes, + ) # 1 x (Num nodes) x (Num nodes) rws = [] if ksteps == list(range(min(ksteps), max(ksteps) + 1)): # Efficient way if ksteps are a consecutive sequence (most of the time the case) Pk = P.clone().detach().matrix_power(min(ksteps)) for k in range(min(ksteps), max(ksteps) + 1): - rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ - (k ** (space_dim / 2))) + rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * (k ** (space_dim / 2))) Pk = Pk @ P else: # Explicitly raising P to power k for each k \in ksteps. for k in ksteps: - rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ - (k ** (space_dim / 2))) + rws.append( + torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) + * (k ** (space_dim / 2)), + ) rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) return rw_landing + class ComputePosencStat(BaseTransform): def __init__(self, pe_types, cfg): self.pe_types = pe_types @@ -133,10 +156,7 @@ def __call__(self, data) -> Data: if isinstance(data, HeteroData): return self._call_hetero(data) - data = compute_posenc_stats(data, - pe_types=self.pe_types, - cfg=self.cfg - ) + data = compute_posenc_stats(data, pe_types=self.pe_types, cfg=self.cfg) return data def _call_hetero(self, data: HeteroData) -> HeteroData: @@ -150,14 +170,19 @@ def _call_hetero(self, data: HeteroData) -> HeteroData: bus_data.edge_weight = data["bus", "connects", "bus"].edge_weight bus_data = compute_posenc_stats( - bus_data, pe_types=self.pe_types, cfg=self.cfg, + bus_data, + pe_types=self.pe_types, + cfg=self.cfg, ) # Copy computed PE attributes back 
onto the HeteroData bus store pe_attrs = [ - "pestat_RWSE", # RWSE - "rrwp", "rrwp_index", "rrwp_val", # RRWP - "log_deg", "deg", # degree info from RRWP + "pestat_RWSE", # RWSE + "rrwp", + "rrwp_index", + "rrwp_val", # RRWP + "log_deg", + "deg", # degree info from RRWP ] for attr in pe_attrs: if hasattr(bus_data, attr): diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index acbe1120..317515e2 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -5,10 +5,9 @@ from torch_sparse import SparseTensor -def add_node_attr(data: Data, value: Any, - attr_name: Optional[str] = None) -> Data: +def add_node_attr(data: Data, value: Any, attr_name: Optional[str] = None) -> Data: if attr_name is None: - if 'x' in data: + if "x" in data: x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1) else: @@ -19,27 +18,29 @@ def add_node_attr(data: Data, value: Any, return data - @torch.no_grad() -def add_full_rrwp(data, - walk_length=8, - attr_name_abs="rrwp", # name: 'rrwp' - attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val') - add_identity=True, - spd=False, - **kwargs - ): +def add_full_rrwp( + data, + walk_length=8, + attr_name_abs="rrwp", # name: 'rrwp' + attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val') + add_identity=True, + spd=False, + **kwargs, +): num_nodes = data.num_nodes edge_index, edge_weight = data.edge_index, data.edge_weight - adj = SparseTensor.from_edge_index(edge_index, edge_weight, - sparse_sizes=(num_nodes, num_nodes), - ) + adj = SparseTensor.from_edge_index( + edge_index, + edge_weight, + sparse_sizes=(num_nodes, num_nodes), + ) # Compute D^{-1} A: deg = adj.sum(dim=1) deg_inv = 1.0 / adj.sum(dim=1) - deg_inv[deg_inv == float('inf')] = 0 + deg_inv[deg_inv == float("inf")] = 0 adj = adj * deg_inv.view(-1, 1) adj = adj.to_dense() @@ -57,19 +58,18 @@ def add_full_rrwp(data, out = out @ adj pe_list.append(out) - pe = 
torch.stack(pe_list, dim=-1) # n x n x k + pe = torch.stack(pe_list, dim=-1) # n x n x k - abs_pe = pe.diagonal().transpose(0, 1) # n x k + abs_pe = pe.diagonal().transpose(0, 1) # n x k rel_pe = SparseTensor.from_dense(pe, has_value=True) rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo() # rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0) rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0) - # the framework of GRIT performing right-mul while adj is row-normalized, + # the framework of GRIT performing right-mul while adj is row-normalized, # need to switch the order or row and col. # note: both can work but the current version is more reasonable. - if spd: spd_idx = walk_length - torch.arange(walk_length) val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0) @@ -84,4 +84,3 @@ def add_full_rrwp(data, data.deg = deg.type(torch.long) return data - diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index b30680ec..f185c6a2 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,3 @@ - from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous from gridfm_graphkit.models.grit_transformer import GritHeteroAdapter from gridfm_graphkit.models.utils import ( @@ -14,4 +13,3 @@ "PhysicsDecoderPF", "PhysicsDecoderSE", ] - diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index a1ffc4a3..2dc4abe5 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -24,7 +24,7 @@ def pyg_softmax(src, index, num_nodes=None): num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`index`. 
(default: :obj:`None`) - Returns: + Returns: out (Tensor) """ @@ -32,25 +32,31 @@ def pyg_softmax(src, index, num_nodes=None): out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index] out = out.exp() - out = out / ( - scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16) + out = out / (scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16) return out - class MultiHeadAttentionLayerGritSparse(nn.Module): """ Attention Computation for GRIT """ - def __init__(self, in_dim, out_dim, num_heads, use_bias, - clamp=5., dropout=0., act=None, - edge_enhance=True, - sqrt_relu=False, - signed_sqrt=True, - cfg={}, - **kwargs): + def __init__( + self, + in_dim, + out_dim, + num_heads, + use_bias, + clamp=5.0, + dropout=0.0, + act=None, + edge_enhance=True, + sqrt_relu=False, + signed_sqrt=True, + cfg={}, + **kwargs, + ): super().__init__() self.out_dim = out_dim @@ -68,7 +74,10 @@ def __init__(self, in_dim, out_dim, num_heads, use_bias, nn.init.xavier_normal_(self.E.weight) nn.init.xavier_normal_(self.V.weight) - self.Aw = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True) + self.Aw = nn.Parameter( + torch.zeros(self.out_dim, self.num_heads, 1), + requires_grad=True, + ) nn.init.xavier_normal_(self.Aw) if act is None: @@ -77,17 +86,20 @@ def __init__(self, in_dim, out_dim, num_heads, use_bias, self.act = nn.ReLU() if self.edge_enhance: - self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) + self.VeRow = nn.Parameter( + torch.zeros(self.out_dim, self.num_heads, self.out_dim), + requires_grad=True, + ) nn.init.xavier_normal_(self.VeRow) def propagate_attention(self, batch): - src = batch.K_h[batch.edge_index[0]] # (num relative) x num_heads x out_dim - dest = batch.Q_h[batch.edge_index[1]] # (num relative) x num_heads x out_dim - score = src + dest # element-wise multiplication + src = batch.K_h[batch.edge_index[0]] # (num relative) x num_heads x out_dim + dest = 
batch.Q_h[batch.edge_index[1]] # (num relative) x num_heads x out_dim + score = src + dest # element-wise multiplication if batch.get("E", None) is not None: batch.E = batch.E.view(-1, self.num_heads, self.out_dim * 2) - E_w, E_b = batch.E[:, :, :self.out_dim], batch.E[:, :, self.out_dim:] + E_w, E_b = batch.E[:, :, : self.out_dim], batch.E[:, :, self.out_dim :] # (num relative) x num_heads x out_dim score = score * E_w score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score)) @@ -106,14 +118,21 @@ def propagate_attention(self, batch): score = torch.clamp(score, min=-self.clamp, max=self.clamp) raw_attn = score - score = pyg_softmax(score, batch.edge_index[1]) # (num relative) x num_heads x 1 + score = pyg_softmax( + score, + batch.edge_index[1], + ) # (num relative) x num_heads x 1 score = self.dropout(score) batch.attn = score # Aggregate with Attn-Score - msg = batch.V_h[batch.edge_index[0]] * score # (num relative) x num_heads x out_dim - batch.wV = torch.zeros_like(batch.V_h) # (num nodes in batch) x num_heads x out_dim - scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce='add') + msg = ( + batch.V_h[batch.edge_index[0]] * score + ) # (num relative) x num_heads x out_dim + batch.wV = torch.zeros_like( + batch.V_h, + ) # (num nodes in batch) x num_heads x out_dim + scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce="add") if self.edge_enhance and batch.E is not None: rowV = scatter(e_t * score, batch.edge_index[1], dim=0, reduce="add") @@ -135,7 +154,7 @@ def forward(self, batch): batch.V_h = V_h.view(-1, self.num_heads, self.out_dim) self.propagate_attention(batch) h_out = batch.wV - e_out = batch.get('wE', None) + e_out = batch.get("wE", None) return h_out, e_out @@ -144,16 +163,23 @@ class GritTransformerLayer(nn.Module): """ Transformer Layer for GRIT """ - def __init__(self, in_dim, out_dim, num_heads, - dropout=0.0, - attn_dropout=0.0, - layer_norm=False, batch_norm=True, - residual=True, - act='relu', - norm_e=True, - 
O_e=True, - cfg=dict(), - **kwargs): + + def __init__( + self, + in_dim, + out_dim, + num_heads, + dropout=0.0, + attn_dropout=0.0, + layer_norm=False, + batch_norm=True, + residual=True, + act="relu", + norm_e=True, + O_e=True, + cfg=dict(), + **kwargs, + ): super().__init__() self.debug = False @@ -171,12 +197,12 @@ def __init__(self, in_dim, out_dim, num_heads, self.update_e = getattr(cfg.attn, "update_e", True) self.bn_momentum = cfg.attn.bn_momentum self.bn_no_runner = cfg.attn.bn_no_runner - self.rezero = getattr(cfg.attn, "rezero", False) + self.rezero = getattr(cfg.attn, "rezero", False) if act is not None: - self.act = nn.ReLU() + self.act = nn.ReLU() else: - self.act = nn.Identity() + self.act = nn.Identity() if getattr(cfg, "attn", None) is None: cfg.attn = dict() @@ -189,43 +215,43 @@ def __init__(self, in_dim, out_dim, num_heads, num_heads=num_heads, use_bias=getattr(cfg.attn, "use_bias", False), dropout=attn_dropout, - clamp=getattr(cfg.attn, "clamp", 5.), + clamp=getattr(cfg.attn, "clamp", 5.0), act=getattr(cfg.attn, "act", "relu"), edge_enhance=getattr(cfg.attn, "edge_enhance", True), sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), - scaled_attn =getattr(cfg.attn,"scaled_attn", False), + scaled_attn=getattr(cfg.attn, "scaled_attn", False), no_qk=getattr(cfg.attn, "no_qk", False), ) - if getattr(cfg.attn, 'graphormer_attn', False): + if getattr(cfg.attn, "graphormer_attn", False): self.attention = MultiHeadAttentionLayerGraphormerSparse( in_dim=in_dim, out_dim=out_dim // num_heads, num_heads=num_heads, use_bias=getattr(cfg.attn, "use_bias", False), dropout=attn_dropout, - clamp=getattr(cfg.attn, "clamp", 5.), + clamp=getattr(cfg.attn, "clamp", 5.0), act=getattr(cfg.attn, "act", "relu"), edge_enhance=True, sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), - scaled_attn =getattr(cfg.attn, "scaled_attn", False), + scaled_attn=getattr(cfg.attn, 
"scaled_attn", False), no_qk=getattr(cfg.attn, "no_qk", False), ) - - - self.O_h = nn.Linear(out_dim//num_heads * num_heads, out_dim) + self.O_h = nn.Linear(out_dim // num_heads * num_heads, out_dim) if O_e: - self.O_e = nn.Linear(out_dim//num_heads * num_heads, out_dim) + self.O_e = nn.Linear(out_dim // num_heads * num_heads, out_dim) else: self.O_e = nn.Identity() # -------- Deg Scaler Option ------ if self.deg_scaler: - self.deg_coef = nn.Parameter(torch.zeros(1, out_dim//num_heads * num_heads, 2)) + self.deg_coef = nn.Parameter( + torch.zeros(1, out_dim // num_heads * num_heads, 2), + ) nn.init.xavier_normal_(self.deg_coef) if self.layer_norm: @@ -234,8 +260,22 @@ def __init__(self, in_dim, out_dim, num_heads, if self.batch_norm: # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) - self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) - self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if norm_e else nn.Identity() + self.batch_norm1_h = nn.BatchNorm1d( + out_dim, + track_running_stats=not self.bn_no_runner, + eps=1e-5, + momentum=cfg.attn.bn_momentum, + ) + self.batch_norm1_e = ( + nn.BatchNorm1d( + out_dim, + track_running_stats=not self.bn_no_runner, + eps=1e-5, + momentum=cfg.attn.bn_momentum, + ) + if norm_e + else nn.Identity() + ) # FFN for h self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) @@ -245,12 +285,17 @@ def __init__(self, in_dim, out_dim, num_heads, self.layer_norm2_h = nn.LayerNorm(out_dim) if self.batch_norm: - self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) + self.batch_norm2_h = nn.BatchNorm1d( + out_dim, + track_running_stats=not self.bn_no_runner, + eps=1e-5, + momentum=cfg.attn.bn_momentum, + ) if self.rezero: - self.alpha1_h = 
nn.Parameter(torch.zeros(1,1)) - self.alpha2_h = nn.Parameter(torch.zeros(1,1)) - self.alpha1_e = nn.Parameter(torch.zeros(1,1)) + self.alpha1_h = nn.Parameter(torch.zeros(1, 1)) + self.alpha2_h = nn.Parameter(torch.zeros(1, 1)) + self.alpha1_e = nn.Parameter(torch.zeros(1, 1)) def forward(self, batch): h = batch.x @@ -279,19 +324,23 @@ def forward(self, batch): e = self.O_e(e) if self.residual: - if self.rezero: h = h * self.alpha1_h + if self.rezero: + h = h * self.alpha1_h h = h_in1 + h # residual connection if e is not None: - if self.rezero: e = e * self.alpha1_e + if self.rezero: + e = e * self.alpha1_e e = e + e_in1 if self.layer_norm: h = self.layer_norm1_h(h) - if e is not None: e = self.layer_norm1_e(e) + if e is not None: + e = self.layer_norm1_e(e) if self.batch_norm: h = self.batch_norm1_h(h) - if e is not None: e = self.batch_norm1_e(e) + if e is not None: + e = self.batch_norm1_e(e) # FFN for h h_in2 = h # for second residual connection @@ -301,7 +350,8 @@ def forward(self, batch): h = self.FFN_h_layer2(h) if self.residual: - if self.rezero: h = h * self.alpha2_h + if self.rezero: + h = h * self.alpha2_h h = h_in2 + h # residual connection if self.layer_norm: @@ -319,11 +369,15 @@ def forward(self, batch): return batch def __repr__(self): - return '{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]'.format( - self.__class__.__name__, - self.in_channels, - self.out_channels, self.num_heads, self.residual, - super().__repr__(), + return ( + "{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]".format( + self.__class__.__name__, + self.in_channels, + self.out_channels, + self.num_heads, + self.residual, + super().__repr__(), + ) ) @@ -335,13 +389,14 @@ def get_log_deg(batch): deg = batch.deg log_deg = torch.log(deg + 1).unsqueeze(-1) else: - warnings.warn("Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs") - deg = pyg.utils.degree(batch.edge_index[1], - 
num_nodes=batch.num_nodes, - dtype=torch.float - ) + warnings.warn( + "Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs", + ) + deg = pyg.utils.degree( + batch.edge_index[1], + num_nodes=batch.num_nodes, + dtype=torch.float, + ) log_deg = torch.log(deg + 1) log_deg = log_deg.view(batch.num_nodes, 1) return log_deg - - diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 3422a21e..285b830b 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -3,7 +3,10 @@ from torch import nn from torch_geometric.data import Data -from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from gridfm_graphkit.models.rrwp_encoder import ( + RRWPLinearNodeEncoder, + RRWPLinearEdgeEncoder, +) from gridfm_graphkit.models.grit_layer import GritTransformerLayer from gridfm_graphkit.models.kernel_pos_encoder import RWSENodeEncoder from torch_scatter import scatter_add @@ -18,6 +21,7 @@ class BatchNorm1dNode(torch.nn.Module): eps (float): BatchNorm eps. momentum (float): BatchNorm momentum. """ + def __init__(self, dim_in, eps, momentum): super().__init__() self.bn = torch.nn.BatchNorm1d( @@ -39,6 +43,7 @@ class BatchNorm1dEdge(torch.nn.Module): eps (float): BatchNorm eps. momentum (float): BatchNorm momentum. 
""" + def __init__(self, dim_in, eps, momentum): super().__init__() self.bn = torch.nn.BatchNorm1d( @@ -61,7 +66,8 @@ def __init__(self, dim_in, emb_dim): def forward(self, batch): batch.x = self.encoder(batch.x) return batch - + + class LinearEdgeEncoder(torch.nn.Module): def __init__(self, edge_dim, emb_dim): super().__init__() @@ -83,18 +89,18 @@ class FeatureEncoder(torch.nn.Module): dim_in (int): Input feature dimension """ - def __init__( - self, - dim_in, - dim_inner, - args - ): + + def __init__(self, dim_in, dim_inner, args): super(FeatureEncoder, self).__init__() self.dim_in = dim_in if args.encoder.node_encoder: # Encode integer node features via nn.Embeddings - if 'RWSE' in args.encoder.node_encoder_name: - self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner, args.encoder.posenc_RWSE) + if "RWSE" in args.encoder.node_encoder_name: + self.node_encoder = RWSENodeEncoder( + self.dim_in, + dim_inner, + args.encoder.posenc_RWSE, + ) else: self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) if args.encoder.node_encoder_bn: @@ -113,7 +119,8 @@ def forward(self, batch): for module in self.children(): batch = module(batch) return batch - + + class GraphHead(nn.Module): """ Prediction head for decoding tasks. @@ -126,11 +133,11 @@ class GraphHead(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - self.FC_layers = nn.Sequential( + self.FC_layers = nn.Sequential( nn.Linear(dim_in, dim_in), nn.LeakyReLU(), nn.Linear(dim_in, dim_out), - ) + ) def _apply_index(self, batch): return batch.graph_feature, batch.y @@ -149,10 +156,10 @@ class GritTransformer(torch.nn.Module): 2023. 
""" + def __init__(self, args, include_decoder=True): super().__init__() - dim_in = args.model.input_dim dim_out = args.model.output_dim dim_inner = args.model.hidden_size @@ -166,38 +173,34 @@ def __init__(self, args, include_decoder=True): if self.learn_mask: self.mask_value = nn.Parameter( torch.randn(self.mask_dim) + self.mask_value, - requires_grad=True, - ) - else: + requires_grad=True, + ) + else: self.mask_value = nn.Parameter( - torch.zeros(self.mask_dim) + self.mask_value, + torch.zeros(self.mask_dim) + self.mask_value, requires_grad=False, ) - - self.encoder = FeatureEncoder( - dim_in, - dim_inner, - args.model - ) - dim_in = self.encoder.dim_in - if args.data.posenc_RRWP.enable: + self.encoder = FeatureEncoder(dim_in, dim_inner, args.model) + dim_in = self.encoder.dim_in + if args.data.posenc_RRWP.enable: self.rrwp_abs_encoder = RRWPLinearNodeEncoder( - args.data.posenc_RRWP.ksteps, - dim_inner - ) + args.data.posenc_RRWP.ksteps, + dim_inner, + ) rel_pe_dim = args.data.posenc_RRWP.ksteps self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( - rel_pe_dim, + rel_pe_dim, dim_inner, pad_to_full_graph=args.model.gt.attn.full_attn, add_node_attr_as_self_loop=False, - fill_value=0. - ) + fill_value=0.0, + ) - assert args.model.hidden_size == dim_inner == dim_in, \ + assert args.model.hidden_size == dim_inner == dim_in, ( "The inner and hidden dims must match." + ) layers = [] for ll in range(num_layers): @@ -205,28 +208,30 @@ def __init__(self, args, include_decoder=True): # (only node features feed into the output heads), so skip # creating O_e / norm_e parameters to avoid DDP unused-parameter # errors. 
- is_last = (ll == num_layers - 1) - layers.append(GritTransformerLayer( - in_dim=args.model.gt.dim_hidden, - out_dim=args.model.gt.dim_hidden, - num_heads=num_heads, - dropout=dropout, - act=args.model.act, - attn_dropout=args.model.gt.attn_dropout, - layer_norm=args.model.gt.layer_norm, - batch_norm=args.model.gt.batch_norm, - residual=True, - norm_e=False if is_last else args.model.gt.attn.norm_e, - O_e=False if is_last else args.model.gt.attn.O_e, - cfg=args.model.gt, - )) + is_last = ll == num_layers - 1 + layers.append( + GritTransformerLayer( + in_dim=args.model.gt.dim_hidden, + out_dim=args.model.gt.dim_hidden, + num_heads=num_heads, + dropout=dropout, + act=args.model.act, + attn_dropout=args.model.gt.attn_dropout, + layer_norm=args.model.gt.layer_norm, + batch_norm=args.model.gt.batch_norm, + residual=True, + norm_e=False if is_last else args.model.gt.attn.norm_e, + O_e=False if is_last else args.model.gt.attn.O_e, + cfg=args.model.gt, + ), + ) self.layers = nn.Sequential(*layers) if include_decoder: self.decoder = GraphHead(dim_inner, dim_out) - def forward(self, batch): + def forward(self, batch): """ Forward pass for GRIT. @@ -244,6 +249,7 @@ def forward(self, batch): return batch + def aggregate_pg(batch, mask_value=-1.0): """Aggregate per-generator active power (PG) onto bus nodes. 
@@ -327,6 +333,7 @@ def __init__(self, args): and hasattr(args.model.encoder, "posenc_RWSE") ): from gridfm_graphkit.io.param_handler import NestedNamespace + enc_rwse = args.model.encoder.posenc_RWSE if not hasattr(enc_rwse, "kernel"): enc_rwse.kernel = NestedNamespace() @@ -373,8 +380,11 @@ def forward(self, batch): # --- Extract bus-only homogeneous subgraph --- # Aggregate generator PG onto buses pg_per_bus = aggregate_pg(batch, mask_value=self.grit.mask_value[0].item()) - bus_x = torch.cat([batch["bus"].x, pg_per_bus.unsqueeze(-1)], dim=-1) # 15 → 16D - + bus_x = torch.cat( + [batch["bus"].x, pg_per_bus.unsqueeze(-1)], + dim=-1, + ) # 15 → 16D + homo = Data( x=bus_x, y=batch["bus"].y, @@ -398,6 +408,10 @@ def forward(self, batch): # --- Per-type decoding --- bus_out = self.bus_head(homo.x) - gen_out = self.gen_head(batch["gen"].x) if self.gen_head is not None else batch["gen"].x + gen_out = ( + self.gen_head(batch["gen"].x) + if self.gen_head is not None + else batch["gen"].x + ) return {"bus": bus_out, "gen": gen_out} diff --git a/gridfm_graphkit/models/kernel_pos_encoder.py b/gridfm_graphkit/models/kernel_pos_encoder.py index b24078de..c75a7585 100644 --- a/gridfm_graphkit/models/kernel_pos_encoder.py +++ b/gridfm_graphkit/models/kernel_pos_encoder.py @@ -25,9 +25,11 @@ class KernelPENodeEncoder(torch.nn.Module): def __init__(self, dim_in, dim_emb, pecfg, expand_x=True): super().__init__() if self.kernel_type is None: - raise ValueError(f"{self.__class__.__name__} has to be " - f"preconfigured by setting 'kernel_type' class" - f"variable before calling the constructor.") + raise ValueError( + f"{self.__class__.__name__} has to be " + f"preconfigured by setting 'kernel_type' class" + f"variable before calling the constructor.", + ) dim_pe = pecfg.pe_dim # Size of the kernel-based PE embedding num_rw_steps = pecfg.kernel.times @@ -35,28 +37,31 @@ def __init__(self, dim_in, dim_emb, pecfg, expand_x=True): # self.pass_as_var = pecfg.pass_as_var # Pass PE also as 
a separate variable if dim_emb - dim_pe < 1: - raise ValueError(f"PE dim size {dim_pe} is too large for " - f"desired embedding size of {dim_emb}.") + raise ValueError( + f"PE dim size {dim_pe} is too large for " + f"desired embedding size of {dim_emb}.", + ) if expand_x: self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe) self.expand_x = expand_x - if norm_type == 'batchnorm': + if norm_type == "batchnorm": self.raw_norm = nn.BatchNorm1d(num_rw_steps) else: self.raw_norm = None self.pe_encoder = nn.Linear(num_rw_steps, dim_pe) - def forward(self, batch): pestat_var = f"pestat_{self.kernel_type}" if not hasattr(batch, pestat_var): - raise ValueError(f"Precomputed '{pestat_var}' variable is " - f"required for {self.__class__.__name__}; set " - f"config 'posenc_{self.kernel_type}.enable' to " - f"True, and also set 'posenc.kernel.times' values") + raise ValueError( + f"Precomputed '{pestat_var}' variable is " + f"required for {self.__class__.__name__}; set " + f"config 'posenc_{self.kernel_type}.enable' to " + f"True, and also set 'posenc.kernel.times' values", + ) pos_enc = getattr(batch, pestat_var) # (Num nodes) x (Num kernel times) # pos_enc = batch.rw_landing # (Num nodes) x (Num kernel times) @@ -79,6 +84,6 @@ def forward(self, batch): class RWSENodeEncoder(KernelPENodeEncoder): - """Random Walk Structural Encoding node encoder. 
- """ - kernel_type = 'RWSE' \ No newline at end of file + """Random Walk Structural Encoding node encoder.""" + + kernel_type = "RWSE" diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 1f7fd10b..19746279 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -1,15 +1,21 @@ """ - The RRWP encoder for GRIT (ours) +The RRWP encoder for GRIT (ours) """ + import torch from torch import nn from torch.nn import functional as F import torch_sparse -from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops +from torch_geometric.utils import ( + remove_self_loops, + add_remaining_self_loops, + add_self_loops, +) from torch_scatter import scatter import warnings + def full_edge_index(edge_index, batch=None): """ Return the Full batched sparse adjacency matrices given by edge indices. @@ -28,16 +34,14 @@ def full_edge_index(edge_index, batch=None): batch_size = batch.max().item() + 1 one = batch.new_ones(batch.size(0)) - num_nodes = scatter(one, batch, - dim=0, dim_size=batch_size, reduce='add') + num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce="add") cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) negative_index_list = [] for i in range(batch_size): n = num_nodes[i].item() size = [n, n] - adj = torch.ones(size, dtype=torch.short, - device=edge_index.device) + adj = torch.ones(size, dtype=torch.short, device=edge_index.device) adj = adj.view(size) _edge_index = adj.nonzero(as_tuple=False).t().contiguous() @@ -48,7 +52,6 @@ def full_edge_index(edge_index, batch=None): return edge_index_full - class RRWPLinearNodeEncoder(torch.nn.Module): """ FC_1(RRWP) + FC_2 (Node-attr) @@ -56,7 +59,16 @@ class RRWPLinearNodeEncoder(torch.nn.Module): Parameters: num_classes - the number of classes for the embedding mapping to learn """ - def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, 
pe_name="rrwp"): + + def __init__( + self, + emb_dim, + out_dim, + use_bias=False, + batchnorm=False, + layernorm=False, + pe_name="rrwp", + ): super().__init__() self.batchnorm = batchnorm self.layernorm = layernorm @@ -98,27 +110,38 @@ class RRWPLinearEdgeEncoder(torch.nn.Module): - (optional) add node-attr as the E_{i,i}'s attr note: assuming node-attr and edge-attr is with the same dimension after Encoders """ - def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, - pad_to_full_graph=True, fill_value=0., - add_node_attr_as_self_loop=False, - overwrite_old_attr=False): + + def __init__( + self, + emb_dim, + out_dim, + batchnorm=False, + layernorm=False, + use_bias=False, + pad_to_full_graph=True, + fill_value=0.0, + add_node_attr_as_self_loop=False, + overwrite_old_attr=False, + ): super().__init__() # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info self.emb_dim = emb_dim self.out_dim = out_dim self.add_node_attr_as_self_loop = add_node_attr_as_self_loop - self.overwrite_old_attr=overwrite_old_attr # remove the old edge-attr + self.overwrite_old_attr = overwrite_old_attr # remove the old edge-attr self.batchnorm = batchnorm self.layernorm = layernorm if self.batchnorm or self.layernorm: - warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + warnings.warn( + "batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ", + ) # print('--------fc in and out:', emb_dim, out_dim) self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.pad_to_full_graph = pad_to_full_graph - self.fill_value = 0. 
+ self.fill_value = 0.0 padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value self.register_buffer("padding", padding) @@ -143,15 +166,20 @@ def forward(self, batch): if self.overwrite_old_attr: out_idx, out_val = rrwp_idx, rrwp_val else: - edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) + edge_index, edge_attr = add_self_loops( + edge_index, + edge_attr, + num_nodes=batch.num_nodes, + fill_value=0.0, + ) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), - batch.num_nodes, batch.num_nodes, - op="add" + batch.num_nodes, + batch.num_nodes, + op="add", ) - if self.pad_to_full_graph: edge_index_full = full_edge_index(out_idx, batch=batch.batch) edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) @@ -159,8 +187,11 @@ def forward(self, batch): out_idx = torch.cat([out_idx, edge_index_full], dim=1) out_val = torch.cat([out_val, edge_attr_pad], dim=0) out_idx, out_val = torch_sparse.coalesce( - out_idx, out_val, batch.num_nodes, batch.num_nodes, - op="add" + out_idx, + out_val, + batch.num_nodes, + batch.num_nodes, + op="add", ) if self.batchnorm: @@ -169,15 +200,13 @@ def forward(self, batch): if self.layernorm: out_val = self.ln(out_val) - batch.edge_index, batch.edge_attr = out_idx, out_val return batch def __repr__(self): - return f"{self.__class__.__name__}" \ - f"(pad_to_full_graph={self.pad_to_full_graph}," \ - f"fill_value={self.fill_value}," \ - f"{self.fc.__repr__()})" - - - + return ( + f"{self.__class__.__name__}" + f"(pad_to_full_graph={self.pad_to_full_graph}," + f"fill_value={self.fill_value}," + f"{self.fc.__repr__()})" + ) diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index 46be6998..d943f397 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -73,20 +73,28 @@ def _clamp_known_to_ground_truth(output_bus, target, batch, 
gen_to_bus_index, nu mask_bus = batch.mask_dict["bus"] eval_bus = output_bus.clone() eval_bus[:, VM_OUT] = torch.where( - mask_bus[:, VM_H], output_bus[:, VM_OUT], target[:, VM_OUT], + mask_bus[:, VM_H], + output_bus[:, VM_OUT], + target[:, VM_OUT], ) eval_bus[:, VA_OUT] = torch.where( - mask_bus[:, VA_H], output_bus[:, VA_OUT], target[:, VA_OUT], + mask_bus[:, VA_H], + output_bus[:, VA_OUT], + target[:, VA_OUT], ) gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() any_gen_masked = ( scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 ) eval_bus[:, PG_OUT] = torch.where( - any_gen_masked, output_bus[:, PG_OUT], target[:, PG_OUT], + any_gen_masked, + output_bus[:, PG_OUT], + target[:, PG_OUT], ) eval_bus[:, QG_OUT] = torch.where( - mask_bus[:, QG_H], output_bus[:, QG_OUT], target[:, QG_OUT], + mask_bus[:, QG_H], + output_bus[:, QG_OUT], + target[:, QG_OUT], ) return eval_bus @@ -119,7 +127,11 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) eval_bus = _clamp_known_to_ground_truth( - output["bus"], target, batch, gen_to_bus_index, num_bus, + output["bus"], + target, + batch, + gen_to_bus_index, + num_bus, ) Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) @@ -413,7 +425,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) eval_bus = _clamp_known_to_ground_truth( - output["bus"], target, batch, gen_to_bus_index, num_bus, + output["bus"], + target, + batch, + gen_to_bus_index, + num_bus, ) Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 69bc10b6..b94ffed1 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -194,33 +194,42 @@ def forward( dim=0, dim_size=num_bus, ) - target = torch.stack([ - target_bus[:, 
VM_H], - target_bus[:, VA_H], - target_pg_agg, - target_bus[:, QG_H], - target_bus[:, PD_H], - target_bus[:, QD_H], - ], dim=1) + target = torch.stack( + [ + target_bus[:, VM_H], + target_bus[:, VA_H], + target_pg_agg, + target_bus[:, QG_H], + target_bus[:, PD_H], + target_bus[:, QD_H], + ], + dim=1, + ) # --- Build mask: [N_bus, 6] --- # PG bus-level mask: True if any generator at the bus has PG masked gen_pg_masked = mask_dict["gen"][:, PG_H].float() - any_gen_masked = scatter_add( - gen_pg_masked, - gen_to_bus_ei[1], - dim=0, - dim_size=num_bus, - ) > 0 + any_gen_masked = ( + scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + > 0 + ) - mask = torch.stack([ - mask_dict["bus"][:, VM_H], - mask_dict["bus"][:, VA_H], - any_gen_masked, - mask_dict["bus"][:, QG_H], - mask_dict["bus"][:, PD_H], - mask_dict["bus"][:, QD_H], - ], dim=1) + mask = torch.stack( + [ + mask_dict["bus"][:, VM_H], + mask_dict["bus"][:, VA_H], + any_gen_masked, + mask_dict["bus"][:, QG_H], + mask_dict["bus"][:, PD_H], + mask_dict["bus"][:, QD_H], + ], + dim=1, + ) # --- Prediction: [VM, VA, PG, QG, PD, QD] from bus head --- pred = pred_bus[:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT, PD_OUT, QD_OUT]] @@ -420,7 +429,8 @@ def forward( f"MSE loss {self.dim}": mse_loss.detach(), f"MAE loss {self.dim}": mae_loss.detach(), } - + + @LOSS_REGISTRY.register("PBE") class PBELoss(BaseLoss): """ @@ -446,8 +456,8 @@ def forward( model=None, x_dict=None, ): - pred_bus = pred_dict["bus"] # [N_bus, output_bus_dim] - target_bus = target_dict["bus"] # [N_bus, bus_feat_dim] + pred_bus = pred_dict["bus"] # [N_bus, output_bus_dim] + target_bus = target_dict["bus"] # [N_bus, bus_feat_dim] num_bus = target_bus.size(0) bus_edge_index = edge_index_dict[("bus", "connects", "bus")] @@ -519,7 +529,8 @@ def forward( # Build complete Y-bus: off-diagonal edges + self-loops for diagonal diag_idx = torch.arange(num_bus, device=bus_edge_index.device) full_edge_index = torch.cat( - [bus_edge_index, 
torch.stack([diag_idx, diag_idx])], dim=1, + [bus_edge_index, torch.stack([diag_idx, diag_idx])], + dim=1, ) full_edge_values = torch.cat([edge_offdiag, Y_diag]) @@ -544,12 +555,15 @@ def forward( dim_size=num_bus, ) gen_pg_masked = mask_dict["gen"][:, PG_H].float() - any_gen_masked = scatter_add( - gen_pg_masked, - gen_to_bus_ei[1], - dim=0, - dim_size=num_bus, - ) > 0 + any_gen_masked = ( + scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + > 0 + ) Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg) # Pd, Qd, Qg: same clamp-to-ground-truth logic. The size guard @@ -557,15 +571,27 @@ def forward( # head (e.g. output_bus_dim=4) that don't predict PD/QD/QG; in that # case the target is always used. if pred_bus.size(1) > PD_OUT: - Pd = torch.where(mask_bus[:, PD_H], pred_bus[:, PD_OUT], target_bus[:, PD_H]) + Pd = torch.where( + mask_bus[:, PD_H], + pred_bus[:, PD_OUT], + target_bus[:, PD_H], + ) else: Pd = target_bus[:, PD_H] if pred_bus.size(1) > QD_OUT: - Qd = torch.where(mask_bus[:, QD_H], pred_bus[:, QD_OUT], target_bus[:, QD_H]) + Qd = torch.where( + mask_bus[:, QD_H], + pred_bus[:, QD_OUT], + target_bus[:, QD_H], + ) else: Qd = target_bus[:, QD_H] if pred_bus.size(1) > QG_OUT: - Qg = torch.where(mask_bus[:, QG_H], pred_bus[:, QG_OUT], target_bus[:, QG_H]) + Qg = torch.where( + mask_bus[:, QG_H], + pred_bus[:, QG_OUT], + target_bus[:, QG_H], + ) else: Qg = target_bus[:, QG_H] diff --git a/pyproject.toml b/pyproject.toml index 100b8753..0e5081e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,8 @@ dependencies = [ "pyyaml>=6.0.2", "torch>=2.7.1,<2.9", "torch-geometric>=2.6.1", - "torch-scatter>=2.1.2", - "torch-sparse>=0.6.18", + "torch-scatter", + "torch-sparse", "torchaudio>=2.7.1", "torchvision>=0.22.1", "lightning", diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index fe3010fc..2027bdde 100644 --- a/scripts/benchmark_model_inference.py +++ 
b/scripts/benchmark_model_inference.py @@ -54,34 +54,59 @@ # Compilation (kept from original) import torch._dynamo as dynamo + dynamo.config.suppress_errors = False # ---------------------------- # Argument Parsing # ---------------------------- -parser = argparse.ArgumentParser(description="Benchmark GNN Model inference with profiling CSV") -parser.add_argument("--model", type=str, choices=["hetero", "grit"], default="hetero", - help="Model type: 'hetero' for GNS_heterogeneous, 'grit' for GritTransformer") -parser.add_argument("--config", type=str, required=True, help="Path to config YAML for model") +parser = argparse.ArgumentParser( + description="Benchmark GNN Model inference with profiling CSV", +) +parser.add_argument( + "--model", + type=str, + choices=["hetero", "grit"], + default="hetero", + help="Model type: 'hetero' for GNS_heterogeneous, 'grit' for GritTransformer", +) +parser.add_argument( + "--config", + type=str, + required=True, + help="Path to config YAML for model", +) parser.add_argument("--num_nodes", type=int, required=True) -parser.add_argument("--num_gens", type=int, default=0, - help="Number of generator nodes (required for hetero, ignored for grit)") +parser.add_argument( + "--num_gens", + type=int, + default=0, + help="Number of generator nodes (required for hetero, ignored for grit)", +) parser.add_argument("--num_edges", type=int, required=True) parser.add_argument("--output_csv", type=str, required=True) parser.add_argument("--iterations", type=int, default=20) parser.add_argument("--num_workers", type=int, default=0, help="DataLoader num_workers") -parser.add_argument("--pin_memory", action="store_true", help="Enable pin_memory in DataLoader when CUDA is available") +parser.add_argument( + "--pin_memory", + action="store_true", + help="Enable pin_memory in DataLoader when CUDA is available", +) args = parser.parse_args() # --- Custom logging (ensure directory exists) import logging -os.makedirs('logs', exist_ok=True) -logger = 
logging.getLogger('ibm_benchmark_logger') + +os.makedirs("logs", exist_ok=True) +logger = logging.getLogger("ibm_benchmark_logger") logger.setLevel(logging.DEBUG) logger.propagate = False -file_handler = logging.FileHandler('logs/ibm_bench_logs.log', mode='a') # 'a' for append, 'w' to overwrite +file_handler = logging.FileHandler( + "logs/ibm_bench_logs.log", + mode="a", +) # 'a' for append, 'w' to overwrite file_handler.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") file_handler.setFormatter(formatter) if not logger.handlers: logger.addHandler(file_handler) @@ -113,17 +138,32 @@ # Both model types use HeteroData with bus + gen nodes. # Fall back to input_dim / defaults for configs that lack the hetero keys. -BUS_FEATS = getattr(config_args.model, "input_bus_dim", - getattr(config_args.model, "input_dim", 15)) +BUS_FEATS = getattr( + config_args.model, + "input_bus_dim", + getattr(config_args.model, "input_dim", 15), +) GEN_FEATS = getattr(config_args.model, "input_gen_dim", 6) if MODEL_TYPE == "grit": # Positional encoding config (only GRIT uses these) # Read enablement and dimensions from data config (canonical source). 
- RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False - RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0 - RWSE_ENABLED = hasattr(config_args.data, "posenc_RWSE") and getattr(config_args.data.posenc_RWSE, "enable", False) - RWSE_TIMES = getattr(config_args.data.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 + RRWP_ENABLED = ( + getattr(config_args.data.posenc_RRWP, "enable", False) + if hasattr(config_args.data, "posenc_RRWP") + else False + ) + RRWP_KSTEPS = ( + getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0 + ) + RWSE_ENABLED = hasattr(config_args.data, "posenc_RWSE") and getattr( + config_args.data.posenc_RWSE, + "enable", + False, + ) + RWSE_TIMES = ( + getattr(config_args.data.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 + ) else: RRWP_ENABLED = False RRWP_KSTEPS = 0 @@ -134,16 +174,19 @@ batch_sizes = [1, 2, 4, 8, 16, 32] iterations = args.iterations + # ---------------------------- # Helpers # ---------------------------- def now_ms() -> float: return time.perf_counter() * 1000.0 + def maybe_cuda_sync(): if torch.cuda.is_available(): torch.cuda.synchronize() + def get_env_info(): # CPU name detection cpu_name = None @@ -163,14 +206,22 @@ def get_env_info(): # GPU names and device info if torch.cuda.is_available(): try: - gpu_names_list = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] + gpu_names_list = [ + torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count()) + ] gpu_names = "; ".join(gpu_names_list) except Exception: gpu_names = "cuda_available_but_name_unreadable" device_type = "cuda" - device_name = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "cuda" + device_name = ( + torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "cuda" + ) cuda_version_in_torch = torch.version.cuda - cudnn_version = 
torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None + cudnn_version = ( + torch.backends.cudnn.version() + if torch.backends.cudnn.is_available() + else None + ) else: # Apple Metal backend? if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): @@ -198,6 +249,7 @@ def get_env_info(): } return info + # ---------------------------- # Generate Synthetic Hetero Graph # ---------------------------- @@ -232,12 +284,14 @@ def generate_hetero_graph(): # Gen → Bus data["gen", "connected_to", "bus"].edge_index = torch.stack( - [torch.arange(N_GEN), gen_to_bus], dim=0 + [torch.arange(N_GEN), gen_to_bus], + dim=0, ) # Bus → Gen data["bus", "connected_to", "gen"].edge_index = torch.stack( - [gen_to_bus, torch.arange(N_GEN)], dim=0 + [gen_to_bus, torch.arange(N_GEN)], + dim=0, ) # No edge features for these @@ -248,7 +302,10 @@ def generate_hetero_graph(): mask_bus = torch.ones_like(data["bus"].x, dtype=torch.bool) mask_gen = torch.ones_like(data["gen"].x, dtype=torch.bool) bus_types = torch.randint(0, 3, (N_BUS,)) - mask_branch = torch.ones_like(data["bus", "connects", "bus"].edge_attr, dtype=torch.bool) + mask_branch = torch.ones_like( + data["bus", "connects", "bus"].edge_attr, + dtype=torch.bool, + ) mask_PQ = bus_types == 0 mask_PV = bus_types == 1 @@ -260,7 +317,7 @@ def generate_hetero_graph(): "PQ": mask_PQ, "PV": mask_PV, "REF": mask_REF, - "branch": mask_branch + "branch": mask_branch, } return data @@ -299,13 +356,14 @@ def generate_grit_graph(): return data + # ---------------------------- # Benchmark Function # ---------------------------- def benchmark(): # Environment/context info (constant per run) env = get_env_info() - timestamp = datetime.now().isoformat(timespec='seconds') + timestamp = datetime.now().isoformat(timespec="seconds") # Measure synthetic graph creation t0 = now_ms() @@ -331,11 +389,9 @@ def benchmark(): # Keep original first two columns "batch_size", "avg_time_per_sample_ms", - # Execution 
config "num_iters", "total_samples", - # Data/IO timing "data_gen_time_ms", "graph_to_device_time_ms", @@ -343,7 +399,6 @@ def benchmark(): "dataloader_create_time_ms", "dataloader_first_iter_time_ms", "batch_to_device_time_ms", - # Model timing "warmup_time_ms", "iter_total_wall_time_ms", @@ -353,21 +408,25 @@ def benchmark(): "samples_per_sec_wall", "samples_per_sec_gpu", "timing_source", # "cuda_event" or "wall_clock" - # Memory "max_cuda_mem_alloc_bytes", "max_cuda_mem_reserved_bytes", - # Graph & model context - "n_bus", "n_gen", "n_edges", - "bus_feats", "gen_feats", "edge_feats", - + "n_bus", + "n_gen", + "n_edges", + "bus_feats", + "gen_feats", + "edge_feats", # Runtime context - "device_type", "device_name", - "torch_version", "cuda_version_in_torch", "cudnn_version", + "device_type", + "device_name", + "torch_version", + "cuda_version_in_torch", + "cudnn_version", "python_version", - "cpu_name", # NEW - "gpu_names", # NEW + "cpu_name", # NEW + "gpu_names", # NEW "timestamp_iso", "num_workers", "pin_memory", @@ -409,7 +468,11 @@ def benchmark(): # Ensure batch on device (likely ~0 if items already on device) maybe_cuda_sync() t_b2d_start = now_ms() - batch = batch.to(device, non_blocking=True) if torch.cuda.is_available() else batch.to(device) + batch = ( + batch.to(device, non_blocking=True) + if torch.cuda.is_available() + else batch.to(device) + ) maybe_cuda_sync() t_b2d_end = now_ms() batch_to_device_time_ms = t_b2d_end - t_b2d_start @@ -455,10 +518,20 @@ def benchmark(): timing_source = "cuda_event" avg_time_per_sample_ms = iter_gpu_time_ms / total_samples gpu_idle_time_ms = max(iter_total_wall_time_ms - iter_gpu_time_ms, 0.0) - gpu_busy_ratio = (iter_gpu_time_ms / iter_total_wall_time_ms) if iter_total_wall_time_ms > 0 else None + gpu_busy_ratio = ( + (iter_gpu_time_ms / iter_total_wall_time_ms) + if iter_total_wall_time_ms > 0 + else None + ) max_cuda_mem_alloc_bytes = int(torch.cuda.max_memory_allocated(device)) - max_cuda_mem_reserved_bytes = 
int(torch.cuda.max_memory_reserved(device)) - samples_per_sec_gpu = (total_samples / (iter_gpu_time_ms / 1000.0)) if iter_gpu_time_ms > 0 else None + max_cuda_mem_reserved_bytes = int( + torch.cuda.max_memory_reserved(device), + ) + samples_per_sec_gpu = ( + (total_samples / (iter_gpu_time_ms / 1000.0)) + if iter_gpu_time_ms > 0 + else None + ) else: iter_gpu_time_ms = None timing_source = "wall_clock" @@ -469,23 +542,24 @@ def benchmark(): max_cuda_mem_reserved_bytes = None samples_per_sec_gpu = None - samples_per_sec_wall = (total_samples / (iter_total_wall_time_ms / 1000.0)) if iter_total_wall_time_ms > 0 else None + samples_per_sec_wall = ( + (total_samples / (iter_total_wall_time_ms / 1000.0)) + if iter_total_wall_time_ms > 0 + else None + ) # Prepare row row = [ batch_size, avg_time_per_sample_ms, - num_iters, total_samples, - data_gen_time_ms, graph_to_device_time_ms, clone_list_time_ms, dataloader_create_time_ms, dataloader_first_iter_time_ms, batch_to_device_time_ms, - warmup_time_ms, iter_total_wall_time_ms, iter_gpu_time_ms, @@ -494,15 +568,19 @@ def benchmark(): samples_per_sec_wall, samples_per_sec_gpu, timing_source, - max_cuda_mem_alloc_bytes, max_cuda_mem_reserved_bytes, - - N_BUS, N_GEN, E, - BUS_FEATS, GEN_FEATS, EDGE_FEATS, - - env["device_type"], env["device_name"], - env["torch_version"], env["cuda_version_in_torch"], env["cudnn_version"], + N_BUS, + N_GEN, + E, + BUS_FEATS, + GEN_FEATS, + EDGE_FEATS, + env["device_type"], + env["device_name"], + env["torch_version"], + env["cuda_version_in_torch"], + env["cudnn_version"], env["python_version"], env["cpu_name"], env["gpu_names"], From 4f22dd1351efe80a900febca053f81bb647b19e9 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:54:25 -0400 Subject: [PATCH 66/95] pre-merge checks Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 11 +---------- 
gridfm_graphkit/models/grit_layer.py | 17 ----------------- gridfm_graphkit/models/grit_transformer.py | 1 - gridfm_graphkit/models/rrwp_encoder.py | 3 --- scripts/benchmark_model_inference.py | 2 +- 5 files changed, 2 insertions(+), 32 deletions(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index cc780ee5..cae18f05 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -1,13 +1,6 @@ -from copy import deepcopy - -import numpy as np import torch -import torch.nn.functional as F from torch_geometric.utils import ( - get_laplacian, - to_scipy_sparse_matrix, - to_undirected, to_dense_adj, ) from torch_geometric.utils.num_nodes import maybe_num_nodes @@ -19,8 +12,6 @@ from torch_geometric.data import Data, HeteroData from typing import Any -from torch_geometric.utils.num_nodes import maybe_num_nodes - def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. @@ -112,7 +103,7 @@ def get_rw_landing_probs( if edge_weight is None: edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) num_nodes = maybe_num_nodes(edge_index, num_nodes) - source, dest = edge_index[0], edge_index[1] + source, _ = edge_index[0], edge_index[1] deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. 
deg_inv = deg.pow(-1.0) deg_inv.masked_fill_(deg_inv == float("inf"), 0) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 2dc4abe5..52d2cdb3 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -117,7 +117,6 @@ def propagate_attention(self, batch): if self.clamp is not None: score = torch.clamp(score, min=-self.clamp, max=self.clamp) - raw_attn = score score = pyg_softmax( score, batch.edge_index[1], @@ -224,22 +223,6 @@ def __init__( no_qk=getattr(cfg.attn, "no_qk", False), ) - if getattr(cfg.attn, "graphormer_attn", False): - self.attention = MultiHeadAttentionLayerGraphormerSparse( - in_dim=in_dim, - out_dim=out_dim // num_heads, - num_heads=num_heads, - use_bias=getattr(cfg.attn, "use_bias", False), - dropout=attn_dropout, - clamp=getattr(cfg.attn, "clamp", 5.0), - act=getattr(cfg.attn, "act", "relu"), - edge_enhance=True, - sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), - signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), - scaled_attn=getattr(cfg.attn, "scaled_attn", False), - no_qk=getattr(cfg.attn, "no_qk", False), - ) - self.O_h = nn.Linear(out_dim // num_heads * num_heads, out_dim) if O_e: self.O_e = nn.Linear(out_dim // num_heads * num_heads, out_dim) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 285b830b..8440b334 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -163,7 +163,6 @@ def __init__(self, args, include_decoder=True): dim_in = args.model.input_dim dim_out = args.model.output_dim dim_inner = args.model.hidden_size - dim_edge = args.model.edge_dim num_heads = args.model.attention_head dropout = args.model.dropout num_layers = args.model.num_layers diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 19746279..960974ec 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ 
b/gridfm_graphkit/models/rrwp_encoder.py @@ -4,12 +4,9 @@ import torch from torch import nn -from torch.nn import functional as F import torch_sparse from torch_geometric.utils import ( - remove_self_loops, - add_remaining_self_loops, add_self_loops, ) from torch_scatter import scatter diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index 2027bdde..01b1e6cd 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -45,6 +45,7 @@ from torch_geometric.loader import DataLoader from torch_geometric.data import HeteroData from gridfm_graphkit.io.param_handler import NestedNamespace, load_model +import logging # Optional: tqdm (imported but not required for core flow) try: @@ -95,7 +96,6 @@ args = parser.parse_args() # --- Custom logging (ensure directory exists) -import logging os.makedirs("logs", exist_ok=True) logger = logging.getLogger("ibm_benchmark_logger") From 9a8ba0576f46adc4aeb24f7cabd91d068c1ba5a2 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 15 Apr 2026 12:56:28 -0400 Subject: [PATCH 67/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- README.md | 1 - gridfm_graphkit/cli.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/README.md b/README.md index aa57d943..ad85f056 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,6 @@ source venv/bin/activate Install gridfm-graphkit in editable mode ```bash pip install -e . 
-pip install torch_sparse torch_scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html ``` Get PyTorch + CUDA version for torch-scatter diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index b21a42ab..d045bc88 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -210,9 +210,6 @@ def main_cli(args): profiler=profiler, ) - # print('******model*****') - # print(model) - # print('******model*****') if args.command == "train" or args.command == "finetune": trainer.fit(model=model, datamodule=litGrid) From 7d294d5680cfd2fd376a2d72850d7736f714d8cf Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 15 Apr 2026 14:23:15 -0400 Subject: [PATCH 68/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- README.md | 1 + pyproject.toml | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index ad85f056..935409a3 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' Install the correct torch-scatter wheel ```bash pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html +pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html ``` diff --git a/pyproject.toml b/pyproject.toml index 0e5081e5..cd8d39db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,6 @@ dependencies = [ "pyyaml>=6.0.2", "torch>=2.7.1,<2.9", "torch-geometric>=2.6.1", - "torch-scatter", - "torch-sparse", "torchaudio>=2.7.1", "torchvision>=0.22.1", "lightning", From d8e7529e54149dc39742568b4b497a3451ed23a1 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:45:37 -0400 Subject: [PATCH 69/95] support for optional scatter and sparse Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/__init__.py | 21 
++++++++++---- gridfm_graphkit/models/__init__.py | 24 +++++++++++----- gridfm_graphkit/models/grit_transformer.py | 13 ++++++++- gridfm_graphkit/models/rrwp_encoder.py | 32 ++++++++++++++++++++-- gridfm_graphkit/tasks/__init__.py | 18 ++++++++++-- 5 files changed, 90 insertions(+), 18 deletions(-) diff --git a/gridfm_graphkit/__init__.py b/gridfm_graphkit/__init__.py index 9378901f..d91c0dc7 100644 --- a/gridfm_graphkit/__init__.py +++ b/gridfm_graphkit/__init__.py @@ -1,8 +1,19 @@ -import gridfm_graphkit.datasets -import gridfm_graphkit.tasks.base_task -import gridfm_graphkit.models.gnn_heterogeneous_gns -import gridfm_graphkit.tasks.reconstruction_tasks +import importlib as _importlib __all__ = [ - "gridfm_graphkit", + "datasets", + "tasks", + "models", ] + +_LAZY_SUBMODULES = { + "datasets": "gridfm_graphkit.datasets", + "tasks": "gridfm_graphkit.tasks", + "models": "gridfm_graphkit.models", +} + + +def __getattr__(name: str): + if name in _LAZY_SUBMODULES: + return _importlib.import_module(_LAZY_SUBMODULES[name]) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index f185c6a2..e9a58399 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,10 +1,4 @@ -from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous -from gridfm_graphkit.models.grit_transformer import GritHeteroAdapter -from gridfm_graphkit.models.utils import ( - PhysicsDecoderOPF, - PhysicsDecoderPF, - PhysicsDecoderSE, -) +import importlib as _importlib __all__ = [ "GNS_heterogeneous", @@ -13,3 +7,19 @@ "PhysicsDecoderPF", "PhysicsDecoderSE", ] + +_LAZY_IMPORTS = { + "GNS_heterogeneous": ("gridfm_graphkit.models.gnn_heterogeneous_gns", "GNS_heterogeneous"), + "GritHeteroAdapter": ("gridfm_graphkit.models.grit_transformer", "GritHeteroAdapter"), + "PhysicsDecoderOPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderOPF"), + 
"PhysicsDecoderPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderPF"), + "PhysicsDecoderSE": ("gridfm_graphkit.models.utils", "PhysicsDecoderSE"), +} + + +def __getattr__(name: str): + if name in _LAZY_IMPORTS: + module_path, attr = _LAZY_IMPORTS[name] + mod = _importlib.import_module(module_path) + return getattr(mod, attr) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 8440b334..1f2a4711 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -9,7 +9,12 @@ ) from gridfm_graphkit.models.grit_layer import GritTransformerLayer from gridfm_graphkit.models.kernel_pos_encoder import RWSENodeEncoder -from torch_scatter import scatter_add + +try: + from torch_scatter import scatter_add +except ImportError: + scatter_add = None + from gridfm_graphkit.datasets.globals import PG_H @@ -263,6 +268,12 @@ def aggregate_pg(batch, mask_value=-1.0): where *all* connected generators are masked receive the mask value instead, preserving a consistent "unknown" indicator. """ + if scatter_add is None: + raise ImportError( + "torch-scatter is required for the GRIT modules but is not installed. 
" + "Install it with: pip install torch-scatter" + ) + gen_to_bus = batch["gen", "connected_to", "bus"].edge_index gen_pg = batch["gen"].x[:, PG_H] gen_masked = batch.mask_dict["gen"][:, PG_H] # True = masked diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 960974ec..116711b5 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -4,14 +4,38 @@ import torch from torch import nn -import torch_sparse + +try: + import torch_sparse +except ImportError: + torch_sparse = None from torch_geometric.utils import ( add_self_loops, ) -from torch_scatter import scatter + +try: + from torch_scatter import scatter +except ImportError: + scatter = None + import warnings +_MISSING_MSG = ( + "{pkg} is required for the RRWP / GRIT modules but is not installed. " + "Install it with: pip install {pkg}" +) + + +def _check_sparse(): + if torch_sparse is None: + raise ImportError(_MISSING_MSG.format(pkg="torch-sparse")) + + +def _check_scatter(): + if scatter is None: + raise ImportError(_MISSING_MSG.format(pkg="torch-scatter")) + def full_edge_index(edge_index, batch=None): """ @@ -26,6 +50,8 @@ def full_edge_index(edge_index, batch=None): Complementary edge index. 
""" + _check_scatter() + if batch is None: batch = edge_index.new_zeros(edge_index.max().item() + 1) @@ -169,6 +195,7 @@ def forward(self, batch): num_nodes=batch.num_nodes, fill_value=0.0, ) + _check_sparse() out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), @@ -183,6 +210,7 @@ def forward(self, batch): # zero padding to fully-connected graphs out_idx = torch.cat([out_idx, edge_index_full], dim=1) out_val = torch.cat([out_val, edge_attr_pad], dim=0) + _check_sparse() out_idx, out_val = torch_sparse.coalesce( out_idx, out_val, diff --git a/gridfm_graphkit/tasks/__init__.py b/gridfm_graphkit/tasks/__init__.py index 8ed9b137..b213a392 100644 --- a/gridfm_graphkit/tasks/__init__.py +++ b/gridfm_graphkit/tasks/__init__.py @@ -1,5 +1,17 @@ -from gridfm_graphkit.tasks.pf_task import PowerFlowTask -from gridfm_graphkit.tasks.opf_task import OptimalPowerFlowTask -from gridfm_graphkit.tasks.se_task import StateEstimationTask +import importlib as _importlib __all__ = ["PowerFlowTask", "OptimalPowerFlowTask", "StateEstimationTask"] + +_LAZY_IMPORTS = { + "PowerFlowTask": ("gridfm_graphkit.tasks.pf_task", "PowerFlowTask"), + "OptimalPowerFlowTask": ("gridfm_graphkit.tasks.opf_task", "OptimalPowerFlowTask"), + "StateEstimationTask": ("gridfm_graphkit.tasks.se_task", "StateEstimationTask"), +} + + +def __getattr__(name: str): + if name in _LAZY_IMPORTS: + module_path, attr = _LAZY_IMPORTS[name] + mod = _importlib.import_module(module_path) + return getattr(mod, attr) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From e111e5f712d6638fb9ad1fe89e250d93adbe0c4e Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:45:37 -0400 Subject: [PATCH 70/95] ruff format Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- {examples => examples_wdir}/__init__.py | 0 
.../config/GRIT_PF_datakit_case14.yaml | 0 .../config/HGNS_OPFData_case118.yaml | 0 .../config/HGNS_OPFData_case14.yaml | 0 .../config/HGNS_OPFData_case2000.yaml | 0 .../config/HGNS_OPFData_case30.yaml | 0 .../config/HGNS_OPFData_case500.yaml | 0 .../config/HGNS_OPFData_case57.yaml | 0 .../config/HGNS_OPF_datakit_case118.yaml | 0 .../config/HGNS_OPF_datakit_case14.yaml | 0 .../config/HGNS_OPF_datakit_case2000.yaml | 0 .../config/HGNS_OPF_datakit_case30.yaml | 0 .../config/HGNS_OPF_datakit_case500.yaml | 0 .../config/HGNS_OPF_datakit_case57.yaml | 0 .../config/HGNS_PF_datakit_case118.yaml | 0 .../config/HGNS_PF_datakit_case14.yaml | 0 .../config/HGNS_PF_datakit_case2000.yaml | 0 .../config/HGNS_PF_datakit_case30.yaml | 0 .../config/HGNS_PF_datakit_case500.yaml | 0 .../config/HGNS_PF_datakit_case57.yaml | 0 .../config/HGNS_PF_datakit_caseTexas.yaml | 0 .../config/HGNS_PF_pfdelta_case118.yaml | 0 .../config/HGNS_SE_datakit_case118.yaml | 0 .../config/HGNS_SE_datakit_case14.yaml | 0 {examples => examples_wdir}/config/__init__.py | 0 .../data/contingency_texas/branch_idx_removed.csv | 0 .../data/contingency_texas/bus_params.csv | 0 .../data/contingency_texas/edge_params.csv | 0 .../data/contingency_texas/pf_node_10_examples.csv | 0 .../data/contingency_texas/predictions_10_examples.csv | 0 .../notebooks/Tutorial_contingency_analisys.ipynb | 0 .../Tutorial_reconstruction_visualization.ipynb | 0 gridfm_graphkit/cli.py | 1 - gridfm_graphkit/models/__init__.py | 10 ++++++++-- 34 files changed, 8 insertions(+), 3 deletions(-) rename {examples => examples_wdir}/__init__.py (100%) rename {examples => examples_wdir}/config/GRIT_PF_datakit_case14.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case14.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case2000.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case30.yaml (100%) rename {examples => 
examples_wdir}/config/HGNS_OPFData_case500.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case57.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case14.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case2000.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case30.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case500.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case57.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case14.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case2000.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case30.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case500.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case57.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_caseTexas.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_pfdelta_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_SE_datakit_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_SE_datakit_case14.yaml (100%) rename {examples => examples_wdir}/config/__init__.py (100%) rename {examples => examples_wdir}/data/contingency_texas/branch_idx_removed.csv (100%) rename {examples => examples_wdir}/data/contingency_texas/bus_params.csv (100%) rename {examples => examples_wdir}/data/contingency_texas/edge_params.csv (100%) rename {examples => examples_wdir}/data/contingency_texas/pf_node_10_examples.csv (100%) rename {examples => examples_wdir}/data/contingency_texas/predictions_10_examples.csv (100%) rename {examples => examples_wdir}/notebooks/Tutorial_contingency_analisys.ipynb (100%) rename {examples => 
examples_wdir}/notebooks/Tutorial_reconstruction_visualization.ipynb (100%) diff --git a/examples/__init__.py b/examples_wdir/__init__.py similarity index 100% rename from examples/__init__.py rename to examples_wdir/__init__.py diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples_wdir/config/GRIT_PF_datakit_case14.yaml similarity index 100% rename from examples/config/GRIT_PF_datakit_case14.yaml rename to examples_wdir/config/GRIT_PF_datakit_case14.yaml diff --git a/examples/config/HGNS_OPFData_case118.yaml b/examples_wdir/config/HGNS_OPFData_case118.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case118.yaml rename to examples_wdir/config/HGNS_OPFData_case118.yaml diff --git a/examples/config/HGNS_OPFData_case14.yaml b/examples_wdir/config/HGNS_OPFData_case14.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case14.yaml rename to examples_wdir/config/HGNS_OPFData_case14.yaml diff --git a/examples/config/HGNS_OPFData_case2000.yaml b/examples_wdir/config/HGNS_OPFData_case2000.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case2000.yaml rename to examples_wdir/config/HGNS_OPFData_case2000.yaml diff --git a/examples/config/HGNS_OPFData_case30.yaml b/examples_wdir/config/HGNS_OPFData_case30.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case30.yaml rename to examples_wdir/config/HGNS_OPFData_case30.yaml diff --git a/examples/config/HGNS_OPFData_case500.yaml b/examples_wdir/config/HGNS_OPFData_case500.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case500.yaml rename to examples_wdir/config/HGNS_OPFData_case500.yaml diff --git a/examples/config/HGNS_OPFData_case57.yaml b/examples_wdir/config/HGNS_OPFData_case57.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case57.yaml rename to examples_wdir/config/HGNS_OPFData_case57.yaml diff --git a/examples/config/HGNS_OPF_datakit_case118.yaml b/examples_wdir/config/HGNS_OPF_datakit_case118.yaml 
similarity index 100% rename from examples/config/HGNS_OPF_datakit_case118.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case118.yaml diff --git a/examples/config/HGNS_OPF_datakit_case14.yaml b/examples_wdir/config/HGNS_OPF_datakit_case14.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case14.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case14.yaml diff --git a/examples/config/HGNS_OPF_datakit_case2000.yaml b/examples_wdir/config/HGNS_OPF_datakit_case2000.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case2000.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case2000.yaml diff --git a/examples/config/HGNS_OPF_datakit_case30.yaml b/examples_wdir/config/HGNS_OPF_datakit_case30.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case30.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case30.yaml diff --git a/examples/config/HGNS_OPF_datakit_case500.yaml b/examples_wdir/config/HGNS_OPF_datakit_case500.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case500.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case500.yaml diff --git a/examples/config/HGNS_OPF_datakit_case57.yaml b/examples_wdir/config/HGNS_OPF_datakit_case57.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case57.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case57.yaml diff --git a/examples/config/HGNS_PF_datakit_case118.yaml b/examples_wdir/config/HGNS_PF_datakit_case118.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case118.yaml rename to examples_wdir/config/HGNS_PF_datakit_case118.yaml diff --git a/examples/config/HGNS_PF_datakit_case14.yaml b/examples_wdir/config/HGNS_PF_datakit_case14.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case14.yaml rename to examples_wdir/config/HGNS_PF_datakit_case14.yaml diff --git a/examples/config/HGNS_PF_datakit_case2000.yaml 
b/examples_wdir/config/HGNS_PF_datakit_case2000.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case2000.yaml rename to examples_wdir/config/HGNS_PF_datakit_case2000.yaml diff --git a/examples/config/HGNS_PF_datakit_case30.yaml b/examples_wdir/config/HGNS_PF_datakit_case30.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case30.yaml rename to examples_wdir/config/HGNS_PF_datakit_case30.yaml diff --git a/examples/config/HGNS_PF_datakit_case500.yaml b/examples_wdir/config/HGNS_PF_datakit_case500.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case500.yaml rename to examples_wdir/config/HGNS_PF_datakit_case500.yaml diff --git a/examples/config/HGNS_PF_datakit_case57.yaml b/examples_wdir/config/HGNS_PF_datakit_case57.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case57.yaml rename to examples_wdir/config/HGNS_PF_datakit_case57.yaml diff --git a/examples/config/HGNS_PF_datakit_caseTexas.yaml b/examples_wdir/config/HGNS_PF_datakit_caseTexas.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_caseTexas.yaml rename to examples_wdir/config/HGNS_PF_datakit_caseTexas.yaml diff --git a/examples/config/HGNS_PF_pfdelta_case118.yaml b/examples_wdir/config/HGNS_PF_pfdelta_case118.yaml similarity index 100% rename from examples/config/HGNS_PF_pfdelta_case118.yaml rename to examples_wdir/config/HGNS_PF_pfdelta_case118.yaml diff --git a/examples/config/HGNS_SE_datakit_case118.yaml b/examples_wdir/config/HGNS_SE_datakit_case118.yaml similarity index 100% rename from examples/config/HGNS_SE_datakit_case118.yaml rename to examples_wdir/config/HGNS_SE_datakit_case118.yaml diff --git a/examples/config/HGNS_SE_datakit_case14.yaml b/examples_wdir/config/HGNS_SE_datakit_case14.yaml similarity index 100% rename from examples/config/HGNS_SE_datakit_case14.yaml rename to examples_wdir/config/HGNS_SE_datakit_case14.yaml diff --git a/examples/config/__init__.py 
b/examples_wdir/config/__init__.py similarity index 100% rename from examples/config/__init__.py rename to examples_wdir/config/__init__.py diff --git a/examples/data/contingency_texas/branch_idx_removed.csv b/examples_wdir/data/contingency_texas/branch_idx_removed.csv similarity index 100% rename from examples/data/contingency_texas/branch_idx_removed.csv rename to examples_wdir/data/contingency_texas/branch_idx_removed.csv diff --git a/examples/data/contingency_texas/bus_params.csv b/examples_wdir/data/contingency_texas/bus_params.csv similarity index 100% rename from examples/data/contingency_texas/bus_params.csv rename to examples_wdir/data/contingency_texas/bus_params.csv diff --git a/examples/data/contingency_texas/edge_params.csv b/examples_wdir/data/contingency_texas/edge_params.csv similarity index 100% rename from examples/data/contingency_texas/edge_params.csv rename to examples_wdir/data/contingency_texas/edge_params.csv diff --git a/examples/data/contingency_texas/pf_node_10_examples.csv b/examples_wdir/data/contingency_texas/pf_node_10_examples.csv similarity index 100% rename from examples/data/contingency_texas/pf_node_10_examples.csv rename to examples_wdir/data/contingency_texas/pf_node_10_examples.csv diff --git a/examples/data/contingency_texas/predictions_10_examples.csv b/examples_wdir/data/contingency_texas/predictions_10_examples.csv similarity index 100% rename from examples/data/contingency_texas/predictions_10_examples.csv rename to examples_wdir/data/contingency_texas/predictions_10_examples.csv diff --git a/examples/notebooks/Tutorial_contingency_analisys.ipynb b/examples_wdir/notebooks/Tutorial_contingency_analisys.ipynb similarity index 100% rename from examples/notebooks/Tutorial_contingency_analisys.ipynb rename to examples_wdir/notebooks/Tutorial_contingency_analisys.ipynb diff --git a/examples/notebooks/Tutorial_reconstruction_visualization.ipynb b/examples_wdir/notebooks/Tutorial_reconstruction_visualization.ipynb similarity 
index 100% rename from examples/notebooks/Tutorial_reconstruction_visualization.ipynb rename to examples_wdir/notebooks/Tutorial_reconstruction_visualization.ipynb diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index d045bc88..0e9d4398 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -210,7 +210,6 @@ def main_cli(args): profiler=profiler, ) - if args.command == "train" or args.command == "finetune": trainer.fit(model=model, datamodule=litGrid) if ( diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index e9a58399..e31310d7 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -9,8 +9,14 @@ ] _LAZY_IMPORTS = { - "GNS_heterogeneous": ("gridfm_graphkit.models.gnn_heterogeneous_gns", "GNS_heterogeneous"), - "GritHeteroAdapter": ("gridfm_graphkit.models.grit_transformer", "GritHeteroAdapter"), + "GNS_heterogeneous": ( + "gridfm_graphkit.models.gnn_heterogeneous_gns", + "GNS_heterogeneous", + ), + "GritHeteroAdapter": ( + "gridfm_graphkit.models.grit_transformer", + "GritHeteroAdapter", + ), "PhysicsDecoderOPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderOPF"), "PhysicsDecoderPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderPF"), "PhysicsDecoderSE": ("gridfm_graphkit.models.utils", "PhysicsDecoderSE"), From 40818c65268cf5779896628089c4a58ef167ad70 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:45:38 -0400 Subject: [PATCH 71/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- {examples_wdir => examples}/__init__.py | 0 {examples_wdir => examples}/config/GRIT_PF_datakit_case14.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case14.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case2000.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case30.yaml | 0 
{examples_wdir => examples}/config/HGNS_OPFData_case500.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case57.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case14.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case2000.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case30.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case500.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case57.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case14.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case2000.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case30.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case500.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case57.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_caseTexas.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_pfdelta_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_SE_datakit_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_SE_datakit_case14.yaml | 0 {examples_wdir => examples}/config/__init__.py | 0 .../data/contingency_texas/branch_idx_removed.csv | 0 {examples_wdir => examples}/data/contingency_texas/bus_params.csv | 0 .../data/contingency_texas/edge_params.csv | 0 .../data/contingency_texas/pf_node_10_examples.csv | 0 .../data/contingency_texas/predictions_10_examples.csv | 0 .../notebooks/Tutorial_contingency_analisys.ipynb | 0 .../notebooks/Tutorial_reconstruction_visualization.ipynb | 0 32 files changed, 0 insertions(+), 0 deletions(-) rename {examples_wdir => examples}/__init__.py (100%) rename {examples_wdir => examples}/config/GRIT_PF_datakit_case14.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case14.yaml 
(100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case2000.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case30.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case500.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case57.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case14.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case2000.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case30.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case500.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case57.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case14.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case2000.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case30.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case500.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case57.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_caseTexas.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_pfdelta_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_SE_datakit_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_SE_datakit_case14.yaml (100%) rename {examples_wdir => examples}/config/__init__.py (100%) rename {examples_wdir => examples}/data/contingency_texas/branch_idx_removed.csv (100%) rename {examples_wdir => examples}/data/contingency_texas/bus_params.csv (100%) rename {examples_wdir => examples}/data/contingency_texas/edge_params.csv (100%) rename {examples_wdir => examples}/data/contingency_texas/pf_node_10_examples.csv (100%) rename {examples_wdir => 
examples}/data/contingency_texas/predictions_10_examples.csv (100%) rename {examples_wdir => examples}/notebooks/Tutorial_contingency_analisys.ipynb (100%) rename {examples_wdir => examples}/notebooks/Tutorial_reconstruction_visualization.ipynb (100%) diff --git a/examples_wdir/__init__.py b/examples/__init__.py similarity index 100% rename from examples_wdir/__init__.py rename to examples/__init__.py diff --git a/examples_wdir/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml similarity index 100% rename from examples_wdir/config/GRIT_PF_datakit_case14.yaml rename to examples/config/GRIT_PF_datakit_case14.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case118.yaml b/examples/config/HGNS_OPFData_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case118.yaml rename to examples/config/HGNS_OPFData_case118.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case14.yaml b/examples/config/HGNS_OPFData_case14.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case14.yaml rename to examples/config/HGNS_OPFData_case14.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case2000.yaml b/examples/config/HGNS_OPFData_case2000.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case2000.yaml rename to examples/config/HGNS_OPFData_case2000.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case30.yaml b/examples/config/HGNS_OPFData_case30.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case30.yaml rename to examples/config/HGNS_OPFData_case30.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case500.yaml b/examples/config/HGNS_OPFData_case500.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case500.yaml rename to examples/config/HGNS_OPFData_case500.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case57.yaml b/examples/config/HGNS_OPFData_case57.yaml similarity index 100% rename from 
examples_wdir/config/HGNS_OPFData_case57.yaml rename to examples/config/HGNS_OPFData_case57.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case118.yaml b/examples/config/HGNS_OPF_datakit_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case118.yaml rename to examples/config/HGNS_OPF_datakit_case118.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case14.yaml b/examples/config/HGNS_OPF_datakit_case14.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case14.yaml rename to examples/config/HGNS_OPF_datakit_case14.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case2000.yaml b/examples/config/HGNS_OPF_datakit_case2000.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case2000.yaml rename to examples/config/HGNS_OPF_datakit_case2000.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case30.yaml b/examples/config/HGNS_OPF_datakit_case30.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case30.yaml rename to examples/config/HGNS_OPF_datakit_case30.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case500.yaml b/examples/config/HGNS_OPF_datakit_case500.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case500.yaml rename to examples/config/HGNS_OPF_datakit_case500.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case57.yaml b/examples/config/HGNS_OPF_datakit_case57.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case57.yaml rename to examples/config/HGNS_OPF_datakit_case57.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case118.yaml b/examples/config/HGNS_PF_datakit_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case118.yaml rename to examples/config/HGNS_PF_datakit_case118.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case14.yaml b/examples/config/HGNS_PF_datakit_case14.yaml similarity index 100% rename from 
examples_wdir/config/HGNS_PF_datakit_case14.yaml rename to examples/config/HGNS_PF_datakit_case14.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case2000.yaml b/examples/config/HGNS_PF_datakit_case2000.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case2000.yaml rename to examples/config/HGNS_PF_datakit_case2000.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case30.yaml b/examples/config/HGNS_PF_datakit_case30.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case30.yaml rename to examples/config/HGNS_PF_datakit_case30.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case500.yaml b/examples/config/HGNS_PF_datakit_case500.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case500.yaml rename to examples/config/HGNS_PF_datakit_case500.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case57.yaml b/examples/config/HGNS_PF_datakit_case57.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case57.yaml rename to examples/config/HGNS_PF_datakit_case57.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_caseTexas.yaml b/examples/config/HGNS_PF_datakit_caseTexas.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_caseTexas.yaml rename to examples/config/HGNS_PF_datakit_caseTexas.yaml diff --git a/examples_wdir/config/HGNS_PF_pfdelta_case118.yaml b/examples/config/HGNS_PF_pfdelta_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_pfdelta_case118.yaml rename to examples/config/HGNS_PF_pfdelta_case118.yaml diff --git a/examples_wdir/config/HGNS_SE_datakit_case118.yaml b/examples/config/HGNS_SE_datakit_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_SE_datakit_case118.yaml rename to examples/config/HGNS_SE_datakit_case118.yaml diff --git a/examples_wdir/config/HGNS_SE_datakit_case14.yaml b/examples/config/HGNS_SE_datakit_case14.yaml similarity index 100% rename from 
examples_wdir/config/HGNS_SE_datakit_case14.yaml rename to examples/config/HGNS_SE_datakit_case14.yaml diff --git a/examples_wdir/config/__init__.py b/examples/config/__init__.py similarity index 100% rename from examples_wdir/config/__init__.py rename to examples/config/__init__.py diff --git a/examples_wdir/data/contingency_texas/branch_idx_removed.csv b/examples/data/contingency_texas/branch_idx_removed.csv similarity index 100% rename from examples_wdir/data/contingency_texas/branch_idx_removed.csv rename to examples/data/contingency_texas/branch_idx_removed.csv diff --git a/examples_wdir/data/contingency_texas/bus_params.csv b/examples/data/contingency_texas/bus_params.csv similarity index 100% rename from examples_wdir/data/contingency_texas/bus_params.csv rename to examples/data/contingency_texas/bus_params.csv diff --git a/examples_wdir/data/contingency_texas/edge_params.csv b/examples/data/contingency_texas/edge_params.csv similarity index 100% rename from examples_wdir/data/contingency_texas/edge_params.csv rename to examples/data/contingency_texas/edge_params.csv diff --git a/examples_wdir/data/contingency_texas/pf_node_10_examples.csv b/examples/data/contingency_texas/pf_node_10_examples.csv similarity index 100% rename from examples_wdir/data/contingency_texas/pf_node_10_examples.csv rename to examples/data/contingency_texas/pf_node_10_examples.csv diff --git a/examples_wdir/data/contingency_texas/predictions_10_examples.csv b/examples/data/contingency_texas/predictions_10_examples.csv similarity index 100% rename from examples_wdir/data/contingency_texas/predictions_10_examples.csv rename to examples/data/contingency_texas/predictions_10_examples.csv diff --git a/examples_wdir/notebooks/Tutorial_contingency_analisys.ipynb b/examples/notebooks/Tutorial_contingency_analisys.ipynb similarity index 100% rename from examples_wdir/notebooks/Tutorial_contingency_analisys.ipynb rename to examples/notebooks/Tutorial_contingency_analisys.ipynb diff --git 
a/examples_wdir/notebooks/Tutorial_reconstruction_visualization.ipynb b/examples/notebooks/Tutorial_reconstruction_visualization.ipynb similarity index 100% rename from examples_wdir/notebooks/Tutorial_reconstruction_visualization.ipynb rename to examples/notebooks/Tutorial_reconstruction_visualization.ipynb From dbb5d476c5808746ef4036ca40ad4bbafaf0c305 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:45:38 -0400 Subject: [PATCH 72/95] support for scatter and sparse Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/__init__.py | 21 +++++-------------- gridfm_graphkit/models/__init__.py | 30 +++++++--------------------- gridfm_graphkit/models/grit_layer.py | 8 +++++++- gridfm_graphkit/tasks/__init__.py | 18 +++-------------- gridfm_graphkit/tasks/pf_task.py | 7 ++++++- 5 files changed, 28 insertions(+), 56 deletions(-) diff --git a/gridfm_graphkit/__init__.py b/gridfm_graphkit/__init__.py index d91c0dc7..9378901f 100644 --- a/gridfm_graphkit/__init__.py +++ b/gridfm_graphkit/__init__.py @@ -1,19 +1,8 @@ -import importlib as _importlib +import gridfm_graphkit.datasets +import gridfm_graphkit.tasks.base_task +import gridfm_graphkit.models.gnn_heterogeneous_gns +import gridfm_graphkit.tasks.reconstruction_tasks __all__ = [ - "datasets", - "tasks", - "models", + "gridfm_graphkit", ] - -_LAZY_SUBMODULES = { - "datasets": "gridfm_graphkit.datasets", - "tasks": "gridfm_graphkit.tasks", - "models": "gridfm_graphkit.models", -} - - -def __getattr__(name: str): - if name in _LAZY_SUBMODULES: - return _importlib.import_module(_LAZY_SUBMODULES[name]) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index e31310d7..f185c6a2 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,10 @@ -import importlib as 
_importlib +from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous +from gridfm_graphkit.models.grit_transformer import GritHeteroAdapter +from gridfm_graphkit.models.utils import ( + PhysicsDecoderOPF, + PhysicsDecoderPF, + PhysicsDecoderSE, +) __all__ = [ "GNS_heterogeneous", @@ -7,25 +13,3 @@ "PhysicsDecoderPF", "PhysicsDecoderSE", ] - -_LAZY_IMPORTS = { - "GNS_heterogeneous": ( - "gridfm_graphkit.models.gnn_heterogeneous_gns", - "GNS_heterogeneous", - ), - "GritHeteroAdapter": ( - "gridfm_graphkit.models.grit_transformer", - "GritHeteroAdapter", - ), - "PhysicsDecoderOPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderOPF"), - "PhysicsDecoderPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderPF"), - "PhysicsDecoderSE": ("gridfm_graphkit.models.utils", "PhysicsDecoderSE"), -} - - -def __getattr__(name: str): - if name in _LAZY_IMPORTS: - module_path, attr = _LAZY_IMPORTS[name] - mod = _importlib.import_module(module_path) - return getattr(mod, attr) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 52d2cdb3..c0136981 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -4,7 +4,13 @@ import torch.nn.functional as F import torch_geometric as pyg from torch_geometric.utils.num_nodes import maybe_num_nodes -from torch_scatter import scatter, scatter_max, scatter_add + +try: + from torch_scatter import scatter, scatter_max, scatter_add +except ImportError: + scatter = None + scatter_max = None + scatter_add = None import opt_einsum as oe diff --git a/gridfm_graphkit/tasks/__init__.py b/gridfm_graphkit/tasks/__init__.py index b213a392..8ed9b137 100644 --- a/gridfm_graphkit/tasks/__init__.py +++ b/gridfm_graphkit/tasks/__init__.py @@ -1,17 +1,5 @@ -import importlib as _importlib +from gridfm_graphkit.tasks.pf_task import PowerFlowTask +from gridfm_graphkit.tasks.opf_task import 
OptimalPowerFlowTask +from gridfm_graphkit.tasks.se_task import StateEstimationTask __all__ = ["PowerFlowTask", "OptimalPowerFlowTask", "StateEstimationTask"] - -_LAZY_IMPORTS = { - "PowerFlowTask": ("gridfm_graphkit.tasks.pf_task", "PowerFlowTask"), - "OptimalPowerFlowTask": ("gridfm_graphkit.tasks.opf_task", "OptimalPowerFlowTask"), - "StateEstimationTask": ("gridfm_graphkit.tasks.se_task", "StateEstimationTask"), -} - - -def __getattr__(name: str): - if name in _LAZY_IMPORTS: - module_path, attr = _LAZY_IMPORTS[name] - mod = _importlib.import_module(module_path) - return getattr(mod, attr) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index d943f397..56b4b4a0 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -24,7 +24,12 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from torch_scatter import scatter_add + +try: + from torch_scatter import scatter_add +except ImportError: + scatter_add = None + from torch_geometric.nn import global_mean_pool from gridfm_graphkit.models.utils import ( ComputeBranchFlow, From 101a9f5f6c0abad68401bf32a4b21c95db41a6d3 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:45:38 -0400 Subject: [PATCH 73/95] formatting Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 1f2a4711..75b0d1c6 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -271,7 +271,7 @@ def aggregate_pg(batch, mask_value=-1.0): if scatter_add is None: raise ImportError( "torch-scatter is required for the GRIT modules but is not installed. 
" - "Install it with: pip install torch-scatter" + "Install it with: pip install torch-scatter", ) gen_to_bus = batch["gen", "connected_to", "bus"].edge_index From e900295821b9986503401e521758414e8cc8d6b2 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:47:46 -0400 Subject: [PATCH 74/95] support for optional scatter and sparse Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/__init__.py | 21 ++++++++++---- gridfm_graphkit/models/__init__.py | 24 +++++++++++----- gridfm_graphkit/models/grit_transformer.py | 13 ++++++++- gridfm_graphkit/models/rrwp_encoder.py | 32 ++++++++++++++++++++-- gridfm_graphkit/tasks/__init__.py | 18 ++++++++++-- 5 files changed, 90 insertions(+), 18 deletions(-) diff --git a/gridfm_graphkit/__init__.py b/gridfm_graphkit/__init__.py index 9378901f..d91c0dc7 100644 --- a/gridfm_graphkit/__init__.py +++ b/gridfm_graphkit/__init__.py @@ -1,8 +1,19 @@ -import gridfm_graphkit.datasets -import gridfm_graphkit.tasks.base_task -import gridfm_graphkit.models.gnn_heterogeneous_gns -import gridfm_graphkit.tasks.reconstruction_tasks +import importlib as _importlib __all__ = [ - "gridfm_graphkit", + "datasets", + "tasks", + "models", ] + +_LAZY_SUBMODULES = { + "datasets": "gridfm_graphkit.datasets", + "tasks": "gridfm_graphkit.tasks", + "models": "gridfm_graphkit.models", +} + + +def __getattr__(name: str): + if name in _LAZY_SUBMODULES: + return _importlib.import_module(_LAZY_SUBMODULES[name]) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index f185c6a2..e9a58399 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,10 +1,4 @@ -from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous -from gridfm_graphkit.models.grit_transformer import GritHeteroAdapter -from 
gridfm_graphkit.models.utils import ( - PhysicsDecoderOPF, - PhysicsDecoderPF, - PhysicsDecoderSE, -) +import importlib as _importlib __all__ = [ "GNS_heterogeneous", @@ -13,3 +7,19 @@ "PhysicsDecoderPF", "PhysicsDecoderSE", ] + +_LAZY_IMPORTS = { + "GNS_heterogeneous": ("gridfm_graphkit.models.gnn_heterogeneous_gns", "GNS_heterogeneous"), + "GritHeteroAdapter": ("gridfm_graphkit.models.grit_transformer", "GritHeteroAdapter"), + "PhysicsDecoderOPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderOPF"), + "PhysicsDecoderPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderPF"), + "PhysicsDecoderSE": ("gridfm_graphkit.models.utils", "PhysicsDecoderSE"), +} + + +def __getattr__(name: str): + if name in _LAZY_IMPORTS: + module_path, attr = _LAZY_IMPORTS[name] + mod = _importlib.import_module(module_path) + return getattr(mod, attr) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 8440b334..1f2a4711 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -9,7 +9,12 @@ ) from gridfm_graphkit.models.grit_layer import GritTransformerLayer from gridfm_graphkit.models.kernel_pos_encoder import RWSENodeEncoder -from torch_scatter import scatter_add + +try: + from torch_scatter import scatter_add +except ImportError: + scatter_add = None + from gridfm_graphkit.datasets.globals import PG_H @@ -263,6 +268,12 @@ def aggregate_pg(batch, mask_value=-1.0): where *all* connected generators are masked receive the mask value instead, preserving a consistent "unknown" indicator. """ + if scatter_add is None: + raise ImportError( + "torch-scatter is required for the GRIT modules but is not installed. 
" + "Install it with: pip install torch-scatter" + ) + gen_to_bus = batch["gen", "connected_to", "bus"].edge_index gen_pg = batch["gen"].x[:, PG_H] gen_masked = batch.mask_dict["gen"][:, PG_H] # True = masked diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 960974ec..116711b5 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -4,14 +4,38 @@ import torch from torch import nn -import torch_sparse + +try: + import torch_sparse +except ImportError: + torch_sparse = None from torch_geometric.utils import ( add_self_loops, ) -from torch_scatter import scatter + +try: + from torch_scatter import scatter +except ImportError: + scatter = None + import warnings +_MISSING_MSG = ( + "{pkg} is required for the RRWP / GRIT modules but is not installed. " + "Install it with: pip install {pkg}" +) + + +def _check_sparse(): + if torch_sparse is None: + raise ImportError(_MISSING_MSG.format(pkg="torch-sparse")) + + +def _check_scatter(): + if scatter is None: + raise ImportError(_MISSING_MSG.format(pkg="torch-scatter")) + def full_edge_index(edge_index, batch=None): """ @@ -26,6 +50,8 @@ def full_edge_index(edge_index, batch=None): Complementary edge index. 
""" + _check_scatter() + if batch is None: batch = edge_index.new_zeros(edge_index.max().item() + 1) @@ -169,6 +195,7 @@ def forward(self, batch): num_nodes=batch.num_nodes, fill_value=0.0, ) + _check_sparse() out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), @@ -183,6 +210,7 @@ def forward(self, batch): # zero padding to fully-connected graphs out_idx = torch.cat([out_idx, edge_index_full], dim=1) out_val = torch.cat([out_val, edge_attr_pad], dim=0) + _check_sparse() out_idx, out_val = torch_sparse.coalesce( out_idx, out_val, diff --git a/gridfm_graphkit/tasks/__init__.py b/gridfm_graphkit/tasks/__init__.py index 8ed9b137..b213a392 100644 --- a/gridfm_graphkit/tasks/__init__.py +++ b/gridfm_graphkit/tasks/__init__.py @@ -1,5 +1,17 @@ -from gridfm_graphkit.tasks.pf_task import PowerFlowTask -from gridfm_graphkit.tasks.opf_task import OptimalPowerFlowTask -from gridfm_graphkit.tasks.se_task import StateEstimationTask +import importlib as _importlib __all__ = ["PowerFlowTask", "OptimalPowerFlowTask", "StateEstimationTask"] + +_LAZY_IMPORTS = { + "PowerFlowTask": ("gridfm_graphkit.tasks.pf_task", "PowerFlowTask"), + "OptimalPowerFlowTask": ("gridfm_graphkit.tasks.opf_task", "OptimalPowerFlowTask"), + "StateEstimationTask": ("gridfm_graphkit.tasks.se_task", "StateEstimationTask"), +} + + +def __getattr__(name: str): + if name in _LAZY_IMPORTS: + module_path, attr = _LAZY_IMPORTS[name] + mod = _importlib.import_module(module_path) + return getattr(mod, attr) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From 8730d0657e5bb578c1ffacdf2d08366648f713a4 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:47:46 -0400 Subject: [PATCH 75/95] ruff format Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- {examples => examples_wdir}/__init__.py | 0 
.../config/GRIT_PF_datakit_case14.yaml | 0 .../config/HGNS_OPFData_case118.yaml | 0 .../config/HGNS_OPFData_case14.yaml | 0 .../config/HGNS_OPFData_case2000.yaml | 0 .../config/HGNS_OPFData_case30.yaml | 0 .../config/HGNS_OPFData_case500.yaml | 0 .../config/HGNS_OPFData_case57.yaml | 0 .../config/HGNS_OPF_datakit_case118.yaml | 0 .../config/HGNS_OPF_datakit_case14.yaml | 0 .../config/HGNS_OPF_datakit_case2000.yaml | 0 .../config/HGNS_OPF_datakit_case30.yaml | 0 .../config/HGNS_OPF_datakit_case500.yaml | 0 .../config/HGNS_OPF_datakit_case57.yaml | 0 .../config/HGNS_PF_datakit_case118.yaml | 0 .../config/HGNS_PF_datakit_case14.yaml | 0 .../config/HGNS_PF_datakit_case2000.yaml | 0 .../config/HGNS_PF_datakit_case30.yaml | 0 .../config/HGNS_PF_datakit_case500.yaml | 0 .../config/HGNS_PF_datakit_case57.yaml | 0 .../config/HGNS_PF_datakit_caseTexas.yaml | 0 .../config/HGNS_PF_pfdelta_case118.yaml | 0 .../config/HGNS_SE_datakit_case118.yaml | 0 .../config/HGNS_SE_datakit_case14.yaml | 0 {examples => examples_wdir}/config/__init__.py | 0 .../data/contingency_texas/branch_idx_removed.csv | 0 .../data/contingency_texas/bus_params.csv | 0 .../data/contingency_texas/edge_params.csv | 0 .../data/contingency_texas/pf_node_10_examples.csv | 0 .../data/contingency_texas/predictions_10_examples.csv | 0 .../notebooks/Tutorial_contingency_analisys.ipynb | 0 .../Tutorial_reconstruction_visualization.ipynb | 0 gridfm_graphkit/cli.py | 1 - gridfm_graphkit/models/__init__.py | 10 ++++++++-- 34 files changed, 8 insertions(+), 3 deletions(-) rename {examples => examples_wdir}/__init__.py (100%) rename {examples => examples_wdir}/config/GRIT_PF_datakit_case14.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case14.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case2000.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case30.yaml (100%) rename {examples => 
examples_wdir}/config/HGNS_OPFData_case500.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPFData_case57.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case14.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case2000.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case30.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case500.yaml (100%) rename {examples => examples_wdir}/config/HGNS_OPF_datakit_case57.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case14.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case2000.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case30.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case500.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_case57.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_datakit_caseTexas.yaml (100%) rename {examples => examples_wdir}/config/HGNS_PF_pfdelta_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_SE_datakit_case118.yaml (100%) rename {examples => examples_wdir}/config/HGNS_SE_datakit_case14.yaml (100%) rename {examples => examples_wdir}/config/__init__.py (100%) rename {examples => examples_wdir}/data/contingency_texas/branch_idx_removed.csv (100%) rename {examples => examples_wdir}/data/contingency_texas/bus_params.csv (100%) rename {examples => examples_wdir}/data/contingency_texas/edge_params.csv (100%) rename {examples => examples_wdir}/data/contingency_texas/pf_node_10_examples.csv (100%) rename {examples => examples_wdir}/data/contingency_texas/predictions_10_examples.csv (100%) rename {examples => examples_wdir}/notebooks/Tutorial_contingency_analisys.ipynb (100%) rename {examples => 
examples_wdir}/notebooks/Tutorial_reconstruction_visualization.ipynb (100%) diff --git a/examples/__init__.py b/examples_wdir/__init__.py similarity index 100% rename from examples/__init__.py rename to examples_wdir/__init__.py diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples_wdir/config/GRIT_PF_datakit_case14.yaml similarity index 100% rename from examples/config/GRIT_PF_datakit_case14.yaml rename to examples_wdir/config/GRIT_PF_datakit_case14.yaml diff --git a/examples/config/HGNS_OPFData_case118.yaml b/examples_wdir/config/HGNS_OPFData_case118.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case118.yaml rename to examples_wdir/config/HGNS_OPFData_case118.yaml diff --git a/examples/config/HGNS_OPFData_case14.yaml b/examples_wdir/config/HGNS_OPFData_case14.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case14.yaml rename to examples_wdir/config/HGNS_OPFData_case14.yaml diff --git a/examples/config/HGNS_OPFData_case2000.yaml b/examples_wdir/config/HGNS_OPFData_case2000.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case2000.yaml rename to examples_wdir/config/HGNS_OPFData_case2000.yaml diff --git a/examples/config/HGNS_OPFData_case30.yaml b/examples_wdir/config/HGNS_OPFData_case30.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case30.yaml rename to examples_wdir/config/HGNS_OPFData_case30.yaml diff --git a/examples/config/HGNS_OPFData_case500.yaml b/examples_wdir/config/HGNS_OPFData_case500.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case500.yaml rename to examples_wdir/config/HGNS_OPFData_case500.yaml diff --git a/examples/config/HGNS_OPFData_case57.yaml b/examples_wdir/config/HGNS_OPFData_case57.yaml similarity index 100% rename from examples/config/HGNS_OPFData_case57.yaml rename to examples_wdir/config/HGNS_OPFData_case57.yaml diff --git a/examples/config/HGNS_OPF_datakit_case118.yaml b/examples_wdir/config/HGNS_OPF_datakit_case118.yaml 
similarity index 100% rename from examples/config/HGNS_OPF_datakit_case118.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case118.yaml diff --git a/examples/config/HGNS_OPF_datakit_case14.yaml b/examples_wdir/config/HGNS_OPF_datakit_case14.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case14.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case14.yaml diff --git a/examples/config/HGNS_OPF_datakit_case2000.yaml b/examples_wdir/config/HGNS_OPF_datakit_case2000.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case2000.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case2000.yaml diff --git a/examples/config/HGNS_OPF_datakit_case30.yaml b/examples_wdir/config/HGNS_OPF_datakit_case30.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case30.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case30.yaml diff --git a/examples/config/HGNS_OPF_datakit_case500.yaml b/examples_wdir/config/HGNS_OPF_datakit_case500.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case500.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case500.yaml diff --git a/examples/config/HGNS_OPF_datakit_case57.yaml b/examples_wdir/config/HGNS_OPF_datakit_case57.yaml similarity index 100% rename from examples/config/HGNS_OPF_datakit_case57.yaml rename to examples_wdir/config/HGNS_OPF_datakit_case57.yaml diff --git a/examples/config/HGNS_PF_datakit_case118.yaml b/examples_wdir/config/HGNS_PF_datakit_case118.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case118.yaml rename to examples_wdir/config/HGNS_PF_datakit_case118.yaml diff --git a/examples/config/HGNS_PF_datakit_case14.yaml b/examples_wdir/config/HGNS_PF_datakit_case14.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case14.yaml rename to examples_wdir/config/HGNS_PF_datakit_case14.yaml diff --git a/examples/config/HGNS_PF_datakit_case2000.yaml 
b/examples_wdir/config/HGNS_PF_datakit_case2000.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case2000.yaml rename to examples_wdir/config/HGNS_PF_datakit_case2000.yaml diff --git a/examples/config/HGNS_PF_datakit_case30.yaml b/examples_wdir/config/HGNS_PF_datakit_case30.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case30.yaml rename to examples_wdir/config/HGNS_PF_datakit_case30.yaml diff --git a/examples/config/HGNS_PF_datakit_case500.yaml b/examples_wdir/config/HGNS_PF_datakit_case500.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case500.yaml rename to examples_wdir/config/HGNS_PF_datakit_case500.yaml diff --git a/examples/config/HGNS_PF_datakit_case57.yaml b/examples_wdir/config/HGNS_PF_datakit_case57.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_case57.yaml rename to examples_wdir/config/HGNS_PF_datakit_case57.yaml diff --git a/examples/config/HGNS_PF_datakit_caseTexas.yaml b/examples_wdir/config/HGNS_PF_datakit_caseTexas.yaml similarity index 100% rename from examples/config/HGNS_PF_datakit_caseTexas.yaml rename to examples_wdir/config/HGNS_PF_datakit_caseTexas.yaml diff --git a/examples/config/HGNS_PF_pfdelta_case118.yaml b/examples_wdir/config/HGNS_PF_pfdelta_case118.yaml similarity index 100% rename from examples/config/HGNS_PF_pfdelta_case118.yaml rename to examples_wdir/config/HGNS_PF_pfdelta_case118.yaml diff --git a/examples/config/HGNS_SE_datakit_case118.yaml b/examples_wdir/config/HGNS_SE_datakit_case118.yaml similarity index 100% rename from examples/config/HGNS_SE_datakit_case118.yaml rename to examples_wdir/config/HGNS_SE_datakit_case118.yaml diff --git a/examples/config/HGNS_SE_datakit_case14.yaml b/examples_wdir/config/HGNS_SE_datakit_case14.yaml similarity index 100% rename from examples/config/HGNS_SE_datakit_case14.yaml rename to examples_wdir/config/HGNS_SE_datakit_case14.yaml diff --git a/examples/config/__init__.py 
b/examples_wdir/config/__init__.py similarity index 100% rename from examples/config/__init__.py rename to examples_wdir/config/__init__.py diff --git a/examples/data/contingency_texas/branch_idx_removed.csv b/examples_wdir/data/contingency_texas/branch_idx_removed.csv similarity index 100% rename from examples/data/contingency_texas/branch_idx_removed.csv rename to examples_wdir/data/contingency_texas/branch_idx_removed.csv diff --git a/examples/data/contingency_texas/bus_params.csv b/examples_wdir/data/contingency_texas/bus_params.csv similarity index 100% rename from examples/data/contingency_texas/bus_params.csv rename to examples_wdir/data/contingency_texas/bus_params.csv diff --git a/examples/data/contingency_texas/edge_params.csv b/examples_wdir/data/contingency_texas/edge_params.csv similarity index 100% rename from examples/data/contingency_texas/edge_params.csv rename to examples_wdir/data/contingency_texas/edge_params.csv diff --git a/examples/data/contingency_texas/pf_node_10_examples.csv b/examples_wdir/data/contingency_texas/pf_node_10_examples.csv similarity index 100% rename from examples/data/contingency_texas/pf_node_10_examples.csv rename to examples_wdir/data/contingency_texas/pf_node_10_examples.csv diff --git a/examples/data/contingency_texas/predictions_10_examples.csv b/examples_wdir/data/contingency_texas/predictions_10_examples.csv similarity index 100% rename from examples/data/contingency_texas/predictions_10_examples.csv rename to examples_wdir/data/contingency_texas/predictions_10_examples.csv diff --git a/examples/notebooks/Tutorial_contingency_analisys.ipynb b/examples_wdir/notebooks/Tutorial_contingency_analisys.ipynb similarity index 100% rename from examples/notebooks/Tutorial_contingency_analisys.ipynb rename to examples_wdir/notebooks/Tutorial_contingency_analisys.ipynb diff --git a/examples/notebooks/Tutorial_reconstruction_visualization.ipynb b/examples_wdir/notebooks/Tutorial_reconstruction_visualization.ipynb similarity 
index 100% rename from examples/notebooks/Tutorial_reconstruction_visualization.ipynb rename to examples_wdir/notebooks/Tutorial_reconstruction_visualization.ipynb diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index d045bc88..0e9d4398 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -210,7 +210,6 @@ def main_cli(args): profiler=profiler, ) - if args.command == "train" or args.command == "finetune": trainer.fit(model=model, datamodule=litGrid) if ( diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index e9a58399..e31310d7 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -9,8 +9,14 @@ ] _LAZY_IMPORTS = { - "GNS_heterogeneous": ("gridfm_graphkit.models.gnn_heterogeneous_gns", "GNS_heterogeneous"), - "GritHeteroAdapter": ("gridfm_graphkit.models.grit_transformer", "GritHeteroAdapter"), + "GNS_heterogeneous": ( + "gridfm_graphkit.models.gnn_heterogeneous_gns", + "GNS_heterogeneous", + ), + "GritHeteroAdapter": ( + "gridfm_graphkit.models.grit_transformer", + "GritHeteroAdapter", + ), "PhysicsDecoderOPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderOPF"), "PhysicsDecoderPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderPF"), "PhysicsDecoderSE": ("gridfm_graphkit.models.utils", "PhysicsDecoderSE"), From dc3f4947937cd5d4a951b5af403c604ee297d944 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:47:47 -0400 Subject: [PATCH 76/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- {examples_wdir => examples}/__init__.py | 0 {examples_wdir => examples}/config/GRIT_PF_datakit_case14.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case14.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case2000.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case30.yaml | 0 
{examples_wdir => examples}/config/HGNS_OPFData_case500.yaml | 0 {examples_wdir => examples}/config/HGNS_OPFData_case57.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case14.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case2000.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case30.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case500.yaml | 0 {examples_wdir => examples}/config/HGNS_OPF_datakit_case57.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case14.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case2000.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case30.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case500.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_case57.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_datakit_caseTexas.yaml | 0 {examples_wdir => examples}/config/HGNS_PF_pfdelta_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_SE_datakit_case118.yaml | 0 {examples_wdir => examples}/config/HGNS_SE_datakit_case14.yaml | 0 {examples_wdir => examples}/config/__init__.py | 0 .../data/contingency_texas/branch_idx_removed.csv | 0 {examples_wdir => examples}/data/contingency_texas/bus_params.csv | 0 .../data/contingency_texas/edge_params.csv | 0 .../data/contingency_texas/pf_node_10_examples.csv | 0 .../data/contingency_texas/predictions_10_examples.csv | 0 .../notebooks/Tutorial_contingency_analisys.ipynb | 0 .../notebooks/Tutorial_reconstruction_visualization.ipynb | 0 32 files changed, 0 insertions(+), 0 deletions(-) rename {examples_wdir => examples}/__init__.py (100%) rename {examples_wdir => examples}/config/GRIT_PF_datakit_case14.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case14.yaml 
(100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case2000.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case30.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case500.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPFData_case57.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case14.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case2000.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case30.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case500.yaml (100%) rename {examples_wdir => examples}/config/HGNS_OPF_datakit_case57.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case14.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case2000.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case30.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case500.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_case57.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_datakit_caseTexas.yaml (100%) rename {examples_wdir => examples}/config/HGNS_PF_pfdelta_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_SE_datakit_case118.yaml (100%) rename {examples_wdir => examples}/config/HGNS_SE_datakit_case14.yaml (100%) rename {examples_wdir => examples}/config/__init__.py (100%) rename {examples_wdir => examples}/data/contingency_texas/branch_idx_removed.csv (100%) rename {examples_wdir => examples}/data/contingency_texas/bus_params.csv (100%) rename {examples_wdir => examples}/data/contingency_texas/edge_params.csv (100%) rename {examples_wdir => examples}/data/contingency_texas/pf_node_10_examples.csv (100%) rename {examples_wdir => 
examples}/data/contingency_texas/predictions_10_examples.csv (100%) rename {examples_wdir => examples}/notebooks/Tutorial_contingency_analisys.ipynb (100%) rename {examples_wdir => examples}/notebooks/Tutorial_reconstruction_visualization.ipynb (100%) diff --git a/examples_wdir/__init__.py b/examples/__init__.py similarity index 100% rename from examples_wdir/__init__.py rename to examples/__init__.py diff --git a/examples_wdir/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml similarity index 100% rename from examples_wdir/config/GRIT_PF_datakit_case14.yaml rename to examples/config/GRIT_PF_datakit_case14.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case118.yaml b/examples/config/HGNS_OPFData_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case118.yaml rename to examples/config/HGNS_OPFData_case118.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case14.yaml b/examples/config/HGNS_OPFData_case14.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case14.yaml rename to examples/config/HGNS_OPFData_case14.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case2000.yaml b/examples/config/HGNS_OPFData_case2000.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case2000.yaml rename to examples/config/HGNS_OPFData_case2000.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case30.yaml b/examples/config/HGNS_OPFData_case30.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case30.yaml rename to examples/config/HGNS_OPFData_case30.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case500.yaml b/examples/config/HGNS_OPFData_case500.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPFData_case500.yaml rename to examples/config/HGNS_OPFData_case500.yaml diff --git a/examples_wdir/config/HGNS_OPFData_case57.yaml b/examples/config/HGNS_OPFData_case57.yaml similarity index 100% rename from 
examples_wdir/config/HGNS_OPFData_case57.yaml rename to examples/config/HGNS_OPFData_case57.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case118.yaml b/examples/config/HGNS_OPF_datakit_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case118.yaml rename to examples/config/HGNS_OPF_datakit_case118.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case14.yaml b/examples/config/HGNS_OPF_datakit_case14.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case14.yaml rename to examples/config/HGNS_OPF_datakit_case14.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case2000.yaml b/examples/config/HGNS_OPF_datakit_case2000.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case2000.yaml rename to examples/config/HGNS_OPF_datakit_case2000.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case30.yaml b/examples/config/HGNS_OPF_datakit_case30.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case30.yaml rename to examples/config/HGNS_OPF_datakit_case30.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case500.yaml b/examples/config/HGNS_OPF_datakit_case500.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case500.yaml rename to examples/config/HGNS_OPF_datakit_case500.yaml diff --git a/examples_wdir/config/HGNS_OPF_datakit_case57.yaml b/examples/config/HGNS_OPF_datakit_case57.yaml similarity index 100% rename from examples_wdir/config/HGNS_OPF_datakit_case57.yaml rename to examples/config/HGNS_OPF_datakit_case57.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case118.yaml b/examples/config/HGNS_PF_datakit_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case118.yaml rename to examples/config/HGNS_PF_datakit_case118.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case14.yaml b/examples/config/HGNS_PF_datakit_case14.yaml similarity index 100% rename from 
examples_wdir/config/HGNS_PF_datakit_case14.yaml rename to examples/config/HGNS_PF_datakit_case14.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case2000.yaml b/examples/config/HGNS_PF_datakit_case2000.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case2000.yaml rename to examples/config/HGNS_PF_datakit_case2000.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case30.yaml b/examples/config/HGNS_PF_datakit_case30.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case30.yaml rename to examples/config/HGNS_PF_datakit_case30.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case500.yaml b/examples/config/HGNS_PF_datakit_case500.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case500.yaml rename to examples/config/HGNS_PF_datakit_case500.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_case57.yaml b/examples/config/HGNS_PF_datakit_case57.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_case57.yaml rename to examples/config/HGNS_PF_datakit_case57.yaml diff --git a/examples_wdir/config/HGNS_PF_datakit_caseTexas.yaml b/examples/config/HGNS_PF_datakit_caseTexas.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_datakit_caseTexas.yaml rename to examples/config/HGNS_PF_datakit_caseTexas.yaml diff --git a/examples_wdir/config/HGNS_PF_pfdelta_case118.yaml b/examples/config/HGNS_PF_pfdelta_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_PF_pfdelta_case118.yaml rename to examples/config/HGNS_PF_pfdelta_case118.yaml diff --git a/examples_wdir/config/HGNS_SE_datakit_case118.yaml b/examples/config/HGNS_SE_datakit_case118.yaml similarity index 100% rename from examples_wdir/config/HGNS_SE_datakit_case118.yaml rename to examples/config/HGNS_SE_datakit_case118.yaml diff --git a/examples_wdir/config/HGNS_SE_datakit_case14.yaml b/examples/config/HGNS_SE_datakit_case14.yaml similarity index 100% rename from 
examples_wdir/config/HGNS_SE_datakit_case14.yaml rename to examples/config/HGNS_SE_datakit_case14.yaml diff --git a/examples_wdir/config/__init__.py b/examples/config/__init__.py similarity index 100% rename from examples_wdir/config/__init__.py rename to examples/config/__init__.py diff --git a/examples_wdir/data/contingency_texas/branch_idx_removed.csv b/examples/data/contingency_texas/branch_idx_removed.csv similarity index 100% rename from examples_wdir/data/contingency_texas/branch_idx_removed.csv rename to examples/data/contingency_texas/branch_idx_removed.csv diff --git a/examples_wdir/data/contingency_texas/bus_params.csv b/examples/data/contingency_texas/bus_params.csv similarity index 100% rename from examples_wdir/data/contingency_texas/bus_params.csv rename to examples/data/contingency_texas/bus_params.csv diff --git a/examples_wdir/data/contingency_texas/edge_params.csv b/examples/data/contingency_texas/edge_params.csv similarity index 100% rename from examples_wdir/data/contingency_texas/edge_params.csv rename to examples/data/contingency_texas/edge_params.csv diff --git a/examples_wdir/data/contingency_texas/pf_node_10_examples.csv b/examples/data/contingency_texas/pf_node_10_examples.csv similarity index 100% rename from examples_wdir/data/contingency_texas/pf_node_10_examples.csv rename to examples/data/contingency_texas/pf_node_10_examples.csv diff --git a/examples_wdir/data/contingency_texas/predictions_10_examples.csv b/examples/data/contingency_texas/predictions_10_examples.csv similarity index 100% rename from examples_wdir/data/contingency_texas/predictions_10_examples.csv rename to examples/data/contingency_texas/predictions_10_examples.csv diff --git a/examples_wdir/notebooks/Tutorial_contingency_analisys.ipynb b/examples/notebooks/Tutorial_contingency_analisys.ipynb similarity index 100% rename from examples_wdir/notebooks/Tutorial_contingency_analisys.ipynb rename to examples/notebooks/Tutorial_contingency_analisys.ipynb diff --git 
a/examples_wdir/notebooks/Tutorial_reconstruction_visualization.ipynb b/examples/notebooks/Tutorial_reconstruction_visualization.ipynb similarity index 100% rename from examples_wdir/notebooks/Tutorial_reconstruction_visualization.ipynb rename to examples/notebooks/Tutorial_reconstruction_visualization.ipynb From 526efa9e1a8fb6428c6a15288c300c99ef9234ed Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:47:47 -0400 Subject: [PATCH 77/95] support for scatter and sparse Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/__init__.py | 21 +++++-------------- gridfm_graphkit/models/__init__.py | 30 +++++++--------------------- gridfm_graphkit/models/grit_layer.py | 8 +++++++- gridfm_graphkit/tasks/__init__.py | 18 +++-------------- gridfm_graphkit/tasks/pf_task.py | 7 ++++++- 5 files changed, 28 insertions(+), 56 deletions(-) diff --git a/gridfm_graphkit/__init__.py b/gridfm_graphkit/__init__.py index d91c0dc7..9378901f 100644 --- a/gridfm_graphkit/__init__.py +++ b/gridfm_graphkit/__init__.py @@ -1,19 +1,8 @@ -import importlib as _importlib +import gridfm_graphkit.datasets +import gridfm_graphkit.tasks.base_task +import gridfm_graphkit.models.gnn_heterogeneous_gns +import gridfm_graphkit.tasks.reconstruction_tasks __all__ = [ - "datasets", - "tasks", - "models", + "gridfm_graphkit", ] - -_LAZY_SUBMODULES = { - "datasets": "gridfm_graphkit.datasets", - "tasks": "gridfm_graphkit.tasks", - "models": "gridfm_graphkit.models", -} - - -def __getattr__(name: str): - if name in _LAZY_SUBMODULES: - return _importlib.import_module(_LAZY_SUBMODULES[name]) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index e31310d7..f185c6a2 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,10 @@ -import importlib as 
_importlib +from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous +from gridfm_graphkit.models.grit_transformer import GritHeteroAdapter +from gridfm_graphkit.models.utils import ( + PhysicsDecoderOPF, + PhysicsDecoderPF, + PhysicsDecoderSE, +) __all__ = [ "GNS_heterogeneous", @@ -7,25 +13,3 @@ "PhysicsDecoderPF", "PhysicsDecoderSE", ] - -_LAZY_IMPORTS = { - "GNS_heterogeneous": ( - "gridfm_graphkit.models.gnn_heterogeneous_gns", - "GNS_heterogeneous", - ), - "GritHeteroAdapter": ( - "gridfm_graphkit.models.grit_transformer", - "GritHeteroAdapter", - ), - "PhysicsDecoderOPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderOPF"), - "PhysicsDecoderPF": ("gridfm_graphkit.models.utils", "PhysicsDecoderPF"), - "PhysicsDecoderSE": ("gridfm_graphkit.models.utils", "PhysicsDecoderSE"), -} - - -def __getattr__(name: str): - if name in _LAZY_IMPORTS: - module_path, attr = _LAZY_IMPORTS[name] - mod = _importlib.import_module(module_path) - return getattr(mod, attr) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 52d2cdb3..c0136981 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -4,7 +4,13 @@ import torch.nn.functional as F import torch_geometric as pyg from torch_geometric.utils.num_nodes import maybe_num_nodes -from torch_scatter import scatter, scatter_max, scatter_add + +try: + from torch_scatter import scatter, scatter_max, scatter_add +except ImportError: + scatter = None + scatter_max = None + scatter_add = None import opt_einsum as oe diff --git a/gridfm_graphkit/tasks/__init__.py b/gridfm_graphkit/tasks/__init__.py index b213a392..8ed9b137 100644 --- a/gridfm_graphkit/tasks/__init__.py +++ b/gridfm_graphkit/tasks/__init__.py @@ -1,17 +1,5 @@ -import importlib as _importlib +from gridfm_graphkit.tasks.pf_task import PowerFlowTask +from gridfm_graphkit.tasks.opf_task import 
OptimalPowerFlowTask +from gridfm_graphkit.tasks.se_task import StateEstimationTask __all__ = ["PowerFlowTask", "OptimalPowerFlowTask", "StateEstimationTask"] - -_LAZY_IMPORTS = { - "PowerFlowTask": ("gridfm_graphkit.tasks.pf_task", "PowerFlowTask"), - "OptimalPowerFlowTask": ("gridfm_graphkit.tasks.opf_task", "OptimalPowerFlowTask"), - "StateEstimationTask": ("gridfm_graphkit.tasks.se_task", "StateEstimationTask"), -} - - -def __getattr__(name: str): - if name in _LAZY_IMPORTS: - module_path, attr = _LAZY_IMPORTS[name] - mod = _importlib.import_module(module_path) - return getattr(mod, attr) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index d943f397..56b4b4a0 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -24,7 +24,12 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from torch_scatter import scatter_add + +try: + from torch_scatter import scatter_add +except ImportError: + scatter_add = None + from torch_geometric.nn import global_mean_pool from gridfm_graphkit.models.utils import ( ComputeBranchFlow, From 32e3ac21f7b4e308846c70a88b94a7fb08245e8d Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:47:47 -0400 Subject: [PATCH 78/95] formatting Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 1f2a4711..75b0d1c6 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -271,7 +271,7 @@ def aggregate_pg(batch, mask_value=-1.0): if scatter_add is None: raise ImportError( "torch-scatter is required for the GRIT modules but is not installed. 
" - "Install it with: pip install torch-scatter" + "Install it with: pip install torch-scatter", ) gen_to_bus = batch["gen", "connected_to", "bus"].edge_index From f9cb7b2e9d8bd8c12760c3a08948509541f64fc8 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:47:47 -0400 Subject: [PATCH 79/95] support for scatter and sparse Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 12 +++++++++++- gridfm_graphkit/datasets/rrwp.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index cae18f05..db38e930 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -4,7 +4,12 @@ to_dense_adj, ) from torch_geometric.utils.num_nodes import maybe_num_nodes -from torch_scatter import scatter_add + +try: + from torch_scatter import scatter_add +except ImportError: + scatter_add = None + from functools import partial from gridfm_graphkit.datasets.rrwp import add_full_rrwp @@ -104,6 +109,11 @@ def get_rw_landing_probs( edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) num_nodes = maybe_num_nodes(edge_index, num_nodes) source, _ = edge_index[0], edge_index[1] + if scatter_add is None: + raise ImportError( + "torch-scatter is required for RWSE positional encodings. " + "Install it with: pip install torch-scatter", + ) deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. 
deg_inv = deg.pow(-1.0) deg_inv.masked_fill_(deg_inv == float("inf"), 0) diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index 317515e2..9faa01af 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -2,7 +2,11 @@ import torch import torch.nn.functional as F from torch_geometric.data import Data -from torch_sparse import SparseTensor + +try: + from torch_sparse import SparseTensor +except ImportError: + SparseTensor = None def add_node_attr(data: Data, value: Any, attr_name: Optional[str] = None) -> Data: @@ -31,6 +35,12 @@ def add_full_rrwp( num_nodes = data.num_nodes edge_index, edge_weight = data.edge_index, data.edge_weight + if SparseTensor is None: + raise ImportError( + "torch-sparse is required for RRWP positional encodings. " + "Install it with: pip install torch-sparse", + ) + adj = SparseTensor.from_edge_index( edge_index, edge_weight, From 969129eb95f047f239212ed3a7dae591129b7db7 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:43 -0400 Subject: [PATCH 80/95] update callback Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/callbacks.py | 49 +++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index df8ee247..18af3420 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -82,3 +82,52 @@ def on_validation_end(self, trainer, pl_module): # Save the model's state_dict model_path = os.path.join(model_dir, self.filename) torch.save(pl_module.state_dict(), model_path) + + +class SaveLastModelStateDict(Callback): + def __init__(self, filename: str = "last_model_state_dict.pt"): + self.filename = filename + + @rank_zero_only + def on_train_epoch_end(self, trainer, pl_module): + logger = trainer.logger + if isinstance(logger, 
MLFlowLogger): + model_dir = os.path.join( + logger.save_dir, + logger.experiment_id, + logger.run_id, + "artifacts", + "model", + ) + else: + model_dir = os.path.join(logger.save_dir, "model") + + os.makedirs(model_dir, exist_ok=True) + model_path = os.path.join(model_dir, self.filename) + torch.save(pl_module.state_dict(), model_path) + +class FreezeMaskTokens(Callback): + """Inject pre-trained mask tokens and freeze them. + + Replaces nn.Parameter with a registered buffer so DDP does not expect + gradients for these tensors. + """ + + def __init__(self, mask_state_path: str): + super().__init__() + self.mask_state_path = mask_state_path + + def setup(self, trainer, pl_module, stage=None): + if stage != "fit": + return + saved = torch.load(self.mask_state_path, map_location="cpu") + model = pl_module.model + + for name in ("bus_mask_token", "edge_mask_token", "gen_mask_token"): + key = f"model.{name}" + if key in saved and hasattr(model, name): + tensor = saved[key] + # Remove the nn.Parameter and re-register as a buffer so DDP + # won't include it in gradient reduction. 
+ delattr(model, name) + model.register_buffer(name, tensor) From 4cbd59b5ec09b16ffe6f547ac7e009aa5e617466 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:43 -0400 Subject: [PATCH 81/95] allow posenc caching Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 1 + gridfm_graphkit/datasets/cached_transform.py | 140 ++++++++++++++++++ .../datasets/hetero_powergrid_datamodule.py | 22 +++ 3 files changed, 163 insertions(+) create mode 100644 gridfm_graphkit/datasets/cached_transform.py diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index cef4b503..dbb5eadf 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -21,6 +21,7 @@ data: ksteps: 21 posenc_RWSE: enable: true + cache: true # cache the computed positional encoding at first pass (if false computes every access) kernel: times: 21 model: diff --git a/gridfm_graphkit/datasets/cached_transform.py b/gridfm_graphkit/datasets/cached_transform.py new file mode 100644 index 00000000..750369ae --- /dev/null +++ b/gridfm_graphkit/datasets/cached_transform.py @@ -0,0 +1,140 @@ +import os +import tempfile +import hashlib + +import torch +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import HeteroData + + +class CachedPosencTransform(BaseTransform): + """Disk-caching wrapper for positional encoding transforms. + + Computes the PE on the first access for each sample and caches the + result to disk. Subsequent accesses load from cache, avoiding + redundant computation across epochs and across jobs that share the + same processed directory. + + Thread/process safety: + - Uses atomic write (write to temp file, then os.replace) so + concurrent DataLoader workers or separate jobs cannot produce + corrupt cache files. 
+ - Cache is keyed by scenario_id, which is unique per graph. + Different train/val/test splits across jobs safely share the + cache since RWSE depends only on topology, not on split + membership. + + Args: + transform: The inner PE transform (e.g. ComputePosencStat). + cache_dir: Directory to store cached PE tensors. + cached_attrs: List of attribute names to cache (e.g. ["pestat_RWSE"]). + key_attr: Attribute on the data object used as the cache key. + Must be a scalar tensor (e.g. scenario_id). + """ + + def __init__( + self, + transform: BaseTransform, + cache_dir: str, + cached_attrs: list[str], + key_attr: str = "scenario_id", + ): + self.transform = transform + self.cache_dir = cache_dir + self.cached_attrs = cached_attrs + self.key_attr = key_attr + os.makedirs(cache_dir, exist_ok=True) + + def _cache_path(self, data) -> str: + key = data[self.key_attr].item() + return os.path.join(self.cache_dir, f"pe_cache_{key}.pt") + + def _load_cache(self, cache_path, data): + """Load cached PE attributes and attach them to data.""" + cached = torch.load(cache_path, weights_only=True) + if isinstance(data, HeteroData): + for attr, val in cached.items(): + data["bus"][attr] = val + else: + for attr, val in cached.items(): + setattr(data, attr, val) + + def _save_cache(self, cache_path, data): + """Atomically save PE attributes to the cache file. + + Uses a temporary file in the same directory followed by + os.replace, which is atomic on both Linux and Windows. + This ensures concurrent workers/jobs never see a partially + written file. + """ + if isinstance(data, HeteroData): + target = data["bus"] + else: + target = data + + to_cache = {} + for attr in self.cached_attrs: + if hasattr(target, attr): + to_cache[attr] = getattr(target, attr) + + if not to_cache: + return + + # Write to a temporary file in the same directory (same filesystem) + # to guarantee os.replace is atomic. 
+ fd, tmp_path = tempfile.mkstemp( + dir=self.cache_dir, + prefix=f".pe_cache_tmp_{os.getpid()}_", + suffix=".pt", + ) + try: + os.close(fd) + torch.save(to_cache, tmp_path) + os.replace(tmp_path, cache_path) + except BaseException: + # Clean up temp file on any failure + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + def __call__(self, data): + cache_path = self._cache_path(data) + + # Fast path: load from cache if available + if os.path.exists(cache_path): + self._load_cache(cache_path, data) + return data + + # Slow path: compute, then cache + data = self.transform(data) + self._save_cache(cache_path, data) + return data + + +def make_pe_cache_dir(processed_dir: str, pe_type: str, cfg) -> str: + """Build a cache directory path that includes a config fingerprint. + + The fingerprint ensures that changing PE parameters (e.g. kernel.times) + invalidates the cache automatically by using a different directory. + + Args: + processed_dir: The dataset's processed directory. + pe_type: "RWSE" or "RRWP". + cfg: The data config namespace containing PE parameters. + + Returns: + Path to the cache directory. 
+ """ + if pe_type == "RWSE": + kernel_times = cfg.posenc_RWSE.kernel.times + fingerprint = f"k{kernel_times}" + elif pe_type == "RRWP": + ksteps = cfg.posenc_RRWP.ksteps + fingerprint = f"k{ksteps}" + else: + fingerprint = "default" + + cache_dir_name = f"pe_cache_{pe_type.lower()}_{fingerprint}" + return os.path.join(processed_dir, cache_dir_name) diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 09534b42..51d2796d 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -18,6 +18,10 @@ from gridfm_graphkit.datasets.powergrid_hetero_dataset import HeteroGridDatasetDisk from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat +from gridfm_graphkit.datasets.cached_transform import ( + CachedPosencTransform, + make_pe_cache_dir, +) import torch_geometric.transforms as T @@ -162,12 +166,30 @@ def setup(self, stage: str): if ("posenc_RRWP" in self.args.data) and self.args.data.posenc_RRWP.enable: pe_transform = ComputePosencStat(pe_types=["RRWP"], cfg=self.args.data) + if getattr(self.args.data.posenc_RRWP, "cache", False): + cache_dir = make_pe_cache_dir( + dataset.processed_dir, "RRWP", self.args.data, + ) + pe_transform = CachedPosencTransform( + pe_transform, + cache_dir, + cached_attrs=["rrwp", "rrwp_index", "rrwp_val", "log_deg", "deg"], + ) if dataset.transform is None: dataset.transform = pe_transform else: dataset.transform = T.Compose([pe_transform, dataset.transform]) if ("posenc_RWSE" in self.args.data) and self.args.data.posenc_RWSE.enable: pe_transform = ComputePosencStat(pe_types=["RWSE"], cfg=self.args.data) + if getattr(self.args.data.posenc_RWSE, "cache", False): + cache_dir = make_pe_cache_dir( + dataset.processed_dir, "RWSE", self.args.data, + ) + pe_transform = CachedPosencTransform( + pe_transform, + cache_dir, + cached_attrs=["pestat_RWSE"], + ) if dataset.transform is None: 
dataset.transform = pe_transform else: From d4c47259b88dffa614b94883b1f6bf28c5a6e803 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:44 -0400 Subject: [PATCH 82/95] allow posenc caching Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/cached_transform.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gridfm_graphkit/datasets/cached_transform.py b/gridfm_graphkit/datasets/cached_transform.py index 750369ae..61695cfe 100644 --- a/gridfm_graphkit/datasets/cached_transform.py +++ b/gridfm_graphkit/datasets/cached_transform.py @@ -3,11 +3,10 @@ import hashlib import torch -from torch_geometric.transforms import BaseTransform from torch_geometric.data import HeteroData -class CachedPosencTransform(BaseTransform): +class CachedPosencTransform: """Disk-caching wrapper for positional encoding transforms. Computes the PE on the first access for each sample and caches the From 6e041e2c15be9f876a19783e84420335e6f75969 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:44 -0400 Subject: [PATCH 83/95] allow posenc caching Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/cached_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/datasets/cached_transform.py b/gridfm_graphkit/datasets/cached_transform.py index 61695cfe..a4eaf85d 100644 --- a/gridfm_graphkit/datasets/cached_transform.py +++ b/gridfm_graphkit/datasets/cached_transform.py @@ -33,7 +33,7 @@ class CachedPosencTransform: def __init__( self, - transform: BaseTransform, + transform, cache_dir: str, cached_attrs: list[str], key_attr: str = "scenario_id", From baa2132089651477b898f7b33044ff97203ed950 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 
2026 09:16:44 -0400 Subject: [PATCH 84/95] update caching for RRWP Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/cached_transform.py | 26 +++++++++++++++++-- .../datasets/hetero_powergrid_datamodule.py | 3 ++- gridfm_graphkit/datasets/posenc_stats.py | 18 ++++++++----- gridfm_graphkit/models/grit_transformer.py | 7 ++++- 4 files changed, 44 insertions(+), 10 deletions(-) diff --git a/gridfm_graphkit/datasets/cached_transform.py b/gridfm_graphkit/datasets/cached_transform.py index a4eaf85d..e31be1fe 100644 --- a/gridfm_graphkit/datasets/cached_transform.py +++ b/gridfm_graphkit/datasets/cached_transform.py @@ -26,7 +26,10 @@ class CachedPosencTransform: Args: transform: The inner PE transform (e.g. ComputePosencStat). cache_dir: Directory to store cached PE tensors. - cached_attrs: List of attribute names to cache (e.g. ["pestat_RWSE"]). + cached_attrs: List of attribute names to cache on the bus node store + (e.g. ["pestat_RWSE"]). + cached_edge_type: Optional edge type tuple (e.g. ("bus", "rrwp", "bus")) + whose edge_index and edge_attr should also be cached. key_attr: Attribute on the data object used as the cache key. Must be a scalar tensor (e.g. scenario_id). 
""" @@ -36,11 +39,13 @@ def __init__( transform, cache_dir: str, cached_attrs: list[str], + cached_edge_type: tuple[str, str, str] | None = None, key_attr: str = "scenario_id", ): self.transform = transform self.cache_dir = cache_dir self.cached_attrs = cached_attrs + self.cached_edge_type = cached_edge_type self.key_attr = key_attr os.makedirs(cache_dir, exist_ok=True) @@ -53,7 +58,12 @@ def _load_cache(self, cache_path, data): cached = torch.load(cache_path, weights_only=True) if isinstance(data, HeteroData): for attr, val in cached.items(): - data["bus"][attr] = val + if attr == "_edge_type_index": + data[self.cached_edge_type].edge_index = val + elif attr == "_edge_type_attr": + data[self.cached_edge_type].edge_attr = val + else: + data["bus"][attr] = val else: for attr, val in cached.items(): setattr(data, attr, val) @@ -76,6 +86,18 @@ def _save_cache(self, cache_path, data): if hasattr(target, attr): to_cache[attr] = getattr(target, attr) + # Cache edge-type data (RRWP sparse index + values) + if ( + self.cached_edge_type is not None + and isinstance(data, HeteroData) + and self.cached_edge_type in data.edge_types + ): + edge_store = data[self.cached_edge_type] + if hasattr(edge_store, "edge_index"): + to_cache["_edge_type_index"] = edge_store.edge_index + if hasattr(edge_store, "edge_attr"): + to_cache["_edge_type_attr"] = edge_store.edge_attr + if not to_cache: return diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 51d2796d..0124550c 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -173,7 +173,8 @@ def setup(self, stage: str): pe_transform = CachedPosencTransform( pe_transform, cache_dir, - cached_attrs=["rrwp", "rrwp_index", "rrwp_val", "log_deg", "deg"], + cached_attrs=["rrwp", "log_deg", "deg"], + cached_edge_type=("bus", "rrwp", "bus"), ) if dataset.transform is None: dataset.transform 
= pe_transform diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index db38e930..7c02f7c0 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -176,17 +176,23 @@ def _call_hetero(self, data: HeteroData) -> HeteroData: cfg=self.cfg, ) - # Copy computed PE attributes back onto the HeteroData bus store - pe_attrs = [ + # Copy computed PE attributes back onto the HeteroData. + # Node-level attrs go on the bus store; the sparse RRWP index/val + # are stored as a dedicated edge type so PyG batches them correctly + # (edge_index needs cat_dim=1 + node-count incrementing). + node_pe_attrs = [ "pestat_RWSE", # RWSE - "rrwp", - "rrwp_index", - "rrwp_val", # RRWP + "rrwp", # RRWP absolute (node-level diagonal) "log_deg", "deg", # degree info from RRWP ] - for attr in pe_attrs: + for attr in node_pe_attrs: if hasattr(bus_data, attr): data["bus"][attr] = getattr(bus_data, attr) + # RRWP relative PE: sparse [2, E] index + [E, K] values → edge type + if hasattr(bus_data, "rrwp_index") and hasattr(bus_data, "rrwp_val"): + data["bus", "rrwp", "bus"].edge_index = bus_data.rrwp_index + data["bus", "rrwp", "bus"].edge_attr = bus_data.rrwp_val + return data diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 75b0d1c6..3ad7545a 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -404,10 +404,15 @@ def forward(self, batch): ) # Forward positional-encoding attributes if present - for attr in ("pestat_RWSE", "rrwp", "rrwp_index", "rrwp_val", "log_deg", "deg"): + for attr in ("pestat_RWSE", "rrwp", "log_deg", "deg"): if hasattr(batch["bus"], attr): setattr(homo, attr, getattr(batch["bus"], attr)) + # RRWP relative PE is stored as a dedicated edge type for correct batching + if ("bus", "rrwp", "bus") in batch.edge_types: + homo.rrwp_index = batch["bus", "rrwp", "bus"].edge_index + 
homo.rrwp_val = batch["bus", "rrwp", "bus"].edge_attr + # --- Run GRIT encoder + PE encoders + transformer layers --- homo = self.grit.encoder(homo) if hasattr(self.grit, "rrwp_abs_encoder"): From aeedb7d801e80ee570a1c51087fade3189f9f691 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:44 -0400 Subject: [PATCH 85/95] topk sparsity Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- .../GRIT_PF_datakit_case14_topk_rrwp.yaml | 159 ++++++++++++++++++ gridfm_graphkit/datasets/cached_transform.py | 6 +- gridfm_graphkit/datasets/posenc_stats.py | 27 ++- gridfm_graphkit/datasets/rrwp.py | 131 +++++++++++++++ 4 files changed, 314 insertions(+), 9 deletions(-) create mode 100644 examples/config/GRIT_PF_datakit_case14_topk_rrwp.yaml diff --git a/examples/config/GRIT_PF_datakit_case14_topk_rrwp.yaml b/examples/config/GRIT_PF_datakit_case14_topk_rrwp.yaml new file mode 100644 index 00000000..e20a3796 --- /dev/null +++ b/examples/config/GRIT_PF_datakit_case14_topk_rrwp.yaml @@ -0,0 +1,159 @@ +callbacks: + patience: 100 + tol: 0 + freeze_mask: None # path to .pt weight file if wanting to freeze to learned values +task: + task_name: PowerFlow +data: + baseMVA: 100 + mask_type: rnd # or determinstic + mask_ratio: 0.5 # for random masking only + mask_value: -1 + # use_learnable_mask: when true, replaces the fixed mask_value scalar with + # a per-feature learnable parameter (nn.Parameter) for each node/edge type. + # The transform still zeroes out masked positions, and the model injects the + # learned token additively in the forward pass. mask_value is ignored. + use_learnable_mask: false + # use_max_basemva (legacy): when true, baseMVA is computed as the absolute + # maximum of all non-zero Pd, Qd, Pg, Qg values. Superseded by + # basemva_mode; kept for backward compatibility. If basemva_mode is set + # explicitly, this flag is ignored. 
+ use_max_basemva: true + # basemva_mode: controls how the per-unit base power (baseMVA) is computed + # during normalizer fitting. Choices: + # "percentile" - 95th percentile of non-zero |Pd|, |Qd|, |Pg|, |Qg| + # (per-generator Pg). Default when neither flag is set. + # "abs_max" - max of absolute non-zero values (equivalent to legacy + # use_max_basemva: true). + # "reference" - aggregates generator Pg onto buses (sum per bus), then + # takes the plain .max() over [Pd, Qd, Pg_agg, Qg]. + # This matches the homogeneous reference normalizer and + # produces identical feature scales when the underlying + # power system data is the same. + basemva_mode: reference + normalization: HeteroDataMVANormalizer + mask_ref_vm: false # match reference PF masking at REF buses + networks: + - case14_ieee + scenarios: + - 5000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 + posenc_RRWP: + enable: true + ksteps: 21 + cache: true + # topk: Number of highest-ranked neighbors to retain per node based on + # the L2 norm of their multi-step random walk probability vector. + # Original graph edges and self-loops are always retained regardless of + # rank. This provides sparse attention that interpolates between: + # topk=0 (or omitted): full RRWP — all N^2 pairs (dense attention) + # topk=K: each node attends to its K structurally most important + # neighbors plus the original graph edges + # For case14 (14 buses), topk=5 retains ~36% of all pairs + graph edges, + # providing global reach while being significantly sparser than full. + topk: 5 + posenc_RWSE: + enable: false + cache: true + kernel: + times: 21 +model: + # homo_compat: when true, the GritHeteroAdapter remaps the heterogeneous + # bus features (15D) and edge features (10D) to the 9D / 2D layout used by + # the homogeneous reference model: bus x = [Pd, Qd, Pg_agg, Qg, Vm, Va, + # PQ, PV, REF], edge = [Yft_r, Yft_i]. This forces input_dim=9 and + # edge_dim=2 regardless of what is set below. 
+ homo_compat: true + # homo_compat_topology: when true, the ConvertToYBusTopology transform is + # applied before masking. It merges parallel branches into a single Y-bus + # admittance per bus pair, adds self-loop diagonal entries (Yff + GS + jBS), + # and outputs 2D [G, B] edge features. Implies homo_compat: true. + # Positional encodings (RWSE/RRWP) are computed after this transform so + # they reflect the final graph topology. + homo_compat_topology: true + # mask_ref_vm: when false, voltage magnitude (Vm) at the reference/slack bus + # is never masked, matching the convention of the homogeneous reference + # where the slack bus Vm is always known. When true, Vm at REF buses can + # be masked like any other bus. + mask_ref_vm: false + attention_head: 8 + dropout: 0.1 + # edge_dim must match the bus-bus edge feature count after transforms + # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) + edge_dim: 10 + hidden_size: 496 + # input_dim = bus feature count + aggregated PG (used by GRIT core FeatureEncoder) + input_dim: 16 + # Hetero adapter head dimensions + input_bus_dim: 16 + input_gen_dim: 6 + output_bus_dim: 6 # [VM, VA, PG, QG, PD, QD] + output_gen_dim: 0 # PG predicted at bus level; no per-generator head needed + num_layers: 7 + type: GRIT + act: relu + encoder: + node_encoder: true + edge_encoder: true + node_encoder_name: Linear + node_encoder_bn: true + edge_encoder_bn: true + posenc_RWSE: + # kernel.times is synced automatically from data.posenc_RWSE.kernel.times + pe_dim: 20 + raw_norm_type: batchnorm + gt: + layer_type: GritTransformer + # dim_hidden is synced automatically from model.hidden_size + layer_norm: false + batch_norm: true + update_e: true + attn_dropout: 0.2 + attn: + clamp: 5. + act: relu + # full_attn: false because the Top-K RRWP sparsification already + # determines the attention pattern. 
The RRWPLinearEdgeEncoder will + # merge the sparse RRWP edges with original graph edges without + # padding to a fully-connected graph. + full_attn: false + edge_enhance: true + O_e: true + norm_e: true + # norm_last_layer: controls whether the final transformer layer + # creates O_e (linear projection) and batch_norm parameters for edge + # features. When false, the last layer skips these to avoid DDP + # unused-parameter errors, since only node features feed into the + # output heads and the last layer's edge output is discarded. When + # true, all layers are identical (matching the reference architecture), + # but requires strategy: ddp_find_unused_parameters_true in DDP. + norm_last_layer: true + signed_sqrt: true + bn_momentum: 0.1 + bn_no_runner: false +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + num_nodes: 1 + loss_weights: + - 0.99 + - 0.01 + losses: + - PBE + - MaskedReconstructionMSE + loss_args: + - {} + - {} + accelerator: auto + devices: auto + strategy: auto +verbose: true diff --git a/gridfm_graphkit/datasets/cached_transform.py b/gridfm_graphkit/datasets/cached_transform.py index e31be1fe..65fb2551 100644 --- a/gridfm_graphkit/datasets/cached_transform.py +++ b/gridfm_graphkit/datasets/cached_transform.py @@ -153,7 +153,11 @@ def make_pe_cache_dir(processed_dir: str, pe_type: str, cfg) -> str: fingerprint = f"k{kernel_times}" elif pe_type == "RRWP": ksteps = cfg.posenc_RRWP.ksteps - fingerprint = f"k{ksteps}" + topk = getattr(cfg.posenc_RRWP, "topk", 0) + if topk and topk > 0: + fingerprint = f"k{ksteps}_topk{topk}" + else: + fingerprint = f"k{ksteps}" else: fingerprint = "default" diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 7c02f7c0..8fca4f19 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -11,7 +11,7 @@ scatter_add = None from 
functools import partial -from gridfm_graphkit.datasets.rrwp import add_full_rrwp +from gridfm_graphkit.datasets.rrwp import add_full_rrwp, add_topk_rrwp from torch_geometric.transforms import BaseTransform from torch_geometric.data import Data, HeteroData @@ -54,13 +54,24 @@ def compute_posenc_stats(data, pe_types, cfg): if "RRWP" in pe_types: param = cfg.posenc_RRWP - transform = partial( - add_full_rrwp, - walk_length=param.ksteps, - attr_name_abs="rrwp", - attr_name_rel="rrwp", - add_identity=True, - ) + topk = getattr(param, "topk", 0) + if topk and topk > 0: + transform = partial( + add_topk_rrwp, + walk_length=param.ksteps, + topk=topk, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + ) + else: + transform = partial( + add_full_rrwp, + walk_length=param.ksteps, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + ) data = transform(data) # Random Walks. diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index 9faa01af..d360117a 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -94,3 +94,134 @@ def add_full_rrwp( data.deg = deg.type(torch.long) return data + + +@torch.no_grad() +def add_topk_rrwp( + data, + walk_length=8, + topk=10, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + spd=False, + **kwargs, +): + """Compute RRWP positional encodings with Top-K sparsification. + + Instead of retaining the full N×N relative PE matrix, this function + keeps only the `topk` highest-magnitude neighbors per node (based on + the L2 norm of the multi-step random walk probability vector). The + original graph edges are always retained regardless of their rank. + + This provides a smooth interpolation between: + - topk=0 (or topk >= N): equivalent to full RRWP (all pairs) + - topk=1: nearly equivalent to RWSE (mostly self-loops / local) + + Args: + data: PyG Data object with edge_index. + walk_length: Number of random walk steps (k). 
+ topk: Number of highest-ranked neighbors to retain per node. + If 0 or >= num_nodes, all edges are kept (full RRWP). + attr_name_abs: Attribute name for the absolute (diagonal) PE. + attr_name_rel: Prefix for relative PE index/val attributes. + add_identity: Whether to include the identity (step 0) in PE. + spd: If True, encode shortest-path distance instead of probabilities. + + Returns: + Data object with rrwp, rrwp_index, rrwp_val, log_deg, deg attributes. + """ + num_nodes = data.num_nodes + edge_index, edge_weight = data.edge_index, data.edge_weight + + if SparseTensor is None: + raise ImportError( + "torch-sparse is required for RRWP positional encodings. " + "Install it with: pip install torch-sparse", + ) + + adj = SparseTensor.from_edge_index( + edge_index, + edge_weight, + sparse_sizes=(num_nodes, num_nodes), + ) + + # Compute D^{-1} A: + deg = adj.sum(dim=1) + deg_inv = 1.0 / adj.sum(dim=1) + deg_inv[deg_inv == float("inf")] = 0 + adj = adj * deg_inv.view(-1, 1) + adj = adj.to_dense() + + pe_list = [] + i = 0 + if add_identity: + pe_list.append(torch.eye(num_nodes, dtype=torch.float)) + i = i + 1 + + out = adj + pe_list.append(adj) + + if walk_length > 2: + for j in range(i + 1, walk_length): + out = out @ adj + pe_list.append(out) + + pe = torch.stack(pe_list, dim=-1) # n x n x k + + abs_pe = pe.diagonal().transpose(0, 1) # n x k + + if spd: + spd_idx = walk_length - torch.arange(walk_length) + val = (pe > 0).type(torch.float) * spd_idx.unsqueeze(0).unsqueeze(0) + val = torch.argmax(val, dim=-1) + pe = F.one_hot(val, walk_length).type(torch.float) + abs_pe = torch.zeros_like(abs_pe) + + # --- Top-K sparsification --- + # If topk <= 0 or topk >= num_nodes, keep everything (full RRWP) + if topk <= 0 or topk >= num_nodes: + rel_pe = SparseTensor.from_dense(pe, has_value=True) + rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo() + rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0) + else: + # Score each (i,j) pair by L2 norm of the k-step 
probability vector + # pe shape: [n, n, k] — pe[i, j, :] is the walk vector from i to j + scores = pe.norm(dim=-1) # [n, n] + + # Always include original graph edges (set their scores to infinity) + edge_mask = torch.zeros(num_nodes, num_nodes, dtype=torch.bool) + edge_mask[edge_index[0], edge_index[1]] = True + # Also always include self-loops + diag_idx = torch.arange(num_nodes) + edge_mask[diag_idx, diag_idx] = True + + # For Top-K selection: get top-k scores per row (per source node) + # Clamp topk to at most num_nodes + k = min(topk, num_nodes) + _, topk_indices = scores.topk(k, dim=1) # [n, k] + + # Build a mask of selected entries + topk_mask = torch.zeros(num_nodes, num_nodes, dtype=torch.bool) + row_idx = torch.arange(num_nodes).unsqueeze(1).expand_as(topk_indices) + topk_mask[row_idx, topk_indices] = True + + # Union of top-k and original edges + keep_mask = topk_mask | edge_mask + + # Extract sparse entries from the masked pe tensor + # Zero out entries we don't want, then sparsify + pe_sparse = pe.clone() + pe_sparse[~keep_mask] = 0 + + rel_pe = SparseTensor.from_dense(pe_sparse, has_value=True) + rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo() + rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0) + + data = add_node_attr(data, abs_pe, attr_name=attr_name_abs) + data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index") + data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val") + data.log_deg = torch.log(deg + 1) + data.deg = deg.type(torch.long) + + return data From db88a7c00df0b8c30d45cdd56a9434c416fa0222 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:45 -0400 Subject: [PATCH 86/95] reduce cache memory footprint Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/cached_transform.py | 100 ++++++++++++++++-- .../datasets/hetero_powergrid_datamodule.py | 12 +++ 2 files changed, 102 
insertions(+), 10 deletions(-) diff --git a/gridfm_graphkit/datasets/cached_transform.py b/gridfm_graphkit/datasets/cached_transform.py index 65fb2551..87e7bef6 100644 --- a/gridfm_graphkit/datasets/cached_transform.py +++ b/gridfm_graphkit/datasets/cached_transform.py @@ -4,23 +4,84 @@ import torch from torch_geometric.data import HeteroData +from torch_geometric.utils import remove_self_loops + + +def _topology_fingerprint(data, cached_edge_type=None, use_admittance=False, + admittance_remove_self_loops=True): + """Compute a short hash that uniquely identifies the graph topology. + + The fingerprint captures everything that determines the positional + encoding output: edge connectivity, number of nodes, and (when + admittance-weighted) the edge weights derived from admittance values. + + This allows samples sharing the same topology (and admittance values) + to reuse a single cached PE file instead of storing redundant copies. + + Args: + data: HeteroData or Data object. + cached_edge_type: Edge type tuple for hetero data, or None. + use_admittance: Whether admittance weights affect the PE. + admittance_remove_self_loops: Whether self-loops are removed + when computing admittance weights. + + Returns: + A 16-character hex string uniquely identifying the topology. 
+ """ + if isinstance(data, HeteroData): + edge_index = data["bus", "connects", "bus"].edge_index + num_nodes = data["bus"].num_nodes + edge_attr = data["bus", "connects", "bus"].edge_attr + else: + edge_index = data.edge_index + num_nodes = data.num_nodes + edge_attr = getattr(data, "edge_attr", None) + + h = hashlib.sha256() + h.update(num_nodes.to_bytes(4, "little") if isinstance(num_nodes, int) + else int(num_nodes).to_bytes(4, "little")) + h.update(edge_index.cpu().numpy().tobytes()) + + # Include admittance values in the fingerprint when they affect the PE + if use_admittance and edge_attr is not None: + if edge_attr.size(1) == 2: + g, b = edge_attr[:, 0], edge_attr[:, 1] + else: + # Yft at indices 4, 5 + g, b = edge_attr[:, 4], edge_attr[:, 5] + edge_weight = torch.sqrt(g ** 2 + b ** 2) + + if admittance_remove_self_loops: + _, edge_weight = remove_self_loops(edge_index, edge_weight) + + # Quantize to 6 decimal places to avoid floating-point noise + # causing spurious cache misses + quantized = (edge_weight * 1e6).round().to(torch.int64) + h.update(quantized.cpu().numpy().tobytes()) + + return h.hexdigest()[:16] class CachedPosencTransform: """Disk-caching wrapper for positional encoding transforms. - Computes the PE on the first access for each sample and caches the - result to disk. Subsequent accesses load from cache, avoiding - redundant computation across epochs and across jobs that share the - same processed directory. + Computes the PE on the first access and caches the result to disk. + Subsequent accesses with the same graph topology load from cache, + avoiding redundant computation across epochs, splits, and jobs. + + Cache keying strategy: + PE depends only on graph topology (edge_index, num_nodes) and + optionally on admittance values — NOT on node features like + loads/generation. The cache is keyed by a hash of these + topology-determining inputs so that all samples sharing the same + network structure reuse a single cache file. 
Thread/process safety: - Uses atomic write (write to temp file, then os.replace) so concurrent DataLoader workers or separate jobs cannot produce corrupt cache files. - - Cache is keyed by scenario_id, which is unique per graph. - Different train/val/test splits across jobs safely share the - cache since RWSE depends only on topology, not on split + - Different train/val/test splits across jobs safely share the + cache since PE depends only on topology, not on split membership. Args: @@ -31,7 +92,14 @@ class CachedPosencTransform: cached_edge_type: Optional edge type tuple (e.g. ("bus", "rrwp", "bus")) whose edge_index and edge_attr should also be cached. key_attr: Attribute on the data object used as the cache key. - Must be a scalar tensor (e.g. scenario_id). + If "topology" (default), uses a hash of the graph structure + so samples with identical topology share one cache file. + Otherwise, uses the named scalar attribute (e.g. "scenario_id") + for per-sample caching. + use_admittance: Whether admittance weights are used in the PE + computation (affects the topology fingerprint). + admittance_remove_self_loops: Whether self-loops are removed + for admittance weighting (affects the topology fingerprint). 
""" def __init__( @@ -40,17 +108,29 @@ def __init__( cache_dir: str, cached_attrs: list[str], cached_edge_type: tuple[str, str, str] | None = None, - key_attr: str = "scenario_id", + key_attr: str = "topology", + use_admittance: bool = False, + admittance_remove_self_loops: bool = True, ): self.transform = transform self.cache_dir = cache_dir self.cached_attrs = cached_attrs self.cached_edge_type = cached_edge_type self.key_attr = key_attr + self.use_admittance = use_admittance + self.admittance_remove_self_loops = admittance_remove_self_loops os.makedirs(cache_dir, exist_ok=True) def _cache_path(self, data) -> str: - key = data[self.key_attr].item() + if self.key_attr == "topology": + key = _topology_fingerprint( + data, + cached_edge_type=self.cached_edge_type, + use_admittance=self.use_admittance, + admittance_remove_self_loops=self.admittance_remove_self_loops, + ) + else: + key = data[self.key_attr].item() return os.path.join(self.cache_dir, f"pe_cache_{key}.pt") def _load_cache(self, cache_path, data): diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 0124550c..a19917ed 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -170,11 +170,22 @@ def setup(self, stage: str): cache_dir = make_pe_cache_dir( dataset.processed_dir, "RRWP", self.args.data, ) + _use_admw = getattr( + self.args.data.posenc_RRWP, + "use_admittance_weights", False, + ) + _rm_sl = getattr( + self.args.data.posenc_RRWP, + "admittance_remove_self_loops", True, + ) pe_transform = CachedPosencTransform( pe_transform, cache_dir, cached_attrs=["rrwp", "log_deg", "deg"], cached_edge_type=("bus", "rrwp", "bus"), + key_attr="topology", + use_admittance=_use_admw, + admittance_remove_self_loops=_rm_sl, ) if dataset.transform is None: dataset.transform = pe_transform @@ -190,6 +201,7 @@ def setup(self, stage: str): pe_transform, cache_dir, 
cached_attrs=["pestat_RWSE"], + key_attr="topology", ) if dataset.transform is None: dataset.transform = pe_transform From fb0dbcfc14e89a82bb8ec6b41be5d233f04e6312 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:45 -0400 Subject: [PATCH 87/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 15 ++ .../GRIT_PF_datakit_case14_topk_rrwp.yaml | 159 ------------------ .../datasets/hetero_powergrid_datamodule.py | 10 -- 3 files changed, 15 insertions(+), 169 deletions(-) delete mode 100644 examples/config/GRIT_PF_datakit_case14_topk_rrwp.yaml diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index dbb5eadf..522bf158 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -19,6 +19,17 @@ data: posenc_RRWP: enable: false ksteps: 21 + cache: true + # topk: Number of highest-ranked neighbors to retain per node based on + # the L2 norm of their multi-step random walk probability vector. + # Original graph edges and self-loops are always retained regardless of + # rank. This provides sparse attention that interpolates between: + # topk=0 (or omitted): full RRWP — all N^2 pairs (dense attention) + # topk=K: each node attends to its K structurally most important + # neighbors plus the original graph edges + # For case14 (14 buses), topk=5 retains ~36% of all pairs + graph edges, + # providing global reach while being significantly sparser than full. + topk: 5 posenc_RWSE: enable: true cache: true # cache the computed positional encoding at first pass (if false computes every access) @@ -61,6 +72,10 @@ model: attn: clamp: 5. act: relu + # full_attn: false because the Top-K RRWP sparsification already + # determines the attention pattern. 
The RRWPLinearEdgeEncoder will + # merge the sparse RRWP edges with original graph edges without + # padding to a fully-connected graph. full_attn: true edge_enhance: true O_e: true diff --git a/examples/config/GRIT_PF_datakit_case14_topk_rrwp.yaml b/examples/config/GRIT_PF_datakit_case14_topk_rrwp.yaml deleted file mode 100644 index e20a3796..00000000 --- a/examples/config/GRIT_PF_datakit_case14_topk_rrwp.yaml +++ /dev/null @@ -1,159 +0,0 @@ -callbacks: - patience: 100 - tol: 0 - freeze_mask: None # path to .pt weight file if wanting to freeze to learned values -task: - task_name: PowerFlow -data: - baseMVA: 100 - mask_type: rnd # or determinstic - mask_ratio: 0.5 # for random masking only - mask_value: -1 - # use_learnable_mask: when true, replaces the fixed mask_value scalar with - # a per-feature learnable parameter (nn.Parameter) for each node/edge type. - # The transform still zeroes out masked positions, and the model injects the - # learned token additively in the forward pass. mask_value is ignored. - use_learnable_mask: false - # use_max_basemva (legacy): when true, baseMVA is computed as the absolute - # maximum of all non-zero Pd, Qd, Pg, Qg values. Superseded by - # basemva_mode; kept for backward compatibility. If basemva_mode is set - # explicitly, this flag is ignored. - use_max_basemva: true - # basemva_mode: controls how the per-unit base power (baseMVA) is computed - # during normalizer fitting. Choices: - # "percentile" - 95th percentile of non-zero |Pd|, |Qd|, |Pg|, |Qg| - # (per-generator Pg). Default when neither flag is set. - # "abs_max" - max of absolute non-zero values (equivalent to legacy - # use_max_basemva: true). - # "reference" - aggregates generator Pg onto buses (sum per bus), then - # takes the plain .max() over [Pd, Qd, Pg_agg, Qg]. - # This matches the homogeneous reference normalizer and - # produces identical feature scales when the underlying - # power system data is the same. 
- basemva_mode: reference - normalization: HeteroDataMVANormalizer - mask_ref_vm: false # match reference PF masking at REF buses - networks: - - case14_ieee - scenarios: - - 5000 - test_ratio: 0.1 - val_ratio: 0.1 - workers: 4 - posenc_RRWP: - enable: true - ksteps: 21 - cache: true - # topk: Number of highest-ranked neighbors to retain per node based on - # the L2 norm of their multi-step random walk probability vector. - # Original graph edges and self-loops are always retained regardless of - # rank. This provides sparse attention that interpolates between: - # topk=0 (or omitted): full RRWP — all N^2 pairs (dense attention) - # topk=K: each node attends to its K structurally most important - # neighbors plus the original graph edges - # For case14 (14 buses), topk=5 retains ~36% of all pairs + graph edges, - # providing global reach while being significantly sparser than full. - topk: 5 - posenc_RWSE: - enable: false - cache: true - kernel: - times: 21 -model: - # homo_compat: when true, the GritHeteroAdapter remaps the heterogeneous - # bus features (15D) and edge features (10D) to the 9D / 2D layout used by - # the homogeneous reference model: bus x = [Pd, Qd, Pg_agg, Qg, Vm, Va, - # PQ, PV, REF], edge = [Yft_r, Yft_i]. This forces input_dim=9 and - # edge_dim=2 regardless of what is set below. - homo_compat: true - # homo_compat_topology: when true, the ConvertToYBusTopology transform is - # applied before masking. It merges parallel branches into a single Y-bus - # admittance per bus pair, adds self-loop diagonal entries (Yff + GS + jBS), - # and outputs 2D [G, B] edge features. Implies homo_compat: true. - # Positional encodings (RWSE/RRWP) are computed after this transform so - # they reflect the final graph topology. - homo_compat_topology: true - # mask_ref_vm: when false, voltage magnitude (Vm) at the reference/slack bus - # is never masked, matching the convention of the homogeneous reference - # where the slack bus Vm is always known. 
When true, Vm at REF buses can - # be masked like any other bus. - mask_ref_vm: false - attention_head: 8 - dropout: 0.1 - # edge_dim must match the bus-bus edge feature count after transforms - # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) - edge_dim: 10 - hidden_size: 496 - # input_dim = bus feature count + aggregated PG (used by GRIT core FeatureEncoder) - input_dim: 16 - # Hetero adapter head dimensions - input_bus_dim: 16 - input_gen_dim: 6 - output_bus_dim: 6 # [VM, VA, PG, QG, PD, QD] - output_gen_dim: 0 # PG predicted at bus level; no per-generator head needed - num_layers: 7 - type: GRIT - act: relu - encoder: - node_encoder: true - edge_encoder: true - node_encoder_name: Linear - node_encoder_bn: true - edge_encoder_bn: true - posenc_RWSE: - # kernel.times is synced automatically from data.posenc_RWSE.kernel.times - pe_dim: 20 - raw_norm_type: batchnorm - gt: - layer_type: GritTransformer - # dim_hidden is synced automatically from model.hidden_size - layer_norm: false - batch_norm: true - update_e: true - attn_dropout: 0.2 - attn: - clamp: 5. - act: relu - # full_attn: false because the Top-K RRWP sparsification already - # determines the attention pattern. The RRWPLinearEdgeEncoder will - # merge the sparse RRWP edges with original graph edges without - # padding to a fully-connected graph. - full_attn: false - edge_enhance: true - O_e: true - norm_e: true - # norm_last_layer: controls whether the final transformer layer - # creates O_e (linear projection) and batch_norm parameters for edge - # features. When false, the last layer skips these to avoid DDP - # unused-parameter errors, since only node features feed into the - # output heads and the last layer's edge output is discarded. When - # true, all layers are identical (matching the reference architecture), - # but requires strategy: ddp_find_unused_parameters_true in DDP. 
- norm_last_layer: true - signed_sqrt: true - bn_momentum: 0.1 - bn_no_runner: false -optimizer: - beta1: 0.9 - beta2: 0.999 - learning_rate: 0.0001 - lr_decay: 0.7 - lr_patience: 10 -seed: 0 -training: - batch_size: 8 - epochs: 500 - num_nodes: 1 - loss_weights: - - 0.99 - - 0.01 - losses: - - PBE - - MaskedReconstructionMSE - loss_args: - - {} - - {} - accelerator: auto - devices: auto - strategy: auto -verbose: true diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index a19917ed..f406549c 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -170,22 +170,12 @@ def setup(self, stage: str): cache_dir = make_pe_cache_dir( dataset.processed_dir, "RRWP", self.args.data, ) - _use_admw = getattr( - self.args.data.posenc_RRWP, - "use_admittance_weights", False, - ) - _rm_sl = getattr( - self.args.data.posenc_RRWP, - "admittance_remove_self_loops", True, - ) pe_transform = CachedPosencTransform( pe_transform, cache_dir, cached_attrs=["rrwp", "log_deg", "deg"], cached_edge_type=("bus", "rrwp", "bus"), key_attr="topology", - use_admittance=_use_admw, - admittance_remove_self_loops=_rm_sl, ) if dataset.transform is None: dataset.transform = pe_transform From d46680f9629319bcd33bc0f06113cf9a66aa74b2 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:45 -0400 Subject: [PATCH 88/95] change encoder Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 3ad7545a..8e3a41e6 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -118,7 +118,7 @@ def __init__(self, dim_in, 
dim_inner, args): # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) if args.encoder.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dEdge(enc_dim_edge, 1e-5, 0.1) + self.edge_encoder_bn = BatchNorm1dNode(enc_dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): From f7af5bedbebd4c7b645706d144d9d0bac7d2b7de Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:45 -0400 Subject: [PATCH 89/95] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/callbacks.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index 18af3420..c2a92e15 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -106,28 +106,3 @@ def on_train_epoch_end(self, trainer, pl_module): model_path = os.path.join(model_dir, self.filename) torch.save(pl_module.state_dict(), model_path) -class FreezeMaskTokens(Callback): - """Inject pre-trained mask tokens and freeze them. - - Replaces nn.Parameter with a registered buffer so DDP does not expect - gradients for these tensors. - """ - - def __init__(self, mask_state_path: str): - super().__init__() - self.mask_state_path = mask_state_path - - def setup(self, trainer, pl_module, stage=None): - if stage != "fit": - return - saved = torch.load(self.mask_state_path, map_location="cpu") - model = pl_module.model - - for name in ("bus_mask_token", "edge_mask_token", "gen_mask_token"): - key = f"model.{name}" - if key in saved and hasattr(model, name): - tensor = saved[key] - # Remove the nn.Parameter and re-register as a buffer so DDP - # won't include it in gradient reduction. 
- delattr(model, name) - model.register_buffer(name, tensor) From 4c09970fc9f82efba3713c486637c405430d2974 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:16:45 -0400 Subject: [PATCH 90/95] linting Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 2 +- gridfm_graphkit/datasets/cached_transform.py | 17 ++++++++++++----- .../datasets/hetero_powergrid_datamodule.py | 8 ++++++-- gridfm_graphkit/training/callbacks.py | 1 - 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 522bf158..5ce232c3 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -24,7 +24,7 @@ data: # the L2 norm of their multi-step random walk probability vector. # Original graph edges and self-loops are always retained regardless of # rank. 
This provides sparse attention that interpolates between: - # topk=0 (or omitted): full RRWP — all N^2 pairs (dense attention) + # topk=0 (or omitted): full RRWP - all N^2 pairs (dense attention) # topk=K: each node attends to its K structurally most important # neighbors plus the original graph edges # For case14 (14 buses), topk=5 retains ~36% of all pairs + graph edges, diff --git a/gridfm_graphkit/datasets/cached_transform.py b/gridfm_graphkit/datasets/cached_transform.py index 87e7bef6..17f37878 100644 --- a/gridfm_graphkit/datasets/cached_transform.py +++ b/gridfm_graphkit/datasets/cached_transform.py @@ -7,8 +7,12 @@ from torch_geometric.utils import remove_self_loops -def _topology_fingerprint(data, cached_edge_type=None, use_admittance=False, - admittance_remove_self_loops=True): +def _topology_fingerprint( + data, + cached_edge_type=None, + use_admittance=False, + admittance_remove_self_loops=True, +): """Compute a short hash that uniquely identifies the graph topology. The fingerprint captures everything that determines the positional @@ -38,8 +42,11 @@ def _topology_fingerprint(data, cached_edge_type=None, use_admittance=False, edge_attr = getattr(data, "edge_attr", None) h = hashlib.sha256() - h.update(num_nodes.to_bytes(4, "little") if isinstance(num_nodes, int) - else int(num_nodes).to_bytes(4, "little")) + h.update( + num_nodes.to_bytes(4, "little") + if isinstance(num_nodes, int) + else int(num_nodes).to_bytes(4, "little"), + ) h.update(edge_index.cpu().numpy().tobytes()) # Include admittance values in the fingerprint when they affect the PE @@ -49,7 +56,7 @@ def _topology_fingerprint(data, cached_edge_type=None, use_admittance=False, else: # Yft at indices 4, 5 g, b = edge_attr[:, 4], edge_attr[:, 5] - edge_weight = torch.sqrt(g ** 2 + b ** 2) + edge_weight = torch.sqrt(g**2 + b**2) if admittance_remove_self_loops: _, edge_weight = remove_self_loops(edge_index, edge_weight) diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py 
b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index f406549c..2bf91a5d 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -168,7 +168,9 @@ def setup(self, stage: str): pe_transform = ComputePosencStat(pe_types=["RRWP"], cfg=self.args.data) if getattr(self.args.data.posenc_RRWP, "cache", False): cache_dir = make_pe_cache_dir( - dataset.processed_dir, "RRWP", self.args.data, + dataset.processed_dir, + "RRWP", + self.args.data, ) pe_transform = CachedPosencTransform( pe_transform, @@ -185,7 +187,9 @@ def setup(self, stage: str): pe_transform = ComputePosencStat(pe_types=["RWSE"], cfg=self.args.data) if getattr(self.args.data.posenc_RWSE, "cache", False): cache_dir = make_pe_cache_dir( - dataset.processed_dir, "RWSE", self.args.data, + dataset.processed_dir, + "RWSE", + self.args.data, ) pe_transform = CachedPosencTransform( pe_transform, diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index c2a92e15..333c9976 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -105,4 +105,3 @@ def on_train_epoch_end(self, trainer, pl_module): os.makedirs(model_dir, exist_ok=True) model_path = os.path.join(model_dir, self.filename) torch.save(pl_module.state_dict(), model_path) - From 91ce9f099b7817b693b612714c14a0a1e6022915 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 17 Jun 2026 14:11:47 -0400 Subject: [PATCH 91/95] settle merge conflicts Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 572cd54f..cb0d9f07 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -185,7 +185,7 @@ def forward( edge_attr_dict, 
mask_dict, model=None, - **kwargs, + x_dict=None, ): pred_bus = pred_dict["bus"] target_bus = target_dict["bus"] From c1706e75789a9c7881b5b38b3f28e70323d4d040 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 17 Jun 2026 14:11:48 -0400 Subject: [PATCH 92/95] adjust grit config Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 5ce232c3..fd53133b 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -76,7 +76,7 @@ model: # determines the attention pattern. The RRWPLinearEdgeEncoder will # merge the sparse RRWP edges with original graph edges without # padding to a fully-connected graph. - full_attn: true + full_attn: false # true only for full RRWP edge_enhance: true O_e: true norm_e: true From af76a9feeb60d4a013162611905db732fbd0dd0a Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 17 Jun 2026 14:11:48 -0400 Subject: [PATCH 93/95] adjust grit config Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index fd53133b..e7845147 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -7,7 +7,7 @@ data: baseMVA: 100 mask_type: rnd # or determinstic mask_ratio: 0.5 # for random masking only - mask_value: -1 + mask_value: 0 normalization: HeteroDataMVANormalizer networks: - case14_ieee From 34bafe9df491021eb81acf812d14d95055acf3ef Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst 
<99353435+ttolhurst@users.noreply.github.com> Date: Wed, 17 Jun 2026 14:11:48 -0400 Subject: [PATCH 94/95] clamp un-predicted values in predict Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/pf_task.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index 831db0e7..f41a0b55 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -433,7 +433,16 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] # from and to buses bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] # edge attributes (admittance) of the bus connections - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) # compute the branch flows + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], + target, + batch, + gen_to_bus_index, + num_bus, + ) + + Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) # compute the branch flows P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) # compute the node injections residual_P, residual_Q = node_residuals_layer( # compute the node residuals P_in, From 41d256713e30741a9159d1ddf8b974aa3da2f87a Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 17 Jun 2026 14:11:48 -0400 Subject: [PATCH 95/95] precommit checks Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- SECURITY.md | 12 +-- docs/quick_start/yaml_config.md | 4 +- gridfm_graphkit/__main__.py | 3 + gridfm_graphkit/cli.py | 21 ++-- .../datasets/hetero_powergrid_datamodule.py | 30 ++++-- gridfm_graphkit/datasets/masking.py | 2 + gridfm_graphkit/datasets/normalizers.py | 25 ++++- .../datasets/powergrid_hetero_dataset.py | 19 
+++- gridfm_graphkit/datasets/task_transforms.py | 3 + gridfm_graphkit/datasets/transforms.py | 1 + gridfm_graphkit/datasets/utils.py | 6 +- gridfm_graphkit/io/registries.py | 1 + gridfm_graphkit/models/utils.py | 3 + gridfm_graphkit/tasks/opf_ac_dc_baseline.py | 96 +++++++++++++------ gridfm_graphkit/tasks/opf_task.py | 37 ++++--- gridfm_graphkit/tasks/pf_task.py | 54 +++++++---- gridfm_graphkit/tasks/se_task.py | 1 + gridfm_graphkit/training/callbacks.py | 1 + gridfm_graphkit/training/loss.py | 27 +++--- integrationtests/conftest.py | 1 - integrationtests/generate_test_data.py | 18 +++- integrationtests/test_base_set.py | 51 +++++++--- 22 files changed, 290 insertions(+), 126 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index edc0d73a..aedac7f0 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -58,8 +58,8 @@ Users are strongly encouraged to upgrade to the latest release to receive securi We aim to follow these response targets: -- **Initial acknowledgment**: within 72 hours -- **Status update**: within 7 days +- **Initial acknowledgment**: within 72 hours +- **Status update**: within 7 days - **Resolution target**: within 90 days (depending on severity) These are targets, not guarantees. @@ -79,10 +79,10 @@ These are targets, not guarantees. 
We follow a **coordinated vulnerability disclosure (CVD)** process: -- We work with reporters to agree on a disclosure timeline -- Public disclosure occurs after a fix is available or mitigation exists -- Contributors are credited unless anonymity is requested -- CVE identifiers will be requested when appropriate +- We work with reporters to agree on a disclosure timeline +- Public disclosure occurs after a fix is available or mitigation exists +- Contributors are credited unless anonymity is requested +- CVE identifiers will be requested when appropriate --- diff --git a/docs/quick_start/yaml_config.md b/docs/quick_start/yaml_config.md index 96c2fa62..568d6f67 100644 --- a/docs/quick_start/yaml_config.md +++ b/docs/quick_start/yaml_config.md @@ -85,12 +85,12 @@ Task name registered in the framework: ### `data.networks` -List of dataset folders under your data root. +List of dataset folders under your data root. Examples: `case14_ieee`, `case118_ieee`, `case2000_goc`, `Texas2k_case1_2016summerpeak`. ### `data.scenarios` -List of scenario counts, one value per network in `data.networks`. +List of scenario counts, one value per network in `data.networks`. Example: with two networks, use two scenario entries in matching order. 
### `data.normalization` diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 695ba531..03166542 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -22,6 +22,7 @@ def _warn_mp_context_on_linux(mp_context): stacklevel=2, ) + def is_lsf(): return ( os.environ.get("LSB_JOBID") is not None @@ -29,6 +30,7 @@ def is_lsf(): and "LSF_ENVDIR" in os.environ # strong LSF indicator ) + def fix_infiniband(): """Configure NCCL to skip Ethernet-only IB ports on this host.""" ibv = subprocess.run("ibv_devinfo", stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -71,6 +73,7 @@ def set_env(): ) os.environ["NCCL_IB_CUDA_SUPPORT"] = "1" # Force use of infiniband + def main(): """Parse CLI arguments and dispatch to the selected GridFM subcommand.""" if is_lsf(): diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 53a970c8..d3256357 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -25,7 +25,9 @@ import lightning as L -def _normalize_loaded_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: +def _normalize_loaded_state_dict_keys( + state_dict: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: """Map legacy torch.compile checkpoint keys to the canonical model namespace.""" has_compiled_prefix = any(key.startswith("model._orig_mod.") for key in state_dict) if not has_compiled_prefix: @@ -163,7 +165,7 @@ def main_cli(args): run_name=args.run_name, ) - # When using torch.compile with Triton, dynamic graph support can cause + # When using torch.compile with Triton, dynamic graph support can cause # out-of-memory errors during autotuning on some kernels. # Disabling dynamic graph support allows those kernels # to be skipped gracefully instead of causing errors. 
@@ -238,12 +240,17 @@ def main_cli(args): _accelerator = config_args.training.accelerator _strategy = config_args.training.strategy # if mps is available and accelerator is auto, explicitely set accelerator to mps to select the right strategy in the next block - if _accelerator == "auto" and torch.backends.mps.is_available(): + if _accelerator == "auto" and torch.backends.mps.is_available(): _accelerator = "mps" - if _accelerator not in ("mps", "cpu") and isinstance(_strategy, str) and _strategy in ( - "auto", - "ddp", - ): # when using mps, we don't want to use ddp. + if ( + _accelerator not in ("mps", "cpu") + and isinstance(_strategy, str) + and _strategy + in ( + "auto", + "ddp", + ) + ): # when using mps, we don't want to use ddp. _strategy = DDPStrategy(find_unused_parameters=False) trainer = L.Trainer( diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 854e7391..6287b6c6 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -132,13 +132,18 @@ def __init__( self._is_setup_done = False if self.split_by_load_scenario_idx: - assert self.split_from_existing_files is None, " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" + assert self.split_from_existing_files is None, ( + " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" + ) if self.split_from_existing_files is not None: - assert isinstance(self.split_from_existing_files, str), "`split_from_existing_files` must be an existing folder in string format" + assert isinstance(self.split_from_existing_files, str), ( + "`split_from_existing_files` must be an existing folder in string format" + ) self.split_from_existing_files = Path(self.split_from_existing_files) - assert self.split_from_existing_files.is_dir(), "`split_from_existing_files` must be an existing folder in string 
format" - + assert self.split_from_existing_files.is_dir(), ( + "`split_from_existing_files` must be an existing folder in string format" + ) def setup(self, stage: str): if self._is_setup_done: @@ -233,7 +238,6 @@ def setup(self, stage: str): # Create a subset all_indices = list(range(len(dataset))) - if self.split_from_existing_files is not None: warnings.warn( "`data.scenarios` is ignored when `split_from_existing_files` is set; " @@ -278,13 +282,14 @@ def setup(self, stage: str): # load_scenario for each scenario in the subset load_scenarios = dataset.load_scenarios[subset_indices] - dataset = Subset(dataset, subset_indices) - + if self.dataset_wrapper is not None: wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) - dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) - + dataset = wrapper_cls( + dataset, + cache_dir=self.dataset_wrapper_cache_dir, + ) # Random seed set before every split, same as above np.random.seed(self.args.seed) @@ -479,7 +484,12 @@ def _dataloader_kwargs(self): return kwargs def train_dataloader(self): - print("creating train dataloader for rank ", dist.get_rank() if dist.is_available() and dist.is_initialized() else "not distributed") + print( + "creating train dataloader for rank ", + dist.get_rank() + if dist.is_available() and dist.is_initialized() + else "not distributed", + ) return DataLoader( self.train_dataset_multi, batch_size=self.batch_size, diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index da78e6c4..3b8f78fe 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -212,6 +212,7 @@ def forward(self, data): class BusToGenBroadcaster(MessagePassing): """Broadcast per-bus values to connected generators via graph propagation.""" + def __init__(self, aggr="add"): super().__init__(aggr=aggr) @@ -230,6 +231,7 @@ def message(self, x_j): class SimulateMeasurements(BaseTransform): """Add configurable noise/outliers and 
masks to simulate measured quantities.""" + def __init__(self, args): super().__init__() self.measurements = args.task.measurements diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 557e5cab..d53e24cd 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -230,8 +230,14 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= self.baseMVA - data.baseMVA = torch.tensor(self.baseMVA, dtype=data.x_dict["bus"].dtype) # # needs to be float32 for MPS - data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS + data.baseMVA = torch.tensor( + self.baseMVA, + dtype=data.x_dict["bus"].dtype, + ) # # needs to be float32 for MPS + data.is_normalized = torch.tensor( + True, + dtype=torch.bool, + ) # needs to be bool for MPS def inverse_transform(self, data: HeteroData): if self.baseMVA is None or self.baseMVA == 0: @@ -301,7 +307,10 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= self.baseMVA - data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = torch.tensor( + False, + dtype=torch.bool, + ) # needs to be bool for MPS def inverse_output(self, output, batch): bus_output = output["bus"] @@ -516,7 +525,10 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= e_b - data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS + 
data.is_normalized = torch.tensor( + True, + dtype=torch.bool, + ) # needs to be bool for MPS def inverse_transform(self, data: HeteroData): """Undo per-unit normalization (multiply by baseMVA, inverse log1p for cost coeffs).""" @@ -579,7 +591,10 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= e_b - data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = torch.tensor( + False, + dtype=torch.bool, + ) # needs to be bool for MPS def inverse_output(self, output, batch): """ diff --git a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py index 82f57a57..a0510ddb 100644 --- a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py +++ b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py @@ -73,9 +73,14 @@ def process(self): ) if "load_scenario_idx" in bus_data.columns: load_scenarios = torch.tensor( - bus_data.groupby("scenario", sort=True)["load_scenario_idx"].first().values, + bus_data.groupby("scenario", sort=True)["load_scenario_idx"] + .first() + .values, + ) + torch.save( + load_scenarios, + osp.join(self.processed_dir, "load_scenarios.pt"), ) - torch.save(load_scenarios, osp.join(self.processed_dir, "load_scenarios.pt")) agg_gen = ( gen_data.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] @@ -136,7 +141,9 @@ def process(self): ] + common_branch_features # Group by scenario - bus_groups = bus_data.groupby("scenario") # Groupby preserves the order of rows within each group. + bus_groups = bus_data.groupby( + "scenario", + ) # Groupby preserves the order of rows within each group. 
# https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.groupby.html gen_groups = gen_data.groupby("scenario") branch_groups = branch_data.groupby("scenario") @@ -159,8 +166,10 @@ def process(self): # Bus nodes bus_df = bus_groups.get_group(scenario) # assert that the buses are in increasing order - assert (bus_df["bus"].values == torch.arange(len(bus_df))).all(), "Buses are not in increasing order" - #todo: we should remove this assert and store the bus idx in the tensors + assert (bus_df["bus"].values == torch.arange(len(bus_df))).all(), ( + "Buses are not in increasing order" + ) + # todo: we should remove this assert and store the bus idx in the tensors # right now we need the increasing order for e.g. the predict step that uses torch.arange(n_nodes) to index the buses. data["bus"].x = torch.tensor(bus_df[bus_features].values, dtype=torch.float) diff --git a/gridfm_graphkit/datasets/task_transforms.py b/gridfm_graphkit/datasets/task_transforms.py index 6e75e92a..64b7e223 100644 --- a/gridfm_graphkit/datasets/task_transforms.py +++ b/gridfm_graphkit/datasets/task_transforms.py @@ -17,6 +17,7 @@ @TRANSFORM_REGISTRY.register("PowerFlow") class PowerFlowTransforms(Compose): """Compose preprocessing and masking transforms for PowerFlow datasets.""" + def __init__(self, args): transforms = [] @@ -38,6 +39,7 @@ def __init__(self, args): @TRANSFORM_REGISTRY.register("OptimalPowerFlow") class OptimalPowerFlowTransforms(Compose): """Compose preprocessing and masking transforms for OptimalPowerFlow datasets.""" + def __init__(self, args): transforms = [] @@ -59,6 +61,7 @@ def __init__(self, args): @TRANSFORM_REGISTRY.register("StateEstimation") class StateEstimationTransforms(Compose): """Compose preprocessing and measurement transforms for StateEstimation datasets.""" + def __init__(self, args): transforms = [] diff --git a/gridfm_graphkit/datasets/transforms.py b/gridfm_graphkit/datasets/transforms.py index c6891dc2..f58a1fec 100644 --- 
a/gridfm_graphkit/datasets/transforms.py +++ b/gridfm_graphkit/datasets/transforms.py @@ -97,6 +97,7 @@ def forward(self, data): class LoadGridParamsFromPath(BaseTransform): """Inject static grid parameters from a saved grid template into each sample.""" + def __init__(self, args): super().__init__() self.grid_path = args.task.grid_path diff --git a/gridfm_graphkit/datasets/utils.py b/gridfm_graphkit/datasets/utils.py index 65b34f4e..8d16de92 100644 --- a/gridfm_graphkit/datasets/utils.py +++ b/gridfm_graphkit/datasets/utils.py @@ -114,8 +114,8 @@ def split_from_existing_files( split_dataset = Subset(dataset, split_indices) output.append(split_dataset) split_indices = list(split_indices) - print(f'{split=} {len(split_indices)=}') - indices[split]=[int(t.item()) for t in split_indices] + print(f"{split=} {len(split_indices)=}") + indices[split] = [int(t.item()) for t in split_indices] output = tuple(output) - return output, indices \ No newline at end of file + return output, indices diff --git a/gridfm_graphkit/io/registries.py b/gridfm_graphkit/io/registries.py index 65d596a9..c26f07b9 100644 --- a/gridfm_graphkit/io/registries.py +++ b/gridfm_graphkit/io/registries.py @@ -1,5 +1,6 @@ class Registry: """Simple name-to-object registry with decorator-based registration.""" + def __init__(self, name: str): self._name = name self._registry = {} diff --git a/gridfm_graphkit/models/utils.py b/gridfm_graphkit/models/utils.py index bc4b9bfa..41b75f20 100644 --- a/gridfm_graphkit/models/utils.py +++ b/gridfm_graphkit/models/utils.py @@ -82,6 +82,7 @@ def compute_shunt_power(bus_data_pred, bus_data_orig): @PHYSICS_DECODER_REGISTRY.register("OptimalPowerFlow") class PhysicsDecoderOPF(nn.Module): """Map network outputs to OPF-consistent bus states using physics constraints.""" + def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): mask_pv = mask_dict["PV"] mask_ref = mask_dict["REF"] @@ -117,6 +118,7 @@ def forward(self, P_in, Q_in, bus_data_pred, 
bus_data_orig, agg_bus, mask_dict): @PHYSICS_DECODER_REGISTRY.register("PowerFlow") class PhysicsDecoderPF(nn.Module): """Map network outputs to PF-consistent bus states using physics constraints.""" + def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): """ PF decoder: @@ -165,6 +167,7 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): @PHYSICS_DECODER_REGISTRY.register("StateEstimation") class PhysicsDecoderSE(nn.Module): """Map network outputs to SE targets via bus power-balance relations.""" + def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): p_shunt, q_shunt = compute_shunt_power(bus_data_pred, bus_data_orig) Vm_out = bus_data_pred[:, VM_OUT] diff --git a/gridfm_graphkit/tasks/opf_ac_dc_baseline.py b/gridfm_graphkit/tasks/opf_ac_dc_baseline.py index a13a5312..601a4e71 100644 --- a/gridfm_graphkit/tasks/opf_ac_dc_baseline.py +++ b/gridfm_graphkit/tasks/opf_ac_dc_baseline.py @@ -28,7 +28,6 @@ ) from gridfm_graphkit.tasks.pf_ac_dc_baseline import ( N_SCENARIO_PER_PARTITION, - NUM_PROCESSES, _compute_residual_stats, _compute_runtime_stats, ) @@ -65,7 +64,9 @@ def _load_test_data(data_dir: str, test_scenario_ids: list[int]): bus_df = bus_df[bus_df["scenario"].isin(test_set)].reset_index(drop=True) gen_df = gen_df[gen_df["scenario"].isin(test_set)].reset_index(drop=True) branch_df = branch_df[branch_df["scenario"].isin(test_set)].reset_index(drop=True) - runtime_df = runtime_df[runtime_df["scenario"].isin(test_set)].reset_index(drop=True) + runtime_df = runtime_df[runtime_df["scenario"].isin(test_set)].reset_index( + drop=True, + ) print( f" Loaded {len(bus_df)} bus rows, {len(gen_df)} gen rows, " @@ -85,8 +86,12 @@ def _compute_optimality_gap(gen_df: pd.DataFrame) -> dict: pg_ac = gen_df["p_mw"].to_numpy(dtype=float) pg_dc = gen_df["p_mw_dc"].to_numpy(dtype=float) g = gen_df.copy() - g["cost_ac"] = (c0 + c1 * pg_ac + c2 * pg_ac * pg_ac) * g["in_service"] # all is already in 
MW - g["cost_dc"] = (c0 + c1 * pg_dc + c2 * pg_dc * pg_dc) * g["in_service"] # all is already in MW + g["cost_ac"] = (c0 + c1 * pg_ac + c2 * pg_ac * pg_ac) * g[ + "in_service" + ] # all is already in MW + g["cost_dc"] = (c0 + c1 * pg_dc + c2 * pg_dc * pg_dc) * g[ + "in_service" + ] # all is already in MW per_scenario = g.groupby("scenario")[["cost_ac", "cost_dc"]].sum() cost_ac = per_scenario["cost_ac"].to_numpy(dtype=float) cost_dc = per_scenario["cost_dc"].to_numpy(dtype=float) @@ -113,26 +118,44 @@ def _compute_pg_violations(gen_df: pd.DataFrame) -> dict: def _compute_qg_violations_ac(bus_df: pd.DataFrame, gen_df: pd.DataFrame) -> dict: """Compute AC reactive-power limit violations for PV/REF buses.""" - # opf_task style on bus Qg; AC only + # opf_task style on bus Qg; AC only bus = bus_df.copy() qg = bus["Qg"].to_numpy(dtype=float) agg_gen = ( - gen_df.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] - .sum() - .reset_index()) + gen_df.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] + .sum() + .reset_index() + ) bus = bus.merge(agg_gen, on=["scenario", "bus"], how="left") - assert bus[bus["PV"]==1]["min_q_mvar"].isna().sum() == 0, "PV buses have no min_q_mvar" - assert bus[bus["PV"]==1]["max_q_mvar"].isna().sum() == 0, "PV buses have no max_q_mvar" - assert bus[bus["REF"]==1]["min_q_mvar"].isna().sum() == 0, "REF buses have no min_q_mvar" - assert bus[bus["REF"]==1]["max_q_mvar"].isna().sum() == 0, "REF buses have no max_q_mvar" - bus["qg_violation_amount"] = np.maximum(qg - bus["max_q_mvar"], 0.0) + np.maximum(bus["min_q_mvar"] - qg, 0.0) + assert bus[bus["PV"] == 1]["min_q_mvar"].isna().sum() == 0, ( + "PV buses have no min_q_mvar" + ) + assert bus[bus["PV"] == 1]["max_q_mvar"].isna().sum() == 0, ( + "PV buses have no max_q_mvar" + ) + assert bus[bus["REF"] == 1]["min_q_mvar"].isna().sum() == 0, ( + "REF buses have no min_q_mvar" + ) + assert bus[bus["REF"] == 1]["max_q_mvar"].isna().sum() == 0, ( + "REF buses have no max_q_mvar" + ) + 
bus["qg_violation_amount"] = np.maximum(qg - bus["max_q_mvar"], 0.0) + np.maximum( + bus["min_q_mvar"] - qg, + 0.0, + ) pv = bus[bus["PV"] == 1] ref = bus[bus["REF"] == 1] pv_ref = bus[(bus["PV"] == 1) | (bus["REF"] == 1)] return { - "AC Mean Qg violation PV buses": float(np.nanmean(pv["qg_violation_amount"].to_numpy(dtype=float))), - "AC Mean Qg violation REF buses": float(np.nanmean(ref["qg_violation_amount"].to_numpy(dtype=float))), - "AC Mean Qg violation": float(np.nanmean(pv_ref["qg_violation_amount"].to_numpy(dtype=float))), + "AC Mean Qg violation PV buses": float( + np.nanmean(pv["qg_violation_amount"].to_numpy(dtype=float)), + ), + "AC Mean Qg violation REF buses": float( + np.nanmean(ref["qg_violation_amount"].to_numpy(dtype=float)), + ), + "AC Mean Qg violation": float( + np.nanmean(pv_ref["qg_violation_amount"].to_numpy(dtype=float)), + ), } @@ -140,13 +163,21 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> """Compute AC/DC branch thermal and angle-limit violation statistics.""" rate = branch_df["rate_a"].to_numpy(dtype=float) ac_from = np.sqrt( - branch_df["pf"].to_numpy(dtype=float) ** 2 + branch_df["qf"].to_numpy(dtype=float) ** 2, + branch_df["pf"].to_numpy(dtype=float) ** 2 + + branch_df["qf"].to_numpy(dtype=float) ** 2, ) ac_to = np.sqrt( - branch_df["pt"].to_numpy(dtype=float) ** 2 + branch_df["qt"].to_numpy(dtype=float) ** 2, + branch_df["pt"].to_numpy(dtype=float) ** 2 + + branch_df["qt"].to_numpy(dtype=float) ** 2, + ) + dc_from = np.sqrt( + branch_df["pf_dc_computed"].to_numpy(dtype=float) ** 2 + + branch_df["qf_dc_computed"].to_numpy(dtype=float) ** 2, + ) # reactive part is needed here + dc_to = np.sqrt( + branch_df["pt_dc_computed"].to_numpy(dtype=float) ** 2 + + branch_df["qt_dc_computed"].to_numpy(dtype=float) ** 2, ) - dc_from = np.sqrt(branch_df["pf_dc_computed"].to_numpy(dtype=float) ** 2 + branch_df["qf_dc_computed"].to_numpy(dtype=float) ** 2) # reactive part is needed here - dc_to = 
np.sqrt(branch_df["pt_dc_computed"].to_numpy(dtype=float) ** 2 + branch_df["qt_dc_computed"].to_numpy(dtype=float) ** 2) ac_thermal_from = np.maximum(ac_from - rate, 0.0) ac_thermal_to = np.maximum(ac_to - rate, 0.0) @@ -155,7 +186,7 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> bus_angles = bus_df[["scenario", "bus", "Va", "Va_dc"]] # convert to radians - bus_angles.loc[:, "Va"] = bus_angles["Va"] * np.pi / 180.0 + bus_angles.loc[:, "Va"] = bus_angles["Va"] * np.pi / 180.0 bus_angles.loc[:, "Va_dc"] = bus_angles["Va_dc"] * np.pi / 180.0 from_angles = bus_angles.rename( columns={"bus": "from_bus", "Va": "Va_from", "Va_dc": "Va_dc_from"}, @@ -165,10 +196,12 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> ) br = branch_df.merge(from_angles, on=["scenario", "from_bus"], how="left") br = br.merge(to_angles, on=["scenario", "to_bus"], how="left") - + # AC angle ac_angle_diff = br["Va_from"] - br["Va_to"] - ac_angle_diff = (ac_angle_diff + np.pi) % (2 * np.pi) - np.pi # wrap to [-pi, pi] + ac_angle_diff = (ac_angle_diff + np.pi) % ( + 2 * np.pi + ) - np.pi # wrap to [-pi, pi] ac_angle_excess_low = np.maximum(br["ang_min"] - ac_angle_diff, 0.0) ac_angle_excess_high = np.maximum(ac_angle_diff - br["ang_max"], 0.0) mean_ac_angle_violation = np.mean(ac_angle_excess_low + ac_angle_excess_high) @@ -180,12 +213,20 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> mean_dc_angle_violation = np.mean(dc_angle_excess_low + dc_angle_excess_high) return { - "AC Mean branch thermal violation from (MVA)": float(np.nanmean(ac_thermal_from)), + "AC Mean branch thermal violation from (MVA)": float( + np.nanmean(ac_thermal_from), + ), "AC Mean branch thermal violation to (MVA)": float(np.nanmean(ac_thermal_to)), - "AC Mean branch angle difference violation (radians)": float(mean_ac_angle_violation), - "DC Mean branch thermal violation from (MVA)": float(np.nanmean(dc_thermal_from)), + 
"AC Mean branch angle difference violation (radians)": float( + mean_ac_angle_violation, + ), + "DC Mean branch thermal violation from (MVA)": float( + np.nanmean(dc_thermal_from), + ), "DC Mean branch thermal violation to (MVA)": float(np.nanmean(dc_thermal_to)), - "DC Mean branch angle difference violation (radians)": float(mean_dc_angle_violation), + "DC Mean branch angle difference violation (radians)": float( + mean_dc_angle_violation, + ), } @@ -256,7 +297,6 @@ def compute_opf_ac_dc_metrics( branch_df["pt_dc_computed"] = pt_dc branch_df["qf_dc_computed"] = qf_dc branch_df["qt_dc_computed"] = qt_dc - opf_extra = {} opf_extra.update(_compute_optimality_gap(gen_df)) diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index cc245ebd..fe517deb 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -87,14 +87,18 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): c2 = batch.x_dict["gen"][:, C2_H] target_pg = batch.y_dict["gen"].squeeze() pred_pg = output["gen"].squeeze() - gen_cost_gt = (c0 + c1 * target_pg + c2 * target_pg**2) # assumes all branches are on! - gen_cost_pred = (c0 + c1 * pred_pg + c2 * pred_pg**2) # assumes all branches are on! + gen_cost_gt = ( + c0 + c1 * target_pg + c2 * target_pg**2 + ) # assumes all branches are on! + gen_cost_pred = ( + c0 + c1 * pred_pg + c2 * pred_pg**2 + ) # assumes all branches are on! 
gen_batch = batch.batch_dict["gen"] # shape: [N_gen_total] cost_gt = scatter_add(gen_cost_gt, gen_batch, dim=0) cost_pred = scatter_add(gen_cost_pred, gen_batch, dim=0) - + optimality_gap = torch.mean(torch.abs((cost_pred - cost_gt) / cost_gt * 100)) agg_gen_on_bus = scatter_add( @@ -138,14 +142,16 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): bus_angles = output["bus"][:, VA_OUT] # in degrees from_bus = bus_edge_index[0] to_bus = bus_edge_index[1] - angle_diff = bus_angles[from_bus] - bus_angles[to_bus] # keep sign - angle_diff = (angle_diff + torch.pi) % (2 * torch.pi) - torch.pi # wrap to [-pi, pi] + angle_diff = bus_angles[from_bus] - bus_angles[to_bus] # keep sign + angle_diff = (angle_diff + torch.pi) % ( + 2 * torch.pi + ) - torch.pi # wrap to [-pi, pi] angle_excess_low = F.relu(angle_min - angle_diff) angle_excess_high = F.relu(angle_diff - angle_max) branch_angle_violation_mean = torch.mean( - angle_excess_low + angle_excess_high - ) # mean of the abs violation + angle_excess_low + angle_excess_high, + ) # mean of the abs violation P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( @@ -174,8 +180,8 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): mean_Qg_violation_PV = Qg_violation_amount[mask_PV].mean() mean_Qg_violation_REF = Qg_violation_amount[mask_REF].mean() - mask_PV_REF = mask_PV | mask_REF # PV or REF buses - mean_Qg_violation = Qg_violation_amount[mask_PV_REF].mean() # + mask_PV_REF = mask_PV | mask_REF # PV or REF buses + mean_Qg_violation = Qg_violation_amount[mask_PV_REF].mean() # if self.args.verbose: mean_res_P_PQ, max_res_P_PQ = residual_stats_by_type( @@ -275,7 +281,9 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Branch voltage angle difference violations"] = ( branch_angle_violation_mean ) - loss_dict["Mean Qg violation PV buses"] = mean_Qg_violation_PV # mean of the abs violation over the entire batch (all oines in the 
batch). + loss_dict["Mean Qg violation PV buses"] = ( + mean_Qg_violation_PV # mean of the abs violation over the entire batch (all oines in the batch). + ) # this is then overaged over all the batches and gives same weight to all batches despite them possibly having varying number of branches loss_dict["Mean Qg violation REF buses"] = mean_Qg_violation_REF loss_dict["Mean Qg violation"] = mean_Qg_violation @@ -377,7 +385,10 @@ def on_test_end(self): "Branch thermal violation from", " ", ) - branch_thermal_violation_to = metrics.get("Branch thermal violation to", " ") + branch_thermal_violation_to = metrics.get( + "Branch thermal violation to", + " ", + ) branch_angle_violation = metrics.get( "Branch voltage angle difference violations", " ", @@ -548,9 +559,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): local_bus_idx = torch.cat( [ torch.arange(c, device=bus_batch.device) - for c in torch.bincount(bus_batch) + for c in torch.bincount(bus_batch) ], - ) # this works because the order of the buses is preserved by the groupby in the dataset wrapper and datakit data has buses in increasing order. + ) # this works because the order of the buses is preserved by the groupby in the dataset wrapper and datakit data has buses in increasing order. 
bus_x = batch.x_dict["bus"] bus_y = batch.y_dict["bus"] diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index f41a0b55..a60e61d9 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -309,7 +309,7 @@ def on_test_end(self): # Only rank 0 proceeds with logging, CSV writing, and plotting if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: - self.test_outputs.clear() # clear the test outputs for other ranks + self.test_outputs.clear() # clear the test outputs for other ranks return if isinstance(self.logger, MLFlowLogger): @@ -420,18 +420,31 @@ def on_test_end(self): self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): - output, _ = self.shared_step(batch) # get the predicted output from the model + output, _ = self.shared_step(batch) # get the predicted output from the model - self.data_normalizers[dataloader_idx].inverse_transform(batch) # normalize the batch data back to the original scale - self.data_normalizers[dataloader_idx].inverse_output(output, batch) # inverse transform the predicted output back to the original scale - - branch_flow_layer = ComputeBranchFlow() # layer to compute the branch flows - node_injection_layer = ComputeNodeInjection() # layer to compute the node injections - node_residuals_layer = ComputeNodeResiduals() # layer to compute the node residuals - - num_bus = batch.x_dict["bus"].size(0) # number of buses in the batch - bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] # from and to buses - bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] # edge attributes (admittance) of the bus connections + self.data_normalizers[dataloader_idx].inverse_transform( + batch, + ) # normalize the batch data back to the original scale + self.data_normalizers[dataloader_idx].inverse_output( + output, + batch, + ) # inverse transform the predicted output back to the original scale + + branch_flow_layer = 
ComputeBranchFlow() # layer to compute the branch flows + node_injection_layer = ( + ComputeNodeInjection() + ) # layer to compute the node injections + node_residuals_layer = ( + ComputeNodeResiduals() + ) # layer to compute the node residuals + + num_bus = batch.x_dict["bus"].size(0) # number of buses in the batch + bus_edge_index = batch.edge_index_dict[ + ("bus", "connects", "bus") + ] # from and to buses + bus_edge_attr = batch.edge_attr_dict[ + ("bus", "connects", "bus") + ] # edge attributes (admittance) of the bus connections target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) eval_bus = _clamp_known_to_ground_truth( @@ -442,9 +455,18 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): num_bus, ) - Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) # compute the branch flows - P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) # compute the node injections - residual_P, residual_Q = node_residuals_layer( # compute the node residuals + Pft, Qft = branch_flow_layer( + eval_bus, + bus_edge_index, + bus_edge_attr, + ) # compute the branch flows + P_in, Q_in = node_injection_layer( + Pft, + Qft, + bus_edge_index, + num_bus, + ) # compute the node injections + residual_P, residual_Q = node_residuals_layer( # compute the node residuals P_in, Q_in, eval_bus, @@ -461,7 +483,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): torch.arange(c, device=bus_batch.device) for c in torch.bincount(bus_batch) ], - ) # this is based on the assumptions that the buses within a graph are ordered and indexed as 0 ... n_nodes-1. + ) # this is based on the assumptions that the buses within a graph are ordered and indexed as 0 ... n_nodes-1. # todo: we should remove this assert and store the bus idx in the tensors # right now we need the increasing order and we have an assert in the dataset to check it. 
bus_x = batch.x_dict["bus"] diff --git a/gridfm_graphkit/tasks/se_task.py b/gridfm_graphkit/tasks/se_task.py index 36667ad2..78aa0fa5 100644 --- a/gridfm_graphkit/tasks/se_task.py +++ b/gridfm_graphkit/tasks/se_task.py @@ -27,6 +27,7 @@ @TASK_REGISTRY.register("StateEstimation") class StateEstimationTask(ReconstructionTask): """State-estimation task with evaluation plots for masked and noisy measurements.""" + def __init__(self, args, data_normalizers): super().__init__(args, data_normalizers) diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index ba7a4049..116aa4f0 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -42,6 +42,7 @@ def last_epoch_iters_per_sec(self) -> float | None: class SaveBestModelStateDict(Callback): """Persist the best model state_dict according to a monitored validation metric.""" + def __init__( self, monitor: str, diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index cb0d9f07..acbceb78 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -30,7 +30,7 @@ YFT_TF_R, YFT_TF_I, # Qg Limits - MIN_QG_H, + MIN_QG_H, MAX_QG_H, ) @@ -95,6 +95,7 @@ def forward( @LOSS_REGISTRY.register("MaskedGenMSE") class MaskedGenMSE(torch.nn.Module): """Compute MSE on generator targets restricted to generator mask entries.""" + def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -123,6 +124,7 @@ def forward( @LOSS_REGISTRY.register("MaskedBusMSE") class MaskedBusMSE(torch.nn.Module): """Compute MSE on selected bus targets, respecting task-specific output columns.""" + def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -340,6 +342,7 @@ def forward( @LOSS_REGISTRY.register("LayeredWeightedPhysics") class LayeredWeightedPhysicsLoss(BaseLoss): """Combine intermediate physics residuals using normalized geometric weights.""" + def __init__(self, loss_args, 
args) -> None: super().__init__() self.base_weight = loss_args.base_weight @@ -381,6 +384,7 @@ def forward( @LOSS_REGISTRY.register("LossPerDim") class LossPerDim(BaseLoss): """Compute MAE/MSE for one named physical dimension of bus outputs.""" + def __init__(self, loss_args, args): super(LossPerDim, self).__init__() self.reduction = "mean" @@ -630,6 +634,8 @@ def forward( torch.imag(S_net - S_injection), ) return result + + @LOSS_REGISTRY.register("QgViolationPenalty") class QgViolationPenaltyLoss(BaseLoss): """Standard Mean Squared Error loss.""" @@ -652,12 +658,8 @@ def forward( Qg_max = x_dict["bus"][:, MAX_QG_H] Qg_min = x_dict["bus"][:, MIN_QG_H] - max_penalty_mask = (Qg_pred > Qg_max) - min_penalty_mask = (Qg_pred < Qg_min) - - mask_PQ = mask["PQ"] # PQ buses - mask_PV = mask["PV"] # PV buses - mask_REF = mask["REF"] # Reference buses + max_penalty_mask = Qg_pred > Qg_max + min_penalty_mask = Qg_pred < Qg_min loss = 0.0 # where there are violations, compute penalty loss @@ -666,19 +668,18 @@ def forward( Qg_over = Qg_over[max_penalty_mask].mean() Qg_under = Qg_under[min_penalty_mask].mean() - - if Qg_over!=Qg_over: # replacing nan with 0 + + if Qg_over != Qg_over: # replacing nan with 0 Qg_over = 0.0 - if Qg_under!=Qg_under: # replacing nan with 0 + if Qg_under != Qg_under: # replacing nan with 0 Qg_under = 0.0 - penalty_loss = Qg_over + Qg_under + penalty_loss = Qg_over + Qg_under loss += penalty_loss try: output = {"loss": loss, "Qg Violation Penalty loss": loss.detach()} - except: + except Exception: output = {"loss": loss, "Qg Violation Penalty loss": loss} return output - diff --git a/integrationtests/conftest.py b/integrationtests/conftest.py index 01737f5e..75e3debb 100644 --- a/integrationtests/conftest.py +++ b/integrationtests/conftest.py @@ -28,4 +28,3 @@ def calibrate_runs(request): def ci_level(request): """Confidence interval level requested via --ci (default 0.995).""" return request.config.getoption("--ci") - diff --git 
a/integrationtests/generate_test_data.py b/integrationtests/generate_test_data.py index fba2624b..c205d03d 100644 --- a/integrationtests/generate_test_data.py +++ b/integrationtests/generate_test_data.py @@ -29,7 +29,9 @@ def _base_config() -> dict: return config -def generate_pf_test_data(config_path: str = "integrationtests/default_pf.yaml") -> None: +def generate_pf_test_data( + config_path: str = "integrationtests/default_pf.yaml", +) -> None: """ Generate power-flow (PF) test data for case14_ieee with 10 000 scenarios and 2 topology variants. @@ -42,12 +44,16 @@ def generate_pf_test_data(config_path: str = "integrationtests/default_pf.yaml") print(f"PF config written to {config_path}") print(f" network.name : {config['network']['name']}") print(f" load.scenarios : {config['load']['scenarios']}") - print(f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}") + print( + f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}", + ) execute_and_live_output(f"gridfm_datakit generate {config_path}") -def generate_opf_test_data(config_path: str = "integrationtests/default_opf.yaml") -> None: +def generate_opf_test_data( + config_path: str = "integrationtests/default_opf.yaml", +) -> None: """ Generate optimal power-flow (OPF) test data for case14_ieee with 10 000 scenarios and 2 topology variants. 
@@ -61,12 +67,14 @@ def generate_opf_test_data(config_path: str = "integrationtests/default_opf.yaml print(f"OPF config written to {config_path}") print(f" network.name : {config['network']['name']}") print(f" load.scenarios : {config['load']['scenarios']}") - print(f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}") + print( + f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}", + ) print(f" settings.mode : {config['settings']['mode']}") execute_and_live_output(f"gridfm_datakit generate {config_path}") if __name__ == "__main__": - #generate_pf_test_data() + # generate_pf_test_data() generate_opf_test_data() diff --git a/integrationtests/test_base_set.py b/integrationtests/test_base_set.py index 2dfe4d8f..e4ec862b 100644 --- a/integrationtests/test_base_set.py +++ b/integrationtests/test_base_set.py @@ -24,13 +24,22 @@ def collect_metrics_from_log(log_base: str, metric_keys: list) -> dict: run_dirs = glob.glob(os.path.join(latest_exp_dir, "*")) assert len(run_dirs) > 0, f"No run directories found in {latest_exp_dir}" latest_run_dir = max(run_dirs, key=os.path.getmtime) - metrics_file = os.path.join(latest_run_dir, "artifacts", "test", "case14_ieee_metrics.csv") + metrics_file = os.path.join( + latest_run_dir, + "artifacts", + "test", + "case14_ieee_metrics.csv", + ) assert os.path.exists(metrics_file), f"Metrics file not found: {metrics_file}" df = pd.read_csv(metrics_file) return dict(zip(df["Metric"], df["Value"].astype(float))) -def print_calibration_stats(all_runs: list, metric_keys: list, confidence_interval: float = 0.995) -> None: +def print_calibration_stats( + all_runs: list, + metric_keys: list, + confidence_interval: float = 0.995, +) -> None: """ Print per-metric stats across calibration runs: - std with Bessel's correction (ddof=1) @@ -49,7 +58,9 @@ def print_calibration_stats(all_runs: list, metric_keys: list, confidence_interv ci_pct = 
f"{confidence_interval * 100:g}" col_w = max(len(k) for k in metric_keys) + 2 header = f" {'Metric':<{col_w}} {'Mean':>10} {'Std(ddof=1)':>12} {f'CI {ci_pct}% lo':>10} {f'CI {ci_pct}% hi':>10}" - print(f"\n===== Calibration Results (n={n}, CI={confidence_interval}, t_crit={t_crit:.4f}) =====") + print( + f"\n===== Calibration Results (n={n}, CI={confidence_interval}, t_crit={t_crit:.4f}) =====", + ) print(header) print(" " + "-" * (len(header) - 2)) for key in metric_keys: @@ -63,7 +74,7 @@ def print_calibration_stats(all_runs: list, metric_keys: list, confidence_interv me = t_crit * std / np.sqrt(len(arr)) # margin of error lo, hi = mean - me, mean + me print( - f" {key:<{col_w}} {mean:>10.4f} {std:>12.4f} {lo:>10.4f} {hi:>10.4f}" + f" {key:<{col_w}} {mean:>10.4f} {std:>12.4f} {lo:>10.4f} {hi:>10.4f}", ) print("=" * (len(header)) + "\n") @@ -88,7 +99,9 @@ def prepare_training_config(): with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) - print(f"Training config updated: epochs set to {config['training']['epochs']}, hidden_size set to {config['model']['hidden_size']}") + print( + f"Training config updated: epochs set to {config['training']['epochs']}, hidden_size set to {config['model']['hidden_size']}", + ) return config_path @@ -113,7 +126,9 @@ def prepare_opf_training_config(): with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) - print(f"OPF training config updated: epochs set to {config['training']['epochs']}, hidden_size set to {config['model']['hidden_size']}") + print( + f"OPF training config updated: epochs set to {config['training']['epochs']}, hidden_size set to {config['model']['hidden_size']}", + ) return config_path @@ -207,7 +222,9 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): last_error = None for attempt in range(1, MAX_RETRIES + 1): if attempt > 1: - print(f"\n--- PF Retry attempt {attempt}/{MAX_RETRIES} after metric interval 
failure ---") + print( + f"\n--- PF Retry attempt {attempt}/{MAX_RETRIES} after metric interval failure ---", + ) execute_and_live_output( f"gridfm_graphkit train " f"--config {training_config_path} " @@ -225,7 +242,9 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): assert 0.2042 <= pbe_mean_value <= 0.6397, ( f"PBE Mean value {pbe_mean_value} is outside 95% CI [0.2042, 0.6397]" ) - print(f"PBE Mean value {pbe_mean_value} is within 95% CI [0.2042, 0.6397] (attempt {attempt})") + print( + f"PBE Mean value {pbe_mean_value} is within 95% CI [0.2042, 0.6397] (attempt {attempt})", + ) last_error = None break except AssertionError as e: @@ -278,7 +297,9 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): opf_data_dir = "data_out_opf" if not os.path.exists(opf_data_dir) or not os.listdir(opf_data_dir): - print("OPF data directory not found or empty, downloading pre-generated data...") + print( + "OPF data directory not found or empty, downloading pre-generated data...", + ) gdrive_file_id = "1p5f5mRvmBQh8lZpIyWWbTbU42aHAIsdT" # pragma: allowlist secret zip_filename = "case14_ieee.10000_scenarios_2_variants_opf.zip" @@ -335,7 +356,9 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): last_error = None for attempt in range(1, MAX_RETRIES + 1): if attempt > 1: - print(f"\n--- OPF Retry attempt {attempt}/{MAX_RETRIES} after metric interval failure ---") + print( + f"\n--- OPF Retry attempt {attempt}/{MAX_RETRIES} after metric interval failure ---", + ) execute_and_live_output( f"gridfm_graphkit train " f"--config {training_config_path} " @@ -350,12 +373,16 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): try: for metric_name, (lo, hi) in checks.items(): - assert metric_name in metrics, f"Metric '{metric_name}' not found in CSV" + assert metric_name in metrics, ( + f"Metric '{metric_name}' not found in CSV" + ) value = metrics[metric_name] assert lo <= value <= hi, ( f"Metric 
'{metric_name}' value {value} is outside 99.5% CI [{lo}, {hi}]" ) - print(f"{metric_name}: {value} is within 99.5% CI [{lo}, {hi}] (attempt {attempt})") + print( + f"{metric_name}: {value} is within 99.5% CI [{lo}, {hi}] (attempt {attempt})", + ) last_error = None break except AssertionError as e: