-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_lstm.py
More file actions
128 lines (104 loc) · 4.73 KB
/
train_lstm.py
File metadata and controls
128 lines (104 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
# Andreas Goulas <goulasand@gmail.com>
# Nikolaos Gkalelis <gkalelis@iti.gr> | 23/4/2021 | minor changes (main, print, path processing, etc.)
import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import average_precision_score, accuracy_score
import os
import sys
from fcvid import FCVID
from ylimed import YLIMED
from model import Classifier
def train(model, loader, crit, opt, sched, device):
    """Run one training epoch and return the mean per-batch loss.

    The optimizer is stepped once per batch; the LR scheduler `sched` is
    stepped exactly once, at the end of the epoch.  Callers must NOT step
    `sched` again for the same epoch.

    :param model: classifier producing logits from a feature batch
    :param loader: DataLoader yielding (feats, label) batches
    :param crit: loss criterion applied to (logits, label)
    :param opt: optimizer over model.parameters()
    :param sched: learning-rate scheduler (stepped here, per epoch)
    :param device: torch.device the batches are moved to
    :return: sum of batch losses divided by the number of batches
    """
    model.train()
    epoch_loss = 0
    for batch in loader:  # batch index was unused; iterate directly
        feats, label = batch
        feats = feats.to(device)
        label = label.to(device)
        opt.zero_grad()
        out_data = model(feats)
        loss = crit(out_data, label)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
    sched.step()
    return epoch_loss / len(loader)
def test(model, loader, scores, device):
gidx = 0
model.eval()
with torch.no_grad():
for i, batch in enumerate(loader):
feats, _ = batch
feats = feats.to(device)
out_data = model(feats)
N = out_data.shape[0]
scores[gidx:gidx+N, :] = out_data.cpu()
gidx += N
# Command-line configuration.  Parsed at import time; the resulting `args`
# namespace is a module global read by main().
cli = argparse.ArgumentParser(description='GCN Video Classification')
cli.add_argument('--dataset', default='fcvid', choices=['fcvid', 'ylimed'])
cli.add_argument('--feats_folder', default='feats', help='directory to load features')
cli.add_argument('--lr', type=float, default=1e-5, help='initial learning rate')
cli.add_argument('--gamma', type=float, default=1, help='learning rate decay rate')
cli.add_argument('--num_epochs', type=int, default=10, help='number of epochs to train')
cli.add_argument('--batch_size', type=int, default=16, help='batch size')
cli.add_argument('--num_workers', type=int, default=0, help='number of workers for data loader; set always to zero!')
cli.add_argument('--eval_interval', type=int, default=50, help='interval for evaluating models (epochs)')
cli.add_argument('-v', '--verbose', action='store_true', help='show details')
args = cli.parse_args()
def main():
    """Train the classifier on precomputed features and periodically evaluate.

    Reads configuration from the module-level `args` namespace, loads
    feature/label tensors from args.feats_folder, trains for
    args.num_epochs epochs, and every args.eval_interval epochs scores
    the test split (mAP for fcvid, accuracy for ylimed).
    """
    # Precomputed per-video feature tensors and ground-truth labels.
    train_feats = torch.load(os.path.join(args.feats_folder, 'feats-train.pt'))
    train_truth = torch.load(os.path.join(args.feats_folder, 'truth-train.pt'))
    test_feats = torch.load(os.path.join(args.feats_folder, 'feats-test.pt'))
    test_truth = torch.load(os.path.join(args.feats_folder, 'truth-test.pt'))

    if args.dataset == 'ylimed':
        # CrossEntropyLoss requires integer class indices.
        train_truth = train_truth.long()
        test_truth = test_truth.long()

    train_dataset = TensorDataset(train_feats, train_truth)
    test_dataset = TensorDataset(test_feats, test_truth)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=args.num_workers, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=args.num_workers)
    test_truth = test_truth.numpy()

    # Fall back to CPU when CUDA is unavailable (was hard-coded to cuda:0,
    # which raises on CPU-only machines).
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if args.verbose:
        print("running on {}".format(device))
        print("num train samples {}".format(len(train_dataset)))
        print("num test samples {}".format(len(test_dataset)))

    if args.dataset == 'fcvid':
        crit = nn.BCEWithLogitsLoss()  # multi-label targets
        num_feats, num_class = FCVID.NUM_FEATS, FCVID.NUM_CLASS
    elif args.dataset == 'ylimed':
        crit = nn.CrossEntropyLoss()  # single-label class indices
        num_feats, num_class = YLIMED.NUM_FEATS, YLIMED.NUM_CLASS
    else:
        sys.exit("Unknown dataset!")

    start_epoch = 0
    model = Classifier(2 * num_feats, num_feats, num_class).to(device)
    opt = optim.Adam(model.parameters(), lr=args.lr)
    sched = optim.lr_scheduler.ExponentialLR(opt, args.gamma)

    for epoch in range(start_epoch, args.num_epochs):
        t0 = time.perf_counter()
        loss = train(model, train_loader, crit, opt, sched, device)
        t1 = time.perf_counter()
        # NOTE: train() already steps the scheduler once per epoch; the
        # extra sched.step() that used to be here decayed the LR twice
        # per epoch (gamma**2 instead of gamma).
        if args.verbose:
            print("[epoch {}] loss={} dt={:.2f}sec".format(epoch + 1, loss, t1 - t0))

        if (epoch + 1) % args.eval_interval == 0:
            num_test = len(test_dataset)
            scores = torch.zeros((num_test, num_class), dtype=torch.float32)
            test(model, test_loader, scores, device)
            scores = scores.numpy()
            if args.dataset == 'fcvid':
                ap = average_precision_score(test_truth, scores)
                print("mAP={:.2f}%".format(100 * ap))
            elif args.dataset == 'ylimed':
                # `pred` only exists on this branch, so the save belongs here.
                pred = scores.argmax(axis=1)
                acc = accuracy_score(test_truth, pred)
                print("accuracy={:.2f}%".format(100 * acc))
                torch.save(pred, os.path.join(args.feats_folder, 'pred-test.pt'))
# Script entry point; `args` was already parsed at module import time.
if __name__ == '__main__':
    main()