-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.py
More file actions
150 lines (131 loc) · 4.64 KB
/
Copy pathmodel.py
File metadata and controls
150 lines (131 loc) · 4.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from typing import Any, Dict, List, Optional, Tuple

import torch
from relbench.modeling.nn import HeteroEncoder, HeteroTemporalEncoder
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from exp_model import HeteroGraphSAGE
class FROG(torch.nn.Module):
    """Heterogeneous graph model: per-table encoders, a GNN and an MLP head.

    The forward passes build one feature tensor per node type by summing

    * the output of a ``HeteroEncoder`` applied to ``batch.tf_dict``,
    * a relative-time encoding from a ``HeteroTemporalEncoder`` (only for the
      node types that encoder returns), and
    * an optional learnable per-node ("shallow") embedding,

    then run ``HeteroGraphSAGE`` over ``batch.edge_index_dict`` and apply a
    single-layer ``MLP`` head to the selected rows.

    Args:
        data: Graph used to size the sub-modules. Read for its node types,
            edge types, per-node-type ``tf.col_names_dict``, the presence of a
            ``"time"`` attribute per node type, and ``num_nodes_dict``.
        two_hop_relations: Forwarded unchanged to ``HeteroGraphSAGE``.
        col_stats_dict: Column statistics, passed to ``HeteroEncoder`` as
            ``node_to_col_stats``.
        num_layers: Forwarded to ``HeteroGraphSAGE`` as ``num_layers``.
        channels: Width shared by the encoders, the GNN, the embeddings and
            the input of the head.
        out_channels: Output width of the head.
        aggr: Forwarded to ``HeteroGraphSAGE`` as ``aggr``.
        norm: Normalisation passed to the ``MLP`` head.
        shallow_list: Node types that receive a learnable embedding with one
            row per node (``data.num_nodes_dict[node_type]`` rows). ``None``
            (the default) means no node type does.
        id_awareness: If ``True``, create the one-row embedding that
            :meth:`forward_dst_readout` adds to the seed rows.
        MAE: Forwarded unchanged to ``HeteroGraphSAGE``.
            NOTE(review): its meaning is defined in ``exp_model``.
        dropout: Forwarded unchanged to ``HeteroGraphSAGE``.
        optim: Forwarded unchanged to ``HeteroGraphSAGE``.
        fix: Forwarded unchanged to ``HeteroGraphSAGE``.
        fix_attn: Forwarded unchanged to ``HeteroGraphSAGE``.
    """

    def __init__(
        self,
        data: HeteroData,
        two_hop_relations: set,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: Optional[List[NodeType]] = None,
        # ID awareness
        id_awareness: bool = False,
        MAE: float = 0.99,
        dropout: float = 0.0,
        optim: str = "both",
        fix: bool = False,
        fix_attn: float = 0.5,
    ):
        super().__init__()

        # A ``None`` default avoids sharing one list object across all
        # instances (the mutable-default-argument pitfall).
        if shallow_list is None:
            shallow_list = []

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        # Only node types that carry a "time" attribute get a temporal encoding.
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type
                for node_type in data.node_types
                if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            two_hop_relations=two_hop_relations,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
            MAE=MAE,
            dropout=dropout,
            optim=optim,
            fix=fix,
            fix_attn=fix_attn,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every sub-module and all embedding tables."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def _encode_and_propagate(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        seed_time: Tensor,
        add_id_awareness: bool = False,
    ) -> Dict[NodeType, Tensor]:
        """Build the per-node-type inputs and run the GNN over them.

        Shared by :meth:`forward` and :meth:`forward_dst_readout` so the two
        cannot drift apart.

        Args:
            batch: Mini-batch; read for ``tf_dict``, ``time_dict``,
                ``batch_dict``, ``edge_index_dict`` and per-node-type ``n_id``.
            entity_table: Node type holding the seed nodes.
            seed_time: ``batch[entity_table].seed_time``; its first dimension
                gives the number of seed rows.
            add_id_awareness: If ``True``, add ``id_awareness_emb.weight`` to
                the first ``seed_time.size(0)`` rows of ``entity_table``. The
                caller must ensure ``id_awareness_emb`` is not ``None``.

        Returns:
            The dictionary returned by the GNN.
        """
        x_dict = self.encoder(batch.tf_dict)

        if add_id_awareness:
            # Applied in place and *before* the temporal/shallow terms, which
            # keeps the summation order of the original implementation.
            x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        return self.gnn(x_dict, batch.edge_index_dict)

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tuple[Tensor, Dict[NodeType, Tensor]]:
        """Predict for the seed nodes of ``entity_table``.

        Args:
            batch: Sampled mini-batch.
            entity_table: Node type whose seed nodes are predicted.

        Returns:
            A pair ``(out, x_dict)``: ``out`` is the head applied to the first
            ``seed_time.size(0)`` rows of ``x_dict[entity_table]``, and
            ``x_dict`` is the full GNN output for every node type.
            NOTE(review): this assumes the seed nodes occupy the leading rows
            of ``entity_table`` in the batch -- confirm against the loader.
        """
        seed_time = batch[entity_table].seed_time
        x_dict = self._encode_and_propagate(batch, entity_table, seed_time)
        return self.head(x_dict[entity_table][: seed_time.size(0)]), x_dict

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tuple[Tensor, Dict[NodeType, Tensor]]:
        """Predict for every node of ``dst_table`` with ID-aware seed rows.

        Args:
            batch: Sampled mini-batch.
            entity_table: Node type holding the seed nodes; its seed rows
                receive the ID-awareness embedding.
            dst_table: Node type whose rows are all passed through the head.

        Returns:
            A pair ``(out, x_dict)``: ``out`` is the head applied to all rows
            of ``x_dict[dst_table]``, and ``x_dict`` is the full GNN output.

        Raises:
            RuntimeError: If the model was built with ``id_awareness=False``.
        """
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self._encode_and_propagate(
            batch, entity_table, seed_time, add_id_awareness=True
        )
        return self.head(x_dict[dst_table]), x_dict