Skip to content

Commit

Permalink
Add hierarchical toy data and fix bugs. (Tencent#82)
Browse files Browse the repository at this point in the history
* add hierar rcv1 data and fix bug

* Update README.md

* Update README.md

* fix config example

* Update Configuration.md

* Update Configuration.md

* remove unused annotation

* remove unused config parameter

Co-authored-by: perrypyli <[email protected]>
  • Loading branch information
coderbyr and coderbyr authored Mar 29, 2021
1 parent 04b937a commit 8fbf95f
Show file tree
Hide file tree
Showing 10 changed files with 31,461 additions and 8 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,18 @@ NeuralClassifier is designed for quick implementation of neural models for hiera

### Training

python train.py conf/train.json
#### How to train a non-hierarchical classifier
python train.py conf/train.json
* set task_info.hierarchical = false.
* model_name can be one of `FastText`, `TextCNN`, `TextRNN`, `TextRCNN`, `DRNN`, `VDCNN`, `DPCNN`, `AttentiveConvNet`, `Transformer`.
#### How to train a hierarchical classifier using a hierarchical penalty
python train.py conf/train.hierar.json
* set task_info.hierarchical = true.
* model_name can be one of `FastText`, `TextCNN`, `TextRNN`, `TextRCNN`, `DRNN`, `VDCNN`, `DPCNN`, `AttentiveConvNet`, `Transformer`.
#### How to train a hierarchical classifier with HMCN
python train.py conf/train.hmcn.json
* set task_info.hierarchical = false.
* set model_name = `HMCN`.

***For detailed configurations and explanations, see [Configuration](readme/Configuration.md).***

Expand Down
161 changes: 161 additions & 0 deletions conf/train.hierar.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
{
"task_info": {
"label_type": "multi_label",
"hierarchical": true,
"hierar_taxonomy": "data/rcv1.taxonomy",
"hierar_penalty": 0.000001
},
"device": "cuda",
"model_name": "TextRNN",
"checkpoint_dir": "checkpoint_dir_rcv1",
"model_dir": "trained_model_rcv1",
"data": {
"train_json_files": [
"data/rcv1_merged.hierar.json"
],
"validate_json_files": [
"data/rcv1_test.hierar.json"
],
"test_json_files": [
"data/rcv1_test.hierar.json"
],
"generate_dict_using_json_files": true,
"generate_dict_using_all_json_files": true,
"generate_dict_using_pretrained_embedding": false,
"generate_hierarchy_label": true,
"dict_dir": "dict_rcv1",
"num_worker": 4
},
"feature": {
"feature_names": [
"token"
],
"min_token_count": 2,
"min_char_count": 2,
"token_ngram": 0,
"min_token_ngram_count": 0,
"min_keyword_count": 0,
"min_topic_count": 2,
"max_token_dict_size": 1000000,
"max_char_dict_size": 150000,
"max_token_ngram_dict_size": 10000000,
"max_keyword_dict_size": 100,
"max_topic_dict_size": 100,
"max_token_len": 256,
"max_char_len": 1024,
"max_char_len_per_token": 4,
"token_pretrained_file": "",
"keyword_pretrained_file": ""
},
"train": {
"batch_size": 64,
"start_epoch": 1,
"num_epochs": 50,
"num_epochs_static_embedding": 0,
"decay_steps": 1000,
"decay_rate": 1.0,
"clip_gradients": 100.0,
"l2_lambda": 0.0,
"loss_type": "BCEWithLogitsLoss",
"sampler": "fixed",
"num_sampled": 5,
"visible_device_list": "0",
"hidden_layer_dropout": 0.5
},
"embedding": {
"type": "embedding",
"dimension": 64,
"region_embedding_type": "context_word",
"region_size": 5,
"initializer": "uniform",
"fan_mode": "FAN_IN",
"uniform_bound": 0.25,
"random_stddev": 0.01,
"dropout": 0.0
},
"optimizer": {
"optimizer_type": "Adam",
"learning_rate": 0.008,
"adadelta_decay_rate": 0.95,
"adadelta_epsilon": 1e-08
},
"TextCNN": {
"kernel_sizes": [
2,
3,
4
],
"num_kernels": 100,
"top_k_max_pooling": 1
},
"TextRNN": {
"hidden_dimension": 64,
"rnn_type": "GRU",
"num_layers": 1,
"doc_embedding_type": "Attention",
"attention_dimension": 16,
"bidirectional": true
},
"DRNN": {
"hidden_dimension": 5,
"window_size": 3,
"rnn_type": "GRU",
"bidirectional": true,
"cell_hidden_dropout": 0.1
},
"eval": {
"text_file": "data/rcv1_test.json",
"threshold": 0.5,
"dir": "eval_dir",
"batch_size": 1024,
"is_flat": true,
"top_k": 100,
"model_dir": "checkpoint_dir_rcv1/TextCNN_best"
},
"TextVDCNN": {
"vdcnn_depth": 9,
"top_k_max_pooling": 8
},
"DPCNN": {
"kernel_size": 3,
"pooling_stride": 2,
"num_kernels": 16,
"blocks": 2
},
"TextRCNN": {
"kernel_sizes": [
2,
3,
4
],
"num_kernels": 100,
"top_k_max_pooling": 1,
"hidden_dimension":64,
"rnn_type": "GRU",
"num_layers": 1,
"bidirectional": true
},
"Transformer": {
"d_inner": 128,
"d_k": 32,
"d_v": 32,
"n_head": 4,
"n_layers": 1,
"dropout": 0.1,
"use_star": true
},
"AttentiveConvNet": {
"attention_type": "bilinear",
"margin_size": 3,
"type": "advanced",
"hidden_size": 64
},
"HMCN": {
"hierarchical_depth": [0, 384, 384, 384, 384],
"global2local": [0, 4, 55, 43, 1]
},
"log": {
"logger_file": "log_test_rcv1_hierar",
"log_level": "warn"
}
}
161 changes: 161 additions & 0 deletions conf/train.hmcn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
{
"task_info": {
"label_type": "multi_label",
"hierarchical": false,
"hierar_taxonomy": "data/rcv1.taxonomy",
"hierar_penalty": 1e-5
},
"device": "cuda",
"model_name": "HMCN",
"checkpoint_dir": "checkpoint_dir_rcv1",
"model_dir": "trained_model_rcv1",
"data": {
"train_json_files": [
"data/rcv1_merged.hierar.json"
],
"validate_json_files": [
"data/rcv1_test.hierar.json"
],
"test_json_files": [
"data/rcv1_test.hierar.json"
],
"generate_dict_using_json_files": true,
"generate_dict_using_all_json_files": true,
"generate_dict_using_pretrained_embedding": false,
"generate_hierarchy_label": true,
"dict_dir": "dict_rcv1",
"num_worker": 4
},
"feature": {
"feature_names": [
"token"
],
"min_token_count": 2,
"min_char_count": 2,
"token_ngram": 0,
"min_token_ngram_count": 0,
"min_keyword_count": 0,
"min_topic_count": 2,
"max_token_dict_size": 1000000,
"max_char_dict_size": 150000,
"max_token_ngram_dict_size": 10000000,
"max_keyword_dict_size": 100,
"max_topic_dict_size": 100,
"max_token_len": 256,
"max_char_len": 1024,
"max_char_len_per_token": 4,
"token_pretrained_file": "",
"keyword_pretrained_file": ""
},
"train": {
"batch_size": 64,
"start_epoch": 1,
"num_epochs": 50,
"num_epochs_static_embedding": 0,
"decay_steps": 1000,
"decay_rate": 1.0,
"clip_gradients": 100.0,
"l2_lambda": 0.0,
"loss_type": "BCEWithLogitsLoss",
"sampler": "fixed",
"num_sampled": 5,
"visible_device_list": "0",
"hidden_layer_dropout": 0.5
},
"embedding": {
"type": "embedding",
"dimension": 64,
"region_embedding_type": "context_word",
"region_size": 5,
"initializer": "uniform",
"fan_mode": "FAN_IN",
"uniform_bound": 0.25,
"random_stddev": 0.01,
"dropout": 0.0
},
"optimizer": {
"optimizer_type": "Adam",
"learning_rate": 0.008,
"adadelta_decay_rate": 0.95,
"adadelta_epsilon": 1e-08
},
"TextCNN": {
"kernel_sizes": [
2,
3,
4
],
"num_kernels": 100,
"top_k_max_pooling": 1
},
"TextRNN": {
"hidden_dimension": 64,
"rnn_type": "GRU",
"num_layers": 1,
"doc_embedding_type": "Attention",
"attention_dimension": 16,
"bidirectional": true
},
"DRNN": {
"hidden_dimension": 5,
"window_size": 3,
"rnn_type": "GRU",
"bidirectional": true,
"cell_hidden_dropout": 0.1
},
"eval": {
"text_file": "data/rcv1_test.json",
"threshold": 0.5,
"dir": "eval_dir",
"batch_size": 1024,
"is_flat": true,
"top_k": 100,
"model_dir": "checkpoint_dir_rcv1/TextCNN_best"
},
"TextVDCNN": {
"vdcnn_depth": 9,
"top_k_max_pooling": 8
},
"DPCNN": {
"kernel_size": 3,
"pooling_stride": 2,
"num_kernels": 16,
"blocks": 2
},
"TextRCNN": {
"kernel_sizes": [
2,
3,
4
],
"num_kernels": 100,
"top_k_max_pooling": 1,
"hidden_dimension":64,
"rnn_type": "GRU",
"num_layers": 1,
"bidirectional": true
},
"Transformer": {
"d_inner": 128,
"d_k": 32,
"d_v": 32,
"n_head": 4,
"n_layers": 1,
"dropout": 0.1,
"use_star": true
},
"AttentiveConvNet": {
"attention_type": "bilinear",
"margin_size": 3,
"type": "advanced",
"hidden_size": 64
},
"HMCN": {
"hierarchical_depth": [0, 384, 384, 384, 384],
"global2local": [0, 4, 55, 43, 1]
},
"log": {
"logger_file": "log_test_rcv1_hierar",
"log_level": "warn"
}
}
3 changes: 1 addition & 2 deletions conf/train.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
{
"task_info":{
"label_type": "multi_label",
"hierarchical": true,
"hierarchical_model": false,
"hierarchical": false,
"hierar_taxonomy": "data/rcv1.taxonomy",
"hierar_penalty": 0.000001
},
Expand Down
Loading

0 comments on commit 8fbf95f

Please sign in to comment.