Skip to content

Commit

Permalink
TextRankKeyword新增了一些接口,优化堆排序以实现TopN
Browse files Browse the repository at this point in the history
  • Loading branch information
hankcs committed Nov 22, 2015
1 parent 8f34f10 commit 915d3ee
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 20 deletions.
98 changes: 98 additions & 0 deletions src/main/java/com/hankcs/hanlp/algoritm/MaxHeap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* <summary></summary>
* <author>He Han</author>
* <email>[email protected]</email>
* <create-date>2015/11/22 13:23</create-date>
*
* <copyright file="MaxHeap.java" company="码农场">
* Copyright (c) 2008-2015, 码农场. All Right Reserved, http://www.hankcs.com/
* This source is subject to Hankcs. Please contact Hankcs to get more information.
* </copyright>
*/
package com.hankcs.hanlp.algoritm;

import java.util.*;

/**
 * A fixed-capacity heap, backed by a {@link PriorityQueue}, that retains only the
 * {@code maxSize} largest elements offered to it. It is intended for top-N selection.
 * <p>
 * "Largest" is defined by the comparator: the head of the backing queue is always the
 * smallest retained element, and it is evicted when a larger element arrives. Passing a
 * reversed comparator therefore retains the N smallest elements instead.
 *
 * @author hankcs
 */
public class MaxHeap<E>
{
    /**
     * Backing priority queue; its head is the smallest retained element.
     */
    private final PriorityQueue<E> queue;
    /**
     * Ordering used to rank elements; never {@code null}.
     */
    private final Comparator<? super E> comparator;
    /**
     * Maximum number of elements retained.
     */
    private final int maxSize;

    /**
     * Creates a heap that keeps at most {@code maxSize} elements.
     *
     * @param maxSize    how many elements to retain; must be positive
     * @param comparator ordering of the elements. An ascending comparator (o1 - o2) keeps the
     *                   largest elements, a descending one (o2 - o1) keeps the smallest.
     *                   {@code null} selects the elements' natural ordering, as it does for
     *                   {@link PriorityQueue}.
     * @throws IllegalArgumentException if {@code maxSize} is not positive
     */
    public MaxHeap(int maxSize, Comparator<? super E> comparator)
    {
        if (maxSize <= 0)
            throw new IllegalArgumentException("maxSize must be positive: " + maxSize);
        this.maxSize = maxSize;
        // Resolve a null comparator once so that add() never dereferences null.
        this.comparator = comparator != null ? comparator : MaxHeap.<E>naturalOrder();
        this.queue = new PriorityQueue<E>(maxSize, this.comparator);
    }

    /**
     * Offers one element to the heap.
     *
     * @param e the element to offer
     * @return {@code true} if the element was retained, {@code false} if the heap is full and
     * the element does not rank above the smallest retained element
     */
    public boolean add(E e)
    {
        if (queue.size() < maxSize)
        { // Below capacity: always keep the element.
            queue.add(e);
            return true;
        }

        // At capacity: the head is the smallest retained element. Replace it only when the
        // new element ranks strictly higher, so the larger of the two is kept.
        E smallest = queue.peek();
        if (comparator.compare(e, smallest) > 0)
        {
            queue.poll();
            queue.add(e);
            return true;
        }
        return false;
    }

    /**
     * Offers every element of a collection to the heap, in iteration order.
     *
     * @param collection the elements to offer
     * @return this heap, to allow call chaining
     */
    public MaxHeap<E> addAll(Collection<? extends E> collection)
    {
        for (E e : collection)
        {
            add(e);
        }

        return this;
    }

    /**
     * Drains the heap into a list ordered from the highest-ranked element to the lowest.
     * This operation is destructive: the heap is empty afterwards.
     *
     * @return the retained elements, highest-ranked first
     */
    public List<E> toList()
    {
        List<E> list = new ArrayList<E>(queue.size());
        while (!queue.isEmpty())
        {
            // poll() yields elements in ascending order; append and reverse once at the end
            // rather than inserting at index 0, which would shift the whole list every time.
            list.add(queue.poll());
        }
        Collections.reverse(list);

        return list;
    }

    /**
     * Builds a comparator that delegates to the elements' own {@link Comparable} implementation.
     * Comparing elements that are not mutually comparable fails with a
     * {@link ClassCastException}.
     */
    @SuppressWarnings("unchecked")
    private static <T> Comparator<T> naturalOrder()
    {
        return new Comparator<T>()
        {
            @Override
            public int compare(T left, T right)
            {
                return ((Comparable<Object>) left).compareTo(right);
            }
        };
    }
}
18 changes: 18 additions & 0 deletions src/main/java/com/hankcs/hanlp/summary/KeywordExtractor.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
package com.hankcs.hanlp.summary;

import com.hankcs.hanlp.dictionary.stopword.CoreStopWordDictionary;
import com.hankcs.hanlp.seg.Segment;
import com.hankcs.hanlp.seg.common.Term;
import com.hankcs.hanlp.tokenizer.StandardTokenizer;

/**
* 提取关键词的基类
* @author hankcs
*/
public class KeywordExtractor
{
/**
 * Segmenter used by this extractor. Initialised to the standard tokenizer's segmenter
 * and replaceable through setSegment.
 */
Segment defaultSegment = StandardTokenizer.SEGMENT;

/**
* 是否应当将这个term纳入计算,词性属于名词、动词、副词、形容词
*
Expand Down Expand Up @@ -61,4 +68,15 @@ public boolean shouldInclude(Term term)

return false;
}

/**
 * Replaces the segmenter this keyword extractor uses.
 *
 * @param segment any segmenter that has part-of-speech tagging enabled
 * @return this extractor, to allow call chaining
 */
public KeywordExtractor setSegment(Segment segment)
{
    this.defaultSegment = segment;
    return this;
}
}
82 changes: 62 additions & 20 deletions src/main/java/com/hankcs/hanlp/summary/TextRankKeyword.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.hankcs.hanlp.summary;


import com.hankcs.hanlp.algoritm.MaxHeap;
import com.hankcs.hanlp.seg.common.Term;
import com.hankcs.hanlp.tokenizer.StandardTokenizer;

Expand Down Expand Up @@ -40,10 +41,67 @@ public static List<String> getKeywordList(String document, int size)
return textRankKeyword.getKeyword(document);
}

/**
 * Extracts keywords from a piece of text.
 * <p>
 * Segmentation and ranking are delegated to getTermAndRank(String, Integer), so the text is
 * segmented exactly once and with the configured segmenter.
 *
 * @param content the text to extract keywords from
 * @return the terms of the top {@code nKeyword} ranked entries, in the order that
 * getTermAndRank(String, Integer) returns them
 */
public List<String> getKeyword(String content)
{
    // Only the terms are wanted; copy the key view so callers get an independent list.
    return new ArrayList<String>(getTermAndRank(content, nKeyword).keySet());
}

/**
 * Segments the text with the configured segmenter and ranks the resulting terms.
 *
 * @param content the text to analyse; non-null is checked only by an {@code assert}
 * @return the term-to-rank map that getRank produces for the segmented text
 */
public Map<String,Float> getTermAndRank(String content)
{
    assert content != null;
    return getRank(defaultSegment.seg(content));
}

/**
 * Ranks the terms of the text and keeps only the {@code size} highest-ranked ones.
 *
 * @param content the text to analyse
 * @param size    how many entries to keep; it is passed to MaxHeap as the heap capacity
 * @return an insertion-ordered map from term to rank, filled in the order MaxHeap.toList
 * yields the retained entries
 */
public Map<String,Float> getTermAndRank(String content, Integer size)
{
    Map<String, Float> rankByTerm = getTermAndRank(content);

    // Order entries by their rank value so the heap retains the highest-ranked ones.
    Comparator<Map.Entry<String, Float>> byRank = new Comparator<Map.Entry<String, Float>>()
    {
        @Override
        public int compare(Map.Entry<String, Float> left, Map.Entry<String, Float> right)
        {
            return left.getValue().compareTo(right.getValue());
        }
    };

    MaxHeap<Map.Entry<String, Float>> heap = new MaxHeap<Map.Entry<String, Float>>(size, byRank);
    heap.addAll(rankByTerm.entrySet());

    // A LinkedHashMap preserves the order in which the heap hands the entries back.
    Map<String, Float> top = new LinkedHashMap<String, Float>();
    for (Map.Entry<String, Float> ranked : heap.toList())
    {
        top.put(ranked.getKey(), ranked.getValue());
    }
    return top;
}

/**
* 使用已经分好的词来计算rank
* @param termList
* @return
*/
public Map<String,Float> getRank(List<Term> termList)
{
List<String> wordList = new ArrayList<String>(termList.size());
for (Term t : termList)
{
if (shouldInclude(t))
Expand Down Expand Up @@ -102,23 +160,7 @@ public List<String> getKeyword(String content)
score = m;
if (max_diff <= min_diff) break;
}
List<Map.Entry<String, Float>> entryList = new ArrayList<Map.Entry<String, Float>>(score.entrySet());
Collections.sort(entryList, new Comparator<Map.Entry<String, Float>>()
{
@Override
public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2)
{
return o2.getValue().compareTo(o1.getValue());
}
});
// System.out.println(entryList);
int limit = Math.min(nKeyword, entryList.size());
List<String> result = new ArrayList<String>(limit);
for (int i = 0; i < limit; ++i)
{
result.add(entryList.get(i).getKey()) ;
}
return result;
}

return score;
}
}
44 changes: 44 additions & 0 deletions src/test/java/com/hankcs/test/algorithm/MaxHeapTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.hankcs.test.algorithm;

import com.hankcs.hanlp.algoritm.MaxHeap;
import junit.framework.TestCase;

import java.util.Collections;
import java.util.Comparator;

/**
 * Unit tests for MaxHeap, using a capacity of five and ascending integer ordering so that
 * the heap retains the five largest values offered.
 */
public class MaxHeapTest extends TestCase
{
    final MaxHeap<Integer> heap = new MaxHeap<Integer>(5, new Comparator<Integer>()
    {
        @Override
        public int compare(Integer o1, Integer o2)
        {
            return o1.compareTo(o2);
        }
    });

    /**
     * Verifies which offers are accepted: everything while below capacity, then only values
     * greater than the smallest retained one.
     */
    public void testAdd() throws Exception
    {
        // The first five offers fill the heap.
        assertTrue(heap.add(1));
        assertTrue(heap.add(3));
        assertTrue(heap.add(5));
        assertTrue(heap.add(7));
        assertTrue(heap.add(9));
        // 8 evicts 1 and 6 evicts 3, leaving {5, 6, 7, 8, 9}.
        assertTrue(heap.add(8));
        assertTrue(heap.add(6));
        // Nothing below the current minimum (5) is accepted.
        assertFalse(heap.add(4));
        assertFalse(heap.add(2));
        assertFalse(heap.add(0));
    }

    /**
     * Verifies that addAll offers every element and returns the heap itself for chaining.
     */
    public void testAddAll() throws Exception
    {
        assertSame(heap, heap.addAll(java.util.Arrays.asList(4, 8, 1, 9, 3, 7, 2)));
        assertEquals(java.util.Arrays.asList(9, 8, 7, 4, 3), heap.toList());
    }

    /**
     * Verifies that toList returns the retained elements largest-first and drains the heap.
     */
    public void testToList() throws Exception
    {
        testAdd();
        assertEquals(java.util.Arrays.asList(9, 8, 7, 6, 5), heap.toList());
        // toList is destructive, so a second call finds nothing left.
        assertTrue(heap.toList().isEmpty());
    }
}

0 comments on commit 915d3ee

Please sign in to comment.