Skip to content

Commit

Permalink
feat: add sso config
Browse files Browse the repository at this point in the history
  • Loading branch information
yaojin3616 committed Jun 14, 2024
1 parent 1f62d56 commit c13b455
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 78 deletions.
2 changes: 2 additions & 0 deletions src/backend/bisheng/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def getn_env():
env['office_url'] = settings.settings.get_from_db('office_url')
# add tips from settings
env['dialog_tips'] = settings.settings.get_from_db('dialog_tips')
# 判断是否SSO
env['sso'] = settings.settings.system_login_method.get('SSO_OAuth', False)
# add env dict from settings
env.update(settings.settings.get_from_db('env') or {})
return resp_200(env)
Expand Down
84 changes: 32 additions & 52 deletions src/backend/bisheng/api/v1/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,23 @@
from base64 import b64decode, b64encode
from datetime import datetime
from io import BytesIO
from typing import List, Optional, Dict, Annotated
from typing import Annotated, Dict, List, Optional
from uuid import UUID

import rsa

from bisheng.api.errcode.user import (UserNotPasswordError, UserPasswordExpireError,
UserValidateError)
from bisheng.api.JWT import get_login_user
from bisheng.api.errcode.user import UserNotPasswordError, UserValidateError, UserPasswordExpireError
from bisheng.api.services.captcha import verify_captcha
from bisheng.api.services.user_service import gen_user_jwt, get_assistant_list_by_access, UserPayload, gen_user_role
from bisheng.api.services.user_service import (UserPayload, gen_user_jwt, gen_user_role,
get_assistant_list_by_access)
from bisheng.api.v1.schemas import UnifiedResponseModel, resp_200
from bisheng.cache.redis import redis_client
from bisheng.database.base import session_getter
from bisheng.database.models.flow import Flow
from bisheng.database.models.group import GroupDao
from bisheng.database.models.knowledge import Knowledge
from bisheng.database.models.role import Role, RoleCreate, RoleUpdate, RoleDao
from bisheng.database.models.role import Role, RoleCreate, RoleDao, RoleUpdate
from bisheng.database.models.role_access import AccessType, RoleAccess, RoleRefresh
from bisheng.database.models.user import User, UserCreate, UserDao, UserLogin, UserRead, UserUpdate
from bisheng.database.models.user_group import UserGroupDao
Expand Down Expand Up @@ -97,12 +98,15 @@ async def regist(*, user: UserCreate):
@router.post('/user/sso', response_model=UnifiedResponseModel[UserRead], status_code=201)
async def sso(*, user: UserCreate, Authorize: AuthJWT = Depends()):
'''给sso提供的接口'''
if True: # 判断sso 是否打开
if settings.system_login_method.get('SSO_OAuth', False): # 判断sso 是否打开
account_name = user.user_name
user_exist = UserDao.get_unique_user_by_name(account_name)
if not user_exist:
logger.info('act=create_user account={}', account_name)
user_exist = UserDao.create_user(user)
default_admin = settings.system_login_method.get('admin_username')
if default_admin and default_admin == account_name:
UserRoleDao.set_admin_user(user_exist.user_id)
UserGroupDao.add_default_user_group(user_exist.user_id)

access_token, refresh_token, _ = gen_user_jwt(user_exist)
Expand Down Expand Up @@ -148,7 +152,8 @@ async def login(*, user: UserLogin, Authorize: AuthJWT = Depends()):

# 判断下密码是否长期未修改
if password_conf.password_valid_period and password_conf.password_valid_period > 0:
if (datetime.now() - db_user.password_update_time).days >= password_conf.password_valid_period:
if (datetime.now() -
db_user.password_update_time).days >= password_conf.password_valid_period:
return UserPasswordExpireError.return_resp()

access_token, refresh_token, role = gen_user_jwt(db_user)
Expand Down Expand Up @@ -257,25 +262,11 @@ async def list_user(*,
group_dict = {}
for one in users:
one_data = one.model_dump()
user_roles = get_user_roles(one, role_dict)
user_groups = get_user_groups(one, group_dict)
# 如果不是超级管理,则需要将数据过滤, 不能看到非他操作用户管理的用户组内的角色和用户组列表
if user_admin_groups:
for i in range(len(user_roles) - 1, -1, -1):
if user_roles[i]["group_id"] not in user_admin_groups:
del user_roles[i]
for i in range(len(user_groups) - 1, -1, -1):
if user_groups[i]["id"] not in user_admin_groups:
del user_groups[i]

one_data["roles"] = user_roles
one_data["groups"] = user_groups
one_data['roles'] = get_user_roles(one, role_dict)
one_data['groups'] = get_user_groups(one, group_dict)
res.append(one_data)

return resp_200({
'data': res,
'total': total_count
})
return resp_200({'data': res, 'total': total_count})


def get_user_roles(user: User, role_cache: Dict) -> List[Dict]:
Expand All @@ -291,11 +282,7 @@ def get_user_roles(user: User, role_cache: Dict) -> List[Dict]:
if user_role_ids:
role_list = RoleDao.get_role_by_ids(user_role_ids)
for role_info in role_list:
role_cache[role_info.id] = {
"id": role_info.id,
"group_id": role_info.group_id,
"name": role_info.role_name
}
role_cache[role_info.id] = {'id': role_info.id, 'name': role_info.role_name}
res.append(role_cache.get(role_info.id))
return res

Expand All @@ -313,10 +300,7 @@ def get_user_groups(user: User, group_cache: Dict) -> List[Dict]:
if user_group_ids:
group_list = GroupDao.get_group_by_ids(user_group_ids)
for group_info in group_list:
group_cache[group_info.id] = {
"id": group_info.id,
"name": group_info.group_name
}
group_cache[group_info.id] = {'id': group_info.id, 'name': group_info.group_name}
res.append(group_cache.get(group_info.id))
return res

Expand Down Expand Up @@ -414,10 +398,7 @@ async def get_role(*,
# 查询所有的角色列表
res = RoleDao.get_role_by_groups(group_ids, role_name, page, limit)
total = RoleDao.count_role_by_groups(group_ids, role_name)
return resp_200(data={
"data": res,
"total": total
})
return resp_200(data={"data": res, "total": total})


@router.delete('/role/{role_id}', status_code=200)
Expand Down Expand Up @@ -618,7 +599,7 @@ async def knowledge_list(*,
'id': access[0].id
} for access in db_role_access],
'total':
total_count
total_count
})


Expand Down Expand Up @@ -667,7 +648,7 @@ async def flow_list(*,
'id': access[0]
} for access in db_role_access],
'total':
total_count
total_count
})


Expand Down Expand Up @@ -714,11 +695,13 @@ async def get_rsa_publish_key():
return resp_200({'public_key': pubkey_str})


@router.post("/user/reset_password", status_code=200)
async def reset_password(*,
user_id: int = Body(embed=True),
password: str = Body(embed=True),
login_user: UserPayload = Depends(get_login_user)):
@router.post('/user/reset_password', status_code=200)
async def reset_password(
*,
user_id: int,
password: str,
login_user: UserPayload = Depends(get_login_user),
):
"""
管理员重置用户密码
"""
Expand All @@ -743,10 +726,10 @@ async def reset_password(*,
return resp_200()


@router.post("/user/change_password", status_code=200)
@router.post('/user/change_password', status_code=200)
async def change_password(*,
password: str = Body(embed=True),
new_password: str = Body(embed=True),
password: str,
new_password: str,
login_user: UserPayload = Depends(get_login_user)):
"""
登录用户 修改自己的密码
Expand All @@ -766,11 +749,8 @@ async def change_password(*,
return resp_200()


@router.post("/user/change_password_public", status_code=200)
async def change_password_public(*,
username: str = Body(embed=True),
password: str = Body(embed=True),
new_password: str = Body(embed=True)):
@router.post('/user/change_password_public', status_code=200)
async def change_password_public(*, username: str, password: str, new_password: str):
"""
未登录用户 修改自己的密码
"""
Expand Down
21 changes: 11 additions & 10 deletions src/backend/bisheng/database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from contextlib import contextmanager
from typing import List

from sqlalchemy import text

from bisheng.database.init_config import init_config
from bisheng.database.models.user_role import UserRoleDao
from bisheng.database.service import DatabaseService
from bisheng.settings import settings
from bisheng.utils.logger import logger
from sqlalchemy import text
from sqlmodel import Session, select, update

db_service: 'DatabaseService' = DatabaseService(settings.database_url)
Expand All @@ -21,7 +21,6 @@ def init_default_data():
from bisheng.database.models.component import Component
from bisheng.database.models.role import Role
from bisheng.database.models.user import User
from bisheng.database.models.user_role import UserRole
from bisheng.database.models.gpts_tools import GptsTools
from bisheng.database.models.gpts_tools import GptsToolsType
from bisheng.database.models.sft_model import SftModel
Expand Down Expand Up @@ -52,9 +51,7 @@ def init_default_data():
session.add(user)
session.commit()
session.refresh(user)
db_userrole = UserRole(user_id=user.user_id, role_id=db_role.id)
session.add(db_userrole)
session.commit()
UserRoleDao.set_admin_user(user.user_id)

component_db = session.exec(select(Component).limit(1)).all()
if not component_db:
Expand Down Expand Up @@ -93,10 +90,12 @@ def init_default_data():
session.exec(update(GptsTools).where(GptsTools.id == i).values(type=i))
# 属于天眼查类别下的工具
tyc_types: List[int] = list(range(7, 18))
session.exec(update(GptsTools).where(GptsTools.id.in_(tyc_types)).values(type=7))
session.exec(
update(GptsTools).where(GptsTools.id.in_(tyc_types)).values(type=7))
# 属于金融类别下的工具
jr_types: List[int] = list(range(18, 28))
session.exec(update(GptsTools).where(GptsTools.id.in_(jr_types)).values(type=8))
session.exec(
update(GptsTools).where(GptsTools.id.in_(jr_types)).values(type=8))
session.commit()
# 初始化配置可用于微调的基准模型
preset_models = session.exec(select(SftModel).limit(1)).all()
Expand All @@ -112,13 +111,15 @@ def init_default_data():
# 初始化补充默认的技能版本表
flow_version = session.exec(select(FlowVersion).limit(1)).all()
if not flow_version:
sql_query = text("INSERT INTO `flowversion` (`name`, `flow_id`, `data`, `user_id`, `is_current`, `is_delete`) \
sql_query = text(
"INSERT INTO `flowversion` (`name`, `flow_id`, `data`, `user_id`, `is_current`, `is_delete`) \
select 'v0', `id` as flow_id, `data`, `user_id`, 1, 0 from `flow`;")
session.execute(sql_query)
session.commit()
# 修改表单数据表
sql_query = text(
"UPDATE `t_variable_value` a SET a.version_id=(SELECT `id` from `flowversion` WHERE flow_id=a.flow_id and is_current=1)")
'UPDATE `t_variable_value` a SET a.version_id=(SELECT `id` from `flowversion` WHERE flow_id=a.flow_id and is_current=1)' # noqa
)
session.execute(sql_query)
session.commit()
# 初始化数据库config
Expand Down
12 changes: 12 additions & 0 deletions src/backend/bisheng/database/models/user_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,15 @@ def get_admins_user(cls) -> List[UserRole]:
with session_getter() as session:
statement = select(UserRole).where(UserRole.role_id == 1)
return session.exec(statement).all()

@classmethod
def set_admin_user(cls, user_id: int) -> UserRole:
"""
设置用户为超级管理员
"""
with session_getter() as session:
user_role = UserRole(user_id=user_id, role_id=1)
session.add(user_role)
session.commit()
session.refresh(user_role)
return user_role
6 changes: 3 additions & 3 deletions src/backend/bisheng/initdb_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ gpts:
system_login_method:
# SSO 登录
SSO_OAuth: true
# LDAP 登录
LDAP: true
LDAP服务器地址配置: XX
# # LDAP 登录
# LDAP: true
# LDAP服务器地址配置: XX

# 切换 SSO/LDAP 登录后管理员用户名
admin_username: admin
26 changes: 13 additions & 13 deletions src/backend/bisheng/settings.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import os
from typing import Dict, Optional, Union, List
import re
import yaml
from pydantic import BaseModel, Field
from typing import Dict, List, Optional, Union

from loguru import logger
import yaml
from bisheng.database.models.config import Config
from cryptography.fernet import Fernet

from langchain.pydantic_v1 import BaseSettings, root_validator, validator
from loguru import logger
from pydantic import BaseModel, Field
from sqlmodel import select

from bisheng.database.models.config import Config


class LoggerConf(BaseModel):
level: str = 'DEBUG'
format: str = '<level>[{level.name} process-{process.id}-{thread.id} {name}:{line}]</level> - <level>trace={extra[trace_id]} {message}</level>'
format: str = '<level>[{level.name} process-{process.id}-{thread.id} {name}:{line}]</level> - <level>trace={extra[trace_id]} {message}</level>' # noqa
handlers: List[Dict] = []

@classmethod
Expand All @@ -25,7 +23,7 @@ def parse_logger_sink(cls, sink: str) -> str:
return sink
env_keys = {}
for one in match.groups():
env_keys[one] = os.getenv(one, "")
env_keys[one] = os.getenv(one, '')
return sink.format(**env_keys)

@validator('handlers', pre=True)
Expand All @@ -41,9 +39,9 @@ def set_handlers(cls, value):


class PasswordConf(BaseModel):
password_valid_period: Optional[int] = Field(description="密码超过X天必须进行修改, 登录提示重新修改密码")
login_error_time_window: Optional[int] = Field(description="登录错误时间窗口,单位分钟")
max_error_times: Optional[int] = Field(description="最大错误次数,超过后会封禁用户")
password_valid_period: Optional[int] = Field(description='密码超过X天必须进行修改, 登录提示重新修改密码')
login_error_time_window: Optional[int] = Field(description='登录错误时间窗口,单位分钟')
max_error_times: Optional[int] = Field(description='最大错误次数,超过后会封禁用户')


class Settings(BaseSettings):
Expand Down Expand Up @@ -81,6 +79,7 @@ class Settings(BaseSettings):
minio_conf = {}
logger_conf: LoggerConf = LoggerConf()
password_conf: PasswordConf = PasswordConf()
system_login_method: dict = {}

@validator('database_url', pre=True)
def set_database_url(cls, value):
Expand Down Expand Up @@ -163,7 +162,8 @@ def get_all_config(self):
return yaml.safe_load(cache)
else:
with session_getter() as session:
initdb_config = session.exec(select(Config).where(Config.key == 'initdb_config')).first()
initdb_config = session.exec(
select(Config).where(Config.key == 'initdb_config')).first()
if initdb_config:
redis_client.set(redis_key, initdb_config.value, 100)
return yaml.safe_load(initdb_config.value)
Expand Down

0 comments on commit c13b455

Please sign in to comment.