Skip to content

Commit

Permalink
Merge pull request wangzheng0822#192 from KPatr1ck/trie
Browse files Browse the repository at this point in the history
Trie implementation in python
  • Loading branch information
wangzheng0822 authored Dec 12, 2018
2 parents 9066ab5 + e1e0673 commit 9a85b1e
Showing 1 changed file with 170 additions and 0 deletions.
170 changes: 170 additions & 0 deletions python/35_trie/trie_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-

from queue import Queue
import pygraphviz as pgv

OUTPUT_PATH = 'E:/'


class Node:
def __init__(self, c):
self.data = c
self.is_ending_char = False
# 使用有序数组,降低空间消耗,支持更多字符
self.children = []

def insert_child(self, c):
"""
插入一个子节点
:param c:
:return:
"""
v = ord(c)
idx = self._find_insert_idx(v)
length = len(self.children)

node = Node(c)
if idx == length:
self.children.append(node)
else:
self.children.append(None)
for i in range(length, idx, -1):
self.children[i] = self.children[i-1]
self.children[idx] = node

def get_child(self, c):
"""
搜索子节点并返回
:param c:
:return:
"""
start = 0
end = len(self.children) - 1
v = ord(c)

while start <= end:
mid = (start + end)//2
if v == ord(self.children[mid].data):
return self.children[mid]
elif v < ord(self.children[mid].data):
end = mid - 1
else:
start = mid + 1
# 找不到返回None
return None

def _find_insert_idx(self, v):
"""
二分查找,找到有序数组的插入位置
:param v:
:return:
"""
start = 0
end = len(self.children) - 1

while start <= end:
mid = (start + end)//2
if v < ord(self.children[mid].data):
end = mid - 1
else:
if mid + 1 == len(self.children) or v < ord(self.children[mid+1].data):
return mid + 1
else:
start = mid + 1
# v < self.children[0]
return 0

def __repr__(self):
return 'node value: {}'.format(self.data) + '\n' \
+ 'children:{}'.format([n.data for n in self.children])


class Trie:
def __init__(self):
self.root = Node(None)

def gen_tree(self, string_list):
"""
创建trie树
1. 遍历每个字符串的字符,从根节点开始,如果没有对应子节点,则创建
2. 每一个串的末尾节点标注为红色(is_ending_char)
:param string_list:
:return:
"""
for string in string_list:
n = self.root
for c in string:
if n.get_child(c) is None:
n.insert_child(c)
n = n.get_child(c)
n.is_ending_char = True

def search(self, pattern):
"""
搜索
1. 遍历模式串的字符,从根节点开始搜索,如果途中子节点不存在,返回False
2. 遍历完模式串,则说明模式串存在,再检查树中最后一个节点是否为红色,是
则返回True,否则False
:param pattern:
:return:
"""
assert type(pattern) is str and len(pattern) > 0

n = self.root
for c in pattern:
if n.get_child(c) is None:
return False
n = n.get_child(c)

return True if n.is_ending_char is True else False

def draw_img(self, img_name='Trie.png'):
"""
画出trie树
:param img_name:
:return:
"""
if self.root is None:
return

tree = pgv.AGraph('graph foo {}', strict=False, directed=False)

# root
nid = 0
color = 'black'
tree.add_node(nid, color=color, label='None')

q = Queue()
q.put((self.root, nid))
while not q.empty():
n, pid = q.get()
for c in n.children:
nid += 1
q.put((c, nid))
color = 'red' if c.is_ending_char is True else 'black'
tree.add_node(nid, color=color, label=c.data)
tree.add_edge(pid, nid)

tree.graph_attr['epsilon'] = '0.01'
tree.layout('dot')
tree.draw(OUTPUT_PATH + img_name)
return True


if __name__ == '__main__':
string_list = ['abc', 'abd', 'abcc', 'accd', 'acml', 'P@trick', 'data', 'structure', 'algorithm']

print('--- gen trie ---')
print(string_list)
trie = Trie()
trie.gen_tree(string_list)
# trie.draw_img()

print('\n')
print('--- search result ---')
search_string = ['a', 'ab', 'abc', 'abcc', 'abe', 'P@trick', 'P@tric', 'Patrick']
for ss in search_string:
print('[pattern]: {}'.format(ss), '[result]: {}'.format(trie.search(ss)))

0 comments on commit 9a85b1e

Please sign in to comment.