Skip to content

Commit

Permalink
Non-resizable textcats to legacy (#10)
Browse files Browse the repository at this point in the history
* non-resizable textcat architectures

* TextCatEnsemble_v2 is in core spacy

* consult registry instead of using imports

* bump to 3.0.6

* add the two new architectures to the test suite
  • Loading branch information
svlandeg authored Jun 14, 2021
1 parent 93b3166 commit 6f2b82e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 4 deletions.
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[metadata]
version = 3.0.5
version = 3.0.6
description = Legacy registered functions for spaCy backwards compatibility
url = https://spacy.io
author = Explosion
Expand Down Expand Up @@ -37,6 +37,8 @@ spacy_architectures =
spacy-legacy.Tok2Vec.v1 = spacy_legacy.architectures.tok2vec:Tok2Vec_v1
spacy-legacy.MaxoutWindowEncoder.v1 = spacy_legacy.architectures.tok2vec:MaxoutWindowEncoder_v1
spacy-legacy.MishWindowEncoder.v1 = spacy_legacy.architectures.tok2vec:MishWindowEncoder_v1
spacy-legacy.TextCatCNN.v1 = spacy_legacy.architectures.textcat:TextCatCNN_v1
spacy-legacy.TextCatBOW.v1 = spacy_legacy.architectures.textcat:TextCatBOW_v1
spacy-legacy.TextCatEnsemble.v1 = spacy_legacy.architectures.textcat:TextCatEnsemble_v1
spacy-legacy.HashEmbedCNN.v1 = spacy_legacy.architectures.tok2vec:HashEmbedCNN_v1
spacy-legacy.MultiHashEmbed.v1 = spacy_legacy.architectures.tok2vec:MultiHashEmbed_v1
Expand Down
66 changes: 64 additions & 2 deletions spacy_legacy/architectures/textcat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,69 @@
from typing import Optional
from thinc.api import Model
from typing import Optional, List
from thinc.types import Floats2d
from thinc.api import Model, with_cpu
from spacy.attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
from spacy.util import registry
from spacy.tokens import Doc

# TODO: replace with registered layer after spacy v3.0.7
from spacy.ml import extract_ngrams


def TextCatCNN_v1(
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]:
"""
Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the
outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
is applied instead, so that outputs are in the range [0, 1].
"""
chain = registry.get("layers", "chain.v1")
reduce_mean = registry.get("layers", "reduce_mean.v1")
Logistic = registry.get("layers", "Logistic.v1")
Softmax = registry.get("layers", "Softmax.v1")
Linear = registry.get("layers", "Linear.v1")
list2ragged = registry.get("layers", "list2ragged.v1")

# extract_ngrams = registry.get("layers", "spacy.extract_ngrams.v1")

with Model.define_operators({">>": chain}):
cnn = tok2vec >> list2ragged() >> reduce_mean()
if exclusive_classes:
output_layer = Softmax(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
model = cnn >> output_layer
model.set_ref("output_layer", output_layer)
else:
linear_layer = Linear(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
model = cnn >> linear_layer >> Logistic()
model.set_ref("output_layer", linear_layer)
model.set_ref("tok2vec", tok2vec)
model.set_dim("nO", nO)
model.attrs["multi_label"] = not exclusive_classes
return model


def TextCatBOW_v1(
exclusive_classes: bool,
ngram_size: int,
no_output_layer: bool,
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
chain = registry.get("layers", "chain.v1")
Logistic = registry.get("layers", "Logistic.v1")
SparseLinear = registry.get("layers", "SparseLinear.v1")
softmax_activation = registry.get("layers", "softmax_activation.v1")

with Model.define_operators({">>": chain}):
sparse_linear = SparseLinear(nO)
model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
model = with_cpu(model, model.ops)
if not no_output_layer:
output_layer = softmax_activation() if exclusive_classes else Logistic()
model = model >> with_cpu(output_layer, output_layer.ops)
model.set_ref("output_layer", sparse_linear)
model.attrs["multi_label"] = not exclusive_classes
return model


def TextCatEnsemble_v1(
Expand Down
5 changes: 4 additions & 1 deletion spacy_legacy/tests/test_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
("architectures", "Tok2Vec.v1"),
("architectures", "MaxoutWindowEncoder.v1"),
("architectures", "MishWindowEncoder.v1"),
("architectures", "TextCatBOW.v1"),
("architectures", "TextCatCNN.v1"),
("architectures", "TextCatEnsemble.v1"),
("architectures", "HashEmbedCNN.v1"),
("architectures", "MultiHashEmbed.v1"),
("architectures", "CharacterEmbed.v1"),
("loggers", "WandbLogger.v1"),
("layers", "StaticVectors.v1")
("layers", "StaticVectors.v1"),
]


@pytest.mark.parametrize("package", PACKAGES)
@pytest.mark.parametrize("reg_name,name", FUNCTIONS)
def test_registry(package, reg_name, name):
Expand Down

0 comments on commit 6f2b82e

Please sign in to comment.