
Commit

update test qps demo.
shibing624 committed May 20, 2022
1 parent fadc7fe commit 6ebd0f4
Showing 2 changed files with 67 additions and 8 deletions.
60 changes: 56 additions & 4 deletions tests/test_qps.py
@@ -14,6 +14,7 @@

sys.path.append('..')
from text2vec import Word2Vec, SentenceModel
from sentence_transformers import SentenceTransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -53,26 +54,48 @@ def mean_pooling(model_output, attention_mask):
return sentence_embeddings


class SentenceTransformersEncoder:
def __init__(self, model_name="shibing624/text2vec-base-chinese"):
self.model = SentenceTransformer(model_name)

def encode(self, sentences, convert_to_numpy=True):
sentence_embeddings = self.model.encode(sentences, convert_to_numpy=convert_to_numpy)
return sentence_embeddings


class QPSEncoderTestCase(unittest.TestCase):
def test_cosent_speed(self):
"""测试cosent_speed"""
logger.info("\n---- cosent:")
model = SentenceModel('shibing624/text2vec-base-chinese')
logger.info(' convert_to_numpy=True:')
for j in range(repeat):
tmp = data * (2 ** j)
c_num_tokens = num_tokens * (2 ** j)
start_t = time.time()
r = model.encode(tmp)
r = model.encode(tmp, convert_to_numpy=True)
assert r is not None
if j == 0:
logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
time_t = time.time() - start_t
logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
(len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
logger.info(' convert_to_numpy=False:')
for j in range(repeat):
tmp = data * (2 ** j)
c_num_tokens = num_tokens * (2 ** j)
start_t = time.time()
r = model.encode(tmp, convert_to_numpy=False)
assert r is not None
if j == 0:
logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
time_t = time.time() - start_t
logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
(len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))

def test_origin_cosent_speed(self):
"""测试origin_cosent_speed"""
logger.info("\n---- origin cosent:")
def test_origin_transformers_speed(self):
"""测试origin_transformers_speed"""
logger.info("\n---- origin transformers:")
model = TransformersEncoder('shibing624/text2vec-base-chinese')
for j in range(repeat):
tmp = data * (2 ** j)
@@ -86,6 +109,35 @@ def test_origin_cosent_speed(self):
logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
(len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))

def test_origin_sentence_transformers_speed(self):
"""测试origin_sentence_transformers_speed"""
logger.info("\n---- origin sentence_transformers:")
model = SentenceTransformersEncoder('shibing624/text2vec-base-chinese')
logger.info(' convert_to_numpy=True:')
for j in range(repeat):
tmp = data * (2 ** j)
c_num_tokens = num_tokens * (2 ** j)
start_t = time.time()
r = model.encode(tmp, convert_to_numpy=True)
assert r is not None
if j == 0:
logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
time_t = time.time() - start_t
logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
(len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
logger.info(' convert_to_numpy=False:')
for j in range(repeat):
tmp = data * (2 ** j)
c_num_tokens = num_tokens * (2 ** j)
start_t = time.time()
r = model.encode(tmp, convert_to_numpy=False)
assert r is not None
if j == 0:
logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
time_t = time.time() - start_t
logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
(len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))

def test_sbert_speed(self):
"""测试sbert_speed"""
logger.info("\n---- sbert:")
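The updated tests all use the same throughput loop: double the workload each round, encode it, and report samples/s and tokens/s. A minimal standalone sketch of that pattern; the sentence list, num_tokens, and repeat count below are illustrative (the real test builds them from a data file):

import time

from text2vec import SentenceModel

# Illustrative workload; the actual test loads `data` and counts `num_tokens` from a file.
data = ['如何更换花呗绑定银行卡'] * 64
num_tokens = sum(len(s) for s in data)
repeat = 4

model = SentenceModel('shibing624/text2vec-base-chinese')
for j in range(repeat):
    tmp = data * (2 ** j)                 # double the number of sentences each round
    c_num_tokens = num_tokens * (2 ** j)  # token count scales with the workload
    start_t = time.time()
    r = model.encode(tmp, convert_to_numpy=True)
    time_t = time.time() - start_t
    print('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
          (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))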
15 changes: 11 additions & 4 deletions text2vec/sentence_model.py
@@ -127,16 +127,18 @@ def get_sentence_embeddings(self, input_ids, attention_mask, token_type_ids):
def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 32,
batch_size: int = 64,
show_progress_bar: bool = False,
convert_to_numpy: bool = True,
device: str = None,
):
):
"""
Returns the embeddings for a batch of sentences.
:param sentences: str/list, Input sentences
:param batch_size: int, Batch size
:param show_progress_bar: bool, Whether to show a progress bar for the sentences
:param convert_to_numpy: bool, Whether to convert the output to numpy, instead of a pytorch tensor
:param device: Which torch.device to use for the computation
"""
self.bert.eval()
@@ -158,10 +160,15 @@ def encode(
**self.tokenizer(sentences_batch, max_length=self.max_seq_length,
padding=True, truncation=True, return_tensors='pt').to(device)
)
embeddings = embeddings.detach().cpu()
embeddings = embeddings.detach()
if convert_to_numpy:
embeddings = embeddings.cpu()
all_embeddings.extend(embeddings)
all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
if convert_to_numpy:
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
else:
all_embeddings = torch.stack(all_embeddings)
if input_is_string:
all_embeddings = all_embeddings[0]

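With this change, encode() only moves embeddings to CPU and converts them to numpy when convert_to_numpy=True; with convert_to_numpy=False it returns a stacked torch tensor and skips the per-batch device-to-host copy. A small usage sketch, assuming the diff above is applied (the sentences are illustrative):

import numpy as np
import torch

from text2vec import SentenceModel

model = SentenceModel('shibing624/text2vec-base-chinese')
sentences = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡']

# Default path: embeddings are detached, moved to CPU, and returned as a numpy array.
emb_np = model.encode(sentences, convert_to_numpy=True)
assert isinstance(emb_np, np.ndarray)

# New path: embeddings stay on the encoding device and come back as a stacked torch.Tensor.
emb_pt = model.encode(sentences, convert_to_numpy=False)
assert isinstance(emb_pt, torch.Tensor)

print(emb_np.shape, emb_pt.shape)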
