Skip to content

Commit

Permalink
fix: sql cell config validator
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshinesmilelk committed Oct 16, 2024
1 parent cc0a0b4 commit 5f38891
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
2 changes: 1 addition & 1 deletion libro-sql/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "libro-sql"
version = "0.1.8"
version = "0.1.9"
description = "libro sql"
authors = [
{ name = "brokun", email = "[email protected]" },
Expand Down
2 changes: 1 addition & 1 deletion libro-sql/src/libro_sql/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# -*- coding: utf-8 -*-
__version__ = "0.1.8"
__version__ = "0.1.9"
38 changes: 20 additions & 18 deletions libro-sql/src/libro_sql/database.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@

from typing import Dict
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, model_validator
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
import pandas as pd
from libro_core.config import libro_config

class DatabaseConfig(BaseModel):
db_type: str
username: str | None = None # 可选字段
password: str | None = None # 可选字段
host: str | None = None # 可选字段
port: int | None = None # 可选字段
username: str = None
password: str = None
host: str = None
port: int = None
database: str

@field_validator('database', mode='before')
def validate_database(cls, v, values):
db_type = values.get('db_type')
if db_type in ['postgresql', 'mysql','sqlite'] and not v:
raise ValueError('database must be provided.')
return v

@field_validator('username', 'password', 'host', 'port', mode='before')
def validate_fields(cls, v, values, field):
@model_validator(mode="before")
def validate_fields(cls, values):
db_type = values.get('db_type')

# 如果 db_type 是 'postgresql' 或 'mysql',则这些字段为必填
if db_type in ['postgresql', 'mysql']:
if v is None:
raise ValueError(f'{field.name} must be provided when db_type is {db_type}')
# 对于sqlite,不需要验证其他字段
return v
required_fields = ['username', 'password', 'host', 'port']
for field in required_fields:
if values.get(field) is None:
raise ValueError(f'{field} must be provided when db_type is {db_type}')

# 如果 db_type 是 'sqlite',则只验证 database 字段
elif db_type == 'sqlite':
if not values.get('database'):
raise ValueError('database must be provided when db_type is sqlite')

return values

class Database:
config: DatabaseConfig
Expand Down
2 changes: 1 addition & 1 deletion libro/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
{ name = "sunshinesmilelk", email = "[email protected]" },
{ name = "zhanba", email = "[email protected]" },
]
dependencies = ["libro-server>=0.1.5", "libro-sql>=0.1.8", "libro-ai>=0.1.8"]
dependencies = ["libro-server>=0.1.5", "libro-sql>=0.1.9", "libro-ai>=0.1.8"]
dev-dependencies = []
readme = "README.md"
requires-python = ">= 3.9"
Expand Down

0 comments on commit 5f38891

Please sign in to comment.