Skip to content

Commit

Permalink
refactored rabbitmq to be more robust to failure and solved connectio…
Browse files Browse the repository at this point in the history
…n issue
  • Loading branch information
David Garnitz authored and David Garnitz committed Aug 16, 2023
1 parent 38e89b7 commit 0453fa6
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 34 deletions.
10 changes: 10 additions & 0 deletions src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,18 @@ def dequeue():
vectorflow_key = request.headers.get('VectorFlowKey')
if not vectorflow_key or not auth.validate_credentials(vectorflow_key):
return jsonify({'error': 'Invalid credentials'}), 401

pipeline.connect()
if pipeline.get_queue_size() == 0:
pipeline.disconnect()
return jsonify({'error': 'No jobs in queue'}), 404
else:
body = pipeline.get_from_queue()
pipeline.disconnect()

data = json.loads(body)
batch_id, source_data = data

return jsonify({'batch_id': batch_id, 'source_data': source_data}), 200

def create_batches(file_content, job_id, embeddings_metadata, vector_db_metadata, lines_per_chunk):
Expand All @@ -116,7 +122,11 @@ def create_batches(file_content, job_id, embeddings_metadata, vector_db_metadata
for batch, chunk in zip(batches, chunks):
data = (batch.id, chunk)
json_data = json.dumps(data)

pipeline.connect()
pipeline.add_to_queue(json_data)
pipeline.disconnect()

return job.total_batches if job else None

def split_file(file_content, lines_per_chunk=1000):
Expand Down
38 changes: 25 additions & 13 deletions src/api/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,36 @@

class Pipeline:
def __init__(self):
credentials = pika.PlainCredentials(os.getenv('RABBITMQ_USERNAME'), os.getenv('RABBITMQ_PASSWORD'))

connection_params = pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
port=os.getenv('RABBITMQ_PORT'),
ssl_options=pika.SSLOptions(ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)),
virtual_host="/"
) if os.getenv('RABBITMQ_PORT') == "5671" else pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
)
self.credentials = pika.PlainCredentials(os.getenv('RABBITMQ_USERNAME'), os.getenv('RABBITMQ_PASSWORD'))
self.queue_name = os.getenv('RABBITMQ_QUEUE')
self.channel = None

def connect(self):
connection_params = self._get_connection_params()
connection = pika.BlockingConnection(connection_params)
self.channel = connection.channel()
self.queue_name = os.getenv('RABBITMQ_QUEUE')
self.channel.queue_declare(queue=self.queue_name)

def disconnect(self):
if self.channel and self.channel.is_open:
self.channel.close()

def _get_connection_params(self):
if os.getenv('RABBITMQ_PORT') == "5671":
return pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=self.credentials,
port=os.getenv('RABBITMQ_PORT'),
ssl_options=pika.SSLOptions(ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)),
virtual_host="/"
)
else:
return pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=self.credentials,
)


def add_to_queue(self, data):
self.channel.basic_publish(exchange='', routing_key=self.queue_name, body=str(data))

Expand Down
3 changes: 2 additions & 1 deletion src/worker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
MAX_THREADS_OPENAI = 20
MAX_OPENAI_EMBEDDING_BATCH_SIZE = 2048
PINECONE_BATCH_SIZE = 128
SLEEP_SECONDS = 5
SLEEP_SECONDS = 5
PIKA_RETRY_INTERVAL = 10
52 changes: 32 additions & 20 deletions src/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,27 +258,39 @@ def callback(ch, method, properties, body):

ch.basic_ack(delivery_tag=method.delivery_tag)

if __name__ == "__main__":
credentials = pika.PlainCredentials(os.getenv('RABBITMQ_USERNAME'), os.getenv('RABBITMQ_PASSWORD'))

connection_params = pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
port=os.getenv('RABBITMQ_PORT'),
ssl_options=pika.SSLOptions(ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)),
virtual_host="/"
) if os.getenv('RABBITMQ_PORT') == "5671" else pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
)
def start_connection():
while True:
try:
credentials = pika.PlainCredentials(os.getenv('RABBITMQ_USERNAME'), os.getenv('RABBITMQ_PASSWORD'))

connection_params = pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
port=os.getenv('RABBITMQ_PORT'),
heartbeat=600,
ssl_options=pika.SSLOptions(ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)),
virtual_host="/"
) if os.getenv('RABBITMQ_PORT') == "5671" else pika.ConnectionParameters(
host=os.getenv('RABBITMQ_HOST'),
credentials=credentials,
heartbeat=600,
)

connection = pika.BlockingConnection(connection_params)
channel = connection.channel()
connection = pika.BlockingConnection(connection_params)
channel = connection.channel()

queue_name = os.getenv('RABBITMQ_QUEUE')
channel.queue_declare(queue=queue_name)
queue_name = os.getenv('RABBITMQ_QUEUE')
channel.queue_declare(queue=queue_name)

channel.basic_consume(queue=queue_name, on_message_callback=callback)
channel.basic_consume(queue=queue_name, on_message_callback=callback)

logging.info('Waiting for messages.')
channel.start_consuming()

except Exception as e:
logging.error('ERROR connecting to RabbitMQ, retrying now. See exception:', e)
time.sleep(config.PIKA_RETRY_INTERVAL) # Wait before retrying

if __name__ == "__main__":
start_connection()

logging.info('Waiting for messages.')
channel.start_consuming()

0 comments on commit 0453fa6

Please sign in to comment.