forked from Tencent/NeuralNLP-NeuralClassifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
coderbyr
committed
Jul 4, 2019
1 parent
aa5bc1b
commit 52a7ede
Showing
40 changed files
with
36,363 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
![NeuralClassifier Logo](readme/logo.png) | ||
|
||
|
||
# NeuralClassifier: An Open-source Neural Hierarchical Multi-label Text Classification Toolkit | ||
|
||
## Introduction | ||
|
||
NeuralClassifier is designed for quick implementation of neural models for hierarchical multi-label classification task, which is more challenging and common in real-world scenarios. A salient feature is that NeuralClassifier currently provides a variety of text encoders, such as FastText, TextCNN, TextRNN, RCNN, VDCNN, DPCNN, DRNN, AttentiveConvNet and Transformer encoder, etc. It also supports other text classification scenarios, including binary-class and multi-class classification. It is built on [PyTorch](https://pytorch.org/). Experiments show that models built in our toolkit achieve comparable performance with reported results in the literature. | ||
|
||
## Support tasks | ||
|
||
* Binary-class text classifcation | ||
* Multi-class text classification | ||
* Multi-label text classification | ||
* Hiearchical (multi-label) text classification (HMC) | ||
|
||
## Support text encoders | ||
|
||
* TextCNN ([Kim, 2014](https://arxiv.org/pdf/1408.5882.pdf)) | ||
* RCNN ([Lai et al., 2015](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552)) | ||
* TextRNN ([Liu et al., 2016](https://arxiv.org/pdf/1605.05101.pdf)) | ||
* FastText ([Joulin et al., 2016](https://arxiv.org/pdf/1607.01759.pdf)) | ||
* VDCNN ([Conneau et al., 2016](https://arxiv.org/pdf/1606.01781.pdf)) | ||
* DPCNN ([Johnson and Zhang, 2017](https://www.aclweb.org/anthology/P17-1052)) | ||
* AttentiveConvNet ([Yin and Schutze, 2017](https://arxiv.org/pdf/1710.00519.pdf)) | ||
* DRNN ([Wang, 2018](https://www.aclweb.org/anthology/P18-1215)) | ||
* Region embedding ([Qiao et al., 2018](http://research.baidu.com/Public/uploads/5acc1e230d179.pdf)) | ||
* Transformer encoder ([Vaswani et al., 2017](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)) | ||
* Star-Transformer encoder ([Guo et al., 2019](https://arxiv.org/pdf/1902.09113.pdf)) | ||
|
||
## Requirement | ||
|
||
* Python 3 | ||
* PyTorch 0.4+ | ||
* Numpy 1.14.3+ | ||
|
||
## System Architecture | ||
|
||
![NeuralClassifier Architecture](readme/deeptext_arc.png) | ||
|
||
|
||
## Usage | ||
|
||
### Training | ||
|
||
python train.py conf/train.json | ||
|
||
***Detail configurations and explanations see [Configuration](readme/Configuration.md).*** | ||
|
||
The training info will be outputted in standard output and log.logger\_file. | ||
|
||
### Evaluation | ||
python eval.py conf/train.json | ||
|
||
* if eval.is\_flat = false, hierarchical evaluation will be outputted. | ||
* eval.model\_dir is the model to evaluate. | ||
* data.test\_json\_files is the input text file to evaluate. | ||
|
||
The evaluation info will be outputed in eval.dir. | ||
|
||
## Input Data Format | ||
|
||
JSON example: | ||
|
||
{ | ||
"doc_label": ["Computer--MachineLearning--DeepLearning", "Neuro--ComputationalNeuro"], | ||
"doc_token": ["I", "love", "deep", "learning"], | ||
"doc_keyword": ["deep learning"], | ||
"doc_topic": ["AI", "Machine learning"] | ||
} | ||
|
||
"doc_keyword" and "doc_topic" are optional. | ||
|
||
## Performance | ||
|
||
### 0. Dataset | ||
|
||
<table> | ||
<tr><th>Dataset<th>Taxonomy<th>#Label<th>#Training<th>#Test | ||
<tr><td>RCV1<td>Tree<td>103<td>23,149<td>781,265 | ||
<tr><td>Yelp<td>DAG<td>539<td>87,375<td>37,265 | ||
</table> | ||
|
||
* RCV1: [Lewis et al., 2004](http://www.jmlr.org/papers/volume5/lewis04a/lewis04a.pdf) | ||
* Yelp: [Yelp](https://www.yelp.com/dataset/challenge) | ||
|
||
### 1. Compare with state-of-the-art | ||
<table> | ||
<tr><th>Text Encoders<th>Micro-F1 on RCV1<th>Micro-F1 on Yelp | ||
<tr><td>HR-DGCNN (Peng et al., 2018)<td>0.7610<td>- | ||
<tr><td>HMCN (Wehrmann et al., 2018)<td>0.8080<td>0.6640 | ||
<tr><td>Ours<td><strong>0.8313</strong><td><strong>0.6704</strong> | ||
</table> | ||
|
||
* HR-DGCNN: [Peng et al., 2018](http://www.cse.ust.hk/~yqsong/papers/2018-WWW-Text-GraphCNN.pdf) | ||
* HMCN: [Wehrmann et al., 2018](http://proceedings.mlr.press/v80/wehrmann18a/wehrmann18a.pdf) | ||
|
||
### 2. Different text encoders | ||
|
||
<table> | ||
<tr><th row_span='2'>Text Encoders<th colspan='2'>RCV1<th colspan='2'>Yelp | ||
<tr><td><th>Micro-F1<th>Macro-F1<th>Micro-F1<th>Macro-F1 | ||
<tr><td>TextCNN<td>0.7717<td>0.5246<td>0.6281<td>0.3657 | ||
<tr><td>TextRNN<td>0.8152<td>0.5458<td><strong>0.6704</strong><td>0.4059 | ||
<tr><td>RCNN<td><strong>0.8313</strong><td><strong>0.6047</strong><td>0.6569<td>0.3951 | ||
<tr><td>FastText<td>0.6887<td>0.2701 <td>0.6031<td>0.2323 | ||
<tr><td>DRNN<td>0.7846 <td>0.5147<td>0.6579<td>0.4401 | ||
<tr><td>DPCNN<td>0.8220 <td>0.5609 <td>0.5671 <td>0.2393 | ||
<tr><td>VDCNN<td>0.7263 <td>0.3860<td>0.6395<td>0.4035 | ||
<tr><td>AttentiveConvNet<td>0.7533<td>0.4373<td>0.6367<td>0.4040 | ||
<tr><td>RegionEmbedding<td>0.7780 <td>0.4888 <td>0.6601<td><strong>0.4514</strong> | ||
<tr><td>Transformer<td>0.7603 <td>0.4274<td>0.6533<td>0.4121 | ||
<tr><td>Star-Transformer<td>0.7668 <td>0.4840<td>0.6482<td>0.3895 | ||
|
||
</table> | ||
|
||
### 3. Hierarchical vs Flat | ||
|
||
<table> | ||
<tr><th row_span='2'>Text Encoders<th colspan='2'>Hierarchical<th colspan='2'>Flat | ||
<tr><td><th>Micro-F1<th>Macro-F1<th>Micro-F1<th>Macro-F1 | ||
<tr><td>TextCNN<td>0.7717<td>0.5246<td>0.7367<td>0.4224 | ||
<tr><td>TextRNN<td>0.8152<td>0.5458<td>0.7546 <td>0.4505 | ||
<tr><td>RCNN<td><strong>0.8313</strong><td><strong>0.6047</strong><td><strong>0.7955</strong><td><strong>0.5123</strong> | ||
<tr><td>FastText<td>0.6887<td>0.2701 <td>0.6865<td>0.2816 | ||
<tr><td>DRNN<td>0.7846 <td>0.5147<td>0.7506<td>0.4450 | ||
<tr><td>DPCNN<td>0.8220 <td>0.5609 <td>0.7423 <td>0.4261 | ||
<tr><td>VDCNN<td>0.7263 <td>0.3860<td>0.7110<td>0.3593 | ||
<tr><td>AttentiveConvNet<td>0.7533<td>0.4373<td>0.7511<td>0.4286 | ||
<tr><td>RegionEmbedding<td>0.7780 <td>0.4888 <td>0.7640<td>0.4617 | ||
<tr><td>Transformer<td>0.7603 <td>0.4274<td>0.7602<td>0.4339 | ||
<tr><td>Star-Transformer<td>0.7668 <td>0.4840<td>0.7618<td>0.4745 | ||
</table> | ||
|
||
## Acknowledgement | ||
|
||
Some public codes are referenced by our toolkit: | ||
|
||
* https://pytorch.org/docs/stable/ | ||
* https://github.com/jadore801120/attention-is-all-you-need-pytorch/ | ||
* https://github.com/Hsuxu/FocalLoss-PyTorch | ||
* https://github.com/Shawn1993/cnn-text-classification-pytorch | ||
* https://github.com/ailias/Focal-Loss-implement-on-Tensorflow/ | ||
* https://github.com/brightmart/text_classification | ||
* https://github.com/NLPLearn/QANet | ||
* https://github.com/huggingface/pytorch-pretrained-BERT | ||
|
||
## Update | ||
|
||
* 2019-04-29, init version |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
{ | ||
"task_info":{ | ||
"label_type": "multi_label", | ||
"hierarchical": true, | ||
"hierar_taxonomy": "data/rcv1.taxonomy", | ||
"hierar_penalty": 0.000001 | ||
}, | ||
"device": "cuda", | ||
"model_name": "TextCNN", | ||
"checkpoint_dir": "checkpoint_dir_rcv1", | ||
"model_dir": "trained_model_rcv1", | ||
"data": { | ||
"train_json_files": [ | ||
"data/rcv1_train.json" | ||
], | ||
"validate_json_files": [ | ||
"data/rcv1_dev.json" | ||
], | ||
"test_json_files": [ | ||
"data/rcv1_test.json" | ||
], | ||
"generate_dict_using_json_files": true, | ||
"generate_dict_using_all_json_files": true, | ||
"generate_dict_using_pretrained_embedding": false, | ||
"dict_dir": "dict_rcv1", | ||
"num_worker": 4 | ||
}, | ||
"feature": { | ||
"feature_names": [ | ||
"token" | ||
], | ||
"min_token_count": 2, | ||
"min_char_count": 2, | ||
"token_ngram": 0, | ||
"min_token_ngram_count": 0, | ||
"min_keyword_count": 0, | ||
"min_topic_count": 2, | ||
"max_token_dict_size": 1000000, | ||
"max_char_dict_size": 150000, | ||
"max_token_ngram_dict_size": 10000000, | ||
"max_keyword_dict_size": 100, | ||
"max_topic_dict_size": 100, | ||
"max_token_len": 256, | ||
"max_char_len": 1024, | ||
"max_char_len_per_token": 4, | ||
"token_pretrained_file": "", | ||
"keyword_pretrained_file": "" | ||
}, | ||
"train": { | ||
"batch_size": 64, | ||
"start_epoch": 1, | ||
"num_epochs": 5, | ||
"num_epochs_static_embedding": 0, | ||
"decay_steps": 1000, | ||
"decay_rate": 1.0, | ||
"clip_gradients": 100.0, | ||
"l2_lambda": 0.0, | ||
"loss_type": "BCEWithLogitsLoss", | ||
"sampler": "fixed", | ||
"num_sampled": 5, | ||
"visible_device_list": "0", | ||
"hidden_layer_dropout": 0.5 | ||
}, | ||
"embedding": { | ||
"type": "embedding", | ||
"dimension": 64, | ||
"region_embedding_type": "context_word", | ||
"region_size": 5, | ||
"initializer": "uniform", | ||
"fan_mode": "FAN_IN", | ||
"uniform_bound": 0.25, | ||
"random_stddev": 0.01, | ||
"dropout": 0.0 | ||
}, | ||
"optimizer": { | ||
"optimizer_type": "Adam", | ||
"learning_rate": 0.008, | ||
"adadelta_decay_rate": 0.95, | ||
"adadelta_epsilon": 1e-08 | ||
}, | ||
"TextCNN": { | ||
"kernel_sizes": [ | ||
2, | ||
3, | ||
4 | ||
], | ||
"num_kernels": 100, | ||
"top_k_max_pooling": 1 | ||
}, | ||
"TextRNN": { | ||
"hidden_dimension": 64, | ||
"rnn_type": "GRU", | ||
"num_layers": 1, | ||
"doc_embedding_type": "Attention", | ||
"attention_dimension": 16, | ||
"bidirectional": true | ||
}, | ||
"DRNN": { | ||
"hidden_dimension": 5, | ||
"window_size": 3, | ||
"rnn_type": "GRU", | ||
"bidirectional": true, | ||
"cell_hidden_dropout": 0.1 | ||
}, | ||
"eval": { | ||
"text_file": "data/rcv1_test.json", | ||
"threshold": 0.5, | ||
"dir": "eval_dir", | ||
"batch_size": 1024, | ||
"is_flat": true, | ||
"top_k": 100, | ||
"model_dir": "checkpoint_dir_rcv1/TextCNN" | ||
}, | ||
"TextVDCNN": { | ||
"vdcnn_depth": 9, | ||
"top_k_max_pooling": 8 | ||
}, | ||
"DPCNN": { | ||
"kernel_size": 3, | ||
"pooling_stride": 2, | ||
"num_kernels": 16, | ||
"blocks": 2 | ||
}, | ||
"TextRCNN": { | ||
"kernel_sizes": [ | ||
2, | ||
3, | ||
4 | ||
], | ||
"num_kernels": 100, | ||
"top_k_max_pooling": 1, | ||
"hidden_dimension":64, | ||
"rnn_type": "GRU", | ||
"num_layers": 1, | ||
"bidirectional": true | ||
}, | ||
"Transformer": { | ||
"d_inner": 128, | ||
"d_k": 32, | ||
"d_v": 32, | ||
"n_head": 4, | ||
"n_layers": 1, | ||
"dropout": 0.1, | ||
"use_star": true | ||
}, | ||
"AttentiveConvNet": { | ||
"attention_type": "bilinear", | ||
"margin_size": 3, | ||
"type": "advanced", | ||
"hidden_size": 64 | ||
}, | ||
"log": { | ||
"logger_file": "log_test_rcv1_hierar", | ||
"log_level": "warn" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#!usr/bin/env python | ||
# coding:utf-8 | ||
""" | ||
Tencent is pleased to support the open source community by making NeuralClassifier available. | ||
Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. | ||
Licensed under the MIT License (the "License"); you may not use this file except in compliance | ||
with the License. You may obtain a copy of the License at | ||
http://opensource.org/licenses/MIT | ||
Unless required by applicable law or agreed to in writing, software distributed under the License | ||
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
or implied. See the License for thespecific language governing permissions and limitations under | ||
the License. | ||
""" | ||
|
||
import json | ||
|
||
|
||
class Config(object): | ||
"""Config load from json file | ||
""" | ||
|
||
def __init__(self, config=None, config_file=None): | ||
if config_file: | ||
with open(config_file, 'r') as fin: | ||
config = json.load(fin) | ||
|
||
self.dict = config | ||
if config: | ||
self._update(config) | ||
|
||
def __getitem__(self, key): | ||
return self.dict[key] | ||
|
||
def __contains__(self, item): | ||
return item in self.dict | ||
|
||
def items(self): | ||
return self.dict.items() | ||
|
||
def add(self, key, value): | ||
"""Add key value pair | ||
""" | ||
self.__dict__[key] = value | ||
|
||
def _update(self, config): | ||
if not isinstance(config, dict): | ||
return | ||
|
||
for key in config: | ||
if isinstance(config[key], dict): | ||
config[key] = Config(config[key]) | ||
|
||
if isinstance(config[key], list): | ||
config[key] = [Config(x) if isinstance(x, dict) else x for x in | ||
config[key]] | ||
|
||
self.__dict__.update(config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
Root CCAT ECAT GCAT MCAT | ||
CCAT C11 C12 C13 C14 C15 C16 C17 C18 C21 C22 C23 C24 C31 C32 C33 C34 C41 C42 | ||
C15 C151 C152 | ||
C151 C1511 | ||
C17 C171 C172 C173 C174 | ||
C18 C181 C182 C183 | ||
C31 C311 C312 C313 | ||
C33 C331 | ||
C41 C411 | ||
ECAT E11 E12 E13 E14 E21 E31 E41 E51 E61 E71 | ||
E12 E121 | ||
E13 E131 E132 | ||
E14 E141 E142 E143 | ||
E21 E211 E212 | ||
E31 E311 E312 E313 | ||
E41 E411 | ||
E51 E511 E512 E513 | ||
GCAT G15 GCRIM GDEF GDIP GDIS GENT GENV GFAS GHEA GJOB GMIL GOBIT GODD GPOL GPRO GREL GSCI GSPO GTOUR GVIO GVOTE GWEA GWELF | ||
G15 G151 G152 G153 G154 G155 G156 G157 G158 G159 | ||
MCAT M11 M12 M13 M14 | ||
M13 M131 M132 | ||
M14 M141 M142 M143 |
Oops, something went wrong.