forked from baidu/Familia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdocument.cpp
110 lines (98 loc) · 3.3 KB
/
document.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "familia/document.h"
#include "familia/util.h"
using std::vector;
using std::string;
namespace familia {
// -------------LDA Begin---------------
// Resets this document for (re)use with `num_topics` topics.
//
// Clears all tokens, zeroes both the current per-topic counts and the
// accumulated per-topic counts, and resets the number of accumulated sampling
// rounds. assign() is used rather than resize(): resize() only initialises
// newly appended elements and keeps the existing ones. On a reused document it
// would therefore leave the previous document's counts in place.
void LDADoc::init(int num_topics) {
    _num_topics = num_topics;
    _num_accum = 0;  // no sampling rounds accumulated yet
    _tokens.clear();
    _topic_sum.assign(_num_topics, 0);
    _accum_topic_sum.assign(_num_topics, 0);
}
// Appends `token` to the document and counts it towards its topic.
// The token's topic is validated with CHECK_GE / CHECK_LT against
// the range [0, _num_topics) before anything is modified.
void LDADoc::add_token(const Token& token) {
    const auto topic = token.topic;
    CHECK_GE(topic, 0) << "Topic " << topic << " out of range!";
    CHECK_LT(topic, _num_topics) << "Topic " << topic << " out of range!";
    _tokens.push_back(token);
    ++_topic_sum[topic];
}
// Reassigns the token at position `index` to `new_topic` and moves one unit of
// count from its previous topic to the new one. Does nothing when the topic is
// unchanged. `new_topic` is validated with CHECK_GE / CHECK_LT. `index` is not
// bounds-checked here, so callers must pass a valid token position.
void LDADoc::set_topic(int index, int new_topic) {
    CHECK_GE(new_topic, 0) << "Topic " << new_topic << " out of range!";
    CHECK_LT(new_topic, _num_topics) << "Topic " << new_topic << " out of range!";
    auto& tok = _tokens[index];
    const int prev_topic = tok.topic;
    if (prev_topic != new_topic) {
        tok.topic = new_topic;
        --_topic_sum[prev_topic];
        ++_topic_sum[new_topic];
    }
}
// Writes the document's accumulated topic distribution into `topic_dist` in
// sparse form. There is one {topic id, probability} entry per topic whose
// accumulated count is non-zero, with probability = count / total accumulated
// count. `topic_dist` is left empty when nothing has been accumulated.
// When `sort` is true, the entries are ordered with std::sort (using Topic's
// operator<).
void LDADoc::sparse_topic_dist(vector<Topic>& topic_dist, bool sort) const {
    topic_dist.clear();
    size_t total = 0;
    for (int k = 0; k < _num_topics; ++k) {
        total += _accum_topic_sum[k];
    }
    if (total != 0) {
        for (int k = 0; k < _num_topics; ++k) {
            const auto count = _accum_topic_sum[k];
            // Zero-count topics are skipped to keep the result sparse.
            if (count != 0) {
                topic_dist.push_back({k, count * 1.0 / total});
            }
        }
        if (sort) {
            std::sort(topic_dist.begin(), topic_dist.end());
        }
    }
}
// Writes the document's smoothed topic distribution into `dense_dist`, one
// entry per topic:
//
//   dense_dist[k] = (accum[k] / num_accum + alpha) / (size() + alpha * K)
//
// Here accum[k] / num_accum is the average count of topic k over the
// accumulated sampling rounds, and K is the number of topics.
//
// The result is all zeros when the document is empty, or when no sampling
// round has been accumulated yet. The second guard avoids dividing by
// _num_accum == 0, which would otherwise fill the result with NaN. It also
// mirrors sparse_topic_dist, which returns an empty result in that case.
void LDADoc::dense_topic_dist(vector<float>& dense_dist) const {
    dense_dist.clear();
    dense_dist.resize(_num_topics, 0.0);
    // Empty document or nothing accumulated: return the all-zero vector.
    if (size() == 0 || _num_accum == 0) {
        return;
    }
    // Loop-invariant normaliser (same expression and type as before).
    const auto denom = size() + _alpha * _num_topics;
    for (int i = 0; i < _num_topics; ++i) {
        dense_dist[i] = (_accum_topic_sum[i] * 1.0 / _num_accum + _alpha) / denom;
    }
}
// Adds the current per-topic counts into the running accumulation and records
// one more accumulated round. dense_topic_dist divides the accumulated counts
// by this round count to average them.
void LDADoc::accumulate_topic_sum() {
    int k = 0;
    while (k < _num_topics) {
        _accum_topic_sum[k] += _topic_sum[k];
        ++k;
    }
    ++_num_accum;
}
// -------------LDA End---------------
// --------Sentence-LDA Begin---------
void SLDADoc::init(int num_topics) {
_num_topics = num_topics;
_sentences.clear();
_topic_sum.resize(_num_topics, 0);
_accum_topic_sum.resize(_num_topics, 0);
}
// Appends `sent` to the document and counts it towards its topic.
// The sentence's topic is validated with CHECK_GE / CHECK_LT against
// the range [0, _num_topics) before anything is modified.
void SLDADoc::add_sentence(const Sentence& sent) {
    const auto topic = sent.topic;
    CHECK_GE(topic, 0) << "Topic " << topic << " out of range!";
    CHECK_LT(topic, _num_topics) << "Topic " << topic << " out of range!";
    _sentences.push_back(sent);
    ++_topic_sum[topic];
}
void SLDADoc::set_topic(int index, int new_topic) {
CHECK_GE(new_topic, 0) << "Topic " << new_topic << " out of range!";
CHECK_LT(new_topic, _num_topics) << "Topic " << new_topic << " out of range!";
int old_topic = _sentences[index].topic;
if (new_topic == old_topic) {
return;
}
_sentences[index].topic = new_topic;
_topic_sum[old_topic]--;
_topic_sum[new_topic]++;
}
// --------Sentence-LDA End---------
} // namespace familia