-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_classifier.py
More file actions
106 lines (91 loc) · 3.57 KB
/
Copy pathtrain_classifier.py
File metadata and controls
106 lines (91 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import random
from pathlib import Path
import torch
from document_segmentation.model.document_classifier import DocumentClassifier
from document_segmentation.model.page_learner import AbstractPageLearner
from document_segmentation.pagexml.annotations.renate_analysis import (
RenateAnalysis,
RenateAnalysisInv,
)
from document_segmentation.pagexml.datamodel.document import Document
from document_segmentation.settings import (
RENATE_ANALYSIS_SHEETS,
RENATE_TANAP_CATEGORISATION_SHEET,
)
if __name__ == "__main__":
    # Command-line entry point: parse the options, load the annotated documents from
    # the annotation sheets, split them into training and validation sets, train a
    # DocumentClassifier, and save the best model to disk.

    ########################################################################################
    # PARSE ARGUMENTS
    ########################################################################################
    arg_parser = argparse.ArgumentParser(
        description="Train a document classification model"
    )
    arg_parser.add_argument(
        "--renate-categorisation-sheet",
        # Parsed as a Path (not str): the file name of this sheet is used further down
        # as the key of its validation set, which requires the ``.name`` attribute.
        type=Path,
        default=RENATE_TANAP_CATEGORISATION_SHEET,
        help="The sheet with input annotations (Appendix F Renate Analysis).",
    )
    arg_parser.add_argument(
        "--renate-analysis-sheet",
        nargs="*",
        type=str,
        default=RENATE_ANALYSIS_SHEETS,
        help="The sheet with input annotations (Entire inventories from Renate's Analyses).",
    )
    arg_parser.add_argument("--split", type=float, default=0.8, help="Train/val split.")
    arg_parser.add_argument(
        "--max-documents",
        "--max",
        type=int,
        required=False,
        help="The maximum number of documents to read.",
    )
    arg_parser.add_argument(
        "--model-file",
        type=Path,
        default=Path("classifier_model.pt"),
        help="Output file for the model. Defaults to 'classifier_model.pt'.",
    )
    arg_parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    # NOTE(review): --device is parsed but not used anywhere in this script;
    # TODO forward it to the model/training call once the intended API is confirmed.
    arg_parser.add_argument(
        "--device",
        "-d",
        type=str,
        choices=["cuda", "mps", "cpu"],
        required=False,
        help="The device to use for training. Auto-detects if not given.",
    )
    arg_parser.add_argument("--seed", type=int, required=False, help="Random Seed")

    args = arg_parser.parse_args()

    # random.seed(None) falls back to a system-provided seed, so it is safe to call
    # unconditionally; torch.manual_seed() needs an actual integer.
    random.seed(args.seed)
    if args.seed is not None:
        # Seed PyTorch as well so that torch-side randomness is reproducible too.
        torch.manual_seed(args.seed)

    ########################################################################################
    # LOAD ANNOTATION SHEETS AND DATA
    ########################################################################################
    training_data: list[Document] = []
    # One validation set per data source, keyed by a name identifying the source.
    validation_data: dict[str, list[Document]] = {}

    # Documents annotated in the categorisation sheet.
    sheet = RenateAnalysis(sheet_file=args.renate_categorisation_sheet)
    docs = list(sheet.documents(n=args.max_documents))
    train, validation = AbstractPageLearner.split(docs, split=args.split)
    training_data.extend(train)
    categorisation_key = Path(args.renate_categorisation_sheet).name
    validation_data[categorisation_key] = validation

    # Documents from the per-inventory analysis sheets. Only processed if at least one
    # analysis sheet is given ("--renate-analysis-sheet" accepts zero values).
    if args.renate_analysis_sheet:
        docs = []
        for inv_sheet in args.renate_analysis_sheet:
            docs.extend(
                sheet.documents_from_sheet(
                    RenateAnalysisInv(sheet_file=inv_sheet), n=args.max_documents
                )
            )
        train, validation = AbstractPageLearner.split(docs, split=args.split)
        training_data.extend(train)
        validation_data["renate_analysis_inv"] = validation

    # TODO: add Generale Missiven?

    ########################################################################################
    # TRAIN AND SAVE MODEL
    ########################################################################################
    # Class weights are derived from the inventory validation set when it exists;
    # without analysis sheets, fall back to the categorisation sheet's validation set
    # instead of failing with a KeyError.
    # NOTE(review): weights are computed from validation documents rather than from
    # the training data; confirm this is intended.
    class_weight_documents = validation_data.get(
        "renate_analysis_inv", validation_data[categorisation_key]
    )

    model = DocumentClassifier()
    best_model = model.train_(
        training_data,
        validation_data,
        epochs=args.epochs,
        weights=model.total_class_weights(class_weight_documents),
    )

    torch.save(best_model, args.model_file)