forked from carefree0910/MachineLearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9b86af1
commit a244fd7
Showing
7 changed files
with
257 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
import pickle | ||
import numpy as np | ||
|
||
|
||
def gen_dataset(dat_path, data_dir="_Data"):
    """Build the pickled text-classification dataset if it does not exist yet.

    Every sub-directory of ``data_dir`` is one class: its name is the label
    and each file inside it is read as one whitespace-tokenised sample.
    The label names are saved to ``<data_dir>/LABEL_DIC.npy`` and the pair
    ``(x, y)`` (token lists, integer label indices) is pickled to
    ``dat_path``. Nothing is done when ``dat_path`` already exists.

    Args:
        dat_path: Path of the pickle file to create.
        data_dir: Root folder holding one sub-folder per label.
    """
    if os.path.isfile(dat_path):
        return
    print("\nGenerating Dataset...")
    # os.listdir returns entries in arbitrary order; sort so that a label
    # always maps to the same index across runs and platforms.
    label_dic = sorted(
        folder for folder in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, folder))
    )
    x, y = [], []
    for i, label in enumerate(label_dic):
        folder = os.path.join(data_dir, label)
        for txt in sorted(os.listdir(folder)):
            try:
                with open(os.path.join(folder, txt), "r", encoding="utf-8") as file:
                    tokens = file.read().split()
            except (UnicodeDecodeError, OSError) as err:
                # Best effort: report the unreadable file and keep going.
                print(err)
                continue
            x.append(tokens)
            y.append(i)
    np.save(os.path.join(data_dir, "LABEL_DIC.npy"), label_dic)
    with open(dat_path, "wb") as file:
        pickle.dump((x, y), file)
    print("Done")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
import pickle | ||
|
||
from SkRun import run | ||
|
||
# Build the pickled dataset once; later runs reuse "dataset.dat".
if not os.path.isfile("dataset.dat"):
    print("Processing data...")
    rs, labels = [], []
    data_folder = "_Data"
    # Only sub-folders are labels. Plain files living in the data folder
    # would make the inner os.listdir fail and would shift the label
    # indices, so they are skipped. Sorting keeps the indices reproducible.
    label_folders = sorted(
        folder for folder in os.listdir(data_folder)
        if os.path.isdir(os.path.join(data_folder, folder))
    )
    for i, folder in enumerate(label_folders):
        for txt_file in os.listdir(os.path.join(data_folder, folder)):
            with open(os.path.join(data_folder, folder, txt_file), "r", encoding="utf-8") as file:
                try:
                    # One sentence per file: only the first line is used.
                    rs.append(file.readline().split())
                    labels.append(i)
                except UnicodeDecodeError as err:
                    print(err)
    with open("dataset.dat", "wb") as file:
        pickle.dump((rs, labels), file)
    print("Done")

print("Running Naive Bayes written by myself...")
# Runs the hand-written classifier as a separate interpreter process.
os.system("python _NB.py")

print("Running Naive Bayes in sklearn...")
run("Naive Bayes")

print("Running LinearSVM in sklearn")
run("SVM")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
[1]:https://github.com/carefree0910/TextClassification/ | ||
|
||
# Text Classification | ||
|
||
Dependencies: numpy, sklearn, matplotlib
|
||
A stand-alone version of this project can be found [here][1].
|
||
+ Put your training set **FOLDERS** into '_Data' folder | ||
+ Each folder name is used as the 'label' of the texts contained in that folder
+ Each folder should contain a number of txt files | ||
+ One sentence per txt file | ||
+ For some languages (e.g. Chinese), sentences should be segmented | ||
+ Run 'Main.py' or 'SkRun.py' (**Recommended**)!
+ They run correctly from PyCharm as-is, but if you want to run them by double-clicking, some import statements may need to be modified
+ You may want a [Stand-alone version][1] for this project, where you can run it by double-clicking 'Main.py' or 'SkRun.py' (**Recommended**) without modifying any import statement! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import os | ||
import math | ||
import pickle | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from sklearn.feature_extraction.text import CountVectorizer | ||
from sklearn.feature_extraction.text import TfidfTransformer | ||
from sklearn import metrics | ||
|
||
from _SKlearn.NaiveBayes import SKMultinomialNB | ||
from _SKlearn.SVM import SKSVM, SKLinearSVM | ||
from _Dist.TextClassification.GenDataset import gen_dataset | ||
from Util.ProgressBar import ProgressBar | ||
|
||
|
||
def main(clf):
    """Run one shuffled 10-fold evaluation of ``clf`` on the text dataset.

    The dataset is generated on first use, shuffled, and split into ten
    consecutive folds. For every fold the vectorizer and the TF-IDF
    transformer are fitted on the training part only, and the same fitted
    pipeline is then applied to the held-out part.

    Args:
        clf: Classifier exposing ``fit``, ``predict`` and ``acc``.

    Returns:
        ``(acc_lst, y_results)``: the ten per-fold accuracies and the ten
        ``[y_test, y_pred]`` pairs, in fold order.
    """
    dat_path = os.path.join("_Data", "dataset.dat")
    gen_dataset(dat_path)
    with open(dat_path, "rb") as _file:
        x, y = pickle.load(_file)
    # CountVectorizer expects one string per document.
    x = [" ".join(sentence) for sentence in x]
    _indices = np.random.permutation(len(x))
    x = list(np.array(x)[_indices])
    y = list(np.array(y)[_indices])
    data_len = len(x)
    batch_size = math.ceil(data_len * 0.1)
    _acc_lst, y_results = [], []
    bar = ProgressBar(max_value=10, name=str(clf))
    bar.start()
    for i in range(10):
        start = i * batch_size
        # The last fold absorbs whatever remains after nine full batches.
        stop = (i + 1) * batch_size if i != 9 else data_len
        x_train = x[:start] + x[(i + 1) * batch_size:]
        y_train = y[:start] + y[(i + 1) * batch_size:]
        x_test, y_test = x[start:stop], y[start:stop]
        count_vec = CountVectorizer()
        tfidf_transformer = TfidfTransformer()
        train_features = tfidf_transformer.fit_transform(count_vec.fit_transform(x_train))
        # The test fold must go through the same TF-IDF weighting as the
        # training data; feeding raw counts would mismatch the feature scale.
        test_features = tfidf_transformer.transform(count_vec.transform(x_test))
        clf.fit(train_features, y_train)
        y_pred = clf.predict(test_features)
        _acc_lst.append(clf.acc(y_test, y_pred))
        y_results.append([y_test, y_pred])
        bar.update()
    return _acc_lst, y_results
|
||
|
||
def run(clf):
    """Evaluate the classifier named ``clf`` ten times and report the results.

    ``"Naive Bayes"`` selects ``SKMultinomialNB(alpha=0.1)``,
    ``"Non-linear SVM"`` selects ``SKSVM()`` and any other value selects
    ``SKLinearSVM()``. Each round calls ``main`` with a fresh classifier.
    Afterwards the accuracies (in percent) are shown as a box plot, a
    classification report is printed for the single best fold, and the mean
    and variance of the round with the best average accuracy are printed.
    """
    def _new_classifier():
        # A fresh estimator per round so that no fitted state is reused.
        if clf == "Naive Bayes":
            return SKMultinomialNB(alpha=0.1)
        if clf == "Non-linear SVM":
            return SKSVM()
        return SKLinearSVM()

    rounds = 10
    acc_records, y_records = [], []
    bar = ProgressBar(max_value=10, name="Main")
    bar.start()
    for _ in range(rounds):
        fold_accs, fold_results = main(_new_classifier())
        acc_records.append(fold_accs)
        y_records.extend(fold_results)
        bar.update()
    acc_records = np.array(acc_records) * 100

    plt.figure()
    plt.boxplot(acc_records, vert=False, showmeans=True)
    plt.show()

    from Util.DataToolkit import DataToolkit
    # argmax over the flattened (round, fold) grid lines up with the flat
    # y_records list, which was filled in the same order.
    idx = np.argmax(acc_records)  # type: int
    best_truth, best_pred = y_records[idx][0], y_records[idx][1]
    label_names = np.load(os.path.join("_Data", "LABEL_DIC.npy"))
    print(metrics.classification_report(best_truth, best_pred, target_names=label_names))
    best_round = np.argmax(np.average(acc_records, axis=1))
    toolkit = DataToolkit(acc_records[best_round])
    print("Acc Mean : {:8.6}".format(toolkit.mean))
    print("Acc Variance : {:8.6}".format(toolkit.variance))
    print("Done")
|
||
# Script entry point: evaluate the classifier selected by the name "SVM".
if __name__ == "__main__":
    run("SVM")
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import os | ||
import math | ||
import pickle | ||
import numpy as np | ||
from collections import Counter | ||
import matplotlib.pyplot as plt | ||
|
||
from sklearn import metrics | ||
|
||
from _Dist.TextClassification.GenDataset import gen_dataset | ||
from Util.ProgressBar import ProgressBar | ||
|
||
|
||
def pick_best(sentence, prob_lst):
    """Return the index of the most likely class for ``sentence``.

    Args:
        sentence: Iterable of word tokens.
        prob_lst: One dict per class mapping each known word to its
            probability, plus two bookkeeping entries: ``"prior"`` (class
            prior) and ``"null"`` (divisor applied for every unseen word).

    Returns:
        The ``np.argmax`` over the per-class scores.

    Scores are accumulated in log space. The product of many small
    probabilities underflows to 0.0 on long sentences, which would make
    every class tie and argmax always pick class 0.
    """
    reserved = ("prior", "null")
    neg_inf = float("-inf")
    scores = []
    for prob_dic in prob_lst:
        prior, null = prob_dic["prior"], prob_dic["null"]
        if prior <= 0 or null <= 0:
            # A class with no training mass can never be the best choice;
            # this also avoids log(0) and division by zero.
            scores.append(neg_inf)
            continue
        log_null = math.log(null)
        score = math.log(prior)
        for word in sentence:
            # Tokens that collide with the bookkeeping keys are not word
            # probabilities, so they are handled like unseen words.
            prob = None if word in reserved else prob_dic.get(word)
            if prob is None:
                score -= log_null
            elif prob > 0:
                score += math.log(prob)
            else:
                score = neg_inf
                break
        scores.append(score)
    return np.argmax(scores)
|
||
|
||
def train(power=6.46):
    """Fit per-class word probabilities for each of ten shuffled folds.

    The dataset is generated on first use, shuffled, and split into ten
    consecutive folds. For every fold the words of all training sentences
    are pooled per class and turned into relative frequencies.

    Args:
        power: Exponent of the unseen-word divisor; each class stores
            ``"null" = <class word count> * 2 ** power``.

    Returns:
        ``(test_sets, prob_lists)``: ten ``(x_test, y_test)`` pairs and ten
        lists of per-class probability dicts. Each dict also carries the
        ``"null"`` and ``"prior"`` bookkeeping entries.
    """
    dat_path = os.path.join("_Data", "dataset.dat")
    gen_dataset(dat_path)
    with open(dat_path, "rb") as _file:
        x, y = pickle.load(_file)
    _indices = np.random.permutation(len(x))
    x = [x[i] for i in _indices]
    y = [y[i] for i in _indices]
    data_len = len(x)
    batch_size = math.ceil(data_len * 0.1)
    # Derive the class count from the labels instead of assuming nine
    # classes, which broke on any dataset with a different label count.
    n_classes = max(y) + 1 if y else 0
    _test_sets, _prob_lists = [], []
    _total = sum(len(sentence) for sentence in x)
    for i in range(10):
        start = i * batch_size
        # The last fold absorbs whatever remains after nine full batches.
        stop = (i + 1) * batch_size if i != 9 else data_len
        x_train = x[:start] + x[(i + 1) * batch_size:]
        y_train = y[:start] + y[(i + 1) * batch_size:]
        _test_sets.append((x[start:stop], y[start:stop]))
        # Count the training words of every class.
        _counters = [Counter() for _ in range(n_classes)]
        for xx, yy in zip(x_train, y_train):
            _counters[yy].update(xx)
        _prob_lst = []
        for counter in _counters:
            _sum = sum(counter.values())
            prob_dic = {key: value / _sum for key, value in counter.items()}
            prob_dic["null"] = _sum * 2 ** power
            prob_dic["prior"] = _sum / _total
            _prob_lst.append(prob_dic)
        _prob_lists.append(_prob_lst)
    return _test_sets, _prob_lists
|
||
|
||
def test(test_sets, prob_lists):
    """Score the ten folds produced by ``train``.

    Args:
        test_sets: Per-fold ``(x_test, y_test)`` pairs.
        prob_lists: Per-fold lists of class probability dicts.

    Returns:
        A list with the accuracy of each of the ten folds, in percent.
    """
    accuracies = []
    for fold in range(10):
        fold_probs = prob_lists[fold]
        sentences, truth = test_sets[fold]
        predicted = np.array([pick_best(words, fold_probs) for words in sentences])
        hits = np.sum(predicted == np.array(truth))
        accuracies.append(100 * hits / len(predicted))
    return accuracies
|
||
if __name__ == "__main__":
    # Repeat the whole 10-fold experiment `epoch` times.
    epoch = 10
    _rs = []
    bar = ProgressBar(max_value=epoch, name="_NB")
    bar.start()
    for _ in range(epoch):
        fold_sets, fold_probs = train()
        _rs.append(test(fold_sets, fold_probs))
        bar.update()
    # Rows become folds and columns become repetitions.
    _rs = np.array(_rs).T

    # Box plot of the accuracies (a per-fold line plot was tried earlier
    # and is no longer used).
    plt.figure()
    plt.boxplot(_rs.T, vert=False, showmeans=True)
    plt.show()

    _rs = np.array(_rs).ravel()
    acc_mean = np.average(_rs)
    print("Acc Mean : {:8.6}".format(acc_mean))
    print("Acc Variance : {:8.6}".format(np.average((_rs - np.average(_rs)) ** 2)))

    # One more run, used only to print a detailed report for its best fold.
    sets, lists = train()
    acc_list = test(sets, lists)
    idx = np.argmax(acc_list)  # type: int
    lst_ = lists[idx]
    x_, y_ = sets[idx]
    y_pred_ = [pick_best(sentence, lst_) for sentence in x_]
    label_names = np.load(os.path.join("_Data", "LABEL_DIC.npy"))
    print(metrics.classification_report(y_, y_pred_, target_names=label_names))

    print("Done")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
from Util.Bases import ClassifierBase | ||
from Util.Metas import SKCompatibleMeta | ||
|
||
from sklearn.svm import SVC | ||
from sklearn.svm import SVC, LinearSVC | ||
|
||
|
||
# sklearn's SVC combined with the project's ClassifierBase; the class is
# created through the SKCompatibleMeta metaclass and adds no members itself.
class SKSVM(SVC, ClassifierBase, metaclass=SKCompatibleMeta):
    pass
|
||
|
||
# sklearn's LinearSVC combined with the project's ClassifierBase; the class
# is created through the SKCompatibleMeta metaclass and adds no members itself.
class SKLinearSVM(LinearSVC, ClassifierBase, metaclass=SKCompatibleMeta):
    pass