forked from zeromicro/zero-contrib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcasbin.go
104 lines (90 loc) · 2.49 KB
/
casbin.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package casbin
import (
"log"
"net/http"
"github.com/casbin/casbin/v2"
"github.com/zeromicro/go-zero/core/logx"
)
// Authorizer is the state behind the Casbin middleware built by NewAuthorizer.
// Besides the enforcer itself it records which request-context keys hold the
// subject (uidField) and the optional domain (domain).
type Authorizer struct {
	enforcer *casbin.Enforcer // evaluates the policy for each request
	uidField string           // context key whose string value is the subject
	domain   string           // context key whose string value is the domain
}

// AuthorizerOption customizes an Authorizer. Options are applied in order,
// after the defaults have been set.
type AuthorizerOption func(opt *Authorizer)
// WithUidField overrides the request-context key from which the subject (uid)
// is read. Without this option the key is "username".
func WithUidField(uidField string) AuthorizerOption {
	setUIDField := func(a *Authorizer) {
		a.uidField = uidField
	}
	return setUIDField
}
// WithDomain overrides the request-context key from which the Casbin domain
// is read. Note that the argument is the key name, not a domain value.
// Without this option the key is "domain".
func WithDomain(domain string) AuthorizerOption {
	setDomainKey := func(a *Authorizer) {
		a.domain = domain
	}
	return setDomainKey
}
// NewAuthorizer builds an HTTP middleware backed by the Casbin enforcer e.
// The given options are applied on top of the default context keys.
// Each request is checked with CheckPermission. Permitted requests are passed
// on to next; all others are answered by RequirePermission and go no further.
func NewAuthorizer(e *casbin.Enforcer, opts ...AuthorizerOption) func(http.Handler) http.Handler {
	authz := &Authorizer{enforcer: e}
	authz.init(opts...)

	return func(next http.Handler) http.Handler {
		guarded := func(w http.ResponseWriter, r *http.Request) {
			if authz.CheckPermission(r) {
				next.ServeHTTP(w, r)
				return
			}
			authz.RequirePermission(w)
		}
		return http.HandlerFunc(guarded)
	}
}
// init sets the default context keys (subject key "username", domain key
// "domain") and then applies each option in order, so a later option
// overrides an earlier one.
func (a *Authorizer) init(opts ...AuthorizerOption) {
	a.uidField, a.domain = "username", "domain"
	for i := range opts {
		opts[i](a)
	}
}
// GetUid returns the subject stored in the request context under the
// configured uid key. The boolean is false when the key is absent or its
// value is not a string.
//
// NOTE(review): the value is presumably placed in the context by an upstream
// JWT middleware from the token claims — confirm against the router setup.
func (a *Authorizer) GetUid(r *http.Request) (uid string, ok bool) {
	ctx := r.Context()
	uid, ok = ctx.Value(a.uidField).(string)
	return uid, ok
}
// GetDomain returns the Casbin domain stored in the request context under the
// configured domain key. The boolean is false when the key is absent or its
// value is not a string.
func (a *Authorizer) GetDomain(r *http.Request) (domain string, ok bool) {
	ctx := r.Context()
	domain, ok = ctx.Value(a.domain).(string)
	return domain, ok
}
// CheckPermission reports whether the enforcer permits the request.
//
// The subject is the uid from GetUid, the object is r.URL.Path and the action
// is r.Method. If GetDomain finds a domain, the enforcer is called as
// Enforce(uid, domain, path, method); otherwise as Enforce(uid, path, method).
// A request without a uid is denied without consulting the enforcer. Enforcer
// errors are logged through logx, and the enforcer's boolean result is
// returned as-is.
func (a *Authorizer) CheckPermission(r *http.Request) bool {
	uid, found := a.GetUid(r)
	if !found {
		return false
	}

	obj, act := r.URL.Path, r.Method
	domain, hasDomain := a.GetDomain(r)
	// NOTE(review): leftover debug print through the stdlib logger on every
	// request (it prints even when no domain is set). Consider dropping it
	// together with the "log" import, or moving it to a logx debug call.
	log.Println("domain:", domain)

	var (
		allowed bool
		err     error
	)
	switch {
	case hasDomain:
		allowed, err = a.enforcer.Enforce(uid, domain, obj, act)
	default:
		allowed, err = a.enforcer.Enforce(uid, obj, act)
	}
	if err != nil {
		logx.WithContext(r.Context()).Errorf("[CASBIN] enforce err %s", err.Error())
	}

	return allowed
}
// RequirePermission rejects the request by sending a 403 Forbidden status.
// Only the status code is written; no response body is produced.
func (a *Authorizer) RequirePermission(writer http.ResponseWriter) {
	const status = http.StatusForbidden
	writer.WriteHeader(status)
}