-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathlp_train.py
More file actions
106 lines (90 loc) · 3.58 KB
/
Copy pathlp_train.py
File metadata and controls
106 lines (90 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import division
from __future__ import print_function
import json
import os
import pickle
import time
import numpy as np
import optimizers
import torch
from config import parser
from models.base_models import LPModel
from utils.data_utils import load_data
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import scipy.sparse as sp
from tqdm import tqdm
def train(args):
    """Train an LPModel and save its artifacts.

    Seeds NumPy/PyTorch, loads the dataset from ``$DATAPATH/<args.dataset>``,
    trains for ``args.epochs`` epochs on ``data['adj_train_norm']``, computes
    ``'test'`` metrics once training ends, and writes into
    ``$LOG_DIR/<args.dataset>``:

    * ``embeddings.npy`` -- the ``h0`` output of the last ``model.encode`` call;
    * ``<dataset>_att_adj.p`` -- the pickled dense ``model.encoder.att_adj``,
      only when the encoder has that attribute;
    * ``config.json`` -- ``vars(args)``, including the fields set here;
    * ``model.pth`` -- the model ``state_dict``.

    Args:
        args: Parsed options from ``config.parser``. This function also sets
            ``device``, ``n_nodes``, ``feat_dim``, ``nb_false_edges``,
            ``nb_edges``, ``eval_freq`` and (if unset/zero) ``lr_reduce_freq``
            on it.

    Raises:
        ValueError: If ``args.epochs`` is less than 1, because no embeddings
            would exist to evaluate or save.
        KeyError: If the ``LOG_DIR`` or ``DATAPATH`` environment variable is
            not set.
    """
    if args.epochs < 1:
        raise ValueError('args.epochs must be >= 1, got {}'.format(args.epochs))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # A negative (or missing) args.cuda selects the CPU; checked once so every
    # CUDA-dependent step below agrees.
    use_cuda = args.cuda is not None and int(args.cuda) >= 0
    if use_cuda:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if use_cuda else 'cpu'

    save_dir = os.path.join(os.environ['LOG_DIR'], args.dataset)
    os.makedirs(save_dir, exist_ok=True)

    # Load data and record its sizes on args (presumably read by LPModel(args)).
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    args.nb_false_edges = len(data['train_edges_false'])
    args.nb_edges = len(data['train_edges'])

    # No validation for the reconstruction task: an eval frequency above the
    # epoch count means periodic validation never runs. It is still stored on
    # args so that config.json records it.
    args.eval_freq = args.epochs + 1
    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs

    model = LPModel(args)
    optimizer = getattr(optimizers, args.optimizer)(
        params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=int(args.lr_reduce_freq),
        gamma=float(args.gamma),
    )

    if use_cuda:
        # NOTE(review): CUDA_VISIBLE_DEVICES only has an effect if CUDA is not
        # yet initialised; if it does take effect, the chosen GPU becomes
        # index 0 and 'cuda:<N>' for N > 0 would be invalid. Kept as-is.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for key, value in data.items():
            if torch.is_tensor(value):
                data[key] = value.to(args.device)

    max_norm = None if args.grad_clip is None else float(args.grad_clip)

    # Train model.
    for _ in tqdm(range(args.epochs)):
        model.train()
        optimizer.zero_grad()
        embeddings, t, h0, adj = model.encode(data['features'], data['adj_train_norm'])
        train_metrics = model.compute_metrics(embeddings, data, t, h0, adj, 'train')
        train_metrics['loss'].backward()
        if max_norm is not None:
            # Clips each parameter's gradient norm separately rather than one
            # global norm over all parameters; kept to preserve training dynamics.
            for param in model.parameters():
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        lr_scheduler.step()

    # NOTE(review): model.eval() does not change these tensors. They come from
    # the last train-mode forward pass, taken before the final optimizer.step().
    # Re-encode here if test metrics should reflect the final weights in eval
    # mode. The result is currently unused (its printout was disabled).
    model.eval()
    model.compute_metrics(embeddings, data, t, h0, adj, 'test')

    print('End encoding!')
    np.save(os.path.join(save_dir, 'embeddings.npy'), h0.cpu().detach().numpy())
    if hasattr(model.encoder, 'att_adj'):
        filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
        with open(filename, 'wb') as f:
            pickle.dump(model.encoder.att_adj.cpu().to_dense(), f)
        print('Dumped attention adj: ' + filename)
    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(vars(args), f)
    torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
if __name__ == '__main__':
    # Script entry point: parse the command-line options defined in
    # config.parser and hand them straight to the training routine.
    train(parser.parse_args())