Skip to content

Commit

Permalink
Fix rabbit mq bug & add basic retry to worker (dgarnitz#92)
Browse files Browse the repository at this point in the history
* upgraded rabbitmq connection logic

* upgraded message broker logic

* basic retry is working

* fixed frontend telemetry bug. added basic retry to worker.py as a test

* fixed tests

---------

Co-authored-by: David Garnitz <[email protected]>
  • Loading branch information
dgarnitz and David Garnitz authored Nov 8, 2023
1 parent 7baa00f commit b87d28c
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 97 deletions.
5 changes: 4 additions & 1 deletion src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import fitz
import base64
import logging
import copy
from io import BytesIO
import services.database.batch_service as batch_service
import services.database.job_service as job_service
Expand Down Expand Up @@ -74,7 +75,9 @@ def embed():
if file and is_valid_file_type(file):
job = safe_db_operation(job_service.create_job, vectorflow_request, file.filename)
batch_count = process_file(file, vectorflow_request, job.id)
send_telemetry("SINGLE_FILE_UPLOAD_SUCCESS", vectorflow_request)

vectorflow_request_copy = copy.deepcopy(vectorflow_request)
send_telemetry("SINGLE_FILE_UPLOAD_SUCCESS", vectorflow_request_copy)

return jsonify({'message': f"Successfully added {batch_count} batches to the queue", 'JobID': job.id}), 200
else:
Expand Down
52 changes: 28 additions & 24 deletions src/extract/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
import json
import pika
import ssl
import time
import fitz
from pathlib import Path
Expand All @@ -21,6 +20,8 @@
from io import BytesIO
from llama_index import download_loader
from shared.vectorflow_request import VectorflowRequest
from services.rabbitmq.rabbit_service import create_connection_params
from pika.exceptions import AMQPConnectionError

logging.basicConfig(filename='./extract-log.txt', level=logging.INFO)
logging.basicConfig(filename='./extract-error-log.txt', level=logging.ERROR)
Expand Down Expand Up @@ -172,30 +173,16 @@ def callback(ch, method, properties, body):

ch.basic_ack(delivery_tag=method.delivery_tag)

def start_connection():
def start_connection(max_retries=5, retry_delay=5):
global publish_channel
global connection

while True:
try:
credentials = pika.PlainCredentials(os.getenv('RABBITMQ_USERNAME'), os.getenv('RABBITMQ_PASSWORD'))

connection_params = pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
port=os.getenv('RABBITMQ_PORT'),
heartbeat=600,
ssl_options=pika.SSLOptions(ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)),
virtual_host="/"
) if os.getenv('RABBITMQ_PORT') == "5671" else pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
heartbeat=600,
)

for attempt in range(max_retries):
try:
connection_params = create_connection_params()
connection = pika.BlockingConnection(connection_params)
consume_channel = connection.channel()
publish_channel = connection.channel()
publish_channel = connection.channel()

consume_queue_name = os.getenv('EXTRACTION_QUEUE')
publish_queue_name = os.getenv('EMBEDDING_QUEUE')
Expand All @@ -207,10 +194,27 @@ def start_connection():

logging.info('Waiting for messages.')
consume_channel.start_consuming()

return # If successful, exit the function

except AMQPConnectionError as e:
logging.error('AMQP Connection Error: %s', e)
except Exception as e:
logging.error('ERROR connecting to RabbitMQ, retrying now. See exception: %s', e)
time.sleep(10) # Wait before retrying
logging.error('Unexpected error: %s', e)
finally:
if connection and not connection.is_closed:
connection.close()

logging.info('Retrying to connect in %s seconds (Attempt %s/%s)', retry_delay, attempt + 1, max_retries)
time.sleep(retry_delay)

raise Exception('Failed to connect after {} attempts'.format(max_retries))

if __name__ == "__main__":
start_connection()
while True:
try:
start_connection()

except Exception as e:
logging.error('Error in start_connection: %s', e)
logging.info('Restarting start_connection after encountering an error.')
time.sleep(10)
47 changes: 47 additions & 0 deletions src/services/rabbitmq/rabbit_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import pika
import ssl
import logging
import time
from pika.exceptions import ConnectionClosed, ChannelClosed

def create_connection_params():
credentials = pika.PlainCredentials(os.getenv('RABBITMQ_USERNAME'), os.getenv('RABBITMQ_PASSWORD'))

connection_params = pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
port=os.getenv('RABBITMQ_PORT'),
heartbeat=600,
ssl_options=pika.SSLOptions(ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)),
virtual_host="/"
) if os.getenv('RABBITMQ_PORT') == "5671" else pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
heartbeat=600,
)
return connection_params

def publish_message_to_retry_queue(channel, retry_queue, message, publish_attempts=5):
for attempt in range(publish_attempts):
try:
channel.basic_publish(exchange='',
routing_key=retry_queue,
body=message)
logging.info("Message retried successfully")
break # Exit the loop if publish is successful
except ConnectionClosed:
logging.error("Connection was closed, retrying...")
connection_params = create_connection_params()
connection = pika.BlockingConnection(connection_params)
channel = connection.channel()
channel.queue_declare(queue=retry_queue)
except ChannelClosed:
logging.error("Channel was closed, retrying...")
connection_params = create_connection_params()
connection = pika.BlockingConnection(connection_params)
channel = connection.channel()
channel.queue_declare(queue=retry_queue)
except Exception as e:
logging.error("Failed to publish message: %s", e)
time.sleep(2 ** attempt) # Exponential backoff
3 changes: 2 additions & 1 deletion src/worker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
SLEEP_SECONDS = 5
PIKA_RETRY_INTERVAL = 10
HUGGING_FACE_BATCH_SIZE = 32
VALIDATION_TIMEOUT = 30
VALIDATION_TIMEOUT = 30
MAX_BATCH_RETRIES = 3
18 changes: 9 additions & 9 deletions src/worker/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_process_batch_success(
mock_safe_db_operation.return_value = "test_db"

# act
worker.process_batch(batch.id, source_data)
worker.process_batch(batch.id, source_data, "fake_vdb_key", "fake_embedding_key")

# assert
mock_embed_openai_batch.assert_called_once_with(batch, chunked_data)
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_process_batch_hugging_face_embeddings_success(
mock_safe_db_operation.return_value = "test_db"

# act
worker.process_batch(batch.id, source_data)
worker.process_batch(batch.id, source_data, "fake_vdb_key", "fake_embedding_key")

# assert
mock_embed_hugging_face_batch.assert_called_once_with(batch, chunked_data)
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_process_batch_failure_no_vectors(
# arrange
source_data = "source_data"
job = Job(id=1, webhook_url="test_webhook_url", job_status=JobStatus.NOT_STARTED)
batch = Batch(id=1, job_id=1, batch_status=BatchStatus.NOT_STARTED, embeddings_metadata=EmbeddingsMetadata(embeddings_type=EmbeddingsType.OPEN_AI))
batch = Batch(id=1, job_id=1, retries=config.MAX_BATCH_RETRIES, batch_status=BatchStatus.NOT_STARTED, embeddings_metadata=EmbeddingsMetadata(embeddings_type=EmbeddingsType.OPEN_AI))
chunked_data = ["test"]
mock_chunk_data.return_value = chunked_data
mock_embed_openai_batch.return_value = None
Expand All @@ -122,11 +122,11 @@ def test_process_batch_failure_no_vectors(
mock_safe_db_operation.return_value = "test_db"

# act
worker.process_batch(batch.id, source_data)
worker.process_batch(batch.id, source_data, "fake_vdb_key", "fake_embedding_key")

# assert
mock_embed_openai_batch.assert_called_once_with(batch, chunked_data)
mock_update_batch_and_job_status.assert_called_with(batch.job_id, BatchStatus.FAILED, batch.id)
mock_update_batch_and_job_status.assert_called_with(batch.job_id, BatchStatus.FAILED, batch.id, 3)
mock_upload_to_vector_db.assert_not_called()

@patch('worker.worker.chunk_data')
Expand All @@ -152,7 +152,7 @@ def test_process_batch_failure_openai(
# arrange
source_data = "source_data"
job = Job(id=1, webhook_url="test_webhook_url", job_status=JobStatus.NOT_STARTED)
batch = Batch(id=1, job_id=1, batch_status=BatchStatus.NOT_STARTED, embeddings_metadata=EmbeddingsMetadata(embeddings_type=EmbeddingsType.OPEN_AI))
batch = Batch(id=1, job_id=1, retries=config.MAX_BATCH_RETRIES, batch_status=BatchStatus.NOT_STARTED, embeddings_metadata=EmbeddingsMetadata(embeddings_type=EmbeddingsType.OPEN_AI))
chunked_data = ["test"]
mock_chunk_data.return_value = chunked_data

Expand All @@ -163,7 +163,7 @@ def test_process_batch_failure_openai(
mock_safe_db_operation.return_value = "test_db"

# act
worker.process_batch(batch.id, source_data)
worker.process_batch(batch.id, source_data, "fake_vdb_key", "fake_embedding_key")

# assert
mock_embed_openai_batch.assert_called_once_with(batch, chunked_data)
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_process_batch_failure_validate_chunks(
# arrange
source_data = "source_data" * 100
job = Job(id=1, chunk_validation_url="test_validation_url", job_status=JobStatus.NOT_STARTED)
batch = Batch(id=1, job_id=1, batch_status=BatchStatus.NOT_STARTED,
batch = Batch(id=1, job_id=1, batch_status=BatchStatus.NOT_STARTED, retries=config.MAX_BATCH_RETRIES,
embeddings_metadata=EmbeddingsMetadata(embeddings_type=EmbeddingsType.OPEN_AI,
chunk_size=256,
chunk_overlap=128,
Expand All @@ -207,7 +207,7 @@ def test_process_batch_failure_validate_chunks(

# act & assert
with self.assertRaises(Exception) as context:
worker.process_batch(batch.id, source_data)
worker.process_batch(batch.id, source_data, "fake_vdb_key", "fake_embedding_key")

# you can add an additional check for the exception message if necessary
self.assertEqual(str(context.exception), "Failed to chunk data")
Expand Down
47 changes: 22 additions & 25 deletions src/worker/vdb_upload_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
# this is needed to import classes from the API. it will be removed when the worker is refactored
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))

import ssl
import time
import pika
import json
import openai
import pinecone
import requests
import logging
import uuid
import weaviate
import redis
import lancedb
Expand All @@ -32,6 +28,8 @@
from shared.vector_db_type import VectorDBType
from shared.utils import generate_uuid_from_tuple
from urllib.parse import quote_plus
from services.rabbitmq.rabbit_service import create_connection_params
from pika.exceptions import AMQPConnectionError

logging.basicConfig(filename='./vdb-upload-log.txt', level=logging.INFO)
logging.basicConfig(filename='./vdb-upload-errors.txt', level=logging.ERROR)
Expand Down Expand Up @@ -412,24 +410,10 @@ def callback(ch, method, properties, body):

ch.basic_ack(delivery_tag=method.delivery_tag)

def start_connection():
while True:
def start_connection(max_retries=5, retry_delay=5):
for attempt in range(max_retries):
try:
credentials = pika.PlainCredentials(os.getenv('RABBITMQ_USERNAME'), os.getenv('RABBITMQ_PASSWORD'))

connection_params = pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
port=os.getenv('RABBITMQ_PORT'),
heartbeat=600,
ssl_options=pika.SSLOptions(ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)),
virtual_host="/"
) if os.getenv('RABBITMQ_PORT') == "5671" else pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
heartbeat=600,
)

connection_params = create_connection_params()
connection = pika.BlockingConnection(connection_params)
channel = connection.channel()

Expand All @@ -440,11 +424,24 @@ def start_connection():

logging.info('Waiting for messages.')
channel.start_consuming()
return # If successful, exit the function

except AMQPConnectionError as e:
logging.error('AMQP Connection Error: %s', e)
except Exception as e:
logging.error('ERROR connecting to RabbitMQ, retrying now. See exception: %s', e)
time.sleep(config.PIKA_RETRY_INTERVAL) # Wait before retrying
logging.error('Unexpected error: %s', e)
finally:
if connection and not connection.is_closed:
connection.close()

logging.info('Retrying to connect in %s seconds (Attempt %s/%s)', retry_delay, attempt + 1, max_retries)
time.sleep(retry_delay)

if __name__ == "__main__":
start_connection()

while True:
try:
start_connection()
except Exception as e:
logging.error('Error in start_connection: %s', e)
logging.info('Restarting start_connection after encountering an error.')
time.sleep(config.PIKA_RETRY_INTERVAL)
Loading

0 comments on commit b87d28c

Please sign in to comment.