-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdocument_classifier.py
More file actions
145 lines (120 loc) · 4.61 KB
/
document_classifier.py
File metadata and controls
145 lines (120 loc) · 4.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
class DocumentClassifier:
"""
文档分类器,用于智能分类上传的文档
"""
def __init__(self, llm_model="gpt-3.5-turbo"):
"""
初始化文档分类器
Args:
llm_model: LLM模型名称
"""
self.llm_model = llm_model
self.llm = None
self.classification_prompt = self._build_classification_prompt()
# 尝试初始化LLM,失败则使用简单分类
self._init_llm()
def _init_llm(self):
"""
初始化LLM模型
"""
try:
from langchain_openai import ChatOpenAI
self.llm = ChatOpenAI(model=self.llm_model, temperature=0.1)
except Exception as e:
print(f"初始化LLM失败,将使用简单分类: {e}")
self.llm = None
def _build_classification_prompt(self):
"""
构建分类提示模板
Returns:
分类提示模板
"""
prompt = ChatPromptTemplate.from_template("""
请根据文档内容,将其分类到以下类别之一:
- 技术文档
- 产品手册
- 法律文件
- 财务报告
- 市场资料
- 教育培训
- 其他
请仅返回类别名称,不要添加任何解释或其他内容。
文档内容:
{document_content}
分类结果:
""")
return prompt
def _simple_classify(self, document_content):
"""
简单分类方法(当LLM不可用时使用)
Args:
document_content: 文档内容
Returns:
分类结果
"""
content_lower = document_content.lower()
if any(word in content_lower for word in ["技术", "编程", "开发", "代码", "软件", "API"]):
return "技术文档"
elif any(word in content_lower for word in ["产品", "使用", "手册", "指南", "功能"]):
return "产品手册"
elif any(word in content_lower for word in ["法律", "法规", "合同", "条款", "协议"]):
return "法律文件"
elif any(word in content_lower for word in ["财务", "报告", "预算", "收入", "支出"]):
return "财务报告"
elif any(word in content_lower for word in ["市场", "营销", "推广", "销售", "客户"]):
return "市场资料"
elif any(word in content_lower for word in ["教育", "培训", "学习", "课程", "教程"]):
return "教育培训"
else:
return "其他"
def classify_document(self, document_content):
"""
分类单个文档
Args:
document_content: 文档内容
Returns:
分类结果
"""
if self.llm:
# 构建分类链
classification_chain = (
self.classification_prompt
| self.llm
| StrOutputParser()
)
try:
# 限制文档内容长度
limited_content = document_content[:1000] # 仅使用前1000个字符进行分类
return classification_chain.invoke({"document_content": limited_content})
except Exception as e:
print(f"使用LLM分类失败,将使用简单分类: {e}")
return self._simple_classify(document_content)
else:
# 使用简单分类
return self._simple_classify(document_content)
def classify_documents(self, documents):
"""
分类多个文档
Args:
documents: 文档对象列表(LangChain Document对象)
Returns:
分类结果列表,每个元素为 (文档对象, 分类结果) 的元组
"""
classified_docs = []
for doc in documents:
# 获取文档内容
content = doc.page_content if hasattr(doc, "page_content") else str(doc)
# 分类
category = self.classify_document(content)
classified_docs.append((doc, category))
return classified_docs
def update_llm_model(self, llm_model):
"""
更新LLM模型
Args:
llm_model: 新的LLM模型名称
"""
self.llm_model = llm_model
self.llm = ChatOpenAI(model=self.llm_model, temperature=0.1)