Commit

'Refactored by Sourcery'
Sourcery AI committed Apr 4, 2023
1 parent ca949b6 commit dcce17e
Showing 9 changed files with 130 additions and 132 deletions.
20 changes: 9 additions & 11 deletions server/app.py
@@ -37,7 +37,7 @@ def warning_on_one_line(message, category, filename, lineno, file=None, line=Non
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def serve(path):
if path == "" or not os.path.exists(app.static_folder + '/' + path):
if path == "" or not os.path.exists(f'{app.static_folder}/{path}'):
path = 'index.html'

return send_from_directory(app.static_folder, path)
@@ -72,23 +72,21 @@ def run(self):
for line in lines:
if line == "":
continue

if line.startswith("Downloading shards:"):
progress = re.search(r"\| (\d+)/(\d+) \[", line)
if progress:
current_shard, total_shards = int(progress.group(1)), int(progress.group(2))
if progress := re.search(r"\| (\d+)/(\d+) \[", line):
current_shard, total_shards = int(progress[1]), int(progress[2])
elif line.startswith("Downloading"):
percentage = re.search(r":\s+(\d+)%", line)
percentage = percentage.group(0)[2:] if percentage else ""
percentage = percentage[0][2:] if percentage else ""

progress = re.search(r"\[(.*?)\]", line)
if progress and "?" not in progress.group(0):
current_duration, rest = progress.group(0)[1:-1].split("<")
if progress and "?" not in progress[0]:
current_duration, rest = progress[0][1:-1].split("<")
total_duration, speed = rest.split(",")

download_size = re.search(r"\| (.*?)\[", line)
if download_size:
current_size, total_size = download_size.group(0)[2:-1].strip().split("/")
if download_size := re.search(r"\| (.*?)\[", line):
current_size, total_size = download_size[0][2:-1].strip().split("/")

self.event_emitter.emit(EVENTS.MODEL_DOWNLOAD_UPDATE, self.model, {
'current_shard': current_shard,
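The hunk above folds each separate `re.search(...)` call and its `if` check into a single assignment expression (the walrus operator, Python 3.8+) and switches from `match.group(n)` to the equivalent `match[n]` indexing. A minimal standalone sketch of that pattern; the sample `line` string is an illustrative stand-in for the download progress output being parsed:

```python
import re

# Illustrative progress line in the style parsed by the hunk above.
line = "Downloading shards:  67%| 2/3 [00:10<00:05,  1.2it/s]"

# The walrus operator binds the match object and tests it in one expression;
# progress[n] is shorthand for progress.group(n).
if progress := re.search(r"\| (\d+)/(\d+) \[", line):
    current_shard, total_shards = int(progress[1]), int(progress[2])
    print(current_shard, total_shards)  # -> 2 3
```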
37 changes: 21 additions & 16 deletions server/lib/api/inference.py
@@ -33,7 +33,7 @@ def stream_inference():

if not isinstance(data['prompt'], str) or not isinstance(data['models'], list):
return create_response_message("Invalid request", 400)

request_uuid = "1"

prompt = data['prompt']
@@ -51,10 +51,10 @@ def stream_inference():

if not isinstance(name, str) or not isinstance(tag, str) or not isinstance(parameters, dict):
continue

if provider not in providers:
continue

models_name_provider.append({"name": model['name'], "provider": model['provider']})

required_parameters = []
@@ -91,23 +91,26 @@ def stream_inference():
if param == "stopSequences":
if parameters[param] is None:
parameters[param] = []
if (not isinstance(parameters[param], list) and not parameters[param] == None):
return create_response_message(f"Invalid stopSequences parameter", 400)
if (
not isinstance(parameters[param], list)
and parameters[param] is not None
):
return create_response_message("Invalid stopSequences parameter", 400)
elif not isinstance(parameters[param], (int, float)) and not (isinstance(parameters[param], str) and parameters[param].replace('.', '').isdigit()):
return create_response_message(f"Invalid parameter: {param} - {name}", 400)

sanitized_params[param] = parameters[param]

all_tasks.append(InferenceRequest(
uuid=request_uuid, model_name=name, model_tag=tag, model_provider=provider,
model_parameters=sanitized_params, prompt=prompt)
)

uuid = "1"

if len(all_tasks) == 0:
if not all_tasks:
return create_response_message("Invalid Request", 400)

thread = threading.Thread(target=bulk_completions, args=(global_state, all_tasks,))
thread.start()
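Two idioms recur in the validation hunk above: `None` is tested with the identity comparison `is not None` rather than `== None`, and list emptiness is tested by truthiness (`if not all_tasks:`) rather than `len(all_tasks) == 0`. A small illustration with hypothetical values:

```python
stop_sequences = None
all_tasks = []

# Identity comparison is the conventional way to test against None.
if not isinstance(stop_sequences, list) and stop_sequences is not None:
    raise ValueError("Invalid stopSequences parameter")

# An empty list is falsy, so this replaces `if len(all_tasks) == 0:`.
if not all_tasks:
    print("Invalid Request")
```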

@@ -140,18 +143,20 @@ def bulk_completions(global_state, tasks: List[InferenceRequest]):
local_tasks.append(task)
else:
remote_tasks.append(task)
if len(remote_tasks) > 0:
if remote_tasks:
with ThreadPoolExecutor(max_workers=len(remote_tasks)) as executor:
futures = []
for inference_request in remote_tasks:
futures.append(executor.submit(global_state.text_generation, inference_request))

futures = [
executor.submit(
global_state.text_generation, inference_request
)
for inference_request in remote_tasks
]
[future.result() for future in futures]

#Not safe to assume that localhost can run multiple models at once
for inference_request in local_tasks:
global_state.text_generation(inference_request)

global_state.get_announcer().announce(InferenceResult(
uuid=tasks[0].uuid,
model_name=None,
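The `bulk_completions` hunk above replaces a loop of `futures.append(...)` calls with a list comprehension and checks `if remote_tasks:` instead of `len(remote_tasks) > 0`. A self-contained sketch of the same submit-then-wait pattern; `fake_generation` and its integer inputs are illustrative stand-ins for `global_state.text_generation` and the real inference requests:

```python
from concurrent.futures import ThreadPoolExecutor

def fake_generation(request_id: int) -> str:
    # Stand-in for a remote text-generation call.
    return f"result for request {request_id}"

remote_tasks = [1, 2, 3]

if remote_tasks:  # truthiness check replaces `len(remote_tasks) > 0`
    with ThreadPoolExecutor(max_workers=len(remote_tasks)) as executor:
        # Submit every task up front, then block until each future completes.
        futures = [executor.submit(fake_generation, task) for task in remote_tasks]
        results = [future.result() for future in futures]
    print(results)
```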
6 changes: 3 additions & 3 deletions server/lib/api/provider.py
@@ -131,10 +131,10 @@ def provider_update_api_key(provider_name):

api_key = data['apiKey']
if api_key is None:
return create_response_message(f"Invalid API key", 400)
return create_response_message("Invalid API key", 400)

storage.update_provider_api_key(provider_name, api_key)

response = jsonify({'status': 'success'})
response.headers.add('Access-Control-Allow-Origin', '*')
return response
5 changes: 1 addition & 4 deletions server/lib/entities.py
@@ -54,10 +54,7 @@ def __init__(
self.search_url = search_url

def has_model(self, model_name: str) -> bool:
for model in self.models:
if model.name == model_name:
return True
return False
return any(model.name == model_name for model in self.models)

def get_model(self, model_name: str) -> Model:
for model in self.models:
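`has_model` now delegates to `any()` with a generator expression, which short-circuits on the first matching name exactly like the original loop's early `return True`. A minimal sketch; the `Model` and `Provider` shapes here are simplified assumptions about the real classes in `entities.py`:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Model:
    name: str

@dataclass
class Provider:
    models: List[Model] = field(default_factory=list)

    def has_model(self, model_name: str) -> bool:
        # any() stops at the first True result, mirroring the early return.
        return any(model.name == model_name for model in self.models)

provider = Provider(models=[Model("gpt-4"), Model("flan-t5")])
print(provider.has_model("flan-t5"))  # True
print(provider.has_model("llama"))    # False
```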
22 changes: 10 additions & 12 deletions server/lib/event_emitter.py
@@ -34,19 +34,17 @@ def on(self, event: EVENTS, listener):

def off(self, event, listener):
with self._lock:
if event in self.listeners:
if listener in self.listeners[event]:
self.listeners[event].remove(listener)
if event in self.listeners and listener in self.listeners[event]:
self.listeners[event].remove(listener)

def emit(self, event: EVENTS, *args, **kwargs):
if event in EVENTS.__members__.values():
if event.value not in self.listeners:
return
if event not in EVENTS.__members__.values():
raise ValueError(f"Invalid event type: {event}")
if event.value not in self.listeners:
return

with self._lock:
listeners_to_notify = self.listeners[event.value].copy()
with self._lock:
listeners_to_notify = self.listeners[event.value].copy()

for listener in listeners_to_notify:
listener(event, *args, **kwargs)
else:
raise ValueError("Invalid event type: %s" % event)
for listener in listeners_to_notify:
listener(event, *args, **kwargs)
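In `emit`, the refactor inverts the validity check into a guard clause that raises immediately and returns early when no listeners are registered, flattening the original nesting; in `off`, the two nested `if` statements are merged with `and`. A standalone sketch of that guard-clause shape, assuming a simplified `EVENTS` enum and keying listeners by `event.value` throughout for self-consistency:

```python
import threading
from enum import Enum

class EVENTS(Enum):
    MODEL_DOWNLOAD_UPDATE = "model_download_update"

class EventEmitter:
    def __init__(self):
        self.listeners = {}
        self._lock = threading.Lock()

    def on(self, event: EVENTS, listener):
        with self._lock:
            self.listeners.setdefault(event.value, []).append(listener)

    def off(self, event: EVENTS, listener):
        with self._lock:
            # Merged condition replaces the two nested `if` statements.
            if event.value in self.listeners and listener in self.listeners[event.value]:
                self.listeners[event.value].remove(listener)

    def emit(self, event: EVENTS, *args, **kwargs):
        # Guard clauses: fail fast on bad input, return early when nothing to do.
        if event not in EVENTS.__members__.values():
            raise ValueError(f"Invalid event type: {event}")
        if event.value not in self.listeners:
            return
        with self._lock:
            listeners_to_notify = self.listeners[event.value].copy()
        for listener in listeners_to_notify:
            listener(event, *args, **kwargs)

emitter = EventEmitter()
emitter.on(EVENTS.MODEL_DOWNLOAD_UPDATE, lambda event, payload: print(event, payload))
emitter.emit(EVENTS.MODEL_DOWNLOAD_UPDATE, {"progress": 42})
```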
85 changes: 44 additions & 41 deletions server/lib/inference/__init__.py
@@ -144,7 +144,7 @@ def __error_handler__(self, inference_fn: InferenceFunction, provider_details: P
probability=None,
top_n_distribution=None
)

if not self.announcer.announce(InferenceResult(
uuid=inference_request.uuid,
model_name=inference_request.model_name,
@@ -180,7 +180,7 @@ def __error_handler__(self, inference_fn: InferenceFunction, provider_details: P
infer_result.token = f"[ERROR] OpenAI API request exceeded rate limit: {e}"
logger.error(f"OpenAI API request exceeded rate limit: {e}")
except requests.exceptions.RequestException as e:
logging.error("RequestException: {}".format(e))
logging.error(f"RequestException: {e}")
infer_result.token = f"[ERROR] No response from {infer_result.model_provider } after sixty seconds"
except ValueError as e:
if infer_result.model_provider == "huggingface-local":
@@ -204,7 +204,7 @@ def __openai_chat_generation__(self, provider_details: ProviderDetails, inferenc
current_date = datetime.now().strftime("%Y-%m-%d")

if inference_request.model_name == "gpt-4":
system_content = f"You are GPT-4, a large language model trained by OpenAI. Answer as concisely as possible"
system_content = "You are GPT-4, a large language model trained by OpenAI. Answer as concisely as possible"
else:
system_content = f"You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: 2021-09-01 Current date: {current_date}"

@@ -233,7 +233,7 @@ def __openai_chat_generation__(self, provider_details: ProviderDetails, inferenc

delta = response['delta']

if not "content" in delta:
if "content" not in delta:
continue

generated_token = delta["content"]
@@ -326,7 +326,7 @@ def __openai_text_generation__(self, provider_details: ProviderDetails, inferenc

def openai_text_generation(self, provider_details: ProviderDetails, inference_request: InferenceRequest):
# TODO: Add a meta field to the inference so we know when a model is chat vs text
if inference_request.model_name == "gpt-3.5-turbo" or inference_request.model_name == "gpt-4":
if inference_request.model_name in ["gpt-3.5-turbo", "gpt-4"]:
self.__error_handler__(self.__openai_chat_generation__, provider_details, inference_request)
else:
self.__error_handler__(self.__openai_text_generation__, provider_details, inference_request)
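The change above collapses a chained equality test into a membership check, and an earlier hunk in this file rewrites `not "content" in delta` as the equivalent but more readable `"content" not in delta`. A short illustration with made-up values:

```python
model_name = "gpt-3.5-turbo"
delta = {"role": "assistant"}

# Membership test replaces `model_name == "gpt-3.5-turbo" or model_name == "gpt-4"`.
is_chat_model = model_name in ["gpt-3.5-turbo", "gpt-4"]

# `x not in d` is the idiomatic spelling of `not x in d`.
has_content = "content" not in delta
print(is_chat_model, has_content)  # True True
```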
@@ -402,7 +402,6 @@ def __huggingface_text_generation__(self, provider_details: ProviderDetails, inf

content_type = response.headers["content-type"]

total_tokens = 0
cancelled = False

if response.status_code != 200:
@@ -423,6 +422,7 @@ def __huggingface_text_generation__(self, provider_details: ProviderDetails, inf
top_n_distribution=None
), event="infer")
else:
total_tokens = 0
for response in response.iter_lines():
response = response.decode('utf-8')
if response == "":
@@ -434,23 +434,26 @@ def __huggingface_text_generation__(self, provider_details: ProviderDetails, inf
raise Exception(f"{error}")

token = response_json['token']

total_tokens += 1

if token["special"]:
continue

if cancelled: continue

if not self.announcer.announce(InferenceResult(
uuid=inference_request.uuid,
model_name=inference_request.model_name,
model_tag=inference_request.model_tag,
model_provider=inference_request.model_provider,
token= " " if token['id'] == 3 else response_json['token']['text'],
probability=response_json['token']['logprob'],
top_n_distribution=None
), event="infer"):
if not self.announcer.announce(
InferenceResult(
uuid=inference_request.uuid,
model_name=inference_request.model_name,
model_tag=inference_request.model_tag,
model_provider=inference_request.model_provider,
token=" " if token['id'] == 3 else token['text'],
probability=token['logprob'],
top_n_distribution=None,
),
event="infer",
):
cancelled = True
logger.info(f"Cancelled inference for {inference_request.uuid} - {inference_request.model_name}")
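The streaming hunk above moves `total_tokens = 0` into the branch that uses it and reuses the already-extracted `token` dict instead of re-indexing `response_json['token']`. A network-free sketch of that parsing loop; the event payloads are fabricated but shaped like the streamed responses the code consumes:

```python
# Fabricated streamed events in the shape the loop above consumes.
events = [
    {"token": {"id": 42, "text": "Hello", "logprob": -0.12, "special": False}},
    {"token": {"id": 7, "text": "</s>", "logprob": -0.01, "special": True}},
]

total_tokens = 0
cancelled = False
for response_json in events:
    token = response_json["token"]
    total_tokens += 1
    if token["special"]:
        continue  # skip control tokens such as end-of-sequence
    if cancelled:
        continue
    # token['id'] == 3 is treated as a whitespace token by the code above.
    text = " " if token["id"] == 3 else token["text"]
    print(text, token["logprob"])

print("tokens streamed:", total_tokens)
```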

@@ -459,24 +462,24 @@ def huggingface_text_generation(self, provider_details: ProviderDetails, inferen

def __forefront_text_generation__(self, provider_details: ProviderDetails, inference_request: InferenceRequest):
with requests.post(
f"https://shared-api.forefront.link/organization/gPn2ZLSO3mTh/{inference_request.model_name}/completions/{provider_details.version_key}",
headers={
"Authorization": f"Bearer {provider_details.api_key}",
"Content-Type": "application/json",
},
data=json.dumps({
"text": inference_request.prompt,
"top_p": float(inference_request.model_parameters['topP']),
"top_k": int(inference_request.model_parameters['topK']),
"temperature": float(inference_request.model_parameters['temperature']),
"repetition_penalty": float(inference_request.model_parameters['repetitionPenalty']),
"length": int(inference_request.model_parameters['maximumLength']),
"stop": inference_request.model_parameters['stopSequences'],
"logprobs": 5,
"stream": True,
}),
stream=True
) as response:
f"https://shared-api.forefront.link/organization/gPn2ZLSO3mTh/{inference_request.model_name}/completions/{provider_details.version_key}",
headers={
"Authorization": f"Bearer {provider_details.api_key}",
"Content-Type": "application/json",
},
data=json.dumps({
"text": inference_request.prompt,
"top_p": float(inference_request.model_parameters['topP']),
"top_k": int(inference_request.model_parameters['topK']),
"temperature": float(inference_request.model_parameters['temperature']),
"repetition_penalty": float(inference_request.model_parameters['repetitionPenalty']),
"length": int(inference_request.model_parameters['maximumLength']),
"stop": inference_request.model_parameters['stopSequences'],
"logprobs": 5,
"stream": True,
}),
stream=True
) as response:
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code} {response.reason}")
cancelled = False
@@ -515,10 +518,10 @@ def __forefront_text_generation__(self, provider_details: ProviderDetails, infer

for index, new_token in enumerate(new_tokens):
generated_token = new_token

probability = token_logprobs[total_tokens + index]
top_logprobs = logprobs["top_logprobs"][total_tokens + index]

chosen_log_prob = 0
prob_dist = ProablityDistribution(
log_prob_sum=0, simple_prob_sum=0, tokens={},
Expand All @@ -529,11 +532,11 @@ def __forefront_text_generation__(self, provider_details: ProviderDetails, infer
simple_prob = round(math.exp(log_prob) * 100, 2)
prob_dist.tokens[token] = [log_prob, simple_prob]

if token == new_token:
if token == generated_token:
chosen_log_prob = round(log_prob, 2)

prob_dist.simple_prob_sum += simple_prob

prob_dist.tokens = dict(
sorted(prob_dist.tokens.items(), key=lambda item: item[1][0], reverse=True)
)
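The tail of this hunk builds a per-token probability distribution from the provider's `top_logprobs` and compares each candidate against the freshly generated token (now via `generated_token` rather than the loop variable `new_token`). A self-contained sketch of that bookkeeping; the `top_logprobs` mapping is made up:

```python
import math

generated_token = " world"
# Made-up top log probabilities for the current position.
top_logprobs = {" world": -0.22, " earth": -2.10, " planet": -3.40}

chosen_log_prob = 0.0
tokens = {}
simple_prob_sum = 0.0

for token, log_prob in top_logprobs.items():
    simple_prob = round(math.exp(log_prob) * 100, 2)  # log prob -> percentage
    tokens[token] = [log_prob, simple_prob]
    if token == generated_token:
        chosen_log_prob = round(log_prob, 2)
    simple_prob_sum += simple_prob

# Sort candidates from most to least likely, as the diff does.
tokens = dict(sorted(tokens.items(), key=lambda item: item[1][0], reverse=True))
print(chosen_log_prob, round(simple_prob_sum, 2), tokens)
```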
(Diffs for the remaining changed files were not loaded.)
