Skip to content

Commit

Permalink
refactor(model): minor code style updates
Browse files Browse the repository at this point in the history
  • Loading branch information
fardeon authored Mar 6, 2023
1 parent 1734206 commit ddd7fb5
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions basaran/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ def __call__(
**samples,
)

@retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
def _infer(self, model_fn, **kwargs):
"""Call a model function in inference mode with auto retrying."""
# This is a temporary workaround for bitsandbytes #162:
# https://github.com/TimDettmers/bitsandbytes/issues/162
with torch.inference_mode():
return model_fn(**kwargs)

def _sample(self, token, token_logprob, top_tokens, top_logprobs):
"""Sample log probabilities of the most likely tokens."""
token = self.tokenizer.decode(token)
Expand Down Expand Up @@ -168,14 +176,6 @@ def _logits_processor(self, config, input_length):

return processor

@retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
def _infer(self, model_fn, **kwargs):
"""Call a model function in inference mode with auto retrying."""
# This is a temporary workaround for bitsandbytes #162:
# https://github.com/TimDettmers/bitsandbytes/issues/162
with torch.inference_mode():
return model_fn(**kwargs)

def tokenize(self, text):
"""Tokenize a string into a tensor of token IDs."""
batch = self.tokenizer.encode(text, return_tensors="pt")
Expand Down Expand Up @@ -280,7 +280,7 @@ def generate(self, input_ids, logprobs=0, **kwargs):
# Append selected tokens to the inputs.
input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

# Extract past key values from model output.
# Extract past key values from model outputs.
if "past_key_values" in outputs:
kwargs["past_key_values"] = outputs.past_key_values
elif "mems" in outputs:
Expand Down

0 comments on commit ddd7fb5

Please sign in to comment.