Skip to content

Commit

Permalink
feat(inf-sd): Add local s3 support for sd api
Browse files Browse the repository at this point in the history
  • Loading branch information
hiro-v committed Aug 30, 2023
1 parent 0ac19f1 commit 19a0fe4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
3 changes: 2 additions & 1 deletion jan-inference/sd/inference.requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Inference
fastapi
uvicorn
python-multipart
python-multipart
boto3
40 changes: 22 additions & 18 deletions jan-inference/sd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,32 @@
import os
from uuid import uuid4
from pydantic import BaseModel
import boto3
from botocore.client import Config

app = FastAPI()

OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "output")
SD_PATH = os.environ.get("SD_PATH", "./sd")
MODEL_DIR = os.environ.get("MODEL_DIR", "./models")
MODEL_NAME = os.environ.get(
"MODEL_NAME", "v1-5-pruned-emaonly-ggml-model-q5_0.bin")
BASE_URL = os.environ.get("BASE_URL", "http://localhost:8000")
"MODEL_NAME", "v1-5-pruned-emaonly.safetensors.q4_0.bin")

# --- S3-compatible object storage settings (each overridable via environment) ---

# Endpoint this service talks to when making S3 API calls.
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL", "http://localhost:9000")
# Endpoint used when building the file URLs handed back to API clients.
S3_PUBLIC_ENDPOINT_URL = os.getenv(
    "S3_PUBLIC_ENDPOINT_URL", "http://localhost:9000")
# Credentials; the fallbacks are local-development values only.
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID", "minio")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY", "minio123")
# Bucket that receives the generated output files.
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "jan")

# Shared S3 resource, created once at import time with SigV4 request signing.
s3 = boto3.resource(
    "s3",
    region_name="us-east-1",
    endpoint_url=S3_ENDPOINT_URL,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
    config=Config(signature_version="s3v4"),
)

# Handle to the target bucket, used for uploads.
s3_bucket = s3.Bucket(S3_BUCKET_NAME)


class Payload(BaseModel):
Expand All @@ -33,9 +50,6 @@ class Payload(BaseModel):
# Make sure the model directory exists. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test: it is a no-op
# when the directory is already there (including when another process
# creates it concurrently).
# NOTE(review): unlike the previous exists() check, this raises if MODEL_DIR
# exists but is not a directory — confirm that failing fast is acceptable.
os.makedirs(MODEL_DIR, exist_ok=True)

# Serve the contents of OUTPUT_DIR as static files under the "/output" path
app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output")


def run_command(payload: Payload, filename: str):
# Construct the command based on your provided example
Expand Down Expand Up @@ -66,21 +80,11 @@ async def run_inference(background_tasks: BackgroundTasks, payload: Payload):
# NOTE: the command currently runs synchronously and blocks this request;
# the non-blocking background-task variant below is disabled.
# background_tasks.add_task(run_command, payload, filename)
run_command(payload, filename)

s3_bucket.upload_file(f'{os.path.join(OUTPUT_DIR, filename)}', filename)
# Return the URL at which the output file can be fetched
return {"url": f'{BASE_URL}/serve/{filename}'}


@app.get("/serve/{filename}")
async def serve_file(filename: str):
    """Return a file from OUTPUT_DIR by its name.

    Args:
        filename: Bare name of a file inside OUTPUT_DIR.

    Returns:
        A FileResponse for the requested file.

    Raises:
        HTTPException: 404 if the name contains a directory component or
            does not refer to an existing regular file in OUTPUT_DIR.
    """
    # Only bare file names are accepted, so the lookup cannot leave OUTPUT_DIR.
    if os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="File not found")

    file_path = os.path.join(OUTPUT_DIR, filename)

    # isfile() rather than exists(): a directory (e.g. "..") also "exists" but
    # cannot be served as a file, so it is reported as not found.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(file_path)
return {"url": f'{S3_PUBLIC_ENDPOINT_URL}/{S3_BUCKET_NAME}/{filename}'}


if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8002)
7 changes: 7 additions & 0 deletions sample.env
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ LLM_MODEL_FILE=llama-2-7b-chat.ggmlv3.q4_1.bin
## SD
SD_MODEL_URL=https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
SD_MODEL_FILE=v1-5-pruned-emaonly.safetensors

# Minio
S3_ACCESS_KEY_ID=minio
S3_SECRET_ACCESS_KEY=minio123
S3_BUCKET_NAME=jan
S3_ENDPOINT_URL=http://minio:9000
S3_PUBLIC_ENDPOINT_URL=http://127.0.0.1:9000

0 comments on commit 19a0fe4

Please sign in to comment.