model encode speed up.
shibing624 committed May 17, 2022
1 parent 17e37e1 commit cac7eca
Showing 5 changed files with 66 additions and 43 deletions.
4 changes: 3 additions & 1 deletion text2vec/bertmatching_model.py
@@ -18,10 +18,11 @@

from text2vec.bertmatching_dataset import BertMatchingTestDataset, BertMatchingTrainDataset, \
load_test_data, load_train_data, HFBertMatchingTrainDataset, HFBertMatchingTestDataset
from text2vec.sentence_model import device
from text2vec.utils.stats_util import compute_spearmanr, compute_pearsonr
from text2vec.utils.stats_util import set_seed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BertMatchModule(nn.Module):
def __init__(
@@ -58,6 +59,7 @@ def __init__(
max_seq_length: int = 128,
num_classes: int = 2,
encoder_type=None,

):
"""
Initializes the base sentence model.
20 changes: 11 additions & 9 deletions text2vec/cosent_model.py
@@ -15,7 +15,7 @@
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

from text2vec.cosent_dataset import CosentTrainDataset, load_cosent_train_data, HFCosentTrainDataset
from text2vec.sentence_model import SentenceModel, device
from text2vec.sentence_model import SentenceModel
from text2vec.text_matching_dataset import TextMatchingTestDataset, load_test_data, HFTextMatchingTestDataset
from text2vec.utils.stats_util import set_seed

@@ -26,6 +26,7 @@ def __init__(
model_name_or_path: str = "hfl/chinese-macbert-base",
encoder_type: str = "FIRST_LAST_AVG",
max_seq_length: int = 128,
device: str = None,
):
"""
Initializes a CoSENT Model.
@@ -34,8 +35,9 @@
model_name_or_path: Default Transformer model name or path to a directory containing Transformer model file (pytorch_model.bin).
encoder_type: Enum of type EncoderType.
max_seq_length: The maximum total input sequence length after tokenization.
device: The device on which the model is allocated.
"""
super().__init__(model_name_or_path, encoder_type, max_seq_length)
super().__init__(model_name_or_path, encoder_type, max_seq_length, device)

def __str__(self):
return f"<CoSENTModel: {self.model_name_or_path}, encoder_type: {self.encoder_type}, " \
@@ -136,7 +138,7 @@ def calc_loss(self, y_true, y_pred):
y_pred = y_pred - (1 - y_true) * 1e12
y_pred = y_pred.view(-1)
# Prepend a 0 here because e^0 = 1, which is equivalent to adding 1 inside the log
y_pred = torch.cat((torch.tensor([0]).float().to(device), y_pred), dim=0)
y_pred = torch.cat((torch.tensor([0]).float().to(self.device), y_pred), dim=0)
return torch.logsumexp(y_pred, dim=0)

def train(
@@ -162,8 +164,8 @@
Utility function to be used by the train_model() method. Not intended to be used directly.
"""
os.makedirs(output_dir, exist_ok=True)
logger.debug("Use pytorch device: {}".format(device))
self.bert.to(device)
logger.debug("Use device: {}".format(self.device))
self.bert.to(self.device)
set_seed(seed)

train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size)
@@ -233,11 +235,11 @@ def train(
steps_trained_in_current_epoch -= 1
continue
inputs, labels = batch
labels = labels.to(device)
labels = labels.to(self.device)
# inputs [batch, 1, seq_len] -> [batch, seq_len]
input_ids = inputs.get('input_ids').squeeze(1).to(device)
attention_mask = inputs.get('attention_mask').squeeze(1).to(device)
token_type_ids = inputs.get('token_type_ids').squeeze(1).to(device)
input_ids = inputs.get('input_ids').squeeze(1).to(self.device)
attention_mask = inputs.get('attention_mask').squeeze(1).to(self.device)
token_type_ids = inputs.get('token_type_ids').squeeze(1).to(self.device)
output_embeddings = self.get_sentence_embeddings(input_ids, attention_mask, token_type_ids)
loss = self.calc_loss(labels, output_embeddings)
current_loss = loss.item()
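The calc_loss hunk above only swaps the module-level device for self.device; the loss math is unchanged. As an illustrative aside (not part of the commit), the zero-prepend trick it uses can be checked in isolation: prepending 0 before torch.logsumexp equals log(1 + sum(exp(x))), since e^0 = 1.

import torch

# Verify the trick used in calc_loss: concatenate a 0 in front, then logsumexp.
x = torch.tensor([0.3, -1.2, 2.0])
with_zero = torch.logsumexp(torch.cat((torch.tensor([0.0]), x)), dim=0)
manual = torch.log(1 + torch.exp(x).sum())
print(torch.allclose(with_zero, manual))  # True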
54 changes: 34 additions & 20 deletions text2vec/sentence_model.py
@@ -16,7 +16,6 @@
from tqdm.auto import tqdm, trange
from text2vec.utils.stats_util import compute_spearmanr, compute_pearsonr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"

@@ -44,7 +43,8 @@ def __init__(
self,
model_name_or_path: str = "shibing624/text2vec-base-chinese",
encoder_type: Union[str, EncoderType] = "MEAN",
max_seq_length: int = 128
max_seq_length: int = 128,
device: Optional[str] = None,
):
"""
Initializes the base sentence model.
@@ -53,6 +53,7 @@
:param encoder_type: The type of encoder to use, See the EncoderType enum for options:
FIRST_LAST_AVG, LAST_AVG, CLS, POOLER(cls + dense), MEAN(mean of last_hidden_state)
:param max_seq_length: The maximum sequence length.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, uses CUDA when a GPU is available.
bert model: https://huggingface.co/transformers/model_doc/bert.html?highlight=bert#transformers.BertModel.forward
BERT return: <last_hidden_state>, <pooler_output> [hidden_states, attentions]
@@ -67,7 +68,11 @@ def __init__(
self.max_seq_length = max_seq_length
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.bert = AutoModel.from_pretrained(model_name_or_path)
self.bert.to(device)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
logger.debug("Use device: {}".format(self.device))
self.bert.to(self.device)
self.results = {} # Save training process evaluation result

def __str__(self):
@@ -119,11 +124,24 @@ def get_sentence_embeddings(self, input_ids, attention_mask, token_type_ids):
input_mask_expanded.sum(1), min=1e-9)
return final_encoding # [batch, hid_size]

def encode(self, sentences: Union[str, List[str]], batch_size: int = 32, show_progress_bar: bool = False):
def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 32,
show_progress_bar: bool = False,
device: str = None,
):
"""
Returns the embeddings for a batch of sentences.
:param sentences: str/list, Input sentences
:param batch_size: int, Batch size
:param show_progress_bar: bool, Whether to show a progress bar for the sentences
:param device: Which torch.device to use for the computation
"""
self.bert.eval()
if device is None:
device = self.device
input_is_string = False
if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
sentences = [sentences]
@@ -134,16 +152,12 @@ def encode(self, sentences: Union[str, List[str]], batch_size: int = 32, show_pr
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
sentences_batch = sentences_sorted[start_index: start_index + batch_size]
# Tokenize sentences
inputs = self.tokenizer(sentences_batch, max_length=self.max_seq_length, truncation=True,
padding='max_length', return_tensors='pt')
input_ids = inputs.get('input_ids').squeeze(1).to(device)
attention_mask = inputs.get('attention_mask').squeeze(1).to(device)
token_type_ids = inputs.get('token_type_ids').squeeze(1).to(device)

# Compute sentences embeddings
with torch.no_grad():
embeddings = self.get_sentence_embeddings(input_ids, attention_mask, token_type_ids)
embeddings = self.get_sentence_embeddings(
**self.tokenizer(sentences_batch, max_length=self.max_seq_length,
padding=True, truncation=True, return_tensors='pt').to(device)
)
embeddings = embeddings.detach().cpu()
all_embeddings.extend(embeddings)
all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
@@ -175,24 +189,24 @@ def evaluate(self, eval_dataset, output_dir: str = None, batch_size: int = 16):
results = {}

eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
self.bert.to(device)
self.bert.to(self.device)
self.bert.eval()

batch_labels = []
batch_preds = []
for batch in tqdm(eval_dataloader, disable=False, desc="Running Evaluation"):
source, target, labels = batch
labels = labels.to(device)
labels = labels.to(self.device)
batch_labels.extend(labels.cpu().numpy())
# source [batch, 1, seq_len] -> [batch, seq_len]
source_input_ids = source.get('input_ids').squeeze(1).to(device)
source_attention_mask = source.get('attention_mask').squeeze(1).to(device)
source_token_type_ids = source.get('token_type_ids').squeeze(1).to(device)
source_input_ids = source.get('input_ids').squeeze(1).to(self.device)
source_attention_mask = source.get('attention_mask').squeeze(1).to(self.device)
source_token_type_ids = source.get('token_type_ids').squeeze(1).to(self.device)

# target [batch, 1, seq_len] -> [batch, seq_len]
target_input_ids = target.get('input_ids').squeeze(1).to(device)
target_attention_mask = target.get('attention_mask').squeeze(1).to(device)
target_token_type_ids = target.get('token_type_ids').squeeze(1).to(device)
target_input_ids = target.get('input_ids').squeeze(1).to(self.device)
target_attention_mask = target.get('attention_mask').squeeze(1).to(self.device)
target_token_type_ids = target.get('token_type_ids').squeeze(1).to(self.device)

with torch.no_grad():
source_embeddings = self.get_sentence_embeddings(source_input_ids, source_attention_mask,
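With this change, encode() picks a device per call (falling back to self.device) and tokenizes with padding=True rather than padding='max_length', so each batch is only padded to its longest sentence; that dynamic padding is the likely source of the "encode speed up" in the commit title. A hedged usage sketch of the new signature (the sentences are illustrative, and the return value beyond the shown hunk is assumed to be a sequence with one embedding per input):

from text2vec.sentence_model import SentenceModel

# device is the new optional argument on both the constructor and encode().
model = SentenceModel("shibing624/text2vec-base-chinese", device="cpu")
sentences = ["The cat sits on the mat", "A cat is sitting on a mat"]
embeddings = model.encode(sentences, batch_size=32, show_progress_bar=False)
print(len(embeddings))  # one embedding per input sentence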
26 changes: 14 additions & 12 deletions text2vec/sentencebert_model.py
@@ -15,7 +15,7 @@
from tqdm.auto import tqdm, trange
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

from text2vec.sentence_model import SentenceModel, device
from text2vec.sentence_model import SentenceModel
from text2vec.text_matching_dataset import (
TextMatchingTrainDataset,
TextMatchingTestDataset,
@@ -34,6 +34,7 @@ def __init__(
encoder_type: str = "MEAN",
max_seq_length: int = 128,
num_classes: int = 2,
device: str = None,
):
"""
Initializes a SentenceBert Model.
@@ -43,9 +44,10 @@
encoder_type: encoder type, set by model name
max_seq_length: The maximum total input sequence length after tokenization.
num_classes: Number of classes for classification.
device: CPU or GPU
"""
super().__init__(model_name_or_path, encoder_type, max_seq_length)
self.classifier = nn.Linear(self.bert.config.hidden_size * 3, num_classes).to(device)
super().__init__(model_name_or_path, encoder_type, max_seq_length, device)
self.classifier = nn.Linear(self.bert.config.hidden_size * 3, num_classes).to(self.device)

def __str__(self):
return f"<SentenceBertModel: {self.model_name_or_path}, encoder_type: {self.encoder_type}, " \
@@ -172,8 +174,8 @@ def train(
Utility function to be used by the train_model() method. Not intended to be used directly.
"""
os.makedirs(output_dir, exist_ok=True)
logger.debug("Use pytorch device: {}".format(device))
self.bert.to(device)
logger.debug("Use pytorch device: {}".format(self.device))
self.bert.to(self.device)
set_seed(seed)

train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size)
@@ -244,14 +246,14 @@ def train(
continue
source, target, labels = batch
# source [batch, 1, seq_len] -> [batch, seq_len]
source_input_ids = source.get('input_ids').squeeze(1).to(device)
source_attention_mask = source.get('attention_mask').squeeze(1).to(device)
source_token_type_ids = source.get('token_type_ids').squeeze(1).to(device)
source_input_ids = source.get('input_ids').squeeze(1).to(self.device)
source_attention_mask = source.get('attention_mask').squeeze(1).to(self.device)
source_token_type_ids = source.get('token_type_ids').squeeze(1).to(self.device)
# target [batch, 1, seq_len] -> [batch, seq_len]
target_input_ids = target.get('input_ids').squeeze(1).to(device)
target_attention_mask = target.get('attention_mask').squeeze(1).to(device)
target_token_type_ids = target.get('token_type_ids').squeeze(1).to(device)
labels = labels.to(device)
target_input_ids = target.get('input_ids').squeeze(1).to(self.device)
target_attention_mask = target.get('attention_mask').squeeze(1).to(self.device)
target_token_type_ids = target.get('token_type_ids').squeeze(1).to(self.device)
labels = labels.to(self.device)

# get sentence embeddings of BERT encoder
source_embeddings = self.get_sentence_embeddings(source_input_ids, source_attention_mask,
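The classifier added above takes hidden_size * 3 input features; the factor of three is consistent with the usual Sentence-BERT head over the two sentence embeddings and their absolute difference. A standalone sketch under that assumption (the concatenation itself is not visible in this diff):

import torch
import torch.nn as nn

hidden_size, num_classes, batch = 768, 2, 4
classifier = nn.Linear(hidden_size * 3, num_classes)
u = torch.randn(batch, hidden_size)   # source sentence embeddings
v = torch.randn(batch, hidden_size)   # target sentence embeddings
features = torch.cat([u, v, torch.abs(u - v)], dim=-1)  # [batch, hidden_size * 3]
logits = classifier(features)
print(logits.shape)  # torch.Size([4, 2])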
5 changes: 4 additions & 1 deletion text2vec/similarity.py
@@ -12,12 +12,15 @@
from numpy import ndarray
from torch import Tensor

from text2vec.sentence_model import SentenceModel, device, EncoderType
from text2vec.sentence_model import SentenceModel, EncoderType
from text2vec.utils.distance import cosine_distance
from text2vec.utils.tokenizer import JiebaTokenizer
from text2vec.word2vec import Word2Vec
from text2vec.utils.get_file import deprecated

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class SimilarityType(Enum):
WMD = 0
COSINE = 1
