1
+ # coding=utf-8
2
+ import ast
3
+ from collections import deque
4
+ from zss import Node , simple_distance
5
+
6
+
7
+ def _jaccard_distance (self ):
8
+ # jaccard相似度
9
+ node_list1 , node_list2 = self ._get_node_lists ()
10
+ node_set1 = set (node_list1 )
11
+ node_set2 = set (node_list2 )
12
+ return 1.0 * len (node_set1 & node_set2 ) / len (node_set1 | node_set2 )
13
+
14
+
15
+ def _fake_anti_uni_distance (node1 , node2 ):
16
+ stack1 = deque ([node1 ])
17
+ stack2 = deque ([node2 ])
18
+ same = 0
19
+ diff = 0
20
+ while stack1 or stack2 :
21
+ if stack1 :
22
+ _node1 = stack1 .popleft ()
23
+ if type (_node1 ).__name__ == 'Load' :
24
+ try :
25
+ _node1 = stack1 .popleft ()
26
+ except IndexError :
27
+ _node1 = None
28
+ else :
29
+ _node1 = None
30
+ if stack2 :
31
+ _node2 = stack2 .popleft ()
32
+ if type (_node2 ).__name__ == 'Load' :
33
+ try :
34
+ _node2 = stack2 .popleft ()
35
+ except IndexError :
36
+ _node2 = None
37
+ else :
38
+ _node2 = None
39
+ if type (_node1 ).__name__ == type (_node2 ).__name__ :
40
+ same += 1
41
+ else :
42
+ diff += 1
43
+ if _node1 and _node2 :
44
+ stack1 .extend (ast .iter_child_nodes (_node1 ))
45
+ stack2 .extend (ast .iter_child_nodes (_node2 ))
46
+ return 1.0 * same / (same + diff )
47
+
48
+
49
+ def _tree_edit_distance (node1 , node2 ):
50
+
51
+ def get_dtc_tree (node ):
52
+ distance_node = Node (type (node ).__name__ )
53
+ tree_size = _dfs (node , distance_node )
54
+ return distance_node , tree_size
55
+
56
+ distance_node1 , tree_size1 = get_dtc_tree (node1 )
57
+ distance_node2 , tree_size2 = get_dtc_tree (node2 )
58
+ distance = simple_distance (distance_node1 , distance_node2 )
59
+ return 1 - 1.0 * distance / max (tree_size1 , tree_size2 )
60
+
61
+
62
+ def _dfs (root , dtc_node = None ):
63
+ _tree_size = 0
64
+ nodes = ast .iter_child_nodes (root )
65
+ for _node in nodes :
66
+ if type (root ).__name__ == 'Load' :
67
+ continue
68
+ _tree_size += 1
69
+ if dtc_node is not None :
70
+ _dtc_node = Node (type (_node ).__name__ )
71
+ dtc_node .addkid (_dtc_node )
72
+ else :
73
+ _dtc_node = None
74
+ _tree_size += _dfs (_node , _dtc_node )
75
+ return _tree_size
76
+
77
+
78
+ def _bfs (root , mass ):
79
+ stack = deque ([root ])
80
+ big_nodes = []
81
+ while stack :
82
+ node = stack .popleft ()
83
+ node_name = type (node ).__name__
84
+ if node_name == 'Load' :
85
+ continue
86
+ distance_node = Node (node_name )
87
+ tree_size = _dfs (node , distance_node )
88
+ if tree_size >= mass :
89
+ big_nodes .append (node )
90
+ stack .extend (ast .iter_child_nodes (node ))
91
+ return big_nodes
92
+
93
+
94
+ class _NodeList (ast .NodeVisitor , list ):
95
+
96
+ # 深度优先遍历抽象语法树保存到列表
97
+
98
+ def visit_Load (self , node ):
99
+ pass
100
+
101
+ def visit_Name (self , node ):
102
+ self .append ('Name' )
103
+
104
+ def generic_visit (self , node ):
105
+ self .append (type (node ).__name__ )
106
+ ast .NodeVisitor .generic_visit (self , node )
107
+
108
+
109
+ class _CodeSim :
110
+
111
+ def __init__ (self , file_name1 , file_name2 ):
112
+ with open (file_name1 ) as f :
113
+ self ._code1 = f .read ()
114
+ with open (file_name2 ) as f :
115
+ self ._code2 = f .read ()
116
+
117
+ def _get_node_lists (self ):
118
+ node1 , node2 = self ._get_nodes ()
119
+ node_list1 , node_list2 = _NodeList (), _NodeList ()
120
+ node_list1 .generic_visit (node1 )
121
+ node_list2 .generic_visit (node2 )
122
+ return node_list1 , node_list2
123
+
124
+ @property
125
+ def fake_anti_uni_distance (self ):
126
+ node1 , node2 = self ._get_nodes ()
127
+ root1_size = _dfs (node1 )
128
+ root2_size = _dfs (node2 )
129
+ mass = min (root1_size , root2_size ) / 3
130
+ sub_node1_list = _bfs (node1 , mass )
131
+ sub_node2_list = _bfs (node2 , mass )
132
+ sims = []
133
+ for sub_node1 in sub_node1_list :
134
+ for sub_node2 in sub_node2_list :
135
+ sims .append ( _fake_anti_uni_distance (sub_node1 , sub_node2 ))
136
+ return max (sims )
137
+
138
+ def _get_nodes (self ):
139
+ node1 , node2 = ast .parse (self ._code1 ), ast .parse (self ._code2 )
140
+ return node1 , node2
141
+
142
+ @property
143
+ def jaccard_distance (self ):
144
+ # jaccard相似度
145
+ node_list1 , node_list2 = self ._get_node_lists ()
146
+ node_set1 = set (node_list1 )
147
+ node_set2 = set (node_list2 )
148
+ return 1.0 * len (node_set1 & node_set2 ) / len (node_set1 | node_set2 )
149
+
150
+ @property
151
+ def tree_edit_distance (self ):
152
+ node1 , node2 = self ._get_nodes ()
153
+ return _tree_edit_distance (node1 , node2 )
154
+
155
+
156
+ def code_sim (file_name1 , file_name2 , method = 'tree_edit' ):
157
+ if method not in ('jaccard' , 'tree_edit' , 'fake_anti_uni' ):
158
+ raise ValueError ('method must be jaccard or tree_edit or fake_anti_uni' )
159
+ return getattr (_CodeSim (file_name1 , file_name2 ), method + '_distance' )
0 commit comments