forked from lestrrat-go/jwx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
refresh.go
648 lines (584 loc) · 18.4 KB
/
refresh.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
package jwk
import (
"context"
"net/http"
"reflect"
"sync"
"time"
"github.com/lestrrat-go/backoff/v2"
"github.com/lestrrat-go/httpcc"
"github.com/pkg/errors"
)
// AutoRefresh is a container that keeps track of jwk.Set objects by their source URLs.
// The jwk.Set objects are refreshed automatically behind the scenes.
//
// Before retrieving the jwk.Set objects, the user must pre-register the
// URLs they intend to use by calling `Configure()`
//
//	ar := jwk.NewAutoRefresh(ctx)
//	ar.Configure(url, options...)
//
// Once registered, you can call `Fetch()` to retrieve the jwk.Set object.
//
// All JWKS objects that are retrieved via the auto-fetch mechanism should be
// treated read-only, as they are shared among the consumers and this object.
type AutoRefresh struct {
// errSink, when non-nil, receives errors raised while refreshing a key set.
// It is assigned by ErrorSink() while holding muErrSink.
errSink chan AutoRefreshError
// cache maps a URL to the key set most recently parsed for it.
// Access is guarded by muCache.
cache map[string]Set
// configureCh is signalled by Configure() so the refresh loop rebuilds
// its list of channels to wait on.
configureCh chan struct{}
// removeCh carries requests from Remove() to the refresh loop.
removeCh chan removeReq
// fetching maps a URL to a channel that is closed once the fetch currently
// in progress for that URL has finished. Access is guarded by muFetching.
fetching map[string]chan struct{}
muErrSink sync.Mutex
muCache sync.RWMutex
muFetching sync.Mutex
muRegistry sync.RWMutex
// registry maps a URL to its configuration and refresh timer.
// Access is guarded by muRegistry.
registry map[string]*target
// resetTimerCh carries requests to re-arm a target's timer to the refresh loop.
resetTimerCh chan *resetTimerReq
}
// target holds the per-URL settings given to Configure() together with the
// timer and bookkeeping used to refresh that URL in the background.
type target struct {
// The backoff policy to use when fetching the JWKS fails
backoff backoff.Policy
// The HTTP client to use. The user may opt to use a client which is
// aware of HTTP caching, or one that goes through a proxy
httpcl HTTPClient
// Intervals between refreshes are calculated in one of two ways.
// 1) You can set an explicit refresh interval by using WithRefreshInterval().
// In this mode, it doesn't matter what the HTTP response says in its
// Cache-Control or Expires headers
// 2) You can let us calculate the time-to-refresh based on the key's
// Cache-Control or Expires headers.
// First, the user provides us the absolute minimum interval before
// refreshes. We will never check for refreshes before this specified
// amount of time.
//
// Next, max-age directive in the Cache-Control header is consulted.
// If `max-age` is not present, we skip the following section, and
// proceed to the next option.
// If `max-age > user-supplied minimum interval`, then we use the max-age,
// otherwise the user-supplied minimum interval is used.
//
// Next, the value specified in Expires header is consulted.
// If the header is not present, we skip the following section and
// proceed to the next option.
// We take the time until expiration `expires - time.Now()`, and
// if `time-until-expiration > user-supplied minimum interval`, then
// we use the expires value, otherwise the user-supplied minimum interval is used.
//
// If all of the above fails, we use the user-supplied minimum interval
refreshInterval *time.Duration
minRefreshInterval time.Duration
// The URL the key set is fetched from; also the key of this target in the registry
url string
// The timer for refreshing the keyset. should not be set by anyone
// other than the refreshing goroutine
timer *time.Timer
// Semaphore to limit the number of concurrent refreshes in the background
sem chan struct{}
// for debugging, snapshotting
lastRefresh time.Time
nextRefresh time.Time
// Optional whitelist; when non-nil it is handed to the fetch as a fetch option
wl Whitelist
}
// resetTimerReq is sent to the refresh loop to ask it to re-arm the timer
// of target t so that it fires after duration d.
type resetTimerReq struct {
t *target
d time.Duration
}
// NewAutoRefresh builds an AutoRefresh container and starts the background
// goroutine that keeps the registered JWKS objects up to date.
//
// The lifetime of that goroutine is bound to ctx: once ctx is done the
// refresh loop exits. In a long running process you would normally pass a
// context that is cancelled when the main part of your program finishes:
//
//	func MainLoop() {
//		ctx, cancel := context.WithCancel(context.Background())
//		defer cancel()
//		ar := jwk.NewAutoRefresh(ctx)
//		for ... {
//			...
//		}
//	}
func NewAutoRefresh(ctx context.Context) *AutoRefresh {
	container := new(AutoRefresh)

	// Lookup tables
	container.cache = make(map[string]Set)
	container.fetching = make(map[string]chan struct{})
	container.registry = make(map[string]*target)

	// Unbuffered channels used to talk to the refresh loop
	container.configureCh = make(chan struct{})
	container.removeCh = make(chan removeReq)
	container.resetTimerCh = make(chan *resetTimerReq)

	go container.refreshLoop(ctx)
	return container
}
// getCached looks up the key set stored for url in the cache. The boolean
// result reports whether an entry was present; when it is false the
// returned Set is nil.
func (af *AutoRefresh) getCached(url string) (Set, bool) {
	af.muCache.RLock()
	defer af.muCache.RUnlock()

	cached, hit := af.cache[url]
	if !hit {
		return nil, false
	}
	return cached, true
}
// removeReq is sent by Remove() to the refresh loop. The loop removes url
// from the registry and reports the outcome on replyCh (nil on success).
type removeReq struct {
replyCh chan error
url string
}
// Remove stops jwk.AutoRefresh from watching `url`.
//
// The request is handed over to the refresh loop and this method waits for
// its answer; an error is returned if the url was not registered.
func (af *AutoRefresh) Remove(url string) error {
	reply := make(chan error)
	request := removeReq{
		url:     url,
		replyCh: reply,
	}
	af.removeCh <- request

	err := <-reply
	return err
}
// Configure registers the url to be controlled by AutoRefresh, and also
// sets any options associated to it.
//
// Note that options are treated as a whole -- you can't just update
// one value. For example, if you did:
//
//	ar.Configure(url, jwk.WithHTTPClient(...))
//	ar.Configure(url, jwk.WithRefreshInterval(...))
//
// Then the end result is that `url` is ONLY associated with the options
// given in the second call to `Configure()`, i.e. `jwk.WithRefreshInterval`.
// The other unspecified options, including the HTTP client and the fetch
// backoff policy, are set to their default values.
//
// Configuration must propagate between goroutines, and therefore is
// not atomic (but changes should be felt "soon enough" for practical
// purposes).
func (af *AutoRefresh) Configure(url string, options ...AutoRefreshOption) {
	// Defaults that apply to every option the caller did not pass.
	var httpcl HTTPClient = http.DefaultClient
	var hasRefreshInterval bool
	var refreshInterval time.Duration
	var wl Whitelist
	minRefreshInterval := time.Hour
	bo := backoff.Null()
	for _, option := range options {
		//nolint:forcetypeassert
		switch option.Ident() {
		case identFetchBackoff{}:
			bo = option.Value().(backoff.Policy)
		case identRefreshInterval{}:
			refreshInterval = option.Value().(time.Duration)
			hasRefreshInterval = true
		case identMinRefreshInterval{}:
			minRefreshInterval = option.Value().(time.Duration)
		case identHTTPClient{}:
			httpcl = option.Value().(HTTPClient)
		case identFetchWhitelist{}:
			wl = option.Value().(Whitelist)
		}
	}

	var doReconfigure bool
	af.muRegistry.Lock()
	t, ok := af.registry[url]
	if ok {
		// The url is already known: bring the existing target in line with
		// the options of this call.

		// The backoff policy used to be left untouched here, so a second
		// Configure() call could neither change nor reset it. It is only
		// consulted when a fetch is started, hence replacing it does not
		// require waking up the refresh loop.
		t.backoff = bo

		// NOTE(review): `!=` on two interface values panics at run time when
		// both hold the same uncomparable dynamic type (e.g. a func-based
		// implementation) -- confirm no such HTTPClient/Whitelist is passed.
		if t.httpcl != httpcl {
			t.httpcl = httpcl
			doReconfigure = true
		}

		if t.minRefreshInterval != minRefreshInterval {
			t.minRefreshInterval = minRefreshInterval
			doReconfigure = true
		}

		if t.refreshInterval != nil {
			if !hasRefreshInterval {
				// Explicit interval was dropped: go back to header based intervals
				t.refreshInterval = nil
				doReconfigure = true
			} else if *t.refreshInterval != refreshInterval {
				*t.refreshInterval = refreshInterval
				doReconfigure = true
			}
		} else if hasRefreshInterval {
			t.refreshInterval = &refreshInterval
			doReconfigure = true
		}

		if t.wl != wl {
			t.wl = wl
			doReconfigure = true
		}
	} else {
		t = &target{
			backoff:            bo,
			httpcl:             httpcl,
			minRefreshInterval: minRefreshInterval,
			url:                url,
			sem:                make(chan struct{}, 1),
			// This is a placeholder timer so we can call Reset() on it later
			// Make it sufficiently in the future so that we don't have bogus
			// events firing
			timer: time.NewTimer(24 * time.Hour),
			wl:    wl,
		}
		if hasRefreshInterval {
			t.refreshInterval = &refreshInterval
		}

		// Record this in the registry
		af.registry[url] = t
		doReconfigure = true
	}
	af.muRegistry.Unlock()

	if doReconfigure {
		// Tell the backend to reconfigure itself. This is sent after the
		// registry lock has been released because the refresh loop needs
		// that lock to rebuild its state.
		af.configureCh <- struct{}{}
	}
}
// releaseFetching marks the fetch in progress for url as finished: the
// entry is removed from the fetching map and its channel is closed, which
// wakes up every goroutine waiting on it.
func (af *AutoRefresh) releaseFetching(url string) {
	af.muFetching.Lock()
	defer af.muFetching.Unlock()

	done, inFlight := af.fetching[url]
	if !inFlight {
		// Not expected to happen, but there is nothing to release
		return
	}

	// Drop the map entry before closing the channel; otherwise more than
	// one goroutine could end up performing the fetch
	delete(af.fetching, url)
	close(done)
}
// IsRegistered reports whether `url` is currently present in the registry.
func (af *AutoRefresh) IsRegistered(url string) bool {
	af.muRegistry.RLock()
	_, registered := af.registry[url]
	af.muRegistry.RUnlock()
	return registered
}
// getRegistered returns the target stored in the registry for url. The
// boolean result reports whether the url has been registered.
func (af *AutoRefresh) getRegistered(url string) (*target, bool) {
	af.muRegistry.RLock()
	defer af.muRegistry.RUnlock()

	entry, found := af.registry[url]
	return entry, found
}
// Fetch returns the jwk.Set associated with the given url.
//
// The url must have been registered through `Configure()` beforehand,
// otherwise an error is returned.
//
// A value that is already in the cache is returned as is. If nothing has
// been cached for `url` yet, an HTTP request is sent synchronously.
//
// When several goroutines ask for an uncached url at the same time, only
// the first one performs the initialization (HTTP fetch and cache
// population); the others are blocked until that operation completes.
//
// DO NOT modify the jwk.Set object returned by this method, as the
// objects are shared among all consumers and the backend goroutine
func (af *AutoRefresh) Fetch(ctx context.Context, url string) (Set, error) {
	_, registered := af.getRegistered(url)
	if !registered {
		return nil, errors.Errorf(`url %s must be configured using "Configure()" first`, url)
	}

	if cached, hit := af.getCached(url); hit {
		return cached, nil
	}
	return af.refresh(ctx, url)
}
// Refresh is like Fetch(), except that it does not consult the cache first:
// it always goes through the synchronous refresh path.
//
// This is useful when you want to force an HTTP fetch instead of waiting
// for the background goroutine to do it, for example when you want to
// make sure the AutoRefresh cache is warmed up before starting your main loop
func (af *AutoRefresh) Refresh(ctx context.Context, url string) (Set, error) {
	_, registered := af.getRegistered(url)
	if !registered {
		return nil, errors.Errorf(`url %s must be configured using "Configure()" first`, url)
	}
	return af.refresh(ctx, url)
}
// refresh makes sure a fetch for url has been performed and returns the
// key set found in the cache afterwards.
//
// At most one goroutine per url performs the HTTP request at a time. Any
// other caller waits until that request is over (or until ctx is done) and
// then reads the cache.
func (af *AutoRefresh) refresh(ctx context.Context, url string) (Set, error) {
	// To avoid a thundering herd, find out under the lock whether somebody
	// is already fetching this url; if not, claim the slot ourselves.
	af.muFetching.Lock()
	inFlight, busy := af.fetching[url]
	if !busy {
		inFlight = make(chan struct{})
		af.fetching[url] = inFlight
	}
	af.muFetching.Unlock()

	if busy {
		// Somebody else is on it: wait for them to finish
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-inFlight:
		}
	} else {
		// Register a cleanup handler, to make sure we always release the
		// slot and wake up the waiters, whatever the outcome
		defer af.releaseFetching(url)

		// We own the slot, so we are the one to fetch the keyset
		if err := af.doRefreshRequest(ctx, url, false); err != nil {
			return nil, errors.Wrapf(err, `failed to fetch resource pointed by %s`, url)
		}
	}

	// the cache should now be populated
	keyset, populated := af.getCached(url)
	if !populated {
		return nil, errors.New("cache was not populated after explicit refresh")
	}
	return keyset, nil
}
// refreshLoop is the background worker started by NewAutoRefresh. It keeps
// looping until ctx is done, waiting for one of the following events:
//
//   - ctx is done: stop the timers of the known targets and exit
//   - Configure() signalled a change: rebuild the list of channels
//   - a request to re-arm a target's timer arrived
//   - a request to remove a url arrived
//   - the timer of a target fired: refresh that target in a new goroutine
func (af *AutoRefresh) refreshLoop(ctx context.Context) {
	// reflect.Select() is slow IF we are executing it over and over
	// in a very fast iteration, but we assume here that refreshes happen
	// seldom enough that being able to call one `select{}` with multiple
	// targets / channels outweighs the speed penalty of using reflect.
	const (
		ctxDoneIdx = iota
		configureIdx
		resetTimerIdx
		removeIdx
		baseSelcasesLen
	)

	// The fixed cases always occupy the first baseSelcasesLen slots; the
	// per-target timer cases are appended after them.
	baseSelcases := make([]reflect.SelectCase, baseSelcasesLen)
	baseSelcases[ctxDoneIdx] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(ctx.Done()),
	}
	baseSelcases[configureIdx] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(af.configureCh),
	}
	baseSelcases[resetTimerIdx] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(af.resetTimerCh),
	}
	baseSelcases[removeIdx] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(af.removeCh),
	}

	var targets []*target
	var selcases []reflect.SelectCase
	for {
		// It seems silly, but it's much easier to keep track of things
		// if we re-build the select cases every iteration
		af.muRegistry.RLock()
		if cap(targets) < len(af.registry) {
			targets = make([]*target, 0, len(af.registry))
		} else {
			targets = targets[:0]
		}

		// selcases holds the fixed cases plus one case per target, so the
		// capacity check has to account for both.
		if needed := len(af.registry) + baseSelcasesLen; cap(selcases) < needed {
			selcases = make([]reflect.SelectCase, 0, needed)
		} else {
			selcases = selcases[:0]
		}
		selcases = append(selcases, baseSelcases...)

		for _, data := range af.registry {
			targets = append(targets, data)
			selcases = append(selcases, reflect.SelectCase{
				Dir:  reflect.SelectRecv,
				Chan: reflect.ValueOf(data.timer.C),
			})
		}
		af.muRegistry.RUnlock()

		chosen, recv, recvOK := reflect.Select(selcases)
		switch chosen {
		case ctxDoneIdx:
			// <-ctx.Done(). Nobody is going to wait on the timers anymore,
			// so stop them before bailing out of this loop.
			for _, t := range targets {
				t.timer.Stop()
			}
			return
		case configureIdx:
			// <-configureCh. rebuild the select list from the registry.
			// since we're rebuilding everything for each iteration,
			// we just need to start the loop all over again
			continue
		case resetTimerIdx:
			// <-resetTimerCh. interrupt polling, and reset the timer on
			// a single target. this needs to be handled inside this select
			if !recvOK {
				continue
			}

			req := recv.Interface().(*resetTimerReq) //nolint:forcetypeassert
			t := req.t
			d := req.d
			if !t.timer.Stop() {
				// The timer already fired: drain the pending tick (if any)
				// so that it does not trigger right after the Reset()
				select {
				case <-t.timer.C:
				default:
				}
			}
			t.timer.Reset(d)
		case removeIdx:
			// <-removeCh. remove the URL from future fetching
			//nolint:forcetypeassert
			req := recv.Interface().(removeReq)
			replyCh := req.replyCh
			url := req.url

			af.muRegistry.Lock()
			removed, ok := af.registry[url]
			if ok {
				delete(af.registry, url)
			}
			af.muRegistry.Unlock()

			if !ok {
				replyCh <- errors.Errorf(`invalid url %q (not registered)`, url)
				continue
			}

			// The target is gone from the select list, so its timer would
			// only linger around until it expires.
			removed.timer.Stop()

			// Forget the cached key set as well. Fetch() serves cached
			// entries without contacting the server, so a leftover entry
			// would be handed out as is if the url were configured again.
			af.muCache.Lock()
			delete(af.cache, url)
			af.muCache.Unlock()

			replyCh <- nil
		default:
			// Do not fire a refresh in case the channel was closed.
			if !recvOK {
				continue
			}

			// Time to refresh a target
			t := targets[chosen-baseSelcasesLen]

			// Check if there are other goroutines still doing the refresh asynchronously.
			// This could happen if the refreshing goroutine is stuck on a backoff
			// waiting for the HTTP request to complete.
			select {
			case t.sem <- struct{}{}:
				// There can only be one refreshing goroutine
			default:
				continue
			}

			go func() {
				//nolint:errcheck
				af.doRefreshRequest(ctx, t.url, true)
				<-t.sem
			}()
		}
	}
}
// doRefreshRequest fetches the key set for url and stores it in the cache.
//
// The url must be present in the registry, otherwise an error is returned.
// When enableBackoff is true the backoff policy configured for the target
// is passed to the fetch.
//
// On success the cache entry for url is replaced, the refresh loop is asked
// to re-arm the target's timer, and nil is returned. On failure (fetch
// error, status code other than 200, or parse error) the previous cache
// entry is left untouched, the error is offered to the error sink, the
// timer is re-armed, and the error is returned. If ctx is done while the
// timer request is being handed over, ctx.Err() is returned instead.
func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableBackoff bool) error {
	af.muRegistry.RLock()
	t, ok := af.registry[url]
	if !ok {
		af.muRegistry.RUnlock()
		return errors.Errorf(`url "%s" is not registered`, url)
	}
	// Configure() modifies the target's settings while holding the registry
	// lock, so take a private copy of them before letting go of the lock.
	// The explicit refresh interval lives behind a pointer whose value may
	// be overwritten in place, hence it is copied as well.
	cfg := *t
	if cfg.refreshInterval != nil {
		interval := *cfg.refreshInterval
		cfg.refreshInterval = &interval
	}
	af.muRegistry.RUnlock()

	// In case the refresh fails due to errors in fetching/parsing the JWKS,
	// we want to retry, so hand over the backoff policy when asked to.
	options := []FetchOption{WithHTTPClient(cfg.httpcl)}
	if enableBackoff {
		options = append(options, WithFetchBackoff(cfg.backoff))
	}

	if cfg.wl != nil {
		options = append(options, WithFetchWhitelist(cfg.wl))
	}

	res, err := fetch(ctx, url, options...)
	if err == nil {
		// The body must be closed for every response we received, not only
		// for the ones carrying a 200 status.
		defer res.Body.Close()

		if res.StatusCode != http.StatusOK {
			// now, can there be a remote resource that responds with a status code
			// other than 200 and still be valid...? naaaaaaahhhhhh....
			err = errors.Errorf(`bad response status code (%d)`, res.StatusCode)
		} else {
			keyset, parseErr := ParseReader(res.Body)
			if parseErr == nil {
				// Got a new key set. replace the keyset in the target
				af.muCache.Lock()
				af.cache[url] = keyset
				af.muCache.Unlock()

				nextInterval := calculateRefreshDuration(res, cfg.refreshInterval, cfg.minRefreshInterval)
				rtr := &resetTimerReq{
					t: t,
					d: nextInterval,
				}
				select {
				case <-ctx.Done():
					return ctx.Err()
				case af.resetTimerCh <- rtr:
				}

				// Snapshot() reads these fields under the registry lock.
				// The lock is only taken after the send above has completed,
				// as the refresh loop needs it to get back to receiving.
				now := time.Now()
				af.muRegistry.Lock()
				t.lastRefresh = now.Local()
				t.nextRefresh = now.Add(nextInterval).Local()
				af.muRegistry.Unlock()
				return nil
			}
			err = parseErr
		}
	}

	// At this point if err != nil, we know that there was something wrong
	// in either the fetching or the parsing. Send this error to be processed,
	// but take the extra mileage to not block regular processing by
	// discarding the error if we fail to send it through the channel
	if err != nil {
		// ErrorSink() replaces the channel under muErrSink, so read it
		// under the same lock. A nil sink simply selects the default case.
		af.muErrSink.Lock()
		sink := af.errSink
		af.muErrSink.Unlock()

		select {
		case sink <- AutoRefreshError{Error: err, URL: url}:
		default:
		}
	}

	// We either failed to perform the HTTP GET, or we failed to parse the
	// JWK set. Even in case of errors, we don't delete the old key.
	// We persist the old key set, even if it may be stale so the user has something to work with
	// TODO: maybe this behavior should be customizable?

	// If we failed to get a single time, then queue another fetch in the future.
	rtr := &resetTimerReq{
		t: t,
		d: calculateRefreshDuration(res, cfg.refreshInterval, cfg.minRefreshInterval),
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	case af.resetTimerCh <- rtr:
	}

	return err
}
// ErrorSink registers ch as the channel that receives JWK fetch errors.
// Errors that occurred before the channel was set are not delivered.
//
// Draining the channel is the caller's responsibility: an error that
// cannot be sent right away is discarded.
//
// Pass a nil channel to turn the error reporting off.
func (af *AutoRefresh) ErrorSink(ch chan AutoRefreshError) {
	af.muErrSink.Lock()
	defer af.muErrSink.Unlock()

	af.errSink = ch
}
// calculateRefreshDuration decides how long to wait before the next refresh.
//
// An explicit refreshInterval always wins. Otherwise, if a response is
// available, the `max-age` directive of its Cache-Control header is used,
// and failing that the time left until its Expires header. A value taken
// from the response is never allowed to be shorter than minRefreshInterval.
// When nothing usable is found, minRefreshInterval is returned.
func calculateRefreshDuration(res *http.Response, refreshInterval *time.Duration, minRefreshInterval time.Duration) time.Duration {
	// This always has precedence
	if refreshInterval != nil {
		return *refreshInterval
	}

	// Without a response there are no headers to look at
	if res == nil {
		return minRefreshInterval
	}

	if cacheControl := res.Header.Get(`Cache-Control`); cacheControl != "" {
		if directives, err := httpcc.ParseResponse(cacheControl); err == nil {
			if seconds, present := directives.MaxAge(); present {
				fromMaxAge := time.Duration(seconds) * time.Second
				if fromMaxAge <= minRefreshInterval {
					return minRefreshInterval
				}
				return fromMaxAge
			}
		}
		// Unparsable header or no max-age: try the next source
	}

	if rawExpires := res.Header.Get(`Expires`); rawExpires != "" {
		if deadline, err := http.ParseTime(rawExpires); err == nil {
			remaining := time.Until(deadline)
			if remaining <= minRefreshInterval {
				return minRefreshInterval
			}
			return remaining
		}
		// Unparsable date: fall back to the minimum below
	}

	return minRefreshInterval
}
// TargetSnapshot is the structure returned by the Snapshot method.
// It contains information about a url that has been configured
// in AutoRefresh.
type TargetSnapshot struct {
// URL is the url the target is registered under
URL string
// NextRefresh is the time of the next refresh as recorded on the target
// after its last successful refresh
NextRefresh time.Time
// LastRefresh is the time of the last successful refresh recorded on the target
LastRefresh time.Time
}
// Snapshot returns a channel holding one TargetSnapshot per url that is
// registered at the time of the call. The channel is already closed when it
// is returned, so it can be drained with a plain `range` loop.
func (af *AutoRefresh) Snapshot() <-chan TargetSnapshot {
	af.muRegistry.Lock()
	defer af.muRegistry.Unlock()

	// Buffered to the number of targets, so that none of the sends blocks
	out := make(chan TargetSnapshot, len(af.registry))
	for key, entry := range af.registry {
		item := TargetSnapshot{
			URL:         key,
			LastRefresh: entry.lastRefresh,
			NextRefresh: entry.nextRefresh,
		}
		out <- item
	}
	close(out)
	return out
}