Skip to content

Commit

Permalink
update test qps.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed May 20, 2022
1 parent 616fbad commit fadc7fe
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions tests/test_qps.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
'花呗更改绑定银行卡']
print("data:", data)
num_tokens = sum([len(i) for i in data])
use_cuda = torch.cuda.is_available()
repeat = 10 if use_cuda else 1


class TransformersEncoder:
def test_cosent_speed(self):
    """Benchmark encoding QPS of the CoSENT-based SentenceModel.

    Doubles the batch each iteration (data * 2**j) and logs sentences/s and
    tokens/s. `repeat` is 10 on CUDA, 1 on CPU (set at module level), so the
    CPU run stays fast.
    NOTE(review): reconstructed from a diff view — both pre- and post-change
    lines were present in the source; this is the post-commit version.
    """
    logger.info("\n---- cosent:")
    model = SentenceModel('shibing624/text2vec-base-chinese')
    for j in range(repeat):
        # Double the workload each round to measure throughput scaling.
        tmp = data * (2 ** j)
        c_num_tokens = num_tokens * (2 ** j)
        start_t = time.time()
        r = model.encode(tmp)
        assert r is not None
        # Log the embedding shape/preview only once to keep output readable.
        if j == 0:
            logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
        time_t = time.time() - start_t
        logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                    (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
def test_origin_cosent_speed(self):
    """Benchmark encoding QPS of the raw Transformers-based encoder.

    Same doubling scheme as test_cosent_speed, but using the plain
    TransformersEncoder wrapper instead of SentenceModel, for comparison.
    NOTE(review): reconstructed from a diff view — both pre- and post-change
    lines were present in the source; this is the post-commit version.
    """
    logger.info("\n---- origin cosent:")
    model = TransformersEncoder('shibing624/text2vec-base-chinese')
    for j in range(repeat):
        # Double the workload each round to measure throughput scaling.
        tmp = data * (2 ** j)
        c_num_tokens = num_tokens * (2 ** j)
        start_t = time.time()
        r = model.encode(tmp)
        assert r is not None
        # Log the embedding shape/preview only once to keep output readable.
        if j == 0:
            logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
        time_t = time.time() - start_t
        logger.info("----\ncosent:")
        logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                    (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))

def test_sbert_speed(self):
    """Benchmark encoding QPS of a Sentence-BERT model.

    Uses the multilingual MiniLM paraphrase model; workload doubles each
    iteration and throughput (samples/s, tokens/s) is logged per round.
    NOTE(review): reconstructed from a diff view — both pre- and post-change
    lines were present in the source; this is the post-commit version.
    """
    logger.info("\n---- sbert:")
    model = SentenceModel('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    for j in range(repeat):
        # Double the workload each round to measure throughput scaling.
        tmp = data * (2 ** j)
        c_num_tokens = num_tokens * (2 ** j)
        start_t = time.time()
        r = model.encode(tmp)
        assert r is not None
        # Log the embedding shape/preview only once to keep output readable.
        if j == 0:
            logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
        time_t = time.time() - start_t
        logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                    (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
def test_w2v_speed(self):
    """Benchmark encoding QPS of the Word2Vec model.

    Same doubling benchmark as the transformer tests; Word2Vec is the
    CPU-friendly baseline.
    NOTE(review): reconstructed from a diff view — both pre- and post-change
    lines were present in the source; this is the post-commit version.
    Assumes Word2Vec().encode returns an array-like with a .shape — grounded
    by the r.shape usage here; confirm against the Word2Vec implementation.
    """
    logger.info("\n---- w2v:")
    model = Word2Vec()
    for j in range(repeat):
        # Double the workload each round to measure throughput scaling.
        tmp = data * (2 ** j)
        c_num_tokens = num_tokens * (2 ** j)
        start_t = time.time()
        r = model.encode(tmp)
        assert r is not None
        # Log the embedding shape/preview only once to keep output readable.
        if j == 0:
            logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
        time_t = time.time() - start_t
        logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                    (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
Expand Down

0 comments on commit fadc7fe

Please sign in to comment.