-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathcors.go
228 lines (204 loc) · 6.59 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
// Copyright 2015 Husobee Associates, LLC. All rights reserved.
// Use of this source code is governed by The MIT License, which
// can be found in the LICENSE file included.
package vestigo
import (
"errors"
"fmt"
"net/http"
"strings"
"time"
)
// CorsAccessControl is the default implementation of a CORS policy. A
// router-wide policy and a per-resource policy are combined with Merge.
type CorsAccessControl struct {
	// AllowOrigin lists the permitted origins. An entry of "*" matches any
	// origin, but an exact match is preferred over the wildcard.
	AllowOrigin []string
	// AllowCredentials requests "Access-Control-Allow-Credentials: true";
	// it is only honoured when the request origin matched AllowOrigin
	// exactly (never via the "*" wildcard).
	AllowCredentials bool
	// ExposeHeaders lists the names advertised in
	// Access-Control-Expose-Headers.
	ExposeHeaders []string
	// MaxAge is advertised in Access-Control-Max-Age as whole (truncated)
	// seconds; a zero value omits the header.
	MaxAge time.Duration
	// AllowMethods restricts the methods accepted in a preflight. When
	// empty, the allowedMethods passed to the preflight check are used.
	AllowMethods []string
	// AllowHeaders lists the request headers that may be used; names are
	// compared case-insensitively.
	AllowHeaders []string
}

// GetAllowOrigin returns the origins permitted by the policy.
func (c *CorsAccessControl) GetAllowOrigin() []string {
	return c.AllowOrigin
}

// GetAllowCredentials reports whether credentialed requests are allowed.
func (c *CorsAccessControl) GetAllowCredentials() bool {
	return c.AllowCredentials
}

// GetExposeHeaders returns the headers to expose to the client.
func (c *CorsAccessControl) GetExposeHeaders() []string {
	return c.ExposeHeaders
}

// GetMaxAge returns how long a preflight result may be cached.
func (c *CorsAccessControl) GetMaxAge() time.Duration {
	return c.MaxAge
}

// GetAllowMethods returns the methods permitted by the policy.
func (c *CorsAccessControl) GetAllowMethods() []string {
	return c.AllowMethods
}

// GetAllowHeaders returns the request headers permitted by the policy.
func (c *CorsAccessControl) GetAllowHeaders() []string {
	return c.AllowHeaders
}

// Merge returns a new policy that combines c with c2. Neither input is
// modified, and the result never shares slice storage with either of them,
// so Merge is safe to call repeatedly on a shared, read-only policy.
//
// Merge rules, when c2 is non-nil:
//   - AllowOrigin: c's origins followed by c2's (no de-duplication).
//   - AllowCredentials: true if either policy allows credentials.
//   - ExposeHeaders, AllowHeaders: c's entries then c2's, with later
//     duplicates dropped case-insensitively.
//   - AllowMethods: c's entries then c2's, with exact duplicates dropped.
//   - MaxAge: c2's value when non-zero, otherwise c's.
//
// A list that is empty in c2 is copied from c unchanged. A nil c2 yields a
// copy of c. A nil receiver yields an empty policy and ignores c2.
func (c *CorsAccessControl) Merge(c2 *CorsAccessControl) *CorsAccessControl {
	result := new(CorsAccessControl)
	if c == nil {
		// NOTE(review): c2 is deliberately ignored here; presumably "no
		// global policy" means CORS is disabled. Confirm with the router
		// before changing.
		return result
	}
	if c2 == nil {
		// Merging with nothing is a copy of c.
		c2 = new(CorsAccessControl)
	}
	result.AllowOrigin = corsMergeList(c.AllowOrigin, c2.AllowOrigin, nil)
	result.AllowCredentials = c.AllowCredentials || c2.AllowCredentials
	result.ExposeHeaders = corsMergeList(c.ExposeHeaders, c2.ExposeHeaders, strings.ToLower)
	result.MaxAge = c.MaxAge
	if c2.MaxAge != 0 {
		result.MaxAge = c2.MaxAge
	}
	// HTTP methods are case-sensitive, so they are de-duplicated verbatim.
	result.AllowMethods = corsMergeList(c.AllowMethods, c2.AllowMethods, func(s string) string { return s })
	result.AllowHeaders = corsMergeList(c.AllowHeaders, c2.AllowHeaders, strings.ToLower)
	return result
}

// corsMergeList returns a freshly allocated list of a followed by b. It
// never appends into a's backing array. When b is empty, a is copied as-is.
// Otherwise, if key is non-nil, an entry whose key was already seen is
// dropped.
func corsMergeList(a, b []string, key func(string) string) []string {
	if len(b) == 0 {
		return corsCloneList(a)
	}
	out := make([]string, 0, len(a)+len(b))
	var seen map[string]bool
	if key != nil {
		seen = make(map[string]bool, len(a)+len(b))
	}
	for _, list := range [][]string{a, b} {
		for _, s := range list {
			if key != nil {
				k := key(s)
				if seen[k] {
					continue
				}
				seen[k] = true
			}
			out = append(out, s)
		}
	}
	return out
}

// corsCloneList returns a copy of s that shares no storage with it,
// preserving nil versus empty.
func corsCloneList(s []string) []string {
	return append(s[:0:0], s...)
}
// corsPreflight performs a CORS preflight check for a resource. It merges
// the router-wide policy gcors with the resource policy lcors (see
// CorsAccessControl.Merge). It then checks the request's Origin,
// Access-Control-Request-Method and Access-Control-Request-Headers against
// the merged policy and adds the matching Access-Control-* headers to w.
//
// allowedMethods is a ", "-separated method list. It is consulted only when
// the merged policy has no AllowMethods.
//
// When the origin or the requested method is not permitted, it writes an
// empty 200 response and returns a non-nil error ("quick cors end") so the
// caller can stop. Otherwise it returns nil without writing a status or
// body. Requests without an Origin header are left untouched.
func corsPreflight(gcors *CorsAccessControl, lcors *CorsAccessControl, allowedMethods string, w http.ResponseWriter, r *http.Request) error {
	cors := gcors.Merge(lcors)
	origin := r.Header.Get("Origin")
	if cors == nil || origin == "" {
		return nil
	}

	// quickEnd terminates a rejected preflight with an empty 200 response.
	quickEnd := func() error {
		w.WriteHeader(http.StatusOK)
		// Best effort: the body is empty, so a failed write loses nothing.
		_, _ = w.Write([]byte(""))
		return errors.New("quick cors end")
	}

	allowOrigin, exact, ok := corsMatchOrigin(cors.GetAllowOrigin(), origin)
	if !ok {
		return quickEnd()
	}
	w.Header().Add("Access-Control-Allow-Origin", allowOrigin)

	if method := r.Header.Get("Access-Control-Request-Method"); method != "" {
		// The policy's own methods win; otherwise fall back to the
		// methods the caller says the resource supports.
		candidates := cors.GetAllowMethods()
		if len(candidates) == 0 {
			candidates = strings.Split(allowedMethods, ", ")
		}
		var responseMethods []string
		for _, m := range candidates {
			if m == method {
				responseMethods = append(responseMethods, m)
			}
		}
		if len(responseMethods) == 0 {
			return quickEnd()
		}
		w.Header().Add("Access-Control-Allow-Methods", strings.Join(responseMethods, ", "))
	}

	// Credentials are never advertised alongside the "*" wildcard.
	if exact && cors.GetAllowCredentials() {
		w.Header().Add("Access-Control-Allow-Credentials", "true")
	}
	if exposeHeaders := cors.GetExposeHeaders(); len(exposeHeaders) != 0 {
		w.Header().Add("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ", "))
	}
	if maxAge := cors.GetMaxAge(); maxAge.Seconds() != 0 {
		w.Header().Add("Access-Control-Max-Age", fmt.Sprint(int64(maxAge.Seconds())))
	}
	if header := r.Header.Get("Access-Control-Request-Headers"); header != "" {
		if good := corsFilterRequestHeaders(header, cors.GetAllowHeaders()); len(good) > 0 {
			w.Header().Add("Access-Control-Allow-Headers", strings.Join(good, ", "))
		}
	}
	return nil
}

// corsMatchOrigin picks the Access-Control-Allow-Origin value for origin.
// The value is the origin itself when it appears verbatim in allowed (exact
// is true), otherwise "*" when allowed contains the wildcard. ok is false
// when neither applies.
func corsMatchOrigin(allowed []string, origin string) (value string, exact, ok bool) {
	wildcard := false
	for _, v := range allowed {
		if v == origin {
			return origin, true, true
		}
		if v == "*" {
			wildcard = true
		}
	}
	if wildcard {
		return "*", false, true
	}
	return "", false, false
}

// corsFilterRequestHeaders splits a comma-separated
// Access-Control-Request-Headers value. It returns, in request order, the
// names that appear in allowed, compared case-insensitively. Optional
// whitespace (spaces and tabs) around each name is ignored, and empty list
// elements are skipped.
func corsFilterRequestHeaders(requested string, allowed []string) []string {
	var good []string
	for _, name := range strings.Split(requested, ",") {
		name = strings.Trim(name, " \t")
		if name == "" {
			continue
		}
		for _, a := range allowed {
			if strings.EqualFold(name, a) {
				good = append(good, name)
				break
			}
		}
	}
	return good
}