forked from MartialBE/one-hub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoidc.go
165 lines (158 loc) · 3.89 KB
/
oidc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package controller
import (
"context"
"errors"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"net/http"
"one-api/common/config"
"one-api/common/logger"
"one-api/common/oidc"
"one-api/common/utils"
"one-api/model"
)
// OIDCEndpoint starts an OIDC login. It generates a random 12-character
// state value, stores it in the session under "oauth_state" (OIDCAuth later
// compares the callback's "state" query parameter against it), and responds
// with the provider login URL built for that state in the "data" field.
//
// When OIDC login is disabled, the OIDC configuration cannot be loaded, or
// the session cannot be saved, it responds (HTTP 200) with success=false and
// a message instead.
func OIDCEndpoint(c *gin.Context) {
	respondFailure := func(message string) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": message,
		})
	}

	if !config.OIDCAuthEnabled {
		respondFailure("管理员未开启通过OIDC登录")
		return
	}

	cfg, cfgErr := oidc.GetOIDCConfigInstance()
	if cfgErr != nil {
		// Full error goes to the system log; the client only gets a generic message.
		logger.SysError("获取 OIDC 配置失败, err: " + cfgErr.Error())
		respondFailure("获取 OIDC 配置失败")
		return
	}

	sess := sessions.Default(c)
	oauthState := utils.GetRandomString(12)
	sess.Set("oauth_state", oauthState)
	authURL := cfg.LoginURL(oauthState)

	// The state must be persisted before the client is redirected, or the
	// callback cannot validate it.
	if saveErr := sess.Save(); saveErr != nil {
		respondFailure(saveErr.Error())
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    authURL,
	})
}
// OIDCAuth handles the OIDC authorization-code callback.
//
// It requires the "state" query parameter to match the "oauth_state" value
// stored in the session by OIDCEndpoint. It then exchanges the "code" query
// parameter for tokens, verifies the returned ID token, and reads the login
// name from the claim named by config.OIDCUsernameClaims.
//
// If no user with that name exists and config.RegisterEnabled is set, a new
// common user is created. The "email" and "displayName" claims are copied
// onto the new user when they are strings. An enabled user is then logged in
// via setupLogin. Every failure is answered with a JSON body carrying
// success=false and a message. A state mismatch uses HTTP 403; all other
// failures use HTTP 200.
func OIDCAuth(c *gin.Context) {
	if !config.OIDCAuthEnabled {
		c.JSON(http.StatusOK, gin.H{
			"message": "管理员未开启通过OIDC登录",
			"success": false,
		})
		return
	}

	session := sessions.Default(c)
	state := c.Query("state")
	// Comma-ok assertion: a missing or non-string session value counts as a
	// state mismatch instead of panicking the handler.
	savedState, ok := session.Get("oauth_state").(string)
	if state == "" || !ok || state != savedState {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "state is empty or not same",
		})
		return
	}
	// NOTE(review): the state is not cleared here, so it stays valid until the
	// session changes; consider session.Delete("oauth_state") if setupLogin
	// persists the session.

	oidcConfig, err := oidc.GetOIDCConfigInstance()
	if err != nil {
		logger.SysError("获取 OIDC 配置失败, err: " + err.Error())
		c.JSON(http.StatusOK, gin.H{
			"message": "获取 OIDC 配置失败",
			"success": false,
		})
		return
	}

	// The authorization code issued by the provider. Reject early rather
	// than making a token request that cannot succeed.
	code := c.Query("code")
	if code == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "code is empty",
		})
		return
	}

	// Bind the outbound token exchange and ID-token verification to the
	// incoming request, so they are abandoned if the client goes away.
	var ctx context.Context = c.Request.Context()

	token, err := oidcConfig.OAuth2Config.Exchange(ctx, code)
	if err != nil {
		logger.SysError("OIDC token exchange failed, err: " + err.Error())
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "Failed to exchange token: " + err.Error(),
		})
		return
	}

	// The ID token arrives as an extra field of the token response. A
	// provider that omits it (or sends a non-string) must not crash the
	// handler.
	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok || rawIDToken == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "Failed to verify ID token: id_token missing from token response",
		})
		return
	}

	idToken, err := oidcConfig.Verifier.Verify(ctx, rawIDToken)
	if err != nil {
		logger.SysError("OIDC ID token verification failed, err: " + err.Error())
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "Failed to verify ID token: " + err.Error(),
		})
		return
	}

	claims := make(map[string]any)
	if err := idToken.Claims(&claims); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "Failed to parse claims: " + err.Error(),
		})
		return
	}

	// A missing, non-string, or empty username claim means this identity
	// cannot be mapped to a local account.
	userName, _ := claims[config.OIDCUsernameClaims].(string)
	if userName == "" {
		c.JSON(http.StatusOK, gin.H{
			"message": "用户没有OIDC登录权限",
			"success": false,
		})
		return
	}

	user := model.User{Username: userName}
	if err := user.FillUserByUsername(); err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			logger.SysError("查询用户错误:" + err.Error())
			c.JSON(http.StatusOK, gin.H{
				"message": err.Error(),
				"success": false,
			})
			return
		}

		// Unknown user: auto-register only when registration is open.
		if !config.RegisterEnabled {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "管理员关闭了新用户注册",
			})
			return
		}

		// Re-assert the username in case the failed lookup touched the struct.
		user.Username = userName
		if email, ok := claims["email"].(string); ok {
			user.Email = email
		}
		if displayName, ok := claims["displayName"].(string); ok {
			user.DisplayName = displayName
		}
		user.Role = config.RoleCommonUser
		user.Status = config.UserStatusEnabled
		if err := user.Insert(0); err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
	}

	if user.Status != config.UserStatusEnabled {
		c.JSON(http.StatusOK, gin.H{
			"message": "用户已被封禁或不存在",
			"success": false,
		})
		return
	}

	setupLogin(&user, c)
}