-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreflight.go
252 lines (229 loc) · 7.92 KB
/
preflight.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
// Created by Yakka (https://theyakka.com)
//
// Copyright (c) 2020 Yakka LLC.
// All rights reserved.
// See the LICENSE file for licensing details and requirements.
package cors
import (
"net/http"
"strconv"
"strings"
)
const (
	// HeaderKeyReqOrigin is the request header carrying the origin of a cross-origin request.
	HeaderKeyReqOrigin string = "Origin"
	// HeaderKeyAccCtlReqMethod is the preflight request header naming the HTTP method the
	// actual request will use.
	HeaderKeyAccCtlReqMethod = "Access-Control-Request-Method"
	// HeaderKeyAccCtlReqHeaders is the preflight request header listing the headers the
	// actual request will send.
	HeaderKeyAccCtlReqHeaders = "Access-Control-Request-Headers"
	// HeaderKeyAccCtlResAllowOrigin is the http header designating the CORS response allowed origin
	HeaderKeyAccCtlResAllowOrigin = "Access-Control-Allow-Origin"
	// HeaderKeyAccCtlResAllowMethods is the http header designating the CORS response allowed methods
	HeaderKeyAccCtlResAllowMethods = "Access-Control-Allow-Methods"
	// HeaderKeyAccCtlResAllowHeaders is the http header designating the CORS response allowed headers
	HeaderKeyAccCtlResAllowHeaders = "Access-Control-Allow-Headers"
	// HeaderKeyAccCtlResExposeHeaders indicates which headers can be exposed as part of the response
	HeaderKeyAccCtlResExposeHeaders = "Access-Control-Expose-Headers"
	// HeaderKeyAccResCtlMaxAge is the http header designating the CORS response maximum age
	HeaderKeyAccResCtlMaxAge = "Access-Control-Max-Age"
	// HeaderKeyAccCtlResAllowCreds is the http header designating whether the CORS response allows
	// cookies / credentials
	HeaderKeyAccCtlResAllowCreds = "Access-Control-Allow-Credentials"
)
// PreflightHandlerFunc is invoked by ValidatePreflight once the preflight checks have
// finished. If the preflight succeeded, err is nil. If it failed, err is a *ValidationError
// whose Code identifies the check that failed. ValidatePreflight only sets response headers,
// so writing the status code and body is the handler's responsibility.
type PreflightHandlerFunc func(w http.ResponseWriter, r *http.Request, err *ValidationError)
// ValidatePreflight runs the CORS preflight checks for r and records the resulting CORS
// response headers on w. It calls handler exactly once when it is done: with a nil error if
// every check passed, or with a *ValidationError describing the first check that failed.
// ValidatePreflight never writes a status code or body itself; that is left to handler.
//
// The checks run in this order:
//   - r.Method must be OPTIONS (PreflightErrMethodInvalid);
//   - the request origin must be allowed (PreflightErrOriginNotAllowed);
//   - the requested method must be present (PreflightErrMethodMissing) and allowed
//     (PreflightErrMethodNotAllowed);
//   - every requested header must be allowed (PreflightErrHeadersNotAllowed).
func (c *CORS) ValidatePreflight(w http.ResponseWriter, r *http.Request, handler PreflightHandlerFunc) {
	headers := w.Header()

	// A preflight is always delivered via OPTIONS. Report a dedicated error code for
	// anything else so the caller can choose to forward the request on instead.
	if r.Method != http.MethodOptions {
		handler(w, r, preflightError(PreflightErrMethodInvalid))
		return
	}

	// The response depends on these request headers; advertise that so a shared cache
	// never serves one origin's preflight answer to another.
	headers.Add("Vary", HeaderKeyReqOrigin)
	headers.Add("Vary", HeaderKeyAccCtlReqMethod)
	headers.Add("Vary", HeaderKeyAccCtlReqHeaders)

	// Origin check.
	if c.areAllOriginsAllowed {
		// NOTE(review): the Fetch standard does not allow "*" together with
		// Access-Control-Allow-Credentials: true; browsers reject such a preflight when
		// c.options.AllowCredentials is also set. Confirm the intended behavior.
		headers.Set(HeaderKeyAccCtlResAllowOrigin, "*")
	} else {
		origin := r.Header.Get(HeaderKeyReqOrigin)
		if !c.IsOriginAllowed(origin) {
			handler(w, r, preflightError(PreflightErrOriginNotAllowed))
			return
		}
		headers.Set(HeaderKeyAccCtlResAllowOrigin, origin)
	}

	// Requested-method check.
	method := r.Header.Get(HeaderKeyAccCtlReqMethod)
	if method == "" {
		handler(w, r, preflightError(PreflightErrMethodMissing))
		return
	}
	// Methods are compared in upper case.
	upperMethod := strings.ToUpper(method)
	if !c.IsMethodAllowed(upperMethod) {
		handler(w, r, preflightError(PreflightErrMethodNotAllowed))
		return
	}
	// Only the requested method is echoed back, not the whole allow list.
	headers.Set(HeaderKeyAccCtlResAllowMethods, upperMethod)

	// Requested-headers check.
	requestedHeaders := r.Header.Get(HeaderKeyAccCtlReqHeaders)
	if c.areAllHeadersAllowed {
		// Skip parsing/validation (it costs time and allocations), but still echo the list
		// back: without Access-Control-Allow-Headers the browser refuses to send any
		// non-safelisted request header.
		if requestedHeaders != "" {
			headers.Set(HeaderKeyAccCtlResAllowHeaders, requestedHeaders)
		}
	} else {
		cleanedHeaders := parseHeaderList(requestedHeaders)
		if !c.AreHeadersAllowed(cleanedHeaders) {
			handler(w, r, preflightError(PreflightErrHeadersNotAllowed))
			return
		}
		if len(cleanedHeaders) > 0 {
			headers.Set(HeaderKeyAccCtlResAllowHeaders, strings.Join(cleanedHeaders, ", "))
		}
	}

	// Optional pass-through headers.
	if c.options.MaxAge > 0 {
		headers.Set(HeaderKeyAccResCtlMaxAge, strconv.Itoa(c.options.MaxAge))
	}
	if c.options.AllowCredentials {
		headers.Set(HeaderKeyAccCtlResAllowCreds, "true")
	}

	handler(w, r, nil)
}
// preflightError builds the *ValidationError for code without wrapping an underlying
// cause; it defers to preflightErrorWithSource with a nil source error.
func preflightError(code int) *ValidationError {
	var noSource error
	return preflightErrorWithSource(code, noSource)
}
// preflightErrorWithSource builds a *ValidationError carrying code and originalError
// (which may be nil). The message is looked up in codedErrorMessages by code; when that
// lookup yields an empty message, a generic fallback pointing at the code and the
// original error is used instead.
func preflightErrorWithSource(code int, originalError error) *ValidationError {
	msg := codedErrorMessages[code]
	if len(msg) == 0 {
		msg = "please check code + original error for details"
	}
	return &ValidationError{
		OriginalError: originalError,
		Code:          code,
		Message:       msg,
	}
}
// IsOriginAllowed reports whether checkOrigin is permitted by the configured allowed
// origins. When every origin is allowed it returns true without inspecting checkOrigin;
// otherwise the lower-cased origin is tested against each configured matcher in turn.
func (c *CORS) IsOriginAllowed(checkOrigin string) bool {
	if c.areAllOriginsAllowed {
		return true
	}
	normalized := strings.ToLower(checkOrigin)
	matched := false
	for _, candidate := range c.allowedOrigins {
		if candidate.Matches(normalized) {
			matched = true
			break
		}
	}
	return matched
}
// IsMethodAllowed reports whether checkMethod appears in the configured allowed methods.
// OPTIONS is always allowed because preflight requests are sent with it. The comparison
// is exact, so callers should normalize case first (ValidatePreflight upper-cases the
// requested method before calling this).
func (c *CORS) IsMethodAllowed(checkMethod string) bool {
	// OPTIONS carries the preflight itself, so it is never rejected.
	if checkMethod == http.MethodOptions {
		return true
	}
	found := false
	for _, allowed := range c.allowedMethods {
		if checkMethod == allowed {
			found = true
			break
		}
	}
	return found
}
// AreHeadersAllowed reports whether every header name in headers appears in the configured
// allowed headers. Names are compared case-insensitively because HTTP field names are
// case-insensitive, so "content-type" matches an allowed "Content-Type". An empty list is
// always allowed.
func (c *CORS) AreHeadersAllowed(headers []string) bool {
	for _, requested := range headers {
		isAllowed := false
		for _, allowed := range c.allowedHeaders {
			if strings.EqualFold(requested, allowed) {
				isAllowed = true
				break
			}
		}
		if !isAllowed {
			// this header is not in the allow list, so the whole set is rejected
			return false
		}
	}
	// every requested header was found
	return true
}
// cleanAllowedHeaderValue splits a comma-separated list of header names, trims surrounding
// whitespace from each entry, drops entries that are empty after trimming and returns the
// remaining names in canonical form (see http.CanonicalHeaderKey), preserving their order.
// The returned slice is never nil, so an empty input yields an empty slice.
func cleanAllowedHeaderValue(value string) []string {
	parts := strings.Split(value, ",")
	headers := make([]string, 0, len(parts))
	for _, part := range parts {
		trimmed := strings.TrimSpace(part)
		if trimmed == "" {
			// tolerate stray separators such as "a,,b" or a trailing comma
			continue
		}
		headers = append(headers, http.CanonicalHeaderKey(trimmed))
	}
	return headers
}
// toLower is the distance between an ASCII lower-case letter and its upper-case form.
const toLower = 'a' - 'A'

// parseHeaderList tokenizes and normalizes a list of header names, such as the value of
// Access-Control-Request-Headers.
//
// Tokens are separated by commas and/or spaces; empty tokens are skipped. Within a token,
// only ASCII letters, digits, '-' and '_' are kept and every other byte is dropped. The
// first letter of a token, and any letter immediately following '-' or '_', is upper-cased;
// all other letters are lower-cased ("content-type" becomes "Content-Type"). The returned
// slice is never nil.
func parseHeaderList(headerList string) []string {
	l := len(headerList)
	h := make([]byte, 0, l)
	upper := true
	// Estimate the result size so the slice is allocated once: n commas normally separate
	// n+1 headers.
	t := strings.Count(headerList, ",") + 1
	headers := make([]string, 0, t)
	for i := 0; i < l; i++ {
		b := headerList[i]
		switch {
		case b >= 'a' && b <= 'z':
			if upper {
				h = append(h, b-toLower)
			} else {
				h = append(h, b)
			}
		case b >= 'A' && b <= 'Z':
			if !upper {
				h = append(h, b+toLower)
			} else {
				h = append(h, b)
			}
		case b == '-' || b == '_' || (b >= '0' && b <= '9'):
			h = append(h, b)
		}
		if b == ' ' || b == ',' || i == l-1 {
			if len(h) > 0 {
				// Flush the found header and start the next one capitalized.
				headers = append(headers, string(h))
				h = h[:0]
				upper = true
			}
		} else {
			// Capitalize the letter that follows a word separator.
			upper = b == '-' || b == '_'
		}
	}
	return headers
}