Commit

update encode.
shibing624 committed May 20, 2022
1 parent ad2ec05 commit b65ac50
Showing 3 changed files with 14 additions and 7 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ubuntu.yml
@@ -50,6 +50,8 @@ jobs:
       run: |
         pip install -r requirements.txt
         pip install .
+        pip install pytest
+        pip install sentence-transformers
     - name: PKG-TEST
       run: |
-        python -m unittest discover tests/
+        python -m pytest
4 changes: 2 additions & 2 deletions tests/test_qps.py
@@ -88,7 +88,7 @@ def test_cosent_speed(self):
             r = model.encode(tmp, convert_to_numpy=False)
             assert r is not None
             if j == 0:
-                logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
+                logger.info(f"result shape: {len(r)}, emb: {r[0][:10]}")
         time_t = time.time() - start_t
         logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                     (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
@@ -134,7 +134,7 @@ def test_origin_sentence_transformers_speed(self):
             r = model.encode(tmp, convert_to_numpy=False)
             assert r is not None
             if j == 0:
-                logger.info(f"result shape: {r.shape}, emb: {r[0][:10]}")
+                logger.info(f"result shape: {len(r)}, emb: {r[0][:10]}")
         time_t = time.time() - start_t
         logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                     (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))
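Why the test changes: with convert_to_numpy=False (and the new convert_to_tensor flag left at its default of False), encode() returns a plain Python list of per-sentence tensors rather than a single stacked array, so the result has no .shape attribute and len(r) gives the sentence count. A minimal sketch of the behavior the tests now rely on, assuming the SentenceModel class and example sentences from the project's README:

    from text2vec import SentenceModel

    model = SentenceModel()  # loads the default pretrained model
    sentences = ["如何更换花呗绑定银行卡", "花呗更改绑定银行卡"]

    r = model.encode(sentences, convert_to_numpy=False)
    # r is a list of torch tensors, one per sentence: use len(r), not r.shape
    print(len(r), r[0][:10])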
13 changes: 9 additions & 4 deletions text2vec/sentence_model.py
@@ -130,6 +130,7 @@ def encode(
             batch_size: int = 64,
             show_progress_bar: bool = False,
             convert_to_numpy: bool = True,
+            convert_to_tensor: bool = False,
             device: str = None,
     ):
         """
@@ -138,12 +139,15 @@ def encode(
         :param sentences: str/list, Input sentences
         :param batch_size: int, Batch size
         :param show_progress_bar: bool, Whether to show a progress bar for the sentences
-        :param convert_to_numpy: bool, Whether to convert the output to numpy, instead of a pytorch tensor
+        :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
+        :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
         :param device: Which torch.device to use for the computation
         """
         self.bert.eval()
         if device is None:
             device = self.device
+        if convert_to_tensor:
+            convert_to_numpy = False
         input_is_string = False
         if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
             sentences = [sentences]
@@ -165,10 +169,11 @@
                 embeddings = embeddings.cpu()
             all_embeddings.extend(embeddings)
         all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
-        if convert_to_numpy:
-            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-        else:
+        if convert_to_tensor:
             all_embeddings = torch.stack(all_embeddings)
+        elif convert_to_numpy:
+            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+
         if input_is_string:
             all_embeddings = all_embeddings[0]
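Taken together, encode() now supports three output modes, and convert_to_tensor=True stacks all embeddings into one tensor while overriding convert_to_numpy. A hedged usage sketch; the model, sentences, and variable names are illustrative assumptions, not taken from the commit:

    from text2vec import SentenceModel

    model = SentenceModel()
    sentences = ["如何更换花呗绑定银行卡", "花呗更改绑定银行卡"]

    emb_array = model.encode(sentences)                           # default: one numpy array
    emb_list = model.encode(sentences, convert_to_numpy=False)    # list of torch tensors
    emb_tensor = model.encode(sentences, convert_to_tensor=True)  # one stacked torch tensor

    # A single string input returns only its own embedding (the input_is_string path).
    single = model.encode("如何更换花呗绑定银行卡", convert_to_tensor=True)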
