Commit

update test qps.
shibing624 committed May 20, 2022
1 parent 03b2e3b commit 5050b3b
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions tests/test_qps.py
@@ -6,66 +6,66 @@
 import os
 import sys
 import unittest
-from time import time
+from loguru import logger
+import time
 
 sys.path.append('..')
 from text2vec import Word2Vec, SentenceModel
 
-pwd_path = os.path.abspath(os.path.dirname(__file__))
-sts_test_path = os.path.join(pwd_path, '../examples/data/STS-B/STS-B.test.data')
+logger.add('test.log')
 
 
-def load_test_data(path):
-    sents1, sents2, labels = [], [], []
-    with open(path, 'r', encoding='utf8') as f:
-        for line in f:
-            line = line.strip().split('\t')
-            if len(line) != 3:
-                continue
-            sents1.append(line[0])
-            sents2.append(line[1])
-            labels.append(int(line[2]))
-            if len(sents1) > 10:
-                break
-    return sents1, sents2, labels
+data = ['如何更换花呗绑定银行卡',
+        '花呗更改绑定银行卡']
+print("data:", data)
+num_tokens = sum([len(i) for i in data])
 
 
 class QPSEncoderTestCase(unittest.TestCase):
     def test_cosent_speed(self):
         """测试cosent_speed"""
-        sents1, sents2, labels = load_test_data(sts_test_path)
-        m = SentenceModel('shibing624/text2vec-base-chinese')
-        sents = sents1 + sents2
-        print('sente size:', len(sents))
-        t1 = time()
-        m.encode(sents)
-        spend_time = time() - t1
-        print('spend time:', spend_time, ' seconds')
-        print('cosent_sbert qps:', len(sents) / spend_time)
+        model = SentenceModel('shibing624/text2vec-base-chinese')
+        for j in range(10):
+            tmp = data * (2 ** j)
+            c_num_tokens = num_tokens * (2 ** j)
+            start_t = time.time()
+            r = model.encode(tmp)
+            assert r is not None
+            print('result shape', r.shape)
+            time_t = time.time() - start_t
+            logger.info("----\ncosent:")
+            logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
+                        (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
 
     def test_sbert_speed(self):
         """测试sbert_speed"""
-        sents1, sents2, labels = load_test_data(sts_test_path)
-        m = SentenceModel()
-        sents = sents1 + sents2
-        print('sente size:', len(sents))
-        t1 = time()
-        m.encode(sents)
-        spend_time = time() - t1
-        print('spend time:', spend_time, ' seconds')
-        print('sbert qps:', len(sents) / spend_time)
+        model = SentenceModel('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
+        for j in range(10):
+            tmp = data * (2 ** j)
+            c_num_tokens = num_tokens * (2 ** j)
+            start_t = time.time()
+            r = model.encode(tmp)
+            assert r is not None
+            print('result shape', r.shape)
+            time_t = time.time() - start_t
+            logger.info("----\nsbert:")
+            logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
+                        (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
 
     def test_w2v_speed(self):
         """测试w2v_speed"""
-        sents1, sents2, labels = load_test_data(sts_test_path)
-        m = Word2Vec()
-        sents = sents1 + sents2
-        print('sente size:', len(sents))
-        t1 = time()
-        m.encode(sents)
-        spend_time = time() - t1
-        print('spend time:', spend_time, ' seconds')
-        print('w2v qps:', len(sents) / spend_time)
+        model = Word2Vec()
+        for j in range(10):
+            tmp = data * (2 ** j)
+            c_num_tokens = num_tokens * (2 ** j)
+            start_t = time.time()
+            r = model.encode(tmp)
+            assert r is not None
+            print('result shape', r.shape)
+            time_t = time.time() - start_t
+            logger.info("----\nword2vec:")
+            logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
+                        (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
 
 
 if __name__ == '__main__':
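For context, each updated test repeats the same timing pattern: double the batch each round, time one encode() call, and report samples/s plus a character-count proxy for tokens/s. The sketch below factors that pattern into a standalone helper; it is illustrative only, and the measure_throughput name and its encode_fn parameter are not part of text2vec or this commit.

import time

def measure_throughput(encode_fn, data, num_rounds=10):
    """Benchmark any encoder callable; the batch size doubles each round, as in test_qps.py."""
    num_tokens = sum(len(s) for s in data)  # character count stands in for token count
    for j in range(num_rounds):
        batch = data * (2 ** j)
        start = time.time()
        result = encode_fn(batch)
        assert result is not None
        elapsed = time.time() - start
        print('encoded %d sentences in %.2fs: %d samples/s, %d tokens/s'
              % (len(batch), elapsed, int(len(batch) / elapsed),
                 int(num_tokens * (2 ** j) / elapsed)))

A caller would pass the encoder's bound method, e.g. measure_throughput(SentenceModel('shibing624/text2vec-base-chinese').encode, data).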
