
Commit

update test qps.
shibing624 committed May 20, 2022
1 parent 5050b3b commit 616fbad
Showing 2 changed files with 54 additions and 8 deletions.
12 changes: 6 additions & 6 deletions README.md
@@ -118,21 +118,21 @@ Cross-Encoder is suited for re-ranking the results of vector retrieval.

| Arch | Backbone | Model Name | ATEC | BQ | LCQMC | PAWSX | STS-B | Avg | QPS |
| :-- | :--- | :--- | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
- | CoSENT | hfl/chinese-macbert-base | CoSENT-macbert-base | 50.39 | **72.93** | **79.17** | **60.86** | **80.51** | **68.77** | 2572 |
+ | CoSENT | hfl/chinese-macbert-base | CoSENT-macbert-base | 50.39 | **72.93** | **79.17** | **60.86** | **80.51** | **68.77** | 3008 |
| CoSENT | Langboat/mengzi-bert-base | CoSENT-mengzi-base | **50.52** | 72.27 | 78.69 | 12.89 | 80.15 | 58.90 | 2502 |
| CoSENT | bert-base-chinese | CoSENT-bert-base | 49.74 | 72.38 | 78.69 | 60.00 | 80.14 | 68.19 | 2653 |
- | SBERT | bert-base-chinese | SBERT-bert-base | 46.36 | 70.36 | 78.72 | 46.86 | 66.41 | 61.74 | 1365 |
- | SBERT | hfl/chinese-macbert-base | SBERT-macbert-base | 47.28 | 68.63 | **79.42** | 55.59 | 64.82 | 63.15 | 1948 |
+ | SBERT | bert-base-chinese | SBERT-bert-base | 46.36 | 70.36 | 78.72 | 46.86 | 66.41 | 61.74 | 3365 |
+ | SBERT | hfl/chinese-macbert-base | SBERT-macbert-base | 47.28 | 68.63 | **79.42** | 55.59 | 64.82 | 63.15 | 2948 |
| CoSENT | hfl/chinese-roberta-wwm-ext | CoSENT-roberta-ext | **50.81** | **71.45** | **79.31** | **61.56** | **81.13** | **68.85** | - |
| SBERT | hfl/chinese-roberta-wwm-ext | SBERT-roberta-ext | 48.29 | 69.99 | 79.22 | 44.10 | 72.42 | 62.80 | - |

- Chinese text-matching evaluation results for the models released by this project:

| Arch | Backbone | Model Name | ATEC | BQ | LCQMC | PAWSX | STS-B | Avg | QPS |
| :-- | :--- | :---- | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
- | Word2Vec | word2vec | w2v-light-tencent-chinese | 20.00 | 31.49 | 59.46 | 2.57 | 55.78 | 33.86 | 10283 |
- | SBERT | xlm-roberta-base | paraphrase-multilingual-MiniLM-L12-v2 | 18.42 | 38.52 | 63.96 | 10.14 | 78.90 | 41.99 | 2371 |
- | CoSENT | hfl/chinese-macbert-base | text2vec-base-chinese | 31.93 | 42.67 | 70.16 | 17.21 | 79.30 | **48.25** | 2572 |
+ | Word2Vec | word2vec | w2v-light-tencent-chinese | 20.00 | 31.49 | 59.46 | 2.57 | 55.78 | 33.86 | 23769 |
+ | SBERT | xlm-roberta-base | paraphrase-multilingual-MiniLM-L12-v2 | 18.42 | 38.52 | 63.96 | 10.14 | 78.90 | 41.99 | 3138 |
+ | CoSENT | hfl/chinese-macbert-base | text2vec-base-chinese | 31.93 | 42.67 | 70.16 | 17.21 | 79.30 | **48.25** | 3008 |

Notes:
- All reported scores are Spearman rank correlation coefficients (see the sketch below).
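
A minimal sketch of how such a Spearman score can be reproduced with a released model; the sentence pairs, gold labels, and the use of `numpy`/`scipy` below are illustrative assumptions, not part of the repository's evaluation scripts:

```python
import numpy as np
from scipy.stats import spearmanr
from text2vec import SentenceModel

# Hypothetical sentence pairs with made-up gold similarity labels (0-5 scale),
# standing in for a dev set such as STS-B.
sents1 = ["今天天气不错", "我喜欢看电影", "他在打篮球", "这家餐厅很好吃"]
sents2 = ["今天天气很好", "他讨厌看电影", "他正在打篮球", "这家餐厅的菜很难吃"]
gold = [5.0, 1.0, 4.5, 2.0]

model = SentenceModel('shibing624/text2vec-base-chinese')
emb1 = np.asarray(model.encode(sents1))
emb2 = np.asarray(model.encode(sents2))

# Cosine similarity per pair, then rank correlation against the gold labels.
cos_sim = np.sum(emb1 * emb2, axis=1) / (
    np.linalg.norm(emb1, axis=1) * np.linalg.norm(emb2, axis=1))
corr, _ = spearmanr(cos_sim, gold)
print(f"Spearman correlation: {corr:.4f}")
```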
50 changes: 48 additions & 2 deletions tests/test_qps.py
@@ -8,10 +8,16 @@
import unittest
from loguru import logger
import time
import os
import torch
from transformers import AutoTokenizer, AutoModel

sys.path.append('..')
from text2vec import Word2Vec, SentenceModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
pwd_path = os.path.abspath(os.path.dirname(__file__))
logger.add('test.log')

@@ -21,10 +21,27 @@
num_tokens = sum([len(i) for i in data])


class TransformersEncoder:
    def __init__(self, model_name='shibing624/text2vec-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)

    def encode(self, sentences):
        # Mean Pooling - take the attention mask into account for correct averaging
        def mean_pooling(model_output, attention_mask):
            token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1),
                                                                                      min=1e-9)

        # Tokenize sentences
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)

        # Compute token embeddings
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling. In this case, mean pooling.
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        return sentence_embeddings
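
# Illustrative usage sketch (not from the original file): the encoder above tokenizes a
# batch, runs the plain transformers model, and mean-pools the token embeddings with the
# attention mask, so its throughput can be compared with SentenceModel.encode in the tests
# below. The sample sentences are made up.
#
#   encoder = TransformersEncoder('shibing624/text2vec-base-chinese')
#   embeddings = encoder.encode(['第一句话', '第二句话'])
#   print(embeddings.shape)  # torch.Size([2, hidden_size]), e.g. hidden_size=768 for a BERT-base backbone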


class QPSEncoderTestCase(unittest.TestCase):
    def test_cosent_speed(self):
        """Test cosent speed"""
        logger.info("\n---- cosent:")
        model = SentenceModel('shibing624/text2vec-base-chinese')
        for j in range(10):
            tmp = data * (2 ** j)
            c_num_tokens = num_tokens * (2 ** j)
            start_t = time.time()
            r = model.encode(tmp)
            assert r is not None
            print('result shape', r.shape)
            time_t = time.time() - start_t
            logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                        (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))

    def test_origin_cosent_speed(self):
        """Test origin cosent speed"""
        logger.info("\n---- origin cosent:")
        model = TransformersEncoder('shibing624/text2vec-base-chinese')
        for j in range(10):
            tmp = data * (2 ** j)
            c_num_tokens = num_tokens * (2 ** j)
@@ -39,6 +85,7 @@ def test_cosent_speed(self):

    def test_sbert_speed(self):
        """Test sbert speed"""
        logger.info("\n---- sbert:")
        model = SentenceModel('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
        for j in range(10):
            tmp = data * (2 ** j)
@@ -48,12 +95,12 @@ def test_sbert_speed(self):
            assert r is not None
            print('result shape', r.shape)
            time_t = time.time() - start_t
            logger.info("----\nsbert:")
            logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                        (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))

    def test_w2v_speed(self):
        """Test w2v speed"""
        logger.info("\n---- w2v:")
        model = Word2Vec()
        for j in range(10):
            tmp = data * (2 ** j)
@@ -63,7 +110,6 @@ def test_w2v_speed(self):
            assert r is not None
            print('result shape', r.shape)
            time_t = time.time() - start_t
            logger.info("----\nword2vec:")
            logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                        (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
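
For reference, QPS here means sentences encoded per second of wall-clock time, which is what the samples/s log line above reports and what the QPS columns in the README tables appear to track. A minimal standalone sketch of that measurement, assuming `text2vec` is installed (the sentences and batch size are placeholders):

```python
import time
from text2vec import SentenceModel

# Placeholder sentences; the benchmark above repeats a fixed test set 2**j times.
sentences = ['这是一个测试句子', '今天天气怎么样'] * 512

model = SentenceModel('shibing624/text2vec-base-chinese')
start = time.time()
embeddings = model.encode(sentences)
elapsed = time.time() - start
# QPS here is simply sentences encoded per second of wall-clock time.
print('encoded %d sentences in %.2fs, %d samples/s (QPS)' %
      (len(sentences), elapsed, int(len(sentences) / elapsed)))
```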

