Skip to content

Commit

Permalink
✨Introduced audit_text
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Oct 8, 2022
1 parent e52b5be commit efed1de
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 21 deletions.
13 changes: 8 additions & 5 deletions apis/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,23 @@ async def hello(data: HelloModel) -> HelloResponse:
return await run_algorithm(loaded_algorithms["demo.hello"], data)


# translate
# get prompt


class TranslateModel(BaseModel):
class GetPromptModel(BaseModel):
text: str


class TranslateResponse(BaseModel):
class GetPromptResponse(BaseModel):
text: str
success: bool
reason: str


@app.post("/translate")
def translate(data: TranslateModel) -> TranslateResponse:
return TranslateResponse(text=data.text)
@app.post("/get_prompt")
def get_prompt(data: GetPromptModel) -> GetPromptResponse:
return GetPromptResponse(text=data.text, success=True, reason="")


# txt2img
Expand Down
27 changes: 22 additions & 5 deletions apis/kafka/producer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import json
import time
import yaml
import redis
import datetime
import requests
import logging.config

from enum import Enum
Expand All @@ -13,14 +15,19 @@
from typing import Optional
from typing import NamedTuple
from fastapi import FastAPI
from pydantic import Field
from pydantic import BaseModel
from qcloud_cos import CosConfig
from qcloud_cos import CosS3Client
from pkg_resources import get_distribution
from cftool.misc import random_hash
from fastapi.openapi.utils import get_openapi
from fastapi.middleware.cors import CORSMiddleware

from cfclient.utils import get_responses

from cfcreator import *


app = FastAPI()
root = os.path.dirname(__file__)
Expand Down Expand Up @@ -61,13 +68,16 @@ def filter(self, record: logging.LogRecord) -> bool:
return True

logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
logging.getLogger("dicttoxml").disabled = True
logging.getLogger("kafka.conn").disabled = True
logging.getLogger("kafka.cluster").disabled = True
logging.getLogger("kafka.coordinator").disabled = True
logging.getLogger("kafka.consumer.subscription_state").disabled = True


# clients
config = CosConfig(Region=REGION, SecretId=SECRET_ID, SecretKey=SECRET_KEY)
cos_client = CosS3Client(config)
redis_client = redis.Redis(host="localhost", port=6379, db=0)
kafka_admin = KafkaAdminClient(bootstrap_servers="172.17.16.8:9092")
kafka_producer = KafkaProducer(bootstrap_servers="172.17.16.8:9092")
Expand Down Expand Up @@ -112,20 +122,27 @@ async def health_check() -> HealthCheckResponse:
return {"status": "alive"}


# translate
# get prompt


class TranslateModel(BaseModel):
class GetPromptModel(BaseModel):
text: str


class TranslateResponse(BaseModel):
class GetPromptResponse(BaseModel):
text: str
success: bool
reason: str


@app.post("/translate")
def translate(data: TranslateModel) -> TranslateResponse:
return TranslateResponse(text=data.text)
@app.post("/get_prompt")
def get_prompt(data: GetPromptModel) -> GetPromptResponse:
text = data.text
audit = audit_text(cos_client, text)
if not audit.safe:
return GetPromptResponse(text="", success=False, reason=audit.reason)
return GetPromptResponse(text=text, success=True, reason="")


# kafka & redis
Expand Down
99 changes: 88 additions & 11 deletions cfcreator/cos.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import io
import os
import time
import uuid

import numpy as np

from PIL import Image
from typing import Union
from typing import BinaryIO
from typing import Optional
from pydantic import Field
from pydantic import BaseModel
from qcloud_cos import CosConfig
Expand All @@ -17,26 +19,97 @@
BUCKET = "ailab-1310750649"
CDN_HOST = "https://ailabcdn.nolibox.com"
COS_HOST = "https://ailab-1310750649.cos.ap-shanghai.myqcloud.com/"
TEXT_BIZ_TYPE = "56daee337ae2d847e55838c0ddb6d547"
SECRET_ID = os.getenv("SECRETID")
SECRET_KEY = os.getenv("SECRETKEY")

TEMP_FOLDER = "tmp"
TEMP_TEXT_FOLDER = "tmp_txt"
TEMP_IMAGE_FOLDER = "tmp"


class UploadResponse(BaseModel):
class UploadTextResponse(BaseModel):
path: str = Field(..., description="The path on the cloud.")
cdn: str = Field(..., description="The `cdn` url of the input text.")
cos: str = Field(..., description="The `cos` url of the input text, which should be used internally.")


class AuditResponse(BaseModel):
safe: bool = Field(..., description="Whether the input content is safe.")
reason: str = Field(..., description="If not safe, what's the reason?")


class UploadImageResponse(BaseModel):
path: str = Field(..., description="The path on the cloud.")
cdn: str = Field(..., description="The `cdn` url of the input image.")
cos: str = Field(..., description="The `cos` url of the input image, which should be used internally.")


def upload_text(
client: CosS3Client,
text: str,
*,
folder: str,
part_size: int = 10,
max_thread: int = 10,
) -> UploadTextResponse:
path = f"{folder}/{uuid.uuid4().hex}.txt"
text_io = io.StringIO(text)
client.upload_file_from_buffer(BUCKET, path, text_io, PartSize=part_size, MAXThread=max_thread)
return UploadTextResponse(
path=path,
cdn=f"{CDN_HOST}/{path}",
cos=f"{COS_HOST}/{path}",
)

def upload_temp_text(
client: CosS3Client,
text: str,
*,
part_size: int = 10,
max_thread: int = 10,
) -> UploadTextResponse:
return upload_text(
client,
text,
folder=TEMP_TEXT_FOLDER,
part_size=part_size,
max_thread=max_thread,
)


def parse_audit_text(res: dict) -> Optional[AuditResponse]:
detail = res["JobsDetail"]
if detail["State"] != "Success":
return
label = detail["Label"]
return AuditResponse(safe=label == "Normal", reason=label)

def audit_text(client: CosS3Client, text: str) -> AuditResponse:
res = client.ci_auditing_text_submit(BUCKET, "", Content=text.encode("utf-8"), BizType=TEXT_BIZ_TYPE)
job_id = res["JobsDetail"]["JobId"]
parsed = parse_audit_text(res)
patience = 20
interval = 100
for i in range(patience):
if parsed is not None:
break
time.sleep(interval)
res = client.ci_auditing_text_query(BUCKET, job_id)
parsed = parse_audit_text(res)
if parsed is None:
return AuditResponse(safe=False, reason=f"Timeout ({patience * interval})")
return parsed


def upload_image(
client: CosS3Client,
inp: Union[bytes, np.ndarray, BinaryIO],
*,
folder: str,
part_size: int = 10,
max_thread: int = 10,
) -> UploadResponse:
temp_path = f"{folder}/{uuid.uuid4().hex}.png"
) -> UploadImageResponse:
path = f"{folder}/{uuid.uuid4().hex}.png"
if isinstance(inp, bytes):
img_bytes = io.BytesIO(inp)
elif isinstance(inp, np.ndarray):
Expand All @@ -45,10 +118,11 @@ def upload_image(
img_bytes.seek(0)
else:
img_bytes = inp
client.upload_file_from_buffer(BUCKET, temp_path, img_bytes, PartSize=part_size, MAXThread=max_thread)
return UploadResponse(
cdn=f"{CDN_HOST}/{temp_path}",
cos=f"{COS_HOST}/{temp_path}",
client.upload_file_from_buffer(BUCKET, path, img_bytes, PartSize=part_size, MAXThread=max_thread)
return UploadImageResponse(
path=path,
cdn=f"{CDN_HOST}/{path}",
cos=f"{COS_HOST}/{path}",
)

def upload_temp_image(
Expand All @@ -57,11 +131,11 @@ def upload_temp_image(
*,
part_size: int = 10,
max_thread: int = 10,
) -> UploadResponse:
) -> UploadImageResponse:
return upload_image(
client,
inp,
folder=TEMP_FOLDER,
folder=TEMP_IMAGE_FOLDER,
part_size=part_size,
max_thread=max_thread,
)
Expand All @@ -71,9 +145,12 @@ def upload_temp_image(
"REGION",
"SECRET_ID",
"SECRET_KEY",
"upload_text",
"upload_temp_text",
"audit_text",
"upload_image",
"upload_temp_image",
"UploadResponse",
"UploadImageResponse",
]


Expand Down

0 comments on commit efed1de

Please sign in to comment.