-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgraph.py
More file actions
229 lines (186 loc) · 8.14 KB
/
Copy pathgraph.py
File metadata and controls
229 lines (186 loc) · 8.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import os
from typing import Any, Dict, NamedTuple, Optional, Tuple
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index
from relbench.base import Database, EntityTask, RecommendationTask, Table, TaskType
from relbench.modeling.utils import remove_pkey_fkey, to_unix_time
def make_pkey_fkey_graph(
    db: Database,
    col_to_stype_dict: Dict[str, Dict[str, stype]],
    text_embedder_cfg: Optional[TextEmbedderConfig] = None,
    cache_dir: Optional[str] = None,
) -> Tuple[HeteroData, Dict[str, Dict[str, Dict[StatType, Any]]]]:
    r"""Turn a :class:`Database` into a heterogeneous graph whose edges follow
    the primary-foreign key links, and gather per-table column statistics.

    Every table becomes one node type carrying a materialized tensor frame
    (``.tf``) and, when the table has a time column, a ``.time`` tensor.
    Every foreign key column yields two edge types: ``f2p_<fkey>`` from the
    referencing table to the referenced table, and ``rev_f2p_<fkey>`` in the
    opposite direction.

    Args:
        db: A database object containing a set of tables.
        col_to_stype_dict: Column to stype mapping for each table. Each
            per-table mapping is handed to :func:`remove_pkey_fkey`, whose
            return value is ignored here, so it presumably edits the
            mapping in place (confirm against that helper).
        text_embedder_cfg: Text embedder config forwarded to the dataset.
        cache_dir: Directory for the materialized tensor frames. When
            given, it is created if missing and ``<table_name>.pt`` inside
            it is passed to ``materialize``. When ``None``, no path is
            passed and nothing is cached.

    Returns:
        A pair ``(data, col_stats_dict)``: the :class:`HeteroData` graph
        and a dict mapping each table name to the column stats of its
        materialized dataset.
    """
    graph = HeteroData()
    stats_by_table = {}

    if cache_dir is not None:
        os.makedirs(cache_dir, exist_ok=True)

    for table_name, table in db.table_dict.items():
        print('processing table name:', table_name)
        frame = table.df

        # Primary keys must be exactly 0..len-1 so that they double as node
        # indices when edges are built below.
        if table.pkey_col is not None:
            expected_ids = np.arange(len(frame))
            assert (frame[table.pkey_col].values == expected_ids).all()

        # Key columns are not input features, so strip them from the stypes.
        stypes = col_to_stype_dict[table_name]
        remove_pkey_fkey(stypes, table)

        if not stypes:
            # No feature column is left: fall back to a constant numerical
            # feature, but keep the fkey columns since the edge construction
            # below still reads them from the frame.
            stypes = {"__const__": stype.numerical}
            fkey_columns = {
                name: frame[name] for name in table.fkey_col_to_pkey_table
            }
            frame = pd.DataFrame(
                {"__const__": np.ones(len(table.df)), **fkey_columns}
            )

        if cache_dir is None:
            cache_path = None
        else:
            cache_path = os.path.join(cache_dir, f"{table_name}.pt")

        dataset = Dataset(
            df=frame,
            col_to_stype=stypes,
            col_to_text_embedder_cfg=text_embedder_cfg,
        )
        materialized = dataset.materialize(path=cache_path)

        graph[table_name].tf = materialized.tensor_frame
        stats_by_table[table_name] = materialized.col_stats

        # The time attribute is always read from the original table frame.
        if table.time_col is not None:
            unix_time = to_unix_time(table.df[table.time_col])
            graph[table_name].time = torch.from_numpy(unix_time)

        for fkey_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
            fkey_values = frame[fkey_name]

            # Rows with a missing foreign key contribute no edge.
            present = ~fkey_values.isna()
            row_ids = torch.arange(len(fkey_values))
            src = row_ids[torch.from_numpy(present.values)]
            dst = torch.from_numpy(fkey_values[present].astype(int).values)

            # Every remaining key must point at an existing row.
            assert (dst < len(db.table_dict[pkey_table_name])).all()

            # fkey -> pkey edges.
            forward_type = (table_name, f"f2p_{fkey_name}", pkey_table_name)
            graph[forward_type].edge_index = sort_edge_index(
                torch.stack([src, dst], dim=0)
            )

            # pkey -> fkey edges. The "rev_" prefix is what lets the PyG
            # loader recognize them as reverse edges.
            backward_type = (pkey_table_name, f"rev_f2p_{fkey_name}", table_name)
            graph[backward_type].edge_index = sort_edge_index(
                torch.stack([dst, src], dim=0)
            )

    graph.validate()
    return graph, stats_by_table
class AttachTargetTransform:
    r"""Callable that writes target labels onto a heterogeneous mini-batch.

    A batch is made of disjoint subgraphs produced by temporal sampling, and
    one input node may show up several times with different timestamps (and
    therefore different subgraphs and labels). Labels can thus not live on
    the graph object itself; they are looked up and attached once the batch
    exists.

    Args:
        entity: Node type whose storage receives the labels.
        target: Label tensor, indexed by the batch's ``input_id``.
    """

    def __init__(self, entity: str, target: Tensor):
        self.entity, self.target = entity, target

    def __call__(self, batch: HeteroData) -> HeteroData:
        # Select the label of every seed node and store it as ``y`` on the
        # entity's node storage; the batch is modified in place and returned.
        store = batch[self.entity]
        store.y = self.target[store.input_id]
        return batch
class NodeTrainTableInput(NamedTuple):
    r"""Bundle of everything a node-prediction loader needs from a training
    table.

    Fields, in order: ``nodes``, ``time``, ``target``, ``transform``.
    """

    # (node type, tensor of node indices) for the seed nodes.
    nodes: Tuple[NodeType, Tensor]
    # Tensor of node timestamps, or ``None`` when there is none.
    time: Optional[Tensor]
    # Tensor of node labels, or ``None`` when there is none.
    target: Optional[Tensor]
    # Callable that attaches ``target`` to a batch, or ``None``.
    transform: Optional[AttachTargetTransform]
def get_node_train_table_input(
    table: Table,
    task: EntityTask,
) -> NodeTrainTableInput:
    r"""Assemble the node-prediction training input from a task table.

    Args:
        table: Task table; its ``df`` provides the entity column, the
            optional time column and the optional target column.
        task: Entity task naming the entity column/table, the target column
            and the task type.

    Returns:
        A :class:`NodeTrainTableInput`. ``time`` is ``None`` when the table
        has no time column. ``target`` and ``transform`` are ``None`` when
        the target column is absent from the table.
    """
    frame = table.df
    nodes = torch.from_numpy(frame[task.entity_col].astype(int).values)

    time: Optional[Tensor] = (
        None
        if table.time_col is None
        else torch.from_numpy(to_unix_time(frame[table.time_col]))
    )

    target: Optional[Tensor] = None
    transform: Optional[AttachTargetTransform] = None
    if task.target_col in frame:
        labels = frame[task.target_col].values
        if task.task_type == TaskType.MULTILABEL_CLASSIFICATION:
            # One label vector per row: stack them into a single array.
            target = torch.from_numpy(np.stack(labels))
        elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
            # Class ids are cast to integers.
            target = torch.from_numpy(labels.astype(int))
        else:
            # Every other task type uses floating point labels.
            target = torch.from_numpy(labels.astype(float))
        transform = AttachTargetTransform(task.entity_table, target)

    return NodeTrainTableInput(
        (task.entity_table, nodes),
        time,
        target,
        transform,
    )
class LinkTrainTableInput(NamedTuple):
    r"""Bundle of everything a link-prediction loader needs from a training
    table.

    Fields, in order: ``src_nodes``, ``dst_nodes``, ``num_dst_nodes``,
    ``src_time``.
    """

    # (node type, tensor of source node indices).
    src_nodes: Tuple[NodeType, Tensor]
    # (node type, PyTorch sparse tensor in CSR format); row ``i`` of the
    # sparse tensor gives the destination node indices of source entry ``i``.
    dst_nodes: Tuple[NodeType, Tensor]
    # Total number of destination nodes (used to perform negative sampling).
    num_dst_nodes: int
    # Tensor of times for ``src_nodes``, or ``None`` when there is none.
    src_time: Optional[Tensor]
def get_link_train_table_input(
    table: Table,
    task: RecommendationTask,
) -> LinkTrainTableInput:
    r"""Get the training table input for link prediction.

    Args:
        table: Task table; its ``df`` provides the source entity column, the
            destination entity column (one collection of destination ids per
            row) and the optional time column.
        task: Recommendation task naming the source/destination columns and
            tables and the total number of destination nodes.

    Returns:
        A :class:`LinkTrainTableInput` whose ``dst_nodes`` tensor is a
        boolean sparse CSR matrix of shape
        ``(len(table.df), task.num_dst_nodes)``. Row ``i`` marks the
        destinations listed in the ``i``-th row of the table, matching
        position ``i`` of the source node tensor.
    """
    src_node_idx: Tensor = torch.from_numpy(
        table.df[task.src_entity_col].astype(int).values
    )

    # ``explode`` repeats the *index label* of each row for every element of
    # its list. The sparse matrix rows, however, are positions aligned with
    # ``src_node_idx``. Resetting the index first makes labels and positions
    # coincide, so a table with a filtered or shuffled index no longer
    # yields out-of-range or misassigned rows.
    dst_lists = table.df[task.dst_entity_col].reset_index(drop=True)

    # An empty destination list explodes to a single NaN entry; drop it so
    # that such a row simply has no destinations instead of breaking the
    # integer cast below.
    exploded = dst_lists.explode().dropna()

    coo_indices = torch.from_numpy(
        np.stack([exploded.index.values, exploded.values.astype(int)])
    )
    sparse_coo = torch.sparse_coo_tensor(
        coo_indices,
        torch.ones(coo_indices.size(1), dtype=torch.bool),
        (len(src_node_idx), task.num_dst_nodes),
    )
    dst_node_indices = sparse_coo.to_sparse_csr()

    time: Optional[Tensor] = None
    if table.time_col is not None:
        time = torch.from_numpy(to_unix_time(table.df[table.time_col]))

    return LinkTrainTableInput(
        src_nodes=(task.src_entity_table, src_node_idx),
        dst_nodes=(task.dst_entity_table, dst_node_indices),
        num_dst_nodes=task.num_dst_nodes,
        src_time=time,
    )