-
Notifications
You must be signed in to change notification settings - Fork 48
/
multilabel_pipeline.py
50 lines (45 loc) · 1.47 KB
/
multilabel_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Union, Optional
import numpy as np
from transformers.pipelines import ArgumentHandler
from transformers import (
Pipeline,
PreTrainedTokenizer,
ModelCard
)
class MultiLabelPipeline(Pipeline):
    """Pipeline that turns per-label model outputs into multi-label predictions.

    Each output value is passed through an independent sigmoid, and every
    label whose resulting score is strictly greater than ``threshold`` is
    reported. Label names are read from ``model.config.id2label``.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: Optional[ArgumentHandler] = None,
        device: int = -1,
        binary_output: bool = False,
        threshold: float = 0.3,
    ):
        """Forward the standard arguments to ``Pipeline`` and store the cut-off.

        Args:
            model: Model handed to the parent ``Pipeline``; its
                ``config.id2label`` is used to name the predicted labels.
            tokenizer: Tokenizer handed to the parent ``Pipeline``.
            modelcard: Forwarded unchanged to the parent ``Pipeline``.
            framework: Forwarded unchanged to the parent ``Pipeline``.
            task: Forwarded unchanged to the parent ``Pipeline``.
            args_parser: Forwarded unchanged to the parent ``Pipeline``.
            device: Forwarded unchanged to the parent ``Pipeline``.
            binary_output: Forwarded unchanged to the parent ``Pipeline``.
            threshold: A label is kept only when its sigmoid score is
                strictly greater than this value.
        """
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=binary_output,
            task=task,
        )
        self.threshold = threshold

    def __call__(self, *args, **kwargs):
        """Run the parent pipeline and keep the labels scoring above the threshold.

        All arguments are passed straight through to ``Pipeline.__call__``.

        Returns:
            A list with one dict per row of the parent's output. Each dict has
            ``"labels"`` (label names from ``model.config.id2label``) and
            ``"scores"`` (the matching sigmoid scores as plain Python floats),
            both in ascending label-index order.
        """
        # NOTE(review): assumes the parent returns raw per-label outputs as a
        # 2-D array-like (one row per input) -- confirm against the installed
        # transformers version.
        logits = np.asarray(super().__call__(*args, **kwargs))
        probabilities = 1 / (1 + np.exp(-logits))  # element-wise sigmoid

        id2label = self.model.config.id2label
        threshold = self.threshold

        results = []
        for row in probabilities:
            # Distinct names here: the original reused ``scores`` for both the
            # full probability array and the per-row output list.
            kept_labels = []
            kept_scores = []
            for idx, prob in enumerate(row):
                if prob > threshold:
                    kept_labels.append(id2label[idx])
                    # Plain float so the result can be JSON-serialized.
                    kept_scores.append(float(prob))
            results.append({"labels": kept_labels, "scores": kept_scores})
        return results