diff --git a/README.md b/README.md index 7c9fdd02..a4dfc9fb 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' Install the correct torch-scatter wheel ```bash pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html +pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html ``` diff --git a/SECURITY.md b/SECURITY.md index edc0d73a..aedac7f0 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -58,8 +58,8 @@ Users are strongly encouraged to upgrade to the latest release to receive securi We aim to follow these response targets: -- **Initial acknowledgment**: within 72 hours -- **Status update**: within 7 days +- **Initial acknowledgment**: within 72 hours +- **Status update**: within 7 days - **Resolution target**: within 90 days (depending on severity) These are targets, not guarantees. @@ -79,10 +79,10 @@ These are targets, not guarantees. We follow a **coordinated vulnerability disclosure (CVD)** process: -- We work with reporters to agree on a disclosure timeline -- Public disclosure occurs after a fix is available or mitigation exists -- Contributors are credited unless anonymity is requested -- CVE identifiers will be requested when appropriate +- We work with reporters to agree on a disclosure timeline +- Public disclosure occurs after a fix is available or mitigation exists +- Contributors are credited unless anonymity is requested +- CVE identifiers will be requested when appropriate --- diff --git a/docs/quick_start/yaml_config.md b/docs/quick_start/yaml_config.md index 96c2fa62..568d6f67 100644 --- a/docs/quick_start/yaml_config.md +++ b/docs/quick_start/yaml_config.md @@ -85,12 +85,12 @@ Task name registered in the framework: ### `data.networks` -List of dataset folders under your data root. +List of dataset folders under your data root. 
Examples: `case14_ieee`, `case118_ieee`, `case2000_goc`, `Texas2k_case1_2016summerpeak`. ### `data.scenarios` -List of scenario counts, one value per network in `data.networks`. +List of scenario counts, one value per network in `data.networks`. Example: with two networks, use two scenario entries in matching order. ### `data.normalization` diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml new file mode 100644 index 00000000..e7845147 --- /dev/null +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -0,0 +1,108 @@ +callbacks: + patience: 100 + tol: 0 +task: + task_name: PowerFlow +data: + baseMVA: 100 + mask_type: rnd # or deterministic + mask_ratio: 0.5 # for random masking only + mask_value: 0 + normalization: HeteroDataMVANormalizer + networks: + - case14_ieee + scenarios: + - 5000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 + posenc_RRWP: + enable: false + ksteps: 21 + cache: true + # topk: Number of highest-ranked neighbors to retain per node based on + # the L2 norm of their multi-step random walk probability vector. + # Original graph edges and self-loops are always retained regardless of + # rank. This provides sparse attention that interpolates between: + # topk=0 (or omitted): full RRWP - all N^2 pairs (dense attention) + # topk=K: each node attends to its K structurally most important + # neighbors plus the original graph edges + # For case14 (14 buses), topk=5 retains ~36% of all pairs + graph edges, + # providing global reach while being significantly sparser than full. 
+ topk: 5 + posenc_RWSE: + enable: true + cache: true # cache the computed positional encoding at first pass (if false computes every access) + kernel: + times: 21 +model: + attention_head: 8 + dropout: 0.1 + # edge_dim must match the bus-bus edge feature count after transforms + # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) + edge_dim: 10 + hidden_size: 496 + # input_dim = bus feature count + aggregated PG (used by GRIT core FeatureEncoder) + input_dim: 16 + # Hetero adapter head dimensions + input_bus_dim: 16 + input_gen_dim: 6 + output_bus_dim: 6 # [VM, VA, PG, QG, PD, QD] + output_gen_dim: 0 # PG predicted at bus level; no per-generator head needed + num_layers: 7 + type: GRIT + act: relu + encoder: + node_encoder: true + edge_encoder: true + node_encoder_name: RWSE + node_encoder_bn: true + edge_encoder_bn: true + posenc_RWSE: + # kernel.times is synced automatically from data.posenc_RWSE.kernel.times + pe_dim: 20 + raw_norm_type: batchnorm + gt: + layer_type: GritTransformer + # dim_hidden is synced automatically from model.hidden_size + layer_norm: false + batch_norm: true + update_e: true + attn_dropout: 0.2 + attn: + clamp: 5. + act: relu + # full_attn: false because the Top-K RRWP sparsification already + # determines the attention pattern. The RRWPLinearEdgeEncoder will + # merge the sparse RRWP edges with original graph edges without + # padding to a fully-connected graph. 
+ full_attn: false # true only for full RRWP + edge_enhance: true + O_e: true + norm_e: true + signed_sqrt: true + bn_momentum: 0.1 + bn_no_runner: false +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.99 + - 0.01 + losses: + - PBE + - MaskedReconstructionMSE + loss_args: + - {} + - {} + accelerator: auto + devices: auto + strategy: auto +verbose: true diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 695ba531..03166542 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -22,6 +22,7 @@ def _warn_mp_context_on_linux(mp_context): stacklevel=2, ) + def is_lsf(): return ( os.environ.get("LSB_JOBID") is not None @@ -29,6 +30,7 @@ def is_lsf(): and "LSF_ENVDIR" in os.environ # strong LSF indicator ) + def fix_infiniband(): """Configure NCCL to skip Ethernet-only IB ports on this host.""" ibv = subprocess.run("ibv_devinfo", stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -71,6 +73,7 @@ def set_env(): ) os.environ["NCCL_IB_CUDA_SUPPORT"] = "1" # Force use of infiniband + def main(): """Parse CLI arguments and dispatch to the selected GridFM subcommand.""" if is_lsf(): diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 53a970c8..d3256357 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -25,7 +25,9 @@ import lightning as L -def _normalize_loaded_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: +def _normalize_loaded_state_dict_keys( + state_dict: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: """Map legacy torch.compile checkpoint keys to the canonical model namespace.""" has_compiled_prefix = any(key.startswith("model._orig_mod.") for key in state_dict) if not has_compiled_prefix: @@ -163,7 +165,7 @@ def main_cli(args): run_name=args.run_name, ) - # When using torch.compile with Triton, dynamic graph support can cause + # 
When using torch.compile with Triton, dynamic graph support can cause # out-of-memory errors during autotuning on some kernels. # Disabling dynamic graph support allows those kernels # to be skipped gracefully instead of causing errors. @@ -238,12 +240,17 @@ def main_cli(args): _accelerator = config_args.training.accelerator _strategy = config_args.training.strategy # if mps is available and accelerator is auto, explicitely set accelerator to mps to select the right strategy in the next block - if _accelerator == "auto" and torch.backends.mps.is_available(): + if _accelerator == "auto" and torch.backends.mps.is_available(): _accelerator = "mps" - if _accelerator not in ("mps", "cpu") and isinstance(_strategy, str) and _strategy in ( - "auto", - "ddp", - ): # when using mps, we don't want to use ddp. + if ( + _accelerator not in ("mps", "cpu") + and isinstance(_strategy, str) + and _strategy + in ( + "auto", + "ddp", + ) + ): # when using mps, we don't want to use ddp. _strategy = DDPStrategy(find_unused_parameters=False) trainer = L.Trainer( diff --git a/gridfm_graphkit/datasets/cached_transform.py b/gridfm_graphkit/datasets/cached_transform.py new file mode 100644 index 00000000..17f37878 --- /dev/null +++ b/gridfm_graphkit/datasets/cached_transform.py @@ -0,0 +1,252 @@ +import os +import tempfile +import hashlib + +import torch +from torch_geometric.data import HeteroData +from torch_geometric.utils import remove_self_loops + + +def _topology_fingerprint( + data, + cached_edge_type=None, + use_admittance=False, + admittance_remove_self_loops=True, +): + """Compute a short hash that uniquely identifies the graph topology. + + The fingerprint captures everything that determines the positional + encoding output: edge connectivity, number of nodes, and (when + admittance-weighted) the edge weights derived from admittance values. 
+ + This allows samples sharing the same topology (and admittance values) + to reuse a single cached PE file instead of storing redundant copies. + + Args: + data: HeteroData or Data object. + cached_edge_type: Edge type tuple for hetero data, or None. Currently not read by this function. + use_admittance: Whether admittance weights affect the PE. + admittance_remove_self_loops: Whether self-loops are removed + when computing admittance weights. + + Returns: + A 16-character hex string uniquely identifying the topology. + """ + if isinstance(data, HeteroData): + edge_index = data["bus", "connects", "bus"].edge_index + num_nodes = data["bus"].num_nodes + edge_attr = data["bus", "connects", "bus"].edge_attr + else: + edge_index = data.edge_index + num_nodes = data.num_nodes + edge_attr = getattr(data, "edge_attr", None) + + h = hashlib.sha256() + h.update( + num_nodes.to_bytes(4, "little") + if isinstance(num_nodes, int) + else int(num_nodes).to_bytes(4, "little"), + ) + h.update(edge_index.cpu().numpy().tobytes()) + + # Include admittance values in the fingerprint when they affect the PE + if use_admittance and edge_attr is not None: + if edge_attr.size(1) == 2: + g, b = edge_attr[:, 0], edge_attr[:, 1] + else: + # Yft at indices 4, 5 + g, b = edge_attr[:, 4], edge_attr[:, 5] + edge_weight = torch.sqrt(g**2 + b**2) + + if admittance_remove_self_loops: + _, edge_weight = remove_self_loops(edge_index, edge_weight) + + # Quantize to 6 decimal places to avoid floating-point noise + # causing spurious cache misses + quantized = (edge_weight * 1e6).round().to(torch.int64) + h.update(quantized.cpu().numpy().tobytes()) + + return h.hexdigest()[:16] + + +class CachedPosencTransform: + """Disk-caching wrapper for positional encoding transforms. + + Computes the PE on the first access and caches the result to disk. + Subsequent accesses with the same graph topology load from cache, + avoiding redundant computation across epochs, splits, and jobs. 
+ + Cache keying strategy: + PE depends only on graph topology (edge_index, num_nodes) and + optionally on admittance values — NOT on node features like + loads/generation. The cache is keyed by a hash of these + topology-determining inputs so that all samples sharing the same + network structure reuse a single cache file. + + Thread/process safety: + - Uses atomic write (write to temp file, then os.replace) so + concurrent DataLoader workers or separate jobs cannot produce + corrupt cache files. + - Different train/val/test splits across jobs safely share the + cache since PE depends only on topology, not on split + membership. + + Args: + transform: The inner PE transform (e.g. ComputePosencStat). + cache_dir: Directory to store cached PE tensors. + cached_attrs: List of attribute names to cache on the bus node store + (e.g. ["pestat_RWSE"]). + cached_edge_type: Optional edge type tuple (e.g. ("bus", "rrwp", "bus")) + whose edge_index and edge_attr should also be cached. + key_attr: Attribute on the data object used as the cache key. + If "topology" (default), uses a hash of the graph structure + so samples with identical topology share one cache file. + Otherwise, uses the named scalar attribute (e.g. "scenario_id") + for per-sample caching. + use_admittance: Whether admittance weights are used in the PE + computation (affects the topology fingerprint). + admittance_remove_self_loops: Whether self-loops are removed + for admittance weighting (affects the topology fingerprint). 
+ """ + + def __init__( + self, + transform, + cache_dir: str, + cached_attrs: list[str], + cached_edge_type: tuple[str, str, str] | None = None, + key_attr: str = "topology", + use_admittance: bool = False, + admittance_remove_self_loops: bool = True, + ): + self.transform = transform + self.cache_dir = cache_dir + self.cached_attrs = cached_attrs + self.cached_edge_type = cached_edge_type + self.key_attr = key_attr + self.use_admittance = use_admittance + self.admittance_remove_self_loops = admittance_remove_self_loops + os.makedirs(cache_dir, exist_ok=True) + + def _cache_path(self, data) -> str: + if self.key_attr == "topology": + key = _topology_fingerprint( + data, + cached_edge_type=self.cached_edge_type, + use_admittance=self.use_admittance, + admittance_remove_self_loops=self.admittance_remove_self_loops, + ) + else: + key = data[self.key_attr].item() + return os.path.join(self.cache_dir, f"pe_cache_{key}.pt") + + def _load_cache(self, cache_path, data): + """Load cached PE attributes and attach them to data.""" + cached = torch.load(cache_path, weights_only=True) + if isinstance(data, HeteroData): + for attr, val in cached.items(): + if attr == "_edge_type_index": + data[self.cached_edge_type].edge_index = val + elif attr == "_edge_type_attr": + data[self.cached_edge_type].edge_attr = val + else: + data["bus"][attr] = val + else: + for attr, val in cached.items(): + setattr(data, attr, val) + + def _save_cache(self, cache_path, data): + """Atomically save PE attributes to the cache file. + + Uses a temporary file in the same directory followed by + os.replace, which is atomic on both Linux and Windows. + This ensures concurrent workers/jobs never see a partially + written file. 
+ """ + if isinstance(data, HeteroData): + target = data["bus"] + else: + target = data + + to_cache = {} + for attr in self.cached_attrs: + if hasattr(target, attr): + to_cache[attr] = getattr(target, attr) + + # Cache edge-type data (RRWP sparse index + values) + if ( + self.cached_edge_type is not None + and isinstance(data, HeteroData) + and self.cached_edge_type in data.edge_types + ): + edge_store = data[self.cached_edge_type] + if hasattr(edge_store, "edge_index"): + to_cache["_edge_type_index"] = edge_store.edge_index + if hasattr(edge_store, "edge_attr"): + to_cache["_edge_type_attr"] = edge_store.edge_attr + + if not to_cache: + return + + # Write to a temporary file in the same directory (same filesystem) + # to guarantee os.replace is atomic. + fd, tmp_path = tempfile.mkstemp( + dir=self.cache_dir, + prefix=f".pe_cache_tmp_{os.getpid()}_", + suffix=".pt", + ) + try: + os.close(fd) + torch.save(to_cache, tmp_path) + os.replace(tmp_path, cache_path) + except BaseException: + # Clean up temp file on any failure + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + def __call__(self, data): + cache_path = self._cache_path(data) + + # Fast path: load from cache if available + if os.path.exists(cache_path): + self._load_cache(cache_path, data) + return data + + # Slow path: compute, then cache + data = self.transform(data) + self._save_cache(cache_path, data) + return data + + +def make_pe_cache_dir(processed_dir: str, pe_type: str, cfg) -> str: + """Build a cache directory path that includes a config fingerprint. + + The fingerprint ensures that changing PE parameters (e.g. kernel.times) + invalidates the cache automatically by using a different directory. + + Args: + processed_dir: The dataset's processed directory. + pe_type: "RWSE" or "RRWP". + cfg: The data config namespace containing PE parameters. + + Returns: + Path to the cache directory. 
+ """ + if pe_type == "RWSE": + kernel_times = cfg.posenc_RWSE.kernel.times + fingerprint = f"k{kernel_times}" + elif pe_type == "RRWP": + ksteps = cfg.posenc_RRWP.ksteps + topk = getattr(cfg.posenc_RRWP, "topk", 0) + if topk and topk > 0: + fingerprint = f"k{ksteps}_topk{topk}" + else: + fingerprint = f"k{ksteps}" + else: + fingerprint = "default" + + cache_dir_name = f"pe_cache_{pe_type.lower()}_{fingerprint}" + return os.path.join(processed_dir, cache_dir_name) diff --git a/gridfm_graphkit/datasets/globals.py b/gridfm_graphkit/datasets/globals.py index ab3c7e3d..c62cfae7 100644 --- a/gridfm_graphkit/datasets/globals.py +++ b/gridfm_graphkit/datasets/globals.py @@ -24,6 +24,8 @@ VA_OUT = 1 PG_OUT = 2 QG_OUT = 3 +PD_OUT = 4 # for random masking +QD_OUT = 5 # for random masking PG_OUT_GEN = 0 diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 83527bdc..6287b6c6 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -17,6 +17,15 @@ split_from_existing_files, ) from gridfm_graphkit.datasets.powergrid_hetero_dataset import HeteroGridDatasetDisk + +from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat +from gridfm_graphkit.datasets.cached_transform import ( + CachedPosencTransform, + make_pe_cache_dir, +) + +import torch_geometric.transforms as T + import numpy as np import random import warnings @@ -123,13 +132,18 @@ def __init__( self._is_setup_done = False if self.split_by_load_scenario_idx: - assert self.split_from_existing_files is None, " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" + assert self.split_from_existing_files is None, ( + " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" + ) if self.split_from_existing_files is not None: - assert isinstance(self.split_from_existing_files, str), 
"`split_from_existing_files` must be an existing folder in string format" + assert isinstance(self.split_from_existing_files, str), ( + "`split_from_existing_files` must be an existing folder in string format" + ) self.split_from_existing_files = Path(self.split_from_existing_files) - assert self.split_from_existing_files.is_dir(), "`split_from_existing_files` must be an existing folder in string format" - + assert self.split_from_existing_files.is_dir(), ( + "`split_from_existing_files` must be an existing folder in string format" + ) def setup(self, stage: str): if self._is_setup_done: @@ -173,6 +187,44 @@ def setup(self, stage: str): transform=get_task_transforms(args=self.args), ) + if ("posenc_RRWP" in self.args.data) and self.args.data.posenc_RRWP.enable: + pe_transform = ComputePosencStat(pe_types=["RRWP"], cfg=self.args.data) + if getattr(self.args.data.posenc_RRWP, "cache", False): + cache_dir = make_pe_cache_dir( + dataset.processed_dir, + "RRWP", + self.args.data, + ) + pe_transform = CachedPosencTransform( + pe_transform, + cache_dir, + cached_attrs=["rrwp", "log_deg", "deg"], + cached_edge_type=("bus", "rrwp", "bus"), + key_attr="topology", + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) + if ("posenc_RWSE" in self.args.data) and self.args.data.posenc_RWSE.enable: + pe_transform = ComputePosencStat(pe_types=["RWSE"], cfg=self.args.data) + if getattr(self.args.data.posenc_RWSE, "cache", False): + cache_dir = make_pe_cache_dir( + dataset.processed_dir, + "RWSE", + self.args.data, + ) + pe_transform = CachedPosencTransform( + pe_transform, + cache_dir, + cached_attrs=["pestat_RWSE"], + key_attr="topology", + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) + self.datasets.append(dataset) num_scenarios = self.args.data.scenarios[i] @@ -186,7 +238,6 @@ def setup(self, 
stage: str): # Create a subset all_indices = list(range(len(dataset))) - if self.split_from_existing_files is not None: warnings.warn( "`data.scenarios` is ignored when `split_from_existing_files` is set; " @@ -231,13 +282,14 @@ def setup(self, stage: str): # load_scenario for each scenario in the subset load_scenarios = dataset.load_scenarios[subset_indices] - dataset = Subset(dataset, subset_indices) - + if self.dataset_wrapper is not None: wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) - dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) - + dataset = wrapper_cls( + dataset, + cache_dir=self.dataset_wrapper_cache_dir, + ) # Random seed set before every split, same as above np.random.seed(self.args.seed) @@ -432,7 +484,12 @@ def _dataloader_kwargs(self): return kwargs def train_dataloader(self): - print("creating train dataloader for rank ", dist.get_rank() if dist.is_available() and dist.is_initialized() else "not distributed") + print( + "creating train dataloader for rank ", + dist.get_rank() + if dist.is_available() and dist.is_initialized() + else "not distributed", + ) return DataLoader( self.train_dataset_multi, batch_size=self.batch_size, diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index df2f657c..3b8f78fe 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -33,6 +33,60 @@ from torch_geometric.nn import MessagePassing +class AddRandomHeteroMask(BaseTransform): + """Creates random masks for self-supervised pretraining on heterogeneous power grid graphs. + + Each selected feature dimension is independently masked per node/edge with + probability ``mask_ratio``. Masked bus features: PD, QD, VM, VA, QG. + Masked gen features: PG. Masked branch features: P_E, Q_E. + + The output ``data.mask_dict`` has the same structure as the deterministic + PF / OPF masks so that downstream losses (``MaskedReconstructionMSE``, + ``PBELoss``, etc.) 
work without modification. + """ + + def __init__(self, mask_ratio=0.5): + super().__init__() + self.mask_ratio = mask_ratio + + def forward(self, data): + bus_x = data.x_dict["bus"] + gen_x = data.x_dict["gen"] + + # Bus type indicators (needed by losses and test metrics) + mask_PQ = bus_x[:, PQ_H] == 1 + mask_PV = bus_x[:, PV_H] == 1 + mask_REF = bus_x[:, REF_H] == 1 + + # Random bus mask on variable features the model reconstructs + mask_bus = torch.zeros_like(bus_x, dtype=torch.bool) + n_bus = bus_x.size(0) + for feat_idx in (PD_H, QD_H, VM_H, VA_H, QG_H): + mask_bus[:, feat_idx] = torch.rand(n_bus) < self.mask_ratio + + # Random gen mask on PG + mask_gen = torch.zeros_like(gen_x, dtype=torch.bool) + mask_gen[:, PG_H] = torch.rand(gen_x.size(0)) < self.mask_ratio + + # Random branch mask on flow features + branch_attr = data.edge_attr_dict[("bus", "connects", "bus")] + mask_branch = torch.zeros_like(branch_attr, dtype=torch.bool) + n_edge = branch_attr.size(0) + for feat_idx in (P_E, Q_E): + mask_branch[:, feat_idx] = torch.rand(n_edge) < self.mask_ratio + + data.mask_dict = { + "bus": mask_bus, + "gen": mask_gen, + "branch": mask_branch, + "PQ": mask_PQ, + "PV": mask_PV, + "REF": mask_REF, + } + + return data + + class AddPFHeteroMask(BaseTransform): """Creates masks for a heterogeneous power flow graph.""" @@ -158,6 +212,7 @@ def forward(self, data): class BusToGenBroadcaster(MessagePassing): """Broadcast per-bus values to connected generators via graph propagation.""" + def __init__(self, aggr="add"): super().__init__(aggr=aggr) @@ -176,6 +231,7 @@ def message(self, x_j): class SimulateMeasurements(BaseTransform): """Add configurable noise/outliers and masks to simulate measured quantities.""" + def __init__(self, args): super().__init__() self.measurements = args.task.measurements diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index eb5652d7..d53e24cd 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ 
b/gridfm_graphkit/datasets/normalizers.py @@ -20,6 +20,8 @@ # Output feature indices PG_OUT, QG_OUT, + PD_OUT, + QD_OUT, PG_OUT_GEN, # Generator feature indices PG_H, @@ -228,8 +230,14 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= self.baseMVA - data.baseMVA = torch.tensor(self.baseMVA, dtype=data.x_dict["bus"].dtype) # # needs to be float32 for MPS - data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS + data.baseMVA = torch.tensor( + self.baseMVA, + dtype=data.x_dict["bus"].dtype, + ) # # needs to be float32 for MPS + data.is_normalized = torch.tensor( + True, + dtype=torch.bool, + ) # needs to be bool for MPS def inverse_transform(self, data: HeteroData): if self.baseMVA is None or self.baseMVA == 0: @@ -299,13 +307,20 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= self.baseMVA - data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = torch.tensor( + False, + dtype=torch.bool, + ) # needs to be bool for MPS def inverse_output(self, output, batch): bus_output = output["bus"] gen_output = output["gen"] bus_output[:, PG_OUT] *= self.baseMVA bus_output[:, QG_OUT] *= self.baseMVA + if bus_output.size(1) > PD_OUT: + bus_output[:, PD_OUT] *= self.baseMVA + if bus_output.size(1) > QD_OUT: + bus_output[:, QD_OUT] *= self.baseMVA gen_output[:, PG_OUT_GEN] *= self.baseMVA def get_stats(self) -> dict: @@ -510,7 +525,10 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 
data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= e_b - data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = torch.tensor( + True, + dtype=torch.bool, + ) # needs to be bool for MPS def inverse_transform(self, data: HeteroData): """Undo per-unit normalization (multiply by baseMVA, inverse log1p for cost coeffs).""" @@ -573,7 +591,10 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= e_b - data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = torch.tensor( + False, + dtype=torch.bool, + ) # needs to be bool for MPS def inverse_output(self, output, batch): """ @@ -606,6 +627,10 @@ def inverse_output(self, output, batch): # Scale per-unit power back to MW/Mvar bus_output[:, PG_OUT] *= b_bus bus_output[:, QG_OUT] *= b_bus + if bus_output.size(1) > PD_OUT: + bus_output[:, PD_OUT] *= b_bus + if bus_output.size(1) > QD_OUT: + bus_output[:, QD_OUT] *= b_bus gen_output[:, PG_OUT_GEN] *= b_gen def get_stats(self) -> dict: diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py new file mode 100644 index 00000000..8fca4f19 --- /dev/null +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -0,0 +1,209 @@ +import torch + +from torch_geometric.utils import ( + to_dense_adj, +) +from torch_geometric.utils.num_nodes import maybe_num_nodes + +try: + from torch_scatter import scatter_add +except ImportError: + scatter_add = None + +from functools import partial +from gridfm_graphkit.datasets.rrwp import add_full_rrwp, add_topk_rrwp + +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data, HeteroData +from typing import Any + + +def compute_posenc_stats(data, pe_types, cfg): + """Precompute positional encodings for the given graph. 
+ Supported PE statistics to precompute in original implementation, + selected by `pe_types` (this function only computes 'RWSE' and 'RRWP'): + 'LapPE': Laplacian eigen-decomposition. + 'RWSE': Random walk landing probabilities (diagonals of RW matrices). + 'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED) + 'HKdiagSE': Diagonals of heat kernel diffusion. + 'ElstaticSE': Kernel based on the electrostatic interaction between nodes. + 'RRWP': Relative Random Walk Probabilities PE (Ours, for GRIT) + Args: + data: PyG graph + pe_types: Positional encoding types to precompute statistics for. + An iterable of the names above, e.g. ['RWSE', 'RRWP'] + is_undirected: (not a parameter of this function) True if the graph is expected to be undirected + cfg: Main configuration node + + Returns: + Extended PyG Data object. + """ + # Verify PE types. + for t in pe_types: + if t not in [ + "LapPE", + "EquivStableLapPE", + "SignNet", + "RWSE", + "HKdiagSE", + "HKfullPE", + "ElstaticSE", + "RRWP", + ]: + raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") + + if "RRWP" in pe_types: + param = cfg.posenc_RRWP + topk = getattr(param, "topk", 0) + if topk and topk > 0: + transform = partial( + add_topk_rrwp, + walk_length=param.ksteps, + topk=topk, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + ) + else: + transform = partial( + add_full_rrwp, + walk_length=param.ksteps, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + ) + data = transform(data) + + # Random Walks. + if "RWSE" in pe_types: + kernel_param = cfg.posenc_RWSE.kernel + if hasattr(data, "num_nodes"): + N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa + else: + N = data.x.shape[0] # Number of nodes, including disconnected nodes. 
+ if kernel_param.times == 0: + raise ValueError("List of kernel times required for RWSE") + rw_landing = get_rw_landing_probs( + ksteps=[xx + 1 for xx in range(kernel_param.times)], + edge_index=data.edge_index, + num_nodes=N, + ) + data.pestat_RWSE = rw_landing + + return data + + +def get_rw_landing_probs( + ksteps, + edge_index, + edge_weight=None, + num_nodes=None, + space_dim=0, +): + """Compute Random Walk landing probabilities for given list of K steps. + + Args: + ksteps: List of k-steps for which to compute the RW landings + edge_index: PyG sparse representation of the graph + edge_weight: (optional) Edge weights + num_nodes: (optional) Number of nodes in the graph + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the random-walk diagonal by a factor `k^(space_dim/2)`. + In euclidean space, this correction means that the height of + the gaussian distribution stays almost constant across the number of + steps, if `space_dim` is the dimension of the euclidean space. + + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + if edge_weight is None: + edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) + num_nodes = maybe_num_nodes(edge_index, num_nodes) + source, _ = edge_index[0], edge_index[1] + if scatter_add is None: + raise ImportError( + "torch-scatter is required for RWSE positional encodings. " + "Install it with: pip install torch-scatter", + ) + deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. 
+ deg_inv = deg.pow(-1.0) + deg_inv.masked_fill_(deg_inv == float("inf"), 0) + + if edge_index.numel() == 0: + P = edge_index.new_zeros((1, num_nodes, num_nodes)) + else: + # P = D^-1 * A + P = torch.diag(deg_inv) @ to_dense_adj( + edge_index, + max_num_nodes=num_nodes, + ) # 1 x (Num nodes) x (Num nodes) + rws = [] + if ksteps == list(range(min(ksteps), max(ksteps) + 1)): + # Efficient way if ksteps are a consecutive sequence (most of the time the case) + Pk = P.clone().detach().matrix_power(min(ksteps)) + for k in range(min(ksteps), max(ksteps) + 1): + rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * (k ** (space_dim / 2))) + Pk = Pk @ P + else: + # Explicitly raising P to power k for each k \in ksteps. + for k in ksteps: + rws.append( + torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) + * (k ** (space_dim / 2)), + ) + rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) + + return rw_landing + + +class ComputePosencStat(BaseTransform): + def __init__(self, pe_types, cfg): + self.pe_types = pe_types + self.cfg = cfg + + def forward(self, data: Any) -> Any: + pass + + def __call__(self, data) -> Data: + if isinstance(data, HeteroData): + return self._call_hetero(data) + + data = compute_posenc_stats(data, pe_types=self.pe_types, cfg=self.cfg) + return data + + def _call_hetero(self, data: HeteroData) -> HeteroData: + """Compute PE on the bus-only subgraph and store results on data['bus'].""" + bus_data = Data( + x=data["bus"].x, + edge_index=data["bus", "connects", "bus"].edge_index, + num_nodes=data["bus"].num_nodes, + ) + if hasattr(data["bus", "connects", "bus"], "edge_weight"): + bus_data.edge_weight = data["bus", "connects", "bus"].edge_weight + + bus_data = compute_posenc_stats( + bus_data, + pe_types=self.pe_types, + cfg=self.cfg, + ) + + # Copy computed PE attributes back onto the HeteroData. 
+ # Node-level attrs go on the bus store; the sparse RRWP index/val + # are stored as a dedicated edge type so PyG batches them correctly + # (edge_index needs cat_dim=1 + node-count incrementing). + node_pe_attrs = [ + "pestat_RWSE", # RWSE + "rrwp", # RRWP absolute (node-level diagonal) + "log_deg", + "deg", # degree info from RRWP + ] + for attr in node_pe_attrs: + if hasattr(bus_data, attr): + data["bus"][attr] = getattr(bus_data, attr) + + # RRWP relative PE: sparse [2, E] index + [E, K] values → edge type + if hasattr(bus_data, "rrwp_index") and hasattr(bus_data, "rrwp_val"): + data["bus", "rrwp", "bus"].edge_index = bus_data.rrwp_index + data["bus", "rrwp", "bus"].edge_attr = bus_data.rrwp_val + + return data diff --git a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py index 82f57a57..a0510ddb 100644 --- a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py +++ b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py @@ -73,9 +73,14 @@ def process(self): ) if "load_scenario_idx" in bus_data.columns: load_scenarios = torch.tensor( - bus_data.groupby("scenario", sort=True)["load_scenario_idx"].first().values, + bus_data.groupby("scenario", sort=True)["load_scenario_idx"] + .first() + .values, + ) + torch.save( + load_scenarios, + osp.join(self.processed_dir, "load_scenarios.pt"), ) - torch.save(load_scenarios, osp.join(self.processed_dir, "load_scenarios.pt")) agg_gen = ( gen_data.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] @@ -136,7 +141,9 @@ def process(self): ] + common_branch_features # Group by scenario - bus_groups = bus_data.groupby("scenario") # Groupby preserves the order of rows within each group. + bus_groups = bus_data.groupby( + "scenario", + ) # Groupby preserves the order of rows within each group. 
def add_node_attr(data: Data, value: Any, attr_name: Optional[str] = None) -> Data:
    """Attach ``value`` to ``data`` and return ``data``.

    With ``attr_name`` given, ``value`` is stored under that key. Otherwise it
    is merged into ``data.x``: appended as extra trailing columns when ``x``
    already exists (after casting to ``x``'s device and dtype), or stored as
    ``x`` when it does not.
    """
    # Named attribute: plain assignment, nothing else to do.
    if attr_name is not None:
        data[attr_name] = value
        return data

    # No node features yet: the value becomes the feature matrix.
    if "x" not in data:
        data.x = value
        return data

    features = data.x
    if features.dim() == 1:
        # Promote a flat feature vector to a single column before concatenating.
        features = features.view(-1, 1)
    extra = value.to(features.device, features.dtype)
    data.x = torch.cat([features, extra], dim=-1)
    return data
@torch.no_grad()
def add_topk_rrwp(
    data,
    walk_length=8,
    topk=10,
    attr_name_abs="rrwp",
    attr_name_rel="rrwp",
    add_identity=True,
    spd=False,
    **kwargs,
):
    """Compute RRWP positional encodings with Top-K sparsification.

    Instead of retaining the full N×N relative PE matrix, this function
    keeps only the `topk` highest-magnitude neighbors per node (based on
    the L2 norm of the multi-step random walk probability vector). The
    original graph edges and self-loops are always retained regardless of
    their rank.

    This provides a smooth interpolation between:
      - topk=0 (or topk >= N): equivalent to full RRWP (all pairs)
      - topk=1: nearly equivalent to RWSE (mostly self-loops / local)

    Ranking and sparsification always operate on the random-walk
    probabilities. When ``spd`` is set, the *retained* entries are converted
    to shortest-path one-hots afterwards (as ``add_full_rrwp`` does), so
    unreachable pairs are never emitted and the Top-K ranking stays
    meaningful.

    Args:
        data: PyG Data object with edge_index.
        walk_length: Number of random walk steps (k).
        topk: Number of highest-ranked neighbors to retain per node.
              If 0 or >= num_nodes, all edges are kept (full RRWP).
        attr_name_abs: Attribute name for the absolute (diagonal) PE.
        attr_name_rel: Prefix for relative PE index/val attributes.
        add_identity: Whether to include the identity (step 0) in PE.
        spd: If True, encode shortest-path distance instead of probabilities.

    Returns:
        Data object with rrwp, rrwp_index, rrwp_val, log_deg, deg attributes.

    Raises:
        ImportError: If torch-sparse is not installed.
    """
    if SparseTensor is None:
        raise ImportError(
            "torch-sparse is required for RRWP positional encodings. "
            "Install it with: pip install torch-sparse",
        )

    num_nodes = data.num_nodes
    edge_index, edge_weight = data.edge_index, data.edge_weight

    adj = SparseTensor.from_edge_index(
        edge_index,
        edge_weight,
        sparse_sizes=(num_nodes, num_nodes),
    )

    # Compute D^{-1} A:
    deg = adj.sum(dim=1)
    deg_inv = 1.0 / deg
    deg_inv[deg_inv == float("inf")] = 0
    adj = (adj * deg_inv.view(-1, 1)).to_dense()
    # Helper tensors must live on the same device as the adjacency.
    device = adj.device

    pe_list = []
    if add_identity:
        pe_list.append(torch.eye(num_nodes, dtype=torch.float, device=device))
    pe_list.append(adj)

    out = adj
    if walk_length > 2:
        # Fill up to `walk_length` channels with successive powers of adj.
        for _ in range(len(pe_list), walk_length):
            out = out @ adj
            pe_list.append(out)

    pe = torch.stack(pe_list, dim=-1)  # n x n x k
    abs_pe = pe.diagonal().transpose(0, 1)  # n x k

    # --- Top-K sparsification ---
    # If topk <= 0 or topk >= num_nodes, keep everything (full RRWP)
    if topk <= 0 or topk >= num_nodes:
        pe_kept = pe
    else:
        # Score each (i, j) pair by the L2 norm of its k-step probability
        # vector; pe[i, j, :] is the walk vector from i to j.
        scores = pe.norm(dim=-1)  # [n, n]
        _, topk_indices = scores.topk(topk, dim=1)  # [n, topk]

        keep_mask = torch.zeros(
            num_nodes,
            num_nodes,
            dtype=torch.bool,
            device=device,
        )
        row_idx = (
            torch.arange(num_nodes, device=device)
            .unsqueeze(1)
            .expand_as(topk_indices)
        )
        keep_mask[row_idx, topk_indices] = True
        # Original graph edges are always retained ...
        keep_mask[edge_index[0], edge_index[1]] = True
        # ... and so are self-loops.
        diag_idx = torch.arange(num_nodes, device=device)
        keep_mask[diag_idx, diag_idx] = True

        # Zero out the pairs we drop so they vanish when sparsifying.
        pe_kept = pe.masked_fill(~keep_mask.unsqueeze(-1), 0)

    rel_pe = SparseTensor.from_dense(pe_kept, has_value=True)
    rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
    # GRIT performs right-multiplication while adj is row-normalized, so the
    # (row, col) order is swapped in the emitted index.
    rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0)

    if spd:
        # One-hot of the first step with non-zero probability, per kept pair.
        spd_idx = walk_length - torch.arange(walk_length, device=device)
        val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
        val = torch.argmax(val, dim=-1)
        rel_pe_val = F.one_hot(val, walk_length).type(torch.float)
        abs_pe = torch.zeros_like(abs_pe)

    data = add_node_attr(data, abs_pe, attr_name=attr_name_abs)
    data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index")
    data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val")
    data.log_deg = torch.log(deg + 1)
    data.deg = deg.type(torch.long)

    return data
@TRANSFORM_REGISTRY.register("OptimalPowerFlow") class OptimalPowerFlowTransforms(Compose): """Compose preprocessing and masking transforms for OptimalPowerFlow datasets.""" + def __init__(self, args): transforms = [] transforms.append(RemoveInactiveBranches()) transforms.append(RemoveInactiveGenerators()) - transforms.append(AddOPFHeteroMask()) + + mask_type = getattr(args.data, "mask_type", None) + if mask_type == "rnd": + transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio)) + else: + transforms.append(AddOPFHeteroMask()) + transforms.append(ApplyMasking(args=args)) # Pass the list of transforms to Compose @@ -46,6 +61,7 @@ def __init__(self, args): @TRANSFORM_REGISTRY.register("StateEstimation") class StateEstimationTransforms(Compose): """Compose preprocessing and measurement transforms for StateEstimation datasets.""" + def __init__(self, args): transforms = [] diff --git a/gridfm_graphkit/datasets/transforms.py b/gridfm_graphkit/datasets/transforms.py index c6891dc2..f58a1fec 100644 --- a/gridfm_graphkit/datasets/transforms.py +++ b/gridfm_graphkit/datasets/transforms.py @@ -97,6 +97,7 @@ def forward(self, data): class LoadGridParamsFromPath(BaseTransform): """Inject static grid parameters from a saved grid template into each sample.""" + def __init__(self, args): super().__init__() self.grid_path = args.task.grid_path diff --git a/gridfm_graphkit/datasets/utils.py b/gridfm_graphkit/datasets/utils.py index 65b34f4e..8d16de92 100644 --- a/gridfm_graphkit/datasets/utils.py +++ b/gridfm_graphkit/datasets/utils.py @@ -114,8 +114,8 @@ def split_from_existing_files( split_dataset = Subset(dataset, split_indices) output.append(split_dataset) split_indices = list(split_indices) - print(f'{split=} {len(split_indices)=}') - indices[split]=[int(t.item()) for t in split_indices] + print(f"{split=} {len(split_indices)=}") + indices[split] = [int(t.item()) for t in split_indices] output = tuple(output) - return output, indices \ No newline at end of file 
+ return output, indices diff --git a/gridfm_graphkit/io/registries.py b/gridfm_graphkit/io/registries.py index 65d596a9..c26f07b9 100644 --- a/gridfm_graphkit/io/registries.py +++ b/gridfm_graphkit/io/registries.py @@ -1,5 +1,6 @@ class Registry: """Simple name-to-object registry with decorator-based registration.""" + def __init__(self, name: str): self._name = name self._registry = {} diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index f8245352..f185c6a2 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,5 @@ from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous +from gridfm_graphkit.models.grit_transformer import GritHeteroAdapter from gridfm_graphkit.models.utils import ( PhysicsDecoderOPF, PhysicsDecoderPF, @@ -7,6 +8,7 @@ __all__ = [ "GNS_heterogeneous", + "GritHeteroAdapter", "PhysicsDecoderOPF", "PhysicsDecoderPF", "PhysicsDecoderSE", diff --git a/gridfm_graphkit/models/gnn_heterogeneous_gns.py b/gridfm_graphkit/models/gnn_heterogeneous_gns.py index 8f2ba3e4..6a339c18 100644 --- a/gridfm_graphkit/models/gnn_heterogeneous_gns.py +++ b/gridfm_graphkit/models/gnn_heterogeneous_gns.py @@ -162,14 +162,20 @@ def __init__(self, args) -> None: # container for monitoring residual norms per layer and type self.layer_residuals = {} - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): + def forward(self, batch): """ - x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} - edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") - edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) - batch_dict: dict mapping node types to batch tensors (if using batching). Not used heavily here but kept for API parity. - mask: optional mask per node (applies when computing residuals) + Accepts a PyG HeteroData batch and extracts the required tensors. 
def pyg_softmax(src, index, num_nodes=None):
    """
    Computes a sparsely evaluated softmax.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Implemented with native ``torch`` scatter operations, so it does not
    need the optional torch-scatter package.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    Returns:
        out (Tensor)
    """
    if num_nodes is None:
        num_nodes = int(index.max()) + 1 if index.numel() > 0 else 0

    group_shape = (num_nodes,) + tuple(src.shape[1:])
    # Broadcast the 1-D group index over the trailing dims of `src`.
    idx = index.reshape((index.numel(),) + (1,) * (src.dim() - 1)).expand_as(src)

    # Per-group maximum for numerical stability. It is detached because the
    # shift is only a stabiliser and cancels in the softmax.
    group_max = src.new_full(group_shape, float("-inf")).scatter_reduce(
        0,
        idx,
        src.detach(),
        reduce="amax",
        include_self=True,
    )

    out = (src - group_max[index]).exp()
    group_sum = src.new_zeros(group_shape).scatter_add(0, idx, out)
    out = out / (group_sum[index] + 1e-16)

    return out
: self.out_dim], batch.E[:, :, self.out_dim :] + # (num relative) x num_heads x out_dim + score = score * E_w + score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score)) + score = score + E_b + + score = self.act(score) + e_t = score + + # output edge + if batch.get("E", None) is not None: + batch.wE = score.flatten(1) + + # final attn + score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch") + if self.clamp is not None: + score = torch.clamp(score, min=-self.clamp, max=self.clamp) + + score = pyg_softmax( + score, + batch.edge_index[1], + ) # (num relative) x num_heads x 1 + score = self.dropout(score) + batch.attn = score + + # Aggregate with Attn-Score + msg = ( + batch.V_h[batch.edge_index[0]] * score + ) # (num relative) x num_heads x out_dim + batch.wV = torch.zeros_like( + batch.V_h, + ) # (num nodes in batch) x num_heads x out_dim + scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce="add") + + if self.edge_enhance and batch.E is not None: + rowV = scatter(e_t * score, batch.edge_index[1], dim=0, reduce="add") + rowV = oe.contract("nhd, dhc -> nhc", rowV, self.VeRow, backend="torch") + batch.wV = batch.wV + rowV + + def forward(self, batch): + Q_h = self.Q(batch.x) + K_h = self.K(batch.x) + + V_h = self.V(batch.x) + if batch.get("edge_attr", None) is not None: + batch.E = self.E(batch.edge_attr) + else: + batch.E = None + + batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim) + batch.K_h = K_h.view(-1, self.num_heads, self.out_dim) + batch.V_h = V_h.view(-1, self.num_heads, self.out_dim) + self.propagate_attention(batch) + h_out = batch.wV + e_out = batch.get("wE", None) + + return h_out, e_out + + +class GritTransformerLayer(nn.Module): + """ + Transformer Layer for GRIT + """ + + def __init__( + self, + in_dim, + out_dim, + num_heads, + dropout=0.0, + attn_dropout=0.0, + layer_norm=False, + batch_norm=True, + residual=True, + act="relu", + norm_e=True, + O_e=True, + cfg=dict(), + **kwargs, + ): + super().__init__() + 
+ self.debug = False + self.in_channels = in_dim + self.out_channels = out_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + # ------- + self.update_e = getattr(cfg.attn, "update_e", True) + self.bn_momentum = cfg.attn.bn_momentum + self.bn_no_runner = cfg.attn.bn_no_runner + self.rezero = getattr(cfg.attn, "rezero", False) + + if act is not None: + self.act = nn.ReLU() + else: + self.act = nn.Identity() + + if getattr(cfg, "attn", None) is None: + cfg.attn = dict() + self.use_attn = getattr(cfg.attn, "use", True) + self.deg_scaler = getattr(cfg.attn, "deg_scaler", True) + + self.attention = MultiHeadAttentionLayerGritSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=getattr(cfg.attn, "use_bias", False), + dropout=attn_dropout, + clamp=getattr(cfg.attn, "clamp", 5.0), + act=getattr(cfg.attn, "act", "relu"), + edge_enhance=getattr(cfg.attn, "edge_enhance", True), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn=getattr(cfg.attn, "scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), + ) + + self.O_h = nn.Linear(out_dim // num_heads * num_heads, out_dim) + if O_e: + self.O_e = nn.Linear(out_dim // num_heads * num_heads, out_dim) + else: + self.O_e = nn.Identity() + + # -------- Deg Scaler Option ------ + + if self.deg_scaler: + self.deg_coef = nn.Parameter( + torch.zeros(1, out_dim // num_heads * num_heads, 2), + ) + nn.init.xavier_normal_(self.deg_coef) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + self.layer_norm1_e = nn.LayerNorm(out_dim) if norm_e else nn.Identity() + + if self.batch_norm: + # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) + self.batch_norm1_h = nn.BatchNorm1d( + out_dim, + 
track_running_stats=not self.bn_no_runner, + eps=1e-5, + momentum=cfg.attn.bn_momentum, + ) + self.batch_norm1_e = ( + nn.BatchNorm1d( + out_dim, + track_running_stats=not self.bn_no_runner, + eps=1e-5, + momentum=cfg.attn.bn_momentum, + ) + if norm_e + else nn.Identity() + ) + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) + self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm2_h = nn.BatchNorm1d( + out_dim, + track_running_stats=not self.bn_no_runner, + eps=1e-5, + momentum=cfg.attn.bn_momentum, + ) + + if self.rezero: + self.alpha1_h = nn.Parameter(torch.zeros(1, 1)) + self.alpha2_h = nn.Parameter(torch.zeros(1, 1)) + self.alpha1_e = nn.Parameter(torch.zeros(1, 1)) + + def forward(self, batch): + h = batch.x + num_nodes = batch.num_nodes + log_deg = get_log_deg(batch) + + h_in1 = h # for first residual connection + e_in1 = batch.get("edge_attr", None) + e = None + # multi-head attention out + + h_attn_out, e_attn_out = self.attention(batch) + + h = h_attn_out.view(num_nodes, -1) + h = F.dropout(h, self.dropout, training=self.training) + + # degree scaler + if self.deg_scaler: + h = torch.stack([h, h * log_deg], dim=-1) + h = (h * self.deg_coef).sum(dim=-1) + + h = self.O_h(h) + if e_attn_out is not None: + e = e_attn_out.flatten(1) + e = F.dropout(e, self.dropout, training=self.training) + e = self.O_e(e) + + if self.residual: + if self.rezero: + h = h * self.alpha1_h + h = h_in1 + h # residual connection + if e is not None: + if self.rezero: + e = e * self.alpha1_e + e = e + e_in1 + + if self.layer_norm: + h = self.layer_norm1_h(h) + if e is not None: + e = self.layer_norm1_e(e) + + if self.batch_norm: + h = self.batch_norm1_h(h) + if e is not None: + e = self.batch_norm1_e(e) + + # FFN for h + h_in2 = h # for second residual connection + h = self.FFN_h_layer1(h) + h = self.act(h) + h = F.dropout(h, self.dropout, 
@torch.no_grad()
def get_log_deg(batch):
    """Return ``log(degree + 1)`` per node as a ``(num_nodes, 1)`` tensor.

    Lookup order: a stored ``log_deg`` attribute, then a stored ``deg``
    attribute, and finally the degree counted from the target row of
    ``batch.edge_index`` (which emits a warning).
    """
    if "log_deg" in batch:
        # Precomputed by the positional-encoding transform.
        result = batch.log_deg
    elif "deg" in batch:
        result = torch.log(batch.deg + 1).unsqueeze(-1)
    else:
        warnings.warn(
            "Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs",
        )
        in_degree = pyg.utils.degree(
            batch.edge_index[1],
            num_nodes=batch.num_nodes,
            dtype=torch.float,
        )
        result = torch.log(in_degree + 1)

    # Normalise every branch to a column vector.
    return result.view(batch.num_nodes, 1)
class FeatureEncoder(torch.nn.Module):
    """
    Encoding node and edge features

    Sub-modules are registered in the order node encoder, node batch-norm,
    edge encoder, edge batch-norm (each only if enabled in the config), and
    ``forward`` applies them to the batch in that registration order.

    Args:
        dim_in (int): Input feature dimension
        dim_inner (int): Embedding dimension produced by the encoders
        args: Model config; reads ``args.encoder.*`` flags and
            ``args.edge_dim``

    """

    def __init__(self, dim_in, dim_inner, args):
        super().__init__()
        self.dim_in = dim_in

        if args.encoder.node_encoder:
            if "RWSE" in args.encoder.node_encoder_name:
                self.node_encoder = RWSENodeEncoder(
                    self.dim_in,
                    dim_inner,
                    args.encoder.posenc_RWSE,
                )
            else:
                self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner)
            if args.encoder.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1)
            # Update dim_in to reflect the new dimension of the node features
            self.dim_in = dim_inner

        if args.encoder.edge_encoder:
            edge_dim = args.edge_dim
            enc_dim_edge = dim_inner
            self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge)
            if args.encoder.edge_encoder_bn:
                # Must be the edge variant: BatchNorm1dNode would normalise
                # batch.x a second time and leave batch.edge_attr untouched.
                # NOTE(review): parameter names/shapes are unchanged, but
                # checkpoints trained with the node variant will behave
                # differently.
                self.edge_encoder_bn = BatchNorm1dEdge(enc_dim_edge, 1e-5, 0.1)

    def forward(self, batch):
        # children() yields sub-modules in registration order.
        for module in self.children():
            batch = module(batch)
        return batch
+ + """ + + def __init__(self, args, include_decoder=True): + super().__init__() + + dim_in = args.model.input_dim + dim_out = args.model.output_dim + dim_inner = args.model.hidden_size + num_heads = args.model.attention_head + dropout = args.model.dropout + num_layers = args.model.num_layers + self.mask_dim = getattr(args.data, "mask_dim", 6) + self.mask_value = getattr(args.data, "mask_value", -1.0) + self.learn_mask = getattr(args.data, "learn_mask", False) + if self.learn_mask: + self.mask_value = nn.Parameter( + torch.randn(self.mask_dim) + self.mask_value, + requires_grad=True, + ) + else: + self.mask_value = nn.Parameter( + torch.zeros(self.mask_dim) + self.mask_value, + requires_grad=False, + ) + + self.encoder = FeatureEncoder(dim_in, dim_inner, args.model) + dim_in = self.encoder.dim_in + + if args.data.posenc_RRWP.enable: + self.rrwp_abs_encoder = RRWPLinearNodeEncoder( + args.data.posenc_RRWP.ksteps, + dim_inner, + ) + rel_pe_dim = args.data.posenc_RRWP.ksteps + self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( + rel_pe_dim, + dim_inner, + pad_to_full_graph=args.model.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0.0, + ) + + assert args.model.hidden_size == dim_inner == dim_in, ( + "The inner and hidden dims must match." + ) + + layers = [] + for ll in range(num_layers): + # The last layer's edge output is never consumed downstream + # (only node features feed into the output heads), so skip + # creating O_e / norm_e parameters to avoid DDP unused-parameter + # errors. 
+ is_last = ll == num_layers - 1 + layers.append( + GritTransformerLayer( + in_dim=args.model.gt.dim_hidden, + out_dim=args.model.gt.dim_hidden, + num_heads=num_heads, + dropout=dropout, + act=args.model.act, + attn_dropout=args.model.gt.attn_dropout, + layer_norm=args.model.gt.layer_norm, + batch_norm=args.model.gt.batch_norm, + residual=True, + norm_e=False if is_last else args.model.gt.attn.norm_e, + O_e=False if is_last else args.model.gt.attn.O_e, + cfg=args.model.gt, + ), + ) + + self.layers = nn.Sequential(*layers) + + if include_decoder: + self.decoder = GraphHead(dim_inner, dim_out) + + def forward(self, batch): + """ + Forward pass for GRIT. + + Args: + batch (Batch): Pytorch Geometric Batch object, with x, y encodings, etc. + + Returns: + output (Tensor): Output node features of shape [num_nodes, output_dim]. + """ + # print('xxxx',batch.x.min(), batch.x.max()) + # print('yyyyy',batch.y.min(), batch.y.max()) + # print('>>>>', batch) + for module in self.children(): + batch = module(batch) + + return batch + + +def aggregate_pg(batch, mask_value=-1.0): + """Aggregate per-generator active power (PG) onto bus nodes. + + In the homogeneous reference, PG is a direct bus feature visible to the + transformer alongside Pd, Qd, Vm, Va, etc. In the heterogeneous + representation PG lives on separate generator nodes, so it must be + aggregated onto buses before the transformer can learn voltage-power + coupling. + + Masked generators (where PG has been replaced by the mask value) are + excluded from the sum to avoid corrupting the aggregated signal. Buses + where *all* connected generators are masked receive the mask value + instead, preserving a consistent "unknown" indicator. + """ + if scatter_add is None: + raise ImportError( + "torch-scatter is required for the GRIT modules but is not installed. 
@MODELS_REGISTRY.register("GRIT")
class GritHeteroAdapter(torch.nn.Module):
    """Adapter that runs the homogeneous GRIT transformer on hetero grids.

    The bus-only homogeneous subgraph is extracted from the ``HeteroData``
    batch, passed through the GRIT encoder and transformer layers, and decoded
    by a bus head. Generators are not seen by the transformer; their output
    comes from a standalone MLP applied to the raw generator features, or is
    the raw generator feature tensor itself when no generator head is
    configured.

    Returns (from ``forward``):
        dict: ``{"bus": Tensor[num_bus, output_bus_dim], "gen": Tensor}``
    """

    def __init__(self, args):
        """Build the wrapped GRIT model and the per-node-type output heads.

        Args:
            args: Configuration namespace. Requires ``args.model.hidden_size``
                and ``args.model.output_bus_dim``. ``args.model.output_gen_dim``
                is optional; ``args.model.input_gen_dim`` is only read when a
                generator head is built. Note that ``args.model`` is updated
                in place (``input_dim``, ``output_dim``, ``gt.dim_hidden`` and
                the RWSE kernel times) so that ``GritTransformer`` finds the
                keys it expects.
        """
        super().__init__()

        dim_inner = args.model.hidden_size
        output_bus_dim = args.model.output_bus_dim
        # Optional: a missing or zero value means "no generator head".
        output_gen_dim = getattr(args.model, "output_gen_dim", 0)

        # Ensure config keys expected by GritTransformer are present.
        #   input_dim  = bus feature dimension (used by FeatureEncoder)
        #   output_dim = bus output dimension (used by the unused GraphHead)
        if not hasattr(args.model, "input_dim"):
            args.model.input_dim = args.model.input_bus_dim
        if not hasattr(args.model, "output_dim"):
            args.model.output_dim = output_bus_dim

        # Sync PE kernel.times from data config into model encoder config so
        # users only need to specify it once (under data.posenc_RWSE).
        if (
            hasattr(args.data, "posenc_RWSE")
            and args.data.posenc_RWSE.enable
            and hasattr(args.model, "encoder")
            and hasattr(args.model.encoder, "posenc_RWSE")
        ):
            from gridfm_graphkit.io.param_handler import NestedNamespace

            enc_rwse = args.model.encoder.posenc_RWSE
            if not hasattr(enc_rwse, "kernel"):
                enc_rwse.kernel = NestedNamespace()
            enc_rwse.kernel.times = args.data.posenc_RWSE.kernel.times

        # Sync gt.dim_hidden from model.hidden_size so it is specified once.
        if hasattr(args.model, "gt"):
            args.model.gt.dim_hidden = args.model.hidden_size

        # The original homogeneous GRIT (encoder + optional PE encoders +
        # transformer layers). The decoder is excluded because this adapter
        # provides its own per-type heads.
        self.grit = GritTransformer(args, include_decoder=False)

        # Per-node-type output heads (replace GraphHead for hetero output).
        self.bus_head = nn.Sequential(
            nn.Linear(dim_inner, dim_inner),
            nn.LeakyReLU(),
            nn.Linear(dim_inner, output_bus_dim),
        )

        # gen_head is only needed for tasks that require per-generator
        # predictions (e.g. OPF cost computation). When output_gen_dim is 0
        # or not set, skip it to avoid DDP unused-parameter errors.
        if output_gen_dim and output_gen_dim > 0:
            # Only required once a generator head is actually requested.
            input_gen_dim = args.model.input_gen_dim
            self.gen_head = nn.Sequential(
                nn.Linear(input_gen_dim, dim_inner),
                nn.LeakyReLU(),
                nn.Linear(dim_inner, output_gen_dim),
            )
        else:
            self.gen_head = None

    def forward(self, batch):
        """Forward pass on a heterogeneous power-grid batch.

        Args:
            batch: A batched ``HeteroData`` with node types ``"bus"`` and
                ``"gen"``, edge type ``("bus", "connects", "bus")`` and the
                ``("gen", "connected_to", "bus")`` edges / ``mask_dict`` used
                by ``aggregate_pg``.

        Returns:
            dict with keys ``"bus"`` and ``"gen"``. ``"bus"`` holds the bus
            head output. ``"gen"`` holds the generator head output, or the
            unmodified ``batch["gen"].x`` when no generator head exists.
        """
        # --- Extract bus-only homogeneous subgraph ---
        # Aggregate generator PG onto buses and append it as one extra bus
        # feature column. Only the first mask entry is used as fill value.
        pg_per_bus = aggregate_pg(
            batch,
            mask_value=self.grit.mask_value[0].item(),
        )
        bus_store = batch["bus"]
        bus_x = torch.cat([bus_store.x, pg_per_bus.unsqueeze(-1)], dim=-1)

        bus_edges = batch["bus", "connects", "bus"]
        homo = Data(
            x=bus_x,
            y=bus_store.y,
            edge_index=bus_edges.edge_index,
            edge_attr=bus_edges.edge_attr,
            batch=bus_store.batch,
        )

        # Forward positional-encoding attributes if present.
        for attr in ("pestat_RWSE", "rrwp", "log_deg", "deg"):
            if hasattr(bus_store, attr):
                setattr(homo, attr, getattr(bus_store, attr))

        # RRWP relative PE is stored as a dedicated edge type for correct
        # batching.
        if ("bus", "rrwp", "bus") in batch.edge_types:
            rrwp_edges = batch["bus", "rrwp", "bus"]
            homo.rrwp_index = rrwp_edges.edge_index
            homo.rrwp_val = rrwp_edges.edge_attr

        # --- Run GRIT encoder + PE encoders + transformer layers ---
        homo = self.grit.encoder(homo)
        if hasattr(self.grit, "rrwp_abs_encoder"):
            homo = self.grit.rrwp_abs_encoder(homo)
        if hasattr(self.grit, "rrwp_rel_encoder"):
            homo = self.grit.rrwp_rel_encoder(homo)
        homo = self.grit.layers(homo)

        # --- Per-type decoding ---
        bus_out = self.bus_head(homo.x)
        gen_x = batch["gen"].x
        gen_out = self.gen_head(gen_x) if self.gen_head is not None else gen_x

        return {"bus": bus_out, "gen": gen_out}
kernel-based Positional Encoding node encoder. + + The choice of which kernel-based statistics to use is configurable through + setting of `kernel_type`. Based on this, the appropriate config is selected, + and also the appropriate variable with precomputed kernel stats is then + selected from PyG Data graphs in `forward` function. + E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'. + + PE of size `dim_pe` will get appended to each node feature vector. + If `expand_x` set True, original node features will be first linearly + projected to (dim_emb - dim_pe) size and the concatenated with PE. + + Args: + dim_emb: Size of final node embedding + expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe) + """ + + kernel_type = None # Instantiated type of the KernelPE, e.g. RWSE + + def __init__(self, dim_in, dim_emb, pecfg, expand_x=True): + super().__init__() + if self.kernel_type is None: + raise ValueError( + f"{self.__class__.__name__} has to be " + f"preconfigured by setting 'kernel_type' class" + f"variable before calling the constructor.", + ) + + dim_pe = pecfg.pe_dim # Size of the kernel-based PE embedding + num_rw_steps = pecfg.kernel.times + norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type + # self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable + + if dim_emb - dim_pe < 1: + raise ValueError( + f"PE dim size {dim_pe} is too large for " + f"desired embedding size of {dim_emb}.", + ) + + if expand_x: + self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe) + self.expand_x = expand_x + + if norm_type == "batchnorm": + self.raw_norm = nn.BatchNorm1d(num_rw_steps) + else: + self.raw_norm = None + + self.pe_encoder = nn.Linear(num_rw_steps, dim_pe) + + def forward(self, batch): + pestat_var = f"pestat_{self.kernel_type}" + if not hasattr(batch, pestat_var): + raise ValueError( + f"Precomputed '{pestat_var}' variable is " + f"required for {self.__class__.__name__}; set " + f"config 
def full_edge_index(edge_index, batch=None):
    """Return the edge index of the fully-connected graph for every example.

    For each graph in the batch, all ``n * n`` ordered node pairs are emitted
    (self-loops included, in row-major order), offset so that node ids match
    the batched numbering. The edges present in ``edge_index`` are not
    removed; ``edge_index`` is only used to infer the node count when
    ``batch`` is omitted, and to pick the output device/dtype.

    Args:
        edge_index: Edge indices of shape ``[2, num_edges]``.
        batch: Batch vector assigning each node to an example. When ``None``
            a single graph with ``edge_index.max() + 1`` nodes is assumed.

    Returns:
        Tensor of shape ``[2, sum_i n_i ** 2]`` with the dense edge index.
        Empty input yields an empty ``[2, 0]`` tensor.
    """
    if batch is None:
        if edge_index.numel() == 0:
            # No edges and no batch vector: nothing to infer a size from.
            return edge_index.new_zeros((2, 0))
        batch = edge_index.new_zeros(int(edge_index.max()) + 1)
    if batch.numel() == 0:
        return edge_index.new_zeros((2, 0))

    # Nodes per example; bincount replaces the torch-scatter based counting.
    num_nodes = torch.bincount(batch)
    # Index of the first node of each example in the batched numbering.
    offsets = torch.cumsum(num_nodes, dim=0) - num_nodes

    blocks = []
    for n, offset in zip(num_nodes.tolist(), offsets.tolist()):
        node_ids = torch.arange(n, device=edge_index.device)
        rows = node_ids.repeat_interleave(n)
        cols = node_ids.repeat(n)
        blocks.append(torch.stack([rows, cols], dim=0) + offset)

    return torch.cat(blocks, dim=1).contiguous()
class RRWPLinearEdgeEncoder(torch.nn.Module):
    """Project RRWP values to edge features and merge them with edge attrs.

    Computes ``FC(RRWP)`` and, unless ``overwrite_old_attr`` is set, adds it
    to the existing edge attributes (after self-loops with zero attributes
    have been appended). With ``pad_to_full_graph`` the result is extended to
    all node pairs of each graph, adding the constant ``fill_value`` row for
    every pair.

    Note: node and edge attributes are assumed to share the same dimension
    after their encoders. ``add_node_attr_as_self_loop`` is stored but not
    used by ``forward``.
    """

    def __init__(
        self,
        emb_dim,
        out_dim,
        batchnorm=False,
        layernorm=False,
        use_bias=False,
        pad_to_full_graph=True,
        fill_value=0.0,
        add_node_attr_as_self_loop=False,
        overwrite_old_attr=False,
    ):
        """Set up the linear projection, padding buffer and optional norms.

        Args:
            emb_dim: Width of the raw RRWP values.
            out_dim: Width of the produced edge features.
            batchnorm: Apply ``BatchNorm1d`` to the merged edge features.
            layernorm: Apply ``LayerNorm`` to the merged edge features.
            use_bias: Whether the linear projection has a bias.
            pad_to_full_graph: Extend the edge set to all node pairs.
            fill_value: Constant used for the padding rows.
            add_node_attr_as_self_loop: Stored on the module only.
            overwrite_old_attr: Drop existing edge attributes and keep only
                the projected RRWP values.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.out_dim = out_dim
        self.add_node_attr_as_self_loop = add_node_attr_as_self_loop
        self.overwrite_old_attr = overwrite_old_attr  # remove the old edge-attr

        self.batchnorm = batchnorm
        self.layernorm = layernorm
        if self.batchnorm or self.layernorm:
            warnings.warn(
                "batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ",
            )

        self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias)
        torch.nn.init.xavier_uniform_(self.fc.weight)
        self.pad_to_full_graph = pad_to_full_graph
        # Fixed: previously hard-coded to 0.0, so __repr__ misreported the
        # value that the padding buffer was actually built with.
        self.fill_value = fill_value

        padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value
        self.register_buffer("padding", padding)

        if self.batchnorm:
            self.bn = nn.BatchNorm1d(out_dim)

        if self.layernorm:
            self.ln = nn.LayerNorm(out_dim)

    def forward(self, batch):
        """Write the merged edge index / attributes back onto ``batch``.

        Args:
            batch: Graph batch providing ``rrwp_index``, ``rrwp_val``,
                ``edge_index``, ``edge_attr``, ``num_nodes`` and ``batch``.

        Returns:
            The same ``batch`` with ``edge_index`` and ``edge_attr`` replaced.

        Raises:
            ImportError: If torch-sparse is needed but not installed.
        """
        rrwp_idx = batch.rrwp_index
        rrwp_val = self.fc(batch.rrwp_val)
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr

        if edge_attr is None:
            # Fixed: build the zero attributes from rrwp_val so they carry its
            # floating dtype; edge_index.new_zeros produced an integer tensor.
            edge_attr = rrwp_val.new_zeros(edge_index.size(1), rrwp_val.size(1))

        if self.overwrite_old_attr:
            out_idx, out_val = rrwp_idx, rrwp_val
        else:
            # Self-loops get zero attributes so that only RRWP contributes on
            # the diagonal after coalescing.
            edge_index, edge_attr = add_self_loops(
                edge_index,
                edge_attr,
                num_nodes=batch.num_nodes,
                fill_value=0.0,
            )
            _check_sparse()
            # Duplicate (i, j) entries are summed: edge-attr + FC(RRWP).
            out_idx, out_val = torch_sparse.coalesce(
                torch.cat([edge_index, rrwp_idx], dim=1),
                torch.cat([edge_attr, rrwp_val], dim=0),
                batch.num_nodes,
                batch.num_nodes,
                op="add",
            )

        if self.pad_to_full_graph:
            # Padding rows for every node pair of each graph; summed into the
            # existing entries by the coalesce below.
            edge_index_full = full_edge_index(out_idx, batch=batch.batch)
            edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1)
            out_idx = torch.cat([out_idx, edge_index_full], dim=1)
            out_val = torch.cat([out_val, edge_attr_pad], dim=0)
            _check_sparse()
            out_idx, out_val = torch_sparse.coalesce(
                out_idx,
                out_val,
                batch.num_nodes,
                batch.num_nodes,
                op="add",
            )

        if self.batchnorm:
            out_val = self.bn(out_val)

        if self.layernorm:
            out_val = self.ln(out_val)

        batch.edge_index, batch.edge_attr = out_idx, out_val
        return batch

    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"(pad_to_full_graph={self.pad_to_full_graph},"
            f"fill_value={self.fill_value},"
            f"{self.fc.__repr__()})"
        )
mask_dict): @PHYSICS_DECODER_REGISTRY.register("PowerFlow") class PhysicsDecoderPF(nn.Module): """Map network outputs to PF-consistent bus states using physics constraints.""" + def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): """ PF decoder: @@ -165,6 +167,7 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): @PHYSICS_DECODER_REGISTRY.register("StateEstimation") class PhysicsDecoderSE(nn.Module): """Map network outputs to SE targets via bus power-balance relations.""" + def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): p_shunt, q_shunt = compute_shunt_power(bus_data_pred, bus_data_orig) Vm_out = bus_data_pred[:, VM_OUT] diff --git a/gridfm_graphkit/tasks/opf_ac_dc_baseline.py b/gridfm_graphkit/tasks/opf_ac_dc_baseline.py index a13a5312..601a4e71 100644 --- a/gridfm_graphkit/tasks/opf_ac_dc_baseline.py +++ b/gridfm_graphkit/tasks/opf_ac_dc_baseline.py @@ -28,7 +28,6 @@ ) from gridfm_graphkit.tasks.pf_ac_dc_baseline import ( N_SCENARIO_PER_PARTITION, - NUM_PROCESSES, _compute_residual_stats, _compute_runtime_stats, ) @@ -65,7 +64,9 @@ def _load_test_data(data_dir: str, test_scenario_ids: list[int]): bus_df = bus_df[bus_df["scenario"].isin(test_set)].reset_index(drop=True) gen_df = gen_df[gen_df["scenario"].isin(test_set)].reset_index(drop=True) branch_df = branch_df[branch_df["scenario"].isin(test_set)].reset_index(drop=True) - runtime_df = runtime_df[runtime_df["scenario"].isin(test_set)].reset_index(drop=True) + runtime_df = runtime_df[runtime_df["scenario"].isin(test_set)].reset_index( + drop=True, + ) print( f" Loaded {len(bus_df)} bus rows, {len(gen_df)} gen rows, " @@ -85,8 +86,12 @@ def _compute_optimality_gap(gen_df: pd.DataFrame) -> dict: pg_ac = gen_df["p_mw"].to_numpy(dtype=float) pg_dc = gen_df["p_mw_dc"].to_numpy(dtype=float) g = gen_df.copy() - g["cost_ac"] = (c0 + c1 * pg_ac + c2 * pg_ac * pg_ac) * g["in_service"] # all is already in MW - g["cost_dc"] = (c0 
+ c1 * pg_dc + c2 * pg_dc * pg_dc) * g["in_service"] # all is already in MW + g["cost_ac"] = (c0 + c1 * pg_ac + c2 * pg_ac * pg_ac) * g[ + "in_service" + ] # all is already in MW + g["cost_dc"] = (c0 + c1 * pg_dc + c2 * pg_dc * pg_dc) * g[ + "in_service" + ] # all is already in MW per_scenario = g.groupby("scenario")[["cost_ac", "cost_dc"]].sum() cost_ac = per_scenario["cost_ac"].to_numpy(dtype=float) cost_dc = per_scenario["cost_dc"].to_numpy(dtype=float) @@ -113,26 +118,44 @@ def _compute_pg_violations(gen_df: pd.DataFrame) -> dict: def _compute_qg_violations_ac(bus_df: pd.DataFrame, gen_df: pd.DataFrame) -> dict: """Compute AC reactive-power limit violations for PV/REF buses.""" - # opf_task style on bus Qg; AC only + # opf_task style on bus Qg; AC only bus = bus_df.copy() qg = bus["Qg"].to_numpy(dtype=float) agg_gen = ( - gen_df.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] - .sum() - .reset_index()) + gen_df.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] + .sum() + .reset_index() + ) bus = bus.merge(agg_gen, on=["scenario", "bus"], how="left") - assert bus[bus["PV"]==1]["min_q_mvar"].isna().sum() == 0, "PV buses have no min_q_mvar" - assert bus[bus["PV"]==1]["max_q_mvar"].isna().sum() == 0, "PV buses have no max_q_mvar" - assert bus[bus["REF"]==1]["min_q_mvar"].isna().sum() == 0, "REF buses have no min_q_mvar" - assert bus[bus["REF"]==1]["max_q_mvar"].isna().sum() == 0, "REF buses have no max_q_mvar" - bus["qg_violation_amount"] = np.maximum(qg - bus["max_q_mvar"], 0.0) + np.maximum(bus["min_q_mvar"] - qg, 0.0) + assert bus[bus["PV"] == 1]["min_q_mvar"].isna().sum() == 0, ( + "PV buses have no min_q_mvar" + ) + assert bus[bus["PV"] == 1]["max_q_mvar"].isna().sum() == 0, ( + "PV buses have no max_q_mvar" + ) + assert bus[bus["REF"] == 1]["min_q_mvar"].isna().sum() == 0, ( + "REF buses have no min_q_mvar" + ) + assert bus[bus["REF"] == 1]["max_q_mvar"].isna().sum() == 0, ( + "REF buses have no max_q_mvar" + ) + bus["qg_violation_amount"] 
= np.maximum(qg - bus["max_q_mvar"], 0.0) + np.maximum( + bus["min_q_mvar"] - qg, + 0.0, + ) pv = bus[bus["PV"] == 1] ref = bus[bus["REF"] == 1] pv_ref = bus[(bus["PV"] == 1) | (bus["REF"] == 1)] return { - "AC Mean Qg violation PV buses": float(np.nanmean(pv["qg_violation_amount"].to_numpy(dtype=float))), - "AC Mean Qg violation REF buses": float(np.nanmean(ref["qg_violation_amount"].to_numpy(dtype=float))), - "AC Mean Qg violation": float(np.nanmean(pv_ref["qg_violation_amount"].to_numpy(dtype=float))), + "AC Mean Qg violation PV buses": float( + np.nanmean(pv["qg_violation_amount"].to_numpy(dtype=float)), + ), + "AC Mean Qg violation REF buses": float( + np.nanmean(ref["qg_violation_amount"].to_numpy(dtype=float)), + ), + "AC Mean Qg violation": float( + np.nanmean(pv_ref["qg_violation_amount"].to_numpy(dtype=float)), + ), } @@ -140,13 +163,21 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> """Compute AC/DC branch thermal and angle-limit violation statistics.""" rate = branch_df["rate_a"].to_numpy(dtype=float) ac_from = np.sqrt( - branch_df["pf"].to_numpy(dtype=float) ** 2 + branch_df["qf"].to_numpy(dtype=float) ** 2, + branch_df["pf"].to_numpy(dtype=float) ** 2 + + branch_df["qf"].to_numpy(dtype=float) ** 2, ) ac_to = np.sqrt( - branch_df["pt"].to_numpy(dtype=float) ** 2 + branch_df["qt"].to_numpy(dtype=float) ** 2, + branch_df["pt"].to_numpy(dtype=float) ** 2 + + branch_df["qt"].to_numpy(dtype=float) ** 2, + ) + dc_from = np.sqrt( + branch_df["pf_dc_computed"].to_numpy(dtype=float) ** 2 + + branch_df["qf_dc_computed"].to_numpy(dtype=float) ** 2, + ) # reactive part is needed here + dc_to = np.sqrt( + branch_df["pt_dc_computed"].to_numpy(dtype=float) ** 2 + + branch_df["qt_dc_computed"].to_numpy(dtype=float) ** 2, ) - dc_from = np.sqrt(branch_df["pf_dc_computed"].to_numpy(dtype=float) ** 2 + branch_df["qf_dc_computed"].to_numpy(dtype=float) ** 2) # reactive part is needed here - dc_to = 
np.sqrt(branch_df["pt_dc_computed"].to_numpy(dtype=float) ** 2 + branch_df["qt_dc_computed"].to_numpy(dtype=float) ** 2) ac_thermal_from = np.maximum(ac_from - rate, 0.0) ac_thermal_to = np.maximum(ac_to - rate, 0.0) @@ -155,7 +186,7 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> bus_angles = bus_df[["scenario", "bus", "Va", "Va_dc"]] # convert to radians - bus_angles.loc[:, "Va"] = bus_angles["Va"] * np.pi / 180.0 + bus_angles.loc[:, "Va"] = bus_angles["Va"] * np.pi / 180.0 bus_angles.loc[:, "Va_dc"] = bus_angles["Va_dc"] * np.pi / 180.0 from_angles = bus_angles.rename( columns={"bus": "from_bus", "Va": "Va_from", "Va_dc": "Va_dc_from"}, @@ -165,10 +196,12 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> ) br = branch_df.merge(from_angles, on=["scenario", "from_bus"], how="left") br = br.merge(to_angles, on=["scenario", "to_bus"], how="left") - + # AC angle ac_angle_diff = br["Va_from"] - br["Va_to"] - ac_angle_diff = (ac_angle_diff + np.pi) % (2 * np.pi) - np.pi # wrap to [-pi, pi] + ac_angle_diff = (ac_angle_diff + np.pi) % ( + 2 * np.pi + ) - np.pi # wrap to [-pi, pi] ac_angle_excess_low = np.maximum(br["ang_min"] - ac_angle_diff, 0.0) ac_angle_excess_high = np.maximum(ac_angle_diff - br["ang_max"], 0.0) mean_ac_angle_violation = np.mean(ac_angle_excess_low + ac_angle_excess_high) @@ -180,12 +213,20 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> mean_dc_angle_violation = np.mean(dc_angle_excess_low + dc_angle_excess_high) return { - "AC Mean branch thermal violation from (MVA)": float(np.nanmean(ac_thermal_from)), + "AC Mean branch thermal violation from (MVA)": float( + np.nanmean(ac_thermal_from), + ), "AC Mean branch thermal violation to (MVA)": float(np.nanmean(ac_thermal_to)), - "AC Mean branch angle difference violation (radians)": float(mean_ac_angle_violation), - "DC Mean branch thermal violation from (MVA)": float(np.nanmean(dc_thermal_from)), + 
"AC Mean branch angle difference violation (radians)": float( + mean_ac_angle_violation, + ), + "DC Mean branch thermal violation from (MVA)": float( + np.nanmean(dc_thermal_from), + ), "DC Mean branch thermal violation to (MVA)": float(np.nanmean(dc_thermal_to)), - "DC Mean branch angle difference violation (radians)": float(mean_dc_angle_violation), + "DC Mean branch angle difference violation (radians)": float( + mean_dc_angle_violation, + ), } @@ -256,7 +297,6 @@ def compute_opf_ac_dc_metrics( branch_df["pt_dc_computed"] = pt_dc branch_df["qf_dc_computed"] = qf_dc branch_df["qt_dc_computed"] = qt_dc - opf_extra = {} opf_extra.update(_compute_optimality_gap(gen_df)) diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index dbb1baab..fe517deb 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -87,14 +87,18 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): c2 = batch.x_dict["gen"][:, C2_H] target_pg = batch.y_dict["gen"].squeeze() pred_pg = output["gen"].squeeze() - gen_cost_gt = (c0 + c1 * target_pg + c2 * target_pg**2) # assumes all branches are on! - gen_cost_pred = (c0 + c1 * pred_pg + c2 * pred_pg**2) # assumes all branches are on! + gen_cost_gt = ( + c0 + c1 * target_pg + c2 * target_pg**2 + ) # assumes all branches are on! + gen_cost_pred = ( + c0 + c1 * pred_pg + c2 * pred_pg**2 + ) # assumes all branches are on! 
gen_batch = batch.batch_dict["gen"] # shape: [N_gen_total] cost_gt = scatter_add(gen_cost_gt, gen_batch, dim=0) cost_pred = scatter_add(gen_cost_pred, gen_batch, dim=0) - + optimality_gap = torch.mean(torch.abs((cost_pred - cost_gt) / cost_gt * 100)) agg_gen_on_bus = scatter_add( @@ -138,14 +142,16 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): bus_angles = output["bus"][:, VA_OUT] # in degrees from_bus = bus_edge_index[0] to_bus = bus_edge_index[1] - angle_diff = bus_angles[from_bus] - bus_angles[to_bus] # keep sign - angle_diff = (angle_diff + torch.pi) % (2 * torch.pi) - torch.pi # wrap to [-pi, pi] + angle_diff = bus_angles[from_bus] - bus_angles[to_bus] # keep sign + angle_diff = (angle_diff + torch.pi) % ( + 2 * torch.pi + ) - torch.pi # wrap to [-pi, pi] angle_excess_low = F.relu(angle_min - angle_diff) angle_excess_high = F.relu(angle_diff - angle_max) branch_angle_violation_mean = torch.mean( - angle_excess_low + angle_excess_high - ) # mean of the abs violation + angle_excess_low + angle_excess_high, + ) # mean of the abs violation P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( @@ -174,8 +180,8 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): mean_Qg_violation_PV = Qg_violation_amount[mask_PV].mean() mean_Qg_violation_REF = Qg_violation_amount[mask_REF].mean() - mask_PV_REF = mask_PV | mask_REF # PV or REF buses - mean_Qg_violation = Qg_violation_amount[mask_PV_REF].mean() # + mask_PV_REF = mask_PV | mask_REF # PV or REF buses + mean_Qg_violation = Qg_violation_amount[mask_PV_REF].mean() # if self.args.verbose: mean_res_P_PQ, max_res_P_PQ = residual_stats_by_type( @@ -242,18 +248,23 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Active Power Loss"] = final_residual_real_bus.detach() loss_dict["Reactive Power Loss"] = final_residual_imag_bus.detach() + # Slice output to the 4 target columns [VM, VA, PG, QG] so that + # models with wider bus 
output (e.g. GRIT with output_bus_dim=6) + # are compared correctly against the 4-column target. + output_bus_metrics = output["bus"][:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT]] + mse_PQ = F.mse_loss( - output["bus"][mask_PQ], + output_bus_metrics[mask_PQ], target[mask_PQ], reduction="none", ) mse_PV = F.mse_loss( - output["bus"][mask_PV], + output_bus_metrics[mask_PV], target[mask_PV], reduction="none", ) mse_REF = F.mse_loss( - output["bus"][mask_REF], + output_bus_metrics[mask_REF], target[mask_REF], reduction="none", ) @@ -270,7 +281,9 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Branch voltage angle difference violations"] = ( branch_angle_violation_mean ) - loss_dict["Mean Qg violation PV buses"] = mean_Qg_violation_PV # mean of the abs violation over the entire batch (all oines in the batch). + loss_dict["Mean Qg violation PV buses"] = ( + mean_Qg_violation_PV # mean of the abs violation over the entire batch (all oines in the batch). + ) # this is then overaged over all the batches and gives same weight to all batches despite them possibly having varying number of branches loss_dict["Mean Qg violation REF buses"] = mean_Qg_violation_REF loss_dict["Mean Qg violation"] = mean_Qg_violation @@ -372,7 +385,10 @@ def on_test_end(self): "Branch thermal violation from", " ", ) - branch_thermal_violation_to = metrics.get("Branch thermal violation to", " ") + branch_thermal_violation_to = metrics.get( + "Branch thermal violation to", + " ", + ) branch_angle_violation = metrics.get( "Branch voltage angle difference violations", " ", @@ -543,9 +559,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): local_bus_idx = torch.cat( [ torch.arange(c, device=bus_batch.device) - for c in torch.bincount(bus_batch) + for c in torch.bincount(bus_batch) ], - ) # this works because the order of the buses is preserved by the groupby in the dataset wrapper and datakit data has buses in increasing order. 
def _build_bus_target(batch, num_bus):
    """Build a 4-column bus-level target tensor [VM, VA, PG_agg, QG].

    Generator targets are summed onto their buses with ``scatter_add`` so
    that the target layout matches the bus head output columns.

    Args:
        batch: Hetero batch providing ``y_dict`` for ``"bus"`` and ``"gen"``
            and the ``("gen", "connected_to", "bus")`` edge index.
        num_bus: Number of bus nodes in the batch.

    Returns:
        Tuple ``(target, gen_to_bus_index, agg_gen_on_bus)`` where ``target``
        has shape ``[num_bus, 4]``, ``gen_to_bus_index`` is the destination
        (bus) row of the generator-to-bus edges, and ``agg_gen_on_bus`` is
        the raw ``scatter_add`` result.

    Raises:
        ImportError: If torch-scatter is not installed.
    """
    if scatter_add is None:
        # The import is optional at module level; fail with a clear message
        # instead of "'NoneType' object is not callable".
        raise ImportError(
            "torch-scatter is required for the PowerFlow task but is not "
            "installed. Install it with: pip install torch-scatter",
        )

    _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")]
    agg_gen_on_bus = scatter_add(
        batch.y_dict["gen"],
        gen_to_bus_index,
        dim=0,
        dim_size=num_bus,
    )
    bus_y = batch.y_dict["bus"]
    target = torch.stack(
        [
            bus_y[:, VM_H],
            bus_y[:, VA_H],
            # reshape instead of a bare squeeze(): squeeze() would collapse a
            # single-bus batch to a 0-dim tensor and break the stack.
            agg_gen_on_bus.reshape(num_bus),
            bus_y[:, QG_H],
        ],
        dim=1,
    )
    return target, gen_to_bus_index, agg_gen_on_bus
""" def __init__(self, args, data_normalizers): @@ -62,34 +133,22 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): num_bus = batch.x_dict["bus"].size(0) bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], + target, + batch, gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - # output_agg = torch.cat([batch.y_dict["bus"], agg_gen_on_bus], dim=1) - target = torch.stack( - [ - batch.y_dict["bus"][:, VM_H], - batch.y_dict["bus"][:, VA_H], - agg_gen_on_bus.squeeze(), - batch.y_dict["bus"][:, QG_H], - ], - dim=1, + num_bus, ) - # UN-COMMENT THIS TO CHECK PBE ON GROUND TRUTH - # output["bus"] = target - - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) + Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( P_in, Q_in, - output["bus"], + eval_bus, batch.x_dict["bus"], ) @@ -170,18 +229,23 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["PBE Mean"] = pbe_mean.detach() + # Slice output to the 4 target columns [VM, VA, PG, QG] so that + # models with wider bus output (e.g. GRIT with output_bus_dim=6) + # are compared correctly against the 4-column target. 
+ output_bus_metrics = output["bus"][:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT]] + mse_PQ = F.mse_loss( - output["bus"][mask_PQ], + output_bus_metrics[mask_PQ], target[mask_PQ], reduction="none", ) mse_PV = F.mse_loss( - output["bus"][mask_PV], + output_bus_metrics[mask_PV], target[mask_PV], reduction="none", ) mse_REF = F.mse_loss( - output["bus"][mask_REF], + output_bus_metrics[mask_REF], target[mask_REF], reduction="none", ) @@ -245,7 +309,7 @@ def on_test_end(self): # Only rank 0 proceeds with logging, CSV writing, and plotting if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: - self.test_outputs.clear() # clear the test outputs for other ranks + self.test_outputs.clear() # clear the test outputs for other ranks return if isinstance(self.logger, MLFlowLogger): @@ -356,25 +420,56 @@ def on_test_end(self): self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): - output, _ = self.shared_step(batch) # get the predicted output from the model - - self.data_normalizers[dataloader_idx].inverse_transform(batch) # normalize the batch data back to the original scale - self.data_normalizers[dataloader_idx].inverse_output(output, batch) # inverse transform the predicted output back to the original scale - - branch_flow_layer = ComputeBranchFlow() # layer to compute the branch flows - node_injection_layer = ComputeNodeInjection() # layer to compute the node injections - node_residuals_layer = ComputeNodeResiduals() # layer to compute the node residuals - - num_bus = batch.x_dict["bus"].size(0) # number of buses in the batch - bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] # from and to buses - bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] # edge attributes (admittance) of the bus connections + output, _ = self.shared_step(batch) # get the predicted output from the model + + self.data_normalizers[dataloader_idx].inverse_transform( + batch, + ) # normalize the batch data back to the 
original scale + self.data_normalizers[dataloader_idx].inverse_output( + output, + batch, + ) # inverse transform the predicted output back to the original scale + + branch_flow_layer = ComputeBranchFlow() # layer to compute the branch flows + node_injection_layer = ( + ComputeNodeInjection() + ) # layer to compute the node injections + node_residuals_layer = ( + ComputeNodeResiduals() + ) # layer to compute the node residuals + + num_bus = batch.x_dict["bus"].size(0) # number of buses in the batch + bus_edge_index = batch.edge_index_dict[ + ("bus", "connects", "bus") + ] # from and to buses + bus_edge_attr = batch.edge_attr_dict[ + ("bus", "connects", "bus") + ] # edge attributes (admittance) of the bus connections + + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], + target, + batch, + gen_to_bus_index, + num_bus, + ) - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) # compute the branch flows - P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) # compute the node injections - residual_P, residual_Q = node_residuals_layer( # compute the node residuals + Pft, Qft = branch_flow_layer( + eval_bus, + bus_edge_index, + bus_edge_attr, + ) # compute the branch flows + P_in, Q_in = node_injection_layer( + Pft, + Qft, + bus_edge_index, + num_bus, + ) # compute the node injections + residual_P, residual_Q = node_residuals_layer( # compute the node residuals P_in, Q_in, - output["bus"], + eval_bus, batch.x_dict["bus"], ) residual_P = torch.abs(residual_P) @@ -388,7 +483,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): torch.arange(c, device=bus_batch.device) for c in torch.bincount(bus_batch) ], - ) # this is based on the assumptions that the buses within a graph are ordered and indexed as 0 ... n_nodes-1. + ) # this is based on the assumptions that the buses within a graph are ordered and indexed as 0 ... n_nodes-1. 
# todo: we should remove this assert and store the bus idx in the tensors # right now we need the increasing order and we have an assert in the dataset to check it. bus_x = batch.x_dict["bus"] @@ -397,14 +492,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): mask_PV = batch.mask_dict["PV"] mask_REF = batch.mask_dict["REF"] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - return { "scenario": scenario_ids.cpu().numpy(), "bus": local_bus_idx.cpu().numpy(), diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 45975aee..f6c4501c 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -39,16 +39,11 @@ def __init__(self, args, data_normalizers): self.batch_size = int(args.training.batch_size) self.test_outputs = {i: [] for i in range(len(args.data.networks))} - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): - return self.model(x_dict, edge_index_dict, edge_attr_dict, mask_dict) + def forward(self, batch): + return self.model(batch) def shared_step(self, batch): - output = self.forward( - x_dict=batch.x_dict, - edge_index_dict=batch.edge_index_dict, - edge_attr_dict=batch.edge_attr_dict, - mask_dict=batch.mask_dict, - ) + output = self.forward(batch) loss_dict = self.loss_fn( output, diff --git a/gridfm_graphkit/tasks/se_task.py b/gridfm_graphkit/tasks/se_task.py index 36667ad2..78aa0fa5 100644 --- a/gridfm_graphkit/tasks/se_task.py +++ b/gridfm_graphkit/tasks/se_task.py @@ -27,6 +27,7 @@ @TASK_REGISTRY.register("StateEstimation") class StateEstimationTask(ReconstructionTask): """State-estimation task with evaluation plots for masked and noisy measurements.""" + def __init__(self, args, data_normalizers): super().__init__(args, data_normalizers) diff --git 
a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index ba7a4049..116aa4f0 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -42,6 +42,7 @@ def last_epoch_iters_per_sec(self) -> float | None: class SaveBestModelStateDict(Callback): """Persist the best model state_dict according to a monitored validation metric.""" + def __init__( self, monitor: str, diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index a0521fc2..acbceb78 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from gridfm_graphkit.io.registries import LOSS_REGISTRY from torch_scatter import scatter_add +from torch_geometric.utils import to_torch_coo_tensor from gridfm_graphkit.datasets.globals import ( # Bus feature indices @@ -12,15 +13,24 @@ VA_H, QD_H, PD_H, + GS, + BS, # Output feature indices VM_OUT, VA_OUT, QG_OUT, PG_OUT, + PD_OUT, + QD_OUT, # Generator feature indices PG_H, + # Edge feature indices + YFF_TT_R, + YFF_TT_I, + YFT_TF_R, + YFT_TF_I, # Qg Limits - MIN_QG_H, + MIN_QG_H, MAX_QG_H, ) @@ -85,6 +95,7 @@ def forward( @LOSS_REGISTRY.register("MaskedGenMSE") class MaskedGenMSE(torch.nn.Module): """Compute MSE on generator targets restricted to generator mask entries.""" + def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -99,9 +110,12 @@ def forward( model=None, x_dict=None, ): + gen_pred = pred_dict["gen"][:, : (PG_H + 1)] + gen_target = target_dict["gen"][:, : (PG_H + 1)] + mask = mask_dict["gen"][:, : (PG_H + 1)] loss = F.mse_loss( - pred_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], - target_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], + gen_pred[mask], + gen_target[mask], reduction=self.reduction, ) return {"loss": loss, "Masked generator MSE loss": loss.detach()} @@ -110,6 +124,7 @@ def forward( @LOSS_REGISTRY.register("MaskedBusMSE") class 
MaskedBusMSE(torch.nn.Module): """Compute MSE on selected bus targets, respecting task-specific output columns.""" + def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -145,6 +160,91 @@ def forward( return {"loss": loss, "Masked bus MSE loss": loss.detach()} +@LOSS_REGISTRY.register("MaskedReconstructionMSE") +class MaskedReconstructionMSE(BaseLoss): + """Unified masked MSE over bus-level quantities [VM, VA, PG, QG, PD, QD]. + + Mirrors the homogeneous reference MaskedMSE by combining bus predictions + and aggregated generator PG into a single prediction/target/mask tensor. + PG targets are aggregated from generator ground truth onto buses via + scatter_add; the bus-level PG mask is True when any generator at the bus + is masked, indicating that the model must reconstruct that quantity. + + Replaces the separate MaskedBusMSE + MaskedGenMSE pair. + Requires output_bus_dim >= 6 so the bus head predicts + [VM, VA, PG, QG, PD, QD]. + """ + + def __init__(self, loss_args, args): + super().__init__() + self.reduction = "mean" + + def forward( + self, + pred_dict, + target_dict, + edge_index_dict, + edge_attr_dict, + mask_dict, + model=None, + x_dict=None, + ): + pred_bus = pred_dict["bus"] + target_bus = target_dict["bus"] + num_bus = target_bus.size(0) + gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] + + # --- Build target: [VM, VA, PG_agg, QG, PD, QD] --- + target_pg_agg = scatter_add( + target_dict["gen"][:, PG_H], + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + target = torch.stack( + [ + target_bus[:, VM_H], + target_bus[:, VA_H], + target_pg_agg, + target_bus[:, QG_H], + target_bus[:, PD_H], + target_bus[:, QD_H], + ], + dim=1, + ) + + # --- Build mask: [N_bus, 6] --- + # PG bus-level mask: True if any generator at the bus has PG masked + gen_pg_masked = mask_dict["gen"][:, PG_H].float() + any_gen_masked = ( + scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + > 0 + ) + + mask 
= torch.stack( + [ + mask_dict["bus"][:, VM_H], + mask_dict["bus"][:, VA_H], + any_gen_masked, + mask_dict["bus"][:, QG_H], + mask_dict["bus"][:, PD_H], + mask_dict["bus"][:, QD_H], + ], + dim=1, + ) + + # --- Prediction: [VM, VA, PG, QG, PD, QD] from bus head --- + pred = pred_bus[:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT, PD_OUT, QD_OUT]] + + loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) + return {"loss": loss, "Masked reconstruction MSE loss": loss.detach()} + + @LOSS_REGISTRY.register("MSE") class MSELoss(BaseLoss): """Standard Mean Squared Error loss.""" @@ -242,6 +342,7 @@ def forward( @LOSS_REGISTRY.register("LayeredWeightedPhysics") class LayeredWeightedPhysicsLoss(BaseLoss): """Combine intermediate physics residuals using normalized geometric weights.""" + def __init__(self, loss_args, args) -> None: super().__init__() self.base_weight = loss_args.base_weight @@ -283,6 +384,7 @@ def forward( @LOSS_REGISTRY.register("LossPerDim") class LossPerDim(BaseLoss): """Compute MAE/MSE for one named physical dimension of bus outputs.""" + def __init__(self, loss_args, args): super(LossPerDim, self).__init__() self.reduction = "mean" @@ -340,6 +442,200 @@ def forward( } +@LOSS_REGISTRY.register("PBE") +class PBELoss(BaseLoss): + """ + Loss based on the Power Balance Equations. + + Adapted for the heterogeneous graph convention: predictions and targets + are passed as dicts (``{"bus": …, "gen": …}``). Generator active power + is aggregated onto bus nodes via the ``(gen, connected_to, bus)`` edge + index before computing the power balance. 
+ """ + + def __init__(self, loss_args, args): + super(PBELoss, self).__init__() + self.visualization = getattr(loss_args, "visualization", False) + + def forward( + self, + pred_dict, + target_dict, + edge_index_dict, + edge_attr_dict, + mask_dict, + model=None, + x_dict=None, + ): + pred_bus = pred_dict["bus"] # [N_bus, output_bus_dim] + target_bus = target_dict["bus"] # [N_bus, bus_feat_dim] + num_bus = target_bus.size(0) + + bus_edge_index = edge_index_dict[("bus", "connects", "bus")] + bus_edge_attr = edge_attr_dict[("bus", "connects", "bus")] + mask_bus = mask_dict["bus"] + + # --- Clamp known values to ground truth --- + # In power flow, certain variables are "known" (unmasked) at each + # bus type (e.g. VM at PV buses, VA at REF). The model only needs + # to predict *masked* unknowns; for everything else we substitute + # the ground truth so that errors in non-target outputs do not + # pollute the physics loss. This matches the reference's + # ``temp_pred[unmasked] = target[unmasked]`` convention. + + Vm_pred = pred_bus[:, VM_OUT] + Va_pred = pred_bus[:, VA_OUT] + Vm_target = target_bus[:, VM_H] + Va_target = target_bus[:, VA_H] + + mask_Vm = mask_bus[:, VM_H] + mask_Va = mask_bus[:, VA_H] + + V_m = torch.where(mask_Vm, Vm_pred, Vm_target) + V_a = torch.where(mask_Va, Va_pred, Va_target) + + # Complex voltage + V = V_m * torch.exp(1j * V_a) + V_conj = torch.conj(V) + + # --- Admittance matrix from bus-bus edge attrs --- + # The Y-bus matrix has off-diagonal AND diagonal entries. + # + # Off-diagonal: Y[from][to] = Yft, Y[to][from] = Ytf, stored in the + # YFT_TF columns of the edge attributes. + # + # Diagonal: Y[k][k] = sum of Yff/Ytt for all branches at bus k. + # The dataset stores Yff (forward edges) and Ytt (reverse edges) in + # the YFF_TT columns. For every edge, YFF_TT at the *source* bus + # gives that branch's diagonal contribution at the source. Summing + # over all edges with source == k yields the full branch-diagonal. 
+ # + # The reference project loads a pre-built Y-bus (y_bus_data.parquet) + # that includes self-loops for diagonal entries. Here we reconstruct + # the same structure from per-branch pi-model parameters. + + # Off-diagonal admittance values + edge_offdiag = bus_edge_attr[:, YFT_TF_R] + 1j * bus_edge_attr[:, YFT_TF_I] + + # Diagonal: aggregate Yff/Ytt to source bus of each edge + Y_diag_r = scatter_add( + bus_edge_attr[:, YFF_TT_R], + bus_edge_index[0], + dim=0, + dim_size=num_bus, + ) + Y_diag_i = scatter_add( + bus_edge_attr[:, YFF_TT_I], + bus_edge_index[0], + dim=0, + dim_size=num_bus, + ) + Y_diag = Y_diag_r + 1j * Y_diag_i + + # Add bus shunt admittance (Gs + jBs) to the diagonal + if x_dict is not None: + bus_orig = x_dict["bus"] + Y_diag = Y_diag + bus_orig[:, GS] + 1j * bus_orig[:, BS] + + # Build complete Y-bus: off-diagonal edges + self-loops for diagonal + diag_idx = torch.arange(num_bus, device=bus_edge_index.device) + full_edge_index = torch.cat( + [bus_edge_index, torch.stack([diag_idx, diag_idx])], + dim=1, + ) + full_edge_values = torch.cat([edge_offdiag, Y_diag]) + + Y_bus_sparse = to_torch_coo_tensor( + full_edge_index, + full_edge_values, + size=(num_bus, num_bus), + ) + Y_bus_conj = torch.conj(Y_bus_sparse) + + # Complex power injection: S_inj = diag(V) * conj(Y) * conj(V) + S_injection = torch.diag(V) @ Y_bus_conj @ V_conj + + # --- Net power from predictions/targets --- + # Pg: use bus head prediction where masked, ground truth where known. + # Ground truth is aggregated from generator targets onto buses. 
+ gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] + target_pg_agg = scatter_add( + target_dict["gen"][:, PG_H], + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + gen_pg_masked = mask_dict["gen"][:, PG_H].float() + any_gen_masked = ( + scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + > 0 + ) + Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg) + + # Pd, Qd, Qg: same clamp-to-ground-truth logic. The size guard + # (``pred_bus.size(1) > *_OUT``) handles models with a narrow bus + # head (e.g. output_bus_dim=4) that don't predict PD/QD/QG; in that + # case the target is always used. + if pred_bus.size(1) > PD_OUT: + Pd = torch.where( + mask_bus[:, PD_H], + pred_bus[:, PD_OUT], + target_bus[:, PD_H], + ) + else: + Pd = target_bus[:, PD_H] + if pred_bus.size(1) > QD_OUT: + Qd = torch.where( + mask_bus[:, QD_H], + pred_bus[:, QD_OUT], + target_bus[:, QD_H], + ) + else: + Qd = target_bus[:, QD_H] + if pred_bus.size(1) > QG_OUT: + Qg = torch.where( + mask_bus[:, QG_H], + pred_bus[:, QG_OUT], + target_bus[:, QG_H], + ) + else: + Qg = target_bus[:, QG_H] + + net_P = Pg_per_bus - Pd + net_Q = Qg - Qd + S_net = net_P + 1j * net_Q + + # --- Loss --- + loss = torch.mean(torch.abs(S_net - S_injection)) + + real_loss = torch.mean( + torch.abs(torch.real(S_net - S_injection)), + ) + imag_loss = torch.mean( + torch.abs(torch.imag(S_net - S_injection)), + ) + + result = { + "loss": loss, + "Power loss in p.u.": loss.detach(), + "Active Power Loss in p.u.": real_loss.detach(), + "Reactive Power Loss in p.u.": imag_loss.detach(), + } + if self.visualization: + result["Nodal Active Power Loss in p.u."] = torch.abs( + torch.real(S_net - S_injection), + ) + result["Nodal Reactive Power Loss in p.u."] = torch.abs( + torch.imag(S_net - S_injection), + ) + return result + + @LOSS_REGISTRY.register("QgViolationPenalty") class QgViolationPenaltyLoss(BaseLoss): """Standard Mean Squared Error loss.""" @@ -362,12 
+658,8 @@ def forward( Qg_max = x_dict["bus"][:, MAX_QG_H] Qg_min = x_dict["bus"][:, MIN_QG_H] - max_penalty_mask = (Qg_pred > Qg_max) - min_penalty_mask = (Qg_pred < Qg_min) - - mask_PQ = mask["PQ"] # PQ buses - mask_PV = mask["PV"] # PV buses - mask_REF = mask["REF"] # Reference buses + max_penalty_mask = Qg_pred > Qg_max + min_penalty_mask = Qg_pred < Qg_min loss = 0.0 # where there are violations, compute penalty loss @@ -376,19 +668,18 @@ def forward( Qg_over = Qg_over[max_penalty_mask].mean() Qg_under = Qg_under[min_penalty_mask].mean() - - if Qg_over!=Qg_over: # replacing nan with 0 + + if Qg_over != Qg_over: # replacing nan with 0 Qg_over = 0.0 - if Qg_under!=Qg_under: # replacing nan with 0 + if Qg_under != Qg_under: # replacing nan with 0 Qg_under = 0.0 - penalty_loss = Qg_over + Qg_under + penalty_loss = Qg_over + Qg_under loss += penalty_loss try: output = {"loss": loss, "Qg Violation Penalty loss": loss.detach()} - except: + except Exception: output = {"loss": loss, "Qg Violation Penalty loss": loss} return output - diff --git a/integrationtests/conftest.py b/integrationtests/conftest.py index 01737f5e..75e3debb 100644 --- a/integrationtests/conftest.py +++ b/integrationtests/conftest.py @@ -28,4 +28,3 @@ def calibrate_runs(request): def ci_level(request): """Confidence interval level requested via --ci (default 0.995).""" return request.config.getoption("--ci") - diff --git a/integrationtests/generate_test_data.py b/integrationtests/generate_test_data.py index fba2624b..c205d03d 100644 --- a/integrationtests/generate_test_data.py +++ b/integrationtests/generate_test_data.py @@ -29,7 +29,9 @@ def _base_config() -> dict: return config -def generate_pf_test_data(config_path: str = "integrationtests/default_pf.yaml") -> None: +def generate_pf_test_data( + config_path: str = "integrationtests/default_pf.yaml", +) -> None: """ Generate power-flow (PF) test data for case14_ieee with 10 000 scenarios and 2 topology variants. 
@@ -42,12 +44,16 @@ def generate_pf_test_data(config_path: str = "integrationtests/default_pf.yaml") print(f"PF config written to {config_path}") print(f" network.name : {config['network']['name']}") print(f" load.scenarios : {config['load']['scenarios']}") - print(f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}") + print( + f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}", + ) execute_and_live_output(f"gridfm_datakit generate {config_path}") -def generate_opf_test_data(config_path: str = "integrationtests/default_opf.yaml") -> None: +def generate_opf_test_data( + config_path: str = "integrationtests/default_opf.yaml", +) -> None: """ Generate optimal power-flow (OPF) test data for case14_ieee with 10 000 scenarios and 2 topology variants. @@ -61,12 +67,14 @@ def generate_opf_test_data(config_path: str = "integrationtests/default_opf.yaml print(f"OPF config written to {config_path}") print(f" network.name : {config['network']['name']}") print(f" load.scenarios : {config['load']['scenarios']}") - print(f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}") + print( + f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}", + ) print(f" settings.mode : {config['settings']['mode']}") execute_and_live_output(f"gridfm_datakit generate {config_path}") if __name__ == "__main__": - #generate_pf_test_data() + # generate_pf_test_data() generate_opf_test_data() diff --git a/integrationtests/test_base_set.py b/integrationtests/test_base_set.py index 2dfe4d8f..e4ec862b 100644 --- a/integrationtests/test_base_set.py +++ b/integrationtests/test_base_set.py @@ -24,13 +24,22 @@ def collect_metrics_from_log(log_base: str, metric_keys: list) -> dict: run_dirs = glob.glob(os.path.join(latest_exp_dir, "*")) assert len(run_dirs) > 0, f"No run directories found in {latest_exp_dir}" 
latest_run_dir = max(run_dirs, key=os.path.getmtime) - metrics_file = os.path.join(latest_run_dir, "artifacts", "test", "case14_ieee_metrics.csv") + metrics_file = os.path.join( + latest_run_dir, + "artifacts", + "test", + "case14_ieee_metrics.csv", + ) assert os.path.exists(metrics_file), f"Metrics file not found: {metrics_file}" df = pd.read_csv(metrics_file) return dict(zip(df["Metric"], df["Value"].astype(float))) -def print_calibration_stats(all_runs: list, metric_keys: list, confidence_interval: float = 0.995) -> None: +def print_calibration_stats( + all_runs: list, + metric_keys: list, + confidence_interval: float = 0.995, +) -> None: """ Print per-metric stats across calibration runs: - std with Bessel's correction (ddof=1) @@ -49,7 +58,9 @@ def print_calibration_stats(all_runs: list, metric_keys: list, confidence_interv ci_pct = f"{confidence_interval * 100:g}" col_w = max(len(k) for k in metric_keys) + 2 header = f" {'Metric':<{col_w}} {'Mean':>10} {'Std(ddof=1)':>12} {f'CI {ci_pct}% lo':>10} {f'CI {ci_pct}% hi':>10}" - print(f"\n===== Calibration Results (n={n}, CI={confidence_interval}, t_crit={t_crit:.4f}) =====") + print( + f"\n===== Calibration Results (n={n}, CI={confidence_interval}, t_crit={t_crit:.4f}) =====", + ) print(header) print(" " + "-" * (len(header) - 2)) for key in metric_keys: @@ -63,7 +74,7 @@ def print_calibration_stats(all_runs: list, metric_keys: list, confidence_interv me = t_crit * std / np.sqrt(len(arr)) # margin of error lo, hi = mean - me, mean + me print( - f" {key:<{col_w}} {mean:>10.4f} {std:>12.4f} {lo:>10.4f} {hi:>10.4f}" + f" {key:<{col_w}} {mean:>10.4f} {std:>12.4f} {lo:>10.4f} {hi:>10.4f}", ) print("=" * (len(header)) + "\n") @@ -88,7 +99,9 @@ def prepare_training_config(): with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) - print(f"Training config updated: epochs set to {config['training']['epochs']}, hidden_size set to {config['model']['hidden_size']}") + print( + 
f"Training config updated: epochs set to {config['training']['epochs']}, hidden_size set to {config['model']['hidden_size']}", + ) return config_path @@ -113,7 +126,9 @@ def prepare_opf_training_config(): with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) - print(f"OPF training config updated: epochs set to {config['training']['epochs']}, hidden_size set to {config['model']['hidden_size']}") + print( + f"OPF training config updated: epochs set to {config['training']['epochs']}, hidden_size set to {config['model']['hidden_size']}", + ) return config_path @@ -207,7 +222,9 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): last_error = None for attempt in range(1, MAX_RETRIES + 1): if attempt > 1: - print(f"\n--- PF Retry attempt {attempt}/{MAX_RETRIES} after metric interval failure ---") + print( + f"\n--- PF Retry attempt {attempt}/{MAX_RETRIES} after metric interval failure ---", + ) execute_and_live_output( f"gridfm_graphkit train " f"--config {training_config_path} " @@ -225,7 +242,9 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): assert 0.2042 <= pbe_mean_value <= 0.6397, ( f"PBE Mean value {pbe_mean_value} is outside 95% CI [0.2042, 0.6397]" ) - print(f"PBE Mean value {pbe_mean_value} is within 95% CI [0.2042, 0.6397] (attempt {attempt})") + print( + f"PBE Mean value {pbe_mean_value} is within 95% CI [0.2042, 0.6397] (attempt {attempt})", + ) last_error = None break except AssertionError as e: @@ -278,7 +297,9 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): opf_data_dir = "data_out_opf" if not os.path.exists(opf_data_dir) or not os.listdir(opf_data_dir): - print("OPF data directory not found or empty, downloading pre-generated data...") + print( + "OPF data directory not found or empty, downloading pre-generated data...", + ) gdrive_file_id = "1p5f5mRvmBQh8lZpIyWWbTbU42aHAIsdT" # pragma: allowlist secret zip_filename = 
"case14_ieee.10000_scenarios_2_variants_opf.zip" @@ -335,7 +356,9 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): last_error = None for attempt in range(1, MAX_RETRIES + 1): if attempt > 1: - print(f"\n--- OPF Retry attempt {attempt}/{MAX_RETRIES} after metric interval failure ---") + print( + f"\n--- OPF Retry attempt {attempt}/{MAX_RETRIES} after metric interval failure ---", + ) execute_and_live_output( f"gridfm_graphkit train " f"--config {training_config_path} " @@ -350,12 +373,16 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): try: for metric_name, (lo, hi) in checks.items(): - assert metric_name in metrics, f"Metric '{metric_name}' not found in CSV" + assert metric_name in metrics, ( + f"Metric '{metric_name}' not found in CSV" + ) value = metrics[metric_name] assert lo <= value <= hi, ( f"Metric '{metric_name}' value {value} is outside 99.5% CI [{lo}, {hi}]" ) - print(f"{metric_name}: {value} is within 99.5% CI [{lo}, {hi}] (attempt {attempt})") + print( + f"{metric_name}: {value} is within 99.5% CI [{lo}, {hi}] (attempt {attempt})", + ) last_error = None break except AssertionError as e: diff --git a/pyproject.toml b/pyproject.toml index b4e8307b..9e7bb04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "nbformat>=5.10.4", "networkx>=3.4.2", "numpy>=2.2.6", + "opt-einsum>=3.3.0", "pandas>=2.3.0", "plotly>=6.1.2", "pyyaml>=6.0.2", diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py new file mode 100644 index 00000000..01b1e6cd --- /dev/null +++ b/scripts/benchmark_model_inference.py @@ -0,0 +1,603 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +A unified script for benchmarking and limited custom profiling. Benchmarking columns in the output csv are [batch_size,avg_time_per_sample_ms]. 
+ +Supports two model types via --model flag: + - "hetero" (default): GNS_heterogeneous with HeteroData (bus + gen nodes) + - "grit": GritHeteroAdapter with HeteroData (bus + gen nodes, optional PE attrs) + +Example usage — Heterogeneous GNS (edge count is 2*E (branch count)): + +###################################### + +CONF_PATH=../examples/config +OUT_DIR=../scripts +mkdir $OUT_DIR + +python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case30.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true +python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case118.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true + +###################################### + +Example usage — GRIT (HeteroData with PE, --num_gens required): + +###################################### + +python benchmark_model_inference.py --model grit --config $CONF_PATH/GRIT_PF_datakit_case14.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/grit_case30.csv || true +python benchmark_model_inference.py --model grit --config $CONF_PATH/GRIT_PF_datakit_case14.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/grit_case118.csv || true + +###################################### + +Author(s): Mangaliso M. - mngomezulum@ibm.com + Matteo M. 
# Inference benchmark: builds one synthetic power-grid graph, replicates it
# into batches of increasing size, times the model forward pass and writes one
# CSV row per batch size.
#
# Typical invocation (this is what scripts/run_benchmark.sh runs for every
# "<nodes> <edges>" pair it lists):
#
#   python benchmark_model_inference.py \
#       --model grit \
#       --config ../examples/config/GRIT_PF_datakit_case14.yaml \
#       --output_csv benchmark_results/grit01_<nodes>nodes_<edges>edges.csv \
#       --num_nodes <nodes> \
#       --num_edges <edges>

import argparse
import csv
import logging
import os
import platform
import time
from datetime import datetime

import torch
import yaml
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader

from gridfm_graphkit.io.param_handler import NestedNamespace, load_model

# Optional: tqdm (imported but not required for core flow)
try:
    from tqdm import tqdm  # noqa: F401
except ImportError:
    pass

# Compilation (kept from original)
import torch._dynamo as dynamo

dynamo.config.suppress_errors = False

# ----------------------------
# Argument Parsing
# ----------------------------
parser = argparse.ArgumentParser(
    description="Benchmark GNN Model inference with profiling CSV",
)
parser.add_argument(
    "--model",
    type=str,
    choices=["hetero", "grit"],
    default="hetero",
    help="Model type: 'hetero' for GNS_heterogeneous, 'grit' for GritTransformer",
)
parser.add_argument(
    "--config",
    type=str,
    required=True,
    help="Path to config YAML for model",
)
parser.add_argument("--num_nodes", type=int, required=True)
parser.add_argument(
    "--num_gens",
    type=int,
    default=0,
    help="Number of generator nodes (required for hetero, ignored for grit)",
)
parser.add_argument("--num_edges", type=int, required=True)
parser.add_argument("--output_csv", type=str, required=True)
parser.add_argument("--iterations", type=int, default=20)
parser.add_argument("--num_workers", type=int, default=0, help="DataLoader num_workers")
parser.add_argument(
    "--pin_memory",
    action="store_true",
    help="Enable pin_memory in DataLoader when CUDA is available",
)
args = parser.parse_args()

# Reject sizes that would only fail later with an obscure error:
# torch.randint(0, 0, ...) raises for an empty node range, and the per-sample
# averages divide by batch_size * iterations.
if args.num_nodes < 1:
    parser.error("--num_nodes must be >= 1")
if args.num_edges < 0:
    parser.error("--num_edges must be >= 0")
if args.iterations < 1:
    parser.error("--iterations must be >= 1")

# --- Custom logging (ensure directory exists)

os.makedirs("logs", exist_ok=True)
logger = logging.getLogger("ibm_benchmark_logger")
logger.setLevel(logging.DEBUG)
logger.propagate = False
file_handler = logging.FileHandler(
    "logs/ibm_bench_logs.log",
    mode="a",
)  # 'a' for append, 'w' to overwrite
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
# Attach the handler only once so repeated setup does not duplicate log lines.
if not logger.handlers:
    logger.addHandler(file_handler)

# ----------------------------
# Load Model
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(args.config, "r", encoding="utf-8") as f:
    base_config = yaml.safe_load(f)

config_args = NestedNamespace(**base_config)
model = load_model(config_args).to(device).eval()
tot_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"**Total model trainable params: {tot_params}")

# ----------------------------
# Parameters
# ----------------------------
MODEL_TYPE = args.model
N_BUS = args.num_nodes
E = args.num_edges

# Default num_gens when not provided (shell script omits --num_gens)
N_GEN = args.num_gens if args.num_gens > 0 else max(1, N_BUS // 5)

EDGE_FEATS = getattr(config_args.model, "edge_dim", 10)

# Both model types use HeteroData with bus + gen nodes.
# Fall back to input_dim / defaults for configs that lack the hetero keys.
BUS_FEATS = getattr(
    config_args.model,
    "input_bus_dim",
    getattr(config_args.model, "input_dim", 15),
)
GEN_FEATS = getattr(config_args.model, "input_gen_dim", 6)

if MODEL_TYPE == "grit":
    # Positional encoding config (only GRIT uses these)
    # Read enablement and dimensions from data config (canonical source).
    RRWP_ENABLED = (
        getattr(config_args.data.posenc_RRWP, "enable", False)
        if hasattr(config_args.data, "posenc_RRWP")
        else False
    )
    RRWP_KSTEPS = (
        getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0
    )
    RWSE_ENABLED = hasattr(config_args.data, "posenc_RWSE") and getattr(
        config_args.data.posenc_RWSE,
        "enable",
        False,
    )
    RWSE_TIMES = (
        getattr(config_args.data.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0
    )
else:
    RRWP_ENABLED = False
    RRWP_KSTEPS = 0
    RWSE_ENABLED = False
    RWSE_TIMES = 0

# Keep original batch sizes list
batch_sizes = [1, 2, 4, 8, 16, 32]
iterations = args.iterations


# ----------------------------
# Helpers
# ----------------------------
def now_ms() -> float:
    """Return the current ``time.perf_counter()`` reading in milliseconds."""
    return time.perf_counter() * 1000.0


def maybe_cuda_sync():
    """Call ``torch.cuda.synchronize()`` when CUDA is available; otherwise do nothing.

    Used around timed regions so wall-clock readings are taken after queued
    GPU work has finished.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def get_env_info():
    """Collect a best-effort description of the runtime environment.

    Returns:
        dict: keys ``device_type``, ``device_name``, ``gpu_names``,
        ``cpu_name``, ``torch_version``, ``cuda_version_in_torch``,
        ``cudnn_version`` and ``python_version``. The CUDA/cuDNN versions are
        ``None`` when CUDA is not available.
    """
    # CPU name: platform.processor() first, then the "model name" line of
    # /proc/cpuinfo, then the machine architecture.
    cpu_name = None
    try:
        cpu_name = platform.processor() or None
        if not cpu_name and os.path.exists("/proc/cpuinfo"):
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "model name" in line:
                        cpu_name = line.strip().split(":", 1)[1].strip()
                        break
        if not cpu_name:
            cpu_name = platform.uname().machine
    except Exception:
        # Environment probing must never abort the benchmark.
        cpu_name = "unknown"

    # GPU names and device info
    if torch.cuda.is_available():
        try:
            gpu_names_list = [
                torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
            ]
            gpu_names = "; ".join(gpu_names_list)
            # Device 0's name comes from the same guarded query, so an
            # unreadable name degrades to a placeholder instead of raising.
            device_name = gpu_names_list[0] if gpu_names_list else "cuda"
        except Exception:
            gpu_names = "cuda_available_but_name_unreadable"
            device_name = "cuda"
        device_type = "cuda"
        cuda_version_in_torch = torch.version.cuda
        cudnn_version = (
            torch.backends.cudnn.version()
            if torch.backends.cudnn.is_available()
            else None
        )
    else:
        # Apple Metal backend?
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            device_type = "mps"
            device_name = "Apple MPS"
            gpu_names = "mps"
        else:
            device_type = "cpu"
            device_name = "cpu"
            gpu_names = "none"
        cuda_version_in_torch = None
        cudnn_version = None

    info = {
        "device_type": device_type,
        "device_name": device_name,
        "gpu_names": gpu_names,
        "cpu_name": cpu_name,
        "torch_version": torch.__version__,
        "cuda_version_in_torch": cuda_version_in_torch,
        "cudnn_version": cudnn_version,
        "python_version": platform.python_version(),
    }
    return info


# ----------------------------
# Generate Synthetic Hetero Graph
# ----------------------------
def generate_hetero_graph():
    """
    Generates a dummy heterogeneous power network graph for benchmarking.

    Sizes come from the module-level settings: N_BUS bus nodes, N_GEN
    generator nodes, E bus-bus edges, and BUS_FEATS / GEN_FEATS / EDGE_FEATS
    feature widths. All values are random; each generator is attached to one
    randomly chosen bus.

    Returns:
        data (HeteroData): single self-contained heterogeneous graph with:
            - data["bus"].x, data["gen"].x (and same-shaped dummy targets .y)
            - edge_index & edge_attr for all relations
            - mask_dict inside data.mask_dict
    """
    data = HeteroData()

    # Node features
    data["bus"].x = torch.randn(N_BUS, BUS_FEATS)
    data["gen"].x = torch.randn(N_GEN, GEN_FEATS)

    # Dummy targets
    data["bus"].y = torch.randn(N_BUS, BUS_FEATS)
    data["gen"].y = torch.randn(N_GEN, GEN_FEATS)

    # Edges: Bus–Bus (random endpoints; duplicates and self-loops are possible)
    src = torch.randint(0, N_BUS, (E,))
    dst = torch.randint(0, N_BUS, (E,))
    data["bus", "connects", "bus"].edge_index = torch.stack([src, dst], dim=0)
    data["bus", "connects", "bus"].edge_attr = torch.randn(E, EDGE_FEATS)

    # Edges: Gen–Bus & Bus–Gen
    gen_to_bus = torch.randint(0, N_BUS, (N_GEN,))
    gen_ids = torch.arange(N_GEN)

    # Gen → Bus
    data["gen", "connected_to", "bus"].edge_index = torch.stack(
        [gen_ids, gen_to_bus],
        dim=0,
    )

    # Bus → Gen (same pairs, reversed direction)
    data["bus", "connected_to", "gen"].edge_index = torch.stack(
        [gen_to_bus, gen_ids],
        dim=0,
    )

    # No edge features for these
    data["gen", "connected_to", "bus"].edge_attr = None
    data["bus", "connected_to", "gen"].edge_attr = None

    # Dummy masks (all True)
    mask_bus = torch.ones_like(data["bus"].x, dtype=torch.bool)
    mask_gen = torch.ones_like(data["gen"].x, dtype=torch.bool)
    mask_branch = torch.ones_like(
        data["bus", "connects", "bus"].edge_attr,
        dtype=torch.bool,
    )

    # Random bus type per node (0, 1 or 2) drives the PQ / PV / REF masks.
    bus_types = torch.randint(0, 3, (N_BUS,))
    mask_PQ = bus_types == 0
    mask_PV = bus_types == 1
    mask_REF = bus_types == 2

    data.mask_dict = {
        "bus": mask_bus,
        "gen": mask_gen,
        "PQ": mask_PQ,
        "PV": mask_PV,
        "REF": mask_REF,
        "branch": mask_branch,
    }
    return data


# ----------------------------
# Generate Synthetic Hetero Graph with PE attributes (GRIT)
# ----------------------------
def generate_grit_graph():
    """
    Generates a dummy heterogeneous graph for GRIT benchmarking.

    GritHeteroAdapter expects a HeteroData batch, so we generate the same
    structure as generate_hetero_graph() but also attach PE attributes on
    the bus node store when RWSE / RRWP are enabled.

    Returns:
        data (HeteroData): heterogeneous graph with bus & gen nodes,
            plus optional PE attributes on data["bus"].
    """
    data = generate_hetero_graph()

    # RWSE positional encoding on bus nodes (non-negative random values)
    if RWSE_ENABLED:
        data["bus"].pestat_RWSE = torch.randn(N_BUS, RWSE_TIMES).abs()

    # RRWP positional / structural encoding on bus nodes
    if RRWP_ENABLED:
        data["bus"].rrwp = torch.randn(N_BUS, RRWP_KSTEPS)
        # Sparse RRWP for edges: include existing bus-bus edges + self-loops
        bb_ei = data["bus", "connects", "bus"].edge_index
        self_loops = torch.arange(N_BUS).unsqueeze(0).repeat(2, 1)
        rrwp_idx = torch.cat([bb_ei, self_loops], dim=1)
        rrwp_nnz = rrwp_idx.size(1)
        data["bus"].rrwp_index = rrwp_idx
        data["bus"].rrwp_val = torch.randn(rrwp_nnz, RRWP_KSTEPS)

    return data


# ----------------------------
# Benchmark Function
# ----------------------------
def benchmark():
    """Run the inference benchmark and write one CSV row per batch size.

    For every entry of ``batch_sizes`` the synthetic graph is cloned
    ``batch_size`` times, collated through a DataLoader, run 5 times as
    warmup and then ``iterations`` times under timing. On CUDA the timed loop
    is measured with CUDA events and peak memory is recorded; otherwise
    wall-clock time is used and the GPU-only columns are left empty.

    Side effects:
        Creates the parent directory of ``args.output_csv`` if needed and
        overwrites that file.

    Returns:
        tuple[list[int], list[float]]: the batch sizes that were run and the
        matching average time per sample in milliseconds.
    """
    use_cuda = torch.cuda.is_available()

    # Environment/context info (constant per run)
    env = get_env_info()
    timestamp = datetime.now().isoformat(timespec="seconds")

    # Measure synthetic graph creation
    t0 = now_ms()
    if MODEL_TYPE == "hetero":
        data = generate_hetero_graph()
    else:
        data = generate_grit_graph()
    t1 = now_ms()
    data_gen_time_ms = t1 - t0

    # Keep a host-resident copy. Memory pinning and DataLoader worker
    # processes require CPU tensors, so the DataLoader must not be fed
    # samples that already live on the GPU in those modes.
    host_data = data.clone()

    # Move the base graph to device (preserve original behavior)
    maybe_cuda_sync()
    t2 = now_ms()
    data = data.to(device)
    maybe_cuda_sync()
    t3 = now_ms()
    graph_to_device_time_ms = t3 - t2

    batch_sizes_used = []
    times = []

    header = [
        # Keep original first two columns
        "batch_size",
        "avg_time_per_sample_ms",
        # Execution config
        "num_iters",
        "total_samples",
        # Data/IO timing
        "data_gen_time_ms",
        "graph_to_device_time_ms",
        "clone_list_time_ms",
        "dataloader_create_time_ms",
        "dataloader_first_iter_time_ms",
        "batch_to_device_time_ms",
        # Model timing
        "warmup_time_ms",
        "iter_total_wall_time_ms",
        "iter_gpu_time_ms",
        "gpu_idle_time_ms",
        "gpu_busy_ratio",
        "samples_per_sec_wall",
        "samples_per_sec_gpu",
        "timing_source",  # "cuda_event" or "wall_clock"
        # Memory
        "max_cuda_mem_alloc_bytes",
        "max_cuda_mem_reserved_bytes",
        # Graph & model context
        "n_bus",
        "n_gen",
        "n_edges",
        "bus_feats",
        "gen_feats",
        "edge_feats",
        # Runtime context
        "device_type",
        "device_name",
        "torch_version",
        "cuda_version_in_torch",
        "cudnn_version",
        "python_version",
        "cpu_name",  # NEW
        "gpu_names",  # NEW
        "timestamp_iso",
        "num_workers",
        "pin_memory",
    ]

    # Allow running without the shell driver, which normally creates the
    # output directory beforehand.
    out_dir = os.path.dirname(args.output_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    pin_mem = args.pin_memory and use_cuda
    persistent = args.num_workers > 0
    # Default mode keeps the original flow (samples already on device); the
    # host copy is used only when pinning or worker processes are requested.
    source_graph = host_data if (pin_mem or args.num_workers > 0) else data
    warmup_iters = 5

    with open(args.output_csv, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(header)

        for batch_size in batch_sizes:
            # Build list of graphs
            maybe_cuda_sync()
            t_clone_start = now_ms()
            data_list = [source_graph.clone() for _ in range(batch_size)]
            maybe_cuda_sync()
            t_clone_end = now_ms()
            clone_list_time_ms = t_clone_end - t_clone_start

            # Create DataLoader
            t_dl_create_start = now_ms()
            loader = DataLoader(
                data_list,
                batch_size=batch_size,
                num_workers=args.num_workers,
                pin_memory=pin_mem,
                persistent_workers=persistent,
            )
            t_dl_create_end = now_ms()
            dataloader_create_time_ms = t_dl_create_end - t_dl_create_start

            # Fetch first batch (collate)
            t_iter_start = now_ms()
            batch = next(iter(loader))
            t_iter_end = now_ms()
            dataloader_first_iter_time_ms = t_iter_end - t_iter_start

            # Ensure batch on device (likely ~0 if items already on device)
            maybe_cuda_sync()
            t_b2d_start = now_ms()
            batch = (
                batch.to(device, non_blocking=True) if use_cuda else batch.to(device)
            )
            maybe_cuda_sync()
            t_b2d_end = now_ms()
            batch_to_device_time_ms = t_b2d_end - t_b2d_start

            test_model = model

            # Warmup (excluded from main timing)
            maybe_cuda_sync()
            t_warmup_start = now_ms()
            with torch.no_grad():
                for _ in range(warmup_iters):
                    # A fresh clone per call: the same batch is reused for
                    # every forward pass, so the model never sees a batch a
                    # previous call may have modified.
                    _ = test_model(batch.clone())
            maybe_cuda_sync()
            t_warmup_end = now_ms()
            warmup_time_ms = t_warmup_end - t_warmup_start

            num_iters = iterations
            total_samples = batch_size * num_iters

            # Reset CUDA memory stats and set up timing
            if use_cuda:
                torch.cuda.reset_peak_memory_stats(device)
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

            # Iteration timing
            maybe_cuda_sync()
            wall_start = now_ms()
            with torch.no_grad():
                if use_cuda:
                    start_event.record()
                for _ in range(num_iters):
                    _ = test_model(batch.clone())
                if use_cuda:
                    end_event.record()
            # Synchronize before reading either clock: elapsed_time() needs
            # both events to have completed.
            maybe_cuda_sync()
            wall_end = now_ms()

            iter_total_wall_time_ms = wall_end - wall_start

            if use_cuda:
                iter_gpu_time_ms = float(start_event.elapsed_time(end_event))  # ms
                timing_source = "cuda_event"
                avg_time_per_sample_ms = iter_gpu_time_ms / total_samples
                gpu_idle_time_ms = max(iter_total_wall_time_ms - iter_gpu_time_ms, 0.0)
                gpu_busy_ratio = (
                    (iter_gpu_time_ms / iter_total_wall_time_ms)
                    if iter_total_wall_time_ms > 0
                    else None
                )
                max_cuda_mem_alloc_bytes = int(torch.cuda.max_memory_allocated(device))
                max_cuda_mem_reserved_bytes = int(
                    torch.cuda.max_memory_reserved(device),
                )
                samples_per_sec_gpu = (
                    (total_samples / (iter_gpu_time_ms / 1000.0))
                    if iter_gpu_time_ms > 0
                    else None
                )
            else:
                iter_gpu_time_ms = None
                timing_source = "wall_clock"
                avg_time_per_sample_ms = iter_total_wall_time_ms / total_samples
                gpu_idle_time_ms = None
                gpu_busy_ratio = None
                max_cuda_mem_alloc_bytes = None
                max_cuda_mem_reserved_bytes = None
                samples_per_sec_gpu = None

            samples_per_sec_wall = (
                (total_samples / (iter_total_wall_time_ms / 1000.0))
                if iter_total_wall_time_ms > 0
                else None
            )

            # Prepare row (order must match `header`)
            row = [
                batch_size,
                avg_time_per_sample_ms,
                num_iters,
                total_samples,
                data_gen_time_ms,
                graph_to_device_time_ms,
                clone_list_time_ms,
                dataloader_create_time_ms,
                dataloader_first_iter_time_ms,
                batch_to_device_time_ms,
                warmup_time_ms,
                iter_total_wall_time_ms,
                iter_gpu_time_ms,
                gpu_idle_time_ms,
                gpu_busy_ratio,
                samples_per_sec_wall,
                samples_per_sec_gpu,
                timing_source,
                max_cuda_mem_alloc_bytes,
                max_cuda_mem_reserved_bytes,
                N_BUS,
                N_GEN,
                E,
                BUS_FEATS,
                GEN_FEATS,
                EDGE_FEATS,
                env["device_type"],
                env["device_name"],
                env["torch_version"],
                env["cuda_version_in_torch"],
                env["cudnn_version"],
                env["python_version"],
                env["cpu_name"],
                env["gpu_names"],
                timestamp,
                args.num_workers,
                bool(pin_mem),
            ]

            writer.writerow(row)
            # Flush per row so finished batch sizes survive a later failure
            # (e.g. running out of memory at a larger batch size).
            csvfile.flush()
            batch_sizes_used.append(batch_size)
            times.append(avg_time_per_sample_ms)

    return batch_sizes_used, times


if __name__ == "__main__":
    print(f"Starting benchmark for {os.path.basename(args.output_csv)} ..")
    benchmark()
    print(f"Finished benchmarking for {os.path.basename(args.output_csv)}\n ...")