-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoutput.txt
More file actions
886 lines (729 loc) · 37.2 KB
/
output.txt
File metadata and controls
886 lines (729 loc) · 37.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
--- Start of config.py ---
# Configuration for the code-graph builder.
#
# Each setting can be overridden with an environment variable of the same
# name, so credentials and machine-specific paths do not have to live in the
# source tree.  The fallbacks are the original hard-coded values, which keeps
# existing set-ups working unchanged.
import os

# Root directory of the project to analyse.
PROJECT_PATH = os.environ.get(
    "PROJECT_PATH",
    "/home/sxj/Desktop/Workspace/CodeQl/gptgraph/DevEval/Source_Code/Internet/Authlib",
)
# Bolt URL of the Neo4j database.
NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
# Neo4j user name.
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
# Neo4j password.  NOTE(review): the default is a development credential;
# set NEO4J_PASSWORD in the environment for any shared deployment.
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "12341234")
--- End of config.py ---
--- Start of neo4j_utils.py ---
import os
import logging
from py2neo import Graph, Node, Relationship
from tqdm import tqdm
class Neo4jHandler:
    """Persist a code graph into Neo4j through py2neo."""

    def __init__(self, url, user, password):
        """Open a connection to the database at ``url`` with basic auth."""
        self.graph = Graph(url, auth=(user, password))
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG)  # Set the desired log level here

    def clean_database(self):
        """Delete every node (and its relationships) from the database."""
        self.graph.run("MATCH (n) DETACH DELETE n")
        self.logger.debug("数据库已清空")

    @staticmethod
    def _short_name(node_type, full_name):
        """Return the display name for a node.

        ``FILE`` nodes use the basename of their path; every other type uses
        the last dot-separated component of the full name.
        """
        if node_type == 'FILE':
            return os.path.basename(full_name)
        return full_name.split('.')[-1]

    def import_graph(self, code_graph):
        """Write the nodes, then the edges, of ``code_graph`` to Neo4j.

        ``code_graph`` must expose ``get_graph()`` returning a networkx graph
        whose nodes carry a ``type`` attribute and whose edges carry a
        ``relationship`` attribute.  Nodes and relationships that already
        exist in the database are left untouched.
        """
        nx_graph = code_graph.get_graph()
        imported = self._import_nodes(nx_graph)
        self._import_relationships(nx_graph, imported)

    def _import_nodes(self, nx_graph):
        """Create missing nodes and return a ``full_name -> Node`` mapping."""
        imported = {}
        nodes = list(nx_graph.nodes(data=True))
        # tqdm.write is used instead of the logger so that messages do not
        # break the progress bar rendering.
        with tqdm(total=len(nodes), desc="导入节点", unit="节点") as pbar:
            for full_name, attrs in nodes:
                # ``type`` may be missing or explicitly None; both count as unknown.
                node_type = str(attrs.get('type') or 'UNKNOWN').upper()
                if node_type == 'UNKNOWN':
                    tqdm.write(f"发现未知类型的节点: {full_name}")
                    pbar.update(1)
                    continue

                db_node = self.graph.nodes.match(node_type, full_name=full_name).first()
                if db_node is None:
                    db_node = Node(
                        node_type,
                        name=self._short_name(node_type, full_name),
                        full_name=full_name,
                        # Attributes can be present with a None value (virtual
                        # import nodes), so ``or`` is needed for the default.
                        code=attrs.get('code') or '',
                        signature=attrs.get('signature') or '',
                        description=attrs.get('description') or '',
                    )
                    self.graph.create(db_node)
                    tqdm.write(f"导入节点: {full_name} (类型: {node_type})")
                imported[full_name] = db_node
                pbar.update(1)
        return imported

    def _lookup_node(self, imported, full_name):
        """Return the node for ``full_name`` from the cache or the database."""
        db_node = imported.get(full_name)
        if db_node is None:
            db_node = self.graph.nodes.match(full_name=full_name).first()
        return db_node

    def _import_relationships(self, nx_graph, imported):
        """Create the missing relationships between already imported nodes."""
        edges = list(nx_graph.edges(data=True))
        with tqdm(total=len(edges), desc="导入关系", unit="关系") as pbar:
            for start, end, edge_attrs in edges:
                rel_type = edge_attrs.get('relationship')
                if not rel_type:
                    # An edge without a type cannot be stored; skip it
                    # instead of aborting the whole import with a KeyError.
                    tqdm.write(f"警告: 关系缺少类型,跳过创建关系: {start} -> {end}")
                    pbar.update(1)
                    continue

                start_node = self._lookup_node(imported, start)
                end_node = self._lookup_node(imported, end)
                if start_node is not None and end_node is not None:
                    existing_rel = self.graph.match_one(nodes=(start_node, end_node), r_type=rel_type)
                    if existing_rel is None:
                        self.graph.create(Relationship(start_node, rel_type, end_node))
                        tqdm.write(f"导入关系: {start} -> {end} (类型: {rel_type})")
                else:
                    tqdm.write(f"警告: 关系的起始节点或终止节点缺失,跳过创建关系: {start} -> {end}")
                pbar.update(1)
--- End of neo4j_utils.py ---
--- Start of main.py ---
import os
from code_graph import CodeGraph
from neo4j_utils import Neo4jHandler
from parsers.contains_parser import ContainsParser # 引入包含关系的解析器
from parsers.import_parser import ImportParser # 引入 import 关系的解析器
from parsers.call_parser import CallParser # 引入调用关系的解析器
from lsp_client import LspClientWrapper # LSP 客户端包装器
import config
import logging
# Global logging configuration: INFO level on the root logger, with a
# "name - level - message" line format.
logging.basicConfig(level=logging.INFO, format=' %(name)s - %(levelname)s - %(message)s')
def main():
    """Build the code graph for ``config.PROJECT_PATH`` and load it into Neo4j.

    Steps: parse the CONTAINS hierarchy, add IMPORTS edges, resolve CALLS
    edges (with the LSP server running), then replace the database content
    with the resulting graph.
    """
    # Connect first so that a bad URL or credential fails before any parsing.
    neo4j_handler = Neo4jHandler(config.NEO4J_URL, config.NEO4J_USER, config.NEO4J_PASSWORD)

    # The repository name is the last component of the project path.
    repo_name = os.path.basename(os.path.normpath(config.PROJECT_PATH))

    # Step 1: CONTAINS relationships (directory layout, classes, functions).
    contains_parser = ContainsParser(config.PROJECT_PATH, repo_name)
    contains_parser.parse()

    # Build the graph from the parsed tree, starting at the root node.
    code_graph = CodeGraph()
    code_graph.build_graph_from_tree(contains_parser.root)

    # Step 2: IMPORTS relationships.
    import_parser = ImportParser(config.PROJECT_PATH, repo_name)
    import_parser.parse()
    print(f"import_parser.imports: {import_parser.imports}")
    for import_data in import_parser.imports:
        importer = import_data[0]
        imported_module = import_data[1]
        print(f"importer: {importer}, imported_module: {imported_module}")
        code_graph.add_import(importer, imported_module)

    # Step 3: CALLS relationships; the LSP server is started and stopped manually.
    lsp_client = LspClientWrapper(config.PROJECT_PATH)
    lsp_client.start_server()
    try:
        call_parser = CallParser(config.PROJECT_PATH, repo_name, code_graph, contains_parser.defined_symbols, lsp_client)
        call_parser.parse()
        # Dump the known symbols (debugging aid).
        print('-'*50+'\n',f"已定义的符号: {call_parser.defined_symbols}\n",'-'*50+'\n')
        for caller, callee in call_parser.calls:
            code_graph.add_call(caller, callee)
    finally:
        lsp_client.stop_server()

    # Wipe the database only once the new graph is complete, so a failure
    # during parsing does not leave the database empty.
    neo4j_handler.clean_database()
    neo4j_handler.import_graph(code_graph)


if __name__ == "__main__":
    main()
--- End of main.py ---
--- Start of code_graph.py ---
import networkx as nx
import sys
import importlib.util
import logging
class CodeGraph:
    """Directed graph of code entities keyed by fully qualified name.

    Nodes carry ``type``, ``code`` and ``signature`` attributes; edges carry
    a ``relationship`` attribute (``CONTAINS``, ``CALLS`` or ``IMPORTS``).

    NOTE(review): the backing store is a ``DiGraph``, which keeps a single
    edge per ordered node pair, so adding a second relationship between the
    same two nodes replaces the first one's ``relationship`` value.
    """

    def __init__(self):
        self.graph = nx.DiGraph()
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG)

    def build_graph_from_tree(self, tree_root):
        """Add ``tree_root`` and all of its descendants, linked by CONTAINS edges."""
        self._add_node(tree_root)
        self._build_edges(tree_root)

    def _add_node(self, node):
        """Add (or update) the graph node for a parsed tree node."""
        self.graph.add_node(node.fullname, type=node.node_type, code=node.code, signature=node.signature)
        self.logger.debug(f"添加节点: {node.fullname} (类型: {node.node_type})")

    def _build_edges(self, node):
        """Recursively add every child of ``node`` with a CONTAINS edge."""
        for child in node.children:
            self.graph.add_edge(node.fullname, child.fullname, relationship="CONTAINS")
            self._add_node(child)
            self._build_edges(child)

    def add_call(self, caller_fullname, callee_fullname):
        """Add a CALLS edge when both endpoints are already in the graph."""
        if caller_fullname in self.graph and callee_fullname in self.graph:
            self.graph.add_edge(caller_fullname, callee_fullname, relationship="CALLS")
            self.logger.debug(f"添加调用关系: {caller_fullname} -> {callee_fullname}")
        else:
            self.logger.debug(f"调用关系中的节点不存在: {caller_fullname} -> {callee_fullname}")

    def add_import(self, importer_fullname, imported_fullname):
        """Add an IMPORTS edge from a known importer to ``imported_fullname``.

        A module that is not yet in the graph gets a placeholder node whose
        type comes from ``_detect_module_type``.  The edge is only added when
        the importer itself is already a node.
        """
        module_type, module_path = self._detect_module_type(imported_fullname)
        if imported_fullname not in self.graph:
            self.logger.debug(f"创建节点: {imported_fullname} (类型: {module_type}) 路径: {module_path}")
            self.graph.add_node(imported_fullname, type=module_type, code=None, signature=None)
        if importer_fullname in self.graph:
            self.graph.add_edge(importer_fullname, imported_fullname, relationship="IMPORTS")
            self.logger.debug(f"添加import关系: {importer_fullname} -> {imported_fullname}")
        else:
            self.logger.debug(f"import关系中的节点不存在: {importer_fullname} -> {imported_fullname}")

    @staticmethod
    def _is_stdlib_path(module_path):
        """Return True when ``module_path`` lies inside the interpreter's stdlib."""
        import os
        import sysconfig

        if module_path in ("built-in", "frozen"):
            return True
        paths = sysconfig.get_paths()
        for key in ("stdlib", "platstdlib"):
            base = paths.get(key)
            if base and os.path.normcase(module_path).startswith(
                os.path.normcase(base.rstrip("/\\")) + os.sep
            ):
                return True
        return False

    def _detect_module_type(self, module_name):
        """Classify a module name.

        Returns a ``(module_type, module_path)`` tuple where ``module_type``
        is one of ``"standard_library"``, ``"third_party_library"``,
        ``"local_module"`` or ``"unknown"``.  ``module_path`` is the spec
        origin when one was looked up, otherwise ``None``.
        """
        # Empty names and relative imports cannot be resolved on their own.
        if not module_name or module_name.startswith('.'):
            self.logger.debug(f"无法找到模块 {module_name}")
            return "unknown", None

        # ``builtin_module_names`` only lists compiled-in modules; pure-Python
        # stdlib modules such as ``os`` need ``stdlib_module_names`` (3.10+).
        top_level = module_name.split('.', 1)[0]
        stdlib_names = getattr(sys, "stdlib_module_names", ())
        if module_name in sys.builtin_module_names or top_level in stdlib_names:
            return "standard_library", None

        try:
            module_spec = importlib.util.find_spec(module_name)
        except ModuleNotFoundError:
            self.logger.debug(f"模块 {module_name} 未安装")
            return "unknown", None
        except Exception as exc:
            # find_spec imports parent packages for dotted names, which runs
            # their code; any failure there means the module is unresolvable.
            self.logger.debug(f"无法解析模块 {module_name}: {exc}")
            return "unknown", None

        if module_spec is None:
            self.logger.debug(f"无法找到模块 {module_name}")
            return "unknown", None

        module_path = module_spec.origin
        if not module_path:
            return "unknown", None
        if "site-packages" in module_path or "dist-packages" in module_path:
            return "third_party_library", module_path
        # Fallback for interpreters without ``sys.stdlib_module_names``.
        if self._is_stdlib_path(module_path):
            return "standard_library", module_path
        return "local_module", module_path

    def get_graph(self):
        """Return the underlying networkx graph."""
        return self.graph
--- End of code_graph.py ---
--- Start of lsp_client.py ---
import os
import logging
from multilspy import SyncLanguageServer
from multilspy.multilspy_config import MultilspyConfig
from multilspy.multilspy_logger import MultilspyLogger
class LspClientWrapper:
    """Process-wide singleton around a multilspy ``SyncLanguageServer``.

    The first construction fixes the project root; later constructions
    return the same instance and ignore their ``project_root`` argument.
    """

    _instance = None  # the shared instance, created on first use

    def __new__(cls, project_root):
        if cls._instance is None:
            instance = super().__new__(cls)
            instance.initialize_server(project_root)
            # Cache the instance only after initialisation succeeded, so a
            # failure does not leave a half-built singleton behind.
            cls._instance = instance
        return cls._instance

    def initialize_server(self, project_root):
        """Create (but do not start) the language server for ``project_root``."""
        self.project_root = os.path.abspath(project_root)
        self.config = MultilspyConfig.from_dict({"code_language": "python"})
        self.logger = MultilspyLogger()
        self.slsp = SyncLanguageServer.create(self.config, self.logger, self.project_root)
        self.active = False  # True while the server context is entered

    def start_server(self):
        """Start the language server; a no-op when it is already running."""
        if not getattr(self, "active", False):
            self.server_context = self.slsp.start_server()
            # The context manager is entered by hand so that the server can
            # outlive a single ``with`` block; stop_server() leaves it.
            self.server_context.__enter__()
            self.active = True
            logging.info("LSP server started")

    def stop_server(self):
        """Stop the language server; a no-op when it is not running."""
        # getattr: __del__ may call this on an instance whose initialisation
        # failed before ``active`` was assigned.
        if getattr(self, "active", False):
            try:
                self.server_context.__exit__(None, None, None)
            finally:
                # Mark the server as stopped even if shutdown raised, so it
                # is not shut down a second time.
                self.active = False
            logging.info("LSP server stopped")

    def find_definition(self, file_path, line, character):
        """Return the LSP definition result for a source position.

        The server is started on demand.  Returns ``None`` when nothing was
        found or the request failed.

        Raises:
            RuntimeError: if the server is not running and cannot be started.
        """
        if not getattr(self, "active", False):
            try:
                self.start_server()
            except Exception as e:
                logging.error(f"Failed to start LSP server: {e}")
                raise RuntimeError("LSP server not started or stopped") from e

        abs_file_path = os.path.abspath(file_path)
        logging.debug(f"Finding definition in file: {abs_file_path} at line: {line}, character: {character}")
        try:
            result = self.slsp.request_definition(abs_file_path, line, character)
        except AssertionError as ae:
            logging.error(f"LSP request failed with assertion error: {ae}")
            return None
        except Exception as e:
            logging.error(f"Error finding definition for {abs_file_path} at line {line}, character {character}: {e}")
            return None

        if not result:
            logging.warning(f"No definition found for {abs_file_path} at line {line}, character {character}")
            return None
        return result

    def __enter__(self):
        """Start the server when used as a context manager."""
        self.start_server()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the server when the ``with`` block ends."""
        self.stop_server()

    def __del__(self):
        """Make sure the server is stopped when the wrapper is destroyed."""
        self.stop_server()
--- End of lsp_client.py ---
--- Start of parsers/contains_parser.py ---
import tree_sitter_python as tspython
from tree_sitter import Language, Parser
import os
class Node:
    """One entry of the containment tree (directory, module, class or function)."""

    def __init__(self, name, node_type, code=None, signature=None, parent_fullname=None):
        """Create a node.

        Args:
            name: Short name of the entity.
            node_type: One of 'directory', 'module', 'class', 'function'.
            code: Source text of the entity, if any.
            signature: Signature text of the entity, if any.
            parent_fullname: Full name of the parent; when given, this node's
                full name is ``"<parent_fullname>.<name>"``, otherwise ``name``.
        """
        self.name = name
        self.node_type = node_type
        self.children = []
        self.code = code
        self.signature = signature
        # Full name: the dotted path from the root node down to this node.
        self.fullname = f"{parent_fullname}.{name}" if parent_fullname else name

    def add_child(self, child_node):
        """Append ``child_node`` to this node's children."""
        self.children.append(child_node)

    def __repr__(self):
        return f"{type(self).__name__}({self.fullname!r}, {self.node_type!r})"
class ContainsParser:
    """Build the containment tree (directories, modules, classes, functions).

    After ``parse()``, ``root`` is the tree, ``nodes`` maps the full names of
    directory and module nodes to their ``Node``, and ``defined_symbols``
    maps each class/function short name to the list of full names defining it.
    """

    def __init__(self, project_path, repo_name):
        self.project_path = project_path
        self.repo_name = repo_name
        self.parser = self._init_parser()
        self.root = Node(repo_name, 'directory')  # root node of the project
        self.nodes = {repo_name: self.root}  # every directory/module node created
        self.defined_symbols = {}  # short name -> list of full names

    def _init_parser(self):
        """Return a tree-sitter parser configured for Python."""
        PY_LANGUAGE = Language(tspython.language())
        parser = Parser(PY_LANGUAGE)
        return parser

    def parse(self):
        """Walk the project directory and populate the tree."""
        self._build_tree(self.project_path, self.root)

    def _build_tree(self, current_path, parent_node):
        """Recursively add directories and ``.py`` files below ``current_path``."""
        # Sorted so that the tree (and everything derived from it) does not
        # depend on the file system's listing order.
        for item in sorted(os.listdir(current_path)):
            item_path = os.path.join(current_path, item)
            if os.path.isdir(item_path):
                dir_node = self._create_node(item, 'directory', parent_node)
                self._build_tree(item_path, dir_node)
            elif item.endswith(".py"):
                module_node = self._create_node(item, 'module', parent_node)
                self._parse_file(item_path, module_node)

    def _create_node(self, name, node_type, parent_node):
        """Create a directory/module node under ``parent_node`` and register it."""
        # Module nodes are named without the ``.py`` suffix.
        if node_type == 'module' and name.endswith('.py'):
            name = name[:-3]
        node = Node(name, node_type, parent_fullname=parent_node.fullname)
        parent_node.add_child(node)
        # Node already derives the full name; reuse it as the registry key.
        self.nodes[node.fullname] = node
        return node

    def _parse_file(self, file_path, module_node):
        """Parse one Python file and attach its classes/functions to ``module_node``."""
        # Read bytes so the parser's byte offsets match the file exactly and
        # no platform-dependent text decoding is involved.
        with open(file_path, "rb") as file:
            source = file.read()
        tree = self.parser.parse(source)
        self._extract_items(tree.root_node, file_path, module_node)

    def _extract_items(self, node, file_path, parent_node):
        """Recursively turn class and function definitions into tree nodes."""
        kinds = {'class_definition': 'class', 'function_definition': 'function'}
        for child in node.children:
            node_type = kinds.get(child.type)
            if node_type is None:
                # Not a definition: keep looking below it with the same parent.
                self._extract_items(child, file_path, parent_node)
                continue
            name = self._get_node_text(child.child_by_field_name('name'), file_path)
            item = Node(
                name,
                node_type,
                self._get_code_segment(child, file_path),
                self._get_signature(child, file_path),
                parent_node.fullname,
            )
            parent_node.add_child(item)
            self._register_symbol(name, item.fullname)
            # Nested definitions become children of this definition.
            self._extract_items(child, file_path, item)

    def _register_symbol(self, name, fullname):
        """Record ``fullname`` as one definition of the short name ``name``."""
        self.defined_symbols.setdefault(name, []).append(fullname)

    def _node_bytes(self, node, file_path):
        """Return the raw source bytes covered by ``node``."""
        raw = getattr(node, "text", None)
        if raw is None:
            # No text attached to the node: slice the file by byte offsets.
            with open(file_path, "rb") as file:
                raw = file.read()[node.start_byte:node.end_byte]
        return raw

    @staticmethod
    def _flatten(text):
        """Strip each line of ``text`` and join the lines with single spaces."""
        lines = text.split("\n")
        if len(lines) == 1:
            return text.strip()
        return " ".join(line.strip() for line in lines)

    def _get_node_text(self, node, file_path):
        """Return the source text of ``node`` flattened onto one line.

        Returns ``""`` for ``None``.  The text is taken from the node's bytes
        rather than by slicing decoded lines with tree-sitter columns, which
        are byte offsets and would be wrong on lines with non-ASCII text.
        """
        if node is None:
            return ""
        return self._flatten(self._node_bytes(node, file_path).decode("utf8", errors="replace"))

    def _get_signature(self, node, file_path):
        """Return the definition header (everything before the body) on one line."""
        raw = self._node_bytes(node, file_path)
        body = node.child_by_field_name('body')
        if body is not None:
            raw = raw[:body.start_byte - node.start_byte]
        return self._flatten(raw.decode("utf8", errors="replace")).strip()

    def _get_code_segment(self, node, file_path):
        """Return the verbatim source of ``node``, keeping line breaks and indentation."""
        if node is None:
            return ""
        return self._node_bytes(node, file_path).decode("utf8", errors="replace")
--- End of parsers/contains_parser.py ---
--- Start of parsers/call_parser.py ---
import os
from tree_sitter import Parser, Language
import tree_sitter_python as tspython
from lsp_client import LspClientWrapper
import logging
class CallParser:
    """Collect ``(caller, callee)`` pairs for calls found in the project.

    Callees are resolved by short name against ``defined_symbols``; when a
    name has several definitions the LSP client is asked for the definition
    at the call site.  Results accumulate in ``self.calls``.
    """

    def __init__(self, project_path, repo_name, code_graph, defined_symbols, lsp_client=None):
        self.project_path = project_path
        self.repo_name = repo_name
        self.code_graph = code_graph
        self.defined_symbols = defined_symbols  # symbol table built by ContainsParser
        self.calls = []  # (caller_fullname, callee_fullname) pairs

        self.logger = logging.getLogger('call_parser')
        self.logger.setLevel(logging.DEBUG)
        # The logger is shared by name, so only attach a handler once;
        # otherwise every new CallParser would duplicate each log line.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # May be None; ambiguous names are then left unresolved.
        self.lsp_client = lsp_client
        self.parser = self._init_parser()

    def _init_parser(self):
        """Return a tree-sitter parser configured for Python."""
        PY_LANGUAGE = Language(tspython.language())
        parser = Parser(PY_LANGUAGE)
        return parser

    def parse(self):
        """Parse every Python file of the project and collect call relations."""
        for file in self._get_py_files():
            self._parse_file(file)

    def _get_py_files(self):
        """Return the paths of all ``.py`` files below the project path."""
        py_files = []
        for root, _, files in os.walk(self.project_path):
            for file in files:
                if file.endswith(".py"):
                    py_files.append(os.path.join(root, file))
        return py_files

    def _parse_file(self, file_path):
        """Parse one file and extract the calls it contains."""
        # Bytes in, so tree-sitter byte offsets match the file content.
        with open(file_path, "rb") as file:
            source = file.read()
        tree = self.parser.parse(source)
        module_name = self._get_module_name(file_path)
        self._extract_calls(tree.root_node, file_path, module_name)

    def _get_module_name(self, file_path):
        """Return the dotted module name of ``file_path``, prefixed with the repo name."""
        relative_path = os.path.relpath(file_path, self.project_path)
        module_name = os.path.splitext(relative_path)[0].replace(os.path.sep, '.')
        return f"{self.repo_name}.{module_name}"

    def _extract_calls(self, node, file_path, current_fullname):
        """Walk the syntax tree top-down, tracking the enclosing definition.

        ``current_fullname`` is the full name of the innermost enclosing
        module/class/function and is used as the caller of any call found.
        """
        for child in node.children:
            if child.type == 'class_definition' or child.type == 'function_definition':
                name = self._get_node_text(child.child_by_field_name('name'), file_path)
                self._extract_calls(child, file_path, f"{current_fullname}.{name}")
            elif child.type == 'call':
                func_name_node = child.child_by_field_name('function')
                if func_name_node:
                    callee_name = self._get_node_text(func_name_node, file_path)
                    self.logger.debug("-" * 50)
                    self._handle_call(func_name_node, callee_name, current_fullname, file_path)
                # Descend as well: arguments and chained calls may hold more calls.
                self._extract_calls(child, file_path, current_fullname)
            else:
                self._extract_calls(child, file_path, current_fullname)

    def _handle_call(self, func_name_node, callee_name, caller_fullname, file_path):
        """Dispatch a call to the method-call or plain-function handler."""
        self.logger.debug(f"Found call to {callee_name} in {caller_fullname}")
        if func_name_node.type == 'attribute':
            object_node = func_name_node.child_by_field_name('object')
            method_node = func_name_node.child_by_field_name('attribute')
            object_name = self._get_node_text(object_node, file_path) if object_node else None
            method_name = self._get_node_text(method_node, file_path) if method_node else None
            self.logger.debug(f"Found method call: {object_name}.{method_name} in {caller_fullname}")
            self._process_method_call(object_name, method_name, caller_fullname, file_path, object_node, method_node)
        else:
            self._process_function_call(callee_name, caller_fullname, file_path, func_name_node)

    def _lookup_definition(self, file_path, node):
        """Ask the LSP client for the definition at ``node``; None if unavailable."""
        if self.lsp_client is None or node is None:
            self.logger.debug("No LSP client or node available, cannot disambiguate.")
            return None
        row, column = node.start_point[0], node.start_point[1]
        return self.lsp_client.find_definition(file_path, row, column)

    def _process_function_call(self, callee_name, caller_fullname, file_path, func_name_node):
        """Record a call to a plain (non-attribute) callee."""
        if callee_name not in self.defined_symbols:
            self.logger.debug(f"Call to external function {callee_name} in {caller_fullname}, skipping.")
            return

        definition_paths = self.defined_symbols[callee_name]
        if len(definition_paths) == 1:
            # A single definition: no disambiguation needed.
            callee_fullname = definition_paths[0]
            self.calls.append((caller_fullname, callee_fullname))
            self.logger.debug(f"Recorded function call: {caller_fullname} -> {callee_fullname}")
        else:
            definition = self._lookup_definition(file_path, func_name_node)
            self._resolve_call_with_lsp(caller_fullname, definition, definition_paths, callee_name)

    @staticmethod
    def _is_same_or_nested(fullname, candidate):
        """Return True if ``fullname`` is ``candidate`` or lies inside its namespace.

        Matching stops at a dot boundary, so ``pkg.foo_bar`` does not match
        the candidate ``pkg.foo``.
        """
        return fullname == candidate or fullname.startswith(candidate + ".")

    def _resolve_call_with_lsp(self, caller_fullname, definition, definition_paths, callee_name, member_name=None):
        """Record a call whose callee was located through an LSP definition.

        Args:
            caller_fullname: Full name of the calling scope.
            definition: LSP definition result (may be empty/None).
            definition_paths: Candidate full names for ``callee_name``.
            callee_name: Short name that was looked up (used for logging).
            member_name: When given, the resolved name is treated as the
                owning class and ``.member_name`` is appended to it.
        """
        self.logger.debug(f"Definition: {definition}")
        callee_fullname = None
        if definition:
            callee_fullname = self._get_fullname_from_definition(definition)
            self.logger.debug(f"Resolved full function name: {callee_fullname}")

        if callee_fullname and any(self._is_same_or_nested(callee_fullname, path) for path in definition_paths):
            if member_name:
                callee_fullname = f"{callee_fullname}.{member_name}"
            self.calls.append((caller_fullname, callee_fullname))
            self.logger.debug(f"Recorded call: {caller_fullname} -> {callee_fullname}")
        else:
            self.logger.warning(f"Could not determine the correct definition for {callee_name} called in {caller_fullname}")

    def _process_method_call(self, object_name, method_name, caller_fullname, file_path, object_node, method_node):
        """Record a call of the form ``object.method(...)``."""
        if object_name in self.defined_symbols:
            # The receiver is a known class/function name: Class.method(...).
            class_definitions = self.defined_symbols[object_name]
            if len(class_definitions) == 1:
                callee_fullname = f"{class_definitions[0]}.{method_name}"
                self.calls.append((caller_fullname, callee_fullname))
                self.logger.debug(f"Recorded static method call: {caller_fullname} -> {callee_fullname}")
            else:
                # Several classes share the name: resolve the receiver, then
                # append the method so the result matches the branch above.
                definition = self._lookup_definition(file_path, object_node)
                self._resolve_call_with_lsp(
                    caller_fullname, definition, class_definitions, object_name, member_name=method_name
                )
        elif method_name in self.defined_symbols:
            method_definitions = self.defined_symbols[method_name]
            if len(method_definitions) == 1:
                callee_fullname = method_definitions[0]
                self.calls.append((caller_fullname, callee_fullname))
                self.logger.debug(f"Recorded instance method call: {caller_fullname} -> {callee_fullname}")
            else:
                # Query at the method identifier for a precise definition.
                definition = self._lookup_definition(file_path, method_node)
                self._resolve_call_with_lsp(caller_fullname, definition, method_definitions, method_name)
        else:
            self.logger.debug(f"Method {method_name} not found for object {object_name}, skipping.")

    def _get_fullname_from_definition(self, definition):
        """Translate an LSP definition result into a full dotted name.

        Only the first location is used.  Returns ``None`` when the result
        has an unexpected shape, points outside the project, or the target
        file cannot be read.
        """
        if not definition:
            return None
        if isinstance(definition, (list, tuple)):
            definition = definition[0]
        if not isinstance(definition, dict):
            self.logger.error(f"Unexpected definition format: {definition}")
            return None

        # Prefer a ready-made absolute path when the location carries one
        # (presumably multilspy adds ``absolutePath`` - TODO confirm);
        # otherwise decode the file URI properly (percent-escapes, drives).
        def_file_path = definition.get('absolutePath')
        if not def_file_path:
            from urllib.parse import urlparse
            from urllib.request import url2pathname

            def_file_path = url2pathname(urlparse(definition['uri']).path)
        def_file_path = os.path.abspath(def_file_path)

        start_line = definition['range']['start']['line']
        start_column = definition['range']['start']['character']
        end_line = definition['range']['end']['line']
        end_column = definition['range']['end']['character']
        self.logger.debug(f"Definition file: {def_file_path}, start: ({start_line}, {start_column}), end: ({end_line}, {end_column})")

        # Definitions outside the project cannot map to a graph node.
        try:
            relative = os.path.relpath(def_file_path, self.project_path)
        except ValueError:  # e.g. a different drive on Windows
            relative = os.pardir
        if relative == os.pardir or relative.startswith(os.pardir + os.sep):
            self.logger.debug(f"Definition outside project: {def_file_path}")
            return None

        try:
            with open(def_file_path, "rb") as file:
                source = file.read()
        except OSError as exc:
            self.logger.error(f"Could not read definition file {def_file_path}: {exc}")
            return None

        tree = self.parser.parse(source)
        target_node = tree.root_node.descendant_for_point_range((start_line, start_column), (end_line, end_column))
        if not target_node:
            self.logger.error(f"Could not locate node at ({start_line}, {start_column}) in {def_file_path}")
            return None

        namespace = self._build_namespace_from_node(target_node, def_file_path)
        self.logger.debug(f"Resolved full function name: {namespace}")
        return namespace

    def _build_namespace_from_node(self, node, def_file_path):
        """Walk from ``node`` up to the root, collecting enclosing definition names.

        Returns the module name of ``def_file_path`` followed by the names of
        the enclosing classes/functions, outermost first.
        """
        components = []
        current_node = node
        while current_node:
            self.logger.debug(f"Current node: {current_node.type}")
            if current_node.type in ['function_definition', 'class_definition', 'module']:
                name_node = current_node.child_by_field_name('name')
                if name_node:
                    component = self._get_node_text(name_node, def_file_path)
                    self.logger.debug(f"Adding component: {component}")
                    components.insert(0, component)
            current_node = current_node.parent

        module_path = self._get_module_name(def_file_path)
        return f"{module_path}{'.' if components else ''}{'.'.join(components)}"

    def _get_node_text(self, node, file_path):
        """Return the source text of ``node`` flattened onto one line.

        Returns ``""`` for ``None``.  Multi-line text has every line stripped
        and joined with single spaces.  The text comes from the node's bytes
        instead of re-reading the file and slicing decoded lines with
        tree-sitter columns, which are byte offsets.
        """
        if node is None:
            return ""
        raw = getattr(node, "text", None)
        if raw is None:
            with open(file_path, "rb") as file:
                raw = file.read()[node.start_byte:node.end_byte]
        text = raw.decode("utf8", errors="replace")
        lines = text.split("\n")
        if len(lines) == 1:
            return text.strip()
        return " ".join(line.strip() for line in lines)
--- End of parsers/call_parser.py ---
--- Start of parsers/import_parser.py ---
import os
import tree_sitter_python as tspython
from tree_sitter import Language, Parser
import logging
class ImportParser:
    """Collect ``(importer, imported)`` pairs from the project's import statements.

    ``importer`` is the dotted module name of the file containing the
    statement; ``imported`` is the module (or ``module.element``) text as
    written in the source.  Results accumulate in ``self.imports``.
    """

    def __init__(self, project_path, repo_name):
        self.project_path = project_path
        self.repo_name = repo_name
        self.parser = self._init_parser()
        self.imports = []  # (importer, imported_module) pairs

        self.logger = logging.getLogger('import_parser')
        self.logger.setLevel(logging.INFO)
        # The logger is shared by name; attach the handler only once so that
        # repeated construction does not duplicate every log line.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def _init_parser(self):
        """Return a tree-sitter parser configured for Python."""
        PY_LANGUAGE = Language(tspython.language())
        parser = Parser(PY_LANGUAGE)
        return parser

    def parse(self):
        """Parse every Python file of the project and collect its imports."""
        py_files = self._get_py_files()
        self.logger.debug(f"Found {len(py_files)} Python files to parse.")
        for file in py_files:
            self.logger.debug(f"Parsing file: {file}")
            self._parse_file(file)

    def _get_py_files(self):
        """Return the paths of all ``.py`` files below the project path."""
        py_files = []
        for root, _, files in os.walk(self.project_path):
            for file in files:
                if file.endswith(".py"):
                    py_files.append(os.path.join(root, file))
        return py_files

    def _parse_file(self, file_path):
        """Parse one file and extract its import statements."""
        # Bytes in, so tree-sitter byte offsets match the file content.
        with open(file_path, "rb") as file:
            source = file.read()
        tree = self.parser.parse(source)
        module_name = self._get_module_name(file_path)
        self.logger.debug(f"Module name for file {file_path}: {module_name}")
        self._extract_imports(tree.root_node, file_path, module_name)

    def _get_module_name(self, file_path):
        """Return the dotted module name of ``file_path``, prefixed with the repo name."""
        relative_path = os.path.relpath(file_path, self.project_path)
        module_name = os.path.splitext(relative_path)[0].replace(os.path.sep, '.')
        return f"{self.repo_name}.{module_name}"

    def _extract_imports(self, node, file_path, current_fullname):
        """Recursively find import statements below ``node``."""
        for child in node.children:
            if child.type == 'import_statement':
                self.logger.debug(f"Found import statement in {current_fullname}")
                self._handle_import_statement(child, current_fullname, file_path)
            elif child.type == 'import_from_statement':
                self.logger.debug(f"Found from-import statement in {current_fullname}")
                self._handle_from_import_statement(child, current_fullname, file_path)
            else:
                self._extract_imports(child, file_path, current_fullname)

    @staticmethod
    def _imported_name_node(node):
        """Return the node holding the imported name, or None.

        ``import a as b`` / ``from m import a as b`` wrap the real name in an
        ``aliased_import`` node; the alias itself is a local binding, not a
        module, so only the ``name`` field is of interest.
        """
        if node.type == 'aliased_import':
            return node.child_by_field_name('name')
        if node.type == 'dotted_name' or node.type == 'identifier':
            return node
        return None

    @staticmethod
    def _join_import_path(module_name, element):
        """Join a module and an imported element into one dotted path.

        A purely relative module (``.``, ``..``) already ends with a dot, so
        no extra separator is inserted.
        """
        if module_name.endswith('.'):
            return f"{module_name}{element}"
        return f"{module_name}.{element}"

    def _handle_import_statement(self, node, current_fullname, file_path):
        """Record every module named in a plain ``import ...`` statement."""
        for name_node in node.named_children:
            target = self._imported_name_node(name_node)
            if target is None:
                continue
            import_name = self._get_node_text(target, file_path)
            self.imports.append((current_fullname, import_name))
            self.logger.debug(f"Recorded import: {current_fullname} imports {import_name}")

    def _handle_from_import_statement(self, node, current_fullname, file_path):
        """Record the module and each element of a ``from ... import ...`` statement."""
        # The grammar names this field ``module_name``.
        module_name_node = node.child_by_field_name('module_name')
        module_name = self._get_node_text(module_name_node, file_path) if module_name_node else None
        if not module_name:
            return

        self.imports.append((current_fullname, module_name))
        self.logger.debug(f"Recorded from-import: {current_fullname} imports from {module_name}")

        # Only the ``name`` children are imported elements; iterating all
        # named children would also pick up the module node itself.
        for import_child in node.children_by_field_name('name'):
            target = self._imported_name_node(import_child)
            if target is None:
                continue
            import_element = self._get_node_text(target, file_path)
            full_import_path = self._join_import_path(module_name, import_element)
            self.imports.append((current_fullname, full_import_path))
            self.logger.debug(f"Recorded from-import element: {current_fullname} imports {full_import_path}")

    def _get_node_text(self, node, file_path):
        """Return the source text of ``node`` flattened onto one line.

        Returns ``""`` for ``None``.  Multi-line text has every line stripped
        and joined with single spaces.  The text comes from the node's bytes
        instead of re-reading the file and slicing decoded lines with
        tree-sitter columns, which are byte offsets.
        """
        if node is None:
            return ""
        raw = getattr(node, "text", None)
        if raw is None:
            with open(file_path, "rb") as file:
                raw = file.read()[node.start_byte:node.end_byte]
        text = raw.decode("utf8", errors="replace")
        lines = text.split("\n")
        if len(lines) == 1:
            return text.strip()
        return " ".join(line.strip() for line in lines)
--- End of parsers/import_parser.py ---
--- Start of parsers/__init__.py ---
--- End of parsers/__init__.py ---