Skip to content

Commit

Permalink
feat: add suport JSON Web Key Set (dunglas#885)
Browse files Browse the repository at this point in the history
  • Loading branch information
dunglas authored May 30, 2024
1 parent abd4c7d commit c8df627
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 89 deletions.
3 changes: 2 additions & 1 deletion .github/linters/.jscpd.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
"**/examples",
"**/node_modules",
"**/dist",
"**/.goreleaser.yml"
"**/.goreleaser.yml",
"**/caddy/caddy.go"
],
"absolute": true
}
125 changes: 90 additions & 35 deletions caddy/caddy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"
"time"

"github.com/MicahParks/keyfunc/v3"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
Expand All @@ -24,8 +25,9 @@ import (
const defaultHubURL = "/.well-known/mercure"

var (
transports = caddy.NewUsagePool() //nolint:gochecknoglobals
metrics = mercure.NewPrometheusMetrics(prometheus.DefaultRegisterer) //nolint:gochecknoglobals
ErrCompatibility = errors.New("compatibility mode only supports protocol version 7")
transports = caddy.NewUsagePool() //nolint:gochecknoglobals
metrics = mercure.NewPrometheusMetrics(prometheus.DefaultRegisterer) //nolint:gochecknoglobals
)

func init() { //nolint:gochecknoinits
Expand All @@ -43,7 +45,7 @@ type transportDestructor struct {
}

func (d *transportDestructor) Destruct() error {
return d.transport.Close()
return d.transport.Close() //nolint:wrapcheck
}

// Mercure implements a Mercure hub as a Caddy module. Mercure is a protocol allowing to push data updates to web browsers and other HTTP clients in a convenient, fast, reliable and battery-efficient way.
Expand Down Expand Up @@ -72,9 +74,15 @@ type Mercure struct {
// JWT key and signing algorithm to use for publishers.
PublisherJWT JWTConfig `json:"publisher_jwt,omitempty"`

// JWK Set URL to use for publishers.
PublisherJWKSURL string `json:"publisher_jwks_url,omitempty"`

// JWT key and signing algorithm to use for subscribers.
SubscriberJWT JWTConfig `json:"subscriber_jwt,omitempty"`

// JWK Set URL to use for subscribers.
SubscriberJWKSURL string `json:"subscriber_jwks_url,omitempty"`

// Origins allowed to publish updates
PublishOrigins []string `json:"publish_origins,omitempty"`

Expand Down Expand Up @@ -105,20 +113,45 @@ func (Mercure) CaddyModule() caddy.ModuleInfo {
}
}

func (m *Mercure) Provision(ctx caddy.Context) error { //nolint:funlen
func (m *Mercure) populateJWTConfig() error {
repl := caddy.NewReplacer()

m.PublisherJWT.Key = repl.ReplaceKnown(m.PublisherJWT.Key, "")
m.PublisherJWT.Alg = repl.ReplaceKnown(m.PublisherJWT.Alg, "HS256")
m.SubscriberJWT.Key = repl.ReplaceKnown(m.SubscriberJWT.Key, "")
m.SubscriberJWT.Alg = repl.ReplaceKnown(m.SubscriberJWT.Alg, "HS256")
if m.PublisherJWKSURL == "" {
m.PublisherJWT.Key = repl.ReplaceKnown(m.PublisherJWT.Key, "")
m.PublisherJWT.Alg = repl.ReplaceKnown(m.PublisherJWT.Alg, "HS256")

if m.PublisherJWT.Key == "" {
return errors.New("a JWT key for publishers must be provided") //nolint:goerr113
if m.PublisherJWT.Key == "" {
return errors.New("a JWT key or the URL of a JWK Set for publishers must be provided") //nolint:goerr113
}

if m.PublisherJWT.Alg == "" {
m.PublisherJWT.Alg = "HS256"
}
}
if m.PublisherJWT.Alg == "" {
m.PublisherJWT.Alg = "HS256"

if m.SubscriberJWKSURL == "" {
m.SubscriberJWT.Key = repl.ReplaceKnown(m.SubscriberJWT.Key, "")
m.SubscriberJWT.Alg = repl.ReplaceKnown(m.SubscriberJWT.Alg, "HS256")

if m.SubscriberJWT.Key == "" {
if !m.Anonymous {
return errors.New("a JWT key or the URL of a JWK Set for subscribers must be provided") //nolint:goerr113
}
}

if m.SubscriberJWT.Alg == "" {
m.SubscriberJWT.Alg = "HS256"
}
}

return nil
}

func (m *Mercure) Provision(ctx caddy.Context) error { //nolint:funlen,gocognit
if err := m.populateJWTConfig(); err != nil {
return err
}

if m.TransportURL == "" {
m.TransportURL = "bolt://mercure.db"
}
Expand Down Expand Up @@ -167,22 +200,30 @@ func (m *Mercure) Provision(ctx caddy.Context) error { //nolint:funlen
mercure.WithTopicSelectorStore(tss),
mercure.WithTransport(destructor.(*transportDestructor).transport),
mercure.WithMetrics(metrics),
mercure.WithPublisherJWT([]byte(m.PublisherJWT.Key), m.PublisherJWT.Alg),
mercure.WithCookieName(m.CookieName),
}
if m.logger.Core().Enabled(zapcore.DebugLevel) {
opts = append(opts, mercure.WithDebug())
}

if m.SubscriberJWT.Key == "" {
if !m.Anonymous {
return errors.New("a JWT key for subscribers must be provided") //nolint:goerr113
}
if m.PublisherJWKSURL == "" {
opts = append(opts, mercure.WithPublisherJWT([]byte(m.PublisherJWT.Key), m.PublisherJWT.Alg))
} else {
if m.SubscriberJWT.Alg == "" {
m.SubscriberJWT.Alg = "HS256"
k, err := keyfunc.NewDefaultCtx(ctx, []string{m.PublisherJWKSURL})
if err != nil {
return fmt.Errorf("failed to retrieve publisher JWK Set: %w", err)
}

opts = append(opts, mercure.WithPublisherJWTKeyFunc(k.Keyfunc))
}

if m.SubscriberJWKSURL != "" {
k, err := keyfunc.NewDefaultCtx(ctx, []string{m.SubscriberJWKSURL})
if err != nil {
return fmt.Errorf("failed to retrieve subscriber JWK Set: %w", err)
}

opts = append(opts, mercure.WithSubscriberJWTKeyFunc(k.Keyfunc))
} else if m.SubscriberJWT.Key != "" {
opts = append(opts, mercure.WithSubscriberJWT([]byte(m.SubscriberJWT.Key), m.SubscriberJWT.Alg))
}

Expand Down Expand Up @@ -235,7 +276,7 @@ func (m *Mercure) Cleanup() error {

func (m Mercure) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
if !strings.HasPrefix(r.URL.Path, defaultHubURL) {
return next.ServeHTTP(w, r)
return next.ServeHTTP(w, r) //nolint:wrapcheck
}

m.hub.ServeHTTP(w, r)
Expand All @@ -244,7 +285,7 @@ func (m Mercure) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhtt
}

// UnmarshalCaddyfile sets up the handler from Caddyfile tokens.
func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:funlen,gocognit
func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:funlen,gocognit,gocyclo
for d.Next() {
for d.NextBlock(0) {
switch d.Val() {
Expand All @@ -262,7 +303,7 @@ func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:fu

case "write_timeout":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

d, err := caddy.ParseDuration(d.Val())
Expand All @@ -275,7 +316,7 @@ func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:fu

case "dispatch_timeout":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

d, err := caddy.ParseDuration(d.Val())
Expand All @@ -288,7 +329,7 @@ func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:fu

case "heartbeat":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

d, err := caddy.ParseDuration(d.Val())
Expand All @@ -299,19 +340,33 @@ func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:fu
cd := caddy.Duration(d)
m.Heartbeat = &cd

case "publisher_jwks_url":
if !d.NextArg() {
return d.ArgErr() //nolint:wrapcheck
}

m.PublisherJWKSURL = d.Val()

case "publisher_jwt":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

m.PublisherJWT.Key = d.Val()
if d.NextArg() {
m.PublisherJWT.Alg = d.Val()
}

case "subscriber_jwks_url":
if !d.NextArg() {
return d.ArgErr() //nolint:wrapcheck
}

m.SubscriberJWKSURL = d.Val()

case "subscriber_jwt":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

m.SubscriberJWT.Key = d.Val()
Expand All @@ -322,29 +377,29 @@ func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:fu
case "publish_origins":
ra := d.RemainingArgs()
if len(ra) == 0 {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

m.PublishOrigins = ra

case "cors_origins":
ra := d.RemainingArgs()
if len(ra) == 0 {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

m.CORSOrigins = ra

case "transport_url":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

m.TransportURL = d.Val()

case "lru_cache":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

v, err := strconv.ParseInt(d.Val(), 10, 64)
Expand All @@ -356,14 +411,14 @@ func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:fu

case "cookie_name":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

m.CookieName = d.Val()

case "protocol_version_compatibility":
if !d.NextArg() {
return d.ArgErr()
return d.ArgErr() //nolint:wrapcheck
}

v, err := strconv.Atoi(d.Val())
Expand All @@ -372,7 +427,7 @@ func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:fu
}

if v != 7 {
return errors.New("compatibility mode only supports protocol version 7")
return ErrCompatibility
}

m.ProtocolVersionCompatibility = v
Expand All @@ -384,7 +439,7 @@ func (m *Mercure) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { //nolint:fu
}

// parseCaddyfile unmarshals tokens from h into a new Middleware.
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { //nolint:ireturn
var m Mercure
err := m.UnmarshalCaddyfile(h.Dispenser)

Expand Down
6 changes: 3 additions & 3 deletions caddy/caddy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestMercure(t *testing.T) {

body := url.Values{"topic": {"http://example.com/foo/1"}, "data": {"bar"}, "id": {"bar"}}
req, err := http.NewRequest(http.MethodPost, "http://localhost:9080/.well-known/mercure", strings.NewReader(body.Encode()))
require.Nil(t, err)
require.NoError(t, err)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Authorization", bearerPrefix+publisherJWT)

Expand Down Expand Up @@ -154,7 +154,7 @@ func TestJWTPlaceholders(t *testing.T) {

body := url.Values{"topic": {"http://example.com/foo/1"}, "data": {"bar"}, "id": {"bar"}}
req, err := http.NewRequest(http.MethodPost, "http://localhost:9080/.well-known/mercure", strings.NewReader(body.Encode()))
require.Nil(t, err)
require.NoError(t, err)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Authorization", bearerPrefix+publisherJWTRSA)

Expand Down Expand Up @@ -254,7 +254,7 @@ func TestCookieName(t *testing.T) {

body := url.Values{"topic": {"http://example.com/foo/1"}, "data": {"bar"}, "id": {"bar"}, "private": {"1"}}
req, err := http.NewRequest(http.MethodPost, "http://localhost:9080/.well-known/mercure", strings.NewReader(body.Encode()))
require.Nil(t, err)
require.NoError(t, err)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Origin", "http://localhost:9080")
req.AddCookie(&http.Cookie{Name: "foo", Value: publisherJWT})
Expand Down
Loading

0 comments on commit c8df627

Please sign in to comment.