Skip to content

Commit

Permalink
Merge pull request flairNLP#1335 from flairNLP/optimizations
Browse files Browse the repository at this point in the history
relax version requirements for Colab installation
  • Loading branch information
Alan Akbik authored Jan 7, 2020
2 parents f4001c7 + c2aefa2 commit 058a458
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 85 deletions.
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@ before_script: cd tests
script:
- pip freeze
- 'if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then pytest --runintegration; fi'
- 'if [ "$TRAVIS_PULL_REQUEST" = "false" ]; then pytest; fi'
- 'if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then pip install black; black --check .; fi'
- 'if [ "$TRAVIS_PULL_REQUEST" = "false" ]; then pytest; fi'
10 changes: 0 additions & 10 deletions flair/visual/training_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,8 @@
import numpy as np
import csv

import matplotlib
import math


# change from Agg to TkAgg for interactive mode
try:
# change from Agg to TkAgg for interactive mode
matplotlib.use("TkAgg")
except:
pass


import matplotlib.pyplot as plt


Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
python-dateutil==2.8.0
python-dateutil>=2.8.1
torch>=1.1.0
gensim>=3.4.0
pytest>=3.6.4
pytest>=5.3.2
tqdm>=4.26.0
segtok>=1.5.7
matplotlib>=2.2.3
mpld3==0.3
scikit-learn==0.21.3
scikit-learn>=0.21.3
sqlitedict>=1.6.0
deprecated>=1.2.4
hyperopt>=0.1.1
Expand All @@ -15,4 +15,4 @@ bpemb>=0.2.9
regex
tabulate
urllib3<1.25,>=1.20
langdetect
langdetect
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
author="Alan Akbik",
author_email="alan.akbik@zalando.de",
url="https://github.com/zalandoresearch/flair",
author_email="alan.akbik@gmail.com",
url="https://github.com/flairNLP/flair",
packages=find_packages(exclude="tests"), # same as name
license="MIT",
install_requires=required,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_keep_batch_order():

assert torch.norm(sentences_1[0].embedding - sentences_2[1].embedding) == 0.0
assert torch.norm(sentences_1[0].embedding - sentences_2[1].embedding) == 0.0
del embeddings


@pytest.mark.integration
Expand All @@ -58,6 +59,7 @@ def test_stacked_embeddings():
token.clear_embeddings()

assert len(token.get_embedding()) == 0
del embeddings


@pytest.mark.integration
Expand Down Expand Up @@ -97,6 +99,7 @@ def test_fine_tunable_flair_embedding():
sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0
del embeddings


@pytest.mark.integration
Expand All @@ -115,6 +118,7 @@ def test_document_lstm_embeddings():
sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0
del embeddings


@pytest.mark.integration
Expand All @@ -133,6 +137,7 @@ def test_document_bidirectional_lstm_embeddings():
sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0
del embeddings


@pytest.mark.integration
Expand All @@ -151,6 +156,7 @@ def test_document_pool_embeddings():
sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0
del embeddings


@pytest.mark.integration
Expand All @@ -169,6 +175,7 @@ def test_document_pool_embeddings_nonlinear():
sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0
del embeddings


def init_document_embeddings():
Expand All @@ -193,6 +200,7 @@ def load_and_apply_word_embeddings(emb_type: str):
token.clear_embeddings()

assert len(token.get_embedding()) == 0
del embeddings


def load_and_apply_char_lm_embeddings(emb_type: str):
Expand All @@ -207,3 +215,4 @@ def load_and_apply_char_lm_embeddings(emb_type: str):
token.clear_embeddings()

assert len(token.get_embedding()) == 0
del embeddings
17 changes: 5 additions & 12 deletions tests/test_hyperparameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
)
import flair.datasets

glove_embedding: WordEmbeddings = WordEmbeddings("glove")


@pytest.mark.integration
def test_sequence_tagger_param_selector(results_base_path, tasks_base_path):
Expand All @@ -27,16 +29,7 @@ def test_sequence_tagger_param_selector(results_base_path, tasks_base_path):
search_space.add(
Parameter.EMBEDDINGS,
hp.choice,
options=[
StackedEmbeddings([WordEmbeddings("glove")]),
StackedEmbeddings(
[
WordEmbeddings("glove"),
FlairEmbeddings("news-forward-fast"),
FlairEmbeddings("news-backward-fast"),
]
),
],
options=[StackedEmbeddings([glove_embedding])],
)
search_space.add(Parameter.USE_CRF, hp.choice, options=[True, False])
search_space.add(Parameter.DROPOUT, hp.uniform, low=0.25, high=0.75)
Expand All @@ -63,14 +56,13 @@ def test_sequence_tagger_param_selector(results_base_path, tasks_base_path):

# clean up results directory
shutil.rmtree(results_base_path)
del optimizer, search_space


@pytest.mark.integration
def test_text_classifier_param_selector(results_base_path, tasks_base_path):
corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")

glove_embedding: WordEmbeddings = WordEmbeddings("glove")

search_space = SearchSpace()

# document embeddings parameter
Expand All @@ -97,3 +89,4 @@ def test_text_classifier_param_selector(results_base_path, tasks_base_path):

# clean up results directory
shutil.rmtree(results_base_path)
del param_selector, search_space
2 changes: 2 additions & 0 deletions tests/test_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_generate_text_with_small_temperatures():
)
assert text is not None
assert len(text) >= 100
del language_model


def test_compute_perplexity():
Expand Down Expand Up @@ -70,3 +71,4 @@ def test_compute_perplexity():
print(f'"{ungrammatical}" - perplexity is {perplexity_ungramamtical_sentence}')

assert perplexity_gramamtical_sentence < perplexity_ungramamtical_sentence
del language_model
Loading

0 comments on commit 058a458

Please sign in to comment.