-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_model.py
More file actions
109 lines (97 loc) · 3.94 KB
/
eval_model.py
File metadata and controls
109 lines (97 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
""" Use the trained models on test data, testing both CCLM and LMCC.
Produce results tables that will be read by results.py to create plots, tables, etc.
"""
import util
import data
import model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext as tt
import numpy as np
import random
import json
import csv
import os
import click
from pathlib import Path
def batch_nll(lm, batch, pad_idx, comm=None):
    """Return the summed negative log likelihood of each example in the batch.

    The language model `lm` is conditioned on `comm`; when `comm` is None,
    each example's own community (batch.community) is used instead.
    Padding positions (pad_idx) contribute zero loss.
    """
    community = batch.community if comm is None else comm
    tokens, _lengths = batch.text
    # Next-token prediction: inputs drop the last token, targets the first.
    inputs = tokens[:-1]
    targets = tokens[1:]
    log_probs = lm(inputs, community)
    n_classes = log_probs.shape[-1]
    per_token_nll = F.nll_loss(
        log_probs.view(-1, n_classes),
        targets.view(-1),
        reduction='none',
        ignore_index=pad_idx)
    # Reshape back to (seq_len, batch) and sum over time -> one NLL per example.
    return per_token_nll.view(targets.shape).sum(dim=0)
@click.command()
@click.argument('model_family_dir', type=click.Path(exists=False))
@click.argument('model_name', type=str)
@click.argument('data_dir', type=click.Path(exists=True))
@click.option('--batch-size', default=512)
@click.option('--max-seq-len', default=64)
@click.option('--file-limit', type=int, default=None,
        help="Number of examples per file (community).")
@click.option('--gpu-id', type=int, default=None,
        help="ID of the GPU, if training with CUDA")
def cli(model_family_dir, model_name, data_dir, batch_size, max_seq_len, file_limit, gpu_id):
    """Evaluate a trained (optionally community-conditioned) LM on the test split.

    Writes <model_dir>/nll.csv with one row per test example: metadata columns
    (community, example_id, length) plus either one NLL column per candidate
    community (conditioned model) or a single 'nll' column (unconditioned).
    results.py reads this file to build plots and tables.
    """
    model_family_dir = Path(model_family_dir)
    model_dir = model_family_dir/model_name
    device = torch.device(f'cuda:{gpu_id}' if gpu_id is not None else 'cpu')
    log = util.create_logger('test', os.path.join(model_dir, 'testing.log'), True)
    log.info(f"Loading data from {data_dir}.")
    fields = data.load_fields(model_family_dir)
    fields['text'].include_lengths = True
    test_data = data.load_data(data_dir, fields, 'test', max_seq_len, file_limit)
    comms = fields['community'].vocab.itos
    pad_idx = fields['text'].vocab.stoi['<pad>']
    log.info(f"Loaded {len(test_data)} test examples.")
    # Use a context manager so the args file handle is closed promptly.
    with open(model_dir/'model_args.json') as args_file:
        model_args = json.load(args_file)
    lm = model.CommunityConditionedLM.build_model(**model_args).to(device)
    # map_location lets a GPU-trained checkpoint load on a CPU-only host.
    lm.load_state_dict(torch.load(model_dir/'model.bin', map_location=device))
    lm.eval()
    log.debug(str(lm))
    test_iterator = tt.data.BucketIterator(
            test_data,
            device=device,
            batch_size=batch_size,
            sort_key=lambda x: -len(x.text),
            shuffle=True,
            train=False)

    def batchify_comm(comm, batch_size):
        # Repeat a single community index so every example in the batch is
        # conditioned on the same community.
        comm_idx = fields['community'].vocab.stoi[comm]
        return torch.tensor(comm_idx).repeat(batch_size).to(device)

    with torch.no_grad(), open(model_dir/'nll.csv', 'w') as f:
        meta_fields = ['community', 'example_id', 'length']
        data_fields = comms if lm.use_community else ['nll']
        writer = csv.DictWriter(f, fieldnames=meta_fields+data_fields)
        writer.writeheader()
        for i, batch in enumerate(test_iterator):
            # One metadata dict per example; NLL columns are filled in below.
            nlls_batch = [
                    dict(zip(meta_fields, meta_values)) for meta_values in zip(
                        [comms[c] for c in batch.community.tolist()],
                        batch.example_id.tolist(),
                        batch.text[1].tolist()
                    )
                ]
            if lm.use_community:
                # Score every example under every candidate community (LMCC).
                for comm in comms:
                    batch_comm = batchify_comm(comm, batch.batch_size)
                    nlls_comm = batch_nll(lm, batch, pad_idx, comm=batch_comm)
                    for j, nll in enumerate(nlls_comm):
                        nlls_batch[j][comm] = nll.item()
            else:
                # Unconditioned model: score once and write the 'nll' column.
                # (The original looped over all communities here and used the
                # community name as the key, which DictWriter would reject.)
                nlls_comm = batch_nll(lm, batch, pad_idx, comm=None)
                for j, nll in enumerate(nlls_comm):
                    nlls_batch[j]['nll'] = nll.item()
            writer.writerows(nlls_batch)
            log.info(f"Completed {i+1}/{len(test_iterator)}")
# Script entry point: Click parses sys.argv; obj={} seeds the Click context.
if __name__ == '__main__':
    cli(obj={})