-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathflairs_spam.py
16 lines (16 loc) · 955 Bytes
/
flairs_spam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from flair.data_fetcher import NLPTaskDataFetcher
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentLSTMEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Path
import flair, torch
# Force flair to run on the CPU.
flair.device = torch.device('cpu')

# Load the classification corpus from ./data/, mapping each split to its
# same-named file: train on train.csv, tune on dev.csv, evaluate on test.csv.
# NOTE(review): the data files are presumably in FastText classification
# format (one "__label__<class> <text>" per line) despite the .csv suffix —
# confirm against the data.
corpus = NLPTaskDataFetcher.load_classification_corpus(
    Path('./data/'),
    train_file='train.csv',
    dev_file='dev.csv',
    test_file='test.csv',
)

# Token-level embeddings: GloVe word vectors plus the forward and backward
# 'news-*-fast' Flair embeddings.
word_embeddings = [
    WordEmbeddings('glove'),
    FlairEmbeddings('news-forward-fast'),
    FlairEmbeddings('news-backward-fast'),
]

# Combine the token embeddings into a single document embedding with an LSTM
# (hidden size 512); word embeddings are reprojected to 256 dimensions first.
document_embeddings = DocumentLSTMEmbeddings(
    word_embeddings,
    hidden_size=512,
    reproject_words=True,
    reproject_words_dimension=256,
)

# Single-label classifier whose label set is derived from the loaded corpus.
classifier = TextClassifier(
    document_embeddings,
    label_dictionary=corpus.make_label_dictionary(),
    multi_label=False,
)

# Train for at most 20 epochs, writing training output to ./model/.
trainer = ModelTrainer(classifier, corpus)
trainer.train('./model/', max_epochs=20)